Skip to main content

standout_input/
chain.rs

1//! Input chain builder for composing multiple sources.
2//!
3//! The [`InputChain`] allows chaining multiple input sources with fallback
4//! behavior. Sources are tried in order until one provides input.
5
6use std::fmt;
7
8use clap::ArgMatches;
9
10use crate::collector::{InputCollector, InputSourceKind, ResolvedInput};
11use crate::InputError;
12
/// Boxed, thread-safe check that either accepts a value or returns a message explaining why not.
type ValidatorFn<T> = Box<dyn Fn(&T) -> Result<(), String> + Sync + Send>;
15
16/// Chain multiple input sources with fallback behavior.
17///
18/// Sources are tried in the order they were added. The first source that
19/// returns `Some(value)` wins. If all sources return `None`, the chain
20/// uses the default value or returns [`InputError::NoInput`].
21///
22/// # Example
23///
24/// ```ignore
25/// use standout_input::{InputChain, ArgSource, StdinSource, DefaultSource};
26///
27/// // Try argument first, then stdin, then use default
28/// let chain = InputChain::<String>::new()
29///     .try_source(ArgSource::new("message"))
30///     .try_source(StdinSource::new())
31///     .try_source(DefaultSource::new("default message".to_string()));
32///
33/// let value = chain.resolve(&matches)?;
34/// ```
35///
36/// # Validation
37///
38/// Add validators to check the resolved value:
39///
40/// ```ignore
41/// let chain = InputChain::<String>::new()
42///     .try_source(ArgSource::new("email"))
43///     .validate(|s| s.contains('@'), "Must be a valid email");
44/// ```
45///
46/// Interactive sources (prompts, editor) can retry on validation failure.
47pub struct InputChain<T> {
48    sources: Vec<(Box<dyn InputCollector<T>>, InputSourceKind)>,
49    validators: Vec<(ValidatorFn<T>, String)>,
50    default: Option<T>,
51}
52
53impl<T: Clone + Send + Sync + 'static> InputChain<T> {
54    /// Create a new empty input chain.
55    pub fn new() -> Self {
56        Self {
57            sources: Vec::new(),
58            validators: Vec::new(),
59            default: None,
60        }
61    }
62
63    /// Add a source to the chain.
64    ///
65    /// Sources are tried in the order they are added.
66    pub fn try_source<C: InputCollector<T> + 'static>(mut self, source: C) -> Self {
67        let kind = source_kind_from_name(source.name());
68        self.sources.push((Box::new(source), kind));
69        self
70    }
71
72    /// Add a source with an explicit kind.
73    ///
74    /// Use this when the source name doesn't map to a standard kind.
75    pub fn try_source_with_kind<C: InputCollector<T> + 'static>(
76        mut self,
77        source: C,
78        kind: InputSourceKind,
79    ) -> Self {
80        self.sources.push((Box::new(source), kind));
81        self
82    }
83
84    /// Add a validation rule.
85    ///
86    /// The validator is called after a source successfully provides input.
87    /// If validation fails:
88    /// - Interactive sources (where `can_retry()` is true) will re-prompt
89    /// - Non-interactive sources will return a validation error
90    ///
91    /// Multiple validators are checked in order; all must pass.
92    pub fn validate<F>(mut self, f: F, error_msg: impl Into<String>) -> Self
93    where
94        F: Fn(&T) -> bool + Send + Sync + 'static,
95    {
96        let msg = error_msg.into();
97        let msg_for_closure = msg.clone();
98        self.validators.push((
99            Box::new(move |value| {
100                if f(value) {
101                    Ok(())
102                } else {
103                    Err(msg_for_closure.clone())
104                }
105            }),
106            msg,
107        ));
108        self
109    }
110
111    /// Add a validation rule that returns a Result.
112    ///
113    /// Unlike [`validate`](Self::validate), this allows custom error messages
114    /// per validation failure.
115    pub fn validate_with<F>(mut self, f: F) -> Self
116    where
117        F: Fn(&T) -> Result<(), String> + Send + Sync + 'static,
118    {
119        self.validators
120            .push((Box::new(f), "validation failed".to_string()));
121        self
122    }
123
124    /// Set a default value to use when no source provides input.
125    ///
126    /// This is equivalent to adding a [`DefaultSource`](crate::DefaultSource)
127    /// at the end of the chain.
128    pub fn default(mut self, value: T) -> Self {
129        self.default = Some(value);
130        self
131    }
132
133    /// Resolve the chain and return the input value.
134    ///
135    /// Tries each source in order until one provides input, then runs
136    /// validation. Returns the value or an error.
137    pub fn resolve(&self, matches: &ArgMatches) -> Result<T, InputError> {
138        self.resolve_with_source(matches).map(|r| r.value)
139    }
140
141    /// Resolve the chain and return the input with source metadata.
142    ///
143    /// Like [`resolve`](Self::resolve), but also returns which source
144    /// provided the value.
145    pub fn resolve_with_source(
146        &self,
147        matches: &ArgMatches,
148    ) -> Result<ResolvedInput<T>, InputError> {
149        for (source, kind) in &self.sources {
150            if !source.is_available(matches) {
151                continue;
152            }
153
154            // This loop is intentional: interactive sources (where can_retry() is true)
155            // will re-prompt on validation failure. The `break` on None moves to the
156            // next source in the chain.
157            #[allow(clippy::while_let_loop)]
158            loop {
159                match source.collect(matches)? {
160                    Some(value) => {
161                        // Run source-level validation
162                        if let Err(msg) = source.validate(&value) {
163                            if source.can_retry() {
164                                eprintln!("Invalid: {}", msg);
165                                continue;
166                            }
167                            return Err(InputError::ValidationFailed(msg));
168                        }
169
170                        // Run chain-level validators
171                        for (validator, _) in &self.validators {
172                            if let Err(msg) = validator(&value) {
173                                if source.can_retry() {
174                                    eprintln!("Invalid: {}", msg);
175                                    continue;
176                                }
177                                return Err(InputError::ValidationFailed(msg));
178                            }
179                        }
180
181                        return Ok(ResolvedInput {
182                            value,
183                            source: *kind,
184                        });
185                    }
186                    None => break, // Try next source
187                }
188            }
189        }
190
191        // No source provided input; try default
192        if let Some(value) = &self.default {
193            return Ok(ResolvedInput {
194                value: value.clone(),
195                source: InputSourceKind::Default,
196            });
197        }
198
199        Err(InputError::NoInput)
200    }
201
202    /// Check if any source is available to provide input.
203    pub fn has_available_source(&self, matches: &ArgMatches) -> bool {
204        self.sources.iter().any(|(s, _)| s.is_available(matches)) || self.default.is_some()
205    }
206
207    /// Get the number of sources in the chain.
208    pub fn source_count(&self) -> usize {
209        self.sources.len()
210    }
211}
212
213impl<T: Clone + Send + Sync + 'static> Default for InputChain<T> {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
219impl<T> fmt::Debug for InputChain<T> {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        f.debug_struct("InputChain")
222            .field(
223                "sources",
224                &self.sources.iter().map(|(_, k)| k).collect::<Vec<_>>(),
225            )
226            .field("validators", &self.validators.len())
227            .field("has_default", &self.default.is_some())
228            .finish()
229    }
230}
231
232/// Map source name to InputSourceKind.
233fn source_kind_from_name(name: &str) -> InputSourceKind {
234    match name {
235        "argument" => InputSourceKind::Arg,
236        "flag" => InputSourceKind::Flag,
237        "stdin" => InputSourceKind::Stdin,
238        "environment variable" => InputSourceKind::Env,
239        "clipboard" => InputSourceKind::Clipboard,
240        "editor" => InputSourceKind::Editor,
241        "prompt" => InputSourceKind::Prompt,
242        "default" => InputSourceKind::Default,
243        _ => InputSourceKind::Default,
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::{MockClipboard, MockEnv, MockStdin};
    use crate::sources::{ArgSource, ClipboardSource, DefaultSource, EnvSource, StdinSource};
    use clap::{Arg, Command};

    /// Parse `argv` with a throwaway command that only knows `--message` / `-m`.
    fn parse_args(argv: &[&str]) -> ArgMatches {
        let cmd = Command::new("test").arg(Arg::new("message").long("message").short('m'));
        cmd.try_get_matches_from(argv).unwrap()
    }

    #[test]
    fn chain_resolves_first_available() {
        let parsed = parse_args(&["test", "--message", "from arg"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(DefaultSource::new("default".to_string()));

        let resolved = chain.resolve_with_source(&parsed).unwrap();
        assert_eq!(resolved.value, "from arg");
        assert_eq!(resolved.source, InputSourceKind::Arg);
    }

    #[test]
    fn chain_falls_back_to_next_source() {
        // `--message` is absent, so the piped stdin has to supply the value.
        let parsed = parse_args(&["test"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::piped("from stdin")));

        let resolved = chain.resolve_with_source(&parsed).unwrap();
        assert_eq!(resolved.value, "from stdin");
        assert_eq!(resolved.source, InputSourceKind::Stdin);
    }

    #[test]
    fn chain_falls_back_to_default() {
        let parsed = parse_args(&["test"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::terminal()))
            .default("default value".to_string());

        let resolved = chain.resolve_with_source(&parsed).unwrap();
        assert_eq!(resolved.value, "default value");
        assert_eq!(resolved.source, InputSourceKind::Default);
    }

    #[test]
    fn chain_error_when_no_input() {
        let parsed = parse_args(&["test"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::terminal()));

        let outcome = chain.resolve(&parsed);
        assert!(matches!(outcome, Err(InputError::NoInput)));
    }

    #[test]
    fn chain_validation_passes() {
        let parsed = parse_args(&["test", "--message", "valid@email.com"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .validate(|s| s.contains('@'), "Must contain @");

        let value = chain.resolve(&parsed).unwrap();
        assert_eq!(value, "valid@email.com");
    }

    #[test]
    fn chain_validation_fails() {
        let parsed = parse_args(&["test", "--message", "invalid"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .validate(|s| s.contains('@'), "Must contain @");

        let outcome = chain.resolve(&parsed);
        assert!(matches!(outcome, Err(InputError::ValidationFailed(_))));
    }

    #[test]
    fn chain_multiple_validators() {
        // "ab" passes the first validator but fails the length check.
        let parsed = parse_args(&["test", "--message", "ab"]);

        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .validate(|s| !s.is_empty(), "Cannot be empty")
            .validate(|s| s.len() >= 3, "Must be at least 3 characters");

        let outcome = chain.resolve(&parsed);
        assert!(matches!(outcome, Err(InputError::ValidationFailed(_))));
    }

    #[test]
    fn chain_complex_fallback() {
        let parsed = parse_args(&["test"]);

        // arg → stdin → env → clipboard; only the clipboard has content.
        let chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::terminal()))
            .try_source(EnvSource::with_reader("MY_MSG", MockEnv::new()))
            .try_source(ClipboardSource::with_reader(MockClipboard::with_content(
                "from clipboard",
            )));

        let resolved = chain.resolve_with_source(&parsed).unwrap();
        assert_eq!(resolved.value, "from clipboard");
        assert_eq!(resolved.source, InputSourceKind::Clipboard);
    }

    #[test]
    fn chain_has_available_source() {
        let parsed = parse_args(&["test"]);

        let with_fallback = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .default("default".to_string());
        assert!(with_fallback.has_available_source(&parsed));

        let without_fallback = InputChain::<String>::new().try_source(ArgSource::new("message"));
        assert!(!without_fallback.has_available_source(&parsed));
    }

    #[test]
    fn chain_source_count() {
        let chain = ["a", "b", "c"]
            .iter()
            .fold(InputChain::<String>::new(), |acc, &name| {
                acc.try_source(ArgSource::new(name))
            });

        assert_eq!(chain.source_count(), 3);
    }
}