// autoagents_llm/models/mod.rs

1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::fmt::Debug;
5
6pub trait ModelListResponse: std::fmt::Debug {
7    fn get_models(&self) -> Vec<String>;
8    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
9    fn get_backend(&self) -> LLMBackend;
10}
11
12pub trait ModelListRawEntry: Debug {
13    fn get_id(&self) -> String;
14    fn get_created_at(&self) -> DateTime<Utc>;
15    fn get_raw(&self) -> serde_json::Value;
16}
17
/// Parameters for a model-listing request.
///
/// The [`Default`] value carries no filter. The type is comparable
/// (`PartialEq`/`Eq`) so whole requests can be checked against each other
/// instead of comparing individual fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelListRequest {
    /// Optional filter by model ID; `None` means no filter is requested.
    pub filter: Option<String>,
}
22
23/// Trait for providers that support listing and retrieving model information.
24#[async_trait]
25pub trait ModelsProvider {
26    /// Asynchronously retrieves the list of available models ID's from the provider.
27    ///
28    /// # Arguments
29    ///
30    /// * `_request` - Optional filter by model ID
31    ///
32    /// # Returns
33    ///
34    /// List of model ID's or error
35    async fn list_models(
36        &self,
37        _request: Option<&ModelListRequest>,
38    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
39        Err(LLMError::ProviderError(
40            "List Models not supported".to_string(),
41        ))
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::TimeZone;
    use chrono::{DateTime, Utc};
    use serde_json::Value;

    /// Unix timestamp shared by the fixtures (2022-01-01T00:00:00Z).
    const FIXTURE_EPOCH_SECS: i64 = 1_640_995_200;

    /// Returns the UTC instant `offset_secs` seconds after the fixture epoch.
    fn fixture_time(offset_secs: i64) -> DateTime<Utc> {
        // A whole-second Unix timestamp in this range is always a valid UTC instant.
        Utc.timestamp_opt(FIXTURE_EPOCH_SECS + offset_secs, 0)
            .unwrap()
    }

    /// Builds a [`MockModelEntry`] created at the fixture epoch.
    fn mock_entry(id: &str, extra_data: Value) -> MockModelEntry {
        MockModelEntry {
            id: id.to_string(),
            created_at: fixture_time(0),
            extra_data,
        }
    }

    // Mock implementations for testing

    /// In-memory raw entry that simply hands back its stored fields.
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    /// In-memory response; `Debug` is derived rather than hand-written.
    #[derive(Debug)]
    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        backend: LLMBackend,
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .cloned()
                .map(|entry| Box::new(entry) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> LLMBackend {
            self.backend.clone()
        }
    }

    /// Provider that either lists its configured models or fails on demand.
    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        /// Creates a provider that successfully lists `models`.
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        /// Creates a provider whose `list_models` always returns an error.
        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: Vec::new(),
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        /// Lists the configured models; the request (and its filter) is ignored.
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(index, model)| {
                    mock_entry(
                        model,
                        serde_json::json!({
                            "index": index,
                            "description": format!("Model {}", model)
                        }),
                    )
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: LLMBackend::OpenAI,
            }))
        }
    }

    // Default provider that always fails: it relies on the trait's default
    // `list_models` body.
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    #[test]
    fn test_model_list_request_default() {
        assert!(ModelListRequest::default().filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter.as_deref(), Some("gpt"));
    }

    #[test]
    fn test_model_list_request_clone() {
        let original = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let copy = original.clone();
        assert_eq!(original.filter, copy.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let rendered = format!("{request:?}");
        assert!(rendered.contains("ModelListRequest"));
        assert!(rendered.contains("debug_test"));
    }

    #[test]
    fn test_mock_model_entry_creation() {
        let entry = mock_entry("test-model", serde_json::json!({"key": "value"}));

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), fixture_time(0));
        assert_eq!(entry.get_raw(), serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = mock_entry("debug-model", serde_json::json!({"debug": true}));

        let rendered = format!("{entry:?}");
        assert!(rendered.contains("MockModelEntry"));
        assert!(rendered.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let original = mock_entry("clone-model", serde_json::json!({"clone": true}));
        let copy = original.clone();

        assert_eq!(original.get_id(), copy.get_id());
        assert_eq!(original.get_created_at(), copy.get_created_at());
        assert_eq!(original.get_raw(), copy.get_raw());
    }

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: Vec::new(),
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries: vec![
                mock_entry("model1", serde_json::json!({"index": 0})),
                MockModelEntry {
                    created_at: fixture_time(1),
                    ..mock_entry("model2", serde_json::json!({"index": 1}))
                },
            ],
            backend: LLMBackend::Anthropic,
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: Vec::new(),
            raw_entries: Vec::new(),
            backend: LLMBackend::Google,
        };

        assert!(matches!(response.get_backend(), LLMBackend::Google));
    }

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should list its models");

        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        // The mock ignores the filter, so every configured model comes back.
        let response = provider
            .list_models(Some(&request))
            .await
            .expect("mock provider should accept a request");

        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let error = provider
            .list_models(None)
            .await
            .expect_err("failing mock provider should return an error");

        assert!(error.to_string().contains("Mock provider failed"));
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(Vec::new());

        let response = provider
            .list_models(None)
            .await
            .expect("an empty model list is still a successful response");

        assert!(response.get_models().is_empty());
        assert!(response.get_models_raw().is_empty());
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let error = provider
            .list_models(None)
            .await
            .expect_err("default implementation should not support listing");

        assert!(error.to_string().contains("List Models not supported"));
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let error = provider
            .list_models(Some(&request))
            .await
            .expect_err("default implementation should not support listing");

        assert!(error.to_string().contains("List Models not supported"));
    }

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let boxed: Box<dyn ModelListRawEntry> =
            Box::new(mock_entry("trait-test", serde_json::json!({"test": "data"})));

        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), serde_json::json!({"test": "data"}));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let boxed: Box<dyn ModelListResponse> = Box::new(MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: Vec::new(),
            backend: LLMBackend::Ollama,
        });

        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert!(matches!(boxed.get_backend(), LLMBackend::Ollama));
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = serde_json::json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = mock_entry("complex-model", complex_data.clone());
        let raw = entry.get_raw();

        assert_eq!(raw, complex_data);
        assert_eq!(raw["capabilities"][0], "chat");
        assert_eq!(raw["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = mock_entry("null-model", Value::Null);

        assert_eq!(entry.get_raw(), Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let earlier = fixture_time(0);
        let later = earlier + chrono::Duration::seconds(1);

        let older = MockModelEntry {
            created_at: earlier,
            ..mock_entry("older", Value::Null)
        };
        let newer = MockModelEntry {
            created_at: later,
            ..mock_entry("newer", Value::Null)
        };

        assert!(older.get_created_at() < newer.get_created_at());
    }

    #[test]
    fn test_backend_variants() {
        let backends = [
            LLMBackend::OpenAI,
            LLMBackend::Anthropic,
            LLMBackend::Ollama,
            LLMBackend::DeepSeek,
            LLMBackend::XAI,
            LLMBackend::Phind,
            LLMBackend::Google,
            LLMBackend::Groq,
            LLMBackend::AzureOpenAI,
        ];

        for backend in backends {
            let response = MockModelListResponse {
                models: Vec::new(),
                raw_entries: Vec::new(),
                backend: backend.clone(),
            };

            // Compare discriminants so the check only depends on which
            // variant came back; `assert_eq!` reports both sides on failure.
            assert_eq!(
                std::mem::discriminant(&response.get_backend()),
                std::mem::discriminant(&backend)
            );
        }
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        match provider.list_models(None).await {
            Err(LLMError::ProviderError(message)) => {
                assert_eq!(message, "Mock provider failed");
            }
            _ => panic!("Expected ProviderError"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models);

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should list a large model set");

        let listed = response.get_models();
        assert_eq!(listed.len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(listed[0], "model-000");
        assert_eq!(listed[99], "model-099");
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some(String::new()),
        };
        assert_eq!(request.filter.as_deref(), Some(""));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter.as_deref(),
            Some("model-name_with.special-chars")
        );
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = mock_entry("", Value::Null);
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = mock_entry("模型-测试-🤖", Value::Null);
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }
}