// nodo/channels/double_buffer_rx.rs
use crate::{
    channels::{
        BackStage, FrontStage, OverflowPolicy, Rx, RxChannelTimeseries, RxConnectable,
        SharedBackStage, SyncResult,
    },
    core::{Message, TimestampKind},
    prelude::{Pop, RetentionPolicy},
};
use core::ops;
use std::{
    collections::vec_deque,
    fmt,
    sync::{Arc, RwLock},
};
17
18pub struct DoubleBufferRx<T> {
28 pub(crate) back: SharedBackStage<T>,
29 front: FrontStage<T>,
30 pub(crate) is_connected: bool,
31}
32
33impl<T> DoubleBufferRx<T> {
34 pub fn new(overflow_policy: OverflowPolicy, retention_policy: RetentionPolicy) -> Self {
37 let back = BackStage::new(overflow_policy, retention_policy);
38 let capacity = back.capacity();
39 Self {
40 back: Arc::new(RwLock::new(back)),
41 front: FrontStage::new(capacity),
42 is_connected: false,
43 }
44 }
45
46 pub fn new_latest() -> Self {
48 Self::new(OverflowPolicy::Forget(1), RetentionPolicy::Keep)
49 }
50
51 pub fn new_auto_size() -> Self {
56 Self::new(OverflowPolicy::Resize, RetentionPolicy::Drop)
57 }
58
59 pub fn front_len(&self) -> usize {
62 self.front.len()
63 }
64
65 pub fn pop_all(&mut self) -> std::collections::vec_deque::Drain<'_, T> {
66 self.front.drain(..)
67 }
68
69 pub fn latest(&self) -> Option<&T> {
71 let n = self.front.len();
72 if n == 0 {
73 None
74 } else {
75 Some(&self.front[n - 1])
76 }
77 }
78
79 pub fn is_full(&self) -> bool {
82 match self.back.read().unwrap().overflow_policy() {
84 OverflowPolicy::Reject(n) | OverflowPolicy::Forget(n) => self.front.len() == *n,
85 OverflowPolicy::Resize => false,
86 }
87 }
88
89 pub fn clear(&mut self) {
90 self.front.clear();
91 }
92
93 pub fn drain<R>(&mut self, range: R) -> vec_deque::Drain<'_, T>
94 where
95 R: ops::RangeBounds<usize>,
96 {
97 self.front.drain(range)
98 }
99}
100
101impl<T> DoubleBufferRx<Message<T>> {
102 pub fn as_acq_time_series<'a>(&'a self) -> RxChannelTimeseries<'a, T> {
103 RxChannelTimeseries {
104 channel: self,
105 kind: TimestampKind::Acq,
106 }
107 }
108
109 pub fn as_pub_time_series<'a>(&'a self) -> RxChannelTimeseries<'a, T> {
110 RxChannelTimeseries {
111 channel: self,
112 kind: TimestampKind::Pub,
113 }
114 }
115}
116
117impl<T> Pop for DoubleBufferRx<T> {
118 type Output = T;
119
120 fn is_empty(&self) -> bool {
121 self.front.is_empty()
122 }
123
124 fn pop(&mut self) -> Result<T, RxRecvError> {
125 self.front.pop().ok_or(RxRecvError::QueueEmtpy)
126 }
127}
128
129impl<T> ops::Index<usize> for DoubleBufferRx<T> {
130 type Output = T;
131
132 fn index(&self, idx: usize) -> &Self::Output {
133 &self.front[idx]
134 }
135}
136
137impl<T> ops::IndexMut<usize> for DoubleBufferRx<T> {
138 fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
139 &mut self.front[idx]
140 }
141}
142
143impl<T: Send + Sync> Rx for DoubleBufferRx<T> {
144 fn is_connected(&self) -> bool {
145 self.is_connected
146 }
147
148 fn sync(&mut self) -> SyncResult {
149 self.back.write().unwrap().sync(&mut self.front)
150 }
151
152 fn len(&self) -> usize {
155 self.front_len()
156 }
157}
158
/// Error returned when popping from a [`DoubleBufferRx`] whose front stage is empty.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RxRecvError {
    /// The front stage held no message to pop.
    ///
    /// The variant name is misspelled ("Emtpy"). It is kept as-is because renaming
    /// it would break every caller that names or matches on it.
    QueueEmtpy,
}

impl fmt::Display for RxRecvError {
    /// Writes a human-readable message: lowercase, no trailing punctuation.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Previously echoed the misspelled variant name ("QueueEmtpy").
            RxRecvError::QueueEmtpy => fmt.write_str("queue is empty"),
        }
    }
}

impl std::error::Error for RxRecvError {}
173
174impl<T: Send + Sync> RxConnectable for DoubleBufferRx<T> {
175 type Message = T;
176
177 fn overflow_policy(&self) -> OverflowPolicy {
178 *self.back.read().unwrap().overflow_policy()
179 }
180
181 fn on_connect(&mut self) -> SharedBackStage<Self::Message> {
182 self.is_connected = true;
183 self.back.clone()
184 }
185}
186
#[cfg(test)]
mod tests {
    use crate::{
        channels::{FlushResult, SyncResult},
        prelude::*,
    };
    use std::sync::mpsc;

    /// Builds a connected sender/receiver pair. The receiver uses
    /// `OverflowPolicy::Reject(size)` and `RetentionPolicy::EnforceEmpty`.
    fn fixed_channel<T: Clone + Send + Sync>(
        size: usize,
    ) -> (DoubleBufferTx<T>, DoubleBufferRx<T>) {
        let mut sender = DoubleBufferTx::new(size);
        let mut receiver =
            DoubleBufferRx::new(OverflowPolicy::Reject(size), RetentionPolicy::EnforceEmpty);
        connect(&mut sender, &mut receiver).unwrap();
        (sender, receiver)
    }

    #[test]
    fn test() {
        const NUM_MESSAGES: usize = 100;
        const NUM_ROUNDS: usize = 100;

        let (mut tx, mut rx) = fixed_channel(NUM_MESSAGES);

        // "go" signal, producer -> consumer: sent after each flush.
        let (go_tx, go_rx) = mpsc::sync_channel(1);
        // "ack" signal, consumer -> producer: sent right after each sync, before the
        // consumer pops. This lets the producer refill the back stage while the
        // consumer drains the front stage.
        let (ack_tx, ack_rx) = mpsc::sync_channel(1);

        let consumer = std::thread::spawn(move || {
            for round in 0..NUM_ROUNDS {
                go_rx.recv().unwrap();

                assert_eq!(
                    rx.sync(),
                    SyncResult {
                        received: NUM_MESSAGES,
                        ..Default::default()
                    }
                );

                ack_tx.send(()).unwrap();

                // Messages must come out in the order they were pushed.
                for idx in 0..NUM_MESSAGES {
                    assert_eq!(rx.pop().unwrap(), format!("hello {round} {idx}"));
                }
            }
        });

        let producer = std::thread::spawn(move || {
            for round in 0..NUM_ROUNDS {
                for idx in 0..NUM_MESSAGES {
                    tx.push(format!("hello {round} {idx}")).unwrap();
                }

                assert_eq!(
                    tx.flush(),
                    FlushResult {
                        available: NUM_MESSAGES,
                        published: NUM_MESSAGES,
                        ..Default::default()
                    }
                );

                go_tx.send(()).unwrap();
                ack_rx.recv().unwrap();
            }
        });

        consumer.join().unwrap();
        producer.join().unwrap();
    }
}