1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
use anyhow::Result;
use uuid;
use super::provider::ModelProvider;
use super::types::*;
// =============================================================================
// QueryEngine — The central orchestrator that sits between the agent loop
// and whichever model provider is active.
// =============================================================================
/// Orchestrator that sits between the agent loop and the active model provider.
///
/// Owns the provider (which can be replaced at runtime), the base system
/// prompt, the tool definitions offered to the model, and the per-turn
/// output token limit.
pub struct QueryEngine {
    /// Backend that performs the actual completion requests.
    provider: Box<dyn ModelProvider>,
    /// Base system prompt; tool-usage instructions may be appended to it per request.
    system_prompt: String,
    /// Tool definitions made available to the model.
    tools: Vec<ToolDefinition>,
    /// Maximum number of tokens the model may generate in one turn.
    max_tokens: u32,
}
impl QueryEngine {
/// Create a new QueryEngine with a provider, system prompt, and tool definitions.
pub fn new(
provider: Box<dyn ModelProvider>,
system_prompt: String,
tools: Vec<ToolDefinition>,
) -> Self {
Self {
provider,
system_prompt,
tools,
max_tokens: 8192,
}
}
/// Set the maximum number of tokens the model can generate per turn.
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
/// Send a query to the active model provider.
///
/// Takes the full conversation history and returns the model's response,
/// which may contain text and/or tool_use blocks.
pub async fn query(&self, messages: &[Message]) -> Result<CompletionResponse> {
let supports_tools = self.provider.supports_tools();
let mut system = self.system_prompt.clone();
if !supports_tools && !self.tools.is_empty() {
// Append tool definitions and XML instructions to the system prompt
let mut tools_desc = String::new();
tools_desc.push_str("\n\nYou have access to the following tools:\n");
for tool in &self.tools {
tools_desc.push_str(&format!(
"- {}: {} | Input Schema: {}\n",
tool.name, tool.description, tool.input_schema
));
}
tools_desc.push_str(
"\nTo use a tool, you MUST output a tool call using the following XML format:\n\
<tool_call>\n\
{\n\
\"name\": \"tool_name\",\n\
\"input\": {\n\
\"param1\": \"value1\"\n\
}\n\
}\n\
</tool_call>\n\n\
You can invoke multiple tools in a single turn if needed. The execution results will be returned to you."
);
system.push_str(&tools_desc);
}
let request = CompletionRequest {
messages: messages.to_vec(),
system,
tools: if supports_tools {
self.tools.clone()
} else {
vec![]
},
max_tokens: self.max_tokens,
};
let mut response = self.provider.complete(request).await?;
// If provider doesn't support tools, we parse tool calls from the text response
if !supports_tools {
let mut parsed_blocks = Vec::new();
let mut has_tool_use = false;
for block in &response.content {
if let ContentBlock::Text { text } = block {
let mut last_idx = 0;
let text_str = text.as_str();
while let Some(start_idx) = text_str[last_idx..].find("<tool_call>") {
let absolute_start = last_idx + start_idx;
let inner_start = absolute_start + "<tool_call>".len();
if let Some(end_idx) = text_str[inner_start..].find("</tool_call>") {
let absolute_end = inner_start + end_idx;
let json_str = &text_str[inner_start..absolute_end];
// Add the text before the tool call
let pre_text = &text_str[last_idx..absolute_start];
if !pre_text.trim().is_empty() {
parsed_blocks.push(ContentBlock::Text {
text: pre_text.to_string(),
});
}
// Parse the tool call JSON
if let Ok(val) = serde_json::from_str::<serde_json::Value>(json_str) {
if let Some(name) = val["name"].as_str() {
let input = val["input"].clone();
has_tool_use = true;
parsed_blocks.push(ContentBlock::ToolUse {
id: format!("toolu_{}", uuid::Uuid::new_v4()),
name: name.to_string(),
input,
});
}
}
last_idx = absolute_end + "</tool_call>".len();
} else {
break;
}
}
// Add remaining text
let post_text = &text_str[last_idx..];
if !post_text.trim().is_empty() || parsed_blocks.is_empty() {
parsed_blocks.push(ContentBlock::Text {
text: post_text.to_string(),
});
}
} else {
parsed_blocks.push(block.clone());
}
}
if has_tool_use {
response.content = parsed_blocks;
response.stop_reason = StopReason::ToolUse;
}
}
Ok(response)
}
/// Hot-swap the model provider at runtime (e.g., via `/model` command).
pub fn switch_provider(&mut self, provider: Box<dyn ModelProvider>) {
self.provider = provider;
}
/// Update the system prompt (e.g., when adding memory context).
pub fn set_system_prompt(&mut self, prompt: String) {
self.system_prompt = prompt;
}
/// Update the tool definitions.
pub fn set_tools(&mut self, tools: Vec<ToolDefinition>) {
self.tools = tools;
}
/// Get the name of the currently active provider.
pub fn provider_name(&self) -> &str {
self.provider.name()
}
/// Get the model ID of the currently active provider.
pub fn model_id(&self) -> &str {
self.provider.model_id()
}
/// Whether the current provider natively supports tool calls.
pub fn supports_tools(&self) -> bool {
self.provider.supports_tools()
}
}