coralstack_cmd_ipc/
channel.rs1use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Arc;
15
16use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
17use futures::channel::oneshot;
18use futures::future::{BoxFuture, Shared};
19use futures::lock::Mutex as AsyncMutex;
20use futures::{FutureExt, StreamExt};
21use parking_lot::Mutex;
22
23use crate::error::ChannelError;
24use crate::message::Message;
25
26pub trait CommandChannel: Send + Sync {
39 fn id(&self) -> &str;
41
42 fn start(&self) -> BoxFuture<'_, Result<(), ChannelError>>;
45
46 fn close(&self) -> BoxFuture<'_, ()>;
50
51 fn send(&self, msg: Message) -> Result<(), ChannelError>;
54
55 fn recv(&self) -> BoxFuture<'_, Option<Message>>;
58}
59
/// In-process [`CommandChannel`] endpoint backed by unbounded `futures` mpsc queues.
///
/// Endpoints are created as a connected pair by [`InMemoryChannel::pair`]: each one
/// writes into the queue its peer reads from, and reads from the queue its peer
/// writes into.
pub struct InMemoryChannel {
    /// Identifier returned by [`CommandChannel::id`].
    id: String,
    /// Sending half of the queue that the peer endpoint reads from.
    outbound: UnboundedSender<Message>,
    /// Receiving half of the queue the peer writes into. Guarded by an async mutex
    /// because `recv` keeps it locked across `.await` points.
    inbound: AsyncMutex<Option<UnboundedReceiver<Message>>>,
    /// Close-signal sender; `close` takes it out and fires it at most once.
    close_tx: Mutex<Option<oneshot::Sender<()>>>,
    /// Shared close-signal receiver; `recv` races a clone of it against the inbound
    /// queue so a pending receive is woken by `close`.
    close_rx: Shared<oneshot::Receiver<()>>,
    /// Set by `close`; `send` and `recv` consult it to refuse work after closing.
    closed: AtomicBool,
}
75
76impl InMemoryChannel {
77 pub fn pair(id_a: impl Into<String>, id_b: impl Into<String>) -> (Arc<Self>, Arc<Self>) {
84 let (tx_a_to_b, rx_b) = unbounded();
85 let (tx_b_to_a, rx_a) = unbounded();
86 let (close_a_tx, close_a_rx) = oneshot::channel::<()>();
87 let (close_b_tx, close_b_rx) = oneshot::channel::<()>();
88 let a = Arc::new(Self {
89 id: id_a.into(),
90 outbound: tx_a_to_b,
91 inbound: AsyncMutex::new(Some(rx_a)),
92 close_tx: Mutex::new(Some(close_a_tx)),
93 close_rx: close_a_rx.shared(),
94 closed: AtomicBool::new(false),
95 });
96 let b = Arc::new(Self {
97 id: id_b.into(),
98 outbound: tx_b_to_a,
99 inbound: AsyncMutex::new(Some(rx_b)),
100 close_tx: Mutex::new(Some(close_b_tx)),
101 close_rx: close_b_rx.shared(),
102 closed: AtomicBool::new(false),
103 });
104 (a, b)
105 }
106}
107
108impl CommandChannel for InMemoryChannel {
109 fn id(&self) -> &str {
110 &self.id
111 }
112
113 fn start(&self) -> BoxFuture<'_, Result<(), ChannelError>> {
114 Box::pin(async { Ok(()) })
115 }
116
117 fn close(&self) -> BoxFuture<'_, ()> {
118 Box::pin(async move {
119 self.closed.store(true, Ordering::SeqCst);
120 self.outbound.close_channel();
121 if let Some(tx) = self.close_tx.lock().take() {
123 let _ = tx.send(());
124 }
125 })
126 }
127
128 fn send(&self, msg: Message) -> Result<(), ChannelError> {
129 if self.closed.load(Ordering::SeqCst) {
130 return Err(ChannelError::Closed);
131 }
132 self.outbound
133 .unbounded_send(msg)
134 .map_err(|e| ChannelError::Send(e.to_string()))
135 }
136
137 fn recv(&self) -> BoxFuture<'_, Option<Message>> {
138 Box::pin(async move {
139 if self.closed.load(Ordering::SeqCst) {
140 return None;
141 }
142 let mut guard = self.inbound.lock().await;
143 let rx = guard.as_mut()?;
144 let close_fut = self.close_rx.clone();
145 futures::select_biased! {
146 msg = rx.next().fuse() => msg,
147 _ = close_fut.fuse() => None,
148 }
149 })
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::MessageId;
    use futures::executor::block_on;
    use futures::future::join;

    /// Builds a minimal request message to push through a channel.
    fn ping(id: MessageId) -> Message {
        Message::ListCommandsRequest { id, meta: None }
    }

    #[test]
    fn pair_sends_in_both_directions() {
        let (a, b) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            assert_eq!(a.id(), "alice");
            assert_eq!(b.id(), "bob");

            let m1 = ping(MessageId::new_v4());
            let m2 = ping(MessageId::new_v4());

            a.send(m1.clone()).unwrap();
            b.send(m2.clone()).unwrap();

            assert_eq!(b.recv().await, Some(m1));
            assert_eq!(a.recv().await, Some(m2));
        });
    }

    #[test]
    fn recv_awaits_future_send() {
        let (a, b) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            let msg = ping(MessageId::new_v4());
            let expected = msg.clone();

            // `join` polls its futures in argument order, so `recv` runs first,
            // finds the queue empty and parks. The send that follows must wake it.
            let (recvd, ()) = join(b.recv(), async {
                a.send(msg).unwrap();
            })
            .await;
            assert_eq!(recvd, Some(expected));
        });
    }

    #[test]
    fn close_stops_recv_on_both_sides() {
        let (a, b) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            a.close().await;
            assert!(a.recv().await.is_none());
            assert!(b.recv().await.is_none());
        });
    }

    #[test]
    fn close_wakes_pending_recv() {
        // Keep `_bob` alive: dropping it would drop the sender feeding alice's
        // inbox, and `recv` would then return `None` without needing the close signal.
        let (alice, _bob) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            // `recv` is polled first and parks on an empty inbox; `close` must wake it.
            let (recvd, ()) = join(alice.recv(), alice.close()).await;
            assert!(recvd.is_none());
        });
    }

    #[test]
    fn close_is_idempotent() {
        let (a, b) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            a.close().await;
            a.close().await;
            assert!(a.recv().await.is_none());
            assert!(b.recv().await.is_none());
        });
        assert!(matches!(
            a.send(ping(MessageId::new_v4())),
            Err(ChannelError::Closed)
        ));
    }

    #[test]
    fn send_after_close_is_error() {
        let (a, _b) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            a.close().await;
        });
        let err = a.send(ping(MessageId::new_v4())).unwrap_err();
        assert!(matches!(err, ChannelError::Closed));
    }

    #[test]
    fn queued_messages_drain_after_peer_close() {
        let (a, b) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            let m = ping(MessageId::new_v4());
            let expected = m.clone();
            a.send(m).unwrap();
            a.close().await;
            // The peer still receives what was queued before the close...
            assert_eq!(b.recv().await, Some(expected));
            // ...and then observes end-of-stream.
            assert!(b.recv().await.is_none());
        });
    }
}