use serde_json::{Value, json};
use std::time::Duration;
use uuid::Uuid;
use crate::base::error::{TestError, TestsResult, ToolCallError};
use crate::mcp::client::McpClient;
use crate::ollama::chat::ChatMessage;
use crate::ollama::client::OllamaClient;
/// A tool invocation the model requested during the first turn of a [`Dialog`].
///
/// Equality and cloning are derived so results can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCall {
    /// Name of the tool the model asked for.
    pub name: String,
    /// The tool arguments, serialized as a JSON string by `Dialog::run`.
    pub args: String,
}
/// Outcome of a completed two-turn [`Dialog`]: which tool ran, with what, and what came back.
///
/// Equality and cloning are derived so results can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DialogResult {
    /// Name of the tool the model chose to call.
    pub tool: String,
    /// Arguments passed to the tool, serialized as a JSON string.
    pub args: String,
    /// The model's trimmed reply on the second turn, after seeing the tool output.
    pub model_response: String,
    /// Raw output returned by the MCP tool call.
    pub tool_response: String,
}
impl DialogResult {
fn from(tool_call: ToolCall, model_response: String, tool_response: String) -> Self {
Self {
tool: tool_call.name,
args: tool_call.args,
model_response,
tool_response,
}
}
}
/// Drives a two-turn exchange: the model picks a tool, the MCP server runs it,
/// and the model then answers with the tool output in its history.
pub struct Dialog {
    /// Ollama model name used for both turns.
    model: String,
    /// Tool definitions offered to the model on the first turn only.
    tools: Vec<Value>,
    /// Timeout handed to every MCP tool call.
    timeout: Duration,
    /// Chat client for the model.
    ollama: OllamaClient,
    /// Client used to execute the selected tool.
    mcp: McpClient,
}
impl Dialog {
    /// Creates a dialog driver from its clients, the model name, the tool
    /// definitions to offer the model, and the per-tool-call timeout.
    pub fn new(
        ollama: OllamaClient,
        mcp: McpClient,
        model: String,
        tools: Vec<Value>,
        timeout: Duration,
    ) -> Self {
        Self {
            ollama,
            mcp,
            model,
            tools,
            timeout,
        }
    }

    /// Runs one two-turn tool-calling exchange for `query`.
    ///
    /// 1. Sends `query` together with `self.tools`. The reply is expected to be a
    ///    JSON tool call, optionally wrapped in a Markdown code fence.
    /// 2. Executes that tool through MCP. Then it sends the history again (query,
    ///    assistant tool call, tool output) with no tools, to obtain the final answer.
    ///
    /// Both turns share one freshly generated session id.
    ///
    /// # Errors
    ///
    /// Returns `TestError::ToolCall` (with `tool`/`args` unset and `code` `-1`) when
    /// the first reply does not contain a recognizable tool call. Errors from the
    /// chat requests or the MCP call are propagated unchanged.
    pub async fn run(&self, query: &str) -> TestsResult<DialogResult> {
        let session_id = Uuid::new_v4().to_string();
        let mut messages = vec![ChatMessage::user(query)];

        tracing::debug!("Turn 1: sending query to model");
        let response = self
            .ollama
            .chat_with_history(&session_id, &self.model, &messages, &self.tools)
            .await?;
        let content = response.message.content.trim();
        tracing::debug!("Turn 1 response: {}", content);

        let json_str = extract_json(content);
        tracing::debug!("Extracted JSON: {}", json_str);

        let Some((name, args)) = parse_tool_call(&json_str) else {
            tracing::error!("Model did not call a tool: {}", content);
            return Err(TestError::ToolCall(ToolCallError {
                tool: None,
                args: None,
                code: -1,
            }));
        };

        tracing::debug!("Tool called: {}", name);
        // Serialize up front so `args` can be moved into the MCP call instead of cloned.
        let args_json = serde_json::to_string(&args).unwrap_or_default();
        let tool_response = self
            .mcp
            .call_tool(name.clone(), args, self.timeout)
            .await?;
        tracing::debug!("Tool result: {}", &tool_response);

        let tool_call = ToolCall {
            name,
            args: args_json,
        };
        messages.push(ChatMessage::assistant(content));
        messages.push(ChatMessage::tool(&tool_response));

        tracing::debug!("Turn 2: requesting final response");
        // No tools on the second turn: the model should answer from the tool output.
        let response = self
            .ollama
            .chat_with_history(&session_id, &self.model, &messages, &[])
            .await?;
        let model_response = response.message.content.trim().to_string();
        tracing::debug!("Model response: {}", model_response);

        Ok(DialogResult::from(tool_call, model_response, tool_response))
    }
}
/// Extracts the JSON payload from a model reply that may be wrapped in a Markdown code fence.
///
/// - Replies that do not start with a fence are returned unchanged.
/// - For fenced replies, the opening fence and its optional info string (`json`, `JSON`,
///   `tool_call`, ...) are removed.
/// - Everything from the last closing fence onward (including trailing prose) is dropped,
///   and the result is trimmed.
/// - An unterminated fence, e.g. from a truncated reply, still has its opening fence removed.
fn extract_json(content: &str) -> String {
    let Some(rest) = content.strip_prefix("```") else {
        return content.to_string();
    };

    let body = match rest.split_once('\n') {
        // Opening fence on its own line, optionally followed by a language tag.
        Some((info, tail))
            if info
                .trim()
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') =>
        {
            tail
        }
        // Single-line fence such as "```json{...}```", or JSON starting right after the fence.
        _ => rest.strip_prefix("json").unwrap_or(rest),
    };

    // Cut at the *last* fence so JSON string values containing backticks survive intact.
    let body = body.rfind("```").map_or(body, |end| &body[..end]);
    body.trim().to_string()
}
fn parse_tool_call(json_str: &str) -> Option<(String, Value)> {
let parsed = serde_json::from_str::<Value>(json_str).ok()?;
let name = parsed
.get("name")
.or_else(|| parsed.get("function"))
.or_else(|| parsed.get("tool"))
.or_else(|| parsed.get("method"))
.or_else(|| parsed.get("call"))
.and_then(|n| n.as_str())?;
let args = parsed.get("arguments").cloned().unwrap_or(json!({}));
Some((name.to_string(), args))
}