1use crate::{builder::LLMBackend, error::LLMError};
2use async_trait::async_trait;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::fmt::Debug;
7
8pub trait ModelListResponse: std::fmt::Debug {
9 fn get_models(&self) -> Vec<String>;
10 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>>;
11 fn get_backend(&self) -> String;
15}
16
17pub trait ModelListRawEntry: Debug {
18 fn get_id(&self) -> String;
19 fn get_created_at(&self) -> DateTime<Utc>;
20 fn get_raw(&self) -> serde_json::Value;
21}
22
/// Options for a "list models" request.
///
/// How (or whether) `filter` is applied is up to each provider
/// implementation.
///
/// The type derives `PartialEq`/`Eq` so that requests can be compared directly
/// (e.g. in caches or assertions).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelListRequest {
    /// Optional filter string forwarded to the provider; `None` by default.
    pub filter: Option<String>,
}
27
28#[async_trait]
30pub trait ModelsProvider {
31 async fn list_models(
41 &self,
42 _request: Option<&ModelListRequest>,
43 ) -> Result<Box<dyn ModelListResponse>, LLMError> {
44 Err(LLMError::ProviderError(
45 "List Models not supported".to_string(),
46 ))
47 }
48}
49
50#[derive(Clone, Debug, Deserialize, Serialize)]
52pub struct StandardModelEntry {
53 pub id: String,
54 pub created: Option<u64>,
55 #[serde(flatten)]
56 pub extra: Value,
57}
58
59impl ModelListRawEntry for StandardModelEntry {
60 fn get_id(&self) -> String {
61 self.id.clone()
62 }
63
64 fn get_created_at(&self) -> DateTime<Utc> {
65 self.created
66 .map(|t| DateTime::from_timestamp(t as i64, 0).unwrap_or_default())
67 .unwrap_or_default()
68 }
69
70 fn get_raw(&self) -> Value {
71 self.extra.clone()
72 }
73}
74
75#[derive(Clone, Debug, Deserialize, Serialize)]
77pub struct StandardModelListResponseInner {
78 pub data: Vec<StandardModelEntry>,
79}
80
81#[derive(Clone, Debug)]
83pub struct StandardModelListResponse {
84 pub inner: StandardModelListResponseInner,
85 pub backend: LLMBackend,
86}
87
88impl ModelListResponse for StandardModelListResponse {
89 fn get_models(&self) -> Vec<String> {
90 self.inner.data.iter().map(|e| e.id.clone()).collect()
91 }
92
93 fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
94 self.inner
95 .data
96 .iter()
97 .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
98 .collect()
99 }
100
101 fn get_backend(&self) -> String {
102 self.backend.to_string()
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::LLMBackend;
    use crate::error::LLMError;
    use async_trait::async_trait;
    use chrono::TimeZone;
    use chrono::{DateTime, Utc};
    use serde_json::Value;

    /// Unix timestamp (seconds) shared by the fixtures: 2022-01-01T00:00:00Z.
    const FIXTURE_TIMESTAMP: i64 = 1_640_995_200;

    /// The creation time used by the mock fixtures.
    fn fixture_time() -> DateTime<Utc> {
        Utc.timestamp_opt(FIXTURE_TIMESTAMP, 0)
            .single()
            .expect("fixture timestamp is a valid UTC instant")
    }

    /// Builds a `MockModelEntry` created at `fixture_time()`.
    fn mock_entry(id: &str, extra_data: Value) -> MockModelEntry {
        MockModelEntry {
            id: id.to_string(),
            created_at: fixture_time(),
            extra_data,
        }
    }

    /// In-memory `ModelListRawEntry` with preset values.
    #[derive(Debug, Clone)]
    struct MockModelEntry {
        id: String,
        created_at: DateTime<Utc>,
        extra_data: Value,
    }

    impl ModelListRawEntry for MockModelEntry {
        fn get_id(&self) -> String {
            self.id.clone()
        }

        fn get_created_at(&self) -> DateTime<Utc> {
            self.created_at
        }

        fn get_raw(&self) -> Value {
            self.extra_data.clone()
        }
    }

    /// In-memory `ModelListResponse` with preset values.
    #[derive(Debug)]
    struct MockModelListResponse {
        models: Vec<String>,
        raw_entries: Vec<MockModelEntry>,
        backend: String,
    }

    impl ModelListResponse for MockModelListResponse {
        fn get_models(&self) -> Vec<String> {
            self.models.clone()
        }

        fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
            self.raw_entries
                .iter()
                .map(|e| Box::new(e.clone()) as Box<dyn ModelListRawEntry>)
                .collect()
        }

        fn get_backend(&self) -> String {
            self.backend.clone()
        }
    }

    /// Provider that either fails or returns `models` with synthetic raw entries.
    struct MockModelsProvider {
        should_fail: bool,
        models: Vec<String>,
    }

    impl MockModelsProvider {
        fn new(models: Vec<String>) -> Self {
            Self {
                should_fail: false,
                models,
            }
        }

        fn with_failure() -> Self {
            Self {
                should_fail: true,
                models: vec![],
            }
        }
    }

    #[async_trait]
    impl ModelsProvider for MockModelsProvider {
        async fn list_models(
            &self,
            _request: Option<&ModelListRequest>,
        ) -> Result<Box<dyn ModelListResponse>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock provider failed".to_string()));
            }

            // The request (including any filter) is deliberately ignored.
            let raw_entries = self
                .models
                .iter()
                .enumerate()
                .map(|(i, model)| {
                    mock_entry(
                        model,
                        serde_json::json!({
                            "index": i,
                            "description": format!("Model {model}")
                        }),
                    )
                })
                .collect();

            Ok(Box::new(MockModelListResponse {
                models: self.models.clone(),
                raw_entries,
                backend: "openai".to_string(),
            }))
        }
    }

    /// Provider relying entirely on the trait's default `list_models`.
    struct DefaultModelsProvider;

    #[async_trait]
    impl ModelsProvider for DefaultModelsProvider {}

    #[test]
    fn test_model_list_request_default() {
        let request = ModelListRequest::default();
        assert!(request.filter.is_none());
    }

    #[test]
    fn test_model_list_request_with_filter() {
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };
        assert_eq!(request.filter, Some("gpt".to_string()));
    }

    #[test]
    fn test_model_list_request_clone() {
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };
        let cloned = request.clone();
        assert_eq!(request.filter, cloned.filter);
    }

    #[test]
    fn test_model_list_request_debug() {
        let request = ModelListRequest {
            filter: Some("debug_test".to_string()),
        };
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("ModelListRequest"));
        assert!(debug_str.contains("debug_test"));
    }

    #[test]
    fn test_mock_model_entry_creation() {
        let now = fixture_time();
        let entry = mock_entry("test-model", serde_json::json!({"key": "value"}));

        assert_eq!(entry.get_id(), "test-model");
        assert_eq!(entry.get_created_at(), now);
        assert_eq!(entry.get_raw(), serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_mock_model_entry_debug() {
        let entry = mock_entry("debug-model", serde_json::json!({"debug": true}));

        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("MockModelEntry"));
        assert!(debug_str.contains("debug-model"));
    }

    #[test]
    fn test_mock_model_entry_clone() {
        let entry = mock_entry("clone-model", serde_json::json!({"clone": true}));

        let cloned = entry.clone();
        assert_eq!(entry.get_id(), cloned.get_id());
        assert_eq!(entry.get_created_at(), cloned.get_created_at());
        assert_eq!(entry.get_raw(), cloned.get_raw());
    }

    #[test]
    fn test_mock_model_list_response_get_models() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let response = MockModelListResponse {
            models: models.clone(),
            raw_entries: vec![],
            backend: "openai".to_string(),
        };

        assert_eq!(response.get_models(), models);
    }

    #[test]
    fn test_mock_model_list_response_get_models_raw() {
        let raw_entries = vec![
            mock_entry("model1", serde_json::json!({"index": 0})),
            MockModelEntry {
                created_at: fixture_time() + chrono::Duration::seconds(1),
                ..mock_entry("model2", serde_json::json!({"index": 1}))
            },
        ];

        let response = MockModelListResponse {
            models: vec!["model1".to_string(), "model2".to_string()],
            raw_entries,
            backend: "anthropic".to_string(),
        };

        let raw = response.get_models_raw();
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].get_id(), "model1");
        assert_eq!(raw[1].get_id(), "model2");
    }

    #[test]
    fn test_mock_model_list_response_get_backend() {
        let response = MockModelListResponse {
            models: vec![],
            raw_entries: vec![],
            backend: "google".to_string(),
        };

        assert_eq!(response.get_backend(), "google");
    }

    #[tokio::test]
    async fn test_mock_models_provider_success() {
        let models = vec!["gpt-3.5-turbo".to_string(), "gpt-4".to_string()];
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed");
        assert_eq!(response.get_models(), models);
        assert_eq!(response.get_models_raw().len(), 2);
        assert_eq!(response.get_backend(), "openai");
    }

    #[tokio::test]
    async fn test_mock_models_provider_with_request() {
        let models = vec!["model1".to_string(), "model2".to_string()];
        let provider = MockModelsProvider::new(models.clone());
        let request = ModelListRequest {
            filter: Some("gpt".to_string()),
        };

        let response = provider
            .list_models(Some(&request))
            .await
            .expect("mock provider should succeed with a request");
        assert_eq!(response.get_models(), models);
    }

    #[tokio::test]
    async fn test_mock_models_provider_failure() {
        let provider = MockModelsProvider::with_failure();

        let err = provider
            .list_models(None)
            .await
            .expect_err("failing mock provider should return an error");
        assert!(err.to_string().contains("Mock provider failed"));
    }

    #[tokio::test]
    async fn test_mock_models_provider_empty_models() {
        let provider = MockModelsProvider::new(vec![]);

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed with no models");
        assert_eq!(response.get_models(), Vec::<String>::new());
        assert_eq!(response.get_models_raw().len(), 0);
    }

    #[tokio::test]
    async fn test_default_models_provider_not_supported() {
        let provider = DefaultModelsProvider;

        let err = provider
            .list_models(None)
            .await
            .expect_err("default list_models should be unsupported");
        assert!(err.to_string().contains("List Models not supported"));
    }

    #[tokio::test]
    async fn test_default_models_provider_with_request() {
        let provider = DefaultModelsProvider;
        let request = ModelListRequest {
            filter: Some("test".to_string()),
        };

        let err = provider
            .list_models(Some(&request))
            .await
            .expect_err("default list_models should be unsupported");
        assert!(err.to_string().contains("List Models not supported"));
    }

    #[test]
    fn test_model_list_raw_entry_trait_object() {
        let entry = mock_entry("trait-test", serde_json::json!({"test": "data"}));

        let boxed: Box<dyn ModelListRawEntry> = Box::new(entry);
        assert_eq!(boxed.get_id(), "trait-test");
        assert_eq!(boxed.get_raw(), serde_json::json!({"test": "data"}));
    }

    #[test]
    fn test_model_list_response_trait_object() {
        let response = MockModelListResponse {
            models: vec!["test-model".to_string()],
            raw_entries: vec![],
            backend: "ollama".to_string(),
        };

        let boxed: Box<dyn ModelListResponse> = Box::new(response);
        assert_eq!(boxed.get_models(), vec!["test-model".to_string()]);
        assert_eq!(boxed.get_backend(), "ollama");
    }

    #[test]
    fn test_model_entry_with_complex_data() {
        let complex_data = serde_json::json!({
            "capabilities": ["chat", "completion"],
            "max_tokens": 4096,
            "pricing": {
                "input": 0.0015,
                "output": 0.002
            }
        });

        let entry = mock_entry("complex-model", complex_data.clone());

        assert_eq!(entry.get_raw(), complex_data);
        assert_eq!(entry.get_raw()["capabilities"][0], "chat");
        assert_eq!(entry.get_raw()["max_tokens"], 4096);
    }

    #[test]
    fn test_model_entry_with_null_data() {
        let entry = mock_entry("null-model", Value::Null);

        assert_eq!(entry.get_raw(), Value::Null);
    }

    #[test]
    fn test_model_entry_time_ordering() {
        let time1 = fixture_time();
        let time2 = time1 + chrono::Duration::seconds(1);

        let entry1 = MockModelEntry {
            created_at: time1,
            ..mock_entry("older", Value::Null)
        };

        let entry2 = MockModelEntry {
            created_at: time2,
            ..mock_entry("newer", Value::Null)
        };

        assert!(entry1.get_created_at() < entry2.get_created_at());
    }

    #[test]
    fn test_backend_variants() {
        let cases = [
            (LLMBackend::OpenAI, "openai"),
            (LLMBackend::Anthropic, "anthropic"),
            (LLMBackend::Ollama, "ollama"),
            (LLMBackend::DeepSeek, "deepseek"),
            (LLMBackend::XAI, "xai"),
            (LLMBackend::Phind, "phind"),
            (LLMBackend::Google, "google"),
            (LLMBackend::Groq, "groq"),
            (LLMBackend::AzureOpenAI, "azure-openai"),
        ];

        for (backend, expected) in cases {
            let response = StandardModelListResponse {
                inner: StandardModelListResponseInner { data: vec![] },
                backend,
            };
            assert_eq!(response.get_backend(), expected);
        }
    }

    #[tokio::test]
    async fn test_models_provider_error_handling() {
        let provider = MockModelsProvider::with_failure();

        let result = provider.list_models(None).await;
        match result {
            Err(LLMError::ProviderError(msg)) => {
                assert_eq!(msg, "Mock provider failed");
            }
            other => panic!("Expected ProviderError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_models_provider_with_many_models() {
        let models: Vec<String> = (0..100).map(|i| format!("model-{i:03}")).collect();
        let provider = MockModelsProvider::new(models.clone());

        let response = provider
            .list_models(None)
            .await
            .expect("mock provider should succeed with many models");
        let listed = response.get_models();
        assert_eq!(listed.len(), 100);
        assert_eq!(response.get_models_raw().len(), 100);
        assert_eq!(listed[0], "model-000");
        assert_eq!(listed[99], "model-099");
    }

    #[test]
    fn test_model_list_request_with_empty_filter() {
        let request = ModelListRequest {
            filter: Some("".to_string()),
        };
        assert_eq!(request.filter, Some("".to_string()));
    }

    #[test]
    fn test_model_list_request_with_special_chars() {
        let request = ModelListRequest {
            filter: Some("model-name_with.special-chars".to_string()),
        };
        assert_eq!(
            request.filter,
            Some("model-name_with.special-chars".to_string())
        );
    }

    #[test]
    fn test_model_entry_with_empty_id() {
        let entry = mock_entry("", Value::Null);
        assert_eq!(entry.get_id(), "");
    }

    #[test]
    fn test_model_entry_with_unicode_id() {
        let entry = mock_entry("模型-测试-🤖", Value::Null);
        assert_eq!(entry.get_id(), "模型-测试-🤖");
    }

    #[test]
    fn test_standard_model_entry_created_at_default() {
        let entry = StandardModelEntry {
            id: "m1".to_string(),
            created: None,
            extra: Value::Null,
        };
        let dt = entry.get_created_at();
        assert_eq!(dt.timestamp(), 0);
    }

    #[test]
    fn test_standard_model_list_response_helpers() {
        let inner = StandardModelListResponseInner {
            data: vec![StandardModelEntry {
                id: "m1".to_string(),
                created: Some(1),
                extra: Value::Null,
            }],
        };
        let response = StandardModelListResponse {
            inner,
            backend: LLMBackend::OpenAI,
        };
        assert_eq!(response.get_models(), vec!["m1".to_string()]);
        assert_eq!(response.get_models_raw().len(), 1);
        assert_eq!(response.get_backend(), "openai");
    }
}