1use serde::{Deserialize, Serialize};
8use std::fmt;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
48pub struct Model {
49 pub id: String,
51
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub name: Option<String>,
55
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub description: Option<String>,
59
60 #[serde(skip_serializing_if = "Option::is_none")]
62 #[serde(rename = "type")]
63 pub r#type: Option<String>,
64
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub created_at: Option<u64>,
68
69 #[serde(skip_serializing_if = "Option::is_none")]
71 pub owned_by: Option<String>,
72
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub context_length: Option<u32>,
76}
77
78impl Model {
79 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
96 Self {
97 id: id.into(),
98 name: Some(name.into()),
99 description: None,
100 r#type: None,
101 created_at: None,
102 owned_by: None,
103 context_length: None,
104 }
105 }
106
107 pub fn from_id(id: impl Into<String>) -> Self {
123 Self {
124 id: id.into(),
125 name: None,
126 description: None,
127 r#type: None,
128 created_at: None,
129 owned_by: None,
130 context_length: None,
131 }
132 }
133
134 pub fn display_name(&self) -> &str {
151 self.name.as_ref().unwrap_or(&self.id)
152 }
153}
154
155impl fmt::Display for Model {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "{}", self.display_name())
158 }
159}
160
/// A list of [`Model`] values, serialized as an object with a single `data` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelList {
    /// The models contained in this list.
    pub data: Vec<Model>,
}
190
191impl ModelList {
192 pub fn new(data: Vec<Model>) -> Self {
210 Self { data }
211 }
212
213 pub fn is_empty(&self) -> bool {
227 self.data.is_empty()
228 }
229
230 pub fn len(&self) -> usize {
244 self.data.len()
245 }
246
247 pub fn iter(&self) -> std::slice::Iter<'_, Model> {
264 self.data.iter()
265 }
266}
267
268impl IntoIterator for ModelList {
269 type Item = Model;
270 type IntoIter = std::vec::IntoIter<Model>;
271
272 fn into_iter(self) -> Self::IntoIter {
273 self.data.into_iter()
274 }
275}
276
277impl<'a> IntoIterator for &'a ModelList {
278 type Item = &'a Model;
279 type IntoIter = std::slice::Iter<'a, Model>;
280
281 fn into_iter(self) -> Self::IntoIter {
282 self.data.iter()
283 }
284}
285
/// Errors that can occur while listing models.
///
/// Every variant carries a human-readable `message`. The `Display` output is a
/// short category prefix followed by that message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelListingError {
    /// The API reported a failure together with a status code.
    ApiError {
        /// Status code reported with the failure.
        /// NOTE(review): presumably the HTTP status code — confirm with callers.
        status_code: u16,
        /// Description of the failure.
        message: String,
    },

    /// The request itself failed. The `From` conversions for
    /// `crate::http_client::Error` and `http::Error` produce this variant.
    RequestError {
        /// Description of the failure.
        message: String,
    },

    /// A response could not be parsed. The `From<serde_json::Error>` conversion
    /// produces this variant.
    ParseError {
        /// Description of the failure.
        message: String,
    },

    /// Authentication failed.
    AuthError {
        /// Description of the failure.
        message: String,
    },

    /// A rate limit was hit.
    RateLimitError {
        /// Description of the failure.
        message: String,
    },

    /// The service is unavailable.
    ServiceUnavailable {
        /// Description of the failure.
        message: String,
    },

    /// Any failure that does not fit the other variants.
    UnknownError {
        /// Description of the failure.
        message: String,
    },
}
336
/// Maximum number of response-body bytes reproduced in an error preview.
const RESPONSE_BODY_PREVIEW_LIMIT: usize = 2048;

/// Renders a response body as text for inclusion in an error message.
///
/// Bodies of at most [`RESPONSE_BODY_PREVIEW_LIMIT`] bytes are returned whole.
/// Longer bodies are cut at the limit and followed by a
/// `\n...<truncated N bytes>` marker, where `N` is the number of bytes left out.
/// Invalid UTF-8 in the body is rendered as U+FFFD replacement characters.
///
/// If the cut would land inside a multi-byte UTF-8 character, it is moved back
/// to the start of that character. A valid body therefore never ends in a
/// spurious U+FFFD, and the split character's bytes count as truncated.
fn format_response_body_preview(body: &[u8]) -> String {
    let limit = RESPONSE_BODY_PREVIEW_LIMIT;
    if body.len() <= limit {
        return String::from_utf8_lossy(body).into_owned();
    }

    // A UTF-8 character is at most 4 bytes long, so an incomplete one at the
    // end of the window is 1 to 3 bytes. `from_utf8` reports exactly that case
    // as an error at offset 0 with no `error_len` ("input ended unexpectedly").
    // Indexing is in bounds: `start < limit < body.len()`.
    let cut = (1..=3usize)
        .filter_map(|tail| limit.checked_sub(tail))
        .find(|&start| {
            matches!(
                std::str::from_utf8(&body[start..limit]),
                Err(err) if err.valid_up_to() == 0 && err.error_len().is_none()
            )
        })
        .unwrap_or(limit);

    let mut preview = String::from_utf8_lossy(&body[..cut]).into_owned();
    preview.push_str(&format!("\n...<truncated {} bytes>", body.len() - cut));
    preview
}
353
354fn format_response_context(
355 provider: &str,
356 path: &str,
357 details: impl fmt::Display,
358 body: &[u8],
359) -> String {
360 format!(
361 "provider={provider}\npath={path}\n{details}\nbody_bytes={}\nresponse_body_preview:\n{}",
362 body.len(),
363 format_response_body_preview(body)
364 )
365}
366
367impl ModelListingError {
368 pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
370 Self::ApiError {
371 status_code,
372 message: message.into(),
373 }
374 }
375
376 pub fn request_error(message: impl Into<String>) -> Self {
378 Self::RequestError {
379 message: message.into(),
380 }
381 }
382
383 pub fn parse_error(message: impl Into<String>) -> Self {
385 Self::ParseError {
386 message: message.into(),
387 }
388 }
389
390 pub(crate) fn api_error_with_context(
391 provider: &str,
392 path: &str,
393 status_code: u16,
394 body: &[u8],
395 ) -> Self {
396 let message =
397 format_response_context(provider, path, format_args!("status={status_code}"), body);
398 Self::api_error(status_code, message)
399 }
400
401 pub(crate) fn parse_error_with_context(
402 provider: &str,
403 path: &str,
404 error: &serde_json::Error,
405 body: &[u8],
406 ) -> Self {
407 let message =
408 format_response_context(provider, path, format_args!("parse_error={error}"), body);
409 Self::parse_error(message)
410 }
411
412 pub(crate) fn parse_error_with_details(
413 provider: &str,
414 path: &str,
415 details: impl fmt::Display,
416 body: &[u8],
417 ) -> Self {
418 let message = format_response_context(provider, path, details, body);
419 Self::parse_error(message)
420 }
421
422 pub fn auth_error(message: impl Into<String>) -> Self {
424 Self::AuthError {
425 message: message.into(),
426 }
427 }
428
429 pub fn rate_limit_error(message: impl Into<String>) -> Self {
431 Self::RateLimitError {
432 message: message.into(),
433 }
434 }
435
436 pub fn service_unavailable(message: impl Into<String>) -> Self {
438 Self::ServiceUnavailable {
439 message: message.into(),
440 }
441 }
442
443 pub fn unknown_error(message: impl Into<String>) -> Self {
445 Self::UnknownError {
446 message: message.into(),
447 }
448 }
449}
450
451impl fmt::Display for ModelListingError {
452 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453 match self {
454 Self::ApiError {
455 status_code,
456 message,
457 } => write!(f, "API error (status {}): {}", status_code, message),
458 Self::RequestError { message } => write!(f, "Request error: {}", message),
459 Self::ParseError { message } => write!(f, "Parse error: {}", message),
460 Self::AuthError { message } => write!(f, "Authentication error: {}", message),
461 Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
462 Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
463 Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
464 }
465 }
466}
467
// Marks `ModelListingError` as a standard error type. No trait methods are
// overridden, so the `std::error::Error` defaults apply.
impl std::error::Error for ModelListingError {}
469
470impl From<crate::http_client::Error> for ModelListingError {
471 fn from(e: crate::http_client::Error) -> Self {
472 Self::request_error(e.to_string())
473 }
474}
475
476impl From<http::Error> for ModelListingError {
477 fn from(e: http::Error) -> Self {
478 Self::request_error(e.to_string())
479 }
480}
481
482impl From<serde_json::Error> for ModelListingError {
483 fn from(e: serde_json::Error) -> Self {
484 Self::parse_error(e.to_string())
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_id` sets the identifier and leaves every optional field unset.
    #[test]
    fn test_model_from_id() {
        let expected = Model {
            id: "gpt-4".to_string(),
            name: None,
            description: None,
            r#type: None,
            created_at: None,
            owned_by: None,
            context_length: None,
        };
        assert_eq!(Model::from_id("gpt-4"), expected);
    }

    /// `new` stores both the identifier and the display name.
    #[test]
    fn test_model_new() {
        let Model { id, name, .. } = Model::new("gpt-4", "GPT-4");
        assert_eq!(id, "gpt-4");
        assert_eq!(name.as_deref(), Some("GPT-4"));
    }

    /// `display_name` prefers the name and falls back to the identifier.
    #[test]
    fn test_model_display_name() {
        let named = Model::new("gpt-4", "GPT-4");
        let unnamed = Model::from_id("gpt-4");

        assert_eq!(named.display_name(), "GPT-4");
        assert_eq!(unnamed.display_name(), "gpt-4");
    }

    /// `Display` prints the display name.
    #[test]
    fn test_model_display() {
        assert_eq!(Model::new("gpt-4", "GPT-4").to_string(), "GPT-4");
    }

    #[test]
    fn test_model_list_new() {
        let models = vec![Model::from_id("gpt-4")];
        assert_eq!(ModelList::new(models).len(), 1);
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(Vec::new());
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    /// Borrowing iteration visits every model.
    #[test]
    fn test_model_list_iter() {
        let ids = ["gpt-4", "gpt-3.5-turbo"];
        let list = ModelList::new(ids.iter().map(|id| Model::from_id(*id)).collect());
        assert_eq!(list.iter().count(), 2);
    }

    /// Consuming iteration yields every model.
    #[test]
    fn test_model_list_into_iter() {
        let ids = ["gpt-4", "gpt-3.5-turbo"];
        let list = ModelList::new(ids.iter().map(|id| Model::from_id(*id)).collect());
        assert_eq!(list.into_iter().count(), 2);
    }

    /// Each variant renders as its category prefix followed by the message.
    #[test]
    fn test_model_listing_error_display() {
        let cases = [
            (
                ModelListingError::api_error(404, "Not found"),
                "API error (status 404): Not found",
            ),
            (
                ModelListingError::request_error("Connection failed"),
                "Request error: Connection failed",
            ),
            (
                ModelListingError::parse_error("Invalid JSON"),
                "Parse error: Invalid JSON",
            ),
            (
                ModelListingError::auth_error("Invalid API key"),
                "Authentication error: Invalid API key",
            ),
            (
                ModelListingError::rate_limit_error("Too many requests"),
                "Rate limit error: Too many requests",
            ),
            (
                ModelListingError::service_unavailable("Maintenance mode"),
                "Service unavailable: Maintenance mode",
            ),
            (
                ModelListingError::unknown_error("Something went wrong"),
                "Unknown error: Something went wrong",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    /// A model survives a JSON round trip.
    #[test]
    fn test_model_serde() {
        let original = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };

        let json = serde_json::to_string(&original).unwrap();
        for needle in ["gpt-4", "GPT-4"] {
            assert!(json.contains(needle));
        }

        let restored: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, "gpt-4");
        assert_eq!(restored.name.as_deref(), Some("GPT-4"));
    }

    /// A model list survives a JSON round trip.
    #[test]
    fn test_model_list_serde() {
        let original = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };

        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("gpt-4"));

        let restored: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.len(), 1);
    }

    /// An error survives a JSON round trip with its variant and fields intact.
    #[test]
    fn test_model_listing_error_serde() {
        let json = serde_json::to_string(&ModelListingError::api_error(404, "Not found")).unwrap();
        assert!(json.contains("ApiError"));

        let restored: ModelListingError = serde_json::from_str(&json).unwrap();
        if let ModelListingError::ApiError {
            status_code,
            message,
        } = restored
        {
            assert_eq!(status_code, 404);
            assert_eq!(message, "Not found");
        } else {
            panic!("Expected ApiError");
        }
    }

    /// Bodies within the limit are previewed unchanged.
    #[test]
    fn test_format_response_body_preview_without_truncation() {
        let body = br#"{"ok":true}"#;
        assert_eq!(format_response_body_preview(body), r#"{"ok":true}"#);
    }

    /// Bodies over the limit are cut and annotated with the number of omitted bytes.
    #[test]
    fn test_format_response_body_preview_with_truncation() {
        let body = vec![b'a'; RESPONSE_BODY_PREVIEW_LIMIT + 3];
        let kept = "a".repeat(RESPONSE_BODY_PREVIEW_LIMIT);

        let preview = format_response_body_preview(&body);

        assert!(preview.starts_with(kept.as_str()));
        assert!(preview.ends_with("\n...<truncated 3 bytes>"));
    }

    /// The contextual API error records provider, path, status and body preview.
    #[test]
    fn test_api_error_with_context_includes_provider_path_and_preview() {
        let error = ModelListingError::api_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            500,
            br#"{"error":"boom"}"#,
        );

        if let ModelListingError::ApiError {
            status_code,
            message,
        } = error
        {
            assert_eq!(status_code, 500);
            for fragment in [
                "provider=Gemini",
                "path=/v1beta/models?pageSize=1000",
                "status=500",
                r#"{"error":"boom"}"#,
            ] {
                assert!(message.contains(fragment));
            }
        } else {
            panic!("Expected ApiError");
        }
    }

    /// The contextual parse error records provider, path, parser message and body preview.
    #[test]
    fn test_parse_error_with_context_includes_parse_error_and_preview() {
        let body = br#"{"models":[{"displayName":"broken"}]}"#;
        let parse_error = serde_json::from_slice::<serde_json::Value>(b"{")
            .expect_err("expected malformed JSON to fail");

        let error = ModelListingError::parse_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            &parse_error,
            body,
        );

        if let ModelListingError::ParseError { message } = error {
            for fragment in [
                "provider=Gemini",
                "path=/v1beta/models?pageSize=1000",
                "parse_error=EOF while parsing an object",
                r#"{"models":[{"displayName":"broken"}]}"#,
            ] {
                assert!(message.contains(fragment));
            }
        } else {
            panic!("Expected ParseError");
        }
    }
}