1use std::collections::VecDeque;
11use std::task::{Poll, ready};
12
13use bytes::Bytes;
14
15use crate::{Error, Result};
16
17use super::{Frame, FrameConsumer, FrameProducer};
18
/// Maximum total bytes of frame data cached per group before old frames
/// are evicted from the front (32 MiB).
const MAX_GROUP_CACHE: u64 = 32 * 1024 * 1024;
/// Maximum number of frames cached per group before old frames are evicted.
const MAX_GROUP_FRAMES: usize = 1024;
24
/// Identifies a group of frames by its sequence number.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
    /// The sequence number of this group.
    pub sequence: u64,
}
33
34impl Group {
35 pub fn produce(self) -> GroupProducer {
36 GroupProducer::new(self)
37 }
38}
39
/// Convert a `usize` sequence number into a [`Group`].
impl From<usize> for Group {
    fn from(sequence: usize) -> Self {
        Self {
            // Lossless on targets where usize <= 64 bits.
            sequence: sequence as u64,
        }
    }
}
47
/// Convert a `u64` sequence number into a [`Group`].
impl From<u64> for Group {
    fn from(sequence: u64) -> Self {
        Self { sequence }
    }
}
53
/// Convert a `u32` sequence number into a [`Group`] (lossless widening).
impl From<u32> for Group {
    fn from(sequence: u32) -> Self {
        Self {
            sequence: sequence as u64,
        }
    }
}
61
/// Convert a `u16` sequence number into a [`Group`] (lossless widening).
impl From<u16> for Group {
    fn from(sequence: u16) -> Self {
        Self {
            sequence: sequence as u64,
        }
    }
}
69
/// Shared state between a [`GroupProducer`] and its [`GroupConsumer`]s.
#[derive(Default)]
struct GroupState {
    /// Frames still cached, in order; slot `i` holds absolute frame
    /// index `offset + i`.
    frames: VecDeque<FrameProducer>,

    /// Number of frames evicted from the front of `frames` so far.
    offset: usize,

    /// Total size in bytes of the cached frames (sum of `frame.info.size`).
    cache: u64,

    /// Set once the producer finished the group; no more frames follow.
    fin: bool,

    /// Set when the producer aborted the group with an error.
    abort: Option<Error>,
}
88
89impl GroupState {
90 fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
91 if index < self.offset {
92 Poll::Ready(Err(Error::CacheFull))
93 } else if let Some(frame) = self.frames.get(index - self.offset) {
94 Poll::Ready(Ok(Some(frame.consume())))
95 } else if self.fin {
96 Poll::Ready(Ok(None))
97 } else if let Some(err) = &self.abort {
98 Poll::Ready(Err(err.clone()))
99 } else {
100 Poll::Pending
101 }
102 }
103
104 fn poll_finished(&self) -> Poll<Result<u64>> {
105 if self.fin {
106 Poll::Ready(Ok((self.offset + self.frames.len()) as u64))
107 } else if let Some(err) = &self.abort {
108 Poll::Ready(Err(err.clone()))
109 } else {
110 Poll::Pending
111 }
112 }
113
114 fn evict(&mut self) {
116 while self.cache > MAX_GROUP_CACHE || self.frames.len() > MAX_GROUP_FRAMES {
117 let Some(frame) = self.frames.pop_front() else {
118 break;
119 };
120 self.cache -= frame.info.size;
121 self.offset += 1;
122 }
123 }
124}
125
126fn modify(state: &conducer::Producer<GroupState>) -> Result<conducer::Mut<'_, GroupState>> {
127 state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
128}
129
/// Produces the frames of a single group.
pub struct GroupProducer {
    /// State shared with every consumer of this group.
    state: conducer::Producer<GroupState>,

    /// The immutable group metadata (sequence number).
    pub info: Group,
}
142
143impl GroupProducer {
144 pub fn new(info: Group) -> Self {
146 Self {
147 info,
148 state: conducer::Producer::default(),
149 }
150 }
151
152 pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
157 let data = frame.into();
158 let frame = Frame {
159 size: data.len() as u64,
160 };
161 let mut frame = self.create_frame(frame)?;
162 frame.write(data)?;
163 frame.finish()?;
164 Ok(())
165 }
166
167 pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
169 let frame = info.produce();
170 self.append_frame(frame.clone())?;
171 Ok(frame)
172 }
173
174 pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
176 let mut state = modify(&self.state)?;
177 if state.fin {
178 return Err(Error::Closed);
179 }
180 state.cache += frame.info.size;
181 state.frames.push_back(frame);
182 state.evict();
183 Ok(())
184 }
185
186 pub fn frame_count(&self) -> usize {
188 let state = self.state.read();
189 state.offset + state.frames.len()
190 }
191
192 pub fn finish(&mut self) -> Result<()> {
194 let mut state = modify(&self.state)?;
195 state.fin = true;
196 Ok(())
197 }
198
199 pub fn abort(&mut self, err: Error) -> Result<()> {
203 let mut guard = modify(&self.state)?;
204
205 for frame in guard.frames.iter_mut() {
207 frame.abort(err.clone()).ok();
209 }
210
211 guard.abort = Some(err);
212 guard.close();
213 Ok(())
214 }
215
216 pub fn consume(&self) -> GroupConsumer {
218 GroupConsumer {
219 info: self.info.clone(),
220 state: self.state.consume(),
221 index: 0,
222 }
223 }
224
225 pub async fn closed(&self) -> Error {
227 self.state.closed().await;
228 self.state.read().abort.clone().unwrap_or(Error::Dropped)
229 }
230
231 pub async fn unused(&self) -> Result<()> {
233 self.state
234 .unused()
235 .await
236 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
237 }
238}
239
impl Clone for GroupProducer {
    // A cloned producer shares the same underlying group state: only the
    // `conducer::Producer` handle is cloned, not the data (writes via one
    // clone are visible to consumers of the other — see `create_frame`).
    fn clone(&self) -> Self {
        Self {
            info: self.info.clone(),
            state: self.state.clone(),
        }
    }
}
248
249impl From<Group> for GroupProducer {
250 fn from(info: Group) -> Self {
251 GroupProducer::new(info)
252 }
253}
254
/// Reads the frames of a single group in order.
///
/// Cloning copies the read position (`index`), so clones advance
/// independently over the shared state.
#[derive(Clone)]
pub struct GroupConsumer {
    /// State shared with the producer.
    state: conducer::Consumer<GroupState>,

    /// The immutable group metadata (sequence number).
    pub info: Group,

    /// Absolute index of the next frame returned by `next_frame`/`read_frame`.
    index: usize,
}
268
269impl GroupConsumer {
270 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
272 where
273 F: Fn(&conducer::Ref<'_, GroupState>) -> Poll<Result<R>>,
274 {
275 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
276 Ok(res) => res,
277 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
279 })
280 }
281
282 pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
286 conducer::wait(|waiter| self.poll_get_frame(waiter, index)).await
287 }
288
289 pub fn poll_get_frame(&self, waiter: &conducer::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
293 self.poll(waiter, |state| state.poll_get_frame(index))
294 }
295
296 pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
298 conducer::wait(|waiter| self.poll_next_frame(waiter)).await
299 }
300
301 pub fn poll_next_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
305 let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
306 return Poll::Ready(Ok(None));
307 };
308
309 self.index += 1;
310 Poll::Ready(Ok(Some(frame)))
311 }
312
313 pub fn poll_read_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
315 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
316 return Poll::Ready(Ok(None));
317 };
318
319 let data = ready!(frame.poll_read_all(waiter))?;
320 self.index += 1;
321
322 Poll::Ready(Ok(Some(data)))
323 }
324
325 pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
327 conducer::wait(|waiter| self.poll_read_frame(waiter)).await
328 }
329
330 pub fn poll_read_frame_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
332 let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
333 return Poll::Ready(Ok(None));
334 };
335
336 let data = ready!(frame.poll_read_all_chunks(waiter))?;
337 self.index += 1;
338
339 Poll::Ready(Ok(Some(data)))
340 }
341
342 pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
344 conducer::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
345 }
346
347 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
349 self.poll(waiter, |state| state.poll_finished())
350 }
351
352 pub async fn finished(&mut self) -> Result<u64> {
354 conducer::wait(|waiter| self.poll_finished(waiter)).await
355 }
356}
357
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    // Frames are delivered in write order; a finished group then yields None.
    #[test]
    fn basic_frame_reading() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
        producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 6);
        let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 6);
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    // read_frame returns the whole payload as one contiguous Bytes.
    #[test]
    fn read_frame_all_at_once() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    // read_frame_chunks preserves the individual write() chunks.
    #[test]
    fn read_frame_chunks() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
        frame.write(Bytes::from_static(b"hello")).unwrap();
        frame.write(Bytes::from_static(b"world")).unwrap();
        frame.finish().unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from_static(b"hello"));
        assert_eq!(chunks[1], Bytes::from_static(b"world"));
    }

    // Random access by absolute index; past-the-end returns None after finish.
    #[test]
    fn get_frame_by_index() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();
        producer.write_frame(Bytes::from_static(b"bb")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 1);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 2);
        let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
        assert!(f2.is_none());
    }

    // An empty, finished group immediately yields None.
    #[test]
    fn group_finish_returns_none() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let end = consumer.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }

    // abort() surfaces the abort error to consumers waiting on next_frame.
    #[test]
    fn abort_propagates() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();
        producer.abort(crate::Error::Cancel).unwrap();

        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::Cancel)));
    }

    // next_frame is Pending until a frame arrives, then Ready.
    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Group { sequence: 0 }.produce();
        let mut consumer = producer.consume();

        assert!(consumer.next_frame().now_or_never().is_none());

        producer.write_frame(Bytes::from_static(b"data")).unwrap();
        producer.finish().unwrap();

        let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(frame.info.size, 4);
    }

    // Exceeding MAX_GROUP_CACHE evicts the oldest frame; reading it yields
    // CacheFull while newer frames remain readable.
    #[test]
    fn eviction_drops_old_frames() {
        let mut producer = Group { sequence: 0 }.produce();

        let big = Bytes::from(vec![0u8; MAX_GROUP_CACHE as usize]);
        producer.write_frame(big.clone()).unwrap();
        producer.write_frame(big).unwrap();

        let consumer = producer.consume();
        let result = consumer.get_frame(0).now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));

        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, MAX_GROUP_CACHE);
    }

    // Staying within the limits never evicts anything.
    #[test]
    fn no_eviction_under_limit() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"small")).unwrap();
        producer.write_frame(Bytes::from_static(b"frames")).unwrap();
        producer.finish().unwrap();

        let consumer = producer.consume();
        let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f0.info.size, 5);
        let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f1.info.size, 6);
    }

    // Exceeding MAX_GROUP_FRAMES (by count, not bytes) also evicts.
    #[test]
    fn eviction_by_frame_count() {
        let mut producer = Group { sequence: 0 }.produce();

        // Write MAX_GROUP_FRAMES + 1 frames so exactly one is evicted.
        for _ in 0..=MAX_GROUP_FRAMES {
            producer.write_frame(Bytes::from_static(b"x")).unwrap();
        }

        let consumer = producer.consume();
        let result = consumer.get_frame(0).now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));

        let f = consumer
            .get_frame(MAX_GROUP_FRAMES)
            .now_or_never()
            .unwrap()
            .unwrap()
            .unwrap();
        assert_eq!(f.info.size, 1);
    }

    // A sequential reader whose next frame was evicted gets CacheFull too.
    #[test]
    fn next_frame_returns_cache_full_on_tombstone() {
        let mut producer = Group { sequence: 0 }.produce();

        let big = Bytes::from(vec![0u8; MAX_GROUP_CACHE as usize]);
        producer.write_frame(big.clone()).unwrap();
        producer.write_frame(big).unwrap();

        let mut consumer = producer.consume();
        let result = consumer.next_frame().now_or_never().unwrap();
        assert!(matches!(result, Err(crate::Error::CacheFull)));
    }

    // Cloning a consumer copies its read position; the clone then advances
    // independently of the original.
    #[test]
    fn clone_consumer_independent() {
        let mut producer = Group { sequence: 0 }.produce();
        producer.write_frame(Bytes::from_static(b"a")).unwrap();

        let mut c1 = producer.consume();
        let _ = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();

        // c2 starts at c1's current position (index 1).
        let mut c2 = c1.clone();

        producer.write_frame(Bytes::from_static(b"b")).unwrap();
        producer.finish().unwrap();

        let f = c2.next_frame().now_or_never().unwrap().unwrap().unwrap();
        assert_eq!(f.info.size, 1);
        let end = c2.next_frame().now_or_never().unwrap().unwrap();
        assert!(end.is_none());
    }
}