Skip to main content

construct/tools/
model_switch.rs

1use super::traits::{Tool, ToolResult};
2use crate::agent::loop_::get_model_switch_state;
3use crate::providers;
4use crate::security::SecurityPolicy;
5use crate::security::policy::ToolOperation;
6use async_trait::async_trait;
7use serde_json::json;
8use std::sync::Arc;
9
10pub struct ModelSwitchTool {
11    security: Arc<SecurityPolicy>,
12}
13
14impl ModelSwitchTool {
15    pub fn new(security: Arc<SecurityPolicy>) -> Self {
16        Self { security }
17    }
18}
19
20#[async_trait]
21impl Tool for ModelSwitchTool {
22    fn name(&self) -> &str {
23        "model_switch"
24    }
25
26    fn description(&self) -> &str {
27        "Switch the AI model at runtime. Use 'get' to see current model, 'list_providers' to see available providers, 'list_models' to see models for a provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation."
28    }
29
30    fn parameters_schema(&self) -> serde_json::Value {
31        json!({
32            "type": "object",
33            "properties": {
34                "action": {
35                    "type": "string",
36                    "enum": ["get", "set", "list_providers", "list_models"],
37                    "description": "Action to perform: get current model, set a new model, list available providers, or list models for a provider"
38                },
39                "provider": {
40                    "type": "string",
41                    "description": "Provider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions."
42                },
43                "model": {
44                    "type": "string",
45                    "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
46                }
47            },
48            "required": ["action"]
49        })
50    }
51
52    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
53        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
54
55        if let Err(error) = self
56            .security
57            .enforce_tool_operation(ToolOperation::Act, "model_switch")
58        {
59            return Ok(ToolResult {
60                success: false,
61                output: String::new(),
62                error: Some(error),
63            });
64        }
65
66        match action {
67            "get" => self.handle_get(),
68            "set" => self.handle_set(&args),
69            "list_providers" => self.handle_list_providers(),
70            "list_models" => self.handle_list_models(&args),
71            _ => Ok(ToolResult {
72                success: false,
73                output: String::new(),
74                error: Some(format!(
75                    "Unknown action: {}. Valid actions: get, set, list_providers, list_models",
76                    action
77                )),
78            }),
79        }
80    }
81}
82
83impl ModelSwitchTool {
84    fn handle_get(&self) -> anyhow::Result<ToolResult> {
85        let switch_state = get_model_switch_state();
86        let pending = switch_state.lock().unwrap().clone();
87
88        Ok(ToolResult {
89            success: true,
90            output: serde_json::to_string_pretty(&json!({
91                "pending_switch": pending,
92                "note": "To switch models, use action 'set' with provider and model parameters"
93            }))?,
94            error: None,
95        })
96    }
97
98    fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
99        let provider = args.get("provider").and_then(|v| v.as_str());
100
101        let provider = match provider {
102            Some(p) => p,
103            None => {
104                return Ok(ToolResult {
105                    success: false,
106                    output: String::new(),
107                    error: Some("Missing 'provider' parameter for 'set' action".to_string()),
108                });
109            }
110        };
111
112        let model = args.get("model").and_then(|v| v.as_str());
113
114        let model = match model {
115            Some(m) => m,
116            None => {
117                return Ok(ToolResult {
118                    success: false,
119                    output: String::new(),
120                    error: Some("Missing 'model' parameter for 'set' action".to_string()),
121                });
122            }
123        };
124
125        // Validate the provider exists
126        let known_providers = providers::list_providers();
127        let provider_valid = known_providers.iter().any(|p| {
128            p.name.eq_ignore_ascii_case(provider)
129                || p.aliases.iter().any(|a| a.eq_ignore_ascii_case(provider))
130        });
131
132        if !provider_valid {
133            return Ok(ToolResult {
134                success: false,
135                output: serde_json::to_string_pretty(&json!({
136                    "available_providers": known_providers.iter().map(|p| p.name).collect::<Vec<_>>()
137                }))?,
138                error: Some(format!(
139                    "Unknown provider: {}. Use 'list_providers' to see available options.",
140                    provider
141                )),
142            });
143        }
144
145        // Set the global model switch request
146        let switch_state = get_model_switch_state();
147        *switch_state.lock().unwrap() = Some((provider.to_string(), model.to_string()));
148
149        Ok(ToolResult {
150            success: true,
151            output: serde_json::to_string_pretty(&json!({
152                "message": "Model switch requested",
153                "provider": provider,
154                "model": model,
155                "note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch."
156            }))?,
157            error: None,
158        })
159    }
160
161    fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
162        let providers_list = providers::list_providers();
163
164        let providers: Vec<serde_json::Value> = providers_list
165            .iter()
166            .map(|p| {
167                json!({
168                    "name": p.name,
169                    "display_name": p.display_name,
170                    "aliases": p.aliases,
171                    "local": p.local
172                })
173            })
174            .collect();
175
176        Ok(ToolResult {
177            success: true,
178            output: serde_json::to_string_pretty(&json!({
179                "providers": providers,
180                "count": providers.len(),
181                "example": "Use action 'set' with provider and model to switch"
182            }))?,
183            error: None,
184        })
185    }
186
187    fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
188        let provider = args.get("provider").and_then(|v| v.as_str());
189
190        let provider = match provider {
191            Some(p) => p,
192            None => {
193                return Ok(ToolResult {
194                    success: false,
195                    output: String::new(),
196                    error: Some(
197                        "Missing 'provider' parameter for 'list_models' action".to_string(),
198                    ),
199                });
200            }
201        };
202
203        // Return common models for known providers
204        let models = match provider.to_lowercase().as_str() {
205            "openai" => vec![
206                "gpt-4o",
207                "gpt-4o-mini",
208                "gpt-4-turbo",
209                "gpt-4",
210                "gpt-3.5-turbo",
211            ],
212            "anthropic" => vec![
213                "claude-sonnet-4-6",
214                "claude-sonnet-4-5",
215                "claude-3-5-sonnet",
216                "claude-3-opus",
217                "claude-3-haiku",
218            ],
219            "openrouter" => vec![
220                "anthropic/claude-sonnet-4-6",
221                "openai/gpt-4o",
222                "google/gemini-pro",
223                "meta-llama/llama-3-70b-instruct",
224            ],
225            "groq" => vec![
226                "llama-3.3-70b-versatile",
227                "mixtral-8x7b-32768",
228                "llama-3.1-70b-speculative",
229            ],
230            "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
231            "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
232            "mistral" => vec![
233                "mistral-large-latest",
234                "mistral-small-latest",
235                "mistral-nemo",
236            ],
237            "google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
238            "xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"],
239            _ => vec![],
240        };
241
242        if models.is_empty() {
243            return Ok(ToolResult {
244                success: true,
245                output: serde_json::to_string_pretty(&json!({
246                    "provider": provider,
247                    "models": [],
248                    "note": "No common models listed for this provider. Check provider documentation for available models."
249                }))?,
250                error: None,
251            });
252        }
253
254        Ok(ToolResult {
255            success: true,
256            output: serde_json::to_string_pretty(&json!({
257                "provider": provider,
258                "models": models,
259                "example": "Use action 'set' with this provider and a model ID to switch"
260            }))?,
261            error: None,
262        })
263    }
264}