1#![deny(
38 missing_docs,
39 unused_must_use,
40 unused_mut,
41 unused_imports,
42 unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51 feature = "async-tls",
52 feature = "async-native-tls",
53 feature = "tokio-native-tls",
54 feature = "tokio-rustls-manual-roots",
55 feature = "tokio-rustls-native-certs",
56 feature = "tokio-rustls-webpki-roots",
57 feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62 io::{Read, Write},
63 pin::Pin,
64 sync::{Arc, Mutex, MutexGuard},
65 task::{ready, Context, Poll},
66};
67
68use compat::{cvt, AllowStd, ContextWaker};
69use futures_core::stream::{FusedStream, Stream};
70use futures_io::{AsyncRead, AsyncWrite};
71use log::*;
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75 client::IntoClientRequest,
76 handshake::{
77 client::{ClientHandshake, Response},
78 server::{Callback, NoCallback},
79 HandshakeError,
80 },
81};
82use tungstenite::{
83 error::Error as WsError,
84 protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96pub mod bytes;
97pub use bytes::ByteReader;
98pub use bytes::ByteWriter;
99
100use tungstenite::protocol::CloseFrame;
101
/// Performs the client side of a WebSocket handshake over an existing `stream`.
///
/// `request` may be any value implementing `IntoClientRequest`. This is a
/// convenience wrapper around [`client_async_with_config`] that passes `None`
/// as the configuration.
///
/// On success the established [`WebSocketStream`] is returned together with the
/// handshake `Response`.
///
/// # Errors
///
/// Propagates any error reported by [`client_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, no_config).await
}
125
/// Performs the client side of a WebSocket handshake over an existing `stream`,
/// using the optional `config` for the resulting connection.
///
/// The request is converted with `IntoClientRequest::into_client_request`, a
/// `ClientHandshake` is started on the wrapped stream, and the handshake is driven
/// by `handshake::client_handshake`.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is. Any other handshake
/// error is converted into `WsError::Io` with kind `Other`, carrying the error's
/// display text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        let request = request.into_client_request()?;
        ClientHandshake::start(allow_std, request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(failure)) => Err(failure),
        Err(unfinished) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            unfinished.to_string(),
        ))),
    }
}
151
/// Accepts a new WebSocket connection on the provided `stream`.
///
/// Performs the server side of the handshake without inspecting or altering the
/// handshake headers: it delegates to [`accept_hdr_async`] with `NoCallback`.
///
/// # Errors
///
/// Propagates any error reported by [`accept_hdr_async`].
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async(stream, no_callback).await
}
170
/// Accepts a new WebSocket connection on the provided `stream`, using the optional
/// `config` for the resulting connection.
///
/// Delegates to [`accept_hdr_async_with_config`] with `NoCallback`, so the handshake
/// headers are neither inspected nor altered.
///
/// # Errors
///
/// Propagates any error reported by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async_with_config(stream, no_callback, config).await
}
183
/// Accepts a new WebSocket connection on the provided `stream`, passing `callback`
/// to the server handshake.
///
/// Convenience wrapper around [`accept_hdr_async_with_config`] that passes `None`
/// as the configuration.
///
/// # Errors
///
/// Propagates any error reported by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
197
/// Accepts a new WebSocket connection on the provided `stream`, passing `callback`
/// to the server handshake and using the optional `config` for the connection.
///
/// The handshake itself is performed by `tungstenite::accept_hdr_with_config` on the
/// wrapped stream, driven by `handshake::server_handshake`.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is. Any other handshake
/// error is converted into `WsError::Io` with kind `Other`, carrying the error's
/// display text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    match outcome {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(failure)) => Err(failure),
        Err(unfinished) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            unfinished.to_string(),
        ))),
    }
}
221
222#[derive(Debug)]
232pub struct WebSocketStream<S> {
233 inner: WebSocket<AllowStd<S>>,
234 #[cfg(feature = "futures-03-sink")]
235 closing: bool,
236 ended: bool,
237 ready: bool,
242}
243
244impl<S> WebSocketStream<S> {
245 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
248 where
249 S: AsyncRead + AsyncWrite + Unpin,
250 {
251 handshake::without_handshake(stream, move |allow_std| {
252 WebSocket::from_raw_socket(allow_std, role, config)
253 })
254 .await
255 }
256
257 pub async fn from_partially_read(
260 stream: S,
261 part: Vec<u8>,
262 role: Role,
263 config: Option<WebSocketConfig>,
264 ) -> Self
265 where
266 S: AsyncRead + AsyncWrite + Unpin,
267 {
268 handshake::without_handshake(stream, move |allow_std| {
269 WebSocket::from_partially_read(allow_std, part, role, config)
270 })
271 .await
272 }
273
274 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
275 Self {
276 inner: ws,
277 #[cfg(feature = "futures-03-sink")]
278 closing: false,
279 ended: false,
280 ready: true,
281 }
282 }
283
284 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
285 where
286 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
287 AllowStd<S>: Read + Write,
288 {
289 #[cfg(feature = "verbose-logging")]
290 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
291 if let Some((kind, ctx)) = ctx {
292 self.inner.get_mut().set_waker(kind, ctx.waker());
293 }
294 f(&mut self.inner)
295 }
296
297 pub fn get_ref(&self) -> &S
299 where
300 S: AsyncRead + AsyncWrite + Unpin,
301 {
302 self.inner.get_ref().get_ref()
303 }
304
305 pub fn get_mut(&mut self) -> &mut S
307 where
308 S: AsyncRead + AsyncWrite + Unpin,
309 {
310 self.inner.get_mut().get_mut()
311 }
312
313 pub fn get_config(&self) -> &WebSocketConfig {
315 self.inner.get_config()
316 }
317
318 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
320 where
321 S: AsyncRead + AsyncWrite + Unpin,
322 {
323 self.send(Message::Close(msg)).await
324 }
325
326 pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
329 let shared = Arc::new(Shared(Mutex::new(self)));
330 let sender = WebSocketSender {
331 shared: shared.clone(),
332 };
333
334 let receiver = WebSocketReceiver { shared };
335 (sender, receiver)
336 }
337
338 pub fn reunite(
343 sender: WebSocketSender<S>,
344 receiver: WebSocketReceiver<S>,
345 ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
346 if sender.is_pair_of(&receiver) {
347 drop(receiver);
348 let stream = Arc::try_unwrap(sender.shared)
349 .ok()
350 .expect("reunite the stream")
351 .into_inner();
352
353 Ok(stream)
354 } else {
355 Err((sender, receiver))
356 }
357 }
358}
359
360impl<S> WebSocketStream<S>
361where
362 S: AsyncRead + AsyncWrite + Unpin,
363{
364 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
365 #[cfg(feature = "verbose-logging")]
366 trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
367
368 if self.ended {
372 return Poll::Ready(None);
373 }
374
375 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
376 #[cfg(feature = "verbose-logging")]
377 trace!(
378 "{}:{} WebSocketStream.with_context poll_next -> read()",
379 file!(),
380 line!()
381 );
382 cvt(s.read())
383 })) {
384 Ok(v) => Poll::Ready(Some(Ok(v))),
385 Err(e) => {
386 self.ended = true;
387 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
388 Poll::Ready(None)
389 } else {
390 Poll::Ready(Some(Err(e)))
391 }
392 }
393 }
394 }
395
396 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
397 if self.ready {
398 return Poll::Ready(Ok(()));
399 }
400
401 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
403 .map(|r| {
404 self.ready = true;
405 r
406 })
407 }
408
409 fn start_send(&mut self, item: Message) -> Result<(), WsError> {
410 match self.with_context(None, |s| s.write(item)) {
411 Ok(()) => {
412 self.ready = true;
413 Ok(())
414 }
415 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
416 self.ready = false;
419 Ok(())
420 }
421 Err(e) => {
422 self.ready = true;
423 debug!("websocket start_send error: {}", e);
424 Err(e)
425 }
426 }
427 }
428
429 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
430 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
431 .map(|r| {
432 self.ready = true;
433 match r {
434 Err(WsError::ConnectionClosed) => Ok(()),
436 other => other,
437 }
438 })
439 }
440
441 #[cfg(feature = "futures-03-sink")]
442 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
443 self.ready = true;
444 let res = if self.closing {
445 self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
447 } else {
448 self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
449 };
450
451 match res {
452 Ok(()) => Poll::Ready(Ok(())),
453 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
454 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
455 trace!("WouldBlock");
456 self.closing = true;
457 Poll::Pending
458 }
459 Err(err) => {
460 debug!("websocket close error: {}", err);
461 Poll::Ready(Err(err))
462 }
463 }
464 }
465}
466
467impl<S> Stream for WebSocketStream<S>
468where
469 S: AsyncRead + AsyncWrite + Unpin,
470{
471 type Item = Result<Message, WsError>;
472
473 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
474 self.get_mut().poll_next(cx)
475 }
476}
477
478impl<S> FusedStream for WebSocketStream<S>
479where
480 S: AsyncRead + AsyncWrite + Unpin,
481{
482 fn is_terminated(&self) -> bool {
483 self.ended
484 }
485}
486
/// `Sink` implementation that forwards every operation to the inherent method of
/// the same name.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let stream: &mut Self = Pin::into_inner(self);
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let stream: &mut Self = Pin::into_inner(self);
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let stream: &mut Self = Pin::into_inner(self);
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let stream: &mut Self = Pin::into_inner(self);
        stream.poll_close(cx)
    }
}
510
511#[cfg(not(feature = "futures-03-sink"))]
512impl<S> bytes::private::SealedSender for WebSocketStream<S>
513where
514 S: AsyncRead + AsyncWrite + Unpin,
515{
516 fn poll_write(
517 self: Pin<&mut Self>,
518 cx: &mut Context<'_>,
519 buf: &[u8],
520 ) -> Poll<Result<usize, WsError>> {
521 let me = self.get_mut();
522 ready!(me.poll_ready(cx))?;
523 let len = buf.len();
524 me.start_send(Message::binary(buf.to_owned()))?;
525 Poll::Ready(Ok(len))
526 }
527
528 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
529 self.get_mut().poll_flush(cx)
530 }
531
532 fn poll_close(
533 self: Pin<&mut Self>,
534 cx: &mut Context<'_>,
535 msg: &mut Option<Message>,
536 ) -> Poll<Result<(), WsError>> {
537 let me = self.get_mut();
538 send_helper(me, msg, cx)
539 }
540}
541
542impl<S> WebSocketStream<S> {
543 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
545 where
546 S: AsyncRead + AsyncWrite + Unpin,
547 {
548 Send {
549 ws: self,
550 msg: Some(msg),
551 }
552 .await
553 }
554}
555
556struct Send<W> {
557 ws: W,
558 msg: Option<Message>,
559}
560
561fn send_helper<S>(
563 ws: &mut WebSocketStream<S>,
564 msg: &mut Option<Message>,
565 cx: &mut Context<'_>,
566) -> Poll<Result<(), WsError>>
567where
568 S: AsyncRead + AsyncWrite + Unpin,
569{
570 if msg.is_some() {
571 ready!(ws.poll_ready(cx))?;
572 let msg = msg.take().expect("unreachable");
573 ws.start_send(msg)?;
574 }
575
576 ws.poll_flush(cx)
577}
578
579impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
580where
581 S: AsyncRead + AsyncWrite + Unpin,
582{
583 type Output = Result<(), WsError>;
584
585 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
586 let me = self.get_mut();
587 send_helper(me.ws, &mut me.msg, cx)
588 }
589}
590
591impl<S> std::future::Future for Send<&Shared<S>>
592where
593 S: AsyncRead + AsyncWrite + Unpin,
594{
595 type Output = Result<(), WsError>;
596
597 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
598 let me = self.get_mut();
599 let mut ws = me.ws.lock();
600 send_helper(&mut ws, &mut me.msg, cx)
601 }
602}
603
604#[derive(Debug)]
606pub struct WebSocketSender<S> {
607 shared: Arc<Shared<S>>,
608}
609
610impl<S> WebSocketSender<S> {
611 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
613 where
614 S: AsyncRead + AsyncWrite + Unpin,
615 {
616 Send {
617 ws: &*self.shared,
618 msg: Some(msg),
619 }
620 .await
621 }
622
623 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
625 where
626 S: AsyncRead + AsyncWrite + Unpin,
627 {
628 self.send(Message::Close(msg)).await
629 }
630
631 pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
634 Arc::ptr_eq(&self.shared, &other.shared)
635 }
636}
637
/// `Sink` implementation that locks the shared stream for each operation and
/// forwards to the stream's method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
661
662#[cfg(not(feature = "futures-03-sink"))]
663impl<S> bytes::private::SealedSender for WebSocketSender<S>
664where
665 S: AsyncRead + AsyncWrite + Unpin,
666{
667 fn poll_write(
668 self: Pin<&mut Self>,
669 cx: &mut Context<'_>,
670 buf: &[u8],
671 ) -> Poll<Result<usize, WsError>> {
672 let me = self.get_mut();
673 let mut ws = me.shared.lock();
674 ready!(ws.poll_ready(cx))?;
675 let len = buf.len();
676 ws.start_send(Message::binary(buf.to_owned()))?;
677 Poll::Ready(Ok(len))
678 }
679
680 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
681 self.shared.lock().poll_flush(cx)
682 }
683
684 fn poll_close(
685 self: Pin<&mut Self>,
686 cx: &mut Context<'_>,
687 msg: &mut Option<Message>,
688 ) -> Poll<Result<(), WsError>> {
689 let me = self.get_mut();
690 let mut ws = me.shared.lock();
691 send_helper(&mut ws, msg, cx)
692 }
693}
694
695#[derive(Debug)]
697pub struct WebSocketReceiver<S> {
698 shared: Arc<Shared<S>>,
699}
700
701impl<S> WebSocketReceiver<S> {
702 pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
705 Arc::ptr_eq(&self.shared, &other.shared)
706 }
707}
708
709impl<S> Stream for WebSocketReceiver<S>
710where
711 S: AsyncRead + AsyncWrite + Unpin,
712{
713 type Item = Result<Message, WsError>;
714
715 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
716 self.shared.lock().poll_next(cx)
717 }
718}
719
720impl<S> FusedStream for WebSocketReceiver<S>
721where
722 S: AsyncRead + AsyncWrite + Unpin,
723{
724 fn is_terminated(&self) -> bool {
725 self.shared.lock().ended
726 }
727}
728
729#[derive(Debug)]
730struct Shared<S>(Mutex<WebSocketStream<S>>);
731
732impl<S> Shared<S> {
733 fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
734 self.0.lock().expect("lock shared stream")
735 }
736
737 fn into_inner(self) -> WebSocketStream<S> {
738 self.0.into_inner().expect("get shared stream")
739 }
740}
741
/// Extracts the host name of the request URI, with the brackets of an IPv6
/// literal removed.
///
/// For example a request to `ws://[::1]:80` yields `"::1"`.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` when the URI has no host component.
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Brackets are not part of the address itself, so they are stripped. Both are
    // matched explicitly instead of slicing by index: a host that lacks the closing
    // bracket is passed through untouched rather than panicking on an empty range
    // or losing an arbitrary final character.
    let bare = host
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(host);

    Ok(bare.to_owned())
}
774
/// Determines the port to connect to for the request URI.
///
/// An explicit port in the URI wins; otherwise the scheme decides: `wss` maps to
/// 443 and `ws` maps to 80.
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` when the URI has no explicit port and
/// its scheme is neither `ws` nor `wss`.
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
797
#[cfg(test)]
mod tests {
    // `domain` must return an IPv6 literal without its surrounding brackets.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let parsed = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&parsed).unwrap();
        assert_eq!(host, "::1");
    }

    // URIs with an unterminated bracket must already be rejected when the request
    // is built.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for malformed in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(malformed.into_client_request().is_err());
        }
    }
}