use std::task::{Poll, ready};
use bytes::Bytes;
use crate::{Error, Result};
use super::{Frame, FrameConsumer, FrameProducer};
/// Describes a group, identified by its sequence number.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
    /// The group's sequence number.
    pub sequence: u64,
}

impl Group {
    /// Consumes this description and creates a [`GroupProducer`] for it.
    pub fn produce(self) -> GroupProducer {
        GroupProducer::new(self)
    }
}

impl From<usize> for Group {
    fn from(sequence: usize) -> Self {
        // `usize` is at most 64 bits wide on all current Rust targets, so this
        // cast does not truncate; there is no lossless `From<usize> for u64`.
        Self {
            sequence: sequence as u64,
        }
    }
}

impl From<u64> for Group {
    fn from(sequence: u64) -> Self {
        Self { sequence }
    }
}

impl From<u32> for Group {
    fn from(sequence: u32) -> Self {
        // Lossless widening: `u64::from` makes it explicit that nothing is truncated.
        Self {
            sequence: u64::from(sequence),
        }
    }
}

impl From<u16> for Group {
    fn from(sequence: u16) -> Self {
        // Lossless widening: `u64::from` makes it explicit that nothing is truncated.
        Self {
            sequence: u64::from(sequence),
        }
    }
}
/// State shared between a group's producer and its consumers.
#[derive(Default)]
struct GroupState {
    // Frames in the order they were appended.
    frames: Vec<FrameProducer>,
    // Set by `GroupProducer::finish`; `append_frame` rejects new frames once set.
    fin: bool,
    // Error recorded by `GroupProducer::abort`, if any.
    abort: Option<Error>,
}

impl GroupState {
    /// Looks up the frame at `index`.
    ///
    /// Priority order: an existing frame yields a consumer obtained from
    /// `FrameProducer::consume`; otherwise a finished group yields `None`;
    /// otherwise a recorded abort error is returned; otherwise the poll is pending.
    fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
        match (self.frames.get(index), self.fin, &self.abort) {
            (Some(frame), _, _) => Poll::Ready(Ok(Some(frame.consume()))),
            (None, true, _) => Poll::Ready(Ok(None)),
            (None, false, Some(err)) => Poll::Ready(Err(err.clone())),
            (None, false, None) => Poll::Pending,
        }
    }

    /// Resolves to the number of frames once finished, or to the recorded abort error.
    ///
    /// `fin` takes precedence over the abort error; if neither is set the poll is pending.
    fn poll_finished(&self) -> Poll<Result<u64>> {
        match (self.fin, &self.abort) {
            (true, _) => Poll::Ready(Ok(self.frames.len() as u64)),
            (false, Some(err)) => Poll::Ready(Err(err.clone())),
            (false, None) => Poll::Pending,
        }
    }
}
fn modify(state: &conducer::Producer<GroupState>) -> Result<conducer::Mut<'_, GroupState>> {
state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
}
/// Write half of a group: appends frames, then finishes or aborts the group.
pub struct GroupProducer {
// Handle to the group state; `consume` derives consumer handles from it.
state: conducer::Producer<GroupState>,
/// Metadata describing this group.
pub info: Group,
}
impl GroupProducer {
pub fn new(info: Group) -> Self {
Self {
info,
state: conducer::Producer::default(),
}
}
pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
let data = frame.into();
let frame = Frame {
size: data.len() as u64,
};
let mut frame = self.create_frame(frame)?;
frame.write(data)?;
frame.finish()?;
Ok(())
}
pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
let frame = info.produce();
self.append_frame(frame.clone())?;
Ok(frame)
}
pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
let mut state = modify(&self.state)?;
if state.fin {
return Err(Error::Closed);
}
state.frames.push(frame);
Ok(())
}
pub fn finish(&mut self) -> Result<()> {
let mut state = modify(&self.state)?;
state.fin = true;
Ok(())
}
pub fn abort(&mut self, err: Error) -> Result<()> {
let mut guard = modify(&self.state)?;
for frame in guard.frames.iter_mut() {
frame.abort(err.clone()).ok();
}
guard.abort = Some(err);
guard.close();
Ok(())
}
pub fn consume(&self) -> GroupConsumer {
GroupConsumer {
info: self.info.clone(),
state: self.state.consume(),
index: 0,
}
}
pub async fn closed(&self) -> Error {
self.state.closed().await;
self.state.read().abort.clone().unwrap_or(Error::Dropped)
}
pub async fn unused(&self) -> Result<()> {
self.state
.unused()
.await
.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
}
}
/// Clones the `info` metadata and the underlying state handle.
impl Clone for GroupProducer {
    fn clone(&self) -> Self {
        let state = Clone::clone(&self.state);
        Self {
            state,
            info: self.info.clone(),
        }
    }
}
impl From<Group> for GroupProducer {
fn from(info: Group) -> Self {
GroupProducer::new(info)
}
}
/// Read half of a group, created by `GroupProducer::consume`.
///
/// Each consumer has its own read cursor; a clone starts at the original's current cursor position.
#[derive(Clone)]
pub struct GroupConsumer {
// Consumer handle onto the shared group state.
state: conducer::Consumer<GroupState>,
/// Metadata describing this group.
pub info: Group,
// Index of the next frame returned by `next_frame`, `read_frame` and `read_frame_chunks`.
index: usize,
}
impl GroupConsumer {
fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
where
F: Fn(&conducer::Ref<'_, GroupState>) -> Poll<Result<R>>,
{
Poll::Ready(match ready!(self.state.poll(waiter, f)) {
Ok(res) => res,
Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
})
}
pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
conducer::wait(|waiter| self.poll_get_frame(waiter, index)).await
}
pub fn poll_get_frame(&self, waiter: &conducer::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
self.poll(waiter, |state| state.poll_get_frame(index))
}
pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
conducer::wait(|waiter| self.poll_next_frame(waiter)).await
}
pub fn poll_next_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
return Poll::Ready(Ok(None));
};
self.index += 1;
Poll::Ready(Ok(Some(frame)))
}
pub fn poll_read_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
return Poll::Ready(Ok(None));
};
let data = ready!(frame.poll_read_all(waiter))?;
self.index += 1;
Poll::Ready(Ok(Some(data)))
}
pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
conducer::wait(|waiter| self.poll_read_frame(waiter)).await
}
pub fn poll_read_frame_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
return Poll::Ready(Ok(None));
};
let data = ready!(frame.poll_read_all_chunks(waiter))?;
self.index += 1;
Poll::Ready(Ok(Some(data)))
}
pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
conducer::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
}
pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
self.poll(waiter, |state| state.poll_finished())
}
pub async fn finished(&mut self) -> Result<u64> {
conducer::wait(|waiter| self.poll_finished(waiter)).await
}
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    #[test]
    fn basic_frame_reading() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
        producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 6);
        let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 6);
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn read_frame_all_at_once() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn read_frame_chunks() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
        frame.write(Bytes::from_static(b"hello")).unwrap();
        frame.write(Bytes::from_static(b"world")).unwrap();
        frame.finish().unwrap();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from_static(b"hello"));
        assert_eq!(chunks[1], Bytes::from_static(b"world"));
    }

    #[test]
    fn get_frame_by_index() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();
        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 1);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 2);
        let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
        assert!(f2.is_none());
    }

    #[test]
    fn group_finish_returns_none() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();
        let mut consumer = producer.consume();
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        producer.abort(crate::Error::Cancel).unwrap();
        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Cancel)));
    }

    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        assert!(consumer.next_frame().now_or_never().is_none());
        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();
        let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame.info.size, 4);
    }

    #[test]
    fn clone_consumer_independent() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        let mut c1 = producer.consume();
        let first = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(first.info.size, 1);

        // The clone inherits c1's cursor (past frame 0) but advances on its own.
        let mut c2 = c1.clone();

        // A different length than "a" lets the assertions tell the two frames apart.
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let f = c2.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 2);
        let end = c2.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());

        // c2's reads did not move c1's cursor.
        let f = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 2);
        let end = c1.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn write_after_finish_fails() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();
        let result = producer.write_frame(Bytes::from_static(b"late"));
        assert!(matches!(result, Err(crate::Error::Closed)));
    }

    #[test]
    fn finished_reports_frame_count() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"x")).unwrap();
        producer.write_frame(Bytes::from_static(b"yz")).unwrap();
        let mut consumer = producer.consume();
        assert!(consumer.finished().now_or_never().is_none());
        producer.finish().unwrap();
        let count = consumer.finished().now_or_never().unwrap().unwrap();
        assert_eq!(count, 2);
    }
}