use crate::bytes::BytesMut;
use crate::codec::{Decoder, Encoder};
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::stream::Stream;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
const DEFAULT_CAPACITY: usize = 8192;
const READ_BUF_SIZE: usize = 8192;
const MAX_READ_PASSES_PER_POLL: usize = 32;
const MAX_WRITE_PASSES_PER_POLL: usize = 32;
pub struct Framed<T, U> {
inner: T,
codec: U,
read_buf: BytesMut,
write_buf: BytesMut,
eof: bool,
}
impl<T, U> Framed<T, U> {
#[inline]
pub fn new(inner: T, codec: U) -> Self {
Self::with_capacity(inner, codec, DEFAULT_CAPACITY)
}
pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Self {
Self {
inner,
codec,
read_buf: BytesMut::with_capacity(capacity),
write_buf: BytesMut::with_capacity(capacity),
eof: false,
}
}
#[inline]
#[must_use]
pub fn get_ref(&self) -> &T {
&self.inner
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
#[inline]
#[must_use]
pub fn codec(&self) -> &U {
&self.codec
}
pub fn codec_mut(&mut self) -> &mut U {
&mut self.codec
}
#[inline]
#[must_use]
pub fn read_buffer(&self) -> &BytesMut {
&self.read_buf
}
#[inline]
#[must_use]
pub fn write_buffer(&self) -> &BytesMut {
&self.write_buf
}
#[inline]
pub fn into_inner(self) -> T {
self.inner
}
pub fn into_parts(self) -> FramedParts<T, U> {
FramedParts {
inner: self.inner,
codec: self.codec,
read_buf: self.read_buf,
write_buf: self.write_buf,
}
}
}
pub struct FramedParts<T, U> {
pub inner: T,
pub codec: U,
pub read_buf: BytesMut,
pub write_buf: BytesMut,
}
impl<T, U> Stream for Framed<T, U>
where
T: AsyncRead + Unpin,
U: Decoder + Unpin,
{
type Item = Result<<U as Decoder>::Item, <U as Decoder>::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let mut read_passes = 0usize;
let mut should_yield = false;
loop {
if !this.eof {
match this.codec.decode(&mut this.read_buf) {
Ok(Some(item)) => return Poll::Ready(Some(Ok(item))),
Ok(None) => {
if should_yield {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Err(e) => return Poll::Ready(Some(Err(e))),
}
}
if this.eof {
return match this.codec.decode_eof(&mut this.read_buf) {
Ok(Some(item)) => Poll::Ready(Some(Ok(item))),
Ok(None) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(e))),
};
}
let mut tmp = [0u8; READ_BUF_SIZE];
let mut read_buf = ReadBuf::new(&mut tmp);
match Pin::new(&mut this.inner).poll_read(cx, &mut read_buf) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Poll::Ready(Ok(())) => {
let filled = read_buf.filled();
if filled.is_empty() {
this.eof = true;
} else {
this.read_buf.put_slice(filled);
read_passes += 1;
if read_passes >= MAX_READ_PASSES_PER_POLL {
should_yield = true;
}
}
}
}
}
}
}
impl<T, U> Framed<T, U> {
pub fn send<I>(&mut self, item: I) -> Result<(), <U as Encoder<I>>::Error>
where
U: Encoder<I>,
{
self.codec.encode(item, &mut self.write_buf)
}
}
impl<T, U> Framed<T, U>
where
T: AsyncWrite + Unpin,
{
pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut write_passes = 0usize;
while !self.write_buf.is_empty() {
if write_passes >= MAX_WRITE_PASSES_PER_POLL {
cx.waker().wake_by_ref();
return Poll::Pending;
}
let n = match Pin::new(&mut self.inner).poll_write(cx, &self.write_buf) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Ready(Ok(n)) => n,
};
if n == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)));
}
let _ = self.write_buf.split_to(n);
write_passes += 1;
}
Pin::new(&mut self.inner).poll_flush(cx)
}
pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.poll_flush(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Ready(Ok(())) => {}
}
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
impl<T: std::fmt::Debug, U: std::fmt::Debug> std::fmt::Debug for Framed<T, U> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Framed")
.field("inner", &self.inner)
.field("codec", &self.codec)
.field("read_buf_len", &self.read_buf.len())
.field("write_buf_len", &self.write_buf.len())
.field("eof", &self.eof)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::{LinesCodec, LinesCodecError};
use std::io;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::Waker;
fn noop_waker() -> Waker {
std::task::Waker::noop().clone()
}
struct TrackWaker(Arc<AtomicBool>);
use std::task::Wake;
impl Wake for TrackWaker {
fn wake(self: Arc<Self>) {
self.0.store(true, Ordering::SeqCst);
}
fn wake_by_ref(self: &Arc<Self>) {
self.0.store(true, Ordering::SeqCst);
}
}
fn track_waker(flag: Arc<AtomicBool>) -> Waker {
Waker::from(Arc::new(TrackWaker(flag)))
}
#[derive(Debug)]
struct DuplexBuf {
read_data: Vec<u8>,
read_pos: usize,
written: Vec<u8>,
}
impl DuplexBuf {
fn new(read_data: &[u8]) -> Self {
Self {
read_data: read_data.to_vec(),
read_pos: 0,
written: Vec::new(),
}
}
}
impl AsyncRead for DuplexBuf {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let remaining = &this.read_data[this.read_pos..];
if remaining.is_empty() {
return Poll::Ready(Ok(()));
}
let n = std::cmp::min(remaining.len(), buf.remaining());
buf.put_slice(&remaining[..n]);
this.read_pos += n;
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for DuplexBuf {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.written.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[derive(Debug)]
struct AlwaysReadyDuplex {
reads: usize,
panic_after: usize,
written: Vec<u8>,
}
impl AlwaysReadyDuplex {
fn new(panic_after: usize) -> Self {
Self {
reads: 0,
panic_after,
written: Vec::new(),
}
}
}
impl AsyncRead for AlwaysReadyDuplex {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
assert!(
this.reads < this.panic_after,
"transport was polled too many times without yielding"
);
this.reads += 1;
buf.put_slice(b"a");
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for AlwaysReadyDuplex {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.written.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[derive(Debug)]
struct AlwaysReadyPartialWriteDuplex {
writes: usize,
panic_after: usize,
max_per_write: usize,
written: Vec<u8>,
}
impl AlwaysReadyPartialWriteDuplex {
fn new(max_per_write: usize, panic_after: usize) -> Self {
Self {
writes: 0,
panic_after,
max_per_write,
written: Vec::new(),
}
}
}
impl AsyncRead for AlwaysReadyPartialWriteDuplex {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl AsyncWrite for AlwaysReadyPartialWriteDuplex {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
assert!(
this.writes < this.panic_after,
"transport was polled too many times without yielding"
);
this.writes += 1;
let n = std::cmp::min(buf.len(), this.max_per_write);
this.written.extend_from_slice(&buf[..n]);
Poll::Ready(Ok(n))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[derive(Debug)]
struct ErrorDuplex {
kind: io::ErrorKind,
}
impl ErrorDuplex {
fn new(kind: io::ErrorKind) -> Self {
Self { kind }
}
}
impl AsyncRead for ErrorDuplex {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let kind = self.get_mut().kind;
Poll::Ready(Err(io::Error::new(kind, "framed duplex read error")))
}
}
impl AsyncWrite for ErrorDuplex {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[test]
fn framed_read_and_write() {
let transport = DuplexBuf::new(b"incoming\n");
let mut framed = Framed::new(transport, LinesCodec::new());
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(Some(Ok(ref s))) if s == "incoming"));
framed.send("outgoing".to_string()).unwrap();
let poll = framed.poll_flush(&mut cx);
assert!(matches!(poll, Poll::Ready(Ok(()))));
assert_eq!(&framed.get_ref().written, b"outgoing\n");
}
#[test]
fn framed_multiple_reads() {
let transport = DuplexBuf::new(b"one\ntwo\nthree\n");
let mut framed = Framed::new(transport, LinesCodec::new());
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(Some(Ok(ref s))) if s == "one"));
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(Some(Ok(ref s))) if s == "two"));
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(Some(Ok(ref s))) if s == "three"));
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(None)));
}
#[test]
fn framed_close() {
let transport = DuplexBuf::new(b"");
let mut framed = Framed::new(transport, LinesCodec::new());
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
framed.send("final".to_string()).unwrap();
let poll = framed.poll_close(&mut cx);
assert!(matches!(poll, Poll::Ready(Ok(()))));
assert_eq!(&framed.get_ref().written, b"final\n");
}
#[test]
fn framed_accessors() {
let transport = DuplexBuf::new(b"");
let mut framed = Framed::new(transport, LinesCodec::new());
assert!(framed.read_buffer().is_empty());
assert!(framed.write_buffer().is_empty());
let _codec = framed.codec();
let _codec_mut = framed.codec_mut();
let _transport = framed.get_ref();
let _transport_mut = framed.get_mut();
}
#[test]
fn framed_into_parts() {
let transport = DuplexBuf::new(b"");
let framed = Framed::new(transport, LinesCodec::new());
let parts = framed.into_parts();
assert!(parts.read_buf.is_empty());
assert!(parts.write_buf.is_empty());
}
#[test]
fn framed_debug() {
let transport = DuplexBuf::new(b"");
let framed = Framed::new(transport, LinesCodec::new());
let dbg = format!("{framed:?}");
assert!(dbg.contains("Framed"));
assert!(dbg.contains("read_buf_len"));
assert!(dbg.contains("write_buf_len"));
}
#[test]
fn framed_with_capacity() {
let transport = DuplexBuf::new(b"");
let framed = Framed::with_capacity(transport, LinesCodec::new(), 256);
assert!(framed.read_buffer().is_empty());
assert!(framed.write_buffer().is_empty());
}
#[test]
fn framed_into_inner() {
let transport = DuplexBuf::new(b"test-data");
let framed = Framed::new(transport, LinesCodec::new());
let inner = framed.into_inner();
assert_eq!(&inner.read_data, b"test-data");
assert_eq!(inner.read_pos, 0);
}
#[test]
fn framed_parts_fields() {
let transport = DuplexBuf::new(b"parts-test\n");
let mut framed = Framed::new(transport, LinesCodec::new());
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let _ = Pin::new(&mut framed).poll_next(&mut cx);
let parts = framed.into_parts();
let inner = parts.inner;
assert_eq!(&inner.read_data, b"parts-test\n");
let _ = parts.codec;
}
#[test]
fn framed_get_mut_modifies_transport() {
let transport = DuplexBuf::new(b"");
let mut framed = Framed::new(transport, LinesCodec::new());
framed.get_mut().read_data = b"modified\n".to_vec();
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(Some(Ok(ref s))) if s == "modified"));
}
#[test]
fn framed_codec_mut_accessible() {
let transport = DuplexBuf::new(b"");
let mut framed = Framed::new(transport, LinesCodec::new());
let _codec = framed.codec_mut();
}
#[test]
fn framed_empty_read_returns_none() {
let transport = DuplexBuf::new(b"");
let mut framed = Framed::new(transport, LinesCodec::new());
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Ready(None)));
}
#[test]
fn framed_yields_cooperatively_on_always_ready_transport() {
let transport = AlwaysReadyDuplex::new(MAX_READ_PASSES_PER_POLL + 1);
let mut framed = Framed::new(transport, LinesCodec::new());
let woke = Arc::new(AtomicBool::new(false));
let waker = track_waker(Arc::clone(&woke));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut framed).poll_next(&mut cx);
assert!(matches!(poll, Poll::Pending));
assert!(
woke.load(Ordering::SeqCst),
"cooperative yield should self-wake for continued draining"
);
assert_eq!(
framed.get_ref().reads,
MAX_READ_PASSES_PER_POLL,
"poll_next should stop after the cooperative read budget"
);
assert_eq!(
framed.read_buffer().len(),
MAX_READ_PASSES_PER_POLL,
"already-read bytes must stay buffered across the cooperative yield"
);
}
#[test]
fn framed_write_side_yields_cooperatively_on_always_ready_partial_transport() {
let transport = AlwaysReadyPartialWriteDuplex::new(1, MAX_WRITE_PASSES_PER_POLL + 1);
let mut framed = Framed::new(transport, LinesCodec::new());
let woke = Arc::new(AtomicBool::new(false));
let waker = track_waker(Arc::clone(&woke));
let mut cx = Context::from_waker(&waker);
framed
.send("x".repeat(MAX_WRITE_PASSES_PER_POLL + 8))
.expect("encode test frame");
let poll = framed.poll_flush(&mut cx);
assert!(matches!(poll, Poll::Pending));
assert!(
woke.load(Ordering::SeqCst),
"cooperative yield should self-wake for continued draining"
);
assert_eq!(
framed.get_ref().writes,
MAX_WRITE_PASSES_PER_POLL,
"poll_flush should stop after the cooperative write budget"
);
assert!(
!framed.write_buffer().is_empty(),
"buffered frame bytes must remain after the cooperative yield"
);
}
#[test]
fn framed_preserves_io_error_kind_from_lines_codec() {
let transport = ErrorDuplex::new(io::ErrorKind::ConnectionReset);
let mut framed = Framed::new(transport, LinesCodec::new());
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut framed).poll_next(&mut cx);
match poll {
Poll::Ready(Some(Err(LinesCodecError::Io(err)))) => {
assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
}
other => panic!("expected io error propagation, got {other:?}"), }
}
}