1use crate::metrics::Metrics;
2use crate::ActionOp;
3use std::collections::VecDeque;
4use std::marker::PhantomData;
5use std::sync::{Arc, Condvar, Mutex};
6
7#[derive(Clone)]
9pub enum BackpressurePolicy<T>
10where
11 T: Send + Sync + Clone + 'static,
12{
13 BlockOnFull,
15 DropOldest,
17 DropLatest,
19 DropLatestIf {
21 predicate: Arc<dyn Fn(&ActionOp<T>) -> bool + Send + Sync>,
22 },
23 DropOldestIf {
25 predicate: Arc<dyn Fn(&ActionOp<T>) -> bool + Send + Sync>,
26 },
27}
28
29impl<T> Default for BackpressurePolicy<T>
30where
31 T: Send + Sync + Clone + 'static,
32{
33 fn default() -> Self {
34 BackpressurePolicy::BlockOnFull
35 }
36}
37
38#[derive(thiserror::Error, Debug)]
39pub(crate) enum SenderError<T> {
40 #[error("Failed to send: {0}")]
41 SendError(T),
42 #[error("Channel is closed")]
43 ChannelClosed,
44}
45
46struct MpscQueue<T>
48where
49 T: Send + Sync + Clone + 'static,
50{
51 queue: Mutex<VecDeque<ActionOp<T>>>,
52 condvar: Condvar,
53 capacity: usize,
54 policy: BackpressurePolicy<T>,
55 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
56 closed: Mutex<bool>,
57}
58
59impl<T> MpscQueue<T>
60where
61 T: Send + Sync + Clone + 'static,
62{
63 fn new(
64 capacity: usize,
65 policy: BackpressurePolicy<T>,
66 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
67 ) -> Self {
68 Self {
69 queue: Mutex::new(VecDeque::new()),
70 condvar: Condvar::new(),
71 capacity,
72 policy,
73 metrics,
74 closed: Mutex::new(false),
75 }
76 }
77
78 fn send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
79 let mut queue = self.queue.lock().unwrap();
80
81 if *self.closed.lock().unwrap() {
83 return Err(SenderError::ChannelClosed);
84 }
85
86 if queue.len() >= self.capacity {
87 match &self.policy {
88 BackpressurePolicy::BlockOnFull => {
89 while queue.len() >= self.capacity {
91 queue = self.condvar.wait(queue).unwrap();
92 if *self.closed.lock().unwrap() {
93 return Err(SenderError::ChannelClosed);
94 }
95 }
96 queue.push_back(item);
97 }
98 BackpressurePolicy::DropOldest => {
99 if let Some(dropped_item) = queue.pop_front() {
101 if let Some(metrics) = &self.metrics {
102 if let ActionOp::Action(action) = &dropped_item {
103 metrics.action_dropped(Some(action as &dyn std::any::Any));
104 }
105 }
106 }
107 queue.push_back(item);
108 }
109 BackpressurePolicy::DropLatest => {
110 if let Some(metrics) = &self.metrics {
112 if let ActionOp::Action(action) = &item {
113 metrics.action_dropped(Some(action as &dyn std::any::Any));
114 }
115 }
116 return Ok(queue.len() as i64);
117 }
118 BackpressurePolicy::DropLatestIf { predicate } => {
119 let mut dropped_count = 0;
121 let mut i = 0;
122 while i < queue.len() {
123 if predicate(&queue[i]) {
124 if let Some(dropped_item) = queue.remove(i) {
125 dropped_count += 1;
126 if let Some(metrics) = &self.metrics {
127 if let ActionOp::Action(action) = &dropped_item {
128 metrics.action_dropped(Some(action as &dyn std::any::Any));
129 }
130 }
131 break;
132 }
133 } else {
134 i += 1;
135 }
136 }
137
138 if dropped_count > 0 {
139 queue.push_back(item);
140 } else {
141 return Err(SenderError::SendError(item));
142 }
143 }
144 BackpressurePolicy::DropOldestIf { predicate } => {
145 let mut dropped_count = 0;
147 let mut i = 0;
148 while i < queue.len() {
149 let index = queue.len() - i - 1;
150 if predicate(&queue[index]) {
151 if let Some(dropped_item) = queue.remove(index) {
152 dropped_count += 1;
153 if let Some(metrics) = &self.metrics {
154 if let ActionOp::Action(action) = &dropped_item {
155 metrics.action_dropped(Some(action as &dyn std::any::Any));
156 }
157 }
158 break;
159 }
160 } else {
161 i += 1;
162 }
163 }
164
165 if dropped_count > 0 {
166 queue.push_back(item);
167 } else {
168 return Err(SenderError::SendError(item));
169 }
170 }
171 }
172 } else {
173 queue.push_back(item);
174 }
175
176 if let Some(metrics) = &self.metrics {
178 metrics.queue_size(queue.len());
179 }
180
181 self.condvar.notify_one();
182 Ok(queue.len() as i64)
183 }
184
185 fn try_send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
186 if *self.closed.lock().unwrap() {
188 return Err(SenderError::ChannelClosed);
189 }
190
191 let mut queue = self.queue.lock().unwrap();
192
193 if queue.len() >= self.capacity {
194 return Err(SenderError::SendError(item));
195 } else {
196 queue.push_back(item);
197 }
198
199 if let Some(metrics) = &self.metrics {
201 metrics.queue_size(queue.len());
202 }
203
204 self.condvar.notify_one();
205 Ok(queue.len() as i64)
206 }
207
208 fn recv(&self) -> Option<ActionOp<T>> {
209 let mut queue = self.queue.lock().unwrap();
210
211 while queue.is_empty() {
213 if *self.closed.lock().unwrap() {
214 return None;
215 }
216 queue = self.condvar.wait(queue).unwrap();
217 }
218
219 let item = queue.pop_front();
220 self.condvar.notify_one();
221 item
222 }
223
224 fn try_recv(&self) -> Option<ActionOp<T>> {
225 let mut queue = self.queue.lock().unwrap();
226 let item = queue.pop_front();
227 if item.is_some() {
228 self.condvar.notify_one();
229 }
230 item
231 }
232
233 fn len(&self) -> usize {
234 self.queue.lock().unwrap().len()
235 }
236
237 fn close(&self) {
238 *self.closed.lock().unwrap() = true;
239 self.condvar.notify_all();
240 }
241}
242
243#[derive(Clone)]
245pub(crate) struct SenderChannel<T>
246where
247 T: Send + Sync + Clone + 'static,
248{
249 _name: String,
250 queue: Arc<MpscQueue<T>>,
251}
252
253impl<Action> Drop for SenderChannel<Action>
254where
255 Action: Send + Sync + Clone + 'static,
256{
257 fn drop(&mut self) {
258 #[cfg(feature = "store-log")]
259 eprintln!("store: drop '{}' sender channel", self._name);
260 }
261}
262
263#[allow(dead_code)]
264impl<T> SenderChannel<T>
265where
266 T: Send + Sync + Clone + 'static,
267{
268 pub fn send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
269 self.queue.send(item)
270 }
271
272 pub fn try_send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
273 self.queue.try_send(item)
274 }
275}
276
277#[allow(dead_code)]
278pub(crate) struct ReceiverChannel<T>
279where
280 T: Send + Sync + Clone + 'static,
281{
282 name: String,
283 queue: Arc<MpscQueue<T>>,
284 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
285}
286
287impl<Action> Drop for ReceiverChannel<Action>
288where
289 Action: Send + Sync + Clone + 'static,
290{
291 fn drop(&mut self) {
292 #[cfg(feature = "store-log")]
293 eprintln!("store: drop '{}' receiver channel", self.name);
294 self.close();
295 }
296}
297
298#[allow(dead_code)]
299impl<T> ReceiverChannel<T>
300where
301 T: Send + Sync + Clone + 'static,
302{
303 pub fn recv(&self) -> Option<ActionOp<T>> {
304 self.queue.recv()
305 }
306
307 #[allow(dead_code)]
308 pub fn try_recv(&self) -> Option<ActionOp<T>> {
309 self.queue.try_recv()
310 }
311
312 pub fn len(&self) -> usize {
313 self.queue.len()
314 }
315
316 pub fn close(&self) {
317 self.queue.close();
318 }
319}
320
/// Namespace for the channel constructors (`pair`, `pair_with_metrics`, `pair_with`);
/// `MSG` only fixes the message type. It is not constructed anywhere in this file.
pub(crate) struct BackpressureChannel<MSG: Send + Sync + Clone + 'static> {
    phantom_data: PhantomData<MSG>,
}
328
329impl<MSG> BackpressureChannel<MSG>
330where
331 MSG: Send + Sync + Clone + 'static,
332{
333 #[allow(dead_code)]
334 pub fn pair(
335 capacity: usize,
336 policy: BackpressurePolicy<MSG>,
337 ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
338 Self::pair_with("<anon>", capacity, policy, None)
339 }
340
341 #[allow(dead_code)]
342 pub fn pair_with_metrics(
343 capacity: usize,
344 policy: BackpressurePolicy<MSG>,
345 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
346 ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
347 Self::pair_with("<anon>", capacity, policy, metrics)
348 }
349
350 #[allow(dead_code)]
351 pub fn pair_with(
352 name: &str,
353 capacity: usize,
354 policy: BackpressurePolicy<MSG>,
355 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
356 ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
357 let queue = Arc::new(MpscQueue::new(capacity, policy, metrics.clone()));
358
359 (
360 SenderChannel {
361 _name: name.to_string(),
362 queue: queue.clone(),
363 },
364 ReceiverChannel {
365 name: name.to_string(),
366 queue,
367 metrics,
368 },
369 )
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Eviction predicate shared by the predicate-policy tests: actions below 5 may be dropped.
    fn below_five(action_op: &ActionOp<i32>) -> bool {
        match action_op {
            ActionOp::Action(value) => *value < 5,
            ActionOp::Exit(_) => false,
        }
    }

    #[test]
    fn test_basic_send_recv() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(5, BackpressurePolicy::BlockOnFull);

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(2)).unwrap();

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.try_recv(), None);
    }

    #[test]
    fn test_drop_oldest() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropOldest);

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(2)).unwrap();
        sender.send(ActionOp::Action(3)).unwrap();

        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(3)));
        assert_eq!(receiver.try_recv(), None);
    }

    #[test]
    fn test_drop_latest() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropLatest);

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(2)).unwrap();
        sender.send(ActionOp::Action(3)).unwrap();

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.try_recv(), None);
    }

    #[test]
    fn test_predicate_dropping() {
        let (sender, receiver) = BackpressureChannel::<i32>::pair(
            2,
            BackpressurePolicy::DropLatestIf {
                predicate: Arc::new(below_five),
            },
        );

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(6)).unwrap();
        let result = sender.send(ActionOp::Action(7));
        assert!(
            result.is_ok(),
            "Should succeed because predicate should drop the first item"
        );

        assert_eq!(
            receiver.recv(),
            Some(ActionOp::Action(6)),
            "Should receive 6, not 1"
        );
        assert_eq!(receiver.recv(), Some(ActionOp::Action(7)), "Should receive 7");
        assert_eq!(receiver.try_recv(), None);
    }

    #[test]
    fn test_predicate_dropping_without_match_rejects_item() {
        let (sender, receiver) = BackpressureChannel::<i32>::pair(
            1,
            BackpressurePolicy::DropLatestIf {
                predicate: Arc::new(below_five),
            },
        );

        sender.send(ActionOp::Action(6)).unwrap();
        let result = sender.send(ActionOp::Action(7));
        assert!(
            matches!(result, Err(SenderError::SendError(ActionOp::Action(7)))),
            "No queued item matches, so the new item must be handed back"
        );

        assert_eq!(receiver.recv(), Some(ActionOp::Action(6)));
        assert_eq!(receiver.try_recv(), None);
    }

    #[test]
    fn test_drop_oldest_if() {
        let (sender, receiver) = BackpressureChannel::<i32>::pair(
            2,
            BackpressurePolicy::DropOldestIf {
                predicate: Arc::new(below_five),
            },
        );

        sender.send(ActionOp::Action(6)).unwrap();
        sender.send(ActionOp::Action(1)).unwrap();
        assert!(sender.send(ActionOp::Action(7)).is_ok());

        assert_eq!(receiver.recv(), Some(ActionOp::Action(6)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(7)));
        assert_eq!(receiver.try_recv(), None);
    }

    #[test]
    fn test_block_on_full() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(1, BackpressurePolicy::BlockOnFull);

        sender.send(ActionOp::Action(1)).unwrap();

        let result = sender.try_send(ActionOp::Action(2));
        assert!(
            matches!(result, Err(SenderError::SendError(ActionOp::Action(2)))),
            "Should fail because channel is full"
        );

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));

        sender.send(ActionOp::Action(2)).unwrap();
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
    }

    #[test]
    fn test_block_on_full_unblocks_sender() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(1, BackpressurePolicy::BlockOnFull);
        sender.send(ActionOp::Action(1)).unwrap();

        // This send blocks until the receiver frees the single slot.
        let producer = {
            let sender = sender.clone();
            std::thread::spawn(move || sender.send(ActionOp::Action(2)).is_ok())
        };

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert!(producer.join().unwrap());
    }

    #[test]
    fn test_close_drains_then_returns_none() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::BlockOnFull);
        sender.send(ActionOp::Action(1)).unwrap();

        receiver.close();

        assert!(matches!(
            sender.send(ActionOp::Action(2)),
            Err(SenderError::ChannelClosed)
        ));
        assert!(matches!(
            sender.try_send(ActionOp::Action(3)),
            Err(SenderError::ChannelClosed)
        ));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), None);
    }

    #[test]
    fn test_close_wakes_blocked_receiver() {
        let (_sender, receiver) =
            BackpressureChannel::<i32>::pair(1, BackpressurePolicy::BlockOnFull);
        let receiver = Arc::new(receiver);

        // Whether the thread blocks before or after `close`, it must observe `None`.
        let waiter = {
            let receiver = Arc::clone(&receiver);
            std::thread::spawn(move || receiver.recv())
        };

        receiver.close();
        assert_eq!(waiter.join().unwrap(), None);
    }
}