clap_serde/
yaml.rs

1use serde::{
2    de::{
3        value::{MapDeserializer, SeqDeserializer},
4        Error as _, IntoDeserializer, Unexpected,
5    },
6    Deserializer,
7};
8use yaml_rust::Yaml;
9
10/**
11Deserializing from [`Yaml`]
12```
13const YAML_STR: &'static str = r#"
14name: app_clap_serde
15version : "1.0"
16about : yaml_support!
17author : yaml_supporter
18
19args:
20    - apple :
21        - short: a
22    - banana:
23        - short: b
24        - long: banana
25        - aliases :
26            - musa_spp
27
28subcommands:
29    - sub1:
30        - about : subcommand_1
31    - sub2:
32        - about : subcommand_2
33
34"#;
35let yaml = yaml_rust::Yaml::Array(yaml_rust::YamlLoader::load_from_str(YAML_STR).expect("not a yaml"));
36let app = clap_serde::yaml_to_app(&yaml).expect("parse failed from yaml");
37assert_eq!(app.get_name(), "app_clap_serde");
38```
39*/
40pub fn yaml_to_app(yaml: &Yaml) -> Result<clap::Command<'_>, Error> {
41    let wrap = YamlWrap { yaml };
42    use serde::Deserialize;
43    crate::CommandWrap::deserialize(wrap).map(|x| x.into())
44}
45
46/// Wrapper to use [`Yaml`] as [`Deserializer`].
47///
48/// Currently this implement functions in [`Deserializer`] that is only needed in deserializing into `Command`.
49/// Recommend to use [`yaml_to_app`] instead.
50pub struct YamlWrap<'a> {
51    yaml: &'a yaml_rust::Yaml,
52}
53
54impl<'a> YamlWrap<'a> {
55    pub fn new(yaml: &'a yaml_rust::Yaml) -> Self {
56        Self { yaml }
57    }
58}
59
/// Error produced while deserializing a [`Yaml`] value through [`YamlWrap`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// Free-form message, e.g. one built through `serde::de::Error::custom`.
    Custom(String),
}

impl std::fmt::Display for Error {
    /// Writes the bare error message, without the `Custom("...")` debug wrapper.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Custom(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for Error {}
71
72impl serde::de::Error for Error {
73    fn custom<T>(msg: T) -> Self
74    where
75        T: std::fmt::Display,
76    {
77        Self::Custom(msg.to_string())
78    }
79}
80
/// Generate a `deserialize_*` method for a fixed-width integer type.
///
/// `$sig` is the `Deserializer` method being defined and `$sig_v` is the
/// `Visitor` method it forwards to. The node must yield an `i64` via
/// `Yaml::as_i64`; that value is converted with `TryInto` into the visitor
/// method's argument type.
///
/// A non-integer node produces an "invalid type" error. An integer that does
/// not fit the target type produces an "invalid value" error.
macro_rules! de_num {
    ($sig:ident, $sig_v:ident) => {
        fn $sig<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: serde::de::Visitor<'de>,
        {
            let int = self
                .yaml
                .as_i64()
                .ok_or_else(|| as_invalid(self.yaml, "Integer"))?;
            // The narrowing target is inferred from `visitor.$sig_v`'s parameter.
            visitor.$sig_v(match int.try_into() {
                Ok(narrowed) => narrowed,
                Err(_) => {
                    return Err(Error::invalid_value(
                        Unexpected::Signed(int),
                        &"Integer in range",
                    ))
                }
            })
        }
    };
}
94
95fn as_invalid(y: &Yaml, expected: &str) -> Error {
96    Error::invalid_type(
97        match y {
98            Yaml::Real(r) => r
99                .parse()
100                .map(Unexpected::Float)
101                .unwrap_or(Unexpected::Other(r)),
102            Yaml::Integer(i) => Unexpected::Signed(*i),
103            Yaml::String(s) => Unexpected::Str(s),
104            Yaml::Boolean(b) => Unexpected::Bool(*b),
105            Yaml::Array(_) => Unexpected::Seq,
106            Yaml::Hash(_) => Unexpected::Map,
107            Yaml::Alias(_) => todo!(),
108            Yaml::Null => Unexpected::Unit,
109            Yaml::BadValue => Unexpected::Other("BadValue"),
110        },
111        &expected,
112    )
113}
114
115impl<'de> IntoDeserializer<'de, Error> for YamlWrap<'de> {
116    type Deserializer = Self;
117
118    fn into_deserializer(self) -> Self::Deserializer {
119        self
120    }
121}
122
123impl<'de> Deserializer<'de> for YamlWrap<'de> {
124    type Error = Error;
125
126    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
127    where
128        V: serde::de::Visitor<'de>,
129    {
130        match self.yaml {
131            yaml_rust::Yaml::Real(s) => {
132                visitor.visit_f64(s.parse::<f64>().map_err(|e| Error::Custom(e.to_string()))?)
133            }
134            yaml_rust::Yaml::Integer(i) => visitor.visit_i64(*i),
135            yaml_rust::Yaml::String(s) => visitor.visit_str(s),
136            yaml_rust::Yaml::Boolean(b) => visitor.visit_bool(*b),
137            yaml_rust::Yaml::Array(_) => self.deserialize_seq(visitor), //visitor.visit_seq(a),
138            yaml_rust::Yaml::Hash(_) => self.deserialize_map(visitor),
139            yaml_rust::Yaml::Alias(_) => todo!(),
140            yaml_rust::Yaml::Null => visitor.visit_none(),
141            yaml_rust::Yaml::BadValue => Err(as_invalid(self.yaml, "any")),
142        }
143    }
144
145    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146    where
147        V: serde::de::Visitor<'de>,
148    {
149        visitor.visit_bool(
150            self.yaml
151                .as_bool()
152                .ok_or_else(|| as_invalid(self.yaml, "bool"))?,
153        )
154    }
155
156    de_num!(deserialize_i8, visit_i8);
157    de_num!(deserialize_i16, visit_i16);
158    de_num!(deserialize_i32, visit_i32);
159    de_num!(deserialize_i64, visit_i64);
160    de_num!(deserialize_u8, visit_i8);
161    de_num!(deserialize_u16, visit_u16);
162    de_num!(deserialize_u32, visit_u32);
163    de_num!(deserialize_u64, visit_u64);
164
165    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166    where
167        V: serde::de::Visitor<'de>,
168    {
169        visitor.visit_f32(
170            self.yaml
171                .as_f64()
172                .ok_or_else(|| as_invalid(self.yaml, "f32"))? as f32,
173        )
174    }
175
176    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
177    where
178        V: serde::de::Visitor<'de>,
179    {
180        visitor.visit_f64(
181            self.yaml
182                .as_f64()
183                .ok_or_else(|| as_invalid(self.yaml, "f64"))?,
184        )
185    }
186
187    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188    where
189        V: serde::de::Visitor<'de>,
190    {
191        visitor.visit_char(
192            self.yaml
193                .as_str()
194                .ok_or_else(|| as_invalid(self.yaml, "char"))?
195                .chars()
196                .next()
197                .ok_or_else(|| as_invalid(self.yaml, "char"))?,
198        )
199    }
200
201    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202    where
203        V: serde::de::Visitor<'de>,
204    {
205        let s = self.yaml.as_str();
206        visitor.visit_borrowed_str(s.ok_or_else(|| as_invalid(self.yaml, "str"))?)
207    }
208
209    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210    where
211        V: serde::de::Visitor<'de>,
212    {
213        visitor.visit_string(
214            self.yaml
215                .as_str()
216                .ok_or_else(|| as_invalid(self.yaml, "string"))?
217                .to_string(),
218        )
219    }
220
221    ///not supported
222    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
223    where
224        V: serde::de::Visitor<'de>,
225    {
226        unimplemented!()
227    }
228
229    ///not supported
230    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: serde::de::Visitor<'de>,
233    {
234        unimplemented!()
235    }
236
237    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
238    where
239        V: serde::de::Visitor<'de>,
240    {
241        if matches!(self.yaml, yaml_rust::Yaml::Null) {
242            visitor.visit_none()
243        } else {
244            visitor.visit_some(self)
245        }
246    }
247
248    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
249    where
250        V: serde::de::Visitor<'de>,
251    {
252        if matches!(self.yaml, yaml_rust::Yaml::Null) {
253            visitor.visit_unit()
254        } else {
255            Err(as_invalid(self.yaml, "unit"))
256        }
257    }
258
259    ///unimplemented
260    fn deserialize_unit_struct<V>(
261        self,
262        _name: &'static str,
263        _visitor: V,
264    ) -> Result<V::Value, Self::Error>
265    where
266        V: serde::de::Visitor<'de>,
267    {
268        todo!()
269    }
270
271    ///unimplemented
272    fn deserialize_newtype_struct<V>(
273        self,
274        _name: &'static str,
275        _visitor: V,
276    ) -> Result<V::Value, Self::Error>
277    where
278        V: serde::de::Visitor<'de>,
279    {
280        todo!()
281    }
282
283    ///unimplemented
284    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
285    where
286        V: serde::de::Visitor<'de>,
287    {
288        if let Some(n) = self.yaml.as_vec() {
289            let seq = SeqDeserializer::new(n.iter().map(|y| YamlWrap { yaml: y }));
290            visitor.visit_seq(seq)
291        } else {
292            Err(as_invalid(self.yaml, "seq"))
293        }
294    }
295
296    ///unimplemented
297    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
298    where
299        V: serde::de::Visitor<'de>,
300    {
301        todo!()
302    }
303
304    ///unimplemented
305    fn deserialize_tuple_struct<V>(
306        self,
307        _name: &'static str,
308        _len: usize,
309        _visitor: V,
310    ) -> Result<V::Value, Self::Error>
311    where
312        V: serde::de::Visitor<'de>,
313    {
314        todo!()
315    }
316
317    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
318    where
319        V: serde::de::Visitor<'de>,
320    {
321        match self.yaml {
322            Yaml::Hash(h) => {
323                let m = MapDeserializer::new(
324                    h.iter()
325                        .map(|(k, v)| (YamlWrap { yaml: k }, YamlWrap { yaml: v })),
326                );
327                visitor.visit_map(m)
328            }
329            Yaml::Array(a) => {
330                let x = a
331                    .iter()
332                    .map(|y| y.as_hash().ok_or_else(|| as_invalid(self.yaml, "map")))
333                    .collect::<Result<Vec<_>, _>>()?;
334                let m = MapDeserializer::new(
335                    x.into_iter()
336                        .flat_map(|x| x.iter())
337                        .map(|(k, v)| (YamlWrap { yaml: k }, YamlWrap { yaml: v })),
338                );
339                visitor.visit_map(m)
340            }
341            _ => Err(as_invalid(self.yaml, "map")),
342        }
343    }
344
345    ///unimplemented
346    fn deserialize_struct<V>(
347        self,
348        _name: &'static str,
349        _fields: &'static [&'static str],
350        _visitor: V,
351    ) -> Result<V::Value, Self::Error>
352    where
353        V: serde::de::Visitor<'de>,
354    {
355        todo!()
356    }
357
358    fn deserialize_enum<V>(
359        self,
360        _name: &'static str,
361        _variants: &'static [&'static str],
362        visitor: V,
363    ) -> Result<V::Value, Self::Error>
364    where
365        V: serde::de::Visitor<'de>,
366    {
367        if let Some(s) = self.yaml.as_str() {
368            visitor.visit_enum(s.into_deserializer())
369        } else {
370            Err(as_invalid(self.yaml, "enum"))
371        }
372    }
373
374    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
375    where
376        V: serde::de::Visitor<'de>,
377    {
378        self.deserialize_str(visitor)
379    }
380
381    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
382    where
383        V: serde::de::Visitor<'de>,
384    {
385        self.deserialize_any(visitor)
386    }
387}