git_revise/
hook.rs

1// use crate::hooks::types::HookType;
2use std::{
3    process::{Command, Stdio},
4    str::FromStr,
5};
6
7use serde::{Deserialize, Serialize};
8
9use crate::error::ReviseResult;
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
12pub enum HookType {
13    PreAdd,
14    PostAdd,
15    PreCommit,
16    PostCommit,
17    PrePush,
18    PostMerge,
19}
20
21impl FromStr for HookType {
22    type Err = anyhow::Error;
23
24    fn from_str(s: &str) -> Result<Self, Self::Err> {
25        match s {
26            "pre-add" => Ok(Self::PreAdd),
27            "post-add" => Ok(Self::PostAdd),
28            "pre-commit" => Ok(Self::PreCommit),
29            "post-commit" => Ok(Self::PostCommit),
30            "pre-push" => Ok(Self::PrePush),
31            "post-merge" => Ok(Self::PostMerge),
32            _ => Err(anyhow::anyhow!("Invalid hook type: {}", s)),
33        }
34    }
35}
36
37pub struct HookRunner;
38
39impl HookRunner {
40    pub fn run_hook<F, Args, R>(f: F, args: Args) -> R
41    where
42        F: FnOnce(Args) -> R,
43    {
44        f(args)
45    }
46
47    pub fn run_command(command: &str) -> ReviseResult<()> {
48        let output = if cfg!(target_os = "windows") {
49            Command::new("cmd")
50                .arg("/C")
51                .arg(command)
52                .stdout(Stdio::inherit())
53                .stderr(Stdio::inherit())
54                .output()?
55        } else {
56            Command::new("sh")
57                .arg("-c")
58                .arg(command)
59                .stdout(Stdio::inherit())
60                .stderr(Stdio::inherit())
61                .output()?
62        };
63
64        if !output.status.success() {
65            return Err(anyhow::anyhow!(
66                "Hook failed: {}\nExit code: {:?}\nStderr: {}",
67                command,
68                output.status.code(),
69                String::from_utf8_lossy(&output.stderr)
70            ));
71        }
72
73        Ok(())
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_hook() {
        let func = |(a, b, c): (i32, i32, i32)| -> String {
            println!("running hook");
            println!("a: {a}");
            println!("b: {b}");
            println!("c: {c}");
            format!("result: {}", a + b + c)
        };

        let result = HookRunner::run_hook(func, (3, 4, 5));
        assert_eq!(result, "result: 12");
    }

    /// Every kebab-case name accepted by `FromStr` maps to its variant.
    #[test]
    fn hook_type_parses_known_names() {
        let cases = [
            ("pre-add", HookType::PreAdd),
            ("post-add", HookType::PostAdd),
            ("pre-commit", HookType::PreCommit),
            ("post-commit", HookType::PostCommit),
            ("pre-push", HookType::PrePush),
            ("post-merge", HookType::PostMerge),
        ];
        for (name, expected) in cases {
            assert_eq!(name.parse::<HookType>().unwrap(), expected);
        }
    }

    /// Parsing is exact: other spellings are rejected with the offending input echoed.
    #[test]
    fn hook_type_rejects_unknown_names() {
        let err = "PreCommit".parse::<HookType>().unwrap_err();
        assert!(err.to_string().contains("Invalid hook type: PreCommit"));
    }

    #[test]
    fn run_command_succeeds_on_zero_exit() {
        assert!(HookRunner::run_command("exit 0").is_ok());
    }

    /// A non-zero exit is surfaced as an error that names the failing command.
    #[test]
    fn run_command_fails_on_nonzero_exit() {
        let err = HookRunner::run_command("exit 3").unwrap_err();
        assert!(err.to_string().contains("exit 3"));
    }
}