1use serde::{Deserialize, Serialize};
8use std::fmt;
9
10use crate::crypto::Hash;
11use crate::error::{Error, Result};
12use crate::event::EventId;
13
14use super::principal::PrincipalId;
15use super::session::SessionId;
16
/// Causal metadata attached to an event: where it sits in its causal chain
/// (root, direct parent, depth, sequence), which session and principal it
/// belongs to, and an optional link to an event in another session.
///
/// Invariant enforced by [`CausalContextBuilder::build`] and
/// [`CausalContext::validate`]: `depth == 0` if and only if
/// `parent_event_id` is `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CausalContext {
    /// Direct parent event; `None` exactly for root contexts (depth 0).
    parent_event_id: Option<EventId>,

    /// First event of the chain. `CausalContext::root` sets it to the root's own
    /// event id, and `CausalContext::child` copies it unchanged from the parent.
    root_event_id: EventId,

    /// Session the event belongs to. `validate_against_parent` requires it to
    /// match the parent's unless `cross_session_ref` is set.
    session_id: SessionId,

    /// Principal the event is attributed to; must match the parent's.
    principal: PrincipalId,

    /// Number of parent hops from the root (root = 0, each child = parent + 1).
    depth: u32,

    /// Ordering number; a child's sequence must be strictly greater than its parent's.
    sequence: u64,

    /// Optional link to an event in another session. `CausalContext::child`
    /// always resets it to `None`.
    cross_session_ref: Option<CrossSessionReference>,
}
55
56impl CausalContext {
57 pub fn builder() -> CausalContextBuilder {
59 CausalContextBuilder::new()
60 }
61
62 pub fn root(event_id: EventId, session_id: SessionId, principal: PrincipalId) -> Self {
66 Self {
67 parent_event_id: None,
68 root_event_id: event_id,
69 session_id,
70 principal,
71 depth: 0,
72 sequence: 0,
73 cross_session_ref: None,
74 }
75 }
76
77 pub fn child(&self, parent_event_id: EventId, sequence: u64) -> Result<Self> {
86 if sequence <= self.sequence {
87 return Err(Error::invalid_input(format!(
88 "Child sequence {} must be greater than parent sequence {}",
89 sequence, self.sequence
90 )));
91 }
92
93 Ok(Self {
94 parent_event_id: Some(parent_event_id),
95 root_event_id: self.root_event_id,
96 session_id: self.session_id,
97 principal: self.principal.clone(),
98 depth: self.depth + 1,
99 sequence,
100 cross_session_ref: None,
101 })
102 }
103
104 pub fn parent_event_id(&self) -> Option<&EventId> {
106 self.parent_event_id.as_ref()
107 }
108
109 pub fn root_event_id(&self) -> &EventId {
111 &self.root_event_id
112 }
113
114 pub fn session_id(&self) -> SessionId {
116 self.session_id
117 }
118
119 pub fn principal(&self) -> &PrincipalId {
121 &self.principal
122 }
123
124 pub fn depth(&self) -> u32 {
126 self.depth
127 }
128
129 pub fn sequence(&self) -> u64 {
131 self.sequence
132 }
133
134 pub fn cross_session_ref(&self) -> Option<&CrossSessionReference> {
136 self.cross_session_ref.as_ref()
137 }
138
139 pub fn is_root(&self) -> bool {
141 self.depth == 0
142 }
143
144 pub fn validate(&self, max_depth: u32) -> Result<()> {
152 if self.depth > max_depth {
154 return Err(Error::invalid_input(format!(
155 "Causal depth {} exceeds maximum {}",
156 self.depth, max_depth
157 )));
158 }
159
160 if self.depth == 0 && self.parent_event_id.is_some() {
162 return Err(Error::invalid_input(
163 "Root event (depth=0) must not have a parent",
164 ));
165 }
166
167 if self.depth > 0 && self.parent_event_id.is_none() {
169 return Err(Error::invalid_input(
170 "Non-root event (depth>0) must have a parent",
171 ));
172 }
173
174 Ok(())
178 }
179
180 pub fn validate_against_parent(&self, parent: &CausalContext) -> Result<()> {
184 if parent.sequence >= self.sequence {
186 return Err(Error::invalid_input(format!(
187 "Parent sequence {} must be less than child sequence {}",
188 parent.sequence, self.sequence
189 )));
190 }
191
192 if parent.depth + 1 != self.depth {
194 return Err(Error::invalid_input(format!(
195 "Parent depth {} + 1 must equal child depth {}",
196 parent.depth, self.depth
197 )));
198 }
199
200 if parent.root_event_id != self.root_event_id {
202 return Err(Error::invalid_input(
203 "Root event ID must match parent's root event ID",
204 ));
205 }
206
207 if self.cross_session_ref.is_none() && parent.session_id != self.session_id {
209 return Err(Error::invalid_input(
210 "Session ID must match parent's session ID (or use cross-session reference)",
211 ));
212 }
213
214 if parent.principal != self.principal {
216 return Err(Error::invalid_input(
217 "Principal must match parent's principal",
218 ));
219 }
220
221 Ok(())
222 }
223}
224
225impl fmt::Display for CausalContext {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 write!(
228 f,
229 "CausalContext(session={}, depth={}, seq={})",
230 self.session_id, self.depth, self.sequence
231 )
232 }
233}
234
235#[derive(Debug, Default)]
237pub struct CausalContextBuilder {
238 parent_event_id: Option<EventId>,
239 root_event_id: Option<EventId>,
240 session_id: Option<SessionId>,
241 principal: Option<PrincipalId>,
242 depth: Option<u32>,
243 sequence: Option<u64>,
244 cross_session_ref: Option<CrossSessionReference>,
245}
246
247impl CausalContextBuilder {
248 pub fn new() -> Self {
250 Self::default()
251 }
252
253 pub fn parent_event_id(mut self, id: EventId) -> Self {
255 self.parent_event_id = Some(id);
256 self
257 }
258
259 pub fn root_event_id(mut self, id: EventId) -> Self {
261 self.root_event_id = Some(id);
262 self
263 }
264
265 pub fn session_id(mut self, id: SessionId) -> Self {
267 self.session_id = Some(id);
268 self
269 }
270
271 pub fn principal(mut self, principal: PrincipalId) -> Self {
273 self.principal = Some(principal);
274 self
275 }
276
277 pub fn depth(mut self, depth: u32) -> Self {
279 self.depth = Some(depth);
280 self
281 }
282
283 pub fn sequence(mut self, sequence: u64) -> Self {
285 self.sequence = Some(sequence);
286 self
287 }
288
289 pub fn cross_session_ref(mut self, reference: CrossSessionReference) -> Self {
291 self.cross_session_ref = Some(reference);
292 self
293 }
294
295 pub fn build(self) -> Result<CausalContext> {
300 let root_event_id = self
301 .root_event_id
302 .ok_or_else(|| Error::invalid_input("root_event_id is required"))?;
303
304 let session_id = self
305 .session_id
306 .ok_or_else(|| Error::invalid_input("session_id is required"))?;
307
308 let principal = self
309 .principal
310 .ok_or_else(|| Error::invalid_input("principal is required"))?;
311
312 let depth = self.depth.unwrap_or(0);
313 let sequence = self.sequence.unwrap_or(0);
314
315 if depth == 0 && self.parent_event_id.is_some() {
317 return Err(Error::invalid_input(
318 "Root context (depth=0) must not have parent_event_id",
319 ));
320 }
321
322 if depth > 0 && self.parent_event_id.is_none() {
323 return Err(Error::invalid_input(
324 "Non-root context (depth>0) requires parent_event_id",
325 ));
326 }
327
328 Ok(CausalContext {
329 parent_event_id: self.parent_event_id,
330 root_event_id,
331 session_id,
332 principal,
333 depth,
334 sequence,
335 cross_session_ref: self.cross_session_ref,
336 })
337 }
338}
339
/// Links a causal context to an event recorded in a different session.
///
/// When present on a [`CausalContext`], `CausalContext::validate_against_parent`
/// skips its same-session check between child and parent.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CrossSessionReference {
    /// Session the referenced event belongs to.
    pub source_session_id: SessionId,

    /// Id of the referenced event.
    pub source_event_id: EventId,

    /// Free-form explanation of why the sessions are linked.
    pub reason: String,

    /// Hash associated with the referenced event (per the field name).
    /// NOTE(review): nothing in this module computes or verifies it — confirm
    /// which bytes it is expected to cover.
    pub source_event_hash: Hash,
}
358
359impl CrossSessionReference {
360 pub fn new(
362 source_session_id: SessionId,
363 source_event_id: EventId,
364 reason: impl Into<String>,
365 source_event_hash: Hash,
366 ) -> Self {
367 Self {
368 source_session_id,
369 source_event_id,
370 reason: reason.into(),
371 source_event_hash,
372 }
373 }
374}
375
/// Read-only queries over stored [`CausalContext`]s, looked up by [`EventId`].
pub trait CausalChainQuery {
    /// Returns the contexts on the parent chain of `event_id`, up to and
    /// including its root.
    ///
    /// NOTE(review): the signature does not fix the ordering; the in-memory
    /// implementation in this module returns root first and `event_id`'s own
    /// context last.
    fn trace_to_root(&self, event_id: &EventId) -> Result<Vec<CausalContext>>;

    /// Returns the context of the root of `event_id`'s causal chain.
    fn find_root(&self, event_id: &EventId) -> Result<CausalContext>;

    /// Returns the contexts recorded for `session_id`.
    fn events_in_session(&self, session_id: &SessionId) -> Result<Vec<CausalContext>>;

    /// Returns the contexts whose direct parent is `event_id`.
    fn children_of(&self, event_id: &EventId) -> Result<Vec<CausalContext>>;

    /// Returns the largest causal depth among the contexts of `session_id`.
    fn max_depth_in_session(&self, session_id: &SessionId) -> Result<u32>;
}
399
400#[derive(Debug, Default)]
403pub struct InMemoryCausalStore {
404 contexts: std::collections::HashMap<EventId, CausalContext>,
406 session_events: std::collections::HashMap<SessionId, Vec<EventId>>,
408}
409
410impl InMemoryCausalStore {
411 pub fn new() -> Self {
413 Self::default()
414 }
415
416 pub fn insert(&mut self, event_id: EventId, context: CausalContext) {
418 let session_id = context.session_id();
419 self.session_events
420 .entry(session_id)
421 .or_default()
422 .push(event_id);
423 self.contexts.insert(event_id, context);
424 }
425
426 pub fn get(&self, event_id: &EventId) -> Option<&CausalContext> {
428 self.contexts.get(event_id)
429 }
430
431 pub fn len(&self) -> usize {
433 self.contexts.len()
434 }
435
436 pub fn is_empty(&self) -> bool {
438 self.contexts.is_empty()
439 }
440}
441
442impl CausalChainQuery for InMemoryCausalStore {
443 fn trace_to_root(&self, event_id: &EventId) -> Result<Vec<CausalContext>> {
444 let mut chain = Vec::new();
445 let mut current_id = *event_id;
446
447 loop {
448 let ctx = self.contexts.get(¤t_id).ok_or_else(|| {
449 Error::invalid_input(format!("event {} not found in causal store", current_id))
450 })?;
451 chain.push(ctx.clone());
452
453 if ctx.is_root() {
454 break;
455 }
456
457 match ctx.parent_event_id() {
458 Some(parent) => current_id = *parent,
459 None => break,
460 }
461 }
462
463 chain.reverse();
464 Ok(chain)
465 }
466
467 fn find_root(&self, event_id: &EventId) -> Result<CausalContext> {
468 let chain = self.trace_to_root(event_id)?;
469 chain
470 .into_iter()
471 .next()
472 .ok_or_else(|| Error::invalid_input("empty causal chain"))
473 }
474
475 fn events_in_session(&self, session_id: &SessionId) -> Result<Vec<CausalContext>> {
476 let event_ids = self
477 .session_events
478 .get(session_id)
479 .cloned()
480 .unwrap_or_default();
481 let mut contexts: Vec<CausalContext> = event_ids
482 .iter()
483 .filter_map(|id| self.contexts.get(id).cloned())
484 .collect();
485 contexts.sort_by_key(|c| c.sequence());
486 Ok(contexts)
487 }
488
489 fn children_of(&self, event_id: &EventId) -> Result<Vec<CausalContext>> {
490 let children: Vec<CausalContext> = self
491 .contexts
492 .values()
493 .filter(|ctx| ctx.parent_event_id() == Some(event_id))
494 .cloned()
495 .collect();
496 Ok(children)
497 }
498
499 fn max_depth_in_session(&self, session_id: &SessionId) -> Result<u32> {
500 let events = self.events_in_session(session_id)?;
501 Ok(events.iter().map(|c| c.depth()).max().unwrap_or(0))
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::hash;

    /// Deterministic event id derived from `label`.
    ///
    /// Different labels give different ids (barring hash collisions), so tests
    /// can tell root, parent and child events apart.
    fn labeled_event_id(label: &[u8]) -> EventId {
        EventId(hash(label))
    }

    fn test_event_id() -> EventId {
        labeled_event_id(b"test-event")
    }

    fn test_session_id() -> SessionId {
        SessionId::random()
    }

    fn test_principal() -> PrincipalId {
        PrincipalId::user("alice").unwrap()
    }

    #[test]
    fn causal_context_root_created_successfully() {
        let event_id = test_event_id();
        let session_id = test_session_id();
        let principal = test_principal();

        let ctx = CausalContext::root(event_id, session_id, principal.clone());

        assert!(ctx.is_root());
        assert_eq!(ctx.depth(), 0);
        assert_eq!(ctx.sequence(), 0);
        assert!(ctx.parent_event_id().is_none());
        assert_eq!(ctx.root_event_id(), &event_id);
        assert_eq!(ctx.session_id(), session_id);
        assert_eq!(ctx.principal(), &principal);
        assert!(ctx.cross_session_ref().is_none());
        assert!(ctx.validate(0).is_ok());
    }

    #[test]
    fn causal_context_requires_root_event_id() {
        let result = CausalContext::builder()
            .session_id(test_session_id())
            .principal(test_principal())
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn causal_context_requires_session_id() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .principal(test_principal())
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn causal_context_requires_principal() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn causal_context_depth_zero_has_no_parent() {
        let ctx = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(0)
            .build()
            .unwrap();

        assert!(ctx.parent_event_id().is_none());
        assert!(ctx.is_root());
    }

    #[test]
    fn causal_context_depth_zero_with_parent_rejected() {
        let result = CausalContext::builder()
            .root_event_id(labeled_event_id(b"root"))
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(0)
            .parent_event_id(labeled_event_id(b"parent"))
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn causal_context_depth_nonzero_requires_parent() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(1)
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn causal_context_depth_nonzero_with_parent_succeeds() {
        let root_id = labeled_event_id(b"root");
        let parent_id = labeled_event_id(b"parent");

        let ctx = CausalContext::builder()
            .root_event_id(root_id)
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(1)
            .sequence(1)
            .parent_event_id(parent_id)
            .build()
            .unwrap();

        assert!(!ctx.is_root());
        assert_eq!(ctx.depth(), 1);
        assert_eq!(ctx.root_event_id(), &root_id);
        assert_eq!(ctx.parent_event_id(), Some(&parent_id));
    }

    #[test]
    fn child_context_created_successfully() {
        let root_id = labeled_event_id(b"root");
        let session = test_session_id();
        let root = CausalContext::root(root_id, session, test_principal());

        let parent_id = labeled_event_id(b"parent");
        let child = root.child(parent_id, 1).unwrap();

        assert_eq!(child.depth(), 1);
        assert_eq!(child.sequence(), 1);
        assert_eq!(child.parent_event_id(), Some(&parent_id));
        // Distinct ids ensure the root id is inherited, not taken from the parent id.
        assert_eq!(child.root_event_id(), &root_id);
        assert_eq!(child.session_id(), session);
        assert_eq!(child.principal(), root.principal());
        assert!(child.cross_session_ref().is_none());
        assert!(child.validate(1).is_ok());
        assert!(child.validate_against_parent(&root).is_ok());
    }

    #[test]
    fn child_sequence_must_exceed_parent() {
        let root = CausalContext::root(test_event_id(), test_session_id(), test_principal());

        let result = root.child(test_event_id(), 0);
        assert!(result.is_err());

        let result = root.child(test_event_id(), 1);
        assert!(result.is_ok());
    }

    #[test]
    fn validate_rejects_depth_exceeding_max() {
        let ctx = CausalContext::builder()
            .root_event_id(labeled_event_id(b"root"))
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(5)
            .sequence(5)
            .parent_event_id(labeled_event_id(b"parent"))
            .build()
            .unwrap();

        assert!(ctx.validate(3).is_err());
        assert!(ctx.validate(4).is_err());
        // The limit is inclusive.
        assert!(ctx.validate(5).is_ok());
        assert!(ctx.validate(10).is_ok());
    }

    #[test]
    fn validate_against_parent_checks_sequence() {
        let parent = CausalContext::root(labeled_event_id(b"root"), test_session_id(), test_principal());

        let child = parent.child(labeled_event_id(b"root"), 1).unwrap();
        assert!(child.validate_against_parent(&parent).is_ok());

        let invalid_child = CausalContext::builder()
            .root_event_id(*parent.root_event_id())
            .session_id(parent.session_id())
            .principal(parent.principal().clone())
            .depth(1)
            .sequence(0)
            .parent_event_id(labeled_event_id(b"root"))
            .build()
            .unwrap();

        assert!(invalid_child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_against_parent_checks_depth() {
        let parent = CausalContext::root(labeled_event_id(b"root"), test_session_id(), test_principal());

        let valid_child = parent.child(labeled_event_id(b"root"), 1).unwrap();
        assert!(valid_child.validate_against_parent(&parent).is_ok());

        let invalid_child = CausalContext::builder()
            .root_event_id(*parent.root_event_id())
            .session_id(parent.session_id())
            .principal(parent.principal().clone())
            .depth(2)
            .sequence(1)
            .parent_event_id(labeled_event_id(b"root"))
            .build()
            .unwrap();

        assert!(invalid_child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_against_parent_rejects_root_mismatch() {
        let parent = CausalContext::root(labeled_event_id(b"root"), test_session_id(), test_principal());

        let child = CausalContext::builder()
            .root_event_id(labeled_event_id(b"other-root"))
            .session_id(parent.session_id())
            .principal(parent.principal().clone())
            .depth(1)
            .sequence(1)
            .parent_event_id(labeled_event_id(b"root"))
            .build()
            .unwrap();

        assert!(child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_against_parent_rejects_session_mismatch_without_cross_ref() {
        let parent = CausalContext::root(labeled_event_id(b"root"), test_session_id(), test_principal());

        let child = CausalContext::builder()
            .root_event_id(*parent.root_event_id())
            .session_id(test_session_id())
            .principal(parent.principal().clone())
            .depth(1)
            .sequence(1)
            .parent_event_id(labeled_event_id(b"root"))
            .build()
            .unwrap();

        assert!(child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_against_parent_rejects_principal_mismatch() {
        let parent = CausalContext::root(labeled_event_id(b"root"), test_session_id(), test_principal());

        let child = CausalContext::builder()
            .root_event_id(*parent.root_event_id())
            .session_id(parent.session_id())
            .principal(PrincipalId::user("bob").unwrap())
            .depth(1)
            .sequence(1)
            .parent_event_id(labeled_event_id(b"root"))
            .build()
            .unwrap();

        assert!(child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_accepts_cross_session_reference() {
        let source_session = test_session_id();
        let target_session = test_session_id();
        let root_id = labeled_event_id(b"root");

        let cross_ref = CrossSessionReference::new(
            source_session,
            root_id,
            "Follow-up task",
            hash(b"event-data"),
        );

        let parent = CausalContext::root(root_id, source_session, test_principal());

        let ctx = CausalContext::builder()
            .root_event_id(root_id)
            .session_id(target_session)
            .principal(test_principal())
            .depth(1)
            .sequence(1)
            .parent_event_id(root_id)
            .cross_session_ref(cross_ref.clone())
            .build()
            .unwrap();

        assert_eq!(ctx.cross_session_ref(), Some(&cross_ref));
        // The cross-session reference waives the same-session requirement.
        assert!(ctx.validate_against_parent(&parent).is_ok());
    }

    #[test]
    fn display_format_correct() {
        let ctx = CausalContext::root(test_event_id(), test_session_id(), test_principal());

        let display = format!("{}", ctx);
        assert!(display.contains("CausalContext"));
        assert!(display.contains("depth=0"));
        assert!(display.contains("seq=0"));
    }

    #[test]
    fn in_memory_store_answers_chain_queries() {
        let session = test_session_id();
        let root_id = labeled_event_id(b"root");
        let child_id = labeled_event_id(b"child");
        let grandchild_id = labeled_event_id(b"grandchild");

        let root = CausalContext::root(root_id, session, test_principal());
        let child = root.child(root_id, 1).unwrap();
        let grandchild = child.child(child_id, 2).unwrap();

        let mut store = InMemoryCausalStore::new();
        assert!(store.is_empty());
        store.insert(root_id, root.clone());
        store.insert(child_id, child.clone());
        store.insert(grandchild_id, grandchild.clone());
        assert_eq!(store.len(), 3);
        assert_eq!(store.get(&child_id), Some(&child));

        let chain = store.trace_to_root(&grandchild_id).unwrap();
        assert_eq!(chain, vec![root.clone(), child.clone(), grandchild.clone()]);
        assert_eq!(store.find_root(&grandchild_id).unwrap(), root);
        assert_eq!(store.children_of(&root_id).unwrap(), vec![child.clone()]);

        let sequences: Vec<u64> = store
            .events_in_session(&session)
            .unwrap()
            .iter()
            .map(|c| c.sequence())
            .collect();
        assert_eq!(sequences, vec![0, 1, 2]);
        assert_eq!(store.max_depth_in_session(&session).unwrap(), 2);

        assert!(store.trace_to_root(&labeled_event_id(b"missing")).is_err());
    }
}