1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::fmt::Debug;
7
8pub trait ModelListResponse: std::fmt::Debug {
9 fn get_models(&self) -> Vec<String>;
10 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
11 fn get_backend(&self) -> LLMBackend;
12}
13
14pub trait ModelListRawEntry: Debug {
15 fn get_id(&self) -> String;
16 fn get_created_at(&self) -> DateTime<Utc>;
17 fn get_raw(&self) -> serde_json::Value;
18}
19
/// Parameters accepted by [`ModelsProvider::list_models`].
#[derive(Default, Clone, Debug)]
pub struct ModelListRequest {
    /// Optional filter string. Whether and how it is applied is up to each
    /// provider; the trait's default `list_models` ignores the request entirely.
    pub filter: Option<String>,
}
24
25#[async_trait]
27pub trait ModelsProvider {
28 async fn list_models(
38 &self,
39 _request: Option<&ModelListRequest>,
40 ) -> Result<Box<dyn ModelListResponse>, LLMError> {
41 Err(LLMError::ProviderError(
42 "List Models not supported".to_string(),
43 ))
44 }
45}
46
47#[derive(Clone, Debug, Deserialize, Serialize)]
49pub struct StandardModelEntry {
50 pub id: String,
51 pub created: Option<u64>,
52 #[serde(flatten)]
53 pub extra: Value,
54}
55
56impl ModelListRawEntry for StandardModelEntry {
57 fn get_id(&self) -> String {
58 self.id.clone()
59 }
60
61 fn get_created_at(&self) -> DateTime<Utc> {
62 self.created
63 .map(|t| DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
64 .unwrap_or_default()
65 }
66
67 fn get_raw(&self) -> Value {
68 self.extra.clone()
69 }
70}
71
72#[derive(Clone, Debug, Deserialize, Serialize)]
74pub struct StandardModelListResponseInner {
75 pub data: Vec<StandardModelEntry>,
76}
77
78#[derive(Clone, Debug)]
80pub struct StandardModelListResponse {
81 pub inner: StandardModelListResponseInner,
82 pub backend: LLMBackend,
83}
84
85impl ModelListResponse for StandardModelListResponse {
86 fn get_models(&self) -> Vec<String> {
87 self.inner.data.iter().map(|e| e.id.clone()).collect()
88 }
89
90 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
91 self.inner
92 .data
93 .iter()
94 .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
95 .collect()
96 }
97
98 fn get_backend(&self) -> LLMBackend {
99 self.backend.clone()
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::TimeZone;
    use chrono::{DateTime, Utc};
    use serde_json::Value;

    /// Hand-built `ModelListRawEntry` with directly controllable fields.
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    /// Hand-built `ModelListResponse`; `models` and `raw_entries` are independent.
    #[derive(Debug)]
    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        backend: LLMBackend,
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> LLMBackend {
            self.backend.clone()
        }
    }

    /// Provider that either fails or returns the configured model ids.
    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: vec![],
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(i, model)| MockModelEntry {
                    id: model.clone(),
                    created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
                    extra_data: serde_json::json!({
                        "index": i,
                        "description": format!("Model {}", model)
                    }),
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: LLMBackend::OpenAI,
            }))
        }
    }

    /// Provider relying entirely on the trait's default `list_models`.
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    #[test]
    fn test_model_list_request_default() {
        let request = ModelListRequest::default();
        assert!(request.filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter, Some("gpt".to_string()));
    }

    #[test]
    fn test_model_list_request_clone() {
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let cloned = request.clone();
        assert_eq!(request.filter, cloned.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("ModelListRequest"));
        assert!(debug_str.contains("debug_test"));
    }

    #[test]
    fn test_mock_model_entry_creation() {
        let now = Utc.timestamp_opt(1640995200, 0).unwrap();
        let entry = MockModelEntry {
            id: "test-model".to_string(),
            created_at: now,
            extra_data: serde_json::json!({"key": "value"}),
        };

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), now);
        assert_eq!(entry.get_raw(), serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = MockModelEntry {
            id: "debug-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"debug": true}),
        };

        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("MockModelEntry"));
        assert!(debug_str.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let entry = MockModelEntry {
            id: "clone-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"clone": true}),
        };

        let cloned = entry.clone();
        assert_eq!(entry.get_id(), cloned.get_id());
        assert_eq!(entry.get_created_at(), cloned.get_created_at());
        assert_eq!(entry.get_raw(), cloned.get_raw());
    }

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: vec![],
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let raw_entries = vec![
            MockModelEntry {
                id: "model1".to_string(),
                created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
                extra_data: serde_json::json!({"index": 0}),
            },
            MockModelEntry {
                id: "model2".to_string(),
                created_at: Utc.timestamp_opt(1640995201, 0).unwrap(),
                extra_data: serde_json::json!({"index": 1}),
            },
        ];

        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries,
            backend: LLMBackend::Anthropic,
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: vec![],
            raw_entries: vec![],
            backend: LLMBackend::Google,
        };

        assert!(matches!(response.get_backend(), LLMBackend::Google));
    }

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        let result = provider.list_models(Some(&request)).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Mock provider failed"));
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(vec![]);

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models(), Vec::<String>::new());
        assert_eq!(response.get_models_raw().len(), 0);
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let result = provider.list_models(None).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("List Models not supported"));
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let result = provider.list_models(Some(&request)).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("List Models not supported"));
    }

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let entry = MockModelEntry {
            id: "trait-test".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::json!({"test": "data"}),
        };

        let boxed: Box<dyn ModelListRawEntry> = Box::new(entry);
        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), serde_json::json!({"test": "data"}));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let response = MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: vec![],
            backend: LLMBackend::Ollama,
        };

        let boxed: Box<dyn ModelListResponse> = Box::new(response);
        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert!(matches!(boxed.get_backend(), LLMBackend::Ollama));
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = serde_json::json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = MockModelEntry {
            id: "complex-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: complex_data.clone(),
        };

        assert_eq!(entry.get_raw(), complex_data);
        assert_eq!(entry.get_raw()["capabilities"][0], "chat");
        assert_eq!(entry.get_raw()["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = MockModelEntry {
            id: "null-model".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };

        assert_eq!(entry.get_raw(), serde_json::Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let time1 = Utc.timestamp_opt(1640995200, 0).unwrap();
        let time2 = time1 + chrono::Duration::seconds(1);

        let entry1 = MockModelEntry {
            id: "older".to_string(),
            created_at: time1,
            extra_data: serde_json::Value::Null,
        };

        let entry2 = MockModelEntry {
            id: "newer".to_string(),
            created_at: time2,
            extra_data: serde_json::Value::Null,
        };

        assert!(entry1.get_created_at() < entry2.get_created_at());
    }

    #[test]
    fn test_backend_variants() {
        // A plain array suffices: it is only iterated (clippy::useless_vec).
        let backends = [
            LLMBackend::OpenAI,
            LLMBackend::Anthropic,
            LLMBackend::Ollama,
            LLMBackend::DeepSeek,
            LLMBackend::XAI,
            LLMBackend::Phind,
            LLMBackend::Google,
            LLMBackend::Groq,
            LLMBackend::AzureOpenAI,
        ];

        for backend in backends {
            let response = MockModelListResponse {
                models: vec![],
                raw_entries: vec![],
                backend: backend.clone(),
            };
            let result_backend = response.get_backend();
            // Compare variants via discriminants (LLMBackend is not required to be
            // PartialEq); assert_eq! reports both sides on failure.
            assert_eq!(
                std::mem::discriminant(&result_backend),
                std::mem::discriminant(&backend)
            );
        }
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        match result {
            Err(LLMError::ProviderError(msg)) => {
                assert_eq!(msg, "Mock provider failed");
            }
            _ => panic!("Expected ProviderError"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models.clone());

        let result = provider.list_models(None).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert_eq!(response.get_models().len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(response.get_models()[0], "model-000");
        assert_eq!(response.get_models()[99], "model-099");
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some("".to_string()),
        };
        assert_eq!(request.filter, Some("".to_string()));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter,
            Some("model-name_with.special-chars".to_string())
        );
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = MockModelEntry {
            id: "".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = MockModelEntry {
            id: "模型-测试-🤖".to_string(),
            created_at: Utc.timestamp_opt(1640995200, 0).unwrap(),
            extra_data: serde_json::Value::Null,
        };
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }

    // --- Coverage for the production `Standard*` types ---

    #[test]
    fn test_standard_model_entry_deserialize_collects_extra_fields() {
        let entry: StandardModelEntry = serde_json::from_value(serde_json::json!({
            "id": "gpt-4",
            "created": 1640995200,
            "owned_by": "openai"
        }))
        .expect("valid model entry JSON");

        assert_eq!(entry.get_id(), "gpt-4");
        assert_eq!(
            entry.get_created_at(),
            Utc.timestamp_opt(1640995200, 0).unwrap()
        );
        let raw = entry.get_raw();
        assert_eq!(raw["owned_by"], "openai");
        // Explicitly-named fields are not duplicated into the flattened extras.
        assert!(raw.get("id").is_none());
        assert!(raw.get("created").is_none());
    }

    #[test]
    fn test_standard_model_entry_missing_created_defaults_to_epoch() {
        let entry: StandardModelEntry =
            serde_json::from_value(serde_json::json!({ "id": "no-date" }))
                .expect("created is optional");

        assert!(entry.created.is_none());
        assert_eq!(entry.get_created_at(), DateTime::<Utc>::default());
    }

    #[test]
    fn test_standard_model_list_response_accessors() {
        let inner: StandardModelListResponseInner = serde_json::from_value(serde_json::json!({
            "data": [
                { "id": "model-a", "created": 1640995200 },
                { "id": "model-b" }
            ]
        }))
        .expect("valid listing JSON");
        let response = StandardModelListResponse {
            inner,
            backend: LLMBackend::OpenAI,
        };

        assert_eq!(
            response.get_models(),
            vec!["model-a".to_string(), "model-b".to_string()]
        );
        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model-a");
        assert_eq!(
            raw[0].get_created_at(),
            Utc.timestamp_opt(1640995200, 0).unwrap()
        );
        assert_eq!(raw[1].get_id(), "model-b");
        assert_eq!(raw[1].get_created_at(), DateTime::<Utc>::default());
        assert!(matches!(response.get_backend(), LLMBackend::OpenAI));
    }
}
608}