iridium_stomp/codec.rs
1use bytes::{Buf, BufMut, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
5use crate::frame::Frame;
6use crate::parser::{parse_frame_slice, unescape_header_value};
7
/// Escapes a header name or value for transmission in a STOMP 1.2 frame.
///
/// STOMP 1.2 defines four escape sequences, which apply to header names and
/// header values alike:
/// - backslash (0x5c) becomes `\\`
/// - carriage return (0x0d) becomes `\r`
/// - line feed (0x0a) becomes `\n`
/// - colon (0x3a) becomes `\c`
///
/// Every other character is copied through unchanged.
fn escape_header_value(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for c in input.chars() {
        let replacement = match c {
            '\\' => "\\\\",
            '\r' => "\\r",
            '\n' => "\\n",
            ':' => "\\c",
            other => {
                // Ordinary character: no escape sequence needed.
                escaped.push(other);
                continue;
            }
        };
        escaped.push_str(replacement);
    }
    escaped
}
28
/// Items produced by the decoder and consumed by the encoder.
///
/// A `StompItem` is either a full STOMP `Frame` or a `Heartbeat` marker.
/// Frame parsing for decoded items is delegated to the `parser` module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StompItem {
    /// A STOMP frame (command, headers and body).
    Frame(Frame),
    /// A heartbeat pulse: a bare end-of-line between frames. The encoder
    /// writes it as a single LF.
    Heartbeat,
}
42
/// `StompCodec` implements `tokio_util::codec::{Decoder, Encoder}` for the
/// STOMP wire protocol.
///
/// Responsibilities:
/// - Decode incoming bytes into `StompItem::Frame` or `StompItem::Heartbeat`.
/// - Support both NUL-terminated frames and frames using the `content-length`
///   header (STOMP 1.2) for binary bodies containing NUL bytes.
/// - Encode `StompItem` back into bytes for the wire format and emit
///   `content-length` when necessary.
///
/// The codec carries no state between calls. Decoding works directly on the
/// buffer handed to it, so a codec value can be cloned or recreated freely.
#[derive(Debug, Clone)]
pub struct StompCodec {
    // Intentionally empty: parsing happens directly on the caller-provided
    // `src` buffer, so nothing needs to be remembered between calls.
}
55
56impl StompCodec {
57 pub fn new() -> Self {
58 Self {}
59 }
60}
61
62impl Default for StompCodec {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl Decoder for StompCodec {
69 type Item = StompItem;
70 type Error = io::Error;
71 /// Decode bytes from `src` into a `StompItem`.
72 ///
73 /// Parameters
74 /// - `src`: a mutable reference to the read buffer containing bytes from the
75 /// transport. The decoder may consume bytes from this buffer (using
76 /// methods like `advance` or `split_to`) when it successfully decodes a
77 /// frame. If there are not enough bytes to form a complete frame, this
78 /// method should return `Ok(None)` and leave `src` in the same state.
79 ///
80 /// Returns
81 /// - `Ok(Some(StompItem))` when a full item (frame or heartbeat) was
82 /// decoded and bytes were consumed from `src` accordingly.
83 /// - `Ok(None)` when more bytes are required to decode a complete item.
84 /// - `Err(io::Error)` on protocol or data errors (invalid UTF-8, malformed
85 /// frames, missing NUL after a content-length body, etc.).
86 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
87 // Move any newly-received bytes from the provided `src` into our
88 // internal buffer. We keep a separate buffer so parsing can proceed
89 // across arbitrary chunk boundaries without relying on indexes into
90 // heartbeat: single LF
91 if let Some(&b'\n') = src.chunk().first() {
92 src.advance(1);
93 return Ok(Some(StompItem::Heartbeat));
94 }
95
96 let chunk = src.chunk();
97 match parse_frame_slice(chunk) {
98 Ok(Some((cmd_bytes, headers, body, consumed))) => {
99 // advance src by consumed
100 src.advance(consumed);
101
102 // build owned Frame
103 let command = String::from_utf8(cmd_bytes).map_err(|e| {
104 io::Error::new(
105 io::ErrorKind::InvalidData,
106 format!("invalid utf8 in command: {}", e),
107 )
108 })?;
109 // convert headers Vec<(Vec<u8>,Vec<u8>)> -> Vec<(String,String)>
110 // and unescape per STOMP 1.2 spec
111 let mut hdrs: Vec<(String, String)> = Vec::new();
112 for (k, v) in headers {
113 // Unescape header key
114 let k_unescaped = unescape_header_value(&k).map_err(|e| {
115 io::Error::new(
116 io::ErrorKind::InvalidData,
117 format!("invalid escape in header key: {}", e),
118 )
119 })?;
120 let ks = String::from_utf8(k_unescaped).map_err(|e| {
121 io::Error::new(
122 io::ErrorKind::InvalidData,
123 format!("invalid utf8 in header key: {}", e),
124 )
125 })?;
126 // Unescape header value
127 let v_unescaped = unescape_header_value(&v).map_err(|e| {
128 io::Error::new(
129 io::ErrorKind::InvalidData,
130 format!("invalid escape in header value: {}", e),
131 )
132 })?;
133 let vs = String::from_utf8(v_unescaped).map_err(|e| {
134 io::Error::new(
135 io::ErrorKind::InvalidData,
136 format!("invalid utf8 in header value: {}", e),
137 )
138 })?;
139 hdrs.push((ks, vs));
140 }
141
142 let body = body.unwrap_or_default();
143
144 let frame = Frame {
145 command,
146 headers: hdrs,
147 body,
148 };
149 Ok(Some(StompItem::Frame(frame)))
150 }
151 Ok(None) => Ok(None),
152 Err(e) => Err(io::Error::new(
153 io::ErrorKind::InvalidData,
154 format!("parse error: {}", e),
155 )),
156 }
157 }
158}
159
160impl Encoder<StompItem> for StompCodec {
161 type Error = io::Error;
162 /// Encode a `StompItem` into the provided destination buffer.
163 ///
164 /// Parameters
165 /// - `item`: the `StompItem` to encode. The encoder takes ownership of the
166 /// item (and any contained `Frame`) and may consume/mutate its contents.
167 /// - `dst`: destination buffer where encoded bytes should be appended.
168 /// This is the same `BytesMut` provided by the `tokio_util::codec`
169 /// framework (e.g. `Framed`). Do not replace or reassign `dst`; instead
170 /// append bytes into it using `BufMut` methods (`put_u8`,
171 /// `put_slice`, `extend_from_slice`, etc.). After `encode` returns the
172 /// contents of `dst` will be written to the underlying transport.
173 ///
174 /// Returns
175 /// - `Ok(())` on success, or `Err(io::Error)` on encoding-related errors.
176 fn encode(&mut self, item: StompItem, dst: &mut BytesMut) -> Result<(), Self::Error> {
177 match item {
178 StompItem::Heartbeat => {
179 dst.put_u8(b'\n');
180 }
181 StompItem::Frame(frame) => {
182 dst.extend_from_slice(frame.command.as_bytes());
183 dst.put_u8(b'\n');
184
185 let mut headers = frame.headers;
186 let has_cl = headers
187 .iter()
188 .any(|(k, _)| k.to_lowercase() == "content-length");
189 if !has_cl {
190 let include_cl =
191 frame.body.contains(&0) || std::str::from_utf8(&frame.body).is_err();
192 if include_cl {
193 headers.push(("content-length".to_string(), frame.body.len().to_string()));
194 }
195 }
196
197 for (k, v) in headers {
198 // Escape header name and value per STOMP 1.2 spec
199 let escaped_key = escape_header_value(&k);
200 let escaped_val = escape_header_value(&v);
201 dst.extend_from_slice(escaped_key.as_bytes());
202 dst.put_u8(b':');
203 dst.extend_from_slice(escaped_val.as_bytes());
204 dst.put_u8(b'\n');
205 }
206
207 dst.put_slice(b"\n");
208 dst.extend_from_slice(&frame.body);
209 dst.put_u8(0);
210 }
211 }
212
213 Ok(())
214 }
215}