Skip to main content

liminal/protocol/
causal.rs

1use std::str;
2
3use super::error::ProtocolError;
4
/// Presence tag byte written for an optional field that carries no value.
const ABSENT_TAG: u8 = 0;
/// Presence tag byte written for an optional field whose value follows.
const PRESENT_TAG: u8 = 1;
/// Encoded width in bytes of a `u8` (used for presence tags).
const U8_LEN: usize = std::mem::size_of::<u8>();
/// Encoded width in bytes of a `u32` (used for length prefixes).
const U32_LEN: usize = std::mem::size_of::<u32>();
/// Encoded width in bytes of a `u64` (used for vector-clock entries).
const U64_LEN: usize = std::mem::size_of::<u64>();
/// Number of schema-id bytes that open a serialized message envelope.
const ENVELOPE_SCHEMA_ID_LEN: usize = 32;
11
/// Opaque, protocol-level handle that names one message within a causal chain.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MessageId(
    /// Stable UTF-8 identifier text; its bytes are what goes on the wire.
    pub String,
);

impl MessageId {
    /// Build an identifier from anything convertible into an owned `String`.
    #[must_use]
    pub fn new(value: impl Into<String>) -> Self {
        let text: String = value.into();
        Self(text)
    }

    /// Borrow the identifier as a UTF-8 string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Unwrap the identifier, handing the inner `String` to the caller.
    #[must_use]
    pub fn into_string(self) -> String {
        let Self(text) = self;
        text
    }
}

impl From<String> for MessageId {
    /// Wrap an already-owned string as an identifier.
    fn from(value: String) -> Self {
        Self::new(value)
    }
}

impl From<&str> for MessageId {
    /// Copy a borrowed string slice into a new identifier.
    fn from(value: &str) -> Self {
        Self(String::from(value))
    }
}
50
51/// Causal metadata carried inside a protocol message envelope.
52#[derive(Clone, Debug, PartialEq, Eq)]
53pub struct CausalContext {
54    /// Optional reference to the causally preceding message.
55    pub parent_id: Option<MessageId>,
56    /// Optional logical timestamp for this message's causal-chain position.
57    pub vector_clock_entry: Option<u64>,
58}
59
60impl CausalContext {
61    /// Create a causally independent context with no parent or vector-clock entry.
62    #[must_use]
63    pub const fn independent() -> Self {
64        Self {
65            parent_id: None,
66            vector_clock_entry: None,
67        }
68    }
69
70    /// Create a context that follows `parent_id` in a causal chain.
71    #[must_use]
72    pub const fn with_parent(parent_id: MessageId) -> Self {
73        Self {
74            parent_id: Some(parent_id),
75            vector_clock_entry: None,
76        }
77    }
78
79    /// Return the number of bytes needed for deterministic causal-context encoding.
80    ///
81    /// # Errors
82    ///
83    /// Returns [`ProtocolError::CodecError`] when the message id length cannot fit
84    /// in the protocol's `u32` length field or the total length overflows `usize`.
85    pub fn encoded_len(&self) -> Result<usize, ProtocolError> {
86        let parent_len = match &self.parent_id {
87            Some(parent_id) => sum_lengths(&[
88                U8_LEN,
89                U32_LEN,
90                checked_u32_len(parent_id.as_str().len(), "message id")?,
91            ])?,
92            None => U8_LEN,
93        };
94        let vector_len = if self.vector_clock_entry.is_some() {
95            U8_LEN + U64_LEN
96        } else {
97            U8_LEN
98        };
99        sum_lengths(&[parent_len, vector_len])
100    }
101
102    /// Serialize this causal context with deterministic big-endian field encoding.
103    ///
104    /// # Errors
105    ///
106    /// Returns [`ProtocolError::CodecError`] when a length field cannot represent
107    /// the encoded context.
108    pub fn serialize(&self) -> Result<Vec<u8>, ProtocolError> {
109        let len = self.encoded_len()?;
110        let mut bytes = Vec::with_capacity(len);
111
112        match &self.parent_id {
113            Some(parent_id) => {
114                bytes.push(PRESENT_TAG);
115                write_u32(&mut bytes, parent_id.as_str().len(), "message id")?;
116                bytes.extend_from_slice(parent_id.as_str().as_bytes());
117            }
118            None => bytes.push(ABSENT_TAG),
119        }
120
121        match self.vector_clock_entry {
122            Some(entry) => {
123                bytes.push(PRESENT_TAG);
124                bytes.extend_from_slice(&entry.to_be_bytes());
125            }
126            None => bytes.push(ABSENT_TAG),
127        }
128
129        if bytes.len() == len {
130            Ok(bytes)
131        } else {
132            Err(ProtocolError::codec(
133                "causal context encoder produced an unexpected length",
134            ))
135        }
136    }
137
138    /// Serialize this causal context for wire transport.
139    ///
140    /// # Errors
141    ///
142    /// Returns [`ProtocolError::CodecError`] when a length field cannot represent
143    /// the encoded context.
144    pub fn to_wire_bytes(&self) -> Result<Vec<u8>, ProtocolError> {
145        self.serialize()
146    }
147
148    /// Deserialize a causal context from deterministic wire bytes.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`ProtocolError::CodecError`] when the bytes are truncated, contain
153    /// invalid presence tags, contain invalid UTF-8 message ids, or contain
154    /// trailing bytes.
155    pub fn deserialize(bytes: &[u8]) -> Result<Self, ProtocolError> {
156        Self::from_wire_bytes(bytes)
157    }
158
159    /// Deserialize a causal context from deterministic wire bytes.
160    ///
161    /// # Errors
162    ///
163    /// Returns [`ProtocolError::CodecError`] when the bytes are truncated, contain
164    /// invalid presence tags, contain invalid UTF-8 message ids, or contain
165    /// trailing bytes.
166    pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, ProtocolError> {
167        let mut offset = 0;
168        let parent_id = match read_u8(bytes, &mut offset, "parent id presence tag")? {
169            ABSENT_TAG => None,
170            PRESENT_TAG => {
171                let len = read_u32_as_usize(bytes, &mut offset, "message id length")?;
172                let id_bytes = read_slice(bytes, &mut offset, len, "message id bytes")?;
173                let id = str::from_utf8(id_bytes)
174                    .map_err(|_| ProtocolError::codec("message id was not valid utf-8"))?;
175                Some(MessageId::new(id))
176            }
177            _ => return Err(ProtocolError::codec("parent id presence tag was invalid")),
178        };
179
180        let vector_clock_entry = match read_u8(bytes, &mut offset, "vector clock presence tag")? {
181            ABSENT_TAG => None,
182            PRESENT_TAG => Some(read_u64(bytes, &mut offset, "vector clock entry")?),
183            _ => {
184                return Err(ProtocolError::codec(
185                    "vector clock presence tag was invalid",
186                ));
187            }
188        };
189
190        if offset == bytes.len() {
191            Ok(Self {
192                parent_id,
193                vector_clock_entry,
194            })
195        } else {
196            Err(ProtocolError::codec(
197                "causal context contained trailing bytes",
198            ))
199        }
200    }
201}
202
203impl Default for CausalContext {
204    fn default() -> Self {
205        Self::independent()
206    }
207}
208
209/// Extract causal context bytes from a serialized message envelope without parsing payload bytes.
210///
211/// The envelope layout is `schema_id(32) + causal_context_length(u32) +
212/// causal_context_bytes + payload_length(u32) + payload_bytes`; this function
213/// reads only the schema prefix, causal-context length, and causal-context bytes.
214///
215/// # Errors
216///
217/// Returns [`ProtocolError::CodecError`] when the envelope is truncated before or
218/// within the causal-context section, or when the causal-context bytes are
219/// malformed.
220pub fn extract_causal_context(envelope_bytes: &[u8]) -> Result<CausalContext, ProtocolError> {
221    let mut offset = 0;
222    let _schema_id = read_slice(
223        envelope_bytes,
224        &mut offset,
225        ENVELOPE_SCHEMA_ID_LEN,
226        "schema id",
227    )?;
228    let causal_len = read_u32_as_usize(envelope_bytes, &mut offset, "causal context length")?;
229    let causal_bytes = read_slice(
230        envelope_bytes,
231        &mut offset,
232        causal_len,
233        "causal context bytes",
234    )?;
235    CausalContext::deserialize(causal_bytes)
236}
237
238fn checked_u32_len(len: usize, field: &str) -> Result<usize, ProtocolError> {
239    u32::try_from(len)
240        .map(|_| len)
241        .map_err(|_| ProtocolError::codec(format!("{field} length exceeded u32::MAX")))
242}
243
244fn sum_lengths(parts: &[usize]) -> Result<usize, ProtocolError> {
245    let mut total = 0_usize;
246    for part in parts {
247        total = total
248            .checked_add(*part)
249            .ok_or_else(|| ProtocolError::codec("causal context length overflowed usize"))?;
250    }
251    Ok(total)
252}
253
254fn write_u32(buffer: &mut Vec<u8>, value: usize, field: &str) -> Result<(), ProtocolError> {
255    let value = u32::try_from(value)
256        .map_err(|_| ProtocolError::codec(format!("{field} length exceeded u32::MAX")))?;
257    buffer.extend_from_slice(&value.to_be_bytes());
258    Ok(())
259}
260
261fn read_u8(bytes: &[u8], offset: &mut usize, field: &str) -> Result<u8, ProtocolError> {
262    let bytes = read_slice(bytes, offset, U8_LEN, field)?;
263    let [value] = bytes else {
264        return Err(ProtocolError::codec(format!("{field} was truncated")));
265    };
266    Ok(*value)
267}
268
269fn read_u32_as_usize(
270    bytes: &[u8],
271    offset: &mut usize,
272    field: &str,
273) -> Result<usize, ProtocolError> {
274    let bytes = read_slice(bytes, offset, U32_LEN, field)?;
275    let [b0, b1, b2, b3] = bytes else {
276        return Err(ProtocolError::codec(format!("{field} was truncated")));
277    };
278    usize::try_from(u32::from_be_bytes([*b0, *b1, *b2, *b3]))
279        .map_err(|_| ProtocolError::codec(format!("{field} cannot fit usize")))
280}
281
282fn read_u64(bytes: &[u8], offset: &mut usize, field: &str) -> Result<u64, ProtocolError> {
283    let bytes = read_slice(bytes, offset, U64_LEN, field)?;
284    let [b0, b1, b2, b3, b4, b5, b6, b7] = bytes else {
285        return Err(ProtocolError::codec(format!("{field} was truncated")));
286    };
287    Ok(u64::from_be_bytes([*b0, *b1, *b2, *b3, *b4, *b5, *b6, *b7]))
288}
289
290fn read_slice<'a>(
291    bytes: &'a [u8],
292    offset: &mut usize,
293    len: usize,
294    field: &str,
295) -> Result<&'a [u8], ProtocolError> {
296    let end = offset
297        .checked_add(len)
298        .ok_or_else(|| ProtocolError::codec(format!("{field} offset overflowed usize")))?;
299    let Some(slice) = bytes.get(*offset..end) else {
300        return Err(ProtocolError::codec(format!(
301            "{field} exceeded available bytes"
302        )));
303    };
304    *offset = end;
305    Ok(slice)
306}
307
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::{CausalContext, MessageId, extract_causal_context};
    use crate::protocol::{MessageEnvelope, ProtocolError, SchemaId};

    /// Compiles only while `CausalContext` implements every listed trait.
    #[test]
    fn causal_context_trait_bounds_are_available() {
        fn require_traits<T>()
        where
            T: Debug + Clone + PartialEq + Eq,
        {
        }

        require_traits::<CausalContext>();
    }

    /// `independent` leaves both fields empty; `with_parent` fills only the parent id.
    #[test]
    fn constructors_create_expected_context_shapes() {
        let root = CausalContext::independent();
        assert_eq!(root.parent_id, None);
        assert_eq!(root.vector_clock_entry, None);

        let parent = MessageId::from("parent-1");
        let follower = CausalContext::with_parent(parent.clone());
        assert_eq!(follower.parent_id, Some(parent));
        assert_eq!(follower.vector_clock_entry, None);
    }

    /// A fully populated context survives encode/decode, and encoding is repeatable.
    #[test]
    fn causal_context_serialization_round_trips() -> Result<(), ProtocolError> {
        let original = CausalContext {
            parent_id: Some(MessageId::from("parent-1")),
            vector_clock_entry: Some(7),
        };
        let wire = original.serialize()?;
        let restored = CausalContext::deserialize(&wire)?;

        assert_eq!(restored, original);
        // Encoding the same context a second time must give identical bytes.
        assert_eq!(wire, original.serialize()?);
        Ok(())
    }

    /// With both fields absent the encoding is just the two absent tags.
    #[test]
    fn independent_context_serializes_as_absent_fields() -> Result<(), ProtocolError> {
        let wire = CausalContext::independent().serialize()?;

        assert_eq!(wire, vec![0, 0]);

        let restored = CausalContext::deserialize(&wire)?;
        assert_eq!(restored, CausalContext::independent());
        Ok(())
    }

    /// The context is recovered from a full envelope whose payload is arbitrary bytes.
    #[test]
    fn extract_reads_causal_context_without_payload_parsing() -> Result<(), ProtocolError> {
        let expected = CausalContext {
            parent_id: Some(MessageId::from("parent-2")),
            vector_clock_entry: Some(11),
        };
        let schema = SchemaId::new([0xAB; 32]);
        let payload = vec![0xFF, 0xFE, 0xFD];
        let wire = MessageEnvelope::new(schema, expected.clone(), payload).serialize()?;

        let extracted = extract_causal_context(&wire)?;
        assert_eq!(extracted, expected);
        Ok(())
    }
}