//! Identifiers for the `o1` model family (`o1-preview`, `o1-mini`), with an optional version suffix.
use serde::{Deserialize, Serialize};
use std::fmt::Display;

/// An `o1`-family model identifier: which model, plus an optional version
/// suffix (for example the `2024-09-12` in `o1-preview-2024-09-12`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct O1 {
    /// Which `o1` model this identifier names.
    pub model: O1Model,
    /// Text following the model name and a `-`; `None` for the bare alias.
    pub version: Option<String>,
}

/// The models in the `o1` family.
///
/// This is a plain fieldless enum, so it is also `Copy` and `Hash` (usable by
/// value and as a map/set key). Note that the serde derives use the variant
/// names (`Preview` / `Mini`), not the `o1-…` identifier strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum O1Model {
    /// The `o1-preview` model.
    Preview,
    /// The `o1-mini` model.
    Mini,
}

impl O1 {
    /// Builds an identifier for `variant` that carries no version suffix.
    pub fn new(variant: O1Model) -> Self {
        let version = None;
        Self {
            model: variant,
            version,
        }
    }
}

impl Display for O1 {
    /// Writes the model's identifier (`o1-preview` or `o1-mini`), followed by
    /// `-<version>` when a version is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let base = match self.model {
            O1Model::Preview => "o1-preview",
            O1Model::Mini => "o1-mini",
        };
        f.write_str(base)?;
        if let Some(version) = &self.version {
            write!(f, "-{}", version)?;
        }
        Ok(())
    }
}

impl std::str::FromStr for O1 {
    type Err = String;

    /// Parses an `o1` model identifier.
    ///
    /// Accepts `o1-preview` and `o1-mini`, optionally followed by
    /// `-<version>`. Everything after that first `-` is kept verbatim as the
    /// version, so `o1-preview-2024-09-12` yields version `2024-09-12`.
    ///
    /// # Errors
    ///
    /// * `"Unknown GPT model: …"` if the string does not start with a known
    ///   model name, or has extra text glued onto it without a `-`
    ///   (e.g. `o1-minix`).
    /// * `"Invalid model version …"` if the `-` separator is present but the
    ///   version after it is empty (e.g. `o1-mini-`).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `strip_prefix` both tests and slices in one step, so there is no
        // separate `starts_with` check and no impossible failure path.
        let (model, rest) = if let Some(rest) = s.strip_prefix("o1-preview") {
            (O1Model::Preview, rest)
        } else if let Some(rest) = s.strip_prefix("o1-mini") {
            (O1Model::Mini, rest)
        } else {
            return Err(format!("Unknown GPT model: {}", s));
        };

        let version = match rest.strip_prefix('-') {
            // Bare model name, e.g. `o1-mini`.
            None if rest.is_empty() => None,
            // Text glued directly onto the name, e.g. `o1-minix`.
            None => return Err(format!("Unknown GPT model: {}", s)),
            // Dangling separator with nothing after it; previously this was
            // silently accepted as `Some("")`.
            Some("") => return Err(format!("Invalid model version {}", s)),
            Some(version) => Some(version.to_string()),
        };

        Ok(O1 { model, version })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    #[test]
    fn test_display() {
        assert_eq!(O1::new(O1Model::Preview).to_string(), "o1-preview");
        assert_eq!(
            O1 {
                model: O1Model::Preview,
                version: Some("2024-09-12".to_string())
            }
            .to_string(),
            "o1-preview-2024-09-12"
        );

        assert_eq!(O1::new(O1Model::Mini).to_string(), "o1-mini");
        assert_eq!(
            O1 {
                model: O1Model::Mini,
                version: Some("2024-09-12".to_string())
            }
            .to_string(),
            "o1-mini-2024-09-12"
        );
    }

    #[test]
    fn should_parse_str() {
        assert_eq!(O1::from_str("o1-preview"), Ok(O1::new(O1Model::Preview)));
        assert_eq!(
            O1::from_str("o1-preview-2024-09-12"),
            Ok(O1 {
                model: O1Model::Preview,
                version: Some("2024-09-12".to_string())
            })
        );
        assert_eq!(
            O1::from_str("o1-preview-extra-part"),
            Ok(O1 {
                model: O1Model::Preview,
                version: Some("extra-part".to_string())
            })
        );

        assert_eq!(O1::from_str("o1-mini"), Ok(O1::new(O1Model::Mini)));
        assert_eq!(
            O1::from_str("o1-mini-2024-09-12"),
            Ok(O1 {
                model: O1Model::Mini,
                version: Some("2024-09-12".to_string())
            })
        );

        assert!(O1::from_str("invalid-model").is_err());
    }

    /// `Display` and `FromStr` must be inverses for every well-formed value.
    #[test]
    fn should_round_trip_through_display() {
        let cases = [
            O1::new(O1Model::Preview),
            O1 {
                model: O1Model::Preview,
                version: Some("2024-09-12".to_string()),
            },
            O1::new(O1Model::Mini),
            O1 {
                model: O1Model::Mini,
                version: Some("2024-09-12".to_string()),
            },
        ];
        for case in cases.iter() {
            assert_eq!(O1::from_str(&case.to_string()), Ok(case.clone()));
        }
    }

    /// Near-miss strings (bare prefix, glued suffix, wrong case, other
    /// families) must not be mistaken for an `o1` model.
    #[test]
    fn should_reject_malformed_names() {
        for bad in ["", "o1", "o1-previewx", "o1-minix", "O1-MINI", "gpt-4o"].iter() {
            assert!(O1::from_str(bad).is_err(), "expected {:?} to be rejected", bad);
        }
    }
}