use std::task::{Poll, ready};
use bytes::{Bytes, BytesMut};
use crate::{Error, Result};
/// Description of a single frame: how many payload bytes it will carry.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct Frame {
    /// Declared payload size in bytes; a [`FrameProducer`] only finishes once exactly this
    /// many bytes have been written.
    pub size: u64,
}

impl Frame {
    /// Consumes this description and returns a [`FrameProducer`] expecting `size` bytes.
    pub fn produce(self) -> FrameProducer {
        FrameProducer::from(self)
    }
}
impl From<usize> for Frame {
fn from(size: usize) -> Self {
Self { size: size as u64 }
}
}
impl From<u64> for Frame {
fn from(size: u64) -> Self {
Self { size }
}
}
impl From<u32> for Frame {
fn from(size: u32) -> Self {
Self { size: size as u64 }
}
}
impl From<u16> for Frame {
fn from(size: u16) -> Self {
Self { size: size as u64 }
}
}
#[derive(Default, Debug)]
struct FrameState {
chunks: Vec<Bytes>,
remaining: u64,
abort: Option<Error>,
}
impl FrameState {
fn write_chunk(&mut self, chunk: Bytes) -> Result<()> {
if let Some(err) = &self.abort {
return Err(err.clone());
}
self.remaining = self.remaining.checked_sub(chunk.len() as u64).ok_or(Error::WrongSize)?;
self.chunks.push(chunk);
Ok(())
}
fn poll_read_chunk(&self, index: usize) -> Poll<Result<Option<Bytes>>> {
if let Some(chunk) = self.chunks.get(index).cloned() {
Poll::Ready(Ok(Some(chunk)))
} else if self.remaining == 0 {
Poll::Ready(Ok(None))
} else if let Some(err) = &self.abort {
Poll::Ready(Err(err.clone()))
} else {
Poll::Pending
}
}
fn poll_read_chunks(&self, index: usize) -> Poll<Result<Vec<Bytes>>> {
if index >= self.chunks.len() && self.remaining == 0 {
Poll::Ready(Ok(Vec::new()))
} else if self.remaining == 0 {
Poll::Ready(Ok(self.chunks[index..].to_vec()))
} else if let Some(err) = &self.abort {
Poll::Ready(Err(err.clone()))
} else {
Poll::Pending
}
}
fn poll_read_all(&self, index: usize) -> Poll<Result<Bytes>> {
let chunks = ready!(self.poll_read_all_chunks(index)?);
Poll::Ready(Ok(match chunks.len() {
0 => Bytes::new(),
1 => chunks[0].clone(),
_ => {
let size = chunks.iter().map(Bytes::len).sum();
let mut buf = BytesMut::with_capacity(size);
for chunk in chunks {
buf.extend_from_slice(chunk.as_ref());
}
buf.freeze()
}
}))
}
fn poll_read_all_chunks(&self, index: usize) -> Poll<Result<&[Bytes]>> {
if self.remaining > 0 {
Poll::Pending
} else if let Some(err) = &self.abort {
Poll::Ready(Err(err.clone()))
} else if index < self.chunks.len() {
Poll::Ready(Ok(&self.chunks[index..]))
} else {
Poll::Ready(Ok(&[]))
}
}
}
pub struct FrameProducer {
pub info: Frame,
state: conducer::Producer<FrameState>,
}
impl FrameProducer {
pub fn new(info: Frame) -> Self {
let state = FrameState {
chunks: Vec::new(),
remaining: info.size,
abort: None,
};
Self {
info,
state: conducer::Producer::new(state),
}
}
pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
let chunk = chunk.into();
let mut state = self.modify()?;
state.write_chunk(chunk)
}
#[deprecated(note = "use write(chunk) instead")]
pub fn write_chunk<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
self.write(chunk)
}
pub fn finish(&mut self) -> Result<()> {
let state = self.modify()?;
if state.remaining != 0 {
return Err(Error::WrongSize);
}
Ok(())
}
pub fn abort(&mut self, err: Error) -> Result<()> {
let mut guard = self.modify()?;
guard.abort = Some(err);
guard.close();
Ok(())
}
pub fn consume(&self) -> FrameConsumer {
FrameConsumer {
info: self.info.clone(),
state: self.state.consume(),
index: 0,
}
}
pub async fn unused(&self) -> Result<()> {
self.state
.unused()
.await
.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
}
fn modify(&mut self) -> Result<conducer::Mut<'_, FrameState>> {
self.state
.write()
.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
}
}
impl Clone for FrameProducer {
fn clone(&self) -> Self {
Self {
info: self.info.clone(),
state: self.state.clone(),
}
}
}
impl From<Frame> for FrameProducer {
fn from(info: Frame) -> Self {
FrameProducer::new(info)
}
}
/// Read half of a frame, created by [`FrameProducer::consume`].
///
/// Each consumer tracks its own read position. `Clone` copies that position, so a clone
/// continues from wherever the original currently is.
#[derive(Clone)]
pub struct FrameConsumer {
    /// Metadata of the frame being read.
    pub info: Frame,
    /// Handle onto the frame state written by the producer.
    state: conducer::Consumer<FrameState>,
    /// Index of the next chunk to read; set to `usize::MAX` once `poll_read_all` succeeds.
    index: usize,
}

impl FrameConsumer {
    /// Evaluates `f` against the shared state through `conducer`.
    ///
    /// NOTE(review): presumably `waiter` is registered for a wake-up when `f` returns
    /// `Pending`.
    ///
    /// If `conducer` hands the state back as an error (e.g. after the producer closed it),
    /// this yields the recorded abort error, or [`Error::Dropped`] if none was recorded.
    fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
    where
        F: Fn(&conducer::Ref<'_, FrameState>) -> Poll<Result<R>>,
    {
        Poll::Ready(match ready!(self.state.poll(waiter, f)) {
            Ok(res) => res,
            Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
        })
    }

    /// Polls for all unread payload bytes, concatenated into one buffer.
    ///
    /// Becomes ready only once the whole frame has been written, or with an error.
    pub fn poll_read_all(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Bytes>> {
        let data = ready!(self.poll(waiter, |state| state.poll_read_all(self.index))?);
        // Park the cursor past any possible chunk so later reads report end-of-frame.
        self.index = usize::MAX;
        Poll::Ready(Ok(data))
    }

    /// Waits for the frame to complete and returns all unread bytes as one buffer.
    pub async fn read_all(&mut self) -> Result<Bytes> {
        conducer::wait(|waiter| self.poll_read_all(waiter)).await
    }

    /// Polls for every unread chunk, becoming ready only once the whole frame has been written.
    pub fn poll_read_all_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
        let chunks = ready!(self.poll(waiter, |state| {
            state
                .poll_read_all_chunks(self.index)
                .map_ok(<[Bytes]>::to_vec)
        })?);
        self.index += chunks.len();
        Poll::Ready(Ok(chunks))
    }

    /// Waits for the frame to complete and returns every chunk not yet read.
    ///
    /// This is the async counterpart of [`FrameConsumer::poll_read_all_chunks`].
    pub async fn read_all_chunks(&mut self) -> Result<Vec<Bytes>> {
        conducer::wait(|waiter| self.poll_read_all_chunks(waiter)).await
    }

    /// Polls for the next chunk.
    ///
    /// Yields `None` once the frame is complete and every chunk has been read.
    pub fn poll_read_chunk(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
        let Some(chunk) = ready!(self.poll(waiter, |state| state.poll_read_chunk(self.index))?) else {
            return Poll::Ready(Ok(None));
        };
        self.index += 1;
        Poll::Ready(Ok(Some(chunk)))
    }

    /// Reads the next chunk, or `None` at the end of a complete frame.
    pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
        conducer::wait(|waiter| self.poll_read_chunk(waiter)).await
    }

    /// Polls for the unread chunks; the frame state only reports them once the frame is complete.
    pub fn poll_read_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
        let chunks = ready!(self.poll(waiter, |state| state.poll_read_chunks(self.index))?);
        self.index += chunks.len();
        Poll::Ready(Ok(chunks))
    }

    /// Async counterpart of [`FrameConsumer::poll_read_chunks`].
    pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
        conducer::wait(|waiter| self.poll_read_chunks(waiter)).await
    }
}
#[cfg(test)]
mod test {
    use super::*;

    use futures::FutureExt;

    /// Builds a finished producer whose frame holds `chunks`, sized to their total length.
    ///
    /// The caller must keep the returned producer alive while consumers read from it.
    fn finished(chunks: &[&'static [u8]]) -> FrameProducer {
        let size: usize = chunks.iter().map(|chunk| chunk.len()).sum();
        let mut producer = Frame::from(size).produce();
        for &chunk in chunks {
            producer.write(Bytes::from_static(chunk)).unwrap();
        }
        producer.finish().unwrap();
        producer
    }

    #[test]
    fn single_chunk_roundtrip() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn multi_chunk_read_all() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"helloworld"));
    }

    #[test]
    fn read_chunk_sequential() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c1, Some(Bytes::from_static(b"hello")));
        let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c2, Some(Bytes::from_static(b"world")));
        let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c3, None);
    }

    // Exercises `read_chunks` (this test was previously misnamed `read_all_chunks`).
    #[test]
    fn read_chunks_after_finish() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from_static(b"hello"));
        assert_eq!(chunks[1], Bytes::from_static(b"world"));
    }

    #[test]
    fn read_chunks_resumes_after_read_chunk() {
        let producer = finished(&[b"hello", b"world"]);
        let mut consumer = producer.consume();
        let first = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(first, Some(Bytes::from_static(b"hello")));
        let rest = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert_eq!(rest, vec![Bytes::from_static(b"world")]);
    }

    #[test]
    fn reads_after_read_all_are_empty() {
        let producer = finished(&[b"hello", b"world"]);
        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"helloworld"));
        let next = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(next, None);
        let rest = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert!(rest.is_empty());
    }

    #[test]
    fn consumers_track_position_independently() {
        let producer = finished(&[b"hello", b"world"]);
        let mut first = producer.consume();
        let mut second = producer.consume();
        let c = first.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c, Some(Bytes::from_static(b"hello")));
        // A clone continues from the position of the consumer it was cloned from.
        let mut forked = first.clone();
        let c = forked.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c, Some(Bytes::from_static(b"world")));
        // A separately created consumer still sees the whole frame.
        let data = second.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"helloworld"));
    }

    #[test]
    fn finish_checks_remaining() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hi")).unwrap();
        let err = producer.finish().unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    #[test]
    fn write_too_many_bytes() {
        let mut producer = Frame { size: 3 }.produce();
        let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        producer.abort(Error::Cancel).unwrap();
        let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn abort_after_partial_write() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        producer.write(Bytes::from_static(b"he")).unwrap();
        producer.abort(Error::Cancel).unwrap();
        let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn write_after_abort_fails() {
        let mut producer = Frame { size: 5 }.produce();
        producer.abort(Error::Cancel).unwrap();
        let err = producer.write(Bytes::from_static(b"hi")).unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn empty_frame() {
        let mut producer = Frame { size: 0 }.produce();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::new());
    }

    #[test]
    fn integer_conversions() {
        assert_eq!(Frame::from(7u16).size, 7);
        assert_eq!(Frame::from(7u32).size, 7);
        assert_eq!(Frame::from(7u64).size, 7);
        assert_eq!(Frame::from(7usize).size, 7);
    }

    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        assert!(consumer.read_all().now_or_never().is_none());
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }
}