// diffuser_scheduler/types/der.rs

1use std::fmt::Formatter;
2
3use serde::{
4    __private::de::{ContentDeserializer, TaggedContentVisitor},
5    de::{Error, MapAccess, Visitor},
6    Deserialize, Deserializer,
7};
8
9use crate::{
10    utils::{SchedulerDeserializer, SchedulerKindDeserializer},
11    DDIMScheduler, DiffuserScheduler, DiffuserSchedulerKind, EulerDiscreteScheduler,
12};
13
14impl<'de> Deserialize<'de> for DiffuserScheduler {
15    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
16    where
17        D: Deserializer<'de>,
18    {
19        deserializer.deserialize_any(SchedulerDeserializer)
20    }
21}
22
23impl<'de> Visitor<'de> for SchedulerDeserializer {
24    type Value = DiffuserScheduler;
25
26    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
27        formatter.write_str("Except scheduler name or Scheduler object.")
28    }
29
30    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
31    where
32        E: Error,
33    {
34        let kind = SchedulerKindDeserializer.visit_str(v)?;
35        Ok(kind.as_scheduler())
36    }
37
38    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
39    where
40        A: MapAccess<'de>,
41    {
42        let visitor = TaggedContentVisitor::<DiffuserSchedulerKind>::new("type", "internally tagged enum DiffuserScheduler");
43        let tagged = visitor.visit_map(map)?;
44        match tagged.tag {
45            DiffuserSchedulerKind::Euler => {
46                let scheduler = EulerDiscreteScheduler::deserialize(ContentDeserializer::<A::Error>::new(tagged.content))?;
47                Ok(DiffuserScheduler::Euler(Box::new(scheduler)))
48            }
49            DiffuserSchedulerKind::DDIM => {
50                let scheduler = DDIMScheduler::deserialize(ContentDeserializer::<A::Error>::new(tagged.content))?;
51                Ok(DiffuserScheduler::DDIM(Box::new(scheduler)))
52            }
53        }
54    }
55}