1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![doc(
4 html_logo_url = "https://raw.githubusercontent.com/surban/aggligator/master/.misc/aggligator.png",
5 html_favicon_url = "https://raw.githubusercontent.com/surban/aggligator/master/.misc/aggligator.png",
6 issue_tracker_base_url = "https://github.com/surban/aggligator/issues/"
7)]
8
9use async_trait::async_trait;
12use axum::{
13 body::Body,
14 extract::{ConnectInfo, WebSocketUpgrade},
15 http::StatusCode,
16 response::Response,
17 routing::get,
18 Router,
19};
20use bytes::Bytes;
21use futures::{SinkExt, StreamExt, TryStreamExt};
22use std::{
23 any::Any,
24 cmp::Ordering,
25 collections::{HashMap, HashSet},
26 fmt,
27 hash::{Hash, Hasher},
28 io::{Error, ErrorKind, Result},
29 net::{IpAddr, Ipv6Addr, SocketAddr},
30 sync::Arc,
31 time::Duration,
32};
33use tokio::{
34 net::TcpSocket,
35 sync::{mpsc, watch, Mutex},
36 time::sleep,
37};
38use tokio_tungstenite::{client_async_tls_with_config, tungstenite::protocol::WebSocketConfig, Connector};
39use tokio_util::io::{CopyToBytes, SinkWriter, StreamReader};
40use url::Url;
41
42use aggligator::{
43 control::Direction,
44 io::{IoBox, StreamBox},
45 transport::{AcceptedStreamBox, AcceptingTransport, ConnectingTransport, LinkTag, LinkTagBox},
46 Link,
47};
48use aggligator_transport_tcp::util::{self, NetworkInterface};
49pub use aggligator_transport_tcp::IpVersion;
50
/// Transport name reported by all WebSocket link tags and transports of this crate.
const NAME: &str = "websocket";
52
/// Link tag for an outgoing WebSocket link.
///
/// Ordering and hashing are derived field by field, in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OutgoingWebSocketLinkTag {
    /// Name of the local network interface the socket is bound to.
    ///
    /// `None` if the socket is not bound to a specific interface, which is the case
    /// when multi-interface mode is disabled on the connector.
    pub interface: Option<Vec<u8>>,
    /// Remote address, resolved from the host and port of [`url`](Self::url).
    pub remote: SocketAddr,
    /// URL used for the WebSocket handshake.
    pub url: String,
    /// Whether the URL uses the `wss` scheme, i.e. whether TLS is used.
    pub tls: bool,
}
65
66impl fmt::Display for OutgoingWebSocketLinkTag {
67 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68 write!(
69 f,
70 "{} -> {} ({})",
71 String::from_utf8_lossy(self.interface.as_deref().unwrap_or_default()),
72 &self.remote,
73 &self.url
74 )
75 }
76}
77
78impl LinkTag for OutgoingWebSocketLinkTag {
79 fn transport_name(&self) -> &str {
80 NAME
81 }
82
83 fn direction(&self) -> Direction {
84 Direction::Outgoing
85 }
86
87 fn user_data(&self) -> Vec<u8> {
88 self.interface.clone().unwrap_or_default()
89 }
90
91 fn as_any(&self) -> &dyn Any {
92 self
93 }
94
95 fn box_clone(&self) -> LinkTagBox {
96 Box::new(self.clone())
97 }
98
99 fn dyn_cmp(&self, other: &dyn LinkTag) -> Ordering {
100 let other = other.as_any().downcast_ref::<Self>().unwrap();
101 Ord::cmp(self, other)
102 }
103
104 fn dyn_hash(&self, mut state: &mut dyn Hasher) {
105 Hash::hash(self, &mut state)
106 }
107}
108
/// WebSocket transport for outgoing connections.
///
/// Connects to one or more `ws://` or `wss://` URLs. The URLs are re-resolved
/// periodically and links are established to every resolved address.
#[derive(Clone)]
pub struct WebSocketConnector {
    // Target URLs; validated by `unresolved` to be non-empty, to have a host and a ws/wss scheme.
    urls: Vec<Url>,
    // IP versions used when resolving the URLs' hosts.
    ip_version: IpVersion,
    // Delay between two rounds of resolving URLs and publishing link tags.
    resolve_interval: Duration,
    // Connector used for `wss` URLs; `ws` URLs always use `Connector::Plain`.
    connector: Option<Connector>,
    // WebSocket configuration passed to the client handshake.
    web_socket_config: Option<WebSocketConfig>,
    // Whether to create one link per suitable local interface (binding the socket to it).
    multi_interface: bool,
    // Selects the local interfaces used in multi-interface mode.
    interface_filter: Arc<dyn Fn(&NetworkInterface) -> bool + Send + Sync>,
}
122
123impl fmt::Debug for WebSocketConnector {
124 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125 f.debug_struct("WebSocketConnector")
126 .field("urls", &self.urls)
127 .field("ip_version", &self.ip_version)
128 .field("resolve_interval", &self.resolve_interval)
129 .field("web_socket_config", &self.web_socket_config)
130 .field("multi_interface", &self.multi_interface)
131 .finish()
132 }
133}
134
135impl fmt::Display for WebSocketConnector {
136 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137 let urls: Vec<_> = self.urls.iter().map(|url| url.to_string()).collect();
138 if self.urls.len() > 1 {
139 write!(f, "[{}]", urls.join(", "))
140 } else {
141 write!(f, "{}", &urls[0])
142 }
143 }
144}
145
146impl WebSocketConnector {
147 pub async fn new(urls: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
156 let this = Self::unresolved(urls).await?;
157
158 let addrs = this.resolve().await;
159 if addrs.values().all(|addrs| addrs.is_empty()) {
160 return Err(Error::new(ErrorKind::NotFound, "cannot resolve IP address of any URL"));
161 }
162 tracing::info!(?addrs, "URLs initially resolved");
163
164 Ok(this)
165 }
166
167 pub async fn unresolved(urls: impl IntoIterator<Item = impl AsRef<str>>) -> Result<Self> {
174 let urls = urls
175 .into_iter()
176 .map(|url| url.as_ref().parse::<Url>())
177 .collect::<std::result::Result<Vec<_>, _>>()
178 .map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;
179
180 if urls.is_empty() {
181 return Err(Error::new(ErrorKind::InvalidInput, "at least one URL is required"));
182 }
183 for url in &urls {
184 if !url.has_host() {
185 return Err(Error::new(ErrorKind::InvalidInput, "URL must have a host"));
186 }
187 if !["ws", "wss"].contains(&url.scheme()) {
188 return Err(Error::new(ErrorKind::InvalidInput, "URL must have scheme ws or wss"));
189 }
190 }
191
192 Ok(Self {
193 urls,
194 ip_version: IpVersion::Both,
195 resolve_interval: Duration::from_secs(10),
196 connector: None,
197 web_socket_config: None,
198 multi_interface: !cfg!(target_os = "android"),
199 interface_filter: Arc::new(|_| true),
200 })
201 }
202
203 pub fn set_ip_version(&mut self, ip_version: IpVersion) {
205 self.ip_version = ip_version;
206 }
207
208 pub fn set_resolve_interval(&mut self, resolve_interval: Duration) {
210 self.resolve_interval = resolve_interval;
211 }
212
213 pub fn set_connector(&mut self, connector: Option<Connector>) {
217 self.connector = connector;
218 }
219
220 pub fn set_web_socket_config(&mut self, web_socket_config: Option<WebSocketConfig>) {
222 self.web_socket_config = web_socket_config;
223 }
224
225 pub fn set_multi_interface(&mut self, multi_interface: bool) {
235 self.multi_interface = multi_interface;
236 }
237
238 pub fn set_interface_filter(
247 &mut self, interface_filter: impl Fn(&NetworkInterface) -> bool + Send + Sync + 'static,
248 ) {
249 self.interface_filter = Arc::new(interface_filter);
250 }
251
252 async fn resolve(&self) -> HashMap<&Url, Vec<SocketAddr>> {
254 let mut url_addrs = HashMap::new();
255
256 for url in &self.urls {
257 let host = url.host_str().unwrap();
258 let port = url.port_or_known_default().unwrap();
259 let addrs = util::resolve_hosts(&[format!("{host}:{port}")], self.ip_version).await;
260 url_addrs.insert(url, addrs);
261 }
262
263 url_addrs
264 }
265}
266
267#[async_trait]
268impl ConnectingTransport for WebSocketConnector {
269 fn name(&self) -> &str {
270 NAME
271 }
272
273 async fn link_tags(&self, tx: watch::Sender<HashSet<LinkTagBox>>) -> Result<()> {
274 loop {
275 let interfaces: Option<Vec<NetworkInterface>> = match self.multi_interface {
276 true => Some(
277 util::local_interfaces()?
278 .into_iter()
279 .filter(|iface| (self.interface_filter)(iface))
280 .collect(),
281 ),
282 false => None,
283 };
284
285 let mut tags: HashSet<LinkTagBox> = HashSet::new();
286 for (url, addrs) in self.resolve().await {
287 for addr in addrs {
288 match &interfaces {
289 Some(interfaces) => {
290 for interface in util::interface_names_for_target(interfaces, addr) {
291 let tag = OutgoingWebSocketLinkTag {
292 interface: Some(interface),
293 remote: addr,
294 url: url.to_string(),
295 tls: url.scheme() == "wss",
296 };
297 tags.insert(Box::new(tag));
298 }
299 }
300 None => {
301 let tag = OutgoingWebSocketLinkTag {
302 interface: None,
303 remote: addr,
304 url: url.to_string(),
305 tls: url.scheme() == "wss",
306 };
307 tags.insert(Box::new(tag));
308 }
309 }
310 }
311 }
312
313 tx.send_if_modified(|v| {
314 if *v != tags {
315 *v = tags;
316 true
317 } else {
318 false
319 }
320 });
321
322 sleep(self.resolve_interval).await;
323 }
324 }
325
326 async fn connect(&self, tag: &dyn LinkTag) -> Result<StreamBox> {
327 let tag: &OutgoingWebSocketLinkTag = tag.as_any().downcast_ref().unwrap();
328
329 let socket = match tag.remote.ip() {
331 IpAddr::V4(_) => TcpSocket::new_v4(),
332 IpAddr::V6(_) => TcpSocket::new_v6(),
333 }?;
334
335 if let Some(interface) = &tag.interface {
336 util::bind_socket_to_interface(&socket, interface, tag.remote.ip())?;
337 }
338
339 let stream = socket.connect(tag.remote).await?;
340 let _ = stream.set_nodelay(true);
341
342 let connector = if tag.tls { self.connector.clone() } else { Some(Connector::Plain) };
344 let (web_socket, _rsp) =
345 client_async_tls_with_config(&tag.url, stream, self.web_socket_config, connector)
346 .await
347 .map_err(|err| Error::new(ErrorKind::ConnectionRefused, err))?;
348
349 let (ws_tx, ws_rx) = web_socket.split();
351 let ws_tx = Box::pin(
352 ws_tx
353 .with(
354 |data: Bytes| async move { Ok::<_, tungstenite::Error>(tungstenite::Message::Binary(data)) },
355 )
356 .sink_map_err(Error::other),
357 );
358 let ws_write = SinkWriter::new(CopyToBytes::new(ws_tx));
359
360 let ws_rx = Box::pin(
361 ws_rx
362 .try_filter_map(|msg: tungstenite::Message| async move {
363 if let tungstenite::Message::Binary(data) = msg {
364 Ok(Some(data))
365 } else {
366 Ok(None)
367 }
368 })
369 .map_err(Error::other),
370 );
371 let ws_read = StreamReader::new(ws_rx);
372
373 Ok(IoBox::new(ws_read, ws_write).into())
374 }
375
376 async fn link_filter(&self, new: &Link<LinkTagBox>, existing: &[Link<LinkTagBox>]) -> bool {
377 let Some(new_tag) = new.tag().as_any().downcast_ref::<OutgoingWebSocketLinkTag>() else { return true };
378
379 let intro = format!(
380 "Judging {} WebSocket link {} {} ({}) on {}",
381 new.direction(),
382 match new.direction() {
383 Direction::Incoming => "from",
384 Direction::Outgoing => "to",
385 },
386 new_tag.remote,
387 String::from_utf8_lossy(new.remote_user_data()),
388 String::from_utf8_lossy(new_tag.interface.as_deref().unwrap_or(b"any interface"))
389 );
390
391 match existing.iter().find(|link| {
392 let Some(tag) = link.tag().as_any().downcast_ref::<OutgoingWebSocketLinkTag>() else { return false };
393 tag.interface == new_tag.interface && link.remote_user_data() == new.remote_user_data()
394 }) {
395 Some(other) => {
396 let other_tag = other.tag().as_any().downcast_ref::<OutgoingWebSocketLinkTag>().unwrap();
397 tracing::debug!("{intro} => link {} is redundant, rejecting.", other_tag.remote);
398 false
399 }
400 None => {
401 tracing::debug!("{intro} => accepted.");
402 true
403 }
404 }
405 }
406}
407
/// Link tag for an incoming WebSocket link.
///
/// Ordering and hashing are derived field by field, in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct IncomingWebSocketLinkTag {
    /// Local address as configured for the router that accepted the connection.
    ///
    /// Routers created by `WebSocketAcceptorBuilder::router` use the unspecified
    /// address `[::]:0`.
    pub local: SocketAddr,
    /// Remote address of the client.
    pub remote: SocketAddr,
    /// WebSocket subprotocol of the connection.
    ///
    /// `None` if there is none or it is not representable as a string.
    pub protocol: Option<String>,
}
418
419impl fmt::Display for IncomingWebSocketLinkTag {
420 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
421 write!(
422 f,
423 "{} <- {}{}",
424 &self.local,
425 &self.remote,
426 match &self.protocol {
427 Some(protocol) => format!(" ({protocol})"),
428 None => String::new(),
429 }
430 )
431 }
432}
433
434impl LinkTag for IncomingWebSocketLinkTag {
435 fn transport_name(&self) -> &str {
436 NAME
437 }
438
439 fn direction(&self) -> Direction {
440 Direction::Incoming
441 }
442
443 fn user_data(&self) -> Vec<u8> {
444 match self.local.ip() {
445 IpAddr::V4(ip) => ip.octets().into(),
446 IpAddr::V6(ip) => ip.octets().into(),
447 }
448 }
449
450 fn as_any(&self) -> &dyn Any {
451 self
452 }
453
454 fn box_clone(&self) -> LinkTagBox {
455 Box::new(self.clone())
456 }
457
458 fn dyn_cmp(&self, other: &dyn LinkTag) -> Ordering {
459 let other = other.as_any().downcast_ref::<Self>().unwrap();
460 Ord::cmp(self, other)
461 }
462
463 fn dyn_hash(&self, mut state: &mut dyn Hasher) {
464 Hash::hash(self, &mut state)
465 }
466}
467
/// An upgraded WebSocket connection handed over from a router to the acceptor.
struct IncomingWebSocket {
    /// Local address configured for the router that accepted the connection.
    local: SocketAddr,
    /// Remote address taken from the connection info of the request.
    remote: SocketAddr,
    /// The upgraded WebSocket.
    web_socket: axum::extract::ws::WebSocket,
}
473
/// Builder for a [`WebSocketAcceptor`] that is served by one or more custom routers.
///
/// All routers created by the same builder hand their connections to the same acceptor.
pub struct WebSocketAcceptorBuilder {
    // Cloned into every router; used to hand over upgraded connections.
    tx: mpsc::Sender<IncomingWebSocket>,
    // Moved into the acceptor by `build`.
    rx: mpsc::Receiver<IncomingWebSocket>,
}
479
480impl fmt::Debug for WebSocketAcceptorBuilder {
481 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
482 f.debug_struct("WebSocketAcceptorBuilder").finish()
483 }
484}
485
486impl WebSocketAcceptorBuilder {
487 fn new() -> Self {
488 let (tx, rx) = mpsc::channel(16);
489 Self { tx, rx }
490 }
491}
492
493impl WebSocketAcceptorBuilder {
494 pub fn router(&self, path: &str) -> Router {
500 let protocols: [String; 0] = [];
501 self.custom_router(path, SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), protocols)
502 }
503
504 pub fn custom_router(
515 &self, path: &str, local_addr: SocketAddr, protocols: impl IntoIterator<Item = impl AsRef<str>>,
516 ) -> Router {
517 let protocols: Vec<_> = protocols.into_iter().map(|p| p.as_ref().to_string()).collect();
518 let tx = self.tx.clone();
519
520 Router::new().route(
521 path,
522 get(move |ws: WebSocketUpgrade, ConnectInfo(remote): ConnectInfo<SocketAddr>| async move {
523 match tx.reserve_owned().await {
524 Ok(permit) => ws.protocols(protocols.clone()).on_upgrade(move |web_socket| async move {
525 permit.send(IncomingWebSocket { local: local_addr, remote, web_socket });
526 }),
527 Err(_) => Response::builder()
528 .status(StatusCode::SERVICE_UNAVAILABLE)
529 .body(Body::from("WebSocketAcceptor was dropped"))
530 .unwrap(),
531 }
532 }),
533 )
534 }
535
536 pub fn build(self) -> WebSocketAcceptor {
538 WebSocketAcceptor { rx: Mutex::new(self.rx) }
539 }
540}
541
/// WebSocket transport for incoming connections.
///
/// Connections are accepted by axum routers obtained from [`WebSocketAcceptor::new`]
/// or from a [`WebSocketAcceptorBuilder`].
#[derive(Debug)]
pub struct WebSocketAcceptor {
    // Receives connections upgraded by the routers; held locked while `listen` runs.
    rx: Mutex<mpsc::Receiver<IncomingWebSocket>>,
}
549
550impl fmt::Display for WebSocketAcceptor {
551 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
552 f.debug_struct("WebSocketAcceptor").finish()
553 }
554}
555
556impl WebSocketAcceptor {
557 pub fn new(path: &str) -> (Self, Router) {
559 let wsab = WebSocketAcceptorBuilder::new();
560 let router = wsab.router(path);
561 (wsab.build(), router)
562 }
563
564 pub fn builder() -> WebSocketAcceptorBuilder {
566 WebSocketAcceptorBuilder::new()
567 }
568}
569
570#[async_trait]
571impl AcceptingTransport for WebSocketAcceptor {
572 fn name(&self) -> &str {
573 NAME
574 }
575
576 async fn listen(&self, tx: mpsc::Sender<AcceptedStreamBox>) -> Result<()> {
577 let mut rx = self.rx.try_lock().unwrap();
578
579 while let Some(IncomingWebSocket { local, mut remote, web_socket }) = rx.recv().await {
580 let protocol = web_socket.protocol().and_then(|hv| hv.to_str().ok()).map(|s| s.to_string());
581 util::use_proper_ipv4(&mut remote);
582
583 let (ws_tx, ws_rx) = web_socket.split();
585
586 let ws_tx =
587 Box::pin(
588 ws_tx
589 .with(|data: Bytes| async move {
590 Ok::<_, axum::Error>(axum::extract::ws::Message::Binary(data))
591 })
592 .sink_map_err(Error::other),
593 );
594 let ws_write = SinkWriter::new(CopyToBytes::new(ws_tx));
595
596 let ws_rx = Box::pin(
597 ws_rx
598 .try_filter_map(|msg: axum::extract::ws::Message| async move {
599 if let axum::extract::ws::Message::Binary(data) = msg {
600 Ok(Some(data))
601 } else {
602 Ok(None)
603 }
604 })
605 .map_err(Error::other),
606 );
607 let ws_read = StreamReader::new(ws_rx);
608
609 tracing::debug!("Accepted WebSocket connection from {remote}");
611 let tag = IncomingWebSocketLinkTag { local, remote, protocol };
612
613 let _ = tx.send(AcceptedStreamBox::new(IoBox::new(ws_read, ws_write).into(), tag)).await;
614 }
615
616 Err(Error::new(ErrorKind::ConnectionReset, "router was dropped"))
617 }
618}