1use agentzero_core::routing::ModelRouter;
2use agentzero_core::{Tool, ToolContext, ToolResult};
3use async_trait::async_trait;
4use serde::Deserialize;
5use serde_json::json;
6
7#[derive(Debug, Deserialize)]
8struct Input {
9 op: String,
10 #[serde(default)]
11 hint: Option<String>,
12 #[serde(default)]
13 query: Option<String>,
14}
15
16pub struct ModelRoutingConfigTool {
17 router: ModelRouter,
18}
19
20impl ModelRoutingConfigTool {
21 pub fn new(router: ModelRouter) -> Self {
22 Self { router }
23 }
24}
25
26#[async_trait]
27impl Tool for ModelRoutingConfigTool {
28 fn name(&self) -> &'static str {
29 "model_routing_config"
30 }
31
32 fn description(&self) -> &'static str {
33 "View or modify the model routing configuration at runtime."
34 }
35
36 fn input_schema(&self) -> Option<serde_json::Value> {
37 Some(serde_json::json!({
38 "type": "object",
39 "properties": {
40 "op": { "type": "string", "enum": ["list_routes", "list_embedding_routes", "resolve_hint", "classify_query", "route_query"], "description": "The routing operation to perform" },
41 "hint": { "type": "string", "description": "Route hint to resolve (for resolve_hint)" },
42 "query": { "type": "string", "description": "Query to classify or route (for classify_query/route_query)" }
43 },
44 "required": ["op"],
45 "additionalProperties": false
46 }))
47 }
48
49 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
50 let parsed: Input =
51 serde_json::from_str(input).map_err(|e| anyhow::anyhow!("invalid input: {e}"))?;
52
53 let output = match parsed.op.as_str() {
54 "list_routes" => {
55 let hints: Vec<&str> = self
56 .router
57 .model_routes
58 .iter()
59 .map(|r| r.hint.as_str())
60 .collect();
61 json!({ "routes": hints }).to_string()
62 }
63 "list_embedding_routes" => {
64 let hints: Vec<&str> = self
65 .router
66 .embedding_routes
67 .iter()
68 .map(|r| r.hint.as_str())
69 .collect();
70 json!({ "embedding_routes": hints }).to_string()
71 }
72 "resolve_hint" => {
73 let hint = parsed
74 .hint
75 .as_deref()
76 .ok_or_else(|| anyhow::anyhow!("resolve_hint requires a `hint` field"))?;
77 match self.router.resolve_hint(hint) {
78 Some(route) => json!({
79 "hint": route.matched_hint,
80 "provider": route.provider,
81 "model": route.model,
82 "max_tokens": route.max_tokens,
83 })
84 .to_string(),
85 None => json!({ "error": format!("unknown hint: {hint}") }).to_string(),
86 }
87 }
88 "classify_query" => {
89 let query = parsed
90 .query
91 .as_deref()
92 .ok_or_else(|| anyhow::anyhow!("classify_query requires a `query` field"))?;
93 match self.router.classify_query(query) {
94 Some(hint) => json!({ "hint": hint }).to_string(),
95 None => json!({ "hint": null }).to_string(),
96 }
97 }
98 "route_query" => {
99 let query = parsed
100 .query
101 .as_deref()
102 .ok_or_else(|| anyhow::anyhow!("route_query requires a `query` field"))?;
103 match self.router.route_query(query) {
104 Some(route) => json!({
105 "hint": route.matched_hint,
106 "provider": route.provider,
107 "model": route.model,
108 "max_tokens": route.max_tokens,
109 })
110 .to_string(),
111 None => json!({ "route": null }).to_string(),
112 }
113 }
114 other => json!({ "error": format!("unknown op: {other}") }).to_string(),
115 };
116
117 Ok(ToolResult { output })
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use agentzero_core::routing::{ClassificationRule, EmbeddingRoute, ModelRoute, ModelRouter};

    /// Router with two model routes, one embedding route, and one keyword rule
    /// mapping queries containing "explain"/"why" to the `reasoning` hint.
    fn test_router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "fast".into(),
                    provider: "openai".into(),
                    model: "gpt-4o-mini".into(),
                    max_tokens: Some(4096),
                    api_key: None,
                    transport: None,
                },
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openai".into(),
                    model: "o1".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                },
            ],
            embedding_routes: vec![EmbeddingRoute {
                hint: "default".into(),
                provider: "openai".into(),
                model: "text-embedding-3-small".into(),
                dimensions: Some(1536),
                api_key: None,
            }],
            classification_rules: vec![ClassificationRule {
                hint: "reasoning".into(),
                keywords: vec!["explain".into(), "why".into()],
                patterns: vec![],
                min_length: None,
                max_length: None,
                priority: 10,
            }],
            classification_enabled: true,
        }
    }

    fn test_ctx() -> ToolContext {
        ToolContext::new("/tmp".to_string())
    }

    /// Executes `input` against a tool built from [`test_router`] and parses
    /// the successful output as JSON; panics if execution fails.
    async fn run_ok(input: &str) -> serde_json::Value {
        let tool = ModelRoutingConfigTool::new(test_router());
        let result = tool
            .execute(input, &test_ctx())
            .await
            .expect("should succeed");
        serde_json::from_str(&result.output).expect("output should be valid JSON")
    }

    /// Executes `input` and returns the error message; panics if execution
    /// unexpectedly succeeds.
    async fn run_err(input: &str) -> String {
        let tool = ModelRoutingConfigTool::new(test_router());
        match tool.execute(input, &test_ctx()).await {
            Ok(_) => panic!("expected `{input}` to be rejected"),
            Err(err) => err.to_string(),
        }
    }

    #[tokio::test]
    async fn list_routes_returns_hints() {
        let v = run_ok(r#"{"op":"list_routes"}"#).await;
        let routes = v["routes"].as_array().unwrap();
        assert_eq!(routes.len(), 2);
        assert_eq!(routes[0], "fast");
        assert_eq!(routes[1], "reasoning");
    }

    #[tokio::test]
    async fn list_routes_on_empty_router_is_empty() {
        let tool = ModelRoutingConfigTool::new(ModelRouter {
            model_routes: Vec::new(),
            embedding_routes: Vec::new(),
            classification_rules: Vec::new(),
            classification_enabled: false,
        });
        let result = tool
            .execute(r#"{"op":"list_routes"}"#, &test_ctx())
            .await
            .expect("should succeed");
        let v: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(v["routes"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn list_embedding_routes_returns_hints() {
        let v = run_ok(r#"{"op":"list_embedding_routes"}"#).await;
        let routes = v["embedding_routes"].as_array().unwrap();
        assert_eq!(routes.len(), 1);
        assert_eq!(routes[0], "default");
    }

    #[tokio::test]
    async fn resolve_hint_returns_route() {
        let v = run_ok(r#"{"op":"resolve_hint","hint":"fast"}"#).await;
        assert_eq!(v["model"], "gpt-4o-mini");
        assert_eq!(v["provider"], "openai");
    }

    #[tokio::test]
    async fn resolve_hint_unknown_returns_error() {
        let v = run_ok(r#"{"op":"resolve_hint","hint":"nonexistent"}"#).await;
        assert!(v["error"].as_str().unwrap().contains("unknown hint"));
    }

    #[tokio::test]
    async fn resolve_hint_without_hint_is_rejected() {
        let msg = run_err(r#"{"op":"resolve_hint"}"#).await;
        assert!(msg.contains("requires a `hint` field"), "got: {msg}");
    }

    #[tokio::test]
    async fn classify_query_returns_hint() {
        let v = run_ok(r#"{"op":"classify_query","query":"explain why this is wrong"}"#).await;
        assert_eq!(v["hint"], "reasoning");
    }

    #[tokio::test]
    async fn classify_query_without_query_is_rejected() {
        let msg = run_err(r#"{"op":"classify_query"}"#).await;
        assert!(msg.contains("requires a `query` field"), "got: {msg}");
    }

    #[tokio::test]
    async fn route_query_without_query_is_rejected() {
        let msg = run_err(r#"{"op":"route_query"}"#).await;
        assert!(msg.contains("requires a `query` field"), "got: {msg}");
    }

    #[tokio::test]
    async fn malformed_input_is_rejected() {
        assert!(run_err("not json").await.contains("invalid input"));
        // `op` is the only mandatory field.
        assert!(run_err("{}").await.contains("invalid input"));
    }

    #[tokio::test]
    async fn invalid_op_returns_error() {
        let v = run_ok(r#"{"op":"delete_everything"}"#).await;
        assert!(v["error"].as_str().unwrap().contains("unknown op"));
    }
}