1use bytes::{Buf, Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3use tokio_util::codec::{Decoder, Encoder};
4
5use crate::error::FramingError;
6
// --- End-of-message framing ---

/// Delimiter that terminates every message in end-of-message framing.
const EOM_MARKER: &[u8] = b"]]>]]>";
/// Byte length of [`EOM_MARKER`].
const EOM_LEN: usize = EOM_MARKER.len();

// --- Chunked framing ---

/// End-of-chunks delimiter that closes a chunked message.
const CHUNKED_EOM_MARKER: &[u8] = b"\n##\n";
/// Byte length of [`CHUNKED_EOM_MARKER`].
const CHUNKED_EOM_MARKER_LEN: usize = CHUNKED_EOM_MARKER.len();
/// Prefix shared by every chunk header and by the end-of-chunks delimiter.
const CHUNKED_HEADER_START: &[u8] = b"\n#";
14
/// Message delimiting scheme used by a `NetconfCodec`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FramingMode {
    /// Each message is followed by the `]]>]]>` delimiter.
    EndOfMessage,
    /// Each message is a sequence of `\n#<size>\n`-prefixed chunks closed by `\n##\n`.
    Chunked,
}
28
/// Options that control how a `NetconfCodec` decodes incoming data.
#[derive(Clone, Copy, Debug, Default)]
pub struct CodecConfig {
    /// Largest accepted message body in bytes; `None` (the default) disables
    /// the check.
    pub max_message_size: Option<usize>,
}
33
34pub struct NetconfCodec {
35 framing_mode: FramingMode,
36 config: CodecConfig,
37 chunked_buf: BytesMut,
38}
39impl NetconfCodec {
69 pub fn new(framing_mode: FramingMode, config: CodecConfig) -> Self {
70 Self {
71 framing_mode,
72 config,
73 chunked_buf: BytesMut::new(),
74 }
75 }
76
77 pub fn set_mode(&mut self, framing_mode: FramingMode) {
78 self.framing_mode = framing_mode;
79 self.chunked_buf.clear();
80 }
81 pub fn framing_mode(&self) -> FramingMode {
82 self.framing_mode
83 }
84
85 fn check_size(&self, size: usize) -> Result<(), FramingError> {
86 if let Some(max_size) = self.config.max_message_size
87 && size > max_size
88 {
89 return Err(FramingError::MessageTooLarge {
90 limit: max_size,
91 received: size,
92 });
93 }
94 Ok(())
95 }
96
97 fn decode_eom(&self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
106 if src.len() < EOM_LEN {
109 return Ok(None);
110 }
111 if let Some(eom_pos_start) = find_subsequence(src, EOM_MARKER) {
112 let msg_len = eom_pos_start;
113 self.check_size(msg_len)?;
114 let msg = src.split_to(msg_len).freeze();
115 src.advance(EOM_LEN);
116 Ok(Some(msg))
117 } else {
118 self.check_size(src.len())?;
119 Ok(None)
120 }
121 }
122
123 fn decode_chunked(&mut self, src: &mut BytesMut) -> Result<Option<Bytes>, FramingError> {
124 loop {
127 if src.len() < CHUNKED_EOM_MARKER_LEN {
128 return Ok(None);
129 }
130
131 if src[0..2] != *CHUNKED_HEADER_START {
133 return Err(FramingError::InvalidHeader {
134 expected: "\\n#",
135 got: src[..2].to_vec(),
136 });
137 }
138
139 if src[2] == b'#' {
142 if src[3] != b'\n' {
143 return Err(FramingError::InvalidHeader {
144 expected: "\\n##\\n",
145 got: src[..4].to_vec(),
146 });
147 }
148
149 src.advance(CHUNKED_EOM_MARKER_LEN);
151 let msg = self.chunked_buf.split().freeze();
152 return Ok(Some(msg));
153 }
154
155 let header_start = 2; let header_end = match src[header_start..].iter().position(|&b| b == b'\n') {
158 Some(pos_end_of_header) => header_start + pos_end_of_header,
159 None => {
160 if src.len() > 20 {
163 return Err(FramingError::InvalidChunkSize(
164 String::from_utf8_lossy(&src[header_start..]).into_owned(),
165 ));
166 }
167 return Ok(None);
168 }
169 };
170
171 let size_str = &src[header_start..header_end];
173 let chunk_size: usize = std::str::from_utf8(size_str)
174 .map_err(|_| {
175 FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
176 })?
177 .parse()
178 .map_err(|_| {
179 FramingError::InvalidChunkSize(String::from_utf8_lossy(size_str).into_owned())
180 })?;
181
182 if chunk_size == 0 {
183 return Err(FramingError::InvalidChunkSize("0".into()));
184 }
185
186 let header_len = header_end + 1; let total_chunk_len = header_len + chunk_size;
191 if src.len() < total_chunk_len {
192 return Ok(None); }
194 self.check_size(self.chunked_buf.len() + chunk_size)?;
195
196 src.advance(header_len);
198
199 self.chunked_buf.extend_from_slice(&src[..chunk_size]);
201 src.advance(chunk_size);
202 }
203 }
204}
205
/// Returns the index of the first occurrence of `needle` in `haystack`, or
/// `None` if it does not occur.
///
/// An empty `needle` matches at index 0 (like `str::find("")`); without the
/// guard, `slice::windows(0)` would panic.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
213
214pub(crate) async fn read_eom_message<R: AsyncRead + Unpin>(
215 reader: &mut R,
216 max_size: Option<usize>,
217) -> crate::Result<String> {
218 let mut buf = Vec::with_capacity(4096);
219 let mut tmp = [0u8; 4096];
220
221 loop {
222 let read_bytes = reader.read(&mut tmp).await?;
223
224 if read_bytes == 0 {
225 return Err(FramingError::UnexpectedEof.into());
226 }
227 buf.extend_from_slice(&tmp[..read_bytes]);
228 if let Some(limit) = max_size
229 && buf.len() > limit + EOM_LEN
230 {
231 return Err(FramingError::MessageTooLarge {
232 limit,
233 received: buf.len(),
234 }
235 .into());
236 }
237 if buf.len() >= EOM_LEN && buf[buf.len() - EOM_LEN..] == *EOM_MARKER {
238 buf.truncate(buf.len() - EOM_LEN);
239 return String::from_utf8(buf).map_err(|_| FramingError::InvalidUtf8.into());
240 }
241 }
242}
243
244pub(crate) async fn write_eom_message<W: AsyncWrite + Unpin>(
245 writer: &mut W,
246 message: &str,
247) -> crate::Result<()> {
248 writer.write_all(message.as_bytes()).await?;
249 writer.write_all(EOM_MARKER).await?;
250 writer.flush().await?;
251 Ok(())
252}
253
254impl Decoder for NetconfCodec {
256 type Item = Bytes;
257 type Error = FramingError;
258
259 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
260 match self.framing_mode {
261 FramingMode::EndOfMessage => self.decode_eom(src),
262 FramingMode::Chunked => self.decode_chunked(src),
263 }
264 }
265}
266
267impl Encoder<Bytes> for NetconfCodec {
268 type Error = FramingError;
269
270 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
271 match self.framing_mode {
272 FramingMode::EndOfMessage => {
273 dst.reserve(item.len() + EOM_LEN);
274 dst.extend_from_slice(&item);
275 dst.extend_from_slice(EOM_MARKER);
276 }
277 FramingMode::Chunked => {
278 let header = format!("\n#{}\n", item.len());
279 dst.reserve(header.len() + item.len() + CHUNKED_EOM_MARKER_LEN);
280 dst.extend_from_slice(header.as_bytes());
281 dst.extend_from_slice(&item);
282 dst.extend_from_slice(CHUNKED_EOM_MARKER);
283 }
284 }
285 Ok(())
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Codec in `mode` without a message-size limit.
    fn unlimited(mode: FramingMode) -> NetconfCodec {
        NetconfCodec::new(mode, CodecConfig::default())
    }

    /// Codec in `mode` that rejects messages longer than `max` bytes.
    fn limited(mode: FramingMode, max: usize) -> NetconfCodec {
        NetconfCodec::new(
            mode,
            CodecConfig {
                max_message_size: Some(max),
            },
        )
    }

    /// Fresh decode buffer pre-filled with `bytes`.
    fn input(bytes: &[u8]) -> BytesMut {
        BytesMut::from(bytes)
    }

    /// Encodes the static `payload` with `codec` into a new buffer.
    fn encoded(codec: &mut NetconfCodec, payload: &'static [u8]) -> BytesMut {
        let mut out = BytesMut::new();
        codec
            .encode(Bytes::from_static(payload), &mut out)
            .unwrap();
        out
    }

    // --- end-of-message decoding ---

    #[test]
    fn eom_decode_complete_message() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut buf = input(b"<rpc-reply/>]]>]]>");
        let decoded = codec.decode(&mut buf).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"<rpc-reply/>")));
        assert!(buf.is_empty());
    }

    #[test]
    fn eom_decode_incomplete_message() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut buf = input(b"<rpc-reply/>");
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
    }

    #[test]
    fn eom_decode_partial_marker() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut buf = input(b"<ok/>]]>");
        assert_eq!(codec.decode(&mut buf).unwrap(), None);

        // The second half of the delimiter arrives in a later read.
        buf.extend_from_slice(b"]]>");
        let decoded = codec.decode(&mut buf).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"<ok/>")));
    }

    #[test]
    fn eom_decode_empty_message() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut buf = input(b"]]>]]>");
        let decoded = codec.decode(&mut buf).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"")));
    }

    #[test]
    fn eom_decode_two_messages() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut buf = input(b"<a/>]]>]]><b/>]]>]]>");
        let expected: [&'static [u8]; 2] = [b"<a/>", b"<b/>"];
        for body in expected {
            assert_eq!(
                codec.decode(&mut buf).unwrap(),
                Some(Bytes::from_static(body))
            );
        }
    }

    #[test]
    fn eom_decode_size_limit_exceeded() {
        let mut codec = limited(FramingMode::EndOfMessage, 5);
        let mut buf = input(b"<too-large/>]]>]]>");
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
    }

    #[test]
    fn eom_decode_size_limit_ok() {
        let mut codec = limited(FramingMode::EndOfMessage, 100);
        let mut buf = input(b"<ok/>]]>]]>");
        assert!(codec.decode(&mut buf).unwrap().is_some());
    }

    // --- chunked decoding ---

    #[test]
    fn chunked_decode_single_chunk() {
        let mut codec = unlimited(FramingMode::Chunked);
        let mut buf = input(b"\n#7\n<data/>\n##\n");
        let decoded = codec.decode(&mut buf).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"<data/>")));
        assert!(buf.is_empty());
    }

    #[test]
    fn chunked_decode_multiple_chunks() {
        let mut codec = unlimited(FramingMode::Chunked);
        let mut buf = input(b"\n#5\nHello\n#6\n World\n##\n");
        let decoded = codec.decode(&mut buf).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"Hello World")));
    }

    #[test]
    fn chunked_decode_incomplete_header() {
        let mut codec = unlimited(FramingMode::Chunked);
        let mut buf = input(b"\n#");
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
    }

    #[test]
    fn chunked_decode_incomplete_data() {
        let mut codec = unlimited(FramingMode::Chunked);
        let mut buf = input(b"\n#10\nHello");
        assert_eq!(codec.decode(&mut buf).unwrap(), None);

        // The rest of the chunk and the end-of-chunks delimiter arrive later.
        buf.extend_from_slice(b" Wrld\n##\n");
        let decoded = codec.decode(&mut buf).unwrap();
        assert_eq!(decoded, Some(Bytes::from_static(b"Hello Wrld")));
    }

    #[test]
    fn chunked_decode_large_chunk() {
        let mut codec = unlimited(FramingMode::Chunked);
        let payload = vec![b'x'; 10_000];
        let mut buf = BytesMut::new();
        buf.extend_from_slice(format!("\n#{}\n", payload.len()).as_bytes());
        buf.extend_from_slice(&payload);
        buf.extend_from_slice(b"\n##\n");
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.len(), 10_000);
    }

    #[test]
    fn chunked_decode_invalid_header() {
        let mut codec = unlimited(FramingMode::Chunked);
        let mut buf = input(b"\n#abc\n");
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
    }

    #[test]
    fn chunked_decode_zero_chunk_size() {
        let mut codec = unlimited(FramingMode::Chunked);
        let mut buf = input(b"\n#0\n\n##\n");
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::InvalidChunkSize(_)));
    }

    #[test]
    fn chunked_decode_size_limit() {
        let mut codec = limited(FramingMode::Chunked, 5);
        let mut buf = input(b"\n#10\n0123456789\n##\n");
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FramingError::MessageTooLarge { .. }));
    }

    // --- encoding ---

    #[test]
    fn eom_encode() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let out = encoded(&mut codec, b"<ok/>");
        assert_eq!(&out[..], b"<ok/>]]>]]>");
    }

    #[test]
    fn chunked_encode() {
        let mut codec = unlimited(FramingMode::Chunked);
        let out = encoded(&mut codec, b"<ok/>");
        assert_eq!(&out[..], b"\n#5\n<ok/>\n##\n");
    }

    // --- round trips ---

    #[test]
    fn eom_roundtrip() {
        let payload: &'static [u8] = b"<rpc message-id=\"1\"><get/></rpc>";
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut wire = encoded(&mut codec, payload);
        let decoded = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded, Bytes::from_static(payload));
    }

    #[test]
    fn chunked_roundtrip() {
        let payload: &'static [u8] = b"<rpc message-id=\"1\"><get/></rpc>";
        let mut codec = unlimited(FramingMode::Chunked);
        let mut wire = encoded(&mut codec, payload);
        let decoded = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded, Bytes::from_static(payload));
    }

    #[test]
    fn mode_switch() {
        let mut codec = unlimited(FramingMode::EndOfMessage);
        let mut wire = encoded(&mut codec, b"hello");
        assert_eq!(
            codec.decode(&mut wire).unwrap(),
            Some(Bytes::from_static(b"hello"))
        );

        // Switch framing on the same codec and check both directions follow.
        codec.set_mode(FramingMode::Chunked);
        let mut wire = encoded(&mut codec, b"world");
        assert_eq!(
            codec.decode(&mut wire).unwrap(),
            Some(Bytes::from_static(b"world"))
        );
    }

    // --- async EOM helpers ---

    #[tokio::test]
    async fn eom_helper_roundtrip() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let msg = "<hello/>";

        tokio::spawn(async move {
            write_eom_message(&mut server, msg).await.unwrap();
        });

        let received = read_eom_message(&mut client, None).await.unwrap();
        assert_eq!(received, msg);
    }

    #[tokio::test]
    async fn eom_helper_size_limit() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        let oversized = "x".repeat(1000);

        tokio::spawn(async move {
            write_eom_message(&mut server, &oversized).await.unwrap();
        });

        let outcome = read_eom_message(&mut client, Some(10)).await;
        assert!(outcome.is_err());
    }
}