1use std::task::{Poll, ready};
11
12use bytes::Bytes;
13
14use crate::{Error, Result};
15
16use super::{Frame, FrameConsumer, FrameProducer};
17
/// Metadata describing a group: currently just its sequence number.
///
/// Equality, hashing and ordering are all derived from `sequence`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
    /// The group's sequence number.
    pub sequence: u64,
}
26
27impl Group {
28 pub fn produce(self) -> GroupProducer {
29 GroupProducer::new(self)
30 }
31}
32
33impl From<usize> for Group {
34 fn from(sequence: usize) -> Self {
35 Self {
36 sequence: sequence as u64,
37 }
38 }
39}
40
41impl From<u64> for Group {
42 fn from(sequence: u64) -> Self {
43 Self { sequence }
44 }
45}
46
47impl From<u32> for Group {
48 fn from(sequence: u32) -> Self {
49 Self {
50 sequence: sequence as u64,
51 }
52 }
53}
54
55impl From<u16> for Group {
56 fn from(sequence: u16) -> Self {
57 Self {
58 sequence: sequence as u64,
59 }
60 }
61}
62
63#[derive(Default)]
64struct GroupState {
65 frames: Vec<FrameProducer>,
68
69 fin: bool,
71
72 abort: Option<Error>,
74}
75
76impl GroupState {
77 fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
78 if let Some(frame) = self.frames.get(index) {
79 Poll::Ready(Ok(Some(frame.consume())))
80 } else if self.fin {
81 Poll::Ready(Ok(None))
82 } else if let Some(err) = &self.abort {
83 Poll::Ready(Err(err.clone()))
84 } else {
85 Poll::Pending
86 }
87 }
88
89 fn poll_finished(&self) -> Poll<Result<u64>> {
90 if self.fin {
91 Poll::Ready(Ok(self.frames.len() as u64))
92 } else if let Some(err) = &self.abort {
93 Poll::Ready(Err(err.clone()))
94 } else {
95 Poll::Pending
96 }
97 }
98}
99
100fn modify(state: &conducer::Producer<GroupState>) -> Result<conducer::Mut<'_, GroupState>> {
101 state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
102}
103
104pub struct GroupProducer {
110 state: conducer::Producer<GroupState>,
112
113 pub info: Group,
115}
116
117impl GroupProducer {
118 pub fn new(info: Group) -> Self {
120 Self {
121 info,
122 state: conducer::Producer::default(),
123 }
124 }
125
126 pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
131 let data = frame.into();
132 let frame = Frame {
133 size: data.len() as u64,
134 };
135 let mut frame = self.create_frame(frame)?;
136 frame.write(data)?;
137 frame.finish()?;
138 Ok(())
139 }
140
141 pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
143 let frame = info.produce();
144 self.append_frame(frame.clone())?;
145 Ok(frame)
146 }
147
148 pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
150 let mut state = modify(&self.state)?;
151 if state.fin {
152 return Err(Error::Closed);
153 }
154 state.frames.push(frame);
155 Ok(())
156 }
157
158 pub fn frame_count(&self) -> usize {
160 self.state.read().frames.len()
161 }
162
163 pub fn finish(&mut self) -> Result<()> {
165 let mut state = modify(&self.state)?;
166 state.fin = true;
167 Ok(())
168 }
169
170 pub fn abort(&mut self, err: Error) -> Result<()> {
174 let mut guard = modify(&self.state)?;
175
176 for frame in guard.frames.iter_mut() {
178 frame.abort(err.clone()).ok();
180 }
181
182 guard.abort = Some(err);
183 guard.close();
184 Ok(())
185 }
186
187 pub fn consume(&self) -> GroupConsumer {
189 GroupConsumer {
190 info: self.info.clone(),
191 state: self.state.consume(),
192 index: 0,
193 }
194 }
195
196 pub async fn closed(&self) -> Error {
198 self.state.closed().await;
199 self.state.read().abort.clone().unwrap_or(Error::Dropped)
200 }
201
202 pub async fn unused(&self) -> Result<()> {
204 self.state
205 .unused()
206 .await
207 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
208 }
209}
210
211impl Clone for GroupProducer {
212 fn clone(&self) -> Self {
213 Self {
214 info: self.info.clone(),
215 state: self.state.clone(),
216 }
217 }
218}
219
220impl From<Group> for GroupProducer {
221 fn from(info: Group) -> Self {
222 GroupProducer::new(info)
223 }
224}
225
226#[derive(Clone)]
228pub struct GroupConsumer {
229 state: conducer::Consumer<GroupState>,
231
232 pub info: Group,
234
235 index: usize,
238}
239
240impl GroupConsumer {
241 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
243 where
244 F: Fn(&conducer::Ref<'_, GroupState>) -> Poll<Result<R>>,
245 {
246 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
247 Ok(res) => res,
248 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
250 })
251 }
252
253 pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
257 conducer::wait(|waiter| self.poll_get_frame(waiter, index)).await
258 }
259
260 pub fn poll_get_frame(&self, waiter: &conducer::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
264 self.poll(waiter, |state| state.poll_get_frame(index))
265 }
266
267 pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
269 conducer::wait(|waiter| self.poll_next_frame(waiter)).await
270 }
271
272 pub fn poll_next_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
276 let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
277 return Poll::Ready(Ok(None));
278 };
279
280 self.index += 1;
281 Poll::Ready(Ok(Some(frame)))
282 }
283
284 pub fn poll_read_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
286 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
287 return Poll::Ready(Ok(None));
288 };
289
290 let data = ready!(frame.poll_read_all(waiter))?;
291 self.index += 1;
292
293 Poll::Ready(Ok(Some(data)))
294 }
295
296 pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
298 conducer::wait(|waiter| self.poll_read_frame(waiter)).await
299 }
300
301 pub fn poll_read_frame_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
303 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
304 return Poll::Ready(Ok(None));
305 };
306
307 let data = ready!(frame.poll_read_all_chunks(waiter))?;
308 self.index += 1;
309
310 Poll::Ready(Ok(Some(data)))
311 }
312
313 pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
315 conducer::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
316 }
317
318 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
320 self.poll(waiter, |state| state.poll_finished())
321 }
322
323 pub async fn finished(&mut self) -> Result<u64> {
325 conducer::wait(|waiter| self.poll_finished(waiter)).await
326 }
327}
328
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    #[test]
    fn basic_frame_reading() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
        producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 6);
        let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 6);
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn read_frame_all_at_once() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn read_frame_chunks() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
        frame.write(Bytes::from_static(b"hello")).unwrap();
        frame.write(Bytes::from_static(b"world")).unwrap();
        frame.finish().unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from_static(b"hello"));
        assert_eq!(chunks[1], Bytes::from_static(b"world"));
    }

    #[test]
    fn get_frame_by_index() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 1);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 2);
        let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
        assert!(f2.is_none());
    }

    #[test]
    fn group_finish_returns_none() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        producer.abort(crate::Error::Cancel).unwrap();

        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Cancel)));
    }

    // Nothing here awaits; `now_or_never` polls once, so no async runtime is needed.
    #[test]
    fn pending_then_ready() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();

        // Nothing written and the group is still open, so the read must not resolve yet.
        assert!(consumer.next_frame().now_or_never().is_none());

        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();

        let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame.info.size, 4);
    }

    #[test]
    fn clone_consumer_independent() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();

        let mut c1 = producer.consume();
        let first = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(first.info.size, 1);

        // The clone inherits c1's read position, i.e. it is already past frame 0.
        let mut c2 = c1.clone();

        // A different payload length lets the assertions tell frame 0 and frame 1 apart.
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let f = c2.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 2);
        let end = c2.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());

        // Advancing c2 must not have moved c1.
        let f = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 2);
        let end = c1.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    #[test]
    fn finished_reports_frame_count() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"x")).unwrap();
        producer.write_frame(Bytes::from_static(b"yz")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let count = consumer.finished().now_or_never().unwrap().unwrap();
        assert_eq!(count, 2);
        assert_eq!(producer.frame_count(), 2);
    }

    #[test]
    fn append_after_finish_fails() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let result = producer.write_frame(Bytes::from_static(b"late"));
        assert!(matches!(result, Err(crate::Error::Closed)));
    }
}