cloudpub_common/fair_channel.rs

extern crate alloc;

use alloc::collections::VecDeque;
use alloc::vec::Vec;
use parking_lot::RwLock;
use std::sync::Arc;
use tokio::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
use tokio::sync::Notify;

/// Trait defining the grouping logic used for fair queuing.
/// Groups are keyed by `group_id()`; `None` denotes the control group,
/// which bypasses fairness and is always delivered first.
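///
/// A minimal sketch of an implementation, assuming a hypothetical `Frame`
/// type whose `channel` field identifies its group:
///
/// ```ignore
/// #[derive(Clone)]
/// struct Frame {
///     channel: Option<u32>,
///     payload: Vec<u8>,
/// }
///
/// impl FairGroup for Frame {
///     fn group_id(&self) -> Option<u32> {
///         self.channel
///     }
///
///     fn get_size(&self) -> Option<usize> {
///         Some(self.payload.len())
///     }
/// }
/// ```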
pub trait FairGroup: Clone {
    fn group_id(&self) -> Option<u32>;
    fn get_size(&self) -> Option<usize>;
}

/// Fair queue that interleaves items from different groups: FIFO within each
/// group, round-robin across groups, so consecutive items from the same
/// group are spaced as far apart as possible. Control items
/// (`group_id() == None`) always take priority over group traffic.
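///
/// For example, inserting items tagged with groups A, A, B (in that order)
/// pops as A, B, A: after the first A the rotation moves on to B before
/// returning to the second A.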
pub struct FairQueue<V: FairGroup> {
    /// Control items, always popped first
    ctrl_group: VecDeque<Arc<V>>,
    /// One FIFO per active group, identified by its front element's group_id
    groups: Vec<VecDeque<Arc<V>>>,
    /// Round-robin cursor into `groups`
    pointer: usize,
    /// Maximum number of queued items per group
    max_group_size: usize,
}

impl<V: FairGroup> FairQueue<V> {
    pub fn new(max_group_size: usize) -> Self {
        Self {
            ctrl_group: VecDeque::new(),
            groups: Vec::new(),
            pointer: 0,
            max_group_size,
        }
    }

    /// Check if a value can be inserted without exceeding group limits
    pub fn can_insert(&self, value: &V) -> bool {
        match value.group_id() {
            None => true, // Control group can always be inserted
            Some(group_id) => {
                if let Some(group) = self
                    .groups
                    .iter()
                    .find(|group| group.front().map(|v| v.group_id()) == Some(Some(group_id)))
                {
                    let can = group.len() < self.max_group_size;
                    if !can {
                        tracing::error!("Cannot insert value into group: group is full");
                    }
                    can
                } else {
                    true // New group can always be created
                }
            }
        }
    }

    /// Inserts a new item into the queue, keeping items of the same group
    /// spaced apart. Returns true if inserted successfully, false if the
    /// group is full.
    pub fn insert(&mut self, value: Arc<V>) -> bool {
        match value.group_id() {
            None => {
                // Control group (group_id is None)
                self.ctrl_group.push_back(value);
                true
            }
            Some(group_id) => {
                // Regular group: a group is identified by the group_id of its
                // front element, so empty groups are never kept around.
                if let Some(group) = self
                    .groups
                    .iter_mut()
                    .find(|group| group.front().map(|v| v.group_id()) == Some(Some(group_id)))
                {
                    if group.len() >= self.max_group_size {
                        return false; // Group is full
                    }
                    group.push_back(value);
                } else {
                    let mut new_group = VecDeque::new();
                    new_group.push_back(value);
                    self.groups.push(new_group);
                }
                true
            }
        }
    }

    #[inline(always)]
    pub fn pop(&mut self) -> Option<Arc<V>> {
        // Control items always take priority over group traffic
        if let Some(v) = self.ctrl_group.pop_front() {
            return Some(v);
        }
        for _ in 0..self.groups.len() {
            let pointer = self.pointer;
            // Optimistically move queue pointer to the next group
            self.pointer = (pointer + 1) % self.groups.len();

            let group = &mut self.groups[pointer];
            let item = group.pop_front();

            if item.is_some() {
                if group.is_empty() {
                    // Drop the emptied group. Removal shifts the following
                    // groups down by one, so the next group in the rotation
                    // is now at `pointer` (or at 0 if the tail was removed).
                    self.groups.remove(pointer);
                    if pointer < self.groups.len() {
                        self.pointer = pointer;
                    } else {
                        self.pointer = 0;
                    }
                }
                return item;
            }
        }

        None
    }
}

/// Shared state between sender and receiver
struct ChannelState<T: FairGroup + 'static> {
    queue: FairQueue<T>,
    closed: bool,
}

impl<T: FairGroup + 'static> ChannelState<T> {
    fn new(max_group_size: usize) -> Self {
        Self {
            queue: FairQueue::new(max_group_size),
            closed: false,
        }
    }

    fn can_insert(&self, value: &T) -> bool {
        self.queue.can_insert(value)
    }
}

/// Sender half of the fair channel
pub struct FairSender<T: FairGroup + 'static> {
    state: Arc<RwLock<ChannelState<T>>>,
    notify_recv: Arc<Notify>,
    notify_send: Arc<Notify>,
}

impl<T: FairGroup + 'static> Clone for FairSender<T> {
    fn clone(&self) -> Self {
        Self {
            state: Arc::clone(&self.state),
            notify_recv: Arc::clone(&self.notify_recv),
            notify_send: Arc::clone(&self.notify_send),
        }
    }
}

impl<T: FairGroup + 'static> FairSender<T> {
    /// Send a value, waiting while the destination group is full.
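    /// Backpressure is per group: `send` only waits while the value's own
    /// group is at `max_group_size`; control values (`group_id() == None`)
    /// are never blocked.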
    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
        let value_arc = Arc::new(value);

        loop {
            // Create the `Notified` future *before* checking the queue:
            // tokio guarantees that `notify_waiters` wakes futures that
            // already exist, so a notification fired between releasing the
            // lock and awaiting below is not lost.
            let notified = self.notify_send.notified();

            {
                let mut state = self.state.write();
                if state.closed {
                    return Err(SendError(
                        Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()),
                    ));
                }

                // Check group capacity
                if state.can_insert(&value_arc) {
                    state.queue.insert(value_arc);
                    drop(state); // Release lock before notifying
                    self.notify_recv.notify_waiters();
                    return Ok(());
                }
            }

            // Wait for space to become available in the group
            notified.await;
        }
    }

    /// Try to send a value without waiting.
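    /// Returns `TrySendError::Full` when the value's group is at capacity,
    /// and `TrySendError::Closed` when the receiver is closed or dropped.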
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        let value_arc = Arc::new(value);

        let mut state = self.state.write();
        if state.closed {
            return Err(TrySendError::Closed(
                Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()),
            ));
        }

        if !state.queue.can_insert(&value_arc) {
            return Err(TrySendError::Full(
                Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()),
            ));
        }

        state.queue.insert(value_arc);
        drop(state); // Release lock before notifying
        self.notify_recv.notify_waiters();
        Ok(())
    }

    /// Wait until the channel is closed (the receiver called `close`
    /// or was dropped)
    pub async fn closed(&self) {
        loop {
            // Create the `Notified` future before checking the flag so a
            // close notification fired in between cannot be missed.
            let notified = self.notify_send.notified();

            {
                let state = self.state.read();
                if state.closed {
                    return;
                }
            }

            // Wait for the channel to be closed
            notified.await;
        }
    }
}

/// Receiver half of the fair channel
pub struct FairReceiver<T: FairGroup + 'static> {
    state: Arc<RwLock<ChannelState<T>>>,
    notify_recv: Arc<Notify>,
    notify_send: Arc<Notify>,
}

impl<T: FairGroup + 'static> FairReceiver<T> {
    /// Receive a value, waiting while the channel is empty.
    /// Returns `None` once the channel is closed and drained.
    pub async fn recv(&mut self) -> Option<T> {
        loop {
            // Register for the notification before checking the queue,
            // mirroring the lost-wakeup protection in `FairSender::send`.
            let notified = self.notify_recv.notified();

            {
                let mut state = self.state.write();
                if let Some(value_arc) = state.queue.pop() {
                    drop(state); // Release lock before notifying
                    self.notify_send.notify_waiters();
                    return Some(Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()));
                }

                if state.closed {
                    return None;
                }
            }

            // Wait for a value to become available
            notified.await;
        }
    }

    /// Try to receive a value without waiting
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let mut state = self.state.write();

        if let Some(value_arc) = state.queue.pop() {
            drop(state); // Release lock before notifying
            self.notify_send.notify_waiters();
            return Ok(Arc::try_unwrap(value_arc).unwrap_or_else(|arc| (*arc).clone()));
        }

        if state.closed {
            Err(TryRecvError::Disconnected)
        } else {
            Err(TryRecvError::Empty)
        }
    }

    /// Close the receiver, which will cause all senders to return errors
    pub async fn close(&mut self) {
        let mut state = self.state.write();
        state.closed = true;
        drop(state); // Release lock before notifying
        self.notify_send.notify_waiters();
    }
}

impl<T: FairGroup + 'static> Drop for FairReceiver<T> {
    fn drop(&mut self) {
        // Best-effort close on drop: `try_write` avoids blocking (or
        // deadlocking) inside `drop` if the lock is currently held.
        if let Some(mut state) = self.state.try_write() {
            state.closed = true;
            drop(state); // Release lock before notifying
            self.notify_send.notify_waiters();
        }
    }
}

/// Creates a new fair channel with the specified max group size
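///
/// A minimal usage sketch (`Event` stands in for any type implementing
/// [`FairGroup`]):
///
/// ```ignore
/// let (tx, mut rx) = fair_channel::<Event>(16);
/// tx.send(event).await?;
/// while let Some(event) = rx.recv().await {
///     // Events from different groups arrive interleaved.
/// }
/// ```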
pub fn fair_channel<T: FairGroup + 'static>(
    max_group_size: usize,
) -> (FairSender<T>, FairReceiver<T>) {
    let state = Arc::new(RwLock::new(ChannelState::new(max_group_size)));
    let notify_recv = Arc::new(Notify::new());
    let notify_send = Arc::new(Notify::new());

    let sender = FairSender {
        state: Arc::clone(&state),
        notify_recv: Arc::clone(&notify_recv),
        notify_send: Arc::clone(&notify_send),
    };

    let receiver = FairReceiver {
        state,
        notify_recv,
        notify_send,
    };

    (sender, receiver)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq, Clone)]
    struct Event {
        timestamp: u32,
        user_id: &'static str,
    }

    impl FairGroup for Event {
        fn group_id(&self) -> Option<u32> {
            // Map each user to a fixed group id
            match self.user_id {
                "user1" => Some(1),
                "user2" => Some(2),
                "user3" => Some(3),
                _ => Some(0), // Default group
            }
        }

        fn get_size(&self) -> Option<usize> {
            None // Not used in current implementation
        }
    }

    #[test]
    fn test_spaced_fairness() {
        let events: Vec<Arc<Event>> = [
            (1, "user1"),
            (2, "user2"),
            (3, "user1"),
            (4, "user3"),
            (5, "user2"),
            (6, "user1"),
            (7, "user1"),
            (8, "user3"),
        ]
        .into_iter()
        .map(|(timestamp, user_id)| Arc::new(Event { timestamp, user_id }))
        .collect();

        let mut queue = FairQueue::new(usize::MAX);
        for event in &events {
            queue.insert(event.clone());
        }

        // Round-robin across groups. After insertion the per-user queues are
        // user1=[1,3,6,7], user2=[2,5], user3=[4,8]; popping interleaves
        // them so that items of the same user are spaced apart.
        let mut results = Vec::new();
        while let Some(event) = queue.pop() {
            results.push(event);
        }

        // Verify we got all events
        assert_eq!(results.len(), 8);

        // Verify the deterministic round-robin order for this input
        let order: Vec<u32> = results.iter().map(|e| e.timestamp).collect();
        assert_eq!(order, vec![1, 2, 4, 3, 5, 8, 6, 7]);

        // Verify every group was fully drained
        let user1_events: Vec<_> = results.iter().filter(|e| e.user_id == "user1").collect();
        let user2_events: Vec<_> = results.iter().filter(|e| e.user_id == "user2").collect();
        let user3_events: Vec<_> = results.iter().filter(|e| e.user_id == "user3").collect();

        assert_eq!(user1_events.len(), 4);
        assert_eq!(user2_events.len(), 2);
        assert_eq!(user3_events.len(), 2);
    }

    #[tokio::test]
    async fn test_fair_channel_basic() {
        let (tx, mut rx) = fair_channel(5);

        let event1 = Event {
            timestamp: 1,
            user_id: "user1",
        };
        let event2 = Event {
            timestamp: 2,
            user_id: "user2",
        };

        tx.send(event1).await.unwrap();
        tx.send(event2).await.unwrap();

        let received1 = rx.recv().await.unwrap();
        let received2 = rx.recv().await.unwrap();

        assert_eq!(received1.timestamp, 1);
        assert_eq!(received2.timestamp, 2);
    }

    #[tokio::test]
    async fn test_fair_channel_fairness() {
        let (tx, mut rx) = fair_channel(5);

        // Send events from different users
        for i in 0..6 {
            let user_id = match i % 3 {
                0 => "user1",
                1 => "user2",
                _ => "user3",
            };
            let event = Event {
                timestamp: i,
                user_id,
            };
            tx.send(event).await.unwrap();
        }

        // Receive events and verify fair distribution
        let mut received = Vec::new();
        for _ in 0..6 {
            received.push(rx.recv().await.unwrap());
        }

        // Round-robin across groups. After sending: user1=[0,3],
        // user2=[1,4], user3=[2,5]; the groups are drained in rotation.
        let user1_count = received.iter().filter(|e| e.user_id == "user1").count();
        let user2_count = received.iter().filter(|e| e.user_id == "user2").count();
        let user3_count = received.iter().filter(|e| e.user_id == "user3").count();

        // Verify all users get fair representation
        assert_eq!(user1_count, 2);
        assert_eq!(user2_count, 2);
        assert_eq!(user3_count, 2);

        // Verify we received all expected timestamps
        let mut timestamps: Vec<_> = received.iter().map(|e| e.timestamp).collect();
        timestamps.sort();
        assert_eq!(timestamps, vec![0, 1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn test_fair_channel_closed_method() {
        let (tx, mut rx) = fair_channel(5);

        // Start a task that waits for the channel to be closed
        let tx_clone = tx.clone();
        let closed_task = tokio::spawn(async move {
            tx_clone.closed().await;
        });

        // Give the closed task a moment to start
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        // Verify the closed task is still running
        assert!(!closed_task.is_finished());

        // Close the receiver
        rx.close().await;

        // The closed task should now complete
        closed_task.await.unwrap();

        // Verify that sending now returns an error
        let result = tx
            .send(Event {
                timestamp: 1,
                user_id: "user1",
            })
            .await;
        assert!(matches!(result, Err(SendError(_))));
    }
}