Skip to main content

daemon/grpc_local_impl/
hook.rs

1// SPDX-License-Identifier: Apache-2.0
//! Local-mode `HookService`. Manages hook registrations on disk,
//! exposes the event-schema catalog so hook authors can scaffold
//! payloads, and relays live events (subscribe + respond) through
//! the service's hook event broker.
6
7use std::{path::PathBuf, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11    DeleteResponse, DeregisterHookRequest, GetHookEventSchemaRequest, GetHookEventSchemaResponse,
12    Hook as ProtoHook, HookEvent as ProtoHookEvent, HookEventSchema, ListHooksRequest,
13    ListHooksResponse, RegisterHookRequest, RespondToHookRequest, RespondToHookResponse,
14    SubscribeHookEventsRequest, hook_service_server::HookService,
15};
16use objects::{error::HeddleError, fs_atomic::write_file_atomic};
17use prost::Message;
18use serde::{Deserialize, Serialize};
19use tokio_stream::{StreamExt, wrappers::ReceiverStream};
20use tonic::{Request, Response, Status};
21
22use super::{GrpcLocalService, HookResponse, to_status, with_idempotency};
23
24#[derive(Clone)]
25pub struct LocalHookService {
26    inner: GrpcLocalService,
27}
28
29impl LocalHookService {
30    pub fn new(inner: GrpcLocalService) -> Self {
31        Self { inner }
32    }
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36struct HookRegistry {
37    #[serde(default)]
38    hooks: Vec<HookConfig>,
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42struct HookConfig {
43    name: String,
44    command: String,
45    #[serde(default)]
46    events: Vec<String>,
47    #[serde(default)]
48    timeout_ms: u32,
49}
50
51impl HookConfig {
52    fn to_proto(&self) -> ProtoHook {
53        ProtoHook {
54            name: self.name.clone(),
55            command: self.command.clone(),
56            events: self.events.clone(),
57            timeout_ms: self.timeout_ms,
58        }
59    }
60}
61
62fn registry_path(heddle_dir: &std::path::Path) -> PathBuf {
63    heddle_dir.join("hooks").join("registry.toml")
64}
65
66fn load_registry(heddle_dir: &std::path::Path) -> Result<HookRegistry, Status> {
67    let path = registry_path(heddle_dir);
68    if !path.exists() {
69        return Ok(HookRegistry::default());
70    }
71    let raw = std::fs::read_to_string(&path).map_err(|e| to_status(HeddleError::from(e)))?;
72    toml::from_str(&raw).map_err(|e| {
73        Status::internal(format!(
74            "hook registry at {} is malformed: {e}",
75            path.display()
76        ))
77    })
78}
79
80fn save_registry(heddle_dir: &std::path::Path, registry: &HookRegistry) -> Result<(), Status> {
81    let path = registry_path(heddle_dir);
82    if let Some(parent) = path.parent() {
83        std::fs::create_dir_all(parent).map_err(|e| to_status(HeddleError::from(e)))?;
84    }
85    let raw = toml::to_string_pretty(registry)
86        .map_err(|e| Status::internal(format!("failed to encode hook registry: {e}")))?;
87    write_file_atomic(&path, raw.as_bytes()).map_err(|e| to_status(HeddleError::from(e)))
88}
89
90/// The hook event catalog. Each entry documents what the payload and
91/// response look like in JSON Schema. Exposing the catalog from
92/// `GetHookEventSchema` lets hook authors scaffold against the contract
93/// even before live event delivery is wired up.
94fn event_catalog() -> Vec<HookEventSchema> {
95    let v1 = 1;
96    vec![
97        HookEventSchema {
98            event_name: "pre_capture".to_string(),
99            schema_version: v1,
100            payload_schema_json: r#"{"type":"object","properties":{"thread":{"type":"string"},"intent":{"type":"string"}},"required":[]}"#.to_string(),
101            response_schema_json: r#"{"type":"object","properties":{"extra_signals":{"type":"array"},"abort":{"type":"string"}}}"#.to_string(),
102        },
103        HookEventSchema {
104            event_name: "post_capture".to_string(),
105            schema_version: v1,
106            payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"}}}"#.to_string(),
107            response_schema_json: r#"{"type":"object"}"#.to_string(),
108        },
109        HookEventSchema {
110            event_name: "pre_merge".to_string(),
111            schema_version: v1,
112            payload_schema_json: r#"{"type":"object","properties":{"source":{"type":"string"},"target":{"type":"string"}}}"#.to_string(),
113            response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
114        },
115        HookEventSchema {
116            event_name: "post_merge".to_string(),
117            schema_version: v1,
118            payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"}}}"#.to_string(),
119            response_schema_json: r#"{"type":"object"}"#.to_string(),
120        },
121        HookEventSchema {
122            event_name: "on_conflict".to_string(),
123            schema_version: v1,
124            payload_schema_json: r#"{"type":"object","properties":{"conflicts":{"type":"array"}}}"#.to_string(),
125            response_schema_json: r#"{"type":"object","properties":{"veto":{"type":"object","properties":{"reason":{"type":"string"},"discussion_id":{"type":"string"}}}}}"#.to_string(),
126        },
127        HookEventSchema {
128            event_name: "pre_thread_create".to_string(),
129            schema_version: v1,
130            payload_schema_json: r#"{"type":"object","properties":{"name":{"type":"string"}}}"#.to_string(),
131            response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
132        },
133        HookEventSchema {
134            event_name: "post_thread_create".to_string(),
135            schema_version: v1,
136            payload_schema_json: r#"{"type":"object","properties":{"name":{"type":"string"}}}"#.to_string(),
137            response_schema_json: r#"{"type":"object"}"#.to_string(),
138        },
139        HookEventSchema {
140            event_name: "pre_push".to_string(),
141            schema_version: v1,
142            payload_schema_json: r#"{"type":"object","properties":{"remote":{"type":"string"}}}"#.to_string(),
143            response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
144        },
145        HookEventSchema {
146            event_name: "post_push".to_string(),
147            schema_version: v1,
148            payload_schema_json: r#"{"type":"object","properties":{"remote":{"type":"string"}}}"#.to_string(),
149            response_schema_json: r#"{"type":"object"}"#.to_string(),
150        },
151        HookEventSchema {
152            event_name: "on_signal".to_string(),
153            schema_version: v1,
154            payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"},"signal_kind":{"type":"string"}}}"#.to_string(),
155            response_schema_json: r#"{"type":"object"}"#.to_string(),
156        },
157    ]
158}
159
160/// Stream type for `SubscribeHookEvents`. Boxed so tonic can hand it
161/// back through the trait associated type without surfacing the
162/// concrete `mpsc::Receiver` shape.
163pub type SubscribeHookEventsStream =
164    Pin<Box<dyn Stream<Item = Result<ProtoHookEvent, Status>> + Send>>;
165
166#[tonic::async_trait]
167impl HookService for LocalHookService {
168    type SubscribeHookEventsStream = SubscribeHookEventsStream;
169
170    async fn register_hook(
171        &self,
172        request: Request<RegisterHookRequest>,
173    ) -> Result<Response<ProtoHook>, Status> {
174        let req = request.into_inner();
175        let body = req.encode_to_vec();
176        let heddle_dir = self.inner.repo().heddle_dir().to_path_buf();
177        let dedup = self.inner.dedup();
178        let client_op = req.client_operation_id.clone();
179
180        let result = with_idempotency(
181            dedup,
182            &client_op,
183            "hook.register_hook",
184            &body,
185            |hook: &ProtoHook| hook.encode_to_vec(),
186            |bytes| ProtoHook::decode(&bytes[..]).map_err(|e| Status::internal(e.to_string())),
187            || async move {
188                if req.name.trim().is_empty() {
189                    return Err(Status::invalid_argument("hook name must not be empty"));
190                }
191                if req.command.trim().is_empty() {
192                    return Err(Status::invalid_argument("hook command must not be empty"));
193                }
194                let catalog: std::collections::HashSet<String> =
195                    event_catalog().into_iter().map(|s| s.event_name).collect();
196                for event in &req.events {
197                    if !catalog.contains(event) {
198                        return Err(Status::invalid_argument(format!(
199                            "unknown hook event '{event}' — see GetHookEventSchema for the catalog"
200                        )));
201                    }
202                }
203                let mut registry = load_registry(&heddle_dir)?;
204                registry.hooks.retain(|h| h.name != req.name);
205                let cfg = HookConfig {
206                    name: req.name.clone(),
207                    command: req.command.clone(),
208                    events: req.events.clone(),
209                    timeout_ms: req.timeout_ms,
210                };
211                registry.hooks.push(cfg.clone());
212                save_registry(&heddle_dir, &registry)?;
213                Ok(cfg.to_proto())
214            },
215        )
216        .await?;
217        Ok(Response::new(result))
218    }
219
220    async fn deregister_hook(
221        &self,
222        request: Request<DeregisterHookRequest>,
223    ) -> Result<Response<DeleteResponse>, Status> {
224        let req = request.into_inner();
225        let body = req.encode_to_vec();
226        let heddle_dir = self.inner.repo().heddle_dir().to_path_buf();
227        let dedup = self.inner.dedup();
228        let client_op = req.client_operation_id.clone();
229        let result = with_idempotency(
230            dedup,
231            &client_op,
232            "hook.deregister_hook",
233            &body,
234            |resp: &DeleteResponse| resp.encode_to_vec(),
235            |bytes| DeleteResponse::decode(&bytes[..]).map_err(|e| Status::internal(e.to_string())),
236            || async move {
237                let mut registry = load_registry(&heddle_dir)?;
238                let before = registry.hooks.len();
239                registry.hooks.retain(|h| h.name != req.name);
240                let deleted = registry.hooks.len() < before;
241                if deleted {
242                    save_registry(&heddle_dir, &registry)?;
243                }
244                Ok(DeleteResponse { deleted })
245            },
246        )
247        .await?;
248        Ok(Response::new(result))
249    }
250
251    async fn list_hooks(
252        &self,
253        _request: Request<ListHooksRequest>,
254    ) -> Result<Response<ListHooksResponse>, Status> {
255        let registry = load_registry(self.inner.repo().heddle_dir())?;
256        let hooks = registry.hooks.iter().map(HookConfig::to_proto).collect();
257        Ok(Response::new(ListHooksResponse { hooks }))
258    }
259
260    async fn get_hook_event_schema(
261        &self,
262        request: Request<GetHookEventSchemaRequest>,
263    ) -> Result<Response<GetHookEventSchemaResponse>, Status> {
264        let req = request.into_inner();
265        let mut catalog = event_catalog();
266        if !req.event_name.is_empty() {
267            catalog.retain(|s| s.event_name == req.event_name);
268            if catalog.is_empty() {
269                return Err(Status::not_found(format!(
270                    "unknown hook event '{}'",
271                    req.event_name
272                )));
273            }
274        }
275        Ok(Response::new(GetHookEventSchemaResponse {
276            schemas: catalog,
277        }))
278    }
279
280    async fn subscribe_hook_events(
281        &self,
282        request: Request<SubscribeHookEventsRequest>,
283    ) -> Result<Response<Self::SubscribeHookEventsStream>, Status> {
284        let req = request.into_inner();
285        // Optional event-name filter. Empty = subscribe to every
286        // event in the catalog. Validate up front so a typo is a
287        // synchronous `InvalidArgument` rather than a silently-empty
288        // stream.
289        let catalog: std::collections::HashSet<String> =
290            event_catalog().into_iter().map(|s| s.event_name).collect();
291        for event in &req.events {
292            if !catalog.contains(event) {
293                return Err(Status::invalid_argument(format!(
294                    "unknown hook event '{event}' — see GetHookEventSchema for the catalog"
295                )));
296            }
297        }
298        let filter: std::collections::HashSet<String> = req.events.into_iter().collect();
299        let receiver = self.inner.hook_events.subscribe();
300        // Adapt the broker's `mpsc::Receiver<ProtoHookEvent>` into a
301        // `tonic::Stream<Result<ProtoHookEvent, Status>>`. Apply the
302        // event-name filter on the read side so subscribers don't pay
303        // for events they don't care about.
304        let stream = ReceiverStream::new(receiver).filter_map(move |event| {
305            if filter.is_empty() || filter.contains(&event.event_name) {
306                Some(Ok(event))
307            } else {
308                None
309            }
310        });
311        Ok(Response::new(Box::pin(stream)))
312    }
313
314    async fn respond_to_hook(
315        &self,
316        request: Request<RespondToHookRequest>,
317    ) -> Result<Response<RespondToHookResponse>, Status> {
318        let req = request.into_inner();
319        let body = req.encode_to_vec();
320        let dedup = self.inner.dedup();
321        let client_op = req.client_operation_id.clone();
322        let broker = self.inner.hook_events.clone();
323        let result = with_idempotency(
324            dedup,
325            &client_op,
326            "hook.respond_to_hook",
327            &body,
328            |resp: &RespondToHookResponse| resp.encode_to_vec(),
329            |bytes| {
330                RespondToHookResponse::decode(&bytes[..])
331                    .map_err(|e| Status::internal(e.to_string()))
332            },
333            move || async move {
334                if req.hook_event_id.trim().is_empty() {
335                    return Err(Status::invalid_argument("hook_event_id must not be empty"));
336                }
337                // Decode `extra_signals_json` lazily — empty string =
338                // no extra. Anything else must parse as JSON; a
339                // malformed payload surfaces as `InvalidArgument`
340                // rather than getting silently dropped on the
341                // emit-side.
342                let extra = if req.extra_signals_json.trim().is_empty() {
343                    serde_json::Value::Null
344                } else {
345                    serde_json::from_str::<serde_json::Value>(&req.extra_signals_json).map_err(
346                        |err| {
347                            Status::invalid_argument(format!(
348                                "extra_signals_json is not valid JSON: {err}"
349                            ))
350                        },
351                    )?
352                };
353                let response = HookResponse {
354                    abort: req.abort,
355                    extra,
356                };
357                let accepted = broker.deliver_response(&req.hook_event_id, response);
358                Ok(RespondToHookResponse { accepted })
359            },
360        )
361        .await?;
362        Ok(Response::new(result))
363    }
364}
365
366#[cfg(test)]
367mod tests {
368    use std::sync::Arc;
369
370    use repo::Repository;
371    use tempfile::TempDir;
372
373    use super::*;
374
375    fn fresh_service() -> (TempDir, LocalHookService) {
376        let temp = TempDir::new().unwrap();
377        let repo = Repository::init_default(temp.path()).unwrap();
378        let dedup =
379            Arc::new(repo::operation_dedup::OperationDedupStore::open(repo.heddle_dir()).unwrap());
380        let inner = GrpcLocalService::new(Arc::new(repo), dedup);
381        let svc = LocalHookService::new(inner);
382        (temp, svc)
383    }
384
385    #[tokio::test]
386    async fn register_then_list_returns_hook() {
387        let (_t, svc) = fresh_service();
388        svc.register_hook(Request::new(RegisterHookRequest {
389            repo_path: String::new(),
390            name: "log-capture".into(),
391            command: "/usr/local/bin/heddle-log".into(),
392            events: vec!["post_capture".into()],
393            timeout_ms: 5000,
394            client_operation_id: String::new(),
395        }))
396        .await
397        .unwrap();
398        let resp = svc
399            .list_hooks(Request::new(ListHooksRequest {
400                repo_path: String::new(),
401            }))
402            .await
403            .unwrap();
404        let hooks = resp.into_inner().hooks;
405        assert_eq!(hooks.len(), 1);
406        assert_eq!(hooks[0].name, "log-capture");
407        assert_eq!(hooks[0].events, vec!["post_capture".to_string()]);
408    }
409
410    #[tokio::test]
411    async fn register_unknown_event_is_invalid_argument() {
412        let (_t, svc) = fresh_service();
413        let err = svc
414            .register_hook(Request::new(RegisterHookRequest {
415                repo_path: String::new(),
416                name: "x".into(),
417                command: "true".into(),
418                events: vec!["definitely_not_an_event".into()],
419                timeout_ms: 0,
420                client_operation_id: String::new(),
421            }))
422            .await
423            .unwrap_err();
424        assert_eq!(err.code(), tonic::Code::InvalidArgument);
425    }
426
427    #[tokio::test]
428    async fn deregister_removes_hook() {
429        let (_t, svc) = fresh_service();
430        svc.register_hook(Request::new(RegisterHookRequest {
431            repo_path: String::new(),
432            name: "x".into(),
433            command: "true".into(),
434            events: vec!["pre_capture".into()],
435            timeout_ms: 0,
436            client_operation_id: String::new(),
437        }))
438        .await
439        .unwrap();
440        let resp = svc
441            .deregister_hook(Request::new(DeregisterHookRequest {
442                repo_path: String::new(),
443                name: "x".into(),
444                client_operation_id: String::new(),
445            }))
446            .await
447            .unwrap();
448        assert!(resp.into_inner().deleted);
449        let listed = svc
450            .list_hooks(Request::new(ListHooksRequest {
451                repo_path: String::new(),
452            }))
453            .await
454            .unwrap();
455        assert!(listed.into_inner().hooks.is_empty());
456    }
457
458    #[tokio::test]
459    async fn get_hook_event_schema_returns_full_catalog() {
460        let (_t, svc) = fresh_service();
461        let resp = svc
462            .get_hook_event_schema(Request::new(GetHookEventSchemaRequest {
463                event_name: String::new(),
464            }))
465            .await
466            .unwrap();
467        let catalog = resp.into_inner().schemas;
468        assert!(catalog.iter().any(|s| s.event_name == "pre_capture"));
469        assert!(catalog.iter().any(|s| s.event_name == "on_conflict"));
470    }
471
472    #[tokio::test]
473    async fn get_hook_event_schema_unknown_returns_not_found() {
474        let (_t, svc) = fresh_service();
475        let err = svc
476            .get_hook_event_schema(Request::new(GetHookEventSchemaRequest {
477                event_name: "pretend".into(),
478            }))
479            .await
480            .unwrap_err();
481        assert_eq!(err.code(), tonic::Code::NotFound);
482    }
483
484    #[tokio::test]
485    async fn subscribe_then_emit_round_trips() {
486        let (_t, svc) = fresh_service();
487        let stream = svc
488            .subscribe_hook_events(Request::new(SubscribeHookEventsRequest {
489                repo_path: String::new(),
490                events: vec![],
491            }))
492            .await
493            .unwrap()
494            .into_inner();
495        let mut stream = Box::pin(stream);
496        // Yield so the subscriber's forwarding task is wired up
497        // before the broker emit fires.
498        tokio::task::yield_now().await;
499        let id = svc.inner.hook_events.emit("post_capture", "{}");
500        let event = futures::StreamExt::next(&mut stream)
501            .await
502            .expect("event")
503            .expect("ok");
504        assert_eq!(event.hook_event_id, id);
505        assert_eq!(event.event_name, "post_capture");
506    }
507
508    #[tokio::test]
509    async fn subscribe_unknown_event_is_invalid_argument() {
510        let (_t, svc) = fresh_service();
511        let result = svc
512            .subscribe_hook_events(Request::new(SubscribeHookEventsRequest {
513                repo_path: String::new(),
514                events: vec!["definitely_not_an_event".into()],
515            }))
516            .await;
517        // `Response<Stream>` doesn't implement Debug so we can't
518        // `unwrap_err` here. Match on the result instead.
519        match result {
520            Err(status) => assert_eq!(status.code(), tonic::Code::InvalidArgument),
521            Ok(_) => panic!("expected InvalidArgument, got Ok"),
522        }
523    }
524
525    #[tokio::test]
526    async fn respond_to_hook_delivers_to_emit_waiter() {
527        use std::time::Duration;
528        let (_t, svc) = fresh_service();
529        let _stream = svc
530            .subscribe_hook_events(Request::new(SubscribeHookEventsRequest {
531                repo_path: String::new(),
532                events: vec![],
533            }))
534            .await
535            .unwrap()
536            .into_inner();
537        tokio::task::yield_now().await;
538        let (id, waiter) =
539            svc.inner
540                .hook_events
541                .emit_and_wait("pre_capture", "{}", Duration::from_secs(1));
542        let resp = svc
543            .respond_to_hook(Request::new(RespondToHookRequest {
544                repo_path: String::new(),
545                hook_event_id: id,
546                abort: "veto".into(),
547                extra_signals_json: String::new(),
548                client_operation_id: String::new(),
549            }))
550            .await
551            .unwrap();
552        assert!(resp.into_inner().accepted);
553        let response = waiter.wait().await.expect("response");
554        assert_eq!(response.abort, "veto");
555    }
556
557    #[tokio::test]
558    async fn respond_to_hook_rejects_empty_id() {
559        let (_t, svc) = fresh_service();
560        let err = svc
561            .respond_to_hook(Request::new(RespondToHookRequest {
562                repo_path: String::new(),
563                hook_event_id: String::new(),
564                abort: String::new(),
565                extra_signals_json: String::new(),
566                client_operation_id: String::new(),
567            }))
568            .await
569            .unwrap_err();
570        assert_eq!(err.code(), tonic::Code::InvalidArgument);
571    }
572
573    #[tokio::test]
574    async fn respond_to_hook_unknown_id_returns_not_accepted() {
575        let (_t, svc) = fresh_service();
576        let resp = svc
577            .respond_to_hook(Request::new(RespondToHookRequest {
578                repo_path: String::new(),
579                hook_event_id: "made-up".into(),
580                abort: String::new(),
581                extra_signals_json: String::new(),
582                client_operation_id: String::new(),
583            }))
584            .await
585            .unwrap();
586        assert!(!resp.into_inner().accepted);
587    }
588
589    #[tokio::test]
590    async fn register_idempotent_returns_same_hook() {
591        let (_t, svc) = fresh_service();
592        let op_id = objects::object::OperationId::new().to_string();
593        let req = RegisterHookRequest {
594            repo_path: String::new(),
595            name: "foo".into(),
596            command: "true".into(),
597            events: vec!["pre_capture".into()],
598            timeout_ms: 1000,
599            client_operation_id: op_id.clone(),
600        };
601        let first = svc
602            .register_hook(Request::new(req.clone()))
603            .await
604            .unwrap()
605            .into_inner();
606        let second = svc
607            .register_hook(Request::new(req))
608            .await
609            .unwrap()
610            .into_inner();
611        assert_eq!(first, second);
612        let listed = svc
613            .list_hooks(Request::new(ListHooksRequest {
614                repo_path: String::new(),
615            }))
616            .await
617            .unwrap();
618        assert_eq!(listed.into_inner().hooks.len(), 1);
619    }
620}