Skip to main content

netconf_rust/
codec.rs

1use bytes::{Buf, Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3use tokio_util::codec::{Decoder, Encoder};
4
5use crate::error::FramingError;
6
/// Byte sequence that terminates every message in end-of-message (NETCONF 1.0) framing.
const EOM_MARKER: &[u8] = b"]]>]]>";
/// Length of [`EOM_MARKER`] in bytes.
const EOM_LEN: usize = EOM_MARKER.len();

/// Byte sequence that terminates a message in chunked (NETCONF 1.1) framing.
const CHUNKED_EOM_MARKER: &[u8] = b"\n##\n";
/// Length of [`CHUNKED_EOM_MARKER`] in bytes.
const CHUNKED_EOM_MARKER_LEN: usize = CHUNKED_EOM_MARKER.len();

/// Two-byte prefix shared by every chunk header and by the end-of-chunks marker.
const CHUNKED_HEADER_START: &[u8] = b"\n#";
14
/// NETCONF defines two framing modes.
/// 1. NETCONF 1.0 defines End of Message (EOM) framing in RFC 4742, where each message is
///    followed by the literal sequence `]]>]]>`. For example,
///    `<rpc-reply message-id="1"><ok/></rpc-reply>]]>]]>`
/// 2. NETCONF 1.1 defines chunked framing in RFC 6242, where each message is sent as one or
///    more length-prefixed chunks. Every chunk starts with `\n#<num_of_bytes_in_chunk>\n`, and
///    `\n##\n` marks the end of the message. For example, the 28-byte message
///    `<rpc-reply><ok/></rpc-reply>` split into two chunks is
///    `\n#24\n<rpc-reply><ok/></rpc-re\n#4\nply>\n##\n`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FramingMode {
    /// NETCONF 1.0 (RFC 4742): messages terminated by `]]>]]>`
    EndOfMessage,
    /// NETCONF 1.1 (RFC 6242): length-prefixed chunked framing
    Chunked,
}
28
/// Limits applied by [`NetconfCodec`] while decoding.
///
/// The derived `Default` leaves every limit unset.
#[derive(Default, Debug, Clone, Copy)]
pub struct CodecConfig {
    /// Largest decoded message, in bytes, that the decoder accepts before failing with
    /// `FramingError::MessageTooLarge`.
    pub max_message_size: Option<usize>, // None = unlimited
}
33
/// Framing state for one NETCONF byte stream; usable as both a decoder and an encoder.
pub struct NetconfCodec {
    /// Framing currently applied when decoding and encoding.
    framing_mode: FramingMode,
    /// Limits enforced while decoding.
    config: CodecConfig,
    /// Payload of the chunked message currently being reassembled. Chunk data is appended
    /// here until the end-of-chunks marker is seen; unused in end-of-message mode.
    chunked_buf: BytesMut,
}
39/// tokio_util codec for NETCONF message framing.
40///
41/// Implements the `Decoder` trait from tokio_util, which defines how to
42/// extract discrete messages from a continuous TCP byte stream.
43///
44/// We don't call `decode()` directly. Instead, we pass this codec to
45/// `FramedRead`, which handles the I/O loop:
46///
47///   let reader = FramedRead::new(read_half, NetconfCodec::new(...));
48///   while let Some(msg) = reader.next().await { ... }
49///
50/// `FramedRead` reads bytes from the network into an internal buffer,
51/// calls our `decode()` to check for a complete message, and yields
52/// messages as a Stream. We only need to implement the framing logic --
53/// "is there a complete message in this buffer?" -- and `FramedRead`
54/// handles everything else (buffer management, read loops, EOF).
55///
56/// decode() is called repeatedly. It owns the internal buffer. When
57/// decode() returns Ok(None), FramedRead reads more bytes from the
58/// underlying AsyncRead into the buffer and calls decode() again.
59/// When decode() returns Ok(Some(item)), the item is yielded to the
60/// consumer. The key invariant is: decode() must consume
61/// (via split_to() or advance()) exactly the bytes belonging to the
62/// returned message. Any leftover bytes in src remain for the next
63/// decode() call.
64///
65/// Returns `Bytes` (not `String`) because UTF-8 validation is deferred
66/// to the XML parser. This avoids a redundant O(n) validation pass at
67/// the framing layer
68impl NetconfCodec {
69    pub fn new(framing_mode: FramingMode, config: CodecConfig) -> Self {
70        Self {
71            framing_mode,
72            config,
73            chunked_buf: BytesMut::new(),
74        }
75    }
76
77    pub fn set_mode(&mut self, framing_mode: FramingMode) {
78        self.framing_mode = framing_mode;
79        self.chunked_buf.clear();
80    }
81    pub fn framing_mode(&self) -> FramingMode {
82        self.framing_mode
83    }
84
85    fn check_size(&self, size: usize) -> Result<(), FramingError> {
86        if let Some(max_size) = self.config.max_message_size
87            && size > max_size
88        {
89            return Err(FramingError::MessageTooLarge {
90                limit: max_size,
91                received: size,
92            });
93        }
94        Ok(())
95    }
96
97    /// Scans src for ]]>]]> and returns message bytes if found otherwise we need more data.
98    /// We assume that the marker only appears at the message boundry because
99    /// 1.  XML itself forbids ]]> inside content. so ]]>]]> can never appear in well-formed XML.
100    /// 2.  RFC 4742 Section 4.2 defines the framing: the message is terminated by ]]>]]>.
101    ///     The content should always be valid XML, and valid XML can't contain ]]>
102    ///     However, in Netconf 1.0, there is actually no guarantee. This fact relies on the netconf server actually
103    ///     returning well formed XML. If it doesn't then our code will find the first occurrence
104    ///     the marker, split the message in the wrong place and then return an error.
105    fn decode_eom(&self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
106        // The smallest possible complete message is EOM_MARKER, so if src is smaller than
107        // that, then we need more data
108        if src.len() < EOM_LEN {
109            return Ok(None);
110        }
111        if let Some(eom_pos_start) = find_subsequence(src, EOM_MARKER) {
112            let msg_len = eom_pos_start;
113            self.check_size(msg_len)?;
114            let msg = src.split_to(msg_len).freeze();
115            src.advance(EOM_LEN);
116            Ok(Some(msg))
117        } else {
118            self.check_size(src.len())?;
119            Ok(None)
120        }
121    }
122
123    fn decode_chunked(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
124        // The smallest possible complete message is CHUNKED_EOM_MARKER. Otherwise more
125        // data is expected.
126        loop {
127            if src.len() < CHUNKED_EOM_MARKER_LEN {
128                return Ok(None);
129            }
130
131            // Every chunk or end marker starts with \n#
132            if src[0..2] != *CHUNKED_HEADER_START {
133                return Err(FramingError::InvalidHeader {
134                    expected: "\\n#",
135                    got: src[..2].to_vec(),
136                });
137            }
138
139            // Check for end of chunks marker: \n##\n.
140            // We know the first 2 bytes are valid already
141            if src[2] == b'#' {
142                if src[3] != b'\n' {
143                    return Err(FramingError::InvalidHeader {
144                        expected: "\\n##\\n",
145                        got: src[..4].to_vec(),
146                    });
147                }
148
149                // We found end of message. Consume \n##\n
150                src.advance(CHUNKED_EOM_MARKER_LEN);
151                let msg = self.chunked_buf.split().freeze();
152                return Ok(Some(msg));
153            }
154
155            // Otherwise not end of message, so we have chunk header \n#<size>\n
156            let header_start = 2; // skip \n#
157            let header_end = match src[header_start..].iter().position(|&b| b == b'\n') {
158                Some(pos_end_of_header) => header_start + pos_end_of_header,
159                None => {
160                    // Header not yet complete - need more data.
161                    // But sanity-check the chuck size
162                    if src.len() > 20 {
163                        return Err(FramingError::InvalidChunkSize(
164                            String::from_utf8_lossy(&src[header_start..]).into_owned(),
165                        ));
166                    }
167                    return Ok(None);
168                }
169            };
170
171            // Extract the chunk size from the header and parse into usize
172            let size_str = &src[header_start..header_end];
173            let chunk_size: usize = std::str::from_utf8(size_str)
174                .map_err(|_| {
175                    FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
176                })?
177                .parse()
178                .map_err(|_| {
179                    FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
180                })?;
181
182            if chunk_size == 0 {
183                return Err(FramingError::InvalidChunkSize("0".into()));
184            }
185
186            // Total header length  \n# + size digits + \n
187            let header_len = header_end + 1; // +1 for the trailing \n
188
189            // Check if the full chunk (header + data) is available
190            let total_chunk_len = header_len + chunk_size;
191            if src.len() < total_chunk_len {
192                return Ok(None); // Need more data for the chunk
193            }
194            self.check_size(self.chunked_buf.len() + chunk_size)?;
195
196            // Consume header
197            src.advance(header_len);
198
199            // Consume the chunk of data
200            self.chunked_buf.extend_from_slice(&src[..chunk_size]);
201            src.advance(chunk_size);
202        }
203    }
204}
205
/// Returns the index of the first byte of the first occurrence of `needle` in `haystack`,
/// or `None` when `needle` does not occur.
///
/// An empty `needle` matches at index 0.
// TODO: Potential optimization: why scan the thing again?
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        // `slice::windows` panics on a window size of 0.
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
213
214pub(crate) async fn read_eom_message<R: AsyncRead + Unpin>(
215    reader: &mut R,
216    max_size: Option<usize>,
217) -> crate::Result<String> {
218    let mut buf = Vec::with_capacity(4096);
219    let mut tmp = [0u8; 4096];
220
221    loop {
222        let read_bytes = reader.read(&mut tmp).await?;
223
224        if read_bytes == 0 {
225            return Err(FramingError::UnexpectedEof.into());
226        }
227        buf.extend_from_slice(&tmp[..read_bytes]);
228        if let Some(limit) = max_size
229            && buf.len() > limit + EOM_LEN
230        {
231            return Err(FramingError::MessageTooLarge {
232                limit,
233                received: buf.len(),
234            }
235            .into());
236        }
237        if buf.len() >= EOM_LEN && buf[buf.len() - EOM_LEN..] == *EOM_MARKER {
238            buf.truncate(buf.len() - EOM_LEN);
239            return String::from_utf8(buf).map_err(|_| FramingError::InvalidUtf8.into());
240        }
241    }
242}
243
244pub(crate) async fn write_eom_message<W: AsyncWrite + Unpin>(
245    writer: &mut W,
246    message: &str,
247) -> crate::Result<()> {
248    writer.write_all(message.as_bytes()).await?;
249    writer.write_all(EOM_MARKER).await?;
250    writer.flush().await?;
251    Ok(())
252}
253
254/// <https://docs.rs/tokio-util/latest/tokio_util/codec/trait.Decoder.html>
255impl Decoder for NetconfCodec {
256    type Item = Bytes;
257    type Error = FramingError;
258
259    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
260        match self.framing_mode {
261            FramingMode::EndOfMessage => self.decode_eom(src),
262            FramingMode::Chunked => self.decode_chunked(src),
263        }
264    }
265}
266
267impl Encoder<Bytes> for NetconfCodec {
268    type Error = FramingError;
269
270    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
271        match self.framing_mode {
272            FramingMode::EndOfMessage => {
273                dst.reserve(item.len() + EOM_LEN);
274                dst.extend_from_slice(&item);
275                dst.extend_from_slice(EOM_MARKER);
276            }
277            FramingMode::Chunked => {
278                let header = format!("\n#{}\n", item.len());
279                dst.reserve(header.len() + item.len() + CHUNKED_EOM_MARKER_LEN);
280                dst.extend_from_slice(header.as_bytes());
281                dst.extend_from_slice(&item);
282                dst.extend_from_slice(CHUNKED_EOM_MARKER);
283            }
284        }
285        Ok(())
286    }
287}
288
// Unit tests for the framing codec: decoder and encoder behaviour in both framing modes,
// encode/decode round trips, mode switching, and the stand-alone EOM read/write helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // ── EOM Decoder tests ───────────────────────────────────────────────

    // A payload followed by the full marker decodes to the payload and drains the buffer.
    #[test]
    fn eom_decode_complete_message() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let mut buf = BytesMut::from(&b"<rpc-reply/>]]>]]>"[..]);
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"<rpc-reply/>")));
        assert!(buf.is_empty());
    }

    // Without any marker the decoder asks for more data.
    #[test]
    fn eom_decode_incomplete_message() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let mut buf = BytesMut::from(&b"<rpc-reply/>"[..]);
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, None);
    }

    // Half a marker is not a boundary; the message is yielded once the rest arrives.
    #[test]
    fn eom_decode_partial_marker() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let mut buf = BytesMut::from(&b"<ok/>]]>"[..]);
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
        // Now complete the marker
        buf.extend_from_slice(b"]]>");
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"<ok/>")));
    }

    // A bare marker decodes to an empty message.
    #[test]
    fn eom_decode_empty_message() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let mut buf = BytesMut::from(&b"]]>]]>"[..]);
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"")));
    }

    // Two back-to-back messages in one buffer are returned by two successive calls.
    #[test]
    fn eom_decode_two_messages() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let mut buf = BytesMut::from(&b"<a/>]]>]]><b/>]]>]]>"[..]);
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some(Bytes::from_static(b"<a/>"))
        );
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some(Bytes::from_static(b"<b/>"))
        );
    }

    // A message larger than the configured limit is rejected.
    #[test]
    fn eom_decode_size_limit_exceeded() {
        let config = CodecConfig {
            max_message_size: Some(5),
        };
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, config);
        let mut buf = BytesMut::from(&b"<too-large/>]]>]]>"[..]);
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
    }

    // A message within the configured limit is accepted.
    #[test]
    fn eom_decode_size_limit_ok() {
        let config = CodecConfig {
            max_message_size: Some(100),
        };
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, config);
        let mut buf = BytesMut::from(&b"<ok/>]]>]]>"[..]);
        assert!(codec.decode(&mut buf).unwrap().is_some());
    }

    // ── Chunked Decoder tests ───────────────────────────────────────────

    // One chunk plus the end-of-chunks marker decodes to the chunk data.
    #[test]
    fn chunked_decode_single_chunk() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::from(&b"\n#7\n<data/>\n##\n"[..]);
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"<data/>")));
        assert!(buf.is_empty());
    }

    // Several chunks are concatenated into one message.
    #[test]
    fn chunked_decode_multiple_chunks() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::from(&b"\n#5\nHello\n#6\n World\n##\n"[..]);
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"Hello World")));
    }

    // A truncated header is not an error; the decoder waits for more data.
    #[test]
    fn chunked_decode_incomplete_header() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::from(&b"\n#"[..]);
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
    }

    // A chunk whose data has only partly arrived is held back until it is complete.
    #[test]
    fn chunked_decode_incomplete_data() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::from(&b"\n#10\nHello"[..]);
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
        // Complete the data + end marker
        buf.extend_from_slice(b" Wrld\n##\n");
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"Hello Wrld")));
    }

    // A multi-digit chunk size (10000 bytes) is parsed and decoded in full.
    #[test]
    fn chunked_decode_large_chunk() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let data = "x".repeat(10000);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(format!("\n#{}\n", data.len()).as_bytes());
        buf.extend_from_slice(data.as_bytes());
        buf.extend_from_slice(b"\n##\n");
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result.unwrap().len(), 10000);
    }

    // A non-numeric chunk size is rejected.
    #[test]
    fn chunked_decode_invalid_header() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::from(&b"\n#abc\n"[..]);
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
    }

    // A chunk size of zero is rejected.
    #[test]
    fn chunked_decode_zero_chunk_size() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::from(&b"\n#0\n\n##\n"[..]);
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
    }

    // A chunk that pushes the message past the configured limit is rejected.
    #[test]
    fn chunked_decode_size_limit() {
        let config = CodecConfig {
            max_message_size: Some(5),
        };
        let mut codec = NetconfCodec::new(FramingMode::Chunked, config);
        let mut buf = BytesMut::from(&b"\n#10\n0123456789\n##\n"[..]);
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
    }

    // ── Encoder tests ───────────────────────────────────────────────────

    // EOM encoding appends the marker to the payload.
    #[test]
    fn eom_encode() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let mut buf = BytesMut::new();
        codec
            .encode(Bytes::from_static(b"<ok/>"), &mut buf)
            .unwrap();
        assert_eq!(&buf[..], b"<ok/>]]>]]>");
    }

    // Chunked encoding writes a single length-prefixed chunk and the end-of-chunks marker.
    #[test]
    fn chunked_encode() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let mut buf = BytesMut::new();
        codec
            .encode(Bytes::from_static(b"<ok/>"), &mut buf)
            .unwrap();
        assert_eq!(&buf[..], b"\n#5\n<ok/>\n##\n");
    }

    // ── Roundtrip tests ─────────────────────────────────────────────────

    // Encoding and then decoding in EOM mode returns the original payload.
    #[test]
    fn eom_roundtrip() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
        let original = Bytes::from_static(b"<rpc message-id=\"1\"><get/></rpc>");
        let mut buf = BytesMut::new();
        codec.encode(original.clone(), &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, original);
    }

    // Encoding and then decoding in chunked mode returns the original payload.
    #[test]
    fn chunked_roundtrip() {
        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
        let original = Bytes::from_static(b"<rpc message-id=\"1\"><get/></rpc>");
        let mut buf = BytesMut::new();
        codec.encode(original.clone(), &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, original);
    }

    // The same codec instance works in EOM mode and, after `set_mode`, in chunked mode.
    #[test]
    fn mode_switch() {
        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());

        // Encode/decode in EOM mode
        let mut buf = BytesMut::new();
        codec
            .encode(Bytes::from_static(b"hello"), &mut buf)
            .unwrap();
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some(Bytes::from_static(b"hello"))
        );

        // Switch to chunked
        codec.set_mode(FramingMode::Chunked);

        let mut buf = BytesMut::new();
        codec
            .encode(Bytes::from_static(b"world"), &mut buf)
            .unwrap();
        assert_eq!(
            codec.decode(&mut buf).unwrap(),
            Some(Bytes::from_static(b"world"))
        );
    }

    // ── EOM hello helpers tests ─────────────────────────────────────────

    // A message written by `write_eom_message` is read back intact by `read_eom_message`.
    #[tokio::test]
    async fn eom_helper_roundtrip() {
        let (mut client, mut server) = tokio::io::duplex(4096);

        let msg = "<hello/>";
        tokio::spawn(async move {
            write_eom_message(&mut server, msg).await.unwrap();
        });

        let received = read_eom_message(&mut client, None).await.unwrap();
        assert_eq!(received, msg);
    }

    // `read_eom_message` fails when the incoming message exceeds `max_size`.
    #[tokio::test]
    async fn eom_helper_size_limit() {
        let (mut client, mut server) = tokio::io::duplex(4096);

        let msg = "x".repeat(1000);
        tokio::spawn(async move {
            write_eom_message(&mut server, &msg).await.unwrap();
        });

        let result = read_eom_message(&mut client, Some(10)).await;
        assert!(result.is_err());
    }
}