Skip to main content

liminal/protocol/
codec.rs

1mod known;
2mod payload;
3
4#[cfg(test)]
5mod tests_support;
6
7use super::causal::MessageId;
8use super::envelope::SchemaId;
9use super::error::ProtocolError;
10use super::frame::{
11    Frame, FrameType, HEADER_LEN, WORKER_REGISTER_ACK_ACCEPTED, WORKER_REGISTER_ACK_REJECTED,
12    WorkerRegisterOutcome, WorkerRegistration, validate_stream,
13};
14use super::version::ProtocolVersion;
15use known::decode_known_payload;
16use payload::{
17    PayloadReader, PayloadWriter, U16_LEN, U32_LEN, U64_LEN, bytes_field_len, checked_u32_len,
18    option_string_len, option_u16_len, schema_ids_field_len, string_field_len,
19    string_vec_field_len, sum_lengths,
20};
21
/// Wire width, in bytes, of a single `u8` field.
///
/// Worker-registration acks use this for their leading status byte. The
/// presence byte of the optional `node` field is not counted here; it is
/// already part of [`option_string_len`].
const U8_FIELD_LEN: usize = core::mem::size_of::<u8>();
26
27/// Return the number of bytes needed to encode a frame.
28///
29/// # Errors
30///
31/// Returns [`ProtocolError`] when the frame violates stream invariants or its
32/// payload cannot fit in the protocol's `u32` length fields.
33pub fn encoded_len(frame: &Frame) -> Result<usize, ProtocolError> {
34    frame.validate()?;
35    let payload_len = encoded_payload_len(frame)?;
36    HEADER_LEN
37        .checked_add(payload_len)
38        .ok_or_else(|| ProtocolError::codec("encoded frame length overflowed usize"))
39}
40
41/// Encode a frame into the provided byte buffer, returning bytes written.
42///
43/// The buffer must be at least [`encoded_len`] bytes long. Encoding writes the
44/// fixed 10-byte header followed by the serialized payload and performs no heap
45/// allocation.
46///
47/// # Errors
48///
49/// Returns [`ProtocolError`] when the frame violates stream invariants, its
50/// payload cannot fit in the protocol's length fields, or the provided buffer is
51/// too small.
52pub fn encode(frame: &Frame, buffer: &mut [u8]) -> Result<usize, ProtocolError> {
53    frame.validate()?;
54    let payload_len = encoded_payload_len(frame)?;
55    let payload_length = u32::try_from(payload_len)
56        .map_err(|_| ProtocolError::codec("payload length exceeded u32::MAX"))?;
57    let total_len = HEADER_LEN
58        .checked_add(payload_len)
59        .ok_or_else(|| ProtocolError::codec("encoded frame length overflowed usize"))?;
60
61    if buffer.len() < total_len {
62        return Err(ProtocolError::codec("output buffer is too small"));
63    }
64
65    let Some(header) = buffer.get_mut(..HEADER_LEN) else {
66        return Err(ProtocolError::codec(
67            "output buffer is too small for header",
68        ));
69    };
70    write_header(frame, payload_length, header)?;
71
72    let Some(payload) = buffer.get_mut(HEADER_LEN..total_len) else {
73        return Err(ProtocolError::codec(
74            "output buffer is too small for payload",
75        ));
76    };
77    write_payload(frame, payload)?;
78
79    Ok(total_len)
80}
81
82/// Decode one complete frame from a byte buffer.
83///
84/// Returns the decoded frame and the number of bytes consumed. Unknown frame
85/// types are length-delimited and returned as [`Frame::Unknown`] without
86/// producing an error.
87///
88/// # Errors
89///
90/// Returns [`ProtocolError::IncompleteHeader`] for buffers shorter than the
91/// fixed header, [`ProtocolError::TruncatedPayload`] when the declared payload
92/// is not fully present, and [`ProtocolError`] for malformed known-frame
93/// payloads or invalid stream placement.
94pub fn decode(buffer: &[u8]) -> Result<(Frame, usize), ProtocolError> {
95    if buffer.len() < HEADER_LEN {
96        return Err(ProtocolError::IncompleteHeader {
97            message: Some("buffer shorter than fixed frame header".to_owned()),
98        });
99    }
100
101    let Some(header) = buffer.get(..HEADER_LEN) else {
102        return Err(ProtocolError::IncompleteHeader {
103            message: Some("buffer shorter than fixed frame header".to_owned()),
104        });
105    };
106    let mut header_reader = PayloadReader::new(header);
107    let type_id = header_reader.read_u8()?;
108    let flags = header_reader.read_u8()?;
109    let stream_id = header_reader.read_u32()?;
110    let payload_length = header_reader.read_u32()?;
111    header_reader.finish()?;
112
113    let payload_len = usize::try_from(payload_length)
114        .map_err(|_| ProtocolError::codec("payload length cannot fit usize"))?;
115    let total_len = HEADER_LEN
116        .checked_add(payload_len)
117        .ok_or_else(|| ProtocolError::codec("decoded frame length overflowed usize"))?;
118
119    if buffer.len() < total_len {
120        return Err(ProtocolError::TruncatedPayload {
121            message: Some("buffer shorter than declared payload length".to_owned()),
122        });
123    }
124
125    let Some(payload) = buffer.get(HEADER_LEN..total_len) else {
126        return Err(ProtocolError::TruncatedPayload {
127            message: Some("buffer shorter than declared payload length".to_owned()),
128        });
129    };
130
131    let frame_type = FrameType::from(type_id);
132    let frame = decode_payload(frame_type, flags, stream_id, payload)?;
133    Ok((frame, total_len))
134}
135
136fn write_header(
137    frame: &Frame,
138    payload_length: u32,
139    buffer: &mut [u8],
140) -> Result<(), ProtocolError> {
141    let mut writer = PayloadWriter::new(buffer);
142    writer.write_u8(u8::from(frame.frame_type()))?;
143    writer.write_u8(frame.flags())?;
144    writer.write_u32(frame.stream_id())?;
145    writer.write_u32(payload_length)?;
146    writer.finish()
147}
148
149fn encoded_payload_len(frame: &Frame) -> Result<usize, ProtocolError> {
150    match frame {
151        Frame::Connect { auth_token, .. } => sum_lengths(&[
152            ProtocolVersion::WIRE_LEN,
153            ProtocolVersion::WIRE_LEN,
154            bytes_field_len(auth_token)?,
155        ]),
156        Frame::ConnectAck { .. } => sum_lengths(&[ProtocolVersion::WIRE_LEN, U32_LEN]),
157        Frame::ConnectError { message, .. }
158        | Frame::SubscribeError { message, .. }
159        | Frame::PublishError { message, .. } => {
160            sum_lengths(&[U16_LEN, option_string_len(message.as_deref())?])
161        }
162        Frame::Disconnect { .. } | Frame::Ping { .. } | Frame::Pong { .. } => Ok(0),
163        Frame::Subscribe {
164            channel,
165            accepted_schemas,
166            ..
167        } => sum_lengths(&[
168            string_field_len(channel)?,
169            schema_ids_field_len(accepted_schemas)?,
170            U32_LEN,
171        ]),
172        Frame::SubscribeAck { .. } => sum_lengths(&[U64_LEN, SchemaId::WIRE_LEN]),
173        Frame::Unsubscribe { .. } | Frame::PublishAck { .. } => Ok(U64_LEN),
174        Frame::Publish {
175            channel,
176            envelope,
177            idempotency_key,
178            ..
179        } => {
180            let mut parts = vec![
181                string_field_len(channel)?,
182                envelope_bytes_field_len(envelope.encoded_len()?)?,
183            ];
184            if let Some(key) = idempotency_key {
185                parts.push(string_field_len(key)?);
186            }
187            sum_lengths(&parts)
188        }
189        Frame::ConversationOpen { subject, .. } => {
190            sum_lengths(&[U64_LEN, string_field_len(subject)?])
191        }
192        Frame::ConversationMessage { envelope, .. } => {
193            sum_lengths(&[U64_LEN, envelope_bytes_field_len(envelope.encoded_len()?)?])
194        }
195        Frame::ConversationClose {
196            reason_code,
197            message,
198            ..
199        } => sum_lengths(&[
200            U64_LEN,
201            option_u16_len(*reason_code),
202            option_string_len(message.as_deref())?,
203        ]),
204        Frame::ConversationError { message, .. } => {
205            sum_lengths(&[U64_LEN, U16_LEN, option_string_len(message.as_deref())?])
206        }
207        Frame::Accept {
208            referenced_message_id,
209            ..
210        } => message_id_field_len(referenced_message_id),
211        Frame::Defer {
212            referenced_message_id,
213            reason,
214            ..
215        }
216        | Frame::Reject {
217            referenced_message_id,
218            reason,
219            ..
220        } => sum_lengths(&[
221            message_id_field_len(referenced_message_id)?,
222            option_string_len(reason.as_deref())?,
223        ]),
224        Frame::Push { payload, .. } | Frame::PushReply { payload, .. } => {
225            sum_lengths(&[U64_LEN, bytes_field_len(payload)?])
226        }
227        Frame::WorkerRegister { registration, .. } => worker_register_payload_len(registration),
228        Frame::WorkerRegisterAck { outcome, .. } => worker_register_ack_payload_len(outcome),
229        Frame::Unknown { payload, .. } => checked_u32_len(payload.len()).map(|()| payload.len()),
230    }
231}
232
233fn envelope_bytes_field_len(envelope_len: usize) -> Result<usize, ProtocolError> {
234    checked_u32_len(envelope_len)?;
235    sum_lengths(&[U32_LEN, envelope_len])
236}
237
238fn message_id_field_len(message_id: &MessageId) -> Result<usize, ProtocolError> {
239    string_field_len(message_id.as_str())
240}
241
242fn worker_register_payload_len(registration: &WorkerRegistration) -> Result<usize, ProtocolError> {
243    sum_lengths(&[
244        string_vec_field_len(&registration.namespaces)?,
245        string_field_len(&registration.task_queue)?,
246        option_string_len(registration.node.as_deref())?,
247        string_vec_field_len(&registration.activity_types)?,
248        string_field_len(&registration.identity)?,
249    ])
250}
251
252fn worker_register_ack_payload_len(
253    outcome: &WorkerRegisterOutcome,
254) -> Result<usize, ProtocolError> {
255    match outcome {
256        WorkerRegisterOutcome::Accepted => Ok(U8_FIELD_LEN),
257        WorkerRegisterOutcome::Rejected { reason } => {
258            sum_lengths(&[U8_FIELD_LEN, string_field_len(reason)?])
259        }
260    }
261}
262
263fn write_handshake_payload(
264    frame: &Frame,
265    writer: &mut PayloadWriter<'_>,
266) -> Result<(), ProtocolError> {
267    match frame {
268        Frame::Connect {
269            min_version,
270            max_version,
271            auth_token,
272            ..
273        } => {
274            writer.write_slice(&min_version.to_wire_bytes())?;
275            writer.write_slice(&max_version.to_wire_bytes())?;
276            writer.write_bytes_field(auth_token)
277        }
278        Frame::ConnectAck {
279            selected_version,
280            capabilities,
281            ..
282        } => {
283            writer.write_slice(&selected_version.to_wire_bytes())?;
284            writer.write_u32(*capabilities)
285        }
286        _ => Err(ProtocolError::codec("frame type was not a handshake frame")),
287    }
288}
289
290fn write_pressure_payload(
291    frame: &Frame,
292    writer: &mut PayloadWriter<'_>,
293) -> Result<(), ProtocolError> {
294    match frame {
295        Frame::Accept {
296            referenced_message_id,
297            ..
298        } => writer.write_string_field(referenced_message_id.as_str()),
299        Frame::Defer {
300            referenced_message_id,
301            reason,
302            ..
303        }
304        | Frame::Reject {
305            referenced_message_id,
306            reason,
307            ..
308        } => {
309            writer.write_string_field(referenced_message_id.as_str())?;
310            writer.write_optional_string(reason.as_deref())
311        }
312        _ => Err(ProtocolError::codec("frame type was not a pressure frame")),
313    }
314}
315
316fn write_publish_payload(
317    frame: &Frame,
318    writer: &mut PayloadWriter<'_>,
319) -> Result<(), ProtocolError> {
320    match frame {
321        Frame::Publish {
322            channel,
323            envelope,
324            idempotency_key,
325            ..
326        } => {
327            writer.write_string_field(channel)?;
328            writer.write_bytes_field(&envelope.serialize()?)?;
329            // The trailing idempotency-key field is written ONLY when present, so a
330            // no-key publish stays byte-identical to the pre-13-L1 layout. The
331            // PUBLISH_IDEMPOTENCY_KEY_FLAG bit (set on construction) tells the
332            // decoder whether to read it back.
333            if let Some(key) = idempotency_key {
334                writer.write_string_field(key)?;
335            }
336            Ok(())
337        }
338        _ => Err(ProtocolError::codec("frame type was not a publish frame")),
339    }
340}
341
342fn write_push_payload(frame: &Frame, writer: &mut PayloadWriter<'_>) -> Result<(), ProtocolError> {
343    match frame {
344        Frame::Push {
345            correlation_id,
346            payload,
347            ..
348        }
349        | Frame::PushReply {
350            correlation_id,
351            payload,
352            ..
353        } => {
354            writer.write_u64(*correlation_id)?;
355            writer.write_bytes_field(payload)
356        }
357        _ => Err(ProtocolError::codec("frame type was not a push frame")),
358    }
359}
360
361fn write_worker_register_payload(
362    registration: &WorkerRegistration,
363    writer: &mut PayloadWriter<'_>,
364) -> Result<(), ProtocolError> {
365    writer.write_string_vec_field(&registration.namespaces)?;
366    writer.write_string_field(&registration.task_queue)?;
367    // `node` is optional locality: a presence byte distinguishes `None` from
368    // `Some("")` so an absent node never collapses to an empty string.
369    writer.write_optional_string(registration.node.as_deref())?;
370    writer.write_string_vec_field(&registration.activity_types)?;
371    writer.write_string_field(&registration.identity)
372}
373
374fn write_worker_register_ack_payload(
375    outcome: &WorkerRegisterOutcome,
376    writer: &mut PayloadWriter<'_>,
377) -> Result<(), ProtocolError> {
378    match outcome {
379        WorkerRegisterOutcome::Accepted => writer.write_u8(WORKER_REGISTER_ACK_ACCEPTED),
380        WorkerRegisterOutcome::Rejected { reason } => {
381            writer.write_u8(WORKER_REGISTER_ACK_REJECTED)?;
382            writer.write_string_field(reason)
383        }
384    }
385}
386
387fn write_payload(frame: &Frame, buffer: &mut [u8]) -> Result<(), ProtocolError> {
388    let mut writer = PayloadWriter::new(buffer);
389    match frame {
390        Frame::Connect { .. } | Frame::ConnectAck { .. } => {
391            write_handshake_payload(frame, &mut writer)?;
392        }
393        Frame::ConnectError {
394            reason_code,
395            message,
396            ..
397        }
398        | Frame::SubscribeError {
399            reason_code,
400            message,
401            ..
402        }
403        | Frame::PublishError {
404            reason_code,
405            message,
406            ..
407        } => {
408            writer.write_u16(*reason_code)?;
409            writer.write_optional_string(message.as_deref())?;
410        }
411        Frame::Disconnect { .. } | Frame::Ping { .. } | Frame::Pong { .. } => {}
412        Frame::Subscribe {
413            channel,
414            accepted_schemas,
415            max_in_flight,
416            ..
417        } => {
418            writer.write_string_field(channel)?;
419            writer.write_schema_ids_field(accepted_schemas)?;
420            writer.write_u32(*max_in_flight)?;
421        }
422        Frame::SubscribeAck {
423            subscription_id,
424            selected_schema,
425            ..
426        } => {
427            writer.write_u64(*subscription_id)?;
428            writer.write_schema_id(*selected_schema)?;
429        }
430        Frame::Unsubscribe {
431            subscription_id, ..
432        } => writer.write_u64(*subscription_id)?,
433        Frame::Publish { .. } => write_publish_payload(frame, &mut writer)?,
434        Frame::PublishAck { message_id, .. } => writer.write_u64(*message_id)?,
435        Frame::ConversationOpen {
436            conversation_id,
437            subject,
438            ..
439        } => {
440            writer.write_u64(*conversation_id)?;
441            writer.write_string_field(subject)?;
442        }
443        Frame::ConversationMessage {
444            conversation_id,
445            envelope,
446            ..
447        } => {
448            writer.write_u64(*conversation_id)?;
449            writer.write_bytes_field(&envelope.serialize()?)?;
450        }
451        Frame::ConversationClose {
452            conversation_id,
453            reason_code,
454            message,
455            ..
456        } => {
457            writer.write_u64(*conversation_id)?;
458            writer.write_optional_u16(*reason_code)?;
459            writer.write_optional_string(message.as_deref())?;
460        }
461        Frame::ConversationError {
462            conversation_id,
463            reason_code,
464            message,
465            ..
466        } => {
467            writer.write_u64(*conversation_id)?;
468            writer.write_u16(*reason_code)?;
469            writer.write_optional_string(message.as_deref())?;
470        }
471        Frame::Accept { .. } | Frame::Defer { .. } | Frame::Reject { .. } => {
472            write_pressure_payload(frame, &mut writer)?;
473        }
474        Frame::Push { .. } | Frame::PushReply { .. } => {
475            write_push_payload(frame, &mut writer)?;
476        }
477        Frame::WorkerRegister { registration, .. } => {
478            write_worker_register_payload(registration, &mut writer)?;
479        }
480        Frame::WorkerRegisterAck { outcome, .. } => {
481            write_worker_register_ack_payload(outcome, &mut writer)?;
482        }
483        Frame::Unknown { payload, .. } => writer.write_slice(payload)?,
484    }
485    writer.finish()
486}
487
488fn decode_payload(
489    frame_type: FrameType,
490    flags: u8,
491    stream_id: u32,
492    payload: &[u8],
493) -> Result<Frame, ProtocolError> {
494    if let FrameType::Unknown(type_id) = frame_type {
495        return Ok(Frame::Unknown {
496            type_id,
497            flags,
498            stream_id,
499            payload: payload.to_vec(),
500        });
501    }
502
503    validate_stream(frame_type, stream_id)?;
504    decode_known_payload(frame_type, flags, stream_id, payload)
505}
506
507#[cfg(test)]
508mod tests;