// ralph_core/hooks/engine.rs

use crate::config::{
    HookDefaults, HookMutationConfig, HookOnError, HookPhaseEvent, HookSpec, HookSuspendMode,
    HooksConfig,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::HashMap;
use std::path::PathBuf;

/// Hat name reported in a payload's `context.active_hat` when the caller
/// supplies none.
const DEFAULT_ACTIVE_HAT: &str = "ralph";

/// Value stamped into every payload's `schema_version` field.
const HOOK_PAYLOAD_SCHEMA_VERSION: u32 = 1;

/// Resolves configured hooks for a lifecycle phase-event.
///
/// Holds its own copy of the hook defaults and the per-phase-event hook
/// lists, cloned from a [`HooksConfig`] when the engine is constructed.
#[derive(Debug, Clone)]
pub struct HookEngine {
    /// Fallback settings applied to hooks that leave a value unset.
    defaults: HookDefaults,
    /// Configured hooks keyed by phase-event; each `Vec` keeps declaration order.
    hooks_by_phase_event: HashMap<HookPhaseEvent, Vec<HookSpec>>,
}

21impl HookEngine {
22    /// Creates a hook engine from validated hook configuration.
23    #[must_use]
24    pub fn new(config: &HooksConfig) -> Self {
25        Self {
26            defaults: config.defaults.clone(),
27            hooks_by_phase_event: config.events.clone(),
28        }
29    }
30
31    /// Resolves hooks for a canonical phase-event key in declaration order.
32    #[must_use]
33    pub fn resolve_phase_event(&self, phase_event: HookPhaseEvent) -> Vec<ResolvedHookSpec> {
34        self.hooks_by_phase_event
35            .get(&phase_event)
36            .map(|hooks| {
37                hooks
38                    .iter()
39                    .enumerate()
40                    .map(|(declaration_order, hook)| {
41                        ResolvedHookSpec::from_spec(
42                            phase_event,
43                            declaration_order,
44                            &self.defaults,
45                            hook,
46                        )
47                    })
48                    .collect()
49            })
50            .unwrap_or_default()
51    }
52
53    /// Resolves hooks by phase-event string key.
54    ///
55    /// Unknown phase-event keys return an empty list.
56    #[must_use]
57    pub fn resolve_phase_event_str(&self, phase_event: &str) -> Vec<ResolvedHookSpec> {
58        HookPhaseEvent::parse(phase_event)
59            .map(|phase| self.resolve_phase_event(phase))
60            .unwrap_or_default()
61    }
62
63    /// Builds the lifecycle JSON payload sent to hook stdin.
64    #[must_use]
65    pub fn build_payload(
66        &self,
67        phase_event: HookPhaseEvent,
68        input: HookPayloadBuilderInput,
69    ) -> HookInvocationPayload {
70        self.build_payload_with_timestamp(phase_event, input, Utc::now())
71    }
72
73    /// Builds the lifecycle JSON payload sent to hook stdin with a fixed timestamp.
74    #[must_use]
75    pub fn build_payload_with_timestamp(
76        &self,
77        phase_event: HookPhaseEvent,
78        input: HookPayloadBuilderInput,
79        timestamp: DateTime<Utc>,
80    ) -> HookInvocationPayload {
81        let (phase, event) = split_phase_event(phase_event);
82        let HookPayloadBuilderInput {
83            loop_id,
84            is_primary,
85            workspace,
86            repo_root,
87            pid,
88            iteration_current,
89            iteration_max,
90            context,
91        } = input;
92
93        let HookPayloadContextInput {
94            active_hat,
95            selected_hat,
96            selected_task,
97            termination_reason,
98            human_interact,
99            metadata,
100        } = context;
101
102        HookInvocationPayload {
103            schema_version: HOOK_PAYLOAD_SCHEMA_VERSION,
104            phase: phase.to_string(),
105            event: event.to_string(),
106            phase_event: phase_event.as_str().to_string(),
107            timestamp,
108            loop_context: HookPayloadLoop {
109                id: loop_id,
110                is_primary,
111                workspace: workspace.to_string_lossy().into_owned(),
112                repo_root: repo_root.to_string_lossy().into_owned(),
113                pid,
114            },
115            iteration: HookPayloadIteration {
116                current: iteration_current,
117                max: iteration_max,
118            },
119            context: HookPayloadContext {
120                active_hat: active_hat.unwrap_or_else(|| DEFAULT_ACTIVE_HAT.to_string()),
121                selected_hat,
122                selected_task,
123                termination_reason,
124                human_interact,
125            },
126            metadata: HookPayloadMetadata {
127                accumulated: metadata,
128            },
129        }
130    }
131}
132
133fn split_phase_event(phase_event: HookPhaseEvent) -> (&'static str, &'static str) {
134    phase_event.as_str().split_once('.').expect(
135        "HookPhaseEvent canonical keys always contain a phase prefix and event suffix separated by '.'",
136    )
137}
138
139/// Hook spec with defaults materialized for runtime dispatch.
140#[derive(Debug, Clone)]
141pub struct ResolvedHookSpec {
142    pub phase_event: HookPhaseEvent,
143    pub declaration_order: usize,
144    pub name: String,
145    pub command: Vec<String>,
146    pub cwd: Option<PathBuf>,
147    pub env: HashMap<String, String>,
148    pub timeout_seconds: u64,
149    pub max_output_bytes: u64,
150    pub on_error: HookOnError,
151    pub suspend_mode: HookSuspendMode,
152    pub mutate: HookMutationConfig,
153}
154
155impl ResolvedHookSpec {
156    fn from_spec(
157        phase_event: HookPhaseEvent,
158        declaration_order: usize,
159        defaults: &HookDefaults,
160        spec: &HookSpec,
161    ) -> Self {
162        Self {
163            phase_event,
164            declaration_order,
165            name: spec.name.clone(),
166            command: spec.command.clone(),
167            cwd: spec.cwd.clone(),
168            env: spec.env.clone(),
169            timeout_seconds: spec.timeout_seconds.unwrap_or(defaults.timeout_seconds),
170            max_output_bytes: spec.max_output_bytes.unwrap_or(defaults.max_output_bytes),
171            on_error: spec.on_error.unwrap_or(HookOnError::Warn),
172            suspend_mode: spec.suspend_mode.unwrap_or(defaults.suspend_mode),
173            mutate: spec.mutate.clone(),
174        }
175    }
176}
177
/// Input contract for building hook invocation stdin payloads.
///
/// Consumed by `HookEngine::build_payload` / `build_payload_with_timestamp`,
/// which move these values into the blocks of a [`HookInvocationPayload`].
#[derive(Debug, Clone)]
pub struct HookPayloadBuilderInput {
    /// Becomes `loop.id` in the payload.
    pub loop_id: String,
    /// Becomes `loop.is_primary`.
    pub is_primary: bool,
    /// Becomes `loop.workspace` (converted with `to_string_lossy`).
    pub workspace: PathBuf,
    /// Becomes `loop.repo_root` (converted with `to_string_lossy`).
    pub repo_root: PathBuf,
    /// Becomes `loop.pid`.
    pub pid: u32,
    /// Becomes `iteration.current`.
    pub iteration_current: u32,
    /// Becomes `iteration.max`.
    pub iteration_max: u32,
    /// Source of the payload's `context` block and of `metadata.accumulated`.
    pub context: HookPayloadContextInput,
}

/// Mutable lifecycle context fields carried in hook stdin payloads.
///
/// `Default` yields every optional field as `None` and an empty metadata map.
#[derive(Debug, Clone, Default)]
pub struct HookPayloadContextInput {
    /// Current hat; when `None` the payload reports `DEFAULT_ACTIVE_HAT` (`"ralph"`).
    pub active_hat: Option<String>,
    /// Passed through unchanged to `context.selected_hat`.
    pub selected_hat: Option<String>,
    /// Passed through unchanged to `context.selected_task`.
    pub selected_task: Option<String>,
    /// Passed through unchanged to `context.termination_reason`.
    pub termination_reason: Option<String>,
    /// Arbitrary JSON passed through unchanged to `context.human_interact`.
    pub human_interact: Option<Value>,
    /// Becomes `metadata.accumulated` in the payload.
    pub metadata: Map<String, Value>,
}

/// Structured lifecycle payload sent to hook stdin as JSON.
///
/// Field declaration order is the key order of the serialized JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookInvocationPayload {
    /// Payload schema version; the engine sets `HOOK_PAYLOAD_SCHEMA_VERSION`.
    pub schema_version: u32,
    /// Text of `phase_event` before its first `.` (e.g. `"pre"`).
    pub phase: String,
    /// Text of `phase_event` after its first `.` (e.g. `"loop.start"`).
    pub event: String,
    /// Canonical phase-event key (e.g. `"pre.loop.start"`).
    pub phase_event: String,
    /// Build time: `Utc::now()` via `build_payload`, or caller-supplied via
    /// `build_payload_with_timestamp`.
    pub timestamp: DateTime<Utc>,
    /// Loop metadata, serialized under the JSON key `"loop"`.
    #[serde(rename = "loop")]
    pub loop_context: HookPayloadLoop,
    /// Iteration counters.
    pub iteration: HookPayloadIteration,
    /// Lifecycle context (hat, task, termination, human interaction).
    pub context: HookPayloadContext,
    /// Accumulated metadata map.
    pub metadata: HookPayloadMetadata,
}

/// Loop metadata payload block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookPayloadLoop {
    /// Loop identifier taken from `HookPayloadBuilderInput::loop_id`.
    pub id: String,
    /// Copied from `HookPayloadBuilderInput::is_primary`.
    pub is_primary: bool,
    /// Workspace path rendered as a string (lossy for non-UTF-8 paths).
    pub workspace: String,
    /// Repository root rendered as a string (lossy for non-UTF-8 paths).
    pub repo_root: String,
    /// `pid` supplied by the caller (presumably the loop's OS process id — confirm).
    pub pid: u32,
}

/// Iteration metadata payload block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookPayloadIteration {
    /// Taken from `HookPayloadBuilderInput::iteration_current`.
    pub current: u32,
    /// Taken from `HookPayloadBuilderInput::iteration_max`.
    pub max: u32,
}

/// Lifecycle context payload block.
///
/// `None` fields are serialized as JSON `null` (they are not skipped).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookPayloadContext {
    /// Always populated; the engine substitutes `DEFAULT_ACTIVE_HAT` when unset.
    pub active_hat: String,
    /// Selected hat, if any.
    pub selected_hat: Option<String>,
    /// Selected task id, if any.
    pub selected_task: Option<String>,
    /// Termination reason, if any.
    pub termination_reason: Option<String>,
    /// Free-form JSON describing a human interaction, if any.
    pub human_interact: Option<Value>,
}

/// Mutable metadata payload block.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HookPayloadMetadata {
    /// Accumulated key/value metadata; a payload missing this key
    /// deserializes with an empty map thanks to `#[serde(default)]`.
    #[serde(default)]
    pub accumulated: Map<String, Value>,
}

#[cfg(test)]
mod tests {
    //! Unit tests for hook resolution (ordering, defaults, string keys) and
    //! payload construction/serialization.

    use super::*;
    use chrono::TimeZone;
    use serde_json::json;

    /// Minimal spec running `echo <name>`; only `on_error` is set explicitly.
    fn hook_spec(name: &str) -> HookSpec {
        HookSpec {
            name: name.to_string(),
            command: vec!["echo".to_string(), name.to_string()],
            cwd: None,
            env: HashMap::new(),
            timeout_seconds: None,
            max_output_bytes: None,
            on_error: Some(HookOnError::Warn),
            suspend_mode: None,
            mutate: HookMutationConfig::default(),
            extra: HashMap::new(),
        }
    }

    /// Enabled hooks config with fixed, recognizable defaults.
    fn hooks_config(events: HashMap<HookPhaseEvent, Vec<HookSpec>>) -> HooksConfig {
        HooksConfig {
            enabled: true,
            defaults: HookDefaults {
                timeout_seconds: 45,
                max_output_bytes: 16_384,
                suspend_mode: HookSuspendMode::WaitThenRetry,
            },
            events,
            extra: HashMap::new(),
        }
    }

    fn fixed_time(hour: u32, minute: u32, second: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2026, 2, 28, hour, minute, second)
            .single()
            .expect("fixed timestamp")
    }

    fn payload_input() -> HookPayloadBuilderInput {
        HookPayloadBuilderInput {
            loop_id: "loop-1234-abcd".to_string(),
            is_primary: false,
            workspace: PathBuf::from("/repo/.worktrees/loop-1234-abcd"),
            repo_root: PathBuf::from("/repo"),
            pid: 12345,
            iteration_current: 7,
            iteration_max: 100,
            context: HookPayloadContextInput::default(),
        }
    }

    #[test]
    fn resolve_phase_event_preserves_declaration_order() {
        let mut events = HashMap::new();
        events.insert(
            HookPhaseEvent::PreLoopStart,
            vec![
                hook_spec("env-guard"),
                hook_spec("workspace-check"),
                hook_spec("notify"),
            ],
        );
        events.insert(HookPhaseEvent::PostLoopStart, vec![hook_spec("post-loop")]);

        let engine = HookEngine::new(&hooks_config(events));
        let resolved = engine.resolve_phase_event(HookPhaseEvent::PreLoopStart);

        assert_eq!(resolved.len(), 3);
        assert_eq!(resolved[0].name, "env-guard");
        assert_eq!(resolved[0].declaration_order, 0);
        assert_eq!(resolved[1].name, "workspace-check");
        assert_eq!(resolved[1].declaration_order, 1);
        assert_eq!(resolved[2].name, "notify");
        assert_eq!(resolved[2].declaration_order, 2);
        assert!(
            resolved
                .iter()
                .all(|hook| hook.phase_event == HookPhaseEvent::PreLoopStart)
        );
    }

    #[test]
    fn resolve_phase_event_applies_defaults_and_per_hook_overrides() {
        let mut hook_with_overrides = hook_spec("manual-gate");
        hook_with_overrides.timeout_seconds = Some(9);
        hook_with_overrides.max_output_bytes = Some(777);
        hook_with_overrides.on_error = Some(HookOnError::Suspend);
        hook_with_overrides.suspend_mode = Some(HookSuspendMode::RetryBackoff);

        let mut events = HashMap::new();
        events.insert(
            HookPhaseEvent::PreIterationStart,
            vec![hook_spec("defaulted"), hook_with_overrides],
        );

        let engine = HookEngine::new(&hooks_config(events));
        let resolved = engine.resolve_phase_event(HookPhaseEvent::PreIterationStart);

        assert_eq!(resolved.len(), 2);

        assert_eq!(resolved[0].timeout_seconds, 45);
        assert_eq!(resolved[0].max_output_bytes, 16_384);
        assert_eq!(resolved[0].on_error, HookOnError::Warn);
        assert_eq!(resolved[0].suspend_mode, HookSuspendMode::WaitThenRetry);

        assert_eq!(resolved[1].timeout_seconds, 9);
        assert_eq!(resolved[1].max_output_bytes, 777);
        assert_eq!(resolved[1].on_error, HookOnError::Suspend);
        assert_eq!(resolved[1].suspend_mode, HookSuspendMode::RetryBackoff);
    }

    #[test]
    fn resolve_phase_event_returns_empty_for_unconfigured_or_unknown_phase() {
        let mut events = HashMap::new();
        events.insert(HookPhaseEvent::PreLoopStart, vec![hook_spec("env-guard")]);

        let engine = HookEngine::new(&hooks_config(events));

        let missing = engine.resolve_phase_event(HookPhaseEvent::PostIterationStart);
        assert!(missing.is_empty());

        let unknown = engine.resolve_phase_event_str("post.nonexistent.event");
        assert!(unknown.is_empty());
    }

    #[test]
    fn resolve_phase_event_str_matches_typed_lookup_for_canonical_keys() {
        fn names(hooks: &[ResolvedHookSpec]) -> Vec<&str> {
            hooks.iter().map(|hook| hook.name.as_str()).collect()
        }

        let mut events = HashMap::new();
        events.insert(
            HookPhaseEvent::PreLoopStart,
            vec![hook_spec("env-guard"), hook_spec("notify")],
        );

        let engine = HookEngine::new(&hooks_config(events));

        // The canonical string key must resolve exactly like the enum variant.
        let by_str = engine.resolve_phase_event_str(HookPhaseEvent::PreLoopStart.as_str());
        let by_enum = engine.resolve_phase_event(HookPhaseEvent::PreLoopStart);

        assert_eq!(names(&by_str), vec!["env-guard", "notify"]);
        assert_eq!(names(&by_str), names(&by_enum));
        assert!(
            by_str
                .iter()
                .all(|hook| hook.phase_event == HookPhaseEvent::PreLoopStart)
        );
    }

    #[test]
    fn build_payload_maps_loop_iteration_and_context_fields() {
        let engine = HookEngine::new(&hooks_config(HashMap::new()));
        let mut input = payload_input();

        let mut metadata = Map::new();
        metadata.insert("risk_score".to_string(), json!(0.72));

        input.context = HookPayloadContextInput {
            active_hat: Some("ralph".to_string()),
            selected_hat: Some("builder".to_string()),
            selected_task: Some("task-1772314313-a244".to_string()),
            termination_reason: None,
            human_interact: Some(json!({"question": "Proceed?"})),
            metadata,
        };

        let payload = engine.build_payload_with_timestamp(
            HookPhaseEvent::PostIterationStart,
            input,
            fixed_time(21, 47, 0),
        );

        assert_eq!(payload.schema_version, HOOK_PAYLOAD_SCHEMA_VERSION);
        assert_eq!(payload.phase, "post");
        assert_eq!(payload.event, "iteration.start");
        assert_eq!(payload.phase_event, "post.iteration.start");
        assert_eq!(payload.loop_context.id, "loop-1234-abcd");
        assert!(!payload.loop_context.is_primary);
        assert_eq!(
            payload.loop_context.workspace,
            "/repo/.worktrees/loop-1234-abcd"
        );
        assert_eq!(payload.loop_context.repo_root, "/repo");
        assert_eq!(payload.loop_context.pid, 12345);
        assert_eq!(payload.iteration.current, 7);
        assert_eq!(payload.iteration.max, 100);
        assert_eq!(payload.context.active_hat, "ralph");
        assert_eq!(payload.context.selected_hat.as_deref(), Some("builder"));
        assert_eq!(
            payload.context.selected_task.as_deref(),
            Some("task-1772314313-a244")
        );
        assert_eq!(payload.metadata.accumulated["risk_score"], json!(0.72));

        let value = serde_json::to_value(&payload).expect("serialize payload");
        assert_eq!(value["loop"]["id"], "loop-1234-abcd");
        assert_eq!(value["context"]["selected_hat"], "builder");
        assert_eq!(value["context"]["selected_task"], "task-1772314313-a244");
        assert_eq!(value["metadata"]["accumulated"]["risk_score"], json!(0.72));
    }

    #[test]
    fn build_payload_defaults_optional_context_fields() {
        let engine = HookEngine::new(&hooks_config(HashMap::new()));
        let payload = engine.build_payload_with_timestamp(
            HookPhaseEvent::PreLoopStart,
            payload_input(),
            fixed_time(21, 48, 0),
        );

        assert_eq!(payload.phase, "pre");
        assert_eq!(payload.event, "loop.start");
        assert_eq!(payload.phase_event, "pre.loop.start");
        assert_eq!(payload.context.active_hat, DEFAULT_ACTIVE_HAT);
        assert!(payload.context.selected_hat.is_none());
        assert!(payload.context.selected_task.is_none());
        assert!(payload.context.termination_reason.is_none());
        assert!(payload.context.human_interact.is_none());
        assert!(payload.metadata.accumulated.is_empty());

        let value = serde_json::to_value(&payload).expect("serialize payload");
        assert!(value["context"]["selected_hat"].is_null());
        assert!(value["context"]["selected_task"].is_null());
        assert!(
            value["metadata"]["accumulated"]
                .as_object()
                .expect("accumulated metadata object")
                .is_empty()
        );
    }

    #[test]
    fn build_payload_stamps_current_time() {
        let engine = HookEngine::new(&hooks_config(HashMap::new()));

        // The wall-clock variant must stamp a time inside the call window.
        let before = Utc::now();
        let payload = engine.build_payload(HookPhaseEvent::PreLoopStart, payload_input());
        let after = Utc::now();

        assert!(payload.timestamp >= before);
        assert!(payload.timestamp <= after);
        assert_eq!(payload.phase_event, "pre.loop.start");
        assert_eq!(payload.loop_context.id, "loop-1234-abcd");
    }
}