1use async_trait::async_trait;
2use std::sync::Arc;
3
4use ai_agents_core::{
5 ChatMessage, LLMCapability, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMProvider,
6 LLMResponse, TaskContext, ToolSelection,
7};
8
9use super::capability::DefaultLLMCapability;
10
11#[derive(Clone)]
12pub struct MultiLLMRouter {
13 primary: Arc<dyn LLMProvider>,
14 tool_selector: Option<Arc<dyn LLMProvider>>,
15 guard_evaluator: Option<Arc<dyn LLMProvider>>,
16 classifier: Option<Arc<dyn LLMProvider>>,
17 enable_fallback: bool,
18}
19
20impl MultiLLMRouter {
21 pub fn new(primary: Arc<dyn LLMProvider>) -> Self {
22 Self {
23 primary,
24 tool_selector: None,
25 guard_evaluator: None,
26 classifier: None,
27 enable_fallback: true,
28 }
29 }
30
31 pub fn with_tool_selector(mut self, provider: Arc<dyn LLMProvider>) -> Self {
32 self.tool_selector = Some(provider);
33 self
34 }
35
36 pub fn with_guard_evaluator(mut self, provider: Arc<dyn LLMProvider>) -> Self {
37 self.guard_evaluator = Some(provider);
38 self
39 }
40
41 pub fn with_classifier(mut self, provider: Arc<dyn LLMProvider>) -> Self {
42 self.classifier = Some(provider);
43 self
44 }
45
46 pub fn with_fallback(mut self, enable: bool) -> Self {
47 self.enable_fallback = enable;
48 self
49 }
50
51 fn get_tool_selector(&self) -> Arc<dyn LLMProvider> {
52 self.tool_selector
53 .as_ref()
54 .cloned()
55 .unwrap_or_else(|| self.primary.clone())
56 }
57
58 fn get_guard_evaluator(&self) -> Arc<dyn LLMProvider> {
59 self.guard_evaluator
60 .as_ref()
61 .cloned()
62 .unwrap_or_else(|| self.primary.clone())
63 }
64
65 fn get_classifier(&self) -> Arc<dyn LLMProvider> {
66 self.classifier
67 .as_ref()
68 .cloned()
69 .unwrap_or_else(|| self.primary.clone())
70 }
71
72 #[allow(dead_code)]
73 async fn execute_with_fallback<F, Fut, T>(
74 &self,
75 primary_fn: F,
76 _fallback_provider: Arc<dyn LLMProvider>,
77 operation: &str,
78 ) -> Result<T, LLMError>
79 where
80 F: FnOnce() -> Fut,
81 Fut: std::future::Future<Output = Result<T, LLMError>>,
82 {
83 match primary_fn().await {
84 Ok(result) => Ok(result),
85 Err(e) if self.enable_fallback => {
86 eprintln!(
87 "Multi-LLM: {} failed with specialized provider, falling back to primary: {}",
88 operation, e
89 );
90 Err(e)
91 }
92 Err(e) => Err(e),
93 }
94 }
95}
96
97#[async_trait]
98impl LLMProvider for MultiLLMRouter {
99 async fn complete(
100 &self,
101 messages: &[ChatMessage],
102 config: Option<&LLMConfig>,
103 ) -> Result<LLMResponse, LLMError> {
104 self.primary.complete(messages, config).await
105 }
106
107 async fn complete_stream(
108 &self,
109 messages: &[ChatMessage],
110 config: Option<&LLMConfig>,
111 ) -> Result<Box<dyn futures::Stream<Item = Result<LLMChunk, LLMError>> + Unpin + Send>, LLMError>
112 {
113 self.primary.complete_stream(messages, config).await
114 }
115
116 fn provider_name(&self) -> &str {
117 "multi-llm-router"
118 }
119
120 fn supports(&self, feature: LLMFeature) -> bool {
121 self.primary.supports(feature)
122 }
123}
124
125#[async_trait]
126impl LLMCapability for MultiLLMRouter {
127 async fn select_tool(
128 &self,
129 context: &TaskContext,
130 user_input: &str,
131 ) -> Result<ToolSelection, LLMError> {
132 let provider = self.get_tool_selector();
133 let capability = DefaultLLMCapability::new(provider);
134 capability.select_tool(context, user_input).await
135 }
136
137 async fn generate_tool_args(
138 &self,
139 tool_id: &str,
140 user_input: &str,
141 schema: &serde_json::Value,
142 ) -> Result<serde_json::Value, LLMError> {
143 let provider = self.get_tool_selector();
144 let capability = DefaultLLMCapability::new(provider);
145 capability
146 .generate_tool_args(tool_id, user_input, schema)
147 .await
148 }
149
150 async fn evaluate_yesno(
151 &self,
152 question: &str,
153 context: &TaskContext,
154 ) -> Result<(bool, String), LLMError> {
155 let provider = self.get_guard_evaluator();
156 let capability = DefaultLLMCapability::new(provider);
157 capability.evaluate_yesno(question, context).await
158 }
159
160 async fn classify(
161 &self,
162 input: &str,
163 categories: &[String],
164 ) -> Result<(String, f32), LLMError> {
165 let provider = self.get_classifier();
166 let capability = DefaultLLMCapability::new(provider);
167 capability.classify(input, categories).await
168 }
169
170 async fn process_task(
171 &self,
172 context: &TaskContext,
173 system_prompt: &str,
174 ) -> Result<LLMResponse, LLMError> {
175 let capability = DefaultLLMCapability::new(self.primary.clone());
176 capability.process_task(context, system_prompt).await
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::MockLLMProvider;
    use ai_agents_core::{FinishReason, Role};
    use std::collections::HashMap;

    /// A completed (`FinishReason::Stop`) response carrying `text`.
    fn stop_response(text: &'static str) -> LLMResponse {
        LLMResponse::new(text, FinishReason::Stop)
    }

    /// A one-message conversation in which the user says `text`.
    fn user_says(text: &str) -> Vec<ChatMessage> {
        vec![ChatMessage {
            timestamp: None,
            role: Role::User,
            content: text.to_string(),
            name: None,
        }]
    }

    /// A task context with no state, memory or history that exposes `tools`.
    fn context_with_tools(tools: &[&str]) -> TaskContext {
        TaskContext {
            current_state: None,
            available_tools: tools.iter().map(|&tool| tool.to_string()).collect(),
            memory_slots: HashMap::new(),
            recent_messages: vec![],
        }
    }

    #[tokio::test]
    async fn test_router_with_primary_only() {
        let mut primary = MockLLMProvider::new("primary");
        primary.add_response(stop_response("Hello from primary"));
        let router = MultiLLMRouter::new(Arc::new(primary));

        let reply = router.complete(&user_says("Test"), None).await.unwrap();

        assert_eq!(reply.content, "Hello from primary");
    }

    #[tokio::test]
    async fn test_router_with_specialized_providers() {
        let mut primary = MockLLMProvider::new("primary");
        primary.add_response(stop_response("Primary response"));

        let mut tool_selector = MockLLMProvider::new("tool-selector");
        tool_selector.add_response(stop_response(
            r#"{"tool_id": "calculator", "confidence": 0.9}"#,
        ));

        let mut guard = MockLLMProvider::new("guard");
        guard.add_response(stop_response(
            r#"{"answer": true, "reasoning": "Approved"}"#,
        ));

        let router = MultiLLMRouter::new(Arc::new(primary))
            .with_tool_selector(Arc::new(tool_selector))
            .with_guard_evaluator(Arc::new(guard));
        let context = context_with_tools(&["calculator"]);

        let selection = router.select_tool(&context, "Do math").await.unwrap();
        assert_eq!(selection.tool_id, "calculator");
        assert_eq!(selection.confidence, 0.9);

        let (approved, why) = router
            .evaluate_yesno("Is it safe?", &context)
            .await
            .unwrap();
        assert!(approved);
        assert_eq!(why, "Approved");
    }

    // NOTE(review): despite its name, this only exercises `complete`, which is
    // always answered by the primary; no specialized provider fails here.
    #[tokio::test]
    async fn test_router_fallback_to_primary() {
        let mut primary = MockLLMProvider::new("primary");
        primary.add_response(stop_response("Primary response"));
        let router = MultiLLMRouter::new(Arc::new(primary)).with_fallback(true);

        let reply = router.complete(&user_says("Test"), None).await.unwrap();

        assert_eq!(reply.content, "Primary response");
    }

    #[tokio::test]
    async fn test_router_provider_name() {
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")));

        assert_eq!(router.provider_name(), "multi-llm-router");
    }

    #[tokio::test]
    async fn test_router_supports() {
        let mut primary = MockLLMProvider::new("primary");
        primary.set_feature_support(LLMFeature::Streaming, true);
        let router = MultiLLMRouter::new(Arc::new(primary));

        assert!(router.supports(LLMFeature::Streaming));
    }

    #[tokio::test]
    async fn test_classify_with_specialized_provider() {
        let mut classifier = MockLLMProvider::new("classifier");
        classifier.add_response(stop_response(
            r#"{"category": "greeting", "confidence": 0.95}"#,
        ));
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")))
            .with_classifier(Arc::new(classifier));

        let labels = vec!["greeting".to_string(), "question".to_string()];
        let (label, score) = router.classify("Hello!", &labels).await.unwrap();

        assert_eq!(label, "greeting");
        assert_eq!(score, 0.95);
    }

    #[tokio::test]
    async fn test_process_task_uses_primary() {
        let mut primary = MockLLMProvider::new("primary");
        primary.add_response(stop_response("Task processed by primary"));
        let router = MultiLLMRouter::new(Arc::new(primary))
            .with_tool_selector(Arc::new(MockLLMProvider::new("tool-selector")));

        let reply = router
            .process_task(&context_with_tools(&[]), "System prompt")
            .await
            .unwrap();

        assert_eq!(reply.content, "Task processed by primary");
    }

    #[tokio::test]
    async fn test_generate_tool_args_with_specialized() {
        let mut tool_selector = MockLLMProvider::new("tool-selector");
        tool_selector.add_response(stop_response(r#"{"expression": "2 + 2"}"#));
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")))
            .with_tool_selector(Arc::new(tool_selector));

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "expression": {"type": "string"}
            }
        });
        let args = router
            .generate_tool_args("calculator", "Calculate 2 + 2", &schema)
            .await
            .unwrap();

        assert_eq!(args["expression"], "2 + 2");
    }

    #[test]
    fn test_builder_pattern() {
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")))
            .with_tool_selector(Arc::new(MockLLMProvider::new("tool-selector")))
            .with_guard_evaluator(Arc::new(MockLLMProvider::new("guard")))
            .with_classifier(Arc::new(MockLLMProvider::new("classifier")))
            .with_fallback(false);

        assert_eq!(router.provider_name(), "multi-llm-router");
        assert!(!router.enable_fallback);
    }
}