diffuser-scheduler 0.0.0

Shared definition for diffuser schedulers
Documentation
use std::{fmt::Formatter, mem::transmute, str::FromStr};

use serde::{
    de::{Error, Visitor},
    Deserialize, Deserializer,
};

use crate::{utils::SchedulerKindDeserializer, DiffuserSchedulerKind};

impl<'de> Deserialize<'de> for DiffuserSchedulerKind {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_any(SchedulerKindDeserializer)
    }
}

impl FromStr for DiffuserSchedulerKind {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.eq_ignore_ascii_case("ddim") {
            Ok(DiffuserSchedulerKind::DDIM)
        }
        else if s.eq_ignore_ascii_case("euler") {
            Ok(DiffuserSchedulerKind::Euler)
        }
        else {
            Err(())
        }
    }
}

impl<'de> Visitor<'de> for SchedulerKindDeserializer {
    type Value = DiffuserSchedulerKind;

    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
        formatter.write_str("Except one of `Euler`, `DDIM`")
    }
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: Error,
    {
        unsafe { Ok(transmute::<u8, DiffuserSchedulerKind>(v as u8)) }
    }
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: Error,
    {
        DiffuserSchedulerKind::from_str(v).map_err(|_| Error::custom(format!("Unknown scheduler type `{v}`")))
    }
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: Error,
    {
        self.visit_str(String::from_utf8_lossy(v).as_ref())
    }
}