1use crate::metrics::Metrics;
2use crate::ActionOp;
3use std::collections::VecDeque;
4use std::marker::PhantomData;
5use std::sync::{Arc, Condvar, Mutex};
6
/// Strategy applied when an item is sent while the queue already holds
/// `capacity` items.
#[derive(Clone, Default)]
pub enum BackpressurePolicy<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Default policy: the blocking `send` waits until space is available,
    /// while `try_send` rejects the item immediately.
    #[default]
    BlockOnFull,
    /// Discard the item at the front of the queue, then enqueue the new one.
    DropOldest,
    /// Discard the incoming item and leave the queue untouched; the send
    /// still reports success.
    DropLatest,
    /// Discard a single already-queued item for which `predicate` returns
    /// `true`, then enqueue the new one. If no queued item matches, the send
    /// is rejected and the new item is handed back to the caller.
    DropLatestIf {
        /// Returns `true` for a queued item that may be discarded.
        predicate: Arc<dyn Fn(&ActionOp<T>) -> bool + Send + Sync>,
    },
    /// Discard a single already-queued item for which `predicate` returns
    /// `true`, then enqueue the new one. If no queued item matches, the send
    /// is rejected and the new item is handed back to the caller.
    DropOldestIf {
        /// Returns `true` for a queued item that may be discarded.
        predicate: Arc<dyn Fn(&ActionOp<T>) -> bool + Send + Sync>,
    },
}
29
/// Errors reported by the sending half of the channel.
#[derive(thiserror::Error, Debug)]
pub(crate) enum SenderError<T> {
    /// The item could not be enqueued; it is returned to the caller.
    #[error("Failed to send: {0}")]
    SendError(T),
    /// The channel has been closed, so no further items are accepted.
    #[error("Channel is closed")]
    ChannelClosed,
}
37
/// Bounded FIFO queue shared (through an `Arc`) by the sender and receiver
/// halves of a channel.
struct MpscQueue<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Pending items; new items are pushed at the back and received from the front.
    queue: Mutex<VecDeque<ActionOp<T>>>,
    /// Signalled when an item is pushed or popped and when the queue is closed.
    condvar: Condvar,
    /// Queue length at which the backpressure policy takes effect.
    capacity: usize,
    /// What to do with a send that arrives while the queue is at capacity.
    policy: BackpressurePolicy<T>,
    /// Optional sink that is told about dropped actions and the queue size.
    metrics: Option<Arc<dyn Metrics + Send + Sync>>,
    /// Set to `true` by `close`; once set, sends are rejected.
    closed: Mutex<bool>,
}
50
51impl<T> MpscQueue<T>
52where
53 T: Send + Sync + Clone + 'static,
54{
55 fn new(
56 capacity: usize,
57 policy: BackpressurePolicy<T>,
58 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
59 ) -> Self {
60 Self {
61 queue: Mutex::new(VecDeque::new()),
62 condvar: Condvar::new(),
63 capacity,
64 policy,
65 metrics,
66 closed: Mutex::new(false),
67 }
68 }
69
70 fn send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
71 let mut queue = self.queue.lock().unwrap();
72
73 if *self.closed.lock().unwrap() {
75 return Err(SenderError::ChannelClosed);
76 }
77
78 if queue.len() >= self.capacity {
79 match &self.policy {
80 BackpressurePolicy::BlockOnFull => {
81 while queue.len() >= self.capacity {
83 queue = self.condvar.wait(queue).unwrap();
84 if *self.closed.lock().unwrap() {
85 return Err(SenderError::ChannelClosed);
86 }
87 }
88 queue.push_back(item);
89 }
90 BackpressurePolicy::DropOldest => {
91 if let Some(dropped_item) = queue.pop_front() {
93 if let Some(metrics) = &self.metrics {
94 if let ActionOp::Action(action) = &dropped_item {
95 metrics.action_dropped(Some(action as &dyn std::any::Any));
96 }
97 }
98 }
99 queue.push_back(item);
100 }
101 BackpressurePolicy::DropLatest => {
102 if let Some(metrics) = &self.metrics {
104 if let ActionOp::Action(action) = &item {
105 metrics.action_dropped(Some(action as &dyn std::any::Any));
106 }
107 }
108 return Ok(queue.len() as i64);
109 }
110 BackpressurePolicy::DropLatestIf { predicate } => {
111 let mut dropped_count = 0;
113 let mut i = 0;
114 while i < queue.len() {
115 if predicate(&queue[i]) {
116 if let Some(dropped_item) = queue.remove(i) {
117 dropped_count += 1;
118 if let Some(metrics) = &self.metrics {
119 if let ActionOp::Action(action) = &dropped_item {
120 metrics.action_dropped(Some(action as &dyn std::any::Any));
121 }
122 }
123 break;
124 }
125 }
126 i += 1;
127 }
128
129 if dropped_count > 0 {
130 queue.push_back(item);
131 } else {
132 return Err(SenderError::SendError(item));
133 }
134 }
135 BackpressurePolicy::DropOldestIf { predicate } => {
136 let mut dropped_count = 0;
138 let mut i = 0;
139 while i < queue.len() {
140 let index = queue.len() - i - 1;
141 if predicate(&queue[index]) {
142 if let Some(dropped_item) = queue.remove(index) {
143 dropped_count += 1;
144 if let Some(metrics) = &self.metrics {
145 if let ActionOp::Action(action) = &dropped_item {
146 metrics.action_dropped(Some(action as &dyn std::any::Any));
147 }
148 }
149 break;
150 }
151 }
152 i += 1;
153 }
154
155 if dropped_count > 0 {
156 queue.push_back(item);
157 } else {
158 return Err(SenderError::SendError(item));
159 }
160 }
161 }
162 } else {
163 queue.push_back(item);
164 }
165
166 if let Some(metrics) = &self.metrics {
168 metrics.queue_size(queue.len());
169 }
170
171 self.condvar.notify_one();
172 Ok(queue.len() as i64)
173 }
174
175 fn try_send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
176 if *self.closed.lock().unwrap() {
178 return Err(SenderError::ChannelClosed);
179 }
180
181 let mut queue = self.queue.lock().unwrap();
182
183 if queue.len() >= self.capacity {
184 match &self.policy {
185 BackpressurePolicy::BlockOnFull => {
186 return Err(SenderError::SendError(item));
187 }
188 BackpressurePolicy::DropOldest => {
189 if let Some(dropped_item) = queue.pop_front() {
191 if let Some(metrics) = &self.metrics {
192 if let ActionOp::Action(action) = &dropped_item {
193 metrics.action_dropped(Some(action as &dyn std::any::Any));
194 }
195 }
196 }
197 queue.push_back(item);
198 }
199 BackpressurePolicy::DropLatest => {
200 if let Some(metrics) = &self.metrics {
202 if let ActionOp::Action(action) = &item {
203 metrics.action_dropped(Some(action as &dyn std::any::Any));
204 }
205 }
206 return Ok(queue.len() as i64);
207 }
208 BackpressurePolicy::DropLatestIf { predicate } => {
209 let mut dropped_count = 0;
211 let mut i = 0;
212 while i < queue.len() {
213 if predicate(&queue[i]) {
214 if let Some(dropped_item) = queue.remove(i) {
215 dropped_count += 1;
216 if let Some(metrics) = &self.metrics {
217 if let ActionOp::Action(action) = &dropped_item {
218 metrics.action_dropped(Some(action as &dyn std::any::Any));
219 }
220 }
221 break;
222 }
223 }
224 i += 1;
225 }
226
227 if dropped_count > 0 {
228 queue.push_back(item);
229 } else {
230 return Err(SenderError::SendError(item));
231 }
232 }
233 BackpressurePolicy::DropOldestIf { predicate } => {
234 let mut dropped_count = 0;
236 let mut i = 0;
237 while i < queue.len() {
238 let index = queue.len() - i - 1;
239 if predicate(&queue[index]) {
240 if let Some(dropped_item) = queue.remove(index) {
241 dropped_count += 1;
242 if let Some(metrics) = &self.metrics {
243 if let ActionOp::Action(action) = &dropped_item {
244 metrics.action_dropped(Some(action as &dyn std::any::Any));
245 }
246 }
247 break;
248 }
249 }
250 i += 1;
251 }
252
253 if dropped_count > 0 {
254 queue.push_back(item);
255 } else {
256 return Err(SenderError::SendError(item));
257 }
258 }
259 }
260 } else {
261 queue.push_back(item);
262 }
263
264 if let Some(metrics) = &self.metrics {
266 metrics.queue_size(queue.len());
267 }
268
269 self.condvar.notify_one();
270 Ok(queue.len() as i64)
271 }
272
273 fn recv(&self) -> Option<ActionOp<T>> {
274 let mut queue = self.queue.lock().unwrap();
275
276 while queue.is_empty() {
278 if *self.closed.lock().unwrap() {
279 return None;
280 }
281 queue = self.condvar.wait(queue).unwrap();
282 }
283
284 let item = queue.pop_front();
285 self.condvar.notify_one();
286 item
287 }
288
289 fn try_recv(&self) -> Option<ActionOp<T>> {
290 let mut queue = self.queue.lock().unwrap();
291 let item = queue.pop_front();
292 if item.is_some() {
293 self.condvar.notify_one();
294 }
295 item
296 }
297
298 fn len(&self) -> usize {
299 self.queue.lock().unwrap().len()
300 }
301
302 fn close(&self) {
303 *self.closed.lock().unwrap() = true;
304 self.condvar.notify_all();
305 }
306}
307
/// Sending half of a backpressure channel. Cloning it yields another handle
/// to the same underlying queue.
#[derive(Clone)]
pub(crate) struct SenderChannel<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Channel name; only read by the `store-log` drop message.
    _name: String,
    /// Queue shared with the receiver half.
    queue: Arc<MpscQueue<T>>,
}
317
/// Dropping a sender does not close the channel; it only logs (to stderr)
/// when the `store-log` feature is enabled.
impl<Action> Drop for SenderChannel<Action>
where
    Action: Send + Sync + Clone + 'static,
{
    fn drop(&mut self) {
        #[cfg(feature = "store-log")]
        eprintln!("store: drop '{}' sender channel", self._name);
    }
}
327
#[allow(dead_code)]
impl<T> SenderChannel<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Sends `item` through the shared queue; may wait for space depending on
    /// the queue's backpressure policy.
    ///
    /// On success returns the queue length reported by the queue.
    pub fn send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
        self.queue.send(item)
    }

    /// Sends `item` through the shared queue without waiting for space.
    ///
    /// On success returns the queue length reported by the queue.
    pub fn try_send(&self, item: ActionOp<T>) -> Result<i64, SenderError<ActionOp<T>>> {
        self.queue.try_send(item)
    }
}
341
/// Receiving half of a backpressure channel.
#[allow(dead_code)]
pub(crate) struct ReceiverChannel<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Channel name; only read by the `store-log` drop message.
    name: String,
    /// Queue shared with the sender half.
    queue: Arc<MpscQueue<T>>,
    /// Metrics handle passed in at construction; not read by any code in this file.
    metrics: Option<Arc<dyn Metrics + Send + Sync>>,
}
351
/// Dropping the receiver closes the shared queue, so later sends are rejected
/// with `SenderError::ChannelClosed`.
impl<Action> Drop for ReceiverChannel<Action>
where
    Action: Send + Sync + Clone + 'static,
{
    fn drop(&mut self) {
        #[cfg(feature = "store-log")]
        eprintln!("store: drop '{}' receiver channel", self.name);
        self.close();
    }
}
362
#[allow(dead_code)]
impl<T> ReceiverChannel<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Receives the next item, waiting while the queue is empty.
    /// Returns `None` once the queue is empty and closed.
    pub fn recv(&self) -> Option<ActionOp<T>> {
        self.queue.recv()
    }

    /// Receives the next item if one is queued; never waits.
    #[allow(dead_code)]
    pub fn try_recv(&self) -> Option<ActionOp<T>> {
        self.queue.try_recv()
    }

    /// Returns the number of items currently queued.
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Closes the shared queue and wakes any threads waiting on it.
    pub fn close(&self) {
        self.queue.close();
    }
}
385
/// Factory type for sender/receiver pairs; its associated functions build the
/// channel halves. No code in this file constructs a value of this type.
pub(crate) struct BackpressureChannel<MSG>
where
    MSG: Send + Sync + Clone + 'static,
{
    /// Ties the otherwise unused `MSG` parameter to the type.
    phantom_data: PhantomData<MSG>,
}
393
394impl<MSG> BackpressureChannel<MSG>
395where
396 MSG: Send + Sync + Clone + 'static,
397{
398 #[allow(dead_code)]
399 pub fn pair(
400 capacity: usize,
401 policy: BackpressurePolicy<MSG>,
402 ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
403 Self::pair_with("<anon>", capacity, policy, None)
404 }
405
406 #[allow(dead_code)]
407 pub fn pair_with_metrics(
408 capacity: usize,
409 policy: BackpressurePolicy<MSG>,
410 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
411 ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
412 Self::pair_with("<anon>", capacity, policy, metrics)
413 }
414
415 #[allow(dead_code)]
416 pub fn pair_with(
417 name: &str,
418 capacity: usize,
419 policy: BackpressurePolicy<MSG>,
420 metrics: Option<Arc<dyn Metrics + Send + Sync>>,
421 ) -> (SenderChannel<MSG>, ReceiverChannel<MSG>) {
422 let queue = Arc::new(MpscQueue::new(capacity, policy, metrics.clone()));
423
424 (
425 SenderChannel {
426 _name: name.to_string(),
427 queue: queue.clone(),
428 },
429 ReceiverChannel {
430 name: name.to_string(),
431 queue,
432 metrics,
433 },
434 )
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Items come out in the order they were sent.
    #[test]
    fn test_basic_send_recv() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(5, BackpressurePolicy::BlockOnFull);

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(2)).unwrap();

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.try_recv(), None);
    }

    /// A full `DropOldest` queue discards its front item to admit the new one.
    #[test]
    fn test_drop_oldest() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropOldest);

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(2)).unwrap();
        sender.send(ActionOp::Action(3)).unwrap();

        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(3)));
        assert_eq!(receiver.try_recv(), None);
    }

    /// A full `DropLatest` queue discards the incoming item.
    #[test]
    fn test_drop_latest() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropLatest);

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(2)).unwrap();
        sender.send(ActionOp::Action(3)).unwrap();

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.try_recv(), None);
    }

    /// The only queued item matching the predicate is evicted to make room.
    #[test]
    fn test_predicate_dropping() {
        let predicate = Arc::new(|action_op: &ActionOp<i32>| match action_op {
            ActionOp::Action(value) => *value < 5,
            ActionOp::Exit(_) => false,
        });

        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::DropLatestIf { predicate });

        sender.send(ActionOp::Action(1)).unwrap();
        sender.send(ActionOp::Action(6)).unwrap();
        let result = sender.send(ActionOp::Action(7));
        assert!(
            result.is_ok(),
            "Should succeed because predicate should drop the first item"
        );

        // Compare whole values so that an unexpected variant fails the test
        // instead of being skipped.
        assert_eq!(
            receiver.recv(),
            Some(ActionOp::Action(6)),
            "Should receive 6, not 1"
        );
        assert_eq!(
            receiver.recv(),
            Some(ActionOp::Action(7)),
            "Should receive 7"
        );
        assert_eq!(receiver.try_recv(), None);
    }

    /// `try_send` never waits under `BlockOnFull`; `send` succeeds once there is room.
    #[test]
    fn test_block_on_full() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(1, BackpressurePolicy::BlockOnFull);

        sender.send(ActionOp::Action(1)).unwrap();

        let result = sender.try_send(ActionOp::Action(2));
        assert!(result.is_err(), "Should fail because channel is full");

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));

        sender.send(ActionOp::Action(2)).unwrap();
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
    }

    /// With nothing evictable, the new item is rejected and the queue is untouched.
    #[test]
    fn test_drop_oldest_if_predicate_always_false() {
        let (sender, receiver) = BackpressureChannel::pair(
            3,
            BackpressurePolicy::DropOldestIf {
                predicate: Arc::new(|_| false),
            },
        );

        assert!(sender.try_send(ActionOp::Action(1)).is_ok());
        assert!(sender.try_send(ActionOp::Action(2)).is_ok());
        assert!(sender.try_send(ActionOp::Action(3)).is_ok());
        assert_eq!(receiver.len(), 3);

        let result = sender.try_send(ActionOp::Action(4));
        assert!(
            result.is_err(),
            "Should fail because no items match the predicate"
        );

        assert_eq!(receiver.len(), 3);
        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(2)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(3)));
    }

    /// One matching item is evicted; once none is left, further sends are rejected.
    #[test]
    fn test_drop_oldest_if_predicate_sometimes_true() {
        let (sender, receiver) = BackpressureChannel::pair(
            3,
            BackpressurePolicy::DropOldestIf {
                predicate: Arc::new(|action_op: &ActionOp<i32>| {
                    if let ActionOp::Action(value) = action_op {
                        *value < 5
                    } else {
                        false
                    }
                }),
            },
        );

        assert!(sender.try_send(ActionOp::Action(6)).is_ok());
        assert!(sender.try_send(ActionOp::Action(2)).is_ok());
        assert!(sender.try_send(ActionOp::Action(8)).is_ok());
        assert_eq!(receiver.len(), 3);

        let result = sender.try_send(ActionOp::Action(9));
        assert!(
            result.is_ok(),
            "Should succeed because item 2 matches the predicate and is dropped"
        );

        let result = sender.try_send(ActionOp::Action(10));
        assert!(
            result.is_err(),
            "Should fail because no items match the predicate"
        );

        assert_eq!(receiver.len(), 3);
        assert_eq!(receiver.recv(), Some(ActionOp::Action(6)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(8)));
        assert_eq!(receiver.recv(), Some(ActionOp::Action(9)));
    }

    /// After `close`, sends are rejected while already-queued items can still be received.
    #[test]
    fn test_close_rejects_send_and_drains_queue() {
        let (sender, receiver) =
            BackpressureChannel::<i32>::pair(2, BackpressurePolicy::BlockOnFull);

        sender.send(ActionOp::Action(1)).unwrap();
        receiver.close();

        assert!(matches!(
            sender.send(ActionOp::Action(2)),
            Err(SenderError::ChannelClosed)
        ));
        assert!(matches!(
            sender.try_send(ActionOp::Action(2)),
            Err(SenderError::ChannelClosed)
        ));

        assert_eq!(receiver.recv(), Some(ActionOp::Action(1)));
        assert_eq!(receiver.recv(), None);
    }
}