1#![cfg_attr(docsrs, feature(doc_cfg))]
13#![warn(missing_docs)]
14#![deny(rustdoc::broken_intra_doc_links)]
15#![doc(
16 html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
17)]
18#![doc(
19 html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
20)]
21
22use std::io::ErrorKind;
23
24use compio_buf::IntoInner;
25use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
26use tungstenite::{
27 Error as WsError, HandshakeError, Message, WebSocket,
28 client::IntoClientRequest,
29 handshake::server::{Callback, NoCallback},
30 protocol::{CloseFrame, Role, WebSocketConfig},
31};
32
33mod tls;
34pub use tls::*;
35pub use tungstenite;
36
37pub struct Config {
48 websocket: Option<WebSocketConfig>,
50
51 buffer_size_base: usize,
53
54 buffer_size_limit: usize,
56
57 disable_nagle: bool,
60}
61
62impl Config {
63 const DEFAULT_BUF_SIZE: usize = 128 * 1024;
65 const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
67
68 pub fn new() -> Self {
70 Self {
71 websocket: None,
72 buffer_size_base: Self::DEFAULT_BUF_SIZE,
73 buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
74 disable_nagle: false,
75 }
76 }
77
78 pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
80 self.websocket.as_ref()
81 }
82
83 pub fn buffer_size_base(&self) -> usize {
85 self.buffer_size_base
86 }
87
88 pub fn buffer_size_limit(&self) -> usize {
90 self.buffer_size_limit
91 }
92
93 pub fn with_buffer_size_base(mut self, size: usize) -> Self {
97 self.buffer_size_base = size;
98 self
99 }
100
101 pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
105 self.buffer_size_limit = size;
106 self
107 }
108
109 pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
113 self.buffer_size_base = base;
114 self.buffer_size_limit = limit;
115 self
116 }
117
118 pub fn disable_nagle(mut self, disable: bool) -> Self {
123 self.disable_nagle = disable;
124 self
125 }
126}
127
128impl Default for Config {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl From<WebSocketConfig> for Config {
135 fn from(config: WebSocketConfig) -> Self {
136 Self {
137 websocket: Some(config),
138 ..Default::default()
139 }
140 }
141}
142
143impl From<Option<WebSocketConfig>> for Config {
144 fn from(config: Option<WebSocketConfig>) -> Self {
145 Self {
146 websocket: config,
147 ..Default::default()
148 }
149 }
150}
151
/// A WebSocket stream running on top of an async stream `S`.
///
/// Wraps a tungstenite [`WebSocket`] that performs its I/O through a
/// [`SyncStream`] adapter; the async methods refill or drain the adapter's
/// buffers whenever the synchronous tungstenite calls report `WouldBlock`.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    // tungstenite protocol state; all socket I/O goes through the SyncStream buffers.
    inner: WebSocket<SyncStream<S>>,
}
157
158impl<S> WebSocketStream<S>
159where
160 S: AsyncRead + AsyncWrite,
161{
162 pub async fn from_raw_socket(stream: S, role: Role, config: impl Into<Config>) -> Self {
169 let config = config.into();
170 let sync_stream =
171 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
172 WebSocketStream {
173 inner: WebSocket::from_raw_socket(sync_stream, role, config.websocket),
174 }
175 }
176
177 pub async fn from_partially_read(
184 stream: S,
185 part: Vec<u8>,
186 role: Role,
187 config: impl Into<Config>,
188 ) -> Self {
189 let config = config.into();
190 let sync_stream =
191 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
192 WebSocketStream {
193 inner: WebSocket::from_partially_read(sync_stream, part, role, config.websocket),
194 }
195 }
196
197 pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
199 self.inner.send(message)?;
202
203 self.flush().await
205 }
206
207 pub async fn read(&mut self) -> Result<Message, WsError> {
209 loop {
210 match self.inner.read() {
211 Ok(msg) => {
212 self.flush().await?;
213 return Ok(msg);
214 }
215 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
216 self.inner
218 .get_mut()
219 .fill_read_buf()
220 .await
221 .map_err(WsError::Io)?;
222 }
223 Err(e) => {
224 let _ = self.flush().await;
225 return Err(e);
226 }
227 }
228 }
229 }
230
231 pub async fn flush(&mut self) -> Result<(), WsError> {
233 loop {
234 match self.inner.flush() {
235 Ok(()) => break,
236 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
237 self.inner
238 .get_mut()
239 .flush_write_buf()
240 .await
241 .map_err(WsError::Io)?;
242 }
243 Err(WsError::ConnectionClosed) => break,
244 Err(e) => return Err(e),
245 }
246 }
247 self.inner
248 .get_mut()
249 .flush_write_buf()
250 .await
251 .map_err(WsError::Io)?;
252 Ok(())
253 }
254
255 pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
257 loop {
258 match self.inner.close(close_frame.clone()) {
259 Ok(()) => break,
260 Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
261 let sync_stream = self.inner.get_mut();
262
263 let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
264
265 if flushed == 0 {
266 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
267 }
268 }
269 Err(WsError::ConnectionClosed) => break,
270 Err(e) => return Err(e),
271 }
272 }
273 self.flush().await
274 }
275
276 pub fn get_ref(&self) -> &S {
278 self.inner.get_ref().get_ref()
279 }
280
281 pub fn get_mut(&mut self) -> &mut S {
283 self.inner.get_mut().get_mut()
284 }
285}
286
287impl<S> IntoInner for WebSocketStream<S> {
288 type Inner = WebSocket<SyncStream<S>>;
289
290 fn into_inner(self) -> Self::Inner {
291 self.inner
292 }
293}
294
295pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
306where
307 S: AsyncRead + AsyncWrite,
308{
309 accept_hdr_async(stream, NoCallback).await
310}
311
312pub async fn accept_async_with_config<S>(
314 stream: S,
315 config: impl Into<Config>,
316) -> Result<WebSocketStream<S>, WsError>
317where
318 S: AsyncRead + AsyncWrite,
319{
320 accept_hdr_with_config_async(stream, NoCallback, config).await
321}
322pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
328where
329 S: AsyncRead + AsyncWrite,
330 C: Callback,
331{
332 accept_hdr_with_config_async(stream, callback, None).await
333}
334
335pub async fn accept_hdr_with_config_async<S, C>(
337 stream: S,
338 callback: C,
339 config: impl Into<Config>,
340) -> Result<WebSocketStream<S>, WsError>
341where
342 S: AsyncRead + AsyncWrite,
343 C: Callback,
344{
345 let config = config.into();
346 let sync_stream =
347 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
348 let mut handshake_result =
349 tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
350
351 loop {
352 match handshake_result {
353 Ok(mut websocket) => {
354 websocket
355 .get_mut()
356 .flush_write_buf()
357 .await
358 .map_err(WsError::Io)?;
359 return Ok(WebSocketStream { inner: websocket });
360 }
361 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
362 let sync_stream = mid_handshake.get_mut().get_mut();
363
364 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
365
366 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
367
368 handshake_result = mid_handshake.handshake();
369 }
370 Err(HandshakeError::Failure(error)) => {
371 return Err(error);
372 }
373 }
374 }
375}
376
377pub async fn client_async<R, S>(
391 request: R,
392 stream: S,
393) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
394where
395 R: IntoClientRequest,
396 S: AsyncRead + AsyncWrite,
397{
398 client_async_with_config(request, stream, None).await
399}
400
401pub async fn client_async_with_config<R, S>(
403 request: R,
404 stream: S,
405 config: impl Into<Config>,
406) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
407where
408 R: IntoClientRequest,
409 S: AsyncRead + AsyncWrite,
410{
411 let config = config.into();
412 let sync_stream =
413 SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
414 let mut handshake_result =
415 tungstenite::client::client_with_config(request, sync_stream, config.websocket);
416
417 loop {
418 match handshake_result {
419 Ok((mut websocket, response)) => {
420 websocket
422 .get_mut()
423 .flush_write_buf()
424 .await
425 .map_err(WsError::Io)?;
426 return Ok((WebSocketStream { inner: websocket }, response));
427 }
428 Err(HandshakeError::Interrupted(mut mid_handshake)) => {
429 let sync_stream = mid_handshake.get_mut().get_mut();
430
431 sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
433
434 sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
435
436 handshake_result = mid_handshake.handshake();
437 }
438 Err(HandshakeError::Failure(error)) => {
439 return Err(error);
440 }
441 }
442 }
443}