Skip to main content

agentzero_tools/
model_routing_config.rs

1use agentzero_core::routing::ModelRouter;
2use agentzero_core::{Tool, ToolContext, ToolResult};
3use async_trait::async_trait;
4use serde::Deserialize;
5use serde_json::json;
6
7#[derive(Debug, Deserialize)]
8struct Input {
9    op: String,
10    #[serde(default)]
11    hint: Option<String>,
12    #[serde(default)]
13    query: Option<String>,
14}
15
16pub struct ModelRoutingConfigTool {
17    router: ModelRouter,
18}
19
20impl ModelRoutingConfigTool {
21    pub fn new(router: ModelRouter) -> Self {
22        Self { router }
23    }
24}
25
26#[async_trait]
27impl Tool for ModelRoutingConfigTool {
28    fn name(&self) -> &'static str {
29        "model_routing_config"
30    }
31
32    fn description(&self) -> &'static str {
33        "View or modify the model routing configuration at runtime."
34    }
35
36    fn input_schema(&self) -> Option<serde_json::Value> {
37        Some(serde_json::json!({
38            "type": "object",
39            "properties": {
40                "op": { "type": "string", "enum": ["list_routes", "list_embedding_routes", "resolve_hint", "classify_query", "route_query"], "description": "The routing operation to perform" },
41                "hint": { "type": "string", "description": "Route hint to resolve (for resolve_hint)" },
42                "query": { "type": "string", "description": "Query to classify or route (for classify_query/route_query)" }
43            },
44            "required": ["op"],
45            "additionalProperties": false
46        }))
47    }
48
49    async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
50        let parsed: Input =
51            serde_json::from_str(input).map_err(|e| anyhow::anyhow!("invalid input: {e}"))?;
52
53        let output = match parsed.op.as_str() {
54            "list_routes" => {
55                let hints: Vec<&str> = self
56                    .router
57                    .model_routes
58                    .iter()
59                    .map(|r| r.hint.as_str())
60                    .collect();
61                json!({ "routes": hints }).to_string()
62            }
63            "list_embedding_routes" => {
64                let hints: Vec<&str> = self
65                    .router
66                    .embedding_routes
67                    .iter()
68                    .map(|r| r.hint.as_str())
69                    .collect();
70                json!({ "embedding_routes": hints }).to_string()
71            }
72            "resolve_hint" => {
73                let hint = parsed
74                    .hint
75                    .as_deref()
76                    .ok_or_else(|| anyhow::anyhow!("resolve_hint requires a `hint` field"))?;
77                match self.router.resolve_hint(hint) {
78                    Some(route) => json!({
79                        "hint": route.matched_hint,
80                        "provider": route.provider,
81                        "model": route.model,
82                        "max_tokens": route.max_tokens,
83                    })
84                    .to_string(),
85                    None => json!({ "error": format!("unknown hint: {hint}") }).to_string(),
86                }
87            }
88            "classify_query" => {
89                let query = parsed
90                    .query
91                    .as_deref()
92                    .ok_or_else(|| anyhow::anyhow!("classify_query requires a `query` field"))?;
93                match self.router.classify_query(query) {
94                    Some(hint) => json!({ "hint": hint }).to_string(),
95                    None => json!({ "hint": null }).to_string(),
96                }
97            }
98            "route_query" => {
99                let query = parsed
100                    .query
101                    .as_deref()
102                    .ok_or_else(|| anyhow::anyhow!("route_query requires a `query` field"))?;
103                match self.router.route_query(query) {
104                    Some(route) => json!({
105                        "hint": route.matched_hint,
106                        "provider": route.provider,
107                        "model": route.model,
108                        "max_tokens": route.max_tokens,
109                    })
110                    .to_string(),
111                    None => json!({ "route": null }).to_string(),
112                }
113            }
114            other => json!({ "error": format!("unknown op: {other}") }).to_string(),
115        };
116
117        Ok(ToolResult { output })
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use agentzero_core::routing::{ClassificationRule, EmbeddingRoute, ModelRoute, ModelRouter};

    /// Router fixture: model routes `fast` and `reasoning`, one embedding route
    /// `default`, and a keyword rule ("explain", "why") pointing at `reasoning`.
    fn test_router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "fast".into(),
                    provider: "openai".into(),
                    model: "gpt-4o-mini".into(),
                    max_tokens: Some(4096),
                    api_key: None,
                    transport: None,
                },
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openai".into(),
                    model: "o1".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                },
            ],
            embedding_routes: vec![EmbeddingRoute {
                hint: "default".into(),
                provider: "openai".into(),
                model: "text-embedding-3-small".into(),
                dimensions: Some(1536),
                api_key: None,
            }],
            classification_rules: vec![ClassificationRule {
                hint: "reasoning".into(),
                keywords: vec!["explain".into(), "why".into()],
                patterns: vec![],
                min_length: None,
                max_length: None,
                priority: 10,
            }],
            classification_enabled: true,
        }
    }

    fn test_ctx() -> ToolContext {
        ToolContext::new("/tmp".to_string())
    }

    /// Runs `input` against a fresh tool, requires success, and parses the JSON output.
    async fn run_ok(input: &str) -> serde_json::Value {
        let tool = ModelRoutingConfigTool::new(test_router());
        let result = tool
            .execute(input, &test_ctx())
            .await
            .expect("should succeed");
        serde_json::from_str(&result.output).expect("tool output should be valid JSON")
    }

    /// Runs `input` against a fresh tool, requires failure, and returns the error text.
    async fn run_err(input: &str) -> String {
        let tool = ModelRoutingConfigTool::new(test_router());
        match tool.execute(input, &test_ctx()).await {
            Ok(result) => panic!("expected an error, got output: {}", result.output),
            Err(e) => e.to_string(),
        }
    }

    #[tokio::test]
    async fn list_routes_returns_hints() {
        let v = run_ok(r#"{"op":"list_routes"}"#).await;
        let routes = v["routes"].as_array().unwrap();
        assert_eq!(routes.len(), 2);
        assert_eq!(routes[0], "fast");
        assert_eq!(routes[1], "reasoning");
    }

    #[tokio::test]
    async fn list_embedding_routes_returns_hints() {
        let v = run_ok(r#"{"op":"list_embedding_routes"}"#).await;
        let routes = v["embedding_routes"].as_array().unwrap();
        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0], "default");
    }

    #[tokio::test]
    async fn resolve_hint_returns_route() {
        let v = run_ok(r#"{"op":"resolve_hint","hint":"fast"}"#).await;
        assert_eq!(v["model"], "gpt-4o-mini");
        assert_eq!(v["provider"], "openai");
    }

    #[tokio::test]
    async fn resolve_hint_unknown_returns_error() {
        let v = run_ok(r#"{"op":"resolve_hint","hint":"nonexistent"}"#).await;
        assert!(v["error"].as_str().unwrap().contains("unknown hint"));
    }

    #[tokio::test]
    async fn classify_query_returns_hint() {
        let v = run_ok(r#"{"op":"classify_query","query":"explain why this is wrong"}"#).await;
        assert_eq!(v["hint"], "reasoning");
    }

    #[tokio::test]
    async fn invalid_op_returns_error() {
        let v = run_ok(r#"{"op":"delete_everything"}"#).await;
        assert!(v["error"].as_str().unwrap().contains("unknown op"));
    }

    #[tokio::test]
    async fn resolve_hint_without_hint_is_error() {
        let err = run_err(r#"{"op":"resolve_hint"}"#).await;
        assert!(err.contains("resolve_hint requires a `hint` field"), "{err}");
    }

    #[tokio::test]
    async fn classify_query_without_query_is_error() {
        let err = run_err(r#"{"op":"classify_query"}"#).await;
        assert!(err.contains("classify_query requires a `query` field"), "{err}");
    }

    #[tokio::test]
    async fn route_query_without_query_is_error() {
        let err = run_err(r#"{"op":"route_query"}"#).await;
        assert!(err.contains("route_query requires a `query` field"), "{err}");
    }

    #[tokio::test]
    async fn malformed_json_is_error() {
        let err = run_err("not json").await;
        assert!(err.starts_with("invalid input"), "{err}");
    }

    #[tokio::test]
    async fn missing_op_is_error() {
        let err = run_err("{}").await;
        assert!(err.starts_with("invalid input"), "{err}");
    }
}