use serde::{Deserialize, Serialize};
use std::fmt;
use crate::crypto::Hash;
use crate::error::{Error, Result};
use crate::event::EventId;
use super::principal::PrincipalId;
use super::session::SessionId;
/// Position of a single event within a causal chain.
///
/// Records which event (if any) directly caused this one, the first event of
/// the chain, the session and principal associated with the event, the number
/// of causal hops from the root (`depth`) and an ordering value (`sequence`).
/// Instances are created through [`CausalContext::root`],
/// [`CausalContext::child`] or [`CausalContext::builder`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CausalContext {
    /// Event that directly caused this one; `None` for a root context.
    parent_event_id: Option<EventId>,
    /// First event of the chain this context belongs to.
    root_event_id: EventId,
    /// Session the event belongs to.
    session_id: SessionId,
    /// Principal associated with the event; `child` copies it unchanged.
    principal: PrincipalId,
    /// Number of causal hops from the root; `0` for the root itself.
    depth: u32,
    /// Ordering value; `child` requires it to strictly increase parent -> child.
    sequence: u64,
    /// Optional link to an event in another session. When present,
    /// `validate_against_parent` skips the same-session check.
    cross_session_ref: Option<CrossSessionReference>,
}
impl CausalContext {
pub fn builder() -> CausalContextBuilder {
CausalContextBuilder::new()
}
pub fn root(event_id: EventId, session_id: SessionId, principal: PrincipalId) -> Self {
Self {
parent_event_id: None,
root_event_id: event_id,
session_id,
principal,
depth: 0,
sequence: 0,
cross_session_ref: None,
}
}
pub fn child(&self, parent_event_id: EventId, sequence: u64) -> Result<Self> {
if sequence <= self.sequence {
return Err(Error::invalid_input(format!(
"Child sequence {} must be greater than parent sequence {}",
sequence, self.sequence
)));
}
Ok(Self {
parent_event_id: Some(parent_event_id),
root_event_id: self.root_event_id,
session_id: self.session_id,
principal: self.principal.clone(),
depth: self.depth + 1,
sequence,
cross_session_ref: None,
})
}
pub fn parent_event_id(&self) -> Option<&EventId> {
self.parent_event_id.as_ref()
}
pub fn root_event_id(&self) -> &EventId {
&self.root_event_id
}
pub fn session_id(&self) -> SessionId {
self.session_id
}
pub fn principal(&self) -> &PrincipalId {
&self.principal
}
pub fn depth(&self) -> u32 {
self.depth
}
pub fn sequence(&self) -> u64 {
self.sequence
}
pub fn cross_session_ref(&self) -> Option<&CrossSessionReference> {
self.cross_session_ref.as_ref()
}
pub fn is_root(&self) -> bool {
self.depth == 0
}
pub fn validate(&self, max_depth: u32) -> Result<()> {
if self.depth > max_depth {
return Err(Error::invalid_input(format!(
"Causal depth {} exceeds maximum {}",
self.depth, max_depth
)));
}
if self.depth == 0 && self.parent_event_id.is_some() {
return Err(Error::invalid_input(
"Root event (depth=0) must not have a parent",
));
}
if self.depth > 0 && self.parent_event_id.is_none() {
return Err(Error::invalid_input(
"Non-root event (depth>0) must have a parent",
));
}
Ok(())
}
pub fn validate_against_parent(&self, parent: &CausalContext) -> Result<()> {
if parent.sequence >= self.sequence {
return Err(Error::invalid_input(format!(
"Parent sequence {} must be less than child sequence {}",
parent.sequence, self.sequence
)));
}
if parent.depth + 1 != self.depth {
return Err(Error::invalid_input(format!(
"Parent depth {} + 1 must equal child depth {}",
parent.depth, self.depth
)));
}
if parent.root_event_id != self.root_event_id {
return Err(Error::invalid_input(
"Root event ID must match parent's root event ID",
));
}
if self.cross_session_ref.is_none() && parent.session_id != self.session_id {
return Err(Error::invalid_input(
"Session ID must match parent's session ID (or use cross-session reference)",
));
}
if parent.principal != self.principal {
return Err(Error::invalid_input(
"Principal must match parent's principal",
));
}
Ok(())
}
}
/// Single-line summary showing the session, depth and sequence of a context.
impl fmt::Display for CausalContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            session_id,
            depth,
            sequence,
            ..
        } = self;
        write!(
            f,
            "CausalContext(session={}, depth={}, seq={})",
            session_id, depth, sequence
        )
    }
}
/// Incremental constructor for [`CausalContext`].
///
/// Every field starts unset. `build` requires `root_event_id`, `session_id`
/// and `principal`; `depth` and `sequence` default to `0`.
#[derive(Debug, Default)]
pub struct CausalContextBuilder {
    /// Direct cause; must be set iff the final depth is non-zero.
    parent_event_id: Option<EventId>,
    /// Required by `build`.
    root_event_id: Option<EventId>,
    /// Required by `build`.
    session_id: Option<SessionId>,
    /// Required by `build`.
    principal: Option<PrincipalId>,
    /// Defaults to `0` in `build`.
    depth: Option<u32>,
    /// Defaults to `0` in `build`.
    sequence: Option<u64>,
    /// Optional link to an event in another session.
    cross_session_ref: Option<CrossSessionReference>,
}
impl CausalContextBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn parent_event_id(mut self, id: EventId) -> Self {
self.parent_event_id = Some(id);
self
}
pub fn root_event_id(mut self, id: EventId) -> Self {
self.root_event_id = Some(id);
self
}
pub fn session_id(mut self, id: SessionId) -> Self {
self.session_id = Some(id);
self
}
pub fn principal(mut self, principal: PrincipalId) -> Self {
self.principal = Some(principal);
self
}
pub fn depth(mut self, depth: u32) -> Self {
self.depth = Some(depth);
self
}
pub fn sequence(mut self, sequence: u64) -> Self {
self.sequence = Some(sequence);
self
}
pub fn cross_session_ref(mut self, reference: CrossSessionReference) -> Self {
self.cross_session_ref = Some(reference);
self
}
pub fn build(self) -> Result<CausalContext> {
let root_event_id = self
.root_event_id
.ok_or_else(|| Error::invalid_input("root_event_id is required"))?;
let session_id = self
.session_id
.ok_or_else(|| Error::invalid_input("session_id is required"))?;
let principal = self
.principal
.ok_or_else(|| Error::invalid_input("principal is required"))?;
let depth = self.depth.unwrap_or(0);
let sequence = self.sequence.unwrap_or(0);
if depth == 0 && self.parent_event_id.is_some() {
return Err(Error::invalid_input(
"Root context (depth=0) must not have parent_event_id",
));
}
if depth > 0 && self.parent_event_id.is_none() {
return Err(Error::invalid_input(
"Non-root context (depth>0) requires parent_event_id",
));
}
Ok(CausalContext {
parent_event_id: self.parent_event_id,
root_event_id,
session_id,
principal,
depth,
sequence,
cross_session_ref: self.cross_session_ref,
})
}
}
/// Link from a context to an event recorded in a different session.
///
/// When attached to a [`CausalContext`], `validate_against_parent` no longer
/// requires the child and parent to share a session.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CrossSessionReference {
    /// Session containing the referenced event.
    pub source_session_id: SessionId,
    /// The referenced event.
    pub source_event_id: EventId,
    /// Free-form explanation of why the reference exists.
    pub reason: String,
    /// Hash associated with the referenced event; presumably used to verify
    /// its content — NOTE(review): confirm against the callers.
    pub source_event_hash: Hash,
}
impl CrossSessionReference {
    /// Bundles the source session, source event, a reason and the event hash
    /// into a new reference; `reason` is converted into an owned `String`.
    pub fn new(
        source_session_id: SessionId,
        source_event_id: EventId,
        reason: impl Into<String>,
        source_event_hash: Hash,
    ) -> Self {
        let reason: String = reason.into();
        Self {
            source_session_id,
            source_event_id,
            source_event_hash,
            reason,
        }
    }
}
/// Read-only queries over a collection of stored causal contexts.
pub trait CausalChainQuery {
    /// Returns the contexts on the path from `event_id` back to its root.
    /// The in-memory implementation orders the result root-first.
    fn trace_to_root(&self, event_id: &EventId) -> Result<Vec<CausalContext>>;
    /// Returns the context at the start of the chain containing `event_id`.
    fn find_root(&self, event_id: &EventId) -> Result<CausalContext>;
    /// Returns all contexts recorded for `session_id`; the in-memory
    /// implementation sorts them by sequence.
    fn events_in_session(&self, session_id: &SessionId) -> Result<Vec<CausalContext>>;
    /// Returns every context whose parent is `event_id`.
    fn children_of(&self, event_id: &EventId) -> Result<Vec<CausalContext>>;
    /// Returns the largest depth among the session's contexts; the in-memory
    /// implementation returns `0` for an unknown or empty session.
    fn max_depth_in_session(&self, session_id: &SessionId) -> Result<u32>;
}
/// [`CausalChainQuery`] implementation backed by in-process hash maps.
#[derive(Debug, Default)]
pub struct InMemoryCausalStore {
    /// Primary storage: context of each event, keyed by event id.
    contexts: std::collections::HashMap<EventId, CausalContext>,
    /// Secondary index: event ids per session, in insertion order.
    session_events: std::collections::HashMap<SessionId, Vec<EventId>>,
}
impl InMemoryCausalStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `context` under `event_id` and indexes it by its session.
    ///
    /// Re-inserting an existing `event_id` replaces the stored context. The
    /// session index is kept in sync, so the event is listed exactly once,
    /// under the session of the new context.
    pub fn insert(&mut self, event_id: EventId, context: CausalContext) {
        let session_id = context.session_id();
        if let Some(previous) = self.contexts.insert(event_id, context) {
            // Drop the stale index entry. Otherwise a replaced event would be
            // listed twice, or remain under a session it no longer belongs to.
            if let Some(ids) = self.session_events.get_mut(&previous.session_id()) {
                ids.retain(|id| *id != event_id);
            }
        }
        self.session_events
            .entry(session_id)
            .or_default()
            .push(event_id);
    }

    /// Returns the context stored for `event_id`, if any.
    pub fn get(&self, event_id: &EventId) -> Option<&CausalContext> {
        self.contexts.get(event_id)
    }

    /// Number of stored contexts (distinct event ids).
    pub fn len(&self) -> usize {
        self.contexts.len()
    }

    /// Returns `true` when no context is stored.
    pub fn is_empty(&self) -> bool {
        self.contexts.is_empty()
    }
}
impl CausalChainQuery for InMemoryCausalStore {
fn trace_to_root(&self, event_id: &EventId) -> Result<Vec<CausalContext>> {
let mut chain = Vec::new();
let mut current_id = *event_id;
loop {
let ctx = self.contexts.get(¤t_id).ok_or_else(|| {
Error::invalid_input(format!("event {} not found in causal store", current_id))
})?;
chain.push(ctx.clone());
if ctx.is_root() {
break;
}
match ctx.parent_event_id() {
Some(parent) => current_id = *parent,
None => break,
}
}
chain.reverse();
Ok(chain)
}
fn find_root(&self, event_id: &EventId) -> Result<CausalContext> {
let chain = self.trace_to_root(event_id)?;
chain
.into_iter()
.next()
.ok_or_else(|| Error::invalid_input("empty causal chain"))
}
fn events_in_session(&self, session_id: &SessionId) -> Result<Vec<CausalContext>> {
let event_ids = self
.session_events
.get(session_id)
.cloned()
.unwrap_or_default();
let mut contexts: Vec<CausalContext> = event_ids
.iter()
.filter_map(|id| self.contexts.get(id).cloned())
.collect();
contexts.sort_by_key(|c| c.sequence());
Ok(contexts)
}
fn children_of(&self, event_id: &EventId) -> Result<Vec<CausalContext>> {
let children: Vec<CausalContext> = self
.contexts
.values()
.filter(|ctx| ctx.parent_event_id() == Some(event_id))
.cloned()
.collect();
Ok(children)
}
fn max_depth_in_session(&self, session_id: &SessionId) -> Result<u32> {
let events = self.events_in_session(session_id)?;
Ok(events.iter().map(|c| c.depth()).max().unwrap_or(0))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::hash;

    /// Fixed event id; every call returns the same value.
    fn test_event_id() -> EventId {
        EventId(hash(b"test-event"))
    }

    /// Fresh random session id on every call.
    fn test_session_id() -> SessionId {
        SessionId::random()
    }

    fn test_principal() -> PrincipalId {
        PrincipalId::user("alice").unwrap()
    }

    #[test]
    fn causal_context_root_created_successfully() {
        let event_id = test_event_id();
        let session_id = test_session_id();
        let principal = test_principal();
        let ctx = CausalContext::root(event_id, session_id, principal.clone());
        assert!(ctx.is_root());
        assert_eq!(ctx.depth(), 0);
        assert_eq!(ctx.sequence(), 0);
        assert!(ctx.parent_event_id().is_none());
        assert_eq!(ctx.root_event_id(), &event_id);
        assert_eq!(ctx.principal(), &principal);
    }

    #[test]
    fn causal_context_requires_session_id() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .principal(test_principal())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn causal_context_requires_principal() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn causal_context_depth_zero_has_no_parent() {
        let ctx = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(0)
            .build()
            .unwrap();
        assert!(ctx.parent_event_id().is_none());
        assert!(ctx.is_root());
    }

    #[test]
    fn causal_context_depth_zero_with_parent_rejected() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(0)
            .parent_event_id(test_event_id())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn causal_context_depth_nonzero_requires_parent() {
        let result = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(1)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn causal_context_depth_nonzero_with_parent_succeeds() {
        let ctx = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(1)
            .sequence(1)
            .parent_event_id(test_event_id())
            .build()
            .unwrap();
        assert!(!ctx.is_root());
        assert_eq!(ctx.depth(), 1);
    }

    #[test]
    fn child_context_created_successfully() {
        let root = CausalContext::root(test_event_id(), test_session_id(), test_principal());
        let parent_id = test_event_id();
        let child = root.child(parent_id, 1).unwrap();
        assert_eq!(child.depth(), 1);
        assert_eq!(child.sequence(), 1);
        assert_eq!(child.parent_event_id(), Some(&parent_id));
        assert_eq!(child.root_event_id(), root.root_event_id());
    }

    #[test]
    fn child_sequence_must_exceed_parent() {
        let root = CausalContext::root(test_event_id(), test_session_id(), test_principal());
        let result = root.child(test_event_id(), 0);
        assert!(result.is_err());
        let result = root.child(test_event_id(), 1);
        assert!(result.is_ok());
    }

    #[test]
    fn validate_rejects_depth_exceeding_max() {
        let ctx = CausalContext::builder()
            .root_event_id(test_event_id())
            .session_id(test_session_id())
            .principal(test_principal())
            .depth(5)
            .sequence(5)
            .parent_event_id(test_event_id())
            .build()
            .unwrap();
        let result = ctx.validate(3);
        assert!(result.is_err());
        let result = ctx.validate(10);
        assert!(result.is_ok());
    }

    #[test]
    fn validate_against_parent_checks_sequence() {
        let parent = CausalContext::root(test_event_id(), test_session_id(), test_principal());
        let child = parent.child(test_event_id(), 1).unwrap();
        assert!(child.validate_against_parent(&parent).is_ok());
        // Sequence 0 does not exceed the parent's sequence 0.
        let invalid_child = CausalContext::builder()
            .root_event_id(parent.root_event_id)
            .session_id(parent.session_id)
            .principal(parent.principal.clone())
            .depth(1)
            .sequence(0)
            .parent_event_id(test_event_id())
            .build()
            .unwrap();
        assert!(invalid_child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_against_parent_checks_depth() {
        let parent = CausalContext::root(test_event_id(), test_session_id(), test_principal());
        let valid_child = parent.child(test_event_id(), 1).unwrap();
        assert!(valid_child.validate_against_parent(&parent).is_ok());
        // Depth 2 skips a level below a depth-0 parent.
        let invalid_child = CausalContext::builder()
            .root_event_id(parent.root_event_id)
            .session_id(parent.session_id)
            .principal(parent.principal.clone())
            .depth(2)
            .sequence(1)
            .parent_event_id(test_event_id())
            .build()
            .unwrap();
        assert!(invalid_child.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn validate_accepts_cross_session_reference() {
        let source_session = test_session_id();
        let target_session = test_session_id();
        let event_id = test_event_id();
        let parent = CausalContext::root(event_id, source_session, test_principal());
        let cross_ref = CrossSessionReference::new(
            source_session,
            event_id,
            "Follow-up task",
            hash(b"event-data"),
        );
        let ctx = CausalContext::builder()
            .root_event_id(event_id)
            .session_id(target_session)
            .principal(test_principal())
            .depth(1)
            .sequence(1)
            .parent_event_id(event_id)
            .cross_session_ref(cross_ref)
            .build()
            .unwrap();
        assert!(ctx.cross_session_ref().is_some());
        // The reference lets the child live in a different session.
        assert!(ctx.validate_against_parent(&parent).is_ok());

        // The same context without the reference must be rejected.
        let without_ref = CausalContext::builder()
            .root_event_id(event_id)
            .session_id(target_session)
            .principal(test_principal())
            .depth(1)
            .sequence(1)
            .parent_event_id(event_id)
            .build()
            .unwrap();
        assert!(without_ref.validate_against_parent(&parent).is_err());
    }

    #[test]
    fn display_format_correct() {
        let ctx = CausalContext::root(test_event_id(), test_session_id(), test_principal());
        let display = format!("{}", ctx);
        assert!(display.contains("CausalContext"));
        assert!(display.contains("depth=0"));
        assert!(display.contains("seq=0"));
    }

    #[test]
    fn in_memory_store_traces_chain_to_root() {
        let session = test_session_id();
        let root_id = EventId(hash(b"root"));
        let child_id = EventId(hash(b"child"));
        let grandchild_id = EventId(hash(b"grandchild"));
        let root = CausalContext::root(root_id, session, test_principal());
        let child = root.child(root_id, 1).unwrap();
        let grandchild = child.child(child_id, 2).unwrap();

        let mut store = InMemoryCausalStore::new();
        assert!(store.is_empty());
        store.insert(root_id, root.clone());
        store.insert(child_id, child.clone());
        store.insert(grandchild_id, grandchild.clone());
        assert_eq!(store.len(), 3);
        assert_eq!(store.get(&child_id), Some(&child));

        let chain = store.trace_to_root(&grandchild_id).unwrap();
        assert_eq!(chain, vec![root.clone(), child.clone(), grandchild.clone()]);
        assert_eq!(store.find_root(&grandchild_id).unwrap(), root);
        assert_eq!(store.children_of(&root_id).unwrap(), vec![child]);
        assert_eq!(store.max_depth_in_session(&session).unwrap(), 2);

        let sequences: Vec<u64> = store
            .events_in_session(&session)
            .unwrap()
            .iter()
            .map(|c| c.sequence())
            .collect();
        assert_eq!(sequences, vec![0, 1, 2]);
    }

    #[test]
    fn in_memory_store_reports_unknown_event() {
        let store = InMemoryCausalStore::new();
        assert!(store.trace_to_root(&test_event_id()).is_err());
        assert!(store.find_root(&test_event_id()).is_err());
        assert!(store.children_of(&test_event_id()).unwrap().is_empty());
        assert!(store.events_in_session(&test_session_id()).unwrap().is_empty());
        assert_eq!(store.max_depth_in_session(&test_session_id()).unwrap(), 0);
    }
}