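//! Small value types used by the OpenAPI model: numbers, the spec version,
//! boolean flags and the `RefOr` wrapper.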
use serde::{
    Deserialize, Deserializer, Serialize, Serializer,
    de::{Error, Expected, Visitor},
};
use std::fmt::Formatter;

use super::Ref;

/// A numeric value that is a signed integer, an unsigned integer or a float.
///
/// The enum is `#[serde(untagged)]`, so no variant name is written when
/// serializing: only the inner number is emitted. When deserializing, serde
/// tries the variants in declaration order (`Int`, then `UInt`, then `Float`)
/// and keeps the first one that accepts the input.
#[derive(Clone, serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum Number {
    /// Signed integer, stored as `isize`.
    Int(isize),
    /// Unsigned integer, stored as `usize`.
    UInt(usize),
    /// Floating point value, stored as `f64`.
    Float(f64),
}

// NOTE: `Float` payloads are compared with `f64`'s `==`, so `Float(NaN)` is
// not equal to itself even though `Eq` is implemented.
impl Eq for Number {}

/// Two numbers are equal only when they are the same variant and their
/// payloads compare equal; values of different variants never compare equal,
/// even if they are numerically the same (e.g. `Int(1)` vs `UInt(1)`).
impl PartialEq for Number {
    fn eq(&self, other: &Self) -> bool {
        // Exhaustive over `self`, so adding a variant forces this to be revisited.
        match *self {
            Self::Int(lhs) => matches!(*other, Self::Int(rhs) if lhs == rhs),
            Self::UInt(lhs) => matches!(*other, Self::UInt(rhs) if lhs == rhs),
            Self::Float(lhs) => matches!(*other, Self::Float(rhs) if lhs == rhs),
        }
    }
}

/// Generates `impl From<$source> for Number` for each listed mapping.
///
/// Each comma-separated entry has the form `source => Variant` or
/// `source => Variant as target`. With the `as target` suffix the incoming
/// value is converted with an `as` cast before being wrapped in the variant;
/// without it the value is wrapped unchanged.
macro_rules! impl_from_for_number {
    ( $( $source:ident => $variant:ident $( as $target:ident )? ),* ) => {
        $(
        impl From<$source> for Number {
            fn from(raw: $source) -> Self {
                // The optional `as` cast is emitted only when a target type was given.
                let converted = raw $( as $target )?;
                Self::$variant(converted)
            }
        }
        )*
    };
}

// `From` conversions from the primitive numeric types into `Number`.
//
// NOTE: the 64-bit sources are converted with plain `as` casts, which truncate
// on targets where `isize`/`usize` are narrower than 64 bits.
#[rustfmt::skip]
impl_from_for_number!(
    // Floating point sources -> `Float`
    f32 => Float as f64,
    f64 => Float,
    // Signed sources -> `Int`
    i8 => Int as isize,
    i16 => Int as isize,
    i32 => Int as isize,
    i64 => Int as isize,
    isize => Int,
    // Unsigned sources -> `UInt`
    u8 => UInt as usize,
    u16 => UInt as usize,
    u32 => UInt as usize,
    u64 => UInt as usize,
    usize => UInt
);

/// OpenAPI specification version of a document.
///
/// Only one version is modelled. It serializes as the string `"3.1.0"` and is
/// also the `Default` value. Deserialization is handled by the hand-written
/// `Deserialize` impl for this type rather than a derive.
#[derive(Serialize, Clone, PartialEq, Eq, Default)]
pub enum OpenApiVersion {
    /// The 3.1 line of the specification, written out as `"3.1.0"`.
    #[serde(rename = "3.1.0")]
    #[default]
    Version31,
}

/// Deserializes an [`OpenApiVersion`] from a version string.
///
/// Any string of the form `3.1.<patch>` maps to [`OpenApiVersion::Version31`].
/// The string must consist of exactly three dot-separated components, each
/// made up solely of ASCII digits; the first must equal 3 and the second 1.
/// Everything else is rejected with an `invalid_value` error.
impl<'de> Deserialize<'de> for OpenApiVersion {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Returns `true` when `version` is exactly `<3>.<1>.<digits>`.
        //
        // Every component is validated instead of silently skipping the ones
        // that fail to parse, so inputs such as "3.1.0.foo" or "x.3.1.0" are
        // rejected, and the patch component is not limited to a small integer
        // range.
        fn is_supported(version: &str) -> bool {
            let is_number =
                |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());

            let mut parts = version.split('.');
            // The fourth `next()` must be `None`: exactly three components.
            match (parts.next(), parts.next(), parts.next(), parts.next()) {
                (Some(major), Some(minor), Some(patch), None) => {
                    is_number(major)
                        && is_number(minor)
                        && is_number(patch)
                        && matches!(major.parse::<u64>(), Ok(3))
                        && matches!(minor.parse::<u64>(), Ok(1))
                }
                _ => false,
            }
        }

        struct VersionVisitor;

        impl<'v> Visitor<'v> for VersionVisitor {
            type Value = OpenApiVersion;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                formatter.write_str("a version string in 3.1.x format")
            }

            // Only `visit_str` is overridden: the default `visit_string` and
            // `visit_borrowed_str` forward here, so no extra `String` is
            // allocated just to inspect the value.
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: Error,
            {
                if is_supported(v) {
                    Ok(OpenApiVersion::Version31)
                } else {
                    let expected: &dyn Expected = &"3.1.0";
                    Err(Error::invalid_value(
                        serde::de::Unexpected::Str(v),
                        expected,
                    ))
                }
            }
        }

        // The visitor only borrows the string, so hint for `str` rather than
        // an owned `String`.
        deserializer.deserialize_str(VersionVisitor)
    }
}

/// Two-state flag recording whether something is deprecated.
///
/// The `Serialize` impl for this type writes it as a plain `bool`
/// (`True` -> `true`, `False` -> `false`). The default is [`Deprecated::False`].
#[derive(PartialEq, Eq, Clone, Default)]
#[allow(missing_docs)]
pub enum Deprecated {
    True,
    #[default]
    False,
}

impl Serialize for Deprecated {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_bool(matches!(self, Self::True))
    }
}

impl<'de> Deserialize<'de> for Deprecated {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserialize_bool_flag(deserializer, Deprecated::True, Deprecated::False)
    }
}

/// Two-state flag recording whether something is required.
///
/// The `Serialize` impl for this type writes it as a plain `bool`
/// (`True` -> `true`, `False` -> `false`). The default is [`Required::False`].
#[derive(PartialEq, Eq, Clone, Default)]
#[allow(missing_docs)]
pub enum Required {
    True,
    #[default]
    False,
}

impl Serialize for Required {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_bool(matches!(self, Self::True))
    }
}

impl<'de> Deserialize<'de> for Required {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserialize_bool_flag(deserializer, Required::True, Required::False)
    }
}

/// Either a [`Ref`] or an inline value of type `T`.
///
/// The enum is `#[serde(untagged)]`, so the active variant's content is
/// serialized directly without a variant name. When deserializing, serde tries
/// `Ref` first and falls back to `T` if the input does not match.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum RefOr<T> {
    /// A reference, represented by [`Ref`].
    Ref(Ref),
    /// An inline value.
    T(T),
}

fn deserialize_bool_flag<'de, D, T>(
    deserializer: D,
    true_value: T,
    false_value: T,
) -> Result<T, D::Error>
where
    D: Deserializer<'de>,
    T: Clone,
{
    struct BoolVisitor<T> {
        true_value: T,
        false_value: T,
    }

    impl<'de, T: Clone> Visitor<'de> for BoolVisitor<T> {
        type Value = T;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a bool true or false")
        }

        fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(if v { self.true_value } else { self.false_value })
        }
    }

    deserializer.deserialize_bool(BoolVisitor {
        true_value,
        false_value,
    })
}