Skip to main content

ai_agents_llm/
multi.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3
4use ai_agents_core::{
5    ChatMessage, LLMCapability, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMProvider,
6    LLMResponse, TaskContext, ToolSelection,
7};
8
9use super::capability::DefaultLLMCapability;
10
11#[derive(Clone)]
12pub struct MultiLLMRouter {
13    primary: Arc<dyn LLMProvider>,
14    tool_selector: Option<Arc<dyn LLMProvider>>,
15    guard_evaluator: Option<Arc<dyn LLMProvider>>,
16    classifier: Option<Arc<dyn LLMProvider>>,
17    enable_fallback: bool,
18}
19
20impl MultiLLMRouter {
21    pub fn new(primary: Arc<dyn LLMProvider>) -> Self {
22        Self {
23            primary,
24            tool_selector: None,
25            guard_evaluator: None,
26            classifier: None,
27            enable_fallback: true,
28        }
29    }
30
31    pub fn with_tool_selector(mut self, provider: Arc<dyn LLMProvider>) -> Self {
32        self.tool_selector = Some(provider);
33        self
34    }
35
36    pub fn with_guard_evaluator(mut self, provider: Arc<dyn LLMProvider>) -> Self {
37        self.guard_evaluator = Some(provider);
38        self
39    }
40
41    pub fn with_classifier(mut self, provider: Arc<dyn LLMProvider>) -> Self {
42        self.classifier = Some(provider);
43        self
44    }
45
46    pub fn with_fallback(mut self, enable: bool) -> Self {
47        self.enable_fallback = enable;
48        self
49    }
50
51    fn get_tool_selector(&self) -> Arc<dyn LLMProvider> {
52        self.tool_selector
53            .as_ref()
54            .cloned()
55            .unwrap_or_else(|| self.primary.clone())
56    }
57
58    fn get_guard_evaluator(&self) -> Arc<dyn LLMProvider> {
59        self.guard_evaluator
60            .as_ref()
61            .cloned()
62            .unwrap_or_else(|| self.primary.clone())
63    }
64
65    fn get_classifier(&self) -> Arc<dyn LLMProvider> {
66        self.classifier
67            .as_ref()
68            .cloned()
69            .unwrap_or_else(|| self.primary.clone())
70    }
71
72    #[allow(dead_code)]
73    async fn execute_with_fallback<F, Fut, T>(
74        &self,
75        primary_fn: F,
76        _fallback_provider: Arc<dyn LLMProvider>,
77        operation: &str,
78    ) -> Result<T, LLMError>
79    where
80        F: FnOnce() -> Fut,
81        Fut: std::future::Future<Output = Result<T, LLMError>>,
82    {
83        match primary_fn().await {
84            Ok(result) => Ok(result),
85            Err(e) if self.enable_fallback => {
86                eprintln!(
87                    "Multi-LLM: {} failed with specialized provider, falling back to primary: {}",
88                    operation, e
89                );
90                Err(e)
91            }
92            Err(e) => Err(e),
93        }
94    }
95}
96
97#[async_trait]
98impl LLMProvider for MultiLLMRouter {
99    async fn complete(
100        &self,
101        messages: &[ChatMessage],
102        config: Option<&LLMConfig>,
103    ) -> Result<LLMResponse, LLMError> {
104        self.primary.complete(messages, config).await
105    }
106
107    async fn complete_stream(
108        &self,
109        messages: &[ChatMessage],
110        config: Option<&LLMConfig>,
111    ) -> Result<Box<dyn futures::Stream<Item = Result<LLMChunk, LLMError>> + Unpin + Send>, LLMError>
112    {
113        self.primary.complete_stream(messages, config).await
114    }
115
116    fn provider_name(&self) -> &str {
117        "multi-llm-router"
118    }
119
120    fn supports(&self, feature: LLMFeature) -> bool {
121        self.primary.supports(feature)
122    }
123}
124
125#[async_trait]
126impl LLMCapability for MultiLLMRouter {
127    async fn select_tool(
128        &self,
129        context: &TaskContext,
130        user_input: &str,
131    ) -> Result<ToolSelection, LLMError> {
132        let provider = self.get_tool_selector();
133        let capability = DefaultLLMCapability::new(provider);
134        capability.select_tool(context, user_input).await
135    }
136
137    async fn generate_tool_args(
138        &self,
139        tool_id: &str,
140        user_input: &str,
141        schema: &serde_json::Value,
142    ) -> Result<serde_json::Value, LLMError> {
143        let provider = self.get_tool_selector();
144        let capability = DefaultLLMCapability::new(provider);
145        capability
146            .generate_tool_args(tool_id, user_input, schema)
147            .await
148    }
149
150    async fn evaluate_yesno(
151        &self,
152        question: &str,
153        context: &TaskContext,
154    ) -> Result<(bool, String), LLMError> {
155        let provider = self.get_guard_evaluator();
156        let capability = DefaultLLMCapability::new(provider);
157        capability.evaluate_yesno(question, context).await
158    }
159
160    async fn classify(
161        &self,
162        input: &str,
163        categories: &[String],
164    ) -> Result<(String, f32), LLMError> {
165        let provider = self.get_classifier();
166        let capability = DefaultLLMCapability::new(provider);
167        capability.classify(input, categories).await
168    }
169
170    async fn process_task(
171        &self,
172        context: &TaskContext,
173        system_prompt: &str,
174    ) -> Result<LLMResponse, LLMError> {
175        let capability = DefaultLLMCapability::new(self.primary.clone());
176        capability.process_task(context, system_prompt).await
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::MockLLMProvider;
    use ai_agents_core::{FinishReason, Role};
    use std::collections::HashMap;

    /// Builds a user-role message with no timestamp and no name.
    fn user_message(text: &str) -> ChatMessage {
        ChatMessage {
            timestamp: None,
            role: Role::User,
            content: text.to_string(),
            name: None,
        }
    }

    /// Builds a task context that lists `tools` and is otherwise empty.
    fn context_with_tools(tools: &[&str]) -> TaskContext {
        TaskContext {
            current_state: None,
            available_tools: tools.iter().map(|tool| tool.to_string()).collect(),
            memory_slots: HashMap::new(),
            recent_messages: Vec::new(),
        }
    }

    /// Builds a mock provider preloaded with one `Stop`-terminated reply.
    fn mock_replying(name: &'static str, reply: &'static str) -> MockLLMProvider {
        let mut provider = MockLLMProvider::new(name);
        provider.add_response(LLMResponse::new(reply, FinishReason::Stop));
        provider
    }

    // Plain completions go to the primary provider.
    #[tokio::test]
    async fn test_router_with_primary_only() {
        let router = MultiLLMRouter::new(Arc::new(mock_replying("primary", "Hello from primary")));
        let messages = [user_message("Test")];

        let reply = router.complete(&messages, None).await.unwrap();

        assert_eq!(reply.content, "Hello from primary");
    }

    // Tool selection and yes/no evaluation use their dedicated providers.
    #[tokio::test]
    async fn test_router_with_specialized_providers() {
        let primary = mock_replying("primary", "Primary response");
        let tool_selector = mock_replying(
            "tool-selector",
            r#"{"tool_id": "calculator", "confidence": 0.9}"#,
        );
        let guard = mock_replying("guard", r#"{"answer": true, "reasoning": "Approved"}"#);

        let router = MultiLLMRouter::new(Arc::new(primary))
            .with_tool_selector(Arc::new(tool_selector))
            .with_guard_evaluator(Arc::new(guard));
        let context = context_with_tools(&["calculator"]);

        let picked = router.select_tool(&context, "Do math").await.unwrap();
        assert_eq!(picked.tool_id, "calculator");
        assert_eq!(picked.confidence, 0.9);

        let (answer, reasoning) = router
            .evaluate_yesno("Is it safe?", &context)
            .await
            .unwrap();
        assert!(answer);
        assert_eq!(reasoning, "Approved");
    }

    // Enabling fallback does not disturb plain completions on the primary.
    #[tokio::test]
    async fn test_router_fallback_to_primary() {
        let primary = mock_replying("primary", "Primary response");
        let router = MultiLLMRouter::new(Arc::new(primary)).with_fallback(true);
        let messages = [user_message("Test")];

        let reply = router.complete(&messages, None).await.unwrap();

        assert_eq!(reply.content, "Primary response");
    }

    // The router reports its own name rather than the wrapped provider's.
    #[tokio::test]
    async fn test_router_provider_name() {
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")));

        assert_eq!(router.provider_name(), "multi-llm-router");
    }

    // Feature support mirrors the primary provider.
    #[tokio::test]
    async fn test_router_supports() {
        let mut primary = MockLLMProvider::new("primary");
        primary.set_feature_support(LLMFeature::Streaming, true);

        let router = MultiLLMRouter::new(Arc::new(primary));

        assert!(router.supports(LLMFeature::Streaming));
    }

    // Classification is answered by the dedicated classifier.
    #[tokio::test]
    async fn test_classify_with_specialized_provider() {
        let classifier = mock_replying(
            "classifier",
            r#"{"category": "greeting", "confidence": 0.95}"#,
        );
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")))
            .with_classifier(Arc::new(classifier));

        let categories = ["greeting", "question"].map(String::from);
        let (category, confidence) = router.classify("Hello!", &categories).await.unwrap();

        assert_eq!(category, "greeting");
        assert_eq!(confidence, 0.95);
    }

    // Task processing stays on the primary even with a tool selector configured.
    #[tokio::test]
    async fn test_process_task_uses_primary() {
        let primary = mock_replying("primary", "Task processed by primary");
        let router = MultiLLMRouter::new(Arc::new(primary))
            .with_tool_selector(Arc::new(MockLLMProvider::new("tool-selector")));
        let context = context_with_tools(&[]);

        let reply = router
            .process_task(&context, "System prompt")
            .await
            .unwrap();

        assert_eq!(reply.content, "Task processed by primary");
    }

    // Tool-argument generation is answered by the tool selector.
    #[tokio::test]
    async fn test_generate_tool_args_with_specialized() {
        let tool_selector = mock_replying("tool-selector", r#"{"expression": "2 + 2"}"#);
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")))
            .with_tool_selector(Arc::new(tool_selector));

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "expression": {"type": "string"}
            }
        });

        let args = router
            .generate_tool_args("calculator", "Calculate 2 + 2", &schema)
            .await
            .unwrap();

        assert_eq!(args["expression"], "2 + 2");
    }

    // Every builder method chains, and `with_fallback(false)` is recorded.
    #[test]
    fn test_builder_pattern() {
        let router = MultiLLMRouter::new(Arc::new(MockLLMProvider::new("primary")))
            .with_tool_selector(Arc::new(MockLLMProvider::new("tool-selector")))
            .with_guard_evaluator(Arc::new(MockLLMProvider::new("guard")))
            .with_classifier(Arc::new(MockLLMProvider::new("classifier")))
            .with_fallback(false);

        assert_eq!(router.provider_name(), "multi-llm-router");
        assert!(!router.enable_fallback);
    }
}