1use clap::ArgMatches;
33use serde::Serialize;
34use std::any::{Any, TypeId};
35use std::collections::HashMap;
36use std::fmt;
37
/// A type-keyed map of values attached to a [`CommandContext`].
///
/// Each concrete type `T` occupies at most one slot, keyed by its [`TypeId`];
/// storing another `T` replaces the previous one. Stored values must be
/// `Send + Sync + 'static`.
#[derive(Default)]
pub struct Extensions {
    /// Type-erased values, keyed by the `TypeId` of their concrete type.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
74
75impl Extensions {
76 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
85 self.map
86 .insert(TypeId::of::<T>(), Box::new(val))
87 .and_then(|boxed| boxed.downcast().ok().map(|b| *b))
88 }
89
90 pub fn get<T: 'static>(&self) -> Option<&T> {
94 self.map
95 .get(&TypeId::of::<T>())
96 .and_then(|boxed| boxed.downcast_ref())
97 }
98
99 pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
103 self.map
104 .get_mut(&TypeId::of::<T>())
105 .and_then(|boxed| boxed.downcast_mut())
106 }
107
108 pub fn get_required<T: 'static>(&self) -> Result<&T, anyhow::Error> {
112 self.get::<T>().ok_or_else(|| {
113 anyhow::anyhow!(
114 "Extension missing: type {} not found in context",
115 std::any::type_name::<T>()
116 )
117 })
118 }
119
120 pub fn get_mut_required<T: 'static>(&mut self) -> Result<&mut T, anyhow::Error> {
124 self.get_mut::<T>().ok_or_else(|| {
125 anyhow::anyhow!(
126 "Extension missing: type {} not found in context",
127 std::any::type_name::<T>()
128 )
129 })
130 }
131
132 pub fn remove<T: 'static>(&mut self) -> Option<T> {
134 self.map
135 .remove(&TypeId::of::<T>())
136 .and_then(|boxed| boxed.downcast().ok().map(|b| *b))
137 }
138
139 pub fn contains<T: 'static>(&self) -> bool {
141 self.map.contains_key(&TypeId::of::<T>())
142 }
143
144 pub fn len(&self) -> usize {
146 self.map.len()
147 }
148
149 pub fn is_empty(&self) -> bool {
151 self.map.is_empty()
152 }
153
154 pub fn clear(&mut self) {
156 self.map.clear();
157 }
158}
159
160impl fmt::Debug for Extensions {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 f.debug_struct("Extensions")
163 .field("len", &self.map.len())
164 .finish_non_exhaustive()
165 }
166}
167
168impl Clone for Extensions {
169 fn clone(&self) -> Self {
170 Self::new()
174 }
175}
176
177#[derive(Debug, Default)]
208pub struct CommandContext {
209 pub command_path: Vec<String>,
211
212 pub extensions: Extensions,
216}
217
/// What a handler produced.
///
/// `T` is the payload of [`Output::Render`]. Handlers constrain it to
/// `Serialize` through their `Output` associated type, so the enum itself
/// carries no trait bound (bounds live on the impls that need them).
#[derive(Debug)]
pub enum Output<T> {
    /// Data to be rendered for the user.
    Render(T),
    /// The handler produced no output.
    Silent,
    /// Raw bytes together with the file name to associate with them.
    Binary {
        /// The raw content.
        data: Vec<u8>,
        /// File name for the content, e.g. `"report.pdf"`.
        filename: String,
    },
}

impl<T> Output<T> {
    /// Returns `true` for [`Output::Render`].
    #[must_use]
    pub fn is_render(&self) -> bool {
        matches!(self, Output::Render(_))
    }

    /// Returns `true` for [`Output::Silent`].
    #[must_use]
    pub fn is_silent(&self) -> bool {
        matches!(self, Output::Silent)
    }

    /// Returns `true` for [`Output::Binary`].
    #[must_use]
    pub fn is_binary(&self) -> bool {
        matches!(self, Output::Binary { .. })
    }
}
252
253pub type HandlerResult<T> = Result<Output<T>, anyhow::Error>;
257
258#[derive(Debug)]
263pub enum RunResult {
264 Handled(String),
266 Binary(Vec<u8>, String),
268 Silent,
270 NoMatch(ArgMatches),
272}
273
274impl RunResult {
275 pub fn is_handled(&self) -> bool {
277 matches!(self, RunResult::Handled(_))
278 }
279
280 pub fn is_binary(&self) -> bool {
282 matches!(self, RunResult::Binary(_, _))
283 }
284
285 pub fn is_silent(&self) -> bool {
287 matches!(self, RunResult::Silent)
288 }
289
290 pub fn output(&self) -> Option<&str> {
292 match self {
293 RunResult::Handled(s) => Some(s),
294 _ => None,
295 }
296 }
297
298 pub fn binary(&self) -> Option<(&[u8], &str)> {
300 match self {
301 RunResult::Binary(bytes, filename) => Some((bytes, filename)),
302 _ => None,
303 }
304 }
305
306 pub fn matches(&self) -> Option<&ArgMatches> {
308 match self {
309 RunResult::NoMatch(m) => Some(m),
310 _ => None,
311 }
312 }
313}
314
315pub trait Handler: Send + Sync {
319 type Output: Serialize;
321
322 fn handle(&self, matches: &ArgMatches, ctx: &CommandContext) -> HandlerResult<Self::Output>;
324}
325
326pub struct FnHandler<F, T>
328where
329 F: Fn(&ArgMatches, &CommandContext) -> HandlerResult<T> + Send + Sync,
330 T: Serialize + Send + Sync,
331{
332 f: F,
333 _phantom: std::marker::PhantomData<fn() -> T>,
334}
335
336impl<F, T> FnHandler<F, T>
337where
338 F: Fn(&ArgMatches, &CommandContext) -> HandlerResult<T> + Send + Sync,
339 T: Serialize + Send + Sync,
340{
341 pub fn new(f: F) -> Self {
343 Self {
344 f,
345 _phantom: std::marker::PhantomData,
346 }
347 }
348}
349
350impl<F, T> Handler for FnHandler<F, T>
351where
352 F: Fn(&ArgMatches, &CommandContext) -> HandlerResult<T> + Send + Sync,
353 T: Serialize + Send + Sync,
354{
355 type Output = T;
356
357 fn handle(&self, matches: &ArgMatches, ctx: &CommandContext) -> HandlerResult<T> {
358 (self.f)(matches, ctx)
359 }
360}
361
362pub trait LocalHandler {
369 type Output: Serialize;
371
372 fn handle(&mut self, matches: &ArgMatches, ctx: &CommandContext)
374 -> HandlerResult<Self::Output>;
375}
376
377pub struct LocalFnHandler<F, T>
379where
380 F: FnMut(&ArgMatches, &CommandContext) -> HandlerResult<T>,
381 T: Serialize,
382{
383 f: F,
384 _phantom: std::marker::PhantomData<fn() -> T>,
385}
386
387impl<F, T> LocalFnHandler<F, T>
388where
389 F: FnMut(&ArgMatches, &CommandContext) -> HandlerResult<T>,
390 T: Serialize,
391{
392 pub fn new(f: F) -> Self {
394 Self {
395 f,
396 _phantom: std::marker::PhantomData,
397 }
398 }
399}
400
401impl<F, T> LocalHandler for LocalFnHandler<F, T>
402where
403 F: FnMut(&ArgMatches, &CommandContext) -> HandlerResult<T>,
404 T: Serialize,
405{
406 type Output = T;
407
408 fn handle(&mut self, matches: &ArgMatches, ctx: &CommandContext) -> HandlerResult<T> {
409 (self.f)(matches, ctx)
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---------------------------------------------------------------------
    // CommandContext
    // ---------------------------------------------------------------------

    #[test]
    fn test_command_context_creation() {
        let ctx = CommandContext {
            command_path: vec!["config".into(), "get".into()],
            extensions: Extensions::new(),
        };
        assert_eq!(ctx.command_path, vec!["config", "get"]);
        assert!(ctx.extensions.is_empty());
    }

    #[test]
    fn test_command_context_default() {
        let ctx = CommandContext::default();
        assert!(ctx.command_path.is_empty());
        assert!(ctx.extensions.is_empty());
    }

    // ---------------------------------------------------------------------
    // Extensions
    // ---------------------------------------------------------------------

    #[test]
    fn test_extensions_insert_and_get() {
        struct MyState {
            value: i32,
        }

        let mut ext = Extensions::new();
        assert!(ext.is_empty());

        // The first value of a type has nothing to replace.
        assert!(ext.insert(MyState { value: 42 }).is_none());
        assert!(!ext.is_empty());
        assert_eq!(ext.len(), 1);

        let state = ext.get::<MyState>().expect("MyState was just inserted");
        assert_eq!(state.value, 42);
    }

    #[test]
    fn test_extensions_get_mut() {
        struct Counter {
            count: i32,
        }

        let mut ext = Extensions::new();
        ext.insert(Counter { count: 0 });

        // `expect` (rather than `if let`) so a missing value fails loudly.
        ext.get_mut::<Counter>()
            .expect("Counter was just inserted")
            .count += 1;

        assert_eq!(ext.get::<Counter>().unwrap().count, 1);
    }

    #[test]
    fn test_extensions_multiple_types() {
        struct TypeA(i32);
        struct TypeB(String);

        let mut ext = Extensions::new();
        ext.insert(TypeA(1));
        ext.insert(TypeB("hello".into()));

        assert_eq!(ext.len(), 2);
        assert_eq!(ext.get::<TypeA>().unwrap().0, 1);
        assert_eq!(ext.get::<TypeB>().unwrap().0, "hello");
    }

    #[test]
    fn test_extensions_replace() {
        struct Value(i32);

        let mut ext = Extensions::new();
        ext.insert(Value(1));

        let old = ext.insert(Value(2));
        assert_eq!(old.unwrap().0, 1);
        assert_eq!(ext.get::<Value>().unwrap().0, 2);
        assert_eq!(ext.len(), 1);
    }

    #[test]
    fn test_extensions_remove() {
        struct Value(i32);

        let mut ext = Extensions::new();
        ext.insert(Value(42));

        let removed = ext.remove::<Value>();
        assert_eq!(removed.unwrap().0, 42);
        assert!(ext.is_empty());
        assert!(ext.get::<Value>().is_none());

        // Removing a type that is not stored is a no-op.
        assert!(ext.remove::<Value>().is_none());
    }

    #[test]
    fn test_extensions_contains() {
        struct Present;
        struct Absent;

        let mut ext = Extensions::new();
        ext.insert(Present);

        assert!(ext.contains::<Present>());
        assert!(!ext.contains::<Absent>());
    }

    #[test]
    fn test_extensions_clear() {
        struct A;
        struct B;

        let mut ext = Extensions::new();
        ext.insert(A);
        ext.insert(B);
        assert_eq!(ext.len(), 2);

        ext.clear();
        assert!(ext.is_empty());
        assert!(!ext.contains::<A>());
    }

    #[test]
    fn test_extensions_missing_type_returns_none() {
        struct NotInserted;

        let ext = Extensions::new();
        assert!(ext.get::<NotInserted>().is_none());
    }

    #[test]
    fn test_extensions_get_required() {
        #[derive(Debug)]
        struct Config {
            value: i32,
        }

        let mut ext = Extensions::new();
        ext.insert(Config { value: 100 });

        let val = ext.get_required::<Config>();
        assert!(val.is_ok());
        assert_eq!(val.unwrap().value, 100);

        // `Debug` is needed for `unwrap_err` on `Result<&Missing, _>`.
        #[derive(Debug)]
        struct Missing;
        let err = ext.get_required::<Missing>();
        assert!(err.is_err());
        assert!(err
            .unwrap_err()
            .to_string()
            .contains("Extension missing: type"));
    }

    #[test]
    fn test_extensions_get_mut_required() {
        #[derive(Debug)]
        struct State {
            count: i32,
        }

        let mut ext = Extensions::new();
        ext.insert(State { count: 0 });

        {
            // Scoped so the mutable borrow ends before the read below.
            let val = ext.get_mut_required::<State>();
            assert!(val.is_ok());
            val.unwrap().count += 1;
        }
        assert_eq!(ext.get_required::<State>().unwrap().count, 1);

        #[derive(Debug)]
        struct Missing;
        let err = ext.get_mut_required::<Missing>();
        assert!(err.is_err());
    }

    #[test]
    fn test_extensions_clone_behavior() {
        struct Data(i32);

        let mut original = Extensions::new();
        original.insert(Data(42));

        let cloned = original.clone();

        // The original keeps its value...
        assert_eq!(original.get::<Data>().unwrap().0, 42);

        // ...but type-erased values are intentionally not carried into the clone.
        assert!(cloned.is_empty());
        assert!(cloned.get::<Data>().is_none());
    }

    #[test]
    fn test_extensions_debug_reports_len_only() {
        let mut ext = Extensions::new();
        ext.insert(7u8);
        assert_eq!(format!("{:?}", ext), "Extensions { len: 1, .. }");
    }

    // ---------------------------------------------------------------------
    // Output
    // ---------------------------------------------------------------------

    #[test]
    fn test_output_render() {
        let output: Output<String> = Output::Render("success".into());
        assert!(output.is_render());
        assert!(!output.is_silent());
        assert!(!output.is_binary());
    }

    #[test]
    fn test_output_silent() {
        let output: Output<String> = Output::Silent;
        assert!(!output.is_render());
        assert!(output.is_silent());
        assert!(!output.is_binary());
    }

    #[test]
    fn test_output_binary() {
        let output: Output<String> = Output::Binary {
            data: vec![0x25, 0x50, 0x44, 0x46],
            filename: "report.pdf".into(),
        };
        assert!(!output.is_render());
        assert!(!output.is_silent());
        assert!(output.is_binary());
    }

    // ---------------------------------------------------------------------
    // RunResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_run_result_handled() {
        let result = RunResult::Handled("output".into());
        assert!(result.is_handled());
        assert!(!result.is_binary());
        assert!(!result.is_silent());
        assert_eq!(result.output(), Some("output"));
        assert!(result.binary().is_none());
        assert!(result.matches().is_none());
    }

    #[test]
    fn test_run_result_silent() {
        let result = RunResult::Silent;
        assert!(!result.is_handled());
        assert!(!result.is_binary());
        assert!(result.is_silent());
        assert!(result.output().is_none());
        assert!(result.binary().is_none());
        assert!(result.matches().is_none());
    }

    #[test]
    fn test_run_result_binary() {
        let bytes = vec![0x25, 0x50, 0x44, 0x46];
        let result = RunResult::Binary(bytes.clone(), "report.pdf".into());
        assert!(!result.is_handled());
        assert!(result.is_binary());
        assert!(!result.is_silent());
        assert!(result.output().is_none());

        let (data, filename) = result.binary().unwrap();
        assert_eq!(data, &bytes);
        assert_eq!(filename, "report.pdf");
    }

    #[test]
    fn test_run_result_no_match() {
        let matches = clap::Command::new("test").get_matches_from(vec!["test"]);
        let result = RunResult::NoMatch(matches);
        assert!(!result.is_handled());
        assert!(!result.is_binary());
        assert!(!result.is_silent());
        assert!(result.output().is_none());
        assert!(result.matches().is_some());
    }

    // ---------------------------------------------------------------------
    // Handlers
    // ---------------------------------------------------------------------

    #[test]
    fn test_fn_handler() {
        let handler = FnHandler::new(|_m: &ArgMatches, _ctx: &CommandContext| {
            Ok(Output::Render(json!({"status": "ok"})))
        });

        let ctx = CommandContext::default();
        let matches = clap::Command::new("test").get_matches_from(vec!["test"]);

        // Pin the payload, not just success.
        match handler.handle(&matches, &ctx) {
            Ok(Output::Render(value)) => assert_eq!(value, json!({"status": "ok"})),
            other => panic!("expected Ok(Render(..)), got {:?}", other),
        }
    }

    #[test]
    fn test_local_fn_handler_mutation() {
        let mut counter = 0u32;

        let mut handler = LocalFnHandler::new(|_m: &ArgMatches, _ctx: &CommandContext| {
            counter += 1;
            Ok(Output::Render(counter))
        });

        let ctx = CommandContext::default();
        let matches = clap::Command::new("test").get_matches_from(vec!["test"]);

        let _ = handler.handle(&matches, &ctx);
        let _ = handler.handle(&matches, &ctx);

        // An exhaustive match: a non-`Render` result must fail the test rather
        // than being silently skipped.
        match handler.handle(&matches, &ctx) {
            Ok(Output::Render(count)) => assert_eq!(count, 3),
            other => panic!("expected Ok(Render(3)), got {:?}", other),
        }
    }
}