facet_args/
format.rs

1use crate::{
2    arg::ArgType,
3    error::{ArgsError, ArgsErrorKind, ArgsErrorWithInput},
4    span::Span,
5};
6use facet_core::{Def, Facet, Field, FieldAttribute, FieldFlags, Shape, Type, UserType};
7use facet_reflect::{HeapValue, Partial};
8use heck::ToSnakeCase;
9
10/// Parse command line arguments provided by std::env::args() into a Facet-compatible type
11pub fn from_std_args<T: Facet<'static>>() -> Result<T, ArgsErrorWithInput> {
12    let args = std::env::args().skip(1).collect::<Vec<String>>();
13    let args_str: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
14    from_slice(&args_str[..])
15}
16
17/// Parse command line arguments into a Facet-compatible type
18pub fn from_slice<'input, T: Facet<'static>>(
19    args: &'input [&'input str],
20) -> Result<T, ArgsErrorWithInput> {
21    let mut cx = Context::new(args, T::SHAPE);
22    let hv = cx.work_add_input()?;
23
24    // TODO: proper error handling
25    Ok(hv.materialize::<T>().unwrap())
26}
27
28struct Context<'input> {
29    /// The shape we're building
30    shape: &'static Shape,
31
32    /// Input arguments (already tokenized)
33    args: &'input [&'input str],
34
35    /// Argument we're currently parsing
36    index: usize,
37
38    /// Flips to true after `--`, which makes us only look for positional args
39    positional_only: bool,
40
41    /// Index of every argument in `flattened_args`
42    arg_indices: Vec<usize>,
43
44    /// Essentially `input.join(" ")`
45    flattened_args: String,
46}
47
48impl<'input> Context<'input> {
49    fn new(args: &'input [&'input str], shape: &'static Shape) -> Self {
50        let mut arg_indices = vec![];
51        let mut flattened_args = String::new();
52
53        for arg in args {
54            arg_indices.push(flattened_args.len());
55            flattened_args.push_str(arg);
56            flattened_args.push(' ');
57        }
58        log::trace!("flattened args: {flattened_args:?}");
59        log::trace!("arg_indices: {arg_indices:?}");
60
61        Self {
62            shape,
63            args,
64            index: 0,
65            positional_only: false,
66            arg_indices,
67            flattened_args,
68        }
69    }
70
71    /// Returns fields for the current shape, errors out if it's not a struct
72    fn fields(&self, p: &Partial<'static>) -> Result<&'static [Field], ArgsErrorKind> {
73        let shape = p.shape();
74        match &shape.ty {
75            Type::User(UserType::Struct(struct_type)) => Ok(struct_type.fields),
76            _ => Err(ArgsErrorKind::NoFields { shape }),
77        }
78    }
79
80    /// Once we have found the struct field that corresponds to a `--long` or `-s` short flag,
81    /// this is where we toggle something on, look for a value, etc.
82    fn handle_field(
83        &mut self,
84        p: &mut Partial<'static>,
85        field_index: usize,
86        value: Option<Token<'input>>,
87    ) -> Result<(), ArgsErrorKind> {
88        let fields = self.fields(p)?;
89        let field = fields[field_index];
90        log::trace!("Found field {field:?}");
91
92        p.begin_nth_field(field_index)?;
93
94        log::trace!("After begin_field, shape is {}", p.shape());
95        if p.shape().is_shape(bool::SHAPE) {
96            log::trace!("Flag is boolean, setting it to true");
97            p.set(true)?;
98
99            self.index += 1;
100        } else {
101            log::trace!("Flag isn't boolean, expecting a {} value", p.shape());
102
103            if let Some(value) = value {
104                self.handle_value(p, value.s)?;
105            } else {
106                if self.index + 1 >= self.args.len() {
107                    return Err(ArgsErrorKind::ExpectedValueGotEof { shape: p.shape() });
108                }
109                let value = self.args[self.index + 1];
110
111                self.index += 1;
112                self.handle_value(p, value)?;
113            }
114
115            self.index += 1;
116        }
117
118        p.end()?;
119
120        Ok(())
121    }
122
123    fn handle_value(
124        &mut self,
125        p: &mut Partial<'static>,
126        value: &'input str,
127    ) -> Result<(), ArgsErrorKind> {
128        match p.shape().def {
129            Def::List(_) => {
130                // if it's a list, then we'll want to initialize the list first and push to it
131                p.begin_list()?;
132                p.begin_list_item()?;
133                p.parse_from_str(value)?;
134                p.end()?;
135            }
136            _ => {
137                // TODO: this surely won't be enough eventually
138                p.parse_from_str(value)?;
139            }
140        }
141
142        Ok(())
143    }
144
145    fn work_add_input(&mut self) -> Result<HeapValue<'static>, ArgsErrorWithInput> {
146        self.work().map_err(|e| ArgsErrorWithInput {
147            inner: e,
148            flattened_args: self.flattened_args.clone(),
149        })
150    }
151
152    /// Forward to `work_inner`, converts `ArgsErrorKind` to `ArgsError` (with span)
153    fn work(&mut self) -> Result<HeapValue<'static>, ArgsError> {
154        self.work_inner().map_err(|kind| {
155            let span = if self.index >= self.args.len() {
156                Span::new(self.flattened_args.len(), 0)
157            } else {
158                let arg = self.args[self.index];
159                let index = self.arg_indices[self.index];
160                Span::new(index, arg.len())
161            };
162            ArgsError::new(kind, span)
163        })
164    }
165
166    fn work_inner(&mut self) -> Result<HeapValue<'static>, ArgsErrorKind> {
167        let mut p = Partial::alloc_shape(self.shape)?;
168
169        while self.args.len() > self.index {
170            let arg = self.args[self.index];
171            let arg_span = Span::new(self.arg_indices[self.index], arg.len());
172            let at = if self.positional_only {
173                ArgType::Positional
174            } else {
175                ArgType::parse(arg)
176            };
177            log::trace!("Parsed {at:?}");
178
179            match at {
180                ArgType::DoubleDash => {
181                    self.positional_only = true;
182                    self.index += 1;
183                }
184                ArgType::LongFlag(flag) => {
185                    let flag_span = Span::new(arg_span.start + 2, arg_span.len - 2);
186                    match split(flag, flag_span) {
187                        Some(tokens) => {
188                            // We have something like `--key=value`
189                            let mut tokens = tokens.into_iter();
190                            let Some(key) = tokens.next() else {
191                                unreachable!()
192                            };
193                            let Some(value) = tokens.next() else {
194                                unreachable!()
195                            };
196
197                            let flag = key.s;
198                            let snek = key.s.to_snake_case();
199                            log::trace!("Looking up long flag {flag} (field name: {snek})");
200                            let Some(field_index) = p.field_index(&snek) else {
201                                return Err(ArgsErrorKind::UnknownLongFlag);
202                            };
203                            self.handle_field(&mut p, field_index, Some(value))?;
204                        }
205                        None => {
206                            let snek = flag.to_snake_case();
207                            log::trace!("Looking up long flag {flag} (field name: {snek})");
208                            let Some(field_index) = p.field_index(&snek) else {
209                                return Err(ArgsErrorKind::UnknownLongFlag);
210                            };
211                            self.handle_field(&mut p, field_index, None)?;
212                        }
213                    }
214                }
215                ArgType::ShortFlag(flag) => {
216                    let flag_span = Span::new(arg_span.start + 1, arg_span.len - 1);
217                    match split(flag, flag_span) {
218                        Some(tokens) => {
219                            // We have something like `--key=value`
220                            let mut tokens = tokens.into_iter();
221                            let Some(key) = tokens.next() else {
222                                unreachable!()
223                            };
224                            let Some(value) = tokens.next() else {
225                                unreachable!()
226                            };
227
228                            let flag = key.s;
229                            log::trace!("Looking up short flag {flag}");
230                            let fields = self.fields(&p)?;
231                            let Some(field_index) = find_field_index_with_short(fields, flag)
232                            else {
233                                return Err(ArgsErrorKind::UnknownShortFlag);
234                            };
235                            self.handle_field(&mut p, field_index, Some(value))?;
236                        }
237                        None => {
238                            log::trace!("Looking up short flag {flag}");
239                            let fields = self.fields(&p)?;
240                            let Some(field_index) = find_field_index_with_short(fields, flag)
241                            else {
242                                return Err(ArgsErrorKind::UnknownShortFlag);
243                            };
244                            self.handle_field(&mut p, field_index, None)?;
245                        }
246                    }
247                }
248                ArgType::Positional => {
249                    let fields = self.fields(&p)?;
250                    let mut chosen_field_index: Option<usize> = None;
251
252                    for (field_index, field) in fields.iter().enumerate() {
253                        let is_positional = field.attributes.iter().any(|attr| match attr {
254                            // this is terrible, tbh
255                            FieldAttribute::Arbitrary(attr) => attr.contains("positional"),
256                            _ => false,
257                        });
258                        if !is_positional {
259                            continue;
260                        }
261
262                        // we've found a positional field. if it's a list, then we're done: every
263                        // positional argument will just be pushed to it.
264                        if matches!(field.shape().def, Def::List(_list_def)) {
265                            // cool, keep going
266                        } else if p.is_field_set(field_index)? {
267                            // field is already set, continue
268                            continue;
269                        }
270
271                        log::trace!("found field, it's not a list {field:?}");
272                        chosen_field_index = Some(field_index);
273                        break;
274                    }
275
276                    let Some(chosen_field_index) = chosen_field_index else {
277                        return Err(ArgsErrorKind::UnexpectedPositionalArgument);
278                    };
279
280                    p.begin_nth_field(chosen_field_index)?;
281
282                    let value = self.args[self.index];
283                    self.handle_value(&mut p, value)?;
284
285                    p.end()?;
286                    self.index += 1;
287                }
288                ArgType::None => todo!(),
289            }
290        }
291
292        {
293            let fields = self.fields(&p)?;
294            for (field_index, field) in fields.iter().enumerate() {
295                if p.is_field_set(field_index)? {
296                    // cool
297                    continue;
298                }
299
300                if field.flags.contains(FieldFlags::DEFAULT) {
301                    log::trace!("Setting #{field_index} field to default: {field:?}");
302                    p.set_nth_field_to_default(field_index)?;
303                } else if (field.shape)().is_shape(bool::SHAPE) {
304                    // bools are just set to false
305                    p.set_nth_field(field_index, false)?;
306                } else {
307                    return Err(ArgsErrorKind::MissingArgument { field });
308                }
309            }
310        }
311
312        Ok(p.build()?)
313    }
314}
315
316/// Result of `split`
317#[derive(Debug, PartialEq)]
318struct Token<'input> {
319    s: &'input str,
320    span: Span,
321}
322
323/// Split on `=`, e.g. `a=b` returns (`a`, `b`).
324/// Span-aware. If `=` is not contained in the input string,
325/// returns None
326fn split<'input>(input: &'input str, span: Span) -> Option<Vec<Token<'input>>> {
327    let equals_index = input.find('=')?;
328
329    let l = &input[0..equals_index];
330    let l_span = Span::new(span.start, l.len());
331
332    let r = &input[equals_index + 1..];
333    let r_span = Span::new(equals_index + 1, r.len());
334
335    Some(vec![
336        Token { s: l, span: l_span },
337        Token { s: r, span: r_span },
338    ])
339}
340
341#[test]
342fn test_split() {
343    assert_eq!(split("ababa", Span::new(5, 5)), None);
344    assert_eq!(
345        split("foo=bar", Span::new(0, 7)),
346        Some(vec![
347            Token {
348                s: "foo",
349                span: Span::new(0, 3)
350            },
351            Token {
352                s: "bar",
353                span: Span::new(4, 3)
354            },
355        ])
356    );
357    assert_eq!(
358        split("foo=", Span::new(0, 4)),
359        Some(vec![
360            Token {
361                s: "foo",
362                span: Span::new(0, 3)
363            },
364            Token {
365                s: "",
366                span: Span::new(4, 0)
367            },
368        ])
369    );
370    assert_eq!(
371        split("=bar", Span::new(0, 4)),
372        Some(vec![
373            Token {
374                s: "",
375                span: Span::new(0, 0)
376            },
377            Token {
378                s: "bar",
379                span: Span::new(1, 3)
380            },
381        ])
382    );
383}
384
385/// Given an array of fields, find the field with the given `short = 'a'`
386/// annotation.
387fn find_field_index_with_short(field: &'static [Field], short: &str) -> Option<usize> {
388    let just_short = "short";
389    let full_attr1 = format!("short = '{short}'");
390    let full_attr2 = format!("short = \"{short}\"");
391
392    field.iter().position(|f| {
393        f.attributes.iter().any(|attr| match attr {
394            FieldAttribute::Arbitrary(attr_str) => {
395                attr_str == &full_attr1
396                    || attr_str == &full_attr2
397                    || (attr_str == &just_short && f.name == short)
398            }
399            _ => false,
400        })
401    })
402}