libsirt/
de.rs

1use serde::Deserialize;
2use serde::de::{self, Deserializer, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor};
3
4use crate::error::SirtDeserializeError;
5use crate::{Block, Value, parse_input};
6
7struct ListAccess<'a> {
8    iter: std::slice::Iter<'a, Value>,
9}
10
11impl<'a, 'de> SeqAccess<'de> for ListAccess<'a> {
12    type Error = SirtDeserializeError;
13
14    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
15    where
16        T: de::DeserializeSeed<'de>,
17    {
18        match self.iter.next() {
19            Some(value) => {
20                let de = ValueDeserializer { value };
21                seed.deserialize(de).map(Some)
22            }
23            None => Ok(None),
24        }
25    }
26}
27
28struct BlockMapAccess<'a> {
29    iter: std::collections::hash_map::Iter<'a, String, Value>,
30    value: Option<&'a Value>,
31}
32
33impl<'de, 'a> MapAccess<'de> for BlockMapAccess<'a> {
34    type Error = SirtDeserializeError;
35
36    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
37    where
38        K: de::DeserializeSeed<'de>,
39    {
40        if let Some((key, value)) = self.iter.next() {
41            self.value = Some(value);
42            seed.deserialize(key.as_str().into_deserializer()).map(Some)
43        } else {
44            Ok(None)
45        }
46    }
47
48    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
49    where
50        V: de::DeserializeSeed<'de>,
51    {
52        let value = self
53            .value
54            .take()
55            .ok_or(SirtDeserializeError::custom("MapAccess error"))?;
56
57        seed.deserialize(ValueDeserializer { value })
58    }
59}
60
61pub struct BlockDeserializer<'a> {
62    block: &'a Block,
63}
64
65impl<'de, 'a> Deserializer<'de> for BlockDeserializer<'a> {
66    type Error = SirtDeserializeError;
67
68    fn deserialize_struct<V>(
69        self,
70        _name: &'static str,
71        _fields: &'static [&'static str],
72        visitor: V,
73    ) -> Result<V::Value, Self::Error>
74    where
75        V: Visitor<'de>,
76    {
77        visitor.visit_map(BlockMapAccess {
78            iter: self.block.fields.iter(),
79            value: None,
80        })
81    }
82
83    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
84    where
85        V: Visitor<'de>,
86    {
87        self.deserialize_struct("", &[], visitor)
88    }
89
90    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
91    where
92        V: Visitor<'de>,
93    {
94        unreachable!()
95    }
96
97    serde::forward_to_deserialize_any! {
98        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
99        bytes byte_buf option unit unit_struct newtype_struct seq tuple
100        tuple_struct enum identifier ignored_any
101    }
102}
103
104struct ValueDeserializer<'a> {
105    value: &'a Value,
106}
107
108impl<'de, 'a> Deserializer<'de> for ValueDeserializer<'a> {
109    type Error = SirtDeserializeError;
110
111    fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
112    where
113        V: Visitor<'de>,
114    {
115        match self.value {
116            Value::List(_) => self.deserialize_seq(v),
117            Value::Bool(_) => self.deserialize_bool(v),
118            Value::Int(_) => self.deserialize_i64(v),
119            Value::Text(_) => self.deserialize_string(v),
120        }
121    }
122
123    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
124    where
125        V: Visitor<'de>,
126    {
127        match self.value {
128            Value::Text(s) => visitor.visit_string(s.clone()),
129            other => Err(SirtDeserializeError::custom(format!(
130                "expected string, found {other:?}"
131            ))),
132        }
133    }
134
135    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
136    where
137        V: Visitor<'de>,
138    {
139        match self.value {
140            Value::Int(num) => visitor.visit_i64(*num),
141            other => Err(SirtDeserializeError::custom(format!(
142                "expected i64, found {other:?}"
143            ))),
144        }
145    }
146
147    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
148    where
149        V: Visitor<'de>,
150    {
151        match self.value {
152            Value::Bool(b) => visitor.visit_bool(*b),
153            other => Err(SirtDeserializeError::custom(format!(
154                "expected bool, found {other:?}"
155            ))),
156        }
157    }
158
159    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
160    where
161        V: Visitor<'de>,
162    {
163        match self.value {
164            Value::List(list) => visitor.visit_seq(ListAccess { iter: list.iter() }),
165            other => Err(SirtDeserializeError::custom(format!(
166                "expected list, found {other:?}"
167            ))),
168        }
169    }
170
171    serde::forward_to_deserialize_any! {
172        i8 i16 i32 i128 u8 u16 u32 u64 u128 f32 f64 char str
173        bytes byte_buf option unit unit_struct newtype_struct
174        tuple_struct map struct enum identifier ignored_any tuple
175    }
176}
177
178pub fn from_str<'de, T>(input: &str) -> Result<T, SirtDeserializeError>
179where
180    T: Deserialize<'de>,
181{
182    let blocks = parse_input(input).map_err(|err| {
183        SirtDeserializeError::custom(format!("failed to parse sirt format: {err:?}"))
184    })?;
185
186    if blocks.len() < 1 {
187        return Err(SirtDeserializeError::custom("expected at least one block"));
188    }
189
190    let block = &blocks[0];
191    let des = BlockDeserializer { block };
192    T::deserialize(des)
193}
194
195pub fn from_str_named<'de, T>(input: &str, name: &str) -> Result<T, SirtDeserializeError>
196where
197    T: Deserialize<'de>,
198{
199    let blocks = parse_input(input).map_err(|err| {
200        SirtDeserializeError::custom(format!("failed to parse sirt format: {err:?}"))
201    })?;
202
203    let block = blocks.iter().find(|block| block.get_name() == name).ok_or(
204        SirtDeserializeError::custom(format!("couldn't find block with name '{name}'")),
205    )?;
206
207    let des = BlockDeserializer { block };
208    T::deserialize(des)
209}
210
211pub fn from_str_named_iter<T>(
212    input: &str,
213    name: &str,
214) -> Result<impl Iterator<Item = Result<T, SirtDeserializeError>>, SirtDeserializeError>
215where
216    T: for<'de> Deserialize<'de>,
217{
218    let blocks = parse_input(input).map_err(|err| {
219        SirtDeserializeError::custom(format!("failed to parse sirt format: '{err:?}'"))
220    })?;
221
222    Ok(blocks
223        .into_iter()
224        .filter(move |block| block.get_name() == name)
225        .map(|block| T::deserialize(BlockDeserializer { block: &block })))
226}