1use serde::Deserialize;
2use serde::de::{self, Deserializer, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor};
3
4use crate::error::SirtDeserializeError;
5use crate::{Block, Value, parse_input};
6
7struct ListAccess<'a> {
8 iter: std::slice::Iter<'a, Value>,
9}
10
11impl<'a, 'de> SeqAccess<'de> for ListAccess<'a> {
12 type Error = SirtDeserializeError;
13
14 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
15 where
16 T: de::DeserializeSeed<'de>,
17 {
18 match self.iter.next() {
19 Some(value) => {
20 let de = ValueDeserializer { value };
21 seed.deserialize(de).map(Some)
22 }
23 None => Ok(None),
24 }
25 }
26}
27
28struct BlockMapAccess<'a> {
29 iter: std::collections::hash_map::Iter<'a, String, Value>,
30 value: Option<&'a Value>,
31}
32
33impl<'de, 'a> MapAccess<'de> for BlockMapAccess<'a> {
34 type Error = SirtDeserializeError;
35
36 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
37 where
38 K: de::DeserializeSeed<'de>,
39 {
40 if let Some((key, value)) = self.iter.next() {
41 self.value = Some(value);
42 seed.deserialize(key.as_str().into_deserializer()).map(Some)
43 } else {
44 Ok(None)
45 }
46 }
47
48 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
49 where
50 V: de::DeserializeSeed<'de>,
51 {
52 let value = self
53 .value
54 .take()
55 .ok_or(SirtDeserializeError::custom("MapAccess error"))?;
56
57 seed.deserialize(ValueDeserializer { value })
58 }
59}
60
61pub struct BlockDeserializer<'a> {
62 block: &'a Block,
63}
64
65impl<'de, 'a> Deserializer<'de> for BlockDeserializer<'a> {
66 type Error = SirtDeserializeError;
67
68 fn deserialize_struct<V>(
69 self,
70 _name: &'static str,
71 _fields: &'static [&'static str],
72 visitor: V,
73 ) -> Result<V::Value, Self::Error>
74 where
75 V: Visitor<'de>,
76 {
77 visitor.visit_map(BlockMapAccess {
78 iter: self.block.fields.iter(),
79 value: None,
80 })
81 }
82
83 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
84 where
85 V: Visitor<'de>,
86 {
87 self.deserialize_struct("", &[], visitor)
88 }
89
90 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
91 where
92 V: Visitor<'de>,
93 {
94 unreachable!()
95 }
96
97 serde::forward_to_deserialize_any! {
98 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
99 bytes byte_buf option unit unit_struct newtype_struct seq tuple
100 tuple_struct enum identifier ignored_any
101 }
102}
103
104struct ValueDeserializer<'a> {
105 value: &'a Value,
106}
107
108impl<'de, 'a> Deserializer<'de> for ValueDeserializer<'a> {
109 type Error = SirtDeserializeError;
110
111 fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
112 where
113 V: Visitor<'de>,
114 {
115 match self.value {
116 Value::List(_) => self.deserialize_seq(v),
117 Value::Bool(_) => self.deserialize_bool(v),
118 Value::Int(_) => self.deserialize_i64(v),
119 Value::Text(_) => self.deserialize_string(v),
120 }
121 }
122
123 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
124 where
125 V: Visitor<'de>,
126 {
127 match self.value {
128 Value::Text(s) => visitor.visit_string(s.clone()),
129 other => Err(SirtDeserializeError::custom(format!(
130 "expected string, found {other:?}"
131 ))),
132 }
133 }
134
135 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
136 where
137 V: Visitor<'de>,
138 {
139 match self.value {
140 Value::Int(num) => visitor.visit_i64(*num),
141 other => Err(SirtDeserializeError::custom(format!(
142 "expected i64, found {other:?}"
143 ))),
144 }
145 }
146
147 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
148 where
149 V: Visitor<'de>,
150 {
151 match self.value {
152 Value::Bool(b) => visitor.visit_bool(*b),
153 other => Err(SirtDeserializeError::custom(format!(
154 "expected bool, found {other:?}"
155 ))),
156 }
157 }
158
159 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
160 where
161 V: Visitor<'de>,
162 {
163 match self.value {
164 Value::List(list) => visitor.visit_seq(ListAccess { iter: list.iter() }),
165 other => Err(SirtDeserializeError::custom(format!(
166 "expected list, found {other:?}"
167 ))),
168 }
169 }
170
171 serde::forward_to_deserialize_any! {
172 i8 i16 i32 i128 u8 u16 u32 u64 u128 f32 f64 char str
173 bytes byte_buf option unit unit_struct newtype_struct
174 tuple_struct map struct enum identifier ignored_any tuple
175 }
176}
177
178pub fn from_str<'de, T>(input: &str) -> Result<T, SirtDeserializeError>
179where
180 T: Deserialize<'de>,
181{
182 let blocks = parse_input(input).map_err(|err| {
183 SirtDeserializeError::custom(format!("failed to parse sirt format: {err:?}"))
184 })?;
185
186 if blocks.len() < 1 {
187 return Err(SirtDeserializeError::custom("expected at least one block"));
188 }
189
190 let block = &blocks[0];
191 let des = BlockDeserializer { block };
192 T::deserialize(des)
193}
194
195pub fn from_str_named<'de, T>(input: &str, name: &str) -> Result<T, SirtDeserializeError>
196where
197 T: Deserialize<'de>,
198{
199 let blocks = parse_input(input).map_err(|err| {
200 SirtDeserializeError::custom(format!("failed to parse sirt format: {err:?}"))
201 })?;
202
203 let block = blocks.iter().find(|block| block.get_name() == name).ok_or(
204 SirtDeserializeError::custom(format!("couldn't find block with name '{name}'")),
205 )?;
206
207 let des = BlockDeserializer { block };
208 T::deserialize(des)
209}
210
211pub fn from_str_named_iter<T>(
212 input: &str,
213 name: &str,
214) -> Result<impl Iterator<Item = Result<T, SirtDeserializeError>>, SirtDeserializeError>
215where
216 T: for<'de> Deserialize<'de>,
217{
218 let blocks = parse_input(input).map_err(|err| {
219 SirtDeserializeError::custom(format!("failed to parse sirt format: '{err:?}'"))
220 })?;
221
222 Ok(blocks
223 .into_iter()
224 .filter(move |block| block.get_name() == name)
225 .map(|block| T::deserialize(BlockDeserializer { block: &block })))
226}