1use super::traits::{Tool, ToolResult};
2use crate::agent::loop_::get_model_switch_state;
3use crate::providers;
4use crate::security::SecurityPolicy;
5use crate::security::policy::ToolOperation;
6use async_trait::async_trait;
7use serde_json::json;
8use std::sync::Arc;
9
10pub struct ModelSwitchTool {
11 security: Arc<SecurityPolicy>,
12}
13
14impl ModelSwitchTool {
15 pub fn new(security: Arc<SecurityPolicy>) -> Self {
16 Self { security }
17 }
18}
19
20#[async_trait]
21impl Tool for ModelSwitchTool {
22 fn name(&self) -> &str {
23 "model_switch"
24 }
25
26 fn description(&self) -> &str {
27 "Switch the AI model at runtime. Use 'get' to see current model, 'list_providers' to see available providers, 'list_models' to see models for a provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation."
28 }
29
30 fn parameters_schema(&self) -> serde_json::Value {
31 json!({
32 "type": "object",
33 "properties": {
34 "action": {
35 "type": "string",
36 "enum": ["get", "set", "list_providers", "list_models"],
37 "description": "Action to perform: get current model, set a new model, list available providers, or list models for a provider"
38 },
39 "provider": {
40 "type": "string",
41 "description": "Provider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions."
42 },
43 "model": {
44 "type": "string",
45 "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
46 }
47 },
48 "required": ["action"]
49 })
50 }
51
52 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
53 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
54
55 if let Err(error) = self
56 .security
57 .enforce_tool_operation(ToolOperation::Act, "model_switch")
58 {
59 return Ok(ToolResult {
60 success: false,
61 output: String::new(),
62 error: Some(error),
63 });
64 }
65
66 match action {
67 "get" => self.handle_get(),
68 "set" => self.handle_set(&args),
69 "list_providers" => self.handle_list_providers(),
70 "list_models" => self.handle_list_models(&args),
71 _ => Ok(ToolResult {
72 success: false,
73 output: String::new(),
74 error: Some(format!(
75 "Unknown action: {}. Valid actions: get, set, list_providers, list_models",
76 action
77 )),
78 }),
79 }
80 }
81}
82
83impl ModelSwitchTool {
84 fn handle_get(&self) -> anyhow::Result<ToolResult> {
85 let switch_state = get_model_switch_state();
86 let pending = switch_state.lock().unwrap().clone();
87
88 Ok(ToolResult {
89 success: true,
90 output: serde_json::to_string_pretty(&json!({
91 "pending_switch": pending,
92 "note": "To switch models, use action 'set' with provider and model parameters"
93 }))?,
94 error: None,
95 })
96 }
97
98 fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
99 let provider = args.get("provider").and_then(|v| v.as_str());
100
101 let provider = match provider {
102 Some(p) => p,
103 None => {
104 return Ok(ToolResult {
105 success: false,
106 output: String::new(),
107 error: Some("Missing 'provider' parameter for 'set' action".to_string()),
108 });
109 }
110 };
111
112 let model = args.get("model").and_then(|v| v.as_str());
113
114 let model = match model {
115 Some(m) => m,
116 None => {
117 return Ok(ToolResult {
118 success: false,
119 output: String::new(),
120 error: Some("Missing 'model' parameter for 'set' action".to_string()),
121 });
122 }
123 };
124
125 let known_providers = providers::list_providers();
127 let provider_valid = known_providers.iter().any(|p| {
128 p.name.eq_ignore_ascii_case(provider)
129 || p.aliases.iter().any(|a| a.eq_ignore_ascii_case(provider))
130 });
131
132 if !provider_valid {
133 return Ok(ToolResult {
134 success: false,
135 output: serde_json::to_string_pretty(&json!({
136 "available_providers": known_providers.iter().map(|p| p.name).collect::<Vec<_>>()
137 }))?,
138 error: Some(format!(
139 "Unknown provider: {}. Use 'list_providers' to see available options.",
140 provider
141 )),
142 });
143 }
144
145 let switch_state = get_model_switch_state();
147 *switch_state.lock().unwrap() = Some((provider.to_string(), model.to_string()));
148
149 Ok(ToolResult {
150 success: true,
151 output: serde_json::to_string_pretty(&json!({
152 "message": "Model switch requested",
153 "provider": provider,
154 "model": model,
155 "note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch."
156 }))?,
157 error: None,
158 })
159 }
160
161 fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
162 let providers_list = providers::list_providers();
163
164 let providers: Vec<serde_json::Value> = providers_list
165 .iter()
166 .map(|p| {
167 json!({
168 "name": p.name,
169 "display_name": p.display_name,
170 "aliases": p.aliases,
171 "local": p.local
172 })
173 })
174 .collect();
175
176 Ok(ToolResult {
177 success: true,
178 output: serde_json::to_string_pretty(&json!({
179 "providers": providers,
180 "count": providers.len(),
181 "example": "Use action 'set' with provider and model to switch"
182 }))?,
183 error: None,
184 })
185 }
186
187 fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
188 let provider = args.get("provider").and_then(|v| v.as_str());
189
190 let provider = match provider {
191 Some(p) => p,
192 None => {
193 return Ok(ToolResult {
194 success: false,
195 output: String::new(),
196 error: Some(
197 "Missing 'provider' parameter for 'list_models' action".to_string(),
198 ),
199 });
200 }
201 };
202
203 let models = match provider.to_lowercase().as_str() {
205 "openai" => vec![
206 "gpt-4o",
207 "gpt-4o-mini",
208 "gpt-4-turbo",
209 "gpt-4",
210 "gpt-3.5-turbo",
211 ],
212 "anthropic" => vec![
213 "claude-sonnet-4-6",
214 "claude-sonnet-4-5",
215 "claude-3-5-sonnet",
216 "claude-3-opus",
217 "claude-3-haiku",
218 ],
219 "openrouter" => vec![
220 "anthropic/claude-sonnet-4-6",
221 "openai/gpt-4o",
222 "google/gemini-pro",
223 "meta-llama/llama-3-70b-instruct",
224 ],
225 "groq" => vec![
226 "llama-3.3-70b-versatile",
227 "mixtral-8x7b-32768",
228 "llama-3.1-70b-speculative",
229 ],
230 "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
231 "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
232 "mistral" => vec![
233 "mistral-large-latest",
234 "mistral-small-latest",
235 "mistral-nemo",
236 ],
237 "google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
238 "xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"],
239 _ => vec![],
240 };
241
242 if models.is_empty() {
243 return Ok(ToolResult {
244 success: true,
245 output: serde_json::to_string_pretty(&json!({
246 "provider": provider,
247 "models": [],
248 "note": "No common models listed for this provider. Check provider documentation for available models."
249 }))?,
250 error: None,
251 });
252 }
253
254 Ok(ToolResult {
255 success: true,
256 output: serde_json::to_string_pretty(&json!({
257 "provider": provider,
258 "models": models,
259 "example": "Use action 'set' with this provider and a model ID to switch"
260 }))?,
261 error: None,
262 })
263 }
264}