use crate::api::session::{FramePhase, SessionTransport};
use crate::errors::CoreError;
use bytes::{Bytes, BytesMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
/// Bytes in one kibibyte; used to spell the size limits below.
const KIB: usize = 1024;
/// Largest inbound frame accepted while in `FramePhase::Handshake`, which is the
/// phase a freshly constructed transport starts in.
const HANDSHAKE_FRAME_CAP: usize = 64 * KIB;
/// Largest inbound frame accepted in `FramePhase::Established`, and the largest
/// frame `send_bytes` will write in any phase.
const STEADY_STATE_FRAME_CAP: usize = 4 * KIB * KIB;
/// Starting capacity of the receive accumulator, and the capacity it is reset to.
const RECV_BUF_INITIAL_CAPACITY: usize = 64 * KIB;
/// Maximum number of bytes the accumulator is extended by per payload read step.
const RECV_CHUNK: usize = 64 * KIB;
/// A frame longer than `RECV_BUF_INITIAL_CAPACITY * SHRINK_SLACK_MULT` bytes causes
/// the accumulator to be replaced by a fresh buffer of the initial capacity.
const SHRINK_SLACK_MULT: usize = 4;
/// Length-prefixed framing over a single TCP connection.
///
/// Each frame on the wire is a 4-byte big-endian length followed by that many
/// payload bytes. The read and write halves sit behind separate locks, so a
/// sender and a receiver do not block each other.
pub struct TcpSessionTransport {
/// Write half. Holding this lock for a whole send keeps one frame's length
/// prefix and payload contiguous on the wire.
write_half: Mutex<tokio::net::tcp::OwnedWriteHalf>,
/// Read half, paired with the reusable accumulator that inbound payloads are
/// read into.
read_half: Mutex<(tokio::net::tcp::OwnedReadHalf, BytesMut)>,
/// Current inbound frame-size cap. It starts at `HANDSHAKE_FRAME_CAP` and is
/// updated by `set_frame_phase`.
frame_cap: AtomicUsize,
}
impl TcpSessionTransport {
    /// Wraps a connected `TcpStream` in length-prefixed framing.
    ///
    /// Tries to enable `TCP_NODELAY` and ignores any failure. It then splits the
    /// stream into separately locked read and write halves. The transport starts
    /// in the handshake phase, so `HANDSHAKE_FRAME_CAP` is the initial inbound cap.
    pub fn new(stream: TcpStream) -> Self {
        // Best-effort: an error from setting TCP_NODELAY is deliberately discarded.
        let _ = stream.set_nodelay(true);
        let (read_half, write_half) = stream.into_split();
        let accumulator = BytesMut::with_capacity(RECV_BUF_INITIAL_CAPACITY);
        Self {
            read_half: Mutex::new((read_half, accumulator)),
            write_half: Mutex::new(write_half),
            frame_cap: AtomicUsize::new(HANDSHAKE_FRAME_CAP),
        }
    }

    /// Test hook: reports the current capacity of the receive accumulator.
    #[cfg(test)]
    pub(crate) async fn accum_capacity(&self) -> usize {
        let guard = self.read_half.lock().await;
        let (_, accumulator) = &*guard;
        accumulator.capacity()
    }
}
impl SessionTransport for TcpSessionTransport {
    /// Sends `data` as one frame: a 4-byte big-endian length prefix, then the payload.
    ///
    /// The outbound limit is always `STEADY_STATE_FRAME_CAP`, whatever the current
    /// `FramePhase` is. The phase only governs what this side accepts.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::NetworkError` in two cases:
    /// - `data` is longer than `STEADY_STATE_FRAME_CAP`. Nothing is written.
    /// - Writing the prefix or payload fails, or flushing fails.
    ///
    /// NOTE(review): the prefix and payload are written by separate awaits. If
    /// this future is dropped between them, a partial frame can be left on the wire.
    async fn send_bytes(&self, data: &[u8]) -> Result<(), CoreError> {
        if data.len() > STEADY_STATE_FRAME_CAP {
            return Err(CoreError::NetworkError(format!(
                "frame too large: {} > {}",
                data.len(),
                STEADY_STATE_FRAME_CAP
            )));
        }
        // Cannot truncate: the check above bounds the length by the 4 MiB cap.
        let len_prefix = (data.len() as u32).to_be_bytes();
        // Hold the write lock across prefix and payload so concurrent senders
        // cannot interleave their frames.
        let mut writer = self.write_half.lock().await;
        writer
            .write_all(&len_prefix)
            .await
            .map_err(|e| CoreError::NetworkError(e.to_string()))?;
        writer
            .write_all(data)
            .await
            .map_err(|e| CoreError::NetworkError(e.to_string()))?;
        writer
            .flush()
            .await
            .map_err(|e| CoreError::NetworkError(e.to_string()))?;
        Ok(())
    }

    /// Receives one length-prefixed frame and returns its payload.
    ///
    /// The announced length is checked against the current phase's cap before
    /// any payload is buffered. The payload is then read in steps of at most
    /// `RECV_CHUNK` bytes. This way the accumulator grows only as data actually
    /// arrives, rather than being sized up front from the peer-supplied length.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::NetworkError` in two cases:
    /// - Any read fails, including EOF.
    /// - The announced length exceeds the current cap. The payload bytes are
    ///   then left unread on the socket, so the stream is no longer at a frame
    ///   boundary. A further call would read payload bytes as a length prefix.
    ///
    /// NOTE(review): the reads are not cancel-safe. If this future is dropped
    /// mid-read, bytes already consumed are lost and the stream is likewise
    /// left mid-frame.
    async fn recv_bytes(&self) -> Result<Bytes, CoreError> {
        let shrink_threshold = RECV_BUF_INITIAL_CAPACITY * SHRINK_SLACK_MULT;
        let mut guard = self.read_half.lock().await;
        let (reader, accum) = &mut *guard;

        // A previous call that failed or was cancelled mid-payload never reached
        // the post-frame reset below. It may have left a large buffer behind, so
        // drop that buffer before it is reused.
        if accum.capacity() > shrink_threshold {
            *accum = BytesMut::with_capacity(RECV_BUF_INITIAL_CAPACITY);
        }

        let mut len_buf = [0u8; 4];
        reader
            .read_exact(&mut len_buf)
            .await
            .map_err(|e| CoreError::NetworkError(e.to_string()))?;
        let len = u32::from_be_bytes(len_buf) as usize;

        // Load the cap only after the header has arrived. A `set_frame_phase`
        // call made while this call waited for the lock or the header still applies.
        // Relaxed suffices: the cap is a standalone value that guards no other memory.
        let cap = self.frame_cap.load(Ordering::Relaxed);
        if len > cap {
            return Err(CoreError::NetworkError(format!(
                "oversized frame from peer: {} > {}",
                len, cap
            )));
        }

        accum.clear();
        let mut filled = 0usize;
        while filled < len {
            let chunk = (len - filled).min(RECV_CHUNK);
            accum.resize(filled + chunk, 0);
            reader
                .read_exact(&mut accum[filled..filled + chunk])
                .await
                .map_err(|e| CoreError::NetworkError(e.to_string()))?;
            filled += chunk;
        }

        let frame = accum.split_to(len).freeze();
        // After `split_to`, `accum` still shares the frame's allocation. After a
        // large frame, swap in a fresh baseline buffer so the accumulator neither
        // keeps that allocation alive nor reclaims it on a later reserve.
        if len > shrink_threshold {
            *accum = BytesMut::with_capacity(RECV_BUF_INITIAL_CAPACITY);
        }
        Ok(frame)
    }

    /// Selects the inbound frame cap that `recv_bytes` checks lengths against
    /// after this call:
    /// - `FramePhase::Handshake` → `HANDSHAKE_FRAME_CAP`
    /// - `FramePhase::Established` → `STEADY_STATE_FRAME_CAP`
    ///
    /// The outbound limit in `send_bytes` is unaffected.
    fn set_frame_phase(&self, phase: FramePhase) {
        let cap = match phase {
            FramePhase::Handshake => HANDSHAKE_FRAME_CAP,
            FramePhase::Established => STEADY_STATE_FRAME_CAP,
        };
        // Relaxed suffices: the cap is a standalone value that guards no other memory.
        self.frame_cap.store(cap, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::{TcpListener, TcpStream};

    /// Builds a connected (client, server) transport pair over loopback.
    async fn tcp_pair() -> (TcpSessionTransport, TcpSessionTransport) {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = listener.local_addr().expect("addr");
        let (client, accepted) = tokio::join!(TcpStream::connect(addr), listener.accept());
        let client = client.expect("connect");
        let (server, _) = accepted.expect("accept");
        (
            TcpSessionTransport::new(client),
            TcpSessionTransport::new(server),
        )
    }

    #[tokio::test]
    async fn handshake_phase_rejects_oversized_frame() {
        let (client, server) = tcp_pair().await;
        let big = vec![0u8; 100 * 1024];
        client
            .send_bytes(&big)
            .await
            .expect("send is within the 4 MiB send cap");
        let err = server
            .recv_bytes()
            .await
            .expect_err("oversized handshake-phase frame must be rejected");
        assert!(matches!(err, CoreError::NetworkError(_)));
    }

    #[tokio::test]
    async fn handshake_phase_accepts_frame_at_cap() {
        let (client, server) = tcp_pair().await;
        let payload = vec![1u8; HANDSHAKE_FRAME_CAP];
        let (send_res, recv_res) = tokio::join!(client.send_bytes(&payload), server.recv_bytes());
        send_res.expect("send frame exactly at the handshake cap");
        let got = recv_res.expect("a frame exactly at the cap must be accepted");
        assert_eq!(got.len(), HANDSHAKE_FRAME_CAP);
    }

    #[tokio::test]
    async fn established_phase_accepts_large_frame_and_resets_accumulator() {
        let (client, server) = tcp_pair().await;
        server.set_frame_phase(FramePhase::Established);
        let payload = vec![7u8; 1024 * 1024];
        let (send_res, recv_res) = tokio::join!(client.send_bytes(&payload), server.recv_bytes());
        send_res.expect("send 1 MiB");
        let got = recv_res.expect("recv 1 MiB");
        assert_eq!(got.len(), payload.len());
        assert!(got[..] == payload[..], "received payload differs from sent payload");
        let cap = server.accum_capacity().await;
        assert!(
            cap <= RECV_BUF_INITIAL_CAPACITY * SHRINK_SLACK_MULT,
            "accumulator must reset to baseline after a large frame (LEGS-003); capacity = {cap}"
        );
        client.send_bytes(b"small").await.expect("send small");
        let got = server.recv_bytes().await.expect("recv small");
        assert_eq!(&got[..], b"small");
    }

    #[tokio::test]
    async fn zero_length_frame_round_trips() {
        let (client, server) = tcp_pair().await;
        client.send_bytes(b"").await.expect("send empty frame");
        let got = server.recv_bytes().await.expect("recv empty frame");
        assert!(got.is_empty());
    }

    #[tokio::test]
    async fn send_rejects_over_steady_state_cap() {
        let (client, _server) = tcp_pair().await;
        let too_big = vec![0u8; STEADY_STATE_FRAME_CAP + 1];
        assert!(client.send_bytes(&too_big).await.is_err());
    }
}