1extern crate alloc;
2
3use alloc::collections::VecDeque;
4use alloc::vec::Vec;
5use parking_lot::RwLock;
6use std::sync::Arc;
7use tokio::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
8use tokio::sync::Notify;
9
/// A value that can be scheduled by [`FairQueue`] / the fair channel.
///
/// `Clone` is required because the channel falls back to cloning a value if it
/// cannot take sole ownership of the `Arc` the value was stored in.
pub trait FairGroup: Clone {
    /// The group this value belongs to.
    ///
    /// Values that share an id are queued together and the queue serves
    /// groups round-robin. `None` places the value in the ungrouped control
    /// queue, which is drained before any group.
    fn group_id(&self) -> Option<u32>;

    /// Optional size hint for the value.
    ///
    /// NOTE(review): not consulted by `FairQueue` in this file.
    fn get_size(&self) -> Option<usize>;
}
16
17pub struct FairQueue<V: FairGroup> {
20 ctrl_group: VecDeque<Arc<V>>,
21 groups: Vec<VecDeque<Arc<V>>>,
22 pointer: usize,
23 max_group_size: usize,
24}
25
26impl<V: FairGroup> FairQueue<V> {
27 pub fn new(max_group_size: usize) -> Self {
28 Self {
29 ctrl_group: VecDeque::new(),
30 groups: Vec::new(),
31 pointer: 0,
32 max_group_size,
33 }
34 }
35
36 pub fn can_insert(&self, value: &V) -> bool {
38 match value.group_id() {
39 None => true, Some(group_id) => {
41 if let Some(group) = self
42 .groups
43 .iter()
44 .find(|group| group.front().map(|v| v.group_id()) == Some(Some(group_id)))
45 {
46 let can = group.len() < self.max_group_size;
47 if !can {
48 tracing::error!("Cannot insert value into group: group is full");
49 }
50 can
51 } else {
52 true }
54 }
55 }
56 }
57
58 pub fn insert(&mut self, value: Arc<V>) -> bool {
61 match value.group_id() {
63 None => {
64 self.ctrl_group.push_back(value);
66 true
67 }
68 Some(group_id) => {
69 if let Some(group) = self
71 .groups
72 .iter_mut()
73 .find(|group| group.front().map(|v| v.group_id()) == Some(Some(group_id)))
74 {
75 if group.len() >= self.max_group_size {
76 return false; }
78 group.push_back(value);
79 } else {
80 let mut new_group = VecDeque::new();
81 new_group.push_back(value);
82 self.groups.push(new_group);
83 }
84 true
85 }
86 }
87 }
88
89 #[inline(always)]
90 pub fn pop(&mut self) -> Option<Arc<V>> {
91 if let Some(v) = self.ctrl_group.pop_front() {
92 return Some(v);
94 }
95 for _ in 0..self.groups.len() {
96 let pointer = self.pointer;
97 self.pointer = (pointer + 1) % self.groups.len();
99
100 let group = &mut self.groups[pointer];
101 let item = group.pop_front();
102
103 if item.is_some() {
104 if group.is_empty() {
105 self.groups.remove(pointer);
106 if pointer < self.groups.len() {
107 self.pointer = pointer;
108 } else {
109 self.pointer = 0;
110 }
111 }
112 return item;
113 }
114 }
115
116 None
117 }
118}
119
120struct ChannelState<T: FairGroup + 'static> {
122 queue: FairQueue<T>,
123 closed: bool,
124}
125
126impl<T: FairGroup + 'static> ChannelState<T> {
127 fn new(max_group_size: usize) -> Self {
128 Self {
129 queue: FairQueue::new(max_group_size),
130 closed: false,
131 }
132 }
133
134 fn can_insert(&self, value: &T) -> bool {
135 self.queue.can_insert(value)
136 }
137}
138
139pub struct FairSender<T: FairGroup + 'static> {
141 state: Arc<RwLock<ChannelState<T>>>,
142 notify_recv: Arc<Notify>,
143 notify_send: Arc<Notify>,
144}
145
146impl<T: FairGroup + 'static> Clone for FairSender<T> {
147 fn clone(&self) -> Self {
148 Self {
149 state: Arc::clone(&self.state),
150 notify_recv: Arc::clone(&self.notify_recv),
151 notify_send: Arc::clone(&self.notify_send),
152 }
153 }
154}
155
156impl<T: FairGroup + 'static> FairSender<T> {
157 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
159 let value_arc = Arc::new(value);
160
161 loop {
162 {
163 let mut state = self.state.write();
164 if state.closed {
165 return Err(SendError(
166 Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()),
167 ));
168 }
169
170 if state.can_insert(&value_arc) {
172 state.queue.insert(value_arc);
173 drop(state);
174 self.notify_recv.notify_waiters();
175 return Ok(());
176 }
177 }
178
179 self.notify_send.notified().await;
181 }
182 }
183
184 pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
186 let value_arc = Arc::new(value);
187
188 let mut state = self.state.write();
189 if state.closed {
190 return Err(TrySendError::Closed(
191 Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()),
192 ));
193 }
194
195 if !state.queue.can_insert(&value_arc) {
196 return Err(TrySendError::Full(
197 Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()),
198 ));
199 }
200
201 state.queue.insert(value_arc);
202 drop(state); self.notify_recv.notify_waiters();
204 Ok(())
205 }
206
207 pub async fn closed(&self) {
209 loop {
210 {
211 let state = self.state.read();
212 if state.closed {
213 return;
214 }
215 }
216
217 self.notify_send.notified().await;
219 }
220 }
221}
222
223pub struct FairReceiver<T: FairGroup + 'static> {
225 state: Arc<RwLock<ChannelState<T>>>,
226 notify_recv: Arc<Notify>,
227 notify_send: Arc<Notify>,
228}
229
230impl<T: FairGroup + 'static> FairReceiver<T> {
231 pub async fn recv(&mut self) -> Option<T> {
233 loop {
234 {
235 let mut state = self.state.write();
236 if let Some(value_arc) = state.queue.pop() {
237 drop(state);
238 self.notify_send.notify_waiters();
239 return Some(Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()));
240 }
241
242 if state.closed {
243 return None;
244 }
245 }
246
247 self.notify_recv.notified().await;
249 }
250 }
251
252 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
254 let mut state = self.state.write();
255
256 if let Some(value_arc) = state.queue.pop() {
257 drop(state); self.notify_send.notify_waiters();
259 return Ok(Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()));
260 }
261
262 if state.closed {
263 Err(TryRecvError::Disconnected)
264 } else {
265 Err(TryRecvError::Empty)
266 }
267 }
268
269 pub async fn close(&mut self) {
271 let mut state = self.state.write();
272 state.closed = true;
273 drop(state); self.notify_send.notify_waiters();
275 }
276}
277
278impl<T: FairGroup + 'static> Drop for FairReceiver<T> {
279 fn drop(&mut self) {
280 if let Some(mut state) = self.state.try_write() {
282 state.closed = true;
283 drop(state); self.notify_send.notify_waiters();
285 }
286 }
287}
288
289pub fn fair_channel<T: FairGroup + 'static>(
291 max_group_size: usize,
292) -> (FairSender<T>, FairReceiver<T>) {
293 let state = Arc::new(RwLock::new(ChannelState::new(max_group_size)));
294 let notify_recv = Arc::new(Notify::new());
295 let notify_send = Arc::new(Notify::new());
296
297 let sender = FairSender {
298 state: Arc::clone(&state),
299 notify_recv: Arc::clone(¬ify_recv),
300 notify_send: Arc::clone(¬ify_send),
301 };
302
303 let receiver = FairReceiver {
304 state,
305 notify_recv,
306 notify_send,
307 };
308
309 (sender, receiver)
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Test payload: `user_id` selects the fairness group and `timestamp`
    /// records the order in which the event was produced.
    #[derive(Debug, PartialEq, Clone)]
    struct Event {
        timestamp: u32,
        user_id: &'static str,
    }

    impl FairGroup for Event {
        fn group_id(&self) -> Option<u32> {
            // Every user maps to some group, so no test event uses the
            // ungrouped control queue.
            match self.user_id {
                "user1" => Some(1),
                "user2" => Some(2),
                "user3" => Some(3),
                _ => Some(0),
            }
        }

        fn get_size(&self) -> Option<usize> {
            None
        }
    }

    /// Shorthand constructor for an [`Event`].
    fn event(timestamp: u32, user_id: &'static str) -> Event {
        Event { timestamp, user_id }
    }

    #[test]
    fn test_spaced_fairness() {
        let events = vec![
            event(1, "user1"),
            event(2, "user2"),
            event(3, "user1"),
            event(4, "user3"),
            event(5, "user2"),
            event(6, "user1"),
            event(7, "user1"),
            event(8, "user3"),
        ];

        let mut queue = FairQueue::new(usize::MAX);
        for e in events {
            assert!(queue.insert(Arc::new(e)));
        }

        let mut results = Vec::new();
        while let Some(e) = queue.pop() {
            results.push(e);
        }

        assert_eq!(results.len(), 8);

        let count_for = |user: &str| results.iter().filter(|e| e.user_id == user).count();
        assert_eq!(count_for("user1"), 4);
        assert_eq!(count_for("user2"), 2);
        assert_eq!(count_for("user3"), 2);

        // Per-user counts alone would also hold for a plain FIFO, so pin the
        // round-robin order. Groups are visited in order of first appearance
        // (user1, user2, user3), one event each, and a group leaves the
        // rotation once it is drained.
        let order: Vec<u32> = results.iter().map(|e| e.timestamp).collect();
        assert_eq!(order, vec![1, 2, 4, 3, 5, 8, 6, 7]);
    }

    #[tokio::test]
    async fn test_fair_channel_basic() {
        let (tx, mut rx) = fair_channel(5);

        tx.send(event(1, "user1")).await.unwrap();
        tx.send(event(2, "user2")).await.unwrap();

        let received1 = rx.recv().await.unwrap();
        let received2 = rx.recv().await.unwrap();

        assert_eq!(received1.timestamp, 1);
        assert_eq!(received2.timestamp, 2);
    }

    #[tokio::test]
    async fn test_fair_channel_fairness() {
        let (tx, mut rx) = fair_channel(5);

        for i in 0..6 {
            let user_id = match i % 3 {
                0 => "user1",
                1 => "user2",
                _ => "user3",
            };
            tx.send(event(i, user_id)).await.unwrap();
        }

        let mut received = Vec::new();
        for _ in 0..6 {
            received.push(rx.recv().await.unwrap());
        }

        let count_for = |user: &str| received.iter().filter(|e| e.user_id == user).count();
        assert_eq!(count_for("user1"), 2);
        assert_eq!(count_for("user2"), 2);
        assert_eq!(count_for("user3"), 2);

        // The events were produced already interleaved across the three users,
        // so round-robin delivery reproduces the production order exactly.
        let timestamps: Vec<u32> = received.iter().map(|e| e.timestamp).collect();
        assert_eq!(timestamps, vec![0, 1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn test_fair_channel_closed_method() {
        let (tx, mut rx) = fair_channel(5);

        let tx_clone = tx.clone();
        let closed_task = tokio::spawn(async move {
            tx_clone.closed().await;
        });

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        assert!(!closed_task.is_finished());

        rx.close().await;

        closed_task.await.unwrap();

        let result = tx.send(event(1, "user1")).await;
        assert!(matches!(result, Err(SendError(_))));
    }

    #[test]
    fn test_try_send_rejects_full_group() {
        let (tx, mut rx) = fair_channel(1);

        tx.try_send(event(1, "user1")).unwrap();

        // user1's group is at capacity; the rejected value is handed back.
        match tx.try_send(event(2, "user1")) {
            Err(TrySendError::Full(rejected)) => assert_eq!(rejected.timestamp, 2),
            other => panic!("expected Full, got {:?}", other),
        }

        // Other groups are unaffected.
        tx.try_send(event(3, "user2")).unwrap();

        // Draining user1 frees its slot again.
        assert_eq!(rx.try_recv().unwrap().timestamp, 1);
        tx.try_send(event(2, "user1")).unwrap();
    }

    #[test]
    fn test_dropping_receiver_closes_channel() {
        let (tx, rx) = fair_channel::<Event>(5);
        drop(rx);
        assert!(matches!(
            tx.try_send(event(1, "user1")),
            Err(TrySendError::Closed(_))
        ));
    }

    #[tokio::test]
    async fn test_recv_wakes_on_later_send() {
        let (tx, mut rx) = fair_channel::<Event>(5);

        // Park the receiver on an empty channel before anything is sent.
        let receiver = tokio::spawn(async move { rx.recv().await });
        tokio::task::yield_now().await;

        tx.send(event(1, "user1")).await.unwrap();

        let received = receiver
            .await
            .unwrap()
            .expect("channel should still be open");
        assert_eq!(received.timestamp, 1);
    }
}