Skip to main content

mxr_rules/
shell_hook.rs

1use serde::Serialize;
2use std::process::Stdio;
3use std::time::Duration;
4use tokio::io::AsyncWriteExt;
5use tokio::process::Command;
6use tokio::time::timeout;
7
/// Time limit applied to a shell hook when the caller passes `None` as its timeout
/// (thirty seconds).
const DEFAULT_HOOK_TIMEOUT: Duration = Duration::from_millis(30_000);
10
/// JSON payload piped to shell hook stdin.
///
/// Serialized with `serde_json` and written to the hook's stdin by
/// `execute_shell_hook`. Field names (and their declaration order) are exactly the
/// JSON keys hooks see, so renaming or reordering fields changes the hook-facing format.
#[derive(Debug, Serialize)]
pub struct ShellHookPayload {
    /// Identifier of the message the hook is running for.
    pub id: String,
    /// Sender address, serialized as a nested JSON object.
    pub from: ShellHookAddress,
    /// Subject line.
    pub subject: String,
    /// Message date as a preformatted string.
    /// NOTE(review): the test fixture uses an RFC 3339 UTC timestamp; confirm every
    /// caller formats it the same way, since hooks will parse it.
    pub date: String,
    /// Plain-text body, if available; serialized as JSON `null` when `None`.
    pub body_text: Option<String>,
    /// Attachment metadata; serialized as a JSON array (empty when there are none).
    pub attachments: Vec<ShellHookAttachment>,
}
21
/// An email address with an optional display name, as exposed to shell hooks.
#[derive(Debug, Serialize)]
pub struct ShellHookAddress {
    /// Display name, if any; serialized as JSON `null` when `None`.
    pub name: Option<String>,
    /// The email address itself (e.g. `alice@example.com`).
    pub email: String,
}
27
/// Metadata for one attachment, as exposed to shell hooks.
#[derive(Debug, Serialize)]
pub struct ShellHookAttachment {
    /// Attachment file name.
    pub filename: String,
    /// Attachment size in bytes.
    pub size_bytes: u64,
    /// Filesystem path where the attachment can be read, if there is one; serialized
    /// as JSON `null` when `None`.
    /// NOTE(review): presumably `None` when the attachment has not been saved locally;
    /// confirm against the code that builds this payload.
    pub local_path: Option<String>,
}
34
35/// Execute a shell hook command with message data on stdin.
36///
37/// Returns Ok(()) on exit code 0, Err on non-zero or timeout.
38pub async fn execute_shell_hook(
39    command: &str,
40    payload: &ShellHookPayload,
41    hook_timeout: Option<Duration>,
42) -> Result<(), ShellHookError> {
43    let timeout_dur = hook_timeout.unwrap_or(DEFAULT_HOOK_TIMEOUT);
44
45    let json = serde_json::to_string(payload)
46        .map_err(|e| ShellHookError::SerializationFailed(e.to_string()))?;
47
48    let mut child = Command::new("sh")
49        .arg("-c")
50        .arg(command)
51        .stdin(Stdio::piped())
52        .stdout(Stdio::null())
53        .stderr(Stdio::piped())
54        .spawn()
55        .map_err(|e| ShellHookError::SpawnFailed {
56            command: command.to_string(),
57            error: e.to_string(),
58        })?;
59
60    if let Some(mut stdin) = child.stdin.take() {
61        stdin
62            .write_all(json.as_bytes())
63            .await
64            .map_err(|e| ShellHookError::StdinWriteFailed(e.to_string()))?;
65    }
66
67    let result = timeout(timeout_dur, child.wait_with_output())
68        .await
69        .map_err(|_| ShellHookError::Timeout {
70            command: command.to_string(),
71            timeout: timeout_dur,
72        })?
73        .map_err(|e| ShellHookError::WaitFailed(e.to_string()))?;
74
75    if result.status.success() {
76        Ok(())
77    } else {
78        let stderr = String::from_utf8_lossy(&result.stderr).to_string();
79        Err(ShellHookError::NonZeroExit {
80            command: command.to_string(),
81            code: result.status.code(),
82            stderr,
83        })
84    }
85}
86
87#[derive(Debug, thiserror::Error)]
88pub enum ShellHookError {
89    #[error("Failed to serialize message to JSON: {0}")]
90    SerializationFailed(String),
91    #[error("Failed to spawn command '{command}': {error}")]
92    SpawnFailed { command: String, error: String },
93    #[error("Failed to write to command stdin: {0}")]
94    StdinWriteFailed(String),
95    #[error("Command '{command}' timed out after {timeout:?}")]
96    Timeout { command: String, timeout: Duration },
97    #[error("Failed to wait for command: {0}")]
98    WaitFailed(String),
99    #[error("Command '{command}' exited with code {code:?}: {stderr}")]
100    NonZeroExit {
101        command: String,
102        code: Option<i32>,
103        stderr: String,
104    },
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated payload (sender, body, one attachment) shared by every test.
    fn sample_payload() -> ShellHookPayload {
        let from = ShellHookAddress {
            name: Some("Alice".into()),
            email: "alice@example.com".into(),
        };
        let invoice = ShellHookAttachment {
            filename: "invoice.pdf".into(),
            size_bytes: 234_567,
            local_path: Some("/tmp/mxr/invoice.pdf".into()),
        };
        ShellHookPayload {
            id: "msg_123".into(),
            from,
            subject: "Invoice #2847".into(),
            date: "2026-03-17T10:30:00Z".into(),
            body_text: Some("Please find attached the invoice.".into()),
            attachments: vec![invoice],
        }
    }

    /// Runs `command` against the sample payload with the default timeout.
    async fn run_with_sample(command: &str) -> Result<(), ShellHookError> {
        execute_shell_hook(command, &sample_payload(), None).await
    }

    #[tokio::test]
    async fn hook_success_exit_zero() {
        let outcome = run_with_sample("cat > /dev/null").await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn hook_failure_exit_nonzero() {
        let outcome = run_with_sample("exit 1").await;
        assert!(matches!(outcome, Err(ShellHookError::NonZeroExit { .. })));
    }

    #[tokio::test]
    async fn hook_captures_stderr_on_failure() {
        let outcome = run_with_sample("echo 'oops' >&2; exit 1").await;
        let stderr = match outcome {
            Err(ShellHookError::NonZeroExit { stderr, .. }) => stderr,
            other => panic!("Expected NonZeroExit, got {:?}", other),
        };
        assert!(stderr.contains("oops"));
    }

    #[tokio::test]
    async fn hook_timeout() {
        let short_limit = Some(Duration::from_millis(100));
        let outcome = execute_shell_hook("sleep 60", &sample_payload(), short_limit).await;
        assert!(matches!(outcome, Err(ShellHookError::Timeout { .. })));
    }

    #[tokio::test]
    async fn hook_receives_valid_json_on_stdin() {
        // The hook parses stdin with Python's `json` module and checks one field.
        const SCRIPT: &str = r#"python3 -c 'import sys, json; d = json.load(sys.stdin); assert d["id"] == "msg_123"'"#;
        let result = run_with_sample(SCRIPT).await;
        assert!(
            result.is_ok(),
            "Hook should receive valid JSON: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn hook_payload_contains_all_fields() {
        // Nested, scalar and list fields must all survive the trip through stdin.
        const SCRIPT: &str = r#"python3 -c 'import sys, json; d = json.load(sys.stdin); assert d["from"]["email"] == "alice@example.com"; assert d["subject"] == "Invoice #2847"; assert len(d["attachments"]) == 1'"#;
        let result = run_with_sample(SCRIPT).await;
        assert!(result.is_ok(), "Payload field check failed: {:?}", result);
    }

    #[tokio::test]
    async fn hook_with_pipe_command() {
        // The command runs via `sh -c`, so shell syntax such as pipelines must work.
        let outcome = run_with_sample("cat | head -c 1 > /dev/null").await;
        assert!(outcome.is_ok());
    }
}