1#![deny(
38 missing_docs,
39 unused_must_use,
40 unused_mut,
41 unused_imports,
42 unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51 feature = "async-tls",
52 feature = "async-native-tls",
53 feature = "tokio-native-tls",
54 feature = "tokio-rustls-manual-roots",
55 feature = "tokio-rustls-native-certs",
56 feature = "tokio-rustls-webpki-roots",
57 feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62 io::{Read, Write},
63 pin::Pin,
64 sync::{Arc, Mutex, MutexGuard},
65 task::{ready, Context, Poll},
66};
67
68use compat::{cvt, AllowStd, ContextWaker};
69use futures_core::stream::{FusedStream, Stream};
70use futures_io::{AsyncRead, AsyncWrite};
71use log::*;
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75 client::IntoClientRequest,
76 handshake::{
77 client::{ClientHandshake, Response},
78 server::{Callback, NoCallback},
79 HandshakeError,
80 },
81};
82use tungstenite::{
83 error::Error as WsError,
84 protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96pub mod bytes;
97pub use bytes::ByteReader;
98#[cfg(feature = "futures-03-sink")]
99pub use bytes::ByteWriter;
100
101use tungstenite::protocol::CloseFrame;
102
/// Performs the client side of the WebSocket handshake over an existing `stream`,
/// using the default tungstenite configuration.
///
/// `request` may be anything implementing [`IntoClientRequest`] (for example a URL
/// string or a pre-built request with custom headers). This function does not open
/// a connection itself: the handshake is carried out over the given `stream`.
///
/// On success, returns the ready-to-use [`WebSocketStream`] together with the
/// server's handshake [`Response`].
///
/// This is equivalent to [`client_async_with_config`] with `config = None`.
///
/// # Errors
///
/// Returns the same errors as [`client_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    client_async_with_config(request, stream, None).await
}
126
/// Performs the client side of the WebSocket handshake over an existing `stream`,
/// like [`client_async`], but with an optional tungstenite [`WebSocketConfig`].
///
/// `config` is forwarded to tungstenite's `ClientHandshake::start`.
///
/// # Errors
///
/// - Returns an error if `request` cannot be converted into a client request.
/// - Returns an error if the handshake fails. A `HandshakeError::Failure` is returned
///   as its inner [`WsError`].
/// - Any other handshake outcome is reported as `WsError::Io` with
///   `std::io::ErrorKind::Other`, carrying the original error's message.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let f = handshake::client_handshake(stream, move |allow_std| {
        let request = request.into_client_request()?;
        let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
        cli_handshake.handshake()
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        // Non-failure handshake errors have no direct `WsError` counterpart, so they
        // are surfaced as a generic I/O error that keeps the original message.
        e => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        )),
    })
}
152
/// Accepts a new WebSocket connection by performing the server side of the
/// handshake over `stream`.
///
/// Uses the default tungstenite configuration and no header callback. It is
/// equivalent to [`accept_hdr_async`] with tungstenite's `NoCallback`.
///
/// # Errors
///
/// Returns the same errors as [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async(stream, NoCallback).await
}
171
/// Accepts a new WebSocket connection, like [`accept_async`], but with an optional
/// tungstenite [`WebSocketConfig`].
///
/// No header callback is installed (tungstenite's `NoCallback` is used).
///
/// # Errors
///
/// Returns the same errors as [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, config).await
}
184
/// Accepts a new WebSocket connection over `stream`, passing `callback` to the
/// server handshake.
///
/// `callback` implements tungstenite's [`Callback`] trait. It is handed to
/// `tungstenite::accept_hdr_with_config`, so it can inspect the client's request
/// and adjust (or reject) the handshake response. The default configuration is
/// used.
///
/// # Errors
///
/// Returns the same errors as [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    accept_hdr_async_with_config(stream, callback, None).await
}
198
/// Accepts a new WebSocket connection over `stream` with a header `callback` and an
/// optional tungstenite [`WebSocketConfig`].
///
/// `callback` and `config` are forwarded to `tungstenite::accept_hdr_with_config`.
///
/// # Errors
///
/// - Returns an error if the handshake fails. A `HandshakeError::Failure` is returned
///   as its inner [`WsError`].
/// - Any other handshake outcome is reported as `WsError::Io` with
///   `std::io::ErrorKind::Other`, carrying the original error's message.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let f = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        // Non-failure handshake errors have no direct `WsError` counterpart, so they
        // are surfaced as a generic I/O error that keeps the original message.
        e => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        )),
    })
}
222
223#[derive(Debug)]
233pub struct WebSocketStream<S> {
234 inner: WebSocket<AllowStd<S>>,
235 #[cfg(feature = "futures-03-sink")]
236 closing: bool,
237 ended: bool,
238 ready: bool,
243}
244
245impl<S> WebSocketStream<S> {
246 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
249 where
250 S: AsyncRead + AsyncWrite + Unpin,
251 {
252 handshake::without_handshake(stream, move |allow_std| {
253 WebSocket::from_raw_socket(allow_std, role, config)
254 })
255 .await
256 }
257
258 pub async fn from_partially_read(
261 stream: S,
262 part: Vec<u8>,
263 role: Role,
264 config: Option<WebSocketConfig>,
265 ) -> Self
266 where
267 S: AsyncRead + AsyncWrite + Unpin,
268 {
269 handshake::without_handshake(stream, move |allow_std| {
270 WebSocket::from_partially_read(allow_std, part, role, config)
271 })
272 .await
273 }
274
275 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
276 Self {
277 inner: ws,
278 #[cfg(feature = "futures-03-sink")]
279 closing: false,
280 ended: false,
281 ready: true,
282 }
283 }
284
285 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
286 where
287 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
288 AllowStd<S>: Read + Write,
289 {
290 #[cfg(feature = "verbose-logging")]
291 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
292 if let Some((kind, ctx)) = ctx {
293 self.inner.get_mut().set_waker(kind, ctx.waker());
294 }
295 f(&mut self.inner)
296 }
297
298 pub fn get_ref(&self) -> &S
300 where
301 S: AsyncRead + AsyncWrite + Unpin,
302 {
303 self.inner.get_ref().get_ref()
304 }
305
306 pub fn get_mut(&mut self) -> &mut S
308 where
309 S: AsyncRead + AsyncWrite + Unpin,
310 {
311 self.inner.get_mut().get_mut()
312 }
313
314 pub fn get_config(&self) -> &WebSocketConfig {
316 self.inner.get_config()
317 }
318
319 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
321 where
322 S: AsyncRead + AsyncWrite + Unpin,
323 {
324 self.send(Message::Close(msg)).await
325 }
326
327 pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
330 let shared = Arc::new(Shared(Mutex::new(self)));
331 let sender = WebSocketSender {
332 shared: shared.clone(),
333 };
334
335 let receiver = WebSocketReceiver { shared };
336 (sender, receiver)
337 }
338
339 pub fn reunite(
344 sender: WebSocketSender<S>,
345 receiver: WebSocketReceiver<S>,
346 ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
347 if sender.is_pair_of(&receiver) {
348 drop(receiver);
349 let stream = Arc::try_unwrap(sender.shared)
350 .ok()
351 .expect("reunite the stream")
352 .into_inner();
353
354 Ok(stream)
355 } else {
356 Err((sender, receiver))
357 }
358 }
359}
360
361impl<T> WebSocketStream<T>
362where
363 T: AsyncRead + AsyncWrite + Unpin,
364{
365 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
366 #[cfg(feature = "verbose-logging")]
367 trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
368
369 if self.ended {
373 return Poll::Ready(None);
374 }
375
376 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
377 #[cfg(feature = "verbose-logging")]
378 trace!(
379 "{}:{} WebSocketStream.with_context poll_next -> read()",
380 file!(),
381 line!()
382 );
383 cvt(s.read())
384 })) {
385 Ok(v) => Poll::Ready(Some(Ok(v))),
386 Err(e) => {
387 self.ended = true;
388 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
389 Poll::Ready(None)
390 } else {
391 Poll::Ready(Some(Err(e)))
392 }
393 }
394 }
395 }
396
397 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
398 if self.ready {
399 return Poll::Ready(Ok(()));
400 }
401
402 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
404 .map(|r| {
405 self.ready = true;
406 r
407 })
408 }
409
410 fn start_send(&mut self, item: Message) -> Result<(), WsError> {
411 match self.with_context(None, |s| s.write(item)) {
412 Ok(()) => {
413 self.ready = true;
414 Ok(())
415 }
416 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
417 self.ready = false;
420 Ok(())
421 }
422 Err(e) => {
423 self.ready = true;
424 debug!("websocket start_send error: {}", e);
425 Err(e)
426 }
427 }
428 }
429
430 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
431 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
432 .map(|r| {
433 self.ready = true;
434 match r {
435 Err(WsError::ConnectionClosed) => Ok(()),
437 other => other,
438 }
439 })
440 }
441
442 #[cfg(feature = "futures-03-sink")]
443 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
444 self.ready = true;
445 let res = if self.closing {
446 self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
448 } else {
449 self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
450 };
451
452 match res {
453 Ok(()) => Poll::Ready(Ok(())),
454 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
455 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
456 trace!("WouldBlock");
457 self.closing = true;
458 Poll::Pending
459 }
460 Err(err) => {
461 debug!("websocket close error: {}", err);
462 Poll::Ready(Err(err))
463 }
464 }
465 }
466}
467
468impl<T> Stream for WebSocketStream<T>
469where
470 T: AsyncRead + AsyncWrite + Unpin,
471{
472 type Item = Result<Message, WsError>;
473
474 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
475 self.get_mut().poll_next(cx)
476 }
477}
478
479impl<T> FusedStream for WebSocketStream<T>
480where
481 T: AsyncRead + AsyncWrite + Unpin,
482{
483 fn is_terminated(&self) -> bool {
484 self.ended
485 }
486}
487
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Each method unpins (the stream is `Unpin` under these bounds) and delegates
    // to the inherent method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this = Pin::get_mut(self);
        this.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_close(cx)
    }
}
511
512impl<S> WebSocketStream<S> {
513 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
515 where
516 S: AsyncRead + AsyncWrite + Unpin,
517 {
518 Send {
519 ws: self,
520 msg: Some(msg),
521 }
522 .await
523 }
524}
525
526struct Send<W> {
527 ws: W,
528 msg: Option<Message>,
529}
530
531fn send_helper<S>(
533 ws: &mut WebSocketStream<S>,
534 msg: &mut Option<Message>,
535 cx: &mut Context<'_>,
536) -> Poll<Result<(), WsError>>
537where
538 S: AsyncRead + AsyncWrite + Unpin,
539{
540 if msg.is_some() {
541 ready!(ws.poll_ready(cx))?;
542 let msg = msg.take().expect("unreachable");
543 ws.start_send(msg)?;
544 }
545
546 ws.poll_flush(cx)
547}
548
549impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
550where
551 S: AsyncRead + AsyncWrite + Unpin,
552{
553 type Output = Result<(), WsError>;
554
555 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
556 let me = self.get_mut();
557 send_helper(me.ws, &mut me.msg, cx)
558 }
559}
560
561impl<S> std::future::Future for Send<&Shared<S>>
562where
563 S: AsyncRead + AsyncWrite + Unpin,
564{
565 type Output = Result<(), WsError>;
566
567 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
568 let me = self.get_mut();
569 let mut ws = me.ws.lock();
570 send_helper(&mut ws, &mut me.msg, cx)
571 }
572}
573
574#[derive(Debug)]
576pub struct WebSocketSender<S> {
577 shared: Arc<Shared<S>>,
578}
579
580impl<S> WebSocketSender<S> {
581 pub async fn send(&self, msg: Message) -> Result<(), WsError>
583 where
584 S: AsyncRead + AsyncWrite + Unpin,
585 {
586 Send {
587 ws: &*self.shared,
588 msg: Some(msg),
589 }
590 .await
591 }
592
593 pub async fn close(&self, msg: Option<CloseFrame>) -> Result<(), WsError>
595 where
596 S: AsyncRead + AsyncWrite + Unpin,
597 {
598 self.send(Message::Close(msg)).await
599 }
600
601 pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
604 Arc::ptr_eq(&self.shared, &other.shared)
605 }
606}
607
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Every method locks the shared stream for the duration of the call only, then
    // delegates to the `WebSocketStream` method of the same name. `poll_close`
    // therefore closes the connection shared with the receiver half.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
631
632#[derive(Debug)]
634pub struct WebSocketReceiver<S> {
635 shared: Arc<Shared<S>>,
636}
637
638impl<S> WebSocketReceiver<S> {
639 pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
642 Arc::ptr_eq(&self.shared, &other.shared)
643 }
644}
645
646impl<S> Stream for WebSocketReceiver<S>
647where
648 S: AsyncRead + AsyncWrite + Unpin,
649{
650 type Item = Result<Message, WsError>;
651
652 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
653 self.shared.lock().poll_next(cx)
654 }
655}
656
657impl<S> FusedStream for WebSocketReceiver<S>
658where
659 S: AsyncRead + AsyncWrite + Unpin,
660{
661 fn is_terminated(&self) -> bool {
662 self.shared.lock().ended
663 }
664}
665
666#[derive(Debug)]
667struct Shared<S>(Mutex<WebSocketStream<S>>);
668
669impl<S> Shared<S> {
670 fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
671 self.0.lock().expect("lock shared stream")
672 }
673
674 fn into_inner(self) -> WebSocketStream<S> {
675 self.0.into_inner().expect("get shared stream")
676 }
677}
678
/// Returns the host of `request`'s URI as an owned `String`.
///
/// For an IPv6 literal, the surrounding brackets are removed (so `[::1]` becomes
/// `::1`). The brackets are URI syntax, not part of the address.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` if the URI has no host.
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip the brackets only as a matched pair. Unlike manual index slicing, this
    // has no panic path and never drops a character from a host that merely
    // starts with '['.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
711
/// Returns the port to connect to for `request`.
///
/// - If the URI has an explicit port, that port is used.
/// - Otherwise the scheme's default is used: `ws` gives 80, `wss` gives 443.
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` when no port is given and the scheme is
/// neither `ws` nor `wss`.
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    let default_port = match uri.scheme_str() {
        Some("wss") => 443,
        Some("ws") => 80,
        _ => {
            return Err(tungstenite::Error::Url(
                tungstenite::error::UrlError::UnsupportedUrlScheme,
            ))
        }
    };
    Ok(default_port)
}
734
#[cfg(test)]
mod tests {
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_keeps_plain_hostnames() {
        use tungstenite::client::IntoClientRequest;

        let request = "wss://example.com/chat".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "example.com");
    }

    #[cfg(any(
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn port_uses_explicit_port_or_scheme_default() {
        use tungstenite::client::IntoClientRequest;

        let ws = "ws://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&ws).unwrap(), 80);

        let wss = "wss://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&wss).unwrap(), 443);

        let explicit = "ws://example.com:9001".into_client_request().unwrap();
        assert_eq!(crate::port(&explicit).unwrap(), 9001);
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}
760}