Skip to main content

iridium_stomp/
codec.rs

use std::io;

use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};

use crate::frame::Frame;
use crate::parser::{parse_frame_slice, unescape_header_value};

/// Escape a STOMP 1.2 header name or value for wire transmission.
///
/// STOMP 1.2 defines four escape sequences, which apply to header names and
/// header values alike:
/// - backslash (0x5c) → `\\`
/// - carriage return (0x0d) → `\r`
/// - line feed (0x0a) → `\n`
/// - colon (0x3a) → `\c`
///
/// Every other character is copied through unchanged.
fn escape_header_value(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, c| {
            match c {
                '\\' => escaped.push_str("\\\\"),
                '\r' => escaped.push_str("\\r"),
                '\n' => escaped.push_str("\\n"),
                ':' => escaped.push_str("\\c"),
                other => escaped.push(other),
            }
            escaped
        })
}

29/// (parser-based implementation uses `src` directly; header parsing is
30/// delegated to the `parser` module.)
31/// Items produced or consumed by the codec.
32///
33/// A `StompItem` is either a decoded `Frame` or a `Heartbeat` marker
34/// representing a single LF received on the wire.
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub enum StompItem {
37    /// A decoded STOMP frame (command + headers + body)
38    Frame(Frame),
39    /// A single heartbeat pulse (LF)
40    Heartbeat,
41}
42
/// `StompCodec` implements `tokio_util::codec::{Decoder, Encoder}` for the
/// STOMP wire protocol.
///
/// Responsibilities:
/// - Decode incoming bytes into `StompItem::Frame` or `StompItem::Heartbeat`.
/// - Support both NUL-terminated frames and frames using the `content-length`
///   header (STOMP 1.2) for binary bodies containing NUL bytes.
/// - Encode `StompItem` back into bytes for the wire format and emit
///   `content-length` when necessary.
///
/// The codec holds no state: decoding works directly on the `src` buffer
/// supplied by the framework, and frame/header parsing is delegated to the
/// `parser` module.
#[derive(Debug, Clone, Copy)]
pub struct StompCodec {}

56impl StompCodec {
57    pub fn new() -> Self {
58        Self {}
59    }
60}
61
62impl Default for StompCodec {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl Decoder for StompCodec {
69    type Item = StompItem;
70    type Error = io::Error;
71    /// Decode bytes from `src` into a `StompItem`.
72    ///
73    /// Parameters
74    /// - `src`: a mutable reference to the read buffer containing bytes from the
75    ///   transport. The decoder may consume bytes from this buffer (using
76    ///   methods like `advance` or `split_to`) when it successfully decodes a
77    ///   frame. If there are not enough bytes to form a complete frame, this
78    ///   method should return `Ok(None)` and leave `src` in the same state.
79    ///
80    /// Returns
81    /// - `Ok(Some(StompItem))` when a full item (frame or heartbeat) was
82    ///   decoded and bytes were consumed from `src` accordingly.
83    /// - `Ok(None)` when more bytes are required to decode a complete item.
84    /// - `Err(io::Error)` on protocol or data errors (invalid UTF-8, malformed
85    ///   frames, missing NUL after a content-length body, etc.).
86    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
87        // Move any newly-received bytes from the provided `src` into our
88        // internal buffer. We keep a separate buffer so parsing can proceed
89        // across arbitrary chunk boundaries without relying on indexes into
90        // heartbeat: single LF
91        if let Some(&b'\n') = src.chunk().first() {
92            src.advance(1);
93            return Ok(Some(StompItem::Heartbeat));
94        }
95
96        let chunk = src.chunk();
97        match parse_frame_slice(chunk) {
98            Ok(Some((cmd_bytes, headers, body, consumed))) => {
99                // advance src by consumed
100                src.advance(consumed);
101
102                // build owned Frame
103                let command = String::from_utf8(cmd_bytes).map_err(|e| {
104                    io::Error::new(
105                        io::ErrorKind::InvalidData,
106                        format!("invalid utf8 in command: {}", e),
107                    )
108                })?;
109                // convert headers Vec<(Vec<u8>,Vec<u8>)> -> Vec<(String,String)>
110                // and unescape per STOMP 1.2 spec
111                let mut hdrs: Vec<(String, String)> = Vec::new();
112                for (k, v) in headers {
113                    // Unescape header key
114                    let k_unescaped = unescape_header_value(&k).map_err(|e| {
115                        io::Error::new(
116                            io::ErrorKind::InvalidData,
117                            format!("invalid escape in header key: {}", e),
118                        )
119                    })?;
120                    let ks = String::from_utf8(k_unescaped).map_err(|e| {
121                        io::Error::new(
122                            io::ErrorKind::InvalidData,
123                            format!("invalid utf8 in header key: {}", e),
124                        )
125                    })?;
126                    // Unescape header value
127                    let v_unescaped = unescape_header_value(&v).map_err(|e| {
128                        io::Error::new(
129                            io::ErrorKind::InvalidData,
130                            format!("invalid escape in header value: {}", e),
131                        )
132                    })?;
133                    let vs = String::from_utf8(v_unescaped).map_err(|e| {
134                        io::Error::new(
135                            io::ErrorKind::InvalidData,
136                            format!("invalid utf8 in header value: {}", e),
137                        )
138                    })?;
139                    hdrs.push((ks, vs));
140                }
141
142                let body = body.unwrap_or_default();
143
144                let frame = Frame {
145                    command,
146                    headers: hdrs,
147                    body,
148                };
149                Ok(Some(StompItem::Frame(frame)))
150            }
151            Ok(None) => Ok(None),
152            Err(e) => Err(io::Error::new(
153                io::ErrorKind::InvalidData,
154                format!("parse error: {}", e),
155            )),
156        }
157    }
158}
159
160impl Encoder<StompItem> for StompCodec {
161    type Error = io::Error;
162    /// Encode a `StompItem` into the provided destination buffer.
163    ///
164    /// Parameters
165    /// - `item`: the `StompItem` to encode. The encoder takes ownership of the
166    ///   item (and any contained `Frame`) and may consume/mutate its contents.
167    /// - `dst`: destination buffer where encoded bytes should be appended.
168    ///   This is the same `BytesMut` provided by the `tokio_util::codec`
169    ///   framework (e.g. `Framed`). Do not replace or reassign `dst`; instead
170    ///   append bytes into it using `BufMut` methods (`put_u8`,
171    ///   `put_slice`, `extend_from_slice`, etc.). After `encode` returns the
172    ///   contents of `dst` will be written to the underlying transport.
173    ///
174    /// Returns
175    /// - `Ok(())` on success, or `Err(io::Error)` on encoding-related errors.
176    fn encode(&mut self, item: StompItem, dst: &mut BytesMut) -> Result<(), Self::Error> {
177        match item {
178            StompItem::Heartbeat => {
179                dst.put_u8(b'\n');
180            }
181            StompItem::Frame(frame) => {
182                dst.extend_from_slice(frame.command.as_bytes());
183                dst.put_u8(b'\n');
184
185                let mut headers = frame.headers;
186                let has_cl = headers
187                    .iter()
188                    .any(|(k, _)| k.to_lowercase() == "content-length");
189                if !has_cl {
190                    let include_cl =
191                        frame.body.contains(&0) || std::str::from_utf8(&frame.body).is_err();
192                    if include_cl {
193                        headers.push(("content-length".to_string(), frame.body.len().to_string()));
194                    }
195                }
196
197                for (k, v) in headers {
198                    // Escape header name and value per STOMP 1.2 spec
199                    let escaped_key = escape_header_value(&k);
200                    let escaped_val = escape_header_value(&v);
201                    dst.extend_from_slice(escaped_key.as_bytes());
202                    dst.put_u8(b':');
203                    dst.extend_from_slice(escaped_val.as_bytes());
204                    dst.put_u8(b'\n');
205                }
206
207                dst.put_slice(b"\n");
208                dst.extend_from_slice(&frame.body);
209                dst.put_u8(0);
210            }
211        }
212
213        Ok(())
214    }
215}