use crate::layers::L2Frame;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use thiserror::Error;
use std::collections::{HashMap, VecDeque};
#[derive(Debug, Error)]
pub enum DilationError {
#[error("io error")] Io,
#[error("protocol error: {0}")] Proto(&'static str),
}
/// One dilation connection: a framed byte stream plus per-channel flow-control state.
///
/// `S` is the underlying transport. The session methods require it to
/// implement tokio's `AsyncReadExt + AsyncWriteExt + Unpin`.
pub struct DilationSession<S> {
    /// Transport that all frames are read from and written to.
    stream: S,
    /// Protocol names agreed with the peer by `negotiate` (empty until then).
    negotiated: Vec<String>,
    /// Per-channel frames waiting for send credit. Each frame is encoded as
    /// `seq (8 bytes, big-endian) | fin (1 byte) | body`.
    _send_queues: HashMap<u32, VecDeque<Vec<u8>>>,
    /// Remaining send credit per channel. `send` treats a channel that is
    /// absent from this map as unlimited.
    send_window: HashMap<u32, i32>,
    /// Receive flow control per channel, stored as `(credit, threshold, grant)`.
    recv_window: HashMap<u32, (i32, i32, i32)>,
}
impl<S: AsyncReadExt + AsyncWriteExt + Unpin> DilationSession<S> {
    /// Receive-window parameters `(credit, threshold, grant)` used for a channel
    /// that has not been configured with [`Self::set_recv_window`].
    const DEFAULT_RECV_WINDOW: (i32, i32, i32) = (64, 16, 48);

    /// Channel id on which `ping` / `recv_pong` exchange their data-frame probes.
    const PING_CHANNEL: u32 = 0xffff_fffe;

    /// Creates a session over `stream`.
    ///
    /// The new session has no negotiated protocols, no queued data and no
    /// per-channel window state.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            negotiated: Vec::new(),
            _send_queues: HashMap::new(),
            send_window: HashMap::new(),
            recv_window: HashMap::new(),
        }
    }

    /// Exchanges supported protocol names with the peer and stores the overlap.
    ///
    /// Each side sends its list as a JSON array in a data frame on channel 0.
    /// The initiator writes first; the responder replies after reading.
    /// Both sides return the common entries in the *initiator's* preference
    /// order, so they agree on the result (for example, on its first element).
    ///
    /// # Errors
    /// - `Proto("json")` if encoding our list or decoding the peer's list fails.
    /// - `Proto("bad frame")` if the peer's frame is not a data frame.
    /// - Any error from the underlying frame I/O.
    pub async fn negotiate(&mut self, ours: &[&str], is_initiator: bool) -> Result<Vec<String>, DilationError> {
        let offer = serde_json::to_vec(&ours).map_err(|_| DilationError::Proto("json"))?;
        let theirs = if is_initiator {
            self.write_frame(0, 0, false, &offer).await?;
            self.read_peer_list().await?
        } else {
            let theirs = self.read_peer_list().await?;
            self.write_frame(0, 0, false, &offer).await?;
            theirs
        };
        self.negotiated = if is_initiator {
            ours.iter()
                .filter(|mine| theirs.iter().any(|t| t.as_str() == **mine))
                .map(|mine| mine.to_string())
                .collect()
        } else {
            // `theirs` is the initiator's list, so keeping its order makes both
            // sides compute the same sequence.
            theirs
                .into_iter()
                .filter(|t| ours.iter().any(|mine| *mine == t.as_str()))
                .collect()
        };
        Ok(self.negotiated.clone())
    }

    /// Reads one frame and decodes its body as the peer's JSON protocol list.
    async fn read_peer_list(&mut self) -> Result<Vec<String>, DilationError> {
        match self.read_frame().await? {
            L2Frame::Data { body, .. } => {
                serde_json::from_slice(&body).map_err(|_| DilationError::Proto("json"))
            }
            _ => Err(DilationError::Proto("bad frame")),
        }
    }

    /// Sends a JSON open request as a data frame on channel `id` with sequence number 0.
    ///
    /// The request has the form `{"open": {"id": id, "name": name}}`.
    pub async fn open_subchannel(&mut self, id: u32, name: &str) -> Result<(), DilationError> {
        let payload = serde_json::to_vec(&serde_json::json!({"open": {"id": id, "name": name}}))
            .map_err(|_| DilationError::Proto("json"))?;
        self.write_frame(id, 0, false, &payload).await
    }

    /// Sends `payload` as a data frame on channel `id`, respecting send credit.
    ///
    /// A channel with no recorded send window is treated as unlimited.
    ///
    /// The frame is queued rather than written when either:
    /// - the window is exhausted, or
    /// - earlier frames for the channel are still queued.
    ///
    /// Queuing behind existing frames ensures frames never overtake each other.
    /// After queuing, as many queued frames as the current window allows are
    /// written. Returns `Ok(())` both when the frame was written and when it
    /// was queued.
    pub async fn send(&mut self, id: u32, seq: u64, payload: &[u8], fin: bool) -> Result<(), DilationError> {
        let window = self.send_credit(id);
        let backlogged = matches!(self._send_queues.get(&id), Some(queue) if !queue.is_empty());
        if window > 0 && !backlogged {
            self.send_window.insert(id, window - 1);
            return self.write_frame(id, seq, fin, payload).await;
        }
        // Queue record layout: seq (8 bytes, big-endian) | fin (1 byte) | payload.
        let mut record = Vec::with_capacity(8 + 1 + payload.len());
        record.extend_from_slice(&seq.to_be_bytes());
        record.push(u8::from(fin));
        record.extend_from_slice(payload);
        self._send_queues.entry(id).or_default().push_back(record);
        // Drain whatever the current window allows; a no-op when it is exhausted.
        self.flush(id).await
    }

    /// Reads the next frame from the peer. See `read_frame` for the wire format.
    pub async fn recv(&mut self) -> Result<L2Frame, DilationError> {
        self.read_frame().await
    }

    /// Writes one data frame and flushes the stream.
    ///
    /// Layout: type `0x01`, flags (bit 0 = fin), then big-endian channel,
    /// sequence number and body length, then the body. The flush keeps
    /// buffered transports from holding the frame back while this side waits
    /// for a reply.
    ///
    /// # Errors
    /// - `Proto("frame too large")` if the body length does not fit the 32-bit
    ///   length field. A truncated length would desync the stream.
    /// - `Io` on any write failure.
    async fn write_frame(&mut self, ch: u32, seq: u64, fin: bool, body: &[u8]) -> Result<(), DilationError> {
        if body.len() > u32::MAX as usize {
            return Err(DilationError::Proto("frame too large"));
        }
        let mut hdr = Vec::with_capacity(1 + 1 + 4 + 8 + 4);
        hdr.push(0x01);
        hdr.push(u8::from(fin));
        hdr.extend_from_slice(&ch.to_be_bytes());
        hdr.extend_from_slice(&seq.to_be_bytes());
        hdr.extend_from_slice(&(body.len() as u32).to_be_bytes());
        self.stream.write_all(&hdr).await.map_err(|_| DilationError::Io)?;
        self.stream.write_all(body).await.map_err(|_| DilationError::Io)?;
        self.stream.flush().await.map_err(|_| DilationError::Io)
    }

    /// Reads one frame from the stream.
    ///
    /// Wire format (integers big-endian), selected by the leading type byte:
    /// - `0x01` data: flags(1, bit 0 = fin) | channel(4) | seq(8) | len(4) | body
    /// - `0x02` window update: channel(4) | credit(4)
    /// - `0x03` ping
    /// - `0x04` pong
    ///
    /// Receiving a data frame charges one unit of the channel's receive credit
    /// and may send a window update back to the peer.
    ///
    /// # Errors
    /// - `Io` on read failure, or on failure to send a window update.
    /// - `Proto("bad type")` on an unknown type byte.
    async fn read_frame(&mut self) -> Result<L2Frame, DilationError> {
        let mut kind = [0u8; 1];
        self.stream.read_exact(&mut kind).await.map_err(|_| DilationError::Io)?;
        match kind[0] {
            0x01 => self.read_data_frame().await,
            0x02 => {
                let mut b = [0u8; 8];
                self.stream.read_exact(&mut b).await.map_err(|_| DilationError::Io)?;
                let ch = u32::from_be_bytes([b[0], b[1], b[2], b[3]]);
                let credit = u32::from_be_bytes([b[4], b[5], b[6], b[7]]);
                Ok(L2Frame::WindowUpdate { ch, credit })
            }
            0x03 => Ok(L2Frame::Ping),
            0x04 => Ok(L2Frame::Pong),
            _ => Err(DilationError::Proto("bad type")),
        }
    }

    /// Reads the remainder of a data frame (after its type byte) and settles
    /// receive credit.
    ///
    /// If advertising new credit to the peer fails, the error is returned and
    /// the frame is dropped. The stream should then be treated as broken,
    /// because the peer never learns of the credit.
    async fn read_data_frame(&mut self) -> Result<L2Frame, DilationError> {
        let mut hdr = [0u8; 1 + 4 + 8 + 4];
        self.stream.read_exact(&mut hdr).await.map_err(|_| DilationError::Io)?;
        let fin = (hdr[0] & 0x01) != 0;
        let ch = u32::from_be_bytes([hdr[1], hdr[2], hdr[3], hdr[4]]);
        let mut seq_bytes = [0u8; 8];
        seq_bytes.copy_from_slice(&hdr[5..13]);
        let seq = u64::from_be_bytes(seq_bytes);
        let len = u32::from_be_bytes([hdr[13], hdr[14], hdr[15], hdr[16]]) as usize;
        // NOTE(review): `len` is peer-controlled and the buffer is allocated up
        // front, so a hostile peer can request up to 4 GiB. Consider capping it
        // at the protocol's maximum frame size.
        let mut body = vec![0u8; len];
        self.stream.read_exact(&mut body).await.map_err(|_| DilationError::Io)?;
        if let Some(credit) = self.consume_recv_credit(ch) {
            self.window_update(ch, credit).await?;
        }
        Ok(L2Frame::Data { ch, seq, fin, body })
    }

    /// Charges one unit of receive credit on `ch`.
    ///
    /// The channel is created with [`Self::DEFAULT_RECV_WINDOW`] if it has no
    /// entry yet. When the remaining credit has fallen to the threshold, the
    /// configured grant is added back locally and returned, so the caller can
    /// advertise it to the peer.
    ///
    /// A non-positive grant is never advertised. Arithmetic saturates instead
    /// of overflowing.
    fn consume_recv_credit(&mut self, ch: u32) -> Option<u32> {
        let (credit, threshold, grant) = self
            .recv_window
            .entry(ch)
            .or_insert(Self::DEFAULT_RECV_WINDOW);
        *credit = credit.saturating_sub(1);
        if *credit > *threshold || *grant <= 0 {
            return None;
        }
        *credit = credit.saturating_add(*grant);
        // `grant` is known positive here, so the cast is lossless.
        Some(*grant as u32)
    }

    /// Sends a ping probe: a data frame with body `b"ping"` on the ping channel.
    pub async fn ping(&mut self) -> Result<(), DilationError> {
        self.write_frame(Self::PING_CHANNEL, 0, false, b"ping").await
    }

    /// Reads one frame and succeeds if it is a pong.
    ///
    /// Both reply forms are accepted:
    /// - a typed pong frame (`0x04`), or
    /// - a data frame on the ping channel with body `b"pong"`.
    ///
    /// # Errors
    /// `Proto("expected pong")` for any other frame; I/O errors from the read.
    pub async fn recv_pong(&mut self) -> Result<(), DilationError> {
        match self.read_frame().await? {
            L2Frame::Pong => Ok(()),
            L2Frame::Data { ch, body, .. } if ch == Self::PING_CHANNEL && body == b"pong" => Ok(()),
            _ => Err(DilationError::Proto("expected pong")),
        }
    }

    /// Sends a window-update frame (`0x02 | ch | credit`, big-endian) granting
    /// `credit` more frames on `ch`, then flushes the stream.
    pub async fn window_update(&mut self, ch: u32, credit: u32) -> Result<(), DilationError> {
        let mut frame = Vec::with_capacity(1 + 4 + 4);
        frame.push(0x02);
        frame.extend_from_slice(&ch.to_be_bytes());
        frame.extend_from_slice(&credit.to_be_bytes());
        self.stream.write_all(&frame).await.map_err(|_| DilationError::Io)?;
        self.stream.flush().await.map_err(|_| DilationError::Io)
    }

    /// Adds `credit` (which may be negative) to the send window of `ch`, saturating at the `i32` bounds.
    ///
    /// A channel without a recorded window starts from 0 here. Note that
    /// `send` treats such a channel as unlimited until a window is recorded.
    pub fn add_send_window(&mut self, ch: u32, credit: i32) {
        let window = self.send_window.entry(ch).or_insert(0);
        *window = window.saturating_add(credit);
    }

    /// Replaces the receive flow-control state of `ch`.
    ///
    /// `credit` is the current credit. When it falls to `threshold`, `grant`
    /// credits are advertised back to the peer.
    pub fn set_recv_window(&mut self, ch: u32, credit: i32, threshold: i32, grant: i32) {
        self.recv_window.insert(ch, (credit, threshold, grant));
    }

    /// Current send credit for `ch`; a channel with no recorded window is unlimited.
    fn send_credit(&self, ch: u32) -> i32 {
        self.send_window.get(&ch).copied().unwrap_or(i32::MAX)
    }

    /// Writes queued frames for `ch`, in order, while send credit remains.
    ///
    /// Stops when the window is exhausted or the queue is empty.
    pub async fn flush(&mut self, ch: u32) -> Result<(), DilationError> {
        loop {
            let window = self.send_credit(ch);
            if window <= 0 {
                break;
            }
            let mut record = match self._send_queues.get_mut(&ch).and_then(|q| q.pop_front()) {
                Some(record) => record,
                None => break,
            };
            // Records are `seq (8 bytes, big-endian) | fin (1 byte) | body`.
            // Anything shorter cannot have come from `send`, so it is skipped.
            if record.len() < 9 {
                continue;
            }
            let body = record.split_off(9);
            let fin = record[8] == 1;
            let mut seq_bytes = [0u8; 8];
            seq_bytes.copy_from_slice(&record[..8]);
            let seq = u64::from_be_bytes(seq_bytes);
            self.send_window.insert(ch, window - 1);
            self.write_frame(ch, seq, fin, &body).await?;
        }
        Ok(())
    }

    /// Moves the negotiated protocol list onto a new transport `stream`.
    ///
    /// Only `negotiated` is carried over. Per-channel send and receive windows
    /// are reset, and any frames still waiting in the send queues are discarded.
    ///
    /// NOTE(review): `send` reports queued frames as `Ok(())`, so callers that
    /// rebind must re-send such frames themselves if they need them. Confirm
    /// that the layer above does this.
    pub fn rebind<T: AsyncReadExt + AsyncWriteExt + Unpin>(self, stream: T) -> DilationSession<T> {
        DilationSession {
            stream,
            negotiated: self.negotiated,
            _send_queues: HashMap::new(),
            send_window: HashMap::new(),
            recv_window: HashMap::new(),
        }
    }
}