1#![deny(missing_docs)]
2use bytes::Bytes;
7use holochain_serialized_bytes::prelude::*;
8use holochain_types::websocket::AllowedOrigins;
9use std::io::ErrorKind;
10pub use std::io::{Error, Result};
11use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
12use std::sync::Arc;
13use tokio::net::ToSocketAddrs;
14use tokio::select;
15use tokio_tungstenite::tungstenite::handshake::client::Request;
16use tokio_tungstenite::tungstenite::handshake::server::{Callback, ErrorResponse, Response};
17use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue, StatusCode};
18use tokio_tungstenite::tungstenite::protocol::Message;
19
/// The envelope written to and read from the websocket for every
/// application-level frame.
///
/// Encoded through [`SerializedBytes`]. The `serde` attributes make this an
/// internally tagged enum: the `type` field carries the snake_case variant name.
#[derive(Debug, serde::Serialize, serde::Deserialize, SerializedBytes)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum WireMessage {
    /// A one-way message; no response is expected.
    Signal {
        /// The serialized signal payload.
        #[serde(with = "serde_bytes")]
        data: Vec<u8>,
    },

    /// An authentication payload (sent by `WebsocketSender::authenticate`).
    Authenticate {
        /// The serialized authentication payload.
        #[serde(with = "serde_bytes")]
        data: Vec<u8>,
    },

    /// A request that expects a `Response` carrying the same `id`.
    Request {
        /// Request identifier, drawn from a process-wide counter in `WireMessage::request`.
        id: u64,
        /// The serialized request payload.
        #[serde(with = "serde_bytes")]
        data: Vec<u8>,
    },

    /// The response to a previously sent `Request`.
    Response {
        /// The `id` of the request being answered.
        id: u64,
        /// The serialized response payload. When `None`, the receiver drops the
        /// pending responder without completing it.
        #[serde(with = "serde_bytes")]
        data: Option<Vec<u8>>,
    },
}
58
59impl WireMessage {
60 fn try_from_bytes(b: Vec<u8>) -> WebsocketResult<Self> {
62 let b = UnsafeBytes::from(b);
63 let b = SerializedBytes::from(b);
64 let b: WireMessage = b.try_into()?;
65 Ok(b)
66 }
67
68 fn authenticate<S>(s: S) -> WebsocketResult<Message>
70 where
71 S: std::fmt::Debug,
72 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
73 {
74 let s1 = SerializedBytes::try_from(s)?;
75 let s2 = Self::Authenticate {
76 data: UnsafeBytes::from(s1).into(),
77 };
78 let s3: SerializedBytes = s2.try_into()?;
79 Ok(Message::Binary(Bytes::copy_from_slice(
80 s3.bytes().as_slice(),
81 )))
82 }
83
84 fn request<S>(s: S) -> WebsocketResult<(Message, u64)>
86 where
87 S: std::fmt::Debug,
88 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
89 {
90 static ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
91 let id = ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
92 tracing::trace!(?s, %id, "OutRequest");
93 let s1 = SerializedBytes::try_from(s)?;
94 let s2 = Self::Request {
95 id,
96 data: UnsafeBytes::from(s1).into(),
97 };
98 let s3: SerializedBytes = s2.try_into()?;
99 Ok((
100 Message::Binary(Bytes::copy_from_slice(s3.bytes().as_slice())),
101 id,
102 ))
103 }
104
105 fn response<S>(id: u64, s: S) -> WebsocketResult<Message>
107 where
108 S: std::fmt::Debug,
109 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
110 {
111 let s1 = SerializedBytes::try_from(s)?;
112 let s2 = Self::Response {
113 id,
114 data: Some(UnsafeBytes::from(s1).into()),
115 };
116 let s3: SerializedBytes = s2.try_into()?;
117 Ok(Message::Binary(Bytes::copy_from_slice(
118 s3.bytes().as_slice(),
119 )))
120 }
121
122 fn signal<S>(s: S) -> WebsocketResult<Message>
124 where
125 S: std::fmt::Debug,
126 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
127 {
128 tracing::trace!(?s, "SendSignal");
129 let s1 = SerializedBytes::try_from(s)?;
130 let s2 = Self::Signal {
131 data: UnsafeBytes::from(s1).into(),
132 };
133 let s3: SerializedBytes = s2.try_into()?;
134 Ok(Message::Binary(Bytes::copy_from_slice(
135 s3.bytes().as_slice(),
136 )))
137 }
138}
139
/// Configuration shared by websocket clients (`connect`) and listeners.
#[derive(Clone, Debug)]
pub struct WebsocketConfig {
    /// Timeout used by operations that are not given an explicit one
    /// (e.g. `WebsocketSender::request`, `WebsocketRespond::respond`).
    pub default_request_timeout: std::time::Duration,

    /// Passed to tungstenite as its `max_message_size` limit.
    pub max_message_size: usize,

    /// Passed to tungstenite as its `max_frame_size` limit.
    pub max_frame_size: usize,

    /// Origins a listener accepts. `WebsocketListener::bind`/`dual_bind` fail
    /// if this is `None`; `connect` does not read it.
    pub allowed_origins: Option<AllowedOrigins>,
}
157
158impl WebsocketConfig {
159 pub const CLIENT_DEFAULT: WebsocketConfig = WebsocketConfig {
161 default_request_timeout: std::time::Duration::from_secs(60),
162 max_message_size: 64 << 20,
163 max_frame_size: 16 << 20,
164 allowed_origins: None,
165 };
166
167 pub const LISTENER_DEFAULT: WebsocketConfig = WebsocketConfig {
169 default_request_timeout: std::time::Duration::from_secs(60),
170 max_message_size: 64 << 20,
171 max_frame_size: 16 << 20,
172 allowed_origins: Some(AllowedOrigins::Any),
173 };
174
175 pub(crate) fn as_tungstenite(
177 &self,
178 ) -> tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
179 tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
180 .max_message_size(Some(self.max_message_size))
181 .max_frame_size(Some(self.max_frame_size))
182 }
183}
184
/// Pending outgoing requests, keyed by request id.
///
/// Each entry holds the oneshot sender that completes the matching
/// `WebsocketSender::request_timeout` call when its response arrives.
struct RMapInner(
    pub std::collections::HashMap<
        u64,
        tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
    >,
);
191
192impl Drop for RMapInner {
193 fn drop(&mut self) {
194 self.close();
195 }
196}
197
198impl RMapInner {
199 fn close(&mut self) {
200 for (_, s) in self.0.drain() {
201 let _ = s.send(Err(WebsocketError::Close("ConnectionClosed".to_string())));
202 }
203 }
204}
205
/// Cloneable, thread-safe handle to the shared [`RMapInner`] of one connection.
#[derive(Clone)]
struct RMap(Arc<std::sync::Mutex<RMapInner>>);
208
209impl Default for RMap {
210 fn default() -> Self {
211 Self(Arc::new(std::sync::Mutex::new(RMapInner(
212 std::collections::HashMap::default(),
213 ))))
214 }
215}
216
217impl RMap {
218 pub fn close(&self) {
219 if let Ok(mut lock) = self.0.lock() {
220 lock.close();
221 }
222 }
223
224 pub fn insert(
225 &self,
226 id: u64,
227 sender: tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>,
228 ) {
229 self.0.lock().unwrap().0.insert(id, sender);
230 }
231
232 pub fn remove(
233 &self,
234 id: u64,
235 ) -> Option<tokio::sync::oneshot::Sender<WebsocketResult<SerializedBytes>>> {
236 self.0.lock().unwrap().0.remove(&id)
237 }
238}
239
/// Errors returned by the websocket API.
#[derive(thiserror::Error, Debug)]
pub enum WebsocketError {
    /// The connection is closed. The string gives the reason, e.g. the peer's
    /// close frame or "No connection" once the connection has been torn down.
    #[error("Websocket closed: {0}")]
    Close(String),
    /// A payload failed to serialize or deserialize. Unlike the other variants,
    /// this does not cause the connection to be closed.
    #[error("Received a message that did not deserialize: {0}")]
    Deserialize(#[from] SerializedBytesError),
    /// An error reported by the tungstenite websocket implementation.
    #[error("Websocket error: {0}")]
    Websocket(#[from] Box<tokio_tungstenite::tungstenite::Error>),
    /// An operation did not complete before its deadline.
    #[error("Timeout")]
    Timeout(#[from] tokio::time::error::Elapsed),
    /// An I/O error (for example, from the TCP connect).
    #[error("IO error: {0}")]
    Io(#[from] Error),
    /// Any other failure, described by the string.
    #[error("Other error: {0}")]
    Other(String),
}
265
/// Result type returned by the websocket API, with [`WebsocketError`] as the error.
pub type WebsocketResult<T> = std::result::Result<T, WebsocketError>;
268
// Tungstenite websocket over a plain TCP stream.
type WsStream = tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>;
// Write half of a split `WsStream`.
type WsSend =
    futures::stream::SplitSink<WsStream, tokio_tungstenite::tungstenite::protocol::Message>;
// Write half behind an async mutex so it can be shared and locked across `.await`.
type WsSendSync = Arc<tokio::sync::Mutex<WsSend>>;
// Read half of a split `WsStream`.
type WsRecv = futures::stream::SplitStream<WsStream>;
// Read half behind an async mutex so it can be shared and locked across `.await`.
type WsRecvSync = Arc<tokio::sync::Mutex<WsRecv>>;
275
/// Everything needed to operate one open connection. Cloning shares all of it.
#[derive(Clone)]
struct WsCore {
    // Shared write half.
    pub send: WsSendSync,
    // Shared read half.
    pub recv: WsRecvSync,
    // Requests awaiting a response on this connection.
    pub rmap: RMap,
    // Default timeout, used by `WebsocketRespond::respond`.
    pub timeout: std::time::Duration,
}
283
/// Shared slot holding the connection state.
///
/// The slot holds `Some` while the connection is open; `WsCoreSync::close`
/// takes the core out, leaving `None`.
#[derive(Clone)]
struct WsCoreSync(Arc<std::sync::Mutex<Option<WsCore>>>);
286
287impl PartialEq for WsCoreSync {
288 fn eq(&self, other: &Self) -> bool {
289 Arc::ptr_eq(&self.0, &other.0)
290 }
291}
292
293impl WsCoreSync {
294 fn close(&self) {
295 if let Some(core) = self.0.lock().unwrap().take() {
296 core.rmap.close();
297 tokio::task::spawn(async move {
298 use futures::sink::SinkExt;
299 let _ = core.send.lock().await.close().await;
300 });
301 }
302 }
303
304 fn close_if_err<R>(&self, r: WebsocketResult<R>) -> WebsocketResult<R> {
305 match r {
306 Err(e @ WebsocketError::Deserialize { .. }) => {
307 Err(e)
310 }
311 Err(err) => {
312 self.close();
313 Err(err)
314 }
315 Ok(res) => Ok(res),
316 }
317 }
318
319 pub async fn exec<F, C, R>(&self, c: C) -> WebsocketResult<R>
320 where
321 F: std::future::Future<Output = WebsocketResult<R>>,
322 C: FnOnce(WsCoreSync, WsCore) -> F,
323 {
324 let core = match self.0.lock().unwrap().as_ref() {
325 Some(core) => core.clone(),
326 None => return Err(WebsocketError::Close("No connection".to_string())),
327 };
328 self.close_if_err(c(self.clone(), core).await)
329 }
330}
331
/// Handle for answering one incoming request.
///
/// Delivered with `ReceiveMessage::Request` and `ReceiveMessage::BadRequest`;
/// consumed by `WebsocketRespond::respond`.
#[derive(PartialEq)]
pub struct WebsocketRespond {
    // Id of the request being answered, echoed back in the `Response`.
    id: u64,
    // The connection the request arrived on.
    core: WsCoreSync,
}
338
339impl std::fmt::Debug for WebsocketRespond {
340 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.debug_struct("WebsocketRespond")
342 .field("id", &self.id)
343 .finish()
344 }
345}
346
347impl WebsocketRespond {
348 pub async fn respond<S>(self, s: S) -> WebsocketResult<()>
350 where
351 S: std::fmt::Debug,
352 SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
353 {
354 tracing::trace!(?s, %self.id, "OutResponse");
355 use futures::sink::SinkExt;
356 self.core
357 .exec(move |_, core| async move {
358 tokio::time::timeout(core.timeout, async {
359 let s = WireMessage::response(self.id, s)?;
360 core.send.lock().await.send(s).await.map_err(Box::new)?;
361 Ok(())
362 })
363 .await?
364 })
365 .await
366 }
367}
368
/// A message delivered to the application by `WebsocketReceiver::recv`.
#[derive(Debug, PartialEq)]
pub enum ReceiveMessage<D>
where
    D: std::fmt::Debug,
    SerializedBytes: TryInto<D, Error = SerializedBytesError>,
{
    /// The raw payload of an authentication message.
    Authenticate(Vec<u8>),

    /// The raw payload of a signal; no response is expected.
    Signal(Vec<u8>),

    /// A request decoded as `D`, with the handle used to respond to it.
    Request(D, WebsocketRespond),
    /// A request whose payload could not be decoded as `D`. The handle can
    /// still be used to respond.
    BadRequest(WebsocketRespond),
}
387
/// The receiving half of a websocket connection.
///
/// Dropping it closes the connection and aborts the keep-alive ping task.
pub struct WebsocketReceiver(
    // Shared connection state, also held by the paired `WebsocketSender`.
    WsCoreSync,
    // Address of the remote peer.
    std::net::SocketAddr,
    // Keep-alive ping task spawned in `WebsocketReceiver::new`.
    tokio::task::JoinHandle<()>,
);
397
398impl Drop for WebsocketReceiver {
399 fn drop(&mut self) {
400 self.0.close();
401 self.2.abort();
402 }
403}
404
impl WebsocketReceiver {
    /// Wrap the shared connection and start the keep-alive task.
    ///
    /// Every 5 seconds the task sends a websocket Ping. If that send fails, it
    /// closes the connection. The task exits once the connection slot is empty,
    /// i.e. after the connection has been closed.
    fn new(core: WsCoreSync, addr: std::net::SocketAddr) -> Self {
        let core2 = core.clone();
        let ping_task = tokio::task::spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                // Snapshot the core so the std mutex is not held across the send below.
                let core = core2.0.lock().unwrap().as_ref().cloned();
                if let Some(core) = core {
                    use futures::sink::SinkExt;
                    if core
                        .send
                        .lock()
                        .await
                        .send(Message::Ping(Bytes::new()))
                        .await
                        .is_err()
                    {
                        core2.close();
                    }
                } else {
                    // Connection closed: nothing left to keep alive.
                    break;
                }
            }
        });
        Self(core, addr, ping_task)
    }

    /// The remote peer's socket address.
    pub fn peer_addr(&self) -> std::net::SocketAddr {
        self.1
    }

    /// Wait for the next message addressed to the application.
    ///
    /// Control frames and responses to our own requests are handled internally
    /// (see `recv_inner`). Any error is logged at `warn` level before being returned.
    ///
    /// # Errors
    /// - `WebsocketError::Close` on a close frame or an already-closed connection.
    /// - `WebsocketError::Other("ReceiverClosed")` when the stream ends.
    /// - `WebsocketError::Deserialize` for an undecodable frame; the connection stays open.
    /// - Websocket transport errors.
    pub async fn recv<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
    where
        D: std::fmt::Debug,
        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
    {
        match self.recv_inner().await {
            Err(err) => {
                tracing::warn!(?err, "WebsocketReceiver Error");
                Err(err)
            }
            Ok(msg) => Ok(msg),
        }
    }

    /// Read frames until one produces a `ReceiveMessage` for the caller.
    ///
    /// Each iteration reads one frame inside `exec`, so every error except
    /// `Deserialize` closes the connection. An iteration that yields `Ok(None)`
    /// (ping, pong, or a response routed to a pending request) loops again.
    async fn recv_inner<D>(&mut self) -> WebsocketResult<ReceiveMessage<D>>
    where
        D: std::fmt::Debug,
        SerializedBytes: TryInto<D, Error = SerializedBytesError>,
    {
        use futures::sink::SinkExt;
        use futures::stream::StreamExt;
        loop {
            if let Some(result) = self
                .0
                .exec(move |core_sync, core| async move {
                    // `None` from the stream means the peer's side has ended.
                    let msg = core
                        .recv
                        .lock()
                        .await
                        .next()
                        .await
                        .ok_or::<WebsocketError>(WebsocketError::Other(
                            "ReceiverClosed".to_string(),
                        ))?
                        .map_err(Box::new)?;
                    // Text and binary frames are both decoded as an encoded `WireMessage`.
                    let msg = match msg {
                        Message::Text(s) => s.as_bytes().to_vec(),
                        Message::Binary(b) => b.to_vec(),
                        Message::Ping(b) => {
                            // Answer the ping explicitly, echoing its payload.
                            core.send
                                .lock()
                                .await
                                .send(Message::Pong(b))
                                .await
                                .map_err(Box::new)?;
                            return Ok(None);
                        }
                        Message::Pong(_) => return Ok(None),
                        Message::Close(frame) => {
                            return Err(WebsocketError::Close(format!("{frame:?}")));
                        }
                        Message::Frame(_) => {
                            return Err(WebsocketError::Other("UnexpectedRawFrame".to_string()))
                        }
                    };
                    match WireMessage::try_from_bytes(msg)? {
                        WireMessage::Authenticate { data } => {
                            Ok(Some(ReceiveMessage::Authenticate(data)))
                        }
                        WireMessage::Request { id, data } => {
                            let resp = WebsocketRespond {
                                id,
                                core: core_sync,
                            };
                            // Decode failure is reported as `BadRequest`, not an error,
                            // so the caller can still reply.
                            let data: D =
                                match SerializedBytes::from(UnsafeBytes::from(data)).try_into() {
                                    Ok(value) => value,
                                    Err(_) => {
                                        return Ok(Some(ReceiveMessage::BadRequest(resp)));
                                    }
                                };
                            tracing::trace!(?data, %id, "InRequest");
                            Ok(Some(ReceiveMessage::Request(data, resp)))
                        }
                        WireMessage::Response { id, data } => {
                            // Responses complete our own pending requests and are never
                            // surfaced to the caller. Unknown ids are ignored. With
                            // `data == None`, the sender is dropped without sending.
                            if let Some(sender) = core.rmap.remove(id) {
                                if let Some(data) = data {
                                    let data = SerializedBytes::from(UnsafeBytes::from(data));
                                    tracing::trace!(%id, ?data, "InResponse");
                                    let _ = sender.send(Ok(data));
                                }
                            }
                            Ok(None)
                        }
                        WireMessage::Signal { data } => Ok(Some(ReceiveMessage::Signal(data))),
                    }
                })
                .await?
            {
                return Ok(result);
            }
        }
    }
}
532
/// The sending half of a websocket connection. Clones share the same connection.
///
/// The `Duration` is the default timeout used by `authenticate`, `request` and `signal`.
#[derive(Clone)]
pub struct WebsocketSender(WsCoreSync, std::time::Duration);
538
impl WebsocketSender {
    /// Send an authentication payload, bounded by this sender's default timeout.
    ///
    /// # Errors
    /// See `authenticate_timeout`.
    pub async fn authenticate<S>(&self, s: S) -> WebsocketResult<()>
    where
        S: std::fmt::Debug,
        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
    {
        self.authenticate_timeout(s, self.1).await
    }

    /// Serialize `s` and send it as a `WireMessage::Authenticate` frame.
    ///
    /// Serializing and sending are bounded by `timeout`. No reply is awaited.
    ///
    /// # Errors
    /// - `Close` if the connection is already closed.
    /// - `Deserialize` if `s` cannot be serialized.
    /// - Send errors.
    /// - `Timeout`.
    ///
    /// All except `Deserialize` close the connection.
    pub async fn authenticate_timeout<S>(
        &self,
        s: S,
        timeout: std::time::Duration,
    ) -> WebsocketResult<()>
    where
        S: std::fmt::Debug,
        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
    {
        use futures::sink::SinkExt;
        self.0
            .exec(move |_, core| async move {
                tokio::time::timeout(timeout, async {
                    let s = WireMessage::authenticate(s)?;
                    core.send.lock().await.send(s).await.map_err(Box::new)?;
                    Ok(())
                })
                .await?
            })
            .await
    }

    /// Send a request and await its response, bounded by this sender's default timeout.
    ///
    /// # Errors
    /// See `request_timeout`.
    pub async fn request<S, R>(&self, s: S) -> WebsocketResult<R>
    where
        S: std::fmt::Debug,
        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
        R: serde::de::DeserializeOwned + std::fmt::Debug,
    {
        self.request_timeout(s, self.1).await
    }

    /// Send `s` as a request and wait for the peer's response, decoded as `R`.
    ///
    /// A single deadline (`now + timeout`) covers both sending and waiting.
    /// Failures while sending run inside `exec` and close the connection, except
    /// `Deserialize`. The wait for the response happens outside `exec`, so a
    /// timeout there does not close the connection.
    ///
    /// # Errors
    /// - `Close` if the connection is closed before or while waiting.
    /// - `Other("ResponderDropped")` if the pending responder is dropped without an answer.
    /// - `Deserialize` if serializing `s` or decoding the response fails.
    /// - `Timeout`.
    /// - Send errors.
    pub async fn request_timeout<S, R>(
        &self,
        s: S,
        timeout: std::time::Duration,
    ) -> WebsocketResult<R>
    where
        S: std::fmt::Debug,
        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
        R: serde::de::DeserializeOwned + std::fmt::Debug,
    {
        let timeout_at = tokio::time::Instant::now() + timeout;

        use futures::sink::SinkExt;

        let (s, id) = WireMessage::request(s)?;

        // Drop guard: removes this request's entry from the response map on every
        // exit path (success, error, timeout, or the future being dropped).
        struct D(RMap, u64);

        impl Drop for D {
            fn drop(&mut self) {
                self.0.remove(self.1);
            }
        }

        let (resp_s, resp_r) = tokio::sync::oneshot::channel();

        // The guard is returned out of `exec` and kept alive in `_drop` until
        // this function finishes.
        let _drop = self
            .0
            .exec(move |_, core| async move {
                let drop = D(core.rmap.clone(), id);

                // Register the responder before sending, so a fast response
                // is not discarded as unknown by the receive loop.
                core.rmap.insert(id, resp_s);

                tokio::time::timeout_at(timeout_at, async move {
                    core.send.lock().await.send(s).await.map_err(Box::new)?;

                    Ok(drop)
                })
                .await?
            })
            .await?;

        tokio::time::timeout_at(timeout_at, async {
            // First `?`: the responder was dropped. Second `?`: the responder
            // delivered an error (e.g. `Close` from `RMapInner::close`).
            let resp = resp_r
                .await
                .map_err(|_| WebsocketError::Other("ResponderDropped".to_string()))??;

            let res = decode(&Vec::from(UnsafeBytes::from(resp)))?;
            tracing::trace!(?res, %id, "OutRequestResponse");
            Ok(res)
        })
        .await?
    }

    /// Send a signal, bounded by this sender's default timeout.
    ///
    /// # Errors
    /// See `signal_timeout`.
    pub async fn signal<S>(&self, s: S) -> WebsocketResult<()>
    where
        S: std::fmt::Debug,
        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
    {
        self.signal_timeout(s, self.1).await
    }

    /// Serialize `s` and send it as a `WireMessage::Signal` frame.
    ///
    /// Serializing and sending are bounded by `timeout`. No reply is expected.
    ///
    /// # Errors
    /// Same as `authenticate_timeout`.
    pub async fn signal_timeout<S>(&self, s: S, timeout: std::time::Duration) -> WebsocketResult<()>
    where
        S: std::fmt::Debug,
        SerializedBytes: TryFrom<S, Error = SerializedBytesError>,
    {
        use futures::sink::SinkExt;
        self.0
            .exec(move |_, core| async move {
                tokio::time::timeout(timeout, async {
                    let s = WireMessage::signal(s)?;
                    core.send.lock().await.send(s).await.map_err(Box::new)?;
                    Ok(())
                })
                .await?
            })
            .await
    }
}
677
678fn split(
679 stream: WsStream,
680 timeout: std::time::Duration,
681 peer_addr: std::net::SocketAddr,
682) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
683 let (sink, stream) = futures::stream::StreamExt::split(stream);
684
685 let core = WsCore {
691 send: Arc::new(tokio::sync::Mutex::new(sink)),
692 recv: Arc::new(tokio::sync::Mutex::new(stream)),
693 rmap: RMap::default(),
694 timeout,
695 };
696
697 let core_send = WsCoreSync(Arc::new(std::sync::Mutex::new(Some(core))));
698 let core_recv = core_send.clone();
699
700 Ok((
701 WebsocketSender(core_send, timeout),
702 WebsocketReceiver::new(core_recv, peer_addr),
703 ))
704}
705
706pub async fn connect(
708 config: Arc<WebsocketConfig>,
709 request: impl Into<ConnectRequest>,
710) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
711 let request = request.into();
712 let stream = tokio::net::TcpStream::connect(request.addr).await?;
713 let peer_addr = stream.peer_addr()?;
714 let (stream, _addr) = tokio_tungstenite::client_async_with_config(
715 request.into_client_request()?,
716 stream,
717 Some(config.as_tungstenite()),
718 )
719 .await
720 .map_err(Box::new)?;
721 split(stream, config.default_request_timeout, peer_addr)
722}
723
/// Parameters for an outgoing connection made with `connect`.
pub struct ConnectRequest {
    // Address to connect to; also used to build the `ws://` URL.
    addr: std::net::SocketAddr,
    // Extra headers sent with the websocket handshake request.
    headers: HeaderMap<HeaderValue>,
}
729
730impl From<std::net::SocketAddr> for ConnectRequest {
731 fn from(addr: std::net::SocketAddr) -> Self {
732 Self::new(addr)
733 }
734}
735
736impl ConnectRequest {
737 pub fn new(addr: std::net::SocketAddr) -> Self {
739 let mut cr = ConnectRequest {
740 addr,
741 headers: HeaderMap::new(),
742 };
743
744 cr.headers.insert(
747 "Origin",
748 HeaderValue::from_str("holochain_websocket").expect("Invalid Origin value"),
749 );
750
751 cr
752 }
753
754 pub fn try_set_header(mut self, name: &'static str, value: &str) -> Result<Self> {
758 self.headers
759 .insert(name, HeaderValue::from_str(value).map_err(Error::other)?);
760 Ok(self)
761 }
762
763 fn into_client_request(
764 self,
765 ) -> Result<impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin> {
766 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
767 let mut req =
768 String::into_client_request(format!("ws://{}", self.addr)).map_err(Error::other)?;
769 for (name, value) in self.headers {
770 if let Some(name) = name {
771 req.headers_mut().insert(name, value);
772 } else {
773 tracing::warn!("Dropping invalid header");
774 }
775 }
776 Ok(req)
777 }
778
779 #[cfg(test)]
780 pub(crate) fn clear_headers(mut self) -> Self {
781 self.headers.clear();
782
783 self
784 }
785}
786
/// Abstraction over a TCP listener that may be backed by one or several sockets.
#[async_trait::async_trait]
trait TcpListener: Send + Sync {
    /// Accept the next incoming TCP connection and return it with the peer's address.
    async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)>;

    /// All local addresses this listener is bound to.
    fn local_addrs(&self) -> Result<Vec<SocketAddr>>;
}
794
795#[async_trait::async_trait]
796impl TcpListener for tokio::net::TcpListener {
797 async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
798 self.accept().await
799 }
800
801 fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
802 Ok(vec![self.local_addr()?])
803 }
804}
805
/// A pair of listeners, one IPv4 and one IPv6, created by `WebsocketListener::dual_bind`.
struct DualStackListener {
    // IPv4 listener.
    v4: tokio::net::TcpListener,
    // IPv6 listener.
    v6: tokio::net::TcpListener,
}
810
811#[async_trait::async_trait]
812impl TcpListener for DualStackListener {
813 async fn accept(&self) -> Result<(tokio::net::TcpStream, SocketAddr)> {
814 let (stream, addr) = select! {
815 res = self.v4.accept() => res?,
816 res = self.v6.accept() => res?,
817 };
818 Ok((stream, addr))
819 }
820
821 fn local_addrs(&self) -> Result<Vec<SocketAddr>> {
822 Ok(vec![self.v4.local_addr()?, self.v6.local_addr()?])
823 }
824}
825
/// Accepts incoming websocket connections, checking the `Origin` header
/// against the configured allowed origins.
pub struct WebsocketListener {
    // Limits and default timeout applied to accepted connections.
    config: Arc<WebsocketConfig>,
    // Origins accepted during the handshake (from `config.allowed_origins`).
    access_control: Arc<AllowedOrigins>,
    // Underlying TCP listener: single-socket or dual-stack.
    listener: Box<dyn TcpListener>,
}
832
833impl Drop for WebsocketListener {
834 fn drop(&mut self) {
835 tracing::info!("WebsocketListenerDrop");
836 }
837}
838
impl WebsocketListener {
    /// Bind a listener to `addr`.
    ///
    /// # Errors
    /// Fails if `config.allowed_origins` is `None`, or if binding fails.
    pub async fn bind(config: Arc<WebsocketConfig>, addr: impl ToSocketAddrs) -> Result<Self> {
        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
            Error::other("WebsocketListener requires allowed_origins to be set in the config")
        })?);

        let listener = tokio::net::TcpListener::bind(addr).await?;

        let addr = listener.local_addr()?;
        tracing::info!(?addr, "WebsocketListener Listening");

        Ok(Self {
            config,
            access_control,
            listener: Box::new(listener),
        })
    }

    /// Bind an IPv6 and an IPv4 listener on the same port.
    ///
    /// The IPv6 socket is bound first and its port is reused for IPv4. It falls
    /// back to a single-stack listener in these cases:
    /// - IPv6 is unavailable (IPv4 only, via `bind`).
    /// - IPv4 is unavailable (IPv6 only).
    /// - The IPv6 address is unspecified and the IPv4 bind reports `AddrInUse` (IPv6 only).
    ///
    /// When the IPv6 port is 0 and the IPv4 bind on the chosen port reports
    /// `AddrInUse`, a new port is tried, up to 5 attempts in total.
    ///
    /// NOTE(review): when `addr_v6` has port 0, a non-zero `addr_v4` port is
    /// silently replaced by the IPv6 port. Confirm this is intended.
    ///
    /// # Errors
    /// Fails if `config.allowed_origins` is `None`, if a non-zero IPv6 port differs
    /// from the IPv4 port, on other bind errors, or when all attempts are used up.
    pub async fn dual_bind(
        config: Arc<WebsocketConfig>,
        addr_v4: SocketAddrV4,
        addr_v6: SocketAddrV6,
    ) -> Result<Self> {
        let access_control = Arc::new(config.allowed_origins.clone().ok_or_else(|| {
            Error::other("WebsocketListener requires allowed_origins to be set in the config")
        })?);

        let addr_v6: SocketAddr = addr_v6.into();
        let mut addr_v4: SocketAddr = addr_v4.into();

        if addr_v6.port() != 0 && addr_v6.port() != addr_v4.port() {
            // A specific IPv6 port was requested, and it differs from the IPv4 port.
            return Err(Error::other(
                "dual_bind requires the same port for IPv4 and IPv6",
            ));
        }

        let mut listener: Option<DualStackListener> = None;
        // Retry only helps when the OS picks the port (IPv6 port 0): that port may
        // already be taken on the IPv4 side.
        for _ in 0..5 {
            let v6_listener = match tokio::net::TcpListener::bind(addr_v6).await {
                Ok(l) => l,
                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
                    tracing::info!(?e, "Failed to bind IPv6 listener because IPv6 appears to be disabled, falling back to IPv4 only");
                    return Self::bind(config, addr_v4).await;
                }
                Err(e) => {
                    tracing::error!("Failed to bind IPv6 listener: {:?}", e);
                    return Err(e);
                }
            };

            // Use whatever port the IPv6 socket actually got (matters when it was 0).
            addr_v4.set_port(v6_listener.local_addr()?.port());

            // Arm order matters: the unspecified-address `AddrInUse` fallback is
            // checked before the retry arm.
            let v4_listener = match tokio::net::TcpListener::bind(addr_v4).await {
                Ok(l) => l,
                Err(e) if e.kind() == ErrorKind::AddrNotAvailable => {
                    tracing::info!(?e, "Failed to bind IPv4 listener because IPv4 appears to be disabled, falling back to IPv6 only");
                    return Ok(Self {
                        config,
                        access_control,
                        listener: Box::new(v6_listener),
                    });
                }
                Err(e) if addr_v6.ip().is_unspecified() && e.kind() == ErrorKind::AddrInUse => {
                    tracing::info!(?e, "Failed to bind IPv4 listener because the address is already in use, falling back to IPv6 only");
                    return Ok(Self {
                        config,
                        access_control,
                        listener: Box::new(v6_listener),
                    });
                }
                Err(e) if addr_v6.port() == 0 && e.kind() == ErrorKind::AddrInUse => {
                    // `v6_listener` is dropped here, so the next attempt binds a fresh port.
                    tracing::warn!(?e, "Failed to bind the same port for IPv4 that was selected for IPv6, retrying with a new port");
                    continue;
                }
                Err(e) => {
                    tracing::error!("Failed to bind IPv4 listener: {:?}", e);
                    return Err(e);
                }
            };

            listener = Some(DualStackListener {
                v4: v4_listener,
                v6: v6_listener,
            });
            break;
        }

        let listener = listener.ok_or_else(|| {
            Error::other("Failed to bind listener to IPv4 and IPv6 interfaces after 5 retries")
        })?;

        let addr = listener.v4.local_addr()?;
        tracing::info!(?addr, "WebsocketListener listening");

        let addr = listener.v6.local_addr()?;
        tracing::info!(?addr, "WebsocketListener listening");

        Ok(Self {
            config,
            access_control,
            listener: Box::new(listener),
        })
    }

    /// All local addresses this listener is bound to.
    pub fn local_addrs(&self) -> Result<Vec<std::net::SocketAddr>> {
        self.listener.local_addrs()
    }

    /// Accept the next connection and perform the websocket handshake, checking
    /// `Origin` via `ConnectCallback`.
    ///
    /// # Errors
    /// Handshake failures, including rejected origins, are returned as
    /// `WebsocketError::Io`. (`connect`, by contrast, reports handshake failures
    /// as `WebsocketError::Websocket`.)
    pub async fn accept(&self) -> WebsocketResult<(WebsocketSender, WebsocketReceiver)> {
        let (stream, addr) = self.listener.accept().await?;
        tracing::debug!(?addr, "Accept Incoming Websocket Connection");
        let stream = tokio_tungstenite::accept_hdr_async_with_config(
            stream,
            ConnectCallback {
                allowed_origin: self.access_control.clone(),
            },
            Some(self.config.as_tungstenite()),
        )
        .await
        .map_err(Error::other)?;
        split(stream, self.config.default_request_timeout, addr)
    }
}
997
/// Handshake callback that accepts a connection only if its `Origin` header is allowed.
struct ConnectCallback {
    // Origins this listener accepts.
    allowed_origin: Arc<AllowedOrigins>,
}
1001
1002impl Callback for ConnectCallback {
1003 fn on_request(
1004 self,
1005 request: &Request,
1006 response: Response,
1007 ) -> std::result::Result<Response, ErrorResponse> {
1008 tracing::trace!(
1009 "Checking incoming websocket connection request with allowed origin {:?}: {:?}",
1010 self.allowed_origin,
1011 request.headers()
1012 );
1013 match request
1014 .headers()
1015 .get("Origin")
1016 .and_then(|v| v.to_str().ok())
1017 {
1018 Some(origin) => {
1019 if self.allowed_origin.is_allowed(origin) {
1020 Ok(response)
1021 } else {
1022 tracing::warn!("Rejecting websocket connection request with disallowed `Origin` header: {:?}", request);
1023 let allowed_origin: String = self.allowed_origin.as_ref().clone().into();
1024 match HeaderValue::from_str(&allowed_origin) {
1025 Ok(allowed_origin) => {
1026 let mut err_response = ErrorResponse::new(None);
1027 *err_response.status_mut() = StatusCode::BAD_REQUEST;
1028 err_response
1029 .headers_mut()
1030 .insert("Access-Control-Allow-Origin", allowed_origin);
1031 Err(err_response)
1032 }
1033 Err(_) => {
1034 let mut err_response = ErrorResponse::new(Some(
1036 "Invalid listener configuration for `Origin`".to_string(),
1037 ));
1038 *err_response.status_mut() = StatusCode::BAD_REQUEST;
1039 Err(err_response)
1040 }
1041 }
1042 }
1043 }
1044 None => {
1045 tracing::warn!(
1046 "Rejecting websocket connection request with missing `Origin` header: {:?}",
1047 request
1048 );
1049 let mut err_response =
1050 ErrorResponse::new(Some("Missing `Origin` header".to_string()));
1051 *err_response.status_mut() = StatusCode::BAD_REQUEST;
1052 Err(err_response)
1053 }
1054 }
1055 }
1056}
1057
// Unit tests for this module live in the `test` submodule.
#[cfg(test)]
mod test;