use serde::{Deserialize, Serialize};
use std::fmt;
/// A single model entry as described by a provider's model listing.
///
/// Only `id` is mandatory. Every optional field is left out of the serialized output when it
/// is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Model {
    /// Model identifier; also used as the fallback by `display_name`.
    pub id: String,
    /// Optional human-readable name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Optional free-form description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional model type, (de)serialized under the JSON key `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    /// Optional creation timestamp (units not specified here — TODO confirm per provider).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<u64>,
    /// Optional owner string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub owned_by: Option<String>,
    /// Optional context length (presumably in tokens — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,
}
impl Model {
    /// Creates a model with the given `id` and display `name`; all other fields are `None`.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        // Start from the id-only model, then attach the name (id is converted first, as before).
        let mut model = Self::from_id(id);
        model.name = Some(name.into());
        model
    }

    /// Creates a model carrying only its `id`; every optional field is `None`.
    pub fn from_id(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            name: None,
            description: None,
            r#type: None,
            created_at: None,
            owned_by: None,
            context_length: None,
        }
    }

    /// Returns `name` when it is `Some` (even if empty); otherwise returns `id`.
    pub fn display_name(&self) -> &str {
        self.name.as_deref().unwrap_or(self.id.as_str())
    }
}
/// Displays a model as its [`Model::display_name`].
impl fmt::Display for Model {
    /// Writes the display name, honouring width, fill, alignment and precision flags
    /// (e.g. `{:<20}` or `{:.8}`) exactly as `str`'s `Display` implementation does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would silently discard the caller's formatting flags;
        // `Formatter::pad` applies them. Output for a plain `{}` is unchanged.
        f.pad(self.display_name())
    }
}
/// An ordered collection of [`Model`]s, serialized as an object with a `data` array.
///
/// `Default` yields an empty list. Equality compares the contained models element-wise, in
/// order, matching the `PartialEq`/`Eq` derives on `Model`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelList {
    /// The models, in the order they were supplied.
    pub data: Vec<Model>,
}
impl ModelList {
pub fn new(data: Vec<Model>) -> Self {
Self { data }
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn iter(&self) -> std::slice::Iter<'_, Model> {
self.data.iter()
}
}
impl IntoIterator for ModelList {
type Item = Model;
type IntoIter = std::vec::IntoIter<Model>;
fn into_iter(self) -> Self::IntoIter {
self.data.into_iter()
}
}
impl<'a> IntoIterator for &'a ModelList {
type Item = &'a Model;
type IntoIter = std::slice::Iter<'a, Model>;
fn into_iter(self) -> Self::IntoIter {
self.data.iter()
}
}
/// Errors that can occur while listing models.
///
/// Every variant carries a human-readable `message`. `ApiError` additionally records the HTTP
/// status code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelListingError {
    /// A non-success response; `status_code` is the HTTP status that was returned.
    ApiError { status_code: u16, message: String },
    /// The request could not be built or sent (see the `From` conversions for sources).
    RequestError { message: String },
    /// The response could not be decoded.
    ParseError { message: String },
    /// Authentication failed.
    AuthError { message: String },
    /// The caller was rate limited.
    RateLimitError { message: String },
    /// The service reported itself as unavailable.
    ServiceUnavailable { message: String },
    /// Any failure not covered by the other variants.
    UnknownError { message: String },
}
/// Maximum number of response-body bytes echoed into error-message previews.
const RESPONSE_BODY_PREVIEW_LIMIT: usize = 2048;

/// Renders at most [`RESPONSE_BODY_PREVIEW_LIMIT`] bytes of `body` as text for diagnostics.
///
/// - Bytes that are not valid UTF-8 are replaced lossily (U+FFFD).
/// - When `body` is longer than the limit, the cut point is moved back by at most three bytes,
///   so that a multi-byte UTF-8 character straddling the limit is dropped whole. Without this,
///   the split character would render as a spurious U+FFFD.
/// - A trailing `\n...<truncated N bytes>` marker reports how many bytes were omitted.
fn format_response_body_preview(body: &[u8]) -> String {
    let preview_len = if body.len() <= RESPONSE_BODY_PREVIEW_LIMIT {
        body.len()
    } else {
        // A UTF-8 sequence is at most 4 bytes long, so the lead byte of the character
        // containing `body[LIMIT]` is at most 3 bytes earlier. Continuation bytes look like
        // 0b10xx_xxxx. The bound also keeps the back-off small for non-UTF-8 binary bodies.
        let floor = RESPONSE_BODY_PREVIEW_LIMIT.saturating_sub(3);
        let mut cut = RESPONSE_BODY_PREVIEW_LIMIT;
        while cut > floor && (body[cut] & 0xC0) == 0x80 {
            cut -= 1;
        }
        cut
    };
    let mut preview = String::from_utf8_lossy(&body[..preview_len]).into_owned();
    if preview_len < body.len() {
        preview.push_str(&format!(
            "\n...<truncated {} bytes>",
            body.len() - preview_len
        ));
    }
    preview
}
/// Builds the multi-line diagnostic message used by the `*_with_context` / `*_with_details`
/// error constructors.
///
/// Lines, in order:
/// 1. `provider=…`
/// 2. `path=…`
/// 3. the caller-supplied `details`
/// 4. `body_bytes=<len>`
/// 5. the `response_body_preview:` header
/// 6. the body preview produced by `format_response_body_preview`
fn format_response_context(
    provider: &str,
    path: &str,
    details: impl fmt::Display,
    body: &[u8],
) -> String {
    let body_len = body.len();
    let body_preview = format_response_body_preview(body);
    format!(
        "provider={provider}\npath={path}\n{details}\nbody_bytes={}\nresponse_body_preview:\n{}",
        body_len, body_preview
    )
}
impl ModelListingError {
    /// Builds an `ApiError` carrying the HTTP `status_code` and `message`.
    pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
        let message = message.into();
        Self::ApiError {
            status_code,
            message,
        }
    }

    /// Builds a `RequestError` with the given message.
    pub fn request_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::RequestError { message }
    }

    /// Builds a `ParseError` with the given message.
    pub fn parse_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::ParseError { message }
    }

    /// Builds an `ApiError` whose message records the provider, request path, a `status=<code>`
    /// line and a preview of the response body.
    pub(crate) fn api_error_with_context(
        provider: &str,
        path: &str,
        status_code: u16,
        body: &[u8],
    ) -> Self {
        Self::api_error(
            status_code,
            format_response_context(provider, path, format_args!("status={status_code}"), body),
        )
    }

    /// Builds a `ParseError` for a JSON decoding failure. The message contains a
    /// `parse_error=<error>` line plus the provider/path/body context.
    pub(crate) fn parse_error_with_context(
        provider: &str,
        path: &str,
        error: &serde_json::Error,
        body: &[u8],
    ) -> Self {
        Self::parse_error_with_details(provider, path, format_args!("parse_error={error}"), body)
    }

    /// Builds a `ParseError` whose message embeds arbitrary `details` along with the
    /// provider/path/body context.
    pub(crate) fn parse_error_with_details(
        provider: &str,
        path: &str,
        details: impl fmt::Display,
        body: &[u8],
    ) -> Self {
        Self::parse_error(format_response_context(provider, path, details, body))
    }

    /// Builds an `AuthError` with the given message.
    pub fn auth_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::AuthError { message }
    }

    /// Builds a `RateLimitError` with the given message.
    pub fn rate_limit_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::RateLimitError { message }
    }

    /// Builds a `ServiceUnavailable` error with the given message.
    pub fn service_unavailable(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::ServiceUnavailable { message }
    }

    /// Builds an `UnknownError` with the given message.
    pub fn unknown_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::UnknownError { message }
    }
}
impl fmt::Display for ModelListingError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ApiError {
status_code,
message,
} => write!(f, "API error (status {}): {}", status_code, message),
Self::RequestError { message } => write!(f, "Request error: {}", message),
Self::ParseError { message } => write!(f, "Parse error: {}", message),
Self::AuthError { message } => write!(f, "Authentication error: {}", message),
Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
}
}
}
impl std::error::Error for ModelListingError {}
impl From<crate::http_client::Error> for ModelListingError {
fn from(e: crate::http_client::Error) -> Self {
Self::request_error(e.to_string())
}
}
impl From<http::Error> for ModelListingError {
fn from(e: http::Error) -> Self {
Self::request_error(e.to_string())
}
}
impl From<serde_json::Error> for ModelListingError {
fn from(e: serde_json::Error) -> Self {
Self::parse_error(e.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_from_id() {
        let model = Model::from_id("gpt-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, None);
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, Some("GPT-4".to_string()));
    }

    #[test]
    fn test_model_display_name() {
        let model_with_name = Model::new("gpt-4", "GPT-4");
        assert_eq!(model_with_name.display_name(), "GPT-4");
        let model_without_name = Model::from_id("gpt-4");
        assert_eq!(model_without_name.display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(format!("{}", model), "GPT-4");
    }

    #[test]
    fn test_model_list_new() {
        let list = ModelList::new(vec![Model::from_id("gpt-4")]);
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(vec![]);
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    #[test]
    fn test_model_list_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let models: Vec<_> = list.iter().collect();
        assert_eq!(models.len(), 2);
    }

    #[test]
    fn test_model_list_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let models: Vec<_> = list.into_iter().collect();
        assert_eq!(models.len(), 2);
    }

    /// `&ModelList` must be usable directly in a `for` loop and preserve order.
    #[test]
    fn test_model_list_borrowed_into_iter() {
        let list = ModelList::new(vec![Model::from_id("a"), Model::from_id("b")]);
        let mut ids = Vec::new();
        for model in &list {
            ids.push(model.id.as_str());
        }
        assert_eq!(ids, ["a", "b"]);
        // The list is only borrowed, so it is still usable afterwards.
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_model_listing_error_display() {
        let error = ModelListingError::api_error(404, "Not found");
        assert_eq!(error.to_string(), "API error (status 404): Not found");
        let error = ModelListingError::request_error("Connection failed");
        assert_eq!(error.to_string(), "Request error: Connection failed");
        let error = ModelListingError::parse_error("Invalid JSON");
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");
        let error = ModelListingError::auth_error("Invalid API key");
        assert_eq!(error.to_string(), "Authentication error: Invalid API key");
        let error = ModelListingError::rate_limit_error("Too many requests");
        assert_eq!(error.to_string(), "Rate limit error: Too many requests");
        let error = ModelListingError::service_unavailable("Maintenance mode");
        assert_eq!(error.to_string(), "Service unavailable: Maintenance mode");
        let error = ModelListingError::unknown_error("Something went wrong");
        assert_eq!(error.to_string(), "Unknown error: Something went wrong");
    }

    #[test]
    fn test_model_serde() {
        let model = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };
        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("GPT-4"));
        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, "gpt-4");
        assert_eq!(deserialized.name, Some("GPT-4".to_string()));
    }

    /// `None` fields must be omitted entirely rather than serialized as `null`.
    #[test]
    fn test_model_serde_omits_none_fields() {
        let json = serde_json::to_string(&Model::from_id("gpt-4")).unwrap();
        assert_eq!(json, r#"{"id":"gpt-4"}"#);
    }

    /// `r#type` must round-trip through the JSON key `type`.
    #[test]
    fn test_model_serde_renames_type_field() {
        let mut model = Model::from_id("gpt-4");
        model.r#type = Some("chat".to_string());
        let json = serde_json::to_string(&model).unwrap();
        assert_eq!(json, r#"{"id":"gpt-4","type":"chat"}"#);
        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, model);
    }

    #[test]
    fn test_model_list_serde() {
        let list = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };
        let json = serde_json::to_string(&list).unwrap();
        assert!(json.contains("gpt-4"));
        let deserialized: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 1);
    }

    #[test]
    fn test_model_listing_error_serde() {
        let error = ModelListingError::api_error(404, "Not found");
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("ApiError"));
        let deserialized: ModelListingError = serde_json::from_str(&json).unwrap();
        match deserialized {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 404);
                assert_eq!(message, "Not found");
            }
            _ => panic!("Expected ApiError"),
        }
    }

    /// `From<serde_json::Error>` must map to `ParseError` with the error's own text.
    #[test]
    fn test_from_serde_json_error_is_parse_error() {
        let json_error = serde_json::from_str::<serde_json::Value>("{")
            .expect_err("expected malformed JSON to fail");
        let expected = json_error.to_string();
        let error: ModelListingError = json_error.into();
        match error {
            ModelListingError::ParseError { message } => assert_eq!(message, expected),
            _ => panic!("Expected ParseError"),
        }
    }

    #[test]
    fn test_format_response_body_preview_without_truncation() {
        let preview = format_response_body_preview(br#"{"ok":true}"#);
        assert_eq!(preview, r#"{"ok":true}"#);
    }

    #[test]
    fn test_format_response_body_preview_with_truncation() {
        let body = vec![b'a'; RESPONSE_BODY_PREVIEW_LIMIT + 3];
        let preview = format_response_body_preview(&body);
        assert!(preview.starts_with(&"a".repeat(RESPONSE_BODY_PREVIEW_LIMIT)));
        assert!(preview.ends_with("\n...<truncated 3 bytes>"));
    }

    #[test]
    fn test_api_error_with_context_includes_provider_path_and_preview() {
        let error = ModelListingError::api_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            500,
            br#"{"error":"boom"}"#,
        );
        match error {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 500);
                assert!(message.contains("provider=Gemini"));
                assert!(message.contains("path=/v1beta/models?pageSize=1000"));
                assert!(message.contains("status=500"));
                assert!(message.contains(r#"{"error":"boom"}"#));
            }
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_parse_error_with_context_includes_parse_error_and_preview() {
        let body = br#"{"models":[{"displayName":"broken"}]}"#;
        let parse_error = serde_json::from_slice::<serde_json::Value>(b"{")
            .expect_err("expected malformed JSON to fail");
        let error = ModelListingError::parse_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            &parse_error,
            body,
        );
        match error {
            ModelListingError::ParseError { message } => {
                assert!(message.contains("provider=Gemini"));
                assert!(message.contains("path=/v1beta/models?pageSize=1000"));
                assert!(message.contains("parse_error=EOF while parsing an object"));
                assert!(message.contains(r#"{"models":[{"displayName":"broken"}]}"#));
            }
            _ => panic!("Expected ParseError"),
        }
    }

    /// Arbitrary details must be embedded along with the provider/path/body context.
    #[test]
    fn test_parse_error_with_details_includes_details_and_preview() {
        let error = ModelListingError::parse_error_with_details(
            "OpenAI",
            "/v1/models",
            "missing field `data`",
            br#"{"object":"list"}"#,
        );
        match error {
            ModelListingError::ParseError { message } => {
                assert!(message.contains("provider=OpenAI"));
                assert!(message.contains("path=/v1/models"));
                assert!(message.contains("missing field `data`"));
                assert!(message.contains("body_bytes=17"));
                assert!(message.contains(r#"{"object":"list"}"#));
            }
            _ => panic!("Expected ParseError"),
        }
    }
}