Skip to main content

netconf_rust/
codec.rs

1use bytes::{Buf, Bytes, BytesMut};
2use log::{debug, trace};
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use tokio_util::codec::{Decoder, Encoder};
5
6use crate::error::FramingError;
7
/// NETCONF 1.0 end-of-message delimiter appended after every message.
const EOM_MARKER: &[u8] = b"]]>]]>";
/// Byte length of [`EOM_MARKER`].
const EOM_LEN: usize = EOM_MARKER.len();

/// NETCONF 1.1 end-of-chunks marker that terminates a chunked message.
const CHUNKED_EOM_MARKER: &[u8] = b"\n##\n";
/// Byte length of [`CHUNKED_EOM_MARKER`].
const CHUNKED_EOM_MARKER_LEN: usize = CHUNKED_EOM_MARKER.len();

/// Prefix shared by every chunk header (`\n#<size>\n`) and the end-of-chunks marker.
const CHUNKED_HEADER_START: &[u8] = b"\n#";
15
/// NETCONF defines two framing modes.
///
/// 1. NETCONF 1.0 defines End of Message (EOM) framing in RFC 4742, where each message is
///    followed by the literal sequence `]]>]]>`. For example,
///    `<rpc-reply message-id="1"><ok/></rpc-reply>]]>]]>`.
/// 2. NETCONF 1.1 defines chunked framing in RFC 6242, where each message is sent as one or
///    more length-prefixed chunks. Every chunk starts with the header `\n#<chunk-size>\n`,
///    where `<chunk-size>` is the number of bytes in *that chunk*, and `\n##\n` marks the
///    end of the message. For example, `<rpc-reply><ok/></rpc-reply>` split into two chunks
///    is sent as `\n#24\n<rpc-reply><ok/></rpc-re\n#4\nply>\n##\n`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FramingMode {
    /// NETCONF 1.0 (RFC 4742): each message is terminated by `]]>]]>`.
    EndOfMessage,
    /// NETCONF 1.1 (RFC 6242): length-prefixed chunked framing.
    Chunked,
}
29
/// Tunable limits for `NetconfCodec`.
#[derive(Debug, Default, Clone, Copy)]
pub struct CodecConfig {
    /// Largest message, in bytes, the decoder will accept; `None` means unlimited.
    pub max_message_size: Option<usize>,
}
34
35pub struct NetconfCodec {
36    framing_mode: FramingMode,
37    config: CodecConfig,
38    chunked_buf: BytesMut,
39    /// Tracks how far we've scanned for the EOM marker in the buffer.
40    /// On the next decode() call, we resume scanning from this offset
41    /// (backed up by EOM_LEN-1 to handle markers split across reads)
42    /// instead of rescanning from the start. This matters for devices
43    /// with small TCP windows (e.g. 4KB) where a large response arrives
44    /// in hundreds of segments.
45    eom_search_offset: usize,
46}
47/// tokio_util codec for NETCONF message framing.
48///
49/// Implements the `Decoder` trait from tokio_util, which defines how to
50/// extract discrete messages from a continuous TCP byte stream.
51///
52/// We don't call `decode()` directly. Instead, we pass this codec to
53/// `FramedRead`, which handles the I/O loop:
54///
55///   let reader = FramedRead::new(read_half, NetconfCodec::new(...));
56///   while let Some(msg) = reader.next().await { ... }
57///
58/// `FramedRead` reads bytes from the network into an internal buffer,
59/// calls our `decode()` to check for a complete message, and yields
60/// messages as a Stream. We only need to implement the framing logic --
61/// "is there a complete message in this buffer?" -- and `FramedRead`
62/// handles everything else (buffer management, read loops, EOF).
63///
64/// decode() is called repeatedly. It owns the internal buffer. When
65/// decode() returns Ok(None), FramedRead reads more bytes from the
66/// underlying AsyncRead into the buffer and calls decode() again.
67/// When decode() returns Ok(Some(item)), the item is yielded to the
68/// consumer. The key invariant is: decode() must consume
69/// (via split_to() or advance()) exactly the bytes belonging to the
70/// returned message. Any leftover bytes in src remain for the next
71/// decode() call.
72///
73/// Returns `Bytes` (not `String`) because UTF-8 validation is deferred
74/// to the XML parser. This avoids a redundant O(n) validation pass at
75/// the framing layer
76impl NetconfCodec {
77    pub fn new(framing_mode: FramingMode, config: CodecConfig) -> Self {
78        Self {
79            framing_mode,
80            config,
81            chunked_buf: BytesMut::new(),
82            eom_search_offset: 0,
83        }
84    }
85
86    pub fn set_mode(&mut self, framing_mode: FramingMode) {
87        self.framing_mode = framing_mode;
88        self.chunked_buf.clear();
89        self.eom_search_offset = 0;
90    }
91    pub fn framing_mode(&self) -> FramingMode {
92        self.framing_mode
93    }
94
95    fn check_size(&self, size: usize) -> Result<(), FramingError> {
96        if let Some(max_size) = self.config.max_message_size
97            && size > max_size
98        {
99            return Err(FramingError::MessageTooLarge {
100                limit: max_size,
101                received: size,
102            });
103        }
104        Ok(())
105    }
106
107    /// Scans src for ]]>]]> and returns message bytes if found otherwise we need more data.
108    /// We assume that the marker only appears at the message boundry because
109    /// 1.  XML itself forbids ]]> inside content. so ]]>]]> can never appear in well-formed XML.
110    /// 2.  RFC 4742 Section 4.2 defines the framing: the message is terminated by ]]>]]>.
111    ///     The content should always be valid XML, and valid XML can't contain ]]>
112    ///     However, in Netconf 1.0, there is actually no guarantee. This fact relies on the netconf server actually
113    ///     returning well formed XML. If it doesn't then our code will find the first occurrence
114    ///     the marker, split the message in the wrong place and then return an error.
115    fn decode_eom(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
116        // The smallest possible complete message is EOM_MARKER, so if src is smaller than
117        // that, then we need more data
118        if src.len() < EOM_LEN {
119            trace!(
120                "eom: buffer too small ({} bytes), need more data",
121                src.len()
122            );
123            return Ok(None);
124        }
125
126        // Resume scanning from where we left off, backed up by EOM_LEN-1
127        // to handle markers that were split across the previous and current read.
128        // Uses memchr's SIMD-accelerated search
129        let search_start = self.eom_search_offset.saturating_sub(EOM_LEN - 1);
130        trace!(
131            "eom: scanning {} bytes (buffer={}, search_offset={}, search_start={})",
132            src.len() - search_start,
133            src.len(),
134            self.eom_search_offset,
135            search_start
136        );
137        if let Some(pos) = memchr::memmem::find(&src[search_start..], EOM_MARKER) {
138            let msg_len = search_start + pos;
139            self.check_size(msg_len)?;
140            let msg = src.split_to(msg_len).freeze();
141            src.advance(EOM_LEN);
142            self.eom_search_offset = 0;
143            debug!("eom: decoded message ({} bytes)", msg_len);
144            trace!(
145                "eom: message preview: {:?}",
146                String::from_utf8_lossy(&msg[..msg.len().min(200)])
147            );
148            Ok(Some(msg))
149        } else {
150            self.eom_search_offset = src.len();
151            self.check_size(src.len())?;
152            trace!("eom: no marker found, buffered {} bytes total", src.len());
153            Ok(None)
154        }
155    }
156
157    fn decode_chunked(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
158        // The smallest possible complete message is CHUNKED_EOM_MARKER. Otherwise more
159        // data is expected.
160        loop {
161            if src.len() < CHUNKED_EOM_MARKER_LEN {
162                trace!(
163                    "chunked: buffer too small ({} bytes), accumulated {} bytes so far",
164                    src.len(),
165                    self.chunked_buf.len()
166                );
167                return Ok(None);
168            }
169
170            // Every chunk or end marker starts with \n#
171            if src[0..2] != *CHUNKED_HEADER_START {
172                return Err(FramingError::InvalidHeader {
173                    expected: "\\n#",
174                    got: src[..2].to_vec(),
175                });
176            }
177
178            // Check for end of chunks marker: \n##\n.
179            // We know the first 2 bytes are valid already
180            if src[2] == b'#' {
181                if src[3] != b'\n' {
182                    return Err(FramingError::InvalidHeader {
183                        expected: "\\n##\\n",
184                        got: src[..4].to_vec(),
185                    });
186                }
187
188                // We found end of message. Consume \n##\n
189                src.advance(CHUNKED_EOM_MARKER_LEN);
190                let msg = self.chunked_buf.split().freeze();
191                debug!("chunked: decoded message ({} bytes)", msg.len());
192                trace!(
193                    "chunked: message preview: {:?}",
194                    String::from_utf8_lossy(&msg[..msg.len().min(200)])
195                );
196                return Ok(Some(msg));
197            }
198
199            // Otherwise not end of message, so we have chunk header \n#<size>\n
200            let header_start = 2; // skip \n#
201            let header_end = match src[header_start..].iter().position(|&b| b == b'\n') {
202                Some(pos_end_of_header) => header_start + pos_end_of_header,
203                None => {
204                    // Header not yet complete - need more data.
205                    // But sanity-check the chuck size
206                    if src.len() > 20 {
207                        return Err(FramingError::InvalidChunkSize(
208                            String::from_utf8_lossy(&src[header_start..]).into_owned(),
209                        ));
210                    }
211                    return Ok(None);
212                }
213            };
214
215            // Extract the chunk size from the header and parse into usize
216            let size_str = &src[header_start..header_end];
217            let chunk_size: usize = std::str::from_utf8(size_str)
218                .map_err(|_| {
219                    FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
220                })?
221                .parse()
222                .map_err(|_| {
223                    FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
224                })?;
225
226            if chunk_size == 0 {
227                return Err(FramingError::InvalidChunkSize("0".into()));
228            }
229
230            // Total header length  \n# + size digits + \n
231            let header_len = header_end + 1; // +1 for the trailing \n
232
233            // Check if the full chunk (header + data) is available
234            let total_chunk_len = header_len + chunk_size;
235            if src.len() < total_chunk_len {
236                trace!(
237                    "chunked: need {} more bytes for chunk (have {}, need {})",
238                    total_chunk_len - src.len(),
239                    src.len(),
240                    total_chunk_len
241                );
242                return Ok(None); // Need more data for the chunk
243            }
244            self.check_size(self.chunked_buf.len() + chunk_size)?;
245
246            trace!(
247                "chunked: consuming chunk ({} bytes, accumulated {} bytes)",
248                chunk_size,
249                self.chunked_buf.len() + chunk_size
250            );
251
252            // Consume header
253            src.advance(header_len);
254
255            // Consume the chunk of data
256            self.chunked_buf.extend_from_slice(&src[..chunk_size]);
257            src.advance(chunk_size);
258        }
259    }
260}
261
262pub(crate) async fn read_eom_message<R: AsyncRead + Unpin>(
263    reader: &mut R,
264    max_size: Option<usize>,
265) -> crate::Result<String> {
266    let mut buf = Vec::with_capacity(4096);
267    let mut tmp = [0u8; 4096];
268
269    loop {
270        let read_bytes = reader.read(&mut tmp).await?;
271
272        if read_bytes == 0 {
273            debug!("read_eom: unexpected EOF after {} bytes", buf.len());
274            return Err(FramingError::UnexpectedEof.into());
275        }
276        buf.extend_from_slice(&tmp[..read_bytes]);
277        trace!(
278            "read_eom: read {} bytes, buffer now {} bytes",
279            read_bytes,
280            buf.len()
281        );
282        if let Some(limit) = max_size
283            && buf.len() > limit + EOM_LEN
284        {
285            return Err(FramingError::MessageTooLarge {
286                limit,
287                received: buf.len(),
288            }
289            .into());
290        }
291        if let Some(pos) = memchr::memmem::find(&buf, EOM_MARKER) {
292            buf.truncate(pos);
293            debug!("read_eom: complete message ({} bytes)", buf.len());
294            return String::from_utf8(buf).map_err(|_| FramingError::InvalidUtf8.into());
295        }
296    }
297}
298
299pub(crate) async fn write_eom_message<W: AsyncWrite + Unpin>(
300    writer: &mut W,
301    message: &str,
302) -> crate::Result<()> {
303    writer.write_all(message.as_bytes()).await?;
304    writer.write_all(EOM_MARKER).await?;
305    writer.flush().await?;
306    Ok(())
307}
308
309/// <https://docs.rs/tokio-util/latest/tokio_util/codec/trait.Decoder.html>
310impl Decoder for NetconfCodec {
311    type Item = Bytes;
312    type Error = FramingError;
313
314    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
315        match self.framing_mode {
316            FramingMode::EndOfMessage => self.decode_eom(src),
317            FramingMode::Chunked => self.decode_chunked(src),
318        }
319    }
320}
321
322impl Encoder<Bytes> for NetconfCodec {
323    type Error = FramingError;
324
325    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
326        debug!(
327            "encode: framing={:?}, message={} bytes",
328            self.framing_mode,
329            item.len()
330        );
331        trace!(
332            "encode: message preview: {:?}",
333            String::from_utf8_lossy(&item[..item.len().min(200)])
334        );
335        match self.framing_mode {
336            FramingMode::EndOfMessage => {
337                dst.reserve(item.len() + EOM_LEN);
338                dst.extend_from_slice(&item);
339                dst.extend_from_slice(EOM_MARKER);
340            }
341            FramingMode::Chunked => {
342                let header = format!("\n#{}\n", item.len());
343                dst.reserve(header.len() + item.len() + CHUNKED_EOM_MARKER_LEN);
344                dst.extend_from_slice(header.as_bytes());
345                dst.extend_from_slice(&item);
346                dst.extend_from_slice(CHUNKED_EOM_MARKER);
347            }
348        }
349        Ok(())
350    }
351}
352
353#[cfg(test)]
354mod tests {
355    use super::*;
356
357    // ── EOM Decoder tests ───────────────────────────────────────────────
358
359    #[test]
360    fn eom_decode_complete_message() {
361        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
362        let mut buf = BytesMut::from(&b"<rpc-reply/>]]>]]>"[..]);
363        let result = codec.decode(&mut buf).unwrap();
364        assert_eq!(result, Some(Bytes::from_static(b"<rpc-reply/>")));
365        assert!(buf.is_empty());
366    }
367
368    #[test]
369    fn eom_decode_incomplete_message() {
370        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
371        let mut buf = BytesMut::from(&b"<rpc-reply/>"[..]);
372        let result = codec.decode(&mut buf).unwrap();
373        assert_eq!(result, None);
374    }
375
376    #[test]
377    fn eom_decode_partial_marker() {
378        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
379        let mut buf = BytesMut::from(&b"<ok/>]]>"[..]);
380        assert_eq!(codec.decode(&mut buf).unwrap(), None);
381        // Now complete the marker
382        buf.extend_from_slice(b"]]>");
383        let result = codec.decode(&mut buf).unwrap();
384        assert_eq!(result, Some(Bytes::from_static(b"<ok/>")));
385    }
386
387    #[test]
388    fn eom_decode_empty_message() {
389        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
390        let mut buf = BytesMut::from(&b"]]>]]>"[..]);
391        let result = codec.decode(&mut buf).unwrap();
392        assert_eq!(result, Some(Bytes::from_static(b"")));
393    }
394
395    #[test]
396    fn eom_decode_two_messages() {
397        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
398        let mut buf = BytesMut::from(&b"<a/>]]>]]><b/>]]>]]>"[..]);
399        assert_eq!(
400            codec.decode(&mut buf).unwrap(),
401            Some(Bytes::from_static(b"<a/>"))
402        );
403        assert_eq!(
404            codec.decode(&mut buf).unwrap(),
405            Some(Bytes::from_static(b"<b/>"))
406        );
407    }
408
409    #[test]
410    fn eom_decode_size_limit_exceeded() {
411        let config = CodecConfig {
412            max_message_size: Some(5),
413        };
414        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, config);
415        let mut buf = BytesMut::from(&b"<too-large/>]]>]]>"[..]);
416        let err = codec.decode(&mut buf).unwrap_err();
417        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
418    }
419
420    #[test]
421    fn eom_decode_size_limit_ok() {
422        let config = CodecConfig {
423            max_message_size: Some(100),
424        };
425        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, config);
426        let mut buf = BytesMut::from(&b"<ok/>]]>]]>"[..]);
427        assert!(codec.decode(&mut buf).unwrap().is_some());
428    }
429
430    // ── Chunked Decoder tests ───────────────────────────────────────────
431
432    #[test]
433    fn chunked_decode_single_chunk() {
434        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
435        let mut buf = BytesMut::from(&b"\n#7\n<data/>\n##\n"[..]);
436        let result = codec.decode(&mut buf).unwrap();
437        assert_eq!(result, Some(Bytes::from_static(b"<data/>")));
438        assert!(buf.is_empty());
439    }
440
441    #[test]
442    fn chunked_decode_multiple_chunks() {
443        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
444        let mut buf = BytesMut::from(&b"\n#5\nHello\n#6\n World\n##\n"[..]);
445        let result = codec.decode(&mut buf).unwrap();
446        assert_eq!(result, Some(Bytes::from_static(b"Hello World")));
447    }
448
449    #[test]
450    fn chunked_decode_incomplete_header() {
451        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
452        let mut buf = BytesMut::from(&b"\n#"[..]);
453        assert_eq!(codec.decode(&mut buf).unwrap(), None);
454    }
455
456    #[test]
457    fn chunked_decode_incomplete_data() {
458        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
459        let mut buf = BytesMut::from(&b"\n#10\nHello"[..]);
460        assert_eq!(codec.decode(&mut buf).unwrap(), None);
461        // Complete the data + end marker
462        buf.extend_from_slice(b" Wrld\n##\n");
463        let result = codec.decode(&mut buf).unwrap();
464        assert_eq!(result, Some(Bytes::from_static(b"Hello Wrld")));
465    }
466
467    #[test]
468    fn chunked_decode_large_chunk() {
469        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
470        let data = "x".repeat(10000);
471        let mut buf = BytesMut::new();
472        buf.extend_from_slice(format!("\n#{}\n", data.len()).as_bytes());
473        buf.extend_from_slice(data.as_bytes());
474        buf.extend_from_slice(b"\n##\n");
475        let result = codec.decode(&mut buf).unwrap();
476        assert_eq!(result.unwrap().len(), 10000);
477    }
478
479    #[test]
480    fn chunked_decode_invalid_header() {
481        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
482        let mut buf = BytesMut::from(&b"\n#abc\n"[..]);
483        let err = codec.decode(&mut buf).unwrap_err();
484        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
485    }
486
487    #[test]
488    fn chunked_decode_zero_chunk_size() {
489        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
490        let mut buf = BytesMut::from(&b"\n#0\n\n##\n"[..]);
491        let err = codec.decode(&mut buf).unwrap_err();
492        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
493    }
494
495    #[test]
496    fn chunked_decode_size_limit() {
497        let config = CodecConfig {
498            max_message_size: Some(5),
499        };
500        let mut codec = NetconfCodec::new(FramingMode::Chunked, config);
501        let mut buf = BytesMut::from(&b"\n#10\n0123456789\n##\n"[..]);
502        let err = codec.decode(&mut buf).unwrap_err();
503        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
504    }
505
506    // ── Encoder tests ───────────────────────────────────────────────────
507
508    #[test]
509    fn eom_encode() {
510        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
511        let mut buf = BytesMut::new();
512        codec
513            .encode(Bytes::from_static(b"<ok/>"), &mut buf)
514            .unwrap();
515        assert_eq!(&buf[..], b"<ok/>]]>]]>");
516    }
517
518    #[test]
519    fn chunked_encode() {
520        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
521        let mut buf = BytesMut::new();
522        codec
523            .encode(Bytes::from_static(b"<ok/>"), &mut buf)
524            .unwrap();
525        assert_eq!(&buf[..], b"\n#5\n<ok/>\n##\n");
526    }
527
528    // ── Roundtrip tests ─────────────────────────────────────────────────
529
530    #[test]
531    fn eom_roundtrip() {
532        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
533        let original = Bytes::from_static(b"<rpc message-id=\"1\"><get/></rpc>");
534        let mut buf = BytesMut::new();
535        codec.encode(original.clone(), &mut buf).unwrap();
536        let decoded = codec.decode(&mut buf).unwrap().unwrap();
537        assert_eq!(decoded, original);
538    }
539
540    #[test]
541    fn chunked_roundtrip() {
542        let mut codec = NetconfCodec::new(FramingMode::Chunked, CodecConfig::default());
543        let original = Bytes::from_static(b"<rpc message-id=\"1\"><get/></rpc>");
544        let mut buf = BytesMut::new();
545        codec.encode(original.clone(), &mut buf).unwrap();
546        let decoded = codec.decode(&mut buf).unwrap().unwrap();
547        assert_eq!(decoded, original);
548    }
549
550    #[test]
551    fn mode_switch() {
552        let mut codec = NetconfCodec::new(FramingMode::EndOfMessage, CodecConfig::default());
553
554        // Encode/decode in EOM mode
555        let mut buf = BytesMut::new();
556        codec
557            .encode(Bytes::from_static(b"hello"), &mut buf)
558            .unwrap();
559        assert_eq!(
560            codec.decode(&mut buf).unwrap(),
561            Some(Bytes::from_static(b"hello"))
562        );
563
564        // Switch to chunked
565        codec.set_mode(FramingMode::Chunked);
566
567        let mut buf = BytesMut::new();
568        codec
569            .encode(Bytes::from_static(b"world"), &mut buf)
570            .unwrap();
571        assert_eq!(
572            codec.decode(&mut buf).unwrap(),
573            Some(Bytes::from_static(b"world"))
574        );
575    }
576
577    // ── EOM hello helpers tests ─────────────────────────────────────────
578
579    #[tokio::test]
580    async fn eom_helper_roundtrip() {
581        let (mut client, mut server) = tokio::io::duplex(4096);
582
583        let msg = "<hello/>";
584        tokio::spawn(async move {
585            write_eom_message(&mut server, msg).await.unwrap();
586        });
587
588        let received = read_eom_message(&mut client, None).await.unwrap();
589        assert_eq!(received, msg);
590    }
591
592    #[tokio::test]
593    async fn eom_helper_size_limit() {
594        let (mut client, mut server) = tokio::io::duplex(4096);
595
596        let msg = "x".repeat(1000);
597        tokio::spawn(async move {
598            write_eom_message(&mut server, &msg).await.unwrap();
599        });
600
601        let result = read_eom_message(&mut client, Some(10)).await;
602        assert!(result.is_err());
603    }
604}