1use std::task::{Poll, ready};
11
12use bytes::Bytes;
13
14use crate::{Error, Result};
15
16use super::{Frame, FrameConsumer, FrameProducer};
17
/// Metadata describing a group: currently just its sequence number.
///
/// Equality, ordering and hashing are all derived from `sequence`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
    /// Sequence number of the group.
    pub sequence: u64,
}
26
27impl Group {
28 pub fn produce(self) -> GroupProducer {
29 GroupProducer::new(self)
30 }
31}
32
33impl From<usize> for Group {
34 fn from(sequence: usize) -> Self {
35 Self {
36 sequence: sequence as u64,
37 }
38 }
39}
40
41impl From<u64> for Group {
42 fn from(sequence: u64) -> Self {
43 Self { sequence }
44 }
45}
46
47impl From<u32> for Group {
48 fn from(sequence: u32) -> Self {
49 Self {
50 sequence: sequence as u64,
51 }
52 }
53}
54
55impl From<u16> for Group {
56 fn from(sequence: u16) -> Self {
57 Self {
58 sequence: sequence as u64,
59 }
60 }
61}
62
63#[derive(Default)]
64struct GroupState {
65 frames: Vec<FrameProducer>,
68
69 fin: bool,
71
72 abort: Option<Error>,
74}
75
76impl GroupState {
77 fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
78 if let Some(frame) = self.frames.get(index) {
79 Poll::Ready(Ok(Some(frame.consume())))
80 } else if self.fin {
81 Poll::Ready(Ok(None))
82 } else if let Some(err) = &self.abort {
83 Poll::Ready(Err(err.clone()))
84 } else {
85 Poll::Pending
86 }
87 }
88
89 fn poll_finished(&self) -> Poll<Result<u64>> {
90 if self.fin {
91 Poll::Ready(Ok(self.frames.len() as u64))
92 } else if let Some(err) = &self.abort {
93 Poll::Ready(Err(err.clone()))
94 } else {
95 Poll::Pending
96 }
97 }
98}
99
100fn modify(state: &conducer::Producer<GroupState>) -> Result<conducer::Mut<'_, GroupState>> {
101 state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
102}
103
104pub struct GroupProducer {
110 state: conducer::Producer<GroupState>,
112
113 pub info: Group,
115}
116
117impl GroupProducer {
118 pub fn new(info: Group) -> Self {
120 Self {
121 info,
122 state: conducer::Producer::default(),
123 }
124 }
125
126 pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
131 let data = frame.into();
132 let frame = Frame {
133 size: data.len() as u64,
134 };
135 let mut frame = self.create_frame(frame)?;
136 frame.write(data)?;
137 frame.finish()?;
138 Ok(())
139 }
140
141 pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
143 let frame = info.produce();
144 self.append_frame(frame.clone())?;
145 Ok(frame)
146 }
147
148 pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
150 let mut state = modify(&self.state)?;
151 if state.fin {
152 return Err(Error::Closed);
153 }
154 state.frames.push(frame);
155 Ok(())
156 }
157
158 pub fn finish(&mut self) -> Result<()> {
160 let mut state = modify(&self.state)?;
161 state.fin = true;
162 Ok(())
163 }
164
165 pub fn abort(&mut self, err: Error) -> Result<()> {
169 let mut guard = modify(&self.state)?;
170
171 for frame in guard.frames.iter_mut() {
173 frame.abort(err.clone()).ok();
175 }
176
177 guard.abort = Some(err);
178 guard.close();
179 Ok(())
180 }
181
182 pub fn consume(&self) -> GroupConsumer {
184 GroupConsumer {
185 info: self.info.clone(),
186 state: self.state.consume(),
187 index: 0,
188 }
189 }
190
191 pub async fn closed(&self) -> Error {
193 self.state.closed().await;
194 self.state.read().abort.clone().unwrap_or(Error::Dropped)
195 }
196
197 pub async fn unused(&self) -> Result<()> {
199 self.state
200 .unused()
201 .await
202 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
203 }
204}
205
206impl Clone for GroupProducer {
207 fn clone(&self) -> Self {
208 Self {
209 info: self.info.clone(),
210 state: self.state.clone(),
211 }
212 }
213}
214
215impl From<Group> for GroupProducer {
216 fn from(info: Group) -> Self {
217 GroupProducer::new(info)
218 }
219}
220
221#[derive(Clone)]
223pub struct GroupConsumer {
224 state: conducer::Consumer<GroupState>,
226
227 pub info: Group,
229
230 index: usize,
233}
234
235impl GroupConsumer {
236 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
238 where
239 F: Fn(&conducer::Ref<'_, GroupState>) -> Poll<Result<R>>,
240 {
241 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
242 Ok(res) => res,
243 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
245 })
246 }
247
248 pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
252 conducer::wait(|waiter| self.poll_get_frame(waiter, index)).await
253 }
254
255 pub fn poll_get_frame(&self, waiter: &conducer::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
259 self.poll(waiter, |state| state.poll_get_frame(index))
260 }
261
262 pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
264 conducer::wait(|waiter| self.poll_next_frame(waiter)).await
265 }
266
267 pub fn poll_next_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
271 let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
272 return Poll::Ready(Ok(None));
273 };
274
275 self.index += 1;
276 Poll::Ready(Ok(Some(frame)))
277 }
278
279 pub fn poll_read_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
281 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
282 return Poll::Ready(Ok(None));
283 };
284
285 let data = ready!(frame.poll_read_all(waiter))?;
286 self.index += 1;
287
288 Poll::Ready(Ok(Some(data)))
289 }
290
291 pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
293 conducer::wait(|waiter| self.poll_read_frame(waiter)).await
294 }
295
296 pub fn poll_read_frame_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
298 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
299 return Poll::Ready(Ok(None));
300 };
301
302 let data = ready!(frame.poll_read_all_chunks(waiter))?;
303 self.index += 1;
304
305 Poll::Ready(Ok(Some(data)))
306 }
307
308 pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
310 conducer::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
311 }
312
313 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
315 self.poll(waiter, |state| state.poll_finished())
316 }
317
318 pub async fn finished(&mut self) -> Result<u64> {
320 conducer::wait(|waiter| self.poll_finished(waiter)).await
321 }
322}
323
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    #[test]
    fn basic_frame_reading() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
        producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 6);
        let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 6);
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn read_frame_all_at_once() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn read_frame_chunks() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
        frame.write(Bytes::from_static(b"hello")).unwrap();
        frame.write(Bytes::from_static(b"world")).unwrap();
        frame.finish().unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from_static(b"hello"));
        assert_eq!(chunks[1], Bytes::from_static(b"world"));
    }

    #[test]
    fn get_frame_by_index() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 1);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 2);
        let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
        assert!(f2.is_none());
    }

    #[test]
    fn group_finish_returns_none() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        producer.abort(crate::Error::Cancel).unwrap();

        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Cancel)));
    }

    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();

        assert!(consumer.next_frame().now_or_never().is_none());

        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();

        let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame.info.size, 4);
    }

    #[test]
    fn clone_consumer_independent() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();

        let mut c1 = producer.consume();
        let _ = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();

        // The clone inherits c1's position (just past frame 0).
        let mut c2 = c1.clone();

        // A distinct size lets the assertions tell frame 1 apart from frame 0.
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let f = c2.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 2);
        let end = c2.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());

        // Advancing c2 must not have moved c1.
        let f = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 2);
    }

    #[test]
    fn finished_reports_frame_count() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();

        assert!(consumer.finished().now_or_never().is_none());

        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"b")).unwrap();
        producer.finish().unwrap();

        let count = consumer.finished().now_or_never().unwrap().unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn write_after_finish_is_rejected() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let result = producer.write_frame(Bytes::from_static(b"late"));
        assert!(matches!(result, Err(crate::Error::Closed)));
    }
}