rs_store/
channel.rs

1use crate::metrics::Metrics;
2use crate::ActionOp;
3use crossbeam::channel::{self, Receiver, Sender, TrySendError};
4use std::marker::PhantomData;
5use std::sync::Arc;
6
/// What a [`SenderChannel`] does when its bounded queue has no free slot.
#[derive(Default, Clone)]
pub enum BackpressurePolicy {
    /// Wait (block the sending thread) until a slot frees up. This is the default.
    #[default]
    BlockOnFull,
    /// Evict the item at the front of the queue to make room for the new one.
    DropOldest,
    /// Reject the new item and keep the queue unchanged.
    DropLatest,
}
18
19#[derive(thiserror::Error, Debug)]
20pub(crate) enum SenderError<T> {
21    #[error("Failed to send: {0}")]
22    SendError(T),
23    #[error("Failed to try_send: {0}")]
24    TrySendError(TrySendError<T>),
25}
26
27/// Channel to hold the sender with backpressure policy
28#[derive(Clone)]
29pub(crate) struct SenderChannel<T>
30where
31    T: Send + Sync + Clone + 'static,
32{
33    _name: String,
34    sender: Sender<ActionOp<T>>,
35    receiver: Receiver<ActionOp<T>>,
36    policy: BackpressurePolicy,
37    metrics: Option<Arc<dyn Metrics + Send + Sync>>,
38}
39
#[cfg(dev)]
impl<T> Drop for SenderChannel<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Dev-build trace: reports every sender handle (including each clone) as it is dropped.
    fn drop(&mut self) {
        let name = &self._name;
        eprintln!("store: drop '{}' sender channel", name);
    }
}
49
50impl<T> SenderChannel<T>
51where
52    T: Send + Sync + Clone + 'static,
53{
54    pub fn send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
55        let r = match self.policy {
56            BackpressurePolicy::BlockOnFull => {
57                match self.sender.send(item).map_err(|e| SenderError::SendError(e.0)) {
58                    Ok(_) => Ok(self.receiver.len() as i64),
59                    Err(e) => Err(e),
60                }
61            }
62            BackpressurePolicy::DropOldest => {
63                if let Err(TrySendError::Full(item)) = self.sender.try_send(item) {
64                    // Drop the oldest item and try sending again
65                    #[cfg(dev)]
66                    eprintln!("store: dropping the oldest item in channel");
67                    // Remove the oldest item
68                    let _old = self.receiver.try_recv();
69                    if let Some(metrics) = &self.metrics {
70                        if let Ok(ActionOp::Action(action)) = _old.as_ref() {
71                            metrics.action_dropped(Some(action));
72                        }
73                    }
74                    match self.sender.try_send(item).map_err(SenderError::TrySendError) {
75                        Ok(_) => Ok(self.receiver.len() as i64),
76                        Err(e) => Err(e),
77                    }
78                } else {
79                    Ok(0)
80                }
81            }
82            BackpressurePolicy::DropLatest => {
83                // Try to send the item, if the queue is full, just ignore the item (drop the latest)
84                match self.sender.try_send(item).map_err(SenderError::TrySendError) {
85                    Ok(_) => Ok(self.receiver.len() as i64),
86                    Err(err) => {
87                        #[cfg(dev)]
88                        eprintln!("store: dropping the latest item in channel");
89                        if let Some(metrics) = &self.metrics {
90                            if let SenderError::TrySendError(TrySendError::Full(
91                                ActionOp::Action(action_drop),
92                            )) = &err
93                            {
94                                metrics.action_dropped(Some(action_drop));
95                            }
96                        }
97                        Err(err)
98                    }
99                }
100            }
101        };
102
103        if let Some(metrics) = &self.metrics {
104            metrics.queue_size(self.receiver.len());
105        }
106        r
107    }
108}
109
110#[allow(dead_code)]
111pub(crate) struct ReceiverChannel<T>
112where
113    T: Send + Sync + Clone + 'static,
114{
115    name: String,
116    receiver: Receiver<ActionOp<T>>,
117    metrics: Option<Arc<dyn Metrics + Send + Sync>>,
118}
119
#[cfg(dev)]
impl<T> Drop for ReceiverChannel<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Dev-build trace: reports the receiver handle as it is dropped.
    fn drop(&mut self) {
        let name = &self.name;
        eprintln!("store: drop '{}' receiver channel", name);
    }
}
129
130impl<T> ReceiverChannel<T>
131where
132    T: Send + Sync + Clone + 'static,
133{
134    pub fn recv(&self) -> Option<ActionOp<T>> {
135        self.receiver.recv().ok()
136    }
137
138    #[allow(dead_code)]
139    pub fn try_recv(&self) -> Option<ActionOp<T>> {
140        self.receiver.try_recv().ok()
141    }
142}
143
/// Factory for a bounded channel whose sending half applies a [`BackpressurePolicy`].
///
/// Not constructed in this file; it only namespaces the `pair*` constructors.
pub(crate) struct BackpressureChannel<MSG>
where
    MSG: Send + Sync + Clone + 'static,
{
    phantom_data: PhantomData<MSG>,
}
151
152impl<MSG> BackpressureChannel<MSG>
153where
154    MSG: Send + Sync + Clone + 'static,
155{
156    #[allow(dead_code)]
157    pub fn pair(
158        capacity: usize,
159        policy: BackpressurePolicy,
160    ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
161        Self::pair_with("<anon>", capacity, policy, None)
162    }
163
164    #[allow(dead_code)]
165    pub fn pair_with_metrics(
166        capacity: usize,
167        policy: BackpressurePolicy,
168        metrics: Option<Arc<dyn Metrics + Send + Sync>>,
169    ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
170        Self::pair_with("<anon>", capacity, policy, metrics)
171    }
172
173    #[allow(dead_code)]
174    pub fn pair_with(
175        name: &str,
176        capacity: usize,
177        policy: BackpressurePolicy,
178        metrics: Option<Arc<dyn Metrics + Send + Sync>>,
179    ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
180        let (sender, receiver) = channel::bounded(capacity);
181        (
182            SenderChannel {
183                _name: name.to_string(),
184                sender,
185                receiver: receiver.clone(),
186                policy,
187                metrics: metrics.clone(),
188            },
189            ReceiverChannel {
190                name: name.to_string(),
191                receiver,
192                metrics: metrics.clone(),
193            },
194        )
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Drains everything currently queued, keeping only `ActionOp::Action` payloads.
    fn drain_actions(receiver: &ReceiverChannel<i32>) -> Vec<i32> {
        let mut items = Vec::new();
        while let Some(op) = receiver.try_recv() {
            if let ActionOp::Action(i) = op {
                items.push(i);
            }
        }
        items
    }

    /// Runs a fast producer (20 items, one every 50 ms) against a slow consumer (one item
    /// every 150 ms) over a 5-slot channel, and returns the payloads the consumer received.
    fn run_fast_producer_slow_consumer(policy: BackpressurePolicy) -> Vec<i32> {
        let (sender, receiver) = BackpressureChannel::<i32>::pair(5, policy);

        let producer = {
            let sender_channel = sender.clone();
            thread::spawn(move || {
                for i in 0..20 {
                    // Send more messages than the channel can hold
                    println!("Sending: {}", i);
                    if let Err(err) = sender_channel.send(ActionOp::Action(i)) {
                        eprintln!("Failed to send: {:?}", err);
                    }
                    thread::sleep(Duration::from_millis(50)); // Slow down to observe full condition
                }
            })
        };

        let consumer = thread::spawn(move || {
            let mut received_items = vec![];
            while let Some(value) = receiver.recv() {
                println!("Received: {:?}", value);
                if let ActionOp::Action(i) = value {
                    received_items.push(i);
                }
                thread::sleep(Duration::from_millis(150)); // Slow down the consumer to create a backlog
            }
            println!("Channel closed, consumer thread exiting.");
            assert!(receiver.try_recv().is_none());
            received_items
        });

        producer.join().unwrap();
        // Dropping the last sender disconnects the channel so the consumer loop ends.
        drop(sender);
        consumer.join().unwrap()
    }

    #[test]
    fn test_block_on_full_returns_queue_length() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(3, BackpressurePolicy::BlockOnFull);
        assert_eq!(sender.send(ActionOp::Action(7)).ok(), Some(1));
        assert_eq!(sender.send(ActionOp::Action(8)).ok(), Some(2));
        assert_eq!(drain_actions(&receiver), vec![7, 8]);
    }

    #[test]
    fn test_drop_oldest_keeps_newest_items() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropOldest);
        for i in 0..3 {
            assert!(sender.send(ActionOp::Action(i)).is_ok());
        }
        // Item 0 was evicted to make room for item 2.
        assert_eq!(drain_actions(&receiver), vec![1, 2]);
    }

    #[test]
    fn test_drop_latest_rejects_new_item_when_full() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropLatest);
        assert_eq!(sender.send(ActionOp::Action(0)).ok(), Some(1));
        assert_eq!(sender.send(ActionOp::Action(1)).ok(), Some(2));
        assert!(matches!(
            sender.send(ActionOp::Action(2)),
            Err(SenderError::TrySendError(TrySendError::Full(ActionOp::Action(2))))
        ));
        assert_eq!(drain_actions(&receiver), vec![0, 1]);
    }

    #[test]
    fn test_channel_backpressure_drop_old() {
        let received_items = run_fast_producer_slow_consumer(BackpressurePolicy::DropOldest);

        // Fewer than the 20 sent items arrive because of drops.
        assert!(received_items.len() < 20);
        // DropOldest never discards the newest item.
        assert_eq!(received_items.last(), Some(&19));
    }

    #[test]
    fn test_channel_backpressure_drop_latest() {
        let received_items = run_fast_producer_slow_consumer(BackpressurePolicy::DropLatest);

        // Fewer than the 20 sent items arrive because of drops.
        assert!(received_items.len() < 20);
        // The earliest items are kept; the latest ones are the ones dropped.
        assert!(received_items.contains(&0));
        assert!(received_items.last().unwrap() < &19);
    }
}