use core::fmt;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub enum Gpt {
#[serde(rename(serialize = "gpt-4"))]
GPT4,
#[serde(rename(serialize = "gpt-4o"))]
GPT4o,
#[serde(rename(serialize = "gpt-4o-mini"))]
GPT4oMini,
#[serde(rename(serialize = "gpt-4-turbo"))]
GPT4Turbo,
#[serde(rename(serialize = "gpt-3.5-turbo"))]
GPT35Turbo,
}
impl<'de> Deserialize<'de> for Gpt {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GptModelVisitor;
impl<'de> serde::de::Visitor<'de> for GptModelVisitor {
type Value = Gpt;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string representing a GPT model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match value {
"gpt-4" => Ok(Gpt::GPT4),
"gpt-4o" => Ok(Gpt::GPT4o),
"gpt-4o-mini" => Ok(Gpt::GPT4oMini),
"gpt-4-turbo" => Ok(Gpt::GPT4Turbo),
"chatgpt-4o-latest" => Ok(Gpt::GPT4o),
_ if value.starts_with("gpt-3.5-turbo") => Ok(Gpt::GPT35Turbo),
_ if value.starts_with("gpt-4-turbo-") => Ok(Gpt::GPT4Turbo),
_ if value.starts_with("gpt-4-") => Ok(Gpt::GPT4),
_ if value.starts_with("gpt-4o-mini-") => Ok(Gpt::GPT4oMini),
_ if value.starts_with("gpt-4o-") => Ok(Gpt::GPT4o),
_ => Err(E::custom(format!("Unknown GPT model: {}", value))),
}
}
}
deserializer.deserialize_str(GptModelVisitor)
}
}
impl FromStr for Gpt {
type Err = crate::error::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gpt-4" => Ok(Gpt::GPT4),
"gpt-4o" => Ok(Gpt::GPT4o),
"gpt-4o-mini" => Ok(Gpt::GPT4oMini),
"gpt-4-turbo" => Ok(Gpt::GPT4Turbo),
"chatgpt-4o-latest" => Ok(Gpt::GPT4o),
_ if s.starts_with("gpt-3.5-turbo") => Ok(Gpt::GPT35Turbo),
_ if s.starts_with("gpt-4o-mini-") => Ok(Gpt::GPT4oMini),
_ if s.starts_with("gpt-4o-") => Ok(Gpt::GPT4o),
_ if s.starts_with("gpt-4-turbo-") => Ok(Gpt::GPT4Turbo),
_ if s.starts_with("gpt-4-") => Ok(Gpt::GPT4),
_ => Err(crate::error::Error::ModelNotSupported(s.to_string())),
}
}
}
impl fmt::Display for Gpt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::GPT4 => write!(f, "gpt-4"),
Self::GPT4o => write!(f, "gpt-4o"),
Self::GPT4oMini => write!(f, "gpt-4o-mini"),
Self::GPT4Turbo => write!(f, "gpt-4-turbo"),
Self::GPT35Turbo => write!(f, "gpt-3.5-turbo"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use serde_json::from_str;
#[test]
fn should_deserialize_gpt_models() {
assert_eq!(from_str::<Gpt>(r#""gpt-4""#).unwrap(), Gpt::GPT4);
assert_eq!(from_str::<Gpt>(r#""gpt-4o""#).unwrap(), Gpt::GPT4o);
assert_eq!(from_str::<Gpt>(r#""gpt-4o-mini""#).unwrap(), Gpt::GPT4oMini);
assert_eq!(from_str::<Gpt>(r#""gpt-4-turbo""#).unwrap(), Gpt::GPT4Turbo);
assert_eq!(
from_str::<Gpt>(r#""chatgpt-4o-latest""#).unwrap(),
Gpt::GPT4o
);
assert_eq!(
from_str::<Gpt>(r#""gpt-3.5-turbo""#).unwrap(),
Gpt::GPT35Turbo
);
assert_eq!(
from_str::<Gpt>(r#""gpt-3.5-turbo-0125""#).unwrap(),
Gpt::GPT35Turbo
);
assert_eq!(
from_str::<Gpt>(r#""gpt-4-0125-preview""#).unwrap(),
Gpt::GPT4
);
assert_eq!(
from_str::<Gpt>(r#""gpt-4o-2024-05-13""#).unwrap(),
Gpt::GPT4o
);
assert_eq!(
from_str::<Gpt>(r#""gpt-4o-mini-1234""#).unwrap(),
Gpt::GPT4oMini
);
assert_eq!(
from_str::<Gpt>(r#""gpt-4-turbo-2024-04-09""#).unwrap(),
Gpt::GPT4Turbo
);
assert!(from_str::<Gpt>(r#""unknown-model""#).is_err());
}
#[test]
fn test_gpt_model_from_str() {
assert_eq!("gpt-4".parse::<Gpt>().unwrap(), Gpt::GPT4);
assert_eq!("gpt-4o".parse::<Gpt>().unwrap(), Gpt::GPT4o);
assert_eq!("gpt-4o-mini".parse::<Gpt>().unwrap(), Gpt::GPT4oMini);
assert_eq!("gpt-4-turbo".parse::<Gpt>().unwrap(), Gpt::GPT4Turbo);
assert_eq!("chatgpt-4o-latest".parse::<Gpt>().unwrap(), Gpt::GPT4o);
assert_eq!("gpt-3.5-turbo".parse::<Gpt>().unwrap(), Gpt::GPT35Turbo);
assert_eq!(
"gpt-3.5-turbo-0125".parse::<Gpt>().unwrap(),
Gpt::GPT35Turbo
);
assert_eq!("gpt-4-9012".parse::<Gpt>().unwrap(), Gpt::GPT4);
assert_eq!("gpt-4o-5678".parse::<Gpt>().unwrap(), Gpt::GPT4o);
assert_eq!("gpt-4o-mini-1234".parse::<Gpt>().unwrap(), Gpt::GPT4oMini);
assert_eq!(
"gpt-4-turbo-2024-04-09".parse::<Gpt>().unwrap(),
Gpt::GPT4Turbo
);
assert!("unknown-model".parse::<Gpt>().is_err());
}
#[test]
fn should_display_gpt_models() {
assert_eq!(Gpt::GPT35Turbo.to_string(), "gpt-3.5-turbo");
assert_eq!(Gpt::GPT4.to_string(), "gpt-4");
assert_eq!(Gpt::GPT4o.to_string(), "gpt-4o");
assert_eq!(Gpt::GPT4oMini.to_string(), "gpt-4o-mini");
assert_eq!(Gpt::GPT4Turbo.to_string(), "gpt-4-turbo");
}
}