Skip to main content

atomr_remote/
codec.rs

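//! Length-prefixed framing for [`AkkaPdu`].
//!
//! On the wire each frame is a big-endian `u32` payload length, followed by
//! exactly that many bytes of a bincode-serialized [`AkkaPdu`] (encoded with
//! `bincode::config::standard()`). The length does not include the 4-byte
//! header itself.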
5
6use thiserror::Error;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
9use crate::pdu::AkkaPdu;
10
11#[derive(Debug, Error)]
12#[non_exhaustive]
13pub enum CodecError {
14    #[error("io: {0}")]
15    Io(#[from] std::io::Error),
16    #[error("encode: {0}")]
17    Encode(String),
18    #[error("decode: {0}")]
19    Decode(String),
20    #[error("frame too large ({0} bytes, max {1})")]
21    FrameTooLarge(usize, usize),
22}
23
24pub fn encode_pdu(pdu: &AkkaPdu) -> Result<Vec<u8>, CodecError> {
25    bincode::serde::encode_to_vec(pdu, bincode::config::standard())
26        .map_err(|e| CodecError::Encode(e.to_string()))
27}
28
29pub fn decode_pdu(bytes: &[u8]) -> Result<AkkaPdu, CodecError> {
30    let (pdu, _) = bincode::serde::decode_from_slice(bytes, bincode::config::standard())
31        .map_err(|e| CodecError::Decode(e.to_string()))?;
32    Ok(pdu)
33}
34
35pub async fn write_frame<W: tokio::io::AsyncWrite + Unpin>(
36    w: &mut W,
37    pdu: &AkkaPdu,
38    max_size: usize,
39) -> Result<(), CodecError> {
40    let bytes = encode_pdu(pdu)?;
41    if bytes.len() > max_size {
42        return Err(CodecError::FrameTooLarge(bytes.len(), max_size));
43    }
44    w.write_all(&(bytes.len() as u32).to_be_bytes()).await?;
45    w.write_all(&bytes).await?;
46    w.flush().await?;
47    Ok(())
48}
49
50pub async fn read_frame<R: tokio::io::AsyncRead + Unpin>(
51    r: &mut R,
52    max_size: usize,
53) -> Result<AkkaPdu, CodecError> {
54    let mut len = [0u8; 4];
55    r.read_exact(&mut len).await?;
56    let n = u32::from_be_bytes(len) as usize;
57    if n > max_size {
58        return Err(CodecError::FrameTooLarge(n, max_size));
59    }
60    let mut buf = vec![0u8; n];
61    r.read_exact(&mut buf).await?;
62    decode_pdu(&buf)
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdu::{AssociateInfo, DisassociateReason, PROTOCOL_VERSION};
    use atomr_core::actor::Address;

    /// Encodes `pdu`, decodes the bytes back, and checks the result is equal.
    fn assert_roundtrip(pdu: AkkaPdu) {
        let encoded = encode_pdu(&pdu).expect("encode should succeed");
        let decoded = decode_pdu(&encoded).expect("decode should succeed");
        assert_eq!(pdu, decoded);
    }

    #[test]
    fn roundtrip_associate() {
        let info = AssociateInfo {
            origin: Address::remote("akka.tcp", "S", "127.0.0.1", 1234),
            uid: 99,
            cookie: Some("hunter2".into()),
            protocol_version: PROTOCOL_VERSION,
        };
        assert_roundtrip(AkkaPdu::Associate(info));
    }

    #[test]
    fn roundtrip_heartbeat_and_disassociate() {
        assert_roundtrip(AkkaPdu::Heartbeat);
        assert_roundtrip(AkkaPdu::Disassociate(DisassociateReason::Normal));
    }
}