1#![deny(missing_docs)]
2use holochain_serialized_bytes::prelude::*;
7use holochain_types::websocket::AllowedOrigins;
8use std::io::ErrorKind;
9pub use std::io::{Error, Result};
10use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
11use std::sync::Arc;
12use tokio::net::ToSocketAddrs;
13use tokio::select;
14use tokio_tungstenite::tungstenite::handshake::client::Request;
15use tokio_tungstenite::tungstenite::handshake::server::{Callback, ErrorResponse, Response};
16use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue, StatusCode};
17use tokio_tungstenite::tungstenite::protocol::Message;
18
/// Envelope for every frame exchanged over the websocket.
///
/// Encoded with serde as an internally tagged enum: the variant name, in
/// snake_case, is stored under a `"type"` key next to the variant's fields.
/// Every `data` field holds an already-serialized caller payload and is
/// encoded through `serde_bytes`.
#[derive(Debug, serde::Serialize, serde::Deserialize, SerializedBytes)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum WireMessage {
    /// A message that carries no id, so it cannot be answered.
    /// Surfaced to the receiver as `ReceiveMessage::Signal`.
    Signal {
        /// Serialized signal payload.
        #[serde(with = "serde_bytes")]
        data: Vec<u8>,
    },

    /// An authentication payload, surfaced to the receiver as
    /// `ReceiveMessage::Authenticate`.
    Authenticate {
        /// Serialized authentication payload.
        #[serde(with = "serde_bytes")]
        data: Vec<u8>,
    },

    /// A request; the peer answers with a `Response` carrying the same `id`.
    Request {
        /// Identifier used to correlate the response with this request.
        id: u64,
        /// Serialized request payload.
        #[serde(with = "serde_bytes")]
        data: Vec<u8>,
    },

    /// The reply to a `Request` with the same `id`.
    Response {
        /// The `id` of the request being answered.
        id: u64,
        /// Serialized response payload. When `None`, the receiving side
        /// delivers nothing to the waiting requester.
        #[serde(with = "serde_bytes")]
        data: Option<Vec<u8>>,
    },
}
57
58impl WireMessage {
59 fn try_from_bytes(b: Vec<u8>) -> WebsocketResult<Self> {
61 let b = UnsafeBytes::from(b);
62 let b = SerializedBytes::from(b);
63 let b: WireMessage = b.try_into()?;
64 Ok(b)
65 }
66
67 fn authenticate<S>(s: S) -> WebsocketResult<Message>
69 where
70 S: std::fmt::Debug,
71 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
72 {
73 let s1 = SerializedBytes::try_from(s)?;
74 let s2 = Self::Authenticate {
75 data: UnsafeBytes::from(s1).into(),
76 };
77 let s3: SerializedBytes = s2.try_into()?;
78 Ok(Message::Binary(UnsafeBytes::from(s3).into()))
79 }
80
81 fn request<S>(s: S) -> WebsocketResult<(Message, u64)>
83 where
84 S: std::fmt::Debug,
85 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
86 {
87 static ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
88 let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
89 tracing::trace!(?s, %id, "OutRequest");
90 let s1 = SerializedBytes::try_from(s)?;
91 let s2 = Self::Request {
92 id,
93 data: UnsafeBytes::from(s1).into(),
94 };
95 let s3: SerializedBytes = s2.try_into()?;
96 Ok((Message::Binary(UnsafeBytes::from(s3).into()), id))
97 }
98
99 fn response<S>(id: u64, s: S) -> WebsocketResult<Message>
101 where
102 S: std::fmt::Debug,
103 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
104 {
105 let s1 = SerializedBytes::try_from(s)?;
106 let s2 = Self::Response {
107 id,
108 data: Some(UnsafeBytes::from(s1).into()),
109 };
110 let s3: SerializedBytes = s2.try_into()?;
111 Ok(Message::Binary(UnsafeBytes::from(s3).into()))
112 }
113
114 fn signal<S>(s: S) -> WebsocketResult<Message>
116 where
117 S: std::fmt::Debug,
118 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
119 {
120 tracing::trace!(?s, "SendSignal");
121 let s1 = SerializedBytes::try_from(s)?;
122 let s2 = Self::Signal {
123 data: UnsafeBytes::from(s1).into(),
124 };
125 let s3: SerializedBytes = s2.try_into()?;
126 Ok(Message::Binary(UnsafeBytes::from(s3).into()))
127 }
128}
129
/// Tunables shared by websocket clients ([`connect`]) and listeners
/// ([`WebsocketListener`]).
#[derive(Clone, Debug)]
pub struct WebsocketConfig {
    /// Timeout used by the non-`_timeout` methods of [`WebsocketSender`].
    /// It is also the timeout for sending a reply via
    /// [`WebsocketRespond::respond`].
    pub default_request_timeout: std::time::Duration,

    /// Passed to tungstenite as its `max_message_size` limit.
    pub max_message_size: usize,

    /// Passed to tungstenite as its `max_frame_size` limit.
    pub max_frame_size: usize,

    /// Origins a listener accepts, checked against the `Origin` header of each
    /// incoming handshake. Must be `Some` to bind a [`WebsocketListener`].
    /// Not read by [`connect`].
    pub allowed_origins: Option<AllowedOrigins>,
}
147
148impl WebsocketConfig {
149 pub const CLIENT_DEFAULT: WebsocketConfig = WebsocketConfig {
151 default_request_timeout: std::time::Duration::from_secs(60),
152 max_message_size: 64 << 20,
153 max_frame_size: 16 << 20,
154 allowed_origins: None,
155 };
156
157 pub const LISTENER_DEFAULT: WebsocketConfig = WebsocketConfig {
159 default_request_timeout: std::time::Duration::from_secs(60),
160 max_message_size: 64 << 20,
161 max_frame_size: 16 << 20,
162 allowed_origins: Some(AllowedOrigins::Any),
163 };
164
165 pub(crate) fn as_tungstenite(
167 &self,
168 ) -> tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
169 tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
170 max_message_size: Some(self.max_message_size),
171 max_frame_size: Some(self.max_frame_size),
172 ..Default::default()
173 }
174 }
175}
176
177struct RMapInner(
178 pub std::collections::HashMap<
179 u64,
180 tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
181 >,
182);
183
184impl Drop for RMapInner {
185 fn drop(&mut self) {
186 self.close();
187 }
188}
189
190impl RMapInner {
191 fn close(&mut self) {
192 for (_, s) in self.0.drain() {
193 let _ = s.send(Err(WebsocketError::Close("ConnectionClosed".to_string())));
194 }
195 }
196}
197
198#[derive(Clone)]
199struct RMap(Arc<std::sync::Mutex<RMapInner>>);
200
201impl Default for RMap {
202 fn default() -> Self {
203 Self(Arc::new(std::sync::Mutex::new(RMapInner(
204 std::collections::HashMap::default(),
205 ))))
206 }
207}
208
209impl RMap {
210 pub fn close(&self) {
211 if let Ok(mut lock) = self.0.lock() {
212 lock.close();
213 }
214 }
215
216 pub fn insert(
217 &self,
218 id: u64,
219 sender: tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
220 ) {
221 self.0.lock().unwrap().0.insert(id, sender);
222 }
223
224 pub fn remove(
225 &self,
226 id: u64,
227 ) -> Option<tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>> {
228 self.0.lock().unwrap().0.remove(&id)
229 }
230}
231
/// Errors produced by the websocket API in this module.
#[derive(thiserror::Error, Debug)]
pub enum WebsocketError {
    /// The connection is closed. The message says why, e.g. `"No connection"`,
    /// `"ConnectionClosed"`, or the debug form of a received close frame.
    #[error("Websocket closed: {0}")]
    Close(String),
    /// A payload failed to serialize or deserialize.
    ///
    /// Unlike the other variants, returning this from a connection operation
    /// does not close the connection.
    #[error("Received a message that did not deserialize: {0}")]
    Deserialize(#[from] SerializedBytesError),
    /// An error reported by the underlying tungstenite websocket.
    #[error("Websocket error: {0}")]
    Websocket(#[from] tokio_tungstenite::tungstenite::Error),
    /// An operation did not complete before its deadline.
    #[error("Timeout")]
    Timeout(#[from] tokio::time::error::Elapsed),
    /// An I/O error, e.g. while connecting, binding or accepting.
    #[error("IO error: {0}")]
    Io(#[from] Error),
    /// Any other failure, described by the message, e.g. `"ReceiverClosed"`,
    /// `"ResponderDropped"` or `"UnexpectedRawFrame"`.
    #[error("Other error: {0}")]
    Other(String),
}

/// Result alias used by the websocket API in this module.
pub type WebsocketResult<T> = std::result::Result<T, WebsocketError>;
260
// The raw websocket over TCP, and its two halves after `split`. Each half is
// wrapped in an async mutex so clones of the connection state can share it.
type WsStream = tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>;
type WsSend =
    futures::stream::SplitSink<WsStream, tokio_tungstenite::tungstenite::protocol::Message>;
type WsSendSync = Arc<tokio::sync::Mutex<WsSend>>;
type WsRecv = futures::stream::SplitStream<WsStream>;
type WsRecvSync = Arc<tokio::sync::Mutex<WsRecv>>;

/// State of one live connection. Clones share the same underlying halves and
/// pending-request map.
#[derive(Clone)]
struct WsCore {
    /// Write half; locked for each outgoing frame.
    pub send: WsSendSync,
    /// Read half; locked while waiting for the next incoming frame.
    pub recv: WsRecvSync,
    /// Response channels for outgoing requests, keyed by request id.
    pub rmap: RMap,
    /// Timeout applied when sending a reply through `WebsocketRespond`.
    pub timeout: std::time::Duration,
}
275
276#[derive(Clone)]
277struct WsCoreSync(Arc<std::sync::Mutex<Option<WsCore>>>);
278
279impl PartialEq for WsCoreSync {
280 fn eq(&self, other: &Self) -> bool {
281 Arc::ptr_eq(&self.0, &other.0)
282 }
283}
284
285impl WsCoreSync {
286 fn close(&self) {
287 if let Some(core) = self.0.lock().unwrap().take() {
288 core.rmap.close();
289 tokio::task::spawn(async move {
290 use futures::sink::SinkExt;
291 let _ = core.send.lock().await.close().await;
292 });
293 }
294 }
295
296 fn close_if_err<R>(&self, r: WebsocketResult<R>) -> WebsocketResult<R> {
297 match r {
298 Err(e @ WebsocketError::Deserialize { .. }) => {
299 Err(e)
302 }
303 Err(err) => {
304 self.close();
305 Err(err)
306 }
307 Ok(res) => Ok(res),
308 }
309 }
310
311 pub async fn exec<F, C, R>(&self, c: C) -> WebsocketResult<R>
312 where
313 F: std::future::Future<Output = WebsocketResult<R>>,
314 C: FnOnce(WsCoreSync, WsCore) -> F,
315 {
316 let core = match self.0.lock().unwrap().as_ref() {
317 Some(core) => core.clone(),
318 None => return Err(WebsocketError::Close("No connection".to_string())),
319 };
320 self.close_if_err(c(self.clone(), core).await)
321 }
322}
323
324#[derive(PartialEq)]
326pub struct WebsocketRespond {
327 id: u64,
328 core: WsCoreSync,
329}
330
331impl std::fmt::Debug for WebsocketRespond {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 f.debug_struct("WebsocketRespond")
334 .field("id", &self.id)
335 .finish()
336 }
337}
338
339impl WebsocketRespond {
340 pub async fn respond<S>(self, s: S) -> WebsocketResult<()>
342 where
343 S: std::fmt::Debug,
344 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
345 {
346 tracing::trace!(?s, %self.id, "OutResponse");
347 use futures::sink::SinkExt;
348 self.core
349 .exec(move |_, core| async move {
350 tokio::time::timeout(core.timeout, async {
351 let s = WireMessage::response(self.id, s)?;
352 core.send.lock().await.send(s).await?;
353 Ok(())
354 })
355 .await?
356 })
357 .await
358 }
359}
360
/// A message delivered to the application by [`WebsocketReceiver::recv`].
///
/// `D` is the type that request payloads are decoded into.
#[derive(Debug, PartialEq)]
pub enum ReceiveMessage<D>
where
    D: std::fmt::Debug,
    SerializedBytes: TryInto<D, Error = SerializedBytesError>,
{
    /// The still-serialized bytes of an authentication payload.
    Authenticate(Vec<u8>),

    /// The still-serialized bytes of a signal payload.
    Signal(Vec<u8>),

    /// A decoded request together with the handle used to answer it.
    Request(D, WebsocketRespond),
}
377
/// Receiving half of a websocket connection.
///
/// Fields, in order:
/// - the shared connection state;
/// - the peer's address;
/// - the handle of the background ping task started in `WebsocketReceiver::new`.
pub struct WebsocketReceiver(
    WsCoreSync,
    std::net::SocketAddr,
    tokio::task::JoinHandle<()>,
);

impl Drop for WebsocketReceiver {
    /// Close the shared connection and stop the ping task.
    ///
    /// Because the state is shared, senders of the same connection observe the
    /// close too.
    fn drop(&mut self) {
        self.0.close();
        self.2.abort();
    }
}
394
impl WebsocketReceiver {
    /// Wrap the shared connection state and start a background ping task.
    ///
    /// Every 5 seconds the task sends an empty `Ping` frame. If that send fails
    /// it closes the connection. Once the connection is closed (the shared core
    /// is `None`) the task exits.
    fn new(core: WsCoreSync, addr: std::net::SocketAddr) -> Self {
        let core2 = core.clone();
        let ping_task = tokio::task::spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                // Clone the core out so the std mutex guard is not held across
                // the awaits below.
                let core = core2.0.lock().unwrap().as_ref().cloned();
                if let Some(core) = core {
                    use futures::sink::SinkExt;
                    if core
                        .send
                        .lock()
                        .await
                        .send(Message::Ping(Vec::new()))
                        .await
                        .is_err()
                    {
                        core2.close();
                    }
                } else {
                    break;
                }
            }
        });
        Self(core, addr, ping_task)
    }

    /// The address of the remote peer of this connection.
    pub fn peer_addr(&self) -> std::net::SocketAddr {
        self.1
    }

    /// Wait for the next message intended for the application.
    ///
    /// Errors are logged at `warn` level before being returned.
    pub async fn recv<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
    where
        D: std::fmt::Debug,
        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
    {
        match self.recv_inner().await {
            Err(err) => {
                tracing::warn!(?err, "WebsocketReceiver Error");
                Err(err)
            }
            Ok(msg) => Ok(msg),
        }
    }

    /// Read frames until one produces a message for the application.
    ///
    /// - `Ping` frames are answered with `Pong`; `Pong` frames are ignored.
    /// - `Response` messages are routed to the waiting request through `rmap`
    ///   instead of being returned.
    /// - The following come back as errors: end of stream, transport errors,
    ///   `Close` frames and raw frames. Through `exec` these also close the
    ///   connection; deserialization errors do not.
    async fn recv_inner<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
    where
        D: std::fmt::Debug,
        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
    {
        use futures::sink::SinkExt;
        use futures::stream::StreamExt;
        loop {
            // `Ok(None)` from the closure means "handled internally, keep reading".
            if let Some(result) = self
                .0
                .exec(move |core_sync, core| async move {
                    // Outer `?`: the stream ended (`ReceiverClosed`).
                    // Inner `?`: tungstenite reported an error.
                    let msg = core
                        .recv
                        .lock()
                        .await
                        .next()
                        .await
                        .ok_or::<WebsocketError>(WebsocketError::Other(
                            "ReceiverClosed".to_string(),
                        ))??;
                    let msg = match msg {
                        // Text frames are decoded exactly like binary frames.
                        Message::Text(s) => s.into_bytes(),
                        Message::Binary(b) => b,
                        Message::Ping(b) => {
                            core.send.lock().await.send(Message::Pong(b)).await?;
                            return Ok(None);
                        }
                        Message::Pong(_) => return Ok(None),
                        Message::Close(frame) => {
                            return Err(WebsocketError::Close(format!("{frame:?}")));
                        }
                        Message::Frame(_) => {
                            return Err(WebsocketError::Other("UnexpectedRawFrame".to_string()))
                        }
                    };
                    match WireMessage::try_from_bytes(msg)? {
                        WireMessage::Authenticate { data } => {
                            Ok(Some(ReceiveMessage::Authenticate(data)))
                        }
                        WireMessage::Request { id, data } => {
                            let resp = WebsocketRespond {
                                id,
                                core: core_sync,
                            };
                            let data: D =
                                SerializedBytes::from(UnsafeBytes::from(data)).try_into()?;
                            tracing::trace!(?data, %id, "InRequest");
                            Ok(Some(ReceiveMessage::Request(data, resp)))
                        }
                        WireMessage::Response { id, data } => {
                            // Responses for unknown ids are discarded. That
                            // includes requests whose guard already removed them.
                            if let Some(sender) = core.rmap.remove(id) {
                                // With `data == None` the sender is dropped
                                // unused, so the requester sees `ResponderDropped`.
                                if let Some(data) = data {
                                    let data = SerializedBytes::from(UnsafeBytes::from(data));
                                    tracing::trace!(%id, ?data, "InResponse");
                                    let _ = sender.send(Ok(data));
                                }
                            }
                            Ok(None)
                        }
                        WireMessage::Signal { data } => Ok(Some(ReceiveMessage::Signal(data))),
                    }
                })
                .await?
            {
                return Ok(result);
            }
        }
    }
}
511
512#[derive(Clone)]
516pub struct WebsocketSender(WsCoreSync, std::time::Duration);
517
518impl WebsocketSender {
519 pub async fn authenticate<S>(&self, s: S) -> WebsocketResult<()>
521 where
522 S: std::fmt::Debug,
523 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
524 {
525 self.authenticate_timeout(s, self.1).await
526 }
527
528 pub async fn authenticate_timeout<S>(
530 &self,
531 s: S,
532 timeout: std::time::Duration,
533 ) -> WebsocketResult<()>
534 where
535 S: std::fmt::Debug,
536 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
537 {
538 use futures::sink::SinkExt;
539 self.0
540 .exec(move |_, core| async move {
541 tokio::time::timeout(timeout, async {
542 let s = WireMessage::authenticate(s)?;
543 core.send.lock().await.send(s).await?;
544 Ok(())
545 })
546 .await?
547 })
548 .await
549 }
550
551 pub async fn request<S, R>(&self, s: S) -> WebsocketResult<R>
555 where
556 S: std::fmt::Debug,
557 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
558 R: serde::de::DeserializeOwned + std::fmt::Debug,
559 {
560 self.request_timeout(s, self.1).await
561 }
562
563 pub async fn request_timeout<S, R>(
565 &self,
566 s: S,
567 timeout: std::time::Duration,
568 ) -> WebsocketResult<R>
569 where
570 S: std::fmt::Debug,
571 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
572 R: serde::de::DeserializeOwned + std::fmt::Debug,
573 {
574 let timeout_at = tokio::time::Instant::now() + timeout;
575
576 use futures::sink::SinkExt;
577
578 let (s, id) = WireMessage::request(s)?;
579
580 struct D(RMap, u64);
582
583 impl Drop for D {
584 fn drop(&mut self) {
585 self.0.remove(self.1);
586 }
587 }
588
589 let (resp_s, resp_r) = tokio::sync::oneshot::channel();
590
591 let _drop = self
592 .0
593 .exec(move |_, core| async move {
594 let drop = D(core.rmap.clone(), id);
596
597 core.rmap.insert(id, resp_s);
599
600 tokio::time::timeout_at(timeout_at, async move {
601 core.send.lock().await.send(s).await?;
603
604 Ok(drop)
605 })
606 .await?
607 })
608 .await?;
609
610 tokio::time::timeout_at(timeout_at, async {
615 let resp = resp_r
617 .await
618 .map_err(|_| WebsocketError::Other("ResponderDropped".to_string()))??;
619
620 let res = decode(&Vec::from(UnsafeBytes::from(resp)))?;
622 tracing::trace!(?res, %id, "OutRequestResponse");
623 Ok(res)
624 })
625 .await?
626 }
627
628 pub async fn signal<S>(&self, s: S) -> WebsocketResult<()>
630 where
631 S: std::fmt::Debug,
632 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
633 {
634 self.signal_timeout(s, self.1).await
635 }
636
637 pub async fn signal_timeout<S>(&self, s: S, timeout: std::time::Duration) -> WebsocketResult<()>
639 where
640 S: std::fmt::Debug,
641 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
642 {
643 use futures::sink::SinkExt;
644 self.0
645 .exec(move |_, core| async move {
646 tokio::time::timeout(timeout, async {
647 let s = WireMessage::signal(s)?;
648 core.send.lock().await.send(s).await?;
649 Ok(())
650 })
651 .await?
652 })
653 .await
654 }
655}
656
657fn split(
658 stream: WsStream,
659 timeout: std::time::Duration,
660 peer_addr: std::net::SocketAddr,
661) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
662 let (sink, stream) = futures::stream::StreamExt::split(stream);
663
664 let core = WsCore {
670 send: Arc::new(tokio::sync::Mutex::new(sink)),
671 recv: Arc::new(tokio::sync::Mutex::new(stream)),
672 rmap: RMap::default(),
673 timeout,
674 };
675
676 let core_send = WsCoreSync(Arc::new(std::sync::Mutex::new(Some(core))));
677 let core_recv = core_send.clone();
678
679 Ok((
680 WebsocketSender(core_send, timeout),
681 WebsocketReceiver::new(core_recv, peer_addr),
682 ))
683}
684
685pub async fn connect(
687 config: Arc<WebsocketConfig>,
688 request: impl Into<ConnectRequest>,
689) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
690 let request = request.into();
691 let stream = tokio::net::TcpStream::connect(request.addr).await?;
692 let peer_addr = stream.peer_addr()?;
693 let (stream, _addr) = tokio_tungstenite::client_async_with_config(
694 request.into_client_request()?,
695 stream,
696 Some(config.as_tungstenite()),
697 )
698 .await?;
699 split(stream, config.default_request_timeout, peer_addr)
700}
701
702pub struct ConnectRequest {
704 addr: std::net::SocketAddr,
705 headers: HeaderMap<HeaderValue>,
706}
707
708impl From<std::net::SocketAddr> for ConnectRequest {
709 fn from(addr: std::net::SocketAddr) -> Self {
710 Self::new(addr)
711 }
712}
713
714impl ConnectRequest {
715 pub fn new(addr: std::net::SocketAddr) -> Self {
717 let mut cr = ConnectRequest {
718 addr,
719 headers: HeaderMap::new(),
720 };
721
722 cr.headers.insert(
725 "Origin",
726 HeaderValue::from_str("holochain_websocket").expect("Invalid Origin value"),
727 );
728
729 cr
730 }
731
732 pub fn try_set_header(mut self, name: &'static str, value: &str) -> Result<Self> {
736 self.headers
737 .insert(name, HeaderValue::from_str(value).map_err(Error::other)?);
738 Ok(self)
739 }
740
741 fn into_client_request(
742 self,
743 ) -> Result<impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin> {
744 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
745 let mut req =
746 String::into_client_request(format!("ws://{}", self.addr)).map_err(Error::other)?;
747 for (name, value) in self.headers {
748 if let Some(name) = name {
749 req.headers_mut().insert(name, value);
750 } else {
751 tracing::warn!("Dropping invalid header");
752 }
753 }
754 Ok(req)
755 }
756
757 #[cfg(test)]
758 pub(crate) fn clear_headers(mut self) -> Self {
759 self.headers.clear();
760
761 self
762 }
763}
764
765#[async_trait::async_trait]
767trait TcpListener: Send + Sync {
768 async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)>;
769
770 fn local_addrs(&self) -> Result<Vec<SocketAddr>>;
771}
772
773#[async_trait::async_trait]
774impl TcpListener for tokio::net::TcpListener {
775 async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
776 self.accept().await
777 }
778
779 fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
780 Ok(vec![self.local_addr()?])
781 }
782}
783
784struct DualStackListener {
785 v4: tokio::net::TcpListener,
786 v6: tokio::net::TcpListener,
787}
788
789#[async_trait::async_trait]
790impl TcpListener for DualStackListener {
791 async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
792 let (stream, addr) = select! {
793 res = self.v4.accept() => res?,
794 res = self.v6.accept() => res?,
795 };
796 Ok((stream, addr))
797 }
798
799 fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
800 Ok(vec![self.v4.local_addr()?, self.v6.local_addr()?])
801 }
802}
803
/// Accepts incoming websocket connections.
///
/// Handshakes whose `Origin` is not allowed are rejected.
pub struct WebsocketListener {
    /// Limits and timeouts applied to every accepted connection.
    config: Arc<WebsocketConfig>,
    /// Origins accepted during the handshake, taken from `config.allowed_origins`.
    access_control: Arc<AllowedOrigins>,
    /// Either a single tokio listener or a `DualStackListener`.
    listener: Box<dyn TcpListener>,
}

impl Drop for WebsocketListener {
    fn drop(&mut self) {
        tracing::info!("WebsocketListenerDrop");
    }
}
816
impl WebsocketListener {
    /// Bind a listener on a single address.
    ///
    /// # Errors
    ///
    /// Fails if `config.allowed_origins` is `None`, or if the TCP bind fails.
    pub async fn bind(config: Arc<WebsocketConfig>, addr: impl ToSocketAddrs) -> Result<Self> {
        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
            Error::other("WebsocketListener requires allowed_origins to be set in the config")
        })?);

        let listener = tokio::net::TcpListener::bind(addr).await?;

        let addr = listener.local_addr()?;
        tracing::info!(?addr, "WebsocketListener Listening");

        Ok(Self {
            config,
            access_control,
            listener: Box::new(listener),
        })
    }

    /// Bind separate IPv4 and IPv6 listeners on the same port.
    ///
    /// IPv6 is bound first, and IPv4 is then bound to the port IPv6 ended up
    /// on. This always overrides `addr_v4`'s port.
    ///
    /// - If `addr_v6`'s port is 0, the OS picks the port. When IPv4 cannot take
    ///   that port (`AddrInUse`), the pair is retried with a fresh port, up to
    ///   5 attempts in total. A non-zero IPv4 port is ignored in this case.
    /// - If `addr_v6`'s port is non-zero, it must equal `addr_v4`'s port.
    ///
    /// When one address family's bind fails with `AddrNotAvailable`, this falls
    /// back to a single-stack listener on the other family.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - `config.allowed_origins` is `None`;
    /// - the ports conflict as described above;
    /// - a bind fails for any other reason;
    /// - all 5 attempts are exhausted.
    pub async fn dual_bind(
        config: Arc<WebsocketConfig>,
        addr_v4: SocketAddrV4,
        addr_v6: SocketAddrV6,
    ) -> Result<Self> {
        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
            Error::other("WebsocketListener requires allowed_origins to be set in the config")
        })?);

        let addr_v6: SocketAddr = addr_v6.into();
        let mut addr_v4: SocketAddr = addr_v4.into();

        // A fixed IPv6 port must match the IPv4 port. Port 0 means "pick one
        // for both".
        if addr_v6.port() != 0 && addr_v6.port() != addr_v4.port() {
            return Err(Error::other(
                "dual_bind requires the same port for IPv4 and IPv6",
            ));
        }

        let mut listener: Option<DualStackListener> = None;
        // Bounded retries rather than an open-ended `loop`.
        for _ in 0..5 {
            let v6_listener = match tokio::net::TcpListener::bind(addr_v6).await {
                Ok(l) => l,
                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
                    // NOTE(review): on a retry iteration `addr_v4` still carries the
                    // previous attempt's port, so this fallback would bind that port.
                    tracing::info!(?e, "Failed to bind IPv6 listener because IPv6 appears to be disabled, falling back to IPv4 only");
                    return Self::bind(config, addr_v4).await;
                }
                Err(e) => {
                    return Err(e);
                }
            };

            // Reuse whatever port IPv6 ended up on (OS-assigned when it was 0).
            addr_v4.set_port(v6_listener.local_addr()?.port());

            let v4_listener = match tokio::net::TcpListener::bind(addr_v4).await {
                Ok(l) => l,
                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
                    tracing::info!(?e, "Failed to bind IPv4 listener because IPv4 appears to be disabled, falling back to IPv6 only");
                    return Ok(Self {
                        config,
                        access_control,
                        listener: Box::new(v6_listener),
                    });
                }
                // Only retry when the port was chosen by the OS. `continue` drops
                // `v6_listener`, freeing that port.
                Err(e) if addr_v6.port() == 0 && e.kind() == ErrorKind::AddrInUse => {
                    tracing::warn!(?e, "Failed to bind the same port for IPv4 that was selected for IPv6, retrying with a new port");
                    continue;
                }
                Err(e) => {
                    return Err(e);
                }
            };

            listener = Some(DualStackListener {
                v4: v4_listener,
                v6: v6_listener,
            });
            break;
        }

        let listener = listener.ok_or_else(|| {
            Error::other("Failed to bind listener to IPv4 and IPv6 interfaces after 5 retries")
        })?;

        let addr = listener.v4.local_addr()?;
        tracing::info!(?addr, "WebsocketListener listening");

        let addr = listener.v6.local_addr()?;
        tracing::info!(?addr, "WebsocketListener listening");

        Ok(Self {
            config,
            access_control,
            listener: Box::new(listener),
        })
    }

    /// Every local address the listener is bound to.
    ///
    /// A dual-stack listener reports IPv4 first, then IPv6.
    pub fn local_addrs(&self) -> Result<Vec<std::net::SocketAddr>> {
        self.listener.local_addrs()
    }

    /// Accept the next TCP connection and complete the websocket handshake.
    ///
    /// The handshake enforces the configured `Origin` allow-list.
    pub async fn accept(&self) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
        let (stream, addr) = self.listener.accept().await?;
        tracing::debug!(?addr, "Accept Incoming Websocket Connection");
        // NOTE(review): handshake failures surface as `WebsocketError::Io`
        // (via `Error::other`), whereas `connect` surfaces them as
        // `WebsocketError::Websocket`.
        let stream = tokio_tungstenite::accept_hdr_async_with_config(
            stream,
            ConnectCallback {
                allowed_origin: self.access_control.clone(),
            },
            Some(self.config.as_tungstenite()),
        )
        .await
        .map_err(Error::other)?;
        split(stream, self.config.default_request_timeout, addr)
    }
}
961
962struct ConnectCallback {
963 allowed_origin: Arc<AllowedOrigins>,
964}
965
966impl Callback for ConnectCallback {
967 fn on_request(
968 self,
969 request: &Request,
970 response: Response,
971 ) -> std::result::Result<Response, ErrorResponse> {
972 tracing::trace!(
973 "Checking incoming websocket connection request with allowed origin {:?}: {:?}",
974 self.allowed_origin,
975 request.headers()
976 );
977 match request
978 .headers()
979 .get("Origin")
980 .and_then(|v| v.to_str().ok())
981 {
982 Some(origin) => {
983 if self.allowed_origin.is_allowed(origin) {
984 Ok(response)
985 } else {
986 tracing::warn!("Rejecting websocket connection request with disallowed `Origin` header: {:?}", request);
987 let allowed_origin: String = self.allowed_origin.as_ref().clone().into();
988 match HeaderValue::from_str(&allowed_origin) {
989 Ok(allowed_origin) => {
990 let mut err_response = ErrorResponse::new(None);
991 *err_response.status_mut() = StatusCode::BAD_REQUEST;
992 err_response
993 .headers_mut()
994 .insert("Access-Control-Allow-Origin", allowed_origin);
995 Err(err_response)
996 }
997 Err(_) => {
998 let mut err_response = ErrorResponse::new(Some(
1000 "Invalid listener configuration for `Origin`".to_string(),
1001 ));
1002 *err_response.status_mut() = StatusCode::BAD_REQUEST;
1003 Err(err_response)
1004 }
1005 }
1006 }
1007 }
1008 None => {
1009 tracing::warn!(
1010 "Rejecting websocket connection request with missing `Origin` header: {:?}",
1011 request
1012 );
1013 let mut err_response =
1014 ErrorResponse::new(Some("Missing `Origin` header".to_string()));
1015 *err_response.status_mut() = StatusCode::BAD_REQUEST;
1016 Err(err_response)
1017 }
1018 }
1019 }
1020}
1021
1022#[cfg(test)]
1023mod test;