1use std::io;
4
5use futures_util::{SinkExt, StreamExt};
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::sync::mpsc;
8use tokio::task::JoinHandle;
9use tokio_tungstenite::WebSocketStream;
10use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
11
12use vox_types::{Backing, Link, LinkRx, LinkTx, LinkTxPermit, WriteSlot};
13
14pub struct WsLink<S> {
23 stream: WebSocketStream<S>,
24}
25
26impl<S> WsLink<S> {
27 pub fn new(stream: WebSocketStream<S>) -> Self {
29 Self { stream }
30 }
31}
32
33impl<S> WsLink<S>
34where
35 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
36{
37 pub async fn server(stream: S) -> Result<Self, io::Error> {
39 let ws = tokio_tungstenite::accept_async(stream)
40 .await
41 .map_err(|e| io::Error::other(e.to_string()))?;
42 Ok(Self::new(ws))
43 }
44}
45
46impl<S> Link for WsLink<S>
47where
48 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
49{
50 type Tx = WsLinkTx;
51 type Rx = WsLinkRx;
52
53 fn split(self) -> (Self::Tx, Self::Rx) {
54 let (tx_out, rx_out) = mpsc::channel::<Vec<u8>>(1);
55 let (tx_in, rx_in) = mpsc::channel::<Result<WsMessage, io::Error>>(1);
56
57 let io_task = tokio::spawn(io_loop(self.stream, rx_out, tx_in));
58
59 (
60 WsLinkTx {
61 tx: tx_out,
62 io_task,
63 },
64 WsLinkRx { rx: rx_in },
65 )
66 }
67}
68
69async fn io_loop<S>(
75 mut ws: WebSocketStream<S>,
76 mut rx_out: mpsc::Receiver<Vec<u8>>,
77 tx_in: mpsc::Sender<Result<WsMessage, io::Error>>,
78) where
79 S: AsyncRead + AsyncWrite + Unpin,
80{
81 loop {
82 tokio::select! {
83 msg = rx_out.recv() => {
85 match msg {
86 Some(bytes) => {
87 if let Err(e) = ws.feed(WsMessage::binary(bytes)).await {
88 let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
89 return;
90 }
91 while let Ok(bytes) = rx_out.try_recv() {
93 if let Err(e) = ws.feed(WsMessage::binary(bytes)).await {
94 let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
95 return;
96 }
97 }
98 if let Err(e) = ws.flush().await {
99 let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
100 return;
101 }
102 }
103 None => {
104 return;
108 }
109 }
110 }
111 frame = ws.next() => {
113 match frame {
114 Some(Ok(msg)) => {
115 if tx_in.send(Ok(msg)).await.is_err() {
116 return;
118 }
119 }
120 Some(Err(e)) => {
121 use tokio_tungstenite::tungstenite::error::ProtocolError;
122 use tokio_tungstenite::tungstenite::Error as WsError;
123 match &e {
124 WsError::Protocol(
127 ProtocolError::ResetWithoutClosingHandshake,
128 ) => {
129 return;
130 }
131 _ => {
132 let _ = tx_in.send(Err(io::Error::other(e.to_string()))).await;
133 return;
134 }
135 }
136 }
137 None => {
138 return;
140 }
141 }
142 }
143 }
144 }
145}
146
147pub struct WsLinkTx {
157 tx: mpsc::Sender<Vec<u8>>,
158 io_task: JoinHandle<()>,
159}
160
161pub struct WsLinkTxPermit {
163 permit: mpsc::OwnedPermit<Vec<u8>>,
164}
165
166pub struct WsWriteSlot {
168 buf: Vec<u8>,
169 permit: mpsc::OwnedPermit<Vec<u8>>,
170}
171
172impl LinkTx for WsLinkTx {
173 type Permit = WsLinkTxPermit;
174
175 async fn reserve(&self) -> io::Result<Self::Permit> {
176 let permit = self.tx.clone().reserve_owned().await.map_err(|_| {
177 io::Error::new(
178 io::ErrorKind::ConnectionReset,
179 "websocket writer task stopped",
180 )
181 })?;
182 Ok(WsLinkTxPermit { permit })
183 }
184
185 async fn close(self) -> io::Result<()> {
186 drop(self.tx);
187 self.io_task.await.map_err(io::Error::other)
188 }
189}
190
191impl LinkTxPermit for WsLinkTxPermit {
193 type Slot = WsWriteSlot;
194
195 fn alloc(self, len: usize) -> io::Result<Self::Slot> {
196 Ok(WsWriteSlot {
197 buf: vec![0u8; len],
198 permit: self.permit,
199 })
200 }
201}
202
203impl WriteSlot for WsWriteSlot {
204 fn as_mut_slice(&mut self) -> &mut [u8] {
205 &mut self.buf
206 }
207
208 fn commit(self) {
209 drop(self.permit.send(self.buf));
210 }
211}
212
213pub struct WsLinkRx {
219 rx: mpsc::Receiver<Result<WsMessage, io::Error>>,
220}
221
/// Error returned by the receiving half of a WebSocket link.
///
/// Wraps the underlying [`io::Error`], which is also exposed through
/// [`std::error::Error::source`].
#[derive(Debug)]
pub struct WsLinkRxError(io::Error);

impl std::fmt::Display for WsLinkRxError {
    /// Formats as `websocket rx: ` followed by the wrapped error's message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        f.write_fmt(format_args!("websocket rx: {inner}"))
    }
}

impl std::error::Error for WsLinkRxError {
    /// Returns the wrapped I/O error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let Self(inner) = self;
        let cause: &(dyn std::error::Error + 'static) = inner;
        Some(cause)
    }
}
237
238impl LinkRx for WsLinkRx {
240 type Error = WsLinkRxError;
241
242 async fn recv(&mut self) -> Result<Option<Backing>, Self::Error> {
243 loop {
244 match self.rx.recv().await {
245 Some(Ok(WsMessage::Binary(data))) => {
246 return Ok(Some(Backing::Boxed(Vec::from(data).into_boxed_slice())));
247 }
248 Some(Ok(WsMessage::Close(_))) | None => {
249 return Ok(None);
250 }
251 Some(Ok(WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_))) => {
252 continue;
253 }
254 Some(Ok(WsMessage::Text(_))) => {
255 return Err(WsLinkRxError(io::Error::new(
256 io::ErrorKind::InvalidData,
257 "text frames not allowed on vox websocket link",
258 )));
259 }
260 Some(Err(e)) => {
261 return Err(WsLinkRxError(e));
262 }
263 }
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use tokio_tungstenite::WebSocketStream;
    use tokio_tungstenite::tungstenite::protocol::Role;
    use vox_types::{Backing, Link, LinkRx, LinkTx, LinkTxPermit, WriteSlot};

    use super::*;

    type TestWsLink = WsLink<tokio::io::DuplexStream>;

    /// Builds two connected links over an in-memory duplex pipe, one in the
    /// server role and one in the client role.
    async fn ws_pair() -> (TestWsLink, TestWsLink) {
        let (server_io, client_io) = tokio::io::duplex(64 * 1024);
        let server_ws = WebSocketStream::from_raw_socket(server_io, Role::Server, None).await;
        let client_ws = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;
        (WsLink::new(server_ws), WsLink::new(client_ws))
    }

    /// Borrows the bytes held by either `Backing` variant.
    fn payload(backing: &Backing) -> &[u8] {
        match backing {
            Backing::Boxed(boxed) => boxed,
            Backing::Shared(shared) => shared.as_bytes(),
        }
    }

    /// Reserves a permit, fills a slot with `bytes` and commits it.
    async fn send_bytes(tx: &WsLinkTx, bytes: &[u8]) {
        let permit = tx.reserve().await.unwrap();
        let mut slot = permit.alloc(bytes.len()).unwrap();
        slot.as_mut_slice().copy_from_slice(bytes);
        slot.commit();
    }

    /// Receives one payload, panicking on error or end of stream.
    async fn recv_payload(rx: &mut WsLinkRx) -> Backing {
        rx.recv().await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn round_trip_single() {
        let (a, b) = ws_pair().await;
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        send_bytes(&tx_a, b"hello").await;

        let msg = recv_payload(&mut rx_b).await;
        assert_eq!(payload(&msg), b"hello");
    }

    #[tokio::test]
    async fn multiple_messages_in_order() {
        let (a, b) = ws_pair().await;
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        let payloads: &[&[u8]] = &[b"one", b"two", b"three", b"four"];
        for &chunk in payloads {
            send_bytes(&tx_a, chunk).await;
        }

        for &expected in payloads {
            let msg = recv_payload(&mut rx_b).await;
            assert_eq!(payload(&msg), expected);
        }
    }

    #[tokio::test]
    async fn empty_payload() {
        let (a, b) = ws_pair().await;
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        send_bytes(&tx_a, b"").await;

        let msg = recv_payload(&mut rx_b).await;
        assert_eq!(payload(&msg), b"");
    }

    #[tokio::test]
    async fn eof_on_peer_close() {
        let (a, b) = ws_pair().await;
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        tx_a.close().await.unwrap();

        // End of stream must be reported, and keep being reported.
        for _ in 0..2 {
            assert!(rx_b.recv().await.unwrap().is_none());
        }
    }

    #[tokio::test]
    async fn dropped_permit_sends_nothing() {
        let (a, b) = ws_pair().await;
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        // Reserve and abandon a permit without allocating a slot.
        let abandoned = tx_a.reserve().await.unwrap();
        drop(abandoned);

        send_bytes(&tx_a, b"yep").await;

        // The first thing the peer sees is the committed payload.
        let msg = recv_payload(&mut rx_b).await;
        assert_eq!(payload(&msg), b"yep");
    }

    #[tokio::test]
    async fn dropped_slot_sends_nothing() {
        let (a, b) = ws_pair().await;
        let (tx_a, _rx_a) = a.split();
        let (_tx_b, mut rx_b) = b.split();

        // Allocate a slot but never commit it.
        let abandoned = tx_a.reserve().await.unwrap().alloc(3).unwrap();
        drop(abandoned);

        send_bytes(&tx_a, b"ok").await;

        // The first thing the peer sees is the committed payload.
        let msg = recv_payload(&mut rx_b).await;
        assert_eq!(payload(&msg), b"ok");
    }
}