clojure_reader/
de.rs

1use alloc::collections::BTreeMap;
2use alloc::format;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use core::fmt::Display;
6
7use crate::edn::{self, Edn};
8
9use serde::de::{
10  self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess, Visitor,
11};
12use serde::{Deserialize, forward_to_deserialize_any};
13
14use crate::error::{Code, Error, Result};
15
16/// Deserializer for a EDN formatted &str.
17///
18/// # Errors
19///
20/// See [`crate::error::Error`].
21/// Always returns `Code::Serde`.
22pub fn from_str<'a, T>(s: &'a str) -> Result<T>
23where
24  T: Deserialize<'a>,
25{
26  let edn = edn::read_string(s)?;
27  let t = T::deserialize(edn)?;
28  Ok(t)
29}
30
31impl de::Error for Error {
32  #[cold]
33  fn custom<T: Display>(msg: T) -> Self {
34    Self { code: Code::Serde(msg.to_string()), line: None, column: None, ptr: None }
35  }
36}
37
38fn get_int_from_edn(edn: &Edn<'_>) -> Result<i64> {
39  if let Edn::Int(i) = edn {
40    return Ok(*i);
41  }
42  Err(de::Error::custom(format!("cannot convert {edn:?} to i64")))
43}
44
45impl<'de> de::Deserializer<'de> for Edn<'de> {
46  type Error = Error;
47
48  fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
49  where
50    V: Visitor<'de>,
51  {
52    match self {
53      Edn::Key(k) => visitor.visit_borrowed_str(k),
54      Edn::Str(s) | Edn::Symbol(s) => visitor.visit_borrowed_str(s),
55      Edn::Int(i) => visitor.visit_i64(i),
56      #[cfg(feature = "floats")]
57      Edn::Double(d) => visitor.visit_f64(*d),
58      Edn::Char(c) => visitor.visit_char(c),
59      Edn::Bool(b) => visitor.visit_bool(b),
60      Edn::Nil => visitor.visit_unit(),
61      Edn::Vector(mut list) | Edn::List(mut list) => {
62        list.reverse();
63        Ok(visitor.visit_seq(SeqEdn::new(list))?)
64      }
65      Edn::Map(mut map) => {
66        if map == BTreeMap::new() {
67          visitor.visit_unit()
68        } else {
69          visitor.visit_map(MapEdn::new(&mut map))
70        }
71      }
72      Edn::Set(set) => {
73        let mut s: Vec<Edn<'_>> = set.into_iter().collect();
74        s.reverse();
75        Ok(visitor.visit_seq(SeqEdn::new(s))?)
76      }
77      // Things like rational numbers and custom tags can't be represented in rust types
78      _ => Err(de::Error::custom(format!("Don't know how to convert {self:?} into any"))),
79    }
80  }
81
82  forward_to_deserialize_any! {
83    bool i64 f64 char str unit map ignored_any seq tuple_struct
84  }
85
86  fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
87  where
88    V: Visitor<'de>,
89  {
90    let int = i8::try_from(get_int_from_edn(&self)?);
91    int.map_or_else(
92      |_| Err(de::Error::custom(format!("can't convert {int:?} into i8"))),
93      |i| visitor.visit_i8(i),
94    )
95  }
96
97  fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
98  where
99    V: Visitor<'de>,
100  {
101    let int = i16::try_from(get_int_from_edn(&self)?);
102    int.map_or_else(
103      |_| Err(de::Error::custom(format!("can't convert {int:?} into i16"))),
104      |i| visitor.visit_i16(i),
105    )
106  }
107
108  fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
109  where
110    V: Visitor<'de>,
111  {
112    let int = i32::try_from(get_int_from_edn(&self)?);
113    int.map_or_else(
114      |_| Err(de::Error::custom(format!("can't convert {int:?} into i32"))),
115      |i| visitor.visit_i32(i),
116    )
117  }
118
119  fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
120  where
121    V: Visitor<'de>,
122  {
123    let int = u8::try_from(get_int_from_edn(&self)?);
124    int.map_or_else(
125      |_| Err(de::Error::custom(format!("can't convert {int:?} into u8"))),
126      |i| visitor.visit_u8(i),
127    )
128  }
129
130  fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
131  where
132    V: Visitor<'de>,
133  {
134    let int = u16::try_from(get_int_from_edn(&self)?);
135    int.map_or_else(
136      |_| Err(de::Error::custom(format!("can't convert {int:?} into u16"))),
137      |i| visitor.visit_u16(i),
138    )
139  }
140
141  fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
142  where
143    V: Visitor<'de>,
144  {
145    let int = u32::try_from(get_int_from_edn(&self)?);
146    int.map_or_else(
147      |_| Err(de::Error::custom(format!("can't convert {int:?} into u32"))),
148      |i| visitor.visit_u32(i),
149    )
150  }
151
152  fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
153  where
154    V: Visitor<'de>,
155  {
156    let int = u64::try_from(get_int_from_edn(&self)?);
157    int.map_or_else(
158      |_| Err(de::Error::custom(format!("can't convert {int:?} into u64"))),
159      |i| visitor.visit_u64(i),
160    )
161  }
162
163  fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
164  where
165    V: Visitor<'de>,
166  {
167    let _ = visitor; // hush clippy
168    #[cfg(feature = "floats")]
169    if let Edn::Double(f) = self {
170      #[expect(clippy::cast_possible_truncation)]
171      return visitor.visit_f32(*f as f32);
172    }
173    Err(de::Error::custom(format!("can't convert {self:?} into f32")))
174  }
175
176  fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
177  where
178    V: Visitor<'de>,
179  {
180    self.deserialize_str(visitor)
181  }
182
183  fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
184  where
185    V: Visitor<'de>,
186  {
187    Err(de::Error::custom("deserialize_bytes is unimplemented/unused".to_string()))
188  }
189
190  fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
191  where
192    V: Visitor<'de>,
193  {
194    self.deserialize_bytes(visitor)
195  }
196
197  fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
198  where
199    V: Visitor<'de>,
200  {
201    if self == Edn::Nil { visitor.visit_none() } else { visitor.visit_some(self) }
202  }
203
204  fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
205  where
206    V: Visitor<'de>,
207  {
208    self.deserialize_unit(visitor)
209  }
210
211  fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
212  where
213    V: Visitor<'de>,
214  {
215    visitor.visit_newtype_struct(self)
216  }
217
218  fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
219  where
220    V: Visitor<'de>,
221  {
222    self.deserialize_seq(visitor)
223  }
224
225  fn deserialize_struct<V>(
226    self,
227    _name: &'static str,
228    _fields: &'static [&'static str],
229    visitor: V,
230  ) -> Result<V::Value>
231  where
232    V: Visitor<'de>,
233  {
234    self.deserialize_map(visitor)
235  }
236
237  fn deserialize_enum<V>(
238    self,
239    name: &'static str,
240    _variants: &'static [&'static str],
241    visitor: V,
242  ) -> Result<V::Value>
243  where
244    V: Visitor<'de>,
245  {
246    let Edn::Tagged(tag, ref edn) = self else {
247      return Err(de::Error::custom(format!("can't convert {self:?} into Tagged for enum")));
248    };
249
250    let mut split = tag.split('/');
251    let (Some(tag_first), Some(tag_second)) = (split.next(), split.next()) else {
252      return Err(de::Error::custom(format!("Expected namespace in {tag} for Tagged for enum")));
253    };
254
255    if name != tag_first {
256      return Err(de::Error::custom(format!("namespace in {tag} can't be matched to {name}")));
257    }
258
259    visitor.visit_enum(EnumEdn::new(edn, tag_second))
260  }
261
262  fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
263  where
264    V: Visitor<'de>,
265  {
266    self.deserialize_str(visitor)
267  }
268}
269
270struct SeqEdn<'de> {
271  de: Vec<Edn<'de>>,
272}
273
274impl<'de> SeqEdn<'de> {
275  const fn new(de: Vec<Edn<'de>>) -> Self {
276    SeqEdn { de }
277  }
278}
279
280impl<'de> SeqAccess<'de> for SeqEdn<'de> {
281  type Error = Error;
282
283  fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
284  where
285    T: DeserializeSeed<'de>,
286  {
287    let s = self.de.pop();
288    match s {
289      Some(e) => Ok(Some(seed.deserialize(e)?)),
290      None => Ok(None),
291    }
292  }
293}
294
295struct MapEdn<'a, 'de> {
296  de: &'a mut BTreeMap<Edn<'de>, Edn<'de>>,
297}
298
299impl<'a, 'de> MapEdn<'a, 'de> {
300  const fn new(de: &'a mut BTreeMap<Edn<'de>, Edn<'de>>) -> Self {
301    MapEdn { de }
302  }
303}
304
305impl<'de> MapAccess<'de> for MapEdn<'_, 'de> {
306  type Error = Error;
307
308  fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
309  where
310    K: DeserializeSeed<'de>,
311  {
312    while let Some((k, _)) = self.de.first_key_value() {
313      // pass over any keys that serde can't handle
314      match k {
315        Edn::Key(_) | Edn::Symbol(_) | Edn::Str(_) => {
316          return Ok(Some(seed.deserialize(k.clone())?));
317        }
318        _ => {
319          self.de.pop_first();
320        }
321      }
322    }
323    Ok(None)
324  }
325
326  fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
327  where
328    V: DeserializeSeed<'de>,
329  {
330    let (_, v) = self.de.pop_first().expect("kv must exist, because next_key_seed succeeded");
331    seed.deserialize(v)
332  }
333}
334
335#[derive(Debug)]
336struct EnumEdn<'a, 'de> {
337  de: &'a Edn<'de>,
338  variant: &'a str,
339}
340
341impl<'a, 'de> EnumEdn<'a, 'de> {
342  const fn new(de: &'a Edn<'de>, variant: &'a str) -> Self {
343    EnumEdn { de, variant }
344  }
345}
346
347impl<'de> EnumAccess<'de> for EnumEdn<'_, 'de> {
348  type Error = Error;
349  type Variant = Self;
350
351  fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
352  where
353    V: DeserializeSeed<'de>,
354  {
355    let val = seed.deserialize(self.variant.into_deserializer())?;
356    Ok((val, self))
357  }
358}
359
360impl<'de> VariantAccess<'de> for EnumEdn<'_, 'de> {
361  type Error = Error;
362
363  fn unit_variant(self) -> Result<()> {
364    Ok(())
365  }
366
367  fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
368  where
369    T: DeserializeSeed<'de>,
370  {
371    seed.deserialize(self.de.clone())
372  }
373
374  fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
375  where
376    V: Visitor<'de>,
377  {
378    de::Deserializer::deserialize_seq(self.de.clone(), visitor)
379  }
380
381  fn struct_variant<V>(
382    self,
383    _fields: &'static [&'static str],
384    visitor: V,
385  ) -> core::result::Result<V::Value, Self::Error>
386  where
387    V: Visitor<'de>,
388  {
389    de::Deserializer::deserialize_map(self.de.clone(), visitor)
390  }
391}