1#![deny(
38 missing_docs,
39 unused_must_use,
40 unused_mut,
41 unused_imports,
42 unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51 feature = "async-tls",
52 feature = "async-native-tls",
53 feature = "tokio-native-tls",
54 feature = "tokio-rustls-manual-roots",
55 feature = "tokio-rustls-native-certs",
56 feature = "tokio-rustls-webpki-roots",
57 feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62 io::{Read, Write},
63 pin::Pin,
64 sync::{Arc, Mutex, MutexGuard},
65 task::{ready, Context, Poll},
66};
67
68use compat::{cvt, AllowStd, ContextWaker};
69use futures_core::stream::{FusedStream, Stream};
70use futures_io::{AsyncRead, AsyncWrite};
71use log::*;
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75 client::IntoClientRequest,
76 handshake::{
77 client::{ClientHandshake, Response},
78 server::{Callback, NoCallback},
79 HandshakeError,
80 },
81};
82use tungstenite::{
83 error::Error as WsError,
84 protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96pub mod bytes;
97pub use bytes::ByteReader;
98pub use bytes::ByteWriter;
99
100use tungstenite::protocol::CloseFrame;
101
#[cfg(feature = "handshake")]
/// Creates a WebSocket client on top of an already-connected `stream`.
///
/// This only performs the client side of the WebSocket handshake for `request`
/// over `stream`; it does not open the connection or set up TLS. No explicit
/// [`WebSocketConfig`] is used (`None` is passed) — see
/// [`client_async_with_config`] to supply one.
///
/// On success, returns the WebSocket stream together with the server's handshake
/// [`Response`].
///
/// # Errors
///
/// Returns the error produced by [`client_async_with_config`].
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    client_async_with_config(request, stream, None).await
}
125
#[cfg(feature = "handshake")]
/// Same as [`client_async`], but lets the caller pass an optional [`WebSocketConfig`].
///
/// `request` is converted with [`IntoClientRequest`] and the client handshake is
/// driven over `stream` via tungstenite's `ClientHandshake`.
///
/// # Errors
///
/// * A `HandshakeError::Failure` is returned as its inner [`WsError`]. This includes
///   errors from converting `request` into a client request.
/// * Any other `HandshakeError` is converted into `WsError::Io` with
///   `std::io::ErrorKind::Other`, carrying the error's `Display` text.
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let f = handshake::client_handshake(stream, move |allow_std| {
        // `?` turns a `WsError` into `HandshakeError::Failure`, which is unwrapped below.
        let request = request.into_client_request()?;
        let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
        cli_handshake.handshake()
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        e => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        )),
    })
}
151
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection on an already-accepted `stream`.
///
/// Performs the server side of the WebSocket handshake over `stream`, without a
/// header callback (`NoCallback`) and without an explicit [`WebSocketConfig`].
///
/// # Errors
///
/// Returns the error produced by [`accept_hdr_async_with_config`].
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async(stream, NoCallback).await
}
170
#[cfg(feature = "handshake")]
/// Same as [`accept_async`], but lets the caller pass an optional [`WebSocketConfig`].
///
/// # Errors
///
/// Returns the error produced by [`accept_hdr_async_with_config`].
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, config).await
}
183
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection on `stream`, running `callback` during the
/// handshake.
///
/// `callback` is forwarded to tungstenite's `accept_hdr_with_config` (see tungstenite's
/// `Callback` trait for what it may inspect or change). No explicit
/// [`WebSocketConfig`] is used.
///
/// # Errors
///
/// Returns the error produced by [`accept_hdr_async_with_config`].
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    accept_hdr_async_with_config(stream, callback, None).await
}
197
#[cfg(feature = "handshake")]
/// Same as [`accept_hdr_async`], but lets the caller pass an optional
/// [`WebSocketConfig`].
///
/// The server handshake is driven over `stream` by tungstenite's
/// `accept_hdr_with_config`, with `callback` and `config` passed through unchanged.
///
/// # Errors
///
/// * A `HandshakeError::Failure` is returned as its inner [`WsError`].
/// * Any other `HandshakeError` is converted into `WsError::Io` with
///   `std::io::ErrorKind::Other`, carrying the error's `Display` text.
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let f = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        e => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        )),
    })
}
221
222#[derive(Debug)]
232pub struct WebSocketStream<S> {
233 inner: WebSocket<AllowStd<S>>,
234 #[cfg(feature = "futures-03-sink")]
235 closing: bool,
236 ended: bool,
237 ready: bool,
242}
243
244impl<S> WebSocketStream<S> {
245 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
248 where
249 S: AsyncRead + AsyncWrite + Unpin,
250 {
251 handshake::without_handshake(stream, move |allow_std| {
252 WebSocket::from_raw_socket(allow_std, role, config)
253 })
254 .await
255 }
256
257 pub async fn from_partially_read(
260 stream: S,
261 part: Vec<u8>,
262 role: Role,
263 config: Option<WebSocketConfig>,
264 ) -> Self
265 where
266 S: AsyncRead + AsyncWrite + Unpin,
267 {
268 handshake::without_handshake(stream, move |allow_std| {
269 WebSocket::from_partially_read(allow_std, part, role, config)
270 })
271 .await
272 }
273
274 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
275 Self {
276 inner: ws,
277 #[cfg(feature = "futures-03-sink")]
278 closing: false,
279 ended: false,
280 ready: true,
281 }
282 }
283
284 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
285 where
286 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
287 AllowStd<S>: Read + Write,
288 {
289 #[cfg(feature = "verbose-logging")]
290 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
291 if let Some((kind, ctx)) = ctx {
292 self.inner.get_mut().set_waker(kind, ctx.waker());
293 }
294 f(&mut self.inner)
295 }
296
297 pub fn into_inner(self) -> S {
299 self.inner.into_inner().into_inner()
300 }
301
302 pub fn get_ref(&self) -> &S
304 where
305 S: AsyncRead + AsyncWrite + Unpin,
306 {
307 self.inner.get_ref().get_ref()
308 }
309
310 pub fn get_mut(&mut self) -> &mut S
312 where
313 S: AsyncRead + AsyncWrite + Unpin,
314 {
315 self.inner.get_mut().get_mut()
316 }
317
318 pub fn get_config(&self) -> &WebSocketConfig {
320 self.inner.get_config()
321 }
322
323 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
325 where
326 S: AsyncRead + AsyncWrite + Unpin,
327 {
328 self.send(Message::Close(msg)).await
329 }
330
331 pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
334 let shared = Arc::new(Shared(Mutex::new(self)));
335 let sender = WebSocketSender {
336 shared: shared.clone(),
337 };
338
339 let receiver = WebSocketReceiver { shared };
340 (sender, receiver)
341 }
342
343 pub fn reunite(
348 sender: WebSocketSender<S>,
349 receiver: WebSocketReceiver<S>,
350 ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
351 if sender.is_pair_of(&receiver) {
352 drop(receiver);
353 let stream = Arc::try_unwrap(sender.shared)
354 .ok()
355 .expect("reunite the stream")
356 .into_inner();
357
358 Ok(stream)
359 } else {
360 Err((sender, receiver))
361 }
362 }
363}
364
365impl<S> WebSocketStream<S>
366where
367 S: AsyncRead + AsyncWrite + Unpin,
368{
369 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
370 #[cfg(feature = "verbose-logging")]
371 trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
372
373 if self.ended {
377 return Poll::Ready(None);
378 }
379
380 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
381 #[cfg(feature = "verbose-logging")]
382 trace!(
383 "{}:{} WebSocketStream.with_context poll_next -> read()",
384 file!(),
385 line!()
386 );
387 cvt(s.read())
388 })) {
389 Ok(v) => Poll::Ready(Some(Ok(v))),
390 Err(e) => {
391 self.ended = true;
392 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
393 Poll::Ready(None)
394 } else {
395 Poll::Ready(Some(Err(e)))
396 }
397 }
398 }
399 }
400
401 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
402 if self.ready {
403 return Poll::Ready(Ok(()));
404 }
405
406 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
408 .map(|r| {
409 self.ready = true;
410 r
411 })
412 }
413
414 fn start_send(&mut self, item: Message) -> Result<(), WsError> {
415 match self.with_context(None, |s| s.write(item)) {
416 Ok(()) => {
417 self.ready = true;
418 Ok(())
419 }
420 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
421 self.ready = false;
424 Ok(())
425 }
426 Err(e) => {
427 self.ready = true;
428 debug!("websocket start_send error: {}", e);
429 Err(e)
430 }
431 }
432 }
433
434 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
435 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
436 .map(|r| {
437 self.ready = true;
438 match r {
439 Err(WsError::ConnectionClosed) => Ok(()),
441 other => other,
442 }
443 })
444 }
445
446 #[cfg(feature = "futures-03-sink")]
447 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
448 self.ready = true;
449 let res = if self.closing {
450 self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
452 } else {
453 self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
454 };
455
456 match res {
457 Ok(()) => Poll::Ready(Ok(())),
458 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
459 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
460 trace!("WouldBlock");
461 self.closing = true;
462 Poll::Pending
463 }
464 Err(err) => {
465 debug!("websocket close error: {}", err);
466 Poll::Ready(Err(err))
467 }
468 }
469 }
470}
471
472impl<S> Stream for WebSocketStream<S>
473where
474 S: AsyncRead + AsyncWrite + Unpin,
475{
476 type Item = Result<Message, WsError>;
477
478 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
479 self.get_mut().poll_next(cx)
480 }
481}
482
483impl<S> FusedStream for WebSocketStream<S>
484where
485 S: AsyncRead + AsyncWrite + Unpin,
486{
487 fn is_terminated(&self) -> bool {
488 self.ended
489 }
490}
491
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Every method unpins `self` (the stream is `Unpin`) and forwards to the inherent
    // method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this: &mut Self = Pin::into_inner(self);
        this.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this: &mut Self = Pin::into_inner(self);
        this.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this: &mut Self = Pin::into_inner(self);
        this.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this: &mut Self = Pin::into_inner(self);
        this.poll_close(cx)
    }
}
515
516#[cfg(not(feature = "futures-03-sink"))]
517impl<S> bytes::private::SealedSender for WebSocketStream<S>
518where
519 S: AsyncRead + AsyncWrite + Unpin,
520{
521 fn poll_write(
522 self: Pin<&mut Self>,
523 cx: &mut Context<'_>,
524 buf: &[u8],
525 ) -> Poll<Result<usize, WsError>> {
526 let me = self.get_mut();
527 ready!(me.poll_ready(cx))?;
528 let len = buf.len();
529 me.start_send(Message::binary(buf.to_owned()))?;
530 Poll::Ready(Ok(len))
531 }
532
533 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
534 self.get_mut().poll_flush(cx)
535 }
536
537 fn poll_close(
538 self: Pin<&mut Self>,
539 cx: &mut Context<'_>,
540 msg: &mut Option<Message>,
541 ) -> Poll<Result<(), WsError>> {
542 let me = self.get_mut();
543 send_helper(me, msg, cx)
544 }
545}
546
547impl<S> WebSocketStream<S> {
548 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
550 where
551 S: AsyncRead + AsyncWrite + Unpin,
552 {
553 Send {
554 ws: self,
555 msg: Some(msg),
556 }
557 .await
558 }
559}
560
561struct Send<W> {
562 ws: W,
563 msg: Option<Message>,
564}
565
566fn send_helper<S>(
568 ws: &mut WebSocketStream<S>,
569 msg: &mut Option<Message>,
570 cx: &mut Context<'_>,
571) -> Poll<Result<(), WsError>>
572where
573 S: AsyncRead + AsyncWrite + Unpin,
574{
575 if msg.is_some() {
576 ready!(ws.poll_ready(cx))?;
577 let msg = msg.take().expect("unreachable");
578 ws.start_send(msg)?;
579 }
580
581 ws.poll_flush(cx)
582}
583
584impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
585where
586 S: AsyncRead + AsyncWrite + Unpin,
587{
588 type Output = Result<(), WsError>;
589
590 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
591 let me = self.get_mut();
592 send_helper(me.ws, &mut me.msg, cx)
593 }
594}
595
596impl<S> std::future::Future for Send<&Shared<S>>
597where
598 S: AsyncRead + AsyncWrite + Unpin,
599{
600 type Output = Result<(), WsError>;
601
602 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
603 let me = self.get_mut();
604 let mut ws = me.ws.lock();
605 send_helper(&mut ws, &mut me.msg, cx)
606 }
607}
608
609#[derive(Debug)]
611pub struct WebSocketSender<S> {
612 shared: Arc<Shared<S>>,
613}
614
615impl<S> WebSocketSender<S> {
616 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
618 where
619 S: AsyncRead + AsyncWrite + Unpin,
620 {
621 Send {
622 ws: &*self.shared,
623 msg: Some(msg),
624 }
625 .await
626 }
627
628 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
630 where
631 S: AsyncRead + AsyncWrite + Unpin,
632 {
633 self.send(Message::Close(msg)).await
634 }
635
636 pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
639 Arc::ptr_eq(&self.shared, &other.shared)
640 }
641}
642
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Every method locks the shared stream for this single call and forwards to the
    // stream's inherent method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
666
667#[cfg(not(feature = "futures-03-sink"))]
668impl<S> bytes::private::SealedSender for WebSocketSender<S>
669where
670 S: AsyncRead + AsyncWrite + Unpin,
671{
672 fn poll_write(
673 self: Pin<&mut Self>,
674 cx: &mut Context<'_>,
675 buf: &[u8],
676 ) -> Poll<Result<usize, WsError>> {
677 let me = self.get_mut();
678 let mut ws = me.shared.lock();
679 ready!(ws.poll_ready(cx))?;
680 let len = buf.len();
681 ws.start_send(Message::binary(buf.to_owned()))?;
682 Poll::Ready(Ok(len))
683 }
684
685 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
686 self.shared.lock().poll_flush(cx)
687 }
688
689 fn poll_close(
690 self: Pin<&mut Self>,
691 cx: &mut Context<'_>,
692 msg: &mut Option<Message>,
693 ) -> Poll<Result<(), WsError>> {
694 let me = self.get_mut();
695 let mut ws = me.shared.lock();
696 send_helper(&mut ws, msg, cx)
697 }
698}
699
700#[derive(Debug)]
702pub struct WebSocketReceiver<S> {
703 shared: Arc<Shared<S>>,
704}
705
706impl<S> WebSocketReceiver<S> {
707 pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
710 Arc::ptr_eq(&self.shared, &other.shared)
711 }
712}
713
714impl<S> Stream for WebSocketReceiver<S>
715where
716 S: AsyncRead + AsyncWrite + Unpin,
717{
718 type Item = Result<Message, WsError>;
719
720 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721 self.shared.lock().poll_next(cx)
722 }
723}
724
725impl<S> FusedStream for WebSocketReceiver<S>
726where
727 S: AsyncRead + AsyncWrite + Unpin,
728{
729 fn is_terminated(&self) -> bool {
730 self.shared.lock().ended
731 }
732}
733
734#[derive(Debug)]
735struct Shared<S>(Mutex<WebSocketStream<S>>);
736
737impl<S> Shared<S> {
738 fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
739 self.0.lock().expect("lock shared stream")
740 }
741
742 fn into_inner(self) -> WebSocketStream<S> {
743 self.0.into_inner().expect("get shared stream")
744 }
745}
746
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Returns the host part of `request`'s URI as an owned string.
///
/// Brackets around an IPv6 literal are removed, so `ws://[::1]:80` yields `"::1"`.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` if the URI has no host.
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request
        .uri()
        .host()
        .ok_or(tungstenite::Error::Url(
            tungstenite::error::UrlError::NoHostName,
        ))?;

    // `Uri::host` keeps the brackets of an IPv6 literal ("[::1]"); callers presumably
    // need the bare address. Strip only a *matched* pair: the previous
    // `&host[1..len - 1]` slice panicked on a host of "[" and dropped a real character
    // when the closing bracket was missing.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
779
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Returns the port to connect to for `request`.
///
/// An explicit port in the URI wins. Otherwise the scheme's default is used: 443 for
/// `wss`, 80 for `ws`.
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` if there is no explicit port and the scheme
/// is neither `ws` nor `wss`.
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
802
#[cfg(test)]
mod tests {
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let req = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&req).unwrap();
        assert_eq!(host, "::1");
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        // Each of these has an unterminated IPv6 bracket and must be rejected up front.
        for bad in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(bad.into_client_request().is_err());
        }
    }
}