// binocular/infra/channel.rs

1use std::fmt::Debug;
2use std::time::{Duration, Instant};
3
/// Errors reported by the channel abstractions in this module.
///
/// The enum is fieldless, so it is `Copy` and can be hashed or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChannelError {
    /// The channel is closed, so the operation could not be completed.
    Disconnected,
    /// No value was available.
    ///
    /// Note: `KanalReceiver::try_recv` reports an empty channel as `Ok(None)`
    /// rather than with this variant.
    Empty,
    /// A non-blocking send was rejected because the channel had no room.
    Full,
}

impl std::fmt::Display for ChannelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Disconnected => "channel disconnected",
            Self::Empty => "channel empty",
            Self::Full => "channel full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ChannelError {}
22
23pub trait Sender<T>: Clone + Send + 'static {
24    fn send(&self, value: T) -> Result<(), ChannelError>;
25
26    fn try_send(&self, value: T) -> Result<(), ChannelError>;
27}
28
29pub trait Receiver<T>: Send + 'static {
30    fn recv(&self) -> Result<T, ChannelError>;
31
32    fn try_recv(&self) -> Result<Option<T>, ChannelError>;
33}
34
35pub struct KanalSender<T>(kanal::Sender<T>);
36
37impl<T> Clone for KanalSender<T> {
38    fn clone(&self) -> Self {
39        KanalSender(self.0.clone())
40    }
41}
42
43impl<T: Send + 'static> Sender<T> for KanalSender<T> {
44    fn send(&self, value: T) -> Result<(), ChannelError> {
45        self.0.send(value).map_err(|_| ChannelError::Disconnected)
46    }
47
48    fn try_send(&self, value: T) -> Result<(), ChannelError> {
49        match self.0.try_send(value) {
50            Ok(true) => Ok(()),
51            Ok(false) => Err(ChannelError::Full),
52            Err(kanal::SendError::Closed | kanal::SendError::ReceiveClosed) => {
53                Err(ChannelError::Disconnected)
54            }
55        }
56    }
57}
58
59pub struct KanalReceiver<T>(kanal::Receiver<T>);
60
61impl<T: Send + 'static> Receiver<T> for KanalReceiver<T> {
62    fn recv(&self) -> Result<T, ChannelError> {
63        self.0.recv().map_err(|_| ChannelError::Disconnected)
64    }
65
66    fn try_recv(&self) -> Result<Option<T>, ChannelError> {
67        match self.0.try_recv() {
68            Ok(Some(value)) => Ok(Some(value)),
69            Ok(None) => Ok(None),
70            Err(_) => Err(ChannelError::Disconnected),
71        }
72    }
73}
74
75pub fn unbounded<T: Send + 'static>() -> (impl Sender<T>, impl Receiver<T>) {
76    let (tx, rx) = kanal::unbounded();
77    (KanalSender(tx), KanalReceiver(rx))
78}
79
80pub fn bounded<T: Send + 'static>(capacity: usize) -> (impl Sender<T>, impl Receiver<T>) {
81    let (tx, rx) = kanal::bounded(capacity);
82    (KanalSender(tx), KanalReceiver(rx))
83}
84
85const BATCH_FLUSH_INTERVAL: Duration = Duration::from_millis(50);
86
87pub struct BatchSender<T, S: Sender<Vec<T>>> {
88    tx: S,
89    buf: Vec<T>,
90    capacity: usize,
91    last_flush: Instant,
92}
93
94impl<T, S: Sender<Vec<T>>> BatchSender<T, S> {
95    pub fn new(tx: S, capacity: usize) -> Self {
96        Self {
97            tx,
98            buf: Vec::with_capacity(capacity),
99            capacity,
100            last_flush: Instant::now(),
101        }
102    }
103
104    fn flush_buf(&mut self) {
105        if !self.buf.is_empty() {
106            let _ = self.tx.send(std::mem::replace(
107                &mut self.buf,
108                Vec::with_capacity(self.capacity),
109            ));
110            self.last_flush = Instant::now();
111        }
112    }
113
114    pub fn push(&mut self, item: T) {
115        self.buf.push(item);
116        if self.buf.len() >= self.capacity || self.last_flush.elapsed() >= BATCH_FLUSH_INTERVAL {
117            self.flush_buf();
118        }
119    }
120
121    pub fn tick(&mut self) {
122        if self.last_flush.elapsed() >= BATCH_FLUSH_INTERVAL {
123            self.flush_buf();
124        }
125    }
126}
127
128impl<T, S: Sender<Vec<T>>> Drop for BatchSender<T, S> {
129    fn drop(&mut self) {
130        self.flush_buf();
131    }
132}
133
134use std::marker::PhantomData;
135
136pub struct MapSender<T, U, F, S> {
137    tx: S,
138    mapper: F,
139    _phantom: PhantomData<fn(T) -> U>,
140}
141
142impl<T, U, F, S> MapSender<T, U, F, S> {
143    pub fn new(tx: S, mapper: F) -> Self {
144        Self {
145            tx,
146            mapper,
147            _phantom: PhantomData,
148        }
149    }
150}
151
152impl<T, U, F, S> Clone for MapSender<T, U, F, S>
153where
154    S: Clone,
155    F: Clone,
156{
157    fn clone(&self) -> Self {
158        Self {
159            tx: self.tx.clone(),
160            mapper: self.mapper.clone(),
161            _phantom: PhantomData,
162        }
163    }
164}
165
166impl<T, U, F, S> Sender<T> for MapSender<T, U, F, S>
167where
168    T: Send + 'static,
169    U: Send + 'static,
170    F: Fn(T) -> U + Clone + Send + 'static,
171    S: Sender<U>,
172{
173    fn send(&self, value: T) -> Result<(), ChannelError> {
174        self.tx.send((self.mapper)(value))
175    }
176
177    fn try_send(&self, value: T) -> Result<(), ChannelError> {
178        self.tx.try_send((self.mapper)(value))
179    }
180}
181
182pub type DefaultSender<T> = KanalSender<T>;
183
184pub type DefaultReceiver<T> = KanalReceiver<T>;
185
186pub fn unbounded_default<T: Send + 'static>() -> (DefaultSender<T>, DefaultReceiver<T>) {
187    let (tx, rx) = kanal::unbounded();
188    (KanalSender(tx), KanalReceiver(rx))
189}
190
191pub fn bounded_default<T: Send + 'static>(
192    capacity: usize,
193) -> (DefaultSender<T>, DefaultReceiver<T>) {
194    let (tx, rx) = kanal::bounded(capacity);
195    (KanalSender(tx), KanalReceiver(rx))
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unbounded_send_recv() {
        let (tx, rx) = unbounded_default::<i32>();
        tx.send(42).unwrap();
        assert_eq!(rx.recv().unwrap(), 42);
    }

    #[test]
    fn test_try_recv_empty() {
        let (_tx, rx) = unbounded_default::<i32>();
        assert_eq!(rx.try_recv().unwrap(), None);
    }

    #[test]
    fn test_try_recv_with_value() {
        let (tx, rx) = unbounded_default::<i32>();
        tx.send(42).unwrap();
        assert_eq!(rx.try_recv().unwrap(), Some(42));
    }

    #[test]
    fn test_sender_clone() {
        let (tx, rx) = unbounded_default::<i32>();
        let tx2 = tx.clone();
        tx.send(1).unwrap();
        tx2.send(2).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
        assert_eq!(rx.recv().unwrap(), 2);
    }

    #[test]
    fn test_map_sender() {
        let (tx, rx) = unbounded_default::<String>();
        let map_tx = MapSender::new(tx, |i: i32| i.to_string());

        map_tx.send(42).unwrap();
        assert_eq!(rx.recv().unwrap(), "42");
    }

    /// A bounded channel with one queued value rejects a further non-blocking send.
    #[test]
    fn test_bounded_try_send_reports_full() {
        let (tx, rx) = bounded_default::<i32>(1);
        tx.try_send(1).unwrap();
        assert_eq!(tx.try_send(2), Err(ChannelError::Full));
        assert_eq!(rx.try_recv().unwrap(), Some(1));
    }

    /// Sending after the only receiver is dropped is reported as a disconnect.
    #[test]
    fn test_send_after_receiver_dropped_is_disconnected() {
        let (tx, rx) = unbounded_default::<i32>();
        drop(rx);
        assert_eq!(tx.send(1), Err(ChannelError::Disconnected));
        assert_eq!(tx.try_send(1), Err(ChannelError::Disconnected));
    }

    /// Polling an empty channel whose senders are all gone reports a disconnect,
    /// not `Ok(None)`.
    #[test]
    fn test_try_recv_after_sender_dropped_is_disconnected() {
        let (tx, rx) = unbounded_default::<i32>();
        drop(tx);
        assert_eq!(rx.try_recv(), Err(ChannelError::Disconnected));
    }

    #[test]
    fn test_map_sender_try_send_and_clone() {
        let (tx, rx) = unbounded_default::<i64>();
        let map_tx = MapSender::new(tx, |i: i32| i64::from(i) * 2);
        let map_tx2 = map_tx.clone();

        map_tx.try_send(3).unwrap();
        map_tx2.send(4).unwrap();
        assert_eq!(rx.recv().unwrap(), 6);
        assert_eq!(rx.recv().unwrap(), 8);
    }

    /// Every pushed item arrives exactly once and in order, with no batch
    /// exceeding `capacity`. The trailing partial batch is sent on drop.
    #[test]
    fn test_batch_sender_delivers_all_items_in_order() {
        let (tx, rx) = unbounded_default::<Vec<i32>>();
        {
            // Keep the original `tx` alive so an empty channel yields `Ok(None)`
            // below rather than a disconnect.
            let mut batch = BatchSender::new(tx.clone(), 2);
            for i in 1..=5 {
                batch.push(i);
            }
        }

        let mut received = Vec::new();
        while let Some(chunk) = rx.try_recv().unwrap() {
            assert!(!chunk.is_empty() && chunk.len() <= 2);
            received.extend(chunk);
        }
        assert_eq!(received, vec![1, 2, 3, 4, 5]);
    }
}