// teamy_figue/extract.rs

1//! Extract requirements structs from parsed configuration.
2//!
3//! This module provides the ability to validate and extract subcommand-specific
4//! required fields from a successfully parsed configuration.
5
6use crate::config_value::{ConfigValue, ObjectMap, Sourced};
7use crate::schema::Schema;
8use facet::{Def, Facet, Field, Type, UserType};
9use heck::{ToKebabCase, ToShoutySnakeCase};
10use indexmap::IndexMap;
11use owo_colors::OwoColorize;
12use owo_colors::Stream::Stdout;
13
/// One entry in an [`ExtractError`]: a requirements-struct field that could
/// not be filled from the parsed configuration.
#[derive(Debug, Clone)]
pub struct ExtractMissingField {
    /// Name of the field in the requirements struct.
    pub field_name: String,
    /// Dot-separated origin path that was looked up in the configuration.
    ///
    /// For failures that are not plain lookups (e.g. a missing `args::origin`
    /// attribute or a deserialization error) this holds a diagnostic message
    /// instead.
    pub origin_path: String,
    /// Type identifier of the value that was expected.
    pub type_name: String,
    /// Command-line flag that could supply the value (e.g. `--config.database-url`).
    pub cli_hint: Option<String>,
    /// Environment variable that could supply the value, rendered as
    /// `$PREFIX__SEGMENT__...` (e.g. `$MYAPP__CONFIG__DATABASE_URL`).
    pub env_hint: Option<String>,
}
28
29/// Error returned when extraction fails.
30#[derive(Debug)]
31pub struct ExtractError {
32    /// List of missing required fields.
33    pub missing_fields: Vec<ExtractMissingField>,
34}
35
36impl std::fmt::Display for ExtractError {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        writeln!(f, "Missing required fields for this operation:\n")?;
39        for field in &self.missing_fields {
40            write!(
41                f,
42                "  {} <{}> at {}",
43                field
44                    .field_name
45                    .if_supports_color(Stdout, |text| text.bold()),
46                field
47                    .type_name
48                    .if_supports_color(Stdout, |text| text.cyan()),
49                field.origin_path
50            )?;
51
52            let mut hints = Vec::new();
53            if let Some(cli) = &field.cli_hint {
54                hints.push(
55                    cli.if_supports_color(Stdout, |text| text.green())
56                        .to_string(),
57                );
58            }
59            if let Some(env) = &field.env_hint {
60                hints.push(
61                    env.if_supports_color(Stdout, |text| text.yellow())
62                        .to_string(),
63                );
64            }
65            if !hints.is_empty() {
66                write!(f, "\n    Set via: {}", hints.join(" or "))?;
67            }
68            writeln!(f)?;
69        }
70        Ok(())
71    }
72}
73
74impl std::error::Error for ExtractError {}
75
76/// Extract a requirements struct from a ConfigValue.
77///
78/// The requirements struct should have fields annotated with `#[facet(args::origin = "path")]`
79/// to indicate which values from the config should be extracted.
80///
81/// Returns an error if:
82/// - Any required field (non-Option) has a missing origin value
83/// - Any field lacks the `args::origin` attribute
84pub fn extract_requirements<R: Facet<'static>>(
85    config_value: &ConfigValue,
86    schema: &Schema,
87) -> Result<R, ExtractError> {
88    let shape = R::SHAPE;
89
90    // Verify it's a struct
91    let struct_type = match &shape.ty {
92        Type::User(UserType::Struct(s)) => *s,
93        _ => {
94            return Err(ExtractError {
95                missing_fields: vec![ExtractMissingField {
96                    field_name: "<root>".to_string(),
97                    origin_path: "<root>".to_string(),
98                    type_name: shape.type_identifier.to_string(),
99                    cli_hint: None,
100                    env_hint: None,
101                }],
102            });
103        }
104    };
105
106    let mut missing_fields = Vec::new();
107    let mut extracted_values: ObjectMap = IndexMap::default();
108
109    // Get env prefix from schema if available
110    let env_prefix = schema.config().and_then(|c| c.env_prefix());
111
112    for field in struct_type.fields {
113        let field_name = field.name;
114
115        // Find the args::origin attribute
116        let origin_path = find_origin_attribute(field);
117
118        let Some(origin_path) = origin_path else {
119            // Field doesn't have args::origin - this is an error
120            return Err(ExtractError {
121                missing_fields: vec![ExtractMissingField {
122                    field_name: field_name.to_string(),
123                    origin_path: "<missing args::origin attribute>".to_string(),
124                    type_name: field.shape().type_identifier.to_string(),
125                    cli_hint: None,
126                    env_hint: None,
127                }],
128            });
129        };
130
131        // Parse the origin path
132        let path_segments: Vec<&str> = origin_path.split('.').collect();
133
134        // Look up value in config_value
135        let value = get_value_by_path(config_value, &path_segments);
136
137        // Check if field is optional (Option<T>)
138        let is_optional = matches!(field.shape().def, Def::Option(_));
139
140        match value {
141            Some(v) if !is_null_value(v) => {
142                // Value exists - add to extracted values
143                extracted_values.insert(field_name.to_string(), v.clone());
144            }
145            _ => {
146                // Value is missing or null
147                if is_optional {
148                    // Optional field - insert null
149                    extracted_values
150                        .insert(field_name.to_string(), ConfigValue::Null(Sourced::new(())));
151                } else {
152                    // Required field is missing - collect error info
153                    let cli_hint = compute_cli_hint(origin_path);
154                    let env_hint = compute_env_hint(origin_path, env_prefix);
155
156                    missing_fields.push(ExtractMissingField {
157                        field_name: field_name.to_string(),
158                        origin_path: origin_path.to_string(),
159                        type_name: field.shape().type_identifier.to_string(),
160                        cli_hint,
161                        env_hint,
162                    });
163                }
164            }
165        }
166    }
167
168    if !missing_fields.is_empty() {
169        return Err(ExtractError { missing_fields });
170    }
171
172    // Build ConfigValue::Object from extracted values and deserialize
173    let extracted_config = ConfigValue::Object(Sourced::new(extracted_values));
174
175    crate::config_value_parser::from_config_value(&extracted_config).map_err(|e| ExtractError {
176        missing_fields: vec![ExtractMissingField {
177            field_name: "<deserialization>".to_string(),
178            origin_path: e.to_string(),
179            type_name: shape.type_identifier.to_string(),
180            cli_hint: None,
181            env_hint: None,
182        }],
183    })
184}
185
186/// Find the `args::origin` attribute value from a field.
187fn find_origin_attribute(field: &Field) -> Option<&'static str> {
188    // The attribute data for args::origin is stored directly as &str
189    for field_attr in field.attributes {
190        if field_attr.ns == Some("args")
191            && field_attr.key == "origin"
192            && let Some(s) = field_attr.get_as::<&str>()
193        {
194            return Some(s);
195        }
196    }
197    None
198}
199
200/// Navigate into a ConfigValue by dot-separated path.
201fn get_value_by_path<'a>(value: &'a ConfigValue, path: &[&str]) -> Option<&'a ConfigValue> {
202    let mut current = value;
203    for segment in path {
204        match current {
205            ConfigValue::Object(obj) => {
206                current = obj.value.get(*segment)?;
207            }
208            _ => return None,
209        }
210    }
211    Some(current)
212}
213
214/// Check if a ConfigValue is null.
215fn is_null_value(value: &ConfigValue) -> bool {
216    matches!(value, ConfigValue::Null(_))
217}
218
219/// Compute CLI hint from origin path.
220fn compute_cli_hint(origin_path: &str) -> Option<String> {
221    let kebab_path = origin_path
222        .split('.')
223        .map(|s| s.to_kebab_case())
224        .collect::<Vec<_>>()
225        .join(".");
226    Some(format!("--{}", kebab_path))
227}
228
229/// Compute environment variable hint from origin path.
230fn compute_env_hint(origin_path: &str, env_prefix: Option<&str>) -> Option<String> {
231    let shouty_path = origin_path
232        .split('.')
233        .map(|s| s.to_shouty_snake_case())
234        .collect::<Vec<_>>()
235        .join("__");
236
237    if let Some(prefix) = env_prefix {
238        Some(format!("${}__{}", prefix, shouty_path))
239    } else {
240        Some(format!("${}", shouty_path))
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config_value::Sourced;
    use facet::Facet;
    use figue_attrs as args;

    // Helpers to create test config values.
    fn cv_object(fields: impl IntoIterator<Item = (&'static str, ConfigValue)>) -> ConfigValue {
        let map: ObjectMap = fields
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();
        ConfigValue::Object(Sourced::new(map))
    }

    fn cv_string(s: &str) -> ConfigValue {
        ConfigValue::String(Sourced::new(s.to_string()))
    }

    fn cv_int(i: i64) -> ConfigValue {
        ConfigValue::Integer(Sourced::new(i))
    }

    fn cv_null() -> ConfigValue {
        ConfigValue::Null(Sourced::new(()))
    }

    // ========================================================================
    // Test structs
    // ========================================================================

    #[derive(Facet, Debug, PartialEq)]
    struct SimpleRequirements {
        #[facet(args::origin = "config.database_url")]
        database_url: String,

        #[facet(args::origin = "config.port")]
        port: u16,
    }

    #[derive(Facet, Debug, PartialEq)]
    struct RequirementsWithOptional {
        #[facet(args::origin = "config.database_url")]
        database_url: String,

        #[facet(args::origin = "config.timeout")]
        timeout: Option<u32>,
    }

    #[derive(Facet, Debug, PartialEq)]
    struct NestedRequirements {
        #[facet(args::origin = "config.server.host")]
        host: String,

        #[facet(args::origin = "config.server.port")]
        port: u16,
    }

    // ========================================================================
    // Tests
    // ========================================================================

    #[test]
    fn test_extract_all_present() {
        let config = cv_object([(
            "config",
            cv_object([
                ("database_url", cv_string("postgres://localhost/db")),
                ("port", cv_int(8080)),
            ]),
        )]);

        // Create a minimal schema for testing
        #[derive(Facet)]
        struct TestConfig {
            database_url: String,
            port: u16,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<SimpleRequirements, _> = extract_requirements(&config, &schema);

        assert!(result.is_ok(), "extraction should succeed: {:?}", result);
        let req = result.unwrap();
        assert_eq!(req.database_url, "postgres://localhost/db");
        assert_eq!(req.port, 8080);
    }

    #[test]
    fn test_extract_missing_required() {
        let config = cv_object([("config", cv_object([("port", cv_int(8080))]))]);

        #[derive(Facet)]
        struct TestConfig {
            database_url: Option<String>,
            port: u16,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<SimpleRequirements, _> = extract_requirements(&config, &schema);

        assert!(result.is_err(), "extraction should fail");
        let err = result.unwrap_err();
        assert_eq!(err.missing_fields.len(), 1);
        assert_eq!(err.missing_fields[0].field_name, "database_url");
        assert_eq!(err.missing_fields[0].origin_path, "config.database_url");
        assert_eq!(
            err.missing_fields[0].cli_hint.as_deref(),
            Some("--config.database-url")
        );
    }

    #[test]
    fn test_extract_null_required_is_missing() {
        let config = cv_object([(
            "config",
            cv_object([("database_url", cv_null()), ("port", cv_int(8080))]),
        )]);

        #[derive(Facet)]
        struct TestConfig {
            database_url: Option<String>,
            port: u16,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<SimpleRequirements, _> = extract_requirements(&config, &schema);

        let err = result.expect_err("an explicit null must count as a missing value");
        assert_eq!(err.missing_fields.len(), 1);
        assert_eq!(err.missing_fields[0].field_name, "database_url");
    }

    #[test]
    fn test_extract_optional_missing() {
        let config = cv_object([(
            "config",
            cv_object([("database_url", cv_string("postgres://localhost/db"))]),
        )]);

        #[derive(Facet)]
        struct TestConfig {
            database_url: String,
            timeout: Option<u32>,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<RequirementsWithOptional, _> = extract_requirements(&config, &schema);

        assert!(
            result.is_ok(),
            "extraction should succeed with missing optional: {:?}",
            result
        );
        let req = result.unwrap();
        assert_eq!(req.database_url, "postgres://localhost/db");
        assert_eq!(req.timeout, None);
    }

    #[test]
    fn test_extract_null_optional_is_none() {
        let config = cv_object([(
            "config",
            cv_object([
                ("database_url", cv_string("postgres://localhost/db")),
                ("timeout", cv_null()),
            ]),
        )]);

        #[derive(Facet)]
        struct TestConfig {
            database_url: String,
            timeout: Option<u32>,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<RequirementsWithOptional, _> = extract_requirements(&config, &schema);

        let req = result.expect("an explicit null in an optional field should yield None");
        assert_eq!(req.database_url, "postgres://localhost/db");
        assert_eq!(req.timeout, None);
    }

    #[test]
    fn test_extract_nested_paths() {
        let config = cv_object([(
            "config",
            cv_object([(
                "server",
                cv_object([("host", cv_string("localhost")), ("port", cv_int(3000))]),
            )]),
        )]);

        #[derive(Facet)]
        struct ServerConfig {
            host: String,
            port: u16,
        }

        #[derive(Facet)]
        struct TestConfig {
            server: ServerConfig,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<NestedRequirements, _> = extract_requirements(&config, &schema);

        assert!(
            result.is_ok(),
            "extraction with nested paths should succeed: {:?}",
            result
        );
        let req = result.unwrap();
        assert_eq!(req.host, "localhost");
        assert_eq!(req.port, 3000);
    }

    #[test]
    fn test_extract_multiple_missing() {
        let config = cv_object([("config", cv_object([]))]);

        #[derive(Facet)]
        struct TestConfig {
            database_url: Option<String>,
            port: Option<u16>,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<SimpleRequirements, _> = extract_requirements(&config, &schema);

        assert!(result.is_err(), "extraction should fail");
        let err = result.unwrap_err();
        assert_eq!(err.missing_fields.len(), 2);

        let field_names: Vec<_> = err
            .missing_fields
            .iter()
            .map(|f| f.field_name.as_str())
            .collect();
        assert!(field_names.contains(&"database_url"));
        assert!(field_names.contains(&"port"));
    }

    #[test]
    fn test_extract_non_struct_requirements() {
        let config = cv_object([]);

        #[derive(Facet)]
        struct TestArgs {}

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<u32, _> = extract_requirements(&config, &schema);

        let err = result.expect_err("a non-struct requirements type must be rejected");
        assert_eq!(err.missing_fields.len(), 1);
        assert_eq!(err.missing_fields[0].field_name, "<root>");
    }

    #[test]
    fn test_get_value_by_path_stops_at_non_object() {
        let config = cv_object([("config", cv_string("not an object"))]);

        assert!(get_value_by_path(&config, &["config", "port"]).is_none());
        assert!(get_value_by_path(&config, &["missing"]).is_none());
        assert!(matches!(
            get_value_by_path(&config, &[]),
            Some(ConfigValue::Object(_))
        ));
    }

    #[test]
    fn test_cli_hint_format() {
        let hint = compute_cli_hint("config.database_url");
        assert_eq!(hint, Some("--config.database-url".to_string()));
    }

    #[test]
    fn test_env_hint_format_with_prefix() {
        let hint = compute_env_hint("config.database_url", Some("MYAPP"));
        assert_eq!(hint, Some("$MYAPP__CONFIG__DATABASE_URL".to_string()));
    }

    #[test]
    fn test_env_hint_format_without_prefix() {
        let hint = compute_env_hint("config.database_url", None);
        assert_eq!(hint, Some("$CONFIG__DATABASE_URL".to_string()));
    }

    #[test]
    fn test_display_mentions_missing_field_and_hints() {
        let err = ExtractError {
            missing_fields: vec![ExtractMissingField {
                field_name: "database_url".to_string(),
                origin_path: "config.database_url".to_string(),
                type_name: "String".to_string(),
                cli_hint: compute_cli_hint("config.database_url"),
                env_hint: compute_env_hint("config.database_url", Some("MYAPP")),
            }],
        };

        let rendered = err.to_string();
        assert!(rendered.starts_with("Missing required fields for this operation:"));
        assert!(rendered.contains("database_url"));
        assert!(rendered.contains("at config.database_url"));
        assert!(rendered.contains("--config.database-url"));
        assert!(rendered.contains("$MYAPP__CONFIG__DATABASE_URL"));
    }

    #[test]
    fn test_missing_origin_attribute_error() {
        #[derive(Facet, Debug)]
        struct BadRequirements {
            // Missing args::origin attribute
            database_url: String,
        }

        let config = cv_object([]);

        #[derive(Facet)]
        struct TestArgs {}

        let schema = Schema::from_shape(TestArgs::SHAPE).unwrap();
        let result: Result<BadRequirements, _> = extract_requirements(&config, &schema);

        assert!(result.is_err(), "should fail for missing origin attribute");
        let err = result.unwrap_err();
        assert!(
            err.missing_fields[0]
                .origin_path
                .contains("missing args::origin")
        );
    }
}