Skip to main content

autoagents_llm/models/mod.rs

1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::fmt::Debug;
7
8pub trait ModelListResponse: std::fmt::Debug {
9    fn get_models(&self) -> Vec<String>;
10    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
11    /// Returns an identifier for the backend that produced this response.
12    /// Built-in providers return their canonical lowercase name (e.g. `"openai"`,
13    /// `"anthropic"`). Third-party implementors may return any string that uniquely identifies their backend
14    fn get_backend(&self) -> String;
15}
16
17pub trait ModelListRawEntry: Debug {
18    fn get_id(&self) -> String;
19    fn get_created_at(&self) -> DateTime<Utc>;
20    fn get_raw(&self) -> serde_json::Value;
21}
22
/// Parameters for [`ModelsProvider::list_models`].
///
/// `PartialEq`/`Eq` are derived so requests can be compared directly
/// (e.g. in tests or when caching responses per request).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelListRequest {
    /// Optional filter on model IDs. Whether and how it is applied is up to
    /// each provider implementation.
    pub filter: Option<String>,
}
27
28/// Trait for providers that support listing and retrieving model information.
29#[async_trait]
30pub trait ModelsProvider {
31    /// Asynchronously retrieves the list of available models ID's from the provider.
32    ///
33    /// # Arguments
34    ///
35    /// * `_request` - Optional filter by model ID
36    ///
37    /// # Returns
38    ///
39    /// List of model ID's or error
40    async fn list_models(
41        &self,
42        _request: Option<&ModelListRequest>,
43    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
44        Err(LLMError::ProviderError(
45            "List Models not supported".to_string(),
46        ))
47    }
48}
49
50/// Standard model entry structure used by OpenAI-compatible providers
51#[derive(Clone, Debug, Deserialize, Serialize)]
52pub struct StandardModelEntry {
53    pub id: String,
54    pub created: Option<u64>,
55    #[serde(flatten)]
56    pub extra: Value,
57}
58
59impl ModelListRawEntry for StandardModelEntry {
60    fn get_id(&self) -> String {
61        self.id.clone()
62    }
63
64    fn get_created_at(&self) -> DateTime<Utc> {
65        self.created
66            .map(|t| DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
67            .unwrap_or_default()
68    }
69
70    fn get_raw(&self) -> Value {
71        self.extra.clone()
72    }
73}
74
75/// Inner structure for model list response data (serializable)
76#[derive(Clone, Debug, Deserialize, Serialize)]
77pub struct StandardModelListResponseInner {
78    pub data: Vec<StandardModelEntry>,
79}
80
81/// Standard model list response structure used by OpenAI-compatible providers
82#[derive(Clone, Debug)]
83pub struct StandardModelListResponse {
84    pub inner: StandardModelListResponseInner,
85    pub backend: LLMBackend,
86}
87
88impl ModelListResponse for StandardModelListResponse {
89    fn get_models(&self) -> Vec<String> {
90        self.inner.data.iter().map(|e| e.id.clone()).collect()
91    }
92
93    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
94        self.inner
95            .data
96            .iter()
97            .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
98            .collect()
99    }
100
101    fn get_backend(&self) -> String {
102        self.backend.to_string()
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::TimeZone;
    use chrono::{DateTime, Utc};
    use serde_json::Value;

    // ---------------------------------------------------------------------
    // Mock implementations used to exercise the traits in isolation.
    // ---------------------------------------------------------------------

    /// In-memory `ModelListRawEntry` whose fields are fully caller-controlled.
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    /// In-memory `ModelListResponse`.
    #[derive(Debug)]
    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        /// Backend identifier — plain string, no coupling to LLMBackend enum.
        backend: String,
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> String {
            self.backend.clone()
        }
    }

    /// Provider that returns its configured models (ignoring any request
    /// filter), or a `ProviderError` when `should_fail` is set.
    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: vec![],
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(i, model)| MockModelEntry {
                    id: model.clone(),
                    created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
                    extra_data: serde_json::json!({
                        "index": i,
                        "description": format!("Model {}", model)
                    }),
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: "openai".to_string(),
            }))
        }
    }

    /// Provider relying on the trait's default `list_models`, which always
    /// reports that listing is unsupported.
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    // ---------------------------------------------------------------------
    // ModelListRequest
    // ---------------------------------------------------------------------

    #[test]
    fn test_model_list_request_default() {
        let request = ModelListRequest::default();
        assert!(request.filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter, Some("gpt".to_string()));
    }

    #[test]
    fn test_model_list_request_clone() {
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let cloned = request.clone();
        assert_eq!(request.filter, cloned.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("ModelListRequest"));
        assert!(debug_str.contains("debug_test"));
    }

    // ---------------------------------------------------------------------
    // Mock entries and responses
    // ---------------------------------------------------------------------

    #[test]
    fn test_mock_model_entry_creation() {
        let now = Utc.timestamp_opt(1640995200, 0).unwrap();
        let entry = MockModelEntry {
            id: "test-model".to_string(),
            created_at: now,
            extra_data: serde_json::json!({"key": "value"}),
        };

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), now);
        assert_eq!(entry.get_raw(), serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = MockModelEntry {
            id: "debug-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"debug": true}),
        };

        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("MockModelEntry"));
        assert!(debug_str.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let entry = MockModelEntry {
            id: "clone-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"clone": true}),
        };

        let cloned = entry.clone();
        assert_eq!(entry.get_id(), cloned.get_id());
        assert_eq!(entry.get_created_at(), cloned.get_created_at());
        assert_eq!(entry.get_raw(), cloned.get_raw());
    }

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: vec![],
            backend: "openai".to_string(),
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let raw_entries = vec![
            MockModelEntry {
                id: "model1".to_string(),
                created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
                extra_data: serde_json::json!({"index": 0}),
            },
            MockModelEntry {
                id: "model2".to_string(),
                created_at: Utc.timestamp_opt(1640995201, 0).unwrap(),
                extra_data: serde_json::json!({"index": 1}),
            },
        ];

        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries: raw_entries.clone(),
            backend: "anthropic".to_string(),
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: vec![],
            raw_entries: vec![],
            backend: "google".to_string(),
        };

        assert_eq!(response.get_backend(), "google");
    }

    // ---------------------------------------------------------------------
    // ModelsProvider
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert_eq!(response.get_backend(), "openai");
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        let result = provider.list_models(Some(&request)).await;
        assert!(result.is_ok());

        // The mock ignores the filter, so every configured model comes back.
        let response = result.unwrap();
        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock provider failed")
        );
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(vec![]);

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), Vec::<String>::new());
        assert_eq!(response.get_models_raw().len(), 0);
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let result = provider.list_models(None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("List Models not supported")
        );
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let result = provider.list_models(Some(&request)).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("List Models not supported")
        );
    }

    // ---------------------------------------------------------------------
    // Trait objects and entry data
    // ---------------------------------------------------------------------

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let entry = MockModelEntry {
            id: "trait-test".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"test": "data"}),
        };

        let boxed: Box<dyn ModelListRawEntry> = Box::new(entry);
        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), serde_json::json!({"test": "data"}));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let response = MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: vec![],
            backend: "ollama".to_string(),
        };

        let boxed: Box<dyn ModelListResponse> = Box::new(response);
        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert_eq!(boxed.get_backend(), "ollama");
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = serde_json::json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = MockModelEntry {
            id: "complex-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: complex_data.clone(),
        };

        assert_eq!(entry.get_raw(), complex_data);
        assert_eq!(entry.get_raw()["capabilities"][0], "chat");
        assert_eq!(entry.get_raw()["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = MockModelEntry {
            id: "null-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };

        assert_eq!(entry.get_raw(), serde_json::Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let time1 = Utc.timestamp_opt(1640995200, 0).unwrap();
        let time2 = time1 + chrono::Duration::seconds(1);

        let entry1 = MockModelEntry {
            id: "older".to_string(),
            created_at: time1,
            extra_data: serde_json::Value::Null,
        };

        let entry2 = MockModelEntry {
            id: "newer".to_string(),
            created_at: time2,
            extra_data: serde_json::Value::Null,
        };

        assert!(entry1.get_created_at() < entry2.get_created_at());
    }

    #[test]
    fn test_backend_variants() {
        let cases = [
            (LLMBackend::OpenAI, "openai"),
            (LLMBackend::Anthropic, "anthropic"),
            (LLMBackend::Ollama, "ollama"),
            (LLMBackend::DeepSeek, "deepseek"),
            (LLMBackend::XAI, "xai"),
            (LLMBackend::Phind, "phind"),
            (LLMBackend::Google, "google"),
            (LLMBackend::Groq, "groq"),
            (LLMBackend::AzureOpenAI, "azure-openai"),
        ];

        for (backend, expected) in cases {
            // StandardModelListResponse stores LLMBackend internally and returns String
            let response = StandardModelListResponse {
                inner: StandardModelListResponseInner { data: vec![] },
                backend,
            };
            assert_eq!(response.get_backend(), expected);
        }
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        match result {
            Err(LLMError::ProviderError(msg)) => {
                assert_eq!(msg, "Mock provider failed");
            }
            _ => panic!("Expected ProviderError"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models.clone());

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        let names = response.get_models();
        assert_eq!(names.len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(names[0], "model-000");
        assert_eq!(names[99], "model-099");
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some("".to_string()),
        };
        assert_eq!(request.filter, Some("".to_string()));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter,
            Some("model-name_with.special-chars".to_string())
        );
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = MockModelEntry {
            id: "".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = MockModelEntry {
            id: "模型-测试-🤖".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }

    // ---------------------------------------------------------------------
    // StandardModelEntry / StandardModelListResponse
    // ---------------------------------------------------------------------

    #[test]
    fn test_standard_model_entry_created_at_default() {
        let entry = StandardModelEntry {
            id: "m1".to_string(),
            created: None,
            extra: serde_json::Value::Null,
        };
        let dt = entry.get_created_at();
        assert_eq!(dt.timestamp(), 0);
    }

    #[test]
    fn test_standard_model_entry_created_at_some() {
        let entry = StandardModelEntry {
            id: "m1".to_string(),
            created: Some(1640995200),
            extra: serde_json::Value::Null,
        };
        assert_eq!(
            entry.get_created_at(),
            Utc.timestamp_opt(1640995200, 0).unwrap()
        );
    }

    #[test]
    fn test_standard_model_entry_deserialize_flattens_extra() {
        let json = r#"{"id":"gpt-4","created":1687882411,"object":"model","owned_by":"openai"}"#;
        let entry: StandardModelEntry = serde_json::from_str(json).unwrap();

        assert_eq!(entry.get_id(), "gpt-4");
        assert_eq!(entry.created, Some(1687882411));
        // `id` and `created` are consumed by their own fields; only the rest
        // of the object ends up in the flattened `extra`.
        assert_eq!(
            entry.get_raw(),
            serde_json::json!({"object": "model", "owned_by": "openai"})
        );
    }

    #[test]
    fn test_standard_model_entry_deserialize_without_created() {
        let entry: StandardModelEntry =
            serde_json::from_str(r#"{"id":"local-model"}"#).unwrap();

        assert_eq!(entry.get_id(), "local-model");
        assert!(entry.created.is_none());
        assert_eq!(entry.get_created_at().timestamp(), 0);
        assert_eq!(entry.get_raw(), serde_json::json!({}));
    }

    #[test]
    fn test_standard_model_entry_serialize_round_trip() {
        let original = serde_json::json!({
            "id": "gpt-4",
            "created": 1687882411,
            "owned_by": "openai"
        });
        let entry: StandardModelEntry = serde_json::from_value(original.clone()).unwrap();
        assert_eq!(serde_json::to_value(&entry).unwrap(), original);
    }

    #[test]
    fn test_standard_model_list_response_helpers() {
        let inner = StandardModelListResponseInner {
            data: vec![StandardModelEntry {
                id: "m1".to_string(),
                created: Some(1),
                extra: serde_json::Value::Null,
            }],
        };
        let response = StandardModelListResponse {
            inner,
            backend: LLMBackend::OpenAI,
        };
        assert_eq!(response.get_models(), vec!["m1".to_string()]);
        assert_eq!(response.get_models_raw().len(), 1);
        assert_eq!(response.get_backend(), "openai");
    }

    #[test]
    fn test_standard_model_list_response_from_json() {
        let json = r#"{"object":"list","data":[{"id":"a"},{"id":"b","created":5}]}"#;
        let inner: StandardModelListResponseInner = serde_json::from_str(json).unwrap();
        let response = StandardModelListResponse {
            inner,
            backend: LLMBackend::Groq,
        };

        assert_eq!(response.get_models(), vec!["a".to_string(), "b".to_string()]);
        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_created_at().timestamp(), 0);
        assert_eq!(raw[1].get_created_at().timestamp(), 5);
        assert_eq!(response.get_backend(), "groq");
    }
}