1pub mod stream;
11
12#[cfg(feature = "rustls")]
13pub mod rustls;
14
15use std::io::ErrorKind;
16
17use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
18use tungstenite::{
19 Error as WsError, HandshakeError, Message, WebSocket,
20 client::IntoClientRequest,
21 handshake::server::{Callback, NoCallback},
22 protocol::CloseFrame,
23};
24pub use tungstenite::{
25 Message as WebSocketMessage, error::Error as TungsteniteError, handshake::client::Response,
26 protocol::WebSocketConfig,
27};
28
29#[cfg(feature = "rustls")]
30pub use crate::rustls::{
31 AutoStream, ConnectStream, Connector, client_async_tls, client_async_tls_with_config,
32 client_async_tls_with_connector, client_async_tls_with_connector_and_config, connect_async,
33 connect_async_with_config, connect_async_with_tls_connector,
34 connect_async_with_tls_connector_and_config,
35};
36
37pub struct WebSocketStream<S> {
38 inner: WebSocket<SyncStream<S>>,
39}
40
41impl<S> WebSocketStream<S>
42where
43 S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
44{
45 pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
46 self.inner.send(message)?;
49
50 self.inner
52 .get_mut()
53 .flush_write_buf()
54 .await
55 .map_err(WsError::Io)?;
56
57 Ok(())
58 }
59
60 pub async fn read(&mut self) -> Result<Message, WsError> {
61 loop {
62 match self.inner.read() {
63 Ok(msg) => {
64 let _ = self.inner.get_mut().flush_write_buf().await;
66 return Ok(msg);
67 }
68 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
69 self.inner
71 .get_mut()
72 .fill_read_buf()
73 .await
74 .map_err(WsError::Io)?;
75 continue;
77 }
78 Err(e) => {
79 let _ = self.inner.get_mut().flush_write_buf().await;
81 return Err(e);
82 }
83 }
84 }
85 }
86
87 pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
88 loop {
89 match self.inner.close(close_frame.clone()) {
90 Ok(()) => return Ok(()),
91 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
92 let sync_stream = self.inner.get_mut();
93
94 let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
95
96 if flushed == 0 {
97 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
98 }
99 continue;
100 }
101 Err(e) => return Err(e),
102 }
103 }
104 }
105
106 pub fn get_ref(&self) -> &S {
107 self.inner.get_ref().get_ref()
108 }
109
110 pub fn get_mut(&mut self) -> &mut S {
111 self.inner.get_mut().get_mut()
112 }
113
114 pub fn get_inner(self) -> WebSocket<SyncStream<S>> {
115 self.inner
116 }
117}
118
119pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
131where
132 S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
133{
134 accept_hdr_async(stream, NoCallback).await
135}
136
137pub async fn accept_async_with_config<S>(
140 stream: S,
141 config: Option<WebSocketConfig>,
142) -> Result<WebSocketStream<S>, WsError>
143where
144 S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
145{
146 accept_hdr_with_config_async(stream, NoCallback, config).await
147}
148pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
154where
155 S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
156 C: Callback,
157{
158 accept_hdr_with_config_async(stream, callback, None).await
159}
160
161pub async fn accept_hdr_with_config_async<S, C>(
164 stream: S,
165 callback: C,
166 config: Option<WebSocketConfig>,
167) -> Result<WebSocketStream<S>, WsError>
168where
169 S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
170 C: Callback,
171{
172 let sync_stream = SyncStream::new(stream);
173 let mut handshake_result = tungstenite::accept_hdr_with_config(sync_stream, callback, config);
174
175 loop {
176 match handshake_result {
177 Ok(mut websocket) => {
178 websocket
179 .get_mut()
180 .flush_write_buf()
181 .await
182 .map_err(WsError::Io)?;
183 return Ok(WebSocketStream { inner: websocket });
184 }
185 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
186 let sync_stream = mid_handshake.get_mut().get_mut();
187
188 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
189
190 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
191
192 handshake_result = mid_handshake.handshake();
193 }
194 Err(HandshakeError::Failure(error)) => {
195 return Err(error);
196 }
197 }
198 }
199}
200
201pub async fn client_async<R, S>(
215 request: R,
216 stream: S,
217) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
218where
219 R: IntoClientRequest,
220 S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
221{
222 client_async_with_config(request, stream, None).await
223}
224
225pub async fn client_async_with_config<R, S>(
228 request: R,
229 stream: S,
230 config: Option<WebSocketConfig>,
231) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
232where
233 R: IntoClientRequest,
234 S: AsyncRead + AsyncWrite + Unpin,
235{
236 let sync_stream = SyncStream::new(stream);
237 let mut handshake_result =
238 tungstenite::client::client_with_config(request, sync_stream, config);
239
240 loop {
241 match handshake_result {
242 Ok((mut websocket, response)) => {
243 websocket
245 .get_mut()
246 .flush_write_buf()
247 .await
248 .map_err(WsError::Io)?;
249 return Ok((WebSocketStream { inner: websocket }, response));
250 }
251 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
252 let sync_stream = mid_handshake.get_mut().get_mut();
253
254 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
256
257 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
258
259 handshake_result = mid_handshake.handshake();
260 }
261 Err(HandshakeError::Failure(error)) => {
262 return Err(error);
263 }
264 }
265 }
266}
267
/// Returns the host of `request`'s URI as an owned string. Only compiled with
/// the `rustls` feature.
///
/// An IPv6 literal is returned without its surrounding brackets
/// (`[::1]` becomes `::1`). Any other host is returned unchanged.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` when the URI has no
/// host.
#[inline]
// The error type is tungstenite's, which is large. The signature is kept as-is
// for the crate-internal callers.
#[allow(clippy::result_large_err)]
#[cfg(feature = "rustls")]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip the brackets only when both are present. Unlike index slicing, this
    // cannot panic or cut off an unrelated final character.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}