1use std::collections::VecDeque;
11use std::task::{Poll, ready};
12
13use bytes::Bytes;
14
15use crate::{Error, Result};
16
17use super::{Frame, FrameConsumer, FrameProducer};
18
/// Maximum number of payload bytes cached per group before old frames are evicted.
const MAX_GROUP_CACHE: u64 = 32 * 1024 * 1024;

/// Maximum number of frames cached per group before old frames are evicted.
const MAX_GROUP_FRAMES: usize = 1024;
24
/// Identifies a single group by its sequence number.
///
/// The derived `Ord`/`PartialOrd`/`Hash` all operate on `sequence`,
/// so groups compare and sort by their position in the sequence.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
    pub sequence: u64,
}
35
36impl Group {
37 pub fn produce(self) -> GroupProducer {
39 GroupProducer::new(self)
40 }
41}
42
43impl From<usize> for Group {
44 fn from(sequence: usize) -> Self {
45 Self {
46 sequence: sequence as u64,
47 }
48 }
49}
50
51impl From<u64> for Group {
52 fn from(sequence: u64) -> Self {
53 Self { sequence }
54 }
55}
56
57impl From<u32> for Group {
58 fn from(sequence: u32) -> Self {
59 Self {
60 sequence: sequence as u64,
61 }
62 }
63}
64
65impl From<u16> for Group {
66 fn from(sequence: u16) -> Self {
67 Self {
68 sequence: sequence as u64,
69 }
70 }
71}
72
/// State shared between a group's producer and all of its consumers.
#[derive(Default)]
struct GroupState {
    // Frames still held in the cache, oldest first; slot 0 holds frame `offset`.
    frames: VecDeque<FrameProducer>,

    // Number of frames already evicted from the front of `frames`.
    offset: usize,

    // Total byte size of the cached frames (sum of each frame's `size`).
    cache: u64,

    // Set once the producer finishes; no further frames will be appended.
    fin: bool,

    // Set when the producer aborts; surfaced to consumers as the error.
    abort: Option<Error>,
}
91
92impl GroupState {
93 fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
94 if index < self.offset {
95 Poll::Ready(Err(Error::CacheFull))
96 } else if let Some(frame) = self.frames.get(index - self.offset) {
97 Poll::Ready(Ok(Some(frame.consume())))
98 } else if self.fin {
99 Poll::Ready(Ok(None))
100 } else if let Some(err) = &self.abort {
101 Poll::Ready(Err(err.clone()))
102 } else {
103 Poll::Pending
104 }
105 }
106
107 fn poll_finished(&self) -> Poll<Result<u64>> {
108 if self.fin {
109 Poll::Ready(Ok((self.offset + self.frames.len()) as u64))
110 } else if let Some(err) = &self.abort {
111 Poll::Ready(Err(err.clone()))
112 } else {
113 Poll::Pending
114 }
115 }
116
117 fn evict(&mut self) {
119 while self.cache > MAX_GROUP_CACHE || self.frames.len() > MAX_GROUP_FRAMES {
120 let Some(frame) = self.frames.pop_front() else {
121 break;
122 };
123 self.cache -= frame.size;
124 self.offset += 1;
125 }
126 }
127}
128
129fn modify(state: &conducer::Producer<GroupState>) -> Result<conducer::Mut<'_, GroupState>> {
130 state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
131}
132
/// Writes frames into a group; clones share the same underlying state.
pub struct GroupProducer {
    // Shared state observed by `GroupConsumer`s.
    state: conducer::Producer<GroupState>,

    // Immutable group metadata (exposed to callers via `Deref`).
    info: Group,
}
145
// Expose the group metadata (e.g. `sequence`) directly on the producer.
impl std::ops::Deref for GroupProducer {
    type Target = Group;

    fn deref(&self) -> &Self::Target {
        &self.info
    }
}
153
154impl GroupProducer {
155 pub fn new(info: Group) -> Self {
157 Self {
158 info,
159 state: conducer::Producer::default(),
160 }
161 }
162
163 pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
168 let data = frame.into();
169 let frame = Frame {
170 size: data.len() as u64,
171 };
172 let mut frame = self.create_frame(frame)?;
173 frame.write(data)?;
174 frame.finish()?;
175 Ok(())
176 }
177
178 pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
180 let frame = info.produce();
181 self.append_frame(frame.clone())?;
182 Ok(frame)
183 }
184
185 pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
187 let mut state = modify(&self.state)?;
188 if state.fin {
189 return Err(Error::Closed);
190 }
191 state.cache += frame.size;
192 state.frames.push_back(frame);
193 state.evict();
194 Ok(())
195 }
196
197 pub fn frame_count(&self) -> usize {
199 let state = self.state.read();
200 state.offset + state.frames.len()
201 }
202
203 pub fn finish(&mut self) -> Result<()> {
205 let mut state = modify(&self.state)?;
206 state.fin = true;
207 Ok(())
208 }
209
210 pub fn abort(&mut self, err: Error) -> Result<()> {
216 let mut guard = modify(&self.state)?;
217 guard.abort = Some(err);
218 guard.close();
219 Ok(())
220 }
221
222 pub fn consume(&self) -> GroupConsumer {
224 GroupConsumer {
225 info: self.info.clone(),
226 state: self.state.consume(),
227 index: 0,
228 }
229 }
230
231 pub async fn closed(&self) -> Error {
233 self.state.closed().await;
234 self.state.read().abort.clone().unwrap_or(Error::Dropped)
235 }
236
237 pub async fn unused(&self) -> Result<()> {
239 self.state
240 .unused()
241 .await
242 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
243 }
244}
245
246impl Clone for GroupProducer {
247 fn clone(&self) -> Self {
248 Self {
249 info: self.info.clone(),
250 state: self.state.clone(),
251 }
252 }
253}
254
255impl From<Group> for GroupProducer {
256 fn from(info: Group) -> Self {
257 GroupProducer::new(info)
258 }
259}
260
/// Reads frames from a group; clones share the producer's state but each
/// clone keeps its own read position.
#[derive(Clone)]
pub struct GroupConsumer {
    // Shared state written by the `GroupProducer`.
    state: conducer::Consumer<GroupState>,

    // Immutable group metadata (exposed to callers via `Deref`).
    info: Group,

    // Index of the next frame returned by `next_frame`/`read_frame*`.
    index: usize,
}
274
// Expose the group metadata (e.g. `sequence`) directly on the consumer.
impl std::ops::Deref for GroupConsumer {
    type Target = Group;

    fn deref(&self) -> &Self::Target {
        &self.info
    }
}
282
283impl GroupConsumer {
284 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
286 where
287 F: Fn(&conducer::Ref<'_, GroupState>) -> Poll<Result<R>>,
288 {
289 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
290 Ok(res) => res,
291 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
293 })
294 }
295
296 pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
300 conducer::wait(|waiter| self.poll_get_frame(waiter, index)).await
301 }
302
303 pub fn poll_get_frame(&self, waiter: &conducer::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
307 self.poll(waiter, |state| state.poll_get_frame(index))
308 }
309
310 pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
312 conducer::wait(|waiter| self.poll_next_frame(waiter)).await
313 }
314
315 pub fn poll_next_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
319 let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
320 return Poll::Ready(Ok(None));
321 };
322
323 self.index += 1;
324 Poll::Ready(Ok(Some(frame)))
325 }
326
327 pub fn poll_read_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
329 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
330 return Poll::Ready(Ok(None));
331 };
332
333 let data = ready!(frame.poll_read_all(waiter))?;
334 self.index += 1;
335
336 Poll::Ready(Ok(Some(data)))
337 }
338
339 pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
341 conducer::wait(|waiter| self.poll_read_frame(waiter)).await
342 }
343
344 pub fn poll_read_frame_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
346 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
347 return Poll::Ready(Ok(None));
348 };
349
350 let data = ready!(frame.poll_read_all_chunks(waiter))?;
351 self.index += 1;
352
353 Poll::Ready(Ok(Some(data)))
354 }
355
356 pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
358 conducer::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
359 }
360
361 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
363 self.poll(waiter, |state| state.poll_finished())
364 }
365
366 pub async fn finished(&mut self) -> Result<u64> {
368 conducer::wait(|waiter| self.poll_finished(waiter)).await
369 }
370}
371
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    // A finished group yields its frames in order, then `None`.
    #[test]
    fn basic_frame_reading() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
        producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.size, 6);
        let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, 6);
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    // `read_frame` returns the whole payload as one `Bytes`.
    #[test]
    fn read_frame_all_at_once() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    // Two writes into one frame come back as a single chunk — presumably
    // the frame coalesces contiguous writes; confirm against FrameConsumer.
    #[test]
    fn read_frame_chunks() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
        frame.write(Bytes::from_static(b"hello")).unwrap();
        frame.write(Bytes::from_static(b"world")).unwrap();
        frame.finish().unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], Bytes::from_static(b"helloworld"));
    }

    // Random access by index works independently of the cursor;
    // an index past the end of a finished group yields `None`.
    #[test]
    fn get_frame_by_index() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.size, 1);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, 2);
        let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
        assert!(f2.is_none());
    }

    // An empty, finished group immediately yields `None`.
    #[test]
    fn group_finish_returns_none() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    // Aborting the producer surfaces the abort error to existing consumers.
    #[test]
    fn abort_propagates() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        producer.abort(crate::Error::Cancel).unwrap();

        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Cancel)));
    }

    // A consumer is pending while the group is empty, then resolves
    // once a frame is written.
    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();

        // Nothing written yet: the poll must stay pending.
        assert!(consumer.next_frame().now_or_never().is_none());

        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();

        let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame.size, 4);
    }

    // Exceeding MAX_GROUP_CACHE evicts the oldest frame, which then
    // reads back as `Error::CacheFull`.
    #[test]
    fn eviction_drops_old_frames() {
        let mut producer = Group { sequence: 0 }.produce();

        // Two frames of exactly MAX_GROUP_CACHE bytes: the second write
        // pushes the total over the limit and evicts the first.
        let big = Bytes::from(vec![0u8; MAX_GROUP_CACHE as usize]);
        producer.write_frame(big.clone()).unwrap();
        producer.write_frame(big).unwrap();

        let consumer = producer.consume();
        let result = consumer.get_frame(0).now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));

        // The newest frame survives eviction.
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, MAX_GROUP_CACHE);
    }

    // Frames under both limits are never evicted.
    #[test]
    fn no_eviction_under_limit() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"small")).unwrap();
        producer.write_frame(Bytes::from_static(b"frames")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.size, 5);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.size, 6);
    }

    // The frame-count limit triggers eviction too, independent of bytes.
    #[test]
    fn eviction_by_frame_count() {
        let mut producer = Group { sequence: 0 }.produce();

        // MAX_GROUP_FRAMES + 1 tiny frames: the oldest is evicted.
        for _ in 0..=MAX_GROUP_FRAMES {
            producer.write_frame(Bytes::from_static(b"x")).unwrap();
        }

        let consumer = producer.consume();
        let result = consumer.get_frame(0).now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));

        let f = consumer
            .get_frame(MAX_GROUP_FRAMES)
            .now_or_never()
            .unwrap()
            .unwrap()
            .unwrap();
        assert_eq!(f.size, 1);
    }

    // A sequential reader whose next frame was evicted also sees CacheFull.
    #[test]
    fn next_frame_returns_cache_full_on_tombstone() {
        let mut producer = Group { sequence: 0 }.produce();

        let big = Bytes::from(vec![0u8; MAX_GROUP_CACHE as usize]);
        producer.write_frame(big.clone()).unwrap();
        producer.write_frame(big).unwrap();

        let mut consumer = producer.consume();
        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));
    }

    // Cloning a consumer copies its read position; the clone then
    // advances independently of the original.
    #[test]
    fn clone_consumer_independent() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();

        let mut c1 = producer.consume();
        // Advance c1 past the first frame before cloning.
        let _ = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();

        let mut c2 = c1.clone();

        producer.write_frame(Bytes::from_static(b"b")).unwrap();
        producer.finish().unwrap();

        // c2 starts where c1 left off: at the second frame.
        let f = c2.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.size, 1);
        let end = c2.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }
}