aggligator_transport_websocket/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/surban/aggligator/master/.misc/aggligator.png",
5    html_favicon_url = "https://raw.githubusercontent.com/surban/aggligator/master/.misc/aggligator.png",
6    issue_tracker_base_url = "https://github.com/surban/aggligator/issues/"
7)]
8
9//! [Aggligator](aggligator) transport: WebSocket on a native platform (not web).
10
11use async_trait::async_trait;
12use axum::{
13    body::Body,
14    extract::{ConnectInfo, WebSocketUpgrade},
15    http::StatusCode,
16    response::Response,
17    routing::get,
18    Router,
19};
20use bytes::Bytes;
21use futures::{SinkExt, StreamExt, TryStreamExt};
22use std::{
23    any::Any,
24    cmp::Ordering,
25    collections::{HashMap, HashSet},
26    fmt,
27    hash::{Hash, Hasher},
28    io::{Error, ErrorKind, Result},
29    net::{IpAddr, Ipv6Addr, SocketAddr},
30    sync::Arc,
31    time::Duration,
32};
33use tokio::{
34    net::TcpSocket,
35    sync::{mpsc, watch, Mutex},
36    time::sleep,
37};
38use tokio_tungstenite::{client_async_tls_with_config, tungstenite::protocol::WebSocketConfig, Connector};
39use tokio_util::io::{CopyToBytes, SinkWriter, StreamReader};
40use url::Url;
41
42use aggligator::{
43    control::Direction,
44    io::{IoBox, StreamBox},
45    transport::{AcceptedStreamBox, AcceptingTransport, ConnectingTransport, LinkTag, LinkTagBox},
46    Link,
47};
48use aggligator_transport_tcp::util::{self, NetworkInterface};
49pub use aggligator_transport_tcp::IpVersion;
50
// Transport name reported by the link tags and by both transports defined in this file.
static NAME: &str = "websocket";
52
/// Link tag for outgoing WebSocket link.
///
/// A tag identifies one outgoing link by the local interface, the resolved
/// remote socket address and the URL it is connecting to.
// Equality, ordering and hashing are derived, so they take all fields into
// account in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OutgoingWebSocketLinkTag {
    /// Local interface name.
    ///
    /// `None` if the socket is not bound to a specific local interface.
    pub interface: Option<Vec<u8>>,
    /// Remote socket address.
    pub remote: SocketAddr,
    /// Remote URL.
    pub url: String,
    /// Whether to use TLS for connecting.
    pub tls: bool,
}
65
66impl fmt::Display for OutgoingWebSocketLinkTag {
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        write!(
69            f,
70            "{} -> {} ({})",
71            String::from_utf8_lossy(self.interface.as_deref().unwrap_or_default()),
72            &self.remote,
73            &self.url
74        )
75    }
76}
77
78impl LinkTag for OutgoingWebSocketLinkTag {
79    fn transport_name(&self) -> &str {
80        NAME
81    }
82
83    fn direction(&self) -> Direction {
84        Direction::Outgoing
85    }
86
87    fn user_data(&self) -> Vec<u8> {
88        self.interface.clone().unwrap_or_default()
89    }
90
91    fn as_any(&self) -> &dyn Any {
92        self
93    }
94
95    fn box_clone(&self) -> LinkTagBox {
96        Box::new(self.clone())
97    }
98
99    fn dyn_cmp(&self, other: &dyn LinkTag) -> Ordering {
100        let other = other.as_any().downcast_ref::<Self>().unwrap();
101        Ord::cmp(self, other)
102    }
103
104    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
105        Hash::hash(self, &mut state)
106    }
107}
108
/// WebSocket transport for outgoing connections.
///
/// This transport is packet-based.
#[derive(Clone)]
pub struct WebSocketConnector {
    // WebSocket URLs of the target; the constructor rejects an empty list.
    urls: Vec<Url>,
    // IP version used for resolving and connecting.
    ip_version: IpVersion,
    // Interval for re-resolving the host names and re-checking the local interfaces.
    resolve_interval: Duration,
    // Connector passed to the WebSocket handshake of `wss` URLs; allows control of TLS.
    connector: Option<Connector>,
    // Optional WebSocket connection configuration.
    web_socket_config: Option<WebSocketConfig>,
    // Whether a separate link is established per local interface.
    multi_interface: bool,
    // Decides which local interfaces are used when multi interface is enabled.
    interface_filter: Arc<dyn Fn(&NetworkInterface) -> bool + Send + Sync>,
}
122
123impl fmt::Debug for WebSocketConnector {
124    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125        f.debug_struct("WebSocketConnector")
126            .field("urls", &self.urls)
127            .field("ip_version", &self.ip_version)
128            .field("resolve_interval", &self.resolve_interval)
129            .field("web_socket_config", &self.web_socket_config)
130            .field("multi_interface", &self.multi_interface)
131            .finish()
132    }
133}
134
135impl fmt::Display for WebSocketConnector {
136    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137        let urls: Vec<_> = self.urls.iter().map(|url| url.to_string()).collect();
138        if self.urls.len() > 1 {
139            write!(f, "[{}]", urls.join(", "))
140        } else {
141            write!(f, "{}", &urls[0])
142        }
143    }
144}
145
146impl WebSocketConnector {
147    /// Create a new WebSocket transport for outgoing connections.
148    ///
149    /// `urls` contains one or more WebSocket URLs of the target.
150    ///
151    /// It is checked at creation that at least one URL can be resolved to an IP address.
152    ///
153    /// Host name resolution is retried periodically, thus DNS updates will be taken
154    /// into account without the need to recreate this transport.
155    pub async fn new(urls: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
156        let this = Self::unresolved(urls).await?;
157
158        let addrs = this.resolve().await;
159        if addrs.values().all(|addrs| addrs.is_empty()) {
160            return Err(Error::new(ErrorKind::NotFound, "cannot resolve IP address of any URL"));
161        }
162        tracing::info!(?addrs, "URLs initially resolved");
163
164        Ok(this)
165    }
166
167    /// Create a new WebSocket transport for outgoing connections ut checking that at least one URL can be resolved.
168    ///
169    /// `urls` contains one or more WebSocket URLs of the target.
170    ///
171    /// Host name resolution is retried periodically, thus DNS updates will be taken
172    /// into account without the need to recreate this transport.
173    pub async fn unresolved(urls: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
174        let urls = urls
175            .into_iter()
176            .map(|url| url.as_ref().parse::<Url>())
177            .collect::<std::result::Result<Vec<_>, _>>()
178            .map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
179
180        if urls.is_empty() {
181            return Err(Error::new(ErrorKind::InvalidInput, "at least one URL is required"));
182        }
183        for url in &urls {
184            if !url.has_host() {
185                return Err(Error::new(ErrorKind::InvalidInput, "URL must have a host"));
186            }
187            if !["ws", "wss"].contains(&url.scheme()) {
188                return Err(Error::new(ErrorKind::InvalidInput, "URL must have scheme ws or wss"));
189            }
190        }
191
192        Ok(Self {
193            urls,
194            ip_version: IpVersion::Both,
195            resolve_interval: Duration::from_secs(10),
196            connector: None,
197            web_socket_config: None,
198            multi_interface: !cfg!(target_os = "android"),
199            interface_filter: Arc::new(|_| true),
200        })
201    }
202
203    /// Sets the IP version used for connecting.
204    pub fn set_ip_version(&mut self, ip_version: IpVersion) {
205        self.ip_version = ip_version;
206    }
207
208    /// Sets the interval for re-resolving the hostname and checking for changed network interfaces.
209    pub fn set_resolve_interval(&mut self, resolve_interval: Duration) {
210        self.resolve_interval = resolve_interval;
211    }
212
213    /// Sets the WebSocket connector for establishing the connection.
214    ///
215    /// Allows control of TLS.
216    pub fn set_connector(&mut self, connector: Option<Connector>) {
217        self.connector = connector;
218    }
219
220    /// Sets the WebSocket connection configuration.
221    pub fn set_web_socket_config(&mut self, web_socket_config: Option<WebSocketConfig>) {
222        self.web_socket_config = web_socket_config;
223    }
224
225    /// Sets whether all available local interfaces should be used for connecting.
226    ///
227    /// If this is true (default for non-Android platforms), a separate link is
228    /// established for each pair of server IP and local interface. Each outgoing socket
229    /// is explicitly bound to a local interface.
230    ///
231    /// If this is false (default for Android platform), one link is established for
232    /// each server IP. The operating system automatically assigns a local interface
233    /// for the outgoing socket.
234    pub fn set_multi_interface(&mut self, multi_interface: bool) {
235        self.multi_interface = multi_interface;
236    }
237
238    /// Sets the local interface filter.
239    ///
240    /// It is only used when multi interface is enabled.
241    ///
242    /// The provided function is called for each discoved local interface and should
243    /// return whether the interface should be used for establishing links.
244    ///
245    /// By default all local interfaces are used.
246    pub fn set_interface_filter(
247        &mut self, interface_filter: impl Fn(&NetworkInterface) -> bool + Send + Sync + 'static,
248    ) {
249        self.interface_filter = Arc::new(interface_filter);
250    }
251
252    /// Resolve URLs to socket addresses.
253    async fn resolve(&self) -> HashMap<&Url, Vec<SocketAddr>> {
254        let mut url_addrs = HashMap::new();
255
256        for url in &self.urls {
257            let host = url.host_str().unwrap();
258            let port = url.port_or_known_default().unwrap();
259            let addrs = util::resolve_hosts(&[format!("{host}:{port}")], self.ip_version).await;
260            url_addrs.insert(url, addrs);
261        }
262
263        url_addrs
264    }
265}
266
267#[async_trait]
268impl ConnectingTransport for WebSocketConnector {
269    fn name(&self) -> &str {
270        NAME
271    }
272
273    async fn link_tags(&self, tx: watch::Sender<HashSet<LinkTagBox>>) -> Result<()> {
274        loop {
275            let interfaces: Option<Vec<NetworkInterface>> = match self.multi_interface {
276                true => Some(
277                    util::local_interfaces()?
278                        .into_iter()
279                        .filter(|iface| (self.interface_filter)(iface))
280                        .collect(),
281                ),
282                false => None,
283            };
284
285            let mut tags: HashSet<LinkTagBox> = HashSet::new();
286            for (url, addrs) in self.resolve().await {
287                for addr in addrs {
288                    match &interfaces {
289                        Some(interfaces) => {
290                            for interface in util::interface_names_for_target(interfaces, addr) {
291                                let tag = OutgoingWebSocketLinkTag {
292                                    interface: Some(interface),
293                                    remote: addr,
294                                    url: url.to_string(),
295                                    tls: url.scheme() == "wss",
296                                };
297                                tags.insert(Box::new(tag));
298                            }
299                        }
300                        None => {
301                            let tag = OutgoingWebSocketLinkTag {
302                                interface: None,
303                                remote: addr,
304                                url: url.to_string(),
305                                tls: url.scheme() == "wss",
306                            };
307                            tags.insert(Box::new(tag));
308                        }
309                    }
310                }
311            }
312
313            tx.send_if_modified(|v| {
314                if *v != tags {
315                    *v = tags;
316                    true
317                } else {
318                    false
319                }
320            });
321
322            sleep(self.resolve_interval).await;
323        }
324    }
325
326    async fn connect(&self, tag: &dyn LinkTag) -> Result<StreamBox> {
327        let tag: &OutgoingWebSocketLinkTag = tag.as_any().downcast_ref().unwrap();
328
329        // Establish TCP connection to server.
330        let socket = match tag.remote.ip() {
331            IpAddr::V4(_) => TcpSocket::new_v4(),
332            IpAddr::V6(_) => TcpSocket::new_v6(),
333        }?;
334
335        if let Some(interface) = &tag.interface {
336            util::bind_socket_to_interface(&socket, interface, tag.remote.ip())?;
337        }
338
339        let stream = socket.connect(tag.remote).await?;
340        let _ = stream.set_nodelay(true);
341
342        // Convert into WebSocket.
343        let connector = if tag.tls { self.connector.clone() } else { Some(Connector::Plain) };
344        let (web_socket, _rsp) =
345            client_async_tls_with_config(&tag.url, stream, self.web_socket_config, connector)
346                .await
347                .map_err(|err| Error::new(ErrorKind::ConnectionRefused, err))?;
348
349        // Adapt WebSocket IO.
350        let (ws_tx, ws_rx) = web_socket.split();
351        let ws_tx = Box::pin(
352            ws_tx
353                .with(
354                    |data: Bytes| async move { Ok::<_, tungstenite::Error>(tungstenite::Message::Binary(data)) },
355                )
356                .sink_map_err(Error::other),
357        );
358        let ws_write = SinkWriter::new(CopyToBytes::new(ws_tx));
359
360        let ws_rx = Box::pin(
361            ws_rx
362                .try_filter_map(|msg: tungstenite::Message| async move {
363                    if let tungstenite::Message::Binary(data) = msg {
364                        Ok(Some(data))
365                    } else {
366                        Ok(None)
367                    }
368                })
369                .map_err(Error::other),
370        );
371        let ws_read = StreamReader::new(ws_rx);
372
373        Ok(IoBox::new(ws_read, ws_write).into())
374    }
375
376    async fn link_filter(&self, new: &Link<LinkTagBox>, existing: &[Link<LinkTagBox>]) -> bool {
377        let Some(new_tag) = new.tag().as_any().downcast_ref::<OutgoingWebSocketLinkTag>() else { return true };
378
379        let intro = format!(
380            "Judging {} WebSocket link {} {} ({}) on {}",
381            new.direction(),
382            match new.direction() {
383                Direction::Incoming => "from",
384                Direction::Outgoing => "to",
385            },
386            new_tag.remote,
387            String::from_utf8_lossy(new.remote_user_data()),
388            String::from_utf8_lossy(new_tag.interface.as_deref().unwrap_or(b"any interface"))
389        );
390
391        match existing.iter().find(|link| {
392            let Some(tag) = link.tag().as_any().downcast_ref::<OutgoingWebSocketLinkTag>() else { return false };
393            tag.interface == new_tag.interface && link.remote_user_data() == new.remote_user_data()
394        }) {
395            Some(other) => {
396                let other_tag = other.tag().as_any().downcast_ref::<OutgoingWebSocketLinkTag>().unwrap();
397                tracing::debug!("{intro} => link {} is redundant, rejecting.", other_tag.remote);
398                false
399            }
400            None => {
401                tracing::debug!("{intro} => accepted.");
402                true
403            }
404        }
405    }
406}
407
/// Link tag for incoming WebSocket link.
///
/// A tag identifies one accepted link by the local address of the server,
/// the address of the connected client and the negotiated sub-protocol.
// Equality, ordering and hashing are derived, so they take all fields into
// account in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct IncomingWebSocketLinkTag {
    /// Local socket address.
    pub local: SocketAddr,
    /// Remote socket address.
    pub remote: SocketAddr,
    /// WebSocket sub-protocol.
    ///
    /// `None` if no sub-protocol is available for the connection.
    pub protocol: Option<String>,
}
418
419impl fmt::Display for IncomingWebSocketLinkTag {
420    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
421        write!(
422            f,
423            "{} <- {}{}",
424            &self.local,
425            &self.remote,
426            match &self.protocol {
427                Some(protocol) => format!(" ({protocol})"),
428                None => String::new(),
429            }
430        )
431    }
432}
433
434impl LinkTag for IncomingWebSocketLinkTag {
435    fn transport_name(&self) -> &str {
436        NAME
437    }
438
439    fn direction(&self) -> Direction {
440        Direction::Incoming
441    }
442
443    fn user_data(&self) -> Vec<u8> {
444        match self.local.ip() {
445            IpAddr::V4(ip) => ip.octets().into(),
446            IpAddr::V6(ip) => ip.octets().into(),
447        }
448    }
449
450    fn as_any(&self) -> &dyn Any {
451        self
452    }
453
454    fn box_clone(&self) -> LinkTagBox {
455        Box::new(self.clone())
456    }
457
458    fn dyn_cmp(&self, other: &dyn LinkTag) -> Ordering {
459        let other = other.as_any().downcast_ref::<Self>().unwrap();
460        Ord::cmp(self, other)
461    }
462
463    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
464        Hash::hash(self, &mut state)
465    }
466}
467
// An upgraded WebSocket connection, handed from a router's request handler
// to the `WebSocketAcceptor` over the internal channel.
struct IncomingWebSocket {
    // Local address the router was created for (`local_addr` of `custom_router`).
    local: SocketAddr,
    // Address of the connected client, taken from the connection info.
    remote: SocketAddr,
    // The upgraded WebSocket.
    web_socket: axum::extract::ws::WebSocket,
}
473
/// Builds a [WebSocket transport listener](WebSocketAcceptor).
pub struct WebSocketAcceptorBuilder {
    // Sending half; a clone is moved into the handler of every created router.
    tx: mpsc::Sender<IncomingWebSocket>,
    // Receiving half; moved into the `WebSocketAcceptor` when building.
    rx: mpsc::Receiver<IncomingWebSocket>,
}
479
480impl fmt::Debug for WebSocketAcceptorBuilder {
481    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
482        f.debug_struct("WebSocketAcceptorBuilder").finish()
483    }
484}
485
486impl WebSocketAcceptorBuilder {
487    fn new() -> Self {
488        let (tx, rx) = mpsc::channel(16);
489        Self { tx, rx }
490    }
491}
492
493impl WebSocketAcceptorBuilder {
494    /// Creates a Axum router that accepts a WebSocket connection at the specified `path`.
495    ///
496    /// The router must be converted into a service with connection info,
497    /// see [`axum::Router::into_make_service_with_connect_info`] with
498    /// connection info type [`SocketAddr`].
499    pub fn router(&self, path: &str) -> Router {
500        let protocols: [String; 0] = [];
501        self.custom_router(path, SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), protocols)
502    }
503
504    /// Creates a Axum router that accepts a WebSocket connection at the specified `path` with custom options.
505    ///
506    /// `local_addr` specifies to local address the axum server is listening on.
507    /// This is used for link filtering if the server is listening on multiple IP addresses.
508    ///
509    /// `protocols` specifies the known WebSocket protocols to advertise to a connecting client.
510    ///
511    /// The router must be converted into a service with connection info,
512    /// see [`axum::Router::into_make_service_with_connect_info`] with
513    /// connection info type [`SocketAddr`].
514    pub fn custom_router(
515        &self, path: &str, local_addr: SocketAddr, protocols: impl IntoIterator<Item = impl AsRef<str>>,
516    ) -> Router {
517        let protocols: Vec<_> = protocols.into_iter().map(|p| p.as_ref().to_string()).collect();
518        let tx = self.tx.clone();
519
520        Router::new().route(
521            path,
522            get(move |ws: WebSocketUpgrade, ConnectInfo(remote): ConnectInfo<SocketAddr>| async move {
523                match tx.reserve_owned().await {
524                    Ok(permit) => ws.protocols(protocols.clone()).on_upgrade(move |web_socket| async move {
525                        permit.send(IncomingWebSocket { local: local_addr, remote, web_socket });
526                    }),
527                    Err(_) => Response::builder()
528                        .status(StatusCode::SERVICE_UNAVAILABLE)
529                        .body(Body::from("WebSocketAcceptor was dropped"))
530                        .unwrap(),
531                }
532            }),
533        )
534    }
535
536    /// Builds the [WebSocket transport listener](WebSocketAcceptor).
537    pub fn build(self) -> WebSocketAcceptor {
538        WebSocketAcceptor { rx: Mutex::new(self.rx) }
539    }
540}
541
/// WebSocket transport for incoming connections.
///
/// This transport is packet-based.
#[derive(Debug)]
pub struct WebSocketAcceptor {
    // Receives the upgraded WebSockets from the routers; locked while listening.
    rx: Mutex<mpsc::Receiver<IncomingWebSocket>>,
}
549
550impl fmt::Display for WebSocketAcceptor {
551    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
552        f.debug_struct("WebSocketAcceptor").finish()
553    }
554}
555
556impl WebSocketAcceptor {
557    /// Create a new WebSocket transport listening for incoming connections at the specified `path`.
558    pub fn new(path: &str) -> (Self, Router) {
559        let wsab = WebSocketAcceptorBuilder::new();
560        let router = wsab.router(path);
561        (wsab.build(), router)
562    }
563
564    /// Starts building a WebSocket transport listener.
565    pub fn builder() -> WebSocketAcceptorBuilder {
566        WebSocketAcceptorBuilder::new()
567    }
568}
569
570#[async_trait]
571impl AcceptingTransport for WebSocketAcceptor {
572    fn name(&self) -> &str {
573        NAME
574    }
575
576    async fn listen(&self, tx: mpsc::Sender<AcceptedStreamBox>) -> Result<()> {
577        let mut rx = self.rx.try_lock().unwrap();
578
579        while let Some(IncomingWebSocket { local, mut remote, web_socket }) = rx.recv().await {
580            let protocol = web_socket.protocol().and_then(|hv| hv.to_str().ok()).map(|s| s.to_string());
581            util::use_proper_ipv4(&mut remote);
582
583            // Adapt WebSocket IO.
584            let (ws_tx, ws_rx) = web_socket.split();
585
586            let ws_tx =
587                Box::pin(
588                    ws_tx
589                        .with(|data: Bytes| async move {
590                            Ok::<_, axum::Error>(axum::extract::ws::Message::Binary(data))
591                        })
592                        .sink_map_err(Error::other),
593                );
594            let ws_write = SinkWriter::new(CopyToBytes::new(ws_tx));
595
596            let ws_rx = Box::pin(
597                ws_rx
598                    .try_filter_map(|msg: axum::extract::ws::Message| async move {
599                        if let axum::extract::ws::Message::Binary(data) = msg {
600                            Ok(Some(data))
601                        } else {
602                            Ok(None)
603                        }
604                    })
605                    .map_err(Error::other),
606            );
607            let ws_read = StreamReader::new(ws_rx);
608
609            // Build tag.
610            tracing::debug!("Accepted WebSocket connection from {remote}");
611            let tag = IncomingWebSocketLinkTag { local, remote, protocol };
612
613            let _ = tx.send(AcceptedStreamBox::new(IoBox::new(ws_read, ws_write).into(), tag)).await;
614        }
615
616        Err(Error::new(ErrorKind::ConnectionReset, "router was dropped"))
617    }
618}