Skip to main content

daemon/grpc_local_impl/
hook.rs

// SPDX-License-Identifier: Apache-2.0
//! Local-mode `HookService`. Manages hook registrations on disk and
//! exposes the event-schema catalog so hook authors can scaffold
//! payloads. Live event delivery (subscribe + respond) is routed through
//! the shared hook-event broker; events reach subscribers once the
//! capture/merge code paths emit them.

7use std::{path::PathBuf, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11    DeleteResponse, DeregisterHookRequest, GetHookEventSchemaRequest, GetHookEventSchemaResponse,
12    Hook as ProtoHook, HookEvent as ProtoHookEvent, HookEventSchema, ListHooksRequest,
13    ListHooksResponse, RegisterHookRequest, RespondToHookRequest, RespondToHookResponse,
14    SubscribeHookEventsRequest, hook_service_server::HookService,
15};
16use objects::{error::HeddleError, fs_atomic::write_file_atomic};
17use prost::Message;
18use serde::{Deserialize, Serialize};
19use tokio_stream::{StreamExt, wrappers::ReceiverStream};
20use tonic::{Request, Response, Status};
21
22use super::{GrpcLocalService, HookResponse, to_status, with_idempotency};
23
24#[derive(Clone)]
25pub struct LocalHookService {
26    inner: GrpcLocalService,
27}
28
29impl LocalHookService {
30    pub fn new(inner: GrpcLocalService) -> Self {
31        Self { inner }
32    }
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36struct HookRegistry {
37    #[serde(default)]
38    hooks: Vec<HookConfig>,
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42struct HookConfig {
43    name: String,
44    command: String,
45    #[serde(default)]
46    events: Vec<String>,
47    #[serde(default)]
48    timeout_ms: u32,
49}
50
51impl HookConfig {
52    fn to_proto(&self) -> ProtoHook {
53        ProtoHook {
54            name: self.name.clone(),
55            command: self.command.clone(),
56            events: self.events.clone(),
57            timeout_ms: self.timeout_ms,
58        }
59    }
60}
61
/// Location of the hook registry: `<heddle_dir>/hooks/registry.toml`.
fn registry_path(heddle_dir: &std::path::Path) -> PathBuf {
    let mut path = heddle_dir.join("hooks");
    path.push("registry.toml");
    path
}
65
66fn load_registry(heddle_dir: &std::path::Path) -> Result<HookRegistry, Status> {
67    let path = registry_path(heddle_dir);
68    if !path.exists() {
69        return Ok(HookRegistry::default());
70    }
71    let raw = std::fs::read_to_string(&path).map_err(|e| to_status(HeddleError::from(e)))?;
72    toml::from_str(&raw).map_err(|e| {
73        Status::internal(format!(
74            "hook registry at {} is malformed: {e}",
75            path.display()
76        ))
77    })
78}
79
80fn save_registry(heddle_dir: &std::path::Path, registry: &HookRegistry) -> Result<(), Status> {
81    let path = registry_path(heddle_dir);
82    if let Some(parent) = path.parent() {
83        std::fs::create_dir_all(parent).map_err(|e| to_status(HeddleError::from(e)))?;
84    }
85    let raw = toml::to_string_pretty(registry)
86        .map_err(|e| Status::internal(format!("failed to encode hook registry: {e}")))?;
87    write_file_atomic(&path, raw.as_bytes()).map_err(|e| to_status(HeddleError::from(e)))
88}
89
90/// The hook event catalog. Each entry documents what the payload and
91/// response look like in JSON Schema. Exposing the catalog from
92/// `GetHookEventSchema` lets hook authors scaffold against the contract
93/// even before live event delivery is wired up.
94fn event_catalog() -> Vec<HookEventSchema> {
95    let v1 = 1;
96    vec![
97        HookEventSchema {
98            event_name: "pre_capture".to_string(),
99            schema_version: v1,
100            payload_schema_json: r#"{"type":"object","properties":{"thread":{"type":"string"},"intent":{"type":"string"}},"required":[]}"#.to_string(),
101            response_schema_json: r#"{"type":"object","properties":{"extra_signals":{"type":"array"},"abort":{"type":"string"}}}"#.to_string(),
102        },
103        HookEventSchema {
104            event_name: "post_capture".to_string(),
105            schema_version: v1,
106            payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"}}}"#.to_string(),
107            response_schema_json: r#"{"type":"object"}"#.to_string(),
108        },
109        HookEventSchema {
110            event_name: "pre_merge".to_string(),
111            schema_version: v1,
112            payload_schema_json: r#"{"type":"object","properties":{"source":{"type":"string"},"target":{"type":"string"}}}"#.to_string(),
113            response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
114        },
115        HookEventSchema {
116            event_name: "post_merge".to_string(),
117            schema_version: v1,
118            payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"}}}"#.to_string(),
119            response_schema_json: r#"{"type":"object"}"#.to_string(),
120        },
121        HookEventSchema {
122            event_name: "on_conflict".to_string(),
123            schema_version: v1,
124            payload_schema_json: r#"{"type":"object","properties":{"conflicts":{"type":"array"}}}"#.to_string(),
125            response_schema_json: r#"{"type":"object","properties":{"veto":{"type":"object","properties":{"reason":{"type":"string"},"discussion_id":{"type":"string"}}}}}"#.to_string(),
126        },
127        HookEventSchema {
128            event_name: "pre_thread_create".to_string(),
129            schema_version: v1,
130            payload_schema_json: r#"{"type":"object","properties":{"name":{"type":"string"}}}"#.to_string(),
131            response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
132        },
133        HookEventSchema {
134            event_name: "post_thread_create".to_string(),
135            schema_version: v1,
136            payload_schema_json: r#"{"type":"object","properties":{"name":{"type":"string"}}}"#.to_string(),
137            response_schema_json: r#"{"type":"object"}"#.to_string(),
138        },
139        HookEventSchema {
140            event_name: "pre_push".to_string(),
141            schema_version: v1,
142            payload_schema_json: r#"{"type":"object","properties":{"remote":{"type":"string"}}}"#.to_string(),
143            response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
144        },
145        HookEventSchema {
146            event_name: "post_push".to_string(),
147            schema_version: v1,
148            payload_schema_json: r#"{"type":"object","properties":{"remote":{"type":"string"}}}"#.to_string(),
149            response_schema_json: r#"{"type":"object"}"#.to_string(),
150        },
151        HookEventSchema {
152            event_name: "on_signal".to_string(),
153            schema_version: v1,
154            payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"},"signal_kind":{"type":"string"}}}"#.to_string(),
155            response_schema_json: r#"{"type":"object"}"#.to_string(),
156        },
157    ]
158}
159
160/// Stream type for `SubscribeHookEvents`. Boxed so tonic can hand it
161/// back through the trait associated type without surfacing the
162/// concrete `mpsc::Receiver` shape.
163pub type SubscribeHookEventsStream =
164    Pin<Box<dyn Stream<Item = Result<ProtoHookEvent, Status>> + Send>>;
165
166#[tonic::async_trait]
167impl HookService for LocalHookService {
168    type SubscribeHookEventsStream = SubscribeHookEventsStream;
169
170    async fn register_hook(
171        &self,
172        request: Request<RegisterHookRequest>,
173    ) -> Result<Response<ProtoHook>, Status> {
174        let req = request.into_inner();
175        let body = req.encode_to_vec();
176        let heddle_dir = self.inner.repo().heddle_dir().to_path_buf();
177        let client_op = req.client_operation_id.clone();
178
179        let result = with_idempotency(
180            &self.inner,
181            &client_op,
182            "hook.register_hook",
183            &body,
184            || async move {
185                if req.name.trim().is_empty() {
186                    return Err(Status::invalid_argument("hook name must not be empty"));
187                }
188                if req.command.trim().is_empty() {
189                    return Err(Status::invalid_argument("hook command must not be empty"));
190                }
191                let catalog: std::collections::HashSet<String> =
192                    event_catalog().into_iter().map(|s| s.event_name).collect();
193                for event in &req.events {
194                    if !catalog.contains(event) {
195                        return Err(Status::invalid_argument(format!(
196                            "unknown hook event '{event}' — see GetHookEventSchema for the catalog"
197                        )));
198                    }
199                }
200                let mut registry = load_registry(&heddle_dir)?;
201                registry.hooks.retain(|h| h.name != req.name);
202                let cfg = HookConfig {
203                    name: req.name.clone(),
204                    command: req.command.clone(),
205                    events: req.events.clone(),
206                    timeout_ms: req.timeout_ms,
207                };
208                registry.hooks.push(cfg.clone());
209                save_registry(&heddle_dir, &registry)?;
210                Ok(cfg.to_proto())
211            },
212        )
213        .await?;
214        Ok(Response::new(result))
215    }
216
217    async fn deregister_hook(
218        &self,
219        request: Request<DeregisterHookRequest>,
220    ) -> Result<Response<DeleteResponse>, Status> {
221        let req = request.into_inner();
222        let body = req.encode_to_vec();
223        let heddle_dir = self.inner.repo().heddle_dir().to_path_buf();
224        let client_op = req.client_operation_id.clone();
225        let result = with_idempotency(
226            &self.inner,
227            &client_op,
228            "hook.deregister_hook",
229            &body,
230            || async move {
231                let mut registry = load_registry(&heddle_dir)?;
232                let before = registry.hooks.len();
233                registry.hooks.retain(|h| h.name != req.name);
234                let deleted = registry.hooks.len() < before;
235                if deleted {
236                    save_registry(&heddle_dir, &registry)?;
237                }
238                Ok(DeleteResponse { deleted })
239            },
240        )
241        .await?;
242        Ok(Response::new(result))
243    }
244
245    async fn list_hooks(
246        &self,
247        _request: Request<ListHooksRequest>,
248    ) -> Result<Response<ListHooksResponse>, Status> {
249        let registry = load_registry(self.inner.repo().heddle_dir())?;
250        let hooks = registry.hooks.iter().map(HookConfig::to_proto).collect();
251        Ok(Response::new(ListHooksResponse { hooks }))
252    }
253
254    async fn get_hook_event_schema(
255        &self,
256        request: Request<GetHookEventSchemaRequest>,
257    ) -> Result<Response<GetHookEventSchemaResponse>, Status> {
258        let req = request.into_inner();
259        let mut catalog = event_catalog();
260        if !req.event_name.is_empty() {
261            catalog.retain(|s| s.event_name == req.event_name);
262            if catalog.is_empty() {
263                return Err(Status::not_found(format!(
264                    "unknown hook event '{}'",
265                    req.event_name
266                )));
267            }
268        }
269        Ok(Response::new(GetHookEventSchemaResponse {
270            schemas: catalog,
271        }))
272    }
273
274    async fn subscribe_hook_events(
275        &self,
276        request: Request<SubscribeHookEventsRequest>,
277    ) -> Result<Response<Self::SubscribeHookEventsStream>, Status> {
278        let req = request.into_inner();
279        // Optional event-name filter. Empty = subscribe to every
280        // event in the catalog. Validate up front so a typo is a
281        // synchronous `InvalidArgument` rather than a silently-empty
282        // stream.
283        let catalog: std::collections::HashSet<String> =
284            event_catalog().into_iter().map(|s| s.event_name).collect();
285        for event in &req.events {
286            if !catalog.contains(event) {
287                return Err(Status::invalid_argument(format!(
288                    "unknown hook event '{event}' — see GetHookEventSchema for the catalog"
289                )));
290            }
291        }
292        let filter: std::collections::HashSet<String> = req.events.into_iter().collect();
293        let receiver = self.inner.hook_events.subscribe();
294        // Adapt the broker's `mpsc::Receiver<ProtoHookEvent>` into a
295        // `tonic::Stream<Result<ProtoHookEvent, Status>>`. Apply the
296        // event-name filter on the read side so subscribers don't pay
297        // for events they don't care about.
298        let stream = ReceiverStream::new(receiver).filter_map(move |event| {
299            if filter.is_empty() || filter.contains(&event.event_name) {
300                Some(Ok(event))
301            } else {
302                None
303            }
304        });
305        Ok(Response::new(Box::pin(stream)))
306    }
307
308    async fn respond_to_hook(
309        &self,
310        request: Request<RespondToHookRequest>,
311    ) -> Result<Response<RespondToHookResponse>, Status> {
312        let req = request.into_inner();
313        let body = req.encode_to_vec();
314        let client_op = req.client_operation_id.clone();
315        let broker = self.inner.hook_events.clone();
316        let result = with_idempotency(
317            &self.inner,
318            &client_op,
319            "hook.respond_to_hook",
320            &body,
321            move || async move {
322                if req.hook_event_id.trim().is_empty() {
323                    return Err(Status::invalid_argument("hook_event_id must not be empty"));
324                }
325                // Decode `extra_signals_json` lazily — empty string =
326                // no extra. Anything else must parse as JSON; a
327                // malformed payload surfaces as `InvalidArgument`
328                // rather than getting silently dropped on the
329                // emit-side.
330                let extra = if req.extra_signals_json.trim().is_empty() {
331                    serde_json::Value::Null
332                } else {
333                    serde_json::from_str::<serde_json::Value>(&req.extra_signals_json).map_err(
334                        |err| {
335                            Status::invalid_argument(format!(
336                                "extra_signals_json is not valid JSON: {err}"
337                            ))
338                        },
339                    )?
340                };
341                let response = HookResponse {
342                    abort: req.abort,
343                    extra,
344                };
345                let accepted = broker.deliver_response(&req.hook_event_id, response);
346                Ok(RespondToHookResponse { accepted })
347            },
348        )
349        .await?;
350        Ok(Response::new(result))
351    }
352}
353
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use repo::Repository;
    use tempfile::TempDir;

    use super::*;

    /// Spins up a `LocalHookService` over a freshly initialised repository.
    /// The `TempDir` is handed back so it outlives the test body.
    fn fresh_service() -> (TempDir, LocalHookService) {
        let scratch = TempDir::new().unwrap();
        let repository = Repository::init_default(scratch.path()).unwrap();
        let dedup_store =
            repo::operation_dedup::OperationDedupStore::open(repository.heddle_dir()).unwrap();
        let shared = GrpcLocalService::new(Arc::new(repository), Arc::new(dedup_store));
        (scratch, LocalHookService::new(shared))
    }

    /// A `RegisterHookRequest` with no repo path and no idempotency key.
    fn register_request(
        name: &str,
        command: &str,
        events: &[&str],
        timeout_ms: u32,
    ) -> RegisterHookRequest {
        RegisterHookRequest {
            repo_path: String::new(),
            name: name.to_string(),
            command: command.to_string(),
            events: events.iter().map(|event| event.to_string()).collect(),
            timeout_ms,
            client_operation_id: String::new(),
        }
    }

    /// A subscription request filtered to `events` (empty = everything).
    fn subscribe_request(events: &[&str]) -> Request<SubscribeHookEventsRequest> {
        Request::new(SubscribeHookEventsRequest {
            repo_path: String::new(),
            events: events.iter().map(|event| event.to_string()).collect(),
        })
    }

    /// A hook response with no extra signals and no idempotency key.
    fn respond_request(
        hook_event_id: impl Into<String>,
        abort: &str,
    ) -> Request<RespondToHookRequest> {
        Request::new(RespondToHookRequest {
            repo_path: String::new(),
            hook_event_id: hook_event_id.into(),
            abort: abort.to_string(),
            extra_signals_json: String::new(),
            client_operation_id: String::new(),
        })
    }

    /// Lists every registered hook through the RPC surface.
    async fn listed_hooks(svc: &LocalHookService) -> Vec<ProtoHook> {
        svc.list_hooks(Request::new(ListHooksRequest {
            repo_path: String::new(),
        }))
        .await
        .unwrap()
        .into_inner()
        .hooks
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn register_then_list_returns_hook() {
        let (_temp, svc) = fresh_service();
        svc.register_hook(Request::new(register_request(
            "log-capture",
            "/usr/local/bin/heddle-log",
            &["post_capture"],
            5000,
        )))
        .await
        .unwrap();
        let hooks = listed_hooks(&svc).await;
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].name, "log-capture");
        assert_eq!(hooks[0].events, vec!["post_capture".to_string()]);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn register_unknown_event_is_invalid_argument() {
        let (_temp, svc) = fresh_service();
        let status = svc
            .register_hook(Request::new(register_request(
                "x",
                "true",
                &["definitely_not_an_event"],
                0,
            )))
            .await
            .unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn deregister_removes_hook() {
        let (_temp, svc) = fresh_service();
        svc.register_hook(Request::new(register_request("x", "true", &["pre_capture"], 0)))
            .await
            .unwrap();
        let deleted = svc
            .deregister_hook(Request::new(DeregisterHookRequest {
                repo_path: String::new(),
                name: "x".to_string(),
                client_operation_id: String::new(),
            }))
            .await
            .unwrap()
            .into_inner()
            .deleted;
        assert!(deleted);
        assert!(listed_hooks(&svc).await.is_empty());
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn get_hook_event_schema_returns_full_catalog() {
        let (_temp, svc) = fresh_service();
        let schemas = svc
            .get_hook_event_schema(Request::new(GetHookEventSchemaRequest {
                event_name: String::new(),
            }))
            .await
            .unwrap()
            .into_inner()
            .schemas;
        for expected in ["pre_capture", "on_conflict"] {
            assert!(schemas.iter().any(|schema| schema.event_name == expected));
        }
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn get_hook_event_schema_unknown_returns_not_found() {
        let (_temp, svc) = fresh_service();
        let status = svc
            .get_hook_event_schema(Request::new(GetHookEventSchemaRequest {
                event_name: "pretend".to_string(),
            }))
            .await
            .unwrap_err();
        assert_eq!(status.code(), tonic::Code::NotFound);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn subscribe_then_emit_round_trips() {
        let (_temp, svc) = fresh_service();
        // The boxed stream is already `Unpin`, so it can be polled directly.
        let mut events = svc
            .subscribe_hook_events(subscribe_request(&[]))
            .await
            .unwrap()
            .into_inner();
        // Give the runtime a turn before emitting so the subscription is
        // in place when the broker fires.
        tokio::task::yield_now().await;
        let id = svc.inner.hook_events.emit("post_capture", "{}");
        let event = futures::StreamExt::next(&mut events)
            .await
            .expect("event")
            .expect("ok");
        assert_eq!(event.hook_event_id, id);
        assert_eq!(event.event_name, "post_capture");
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn subscribe_unknown_event_is_invalid_argument() {
        let (_temp, svc) = fresh_service();
        let outcome = svc
            .subscribe_hook_events(subscribe_request(&["definitely_not_an_event"]))
            .await;
        // `Response<Stream>` is not `Debug`, so `unwrap_err` is unavailable.
        match outcome {
            Ok(_) => panic!("expected InvalidArgument, got Ok"),
            Err(status) => assert_eq!(status.code(), tonic::Code::InvalidArgument),
        }
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn respond_to_hook_delivers_to_emit_waiter() {
        let (_temp, svc) = fresh_service();
        let _subscription = svc
            .subscribe_hook_events(subscribe_request(&[]))
            .await
            .unwrap()
            .into_inner();
        tokio::task::yield_now().await;
        let (id, waiter) =
            svc.inner
                .hook_events
                .emit_and_wait("pre_capture", "{}", Duration::from_secs(1));
        let accepted = svc
            .respond_to_hook(respond_request(id, "veto"))
            .await
            .unwrap()
            .into_inner()
            .accepted;
        assert!(accepted);
        let response = waiter.wait().await.expect("response");
        assert_eq!(response.abort, "veto");
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn respond_to_hook_rejects_empty_id() {
        let (_temp, svc) = fresh_service();
        let status = svc
            .respond_to_hook(respond_request(String::new(), ""))
            .await
            .unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn respond_to_hook_unknown_id_returns_not_accepted() {
        let (_temp, svc) = fresh_service();
        let accepted = svc
            .respond_to_hook(respond_request("made-up", ""))
            .await
            .unwrap()
            .into_inner()
            .accepted;
        assert!(!accepted);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn register_idempotent_returns_same_hook() {
        let (_temp, svc) = fresh_service();
        let req = RegisterHookRequest {
            client_operation_id: objects::object::OperationId::new().to_string(),
            ..register_request("foo", "true", &["pre_capture"], 1000)
        };
        let first = svc
            .register_hook(Request::new(req.clone()))
            .await
            .unwrap()
            .into_inner();
        let second = svc
            .register_hook(Request::new(req))
            .await
            .unwrap()
            .into_inner();
        assert_eq!(first, second);
        assert_eq!(listed_hooks(&svc).await.len(), 1);
    }
}