1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use std::fmt::Debug;
5
6pub trait ModelListResponse: std::fmt::Debug {
7 fn get_models(&self) -> Vec<String>;
8 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
9 fn get_backend(&self) -> LLMBackend;
10}
11
12pub trait ModelListRawEntry: Debug {
13 fn get_id(&self) -> String;
14 fn get_created_at(&self) -> DateTime<Utc>;
15 fn get_raw(&self) -> serde_json::Value;
16}
17
/// Parameters for a model-listing request.
///
/// Every field is optional; `ModelListRequest::default()` produces a request
/// with no filter.
#[derive(Clone, Debug, Default)]
pub struct ModelListRequest {
    /// Optional filter string. Whether and how it is applied is left to each
    /// `ModelsProvider` implementation.
    pub filter: Option<String>,
}
22
23#[async_trait]
25pub trait ModelsProvider {
26 async fn list_models(
36 &self,
37 _request: Option<&ModelListRequest>,
38 ) -> Result<Box<dyn ModelListResponse>, LLMError> {
39 Err(LLMError::ProviderError(
40 "List Models not supported".to_string(),
41 ))
42 }
43}
44
#[cfg(test)]
mod tests {
    // Unit tests for the model-listing traits, using in-memory test doubles.
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::TimeZone;
    use chrono::{DateTime, Utc};
    use serde_json::Value;

    /// Unix timestamp for 2022-01-01T00:00:00Z. Every fixture is anchored here
    /// so test results are deterministic.
    const FIXED_TIMESTAMP: i64 = 1_640_995_200;

    /// Returns `FIXED_TIMESTAMP + offset_secs` as a UTC `DateTime`.
    fn fixed_time(offset_secs: i64) -> DateTime<Utc> {
        Utc.timestamp_opt(FIXED_TIMESTAMP + offset_secs, 0)
            .single()
            .expect("a UTC timestamp maps to exactly one instant")
    }

    /// In-memory `ModelListRawEntry` test double.
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    /// In-memory `ModelListResponse` test double.
    #[derive(Debug)]
    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        backend: LLMBackend,
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .cloned()
                .map(|entry| Box::new(entry) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> LLMBackend {
            self.backend.clone()
        }
    }

    /// `ModelsProvider` test double. It either fails on every call or returns
    /// the configured model names unfiltered (the request is ignored).
    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        /// Creates a provider that succeeds and returns `models`.
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        /// Creates a provider whose `list_models` always returns an error.
        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: vec![],
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(i, model)| MockModelEntry {
                    id: model.clone(),
                    created_at: fixed_time(0),
                    extra_data: serde_json::json!({
                        "index": i,
                        "description": format!("Model {}", model)
                    }),
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: LLMBackend::OpenAI,
            }))
        }
    }

    /// Provider that relies entirely on the trait's default `list_models`.
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    #[test]
    fn test_model_list_request_default() {
        let request = ModelListRequest::default();
        assert!(request.filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter.as_deref(), Some("gpt"));
    }

    #[test]
    fn test_model_list_request_clone() {
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let cloned = request.clone();
        assert_eq!(request.filter, cloned.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("ModelListRequest"));
        assert!(debug_str.contains("debug_test"));
    }

    #[test]
    fn test_mock_model_entry_creation() {
        let created = fixed_time(0);
        let entry = MockModelEntry {
            id: "test-model".to_string(),
            created_at: created,
            extra_data: serde_json::json!({"key": "value"}),
        };

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), created);
        assert_eq!(entry.get_raw(), serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = MockModelEntry {
            id: "debug-model".to_string(),
            created_at: fixed_time(0),
            extra_data: serde_json::json!({"debug": true}),
        };

        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("MockModelEntry"));
        assert!(debug_str.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let entry = MockModelEntry {
            id: "clone-model".to_string(),
            created_at: fixed_time(0),
            extra_data: serde_json::json!({"clone": true}),
        };

        let cloned = entry.clone();
        assert_eq!(entry.get_id(), cloned.get_id());
        assert_eq!(entry.get_created_at(), cloned.get_created_at());
        assert_eq!(entry.get_raw(), cloned.get_raw());
    }

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: vec![],
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let raw_entries = vec![
            MockModelEntry {
                id: "model1".to_string(),
                created_at: fixed_time(0),
                extra_data: serde_json::json!({"index": 0}),
            },
            MockModelEntry {
                id: "model2".to_string(),
                created_at: fixed_time(1),
                extra_data: serde_json::json!({"index": 1}),
            },
        ];

        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries,
            backend: LLMBackend::Anthropic,
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
        assert_eq!(raw[1].get_created_at(), fixed_time(1));
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: vec![],
            raw_entries: vec![],
            backend: LLMBackend::Google,
        };

        assert!(matches!(response.get_backend(), LLMBackend::Google));
    }

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");

        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        // The mock ignores the filter, so the full list comes back.
        let response = provider
            .list_models(Some(&request))
            .await
            .expect("mock provider should succeed");

        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let err = provider
            .list_models(None)
            .await
            .expect_err("failing mock provider should return an error");

        assert!(err.to_string().contains("Mock provider failed"));
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(vec![]);

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");

        assert!(response.get_models().is_empty());
        assert!(response.get_models_raw().is_empty());
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let err = provider
            .list_models(None)
            .await
            .expect_err("default list_models should be unsupported");

        assert!(err.to_string().contains("List Models not supported"));
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let err = provider
            .list_models(Some(&request))
            .await
            .expect_err("default list_models should be unsupported");

        assert!(err.to_string().contains("List Models not supported"));
    }

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let entry = MockModelEntry {
            id: "trait-test".to_string(),
            created_at: fixed_time(0),
            extra_data: serde_json::json!({"test": "data"}),
        };

        let boxed: Box<dyn ModelListRawEntry> = Box::new(entry);
        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), serde_json::json!({"test": "data"}));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let response = MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: vec![],
            backend: LLMBackend::Ollama,
        };

        let boxed: Box<dyn ModelListResponse> = Box::new(response);
        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert!(matches!(boxed.get_backend(), LLMBackend::Ollama));
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = serde_json::json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = MockModelEntry {
            id: "complex-model".to_string(),
            created_at: fixed_time(0),
            extra_data: complex_data.clone(),
        };

        let raw = entry.get_raw();
        assert_eq!(raw, complex_data);
        assert_eq!(raw["capabilities"][0], "chat");
        assert_eq!(raw["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = MockModelEntry {
            id: "null-model".to_string(),
            created_at: fixed_time(0),
            extra_data: Value::Null,
        };

        assert_eq!(entry.get_raw(), Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let older = MockModelEntry {
            id: "older".to_string(),
            created_at: fixed_time(0),
            extra_data: Value::Null,
        };

        let newer = MockModelEntry {
            id: "newer".to_string(),
            created_at: fixed_time(1),
            extra_data: Value::Null,
        };

        assert!(older.get_created_at() < newer.get_created_at());
    }

    #[test]
    fn test_backend_variants() {
        // An array is enough here: it is only iterated once, by value.
        let backends = [
            LLMBackend::OpenAI,
            LLMBackend::Anthropic,
            LLMBackend::Ollama,
            LLMBackend::DeepSeek,
            LLMBackend::XAI,
            LLMBackend::Phind,
            LLMBackend::Google,
            LLMBackend::Groq,
            LLMBackend::AzureOpenAI,
        ];

        for backend in backends {
            let response = MockModelListResponse {
                models: vec![],
                raw_entries: vec![],
                backend: backend.clone(),
            };
            let result_backend = response.get_backend();
            // Compare variants only, so this does not require `LLMBackend: PartialEq`.
            assert_eq!(
                std::mem::discriminant(&result_backend),
                std::mem::discriminant(&backend)
            );
        }
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        match result {
            Err(LLMError::ProviderError(msg)) => {
                assert_eq!(msg, "Mock provider failed");
            }
            other => panic!("Expected ProviderError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");

        // Fetch once: each `get_models` call clones the entire list.
        let names = response.get_models();
        assert_eq!(names.len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(names[0], "model-000");
        assert_eq!(names[99], "model-099");
        assert_eq!(names, models);
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some(String::new()),
        };
        assert_eq!(request.filter.as_deref(), Some(""));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter.as_deref(),
            Some("model-name_with.special-chars")
        );
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = MockModelEntry {
            id: String::new(),
            created_at: fixed_time(0),
            extra_data: Value::Null,
        };
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = MockModelEntry {
            id: "模型-测试-🤖".to_string(),
            created_at: fixed_time(0),
            extra_data: Value::Null,
        };
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }
}