facet_args/
lib.rs

1#![warn(missing_docs)]
2#![warn(clippy::std_instead_of_core)]
3#![warn(clippy::std_instead_of_alloc)]
4#![forbid(unsafe_code)]
5#![doc = include_str!("../README.md")]
6
7extern crate alloc;
8use alloc::borrow::Cow;
9
10mod error;
11
12use error::{ArgsError, ArgsErrorKind};
13use facet_core::{Def, Facet, FieldAttribute, Type, UserType};
14use facet_reflect::{ReflectError, Wip};
15
16fn parse_field<'facet>(wip: Wip<'facet>, value: &'facet str) -> Result<Wip<'facet>, ArgsError> {
17    let shape = wip.shape();
18
19    if shape.is_type::<String>() {
20        log::trace!("shape is String");
21        wip.put(value.to_string())
22    } else if shape.is_type::<&str>() {
23        log::trace!("shape is &str");
24        wip.put(value)
25    } else if shape.is_type::<bool>() {
26        log::trace!("shape is bool, setting to true");
27        wip.put(value.to_lowercase() == "true")
28    } else {
29        match shape.def {
30            Def::Scalar(_) => {
31                log::trace!("shape is nothing known, falling back to parse: {}", shape);
32                wip.parse(value)
33            }
34            _def => {
35                return Err(ArgsError::new(ArgsErrorKind::GenericReflect(
36                    ReflectError::OperationFailed {
37                        shape,
38                        operation: "parsing field",
39                    },
40                )));
41            }
42        }
43    }
44    .map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?
45    .pop()
46    .map_err(|e| ArgsError {
47        kind: ArgsErrorKind::GenericReflect(e),
48    })
49}
50
/// Converts a GNU/Unix-style kebab-case argument name into the snake_case
/// spelling used by struct fields (`dry-run` becomes `dry_run`).
///
/// The input is returned borrowed when it contains no `-`; a new string is
/// only allocated when a replacement is actually needed.
fn kebab_to_snake(input: &str) -> Cow<'_, str> {
    // ASSUMPTION: only GNU/Unix kebab-case named arguments are supported,
    // and struct fields are snake_case.
    match input.find('-') {
        None => Cow::Borrowed(input),
        Some(_) => Cow::Owned(input.replace('-', "_")),
    }
}
59
60/// Parses command-line arguments
61pub fn from_slice<'input, 'facet, T>(s: &[&'input str]) -> Result<T, ArgsError>
62where
63    T: Facet<'facet>,
64    'input: 'facet,
65{
66    log::trace!("Entering from_slice function");
67    let mut s = s;
68    let mut wip =
69        Wip::alloc::<T>().map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?;
70    log::trace!("Allocated Poke for type T");
71    let Type::User(UserType::Struct(st)) = wip.shape().ty else {
72        return Err(ArgsError::new(ArgsErrorKind::GenericArgsError(
73            "Expected struct type".to_string(),
74        )));
75    };
76
77    while let Some(token) = s.first() {
78        log::trace!("Processing token: {}", token);
79        s = &s[1..];
80
81        if let Some(key) = token.strip_prefix("--") {
82            let key = kebab_to_snake(key);
83            let field_index = match wip.field_index(&key) {
84                Some(index) => index,
85                None => {
86                    return Err(ArgsError::new(ArgsErrorKind::GenericArgsError(format!(
87                        "Unknown argument `{key}`",
88                    ))));
89                }
90            };
91            log::trace!("Found named argument: {}", key);
92
93            let field = wip
94                .field(field_index)
95                .expect("field_index should be a valid field bound");
96
97            if field.shape().is_type::<bool>() {
98                // TODO: absence i.e "false" case is not handled
99                wip = parse_field(field, "true")?;
100            } else {
101                let value = s
102                    .first()
103                    .ok_or(ArgsError::new(ArgsErrorKind::GenericArgsError(format!(
104                        "expected value after argument `{key}`"
105                    ))))?;
106                log::trace!("Field value: {}", value);
107                s = &s[1..];
108                wip = parse_field(field, value)?;
109            }
110        } else if let Some(key) = token.strip_prefix("-") {
111            log::trace!("Found short named argument: {}", key);
112            for (field_index, f) in st.fields.iter().enumerate() {
113                if f.attributes
114                    .iter()
115                    .any(|a| matches!(a, FieldAttribute::Arbitrary(a) if a.contains("short") && a.contains(key))
116                   )
117                {
118                    log::trace!("Found field matching short_code: {} for field {}", key, f.name);
119                    let field = wip.field(field_index).expect("field_index is in bounds");
120                    if field.shape().is_type::<bool>() {
121                        wip = parse_field(field, "true")?;
122                    } else {
123                        let value = s
124                            .first()
125                            .ok_or(ArgsError::new(ArgsErrorKind::GenericArgsError(format!(
126                                "expected value after argument `{key}`"
127                            ))))?;
128                        log::trace!("Field value: {}", value);
129                        s = &s[1..];
130                        wip = parse_field(field, value)?;
131                    }
132                    break;
133                }
134            }
135        } else {
136            log::trace!("Encountered positional argument: {}", token);
137            for (field_index, f) in st.fields.iter().enumerate() {
138                if f.attributes
139                    .iter()
140                    .any(|a| matches!(a, FieldAttribute::Arbitrary(a) if a.contains("positional")))
141                {
142                    if wip
143                        .is_field_set(field_index)
144                        .expect("field_index is in bounds")
145                    {
146                        continue;
147                    }
148                    let field = wip.field(field_index).expect("field_index is in bounds");
149                    wip = parse_field(field, token)?;
150                    break;
151                }
152            }
153        }
154    }
155
156    // Look for uninitialized fields with DEFAULT flag
157    // Adapted from the approach in `facet-deserialize::StackRunner::pop()`
158    for (field_index, field) in st.fields.iter().enumerate() {
159        if !wip.is_field_set(field_index).expect("in bounds") {
160            log::trace!(
161                "Field {} is not initialized, checking if it has DEFAULT flag",
162                field.name
163            );
164
165            // Check if the field has the DEFAULT flag
166            if field.flags.contains(facet_core::FieldFlags::DEFAULT) {
167                log::trace!("Field {} has DEFAULT flag, applying default", field.name);
168
169                let field_wip = wip.field(field_index).expect("field_index is in bounds");
170
171                // Check if there's a custom default function
172                if let Some(default_fn) = field.vtable.default_fn {
173                    log::trace!("Using custom default function for field {}", field.name);
174                    wip = field_wip
175                        .put_from_fn(default_fn)
176                        .map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?;
177                } else {
178                    // Otherwise use the Default trait
179                    log::trace!("Using Default trait for field {}", field.name);
180                    wip = field_wip
181                        .put_default()
182                        .map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?;
183                }
184
185                // Pop back up to the struct level
186                wip = wip
187                    .pop()
188                    .map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?;
189            }
190        }
191    }
192
193    // If a boolean field is unset the value is set to `false`
194    // This behaviour means `#[facet(default = false)]` does not need to be explicitly set
195    // on each boolean field specified on a Command struct
196    for (field_index, f) in st.fields.iter().enumerate() {
197        if f.shape().is_type::<bool>() && !wip.is_field_set(field_index).expect("in bounds") {
198            let field = wip.field(field_index).expect("field_index is in bounds");
199            wip = parse_field(field, "false")?;
200        }
201    }
202
203    // Add this right after getting the struct type (st)
204    log::trace!("Checking field attributes");
205    for (i, field) in st.fields.iter().enumerate() {
206        log::trace!(
207            "Field {}: {} - Attributes: {:?}",
208            i,
209            field.name,
210            field.attributes
211        );
212    }
213
214    let heap_vale = wip
215        .build()
216        .map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?;
217    let result = heap_vale
218        .materialize()
219        .map_err(|e| ArgsError::new(ArgsErrorKind::GenericReflect(e)))?;
220    Ok(result)
221}