diffuser-scheduler 0.0.0

Shared definition for diffuser schedulers
Documentation
use std::fmt::Formatter;

use serde::{
    __private::de::{ContentDeserializer, TaggedContentVisitor},
    de::{Error, MapAccess, Visitor},
    Deserialize, Deserializer,
};

use crate::{
    utils::{SchedulerDeserializer, SchedulerKindDeserializer},
    DDIMScheduler, DiffuserScheduler, DiffuserSchedulerKind, EulerDiscreteScheduler,
};

impl<'de> Deserialize<'de> for DiffuserScheduler {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_any(SchedulerDeserializer)
    }
}

impl<'de> Visitor<'de> for SchedulerDeserializer {
    type Value = DiffuserScheduler;

    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
        formatter.write_str("Except scheduler name or Scheduler object.")
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: Error,
    {
        let kind = SchedulerKindDeserializer.visit_str(v)?;
        Ok(kind.as_scheduler())
    }

    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
    where
        A: MapAccess<'de>,
    {
        let visitor = TaggedContentVisitor::<DiffuserSchedulerKind>::new("type", "internally tagged enum DiffuserScheduler");
        let tagged = visitor.visit_map(map)?;
        match tagged.tag {
            DiffuserSchedulerKind::Euler => {
                let scheduler = EulerDiscreteScheduler::deserialize(ContentDeserializer::<A::Error>::new(tagged.content))?;
                Ok(DiffuserScheduler::Euler(Box::new(scheduler)))
            }
            DiffuserSchedulerKind::DDIM => {
                let scheduler = DDIMScheduler::deserialize(ContentDeserializer::<A::Error>::new(tagged.content))?;
                Ok(DiffuserScheduler::DDIM(Box::new(scheduler)))
            }
        }
    }
}