1use std::fmt;
7
8use clap::ArgMatches;
9
10use crate::collector::{InputCollector, InputSourceKind, ResolvedInput};
11use crate::InputError;
12
/// Boxed chain-level validator.
///
/// Receives a reference to a candidate value and returns `Ok(())` to accept it
/// or `Err(message)` to reject it with a human-readable reason.
type ValidatorFn<T> = Box<dyn Fn(&T) -> Result<(), String> + Send + Sync>;
15
/// A builder-style chain of input sources that are tried in order until one
/// yields a value, with optional validators and an optional default value.
pub struct InputChain<T> {
    /// Sources in the order they are tried, each paired with the
    /// `InputSourceKind` reported when that source supplies the value.
    sources: Vec<(Box<dyn InputCollector<T>>, InputSourceKind)>,
    /// Chain-level validators, each stored together with a message string.
    validators: Vec<(ValidatorFn<T>, String)>,
    /// Value returned when no source yields input.
    default: Option<T>,
}
52
53impl<T: Clone + Send + Sync + 'static> InputChain<T> {
54 pub fn new() -> Self {
56 Self {
57 sources: Vec::new(),
58 validators: Vec::new(),
59 default: None,
60 }
61 }
62
63 pub fn try_source<C: InputCollector<T> + 'static>(mut self, source: C) -> Self {
67 let kind = source_kind_from_name(source.name());
68 self.sources.push((Box::new(source), kind));
69 self
70 }
71
72 pub fn try_source_with_kind<C: InputCollector<T> + 'static>(
76 mut self,
77 source: C,
78 kind: InputSourceKind,
79 ) -> Self {
80 self.sources.push((Box::new(source), kind));
81 self
82 }
83
84 pub fn validate<F>(mut self, f: F, error_msg: impl Into<String>) -> Self
93 where
94 F: Fn(&T) -> bool + Send + Sync + 'static,
95 {
96 let msg = error_msg.into();
97 let msg_for_closure = msg.clone();
98 self.validators.push((
99 Box::new(move |value| {
100 if f(value) {
101 Ok(())
102 } else {
103 Err(msg_for_closure.clone())
104 }
105 }),
106 msg,
107 ));
108 self
109 }
110
111 pub fn validate_with<F>(mut self, f: F) -> Self
116 where
117 F: Fn(&T) -> Result<(), String> + Send + Sync + 'static,
118 {
119 self.validators
120 .push((Box::new(f), "validation failed".to_string()));
121 self
122 }
123
124 pub fn default(mut self, value: T) -> Self {
129 self.default = Some(value);
130 self
131 }
132
133 pub fn resolve(&self, matches: &ArgMatches) -> Result<T, InputError> {
138 self.resolve_with_source(matches).map(|r| r.value)
139 }
140
141 pub fn resolve_with_source(
146 &self,
147 matches: &ArgMatches,
148 ) -> Result<ResolvedInput<T>, InputError> {
149 for (source, kind) in &self.sources {
150 if !source.is_available(matches) {
151 continue;
152 }
153
154 #[allow(clippy::while_let_loop)]
158 loop {
159 match source.collect(matches)? {
160 Some(value) => {
161 if let Err(msg) = source.validate(&value) {
163 if source.can_retry() {
164 eprintln!("Invalid: {}", msg);
165 continue;
166 }
167 return Err(InputError::ValidationFailed(msg));
168 }
169
170 for (validator, _) in &self.validators {
172 if let Err(msg) = validator(&value) {
173 if source.can_retry() {
174 eprintln!("Invalid: {}", msg);
175 continue;
176 }
177 return Err(InputError::ValidationFailed(msg));
178 }
179 }
180
181 return Ok(ResolvedInput {
182 value,
183 source: *kind,
184 });
185 }
186 None => break, }
188 }
189 }
190
191 if let Some(value) = &self.default {
193 return Ok(ResolvedInput {
194 value: value.clone(),
195 source: InputSourceKind::Default,
196 });
197 }
198
199 Err(InputError::NoInput)
200 }
201
202 pub fn has_available_source(&self, matches: &ArgMatches) -> bool {
204 self.sources.iter().any(|(s, _)| s.is_available(matches)) || self.default.is_some()
205 }
206
207 pub fn source_count(&self) -> usize {
209 self.sources.len()
210 }
211}
212
213impl<T: Clone + Send + Sync + 'static> Default for InputChain<T> {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
219impl<T> fmt::Debug for InputChain<T> {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 f.debug_struct("InputChain")
222 .field(
223 "sources",
224 &self.sources.iter().map(|(_, k)| k).collect::<Vec<_>>(),
225 )
226 .field("validators", &self.validators.len())
227 .field("has_default", &self.default.is_some())
228 .finish()
229 }
230}
231
232fn source_kind_from_name(name: &str) -> InputSourceKind {
234 match name {
235 "argument" => InputSourceKind::Arg,
236 "flag" => InputSourceKind::Flag,
237 "stdin" => InputSourceKind::Stdin,
238 "environment variable" => InputSourceKind::Env,
239 "clipboard" => InputSourceKind::Clipboard,
240 "editor" => InputSourceKind::Editor,
241 "prompt" => InputSourceKind::Prompt,
242 "default" => InputSourceKind::Default,
243 _ => InputSourceKind::Default,
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::{MockClipboard, MockEnv, MockStdin};
    use crate::sources::{ArgSource, ClipboardSource, DefaultSource, EnvSource, StdinSource};
    use clap::{Arg, Command};

    /// Parses `args` with a test command exposing a single `--message` / `-m`
    /// option. Panics if parsing fails.
    fn make_matches(args: &[&str]) -> ArgMatches {
        let command = Command::new("test").arg(Arg::new("message").long("message").short('m'));
        command.try_get_matches_from(args).unwrap()
    }

    /// The first source that has a value wins over later ones.
    #[test]
    fn chain_resolves_first_available() {
        let cli = make_matches(&["test", "--message", "from arg"]);
        let arg_then_default = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(DefaultSource::new("default".to_string()));

        let resolved = arg_then_default.resolve_with_source(&cli).unwrap();

        assert_eq!(resolved.value, "from arg");
        assert_eq!(resolved.source, InputSourceKind::Arg);
    }

    /// With no argument given, the piped stdin source supplies the value.
    #[test]
    fn chain_falls_back_to_next_source() {
        let cli = make_matches(&["test"]);
        let arg_then_stdin = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::piped("from stdin")));

        let resolved = arg_then_stdin.resolve_with_source(&cli).unwrap();

        assert_eq!(resolved.value, "from stdin");
        assert_eq!(resolved.source, InputSourceKind::Stdin);
    }

    /// With no source yielding input, the chain default is returned.
    #[test]
    fn chain_falls_back_to_default() {
        let cli = make_matches(&["test"]);
        let with_default = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::terminal()))
            .default("default value".to_string());

        let resolved = with_default.resolve_with_source(&cli).unwrap();

        assert_eq!(resolved.value, "default value");
        assert_eq!(resolved.source, InputSourceKind::Default);
    }

    /// With no source yielding input and no default, resolution fails.
    #[test]
    fn chain_error_when_no_input() {
        let cli = make_matches(&["test"]);
        let no_default = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::terminal()));

        let outcome = no_default.resolve(&cli);

        assert!(matches!(outcome, Err(InputError::NoInput)));
    }

    /// A value accepted by the validator is returned unchanged.
    #[test]
    fn chain_validation_passes() {
        let cli = make_matches(&["test", "--message", "valid@email.com"]);
        let email_chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .validate(|s| s.contains('@'), "Must contain @");

        let value = email_chain.resolve(&cli).unwrap();

        assert_eq!(value, "valid@email.com");
    }

    /// A value rejected by the validator yields `ValidationFailed`.
    #[test]
    fn chain_validation_fails() {
        let cli = make_matches(&["test", "--message", "invalid"]);
        let email_chain = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .validate(|s| s.contains('@'), "Must contain @");

        let outcome = email_chain.resolve(&cli);

        assert!(matches!(outcome, Err(InputError::ValidationFailed(_))));
    }

    /// A value passing the first validator but not the second is rejected.
    #[test]
    fn chain_multiple_validators() {
        let cli = make_matches(&["test", "--message", "ab"]);
        let two_checks = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .validate(|s| !s.is_empty(), "Cannot be empty")
            .validate(|s| s.len() >= 3, "Must be at least 3 characters");

        let outcome = two_checks.resolve(&cli);

        assert!(matches!(outcome, Err(InputError::ValidationFailed(_))));
    }

    /// Arg, stdin and env yield nothing, so the clipboard supplies the value.
    #[test]
    fn chain_complex_fallback() {
        let cli = make_matches(&["test"]);
        let clipboard = MockClipboard::with_content("from clipboard");
        let four_sources = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .try_source(StdinSource::with_reader(MockStdin::terminal()))
            .try_source(EnvSource::with_reader("MY_MSG", MockEnv::new()))
            .try_source(ClipboardSource::with_reader(clipboard));

        let resolved = four_sources.resolve_with_source(&cli).unwrap();

        assert_eq!(resolved.value, "from clipboard");
        assert_eq!(resolved.source, InputSourceKind::Clipboard);
    }

    /// A default counts as an available source; an absent argument does not.
    #[test]
    fn chain_has_available_source() {
        let cli = make_matches(&["test"]);

        let chain_with_default = InputChain::<String>::new()
            .try_source(ArgSource::new("message"))
            .default("default".to_string());
        assert!(chain_with_default.has_available_source(&cli));

        let chain_without = InputChain::<String>::new().try_source(ArgSource::new("message"));
        assert!(!chain_without.has_available_source(&cli));
    }

    /// `source_count` reports how many sources were added.
    #[test]
    fn chain_source_count() {
        let three_args = InputChain::<String>::new()
            .try_source(ArgSource::new("a"))
            .try_source(ArgSource::new("b"))
            .try_source(ArgSource::new("c"));

        assert_eq!(three_args.source_count(), 3);
    }
}