1use crate::{
2 arg::ArgType,
3 error::{ArgsError, ArgsErrorKind, ArgsErrorWithInput},
4 span::Span,
5};
6use facet_core::{Def, Facet, Field, FieldAttribute, FieldFlags, Shape, Type, UserType};
7use facet_reflect::{HeapValue, Partial};
8use heck::ToSnakeCase;
9
10pub fn from_std_args<T: Facet<'static>>() -> Result<T, ArgsErrorWithInput> {
12 let args = std::env::args().skip(1).collect::<Vec<String>>();
13 let args_str: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
14 from_slice(&args_str[..])
15}
16
17pub fn from_slice<'input, T: Facet<'static>>(
19 args: &'input [&'input str],
20) -> Result<T, ArgsErrorWithInput> {
21 let mut cx = Context::new(args, T::SHAPE);
22 let hv = cx.work_add_input()?;
23
24 Ok(hv.materialize::<T>().unwrap())
26}
27
28struct Context<'input> {
29 shape: &'static Shape,
31
32 args: &'input [&'input str],
34
35 index: usize,
37
38 positional_only: bool,
40
41 arg_indices: Vec<usize>,
43
44 flattened_args: String,
46}
47
impl<'input> Context<'input> {
    /// Creates a parsing context for `args`, targeting `shape`.
    ///
    /// It builds `flattened_args` (each argument followed by one space) and records,
    /// per argument, the byte offset where it starts in that string.
    fn new(args: &'input [&'input str], shape: &'static Shape) -> Self {
        let mut flattened_args = String::new();
        let arg_indices: Vec<usize> = args
            .iter()
            .map(|arg| {
                let start = flattened_args.len();
                flattened_args.push_str(arg);
                flattened_args.push(' ');
                start
            })
            .collect();
        log::trace!("flattened args: {flattened_args:?}");
        log::trace!("arg_indices: {arg_indices:?}");

        Self {
            shape,
            args,
            index: 0,
            positional_only: false,
            arg_indices,
            flattened_args,
        }
    }
70
71 fn fields(&self, p: &Partial<'static>) -> Result<&'static [Field], ArgsErrorKind> {
73 let shape = p.shape();
74 match &shape.ty {
75 Type::User(UserType::Struct(struct_type)) => Ok(struct_type.fields),
76 _ => Err(ArgsErrorKind::NoFields { shape }),
77 }
78 }
79
80 fn handle_field(
83 &mut self,
84 p: &mut Partial<'static>,
85 field_index: usize,
86 value: Option<Token<'input>>,
87 ) -> Result<(), ArgsErrorKind> {
88 let fields = self.fields(p)?;
89 let field = fields[field_index];
90 log::trace!("Found field {field:?}");
91
92 p.begin_nth_field(field_index)?;
93
94 log::trace!("After begin_field, shape is {}", p.shape());
95 if p.shape().is_shape(bool::SHAPE) {
96 log::trace!("Flag is boolean, setting it to true");
97 p.set(true)?;
98
99 self.index += 1;
100 } else {
101 log::trace!("Flag isn't boolean, expecting a {} value", p.shape());
102
103 if let Some(value) = value {
104 self.handle_value(p, value.s)?;
105 } else {
106 if self.index + 1 >= self.args.len() {
107 return Err(ArgsErrorKind::ExpectedValueGotEof { shape: p.shape() });
108 }
109 let value = self.args[self.index + 1];
110
111 self.index += 1;
112 self.handle_value(p, value)?;
113 }
114
115 self.index += 1;
116 }
117
118 p.end()?;
119
120 Ok(())
121 }
122
123 fn handle_value(
124 &mut self,
125 p: &mut Partial<'static>,
126 value: &'input str,
127 ) -> Result<(), ArgsErrorKind> {
128 match p.shape().def {
129 Def::List(_) => {
130 p.begin_list()?;
132 p.begin_list_item()?;
133 p.parse_from_str(value)?;
134 p.end()?;
135 }
136 _ => {
137 p.parse_from_str(value)?;
139 }
140 }
141
142 Ok(())
143 }
144
145 fn work_add_input(&mut self) -> Result<HeapValue<'static>, ArgsErrorWithInput> {
146 self.work().map_err(|e| ArgsErrorWithInput {
147 inner: e,
148 flattened_args: self.flattened_args.clone(),
149 })
150 }
151
152 fn work(&mut self) -> Result<HeapValue<'static>, ArgsError> {
154 self.work_inner().map_err(|kind| {
155 let span = if self.index >= self.args.len() {
156 Span::new(self.flattened_args.len(), 0)
157 } else {
158 let arg = self.args[self.index];
159 let index = self.arg_indices[self.index];
160 Span::new(index, arg.len())
161 };
162 ArgsError::new(kind, span)
163 })
164 }
165
166 fn work_inner(&mut self) -> Result<HeapValue<'static>, ArgsErrorKind> {
167 let mut p = Partial::alloc_shape(self.shape)?;
168
169 while self.args.len() > self.index {
170 let arg = self.args[self.index];
171 let arg_span = Span::new(self.arg_indices[self.index], arg.len());
172 let at = if self.positional_only {
173 ArgType::Positional
174 } else {
175 ArgType::parse(arg)
176 };
177 log::trace!("Parsed {at:?}");
178
179 match at {
180 ArgType::DoubleDash => {
181 self.positional_only = true;
182 self.index += 1;
183 }
184 ArgType::LongFlag(flag) => {
185 let flag_span = Span::new(arg_span.start + 2, arg_span.len - 2);
186 match split(flag, flag_span) {
187 Some(tokens) => {
188 let mut tokens = tokens.into_iter();
190 let Some(key) = tokens.next() else {
191 unreachable!()
192 };
193 let Some(value) = tokens.next() else {
194 unreachable!()
195 };
196
197 let flag = key.s;
198 let snek = key.s.to_snake_case();
199 log::trace!("Looking up long flag {flag} (field name: {snek})");
200 let Some(field_index) = p.field_index(&snek) else {
201 return Err(ArgsErrorKind::UnknownLongFlag);
202 };
203 self.handle_field(&mut p, field_index, Some(value))?;
204 }
205 None => {
206 let snek = flag.to_snake_case();
207 log::trace!("Looking up long flag {flag} (field name: {snek})");
208 let Some(field_index) = p.field_index(&snek) else {
209 return Err(ArgsErrorKind::UnknownLongFlag);
210 };
211 self.handle_field(&mut p, field_index, None)?;
212 }
213 }
214 }
215 ArgType::ShortFlag(flag) => {
216 let flag_span = Span::new(arg_span.start + 1, arg_span.len - 1);
217 match split(flag, flag_span) {
218 Some(tokens) => {
219 let mut tokens = tokens.into_iter();
221 let Some(key) = tokens.next() else {
222 unreachable!()
223 };
224 let Some(value) = tokens.next() else {
225 unreachable!()
226 };
227
228 let flag = key.s;
229 log::trace!("Looking up short flag {flag}");
230 let fields = self.fields(&p)?;
231 let Some(field_index) = find_field_index_with_short(fields, flag)
232 else {
233 return Err(ArgsErrorKind::UnknownShortFlag);
234 };
235 self.handle_field(&mut p, field_index, Some(value))?;
236 }
237 None => {
238 log::trace!("Looking up short flag {flag}");
239 let fields = self.fields(&p)?;
240 let Some(field_index) = find_field_index_with_short(fields, flag)
241 else {
242 return Err(ArgsErrorKind::UnknownShortFlag);
243 };
244 self.handle_field(&mut p, field_index, None)?;
245 }
246 }
247 }
248 ArgType::Positional => {
249 let fields = self.fields(&p)?;
250 let mut chosen_field_index: Option<usize> = None;
251
252 for (field_index, field) in fields.iter().enumerate() {
253 let is_positional = field.attributes.iter().any(|attr| match attr {
254 FieldAttribute::Arbitrary(attr) => attr.contains("positional"),
256 _ => false,
257 });
258 if !is_positional {
259 continue;
260 }
261
262 if matches!(field.shape().def, Def::List(_list_def)) {
265 } else if p.is_field_set(field_index)? {
267 continue;
269 }
270
271 log::trace!("found field, it's not a list {field:?}");
272 chosen_field_index = Some(field_index);
273 break;
274 }
275
276 let Some(chosen_field_index) = chosen_field_index else {
277 return Err(ArgsErrorKind::UnexpectedPositionalArgument);
278 };
279
280 p.begin_nth_field(chosen_field_index)?;
281
282 let value = self.args[self.index];
283 self.handle_value(&mut p, value)?;
284
285 p.end()?;
286 self.index += 1;
287 }
288 ArgType::None => todo!(),
289 }
290 }
291
292 {
293 let fields = self.fields(&p)?;
294 for (field_index, field) in fields.iter().enumerate() {
295 if p.is_field_set(field_index)? {
296 continue;
298 }
299
300 if field.flags.contains(FieldFlags::DEFAULT) {
301 log::trace!("Setting #{field_index} field to default: {field:?}");
302 p.set_nth_field_to_default(field_index)?;
303 } else if (field.shape)().is_shape(bool::SHAPE) {
304 p.set_nth_field(field_index, false)?;
306 } else {
307 return Err(ArgsErrorKind::MissingArgument { field });
308 }
309 }
310 }
311
312 Ok(p.build()?)
313 }
314}
315
316#[derive(Debug, PartialEq)]
318struct Token<'input> {
319 s: &'input str,
320 span: Span,
321}
322
323fn split<'input>(input: &'input str, span: Span) -> Option<Vec<Token<'input>>> {
327 let equals_index = input.find('=')?;
328
329 let l = &input[0..equals_index];
330 let l_span = Span::new(span.start, l.len());
331
332 let r = &input[equals_index + 1..];
333 let r_span = Span::new(equals_index + 1, r.len());
334
335 Some(vec![
336 Token { s: l, span: l_span },
337 Token { s: r, span: r_span },
338 ])
339}
340
#[test]
fn test_split() {
    let tok = |s, start, len| Token {
        s,
        span: Span::new(start, len),
    };

    // No `=`: nothing to split.
    assert_eq!(split("ababa", Span::new(5, 5)), None);

    // Both key and value present.
    assert_eq!(
        split("foo=bar", Span::new(0, 7)),
        Some(vec![tok("foo", 0, 3), tok("bar", 4, 3)])
    );

    // Empty value.
    assert_eq!(
        split("foo=", Span::new(0, 4)),
        Some(vec![tok("foo", 0, 3), tok("", 4, 0)])
    );

    // Empty key.
    assert_eq!(
        split("=bar", Span::new(0, 4)),
        Some(vec![tok("", 0, 0), tok("bar", 1, 3)])
    );
}
384
385fn find_field_index_with_short(field: &'static [Field], short: &str) -> Option<usize> {
388 let just_short = "short";
389 let full_attr1 = format!("short = '{short}'");
390 let full_attr2 = format!("short = \"{short}\"");
391
392 field.iter().position(|f| {
393 f.attributes.iter().any(|attr| match attr {
394 FieldAttribute::Arbitrary(attr_str) => {
395 attr_str == &full_attr1
396 || attr_str == &full_attr2
397 || (attr_str == &just_short && f.name == short)
398 }
399 _ => false,
400 })
401 })
402}