1use std::{path::PathBuf, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11 DeleteResponse, DeregisterHookRequest, GetHookEventSchemaRequest, GetHookEventSchemaResponse,
12 Hook as ProtoHook, HookEvent as ProtoHookEvent, HookEventSchema, ListHooksRequest,
13 ListHooksResponse, RegisterHookRequest, RespondToHookRequest, RespondToHookResponse,
14 SubscribeHookEventsRequest, hook_service_server::HookService,
15};
16use objects::{error::HeddleError, fs_atomic::write_file_atomic};
17use prost::Message;
18use serde::{Deserialize, Serialize};
19use tokio_stream::{StreamExt, wrappers::ReceiverStream};
20use tonic::{Request, Response, Status};
21
22use super::{GrpcLocalService, HookResponse, to_status, with_idempotency};
23
24#[derive(Clone)]
25pub struct LocalHookService {
26 inner: GrpcLocalService,
27}
28
29impl LocalHookService {
30 pub fn new(inner: GrpcLocalService) -> Self {
31 Self { inner }
32 }
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36struct HookRegistry {
37 #[serde(default)]
38 hooks: Vec<HookConfig>,
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42struct HookConfig {
43 name: String,
44 command: String,
45 #[serde(default)]
46 events: Vec<String>,
47 #[serde(default)]
48 timeout_ms: u32,
49}
50
51impl HookConfig {
52 fn to_proto(&self) -> ProtoHook {
53 ProtoHook {
54 name: self.name.clone(),
55 command: self.command.clone(),
56 events: self.events.clone(),
57 timeout_ms: self.timeout_ms,
58 }
59 }
60}
61
/// Returns the location of the hook registry file under `heddle_dir`,
/// i.e. `<heddle_dir>/hooks/registry.toml`.
fn registry_path(heddle_dir: &std::path::Path) -> PathBuf {
    let mut location = heddle_dir.to_path_buf();
    location.push("hooks");
    location.push("registry.toml");
    location
}
65
66fn load_registry(heddle_dir: &std::path::Path) -> Result<HookRegistry, Status> {
67 let path = registry_path(heddle_dir);
68 if !path.exists() {
69 return Ok(HookRegistry::default());
70 }
71 let raw = std::fs::read_to_string(&path).map_err(|e| to_status(HeddleError::from(e)))?;
72 toml::from_str(&raw).map_err(|e| {
73 Status::internal(format!(
74 "hook registry at {} is malformed: {e}",
75 path.display()
76 ))
77 })
78}
79
80fn save_registry(heddle_dir: &std::path::Path, registry: &HookRegistry) -> Result<(), Status> {
81 let path = registry_path(heddle_dir);
82 if let Some(parent) = path.parent() {
83 std::fs::create_dir_all(parent).map_err(|e| to_status(HeddleError::from(e)))?;
84 }
85 let raw = toml::to_string_pretty(registry)
86 .map_err(|e| Status::internal(format!("failed to encode hook registry: {e}")))?;
87 write_file_atomic(&path, raw.as_bytes()).map_err(|e| to_status(HeddleError::from(e)))
88}
89
90fn event_catalog() -> Vec<HookEventSchema> {
95 let v1 = 1;
96 vec![
97 HookEventSchema {
98 event_name: "pre_capture".to_string(),
99 schema_version: v1,
100 payload_schema_json: r#"{"type":"object","properties":{"thread":{"type":"string"},"intent":{"type":"string"}},"required":[]}"#.to_string(),
101 response_schema_json: r#"{"type":"object","properties":{"extra_signals":{"type":"array"},"abort":{"type":"string"}}}"#.to_string(),
102 },
103 HookEventSchema {
104 event_name: "post_capture".to_string(),
105 schema_version: v1,
106 payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"}}}"#.to_string(),
107 response_schema_json: r#"{"type":"object"}"#.to_string(),
108 },
109 HookEventSchema {
110 event_name: "pre_merge".to_string(),
111 schema_version: v1,
112 payload_schema_json: r#"{"type":"object","properties":{"source":{"type":"string"},"target":{"type":"string"}}}"#.to_string(),
113 response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
114 },
115 HookEventSchema {
116 event_name: "post_merge".to_string(),
117 schema_version: v1,
118 payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"}}}"#.to_string(),
119 response_schema_json: r#"{"type":"object"}"#.to_string(),
120 },
121 HookEventSchema {
122 event_name: "on_conflict".to_string(),
123 schema_version: v1,
124 payload_schema_json: r#"{"type":"object","properties":{"conflicts":{"type":"array"}}}"#.to_string(),
125 response_schema_json: r#"{"type":"object","properties":{"veto":{"type":"object","properties":{"reason":{"type":"string"},"discussion_id":{"type":"string"}}}}}"#.to_string(),
126 },
127 HookEventSchema {
128 event_name: "pre_thread_create".to_string(),
129 schema_version: v1,
130 payload_schema_json: r#"{"type":"object","properties":{"name":{"type":"string"}}}"#.to_string(),
131 response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
132 },
133 HookEventSchema {
134 event_name: "post_thread_create".to_string(),
135 schema_version: v1,
136 payload_schema_json: r#"{"type":"object","properties":{"name":{"type":"string"}}}"#.to_string(),
137 response_schema_json: r#"{"type":"object"}"#.to_string(),
138 },
139 HookEventSchema {
140 event_name: "pre_push".to_string(),
141 schema_version: v1,
142 payload_schema_json: r#"{"type":"object","properties":{"remote":{"type":"string"}}}"#.to_string(),
143 response_schema_json: r#"{"type":"object","properties":{"abort":{"type":"string"}}}"#.to_string(),
144 },
145 HookEventSchema {
146 event_name: "post_push".to_string(),
147 schema_version: v1,
148 payload_schema_json: r#"{"type":"object","properties":{"remote":{"type":"string"}}}"#.to_string(),
149 response_schema_json: r#"{"type":"object"}"#.to_string(),
150 },
151 HookEventSchema {
152 event_name: "on_signal".to_string(),
153 schema_version: v1,
154 payload_schema_json: r#"{"type":"object","properties":{"state_id":{"type":"string"},"signal_kind":{"type":"string"}}}"#.to_string(),
155 response_schema_json: r#"{"type":"object"}"#.to_string(),
156 },
157 ]
158}
159
160pub type SubscribeHookEventsStream =
164 Pin<Box<dyn Stream<Item = Result<ProtoHookEvent, Status>> + Send>>;
165
166#[tonic::async_trait]
167impl HookService for LocalHookService {
168 type SubscribeHookEventsStream = SubscribeHookEventsStream;
169
170 async fn register_hook(
171 &self,
172 request: Request<RegisterHookRequest>,
173 ) -> Result<Response<ProtoHook>, Status> {
174 let req = request.into_inner();
175 let body = req.encode_to_vec();
176 let heddle_dir = self.inner.repo().heddle_dir().to_path_buf();
177 let client_op = req.client_operation_id.clone();
178
179 let result = with_idempotency(
180 &self.inner,
181 &client_op,
182 "hook.register_hook",
183 &body,
184 || async move {
185 if req.name.trim().is_empty() {
186 return Err(Status::invalid_argument("hook name must not be empty"));
187 }
188 if req.command.trim().is_empty() {
189 return Err(Status::invalid_argument("hook command must not be empty"));
190 }
191 let catalog: std::collections::HashSet<String> =
192 event_catalog().into_iter().map(|s| s.event_name).collect();
193 for event in &req.events {
194 if !catalog.contains(event) {
195 return Err(Status::invalid_argument(format!(
196 "unknown hook event '{event}' — see GetHookEventSchema for the catalog"
197 )));
198 }
199 }
200 let mut registry = load_registry(&heddle_dir)?;
201 registry.hooks.retain(|h| h.name != req.name);
202 let cfg = HookConfig {
203 name: req.name.clone(),
204 command: req.command.clone(),
205 events: req.events.clone(),
206 timeout_ms: req.timeout_ms,
207 };
208 registry.hooks.push(cfg.clone());
209 save_registry(&heddle_dir, ®istry)?;
210 Ok(cfg.to_proto())
211 },
212 )
213 .await?;
214 Ok(Response::new(result))
215 }
216
217 async fn deregister_hook(
218 &self,
219 request: Request<DeregisterHookRequest>,
220 ) -> Result<Response<DeleteResponse>, Status> {
221 let req = request.into_inner();
222 let body = req.encode_to_vec();
223 let heddle_dir = self.inner.repo().heddle_dir().to_path_buf();
224 let client_op = req.client_operation_id.clone();
225 let result = with_idempotency(
226 &self.inner,
227 &client_op,
228 "hook.deregister_hook",
229 &body,
230 || async move {
231 let mut registry = load_registry(&heddle_dir)?;
232 let before = registry.hooks.len();
233 registry.hooks.retain(|h| h.name != req.name);
234 let deleted = registry.hooks.len() < before;
235 if deleted {
236 save_registry(&heddle_dir, ®istry)?;
237 }
238 Ok(DeleteResponse { deleted })
239 },
240 )
241 .await?;
242 Ok(Response::new(result))
243 }
244
245 async fn list_hooks(
246 &self,
247 _request: Request<ListHooksRequest>,
248 ) -> Result<Response<ListHooksResponse>, Status> {
249 let registry = load_registry(self.inner.repo().heddle_dir())?;
250 let hooks = registry.hooks.iter().map(HookConfig::to_proto).collect();
251 Ok(Response::new(ListHooksResponse { hooks }))
252 }
253
254 async fn get_hook_event_schema(
255 &self,
256 request: Request<GetHookEventSchemaRequest>,
257 ) -> Result<Response<GetHookEventSchemaResponse>, Status> {
258 let req = request.into_inner();
259 let mut catalog = event_catalog();
260 if !req.event_name.is_empty() {
261 catalog.retain(|s| s.event_name == req.event_name);
262 if catalog.is_empty() {
263 return Err(Status::not_found(format!(
264 "unknown hook event '{}'",
265 req.event_name
266 )));
267 }
268 }
269 Ok(Response::new(GetHookEventSchemaResponse {
270 schemas: catalog,
271 }))
272 }
273
274 async fn subscribe_hook_events(
275 &self,
276 request: Request<SubscribeHookEventsRequest>,
277 ) -> Result<Response<Self::SubscribeHookEventsStream>, Status> {
278 let req = request.into_inner();
279 let catalog: std::collections::HashSet<String> =
284 event_catalog().into_iter().map(|s| s.event_name).collect();
285 for event in &req.events {
286 if !catalog.contains(event) {
287 return Err(Status::invalid_argument(format!(
288 "unknown hook event '{event}' — see GetHookEventSchema for the catalog"
289 )));
290 }
291 }
292 let filter: std::collections::HashSet<String> = req.events.into_iter().collect();
293 let receiver = self.inner.hook_events.subscribe();
294 let stream = ReceiverStream::new(receiver).filter_map(move |event| {
299 if filter.is_empty() || filter.contains(&event.event_name) {
300 Some(Ok(event))
301 } else {
302 None
303 }
304 });
305 Ok(Response::new(Box::pin(stream)))
306 }
307
308 async fn respond_to_hook(
309 &self,
310 request: Request<RespondToHookRequest>,
311 ) -> Result<Response<RespondToHookResponse>, Status> {
312 let req = request.into_inner();
313 let body = req.encode_to_vec();
314 let client_op = req.client_operation_id.clone();
315 let broker = self.inner.hook_events.clone();
316 let result = with_idempotency(
317 &self.inner,
318 &client_op,
319 "hook.respond_to_hook",
320 &body,
321 move || async move {
322 if req.hook_event_id.trim().is_empty() {
323 return Err(Status::invalid_argument("hook_event_id must not be empty"));
324 }
325 let extra = if req.extra_signals_json.trim().is_empty() {
331 serde_json::Value::Null
332 } else {
333 serde_json::from_str::<serde_json::Value>(&req.extra_signals_json).map_err(
334 |err| {
335 Status::invalid_argument(format!(
336 "extra_signals_json is not valid JSON: {err}"
337 ))
338 },
339 )?
340 };
341 let response = HookResponse {
342 abort: req.abort,
343 extra,
344 };
345 let accepted = broker.deliver_response(&req.hook_event_id, response);
346 Ok(RespondToHookResponse { accepted })
347 },
348 )
349 .await?;
350 Ok(Response::new(result))
351 }
352}
353
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use repo::Repository;
    use tempfile::TempDir;

    use super::*;

    /// Creates a service over a freshly initialised repository in a temp dir.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for
    /// the duration of the test.
    fn fresh_service() -> (TempDir, LocalHookService) {
        let workdir = TempDir::new().unwrap();
        let repository = Repository::init_default(workdir.path()).unwrap();
        let dedup_store =
            repo::operation_dedup::OperationDedupStore::open(repository.heddle_dir()).unwrap();
        let local = GrpcLocalService::new(Arc::new(repository), Arc::new(dedup_store));
        (workdir, LocalHookService::new(local))
    }

    /// Builds a `RegisterHookRequest` with an empty `repo_path`.
    fn register_request(
        name: &str,
        command: &str,
        events: &[&str],
        timeout_ms: u32,
        client_operation_id: &str,
    ) -> RegisterHookRequest {
        RegisterHookRequest {
            repo_path: String::new(),
            name: name.to_owned(),
            command: command.to_owned(),
            events: events.iter().map(|event| (*event).to_owned()).collect(),
            timeout_ms,
            client_operation_id: client_operation_id.to_owned(),
        }
    }

    /// Builds a subscribe request filtered to `events` (empty = no filter).
    fn subscribe_request(events: &[&str]) -> Request<SubscribeHookEventsRequest> {
        Request::new(SubscribeHookEventsRequest {
            repo_path: String::new(),
            events: events.iter().map(|event| (*event).to_owned()).collect(),
        })
    }

    /// Builds a respond request with empty extra signals and operation id.
    fn respond_request(hook_event_id: &str, abort: &str) -> Request<RespondToHookRequest> {
        Request::new(RespondToHookRequest {
            repo_path: String::new(),
            hook_event_id: hook_event_id.to_owned(),
            abort: abort.to_owned(),
            extra_signals_json: String::new(),
            client_operation_id: String::new(),
        })
    }

    /// Calls `list_hooks` and returns the hooks from the response.
    async fn listed_hooks(svc: &LocalHookService) -> Vec<ProtoHook> {
        svc.list_hooks(Request::new(ListHooksRequest {
            repo_path: String::new(),
        }))
        .await
        .unwrap()
        .into_inner()
        .hooks
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn register_then_list_returns_hook() {
        let (_guard, svc) = fresh_service();
        let request = register_request(
            "log-capture",
            "/usr/local/bin/heddle-log",
            &["post_capture"],
            5000,
            "",
        );
        svc.register_hook(Request::new(request)).await.unwrap();

        let hooks = listed_hooks(&svc).await;
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].name, "log-capture");
        assert_eq!(hooks[0].events, vec!["post_capture".to_string()]);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn register_unknown_event_is_invalid_argument() {
        let (_guard, svc) = fresh_service();
        let request = register_request("x", "true", &["definitely_not_an_event"], 0, "");
        let status = svc
            .register_hook(Request::new(request))
            .await
            .unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn deregister_removes_hook() {
        let (_guard, svc) = fresh_service();
        let request = register_request("x", "true", &["pre_capture"], 0, "");
        svc.register_hook(Request::new(request)).await.unwrap();

        let removal = svc
            .deregister_hook(Request::new(DeregisterHookRequest {
                repo_path: String::new(),
                name: "x".into(),
                client_operation_id: String::new(),
            }))
            .await
            .unwrap()
            .into_inner();
        assert!(removal.deleted);

        assert!(listed_hooks(&svc).await.is_empty());
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn get_hook_event_schema_returns_full_catalog() {
        let (_guard, svc) = fresh_service();
        let schemas = svc
            .get_hook_event_schema(Request::new(GetHookEventSchemaRequest {
                event_name: String::new(),
            }))
            .await
            .unwrap()
            .into_inner()
            .schemas;
        for expected in ["pre_capture", "on_conflict"] {
            assert!(schemas.iter().any(|schema| schema.event_name == expected));
        }
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn get_hook_event_schema_unknown_returns_not_found() {
        let (_guard, svc) = fresh_service();
        let status = svc
            .get_hook_event_schema(Request::new(GetHookEventSchemaRequest {
                event_name: "pretend".into(),
            }))
            .await
            .unwrap_err();
        assert_eq!(status.code(), tonic::Code::NotFound);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn subscribe_then_emit_round_trips() {
        let (_guard, svc) = fresh_service();
        // The returned stream is already a pinned box, so it can be polled
        // directly without pinning it a second time.
        let mut events = svc
            .subscribe_hook_events(subscribe_request(&[]))
            .await
            .unwrap()
            .into_inner();
        tokio::task::yield_now().await;

        let emitted_id = svc.inner.hook_events.emit("post_capture", "{}");
        let received = futures::StreamExt::next(&mut events)
            .await
            .expect("event")
            .expect("ok");
        assert_eq!(received.hook_event_id, emitted_id);
        assert_eq!(received.event_name, "post_capture");
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn subscribe_unknown_event_is_invalid_argument() {
        let (_guard, svc) = fresh_service();
        let outcome = svc
            .subscribe_hook_events(subscribe_request(&["definitely_not_an_event"]))
            .await;
        // `.err()` avoids needing `Debug` on the stream in the `Ok` variant.
        let status = outcome.err().expect("expected InvalidArgument, got Ok");
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn respond_to_hook_delivers_to_emit_waiter() {
        let (_guard, svc) = fresh_service();
        // Held for the whole test so a subscriber exists while emitting.
        let _subscription = svc
            .subscribe_hook_events(subscribe_request(&[]))
            .await
            .unwrap()
            .into_inner();
        tokio::task::yield_now().await;

        let (event_id, waiter) = svc.inner.hook_events.emit_and_wait(
            "pre_capture",
            "{}",
            std::time::Duration::from_secs(1),
        );
        let reply = svc
            .respond_to_hook(respond_request(&event_id, "veto"))
            .await
            .unwrap()
            .into_inner();
        assert!(reply.accepted);

        let delivered = waiter.wait().await.expect("response");
        assert_eq!(delivered.abort, "veto");
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn respond_to_hook_rejects_empty_id() {
        let (_guard, svc) = fresh_service();
        let status = svc
            .respond_to_hook(respond_request("", ""))
            .await
            .unwrap_err();
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn respond_to_hook_unknown_id_returns_not_accepted() {
        let (_guard, svc) = fresh_service();
        let reply = svc
            .respond_to_hook(respond_request("made-up", ""))
            .await
            .unwrap()
            .into_inner();
        assert!(!reply.accepted);
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn register_idempotent_returns_same_hook() {
        let (_guard, svc) = fresh_service();
        let op_id = objects::object::OperationId::new().to_string();
        let request = register_request("foo", "true", &["pre_capture"], 1000, &op_id);

        let first = svc
            .register_hook(Request::new(request.clone()))
            .await
            .unwrap()
            .into_inner();
        let second = svc
            .register_hook(Request::new(request))
            .await
            .unwrap()
            .into_inner();
        assert_eq!(first, second);

        assert_eq!(listed_hooks(&svc).await.len(), 1);
    }
}