anubis-wormhole 1.0.0

A post-quantum secure file transfer tool based on the Magic Wormhole protocol.
Documentation
use crate::layers::L2Frame;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use thiserror::Error;
use std::collections::{HashMap, VecDeque};

#[derive(Debug, Error)]
pub enum DilationError {
    #[error("io error")] Io,
    #[error("protocol error: {0}")] Proto(&'static str),
}

/// A framed session layered over a byte stream `S`, with per-channel
/// credit-based flow control in both directions.
pub struct DilationSession<S> {
    /// The underlying transport that frames are read from and written to.
    stream: S,
    /// Capability names agreed with the peer; filled in by `negotiate`.
    negotiated: Vec<String>,
    /// Per-channel backlog of frames that could not be sent for lack of credit.
    /// Each record is encoded as `[seq(8, big-endian)][fin(1)]` followed by the payload.
    _send_queues: HashMap<u32, VecDeque<Vec<u8>>>,
    /// Per-channel send credit, counted in frames; one unit is consumed per frame sent.
    send_window: HashMap<u32, i32>,
    /// Per-channel receive-side policy and state.
    recv_window: HashMap<u32, (i32, i32, i32)>, // (credit, threshold, grant)
}

impl<S: AsyncReadExt + AsyncWriteExt + Unpin> DilationSession<S> {
    /// Wraps `stream` in a fresh session: nothing negotiated yet, no queued
    /// frames, and no flow-control state recorded for any channel.
    pub fn new(stream: S) -> Self {
        let negotiated = Vec::new();
        let _send_queues = HashMap::new();
        let send_window = HashMap::new();
        let recv_window = HashMap::new();
        Self { stream, negotiated, _send_queues, send_window, recv_window }
    }

    pub async fn negotiate(&mut self, ours: &[&str], is_initiator: bool) -> Result<Vec<String>, DilationError> {
        // very simple: send JSON of our list, read peer list, intersect
        let my = serde_json::to_vec(&ours).map_err(|_| DilationError::Proto("json"))?;
        if is_initiator {
            self.write_frame(0, 0, false, &my).await?;
            let peer = self.read_frame().await?;
            let theirs: Vec<String> = match peer { L2Frame::Data { body, .. } => serde_json::from_slice(&body).map_err(|_| DilationError::Proto("json"))?, _ => return Err(DilationError::Proto("bad frame")) };
            self.negotiated = ours.iter().map(|s| s.to_string()).filter(|s| theirs.contains(s)).collect();
        } else {
            let peer = self.read_frame().await?;
            let theirs: Vec<String> = match peer { L2Frame::Data { body, .. } => serde_json::from_slice(&body).map_err(|_| DilationError::Proto("json"))?, _ => return Err(DilationError::Proto("bad frame")) };
            self.write_frame(0, 0, false, &my).await?;
            self.negotiated = ours.iter().map(|s| s.to_string()).filter(|s| theirs.contains(s)).collect();
        }
        Ok(self.negotiated.clone())
    }

    /// Announces a new subchannel to the peer.
    ///
    /// Sends the JSON object `{"open": {"id": <id>, "name": <name>}}` as a data
    /// frame on channel `id` with seq 0 and no FIN flag.
    ///
    /// # Errors
    /// `Proto("json")` if the announcement cannot be serialized, or any error
    /// from writing the frame.
    pub async fn open_subchannel(&mut self, id: u32, name: &str) -> Result<(), DilationError> {
        let announcement = serde_json::json!({"open": {"id": id, "name": name}});
        let encoded = serde_json::to_vec(&announcement).map_err(|_| DilationError::Proto("json"))?;
        self.write_frame(id, 0, false, &encoded).await
    }

    /// Sends `payload` on channel `id`, or queues it when it cannot go out yet.
    ///
    /// A channel with no recorded send window is treated as having
    /// `i32::MAX` credit. Each frame written consumes one unit of credit.
    ///
    /// The frame is queued instead of written when either
    /// * the channel has no credit left, or
    /// * older frames for this channel are still queued. Writing the new frame
    ///   directly in that situation would overtake them on the wire, so it is
    ///   appended behind them and the queue is drained in order, as far as the
    ///   current credit allows.
    ///
    /// Queued records are encoded as `[seq(8, big-endian)][fin(1)]` followed by
    /// the payload, which is the layout `flush` decodes.
    ///
    /// # Errors
    /// Any error from writing a frame to the stream.
    pub async fn send(&mut self, id: u32, seq: u64, payload: &[u8], fin: bool) -> Result<(), DilationError> {
        let window = self.send_window.get(&id).copied().unwrap_or(i32::MAX);
        let has_backlog = self._send_queues.get(&id).map_or(false, |q| !q.is_empty());

        if window <= 0 || has_backlog {
            let mut record = Vec::with_capacity(8 + 1 + payload.len());
            record.extend_from_slice(&seq.to_be_bytes());
            record.push(if fin { 1 } else { 0 });
            record.extend_from_slice(payload);
            self._send_queues.entry(id).or_default().push_back(record);
            // With no credit this returns immediately; with credit it sends the
            // backlog (including the record just queued) oldest-first.
            return self.flush(id).await;
        }

        self.send_window.insert(id, window - 1);
        self.write_frame(id, seq, fin, payload).await
    }

    /// Reads and returns the next frame from the peer.
    ///
    /// # Errors
    /// Whatever the internal frame reader reports: `Io` on stream failure or
    /// `Proto` on an unknown frame type.
    pub async fn recv(&mut self) -> Result<L2Frame, DilationError> {
        let frame = self.read_frame().await?;
        Ok(frame)
    }

    /// Writes one data frame to the stream.
    ///
    /// Wire layout (all integers big-endian):
    /// `[type=0x01][flags(1)][ch(4)][seq(8)][len(4)]` followed by `len` body
    /// bytes. Bit 0 of `flags` is the FIN marker.
    ///
    /// # Errors
    /// * `Proto("frame too large")` — `body` is longer than the 32-bit length
    ///   field can express. Nothing is written in that case.
    /// * `Io` — writing the header or the body failed.
    async fn write_frame(&mut self, ch: u32, seq: u64, fin: bool, body: &[u8]) -> Result<(), DilationError> {
        // A plain `as u32` cast would wrap for oversized bodies: the header
        // would then advertise a shorter length than the bytes that follow and
        // the peer would lose frame synchronisation.
        let len = <u32 as std::convert::TryFrom<usize>>::try_from(body.len())
            .map_err(|_| DilationError::Proto("frame too large"))?;

        let mut hdr = [0u8; 1 + 1 + 4 + 8 + 4];
        hdr[0] = 0x01;
        hdr[1] = if fin { 0x01 } else { 0 };
        hdr[2..6].copy_from_slice(&ch.to_be_bytes());
        hdr[6..14].copy_from_slice(&seq.to_be_bytes());
        hdr[14..18].copy_from_slice(&len.to_be_bytes());

        self.stream.write_all(&hdr).await.map_err(|_| DilationError::Io)?;
        self.stream.write_all(body).await.map_err(|_| DilationError::Io)?;
        Ok(())
    }

    async fn read_frame(&mut self) -> Result<L2Frame, DilationError> {
        // Peek type
        let mut t = [0u8;1]; self.stream.read_exact(&mut t).await.map_err(|_| DilationError::Io)?;
        match t[0] {
            0x01 => {
                let mut rest = [0u8; 1+4+8+4];
                self.stream.read_exact(&mut rest).await.map_err(|_| DilationError::Io)?;
                let flags = rest[0];
                let mut chb=[0u8;4]; chb.copy_from_slice(&rest[1..5]); let ch = u32::from_be_bytes(chb);
                let mut sqb=[0u8;8]; sqb.copy_from_slice(&rest[5..13]); let seq = u64::from_be_bytes(sqb);
                let mut lnb=[0u8;4]; lnb.copy_from_slice(&rest[13..17]); let len = u32::from_be_bytes(lnb) as usize;
                let mut body = vec![0u8; len]; self.stream.read_exact(&mut body).await.map_err(|_| DilationError::Io)?;
                // Receive-side window tracking and automatic updates
                let mut send_update: Option<(u32, u32)> = None;
                {
                    let e = self.recv_window.entry(ch).or_insert((64, 16, 48));
                    e.0 -= 1;
                    if e.0 <= e.1 {
                        let grant = e.2.max(0) as u32;
                        e.0 += e.2;
                        send_update = Some((ch, grant));
                    }
                }
                if let Some((ch, credit)) = send_update {
                    // Ignore update send failures; delivery of data took priority
                    let _ = self.window_update(ch, credit).await;
                }
                Ok(L2Frame::Data { ch, seq, fin: flags & 0x01 != 0, body })
            }
            0x02 => {
                let mut b=[0u8;8]; self.stream.read_exact(&mut b).await.map_err(|_| DilationError::Io)?;
                let mut chb=[0u8;4]; chb.copy_from_slice(&b[..4]); let ch = u32::from_be_bytes(chb);
                let mut crb=[0u8;4]; crb.copy_from_slice(&b[4..]); let credit = u32::from_be_bytes(crb);
                Ok(L2Frame::WindowUpdate { ch, credit })
            }
            0x03 => Ok(L2Frame::Ping),
            0x04 => Ok(L2Frame::Pong),
            _ => Err(DilationError::Proto("bad type"))
        }
    }

    /// Liveness probe: sends the body `ping` as a data frame on the reserved
    /// channel `0xffff_fffe` (seq 0, no FIN).
    ///
    /// # Errors
    /// Any error from writing the frame.
    pub async fn ping(&mut self) -> Result<(), DilationError> {
        let probe: &[u8] = b"ping";
        self.write_frame(0xffff_fffe, 0, false, probe).await
    }
    /// Reads the next frame and checks that it is the liveness reply: a data
    /// frame on channel `0xffff_fffe` whose body is exactly `pong`.
    ///
    /// # Errors
    /// `Proto("expected pong")` if the next frame is anything else, or any
    /// error from reading the frame.
    pub async fn recv_pong(&mut self) -> Result<(), DilationError> {
        if let L2Frame::Data { ch: 0xffff_fffe, body, .. } = self.read_frame().await? {
            if body == b"pong" {
                return Ok(());
            }
        }
        Err(DilationError::Proto("expected pong"))
    }

    /// Sends a window-update frame granting `credit` to the peer on channel `ch`.
    ///
    /// Wire layout (big-endian): `[type=0x02][ch(4)][credit(4)]`.
    ///
    /// # Errors
    /// `Io` if the write fails.
    pub async fn window_update(&mut self, ch: u32, credit: u32) -> Result<(), DilationError> {
        let mut frame = [0u8; 1 + 4 + 4];
        frame[0] = 0x02;
        frame[1..5].copy_from_slice(&ch.to_be_bytes());
        frame[5..9].copy_from_slice(&credit.to_be_bytes());
        self.stream.write_all(&frame).await.map_err(|_| DilationError::Io)
    }

    /// Adds `credit` (which may be negative) to the send window of channel
    /// `ch`. A channel without a recorded window starts from 0.
    ///
    /// The addition saturates at the `i32` bounds. `send` seeds a channel that
    /// has no recorded window with `i32::MAX - 1`, so a later grant from the
    /// peer would otherwise overflow: a panic in debug builds, and a wrap to a
    /// negative window (stalling the channel) in release builds.
    pub fn add_send_window(&mut self, ch: u32, credit: i32) {
        let window = self.send_window.entry(ch).or_insert(0);
        *window = window.saturating_add(credit);
    }

    /// Installs the receive-window policy for channel `ch`, replacing any
    /// existing state for that channel.
    ///
    /// * `credit` — current receive credit; one unit is consumed per data frame read.
    /// * `threshold` — once the credit is at or below this value, a window
    ///   update is sent to the peer and the credit is replenished.
    /// * `grant` — the amount used for that replenishment.
    ///
    /// Intended for tests and tuning.
    pub fn set_recv_window(&mut self, ch: u32, credit: i32, threshold: i32, grant: i32) {
        let policy = (credit, threshold, grant);
        self.recv_window.insert(ch, policy);
    }

    /// Sends queued frames for channel `ch`, oldest first, for as long as the
    /// channel has send credit and the queue is not empty.
    ///
    /// A channel without a recorded send window is treated as having no credit
    /// here. Each frame written consumes one unit of credit. Queued records
    /// shorter than the 9-byte `[seq(8)][fin(1)]` prefix are dropped without
    /// being sent.
    ///
    /// # Errors
    /// Any error from writing a frame. The record being written at that point
    /// has already been removed from the queue and its credit consumed.
    pub async fn flush(&mut self, ch: u32) -> Result<(), DilationError> {
        while self.send_window.get(&ch).copied().unwrap_or(0) > 0 {
            let popped = self._send_queues.get_mut(&ch).and_then(|q| q.pop_front());
            let mut record = match popped {
                Some(r) => r,
                None => break,
            };
            if record.len() < 9 {
                continue;
            }

            // Record layout: [seq(8)][fin(1)] + body
            let payload = record.split_off(9);
            let mut seq_bytes = [0u8; 8];
            seq_bytes.copy_from_slice(&record[..8]);
            let seq = u64::from_be_bytes(seq_bytes);
            let fin = record[8] == 1;

            // The loop condition guarantees an entry with positive credit.
            if let Some(credit) = self.send_window.get_mut(&ch) {
                *credit -= 1;
            }
            self.write_frame(ch, seq, fin, &payload).await?;
        }
        Ok(())
    }

    /// Moves the session onto a different stream.
    ///
    /// Only the negotiated capability list is carried over. The old stream,
    /// all queued frames and every send/receive window are discarded, so the
    /// returned session starts with empty flow-control state.
    pub fn rebind<T: AsyncReadExt + AsyncWriteExt + Unpin>(self, stream: T) -> DilationSession<T> {
        let negotiated = self.negotiated;
        DilationSession {
            stream,
            negotiated,
            _send_queues: HashMap::new(),
            send_window: HashMap::new(),
            recv_window: HashMap::new(),
        }
    }
}