1mod known;
2mod payload;
3
4#[cfg(test)]
5mod tests_support;
6
7use super::causal::MessageId;
8use super::envelope::SchemaId;
9use super::error::ProtocolError;
10use super::frame::{
11 Frame, FrameType, HEADER_LEN, WORKER_REGISTER_ACK_ACCEPTED, WORKER_REGISTER_ACK_REJECTED,
12 WorkerRegisterOutcome, WorkerRegistration, validate_stream,
13};
14use super::version::ProtocolVersion;
15use known::decode_known_payload;
16use payload::{
17 PayloadReader, PayloadWriter, U16_LEN, U32_LEN, U64_LEN, bytes_field_len, checked_u32_len,
18 option_string_len, option_u16_len, schema_ids_field_len, string_field_len,
19 string_vec_field_len, sum_lengths,
20};
21
/// Number of bytes a single `u8` field occupies in an encoded payload.
const U8_FIELD_LEN: usize = core::mem::size_of::<u8>();
26
27pub fn encoded_len(frame: &Frame) -> Result<usize, ProtocolError> {
34 frame.validate()?;
35 let payload_len = encoded_payload_len(frame)?;
36 HEADER_LEN
37 .checked_add(payload_len)
38 .ok_or_else(|| ProtocolError::codec("encoded frame length overflowed usize"))
39}
40
41pub fn encode(frame: &Frame, buffer: &mut [u8]) -> Result<usize, ProtocolError> {
53 frame.validate()?;
54 let payload_len = encoded_payload_len(frame)?;
55 let payload_length = u32::try_from(payload_len)
56 .map_err(|_| ProtocolError::codec("payload length exceeded u32::MAX"))?;
57 let total_len = HEADER_LEN
58 .checked_add(payload_len)
59 .ok_or_else(|| ProtocolError::codec("encoded frame length overflowed usize"))?;
60
61 if buffer.len() < total_len {
62 return Err(ProtocolError::codec("output buffer is too small"));
63 }
64
65 let Some(header) = buffer.get_mut(..HEADER_LEN) else {
66 return Err(ProtocolError::codec(
67 "output buffer is too small for header",
68 ));
69 };
70 write_header(frame, payload_length, header)?;
71
72 let Some(payload) = buffer.get_mut(HEADER_LEN..total_len) else {
73 return Err(ProtocolError::codec(
74 "output buffer is too small for payload",
75 ));
76 };
77 write_payload(frame, payload)?;
78
79 Ok(total_len)
80}
81
82pub fn decode(buffer: &[u8]) -> Result<(Frame, usize), ProtocolError> {
95 if buffer.len() < HEADER_LEN {
96 return Err(ProtocolError::IncompleteHeader {
97 message: Some("buffer shorter than fixed frame header".to_owned()),
98 });
99 }
100
101 let Some(header) = buffer.get(..HEADER_LEN) else {
102 return Err(ProtocolError::IncompleteHeader {
103 message: Some("buffer shorter than fixed frame header".to_owned()),
104 });
105 };
106 let mut header_reader = PayloadReader::new(header);
107 let type_id = header_reader.read_u8()?;
108 let flags = header_reader.read_u8()?;
109 let stream_id = header_reader.read_u32()?;
110 let payload_length = header_reader.read_u32()?;
111 header_reader.finish()?;
112
113 let payload_len = usize::try_from(payload_length)
114 .map_err(|_| ProtocolError::codec("payload length cannot fit usize"))?;
115 let total_len = HEADER_LEN
116 .checked_add(payload_len)
117 .ok_or_else(|| ProtocolError::codec("decoded frame length overflowed usize"))?;
118
119 if buffer.len() < total_len {
120 return Err(ProtocolError::TruncatedPayload {
121 message: Some("buffer shorter than declared payload length".to_owned()),
122 });
123 }
124
125 let Some(payload) = buffer.get(HEADER_LEN..total_len) else {
126 return Err(ProtocolError::TruncatedPayload {
127 message: Some("buffer shorter than declared payload length".to_owned()),
128 });
129 };
130
131 let frame_type = FrameType::from(type_id);
132 let frame = decode_payload(frame_type, flags, stream_id, payload)?;
133 Ok((frame, total_len))
134}
135
136fn write_header(
137 frame: &Frame,
138 payload_length: u32,
139 buffer: &mut [u8],
140) -> Result<(), ProtocolError> {
141 let mut writer = PayloadWriter::new(buffer);
142 writer.write_u8(u8::from(frame.frame_type()))?;
143 writer.write_u8(frame.flags())?;
144 writer.write_u32(frame.stream_id())?;
145 writer.write_u32(payload_length)?;
146 writer.finish()
147}
148
149fn encoded_payload_len(frame: &Frame) -> Result<usize, ProtocolError> {
150 match frame {
151 Frame::Connect { auth_token, .. } => sum_lengths(&[
152 ProtocolVersion::WIRE_LEN,
153 ProtocolVersion::WIRE_LEN,
154 bytes_field_len(auth_token)?,
155 ]),
156 Frame::ConnectAck { .. } => sum_lengths(&[ProtocolVersion::WIRE_LEN, U32_LEN]),
157 Frame::ConnectError { message, .. }
158 | Frame::SubscribeError { message, .. }
159 | Frame::PublishError { message, .. } => {
160 sum_lengths(&[U16_LEN, option_string_len(message.as_deref())?])
161 }
162 Frame::Disconnect { .. } | Frame::Ping { .. } | Frame::Pong { .. } => Ok(0),
163 Frame::Subscribe {
164 channel,
165 accepted_schemas,
166 ..
167 } => sum_lengths(&[
168 string_field_len(channel)?,
169 schema_ids_field_len(accepted_schemas)?,
170 U32_LEN,
171 ]),
172 Frame::SubscribeAck { .. } => sum_lengths(&[U64_LEN, SchemaId::WIRE_LEN]),
173 Frame::Unsubscribe { .. } | Frame::PublishAck { .. } => Ok(U64_LEN),
174 Frame::Publish {
175 channel,
176 envelope,
177 idempotency_key,
178 ..
179 } => {
180 let mut parts = vec![
181 string_field_len(channel)?,
182 envelope_bytes_field_len(envelope.encoded_len()?)?,
183 ];
184 if let Some(key) = idempotency_key {
185 parts.push(string_field_len(key)?);
186 }
187 sum_lengths(&parts)
188 }
189 Frame::ConversationOpen { subject, .. } => {
190 sum_lengths(&[U64_LEN, string_field_len(subject)?])
191 }
192 Frame::ConversationMessage { envelope, .. } => {
193 sum_lengths(&[U64_LEN, envelope_bytes_field_len(envelope.encoded_len()?)?])
194 }
195 Frame::ConversationClose {
196 reason_code,
197 message,
198 ..
199 } => sum_lengths(&[
200 U64_LEN,
201 option_u16_len(*reason_code),
202 option_string_len(message.as_deref())?,
203 ]),
204 Frame::ConversationError { message, .. } => {
205 sum_lengths(&[U64_LEN, U16_LEN, option_string_len(message.as_deref())?])
206 }
207 Frame::Accept {
208 referenced_message_id,
209 ..
210 } => message_id_field_len(referenced_message_id),
211 Frame::Defer {
212 referenced_message_id,
213 reason,
214 ..
215 }
216 | Frame::Reject {
217 referenced_message_id,
218 reason,
219 ..
220 } => sum_lengths(&[
221 message_id_field_len(referenced_message_id)?,
222 option_string_len(reason.as_deref())?,
223 ]),
224 Frame::Push { payload, .. } | Frame::PushReply { payload, .. } => {
225 sum_lengths(&[U64_LEN, bytes_field_len(payload)?])
226 }
227 Frame::WorkerRegister { registration, .. } => worker_register_payload_len(registration),
228 Frame::WorkerRegisterAck { outcome, .. } => worker_register_ack_payload_len(outcome),
229 Frame::Unknown { payload, .. } => checked_u32_len(payload.len()).map(|()| payload.len()),
230 }
231}
232
233fn envelope_bytes_field_len(envelope_len: usize) -> Result<usize, ProtocolError> {
234 checked_u32_len(envelope_len)?;
235 sum_lengths(&[U32_LEN, envelope_len])
236}
237
238fn message_id_field_len(message_id: &MessageId) -> Result<usize, ProtocolError> {
239 string_field_len(message_id.as_str())
240}
241
242fn worker_register_payload_len(registration: &WorkerRegistration) -> Result<usize, ProtocolError> {
243 sum_lengths(&[
244 string_vec_field_len(®istration.namespaces)?,
245 string_field_len(®istration.task_queue)?,
246 option_string_len(registration.node.as_deref())?,
247 string_vec_field_len(®istration.activity_types)?,
248 string_field_len(®istration.identity)?,
249 ])
250}
251
252fn worker_register_ack_payload_len(
253 outcome: &WorkerRegisterOutcome,
254) -> Result<usize, ProtocolError> {
255 match outcome {
256 WorkerRegisterOutcome::Accepted => Ok(U8_FIELD_LEN),
257 WorkerRegisterOutcome::Rejected { reason } => {
258 sum_lengths(&[U8_FIELD_LEN, string_field_len(reason)?])
259 }
260 }
261}
262
263fn write_handshake_payload(
264 frame: &Frame,
265 writer: &mut PayloadWriter<'_>,
266) -> Result<(), ProtocolError> {
267 match frame {
268 Frame::Connect {
269 min_version,
270 max_version,
271 auth_token,
272 ..
273 } => {
274 writer.write_slice(&min_version.to_wire_bytes())?;
275 writer.write_slice(&max_version.to_wire_bytes())?;
276 writer.write_bytes_field(auth_token)
277 }
278 Frame::ConnectAck {
279 selected_version,
280 capabilities,
281 ..
282 } => {
283 writer.write_slice(&selected_version.to_wire_bytes())?;
284 writer.write_u32(*capabilities)
285 }
286 _ => Err(ProtocolError::codec("frame type was not a handshake frame")),
287 }
288}
289
290fn write_pressure_payload(
291 frame: &Frame,
292 writer: &mut PayloadWriter<'_>,
293) -> Result<(), ProtocolError> {
294 match frame {
295 Frame::Accept {
296 referenced_message_id,
297 ..
298 } => writer.write_string_field(referenced_message_id.as_str()),
299 Frame::Defer {
300 referenced_message_id,
301 reason,
302 ..
303 }
304 | Frame::Reject {
305 referenced_message_id,
306 reason,
307 ..
308 } => {
309 writer.write_string_field(referenced_message_id.as_str())?;
310 writer.write_optional_string(reason.as_deref())
311 }
312 _ => Err(ProtocolError::codec("frame type was not a pressure frame")),
313 }
314}
315
316fn write_publish_payload(
317 frame: &Frame,
318 writer: &mut PayloadWriter<'_>,
319) -> Result<(), ProtocolError> {
320 match frame {
321 Frame::Publish {
322 channel,
323 envelope,
324 idempotency_key,
325 ..
326 } => {
327 writer.write_string_field(channel)?;
328 writer.write_bytes_field(&envelope.serialize()?)?;
329 if let Some(key) = idempotency_key {
334 writer.write_string_field(key)?;
335 }
336 Ok(())
337 }
338 _ => Err(ProtocolError::codec("frame type was not a publish frame")),
339 }
340}
341
342fn write_push_payload(frame: &Frame, writer: &mut PayloadWriter<'_>) -> Result<(), ProtocolError> {
343 match frame {
344 Frame::Push {
345 correlation_id,
346 payload,
347 ..
348 }
349 | Frame::PushReply {
350 correlation_id,
351 payload,
352 ..
353 } => {
354 writer.write_u64(*correlation_id)?;
355 writer.write_bytes_field(payload)
356 }
357 _ => Err(ProtocolError::codec("frame type was not a push frame")),
358 }
359}
360
361fn write_worker_register_payload(
362 registration: &WorkerRegistration,
363 writer: &mut PayloadWriter<'_>,
364) -> Result<(), ProtocolError> {
365 writer.write_string_vec_field(®istration.namespaces)?;
366 writer.write_string_field(®istration.task_queue)?;
367 writer.write_optional_string(registration.node.as_deref())?;
370 writer.write_string_vec_field(®istration.activity_types)?;
371 writer.write_string_field(®istration.identity)
372}
373
374fn write_worker_register_ack_payload(
375 outcome: &WorkerRegisterOutcome,
376 writer: &mut PayloadWriter<'_>,
377) -> Result<(), ProtocolError> {
378 match outcome {
379 WorkerRegisterOutcome::Accepted => writer.write_u8(WORKER_REGISTER_ACK_ACCEPTED),
380 WorkerRegisterOutcome::Rejected { reason } => {
381 writer.write_u8(WORKER_REGISTER_ACK_REJECTED)?;
382 writer.write_string_field(reason)
383 }
384 }
385}
386
387fn write_payload(frame: &Frame, buffer: &mut [u8]) -> Result<(), ProtocolError> {
388 let mut writer = PayloadWriter::new(buffer);
389 match frame {
390 Frame::Connect { .. } | Frame::ConnectAck { .. } => {
391 write_handshake_payload(frame, &mut writer)?;
392 }
393 Frame::ConnectError {
394 reason_code,
395 message,
396 ..
397 }
398 | Frame::SubscribeError {
399 reason_code,
400 message,
401 ..
402 }
403 | Frame::PublishError {
404 reason_code,
405 message,
406 ..
407 } => {
408 writer.write_u16(*reason_code)?;
409 writer.write_optional_string(message.as_deref())?;
410 }
411 Frame::Disconnect { .. } | Frame::Ping { .. } | Frame::Pong { .. } => {}
412 Frame::Subscribe {
413 channel,
414 accepted_schemas,
415 max_in_flight,
416 ..
417 } => {
418 writer.write_string_field(channel)?;
419 writer.write_schema_ids_field(accepted_schemas)?;
420 writer.write_u32(*max_in_flight)?;
421 }
422 Frame::SubscribeAck {
423 subscription_id,
424 selected_schema,
425 ..
426 } => {
427 writer.write_u64(*subscription_id)?;
428 writer.write_schema_id(*selected_schema)?;
429 }
430 Frame::Unsubscribe {
431 subscription_id, ..
432 } => writer.write_u64(*subscription_id)?,
433 Frame::Publish { .. } => write_publish_payload(frame, &mut writer)?,
434 Frame::PublishAck { message_id, .. } => writer.write_u64(*message_id)?,
435 Frame::ConversationOpen {
436 conversation_id,
437 subject,
438 ..
439 } => {
440 writer.write_u64(*conversation_id)?;
441 writer.write_string_field(subject)?;
442 }
443 Frame::ConversationMessage {
444 conversation_id,
445 envelope,
446 ..
447 } => {
448 writer.write_u64(*conversation_id)?;
449 writer.write_bytes_field(&envelope.serialize()?)?;
450 }
451 Frame::ConversationClose {
452 conversation_id,
453 reason_code,
454 message,
455 ..
456 } => {
457 writer.write_u64(*conversation_id)?;
458 writer.write_optional_u16(*reason_code)?;
459 writer.write_optional_string(message.as_deref())?;
460 }
461 Frame::ConversationError {
462 conversation_id,
463 reason_code,
464 message,
465 ..
466 } => {
467 writer.write_u64(*conversation_id)?;
468 writer.write_u16(*reason_code)?;
469 writer.write_optional_string(message.as_deref())?;
470 }
471 Frame::Accept { .. } | Frame::Defer { .. } | Frame::Reject { .. } => {
472 write_pressure_payload(frame, &mut writer)?;
473 }
474 Frame::Push { .. } | Frame::PushReply { .. } => {
475 write_push_payload(frame, &mut writer)?;
476 }
477 Frame::WorkerRegister { registration, .. } => {
478 write_worker_register_payload(registration, &mut writer)?;
479 }
480 Frame::WorkerRegisterAck { outcome, .. } => {
481 write_worker_register_ack_payload(outcome, &mut writer)?;
482 }
483 Frame::Unknown { payload, .. } => writer.write_slice(payload)?,
484 }
485 writer.finish()
486}
487
488fn decode_payload(
489 frame_type: FrameType,
490 flags: u8,
491 stream_id: u32,
492 payload: &[u8],
493) -> Result<Frame, ProtocolError> {
494 if let FrameType::Unknown(type_id) = frame_type {
495 return Ok(Frame::Unknown {
496 type_id,
497 flags,
498 stream_id,
499 payload: payload.to_vec(),
500 });
501 }
502
503 validate_stream(frame_type, stream_id)?;
504 decode_known_payload(frame_type, flags, stream_id, payload)
505}
506
507#[cfg(test)]
508mod tests;