Skip to main content

codetether_agent/cognition/
tool_router.rs

//! FunctionGemma-powered hybrid tool-call router.
//!
//! Sits between the primary LLM response and the tool-extraction step in the
//! session agentic loop. When the primary LLM returns text-only output that
//! *describes* tool calls without using structured `ContentPart::ToolCall`
//! entries, the router passes the text plus the available tool definitions
//! through a local FunctionGemma model (via Candle) and emits properly
//! formatted `ContentPart::ToolCall` entries.
//!
//! **Feature-gated**: this module only compiles when the `functiongemma` cargo
//! feature is enabled, so default builds are unaffected in binary size.
12
13use crate::provider::{CompletionResponse, ContentPart, FinishReason, ToolDefinition};
14use anyhow::{Result, anyhow};
15use std::sync::{Arc, Mutex};
16use uuid::Uuid;
17
18use super::thinker::{CandleThinker, ThinkerBackend, ThinkerConfig};
19
20// ── Configuration ────────────────────────────────────────────────────────────
21
22/// Environment-variable driven configuration for the tool-call router.
23#[derive(Debug, Clone)]
24pub struct ToolRouterConfig {
25    /// Whether the router is active.  Default: `false`.
26    pub enabled: bool,
27    /// Filesystem path to the FunctionGemma GGUF model.
28    pub model_path: Option<String>,
29    /// Filesystem path to the matching tokenizer.json.
30    pub tokenizer_path: Option<String>,
31    /// Architecture hint (default: `"gemma3"`).
32    pub arch: String,
33    /// Device preference (auto / cpu / cuda).
34    pub device: super::thinker::CandleDevicePreference,
35    /// Max tokens for the FunctionGemma response.
36    /// FunctionGemma only outputs `<tool_call>` JSON blocks — 128 is generous.
37    pub max_tokens: usize,
38    /// Temperature for FunctionGemma sampling.
39    pub temperature: f32,
40}
41
42impl Default for ToolRouterConfig {
43    fn default() -> Self {
44        Self {
45            enabled: false,
46            model_path: None,
47            tokenizer_path: None,
48            arch: "gemma3".to_string(),
49            device: super::thinker::CandleDevicePreference::Auto,
50            max_tokens: 128,
51            temperature: 0.1,
52        }
53    }
54}
55
56impl ToolRouterConfig {
57    /// Build from environment variables.
58    ///
59    /// | Variable | Description |
60    /// |----------|-------------|
61    /// | `CODETETHER_TOOL_ROUTER_ENABLED` | `true` / `1` to activate |
62    /// | `CODETETHER_TOOL_ROUTER_MODEL_PATH` | Path to `.gguf` model |
63    /// | `CODETETHER_TOOL_ROUTER_TOKENIZER_PATH` | Path to `tokenizer.json` |
64    /// | `CODETETHER_TOOL_ROUTER_ARCH` | Architecture hint (default: `gemma3`) |
65    /// | `CODETETHER_TOOL_ROUTER_DEVICE` | `auto` / `cpu` / `cuda` |
66    /// | `CODETETHER_TOOL_ROUTER_MAX_TOKENS` | Max decode tokens (default: 512) |
67    /// | `CODETETHER_TOOL_ROUTER_TEMPERATURE` | Sampling temp (default: 0.1) |
68    pub fn from_env() -> Self {
69        let enabled = std::env::var("CODETETHER_TOOL_ROUTER_ENABLED")
70            .map(|v| matches!(v.as_str(), "1" | "true" | "yes"))
71            .unwrap_or(false);
72
73        Self {
74            enabled,
75            model_path: std::env::var("CODETETHER_TOOL_ROUTER_MODEL_PATH").ok(),
76            tokenizer_path: std::env::var("CODETETHER_TOOL_ROUTER_TOKENIZER_PATH").ok(),
77            arch: std::env::var("CODETETHER_TOOL_ROUTER_ARCH")
78                .unwrap_or_else(|_| "gemma3".to_string()),
79            device: std::env::var("CODETETHER_TOOL_ROUTER_DEVICE")
80                .map(|v| super::thinker::CandleDevicePreference::from_env(&v))
81                .unwrap_or(super::thinker::CandleDevicePreference::Auto),
82            max_tokens: std::env::var("CODETETHER_TOOL_ROUTER_MAX_TOKENS")
83                .ok()
84                .and_then(|v| v.parse().ok())
85                .unwrap_or(128),
86            temperature: std::env::var("CODETETHER_TOOL_ROUTER_TEMPERATURE")
87                .ok()
88                .and_then(|v| v.parse().ok())
89                .unwrap_or(0.1),
90        }
91    }
92}
93
94// ── Prompt formatting ────────────────────────────────────────────────────────
95
96/// Serialize tool definitions into FunctionGemma's expected chat template.
97///
98/// FunctionGemma expects tools as a JSON list in the system turn, followed by
99/// the user's intent.  The model produces structured JSON function call output.
100fn build_functiongemma_prompt(assistant_text: &str, tools: &[ToolDefinition]) -> String {
101    // Build tool descriptions as a JSON array for the system section.
102    let tool_defs: Vec<serde_json::Value> = tools
103        .iter()
104        .map(|t| {
105            serde_json::json!({
106                "name": t.name,
107                "description": t.description,
108                "parameters": t.parameters,
109            })
110        })
111        .collect();
112
113    let tools_json = serde_json::to_string_pretty(&tool_defs).unwrap_or_else(|_| "[]".to_string());
114
115    // FunctionGemma chat template:
116    //   <start_of_turn>system
117    //   You are a function calling AI model. ...
118    //   <end_of_turn>
119    //   <start_of_turn>user
120    //   <user intent text>
121    //   <end_of_turn>
122    //   <start_of_turn>model
123    format!(
124        "<start_of_turn>system\n\
125         You are a function calling AI model. You are provided with function \
126         signatures within <tools></tools> XML tags. You may call one or more \
127         functions to assist with the user query. Don't make assumptions about \
128         what values to plug into functions.\n\n\
129         <tools>\n{tools_json}\n</tools>\n\n\
130         For each function call return a JSON object with function name and \
131         arguments within <tool_call></tool_call> XML tags as follows:\n\
132         <tool_call>\n{{\"name\": \"function_name\", \"arguments\": {{\"arg1\": \"value1\"}}}}\n</tool_call>\n\
133         <end_of_turn>\n\
134         <start_of_turn>user\n\
135         {assistant_text}\n\
136         <end_of_turn>\n\
137         <start_of_turn>model\n"
138    )
139}
140
141// ── Response parsing ─────────────────────────────────────────────────────────
142
143/// A single parsed tool call from FunctionGemma output.
144#[derive(Debug, Clone)]
145struct ParsedToolCall {
146    name: String,
147    arguments: String, // JSON string
148}
149
150/// Parse FunctionGemma output into zero or more structured tool calls.
151///
152/// Expected format:
153/// ```text
154/// <tool_call>
155/// {"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
156/// </tool_call>
157/// ```
158///
159/// Handles multiple `<tool_call>` blocks in a single response.
160fn parse_functiongemma_response(text: &str) -> Vec<ParsedToolCall> {
161    let mut calls = Vec::new();
162
163    // Extract everything between <tool_call> and </tool_call>
164    let mut remaining = text;
165    while let Some(start) = remaining.find("<tool_call>") {
166        remaining = &remaining[start + "<tool_call>".len()..];
167        if let Some(end) = remaining.find("</tool_call>") {
168            let block = remaining[..end].trim();
169            remaining = &remaining[end + "</tool_call>".len()..];
170
171            // Try to parse the JSON block
172            if let Ok(value) = serde_json::from_str::<serde_json::Value>(block) {
173                let name = value
174                    .get("name")
175                    .and_then(|n| n.as_str())
176                    .unwrap_or("")
177                    .to_string();
178                let arguments = value
179                    .get("arguments")
180                    .map(|a| serde_json::to_string(a).unwrap_or_else(|_| "{}".to_string()))
181                    .unwrap_or_else(|| "{}".to_string());
182
183                if !name.is_empty() {
184                    calls.push(ParsedToolCall { name, arguments });
185                }
186            } else {
187                tracing::warn!(
188                    block = %block,
189                    "FunctionGemma produced unparseable tool_call block"
190                );
191            }
192        } else {
193            break; // Unclosed <tool_call> — stop
194        }
195    }
196
197    calls
198}
199
200// ── Router ───────────────────────────────────────────────────────────────────
201
202/// Hybrid tool-call router backed by a local FunctionGemma model.
203///
204/// Created once at session start; shared via `Arc` across prompt calls.
205pub struct ToolCallRouter {
206    runtime: Arc<Mutex<CandleThinker>>,
207}
208
209impl std::fmt::Debug for ToolCallRouter {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        f.debug_struct("ToolCallRouter").finish()
212    }
213}
214
215impl ToolCallRouter {
216    /// Construct from a [`ToolRouterConfig`].
217    ///
218    /// Returns `None` if the router is disabled or missing required paths.
219    pub fn from_config(config: &ToolRouterConfig) -> Result<Option<Self>> {
220        if !config.enabled {
221            tracing::debug!("FunctionGemma tool router is disabled");
222            return Ok(None);
223        }
224
225        let model_path = config.model_path.as_ref().ok_or_else(|| {
226            anyhow!("CODETETHER_TOOL_ROUTER_MODEL_PATH is required when the tool router is enabled")
227        })?;
228        let tokenizer_path = config.tokenizer_path.as_ref().ok_or_else(|| {
229            anyhow!(
230                "CODETETHER_TOOL_ROUTER_TOKENIZER_PATH is required when the tool router is enabled"
231            )
232        })?;
233
234        // Build a ThinkerConfig configured for the FunctionGemma model
235        let thinker_config = ThinkerConfig {
236            enabled: true,
237            backend: ThinkerBackend::Candle,
238            candle_model_path: Some(model_path.clone()),
239            candle_tokenizer_path: Some(tokenizer_path.clone()),
240            candle_arch: Some(config.arch.clone()),
241            candle_device: config.device,
242            max_tokens: config.max_tokens,
243            temperature: config.temperature,
244            ..ThinkerConfig::default()
245        };
246
247        let runtime = CandleThinker::new(&thinker_config)?;
248        tracing::info!(
249            model_path = %model_path,
250            arch = %config.arch,
251            "FunctionGemma tool-call router initialised"
252        );
253
254        Ok(Some(Self {
255            runtime: Arc::new(Mutex::new(runtime)),
256        }))
257    }
258
259    /// Conditionally reformat a `CompletionResponse`.
260    ///
261    /// - If the model natively supports tool calling, return **unchanged**
262    ///   (FunctionGemma is only useful for models that lack native tool support).
263    /// - If the response already contains `ContentPart::ToolCall` entries,
264    ///   return it **unchanged** (zero overhead path).
265    /// - If the assistant text doesn't look like it's describing tool usage,
266    ///   return **unchanged** (cheap heuristic avoids expensive inference).
267    /// - Otherwise, run FunctionGemma to convert the text into structured
268    ///   tool calls.
269    /// - On any internal error, return the **original** response unchanged
270    ///   (safe degradation — the router never breaks existing functionality).
271    pub async fn maybe_reformat(
272        &self,
273        response: CompletionResponse,
274        tools: &[ToolDefinition],
275        model_supports_tools: bool,
276    ) -> CompletionResponse {
277        // Fast path: model already handles tool calling natively.
278        // FunctionGemma is only needed for models that return text descriptions
279        // of tool calls instead of structured ContentPart::ToolCall entries.
280        if model_supports_tools {
281            tracing::trace!("Skipping FunctionGemma: model supports native tool calling");
282            return response;
283        }
284
285        // Fast path: if the response already has structured tool calls, pass through.
286        let has_tool_calls = response
287            .message
288            .content
289            .iter()
290            .any(|p| matches!(p, ContentPart::ToolCall { .. }));
291
292        if has_tool_calls {
293            return response;
294        }
295
296        // No tools were provided — nothing for FunctionGemma to match against.
297        if tools.is_empty() {
298            return response;
299        }
300
301        // Collect assistant text from the response.
302        let assistant_text: String = response
303            .message
304            .content
305            .iter()
306            .filter_map(|p| match p {
307                ContentPart::Text { text } => Some(text.as_str()),
308                _ => None,
309            })
310            .collect::<Vec<_>>()
311            .join("\n");
312
313        if assistant_text.trim().is_empty() {
314            return response;
315        }
316
317        // Cheap heuristic: skip FunctionGemma if the text doesn't mention any
318        // available tool name.  This avoids expensive CPU inference for pure
319        // conversational / final-answer responses.
320        let text_lower = assistant_text.to_lowercase();
321        let mentions_tool = tools
322            .iter()
323            .any(|t| text_lower.contains(&t.name.to_lowercase()));
324        if !mentions_tool {
325            tracing::trace!("Skipping FunctionGemma: assistant text mentions no tool names");
326            return response;
327        }
328
329        // Run FunctionGemma in a blocking thread (CPU-bound).
330        match self.run_functiongemma(&assistant_text, tools).await {
331            Ok(parsed) if !parsed.is_empty() => {
332                tracing::info!(
333                    num_calls = parsed.len(),
334                    "FunctionGemma router produced tool calls from text-only response"
335                );
336                self.rewrite_response(response, parsed)
337            }
338            Ok(_) => {
339                // FunctionGemma decided no tool calls are needed — pass through.
340                response
341            }
342            Err(e) => {
343                tracing::warn!(
344                    error = %e,
345                    "FunctionGemma router failed; returning original response"
346                );
347                response
348            }
349        }
350    }
351
352    /// Run the FunctionGemma model in a blocking thread.
353    async fn run_functiongemma(
354        &self,
355        assistant_text: &str,
356        tools: &[ToolDefinition],
357    ) -> Result<Vec<ParsedToolCall>> {
358        let prompt = build_functiongemma_prompt(assistant_text, tools);
359        let runtime = Arc::clone(&self.runtime);
360
361        let output = tokio::task::spawn_blocking(move || {
362            let mut guard = runtime
363                .lock()
364                .map_err(|_| anyhow!("FunctionGemma mutex poisoned"))?;
365            // Use the raw prompt — we've already formatted it with the Gemma chat template.
366            // The thinker's `think()` wraps in System/User/Assistant roles; we need direct
367            // access to the generation loop.  We pass the full prompt as the user message
368            // and an empty system prompt so the thinker doesn't re-wrap it.
369            guard.think("", &prompt)
370        })
371        .await
372        .map_err(|e| anyhow!("FunctionGemma task join failed: {e}"))??;
373
374        Ok(parse_functiongemma_response(&output.text))
375    }
376
377    /// Rewrite the `CompletionResponse` to replace text with structured tool calls.
378    ///
379    /// The original text is **removed** so the model sees a pure tool-call
380    /// assistant turn.  On the follow-up turn it will receive the tool results
381    /// and compose a proper answer – rather than ignoring them because it
382    /// already gave a complete text response.
383    fn rewrite_response(
384        &self,
385        mut response: CompletionResponse,
386        calls: Vec<ParsedToolCall>,
387    ) -> CompletionResponse {
388        // Strip all text parts — the model should see only tool calls so it
389        // properly processes the tool results on the next iteration.
390        response
391            .message
392            .content
393            .retain(|p| !matches!(p, ContentPart::Text { .. }));
394
395        for call in calls {
396            response.message.content.push(ContentPart::ToolCall {
397                id: format!("fc_{}", Uuid::new_v4()),
398                name: call.name,
399                arguments: call.arguments,
400            });
401        }
402
403        // Signal the session loop that tool calls are present.
404        response.finish_reason = FinishReason::ToolCalls;
405        response
406    }
407}
408
// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_single_tool_call() {
        let text = r#"<tool_call>
{"name": "read_file", "arguments": {"path": "/tmp/foo.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
        assert!(calls[0].arguments.contains("/tmp/foo.rs"));
    }

    #[test]
    fn parse_multiple_tool_calls() {
        let text = r#"I'll read both files.
<tool_call>
{"name": "read_file", "arguments": {"path": "a.rs"}}
</tool_call>
<tool_call>
{"name": "read_file", "arguments": {"path": "b.rs"}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(calls[1].name, "read_file");
    }

    #[test]
    fn parse_no_tool_calls() {
        let text = "I cannot help with that request.";
        let calls = parse_functiongemma_response(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_malformed_json_skipped() {
        let text = r#"<tool_call>
not valid json
</tool_call>
<tool_call>
{"name": "list_dir", "arguments": {"path": "."}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "list_dir");
    }

    #[test]
    fn parse_empty_name_skipped() {
        let text = r#"<tool_call>
{"name": "", "arguments": {}}
</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_unclosed_tool_call_is_ignored() {
        // Decoding can stop mid-block; a partial block must not become a call.
        let text = "<tool_call>\n{\"name\": \"read_file\", \"arguments\": {\"path\": \"a.rs\"}}";
        assert!(parse_functiongemma_response(text).is_empty());
    }

    #[test]
    fn parse_missing_arguments_defaults_to_empty_object() {
        let text = r#"<tool_call>{"name": "list_dir"}</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].arguments, "{}");
    }

    #[test]
    fn parse_arguments_are_valid_json() {
        let text = r#"<tool_call>{"name": "read_file", "arguments": {"path": "a.rs"}}</tool_call>"#;
        let calls = parse_functiongemma_response(text);
        assert_eq!(calls.len(), 1);
        let args: serde_json::Value =
            serde_json::from_str(&calls[0].arguments).expect("arguments must be valid JSON");
        assert_eq!(args["path"], "a.rs");
    }

    #[test]
    fn prompt_contains_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        }];
        let prompt = build_functiongemma_prompt("Please read foo.rs", &tools);
        assert!(prompt.contains("<start_of_turn>system"));
        assert!(prompt.contains("read_file"));
        assert!(prompt.contains("<tools>"));
        assert!(prompt.contains("Please read foo.rs"));
        assert!(prompt.contains("<start_of_turn>model"));
    }

    #[test]
    fn prompt_with_no_tools_has_empty_list_and_open_model_turn() {
        let prompt = build_functiongemma_prompt("hi", &[]);
        assert!(prompt.contains("<tools>\n[]\n</tools>"));
        assert!(prompt.ends_with("<start_of_turn>model\n"));
    }

    #[test]
    fn config_defaults_disabled() {
        let config = ToolRouterConfig::default();
        assert!(!config.enabled);
        assert!(config.model_path.is_none());
        assert!(config.tokenizer_path.is_none());
        assert_eq!(config.arch, "gemma3");
        assert_eq!(config.max_tokens, 128);
        assert!((config.temperature - 0.1).abs() < f32::EPSILON);
    }
}