1use thiserror::Error;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
9use crate::pdu::AkkaPdu;
10
11#[derive(Debug, Error)]
12#[non_exhaustive]
13pub enum CodecError {
14 #[error("io: {0}")]
15 Io(#[from] std::io::Error),
16 #[error("encode: {0}")]
17 Encode(String),
18 #[error("decode: {0}")]
19 Decode(String),
20 #[error("frame too large ({0} bytes, max {1})")]
21 FrameTooLarge(usize, usize),
22}
23
24pub fn encode_pdu(pdu: &AkkaPdu) -> Result<Vec<u8>, CodecError> {
25 bincode::serde::encode_to_vec(pdu, bincode::config::standard())
26 .map_err(|e| CodecError::Encode(e.to_string()))
27}
28
29pub fn decode_pdu(bytes: &[u8]) -> Result<AkkaPdu, CodecError> {
30 let (pdu, _) = bincode::serde::decode_from_slice(bytes, bincode::config::standard())
31 .map_err(|e| CodecError::Decode(e.to_string()))?;
32 Ok(pdu)
33}
34
35pub async fn write_frame<W: tokio::io::AsyncWrite + Unpin>(
36 w: &mut W,
37 pdu: &AkkaPdu,
38 max_size: usize,
39) -> Result<(), CodecError> {
40 let bytes = encode_pdu(pdu)?;
41 if bytes.len() > max_size {
42 return Err(CodecError::FrameTooLarge(bytes.len(), max_size));
43 }
44 w.write_all(&(bytes.len() as u32).to_be_bytes()).await?;
45 w.write_all(&bytes).await?;
46 w.flush().await?;
47 Ok(())
48}
49
50pub async fn read_frame<R: tokio::io::AsyncRead + Unpin>(
51 r: &mut R,
52 max_size: usize,
53) -> Result<AkkaPdu, CodecError> {
54 let mut len = [0u8; 4];
55 r.read_exact(&mut len).await?;
56 let n = u32::from_be_bytes(len) as usize;
57 if n > max_size {
58 return Err(CodecError::FrameTooLarge(n, max_size));
59 }
60 let mut buf = vec![0u8; n];
61 r.read_exact(&mut buf).await?;
62 decode_pdu(&buf)
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdu::{AssociateInfo, DisassociateReason, PROTOCOL_VERSION};
    use atomr_core::actor::Address;

    /// Encodes `pdu`, decodes the bytes, and asserts the result equals the input.
    fn assert_roundtrip(pdu: &AkkaPdu) {
        let encoded = encode_pdu(pdu).expect("encoding should succeed");
        let decoded = decode_pdu(&encoded).expect("decoding freshly encoded bytes should succeed");
        assert_eq!(&decoded, pdu);
    }

    #[test]
    fn roundtrip_associate() {
        let info = AssociateInfo {
            origin: Address::remote("akka.tcp", "S", "127.0.0.1", 1234),
            uid: 99,
            cookie: Some("hunter2".into()),
            protocol_version: PROTOCOL_VERSION,
        };
        assert_roundtrip(&AkkaPdu::Associate(info));
    }

    #[test]
    fn roundtrip_heartbeat_and_disassociate() {
        let cases = [
            AkkaPdu::Heartbeat,
            AkkaPdu::Disassociate(DisassociateReason::Normal),
        ];
        cases.iter().for_each(assert_roundtrip);
    }
}