diffuser_scheduler/kinds/der.rs

1use std::{fmt::Formatter, mem::transmute, str::FromStr};
2
3use serde::{
4    de::{Error, Visitor},
5    Deserialize, Deserializer,
6};
7
8use crate::{utils::SchedulerKindDeserializer, DiffuserSchedulerKind};
9
10impl<'de> Deserialize<'de> for DiffuserSchedulerKind {
11    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
12    where
13        D: Deserializer<'de>,
14    {
15        deserializer.deserialize_any(SchedulerKindDeserializer)
16    }
17}
18
19impl FromStr for DiffuserSchedulerKind {
20    type Err = ();
21
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        if s.eq_ignore_ascii_case("ddim") {
24            Ok(DiffuserSchedulerKind::DDIM)
25        }
26        else if s.eq_ignore_ascii_case("euler") {
27            Ok(DiffuserSchedulerKind::Euler)
28        }
29        else {
30            Err(())
31        }
32    }
33}
34
35impl<'de> Visitor<'de> for SchedulerKindDeserializer {
36    type Value = DiffuserSchedulerKind;
37
38    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
39        formatter.write_str("Except one of `Euler`, `DDIM`")
40    }
41    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
42    where
43        E: Error,
44    {
45        unsafe { Ok(transmute::<u8, DiffuserSchedulerKind>(v as u8)) }
46    }
47    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
48    where
49        E: Error,
50    {
51        DiffuserSchedulerKind::from_str(v).map_err(|_| Error::custom(format!("Unknown scheduler type `{v}`")))
52    }
53    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
54    where
55        E: Error,
56    {
57        self.visit_str(String::from_utf8_lossy(v).as_ref())
58    }
59}