clojure_reader/
de.rs

1use alloc::collections::BTreeMap;
2use alloc::format;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use core::fmt::Display;
6
7use crate::edn::{self, Edn};
8
9use serde::de::{
10  self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess, Visitor,
11};
12use serde::{forward_to_deserialize_any, Deserialize};
13
14use crate::error::{Code, Error, Result};
15
16/// Deserializer for a EDN formatted &str.
17///
18/// # Errors
19///
20/// See [`crate::error::Error`].
21/// Always returns `Code::Serde`.
22pub fn from_str<'a, T>(s: &'a str) -> Result<T>
23where
24  T: Deserialize<'a>,
25{
26  let edn = edn::read_string(s)?;
27  let t = T::deserialize(edn)?;
28  Ok(t)
29}
30
31impl de::Error for Error {
32  #[cold]
33  fn custom<T: Display>(msg: T) -> Self {
34    Self { code: Code::Serde(msg.to_string()), line: None, column: None, ptr: None }
35  }
36}
37
38fn get_int_from_edn(edn: &Edn<'_>) -> Result<i64> {
39  if let Edn::Int(i) = edn {
40    return Ok(*i);
41  };
42  Err(de::Error::custom(format!("cannot convert {edn:?} to i64")))
43}
44
45impl<'de> de::Deserializer<'de> for Edn<'de> {
46  type Error = Error;
47
48  fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
49  where
50    V: Visitor<'de>,
51  {
52    match self {
53      Edn::Key(k) => visitor.visit_borrowed_str(k),
54      Edn::Str(s) | Edn::Symbol(s) => visitor.visit_borrowed_str(s),
55      Edn::Int(i) => visitor.visit_i64(i),
56      #[cfg(feature = "floats")]
57      Edn::Double(d) => visitor.visit_f64(*d),
58      Edn::Char(c) => visitor.visit_char(c),
59      Edn::Bool(b) => visitor.visit_bool(b),
60      Edn::Nil => visitor.visit_unit(),
61      Edn::Vector(mut list) | Edn::List(mut list) => {
62        list.reverse();
63        Ok(visitor.visit_seq(SeqEdn::new(list))?)
64      }
65      Edn::Map(mut map) => {
66        if map == BTreeMap::new() {
67          visitor.visit_unit()
68        } else {
69          visitor.visit_map(MapEdn::new(&mut map))
70        }
71      }
72      Edn::Set(set) => {
73        let mut s: Vec<Edn<'_>> = set.into_iter().collect();
74        s.reverse();
75        Ok(visitor.visit_seq(SeqEdn::new(s))?)
76      }
77      // Things like rational numbers and custom tags can't be represented in rust types
78      _ => Err(de::Error::custom(format!("Don't know how to convert {self:?} into any"))),
79    }
80  }
81
82  forward_to_deserialize_any! {
83    bool i64 f64 char str unit map ignored_any seq tuple_struct
84  }
85
86  fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
87  where
88    V: Visitor<'de>,
89  {
90    let int = i8::try_from(get_int_from_edn(&self)?);
91    int.map_or_else(
92      |_| Err(de::Error::custom(format!("can't convert {int:?} into i8"))),
93      |i| visitor.visit_i8(i),
94    )
95  }
96
97  fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
98  where
99    V: Visitor<'de>,
100  {
101    let int = i16::try_from(get_int_from_edn(&self)?);
102    int.map_or_else(
103      |_| Err(de::Error::custom(format!("can't convert {int:?} into i16"))),
104      |i| visitor.visit_i16(i),
105    )
106  }
107
108  fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
109  where
110    V: Visitor<'de>,
111  {
112    let int = i32::try_from(get_int_from_edn(&self)?);
113    int.map_or_else(
114      |_| Err(de::Error::custom(format!("can't convert {int:?} into i32"))),
115      |i| visitor.visit_i32(i),
116    )
117  }
118
119  fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
120  where
121    V: Visitor<'de>,
122  {
123    let int = u8::try_from(get_int_from_edn(&self)?);
124    int.map_or_else(
125      |_| Err(de::Error::custom(format!("can't convert {int:?} into u8"))),
126      |i| visitor.visit_u8(i),
127    )
128  }
129
130  fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
131  where
132    V: Visitor<'de>,
133  {
134    let int = u16::try_from(get_int_from_edn(&self)?);
135    int.map_or_else(
136      |_| Err(de::Error::custom(format!("can't convert {int:?} into u16"))),
137      |i| visitor.visit_u16(i),
138    )
139  }
140
141  fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
142  where
143    V: Visitor<'de>,
144  {
145    let int = u32::try_from(get_int_from_edn(&self)?);
146    int.map_or_else(
147      |_| Err(de::Error::custom(format!("can't convert {int:?} into u32"))),
148      |i| visitor.visit_u32(i),
149    )
150  }
151
152  fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
153  where
154    V: Visitor<'de>,
155  {
156    let int = u64::try_from(get_int_from_edn(&self)?);
157    int.map_or_else(
158      |_| Err(de::Error::custom(format!("can't convert {int:?} into u64"))),
159      |i| visitor.visit_u64(i),
160    )
161  }
162
163  fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
164  where
165    V: Visitor<'de>,
166  {
167    let _ = visitor; // hush clippy
168    #[cfg(feature = "floats")]
169    if let Edn::Double(f) = self {
170      #[expect(clippy::cast_possible_truncation)]
171      return visitor.visit_f32(*f as f32);
172    };
173    Err(de::Error::custom(format!("can't convert {self:?} into f32")))
174  }
175
176  fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
177  where
178    V: Visitor<'de>,
179  {
180    self.deserialize_str(visitor)
181  }
182
183  fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
184  where
185    V: Visitor<'de>,
186  {
187    Err(de::Error::custom("deserialize_bytes is unimplemented/unused".to_string()))
188  }
189
190  fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
191  where
192    V: Visitor<'de>,
193  {
194    self.deserialize_bytes(visitor)
195  }
196
197  fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
198  where
199    V: Visitor<'de>,
200  {
201    if self == Edn::Nil {
202      visitor.visit_none()
203    } else {
204      visitor.visit_some(self)
205    }
206  }
207
208  fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
209  where
210    V: Visitor<'de>,
211  {
212    self.deserialize_unit(visitor)
213  }
214
215  fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
216  where
217    V: Visitor<'de>,
218  {
219    visitor.visit_newtype_struct(self)
220  }
221
222  fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
223  where
224    V: Visitor<'de>,
225  {
226    self.deserialize_seq(visitor)
227  }
228
229  fn deserialize_struct<V>(
230    self,
231    _name: &'static str,
232    _fields: &'static [&'static str],
233    visitor: V,
234  ) -> Result<V::Value>
235  where
236    V: Visitor<'de>,
237  {
238    self.deserialize_map(visitor)
239  }
240
241  fn deserialize_enum<V>(
242    self,
243    name: &'static str,
244    _variants: &'static [&'static str],
245    visitor: V,
246  ) -> Result<V::Value>
247  where
248    V: Visitor<'de>,
249  {
250    let Edn::Tagged(tag, ref edn) = self else {
251      return Err(de::Error::custom(format!("can't convert {self:?} into Tagged for enum")));
252    };
253
254    let mut split = tag.split('/');
255    let (Some(tag_first), Some(tag_second)) = (split.next(), split.next()) else {
256      return Err(de::Error::custom(format!("Expected namespace in {tag} for Tagged for enum")));
257    };
258
259    if name != tag_first {
260      return Err(de::Error::custom(format!("namespace in {tag} can't be matched to {name}")));
261    }
262
263    visitor.visit_enum(EnumEdn::new(edn, tag_second))
264  }
265
266  fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
267  where
268    V: Visitor<'de>,
269  {
270    self.deserialize_str(visitor)
271  }
272}
273
274struct SeqEdn<'de> {
275  de: Vec<Edn<'de>>,
276}
277
278impl<'de> SeqEdn<'de> {
279  const fn new(de: Vec<Edn<'de>>) -> Self {
280    SeqEdn { de }
281  }
282}
283
284impl<'de> SeqAccess<'de> for SeqEdn<'de> {
285  type Error = Error;
286
287  fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
288  where
289    T: DeserializeSeed<'de>,
290  {
291    let s = self.de.pop();
292    match s {
293      Some(e) => Ok(Some(seed.deserialize(e)?)),
294      None => Ok(None),
295    }
296  }
297}
298
299struct MapEdn<'a, 'de> {
300  de: &'a mut BTreeMap<Edn<'de>, Edn<'de>>,
301}
302
303impl<'a, 'de> MapEdn<'a, 'de> {
304  fn new(de: &'a mut BTreeMap<Edn<'de>, Edn<'de>>) -> Self {
305    MapEdn { de }
306  }
307}
308
309impl<'de> MapAccess<'de> for MapEdn<'_, 'de> {
310  type Error = Error;
311
312  fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
313  where
314    K: DeserializeSeed<'de>,
315  {
316    while let Some((k, _)) = self.de.first_key_value() {
317      // pass over any keys that serde can't handle
318      match k {
319        Edn::Key(_) | Edn::Symbol(_) | Edn::Str(_) => {
320          return Ok(Some(seed.deserialize(k.clone())?))
321        }
322        _ => {
323          self.de.pop_first();
324          continue;
325        }
326      }
327    }
328    Ok(None)
329  }
330
331  fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
332  where
333    V: DeserializeSeed<'de>,
334  {
335    let (_, v) = self.de.pop_first().expect("kv must exist, because next_key_seed succeeded");
336    seed.deserialize(v)
337  }
338}
339
340#[derive(Debug)]
341struct EnumEdn<'a, 'de> {
342  de: &'a Edn<'de>,
343  variant: &'a str,
344}
345
346impl<'a, 'de> EnumEdn<'a, 'de> {
347  const fn new(de: &'a Edn<'de>, variant: &'a str) -> Self {
348    EnumEdn { de, variant }
349  }
350}
351
352impl<'de> EnumAccess<'de> for EnumEdn<'_, 'de> {
353  type Error = Error;
354  type Variant = Self;
355
356  fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
357  where
358    V: DeserializeSeed<'de>,
359  {
360    let val = seed.deserialize(self.variant.into_deserializer())?;
361    Ok((val, self))
362  }
363}
364
365impl<'de> VariantAccess<'de> for EnumEdn<'_, 'de> {
366  type Error = Error;
367
368  fn unit_variant(self) -> Result<()> {
369    Ok(())
370  }
371
372  fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
373  where
374    T: DeserializeSeed<'de>,
375  {
376    seed.deserialize(self.de.clone())
377  }
378
379  fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
380  where
381    V: Visitor<'de>,
382  {
383    de::Deserializer::deserialize_seq(self.de.clone(), visitor)
384  }
385
386  fn struct_variant<V>(
387    self,
388    _fields: &'static [&'static str],
389    visitor: V,
390  ) -> core::result::Result<V::Value, Self::Error>
391  where
392    V: Visitor<'de>,
393  {
394    de::Deserializer::deserialize_map(self.de.clone(), visitor)
395  }
396}