Skip to main content

agent_orchestrator/crd/
plugins.rs

1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3use std::path::Path;
4use std::process::Stdio;
5use std::time::Duration;
6use tokio::io::AsyncWriteExt;
7use tokio::process::Command;
8
9use crate::crd::types::CrdPlugin;
10
/// `plugin_type` value for interceptor plugins, which accept or reject (gate) a request.
pub const PLUGIN_TYPE_INTERCEPTOR: &str = "interceptor";
/// `plugin_type` value for transformer plugins, which rewrite payload data.
pub const PLUGIN_TYPE_TRANSFORMER: &str = "transformer";
/// `plugin_type` value for cron plugins, which run a periodic maintenance task.
pub const PLUGIN_TYPE_CRON: &str = "cron";

/// `phase` value for webhook authentication (runs before signature verification).
pub const PHASE_WEBHOOK_AUTHENTICATE: &str = "webhook.authenticate";
/// `phase` value for webhook transformation (normalizes the payload before trigger matching).
pub const PHASE_WEBHOOK_TRANSFORM: &str = "webhook.transform";
22
23/// Execute an interceptor plugin (e.g. custom signature verification).
24///
25/// The plugin receives context via environment variables:
26/// - `PLUGIN_NAME`, `PLUGIN_TYPE`, `CRD_KIND`: plugin identity
27/// - `WEBHOOK_BODY`: raw request body
28/// - `WEBHOOK_HEADER_<NAME>`: one variable per HTTP header (uppercased, hyphens→underscores)
29///
30/// Returns Ok(()) if the plugin exits 0 (accept), or Err if non-zero (reject).
31pub async fn execute_interceptor(
32    plugin: &CrdPlugin,
33    crd_kind: &str,
34    headers: &HashMap<String, String>,
35    body: &str,
36    db_path: Option<&Path>,
37) -> Result<()> {
38    audit_plugin_execution(db_path, "plugin_execute", crd_kind, plugin);
39    let timeout = Duration::from_secs(plugin.effective_timeout());
40
41    let mut cmd = Command::new("sh");
42    cmd.arg("-c")
43        .arg(&plugin.command)
44        .env("PLUGIN_NAME", &plugin.name)
45        .env("PLUGIN_TYPE", PLUGIN_TYPE_INTERCEPTOR)
46        .env("CRD_KIND", crd_kind)
47        .env("WEBHOOK_BODY", body);
48
49    for (key, value) in headers {
50        let env_key = format!("WEBHOOK_HEADER_{}", key.to_uppercase().replace('-', "_"));
51        cmd.env(env_key, value);
52    }
53
54    let output = run_plugin_with_timeout(&mut cmd, None, timeout)
55        .await
56        .map_err(|e| {
57            audit_plugin_timeout(db_path, crd_kind, plugin, &e);
58            anyhow!(
59                "interceptor plugin '{}' for CRD '{}' failed: {}",
60                plugin.name,
61                crd_kind,
62                e
63            )
64        })?;
65
66    if !output.status.success() {
67        let stderr = String::from_utf8_lossy(&output.stderr);
68        return Err(anyhow!(
69            "interceptor plugin '{}' for CRD '{}' rejected request (exit {}): {}",
70            plugin.name,
71            crd_kind,
72            output.status.code().unwrap_or(-1),
73            stderr.trim()
74        ));
75    }
76
77    Ok(())
78}
79
80/// Execute a transformer plugin (e.g. payload normalization).
81///
82/// The plugin receives:
83/// - stdin: the original JSON payload
84/// - env: `PLUGIN_NAME`, `PLUGIN_TYPE`, `CRD_KIND`
85///
86/// Returns the transformed JSON from stdout.
87pub async fn execute_transformer(
88    plugin: &CrdPlugin,
89    crd_kind: &str,
90    payload: &serde_json::Value,
91    db_path: Option<&Path>,
92) -> Result<serde_json::Value> {
93    audit_plugin_execution(db_path, "plugin_execute", crd_kind, plugin);
94    let timeout = Duration::from_secs(plugin.effective_timeout());
95    let input = serde_json::to_string(payload)
96        .map_err(|e| anyhow!("failed to serialize payload for transformer: {}", e))?;
97
98    let mut cmd = Command::new("sh");
99    cmd.arg("-c")
100        .arg(&plugin.command)
101        .env("PLUGIN_NAME", &plugin.name)
102        .env("PLUGIN_TYPE", PLUGIN_TYPE_TRANSFORMER)
103        .env("CRD_KIND", crd_kind);
104
105    let output = run_plugin_with_timeout(&mut cmd, Some(input.as_bytes()), timeout)
106        .await
107        .map_err(|e| {
108            audit_plugin_timeout(db_path, crd_kind, plugin, &e);
109            anyhow!(
110                "transformer plugin '{}' for CRD '{}' failed: {}",
111                plugin.name,
112                crd_kind,
113                e
114            )
115        })?;
116
117    if !output.status.success() {
118        let stderr = String::from_utf8_lossy(&output.stderr);
119        return Err(anyhow!(
120            "transformer plugin '{}' for CRD '{}' failed (exit {}): {}",
121            plugin.name,
122            crd_kind,
123            output.status.code().unwrap_or(-1),
124            stderr.trim()
125        ));
126    }
127
128    let stdout = String::from_utf8_lossy(&output.stdout);
129    serde_json::from_str(stdout.trim()).map_err(|e| {
130        anyhow!(
131            "transformer plugin '{}' for CRD '{}' returned invalid JSON: {}",
132            plugin.name,
133            crd_kind,
134            e
135        )
136    })
137}
138
139/// Execute a cron plugin (periodic maintenance task).
140///
141/// The plugin receives env: `PLUGIN_NAME`, `PLUGIN_TYPE`, `CRD_KIND`.
142/// Returns Ok(()) on success, Err on failure (caller should log, not abort).
143pub async fn execute_cron_plugin(
144    plugin: &CrdPlugin,
145    crd_kind: &str,
146    db_path: Option<&Path>,
147) -> Result<()> {
148    audit_plugin_execution(db_path, "plugin_execute", crd_kind, plugin);
149    let timeout = Duration::from_secs(plugin.effective_timeout());
150
151    let mut cmd = Command::new("sh");
152    cmd.arg("-c")
153        .arg(&plugin.command)
154        .env("PLUGIN_NAME", &plugin.name)
155        .env("PLUGIN_TYPE", PLUGIN_TYPE_CRON)
156        .env("CRD_KIND", crd_kind);
157
158    let output = run_plugin_with_timeout(&mut cmd, None, timeout)
159        .await
160        .map_err(|e| {
161            audit_plugin_timeout(db_path, crd_kind, plugin, &e);
162            anyhow!(
163                "cron plugin '{}' for CRD '{}' failed: {}",
164                plugin.name,
165                crd_kind,
166                e
167            )
168        })?;
169
170    if !output.status.success() {
171        let stderr = String::from_utf8_lossy(&output.stderr);
172        return Err(anyhow!(
173            "cron plugin '{}' for CRD '{}' failed (exit {}): {}",
174            plugin.name,
175            crd_kind,
176            output.status.code().unwrap_or(-1),
177            stderr.trim()
178        ));
179    }
180
181    Ok(())
182}
183
184/// Collect plugins of a given phase from a CRD's plugin list.
185pub fn plugins_for_phase<'a>(plugins: &'a [CrdPlugin], phase: &str) -> Vec<&'a CrdPlugin> {
186    plugins
187        .iter()
188        .filter(|p| p.phase.as_deref() == Some(phase))
189        .collect()
190}
191
192/// Collect cron-type plugins from a CRD's plugin list.
193pub fn cron_plugins(plugins: &[CrdPlugin]) -> Vec<&CrdPlugin> {
194    plugins
195        .iter()
196        .filter(|p| p.plugin_type == PLUGIN_TYPE_CRON)
197        .collect()
198}
199
200// --- audit helper ---
201
202fn audit_plugin_execution(
203    db_path: Option<&Path>,
204    action: &str,
205    crd_kind: &str,
206    plugin: &CrdPlugin,
207) {
208    if let Some(path) = db_path {
209        let _ = crate::db::insert_plugin_audit(
210            path,
211            &crate::db::PluginAuditRecord {
212                action: action.into(),
213                crd_kind: crd_kind.into(),
214                plugin_name: Some(plugin.name.clone()),
215                plugin_type: Some(plugin.plugin_type.clone()),
216                command: plugin.command.clone(),
217                applied_by: None,
218                transport: None,
219                peer_pid: None,
220                result: "allowed".into(),
221                policy_mode: None,
222            },
223        );
224    }
225}
226
227// --- internal helpers ---
228
229/// Spawn a plugin process with process-group isolation and async timeout.
230///
231/// - Sets `process_group(0)` so the child becomes its own PGID leader (Unix).
232/// - Sets `kill_on_drop(true)` as a safety net.
233/// - On timeout, kills the entire process group (child + all descendants)
234///   via `SIGKILL` to `-pid`, not just the direct child.
235/// - Uses `tokio::time::timeout` instead of busy-wait polling.
236async fn run_plugin_with_timeout(
237    cmd: &mut Command,
238    stdin_data: Option<&[u8]>,
239    timeout: Duration,
240) -> Result<std::process::Output> {
241    #[cfg(unix)]
242    {
243        cmd.process_group(0);
244    }
245    cmd.kill_on_drop(true);
246
247    if stdin_data.is_some() {
248        cmd.stdin(Stdio::piped());
249    }
250    cmd.stdout(Stdio::piped());
251    cmd.stderr(Stdio::piped());
252
253    let mut child = cmd.spawn().map_err(|e| anyhow!("spawn failed: {}", e))?;
254
255    // Write stdin data and close the handle before waiting.
256    if let Some(data) = stdin_data {
257        if let Some(mut stdin) = child.stdin.take() {
258            let _ = stdin.write_all(data).await;
259            drop(stdin);
260        }
261    }
262
263    // Take stdout/stderr pipes before waiting so we retain `&mut child` for kill.
264    let mut child_stdout = child.stdout.take();
265    let mut child_stderr = child.stderr.take();
266
267    match tokio::time::timeout(timeout, child.wait()).await {
268        Ok(Ok(status)) => {
269            use tokio::io::AsyncReadExt;
270            let mut stdout = Vec::new();
271            let mut stderr = Vec::new();
272            if let Some(ref mut p) = child_stdout {
273                let _ = p.read_to_end(&mut stdout).await;
274            }
275            if let Some(ref mut p) = child_stderr {
276                let _ = p.read_to_end(&mut stderr).await;
277            }
278            Ok(std::process::Output {
279                status,
280                stdout,
281                stderr,
282            })
283        }
284        Ok(Err(e)) => Err(anyhow!("wait failed: {}", e)),
285        Err(_elapsed) => {
286            // Timeout — kill the entire process group, not just the direct child.
287            crate::runner::kill_child_process_group(&mut child).await;
288            Err(anyhow!("timed out after {}s", timeout.as_secs()))
289        }
290    }
291}
292
293fn audit_plugin_timeout(
294    db_path: Option<&Path>,
295    crd_kind: &str,
296    plugin: &CrdPlugin,
297    error: &anyhow::Error,
298) {
299    if !error.to_string().contains("timed out") {
300        return;
301    }
302    if let Some(path) = db_path {
303        let _ = crate::db::insert_plugin_audit(
304            path,
305            &crate::db::PluginAuditRecord {
306                action: "plugin_timeout_kill".into(),
307                crd_kind: crd_kind.into(),
308                plugin_name: Some(plugin.name.clone()),
309                plugin_type: Some(plugin.plugin_type.clone()),
310                command: plugin.command.clone(),
311                applied_by: None,
312                transport: None,
313                peer_pid: None,
314                result: format!("killed_after_{}s", plugin.effective_timeout()),
315                policy_mode: None,
316            },
317        );
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crd::types::CrdPlugin;

    fn make_plugin(name: &str, plugin_type: &str, phase: Option<&str>, command: &str) -> CrdPlugin {
        CrdPlugin {
            name: name.to_string(),
            plugin_type: plugin_type.to_string(),
            phase: phase.map(|s| s.to_string()),
            command: command.to_string(),
            timeout: Some(5),
            schedule: None,
            timezone: None,
        }
    }

    #[tokio::test]
    async fn interceptor_accepts_on_exit_zero() {
        let plugin = make_plugin("test", "interceptor", Some("webhook.authenticate"), "true");
        let headers = HashMap::new();
        assert!(
            execute_interceptor(&plugin, "Foo", &headers, "{}", None)
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn interceptor_rejects_on_exit_nonzero() {
        let plugin = make_plugin(
            "test",
            "interceptor",
            Some("webhook.authenticate"),
            "exit 1",
        );
        let headers = HashMap::new();
        let err = execute_interceptor(&plugin, "Foo", &headers, "{}", None)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("rejected request"));
    }

    #[tokio::test]
    async fn interceptor_passes_headers_and_body() {
        let plugin = make_plugin(
            "check-env",
            "interceptor",
            Some("webhook.authenticate"),
            r#"test "$WEBHOOK_BODY" = '{"ok":true}' && test "$WEBHOOK_HEADER_X_SIG" = "abc""#,
        );
        let mut headers = HashMap::new();
        headers.insert("X-Sig".to_string(), "abc".to_string());
        assert!(
            execute_interceptor(&plugin, "Foo", &headers, r#"{"ok":true}"#, None)
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn transformer_returns_modified_json() {
        // Transformer that wraps input in {"wrapped": <input>}
        let plugin = make_plugin(
            "wrap",
            "transformer",
            Some("webhook.transform"),
            r#"read input; echo "{\"wrapped\":$input}""#,
        );
        let payload = serde_json::json!({"a": 1});
        let result = execute_transformer(&plugin, "Foo", &payload, None)
            .await
            .unwrap();
        assert!(result.get("wrapped").is_some());
    }

    #[tokio::test]
    async fn transformer_rejects_invalid_json_output() {
        let plugin = make_plugin(
            "bad",
            "transformer",
            Some("webhook.transform"),
            "echo 'not json'",
        );
        let payload = serde_json::json!({});
        assert!(
            execute_transformer(&plugin, "Foo", &payload, None)
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn cron_plugin_success() {
        let plugin = make_plugin("daily", "cron", None, "true");
        assert!(execute_cron_plugin(&plugin, "Foo", None).await.is_ok());
    }

    #[tokio::test]
    async fn cron_plugin_failure() {
        let plugin = make_plugin("daily", "cron", None, "exit 42");
        assert!(execute_cron_plugin(&plugin, "Foo", None).await.is_err());
    }

    #[test]
    fn plugins_for_phase_filters_correctly() {
        let plugins = vec![
            make_plugin("a", "interceptor", Some("webhook.authenticate"), "true"),
            make_plugin("b", "transformer", Some("webhook.transform"), "cat"),
            make_plugin("c", "interceptor", Some("webhook.authenticate"), "true"),
            make_plugin("d", "cron", None, "true"),
        ];
        let auth = plugins_for_phase(&plugins, PHASE_WEBHOOK_AUTHENTICATE);
        assert_eq!(auth.len(), 2);
        let transform = plugins_for_phase(&plugins, PHASE_WEBHOOK_TRANSFORM);
        assert_eq!(transform.len(), 1);
    }

    #[test]
    fn cron_plugins_filters_correctly() {
        let plugins = vec![
            make_plugin("a", "interceptor", Some("webhook.authenticate"), "true"),
            make_plugin("b", "cron", None, "true"),
            make_plugin("c", "cron", None, "echo hi"),
        ];
        assert_eq!(cron_plugins(&plugins).len(), 2);
    }

    #[tokio::test]
    async fn interceptor_timeout_kills_process() {
        let plugin = CrdPlugin {
            name: "slow".to_string(),
            plugin_type: "interceptor".to_string(),
            phase: Some("webhook.authenticate".to_string()),
            command: "sleep 60".to_string(),
            timeout: Some(1), // 1 second timeout
            schedule: None,
            timezone: None,
        };
        let headers = HashMap::new();
        let err = execute_interceptor(&plugin, "Foo", &headers, "{}", None)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("timed out"));
    }

    /// Whether `pid` still names a live, non-zombie process.
    #[cfg(unix)]
    fn process_is_alive(pid: i32) -> bool {
        // SAFETY: kill(pid, 0) only checks whether the process exists; no signal
        // is delivered. `pid` is a plain integer and is never dereferenced.
        if unsafe { libc::kill(pid, 0) } != 0 {
            return false;
        }
        // A killed process that has not been reaped yet still answers
        // kill(pid, 0). On Linux, /proc/<pid>/stat reports its state right after
        // the last ')' of the command name ("Z" for a zombie); count zombies as
        // dead. Without /proc, fall back to the kill(pid, 0) result.
        match std::fs::read_to_string(format!("/proc/{pid}/stat")) {
            Ok(stat) => !stat
                .rsplit(')')
                .next()
                .is_some_and(|rest| rest.trim_start().starts_with('Z')),
            Err(_) => true,
        }
    }

    /// Verify that timeout kills the entire process group, not just the direct child.
    /// Spawns a plugin that forks a background grandchild, then asserts the grandchild
    /// is also killed when the plugin times out.
    #[tokio::test]
    async fn timeout_kills_entire_process_group() {
        let pid_file =
            std::env::temp_dir().join(format!("plugin_pgkill_test_{}", std::process::id()));
        // A PID file left over from an earlier run must not satisfy this test.
        let _ = std::fs::remove_file(&pid_file);
        let command = format!(
            // Fork a background child that writes its PID to a file, then sleep forever.
            // The parent also sleeps forever. On timeout, both should be killed.
            r#"sh -c 'echo $$ > {}; sleep 3600' & sleep 3600"#,
            pid_file.display()
        );
        let plugin = CrdPlugin {
            name: "pgkill".to_string(),
            plugin_type: "interceptor".to_string(),
            phase: Some("webhook.authenticate".to_string()),
            command,
            timeout: Some(1),
            schedule: None,
            timezone: None,
        };
        let headers = HashMap::new();
        let err = execute_interceptor(&plugin, "Foo", &headers, "{}", None)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("timed out"));

        // The grandchild had the whole 1s timeout to record its PID, so require
        // it. Otherwise the liveness check below is skipped and the test passes
        // without checking anything. Remove the file before asserting so a
        // failure does not leave it behind.
        let pid_contents = std::fs::read_to_string(&pid_file);
        let _ = std::fs::remove_file(&pid_file);
        let pid: i32 = pid_contents
            .expect("grandchild should have written its PID file before the timeout")
            .trim()
            .parse()
            .expect("PID file should contain an integer PID");

        #[cfg(unix)]
        {
            // The grandchild dies asynchronously after the group kill. Poll for up
            // to ~2s instead of relying on a single fixed sleep.
            let mut alive = process_is_alive(pid);
            for _ in 0..40 {
                if !alive {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(50)).await;
                alive = process_is_alive(pid);
            }
            assert!(!alive, "grandchild process {} should be dead", pid);
        }
        #[cfg(not(unix))]
        let _ = pid;
    }
}