1use std::str;
2
3use super::error::ProtocolError;
4
/// Presence tag written in front of an optional field that is absent.
const ABSENT_TAG: u8 = 0;
/// Presence tag written in front of an optional field that is present.
const PRESENT_TAG: u8 = 1;
/// Encoded width, in bytes, of a single presence tag.
const U8_LEN: usize = std::mem::size_of::<u8>();
/// Encoded width, in bytes, of a big-endian `u32` length prefix.
const U32_LEN: usize = std::mem::size_of::<u32>();
/// Encoded width, in bytes, of a big-endian `u64` value.
const U64_LEN: usize = std::mem::size_of::<u64>();
/// Width, in bytes, of the schema identifier at the start of a serialized envelope.
const ENVELOPE_SCHEMA_ID_LEN: usize = 32;
11
/// Newtype wrapper around the string identifier of a message.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MessageId(
    /// The identifier text.
    pub String,
);

impl MessageId {
    /// Wraps any value convertible into a `String` as a message id.
    #[must_use]
    pub fn new(value: impl Into<String>) -> Self {
        let text: String = value.into();
        Self(text)
    }

    /// Borrows the identifier text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the id and returns the underlying `String`.
    #[must_use]
    pub fn into_string(self) -> String {
        let Self(text) = self;
        text
    }
}

impl From<String> for MessageId {
    fn from(value: String) -> Self {
        Self::new(value)
    }
}

impl From<&str> for MessageId {
    fn from(value: &str) -> Self {
        Self::new(value)
    }
}
50
51#[derive(Clone, Debug, PartialEq, Eq)]
53pub struct CausalContext {
54 pub parent_id: Option<MessageId>,
56 pub vector_clock_entry: Option<u64>,
58}
59
60impl CausalContext {
61 #[must_use]
63 pub const fn independent() -> Self {
64 Self {
65 parent_id: None,
66 vector_clock_entry: None,
67 }
68 }
69
70 #[must_use]
72 pub const fn with_parent(parent_id: MessageId) -> Self {
73 Self {
74 parent_id: Some(parent_id),
75 vector_clock_entry: None,
76 }
77 }
78
79 pub fn encoded_len(&self) -> Result<usize, ProtocolError> {
86 let parent_len = match &self.parent_id {
87 Some(parent_id) => sum_lengths(&[
88 U8_LEN,
89 U32_LEN,
90 checked_u32_len(parent_id.as_str().len(), "message id")?,
91 ])?,
92 None => U8_LEN,
93 };
94 let vector_len = if self.vector_clock_entry.is_some() {
95 U8_LEN + U64_LEN
96 } else {
97 U8_LEN
98 };
99 sum_lengths(&[parent_len, vector_len])
100 }
101
102 pub fn serialize(&self) -> Result<Vec<u8>, ProtocolError> {
109 let len = self.encoded_len()?;
110 let mut bytes = Vec::with_capacity(len);
111
112 match &self.parent_id {
113 Some(parent_id) => {
114 bytes.push(PRESENT_TAG);
115 write_u32(&mut bytes, parent_id.as_str().len(), "message id")?;
116 bytes.extend_from_slice(parent_id.as_str().as_bytes());
117 }
118 None => bytes.push(ABSENT_TAG),
119 }
120
121 match self.vector_clock_entry {
122 Some(entry) => {
123 bytes.push(PRESENT_TAG);
124 bytes.extend_from_slice(&entry.to_be_bytes());
125 }
126 None => bytes.push(ABSENT_TAG),
127 }
128
129 if bytes.len() == len {
130 Ok(bytes)
131 } else {
132 Err(ProtocolError::codec(
133 "causal context encoder produced an unexpected length",
134 ))
135 }
136 }
137
138 pub fn to_wire_bytes(&self) -> Result<Vec<u8>, ProtocolError> {
145 self.serialize()
146 }
147
148 pub fn deserialize(bytes: &[u8]) -> Result<Self, ProtocolError> {
156 Self::from_wire_bytes(bytes)
157 }
158
159 pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, ProtocolError> {
167 let mut offset = 0;
168 let parent_id = match read_u8(bytes, &mut offset, "parent id presence tag")? {
169 ABSENT_TAG => None,
170 PRESENT_TAG => {
171 let len = read_u32_as_usize(bytes, &mut offset, "message id length")?;
172 let id_bytes = read_slice(bytes, &mut offset, len, "message id bytes")?;
173 let id = str::from_utf8(id_bytes)
174 .map_err(|_| ProtocolError::codec("message id was not valid utf-8"))?;
175 Some(MessageId::new(id))
176 }
177 _ => return Err(ProtocolError::codec("parent id presence tag was invalid")),
178 };
179
180 let vector_clock_entry = match read_u8(bytes, &mut offset, "vector clock presence tag")? {
181 ABSENT_TAG => None,
182 PRESENT_TAG => Some(read_u64(bytes, &mut offset, "vector clock entry")?),
183 _ => {
184 return Err(ProtocolError::codec(
185 "vector clock presence tag was invalid",
186 ));
187 }
188 };
189
190 if offset == bytes.len() {
191 Ok(Self {
192 parent_id,
193 vector_clock_entry,
194 })
195 } else {
196 Err(ProtocolError::codec(
197 "causal context contained trailing bytes",
198 ))
199 }
200 }
201}
202
203impl Default for CausalContext {
204 fn default() -> Self {
205 Self::independent()
206 }
207}
208
209pub fn extract_causal_context(envelope_bytes: &[u8]) -> Result<CausalContext, ProtocolError> {
221 let mut offset = 0;
222 let _schema_id = read_slice(
223 envelope_bytes,
224 &mut offset,
225 ENVELOPE_SCHEMA_ID_LEN,
226 "schema id",
227 )?;
228 let causal_len = read_u32_as_usize(envelope_bytes, &mut offset, "causal context length")?;
229 let causal_bytes = read_slice(
230 envelope_bytes,
231 &mut offset,
232 causal_len,
233 "causal context bytes",
234 )?;
235 CausalContext::deserialize(causal_bytes)
236}
237
238fn checked_u32_len(len: usize, field: &str) -> Result<usize, ProtocolError> {
239 u32::try_from(len)
240 .map(|_| len)
241 .map_err(|_| ProtocolError::codec(format!("{field} length exceeded u32::MAX")))
242}
243
244fn sum_lengths(parts: &[usize]) -> Result<usize, ProtocolError> {
245 let mut total = 0_usize;
246 for part in parts {
247 total = total
248 .checked_add(*part)
249 .ok_or_else(|| ProtocolError::codec("causal context length overflowed usize"))?;
250 }
251 Ok(total)
252}
253
254fn write_u32(buffer: &mut Vec<u8>, value: usize, field: &str) -> Result<(), ProtocolError> {
255 let value = u32::try_from(value)
256 .map_err(|_| ProtocolError::codec(format!("{field} length exceeded u32::MAX")))?;
257 buffer.extend_from_slice(&value.to_be_bytes());
258 Ok(())
259}
260
261fn read_u8(bytes: &[u8], offset: &mut usize, field: &str) -> Result<u8, ProtocolError> {
262 let bytes = read_slice(bytes, offset, U8_LEN, field)?;
263 let [value] = bytes else {
264 return Err(ProtocolError::codec(format!("{field} was truncated")));
265 };
266 Ok(*value)
267}
268
269fn read_u32_as_usize(
270 bytes: &[u8],
271 offset: &mut usize,
272 field: &str,
273) -> Result<usize, ProtocolError> {
274 let bytes = read_slice(bytes, offset, U32_LEN, field)?;
275 let [b0, b1, b2, b3] = bytes else {
276 return Err(ProtocolError::codec(format!("{field} was truncated")));
277 };
278 usize::try_from(u32::from_be_bytes([*b0, *b1, *b2, *b3]))
279 .map_err(|_| ProtocolError::codec(format!("{field} cannot fit usize")))
280}
281
282fn read_u64(bytes: &[u8], offset: &mut usize, field: &str) -> Result<u64, ProtocolError> {
283 let bytes = read_slice(bytes, offset, U64_LEN, field)?;
284 let [b0, b1, b2, b3, b4, b5, b6, b7] = bytes else {
285 return Err(ProtocolError::codec(format!("{field} was truncated")));
286 };
287 Ok(u64::from_be_bytes([*b0, *b1, *b2, *b3, *b4, *b5, *b6, *b7]))
288}
289
290fn read_slice<'a>(
291 bytes: &'a [u8],
292 offset: &mut usize,
293 len: usize,
294 field: &str,
295) -> Result<&'a [u8], ProtocolError> {
296 let end = offset
297 .checked_add(len)
298 .ok_or_else(|| ProtocolError::codec(format!("{field} offset overflowed usize")))?;
299 let Some(slice) = bytes.get(*offset..end) else {
300 return Err(ProtocolError::codec(format!(
301 "{field} exceeded available bytes"
302 )));
303 };
304 *offset = end;
305 Ok(slice)
306}
307
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::{CausalContext, MessageId, extract_causal_context};
    use crate::protocol::{MessageEnvelope, ProtocolError, SchemaId};

    #[test]
    fn causal_context_trait_bounds_are_available() {
        fn assert_traits<T: Debug + Clone + PartialEq + Eq>() {}

        assert_traits::<CausalContext>();
    }

    #[test]
    fn constructors_create_expected_context_shapes() {
        let independent = CausalContext::independent();
        assert_eq!(independent.parent_id, None);
        assert_eq!(independent.vector_clock_entry, None);

        let parent = MessageId::from("parent-1");
        let child = CausalContext::with_parent(parent.clone());
        assert_eq!(child.parent_id, Some(parent));
        assert_eq!(child.vector_clock_entry, None);
    }

    #[test]
    fn causal_context_serialization_round_trips() -> Result<(), ProtocolError> {
        let context = CausalContext {
            parent_id: Some(MessageId::from("parent-1")),
            vector_clock_entry: Some(7),
        };
        let encoded = context.serialize()?;
        let decoded = CausalContext::deserialize(&encoded)?;

        assert_eq!(decoded, context);
        assert_eq!(encoded, context.serialize()?);
        Ok(())
    }

    #[test]
    fn independent_context_serializes_as_absent_fields() -> Result<(), ProtocolError> {
        let encoded = CausalContext::independent().serialize()?;

        assert_eq!(encoded, vec![0, 0]);
        assert_eq!(
            CausalContext::deserialize(&encoded)?,
            CausalContext::independent()
        );
        Ok(())
    }

    /// `encoded_len`, `serialize`, `to_wire_bytes` and `deserialize` must agree
    /// for every combination of present/absent fields, including an empty id.
    #[test]
    fn encoded_len_matches_serialized_output_for_every_shape() -> Result<(), ProtocolError> {
        let contexts = [
            CausalContext::independent(),
            CausalContext::with_parent(MessageId::from("p")),
            CausalContext {
                parent_id: None,
                vector_clock_entry: Some(u64::MAX),
            },
            CausalContext {
                parent_id: Some(MessageId::from("")),
                vector_clock_entry: Some(0),
            },
        ];
        for context in &contexts {
            let encoded = context.serialize()?;
            assert_eq!(encoded.len(), context.encoded_len()?);
            assert_eq!(context.to_wire_bytes()?, encoded);
            assert_eq!(&CausalContext::deserialize(&encoded)?, context);
        }
        Ok(())
    }

    /// Pins the wire layout: tag, big-endian u32 length, id bytes, then the vector tag.
    #[test]
    fn parent_only_context_has_expected_wire_layout() -> Result<(), ProtocolError> {
        let encoded = CausalContext::with_parent(MessageId::from("ab")).serialize()?;

        assert_eq!(encoded, vec![1, 0, 0, 0, 2, b'a', b'b', 0]);
        Ok(())
    }

    #[test]
    fn deserialize_rejects_malformed_input() {
        // Missing the parent id presence tag entirely.
        assert!(CausalContext::deserialize(&[]).is_err());
        // Missing the vector clock presence tag.
        assert!(CausalContext::deserialize(&[0]).is_err());
        // Presence tags other than 0 or 1.
        assert!(CausalContext::deserialize(&[2, 0]).is_err());
        assert!(CausalContext::deserialize(&[0, 2]).is_err());
        // Trailing byte after a complete context.
        assert!(CausalContext::deserialize(&[0, 0, 0]).is_err());
        // Declared message id length exceeds the available bytes.
        assert!(CausalContext::deserialize(&[1, 0, 0, 0, 5, b'a', 0]).is_err());
        // Message id is not valid UTF-8.
        assert!(CausalContext::deserialize(&[1, 0, 0, 0, 1, 0xFF, 0]).is_err());
        // Vector clock entry is shorter than eight bytes.
        assert!(CausalContext::deserialize(&[0, 1, 0, 0]).is_err());
    }

    #[test]
    fn extract_reads_causal_context_without_payload_parsing() -> Result<(), ProtocolError> {
        let context = CausalContext {
            parent_id: Some(MessageId::from("parent-2")),
            vector_clock_entry: Some(11),
        };
        let envelope = MessageEnvelope::new(
            SchemaId::new([0xAB; 32]),
            context.clone(),
            vec![0xFF, 0xFE, 0xFD],
        );
        let encoded = envelope.serialize()?;

        assert_eq!(extract_causal_context(&encoded)?, context);
        Ok(())
    }

    #[test]
    fn extract_rejects_truncated_envelopes() {
        // Schema id incomplete.
        assert!(extract_causal_context(&[]).is_err());
        assert!(extract_causal_context(&[0xAB; 31]).is_err());
        // Schema id present but the causal length prefix is missing.
        assert!(extract_causal_context(&[0xAB; 32]).is_err());
        // Length prefix claims more causal bytes than follow.
        let mut envelope = vec![0xAB_u8; 32];
        envelope.extend_from_slice(&[0, 0, 0, 3, 0, 0]);
        assert!(extract_causal_context(&envelope).is_err());
    }
}