diffuser_scheduler/kinds/
der.rs1use std::{fmt::Formatter, mem::transmute, str::FromStr};
2
3use serde::{
4 de::{Error, Visitor},
5 Deserialize, Deserializer,
6};
7
8use crate::{utils::SchedulerKindDeserializer, DiffuserSchedulerKind};
9
10impl<'de> Deserialize<'de> for DiffuserSchedulerKind {
11 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
12 where
13 D: Deserializer<'de>,
14 {
15 deserializer.deserialize_any(SchedulerKindDeserializer)
16 }
17}
18
19impl FromStr for DiffuserSchedulerKind {
20 type Err = ();
21
22 fn from_str(s: &str) -> Result<Self, Self::Err> {
23 if s.eq_ignore_ascii_case("ddim") {
24 Ok(DiffuserSchedulerKind::DDIM)
25 }
26 else if s.eq_ignore_ascii_case("euler") {
27 Ok(DiffuserSchedulerKind::Euler)
28 }
29 else {
30 Err(())
31 }
32 }
33}
34
35impl<'de> Visitor<'de> for SchedulerKindDeserializer {
36 type Value = DiffuserSchedulerKind;
37
38 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
39 formatter.write_str("Except one of `Euler`, `DDIM`")
40 }
41 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
42 where
43 E: Error,
44 {
45 unsafe { Ok(transmute::<u8, DiffuserSchedulerKind>(v as u8)) }
46 }
47 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
48 where
49 E: Error,
50 {
51 DiffuserSchedulerKind::from_str(v).map_err(|_| Error::custom(format!("Unknown scheduler type `{v}`")))
52 }
53 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
54 where
55 E: Error,
56 {
57 self.visit_str(String::from_utf8_lossy(v).as_ref())
58 }
59}