nt_rs/
backend.rs

1use core::panic;
2use futures::{sink::SinkExt, stream::StreamExt, FutureExt};
3use http::{header::SEC_WEBSOCKET_PROTOCOL, Request};
4use std::{io::Cursor, str::FromStr};
5use tokio::{select, task::JoinHandle};
6use tokio_tungstenite::connect_async;
7use tungstenite::{handshake::client::generate_key, Message};
8
9use http::Uri;
10
11use crate::{
12    types::{BinaryMessage, TextMessage},
13    Backend, Error, Result, Timer,
14};
15
/// NetworkTables transport built on the Tokio runtime and `tokio-tungstenite`.
///
/// The type holds no state; it exists only to carry the [`Backend`] and [`Timer`]
/// implementations.
pub struct TokioBackend;
17
18impl Backend for TokioBackend {
19    type Output = JoinHandle<()>;
20    type Error = crate::Error;
21
22    fn create(
23        host: &str,
24        name: &str,
25        send: flume::Sender<Result<crate::NtMessage>>,
26        receive: flume::Receiver<crate::NtMessage>,
27    ) -> Result<Self::Output> {
28        let uri = Uri::from_str(&format!("ws://{host}:5810/nt/{name}"))?;
29
30        let send2 = send.clone();
31
32        Ok(tokio::spawn(async move {
33            let req = Request::builder()
34                .method("GET")
35                .header("Host", uri.host().unwrap())
36                .header("Connection", "Upgrade")
37                .header("Upgrade", "websocket")
38                .header("Sec-WebSocket-Version", "13")
39                .header("Sec-WebSocket-Key", generate_key())
40                .header("Sec-WebSocket-Protocol", "networktables.first.wpi.edu")
41                .uri(uri)
42                .body(())?;
43
44            let (mut connection, res) = connect_async(req).await?;
45
46            if res
47                .headers()
48                .get(SEC_WEBSOCKET_PROTOCOL)
49                .ok_or(Error::UnsupportedServer)?
50                != "networktables.first.wpi.edu"
51            {
52                return Err(Error::UnsupportedServer);
53            }
54
55            loop {
56                select! {
57                    message = receive.recv_async() => {
58                        let message = message?;
59
60                        match message {
61                            crate::NtMessage::Text(msg) => connection.send(Message::Text(serde_json::to_string(&[msg])?)).await?,
62                            crate::NtMessage::Binary(msg) => {
63                                let mut buf = Vec::new();
64                                msg.to_writer(&mut buf)?;
65                                connection.send(Message::Binary(buf)).await?
66                            },
67                        }
68                    }
69                    message = connection.next() => {
70                        if message.is_none() {
71                            return Ok(());
72                        }
73                        let message = message.unwrap()?;
74
75                        match message {
76                            Message::Text(msg) => {
77                                let msgs = serde_json::from_str::<Vec<TextMessage>>(&msg)?;
78                                for msg in msgs {
79                                    send.send(Ok(crate::NtMessage::Text(msg))).map_err(|_| Error::Send)?;
80                                }
81                            }
82                            Message::Binary(msg) => {
83                                let mut cursor = Cursor::new(msg);
84
85                                while (cursor.position() as usize) < cursor.get_ref().len() {
86                                    send.send(Ok(crate::NtMessage::Binary(BinaryMessage::from_reader(&mut cursor)?))).map_err(|_| Error::Send)?;
87                                }
88                            }
89                            _ => return <Result<()>>::Err(Error::UnknownFrame),
90                        }
91                    }
92                }
93            }
94        }.map(move |out| {
95            if let Err(err) = out {
96                let _res = send2.send(Err(err));
97            }
98        })))
99    }
100}
101
102impl Timer for TokioBackend {
103    async fn time(duration: std::time::Duration) {
104        tokio::time::sleep(duration).await;
105    }
106}