diffuser_scheduler/types/
der.rs1use std::fmt::Formatter;
2
3use serde::{
4 __private::de::{ContentDeserializer, TaggedContentVisitor},
5 de::{Error, MapAccess, Visitor},
6 Deserialize, Deserializer,
7};
8
9use crate::{
10 utils::{SchedulerDeserializer, SchedulerKindDeserializer},
11 DDIMScheduler, DiffuserScheduler, DiffuserSchedulerKind, EulerDiscreteScheduler,
12};
13
14impl<'de> Deserialize<'de> for DiffuserScheduler {
15 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
16 where
17 D: Deserializer<'de>,
18 {
19 deserializer.deserialize_any(SchedulerDeserializer)
20 }
21}
22
23impl<'de> Visitor<'de> for SchedulerDeserializer {
24 type Value = DiffuserScheduler;
25
26 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
27 formatter.write_str("Except scheduler name or Scheduler object.")
28 }
29
30 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
31 where
32 E: Error,
33 {
34 let kind = SchedulerKindDeserializer.visit_str(v)?;
35 Ok(kind.as_scheduler())
36 }
37
38 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
39 where
40 A: MapAccess<'de>,
41 {
42 let visitor = TaggedContentVisitor::<DiffuserSchedulerKind>::new("type", "internally tagged enum DiffuserScheduler");
43 let tagged = visitor.visit_map(map)?;
44 match tagged.tag {
45 DiffuserSchedulerKind::Euler => {
46 let scheduler = EulerDiscreteScheduler::deserialize(ContentDeserializer::<A::Error>::new(tagged.content))?;
47 Ok(DiffuserScheduler::Euler(Box::new(scheduler)))
48 }
49 DiffuserSchedulerKind::DDIM => {
50 let scheduler = DDIMScheduler::deserialize(ContentDeserializer::<A::Error>::new(tagged.content))?;
51 Ok(DiffuserScheduler::DDIM(Box::new(scheduler)))
52 }
53 }
54 }
55}