1use crate::config_value::{ConfigValue, ObjectMap, Sourced};
7use crate::schema::Schema;
8use facet::{Def, Facet, Field, Type, UserType};
9use heck::{ToKebabCase, ToShoutySnakeCase};
10use indexmap::IndexMap;
11use owo_colors::OwoColorize;
12use owo_colors::Stream::Stdout;
13
/// One requirement that could not be satisfied, plus hints telling the user
/// how to provide it.
#[derive(Clone, Debug)]
pub struct ExtractMissingField {
    /// Name of the field in the requirements struct.
    pub field_name: String,
    /// Dot-separated path (from `args::origin`) where the value was looked up.
    /// For structural errors this instead carries a placeholder such as
    /// `<root>` or diagnostic text.
    pub origin_path: String,
    /// Type identifier of the field's shape.
    pub type_name: String,
    /// Suggested command-line flag, e.g. `--config.database-url`.
    pub cli_hint: Option<String>,
    /// Suggested environment variable, e.g. `$MYAPP__CONFIG__DATABASE_URL`.
    pub env_hint: Option<String>,
}
28
29#[derive(Debug)]
31pub struct ExtractError {
32 pub missing_fields: Vec<ExtractMissingField>,
34}
35
36impl std::fmt::Display for ExtractError {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 writeln!(f, "Missing required fields for this operation:\n")?;
39 for field in &self.missing_fields {
40 write!(
41 f,
42 " {} <{}> at {}",
43 field
44 .field_name
45 .if_supports_color(Stdout, |text| text.bold()),
46 field
47 .type_name
48 .if_supports_color(Stdout, |text| text.cyan()),
49 field.origin_path
50 )?;
51
52 let mut hints = Vec::new();
53 if let Some(cli) = &field.cli_hint {
54 hints.push(
55 cli.if_supports_color(Stdout, |text| text.green())
56 .to_string(),
57 );
58 }
59 if let Some(env) = &field.env_hint {
60 hints.push(
61 env.if_supports_color(Stdout, |text| text.yellow())
62 .to_string(),
63 );
64 }
65 if !hints.is_empty() {
66 write!(f, "\n Set via: {}", hints.join(" or "))?;
67 }
68 writeln!(f)?;
69 }
70 Ok(())
71 }
72}
73
74impl std::error::Error for ExtractError {}
75
76pub fn extract_requirements<R: Facet<'static>>(
85 config_value: &ConfigValue,
86 schema: &Schema,
87) -> Result<R, ExtractError> {
88 let shape = R::SHAPE;
89
90 let struct_type = match &shape.ty {
92 Type::User(UserType::Struct(s)) => *s,
93 _ => {
94 return Err(ExtractError {
95 missing_fields: vec![ExtractMissingField {
96 field_name: "<root>".to_string(),
97 origin_path: "<root>".to_string(),
98 type_name: shape.type_identifier.to_string(),
99 cli_hint: None,
100 env_hint: None,
101 }],
102 });
103 }
104 };
105
106 let mut missing_fields = Vec::new();
107 let mut extracted_values: ObjectMap = IndexMap::default();
108
109 let env_prefix = schema.config().and_then(|c| c.env_prefix());
111
112 for field in struct_type.fields {
113 let field_name = field.name;
114
115 let origin_path = find_origin_attribute(field);
117
118 let Some(origin_path) = origin_path else {
119 return Err(ExtractError {
121 missing_fields: vec![ExtractMissingField {
122 field_name: field_name.to_string(),
123 origin_path: "<missing args::origin attribute>".to_string(),
124 type_name: field.shape().type_identifier.to_string(),
125 cli_hint: None,
126 env_hint: None,
127 }],
128 });
129 };
130
131 let path_segments: Vec<&str> = origin_path.split('.').collect();
133
134 let value = get_value_by_path(config_value, &path_segments);
136
137 let is_optional = matches!(field.shape().def, Def::Option(_));
139
140 match value {
141 Some(v) if !is_null_value(v) => {
142 extracted_values.insert(field_name.to_string(), v.clone());
144 }
145 _ => {
146 if is_optional {
148 extracted_values
150 .insert(field_name.to_string(), ConfigValue::Null(Sourced::new(())));
151 } else {
152 let cli_hint = compute_cli_hint(origin_path);
154 let env_hint = compute_env_hint(origin_path, env_prefix);
155
156 missing_fields.push(ExtractMissingField {
157 field_name: field_name.to_string(),
158 origin_path: origin_path.to_string(),
159 type_name: field.shape().type_identifier.to_string(),
160 cli_hint,
161 env_hint,
162 });
163 }
164 }
165 }
166 }
167
168 if !missing_fields.is_empty() {
169 return Err(ExtractError { missing_fields });
170 }
171
172 let extracted_config = ConfigValue::Object(Sourced::new(extracted_values));
174
175 crate::config_value_parser::from_config_value(&extracted_config).map_err(|e| ExtractError {
176 missing_fields: vec![ExtractMissingField {
177 field_name: "<deserialization>".to_string(),
178 origin_path: e.to_string(),
179 type_name: shape.type_identifier.to_string(),
180 cli_hint: None,
181 env_hint: None,
182 }],
183 })
184}
185
186fn find_origin_attribute(field: &Field) -> Option<&'static str> {
188 for field_attr in field.attributes {
190 if field_attr.ns == Some("args")
191 && field_attr.key == "origin"
192 && let Some(s) = field_attr.get_as::<&str>()
193 {
194 return Some(s);
195 }
196 }
197 None
198}
199
200fn get_value_by_path<'a>(value: &'a ConfigValue, path: &[&str]) -> Option<&'a ConfigValue> {
202 let mut current = value;
203 for segment in path {
204 match current {
205 ConfigValue::Object(obj) => {
206 current = obj.value.get(*segment)?;
207 }
208 _ => return None,
209 }
210 }
211 Some(current)
212}
213
214fn is_null_value(value: &ConfigValue) -> bool {
216 matches!(value, ConfigValue::Null(_))
217}
218
219fn compute_cli_hint(origin_path: &str) -> Option<String> {
221 let kebab_path = origin_path
222 .split('.')
223 .map(|s| s.to_kebab_case())
224 .collect::<Vec<_>>()
225 .join(".");
226 Some(format!("--{}", kebab_path))
227}
228
229fn compute_env_hint(origin_path: &str, env_prefix: Option<&str>) -> Option<String> {
231 let shouty_path = origin_path
232 .split('.')
233 .map(|s| s.to_shouty_snake_case())
234 .collect::<Vec<_>>()
235 .join("__");
236
237 if let Some(prefix) = env_prefix {
238 Some(format!("${}__{}", prefix, shouty_path))
239 } else {
240 Some(format!("${}", shouty_path))
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config_value::Sourced;
    use facet::Facet;
    use figue_attrs as args;

    /// Builds an object value from `(key, value)` pairs, keeping their order.
    fn make_object(entries: impl IntoIterator<Item = (&'static str, ConfigValue)>) -> ConfigValue {
        let mut map = ObjectMap::default();
        for (key, value) in entries {
            map.insert(key.to_owned(), value);
        }
        ConfigValue::Object(Sourced::new(map))
    }

    /// Builds a string value.
    fn make_string(s: &str) -> ConfigValue {
        ConfigValue::String(Sourced::new(String::from(s)))
    }

    /// Builds an integer value.
    fn make_int(n: i64) -> ConfigValue {
        ConfigValue::Integer(Sourced::new(n))
    }

    /// Builds the schema for an args type, panicking if it is invalid.
    fn schema_for<T: Facet<'static>>() -> Schema {
        Schema::from_shape(T::SHAPE).unwrap()
    }

    /// Two required fields read from the top level of `config`.
    #[derive(Facet, Debug, PartialEq)]
    struct SimpleRequirements {
        #[facet(args::origin = "config.database_url")]
        database_url: String,

        #[facet(args::origin = "config.port")]
        port: u16,
    }

    /// One required field and one optional field.
    #[derive(Facet, Debug, PartialEq)]
    struct RequirementsWithOptional {
        #[facet(args::origin = "config.database_url")]
        database_url: String,

        #[facet(args::origin = "config.timeout")]
        timeout: Option<u32>,
    }

    /// Required fields read from a nested `config.server` object.
    #[derive(Facet, Debug, PartialEq)]
    struct NestedRequirements {
        #[facet(args::origin = "config.server.host")]
        host: String,

        #[facet(args::origin = "config.server.port")]
        port: u16,
    }

    #[test]
    fn test_extract_all_present() {
        #[derive(Facet)]
        struct TestConfig {
            database_url: String,
            port: u16,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let config = make_object([(
            "config",
            make_object([
                ("database_url", make_string("postgres://localhost/db")),
                ("port", make_int(8080)),
            ]),
        )]);
        let schema = schema_for::<TestArgs>();

        let req = extract_requirements::<SimpleRequirements>(&config, &schema)
            .unwrap_or_else(|err| panic!("extraction should succeed: {err:?}"));

        assert_eq!(req.database_url, "postgres://localhost/db");
        assert_eq!(req.port, 8080);
    }

    #[test]
    fn test_extract_missing_required() {
        #[derive(Facet)]
        struct TestConfig {
            database_url: Option<String>,
            port: u16,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let config = make_object([("config", make_object([("port", make_int(8080))]))]);
        let schema = schema_for::<TestArgs>();

        let err = extract_requirements::<SimpleRequirements>(&config, &schema)
            .expect_err("extraction should fail");

        let [missing] = err.missing_fields.as_slice() else {
            panic!("expected exactly one missing field, got {:?}", err.missing_fields);
        };
        assert_eq!(missing.field_name, "database_url");
        assert_eq!(missing.origin_path, "config.database_url");
    }

    #[test]
    fn test_extract_optional_missing() {
        #[derive(Facet)]
        struct TestConfig {
            database_url: String,
            timeout: Option<u32>,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let config = make_object([(
            "config",
            make_object([("database_url", make_string("postgres://localhost/db"))]),
        )]);
        let schema = schema_for::<TestArgs>();

        let req = extract_requirements::<RequirementsWithOptional>(&config, &schema)
            .unwrap_or_else(|err| {
                panic!("extraction should succeed with missing optional: {err:?}")
            });

        assert_eq!(req.database_url, "postgres://localhost/db");
        assert!(req.timeout.is_none());
    }

    #[test]
    fn test_extract_nested_paths() {
        #[derive(Facet)]
        struct ServerConfig {
            host: String,
            port: u16,
        }

        #[derive(Facet)]
        struct TestConfig {
            server: ServerConfig,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let server = make_object([("host", make_string("localhost")), ("port", make_int(3000))]);
        let config = make_object([("config", make_object([("server", server)]))]);
        let schema = schema_for::<TestArgs>();

        let req = extract_requirements::<NestedRequirements>(&config, &schema)
            .unwrap_or_else(|err| panic!("extraction with nested paths should succeed: {err:?}"));

        assert_eq!(req.host, "localhost");
        assert_eq!(req.port, 3000);
    }

    #[test]
    fn test_extract_multiple_missing() {
        #[derive(Facet)]
        struct TestConfig {
            database_url: Option<String>,
            port: Option<u16>,
        }

        #[derive(Facet)]
        struct TestArgs {
            #[facet(args::config)]
            config: TestConfig,
        }

        let config = make_object([("config", make_object([]))]);
        let schema = schema_for::<TestArgs>();

        let err = extract_requirements::<SimpleRequirements>(&config, &schema)
            .expect_err("extraction should fail");
        assert_eq!(err.missing_fields.len(), 2);

        let names: Vec<&str> = err
            .missing_fields
            .iter()
            .map(|entry| entry.field_name.as_str())
            .collect();
        for expected in ["database_url", "port"] {
            assert!(names.contains(&expected), "no missing entry for {expected}");
        }
    }

    #[test]
    fn test_cli_hint_format() {
        assert_eq!(
            compute_cli_hint("config.database_url"),
            Some("--config.database-url".to_string())
        );
    }

    #[test]
    fn test_env_hint_format_with_prefix() {
        assert_eq!(
            compute_env_hint("config.database_url", Some("MYAPP")),
            Some("$MYAPP__CONFIG__DATABASE_URL".to_string())
        );
    }

    #[test]
    fn test_env_hint_format_without_prefix() {
        assert_eq!(
            compute_env_hint("config.database_url", None),
            Some("$CONFIG__DATABASE_URL".to_string())
        );
    }

    #[test]
    fn test_missing_origin_attribute_error() {
        #[derive(Facet, Debug)]
        struct BadRequirements {
            database_url: String,
        }

        #[derive(Facet)]
        struct TestArgs {}

        let config = make_object([]);
        let schema = schema_for::<TestArgs>();

        let err = extract_requirements::<BadRequirements>(&config, &schema)
            .expect_err("should fail for missing origin attribute");

        assert!(
            err.missing_fields[0]
                .origin_path
                .contains("missing args::origin")
        );
    }
}