1use serde::{
2 de::{
3 value::{MapDeserializer, SeqDeserializer},
4 Error as _, IntoDeserializer, Unexpected,
5 },
6 Deserializer,
7};
8use yaml_rust::Yaml;
9
10pub fn yaml_to_app(yaml: &Yaml) -> Result<clap::Command<'_>, Error> {
41 let wrap = YamlWrap { yaml };
42 use serde::Deserialize;
43 crate::CommandWrap::deserialize(wrap).map(|x| x.into())
44}
45
46pub struct YamlWrap<'a> {
51 yaml: &'a yaml_rust::Yaml,
52}
53
54impl<'a> YamlWrap<'a> {
55 pub fn new(yaml: &'a yaml_rust::Yaml) -> Self {
56 Self { yaml }
57 }
58}
59
/// Error produced while deserializing YAML through [`YamlWrap`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A free-form message; every serde-reported error is stored here.
    Custom(String),
}

impl std::fmt::Display for Error {
    /// Writes the human-readable message itself rather than the `Debug`
    /// representation (which would render as `Custom("...")`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Custom(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for Error {}
71
72impl serde::de::Error for Error {
73 fn custom<T>(msg: T) -> Self
74 where
75 T: std::fmt::Display,
76 {
77 Self::Custom(msg.to_string())
78 }
79}
80
/// Generates an integer `Deserializer` method named `$sig` that forwards the
/// node's value to `visitor.$sig_v`.
///
/// The node must be a YAML integer (read with `as_i64`). Its value is then
/// narrowed to the parameter type of `$sig_v`.
///
/// A non-integer node yields an `invalid_type` error. An integer that does not
/// fit the target type yields an `invalid_value` error.
macro_rules! de_num {
    ($sig:ident, $sig_v:ident) => {
        fn $sig<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: serde::de::Visitor<'de>,
        {
            let i = self
                .yaml
                .as_i64()
                .ok_or_else(|| as_invalid(self.yaml, "integer"))?;
            // The narrowing target is inferred from `$sig_v`'s parameter type.
            visitor.$sig_v(i.try_into().map_err(|_| {
                Error::invalid_value(Unexpected::Signed(i), &"an integer in range")
            })?)
        }
    };
}
94
95fn as_invalid(y: &Yaml, expected: &str) -> Error {
96 Error::invalid_type(
97 match y {
98 Yaml::Real(r) => r
99 .parse()
100 .map(Unexpected::Float)
101 .unwrap_or(Unexpected::Other(r)),
102 Yaml::Integer(i) => Unexpected::Signed(*i),
103 Yaml::String(s) => Unexpected::Str(s),
104 Yaml::Boolean(b) => Unexpected::Bool(*b),
105 Yaml::Array(_) => Unexpected::Seq,
106 Yaml::Hash(_) => Unexpected::Map,
107 Yaml::Alias(_) => todo!(),
108 Yaml::Null => Unexpected::Unit,
109 Yaml::BadValue => Unexpected::Other("BadValue"),
110 },
111 &expected,
112 )
113}
114
115impl<'de> IntoDeserializer<'de, Error> for YamlWrap<'de> {
116 type Deserializer = Self;
117
118 fn into_deserializer(self) -> Self::Deserializer {
119 self
120 }
121}
122
123impl<'de> Deserializer<'de> for YamlWrap<'de> {
124 type Error = Error;
125
126 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
127 where
128 V: serde::de::Visitor<'de>,
129 {
130 match self.yaml {
131 yaml_rust::Yaml::Real(s) => {
132 visitor.visit_f64(s.parse::<f64>().map_err(|e| Error::Custom(e.to_string()))?)
133 }
134 yaml_rust::Yaml::Integer(i) => visitor.visit_i64(*i),
135 yaml_rust::Yaml::String(s) => visitor.visit_str(s),
136 yaml_rust::Yaml::Boolean(b) => visitor.visit_bool(*b),
137 yaml_rust::Yaml::Array(_) => self.deserialize_seq(visitor), yaml_rust::Yaml::Hash(_) => self.deserialize_map(visitor),
139 yaml_rust::Yaml::Alias(_) => todo!(),
140 yaml_rust::Yaml::Null => visitor.visit_none(),
141 yaml_rust::Yaml::BadValue => Err(as_invalid(self.yaml, "any")),
142 }
143 }
144
145 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146 where
147 V: serde::de::Visitor<'de>,
148 {
149 visitor.visit_bool(
150 self.yaml
151 .as_bool()
152 .ok_or_else(|| as_invalid(self.yaml, "bool"))?,
153 )
154 }
155
156 de_num!(deserialize_i8, visit_i8);
157 de_num!(deserialize_i16, visit_i16);
158 de_num!(deserialize_i32, visit_i32);
159 de_num!(deserialize_i64, visit_i64);
160 de_num!(deserialize_u8, visit_i8);
161 de_num!(deserialize_u16, visit_u16);
162 de_num!(deserialize_u32, visit_u32);
163 de_num!(deserialize_u64, visit_u64);
164
165 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166 where
167 V: serde::de::Visitor<'de>,
168 {
169 visitor.visit_f32(
170 self.yaml
171 .as_f64()
172 .ok_or_else(|| as_invalid(self.yaml, "f32"))? as f32,
173 )
174 }
175
176 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
177 where
178 V: serde::de::Visitor<'de>,
179 {
180 visitor.visit_f64(
181 self.yaml
182 .as_f64()
183 .ok_or_else(|| as_invalid(self.yaml, "f64"))?,
184 )
185 }
186
187 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188 where
189 V: serde::de::Visitor<'de>,
190 {
191 visitor.visit_char(
192 self.yaml
193 .as_str()
194 .ok_or_else(|| as_invalid(self.yaml, "char"))?
195 .chars()
196 .next()
197 .ok_or_else(|| as_invalid(self.yaml, "char"))?,
198 )
199 }
200
201 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202 where
203 V: serde::de::Visitor<'de>,
204 {
205 let s = self.yaml.as_str();
206 visitor.visit_borrowed_str(s.ok_or_else(|| as_invalid(self.yaml, "str"))?)
207 }
208
209 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210 where
211 V: serde::de::Visitor<'de>,
212 {
213 visitor.visit_string(
214 self.yaml
215 .as_str()
216 .ok_or_else(|| as_invalid(self.yaml, "string"))?
217 .to_string(),
218 )
219 }
220
221 fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
223 where
224 V: serde::de::Visitor<'de>,
225 {
226 unimplemented!()
227 }
228
229 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
231 where
232 V: serde::de::Visitor<'de>,
233 {
234 unimplemented!()
235 }
236
237 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
238 where
239 V: serde::de::Visitor<'de>,
240 {
241 if matches!(self.yaml, yaml_rust::Yaml::Null) {
242 visitor.visit_none()
243 } else {
244 visitor.visit_some(self)
245 }
246 }
247
248 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
249 where
250 V: serde::de::Visitor<'de>,
251 {
252 if matches!(self.yaml, yaml_rust::Yaml::Null) {
253 visitor.visit_unit()
254 } else {
255 Err(as_invalid(self.yaml, "unit"))
256 }
257 }
258
259 fn deserialize_unit_struct<V>(
261 self,
262 _name: &'static str,
263 _visitor: V,
264 ) -> Result<V::Value, Self::Error>
265 where
266 V: serde::de::Visitor<'de>,
267 {
268 todo!()
269 }
270
271 fn deserialize_newtype_struct<V>(
273 self,
274 _name: &'static str,
275 _visitor: V,
276 ) -> Result<V::Value, Self::Error>
277 where
278 V: serde::de::Visitor<'de>,
279 {
280 todo!()
281 }
282
283 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
285 where
286 V: serde::de::Visitor<'de>,
287 {
288 if let Some(n) = self.yaml.as_vec() {
289 let seq = SeqDeserializer::new(n.iter().map(|y| YamlWrap { yaml: y }));
290 visitor.visit_seq(seq)
291 } else {
292 Err(as_invalid(self.yaml, "seq"))
293 }
294 }
295
296 fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
298 where
299 V: serde::de::Visitor<'de>,
300 {
301 todo!()
302 }
303
304 fn deserialize_tuple_struct<V>(
306 self,
307 _name: &'static str,
308 _len: usize,
309 _visitor: V,
310 ) -> Result<V::Value, Self::Error>
311 where
312 V: serde::de::Visitor<'de>,
313 {
314 todo!()
315 }
316
317 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
318 where
319 V: serde::de::Visitor<'de>,
320 {
321 match self.yaml {
322 Yaml::Hash(h) => {
323 let m = MapDeserializer::new(
324 h.iter()
325 .map(|(k, v)| (YamlWrap { yaml: k }, YamlWrap { yaml: v })),
326 );
327 visitor.visit_map(m)
328 }
329 Yaml::Array(a) => {
330 let x = a
331 .iter()
332 .map(|y| y.as_hash().ok_or_else(|| as_invalid(self.yaml, "map")))
333 .collect::<Result<Vec<_>, _>>()?;
334 let m = MapDeserializer::new(
335 x.into_iter()
336 .flat_map(|x| x.iter())
337 .map(|(k, v)| (YamlWrap { yaml: k }, YamlWrap { yaml: v })),
338 );
339 visitor.visit_map(m)
340 }
341 _ => Err(as_invalid(self.yaml, "map")),
342 }
343 }
344
345 fn deserialize_struct<V>(
347 self,
348 _name: &'static str,
349 _fields: &'static [&'static str],
350 _visitor: V,
351 ) -> Result<V::Value, Self::Error>
352 where
353 V: serde::de::Visitor<'de>,
354 {
355 todo!()
356 }
357
358 fn deserialize_enum<V>(
359 self,
360 _name: &'static str,
361 _variants: &'static [&'static str],
362 visitor: V,
363 ) -> Result<V::Value, Self::Error>
364 where
365 V: serde::de::Visitor<'de>,
366 {
367 if let Some(s) = self.yaml.as_str() {
368 visitor.visit_enum(s.into_deserializer())
369 } else {
370 Err(as_invalid(self.yaml, "enum"))
371 }
372 }
373
374 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
375 where
376 V: serde::de::Visitor<'de>,
377 {
378 self.deserialize_str(visitor)
379 }
380
381 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
382 where
383 V: serde::de::Visitor<'de>,
384 {
385 self.deserialize_any(visitor)
386 }
387}