Skip to main content

autoagents_llm/models/
mod.rs

1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::fmt::Debug;
7
8pub trait ModelListResponse: std::fmt::Debug {
9    fn get_models(&self) -> Vec<String>;
10    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
11    fn get_backend(&self) -> LLMBackend;
12}
13
14pub trait ModelListRawEntry: Debug {
15    fn get_id(&self) -> String;
16    fn get_created_at(&self) -> DateTime<Utc>;
17    fn get_raw(&self) -> serde_json::Value;
18}
19
/// Options for a model-listing request.
///
/// `ModelListRequest::default()` leaves every option unset. Interpretation of
/// the options is left to each [`ModelsProvider`] implementation.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelListRequest {
    /// Optional filter by model ID.
    pub filter: Option<String>,
}
24
25/// Trait for providers that support listing and retrieving model information.
26#[async_trait]
27pub trait ModelsProvider {
28    /// Asynchronously retrieves the list of available models ID's from the provider.
29    ///
30    /// # Arguments
31    ///
32    /// * `_request` - Optional filter by model ID
33    ///
34    /// # Returns
35    ///
36    /// List of model ID's or error
37    async fn list_models(
38        &self,
39        _request: Option<&ModelListRequest>,
40    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
41        Err(LLMError::ProviderError(
42            "List Models not supported".to_string(),
43        ))
44    }
45}
46
47/// Standard model entry structure used by OpenAI-compatible providers
48#[derive(Clone, Debug, Deserialize, Serialize)]
49pub struct StandardModelEntry {
50    pub id: String,
51    pub created: Option<u64>,
52    #[serde(flatten)]
53    pub extra: Value,
54}
55
56impl ModelListRawEntry for StandardModelEntry {
57    fn get_id(&self) -> String {
58        self.id.clone()
59    }
60
61    fn get_created_at(&self) -> DateTime<Utc> {
62        self.created
63            .map(|t| DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
64            .unwrap_or_default()
65    }
66
67    fn get_raw(&self) -> Value {
68        self.extra.clone()
69    }
70}
71
72/// Inner structure for model list response data (serializable)
73#[derive(Clone, Debug, Deserialize, Serialize)]
74pub struct StandardModelListResponseInner {
75    pub data: Vec<StandardModelEntry>,
76}
77
78/// Standard model list response structure used by OpenAI-compatible providers
79#[derive(Clone, Debug)]
80pub struct StandardModelListResponse {
81    pub inner: StandardModelListResponseInner,
82    pub backend: LLMBackend,
83}
84
85impl ModelListResponse for StandardModelListResponse {
86    fn get_models(&self) -> Vec<String> {
87        self.inner.data.iter().map(|e| e.id.clone()).collect()
88    }
89
90    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
91        self.inner
92            .data
93            .iter()
94            .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
95            .collect()
96    }
97
98    fn get_backend(&self) -> LLMBackend {
99        self.backend.clone()
100    }
101}
102
#[cfg(test)]
mod tests {
    // Unit tests for the model-listing traits. Mock implementations exercise
    // the trait contracts; the tests at the end cover the concrete
    // OpenAI-compatible `Standard*` types defined in the parent module.
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::TimeZone;
    use chrono::{DateTime, Utc};
    use serde_json::Value;

    // Mock implementations for testing
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        backend: LLMBackend,
    }

    impl std::fmt::Debug for MockModelListResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("MockModelListResponse")
                .field("models", &self.models)
                .field("raw_entries", &self.raw_entries)
                .field("backend", &self.backend)
                .finish()
        }
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> LLMBackend {
            self.backend.clone()
        }
    }

    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: vec![],
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(i, model)| MockModelEntry {
                    id: model.clone(),
                    created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
                    extra_data: serde_json::json!({
                        "index": i,
                        "description": format!("Model {}", model)
                    }),
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: LLMBackend::OpenAI,
            }))
        }
    }

    // Provider that relies on the trait's default `list_models`, which always fails
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    #[test]
    fn test_model_list_request_default() {
        let request = ModelListRequest::default();
        assert!(request.filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter, Some("gpt".to_string()));
    }

    #[test]
    fn test_model_list_request_clone() {
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let cloned = request.clone();
        assert_eq!(request.filter, cloned.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("ModelListRequest"));
        assert!(debug_str.contains("debug_test"));
    }

    #[test]
    fn test_mock_model_entry_creation() {
        let now = Utc.timestamp_opt(1640995200, 0).unwrap();
        let entry = MockModelEntry {
            id: "test-model".to_string(),
            created_at: now,
            extra_data: serde_json::json!({"key": "value"}),
        };

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), now);
        assert_eq!(entry.get_raw(), serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = MockModelEntry {
            id: "debug-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"debug": true}),
        };

        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("MockModelEntry"));
        assert!(debug_str.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let entry = MockModelEntry {
            id: "clone-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"clone": true}),
        };

        let cloned = entry.clone();
        assert_eq!(entry.get_id(), cloned.get_id());
        assert_eq!(entry.get_created_at(), cloned.get_created_at());
        assert_eq!(entry.get_raw(), cloned.get_raw());
    }

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: vec![],
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let raw_entries = vec![
            MockModelEntry {
                id: "model1".to_string(),
                created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
                extra_data: serde_json::json!({"index": 0}),
            },
            MockModelEntry {
                id: "model2".to_string(),
                created_at: Utc.timestamp_opt(1640995201, 0).unwrap(),
                extra_data: serde_json::json!({"index": 1}),
            },
        ];

        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries: raw_entries.clone(),
            backend: LLMBackend::Anthropic,
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: vec![],
            raw_entries: vec![],
            backend: LLMBackend::Google,
        };

        assert!(matches!(response.get_backend(), LLMBackend::Google));
    }

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        let result = provider.list_models(Some(&request)).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock provider failed")
        );
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(vec![]);

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), Vec::<String>::new());
        assert_eq!(response.get_models_raw().len(), 0);
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let result = provider.list_models(None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("List Models not supported")
        );
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let result = provider.list_models(Some(&request)).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("List Models not supported")
        );
    }

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let entry = MockModelEntry {
            id: "trait-test".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"test": "data"}),
        };

        let boxed: Box<dyn ModelListRawEntry> = Box::new(entry);
        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), serde_json::json!({"test": "data"}));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let response = MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: vec![],
            backend: LLMBackend::Ollama,
        };

        let boxed: Box<dyn ModelListResponse> = Box::new(response);
        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert!(matches!(boxed.get_backend(), LLMBackend::Ollama));
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = serde_json::json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = MockModelEntry {
            id: "complex-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: complex_data.clone(),
        };

        assert_eq!(entry.get_raw(), complex_data);
        assert_eq!(entry.get_raw()["capabilities"][0], "chat");
        assert_eq!(entry.get_raw()["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = MockModelEntry {
            id: "null-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };

        assert_eq!(entry.get_raw(), serde_json::Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let time1 = Utc.timestamp_opt(1640995200, 0).unwrap();
        let time2 = time1 + chrono::Duration::seconds(1);

        let entry1 = MockModelEntry {
            id: "older".to_string(),
            created_at: time1,
            extra_data: serde_json::Value::Null,
        };

        let entry2 = MockModelEntry {
            id: "newer".to_string(),
            created_at: time2,
            extra_data: serde_json::Value::Null,
        };

        assert!(entry1.get_created_at() < entry2.get_created_at());
    }

    #[test]
    fn test_backend_variants() {
        let backends = vec![
            LLMBackend::OpenAI,
            LLMBackend::Anthropic,
            LLMBackend::Ollama,
            LLMBackend::DeepSeek,
            LLMBackend::XAI,
            LLMBackend::Phind,
            LLMBackend::Google,
            LLMBackend::Groq,
            LLMBackend::AzureOpenAI,
        ];

        for backend in backends {
            let response = MockModelListResponse {
                models: vec![],
                raw_entries: vec![],
                backend: backend.clone(),
            };
            // Compare variants by discriminant so the test does not depend on
            // `LLMBackend` implementing `PartialEq`.
            let result_backend = response.get_backend();
            assert_eq!(
                std::mem::discriminant(&result_backend),
                std::mem::discriminant(&backend)
            );
        }
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        match result {
            Err(LLMError::ProviderError(msg)) => {
                assert_eq!(msg, "Mock provider failed");
            }
            _ => panic!("Expected ProviderError"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models.clone());

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models().len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(response.get_models()[0], "model-000");
        assert_eq!(response.get_models()[99], "model-099");
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some("".to_string()),
        };
        assert_eq!(request.filter, Some("".to_string()));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter,
            Some("model-name_with.special-chars".to_string())
        );
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = MockModelEntry {
            id: "".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = MockModelEntry {
            id: "模型-测试-🤖".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }

    // Tests for the concrete OpenAI-compatible types defined in the parent module.

    #[test]
    fn test_standard_model_entry_deserialize_with_extra_fields() {
        let entry: StandardModelEntry = serde_json::from_value(serde_json::json!({
            "id": "gpt-4",
            "created": 1640995200,
            "object": "model",
            "owned_by": "openai"
        }))
        .expect("entry JSON should deserialize");

        assert_eq!(entry.get_id(), "gpt-4");
        assert_eq!(entry.created, Some(1640995200));
        assert_eq!(
            entry.get_created_at(),
            Utc.timestamp_opt(1640995200, 0).unwrap()
        );
        // `id` and `created` are consumed by named fields; only the remaining
        // keys are flattened into `extra`.
        assert_eq!(
            entry.get_raw(),
            serde_json::json!({"object": "model", "owned_by": "openai"})
        );
    }

    #[test]
    fn test_standard_model_entry_without_created_uses_default_time() {
        let entry: StandardModelEntry =
            serde_json::from_value(serde_json::json!({"id": "local-model"}))
                .expect("entry JSON should deserialize");

        assert_eq!(entry.created, None);
        assert_eq!(entry.get_created_at(), DateTime::<Utc>::default());
        assert_eq!(entry.get_raw(), serde_json::json!({}));
    }

    #[test]
    fn test_standard_model_list_response_from_json() {
        let inner: StandardModelListResponseInner = serde_json::from_value(serde_json::json!({
            "object": "list",
            "data": [
                {"id": "model-a", "created": 1},
                {"id": "model-b"}
            ]
        }))
        .expect("list JSON should deserialize");
        let response = StandardModelListResponse {
            inner,
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(
            response.get_models(),
            vec!["model-a".to_string(), "model-b".to_string()]
        );

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[1].get_id(), "model-b");
        assert_eq!(raw[0].get_created_at(), Utc.timestamp_opt(1, 0).unwrap());
        assert_eq!(raw[1].get_created_at(), DateTime::<Utc>::default());
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }
}