codetether_agent/cognition/tool_router.rs

use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
use anyhow::{Result, anyhow};
use std::sync::{Arc, Mutex};
use uuid::Uuid;

use super::thinker::{CandleThinker, ThinkerConfig, ThinkerBackend};

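/// Configuration for the FunctionGemma tool-call router.
///
/// Disabled by default; see [`ToolRouterConfig::from_env`] for the
/// environment variables that populate each field.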
#[derive(Debug, Clone)]
pub struct ToolRouterConfig {
    /// Whether the router is active at all.
    pub enabled: bool,
    /// Path to the FunctionGemma model weights.
    pub model_path: Option<String>,
    /// Path to the matching tokenizer file.
    pub tokenizer_path: Option<String>,
    /// Model architecture identifier (defaults to "gemma3").
    pub arch: String,
    /// Device preference for the Candle runtime.
    pub device: super::thinker::CandleDevicePreference,
    /// Maximum number of tokens to generate per routing call.
    pub max_tokens: usize,
    /// Sampling temperature for the router model.
    pub temperature: f32,
}

impl Default for ToolRouterConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            model_path: None,
            tokenizer_path: None,
            arch: "gemma3".to_string(),
            device: super::thinker::CandleDevicePreference::Auto,
            max_tokens: 512,
            temperature: 0.1,
        }
    }
}

impl ToolRouterConfig {
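    /// Builds a config from `CODETETHER_TOOL_ROUTER_*` environment variables:
    ///
    /// - `CODETETHER_TOOL_ROUTER_ENABLED`: "1", "true", or "yes" enables the router.
    /// - `CODETETHER_TOOL_ROUTER_MODEL_PATH`: model weights path.
    /// - `CODETETHER_TOOL_ROUTER_TOKENIZER_PATH`: tokenizer path.
    /// - `CODETETHER_TOOL_ROUTER_ARCH`: architecture name (default `"gemma3"`).
    /// - `CODETETHER_TOOL_ROUTER_DEVICE`: device preference (default auto).
    /// - `CODETETHER_TOOL_ROUTER_MAX_TOKENS`: generation limit (default 512).
    /// - `CODETETHER_TOOL_ROUTER_TEMPERATURE`: sampling temperature (default 0.1).
    ///
    /// Missing or unparseable values fall back to the defaults above.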
    pub fn from_env() -> Self {
        let enabled = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
            .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
            .unwrap_or(false);

        Self {
            enabled,
            model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
            tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
            arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
                .unwrap_or_else(|_| "gemma3".to_string()),
            device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
                .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
                .unwrap_or(super::thinker::CandleDevicePreference::Auto),
            max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(512),
            temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(0.1),
        }
    }
}

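/// Builds a Gemma-style chat prompt that lists the available tool definitions
/// inside `<tools>` tags and instructs the model to emit each call as a JSON
/// object wrapped in `<tool_call></tool_call>` tags.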
fn build_functiongemma_prompt(
    assistant_text: &str,
    tools: &[ToolDefinition],
) -> String {
    let tool_defs: Vec<serde_json::Value> = tools
        .iter()
        .map(|t| {
            serde_json::json!({
                "name": t.name,
                "description": t.description,
                "parameters": t.parameters,
            })
        })
        .collect();

    let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());

    format!(
        "<start_of_turn>system\n\
         You are a function calling AI model. You are provided with function \
         signatures within <tools></tools> XML tags. You may call one or more \
         functions to assist with the user query. Don't make assumptions about \
         what values to plug into functions.\n\n\
         <tools>\n{tools_json}\n</tools>\n\n\
         For each function call return a JSON object with function name and \
         arguments within <tool_call></tool_call> XML tags as follows:\n\
         <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
         <end_of_turn>\n\
         <start_of_turn>user\n\
         {assistant_text}\n\
         <end_of_turn>\n\
         <start_of_turn>model\n"
    )
}

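/// A tool call extracted from FunctionGemma output: the tool name plus its
/// arguments serialised as a JSON string.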
#[derive(Debug, Clone)]
struct ParsedToolCall {
    name: String,
    arguments: String,
}

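/// Extracts every `<tool_call>...</tool_call>` block from the model output.
/// Blocks that fail to parse as JSON are skipped with a warning; calls with
/// an empty `name` are dropped silently.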
fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
    let mut calls = Vec::new();

    let mut remaining = text;
    while let Some(start) = remaining.find("<tool_call>") {
        remaining = &remaining[start + "<tool_call>".len()..];
        if let Some(end) = remaining.find("</tool_call>") {
            let block = remaining[..end].trim();
            remaining = &remaining[end + "</tool_call>".len()..];

            if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
                let name = value
                    .get("name")
                    .and_then(|n| n.as_str())
                    .unwrap_or("")
                    .to_string();
                let arguments = value
                    .get("arguments")
                    .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
                    .unwrap_or_else(|| "{}".to_string());

                if !name.is_empty() {
                    calls.push(ParsedToolCall { name, arguments });
                }
            } else {
                tracing::warn!(
                    block = %block,
                    "FunctionGemma produced unparseable tool_call block"
                );
            }
        } else {
            break;
        }
    }

    calls
}

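/// Routes text-only completions through a local FunctionGemma model so that
/// intended tool invocations come back as structured `ContentPart::ToolCall`
/// entries.
///
/// A minimal wiring sketch (marked `ignore`: `response` and `tools` are
/// assumed to come from the surrounding provider layer):
///
/// ```ignore
/// let config = ToolRouterConfig::from_env();
/// if let Some(router) = ToolCallRouter::from_config(&config)? {
///     let response = router.maybe_reformat(response, &tools).await;
/// }
/// ```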
pub struct ToolCallRouter {
    runtime: Arc<Mutex<CandleThinker>>,
}

impl std::fmt::Debug for ToolCallRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ToolCallRouter").finish()
    }
}

impl ToolCallRouter {
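    /// Builds a router from `config`. Returns `Ok(None)` when the router is
    /// disabled; returns an error when it is enabled but the model or
    /// tokenizer path is missing, or the Candle runtime fails to initialise.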
    pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
        if !config.enabled {
            tracing::debug!("FunctionGemma tool router is disabled");
            return Ok(None);
        }

        let model_path = config.model_path.as_ref().ok_or_else(|| {
            anyhow!(
                "CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled"
            )
        })?;
        let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
            anyhow!(
                "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
            )
        })?;

        let thinker_config = ThinkerConfig {
            enabled: true,
            backend: ThinkerBackend::Candle,
            candle_model_path: Some(model_path.clone()),
            candle_tokenizer_path: Some(tokenizer_path.clone()),
            candle_arch: Some(config.arch.clone()),
            candle_device: config.device,
            max_tokens: config.max_tokens,
            temperature: config.temperature,
            ..ThinkerConfig::default()
        };

        let runtime = CandleThinker::new(&thinker_config)?;
        tracing::info!(
            model_path = %model_path,
            arch = %config.arch,
            "FunctionGemma tool-call router initialised"
        );

        Ok(Some(Self {
            runtime: Arc::new(Mutex::new(runtime)),
        }))
    }
    pub async fn maybe_reformat(
        &self,
        response: CompletionResponse,
        tools: &[ToolDefinition],
    ) -> CompletionResponse {
        let has_tool_calls = response
            .message
            .content
            .iter()
            .any(|p| matches!(p, ContentPart::ToolCall { .. }));

        if has_tool_calls {
            return response;
        }

        if tools.is_empty() {
            return response;
        }

        let assistant_text: String = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("\n");

        if assistant_text.trim().is_empty() {
            return response;
        }

        match self.run_functiongemma(&assistant_text, tools).await {
            Ok(parsed) if !parsed.is_empty() => {
                tracing::info!(
                    num_calls = parsed.len(),
                    "FunctionGemma router produced tool calls from text-only response"
                );
                self.rewrite_response(response, parsed)
            }
            Ok(_) => response,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "FunctionGemma router failed; returning original response"
                );
                response
            }
        }
    }

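    /// Runs FunctionGemma generation on a blocking thread (the Candle runtime
    /// is synchronous and behind a mutex) and parses the output text into
    /// tool calls.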
    async fn run_functiongemma(
        &self,
        assistant_text: &str,
        tools: &[ToolDefinition],
    ) -> Result<Vec<ParsedToolCall>> {
        let prompt = build_functiongemma_prompt(assistant_text, tools);
        let runtime = Arc::clone(&self.runtime);

        let output = tokio::task::spawn_blocking(move || {
            let mut guard = runtime
                .lock()
                .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
            guard.think("", &prompt)
        })
        .await
        .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;

        Ok(parse_functiongemma_response(&output.text))
    }

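    /// Appends the parsed calls to the response content as `ContentPart::ToolCall`
    /// entries with fresh `fc_`-prefixed ids and marks the finish reason as
    /// `ToolCalls`.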
    fn rewrite_response(
        &self,
        mut response: CompletionResponse,
        calls: Vec<ParsedToolCall>,
    ) -> CompletionResponse {
        for call in calls {
            response.message.content.push(ContentPart::ToolCall {
                id: format!("fc_{}", Uuid::new_v4()),
                name: call.name,
                arguments: call.arguments,
            });
        }

        response.finish_reason = FinishReason::ToolCalls;
        response
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_single_tool_call() {
        let text = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
        assert!(calls[0].arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let text = r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(calls[1].name, "read_file");
    }

    #[test]
    fn parse_no_tool_calls() {
        let text = "I cannot help with that request.";
        let calls = parse_functiongemma_response(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let text = r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "list_dir");
    }

    #[test]
    fn parse_empty_name_skipped() {
        let text = r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        }];
        let prompt = build_functiongemma_prompt("Please read foo.rs", &tools);
        assert!(prompt.contains("<start_of_turn>system"));
        assert!(prompt.contains("read_file"));
        assert!(prompt.contains("<tools>"));
        assert!(prompt.contains("Please read foo.rs"));
        assert!(prompt.contains("<start_of_turn>model"));
    }

    #[test]
    fn config_defaults_disabled() {
        let config = ToolRouterConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.arch, "gemma3");
        assert_eq!(config.max_tokens, 512);
    }
}
466}