diffuser_scheduler/types/ddim/der.rs

1use std::fmt::Formatter;
2
3use serde::{
4    __private::de::Content,
5    de::{MapAccess, Visitor},
6    Deserialize, Deserializer,
7};
8
9use crate::DDIMScheduler;
10
11struct DDIMDeserialize<'i> {
12    place: &'i mut DDIMScheduler,
13}
14
15impl<'de> Deserialize<'de> for DDIMScheduler {
16    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
17    where
18        D: Deserializer<'de>,
19    {
20        let mut out = DDIMScheduler::default();
21        DDIMScheduler::deserialize_in_place(deserializer, &mut out)?;
22        Ok(out)
23    }
24    fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error>
25    where
26        D: Deserializer<'de>,
27    {
28        deserializer.deserialize_map(DDIMDeserialize { place })
29    }
30}
31
32impl<'i, 'de> Visitor<'de> for DDIMDeserialize<'i> {
33    type Value = ();
34
35    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
36        formatter.write_str("Except scheduler name or Scheduler object.")
37    }
38
39    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
40    where
41        A: MapAccess<'de>,
42    {
43        while let Some(key) = map.next_key::<&str>()? {
44            match key {
45                "init_noise_sigma" => {
46                    self.place.init_noise_sigma = map.next_value()?;
47                }
48                _ => {
49                    println!("Unknown {}", key);
50                    let _ = map.next_value::<Content>()?;
51                }
52            }
53        }
54        Ok(())
55    }
56}