1use crate::{
2 arg::ArgType,
3 error::{ArgsError, ArgsErrorKind, ArgsErrorWithInput},
4 span::Span,
5};
6use facet_core::{Def, Facet, Field, FieldAttribute, FieldFlags, Shape, Type, UserType};
7use facet_reflect::{HeapValue, Partial};
8use heck::ToSnakeCase;
9
10pub fn from_std_args<T: Facet<'static>>() -> Result<T, ArgsErrorWithInput> {
12 let args = std::env::args().skip(1).collect::<Vec<String>>();
13 let args_str: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
14 from_slice(&args_str[..])
15}
16
17pub fn from_slice<'input, T: Facet<'static>>(
19 args: &'input [&'input str],
20) -> Result<T, ArgsErrorWithInput> {
21 let mut cx = Context::new(args, T::SHAPE);
22 let hv = cx.work_add_input()?;
23
24 Ok(hv.materialize::<T>().unwrap())
26}
27
28struct Context<'input> {
29 shape: &'static Shape,
31
32 args: &'input [&'input str],
34
35 index: usize,
37
38 positional_only: bool,
40
41 arg_indices: Vec<usize>,
43
44 flattened_args: String,
46}
47
48impl<'input> Context<'input> {
49 fn new(args: &'input [&'input str], shape: &'static Shape) -> Self {
50 let mut arg_indices = vec![];
51 let mut flattened_args = String::new();
52
53 for arg in args {
54 arg_indices.push(flattened_args.len());
55 flattened_args.push_str(arg);
56 flattened_args.push(' ');
57 }
58 log::trace!("flattened args: {flattened_args:?}");
59 log::trace!("arg_indices: {arg_indices:?}");
60
61 Self {
62 shape,
63 args,
64 index: 0,
65 positional_only: false,
66 arg_indices,
67 flattened_args,
68 }
69 }
70
71 fn fields(&self, p: &Partial<'static>) -> Result<&'static [Field], ArgsErrorKind> {
73 let shape = p.shape();
74 match &shape.ty {
75 Type::User(UserType::Struct(struct_type)) => Ok(struct_type.fields),
76 _ => Err(ArgsErrorKind::NoFields { shape }),
77 }
78 }
79
80 fn handle_field(
83 &mut self,
84 p: &mut Partial<'static>,
85 field_index: usize,
86 value: Option<Token<'input>>,
87 ) -> Result<(), ArgsErrorKind> {
88 let fields = self.fields(p)?;
89 let field = fields[field_index];
90 log::trace!("Found field {field:?}");
91
92 p.begin_nth_field(field_index)?;
93
94 log::trace!("After begin_field, shape is {}", p.shape());
95 if p.shape().is_shape(bool::SHAPE) {
96 log::trace!("Flag is boolean, setting it to true");
97 p.set(true)?;
98
99 self.index += 1;
100 } else {
101 log::trace!("Flag isn't boolean, expecting a {} value", p.shape());
102
103 if let Some(value) = value {
104 self.handle_value(p, value.s)?;
105 } else {
106 if self.index + 1 >= self.args.len() {
107 return Err(ArgsErrorKind::ExpectedValueGotEof { shape: p.shape() });
108 }
109 let value = self.args[self.index + 1];
110
111 self.index += 1;
112 self.handle_value(p, value)?;
113 }
114
115 self.index += 1;
116 }
117
118 p.end()?;
119
120 Ok(())
121 }
122
123 fn handle_value(
124 &mut self,
125 p: &mut Partial<'static>,
126 value: &'input str,
127 ) -> Result<(), ArgsErrorKind> {
128 match p.shape().def {
129 Def::List(_) => {
130 p.begin_list()?;
132 p.begin_list_item()?;
133 p.parse_from_str(value)?;
134 p.end()?;
135 }
136 _ => {
137 p.parse_from_str(value)?;
139 }
140 }
141
142 Ok(())
143 }
144
145 fn work_add_input(&mut self) -> Result<HeapValue<'static>, ArgsErrorWithInput> {
146 self.work().map_err(|e| ArgsErrorWithInput {
147 inner: e,
148 flattened_args: self.flattened_args.clone(),
149 })
150 }
151
152 fn work(&mut self) -> Result<HeapValue<'static>, ArgsError> {
154 self.work_inner().map_err(|kind| {
155 let span = if self.index >= self.args.len() {
156 Span::new(self.flattened_args.len(), 0)
157 } else {
158 let arg = self.args[self.index];
159 let index = self.arg_indices[self.index];
160 Span::new(index, arg.len())
161 };
162 ArgsError::new(kind, span)
163 })
164 }
165
166 fn work_inner(&mut self) -> Result<HeapValue<'static>, ArgsErrorKind> {
167 let mut p = Partial::alloc_shape(self.shape)?;
168
169 while self.args.len() > self.index {
170 let arg = self.args[self.index];
171 let arg_span = Span::new(self.arg_indices[self.index], arg.len());
172 let at = if self.positional_only {
173 ArgType::Positional
174 } else {
175 ArgType::parse(arg)
176 };
177 log::trace!("Parsed {at:?}");
178
179 match at {
180 ArgType::DoubleDash => {
181 self.positional_only = true;
182 self.index += 1;
183 }
184 ArgType::LongFlag(flag) => {
185 let flag_span = Span::new(arg_span.start + 2, arg_span.len - 2);
186 match split(flag, flag_span) {
187 Some(tokens) => {
188 let mut tokens = tokens.into_iter();
190 let Some(key) = tokens.next() else {
191 unreachable!()
192 };
193 let Some(value) = tokens.next() else {
194 unreachable!()
195 };
196
197 let flag = key.s;
198 let snek = key.s.to_snake_case();
199 log::trace!("Looking up long flag {flag} (field name: {snek})");
200 let Some(field_index) = p.field_index(&snek) else {
201 return Err(ArgsErrorKind::UnknownLongFlag);
202 };
203 self.handle_field(&mut p, field_index, Some(value))?;
204 }
205 None => {
206 let snek = flag.to_snake_case();
207 log::trace!("Looking up long flag {flag} (field name: {snek})");
208 let Some(field_index) = p.field_index(&snek) else {
209 return Err(ArgsErrorKind::UnknownLongFlag);
210 };
211 self.handle_field(&mut p, field_index, None)?;
212 }
213 }
214 }
215 ArgType::ShortFlag(flag) => {
216 let flag_span = Span::new(arg_span.start + 1, arg_span.len - 1);
217 match split(flag, flag_span) {
218 Some(tokens) => {
219 let mut tokens = tokens.into_iter();
221 let Some(key) = tokens.next() else {
222 unreachable!()
223 };
224 let Some(value) = tokens.next() else {
225 unreachable!()
226 };
227
228 let flag = key.s;
229 log::trace!("Looking up short flag {flag}");
230 let fields = self.fields(&p)?;
231 let Some(field_index) = find_field_index_with_short(fields, flag)
232 else {
233 return Err(ArgsErrorKind::UnknownShortFlag);
234 };
235 self.handle_field(&mut p, field_index, Some(value))?;
236 }
237 None => {
238 log::trace!("Looking up short flag {flag}");
239 let fields = self.fields(&p)?;
240 let Some(field_index) = find_field_index_with_short(fields, flag)
241 else {
242 return Err(ArgsErrorKind::UnknownShortFlag);
243 };
244 self.handle_field(&mut p, field_index, None)?;
245 }
246 }
247 }
248 ArgType::Positional => {
249 let fields = self.fields(&p)?;
250 let mut chosen_field_index: Option<usize> = None;
251
252 for (field_index, field) in fields.iter().enumerate() {
253 let is_positional = field.attributes.iter().any(|attr| match attr {
254 FieldAttribute::Arbitrary(attr) => attr.contains("positional"),
256 });
257 if !is_positional {
258 continue;
259 }
260
261 if matches!(field.shape().def, Def::List(_list_def)) {
264 } else if p.is_field_set(field_index)? {
266 continue;
268 }
269
270 log::trace!("found field, it's not a list {field:?}");
271 chosen_field_index = Some(field_index);
272 break;
273 }
274
275 let Some(chosen_field_index) = chosen_field_index else {
276 return Err(ArgsErrorKind::UnexpectedPositionalArgument);
277 };
278
279 p.begin_nth_field(chosen_field_index)?;
280
281 let value = self.args[self.index];
282 self.handle_value(&mut p, value)?;
283
284 p.end()?;
285 self.index += 1;
286 }
287 ArgType::None => todo!(),
288 }
289 }
290
291 {
292 let fields = self.fields(&p)?;
293 for (field_index, field) in fields.iter().enumerate() {
294 if p.is_field_set(field_index)? {
295 continue;
297 }
298
299 if field.flags.contains(FieldFlags::DEFAULT) {
300 log::trace!("Setting #{field_index} field to default: {field:?}");
301 p.set_nth_field_to_default(field_index)?;
302 } else if field.shape.is_shape(bool::SHAPE) {
303 p.set_nth_field(field_index, false)?;
305 } else {
306 return Err(ArgsErrorKind::MissingArgument { field });
307 }
308 }
309 }
310
311 Ok(p.build()?)
312 }
313}
314
315#[derive(Debug, PartialEq)]
317struct Token<'input> {
318 s: &'input str,
319 span: Span,
320}
321
322fn split<'input>(input: &'input str, span: Span) -> Option<Vec<Token<'input>>> {
326 let equals_index = input.find('=')?;
327
328 let l = &input[0..equals_index];
329 let l_span = Span::new(span.start, l.len());
330
331 let r = &input[equals_index + 1..];
332 let r_span = Span::new(equals_index + 1, r.len());
333
334 Some(vec![
335 Token { s: l, span: l_span },
336 Token { s: r, span: r_span },
337 ])
338}
339
340#[test]
341fn test_split() {
342 assert_eq!(split("ababa", Span::new(5, 5)), None);
343 assert_eq!(
344 split("foo=bar", Span::new(0, 7)),
345 Some(vec![
346 Token {
347 s: "foo",
348 span: Span::new(0, 3)
349 },
350 Token {
351 s: "bar",
352 span: Span::new(4, 3)
353 },
354 ])
355 );
356 assert_eq!(
357 split("foo=", Span::new(0, 4)),
358 Some(vec![
359 Token {
360 s: "foo",
361 span: Span::new(0, 3)
362 },
363 Token {
364 s: "",
365 span: Span::new(4, 0)
366 },
367 ])
368 );
369 assert_eq!(
370 split("=bar", Span::new(0, 4)),
371 Some(vec![
372 Token {
373 s: "",
374 span: Span::new(0, 0)
375 },
376 Token {
377 s: "bar",
378 span: Span::new(1, 3)
379 },
380 ])
381 );
382}
383
384fn find_field_index_with_short(field: &'static [Field], short: &str) -> Option<usize> {
387 let just_short = "short";
388 let full_attr1 = format!("short = '{short}'");
389 let full_attr2 = format!("short = \"{short}\"");
390
391 field.iter().position(|f| {
392 f.attributes.iter().any(|attr| match attr {
393 FieldAttribute::Arbitrary(attr_str) => {
394 attr_str == &full_attr1
395 || attr_str == &full_attr2
396 || (attr_str == &just_short && f.name == short)
397 }
398 })
399 })
400}