1use std::collections::VecDeque;
11use std::task::{Poll, ready};
12
13use bytes::Bytes;
14
15use crate::{Error, MAX_FRAME_SIZE, Result};
16
17use super::{Frame, FrameConsumer, FrameProducer};
18
/// Upper bound, in bytes, on the summed declared sizes of the frames a group keeps cached
/// (32 MiB). Once exceeded, the oldest cached frames are evicted.
const MAX_GROUP_CACHE: u64 = 32 * 1024 * 1024;

/// Upper bound on the number of frames a group keeps cached. Once exceeded, the oldest
/// cached frames are evicted.
const MAX_GROUP_FRAMES: usize = 1024;
26
/// Metadata describing a group; currently just its sequence number.
///
/// The derived comparisons order groups by `sequence`.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
    /// Sequence number of this group.
    pub sequence: u64,
}
37
38impl Group {
39 pub fn produce(self) -> GroupProducer {
41 GroupProducer::new(self)
42 }
43}
44
45impl From<usize> for Group {
46 fn from(sequence: usize) -> Self {
47 Self {
48 sequence: sequence as u64,
49 }
50 }
51}
52
53impl From<u64> for Group {
54 fn from(sequence: u64) -> Self {
55 Self { sequence }
56 }
57}
58
59impl From<u32> for Group {
60 fn from(sequence: u32) -> Self {
61 Self {
62 sequence: sequence as u64,
63 }
64 }
65}
66
67impl From<u16> for Group {
68 fn from(sequence: u16) -> Self {
69 Self {
70 sequence: sequence as u64,
71 }
72 }
73}
74
75#[derive(Default)]
76struct GroupState {
77 frames: VecDeque<FrameProducer>,
80
81 offset: usize,
83
84 cache: u64,
86
87 fin: bool,
89
90 abort: Option<Error>,
92}
93
94impl GroupState {
95 fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
96 if index < self.offset {
97 Poll::Ready(Err(Error::CacheFull))
98 } else if let Some(frame) = self.frames.get(index - self.offset) {
99 Poll::Ready(Ok(Some(frame.consume())))
100 } else if self.fin {
101 Poll::Ready(Ok(None))
102 } else if let Some(err) = &self.abort {
103 Poll::Ready(Err(err.clone()))
104 } else {
105 Poll::Pending
106 }
107 }
108
109 fn poll_finished(&self) -> Poll<Result<u64>> {
110 if self.fin {
111 Poll::Ready(Ok((self.offset + self.frames.len()) as u64))
112 } else if let Some(err) = &self.abort {
113 Poll::Ready(Err(err.clone()))
114 } else {
115 Poll::Pending
116 }
117 }
118
119 fn evict(&mut self) {
121 while self.cache > MAX_GROUP_CACHE || self.frames.len() > MAX_GROUP_FRAMES {
122 let Some(frame) = self.frames.pop_front() else {
123 break;
124 };
125 self.cache -= frame.size;
126 self.offset += 1;
127 }
128 }
129}
130
131fn modify(state: &kio::Producer<GroupState>) -> Result<kio::Mut<'_, GroupState>> {
132 state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
133}
134
135pub struct GroupProducer {
141 state: kio::Producer<GroupState>,
143
144 info: Group,
146}
147
148impl std::ops::Deref for GroupProducer {
149 type Target = Group;
150
151 fn deref(&self) -> &Self::Target {
152 &self.info
153 }
154}
155
156impl GroupProducer {
157 pub fn new(info: Group) -> Self {
159 Self {
160 info,
161 state: kio::Producer::default(),
162 }
163 }
164
165 pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
170 let data = frame.into();
171 let frame = Frame {
172 size: data.len() as u64,
173 };
174 let mut frame = self.create_frame(frame)?;
175 frame.write(data)?;
176 frame.finish()?;
177 Ok(())
178 }
179
180 pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
182 if info.size > MAX_FRAME_SIZE {
185 return Err(Error::FrameTooLarge);
186 }
187 let frame = info.produce();
188 self.append_frame(frame.clone())?;
189 Ok(frame)
190 }
191
192 pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
194 if frame.size > MAX_FRAME_SIZE {
197 return Err(Error::FrameTooLarge);
198 }
199 let mut state = modify(&self.state)?;
200 if state.fin {
201 return Err(Error::Closed);
202 }
203 state.cache += frame.size;
204 state.frames.push_back(frame);
205 state.evict();
206 Ok(())
207 }
208
209 pub fn frame_count(&self) -> usize {
211 let state = self.state.read();
212 state.offset + state.frames.len()
213 }
214
215 pub fn finish(&mut self) -> Result<()> {
217 let mut state = modify(&self.state)?;
218 state.fin = true;
219 Ok(())
220 }
221
222 pub fn abort(&mut self, err: Error) -> Result<()> {
230 let mut guard = modify(&self.state)?;
231 guard.abort = Some(err);
232 guard.frames.clear();
233 guard.cache = 0;
234 guard.close();
235 Ok(())
236 }
237
238 pub fn consume(&self) -> GroupConsumer {
240 GroupConsumer {
241 info: self.info.clone(),
242 state: self.state.consume(),
243 index: 0,
244 }
245 }
246
247 pub async fn closed(&self) -> Error {
249 self.state.closed().await;
250 self.state.read().abort.clone().unwrap_or(Error::Dropped)
251 }
252
253 pub async fn unused(&self) -> Result<()> {
255 self.state
256 .unused()
257 .await
258 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
259 }
260}
261
262impl Clone for GroupProducer {
263 fn clone(&self) -> Self {
264 Self {
265 info: self.info.clone(),
266 state: self.state.clone(),
267 }
268 }
269}
270
271impl Drop for GroupProducer {
272 fn drop(&mut self) {
273 if !self.state.is_last() {
277 return;
278 }
279 if let Ok(mut state) = modify(&self.state)
280 && !state.fin
281 {
282 state.frames.clear();
283 state.cache = 0;
284 }
285 }
286}
287
288impl From<Group> for GroupProducer {
289 fn from(info: Group) -> Self {
290 GroupProducer::new(info)
291 }
292}
293
294#[derive(Clone)]
296pub struct GroupConsumer {
297 state: kio::Consumer<GroupState>,
299
300 info: Group,
302
303 index: usize,
306}
307
308impl std::ops::Deref for GroupConsumer {
309 type Target = Group;
310
311 fn deref(&self) -> &Self::Target {
312 &self.info
313 }
314}
315
316impl GroupConsumer {
317 fn poll<F, R>(&self, waiter: &kio::Waiter, f: F) -> Poll<Result<R>>
319 where
320 F: Fn(&kio::Ref<'_, GroupState>) -> Poll<Result<R>>,
321 {
322 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
323 Ok(res) => res,
324 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
326 })
327 }
328
329 pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
333 kio::wait(|waiter| self.poll_get_frame(waiter, index)).await
334 }
335
336 pub fn poll_get_frame(&self, waiter: &kio::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
340 self.poll(waiter, |state| state.poll_get_frame(index))
341 }
342
343 pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
345 kio::wait(|waiter| self.poll_next_frame(waiter)).await
346 }
347
348 pub fn poll_next_frame(&mut self, waiter: &kio::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
352 let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
353 return Poll::Ready(Ok(None));
354 };
355
356 self.index += 1;
357 Poll::Ready(Ok(Some(frame)))
358 }
359
360 pub fn poll_read_frame(&mut self, waiter: &kio::Waiter) -> Poll<Result<Option<Bytes>>> {
362 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
363 return Poll::Ready(Ok(None));
364 };
365
366 let data = ready!(frame.poll_read_all(waiter))?;
367 self.index += 1;
368
369 Poll::Ready(Ok(Some(data)))
370 }
371
372 pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
374 kio::wait(|waiter| self.poll_read_frame(waiter)).await
375 }
376
377 pub fn poll_read_frame_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
379 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
380 return Poll::Ready(Ok(None));
381 };
382
383 let data = ready!(frame.poll_read_all_chunks(waiter))?;
384 self.index += 1;
385
386 Poll::Ready(Ok(Some(data)))
387 }
388
389 pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
391 kio::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
392 }
393
394 pub fn poll_finished(&mut self, waiter: &kio::Waiter) -> Poll<Result<u64>> {
396 self.poll(waiter, |state| state.poll_finished())
397 }
398
399 pub async fn finished(&mut self) -> Result<u64> {
401 kio::wait(|waiter| self.poll_finished(waiter)).await
402 }
403}
404
// Unit tests for `GroupProducer` / `GroupConsumer`. Most tests drive the async API with a
// single poll via `FutureExt::now_or_never`, so `None` from it means "still pending".
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    // Frames are delivered in order, and a finished group then yields `None`.
    #[test]
    fn basic_frame_reading() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
        producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.size, 6);
        let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, 6);
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    // `read_frame` returns the whole payload of the next frame.
    #[test]
    fn read_frame_all_at_once() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    // A frame written in two pieces is read back through `read_frame_chunks`.
    #[test]
    fn read_frame_chunks() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
        frame.write(Bytes::from_static(b"hello")).unwrap();
        frame.write(Bytes::from_static(b"world")).unwrap();
        frame.finish().unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        // NOTE(review): the two writes come back as a single chunk; presumably the frame
        // layer coalesces writes. Confirm against `FrameProducer`/`FrameConsumer`.
        let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], Bytes::from_static(b"helloworld"));
    }

    // `MAX_FRAME_SIZE` is an inclusive limit: one byte over is rejected, exactly at it is accepted.
    #[test]
    fn append_rejects_oversized_frame() {
        let mut producer = Group { sequence: 0 }.produce();
        let err = producer.create_frame(Frame {
            size: MAX_FRAME_SIZE + 1,
        });
        assert!(
            matches!(err, Err(Error::FrameTooLarge)),
            "a frame over the limit is rejected"
        );
        assert!(producer.create_frame(Frame { size: MAX_FRAME_SIZE }).is_ok());
    }

    // Random access by absolute index; an index past the end of a finished group is `None`.
    #[test]
    fn get_frame_by_index() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.size, 1);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, 2);
        let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
        assert!(f2.is_none());
    }

    // An empty, finished group ends immediately.
    #[test]
    fn group_finish_returns_none() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    // A consumer created before `abort` observes the abort error.
    #[test]
    fn abort_propagates() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        producer.abort(crate::Error::Cancel).unwrap();

        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Cancel)));
    }

    // `abort` releases cached frames and resets the byte accounting.
    #[test]
    fn abort_clears_cached_frames() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"data")).unwrap();

        // Keep a consumer alive while aborting.
        let _consumer = producer.consume();
        assert_eq!(producer.state.read().frames.len(), 1);

        producer.abort(crate::Error::Cancel).unwrap();

        let state = producer.state.read();
        assert!(state.frames.is_empty(), "cached frames should be dropped on abort");
        assert_eq!(state.cache, 0);
    }

    // Dropping every producer handle of an unfinished group discards its frames, and
    // readers get `Dropped`.
    #[test]
    fn drop_unfinished_clears_cached_frames() {
        let producer = Group { sequence: 0 }.produce();
        let mut writer = producer.clone();
        writer.write_frame(Bytes::from_static(b"data")).unwrap();

        let mut consumer = producer.consume();
        // The clone wrote into the state shared with `producer`.
        assert_eq!(producer.state.read().frames.len(), 1);

        drop(writer);
        // `producer` is now the last handle, so dropping it triggers the cleanup.
        drop(producer);

        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Dropped)));
    }

    // A finished group stays readable after its producer is dropped.
    #[test]
    fn drop_finished_keeps_cached_frames() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        drop(producer);

        let frame = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame, Bytes::from_static(b"data"));
    }

    // This is the only test that drives the pending/waiter path; it runs under a tokio runtime.
    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();

        // Nothing written and not finished: a single poll must stay pending.
        assert!(consumer.next_frame().now_or_never().is_none());

        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();

        let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame.size, 4);
    }

    // Byte-budget eviction: each frame alone fills `MAX_GROUP_CACHE`, so appending the
    // second frame evicts the first.
    #[test]
    fn eviction_drops_old_frames() {
        let mut producer = Group { sequence: 0 }.produce();

        let big = Bytes::from(vec![0u8; MAX_GROUP_CACHE as usize]);
        producer.write_frame(big.clone()).unwrap();
        producer.write_frame(big).unwrap();

        let consumer = producer.consume();
        // Frame 0 was evicted.
        let result = consumer.get_frame(0).now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));

        // Frame 1 is still cached under its original absolute index.
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, MAX_GROUP_CACHE);
    }

    // Below both limits nothing is evicted.
    #[test]
    fn no_eviction_under_limit() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"small")).unwrap();
        producer.write_frame(Bytes::from_static(b"frames")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.size, 5);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, 6);
    }

    // Frame-count eviction: `MAX_GROUP_FRAMES + 1` frames push out exactly frame 0.
    #[test]
    fn eviction_by_frame_count() {
        let mut producer = Group { sequence: 0 }.produce();

        // `0..=MAX_GROUP_FRAMES` is one more frame than the cache may hold.
        for _ in 0..=MAX_GROUP_FRAMES {
            producer.write_frame(Bytes::from_static(b"x")).unwrap();
        }

        let consumer = producer.consume();
        // The oldest frame was evicted.
        let result = consumer.get_frame(0).now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));

        let f = consumer
            // The newest frame is still addressable by its absolute index.
            .get_frame(MAX_GROUP_FRAMES)
            .now_or_never()
            .unwrap()
            .unwrap()
            .unwrap();
        assert_eq!(f.size, 1);
    }

    // A sequential reader that falls behind eviction gets `CacheFull` rather than skipping ahead.
    #[test]
    fn next_frame_returns_cache_full_on_tombstone() {
        let mut producer = Group { sequence: 0 }.produce();

        let big = Bytes::from(vec![0u8; MAX_GROUP_CACHE as usize]);
        producer.write_frame(big.clone()).unwrap();
        producer.write_frame(big).unwrap();

        let mut consumer = producer.consume();
        let result = consumer.next_frame().now_or_never().unwrap();
        // Index 0 has been evicted.
        assert!(matches!(result, Err(crate::Error::CacheFull)));
    }

    // A cloned consumer starts at the original's current position and advances independently.
    #[test]
    fn clone_consumer_independent() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();

        let mut c1 = producer.consume();
        let _ = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();

        // `c1` is now at index 1, and so is the clone.
        let mut c2 = c1.clone();

        producer.write_frame(Bytes::from_static(b"b")).unwrap();
        producer.finish().unwrap();

        // `c2` sees only "b"; if it had started at index 0, the next call would return "b"
        // instead of the end of the group.
        let f = c2.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.size, 1);
        let end = c2.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }
}