1#![cfg_attr(docsrs, feature(doc_cfg))]
11#![warn(missing_docs)]
12
13use std::io::ErrorKind;
14
15use compio_buf::IntoInner;
16use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
17use tungstenite::{
18 Error as WsError, HandshakeError, Message, WebSocket,
19 client::IntoClientRequest,
20 handshake::server::{Callback, NoCallback},
21 protocol::{CloseFrame, WebSocketConfig},
22};
23
24mod tls;
25pub use tls::*;
26pub use tungstenite;
27
28pub struct Config {
39 websocket: Option<WebSocketConfig>,
41
42 buffer_size_base: usize,
44
45 buffer_size_limit: usize,
47
48 disable_nagle: bool,
51}
52
53impl Config {
54 const DEFAULT_BUF_SIZE: usize = 128 * 1024;
56 const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
58
59 pub fn new() -> Self {
61 Self {
62 websocket: None,
63 buffer_size_base: Self::DEFAULT_BUF_SIZE,
64 buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
65 disable_nagle: false,
66 }
67 }
68
69 pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
71 self.websocket.as_ref()
72 }
73
74 pub fn buffer_size_base(&self) -> usize {
76 self.buffer_size_base
77 }
78
79 pub fn buffer_size_limit(&self) -> usize {
81 self.buffer_size_limit
82 }
83
84 pub fn with_buffer_size_base(mut self, size: usize) -> Self {
88 self.buffer_size_base = size;
89 self
90 }
91
92 pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
96 self.buffer_size_limit = size;
97 self
98 }
99
100 pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
104 self.buffer_size_base = base;
105 self.buffer_size_limit = limit;
106 self
107 }
108
109 pub fn disable_nagle(mut self, disable: bool) -> Self {
114 self.disable_nagle = disable;
115 self
116 }
117}
118
119impl Default for Config {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl From<WebSocketConfig> for Config {
126 fn from(config: WebSocketConfig) -> Self {
127 Self {
128 websocket: Some(config),
129 ..Default::default()
130 }
131 }
132}
133
134impl From<Option<WebSocketConfig>> for Config {
135 fn from(config: Option<WebSocketConfig>) -> Self {
136 Self {
137 websocket: config,
138 ..Default::default()
139 }
140 }
141}
142
/// An async WebSocket connection over a stream `S`.
///
/// Internally this is a tungstenite [`WebSocket`] driving a [`SyncStream`]
/// adapter. tungstenite's synchronous calls report `WouldBlock` when the
/// adapter's buffers need servicing, and the async methods on this type then
/// fill or flush those buffers before retrying.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    // tungstenite socket running on the buffered sync adapter around `S`.
    inner: WebSocket<SyncStream<S>>,
}
148
149impl<S> WebSocketStream<S>
150where
151 S: AsyncRead + AsyncWrite,
152{
153 pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
155 self.inner.send(message)?;
158
159 self.flush().await
161 }
162
163 pub async fn read(&mut self) -> Result<Message, WsError> {
165 loop {
166 match self.inner.read() {
167 Ok(msg) => {
168 self.flush().await?;
169 return Ok(msg);
170 }
171 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
172 self.inner
174 .get_mut()
175 .fill_read_buf()
176 .await
177 .map_err(WsError::Io)?;
178 }
179 Err(e) => {
180 let _ = self.flush().await;
181 return Err(e);
182 }
183 }
184 }
185 }
186
187 pub async fn flush(&mut self) -> Result<(), WsError> {
189 loop {
190 match self.inner.flush() {
191 Ok(()) => break,
192 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
193 self.inner
194 .get_mut()
195 .flush_write_buf()
196 .await
197 .map_err(WsError::Io)?;
198 }
199 Err(WsError::ConnectionClosed) => break,
200 Err(e) => return Err(e),
201 }
202 }
203 self.inner
204 .get_mut()
205 .flush_write_buf()
206 .await
207 .map_err(WsError::Io)?;
208 Ok(())
209 }
210
211 pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
213 loop {
214 match self.inner.close(close_frame.clone()) {
215 Ok(()) => break,
216 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
217 let sync_stream = self.inner.get_mut();
218
219 let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
220
221 if flushed == 0 {
222 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
223 }
224 }
225 Err(WsError::ConnectionClosed) => break,
226 Err(e) => return Err(e),
227 }
228 }
229 self.flush().await
230 }
231
232 pub fn get_ref(&self) -> &S {
234 self.inner.get_ref().get_ref()
235 }
236
237 pub fn get_mut(&mut self) -> &mut S {
239 self.inner.get_mut().get_mut()
240 }
241}
242
243impl<S> IntoInner for WebSocketStream<S> {
244 type Inner = WebSocket<SyncStream<S>>;
245
246 fn into_inner(self) -> Self::Inner {
247 self.inner
248 }
249}
250
251pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
262where
263 S: AsyncRead + AsyncWrite,
264{
265 accept_hdr_async(stream, NoCallback).await
266}
267
268pub async fn accept_async_with_config<S>(
270 stream: S,
271 config: impl Into<Config>,
272) -> Result<WebSocketStream<S>, WsError>
273where
274 S: AsyncRead + AsyncWrite,
275{
276 accept_hdr_with_config_async(stream, NoCallback, config).await
277}
278pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
284where
285 S: AsyncRead + AsyncWrite,
286 C: Callback,
287{
288 accept_hdr_with_config_async(stream, callback, None).await
289}
290
291pub async fn accept_hdr_with_config_async<S, C>(
293 stream: S,
294 callback: C,
295 config: impl Into<Config>,
296) -> Result<WebSocketStream<S>, WsError>
297where
298 S: AsyncRead + AsyncWrite,
299 C: Callback,
300{
301 let config = config.into();
302 let sync_stream =
303 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
304 let mut handshake_result =
305 tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
306
307 loop {
308 match handshake_result {
309 Ok(mut websocket) => {
310 websocket
311 .get_mut()
312 .flush_write_buf()
313 .await
314 .map_err(WsError::Io)?;
315 return Ok(WebSocketStream { inner: websocket });
316 }
317 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
318 let sync_stream = mid_handshake.get_mut().get_mut();
319
320 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
321
322 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
323
324 handshake_result = mid_handshake.handshake();
325 }
326 Err(HandshakeError::Failure(error)) => {
327 return Err(error);
328 }
329 }
330 }
331}
332
333pub async fn client_async<R, S>(
347 request: R,
348 stream: S,
349) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
350where
351 R: IntoClientRequest,
352 S: AsyncRead + AsyncWrite,
353{
354 client_async_with_config(request, stream, None).await
355}
356
357pub async fn client_async_with_config<R, S>(
359 request: R,
360 stream: S,
361 config: impl Into<Config>,
362) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
363where
364 R: IntoClientRequest,
365 S: AsyncRead + AsyncWrite,
366{
367 let config = config.into();
368 let sync_stream =
369 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
370 let mut handshake_result =
371 tungstenite::client::client_with_config(request, sync_stream, config.websocket);
372
373 loop {
374 match handshake_result {
375 Ok((mut websocket, response)) => {
376 websocket
378 .get_mut()
379 .flush_write_buf()
380 .await
381 .map_err(WsError::Io)?;
382 return Ok((WebSocketStream { inner: websocket }, response));
383 }
384 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
385 let sync_stream = mid_handshake.get_mut().get_mut();
386
387 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
389
390 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
391
392 handshake_result = mid_handshake.handshake();
393 }
394 Err(HandshakeError::Failure(error)) => {
395 return Err(error);
396 }
397 }
398 }
399}