use core::fmt;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
/// GPT model variants known to this module.
///
/// Only `Serialize` is derived here; each variant's `serde` rename attribute
/// sets the string written when the value is serialized. `Deserialize` is
/// implemented by hand in this file rather than derived.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub enum GptModel {
/// Serialized as `"gpt-4o"`.
#[serde(rename(serialize = "gpt-4o"))]
GPT4o,
/// Serialized as `"gpt-4o-mini"`.
#[serde(rename(serialize = "gpt-4o-mini"))]
GPT4oMini,
}
/// Hand-written deserialization that maps a model-name string to a variant.
///
/// Besides the exact names `"gpt-4o"` and `"gpt-4o-mini"`, any string that
/// starts with `"gpt-4o-mini-"` or `"gpt-4o-"` is folded into the matching
/// base variant. Any other string is rejected with a custom error.
impl<'de> Deserialize<'de> for GptModel {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that turns a string model name into a [`GptModel`].
        struct ModelNameVisitor;

        impl<'de> serde::de::Visitor<'de> for ModelNameVisitor {
            type Value = GptModel;

            fn expecting(&self, out: &mut std::fmt::Formatter) -> std::fmt::Result {
                out.write_str("a string representing a GPT model")
            }

            fn visit_str<E>(self, name: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // The mini family must be tested first: every "gpt-4o-mini..."
                // name also starts with "gpt-4o-" and would otherwise be
                // classified as the full-size model.
                if name == "gpt-4o-mini" || name.starts_with("gpt-4o-mini-") {
                    return Ok(GptModel::GPT4oMini);
                }
                if name == "gpt-4o" || name.starts_with("gpt-4o-") {
                    return Ok(GptModel::GPT4o);
                }
                Err(E::custom(format!("Unknown GPT model: {}", name)))
            }
        }

        deserializer.deserialize_str(ModelNameVisitor)
    }
}
/// Parses a model name into a [`GptModel`].
///
/// Accepts exactly the strings produced by the `Display` implementation
/// (`"gpt-4o"` and `"gpt-4o-mini"`), so `model.to_string().parse()` round-trips
/// for every variant. Unlike the `Deserialize` implementation, suffixed names
/// such as `"gpt-4o-<something>"` are not accepted here.
///
/// # Errors
///
/// Returns `Error::ModelNotSupported` carrying the offending input when the
/// string is not one of the accepted names.
impl FromStr for GptModel {
    type Err = crate::error::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "gpt-4o" => Ok(Self::GPT4o),
            // Previously missing: without this arm the mini variant could be
            // displayed and deserialized but never parsed back from a string.
            "gpt-4o-mini" => Ok(Self::GPT4oMini),
            _ => Err(crate::error::Error::ModelNotSupported(s.to_string())),
        }
    }
}
/// Formats the model as its name string: `"gpt-4o"` or `"gpt-4o-mini"`.
impl fmt::Display for GptModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the name first, then emit it with a single write.
        let name = match self {
            Self::GPT4o => "gpt-4o",
            Self::GPT4oMini => "gpt-4o-mini",
        };
        f.write_str(name)
    }
}