// codetether_agent/cognition/tool_router.rs
use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerBackend, ThinkerConfig};
19
/// Configuration for the FunctionGemma tool-call router.
///
/// Built either via `Default` (router disabled) or from
/// `CODETETHER_TOOL_ROUTER_*` environment variables via `from_env`.
#[derive(Debug, Clone)]
pub struct ToolRouterConfig {
    /// Master switch; when `false`, `ToolCallRouter::from_config` returns `Ok(None)`.
    pub enabled: bool,
    /// Filesystem path to the model weights; required when `enabled` is true.
    pub model_path: Option<String>,
    /// Filesystem path to the tokenizer; required when `enabled` is true.
    pub tokenizer_path: Option<String>,
    /// Model architecture identifier passed to the Candle runtime (default "gemma3").
    pub arch: String,
    /// Preferred compute device for Candle inference.
    pub device: super::thinker::CandleDevicePreference,
    /// Generation cap per routing call (default 512).
    pub max_tokens: usize,
    /// Sampling temperature (default 0.1).
    pub temperature: f32,
}
40
41impl Default for ToolRouterConfig {
42 fn default() -> Self {
43 Self {
44 enabled: false,
45 model_path: None,
46 tokenizer_path: None,
47 arch: "gemma3".to_string(),
48 device: super::thinker::CandleDevicePreference::Auto,
49 max_tokens: 512,
50 temperature: 0.1,
51 }
52 }
53}
54
55impl ToolRouterConfig {
56 pub fn from_env() -> Self {
68 let enabled = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
69 .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
70 .unwrap_or(false);
71
72 Self {
73 enabled,
74 model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
75 tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
76 arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
77 .unwrap_or_else(|_| "gemma3".to_string()),
78 device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
79 .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
80 .unwrap_or(super::thinker::CandleDevicePreference::Auto),
81 max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
82 .ok()
83 .and_then(|v| v.parse().ok())
84 .unwrap_or(512),
85 temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
86 .ok()
87 .and_then(|v| v.parse().ok())
88 .unwrap_or(0.1),
89 }
90 }
91}
92
93fn build_functiongemma_prompt(assistant_text: &str, tools: &[ToolDefinition]) -> String {
100 let tool_defs: Vec<serde_json::Value> = tools
102 .iter()
103 .map(|t| {
104 serde_json::json!({
105 "name": t.name,
106 "description": t.description,
107 "parameters": t.parameters,
108 })
109 })
110 .collect();
111
112 let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
113
114 format!(
123 "<start_of_turn>system\n\
124 You are a function calling AI model. You are provided with function \
125 signatures within <tools></tools> XML tags. You may call one or more \
126 functions to assist with the user query. Don't make assumptions about \
127 what values to plug into functions.\n\n\
128 <tools>\n{tools_json}\n</tools>\n\n\
129 For each function call return a JSON object with function name and \
130 arguments within <tool_call></tool_call> XML tags as follows:\n\
131 <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
132 <end_of_turn>\n\
133 <start_of_turn>user\n\
134 {assistant_text}\n\
135 <end_of_turn>\n\
136 <start_of_turn>model\n"
137 )
138}
139
/// One function invocation extracted from FunctionGemma output.
#[derive(Debug, Clone)]
struct ParsedToolCall {
    // Target tool name; parsing skips calls whose name is empty.
    name: String,
    // Arguments re-serialised as a JSON string ("{}" when absent).
    arguments: String,
}
148
149fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
160 let mut calls = Vec::new();
161
162 let mut remaining = text;
164 while let Some(start) = remaining.find("<tool_call>") {
165 remaining = &remaining[start + "<tool_call>".len()..];
166 if let Some(end) = remaining.find("</tool_call>") {
167 let block = remaining[..end].trim();
168 remaining = &remaining[end + "</tool_call>".len()..];
169
170 if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
172 let name = value
173 .get("name")
174 .and_then(|n| n.as_str())
175 .unwrap_or("")
176 .to_string();
177 let arguments = value
178 .get("arguments")
179 .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
180 .unwrap_or_else(|| "{}".to_string());
181
182 if !name.is_empty() {
183 calls.push(ParsedToolCall { name, arguments });
184 }
185 } else {
186 tracing::warn!(
187 block = %block,
188 "FunctionGemma produced unparseable tool_call block"
189 );
190 }
191 } else {
192 break; }
194 }
195
196 calls
197}
198
/// Routes text-only LLM responses through a local FunctionGemma model to
/// recover structured tool calls that the upstream provider failed to emit.
pub struct ToolCallRouter {
    // Mutex: inference needs exclusive (&mut) access to the thinker.
    // Arc: the runtime is moved into spawn_blocking closures.
    runtime: Arc<Mutex<CandleThinker>>,
}
207
208impl std::fmt::Debug for ToolCallRouter {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.debug_struct("ToolCallRouter").finish()
211 }
212}
213
214impl ToolCallRouter {
215 pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
219 if !config.enabled {
220 tracing::debug!("FunctionGemma tool router is disabled");
221 return Ok(None);
222 }
223
224 let model_path = config.model_path.as_ref().ok_or_else(|| {
225 anyhow!("CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled")
226 })?;
227 let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
228 anyhow!(
229 "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
230 )
231 })?;
232
233 let thinker_config = ThinkerConfig {
235 enabled: true,
236 backend: ThinkerBackend::Candle,
237 candle_model_path: Some(model_path.clone()),
238 candle_tokenizer_path: Some(tokenizer_path.clone()),
239 candle_arch: Some(config.arch.clone()),
240 candle_device: config.device,
241 max_tokens: config.max_tokens,
242 temperature: config.temperature,
243 ..ThinkerConfig::default()
244 };
245
246 let runtime = CandleThinker::new(&thinker_config)?;
247 tracing::info!(
248 model_path = %model_path,
249 arch = %config.arch,
250 "FunctionGemma tool-call router initialised"
251 );
252
253 Ok(Some(Self {
254 runtime: Arc::new(Mutex::new(runtime)),
255 }))
256 }
257
258 pub async fn maybe_reformat(
267 &self,
268 response: CompletionResponse,
269 tools: &[ToolDefinition],
270 ) -> CompletionResponse {
271 let has_tool_calls = response
273 .message
274 .content
275 .iter()
276 .any(|p| matches!(p, ContentPart::ToolCall { .. }));
277
278 if has_tool_calls {
279 return response;
280 }
281
282 if tools.is_empty() {
284 return response;
285 }
286
287 let assistant_text: String = response
289 .message
290 .content
291 .iter()
292 .filter_map(|p| match p {
293 ContentPart::Text { text } => Some(text.as_str()),
294 _ => None,
295 })
296 .collect::<Vec<_>>()
297 .join("\n");
298
299 if assistant_text.trim().is_empty() {
300 return response;
301 }
302
303 match self.run_functiongemma(&assistant_text, tools).await {
305 Ok(parsed) if !parsed.is_empty() => {
306 tracing::info!(
307 num_calls = parsed.len(),
308 "FunctionGemma router produced tool calls from text-only response"
309 );
310 self.rewrite_response(response, parsed)
311 }
312 Ok(_) => {
313 response
315 }
316 Err(e) => {
317 tracing::warn!(
318 error = %e,
319 "FunctionGemma router failed; returning original response"
320 );
321 response
322 }
323 }
324 }
325
326 async fn run_functiongemma(
328 &self,
329 assistant_text: &str,
330 tools: &[ToolDefinition],
331 ) -> Result<Vec<ParsedToolCall>> {
332 let prompt = build_functiongemma_prompt(assistant_text, tools);
333 let runtime = Arc::clone(&self.runtime);
334
335 let output = tokio::task::spawn_blocking(move || {
336 let mut guard = runtime
337 .lock()
338 .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
339 guard.think("", &prompt)
344 })
345 .await
346 .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
347
348 Ok(parse_functiongemma_response(&output.text))
349 }
350
351 fn rewrite_response(
353 &self,
354 mut response: CompletionResponse,
355 calls: Vec<ParsedToolCall>,
356 ) -> CompletionResponse {
357 for call in calls {
359 response.message.content.push(ContentPart::ToolCall {
360 id: format!("fc_{}", Uuid::new_v4()),
361 name: call.name,
362 arguments: call.arguments,
363 });
364 }
365
366 response.finish_reason = FinishReason::ToolCalls;
368 response
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_single_tool_call() {
        let input = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#;
        let parsed = parse_functiongemma_response(input);
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "read_file");
        assert!(parsed[0].arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        // Leading prose before the first tag must not confuse the scanner.
        let input = r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#;
        let parsed = parse_functiongemma_response(input);
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].name, "read_file");
        assert_eq!(parsed[1].name, "read_file");
    }

    #[test]
    fn parse_no_tool_calls() {
        let parsed = parse_functiongemma_response("I cannot help with that request.");
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        // The bad first block is dropped; parsing continues to the good one.
        let input = r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#;
        let parsed = parse_functiongemma_response(input);
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "list_dir");
    }

    #[test]
    fn parse_empty_name_skipped() {
        let input = r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#;
        assert!(parse_functiongemma_response(input).is_empty());
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        }];
        let prompt = build_functiongemma_prompt("Please read foo.rs", &tools);
        for needle in [
            "<start_of_turn>system",
            "read_file",
            "<tools>",
            "Please read foo.rs",
            "<start_of_turn>model",
        ] {
            assert!(prompt.contains(needle), "prompt missing {needle:?}");
        }
    }

    #[test]
    fn config_defaults_disabled() {
        let config = ToolRouterConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.arch, "gemma3");
        assert_eq!(config.max_tokens, 512);
    }
}