autoagents_llm/models/mod.rs

1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::fmt::Debug;
7
8pub trait ModelListResponse: std::fmt::Debug {
9    fn get_models(&self) -> Vec<String>;
10    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
11    fn get_backend(&self) -> LLMBackend;
12}
13
14pub trait ModelListRawEntry: Debug {
15    fn get_id(&self) -> String;
16    fn get_created_at(&self) -> DateTime<Utc>;
17    fn get_raw(&self) -> serde_json::Value;
18}
19
/// Options passed to [`ModelsProvider::list_models`].
#[derive(Clone, Debug, Default)]
pub struct ModelListRequest {
    /// Optional filter on model IDs; `None` (the default) means "no filter".
    ///
    /// How the filter string is matched is left to each provider.
    pub filter: Option<String>,
}
24
25/// Trait for providers that support listing and retrieving model information.
26#[async_trait]
27pub trait ModelsProvider {
28    /// Asynchronously retrieves the list of available models ID's from the provider.
29    ///
30    /// # Arguments
31    ///
32    /// * `_request` - Optional filter by model ID
33    ///
34    /// # Returns
35    ///
36    /// List of model ID's or error
37    async fn list_models(
38        &self,
39        _request: Option<&ModelListRequest>,
40    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
41        Err(LLMError::ProviderError(
42            "List Models not supported".to_string(),
43        ))
44    }
45}
46
47/// Standard model entry structure used by OpenAI-compatible providers
48#[derive(Clone, Debug, Deserialize, Serialize)]
49pub struct StandardModelEntry {
50    pub id: String,
51    pub created: Option<u64>,
52    #[serde(flatten)]
53    pub extra: Value,
54}
55
56impl ModelListRawEntry for StandardModelEntry {
57    fn get_id(&self) -> String {
58        self.id.clone()
59    }
60
61    fn get_created_at(&self) -> DateTime<Utc> {
62        self.created
63            .map(|t| DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
64            .unwrap_or_default()
65    }
66
67    fn get_raw(&self) -> Value {
68        self.extra.clone()
69    }
70}
71
72/// Inner structure for model list response data (serializable)
73#[derive(Clone, Debug, Deserialize, Serialize)]
74pub struct StandardModelListResponseInner {
75    pub data: Vec<StandardModelEntry>,
76}
77
78/// Standard model list response structure used by OpenAI-compatible providers
79#[derive(Clone, Debug)]
80pub struct StandardModelListResponse {
81    pub inner: StandardModelListResponseInner,
82    pub backend: LLMBackend,
83}
84
85impl ModelListResponse for StandardModelListResponse {
86    fn get_models(&self) -> Vec<String> {
87        self.inner.data.iter().map(|e| e.id.clone()).collect()
88    }
89
90    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
91        self.inner
92            .data
93            .iter()
94            .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
95            .collect()
96    }
97
98    fn get_backend(&self) -> LLMBackend {
99        self.backend.clone()
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::{DateTime, TimeZone, Utc};
    use serde_json::{json, Value};

    /// Reference creation time used by the fixtures (2022-01-01T00:00:00Z).
    const BASE_TIMESTAMP: i64 = 1_640_995_200;

    /// Converts whole Unix seconds into a UTC datetime.
    fn at(secs: i64) -> DateTime<Utc> {
        Utc.timestamp_opt(secs, 0).unwrap()
    }

    /// The creation time shared by most fixtures.
    fn base_time() -> DateTime<Utc> {
        at(BASE_TIMESTAMP)
    }

    /// Builds a mock entry created at `base_time()`.
    fn mock_entry(id: &str, extra_data: Value) -> MockModelEntry {
        MockModelEntry {
            id: id.to_string(),
            created_at: base_time(),
            extra_data,
        }
    }

    // ----- Mock implementations ------------------------------------------

    /// In-memory `ModelListRawEntry` whose fields are returned verbatim.
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    /// In-memory `ModelListResponse`; the derived `Debug` satisfies the
    /// trait's `Debug` supertrait.
    #[derive(Debug)]
    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        backend: LLMBackend,
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> LLMBackend {
            self.backend.clone()
        }
    }

    /// Provider that either fails or returns its configured model IDs.
    /// It ignores any `ModelListRequest` filter.
    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: vec![],
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(i, model)| {
                    mock_entry(
                        model,
                        json!({
                            "index": i,
                            "description": format!("Model {model}")
                        }),
                    )
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: LLMBackend::OpenAI,
            }))
        }
    }

    /// Provider that relies entirely on the trait's default `list_models`.
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    // ----- ModelListRequest ------------------------------------------------

    #[test]
    fn test_model_list_request_default() {
        let request = ModelListRequest::default();
        assert!(request.filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter.as_deref(), Some("gpt"));
    }

    #[test]
    fn test_model_list_request_clone() {
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let cloned = request.clone();
        assert_eq!(request.filter, cloned.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("ModelListRequest"));
        assert!(debug_str.contains("debug_test"));
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some(String::new()),
        };
        assert_eq!(request.filter.as_deref(), Some(""));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter.as_deref(),
            Some("model-name_with.special-chars")
        );
    }

    // ----- ModelListRawEntry (mock) -------------------------------------------

    #[test]
    fn test_mock_model_entry_creation() {
        let entry = mock_entry("test-model", json!({"key": "value"}));

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), base_time());
        assert_eq!(entry.get_raw(), json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = mock_entry("debug-model", json!({"debug": true}));

        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("MockModelEntry"));
        assert!(debug_str.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let entry = mock_entry("clone-model", json!({"clone": true}));

        let cloned = entry.clone();
        assert_eq!(entry.get_id(), cloned.get_id());
        assert_eq!(entry.get_created_at(), cloned.get_created_at());
        assert_eq!(entry.get_raw(), cloned.get_raw());
    }

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let entry = mock_entry("trait-test", json!({"test": "data"}));

        let boxed: Box<dyn ModelListRawEntry> = Box::new(entry);
        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), json!({"test": "data"}));
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = mock_entry("complex-model", complex_data.clone());

        assert_eq!(entry.get_raw(), complex_data);
        assert_eq!(entry.get_raw()["capabilities"][0], "chat");
        assert_eq!(entry.get_raw()["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = mock_entry("null-model", Value::Null);
        assert_eq!(entry.get_raw(), Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let time1 = base_time();
        let time2 = time1 + chrono::Duration::seconds(1);

        let entry1 = MockModelEntry {
            created_at: time1,
            ..mock_entry("older", Value::Null)
        };
        let entry2 = MockModelEntry {
            created_at: time2,
            ..mock_entry("newer", Value::Null)
        };

        assert!(entry1.get_created_at() < entry2.get_created_at());
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = mock_entry("", Value::Null);
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = mock_entry("模型-测试-🤖", Value::Null);
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }

    // ----- ModelListResponse (mock) -------------------------------------------

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: vec![],
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let raw_entries = vec![
            mock_entry("model1", json!({"index": 0})),
            MockModelEntry {
                created_at: at(BASE_TIMESTAMP + 1),
                ..mock_entry("model2", json!({"index": 1}))
            },
        ];

        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries,
            backend: LLMBackend::Anthropic,
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
        assert!(raw[0].get_created_at() < raw[1].get_created_at());
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: vec![],
            raw_entries: vec![],
            backend: LLMBackend::Google,
        };

        assert!(matches!(response.get_backend(), LLMBackend::Google));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let response = MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: vec![],
            backend: LLMBackend::Ollama,
        };

        let boxed: Box<dyn ModelListResponse> = Box::new(response);
        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert!(matches!(boxed.get_backend(), LLMBackend::Ollama));
    }

    #[test]
    fn test_backend_variants() {
        let backends = [
            LLMBackend::OpenAI,
            LLMBackend::Anthropic,
            LLMBackend::Ollama,
            LLMBackend::DeepSeek,
            LLMBackend::XAI,
            LLMBackend::Phind,
            LLMBackend::Google,
            LLMBackend::Groq,
            LLMBackend::AzureOpenAI,
        ];

        for backend in backends {
            let response = MockModelListResponse {
                models: vec![],
                raw_entries: vec![],
                backend: backend.clone(),
            };
            // `matches!` needs a literal pattern rather than a runtime value, so
            // compare enum discriminants to check the variant round-trips
            // without requiring `LLMBackend: PartialEq`.
            assert_eq!(
                std::mem::discriminant(&response.get_backend()),
                std::mem::discriminant(&backend)
            );
        }
    }

    // ----- ModelsProvider ------------------------------------------------------

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");

        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        let response = provider
            .list_models(Some(&request))
            .await
            .expect("mock provider should succeed");

        // The mock ignores the filter, so every configured model comes back.
        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let err = provider
            .list_models(None)
            .await
            .expect_err("mock provider is configured to fail");
        assert!(err.to_string().contains("Mock provider failed"));
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(vec![]);

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");

        assert!(response.get_models().is_empty());
        assert!(response.get_models_raw().is_empty());
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let err = provider
            .list_models(None)
            .await
            .expect_err("listing is unsupported by default");
        assert!(err.to_string().contains("List Models not supported"));
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let err = provider
            .list_models(Some(&request))
            .await
            .expect_err("listing is unsupported by default");
        assert!(err.to_string().contains("List Models not supported"));
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        match provider.list_models(None).await {
            Err(LLMError::ProviderError(msg)) => assert_eq!(msg, "Mock provider failed"),
            other => panic!("expected ProviderError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");

        let listed = response.get_models();
        assert_eq!(listed.len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(listed[0], "model-000");
        assert_eq!(listed[99], "model-099");
    }

    // ----- Standard (OpenAI-compatible) types -------------------------------

    #[test]
    fn test_standard_model_entry_deserializes_extra_fields() {
        let entry: StandardModelEntry = serde_json::from_value(json!({
            "id": "gpt-4",
            "created": 1640995200,
            "object": "model",
            "owned_by": "openai"
        }))
        .expect("valid model entry JSON");

        assert_eq!(entry.get_id(), "gpt-4");
        assert_eq!(entry.created, Some(1640995200));
        assert_eq!(entry.get_created_at(), base_time());
        // `id` and `created` are consumed by their own fields; only the rest
        // of the object ends up in the flattened `extra`.
        assert_eq!(
            entry.get_raw(),
            json!({"object": "model", "owned_by": "openai"})
        );
    }

    #[test]
    fn test_standard_model_entry_without_created_uses_default_time() {
        let entry: StandardModelEntry =
            serde_json::from_value(json!({"id": "local-model", "object": "model"}))
                .expect("valid model entry JSON");

        assert_eq!(entry.created, None);
        assert_eq!(entry.get_created_at(), DateTime::<Utc>::default());
    }

    #[test]
    fn test_standard_model_entry_serialization_round_trip() {
        let entry = StandardModelEntry {
            id: "gpt-4".to_string(),
            created: Some(1640995200),
            extra: json!({"owned_by": "openai"}),
        };

        let encoded = serde_json::to_value(&entry).expect("entry serializes");
        assert_eq!(encoded["id"], "gpt-4");
        assert_eq!(encoded["created"], 1640995200);
        assert_eq!(encoded["owned_by"], "openai");

        let decoded: StandardModelEntry =
            serde_json::from_value(encoded).expect("entry deserializes");
        assert_eq!(decoded.id, entry.id);
        assert_eq!(decoded.created, entry.created);
        assert_eq!(decoded.extra, entry.extra);
    }

    #[test]
    fn test_standard_model_list_response_accessors() {
        let inner: StandardModelListResponseInner = serde_json::from_value(json!({
            "object": "list",
            "data": [
                {"id": "gpt-4", "created": 1640995200, "owned_by": "openai"},
                {"id": "gpt-3.5-turbo", "owned_by": "openai"}
            ]
        }))
        .expect("valid model list JSON");

        let response = StandardModelListResponse {
            inner,
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(
            response.get_models(),
            vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()]
        );

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "gpt-4");
        assert_eq!(raw[0].get_created_at(), base_time());
        assert_eq!(raw[1].get_created_at(), DateTime::<Utc>::default());
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }
}