1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use futures_util::stream::{SplitSink, SplitStream};
8use futures_util::{SinkExt, StreamExt};
9use log::{debug, error, info};
10use serde::Deserialize;
11use tokio::fs::File;
12use tokio::io::AsyncReadExt;
13use tokio::net::{TcpStream, UdpSocket};
14use tokio::process::Command;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, sleep, timeout};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
19use uuid::Uuid;
20
21use crate::protocol::*;
22use crate::utils::{AnyError, resolve_host};
23
24#[derive(Default, Deserialize, Clone)]
25#[serde(rename_all = "camelCase")]
26pub struct Status {
27 pub battery_percentage: Option<i32>,
28}
29
/// Callback that asynchronously produces the current [`Status`].
///
/// The relay invokes it each time the streamer sends a status request and copies the
/// resulting `battery_percentage` into the response. Both the closure and the future it
/// returns are required to be `Send + Sync`.
pub type GetStatusClosure =
    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
32
/// Mutable relay state shared (behind `Arc<Mutex<_>>`) between the public [`Relay`]
/// handle and the tasks it spawns.
struct RelayInner {
    /// Weak handle to the `Arc<Mutex<Self>>` that owns this value (set up with
    /// `Arc::new_cyclic`); spawned tasks clone it and upgrade it when they need the relay.
    me: Weak<Mutex<Self>>,
    /// Local address ("IP" or "IP:port") for destination-facing UDP sockets.
    bind_address: String,
    /// Identifier sent to the streamer in the `Identify` message.
    relay_id: Uuid,
    /// Websocket URL of the streamer.
    streamer_url: String,
    /// Password used to answer the streamer's authentication challenge.
    password: String,
    /// Human-readable name sent in the `Identify` message.
    name: String,
    /// Receives a status text whenever `update_status` runs.
    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
    /// Optional provider used to answer status requests.
    get_status: Option<Arc<GetStatusClosure>>,
    /// Sending half of the current websocket; `None` while disconnected.
    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    /// Whether the relay was asked to run (set by `start`, cleared by `stop`).
    started: bool,
    /// Set once the streamer accepted our identification.
    connected: bool,
    /// Set when the streamer rejected the password.
    wrong_password: bool,
    /// Cancellation flag shared with the supervising task of the most recent tunnel;
    /// a fresh flag is created per tunnel and setting it to `false` suppresses the
    /// reconnect that would otherwise follow that tunnel's end.
    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
    /// Cancellation flag shared with the currently pending delayed start (`start_soon`).
    start_on_reconnect_soon: Arc<Mutex<bool>>,
}
50
51impl RelayInner {
52 fn new() -> Arc<Mutex<Self>> {
53 Arc::new_cyclic(|me| {
54 Mutex::new(Self {
55 me: me.clone(),
56 bind_address: Self::get_default_bind_address(),
57 relay_id: Uuid::new_v4(),
58 streamer_url: "".to_string(),
59 password: "".to_string(),
60 name: "".to_string(),
61 on_status_updated: None,
62 get_status: None,
63 ws_writer: None,
64 started: false,
65 connected: false,
66 wrong_password: false,
67 reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
68 start_on_reconnect_soon: Arc::new(Mutex::new(false)),
69 })
70 })
71 }
72
73 fn set_bind_address(&mut self, address: String) {
74 self.bind_address = address;
75 }
76
77 async fn setup<F>(
78 &mut self,
79 streamer_url: String,
80 password: String,
81 relay_id: Uuid,
82 name: String,
83 on_status_updated: F,
84 get_status: Option<GetStatusClosure>,
85 ) where
86 F: Fn(String) + Send + Sync + 'static,
87 {
88 self.on_status_updated = Some(Box::new(on_status_updated));
89 self.get_status = get_status.map(Arc::new);
90 self.relay_id = relay_id;
91 self.streamer_url = streamer_url;
92 self.password = password;
93 self.name = name;
94 }
95
96 fn is_started(&self) -> bool {
97 self.started
98 }
99
100 async fn start(&mut self) {
101 if !self.started {
102 self.started = true;
103 self.start_internal().await;
104 }
105 }
106
107 async fn stop(&mut self) {
108 if self.started {
109 self.started = false;
110 self.stop_internal().await;
111 }
112 }
113
114 fn get_default_bind_address() -> String {
115 let interfaces = pnet::datalink::interfaces();
117 let interface = interfaces.iter().find(|interface| {
118 interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
119 });
120
121 let ipv4_addresses: Vec<String> = interface
123 .expect("No available network interfaces found")
124 .ips
125 .iter()
126 .filter_map(|ip| {
127 let ip = ip.ip();
128 ip.is_ipv4().then(|| ip.to_string())
129 })
130 .collect();
131
132 ipv4_addresses
134 .first()
135 .cloned()
136 .unwrap_or("0.0.0.0:0".to_string())
137 }
138
139 async fn start_internal(&mut self) {
140 if !self.started {
141 self.stop_internal().await;
142 return;
143 }
144
145 let request = match url::Url::parse(&self.streamer_url) {
146 Ok(url) => url,
147 Err(e) => {
148 error!("Failed to parse URL: {}", e);
149 return;
150 }
151 };
152
153 match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
154 Ok(Ok((ws_stream, _))) => {
155 debug!("Connected to {}", self.streamer_url);
156 let (writer, reader) = ws_stream.split();
157 self.ws_writer = Some(writer);
158 self.start_websocket_receiver(reader);
159 }
160 Ok(Err(error)) => {
161 debug!(
162 "Failed to connect to {} with error: {}",
163 self.streamer_url, error
164 );
165 self.reconnect_soon().await;
166 }
167 Err(_elapsed) => {
168 debug!(
169 "Failed to connect to {} within 10 seconds",
170 self.streamer_url
171 );
172 self.reconnect_soon().await;
173 }
174 }
175 }
176
177 fn start_websocket_receiver(
178 &mut self,
179 mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
180 ) {
181 let relay = self.me.clone();
183
184 tokio::spawn(async move {
185 let Some(relay_arc) = relay.upgrade() else {
186 return;
187 };
188
189 while let Some(result) = reader.next().await {
190 let mut relay = relay_arc.lock().await;
191 match result {
192 Ok(message) => match message {
193 Message::Text(text) => {
194 match serde_json::from_str::<MessageToRelay>(&text) {
195 Ok(message) => {
196 if let Err(error) = relay.handle_message(message).await {
197 error!("Message handling failed with error: {}", error);
198 relay.reconnect_soon().await;
199 break;
200 }
201 }
202 _ => {
203 error!("Failed to deserialize message: {}", text);
204 }
205 }
206 }
207 Message::Binary(data) => {
208 debug!("Received binary message of length: {}", data.len());
209 }
210 Message::Ping(data) => {
211 relay.send_message(Message::Pong(data)).await.ok();
212 }
213 Message::Pong(_) => {
214 debug!("Received pong message");
215 }
216 Message::Close(frame) => {
217 info!("Received close message: {:?}", frame);
218 relay.reconnect_soon().await;
219 break;
220 }
221 Message::Frame(_) => {
222 unreachable!("This is never used")
223 }
224 },
225 Err(e) => {
226 debug!("Error processing message: {}", e);
227 if e.to_string()
229 .contains("Connection reset without closing handshake")
230 {
231 relay.reconnect_soon().await;
232 }
233 break;
234 }
235 }
236 }
237 });
238 }
239
240 async fn stop_internal(&mut self) {
241 if let Some(mut ws_writer) = self.ws_writer.take() {
242 match ws_writer.close().await {
243 Err(e) => {
244 error!("Error closing WebSocket: {}", e);
245 }
246 _ => {
247 debug!("WebSocket closed successfully");
248 }
249 }
250 }
251 self.connected = false;
252 self.wrong_password = false;
253 *self.reconnect_on_tunnel_error.lock().await = false;
254 *self.start_on_reconnect_soon.lock().await = false;
255 self.update_status();
256 }
257
258 fn update_status(&self) {
259 let Some(on_status_updated) = &self.on_status_updated else {
260 return;
261 };
262 let status = if self.connected {
263 "Connected to streamer"
264 } else if self.wrong_password {
265 "Wrong password"
266 } else if self.started {
267 "Connecting to streamer"
268 } else {
269 "Disconnected from streamer"
270 };
271 on_status_updated(status.to_string());
272 }
273
274 async fn reconnect_soon(&mut self) {
275 self.stop_internal().await;
276 *self.start_on_reconnect_soon.lock().await = false;
277 let start_on_reconnect_soon = Arc::new(Mutex::new(true));
278 self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
279 self.start_soon(start_on_reconnect_soon);
280 }
281
282 fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
283 let relay = self.me.clone();
284
285 tokio::spawn(async move {
286 sleep(Duration::from_secs(5)).await;
287
288 if *start_on_reconnect_soon.lock().await {
289 debug!("Reconnecting...");
290 if let Some(relay) = relay.upgrade() {
291 relay.lock().await.start_internal().await;
292 }
293 }
294 });
295 }
296
297 async fn handle_message(&mut self, message: MessageToRelay) -> Result<(), AnyError> {
298 match message {
299 MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
300 MessageToRelay::Identified(identified) => {
301 self.handle_message_identified(identified).await
302 }
303 MessageToRelay::Request(request) => self.handle_message_request(request).await,
304 }
305 }
306
307 async fn handle_message_hello(&mut self, hello: Hello) -> Result<(), AnyError> {
308 let authentication = calculate_authentication(
309 &self.password,
310 &hello.authentication.salt,
311 &hello.authentication.challenge,
312 );
313 let identify = Identify {
314 id: self.relay_id,
315 name: self.name.clone(),
316 authentication,
317 };
318 self.send(MessageToStreamer::Identify(identify)).await
319 }
320
321 async fn handle_message_identified(&mut self, identified: Identified) -> Result<(), AnyError> {
322 match identified.result {
323 MoblinkResult::Ok(_) => {
324 self.connected = true;
325 }
326 MoblinkResult::WrongPassword(_) => {
327 self.wrong_password = true;
328 }
329 }
330 self.update_status();
331 Ok(())
332 }
333
334 async fn handle_message_request(&mut self, request: MessageRequest) -> Result<(), AnyError> {
335 match &request.data {
336 MessageRequestData::StartTunnel(start_tunnel) => {
337 self.handle_message_request_start_tunnel(&request, start_tunnel)
338 .await
339 }
340 MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
341 }
342 }
343
344 async fn handle_message_request_start_tunnel(
345 &mut self,
346 request: &MessageRequest,
347 start_tunnel: &StartTunnelRequest,
348 ) -> Result<(), AnyError> {
349 let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
351 let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
352
353 debug!(
354 "Binding streamer socket on: {}, destination socket on: {}",
355 local_bind_addr_for_streamer, local_bind_addr_for_destination
356 );
357 let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
360 let streamer_port = streamer_socket.local_addr()?.port();
361 let streamer_socket = Arc::new(streamer_socket);
362
363 let data = ResponseData::StartTunnel(StartTunnelResponseData {
365 port: streamer_port,
366 });
367 let response = request.to_ok_response(data);
368 self.send(MessageToStreamer::Response(response)).await?;
369
370 let destination_socket =
373 create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
374
375 let destination_socket = Arc::new(destination_socket);
376 let destination_address = resolve_host(&start_tunnel.address).await?;
377 let destination_address = match IpAddr::from_str(&destination_address)? {
378 IpAddr::V4(v4) => IpAddr::V4(v4),
379 IpAddr::V6(v6) => {
380 if let Some(mapped_v4) = v6.to_ipv4() {
382 IpAddr::V4(mapped_v4)
383 } else {
384 IpAddr::V6(v6)
386 }
387 }
388 };
389
390 let destination_address = SocketAddr::new(destination_address, start_tunnel.port);
391 info!("Destination address: {}", destination_address);
392 let streamer_address: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
393
394 let relay_to_destination = start_relay_from_streamer_to_destination(
395 streamer_socket.clone(),
396 destination_socket.clone(),
397 streamer_address.clone(),
398 destination_address,
399 );
400 let relay_to_streamer = start_relay_from_destination_to_streamer(
401 streamer_socket,
402 destination_socket,
403 streamer_address,
404 );
405
406 *self.reconnect_on_tunnel_error.lock().await = false;
407 let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
408 self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
409 let relay = self.me.clone();
410
411 tokio::spawn(async move {
412 let Some(relay) = relay.upgrade() else {
413 return;
414 };
415
416 tokio::select! {
419 res = relay_to_destination => {
420 if let Err(e) = res {
421 error!("relay_to_destination task failed: {}", e);
422 }
423 }
424 res = relay_to_streamer => {
425 if let Err(e) = res {
426 error!("relay_to_streamer task failed: {}", e);
427 }
428 }
429 }
430
431 if *reconnect_on_tunnel_error.lock().await {
432 relay.lock().await.reconnect_soon().await;
433 } else {
434 info!("Not reconnecting after tunnel error");
435 }
436 });
437
438 Ok(())
439 }
440
441 async fn handle_message_request_status(
442 &mut self,
443 request: MessageRequest,
444 ) -> Result<(), AnyError> {
445 let mut battery_percentage = None;
446 if let Some(get_status) = self.get_status.as_ref() {
447 battery_percentage = get_status().await.battery_percentage;
448 }
449 let data = ResponseData::Status(StatusResponseData { battery_percentage });
450 let response = request.to_ok_response(data);
451 self.send(MessageToStreamer::Response(response)).await
452 }
453
454 async fn send(&mut self, message: MessageToStreamer) -> Result<(), AnyError> {
455 let text = serde_json::to_string(&message)?;
456 self.send_message(Message::Text(text.into())).await
457 }
458
459 async fn send_message(&mut self, message: Message) -> Result<(), AnyError> {
460 let Some(writer) = self.ws_writer.as_mut() else {
461 return Err("No websocket writer".into());
462 };
463 writer.send(message).await?;
464 Ok(())
465 }
466}
467
/// Public handle to a relay; each method locks the shared inner state and forwards to it.
pub struct Relay {
    /// Shared state, also referenced (weakly) by the relay's background tasks.
    inner: Arc<Mutex<RelayInner>>,
}
471
472impl Default for Relay {
473 fn default() -> Self {
474 Self::new()
475 }
476}
477
478impl Relay {
479 pub fn new() -> Self {
480 Self {
481 inner: RelayInner::new(),
482 }
483 }
484
485 pub async fn set_bind_address(&self, address: String) {
486 self.inner.lock().await.set_bind_address(address);
487 }
488
489 pub async fn setup<F>(
490 &self,
491 streamer_url: String,
492 password: String,
493 relay_id: Uuid,
494 name: String,
495 on_status_updated: F,
496 get_status: Option<GetStatusClosure>,
497 ) where
498 F: Fn(String) + Send + Sync + 'static,
499 {
500 self.inner
501 .lock()
502 .await
503 .setup(
504 streamer_url,
505 password,
506 relay_id,
507 name,
508 on_status_updated,
509 get_status,
510 )
511 .await;
512 }
513
514 pub async fn is_started(&self) -> bool {
515 self.inner.lock().await.is_started()
516 }
517
518 pub async fn start(&self) {
519 self.inner.lock().await.start().await;
520 }
521
522 pub async fn stop(&self) {
523 self.inner.lock().await.stop().await;
524 }
525}
526
527fn start_relay_from_streamer_to_destination(
528 streamer_socket: Arc<UdpSocket>,
529 destination_socket: Arc<UdpSocket>,
530 streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
531 destination_addr: SocketAddr,
532) -> tokio::task::JoinHandle<()> {
533 tokio::spawn(async move {
534 loop {
535 if let Err(error) = relay_one_packet_from_streamer_to_destination(
536 &streamer_socket,
537 &destination_socket,
538 &streamer_addr,
539 &destination_addr,
540 )
541 .await
542 {
543 error!("(relay_to_destination) Failed with error: {}", error);
544 break;
545 }
546 }
547 })
548}
549
550async fn relay_one_packet_from_streamer_to_destination(
551 streamer_socket: &Arc<UdpSocket>,
552 destination_socket: &Arc<UdpSocket>,
553 streamer_addr: &Arc<Mutex<Option<SocketAddr>>>,
554 destination_addr: &SocketAddr,
555) -> Result<(), AnyError> {
556 let mut buf = [0; 2048];
557 let (size, remote_addr) =
558 timeout(Duration::from_secs(30), streamer_socket.recv_from(&mut buf)).await??;
559 destination_socket
560 .send_to(&buf[..size], &destination_addr)
561 .await?;
562 streamer_addr.lock().await.replace(remote_addr);
563 Ok(())
564}
565
566fn start_relay_from_destination_to_streamer(
567 streamer_socket: Arc<UdpSocket>,
568 destination_socket: Arc<UdpSocket>,
569 streamer_address: Arc<Mutex<Option<SocketAddr>>>,
570) -> tokio::task::JoinHandle<()> {
571 tokio::spawn(async move {
572 loop {
573 if let Err(error) = relay_one_packet_from_destination_to_streamer(
574 &streamer_socket,
575 &destination_socket,
576 &streamer_address,
577 )
578 .await
579 {
580 error!("(relay_to_streamer) Failed with error: {}", error);
581 break;
582 }
583 }
584 })
585}
586
587async fn relay_one_packet_from_destination_to_streamer(
588 streamer_socket: &Arc<UdpSocket>,
589 destination_socket: &Arc<UdpSocket>,
590 streamer_address: &Arc<Mutex<Option<SocketAddr>>>,
591) -> Result<(), AnyError> {
592 let mut buf = [0; 2048];
593 let size = timeout(Duration::from_secs(30), destination_socket.recv(&mut buf)).await??;
594 let streamer_addr = streamer_address
595 .lock()
596 .await
597 .ok_or("Failed to get address lock")?;
598 streamer_socket
599 .send_to(&buf[..size], &streamer_addr)
600 .await?;
601 Ok(())
602}
603
604async fn create_dual_stack_udp_socket(
605 addr: SocketAddr,
606) -> Result<tokio::net::UdpSocket, std::io::Error> {
607 let socket = match addr.is_ipv4() {
608 true => {
609 tokio::net::UdpSocket::bind(addr).await?
611 }
612 false => {
613 let socket = socket2::Socket::new(
615 socket2::Domain::IPV6,
616 socket2::Type::DGRAM,
617 Some(socket2::Protocol::UDP),
618 )?;
619
620 socket.set_only_v6(false)?;
622
623 socket.bind(&socket2::SockAddr::from(addr))?;
625
626 tokio::net::UdpSocket::from_std(socket.into())?
628 }
629 };
630
631 Ok(socket)
632}
633
/// Parses `"IP:port"` (including `"[v6]:port"`) or a bare `"IP"` into a socket address.
///
/// A bare IP gets port `0`, which lets the OS choose a port when the address is bound.
/// Any other input yields an `InvalidInput` I/O error.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    addr_str
        .parse::<SocketAddr>()
        .or_else(|_| addr_str.parse::<IpAddr>().map(|ip| SocketAddr::new(ip, 0)))
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}
655
656pub fn create_get_status_closure(
657 status_executable: &Option<String>,
658 status_file: &Option<String>,
659) -> Option<GetStatusClosure> {
660 let status_executable = status_executable.clone();
661 let status_file = status_file.clone();
662 Some(Box::new(move || {
663 let status_executable = status_executable.clone();
664 let status_file = status_file.clone();
665 Box::pin(async move {
666 let output = if let Some(status_executable) = &status_executable {
667 let Ok(output) = Command::new(status_executable).output().await else {
668 return Default::default();
669 };
670 output.stdout
671 } else if let Some(status_file) = &status_file {
672 let Ok(mut file) = File::open(status_file).await else {
673 return Default::default();
674 };
675 let mut contents = vec![];
676 if file.read_to_end(&mut contents).await.is_err() {
677 return Default::default();
678 }
679 contents
680 } else {
681 return Default::default();
682 };
683 let output = String::from_utf8(output).unwrap_or_default();
684 match serde_json::from_str(&output) {
685 Ok(status) => status,
686 Err(e) => {
687 error!("Failed to decode status with error: {e}");
688 Default::default()
689 }
690 }
691 })
692 }))
693}