1use std::{
2 net::SocketAddr,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use airio_core::{ListenerEvent, Transport, utils::RwStreamSink};
8use airio_tcp::TcpStream;
9use async_tungstenite::{
10 accept_async_with_config, client_async_with_config,
11 tungstenite::{self, protocol::WebSocketConfig},
12};
13use futures::{FutureExt, Stream, TryFutureExt};
14
15use crate::framed::BytesWebSocketStream;
16pub use tungstenite::Error;
17
18mod framed;
19
20#[derive(Debug, Clone)]
21pub struct Config {
22 pub websocket: WebSocketConfig,
23 pub tcp: airio_tcp::Config,
24}
25
26impl Default for Config {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl Config {
33 pub fn new() -> Self {
34 Self {
35 websocket: WebSocketConfig::default(),
36 tcp: airio_tcp::Config::default(),
37 }
38 }
39
40 pub fn read_buffer_size(mut self, read_buffer_size: usize) -> Self {
42 self.websocket.read_buffer_size = read_buffer_size;
43 self
44 }
45
46 pub fn write_buffer_size(mut self, write_buffer_size: usize) -> Self {
48 self.websocket.write_buffer_size = write_buffer_size;
49 self
50 }
51
52 pub fn max_write_buffer_size(mut self, max_write_buffer_size: usize) -> Self {
54 self.websocket.max_write_buffer_size = max_write_buffer_size;
55 self
56 }
57
58 pub fn max_message_size(mut self, max_message_size: Option<usize>) -> Self {
60 self.websocket.max_message_size = max_message_size;
61 self
62 }
63
64 pub fn max_frame_size(mut self, max_frame_size: Option<usize>) -> Self {
66 self.websocket.max_frame_size = max_frame_size;
67 self
68 }
69
70 pub fn accept_unmasked_frames(mut self, accept_unmasked_frames: bool) -> Self {
72 self.websocket.accept_unmasked_frames = accept_unmasked_frames;
73 self
74 }
75}
76
77type ListenerUpgrade = Pin<
78 Box<dyn Future<Output = Result<RwStreamSink<BytesWebSocketStream<TcpStream>>, Error>> + Send>,
79>;
80
81impl Transport for Config {
82 type Output = RwStreamSink<BytesWebSocketStream<TcpStream>>;
83 type Error = tungstenite::Error;
84 type Dialer = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
85 type ListenerUpgrade = ListenerUpgrade;
86 type Listener = ListenStream;
87
88 fn connect(&self, addr: SocketAddr) -> Result<Self::Dialer, Self::Error> {
89 let dialer = self.tcp.connect(addr)?;
90 let config = self.websocket.clone();
91 let request = tungstenite::http::Uri::builder()
92 .scheme("ws")
93 .authority(addr.to_string())
94 .path_and_query("/")
95 .build()
96 .map_err(tungstenite::Error::from)?;
97 tracing::debug!("Connecting to WebSocket at {}", request);
98 Ok(dialer
99 .map_err(tungstenite::Error::from)
100 .and_then(move |stream| client_async_with_config(request, stream, Some(config)))
101 .map_ok(|(s, response)| {
102 tracing::debug!("WebSocket handshake response: {:?}", response);
103 BytesWebSocketStream::new(s)
104 })
105 .map_ok(RwStreamSink::new)
106 .boxed())
107 }
108
109 fn listen(&self, addr: SocketAddr) -> Result<Self::Listener, Self::Error> {
110 let listener = self.tcp.listen(addr)?;
111 tracing::debug!("Listening for WebSocket connections on {}", addr);
112 Ok(ListenStream {
113 config: self.websocket.clone(),
114 inner: listener,
115 })
116 }
117}
118
119pub struct ListenStream {
120 config: WebSocketConfig,
121 inner: airio_tcp::ListenStream,
122}
123
124impl Stream for ListenStream {
125 type Item = ListenerEvent<ListenerUpgrade, Error>;
126
127 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128 let config = self.config.clone();
129 let event = match Pin::new(&mut self.inner).poll_next(cx) {
130 Poll::Ready(Some(event)) => event
131 .map_upgrade(|u| {
132 u.map_err(Error::from)
133 .and_then(move |stream| accept_async_with_config(stream, Some(config)))
134 .map_ok(BytesWebSocketStream::new)
135 .map_ok(RwStreamSink::new)
136 .boxed()
137 })
138 .map_err(Error::from),
139 Poll::Ready(None) => return Poll::Ready(None),
140 Poll::Pending => return Poll::Pending,
141 };
142 Poll::Ready(Some(event))
143 }
144}