moblink_rust/
relay.rs

1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use futures_util::stream::{SplitSink, SplitStream};
8use futures_util::{SinkExt, StreamExt};
9use log::{debug, error, info};
10use serde::Deserialize;
11use tokio::fs::File;
12use tokio::io::AsyncReadExt;
13use tokio::net::{TcpStream, UdpSocket};
14use tokio::process::Command;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, sleep, timeout};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
19use uuid::Uuid;
20
21use crate::protocol::*;
22use crate::utils::AnyError;
23
24#[derive(Default, Deserialize, Clone)]
25#[serde(rename_all = "camelCase")]
26pub struct Status {
27    pub battery_percentage: Option<i32>,
28}
29
30pub type GetStatusClosure =
31    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
32
33struct RelayInner {
34    me: Weak<Mutex<Self>>,
35    /// Store a local IP address  for binding UDP sockets
36    bind_address: String,
37    relay_id: Uuid,
38    streamer_url: String,
39    password: String,
40    name: String,
41    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
42    get_status: Option<Arc<GetStatusClosure>>,
43    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
44    started: bool,
45    connected: bool,
46    wrong_password: bool,
47    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
48    start_on_reconnect_soon: Arc<Mutex<bool>>,
49}
50
51impl RelayInner {
52    fn new() -> Arc<Mutex<Self>> {
53        Arc::new_cyclic(|me| {
54            Mutex::new(Self {
55                me: me.clone(),
56                bind_address: Self::get_default_bind_address(),
57                relay_id: Uuid::new_v4(),
58                streamer_url: "".to_string(),
59                password: "".to_string(),
60                name: "".to_string(),
61                on_status_updated: None,
62                get_status: None,
63                ws_writer: None,
64                started: false,
65                connected: false,
66                wrong_password: false,
67                reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
68                start_on_reconnect_soon: Arc::new(Mutex::new(false)),
69            })
70        })
71    }
72
73    fn set_bind_address(&mut self, address: String) {
74        self.bind_address = address;
75    }
76
77    async fn setup<F>(
78        &mut self,
79        streamer_url: String,
80        password: String,
81        relay_id: Uuid,
82        name: String,
83        on_status_updated: F,
84        get_status: Option<GetStatusClosure>,
85    ) where
86        F: Fn(String) + Send + Sync + 'static,
87    {
88        self.on_status_updated = Some(Box::new(on_status_updated));
89        self.get_status = get_status.map(Arc::new);
90        self.relay_id = relay_id;
91        self.streamer_url = streamer_url;
92        self.password = password;
93        self.name = name;
94        info!("Binding to address: {:?}", self.bind_address);
95    }
96
97    fn is_started(&self) -> bool {
98        self.started
99    }
100
101    async fn start(&mut self) {
102        if !self.started {
103            self.started = true;
104            self.start_internal().await;
105        }
106    }
107
108    async fn stop(&mut self) {
109        if self.started {
110            self.started = false;
111            self.stop_internal().await;
112        }
113    }
114
115    fn get_default_bind_address() -> String {
116        // Get main network interface
117        let interfaces = pnet::datalink::interfaces();
118        let interface = interfaces.iter().find(|interface| {
119            interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
120        });
121
122        // Only ipv4 addresses are supported
123        let ipv4_addresses: Vec<String> = interface
124            .expect("No available network interfaces found")
125            .ips
126            .iter()
127            .filter_map(|ip| {
128                let ip = ip.ip();
129                ip.is_ipv4().then(|| ip.to_string())
130            })
131            .collect();
132
133        // Return the first address
134        ipv4_addresses
135            .first()
136            .cloned()
137            .unwrap_or("0.0.0.0:0".to_string())
138    }
139
140    async fn start_internal(&mut self) {
141        info!("Start internal");
142        if !self.started {
143            self.stop_internal().await;
144            return;
145        }
146
147        let request = match url::Url::parse(&self.streamer_url) {
148            Ok(url) => url,
149            Err(e) => {
150                error!("Failed to parse URL: {}", e);
151                return;
152            }
153        };
154
155        match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
156            Ok(Ok((ws_stream, _))) => {
157                info!("WebSocket connected");
158                let (writer, reader) = ws_stream.split();
159                self.ws_writer = Some(writer);
160                self.start_websocket_receiver(reader);
161            }
162            Ok(Err(error)) => {
163                // This means the future completed but the connection failed
164                error!("WebSocket connection failed immediately: {}", error);
165                self.reconnect_soon().await;
166            }
167            Err(_elapsed) => {
168                // This means the future did NOT complete within 10 seconds
169                error!("WebSocket connection attempt timed out after 10 seconds");
170                self.reconnect_soon().await;
171            }
172        }
173    }
174
175    fn start_websocket_receiver(
176        &mut self,
177        mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
178    ) {
179        // Task to process messages received from the channel.
180        let relay = self.me.clone();
181
182        tokio::spawn(async move {
183            let Some(relay_arc) = relay.upgrade() else {
184                return;
185            };
186
187            while let Some(result) = reader.next().await {
188                let mut relay = relay_arc.lock().await;
189                match result {
190                    Ok(message) => match message {
191                        Message::Text(text) => {
192                            match serde_json::from_str::<MessageToRelay>(&text) {
193                                Ok(message) => {
194                                    relay.handle_message(message).await.ok();
195                                }
196                                _ => {
197                                    error!("Failed to deserialize message: {}", text);
198                                }
199                            }
200                        }
201                        Message::Binary(data) => {
202                            debug!("Received binary message of length: {}", data.len());
203                        }
204                        Message::Ping(data) => {
205                            relay.send_message(Message::Pong(data)).await.ok();
206                        }
207                        Message::Pong(_) => {
208                            debug!("Received pong message");
209                        }
210                        Message::Close(frame) => {
211                            info!("Received close message: {:?}", frame);
212                            relay.reconnect_soon().await;
213                            break;
214                        }
215                        Message::Frame(_) => {
216                            unreachable!("This is never used")
217                        }
218                    },
219                    Err(e) => {
220                        error!("Error processing message: {}", e);
221                        // TODO: There has to be a better way to handle this
222                        if e.to_string()
223                            .contains("Connection reset without closing handshake")
224                        {
225                            relay.reconnect_soon().await;
226                        }
227                        break;
228                    }
229                }
230            }
231        });
232    }
233
234    async fn stop_internal(&mut self) {
235        info!("Stop internal");
236        if let Some(mut ws_writer) = self.ws_writer.take() {
237            match ws_writer.close().await {
238                Err(e) => {
239                    error!("Error closing WebSocket: {}", e);
240                }
241                _ => {
242                    info!("WebSocket closed successfully");
243                }
244            }
245        }
246        self.connected = false;
247        self.wrong_password = false;
248        *self.reconnect_on_tunnel_error.lock().await = false;
249        *self.start_on_reconnect_soon.lock().await = false;
250        self.update_status();
251    }
252
253    fn update_status(&self) {
254        let Some(on_status_updated) = &self.on_status_updated else {
255            return;
256        };
257        let status = if self.connected {
258            "Connected to streamer"
259        } else if self.wrong_password {
260            "Wrong password"
261        } else if self.started {
262            "Connecting to streamer"
263        } else {
264            "Disconnected from streamer"
265        };
266        on_status_updated(status.to_string());
267    }
268
269    async fn reconnect_soon(&mut self) {
270        self.stop_internal().await;
271        *self.start_on_reconnect_soon.lock().await = false;
272        let start_on_reconnect_soon = Arc::new(Mutex::new(true));
273        self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
274        self.start_soon(start_on_reconnect_soon);
275    }
276
277    fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
278        let relay = self.me.clone();
279
280        tokio::spawn(async move {
281            info!("Reconnecting in 5 seconds...");
282            sleep(Duration::from_secs(5)).await;
283
284            if *start_on_reconnect_soon.lock().await {
285                info!("Reconnecting...");
286                if let Some(relay) = relay.upgrade() {
287                    relay.lock().await.start_internal().await;
288                }
289            }
290        });
291    }
292
293    async fn handle_message(&mut self, message: MessageToRelay) -> Result<(), AnyError> {
294        match message {
295            MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
296            MessageToRelay::Identified(identified) => {
297                self.handle_message_identified(identified).await
298            }
299            MessageToRelay::Request(request) => self.handle_message_request(request).await,
300        }
301    }
302
303    async fn handle_message_hello(&mut self, hello: Hello) -> Result<(), AnyError> {
304        let authentication = calculate_authentication(
305            &self.password,
306            &hello.authentication.salt,
307            &hello.authentication.challenge,
308        );
309        let identify = Identify {
310            id: self.relay_id,
311            name: self.name.clone(),
312            authentication,
313        };
314        self.send(MessageToStreamer::Identify(identify)).await
315    }
316
317    async fn handle_message_identified(&mut self, identified: Identified) -> Result<(), AnyError> {
318        match identified.result {
319            MoblinkResult::Ok(_) => {
320                self.connected = true;
321            }
322            MoblinkResult::WrongPassword(_) => {
323                self.wrong_password = true;
324            }
325        }
326        self.update_status();
327        Ok(())
328    }
329
330    async fn handle_message_request(&mut self, request: MessageRequest) -> Result<(), AnyError> {
331        match &request.data {
332            MessageRequestData::StartTunnel(start_tunnel) => {
333                self.handle_message_request_start_tunnel(&request, start_tunnel)
334                    .await
335            }
336            MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
337        }
338    }
339
340    async fn handle_message_request_start_tunnel(
341        &mut self,
342        request: &MessageRequest,
343        start_tunnel: &StartTunnelRequest,
344    ) -> Result<(), AnyError> {
345        // Pick bind addresses from the relay
346        let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
347        let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
348
349        info!(
350            "Binding streamer socket on: {}, destination socket on: {}",
351            local_bind_addr_for_streamer, local_bind_addr_for_destination
352        );
353        // Create a UDP socket bound for receiving packets from the server.
354        // Use dual-stack socket creation.
355        let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
356        let streamer_port = streamer_socket.local_addr()?.port();
357        info!("Listening on UDP port: {}", streamer_port);
358        let streamer_socket = Arc::new(streamer_socket);
359
360        // Inform the server about the chosen port.
361        let data = ResponseData::StartTunnel(StartTunnelResponseData {
362            port: streamer_port,
363        });
364        let response = request.to_ok_response(data);
365        self.send(MessageToStreamer::Response(response)).await?;
366
367        // Create a new UDP socket for communication with the destination.
368        // Use dual-stack socket creation.
369        let destination_socket =
370            create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
371
372        info!(
373            "Bound destination socket to: {:?}",
374            destination_socket.local_addr()?
375        );
376        let destination_socket = Arc::new(destination_socket);
377
378        let normalized_ip = match IpAddr::from_str(&start_tunnel.address)? {
379            IpAddr::V4(v4) => IpAddr::V4(v4),
380            IpAddr::V6(v6) => {
381                // If it’s an IPv4-mapped IPv6 like ::ffff:x.x.x.x, convert to real IPv4
382                if let Some(mapped_v4) = v6.to_ipv4() {
383                    IpAddr::V4(mapped_v4)
384                } else {
385                    // Otherwise, keep it as IPv6
386                    IpAddr::V6(v6)
387                }
388            }
389        };
390        let destination_addr = SocketAddr::new(normalized_ip, start_tunnel.port);
391        info!("Destination address resolved: {}", destination_addr);
392
393        // Use an Arc<Mutex> to share the server_remote_addr between tasks.
394        let streamer_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
395
396        let relay_to_destination = start_relay_from_streamer_to_destination(
397            streamer_socket.clone(),
398            destination_socket.clone(),
399            streamer_addr.clone(),
400            destination_addr,
401        );
402        let relay_to_streamer = start_relay_from_destination_to_streamer(
403            streamer_socket,
404            destination_socket,
405            streamer_addr,
406        );
407
408        *self.reconnect_on_tunnel_error.lock().await = false;
409        let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
410        self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
411        let relay = self.me.clone();
412
413        tokio::spawn(async move {
414            let Some(relay) = relay.upgrade() else {
415                return;
416            };
417
418            // Wait for relay tasks to complete (they won't unless an error occurs or the
419            // socket is closed).
420            tokio::select! {
421                res = relay_to_destination => {
422                    if let Err(e) = res {
423                        error!("relay_to_destination task failed: {}", e);
424                    }
425                }
426                res = relay_to_streamer => {
427                    if let Err(e) = res {
428                        error!("relay_to_streamer task failed: {}", e);
429                    }
430                }
431            }
432
433            if *reconnect_on_tunnel_error.lock().await {
434                relay.lock().await.reconnect_soon().await;
435            } else {
436                info!("Not reconnecting after tunnel error");
437            }
438        });
439
440        Ok(())
441    }
442
443    async fn handle_message_request_status(
444        &mut self,
445        request: MessageRequest,
446    ) -> Result<(), AnyError> {
447        let mut battery_percentage = None;
448        if let Some(get_status) = self.get_status.as_ref() {
449            battery_percentage = get_status().await.battery_percentage;
450        }
451        let data = ResponseData::Status(StatusResponseData { battery_percentage });
452        let response = request.to_ok_response(data);
453        self.send(MessageToStreamer::Response(response)).await
454    }
455
456    async fn send(&mut self, message: MessageToStreamer) -> Result<(), AnyError> {
457        let text = serde_json::to_string(&message)?;
458        self.send_message(Message::Text(text.into())).await
459    }
460
461    async fn send_message(&mut self, message: Message) -> Result<(), AnyError> {
462        let Some(writer) = self.ws_writer.as_mut() else {
463            return Err("No websocket writer".into());
464        };
465        writer.send(message).await?;
466        Ok(())
467    }
468}
469
470pub struct Relay {
471    inner: Arc<Mutex<RelayInner>>,
472}
473
474impl Default for Relay {
475    fn default() -> Self {
476        Self::new()
477    }
478}
479
480impl Relay {
481    pub fn new() -> Self {
482        Self {
483            inner: RelayInner::new(),
484        }
485    }
486
487    pub async fn set_bind_address(&self, address: String) {
488        self.inner.lock().await.set_bind_address(address);
489    }
490
491    pub async fn setup<F>(
492        &self,
493        streamer_url: String,
494        password: String,
495        relay_id: Uuid,
496        name: String,
497        on_status_updated: F,
498        get_status: Option<GetStatusClosure>,
499    ) where
500        F: Fn(String) + Send + Sync + 'static,
501    {
502        self.inner
503            .lock()
504            .await
505            .setup(
506                streamer_url,
507                password,
508                relay_id,
509                name,
510                on_status_updated,
511                get_status,
512            )
513            .await;
514    }
515
516    pub async fn is_started(&self) -> bool {
517        self.inner.lock().await.is_started()
518    }
519
520    pub async fn start(&self) {
521        self.inner.lock().await.start().await;
522    }
523
524    pub async fn stop(&self) {
525        self.inner.lock().await.stop().await;
526    }
527}
528
529fn start_relay_from_streamer_to_destination(
530    streamer_socket: Arc<UdpSocket>,
531    destination_socket: Arc<UdpSocket>,
532    streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
533    destination_addr: SocketAddr,
534) -> tokio::task::JoinHandle<()> {
535    tokio::spawn(async move {
536        debug!("(relay_to_destination) Task started");
537        loop {
538            let mut buf = [0; 2048];
539            let (size, remote_addr) =
540                match timeout(Duration::from_secs(30), streamer_socket.recv_from(&mut buf)).await {
541                    Ok(result) => match result {
542                        Ok((size, addr)) => (size, addr),
543                        Err(e) => {
544                            error!("(relay_to_destination) Error receiving from server: {}", e);
545                            continue;
546                        }
547                    },
548                    Err(e) => {
549                        error!(
550                            "(relay_to_destination) Timeout receiving from server: {}",
551                            e
552                        );
553                        break;
554                    }
555                };
556
557            debug!(
558                "(relay_to_destination) Received {} bytes from server: {}",
559                size, remote_addr
560            );
561
562            // Forward to destination.
563            match destination_socket
564                .send_to(&buf[..size], &destination_addr)
565                .await
566            {
567                Ok(bytes_sent) => {
568                    debug!(
569                        "(relay_to_destination) Sent {} bytes to destination",
570                        bytes_sent
571                    )
572                }
573                Err(e) => {
574                    error!(
575                        "(relay_to_destination) Failed to send to destination: {}",
576                        e
577                    );
578                    break;
579                }
580            }
581
582            // Set the remote address if it hasn't been set yet.
583            let mut streamer_addr_lock = streamer_addr.lock().await;
584            if streamer_addr_lock.is_none() {
585                *streamer_addr_lock = Some(remote_addr);
586                debug!(
587                    "(relay_to_destination) Server remote address set to: {}",
588                    remote_addr
589                );
590            }
591        }
592        info!("(relay_to_destination) Task exiting");
593    })
594}
595
596fn start_relay_from_destination_to_streamer(
597    streamer_socket: Arc<UdpSocket>,
598    destination_socket: Arc<UdpSocket>,
599    streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
600) -> tokio::task::JoinHandle<()> {
601    tokio::spawn(async move {
602        debug!("(relay_to_streamer) Task started");
603        loop {
604            let mut buf = [0; 2048];
605            let (size, remote_addr) = match timeout(
606                Duration::from_secs(30),
607                destination_socket.recv_from(&mut buf),
608            )
609            .await
610            {
611                Ok(result) => match result {
612                    Ok((size, addr)) => (size, addr),
613                    Err(e) => {
614                        error!(
615                            "(relay_to_streamer) Error receiving from destination: {}",
616                            e
617                        );
618                        continue;
619                    }
620                },
621                Err(e) => {
622                    error!(
623                        "(relay_to_streamer) Timeout receiving from destination: {}",
624                        e
625                    );
626                    break;
627                }
628            };
629
630            debug!(
631                "(relay_to_streamer) Received {} bytes from destination: {}",
632                size, remote_addr
633            );
634            // Forward to server.
635            let streamer_addr_lock = streamer_addr.lock().await;
636            match *streamer_addr_lock {
637                Some(streamer_addr) => {
638                    match streamer_socket.send_to(&buf[..size], &streamer_addr).await {
639                        Ok(bytes_sent) => {
640                            debug!("(relay_to_streamer) Sent {} bytes to server", bytes_sent)
641                        }
642                        Err(e) => {
643                            error!("(relay_to_streamer) Failed to send to server: {}", e);
644                            break;
645                        }
646                    }
647                }
648                None => {
649                    error!("(relay_to_streamer) Server address not set, cannot forward packet");
650                }
651            }
652        }
653        info!("(relay_to_streamer) Task exiting");
654    })
655}
656
657async fn create_dual_stack_udp_socket(
658    addr: SocketAddr,
659) -> Result<tokio::net::UdpSocket, std::io::Error> {
660    let socket = match addr.is_ipv4() {
661        true => {
662            // Create an IPv4 socket
663            tokio::net::UdpSocket::bind(addr).await?
664        }
665        false => {
666            // Create a dual-stack socket (supporting both IPv4 and IPv6)
667            let socket = socket2::Socket::new(
668                socket2::Domain::IPV6,
669                socket2::Type::DGRAM,
670                Some(socket2::Protocol::UDP),
671            )?;
672
673            // Set IPV6_V6ONLY to false to enable dual-stack support
674            socket.set_only_v6(false)?;
675
676            // Bind the socket
677            socket.bind(&socket2::SockAddr::from(addr))?;
678
679            // Convert to a tokio UdpSocket
680            tokio::net::UdpSocket::from_std(socket.into())?
681        }
682    };
683
684    Ok(socket)
685}
686
/// Parses `addr_str` as either `IP:port` or a bare `IP`.
///
/// A bare IP gets port 0, which lets the OS pick a free port when the address
/// is bound.
///
/// # Errors
///
/// Returns an `InvalidInput` I/O error when the string is in neither form.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    let full_or_bare = SocketAddr::from_str(addr_str)
        .or_else(|_| IpAddr::from_str(addr_str).map(|ip| SocketAddr::new(ip, 0)));
    full_or_bare.map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
        )
    })
}
708
709pub fn create_get_status_closure(
710    status_executable: &Option<String>,
711    status_file: &Option<String>,
712) -> Option<GetStatusClosure> {
713    let status_executable = status_executable.clone();
714    let status_file = status_file.clone();
715    Some(Box::new(move || {
716        let status_executable = status_executable.clone();
717        let status_file = status_file.clone();
718        Box::pin(async move {
719            let output = if let Some(status_executable) = &status_executable {
720                let Ok(output) = Command::new(status_executable).output().await else {
721                    return Default::default();
722                };
723                output.stdout
724            } else if let Some(status_file) = &status_file {
725                let Ok(mut file) = File::open(status_file).await else {
726                    return Default::default();
727                };
728                let mut contents = vec![];
729                if file.read_to_end(&mut contents).await.is_err() {
730                    return Default::default();
731                }
732                contents
733            } else {
734                return Default::default();
735            };
736            let output = String::from_utf8(output).unwrap_or_default();
737            match serde_json::from_str(&output) {
738                Ok(status) => status,
739                Err(e) => {
740                    error!("Failed to decode status with error: {e}");
741                    Default::default()
742                }
743            }
744        })
745    }))
746}