1use bytes::{Buf, Bytes, BytesMut};
2use log::{debug, trace};
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use tokio_util::codec::{Decoder, Encoder};
5
6use crate::error::FramingError;
7
/// Delimiter that terminates every message in end-of-message framing.
const EOM_MARKER: &[u8] = b"]]>]]>";
/// Byte length of [`EOM_MARKER`].
const EOM_LEN: usize = EOM_MARKER.len();

/// Trailer (`\n##\n`) that closes a message in chunked framing.
const CHUNKED_EOM_MARKER: &[u8] = b"\n##\n";
/// Byte length of [`CHUNKED_EOM_MARKER`].
const CHUNKED_EOM_MARKER_LEN: usize = CHUNKED_EOM_MARKER.len();

/// Two-byte prefix shared by every chunk header and by the end-of-chunks trailer.
const CHUNKED_HEADER_START: &[u8] = b"\n#";
15
/// How NETCONF messages are delimited on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FramingMode {
    /// Each message is followed by the `]]>]]>` end-of-message delimiter.
    EndOfMessage,
    /// Each message is a sequence of `\n#<size>\n<data>` chunks closed by `\n##\n`.
    Chunked,
}
29
/// Decoding limits for a NETCONF codec.
#[derive(Clone, Copy, Debug, Default)]
pub struct CodecConfig {
    /// Largest accepted message payload in bytes; `None` means unlimited.
    pub max_message_size: Option<usize>,
}
34
35pub struct NetconfCodec {
36 framing_mode: FramingMode,
37 config: CodecConfig,
38 chunked_buf: BytesMut,
39 eom_search_offset: usize,
46}
47impl NetconfCodec {
77 pub fn new(framing_mode: FramingMode, config: CodecConfig) -> Self {
78 Self {
79 framing_mode,
80 config,
81 chunked_buf: BytesMut::new(),
82 eom_search_offset: 0,
83 }
84 }
85
86 pub fn set_mode(&mut self, framing_mode: FramingMode) {
87 self.framing_mode = framing_mode;
88 self.chunked_buf.clear();
89 self.eom_search_offset = 0;
90 }
91 pub fn framing_mode(&self) -> FramingMode {
92 self.framing_mode
93 }
94
95 fn check_size(&self, size: usize) -> Result<(), FramingError> {
96 if let Some(max_size) = self.config.max_message_size
97 && size > max_size
98 {
99 return Err(FramingError::MessageTooLarge {
100 limit: max_size,
101 received: size,
102 });
103 }
104 Ok(())
105 }
106
107 fn decode_eom(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
116 if src.len() < EOM_LEN {
119 trace!(
120 "eom: buffer too small ({} bytes), need more data",
121 src.len()
122 );
123 return Ok(None);
124 }
125
126 let search_start = self.eom_search_offset.saturating_sub(EOM_LEN - 1);
130 trace!(
131 "eom: scanning {} bytes (buffer={}, search_offset={}, search_start={})",
132 src.len() - search_start,
133 src.len(),
134 self.eom_search_offset,
135 search_start
136 );
137 if let Some(pos) = memchr::memmem::find(&src[search_start..], EOM_MARKER) {
138 let msg_len = search_start + pos;
139 self.check_size(msg_len)?;
140 let msg = src.split_to(msg_len).freeze();
141 src.advance(EOM_LEN);
142 self.eom_search_offset = 0;
143 debug!("eom: decoded message ({} bytes)", msg_len);
144 trace!(
145 "eom: message preview: {:?}",
146 String::from_utf8_lossy(&msg[..msg.len().min(200)])
147 );
148 Ok(Some(msg))
149 } else {
150 self.eom_search_offset = src.len();
151 self.check_size(src.len())?;
152 trace!("eom: no marker found, buffered {} bytes total", src.len());
153 Ok(None)
154 }
155 }
156
157 fn decode_chunked(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
158 loop {
161 if src.len() < CHUNKED_EOM_MARKER_LEN {
162 trace!(
163 "chunked: buffer too small ({} bytes), accumulated {} bytes so far",
164 src.len(),
165 self.chunked_buf.len()
166 );
167 return Ok(None);
168 }
169
170 if src[0..2] != *CHUNKED_HEADER_START {
172 return Err(FramingError::InvalidHeader {
173 expected: "\\n#",
174 got: src[..2].to_vec(),
175 });
176 }
177
178 if src[2] == b'#' {
181 if src[3] != b'\n' {
182 return Err(FramingError::InvalidHeader {
183 expected: "\\n##\\n",
184 got: src[..4].to_vec(),
185 });
186 }
187
188 src.advance(CHUNKED_EOM_MARKER_LEN);
190 let msg = self.chunked_buf.split().freeze();
191 debug!("chunked: decoded message ({} bytes)", msg.len());
192 trace!(
193 "chunked: message preview: {:?}",
194 String::from_utf8_lossy(&msg[..msg.len().min(200)])
195 );
196 return Ok(Some(msg));
197 }
198
199 let header_start = 2; let header_end = match src[header_start..].iter().position(|&b| b == b'\n') {
202 Some(pos_end_of_header) => header_start + pos_end_of_header,
203 None => {
204 if src.len() > 20 {
207 return Err(FramingError::InvalidChunkSize(
208 String::from_utf8_lossy(&src[header_start..]).into_owned(),
209 ));
210 }
211 return Ok(None);
212 }
213 };
214
215 let size_str = &src[header_start..header_end];
217 let chunk_size: usize = std::str::from_utf8(size_str)
218 .map_err(|_| {
219 FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
220 })?
221 .parse()
222 .map_err(|_| {
223 FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
224 })?;
225
226 if chunk_size == 0 {
227 return Err(FramingError::InvalidChunkSize("0".into()));
228 }
229
230 let header_len = header_end + 1; let total_chunk_len = header_len + chunk_size;
235 if src.len() < total_chunk_len {
236 trace!(
237 "chunked: need {} more bytes for chunk (have {}, need {})",
238 total_chunk_len - src.len(),
239 src.len(),
240 total_chunk_len
241 );
242 return Ok(None); }
244 self.check_size(self.chunked_buf.len() + chunk_size)?;
245
246 trace!(
247 "chunked: consuming chunk ({} bytes, accumulated {} bytes)",
248 chunk_size,
249 self.chunked_buf.len() + chunk_size
250 );
251
252 src.advance(header_len);
254
255 self.chunked_buf.extend_from_slice(&src[..chunk_size]);
257 src.advance(chunk_size);
258 }
259 }
260}
261
262pub(crate) async fn read_eom_message<R: AsyncRead + Unpin>(
263 reader: &mut R,
264 max_size: Option<usize>,
265) -> crate::Result<String> {
266 let mut buf = Vec::with_capacity(4096);
267 let mut tmp = [0u8; 4096];
268
269 loop {
270 let read_bytes = reader.read(&mut tmp).await?;
271
272 if read_bytes == 0 {
273 debug!("read_eom: unexpected EOF after {} bytes", buf.len());
274 return Err(FramingError::UnexpectedEof.into());
275 }
276 buf.extend_from_slice(&tmp[..read_bytes]);
277 trace!(
278 "read_eom: read {} bytes, buffer now {} bytes",
279 read_bytes,
280 buf.len()
281 );
282 if let Some(limit) = max_size
283 && buf.len() > limit + EOM_LEN
284 {
285 return Err(FramingError::MessageTooLarge {
286 limit,
287 received: buf.len(),
288 }
289 .into());
290 }
291 if let Some(pos) = memchr::memmem::find(&buf, EOM_MARKER) {
292 buf.truncate(pos);
293 debug!("read_eom: complete message ({} bytes)", buf.len());
294 return String::from_utf8(buf).map_err(|_| FramingError::InvalidUtf8.into());
295 }
296 }
297}
298
299pub(crate) async fn write_eom_message<W: AsyncWrite + Unpin>(
300 writer: &mut W,
301 message: &str,
302) -> crate::Result<()> {
303 writer.write_all(message.as_bytes()).await?;
304 writer.write_all(EOM_MARKER).await?;
305 writer.flush().await?;
306 Ok(())
307}
308
309impl Decoder for NetconfCodec {
311 type Item = Bytes;
312 type Error = FramingError;
313
314 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
315 match self.framing_mode {
316 FramingMode::EndOfMessage => self.decode_eom(src),
317 FramingMode::Chunked => self.decode_chunked(src),
318 }
319 }
320}
321
322impl Encoder<Bytes> for NetconfCodec {
323 type Error = FramingError;
324
325 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
326 debug!(
327 "encode: framing={:?}, message={} bytes",
328 self.framing_mode,
329 item.len()
330 );
331 trace!(
332 "encode: message preview: {:?}",
333 String::from_utf8_lossy(&item[..item.len().min(200)])
334 );
335 match self.framing_mode {
336 FramingMode::EndOfMessage => {
337 dst.reserve(item.len() + EOM_LEN);
338 dst.extend_from_slice(&item);
339 dst.extend_from_slice(EOM_MARKER);
340 }
341 FramingMode::Chunked => {
342 let header = format!("\n#{}\n", item.len());
343 dst.reserve(header.len() + item.len() + CHUNKED_EOM_MARKER_LEN);
344 dst.extend_from_slice(header.as_bytes());
345 dst.extend_from_slice(&item);
346 dst.extend_from_slice(CHUNKED_EOM_MARKER);
347 }
348 }
349 Ok(())
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a codec in `mode` with the default (unlimited) configuration.
    fn new_codec(mode: FramingMode) -> NetconfCodec {
        NetconfCodec::new(mode, CodecConfig::default())
    }

    /// Builds a codec in `mode` that rejects messages larger than `limit` bytes.
    fn limited_codec(mode: FramingMode, limit: usize) -> NetconfCodec {
        let config = CodecConfig {
            max_message_size: Some(limit),
        };
        NetconfCodec::new(mode, config)
    }

    /// Encodes `payload` with `codec` into a freshly allocated buffer.
    fn encode_to_buf(codec: &mut NetconfCodec, payload: Bytes) -> BytesMut {
        let mut out = BytesMut::new();
        codec.encode(payload, &mut out).unwrap();
        out
    }

    #[test]
    fn eom_decode_complete_message() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let mut input = BytesMut::from(&b"<rpc-reply/>]]>]]>"[..]);
        let decoded = codec.decode(&mut input).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"<rpc-reply/>")));
        assert!(input.is_empty());
    }

    #[test]
    fn eom_decode_incomplete_message() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let mut input = BytesMut::from(&b"<rpc-reply/>"[..]);
        assert_eq!(codec.decode(&mut input).unwrap(), None);
    }

    #[test]
    fn eom_decode_partial_marker() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let mut input = BytesMut::from(&b"<ok/>]]>"[..]);
        assert_eq!(codec.decode(&mut input).unwrap(), None);
        // The second half of the delimiter arrives in a later read.
        input.extend_from_slice(b"]]>");
        assert_eq!(
            codec.decode(&mut input).unwrap(),
            Some(Bytes::from_static(b"<ok/>"))
        );
    }

    #[test]
    fn eom_decode_empty_message() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let mut input = BytesMut::from(&b"]]>]]>"[..]);
        assert_eq!(
            codec.decode(&mut input).unwrap(),
            Some(Bytes::from_static(b""))
        );
    }

    #[test]
    fn eom_decode_two_messages() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let mut input = BytesMut::from(&b"<a/>]]>]]><b/>]]>]]>"[..]);
        for expected in [&b"<a/>"[..], &b"<b/>"[..]] {
            assert_eq!(
                codec.decode(&mut input).unwrap(),
                Some(Bytes::from_static(expected))
            );
        }
    }

    #[test]
    fn eom_decode_size_limit_exceeded() {
        let mut codec = limited_codec(FramingMode::EndOfMessage, 5);
        let mut input = BytesMut::from(&b"<too-large/>]]>]]>"[..]);
        let err = codec.decode(&mut input).unwrap_err();
        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
    }

    #[test]
    fn eom_decode_size_limit_ok() {
        let mut codec = limited_codec(FramingMode::EndOfMessage, 100);
        let mut input = BytesMut::from(&b"<ok/>]]>]]>"[..]);
        assert!(codec.decode(&mut input).unwrap().is_some());
    }

    #[test]
    fn chunked_decode_single_chunk() {
        let mut codec = new_codec(FramingMode::Chunked);
        let mut input = BytesMut::from(&b"\n#7\n<data/>\n##\n"[..]);
        let decoded = codec.decode(&mut input).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"<data/>")));
        assert!(input.is_empty());
    }

    #[test]
    fn chunked_decode_multiple_chunks() {
        let mut codec = new_codec(FramingMode::Chunked);
        let mut input = BytesMut::from(&b"\n#5\nHello\n#6\n World\n##\n"[..]);
        assert_eq!(
            codec.decode(&mut input).unwrap(),
            Some(Bytes::from_static(b"Hello World"))
        );
    }

    #[test]
    fn chunked_decode_incomplete_header() {
        let mut codec = new_codec(FramingMode::Chunked);
        let mut input = BytesMut::from(&b"\n#"[..]);
        assert_eq!(codec.decode(&mut input).unwrap(), None);
    }

    #[test]
    fn chunked_decode_incomplete_data() {
        let mut codec = new_codec(FramingMode::Chunked);
        let mut input = BytesMut::from(&b"\n#10\nHello"[..]);
        assert_eq!(codec.decode(&mut input).unwrap(), None);
        // The rest of the chunk and the trailer arrive later.
        input.extend_from_slice(b" Wrld\n##\n");
        assert_eq!(
            codec.decode(&mut input).unwrap(),
            Some(Bytes::from_static(b"Hello Wrld"))
        );
    }

    #[test]
    fn chunked_decode_large_chunk() {
        let mut codec = new_codec(FramingMode::Chunked);
        let payload = vec![b'x'; 10_000];
        let mut input = BytesMut::new();
        input.extend_from_slice(format!("\n#{}\n", payload.len()).as_bytes());
        input.extend_from_slice(&payload);
        input.extend_from_slice(b"\n##\n");
        let decoded = codec.decode(&mut input).unwrap().unwrap();
        assert_eq!(decoded.len(), 10_000);
    }

    #[test]
    fn chunked_decode_invalid_header() {
        let mut codec = new_codec(FramingMode::Chunked);
        let mut input = BytesMut::from(&b"\n#abc\n"[..]);
        let err = codec.decode(&mut input).unwrap_err();
        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
    }

    #[test]
    fn chunked_decode_zero_chunk_size() {
        let mut codec = new_codec(FramingMode::Chunked);
        let mut input = BytesMut::from(&b"\n#0\n\n##\n"[..]);
        let err = codec.decode(&mut input).unwrap_err();
        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
    }

    #[test]
    fn chunked_decode_size_limit() {
        let mut codec = limited_codec(FramingMode::Chunked, 5);
        let mut input = BytesMut::from(&b"\n#10\n0123456789\n##\n"[..]);
        let err = codec.decode(&mut input).unwrap_err();
        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
    }

    #[test]
    fn eom_encode() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let out = encode_to_buf(&mut codec, Bytes::from_static(b"<ok/>"));
        assert_eq!(&out[..], b"<ok/>]]>]]>");
    }

    #[test]
    fn chunked_encode() {
        let mut codec = new_codec(FramingMode::Chunked);
        let out = encode_to_buf(&mut codec, Bytes::from_static(b"<ok/>"));
        assert_eq!(&out[..], b"\n#5\n<ok/>\n##\n");
    }

    #[test]
    fn eom_roundtrip() {
        let mut codec = new_codec(FramingMode::EndOfMessage);
        let original = Bytes::from_static(b"<rpc message-id=\"1\"><get/></rpc>");
        let mut wire = encode_to_buf(&mut codec, original.clone());
        assert_eq!(codec.decode(&mut wire).unwrap().unwrap(), original);
    }

    #[test]
    fn chunked_roundtrip() {
        let mut codec = new_codec(FramingMode::Chunked);
        let original = Bytes::from_static(b"<rpc message-id=\"1\"><get/></rpc>");
        let mut wire = encode_to_buf(&mut codec, original.clone());
        assert_eq!(codec.decode(&mut wire).unwrap().unwrap(), original);
    }

    #[test]
    fn mode_switch() {
        let mut codec = new_codec(FramingMode::EndOfMessage);

        let mut wire = encode_to_buf(&mut codec, Bytes::from_static(b"hello"));
        assert_eq!(
            codec.decode(&mut wire).unwrap(),
            Some(Bytes::from_static(b"hello"))
        );

        codec.set_mode(FramingMode::Chunked);

        let mut wire = encode_to_buf(&mut codec, Bytes::from_static(b"world"));
        assert_eq!(
            codec.decode(&mut wire).unwrap(),
            Some(Bytes::from_static(b"world"))
        );
    }

    #[tokio::test]
    async fn eom_helper_roundtrip() {
        let (mut client, mut server) = tokio::io::duplex(4096);

        let msg = "<hello/>";
        tokio::spawn(async move {
            write_eom_message(&mut server, msg).await.unwrap();
        });

        let received = read_eom_message(&mut client, None).await.unwrap();
        assert_eq!(received, msg);
    }

    #[tokio::test]
    async fn eom_helper_size_limit() {
        let (mut client, mut server) = tokio::io::duplex(4096);

        let oversized = "x".repeat(1000);
        tokio::spawn(async move {
            write_eom_message(&mut server, &oversized).await.unwrap();
        });

        let outcome = read_eom_message(&mut client, Some(10)).await;
        assert!(outcome.is_err());
    }
}