use crate::types::*;
use crate::error::AgentError;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Conversation driver that talks to a chat-completions endpoint and runs
/// registered tool executors when the model asks for them.
pub struct QueryEngine {
    /// Settings supplied to `QueryEngine::new`.
    config: QueryEngineConfig,
    /// HTTP client used for every API request this engine makes.
    http_client: reqwest::Client,
    /// Tool name -> executor, guarded by a `Mutex`; `execute_tool` clones the
    /// `Arc` out and releases the guard before awaiting the executor.
    tool_executors: Mutex<HashMap<String, Arc<ToolExecutor>>>,
    /// Conversation history, in the order it is sent to the API.
    messages: Vec<crate::types::Message>,
    /// Incremented once per `submit_message` call.
    turn_count: u32,
    /// Running token totals taken from API responses.
    total_usage: TokenUsage,
    /// Running cost in USD. NOTE(review): starts at 0.0 and is not updated
    /// anywhere in this module.
    total_cost: f64,
}
/// Heap-allocated, pinned, type-erased future that is `Send` and resolves to `T`.
type BoxFuture<T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
/// Signature every registered tool must implement: it receives the tool's JSON
/// input plus a borrowed `ToolContext` and returns a boxed future yielding the
/// tool's result.
type ToolExecutor = dyn Fn(serde_json::Value, &ToolContext) -> BoxFuture<Result<ToolResult, AgentError>>
    + Send
    + Sync;
/// Construction-time settings for a `QueryEngine`.
pub struct QueryEngineConfig {
    /// Model identifier sent as `model` in each request.
    pub model: String,
    /// Sent as a `Bearer` token; `submit_message` returns `AgentError::Api`
    /// when this is `None`.
    pub api_key: Option<String>,
    /// API root that `/v1/chat/completions` is appended to; falls back to
    /// `https://api.anthropic.com` when `None`.
    pub base_url: Option<String>,
    /// Sent as `max_tokens` in each request.
    pub max_tokens: u32,
    /// When set, sent as a leading `system` message.
    pub system_prompt: Option<String>,
    /// Advertised to the model as `function` tools (name, description, and
    /// `input_schema` as `parameters`).
    pub tools: Vec<ToolDefinition>,
    /// NOTE(review): not read by `QueryEngine` in this module; confirm
    /// whether it is meant to bound tool rounds or user turns.
    pub max_turns: u32,
    /// NOTE(review): not read by `QueryEngine` in this module.
    pub max_budget_usd: Option<f64>,
    /// Working directory placed in the `ToolContext` handed to executors.
    pub cwd: String,
    /// Permission hook for tool use. NOTE(review): not consulted before tools
    /// run in this module.
    pub can_use_tool: Option<fn(ToolDefinition, serde_json::Value) -> bool>,
}
impl QueryEngine {
    /// Upper bound on API round-trips made by one `submit_message` call while the
    /// model keeps answering with tool calls.
    const MAX_TOOL_ROUNDS: u32 = 10;

    /// API root used when `QueryEngineConfig::base_url` is `None`.
    const DEFAULT_BASE_URL: &'static str = "https://api.anthropic.com";

    /// Creates an engine with an empty history, zeroed usage counters and no
    /// registered tool executors.
    pub fn new(config: QueryEngineConfig) -> Self {
        Self {
            config,
            messages: vec![],
            turn_count: 0,
            total_usage: TokenUsage {
                input_tokens: 0,
                output_tokens: 0,
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
            total_cost: 0.0,
            http_client: reqwest::Client::new(),
            tool_executors: Mutex::new(HashMap::new()),
        }
    }

    /// Registers `executor` under `name`, replacing any executor previously
    /// registered under the same name.
    pub fn register_tool<F>(&mut self, name: String, executor: F)
    where
        F: Fn(serde_json::Value, &ToolContext) -> BoxFuture<Result<ToolResult, AgentError>>
            + Send
            + Sync
            + 'static,
    {
        // `&mut self` already guarantees exclusive access, so no locking is needed.
        // A poisoned map is still a structurally valid map; recover it instead of
        // panicking.
        self.tool_executors
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(name, Arc::new(executor));
    }

    /// Replaces the conversation history with `messages`.
    pub fn set_messages(&mut self, messages: Vec<crate::types::Message>) {
        self.messages = messages;
    }

    /// Runs the executor registered under `name` with `input`.
    ///
    /// The executor receives a `ToolContext` carrying the configured `cwd` and no
    /// abort signal.
    ///
    /// # Errors
    /// Returns `AgentError::Tool` when no executor is registered under `name`;
    /// otherwise returns whatever the executor produces.
    pub async fn execute_tool(
        &self,
        name: &str,
        input: serde_json::Value,
    ) -> Result<ToolResult, AgentError> {
        let context = ToolContext {
            cwd: self.config.cwd.clone(),
            abort_signal: None,
        };
        // Clone the `Arc` out inside a scope so the `std::sync::Mutex` guard is
        // released before the executor's future is awaited.
        let executor = {
            let executors = self
                .tool_executors
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let found = executors.get(name).cloned();
            found
        };
        match executor {
            Some(executor) => executor(input, &context).await,
            None => Err(AgentError::Tool(format!("Tool '{}' not found", name))),
        }
    }

    /// Number of `submit_message` calls made on this engine.
    pub fn get_turn_count(&self) -> u32 {
        self.turn_count
    }

    /// Returns a copy of the conversation history.
    pub fn get_messages(&self) -> Vec<crate::types::Message> {
        self.messages.clone()
    }

    /// Sends `prompt` to the model and keeps executing requested tools until the
    /// model replies without tool calls, returning that reply's text.
    ///
    /// The prompt, every tool call, every tool result and the final reply are
    /// appended to the history. Token usage from every response (including
    /// tool-calling ones) is added to the running total. At most
    /// `MAX_TOOL_ROUNDS` requests are made; if every one of them ends in tool
    /// calls, `"Max tool execution turns reached"` is returned.
    ///
    /// The user message is recorded and the turn counted before the API key is
    /// checked, so they remain even when this returns an error.
    ///
    /// NOTE(review): `max_turns`, `max_budget_usd` and `can_use_tool` from the
    /// config are not consulted here, and `total_cost` is never updated.
    ///
    /// # Errors
    /// `AgentError::Api` when no API key is configured, the request cannot be
    /// sent, the server answers with a non-success status, or the body is not
    /// valid JSON. Errors from tool executors are propagated unchanged.
    pub async fn submit_message(&mut self, prompt: &str) -> Result<String, AgentError> {
        self.messages.push(crate::types::Message {
            role: crate::types::MessageRole::User,
            content: prompt.to_string(),
            ..Default::default()
        });
        self.turn_count += 1;

        // Everything below is constant across rounds, so resolve it once.
        let api_key = self
            .config
            .api_key
            .clone()
            .ok_or_else(|| AgentError::Api("API key not provided".to_string()))?;
        let base_url = self
            .config
            .base_url
            .as_deref()
            .unwrap_or(Self::DEFAULT_BASE_URL);
        // Trim a trailing '/' so "https://host/" does not yield "https://host//v1/...".
        let url = format!("{}/v1/chat/completions", base_url.trim_end_matches('/'));
        let tools = self.tool_specs();

        for _ in 0..Self::MAX_TOOL_ROUNDS {
            let response_json = self.send_request(&url, &api_key, tools.as_ref()).await?;
            // Count usage from every response, including ones that only request tools.
            self.record_usage(&extract_usage(&response_json));

            let tool_calls = extract_tool_calls(&response_json);
            if tool_calls.is_empty() {
                let response_text = extract_response_text(&response_json);
                self.messages.push(crate::types::Message {
                    role: crate::types::MessageRole::Assistant,
                    content: response_text.clone(),
                    ..Default::default()
                });
                return Ok(response_text);
            }
            for tool_call in &tool_calls {
                self.run_tool_call(tool_call).await?;
            }
        }
        // Every round ended in tool calls, so the newest assistant message is a
        // synthetic "Calling tool" record rather than an answer; report the limit.
        Ok("Max tool execution turns reached".to_string())
    }

    /// Builds the `tools` request field from the configured tool definitions, or
    /// returns `None` when no tools are configured.
    fn tool_specs(&self) -> Option<serde_json::Value> {
        if self.config.tools.is_empty() {
            return None;
        }
        let tools: Vec<serde_json::Value> = self
            .config
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema
                    }
                })
            })
            .collect();
        Some(serde_json::Value::Array(tools))
    }

    /// Sends one chat-completions request containing the current history and
    /// returns the parsed JSON body.
    ///
    /// # Errors
    /// `AgentError::Api` on transport failure, non-success status (the response
    /// body is included in the message), unreadable body, or invalid JSON.
    async fn send_request(
        &self,
        url: &str,
        api_key: &str,
        tools: Option<&serde_json::Value>,
    ) -> Result<serde_json::Value, AgentError> {
        let api_messages = self.build_api_messages()?;
        let mut request_body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": self.config.max_tokens,
            "messages": api_messages,
        });
        if let Some(tools) = tools {
            request_body["tools"] = tools.clone();
        }

        let response = self
            .http_client
            .post(url)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AgentError::Api(format!("Send error: {}", e)))?;

        let status_code = response.status();
        if !status_code.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(AgentError::Api(format!(
                "API error {}: {}",
                status_code, error_text
            )));
        }

        let response_text = response
            .text()
            .await
            .map_err(|e| AgentError::Api(format!("Failed to read response text: {}", e)))?;
        serde_json::from_str(&response_text)
            .map_err(|e| AgentError::Api(format!("Failed to parse response: {}", e)))
    }

    /// Executes one normalized tool call (as produced by `extract_tool_calls`).
    /// Then appends the assistant tool-call message and the tool result to the
    /// history.
    ///
    /// # Errors
    /// Propagates any error from `execute_tool`; nothing is appended in that case.
    async fn run_tool_call(&mut self, tool_call: &serde_json::Value) -> Result<(), AgentError> {
        let tool_name = tool_call.get("name").and_then(|n| n.as_str()).unwrap_or("");
        let tool_args = tool_call
            .get("arguments")
            .cloned()
            .unwrap_or(serde_json::Value::Null);
        let tool_call_id = tool_call
            .get("id")
            .and_then(|id| id.as_str())
            .map(String::from);

        let tool_result = self.execute_tool(tool_name, tool_args.clone()).await?;

        // The assistant message references the call via `tool_calls[].id`;
        // `tool_call_id` belongs only on the `tool` reply below.
        self.messages.push(crate::types::Message {
            role: crate::types::MessageRole::Assistant,
            content: format!("Calling tool: {} with args: {:?}", tool_name, tool_args),
            tool_calls: Some(vec![crate::types::ToolCall {
                id: tool_call_id.clone().unwrap_or_default(),
                name: tool_name.to_string(),
                arguments: tool_args,
            }]),
            ..Default::default()
        });
        self.messages.push(crate::types::Message {
            role: crate::types::MessageRole::Tool,
            content: tool_result.content,
            tool_call_id,
            ..Default::default()
        });
        Ok(())
    }

    /// Adds one response's token counts to `total_usage`. Cache counters start
    /// from zero the first time a response reports them.
    fn record_usage(&mut self, usage: &TokenUsage) {
        self.total_usage.input_tokens += usage.input_tokens;
        self.total_usage.output_tokens += usage.output_tokens;
        if let Some(tokens) = usage.cache_creation_input_tokens {
            *self
                .total_usage
                .cache_creation_input_tokens
                .get_or_insert(0) += tokens;
        }
        if let Some(tokens) = usage.cache_read_input_tokens {
            *self.total_usage.cache_read_input_tokens.get_or_insert(0) += tokens;
        }
    }

    /// Converts the configured system prompt and the history into
    /// chat-completions message objects.
    ///
    /// Tool-call arguments are re-encoded as JSON strings, the form
    /// chat-completions uses for `function.arguments`.
    fn build_api_messages(&self) -> Result<Vec<serde_json::Value>, AgentError> {
        let mut api_messages = Vec::with_capacity(self.messages.len() + 1);
        if let Some(system_prompt) = &self.config.system_prompt {
            api_messages.push(serde_json::json!({
                "role": "system",
                "content": system_prompt
            }));
        }
        for msg in &self.messages {
            let role_str = match msg.role {
                crate::types::MessageRole::User => "user",
                crate::types::MessageRole::Assistant => "assistant",
                crate::types::MessageRole::Tool => "tool",
            };
            let mut msg_json = serde_json::json!({
                "role": role_str,
                "content": msg.content
            });
            if let Some(tool_call_id) = &msg.tool_call_id {
                msg_json["tool_call_id"] = serde_json::json!(tool_call_id);
            }
            // OpenAI-compatible endpoints may reject an empty `tool_calls` array,
            // so only emit the field when there is at least one call.
            if let Some(tool_calls) = msg.tool_calls.as_ref().filter(|calls| !calls.is_empty()) {
                let tc_json: Vec<serde_json::Value> = tool_calls
                    .iter()
                    .map(|tc| {
                        serde_json::json!({
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": tc.arguments.to_string()
                            }
                        })
                    })
                    .collect();
                msg_json["tool_calls"] = serde_json::Value::Array(tc_json);
            }
            api_messages.push(msg_json);
        }
        Ok(api_messages)
    }
}
/// Extracts the tool calls from the first choice of a chat-completions response.
///
/// Each returned value is normalized to an object with:
/// - `name`: `function.name`, or `null` when absent;
/// - `arguments`: `function.arguments` decoded from its JSON-string form when it
///   parses, otherwise the raw value (`null` when absent);
/// - `id`: copied only when the call carries one.
///
/// Returns an empty vector when there are no tool calls or the response does not
/// have the expected `choices[0].message.tool_calls` shape.
fn extract_tool_calls(response: &serde_json::Value) -> Vec<serde_json::Value> {
    /// Normalizes one raw `tool_calls[]` entry into `{name, arguments, id?}`.
    fn normalize(call: &serde_json::Value) -> serde_json::Value {
        let function = call.get("function");
        let name = function
            .and_then(|f| f.get("name"))
            .cloned()
            .unwrap_or(serde_json::Value::Null);
        let arguments = match function.and_then(|f| f.get("arguments")) {
            // Arguments arrive as a JSON-encoded string; decode when possible and
            // keep the raw string otherwise. The fallback clone is lazy so the
            // common (parseable) path does not copy the string.
            Some(serde_json::Value::String(raw)) => serde_json::from_str(raw)
                .unwrap_or_else(|_| serde_json::Value::String(raw.clone())),
            Some(other) => other.clone(),
            None => serde_json::Value::Null,
        };
        let mut normalized = serde_json::json!({
            "name": name,
            "arguments": arguments,
        });
        if let Some(id) = call.get("id") {
            normalized["id"] = id.clone();
        }
        normalized
    }

    response
        .get("choices")
        .and_then(|choices| choices.as_array())
        .and_then(|choices| choices.first())
        .and_then(|choice| choice.get("message"))
        .and_then(|message| message.get("tool_calls"))
        .and_then(|calls| calls.as_array())
        .map(|calls| calls.iter().map(normalize).collect())
        .unwrap_or_default()
}
/// Returns the assistant text carried by a model response.
///
/// Looks first for a string at `choices[0].message.content` (chat-completions
/// shape). Failing that, it uses the `text` of the first `{"type": "text"}` block
/// in a top-level `content` array (Anthropic Messages shape). Returns an empty
/// string when neither is present.
fn extract_response_text(response: &serde_json::Value) -> String {
    let chat_completion_text = response
        .get("choices")
        .and_then(|choices| choices.as_array())
        .and_then(|choices| choices.first())
        .and_then(|choice| choice.get("message"))
        .and_then(|message| message.get("content"))
        .and_then(|content| content.as_str());
    if let Some(text) = chat_completion_text {
        return text.to_string();
    }

    response
        .get("content")
        .and_then(|content| content.as_array())
        .into_iter()
        .flatten()
        .filter(|block| block.get("type").and_then(|kind| kind.as_str()) == Some("text"))
        .find_map(|block| block.get("text").and_then(|text| text.as_str()))
        .map(str::to_string)
        .unwrap_or_default()
}
/// Normalizes token usage from a model response into `TokenUsage`.
///
/// Two `usage` shapes are recognized:
/// - chat-completions: `prompt_tokens` (input) and `completion_tokens` (output);
/// - Anthropic Messages: `input_tokens`, `output_tokens`, and the optional
///   `cache_creation_input_tokens` / `cache_read_input_tokens`.
///
/// The chat-completions shape is chosen when either of its counters is present.
/// Missing counters default to 0, and missing cache counters to `None`. A
/// response without `usage` yields all zeros.
fn extract_usage(response: &serde_json::Value) -> TokenUsage {
    let usage = response.get("usage");
    let count = |key: &str| usage.and_then(|u| u.get(key)).and_then(|v| v.as_u64());

    let prompt_tokens = count("prompt_tokens");
    let completion_tokens = count("completion_tokens");
    if prompt_tokens.is_some() || completion_tokens.is_some() {
        // Input is the prompt alone. Completion tokens are output and must not
        // also be added to the input count.
        return TokenUsage {
            input_tokens: prompt_tokens.unwrap_or(0),
            output_tokens: completion_tokens.unwrap_or(0),
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
        };
    }

    TokenUsage {
        input_tokens: count("input_tokens").unwrap_or(0),
        output_tokens: count("output_tokens").unwrap_or(0),
        cache_creation_input_tokens: count("cache_creation_input_tokens"),
        cache_read_input_tokens: count("cache_read_input_tokens"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config shared by these tests. `api_key: None` makes `submit_message` fail
    /// before any request is sent, so no test touches the network.
    fn test_config() -> QueryEngineConfig {
        QueryEngineConfig {
            cwd: "/tmp".to_string(),
            model: "claude-sonnet-4-6".to_string(),
            api_key: None,
            base_url: None,
            tools: vec![],
            system_prompt: None,
            max_turns: 10,
            max_budget_usd: None,
            max_tokens: 16384,
            can_use_tool: None,
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let engine = QueryEngine::new(test_config());
        assert_eq!(engine.get_turn_count(), 0);
        assert!(engine.get_messages().is_empty());
    }

    #[tokio::test]
    async fn test_engine_submit_message() {
        let mut engine = QueryEngine::new(test_config());
        let result = engine.submit_message("Hello").await;
        assert!(result.is_err());
        // The prompt is recorded and the turn counted before the key check fails.
        assert_eq!(engine.get_turn_count(), 1);
        assert_eq!(engine.get_messages().len(), 1);
    }

    #[tokio::test]
    async fn test_execute_unknown_tool() {
        let engine = QueryEngine::new(test_config());
        let result = engine
            .execute_tool("missing", serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(AgentError::Tool(_))));
    }

    #[test]
    fn test_extract_tool_calls_decodes_string_arguments() {
        let response = serde_json::json!({
            "choices": [{
                "message": {
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "read_file",
                            "arguments": "{\"path\":\"a.txt\"}"
                        }
                    }]
                }
            }]
        });
        let calls = extract_tool_calls(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0]["name"], "read_file");
        assert_eq!(calls[0]["id"], "call_1");
        assert_eq!(calls[0]["arguments"]["path"], "a.txt");

        let no_calls = serde_json::json!({ "choices": [{ "message": { "content": "hi" } }] });
        assert!(extract_tool_calls(&no_calls).is_empty());
    }

    #[test]
    fn test_extract_response_text_shapes() {
        let chat = serde_json::json!({ "choices": [{ "message": { "content": "hello" } }] });
        assert_eq!(extract_response_text(&chat), "hello");

        let messages = serde_json::json!({
            "content": [{ "type": "tool_use" }, { "type": "text", "text": "hi" }]
        });
        assert_eq!(extract_response_text(&messages), "hi");

        assert_eq!(extract_response_text(&serde_json::json!({})), "");
    }
}