Skip to main content

airio_ws/
lib.rs

1use std::{
2    net::SocketAddr,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use airio_core::{ListenerEvent, Transport, utils::RwStreamSink};
8use airio_tcp::TcpStream;
9use async_tungstenite::{
10    accept_async_with_config, client_async_with_config,
11    tungstenite::{self, protocol::WebSocketConfig},
12};
13use futures::{FutureExt, Stream, TryFutureExt};
14
15use crate::framed::BytesWebSocketStream;
16pub use tungstenite::Error;
17
18mod framed;
19
20#[derive(Debug, Clone)]
21pub struct Config {
22    pub websocket: WebSocketConfig,
23    pub tcp: airio_tcp::Config,
24}
25
26impl Default for Config {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl Config {
33    pub fn new() -> Self {
34        Self {
35            websocket: WebSocketConfig::default(),
36            tcp: airio_tcp::Config::default(),
37        }
38    }
39
40    /// Set [`Self::read_buffer_size`].
41    pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
42        self.websocket.read_buffer_size = read_buffer_size;
43        self
44    }
45
46    /// Set [`Self::write_buffer_size`].
47    pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
48        self.websocket.write_buffer_size = write_buffer_size;
49        self
50    }
51
52    /// Set [`Self::max_write_buffer_size`].
53    pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
54        self.websocket.max_write_buffer_size = max_write_buffer_size;
55        self
56    }
57
58    /// Set [`Self::max_message_size`].
59    pub fn max_message_size(mut self, max_message_size: Option<usize>) -> Self {
60        self.websocket.max_message_size = max_message_size;
61        self
62    }
63
64    /// Set [`Self::max_frame_size`].
65    pub fn max_frame_size(mut self, max_frame_size: Option<usize>) -> Self {
66        self.websocket.max_frame_size = max_frame_size;
67        self
68    }
69
70    /// Set [`Self::accept_unmasked_frames`].
71    pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
72        self.websocket.accept_unmasked_frames = accept_unmasked_frames;
73        self
74    }
75}
76
77type ListenerUpgrade = Pin<
78    Box<dyn Future<Output = Result<RwStreamSink<BytesWebSocketStream<TcpStream>>, Error>> + Send>,
79>;
80
81impl Transport for Config {
82    type Output = RwStreamSink<BytesWebSocketStream<TcpStream>>;
83    type Error = tungstenite::Error;
84    type Dialer = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
85    type ListenerUpgrade = ListenerUpgrade;
86    type Listener = ListenStream;
87
88    fn connect(&self, addr: SocketAddr) -> Result<Self::Dialer, Self::Error> {
89        let dialer = self.tcp.connect(addr)?;
90        let config = self.websocket.clone();
91        let request = tungstenite::http::Uri::builder()
92            .scheme("ws")
93            .authority(addr.to_string())
94            .path_and_query("/")
95            .build()
96            .map_err(tungstenite::Error::from)?;
97        tracing::debug!("Connecting to WebSocket at {}", request);
98        Ok(dialer
99            .map_err(tungstenite::Error::from)
100            .and_then(move |stream| client_async_with_config(request, stream, Some(config)))
101            .map_ok(|(s, response)| {
102                tracing::debug!("WebSocket handshake response: {:?}", response);
103                BytesWebSocketStream::new(s)
104            })
105            .map_ok(RwStreamSink::new)
106            .boxed())
107    }
108
109    fn listen(&self, addr: SocketAddr) -> Result<Self::Listener, Self::Error> {
110        let listener = self.tcp.listen(addr)?;
111        tracing::debug!("Listening for WebSocket connections on {}", addr);
112        Ok(ListenStream {
113            config: self.websocket.clone(),
114            inner: listener,
115        })
116    }
117}
118
119pub struct ListenStream {
120    config: WebSocketConfig,
121    inner: airio_tcp::ListenStream,
122}
123
124impl Stream for ListenStream {
125    type Item = ListenerEvent<ListenerUpgrade, Error>;
126
127    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128        let config = self.config.clone();
129        let event = match Pin::new(&mut self.inner).poll_next(cx) {
130            Poll::Ready(Some(event)) => event
131                .map_upgrade(|u| {
132                    u.map_err(Error::from)
133                        .and_then(move |stream| accept_async_with_config(stream, Some(config)))
134                        .map_ok(BytesWebSocketStream::new)
135                        .map_ok(RwStreamSink::new)
136                        .boxed()
137                })
138                .map_err(Error::from),
139            Poll::Ready(None) => return Poll::Ready(None),
140            Poll::Pending => return Poll::Pending,
141        };
142        Poll::Ready(Some(event))
143    }
144}