use serde::{Deserialize, Serialize};
use std::fmt;
/// Metadata for a single model entry (the element type of `ModelList::data`).
///
/// Only `id` is mandatory. Every other field is an `Option` and is left out of
/// serialized output entirely when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Model {
    /// Identifier of the model (for example `"gpt-4"`).
    pub id: String,
    /// Optional human-readable name; `display_name` falls back to `id` without it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Optional free-form description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional model type, (de)serialized under the JSON key `"type"`
    /// (a raw identifier is needed because `type` is a Rust keyword).
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    /// Optional creation time.
    /// NOTE(review): unit is not visible here — presumably Unix seconds; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<u64>,
    /// Optional owner / publisher of the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub owned_by: Option<String>,
    /// Optional context window size.
    /// NOTE(review): presumably measured in tokens; confirm with the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,
}
impl Model {
    /// Creates a model with both an `id` and a human-readable `name`.
    ///
    /// Every other metadata field starts out as `None`, exactly as with
    /// [`Model::from_id`].
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        // Build on `from_id` so the list of optional fields lives in one place;
        // adding a field to `Model` then only requires touching `from_id`.
        Self {
            name: Some(name.into()),
            ..Self::from_id(id)
        }
    }

    /// Creates a model that carries only its `id`; every optional field is `None`.
    #[must_use]
    pub fn from_id(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            name: None,
            description: None,
            r#type: None,
            created_at: None,
            owned_by: None,
            context_length: None,
        }
    }

    /// Returns the text to show to users: `name` when present, otherwise `id`.
    #[must_use]
    pub fn display_name(&self) -> &str {
        self.name.as_deref().unwrap_or(self.id.as_str())
    }
}
impl fmt::Display for Model {
    /// Writes [`Model::display_name`], honouring any width, fill/alignment and
    /// precision in the format spec (e.g. `{:<20}` for column-aligned output).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and silently drop
        // the caller's padding/precision; `pad` applies them the way `str` does.
        // Output with a plain `{}` is unchanged.
        f.pad(self.display_name())
    }
}
/// An ordered collection of [`Model`] entries.
///
/// Serializes as a JSON object with a single `"data"` array field.
///
/// `PartialEq`/`Eq` mirror the derives on `Model` so whole lists can be compared
/// directly. `Default` yields an empty list.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelList {
    /// The models, in the order they were supplied or deserialized.
    pub data: Vec<Model>,
}
impl ModelList {
pub fn new(data: Vec<Model>) -> Self {
Self { data }
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn iter(&self) -> std::slice::Iter<'_, Model> {
self.data.iter()
}
}
impl IntoIterator for ModelList {
type Item = Model;
type IntoIter = std::vec::IntoIter<Model>;
fn into_iter(self) -> Self::IntoIter {
self.data.into_iter()
}
}
impl<'a> IntoIterator for &'a ModelList {
type Item = &'a Model;
type IntoIter = std::slice::Iter<'a, Model>;
fn into_iter(self) -> Self::IntoIter {
self.data.iter()
}
}
/// Errors that can occur while listing models.
///
/// Every variant carries a human-readable `message`. `ApiError` additionally
/// carries a numeric status code.
///
/// `PartialEq`/`Eq` are derived (consistent with `Model`) so errors can be
/// compared directly, e.g. with `assert_eq!`, instead of being pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelListingError {
    /// The API returned an error response.
    /// NOTE(review): `status_code` is presumably the HTTP status; confirm at call sites.
    ApiError {
        status_code: u16,
        message: String,
    },
    /// The request itself failed; displayed as "Request error".
    RequestError {
        message: String,
    },
    /// A response could not be parsed; displayed as "Parse error".
    ParseError {
        message: String,
    },
    /// Authentication failed; displayed as "Authentication error".
    AuthError {
        message: String,
    },
    /// The caller was rate limited; displayed as "Rate limit error".
    RateLimitError {
        message: String,
    },
    /// The service reported itself unavailable; displayed as "Service unavailable".
    ServiceUnavailable {
        message: String,
    },
    /// Any failure not covered by the other variants.
    UnknownError {
        message: String,
    },
}
impl ModelListingError {
    /// Builds an `ApiError` from a status code and a message.
    pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
        let message = message.into();
        Self::ApiError { status_code, message }
    }

    /// Builds a `RequestError` with the given message.
    pub fn request_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::RequestError { message }
    }

    /// Builds a `ParseError` with the given message.
    pub fn parse_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::ParseError { message }
    }

    /// Builds an `AuthError` with the given message.
    pub fn auth_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::AuthError { message }
    }

    /// Builds a `RateLimitError` with the given message.
    pub fn rate_limit_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::RateLimitError { message }
    }

    /// Builds a `ServiceUnavailable` error with the given message.
    pub fn service_unavailable(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::ServiceUnavailable { message }
    }

    /// Builds an `UnknownError` with the given message.
    pub fn unknown_error(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::UnknownError { message }
    }
}
impl fmt::Display for ModelListingError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ApiError {
status_code,
message,
} => write!(f, "API error (status {}): {}", status_code, message),
Self::RequestError { message } => write!(f, "Request error: {}", message),
Self::ParseError { message } => write!(f, "Parse error: {}", message),
Self::AuthError { message } => write!(f, "Authentication error: {}", message),
Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
}
}
}
impl std::error::Error for ModelListingError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_from_id() {
        let model = Model::from_id("gpt-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, None);
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, Some("GPT-4".to_string()));
        // Apart from `name`, `new` must leave every field exactly as `from_id` does.
        assert_eq!(Model { name: None, ..model }, Model::from_id("gpt-4"));
    }

    #[test]
    fn test_model_display_name() {
        let model_with_name = Model::new("gpt-4", "GPT-4");
        assert_eq!(model_with_name.display_name(), "GPT-4");
        let model_without_name = Model::from_id("gpt-4");
        assert_eq!(model_without_name.display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(format!("{}", model), "GPT-4");
    }

    #[test]
    fn test_model_list_new() {
        let list = ModelList::new(vec![Model::from_id("gpt-4")]);
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(vec![]);
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    #[test]
    fn test_model_list_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        // Check order and content, not just the count.
        let ids: Vec<&str> = list.iter().map(|m| m.id.as_str()).collect();
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);
    }

    #[test]
    fn test_model_list_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let ids: Vec<String> = list.into_iter().map(|m| m.id).collect();
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);
    }

    #[test]
    fn test_model_list_ref_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        // Exercises `impl IntoIterator for &ModelList`.
        let mut seen = 0;
        for model in &list {
            assert!(model.id.starts_with("gpt"));
            seen += 1;
        }
        assert_eq!(seen, 2);
        // The list was only borrowed and is still usable.
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_model_listing_error_display() {
        let error = ModelListingError::api_error(404, "Not found");
        assert_eq!(error.to_string(), "API error (status 404): Not found");
        let error = ModelListingError::request_error("Connection failed");
        assert_eq!(error.to_string(), "Request error: Connection failed");
        let error = ModelListingError::parse_error("Invalid JSON");
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");
        let error = ModelListingError::auth_error("Invalid API key");
        assert_eq!(error.to_string(), "Authentication error: Invalid API key");
        let error = ModelListingError::rate_limit_error("Too many requests");
        assert_eq!(error.to_string(), "Rate limit error: Too many requests");
        let error = ModelListingError::service_unavailable("Maintenance mode");
        assert_eq!(error.to_string(), "Service unavailable: Maintenance mode");
        let error = ModelListingError::unknown_error("Something went wrong");
        assert_eq!(error.to_string(), "Unknown error: Something went wrong");
    }

    #[test]
    fn test_model_serde() {
        let model = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };
        let json = serde_json::to_string(&model).unwrap();
        // Pins the wire format:
        // - `description` (None) is omitted;
        // - `r#type` is emitted under the key "type";
        // - fields appear in declaration order.
        assert_eq!(
            json,
            r#"{"id":"gpt-4","name":"GPT-4","type":"chat","created_at":1677610600,"owned_by":"openai","context_length":8192}"#
        );
        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, model);
    }

    #[test]
    fn test_model_deserialize_minimal() {
        // Missing optional fields become `None`; unknown keys are ignored.
        let model: Model = serde_json::from_str(r#"{"id":"gpt-4","object":"model"}"#).unwrap();
        assert_eq!(model, Model::from_id("gpt-4"));
    }

    #[test]
    fn test_model_list_serde() {
        let list = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };
        let json = serde_json::to_string(&list).unwrap();
        assert_eq!(json, r#"{"data":[{"id":"gpt-4"}]}"#);
        let deserialized: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 1);
        assert_eq!(deserialized.data[0], Model::from_id("gpt-4"));
    }

    #[test]
    fn test_model_listing_error_serde() {
        let error = ModelListingError::api_error(404, "Not found");
        let json = serde_json::to_string(&error).unwrap();
        // serde's default (externally tagged) enum representation.
        assert_eq!(
            json,
            r#"{"ApiError":{"status_code":404,"message":"Not found"}}"#
        );
        let deserialized: ModelListingError = serde_json::from_str(&json).unwrap();
        match deserialized {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 404);
                assert_eq!(message, "Not found");
            }
            _ => panic!("Expected ApiError"),
        }
    }
}