1use std::fmt;
33use std::sync::Arc;
34use thiserror::Error;
35
36use crate::handler::CommandContext;
37use clap::ArgMatches;
38
/// A piece of rendered text kept in two forms.
///
/// `formatted` is the primary form, returned by `RenderedOutput::as_text`.
/// `raw` is the alternate form, returned by `RenderedOutput::as_raw_text`.
/// For text built with [`TextOutput::plain`] the two fields are identical.
///
/// Equality compares both fields, which lets callers (and tests) check hook
/// results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextOutput {
    /// Primary text form (returned by `RenderedOutput::as_text`).
    pub formatted: String,
    /// Alternate text form (returned by `RenderedOutput::as_raw_text`).
    pub raw: String,
}
53
54impl TextOutput {
55 pub fn new(formatted: String, raw: String) -> Self {
57 Self { formatted, raw }
58 }
59
60 pub fn plain(text: String) -> Self {
64 Self {
65 formatted: text.clone(),
66 raw: text,
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
75pub enum RenderedOutput {
76 Text(TextOutput),
80 Binary(Vec<u8>, String),
82 Silent,
84}
85
86impl RenderedOutput {
87 pub fn is_text(&self) -> bool {
89 matches!(self, RenderedOutput::Text(_))
90 }
91
92 pub fn is_binary(&self) -> bool {
94 matches!(self, RenderedOutput::Binary(_, _))
95 }
96
97 pub fn is_silent(&self) -> bool {
99 matches!(self, RenderedOutput::Silent)
100 }
101
102 pub fn as_text(&self) -> Option<&str> {
104 match self {
105 RenderedOutput::Text(t) => Some(&t.formatted),
106 _ => None,
107 }
108 }
109
110 pub fn as_raw_text(&self) -> Option<&str> {
113 match self {
114 RenderedOutput::Text(t) => Some(&t.raw),
115 _ => None,
116 }
117 }
118
119 pub fn as_text_output(&self) -> Option<&TextOutput> {
121 match self {
122 RenderedOutput::Text(t) => Some(t),
123 _ => None,
124 }
125 }
126
127 pub fn as_binary(&self) -> Option<(&[u8], &str)> {
129 match self {
130 RenderedOutput::Binary(bytes, filename) => Some((bytes, filename)),
131 _ => None,
132 }
133 }
134}
135
/// Stage of command handling at which a hook runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookPhase {
    /// Before dispatch. Hooks get mutable access to the [`CommandContext`].
    PreDispatch,
    /// After dispatch. Hooks transform a `serde_json::Value`, which is
    /// presumably the command's result data (confirm against the dispatcher).
    PostDispatch,
    /// After rendering. Hooks transform the [`RenderedOutput`].
    PostOutput,
}
146
147impl fmt::Display for HookPhase {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 match self {
150 HookPhase::PreDispatch => write!(f, "pre-dispatch"),
151 HookPhase::PostDispatch => write!(f, "post-dispatch"),
152 HookPhase::PostOutput => write!(f, "post-output"),
153 }
154 }
155}
156
157#[derive(Debug, Error)]
159#[error("hook error ({phase}): {message}")]
160pub struct HookError {
161 pub message: String,
163 pub phase: HookPhase,
165 #[source]
167 pub source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
168}
169
170impl HookError {
171 pub fn pre_dispatch(message: impl Into<String>) -> Self {
173 Self {
174 message: message.into(),
175 phase: HookPhase::PreDispatch,
176 source: None,
177 }
178 }
179
180 pub fn post_dispatch(message: impl Into<String>) -> Self {
182 Self {
183 message: message.into(),
184 phase: HookPhase::PostDispatch,
185 source: None,
186 }
187 }
188
189 pub fn post_output(message: impl Into<String>) -> Self {
191 Self {
192 message: message.into(),
193 phase: HookPhase::PostOutput,
194 source: None,
195 }
196 }
197
198 pub fn with_source<E>(mut self, source: E) -> Self
200 where
201 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
202 {
203 self.source = Some(source.into());
204 self
205 }
206}
207
208pub type PreDispatchFn =
213 Arc<dyn Fn(&ArgMatches, &mut CommandContext) -> Result<(), HookError> + Send + Sync>;
214
215pub type PostDispatchFn = Arc<
217 dyn Fn(&ArgMatches, &CommandContext, serde_json::Value) -> Result<serde_json::Value, HookError>
218 + Send
219 + Sync,
220>;
221
222pub type PostOutputFn = Arc<
224 dyn Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
225 + Send
226 + Sync,
227>;
228
229#[derive(Clone, Default)]
233pub struct Hooks {
234 pre_dispatch: Vec<PreDispatchFn>,
235 post_dispatch: Vec<PostDispatchFn>,
236 post_output: Vec<PostOutputFn>,
237}
238
239impl Hooks {
240 pub fn new() -> Self {
242 Self::default()
243 }
244
245 pub fn is_empty(&self) -> bool {
247 self.pre_dispatch.is_empty() && self.post_dispatch.is_empty() && self.post_output.is_empty()
248 }
249
250 pub fn pre_dispatch<F>(mut self, f: F) -> Self
271 where
272 F: Fn(&ArgMatches, &mut CommandContext) -> Result<(), HookError> + Send + Sync + 'static,
273 {
274 self.pre_dispatch.push(Arc::new(f));
275 self
276 }
277
278 pub fn post_dispatch<F>(mut self, f: F) -> Self
280 where
281 F: Fn(
282 &ArgMatches,
283 &CommandContext,
284 serde_json::Value,
285 ) -> Result<serde_json::Value, HookError>
286 + Send
287 + Sync
288 + 'static,
289 {
290 self.post_dispatch.push(Arc::new(f));
291 self
292 }
293
294 pub fn post_output<F>(mut self, f: F) -> Self
296 where
297 F: Fn(&ArgMatches, &CommandContext, RenderedOutput) -> Result<RenderedOutput, HookError>
298 + Send
299 + Sync
300 + 'static,
301 {
302 self.post_output.push(Arc::new(f));
303 self
304 }
305
306 pub fn run_pre_dispatch(
310 &self,
311 matches: &ArgMatches,
312 ctx: &mut CommandContext,
313 ) -> Result<(), HookError> {
314 for hook in &self.pre_dispatch {
315 hook(matches, ctx)?;
316 }
317 Ok(())
318 }
319
320 pub fn run_post_dispatch(
322 &self,
323 matches: &ArgMatches,
324 ctx: &CommandContext,
325 data: serde_json::Value,
326 ) -> Result<serde_json::Value, HookError> {
327 let mut current = data;
328 for hook in &self.post_dispatch {
329 current = hook(matches, ctx, current)?;
330 }
331 Ok(current)
332 }
333
334 pub fn run_post_output(
336 &self,
337 matches: &ArgMatches,
338 ctx: &CommandContext,
339 output: RenderedOutput,
340 ) -> Result<RenderedOutput, HookError> {
341 let mut current = output;
342 for hook in &self.post_output {
343 current = hook(matches, ctx, current)?;
344 }
345 Ok(current)
346 }
347}
348
349impl fmt::Debug for Hooks {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 f.debug_struct("Hooks")
352 .field("pre_dispatch_count", &self.pre_dispatch.len())
353 .field("post_dispatch_count", &self.post_dispatch.len())
354 .field("post_output_count", &self.post_output.len())
355 .finish()
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal context whose command path is `["test"]`; all other fields default.
    fn test_context() -> CommandContext {
        CommandContext {
            command_path: vec!["test".into()],
            ..Default::default()
        }
    }

    /// Argument matches for a bare `test` command with no arguments.
    fn test_matches() -> ArgMatches {
        clap::Command::new("test").get_matches_from(vec!["test"])
    }

    #[test]
    fn test_rendered_output_variants() {
        let text = RenderedOutput::Text(TextOutput::new("formatted".into(), "raw".into()));
        assert!(text.is_text());
        assert!(!text.is_binary());
        assert!(!text.is_silent());
        assert_eq!(text.as_text(), Some("formatted"));
        assert_eq!(text.as_raw_text(), Some("raw"));

        // `plain` uses the same string for both forms.
        let plain = RenderedOutput::Text(TextOutput::plain("hello".into()));
        assert_eq!(plain.as_text(), Some("hello"));
        assert_eq!(plain.as_raw_text(), Some("hello"));

        let binary = RenderedOutput::Binary(vec![1, 2, 3], "file.bin".into());
        assert!(!binary.is_text());
        assert!(binary.is_binary());
        assert_eq!(binary.as_binary(), Some((&[1u8, 2, 3][..], "file.bin")));

        let silent = RenderedOutput::Silent;
        assert!(silent.is_silent());
    }

    #[test]
    fn test_hook_error_creation() {
        let err = HookError::pre_dispatch("test error");
        assert_eq!(err.phase, HookPhase::PreDispatch);
        assert_eq!(err.message, "test error");
    }

    #[test]
    fn test_hook_error_display_and_source() {
        let err = HookError::post_output("render failed").with_source(std::io::Error::new(
            std::io::ErrorKind::Other,
            "disk full",
        ));
        assert_eq!(err.to_string(), "hook error (post-output): render failed");
        let source = std::error::Error::source(&err).expect("source was attached");
        assert_eq!(source.to_string(), "disk full");

        let bare = HookError::pre_dispatch("nope");
        assert_eq!(bare.to_string(), "hook error (pre-dispatch): nope");
        assert!(std::error::Error::source(&bare).is_none());
    }

    #[test]
    fn test_hooks_empty() {
        let hooks = Hooks::new();
        assert!(hooks.is_empty());
    }

    #[test]
    fn test_hooks_not_empty_after_registration() {
        assert!(!Hooks::new().pre_dispatch(|_, _| Ok(())).is_empty());
        assert!(!Hooks::new().post_dispatch(|_, _, data| Ok(data)).is_empty());
        assert!(!Hooks::new().post_output(|_, _, output| Ok(output)).is_empty());
    }

    #[test]
    fn test_pre_dispatch_success() {
        let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let called_clone = called.clone();

        let hooks = Hooks::new().pre_dispatch(move |_, _| {
            called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        });

        let mut ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &mut ctx);

        assert!(result.is_ok());
        assert!(called.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[test]
    fn test_pre_dispatch_error_aborts() {
        let hooks = Hooks::new()
            .pre_dispatch(|_, _| Err(HookError::pre_dispatch("first fails")))
            .pre_dispatch(|_, _| panic!("should not be called"));

        let mut ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_pre_dispatch(&matches, &mut ctx);

        assert!(result.is_err());
    }

    #[test]
    fn test_pre_dispatch_injects_extensions() {
        struct TestState {
            value: i32,
        }

        let hooks = Hooks::new().pre_dispatch(|_, ctx| {
            ctx.extensions.insert(TestState { value: 42 });
            Ok(())
        });

        let mut ctx = test_context();
        let matches = test_matches();

        // The extension must not exist before the hook runs.
        assert!(!ctx.extensions.contains::<TestState>());

        hooks.run_pre_dispatch(&matches, &mut ctx).unwrap();

        let state = ctx.extensions.get::<TestState>().unwrap();
        assert_eq!(state.value, 42);
    }

    #[test]
    fn test_pre_dispatch_multiple_hooks_share_context() {
        struct Counter {
            count: i32,
        }

        let hooks = Hooks::new()
            .pre_dispatch(|_, ctx| {
                ctx.extensions.insert(Counter { count: 1 });
                Ok(())
            })
            .pre_dispatch(|_, ctx| {
                // Sees the value inserted by the previous hook.
                if let Some(counter) = ctx.extensions.get_mut::<Counter>() {
                    counter.count += 10;
                }
                Ok(())
            });

        let mut ctx = test_context();
        let matches = test_matches();
        hooks.run_pre_dispatch(&matches, &mut ctx).unwrap();

        let counter = ctx.extensions.get::<Counter>().unwrap();
        assert_eq!(counter.count, 11);
    }

    #[test]
    fn test_post_dispatch_transformation() {
        use serde_json::json;

        let hooks = Hooks::new().post_dispatch(|_, _, mut data| {
            if let Some(obj) = data.as_object_mut() {
                obj.insert("modified".into(), json!(true));
            }
            Ok(data)
        });

        let ctx = test_context();
        let matches = test_matches();
        let data = json!({"value": 42});
        let result = hooks.run_post_dispatch(&matches, &ctx, data);

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output["value"], 42);
        assert_eq!(output["modified"], true);
    }

    #[test]
    fn test_post_dispatch_runs_in_registration_order() {
        use serde_json::json;

        let hooks = Hooks::new()
            .post_dispatch(|_, _, data| Ok(json!(data.as_i64().unwrap() * 2)))
            .post_dispatch(|_, _, data| Ok(json!(data.as_i64().unwrap() + 1)));

        let ctx = test_context();
        let matches = test_matches();
        let result = hooks.run_post_dispatch(&matches, &ctx, json!(3)).unwrap();

        // (3 * 2) + 1 = 7; the reverse order would give (3 + 1) * 2 = 8.
        assert_eq!(result, json!(7));
    }

    #[test]
    fn test_post_output_transformation() {
        let hooks = Hooks::new().post_output(|_, _, output| {
            if let RenderedOutput::Text(text_output) = output {
                Ok(RenderedOutput::Text(TextOutput::new(
                    text_output.formatted.to_uppercase(),
                    text_output.raw.to_uppercase(),
                )))
            } else {
                Ok(output)
            }
        });

        let ctx = test_context();
        let matches = test_matches();
        let input = RenderedOutput::Text(TextOutput::plain("hello".into()));
        let result = hooks.run_post_output(&matches, &ctx, input);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_text(), Some("HELLO"));
    }

    #[test]
    fn test_post_output_error_aborts() {
        let hooks = Hooks::new()
            .post_output(|_, _, _| Err(HookError::post_output("first fails")))
            .post_output(|_, _, _| panic!("should not be called"));

        let ctx = test_context();
        let matches = test_matches();
        let err = hooks
            .run_post_output(&matches, &ctx, RenderedOutput::Silent)
            .unwrap_err();

        assert_eq!(err.phase, HookPhase::PostOutput);
        assert_eq!(err.message, "first fails");
    }
}