moblink_rust/
relay.rs

1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use futures_util::stream::{SplitSink, SplitStream};
8use futures_util::{SinkExt, StreamExt};
9use log::{debug, error, info};
10use serde::Deserialize;
11use tokio::fs::File;
12use tokio::io::AsyncReadExt;
13use tokio::net::{TcpStream, UdpSocket};
14use tokio::process::Command;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, sleep, timeout};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
19use uuid::Uuid;
20
21use crate::protocol::*;
22use crate::utils::{AnyError, resolve_host};
23
24#[derive(Default, Deserialize, Clone)]
25#[serde(rename_all = "camelCase")]
26pub struct Status {
27    pub battery_percentage: Option<i32>,
28}
29
30pub type GetStatusClosure =
31    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
32
/// Mutable relay state behind [`Relay`]; created and shared as `Arc<Mutex<RelayInner>>`.
struct RelayInner {
    /// Weak self-reference cloned into spawned tasks so they can reach the relay later.
    me: Weak<Mutex<Self>>,
    /// Local IP address (bare `IP` or `IP:port`) used to bind destination-facing UDP sockets.
    bind_address: String,
    /// Identifier sent to the streamer in the `Identify` message.
    relay_id: Uuid,
    /// Websocket URL of the streamer to connect to.
    streamer_url: String,
    /// Password used to answer the streamer's authentication challenge.
    password: String,
    /// Relay name sent to the streamer in the `Identify` message.
    name: String,
    /// Callback receiving human-readable status strings.
    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
    /// Optional provider of device status used to answer status requests.
    get_status: Option<Arc<GetStatusClosure>>,
    /// Write half of the current websocket connection, if connected.
    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    /// Whether the relay has been started by the user.
    started: bool,
    /// Whether the streamer accepted the relay's identification.
    connected: bool,
    /// Whether the streamer rejected the password.
    wrong_password: bool,
    /// Flag shared with the current tunnel's watcher task; cleared to stop a tunnel
    /// failure from triggering a reconnect.
    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
    /// Flag shared with the pending delayed-restart task; cleared to cancel that restart.
    start_on_reconnect_soon: Arc<Mutex<bool>>,
}
50
51impl RelayInner {
52    fn new() -> Arc<Mutex<Self>> {
53        Arc::new_cyclic(|me| {
54            Mutex::new(Self {
55                me: me.clone(),
56                bind_address: Self::get_default_bind_address(),
57                relay_id: Uuid::new_v4(),
58                streamer_url: "".to_string(),
59                password: "".to_string(),
60                name: "".to_string(),
61                on_status_updated: None,
62                get_status: None,
63                ws_writer: None,
64                started: false,
65                connected: false,
66                wrong_password: false,
67                reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
68                start_on_reconnect_soon: Arc::new(Mutex::new(false)),
69            })
70        })
71    }
72
73    fn set_bind_address(&mut self, address: String) {
74        self.bind_address = address;
75    }
76
77    async fn setup<F>(
78        &mut self,
79        streamer_url: String,
80        password: String,
81        relay_id: Uuid,
82        name: String,
83        on_status_updated: F,
84        get_status: Option<GetStatusClosure>,
85    ) where
86        F: Fn(String) + Send + Sync + 'static,
87    {
88        self.on_status_updated = Some(Box::new(on_status_updated));
89        self.get_status = get_status.map(Arc::new);
90        self.relay_id = relay_id;
91        self.streamer_url = streamer_url;
92        self.password = password;
93        self.name = name;
94    }
95
96    fn is_started(&self) -> bool {
97        self.started
98    }
99
100    async fn start(&mut self) {
101        if !self.started {
102            self.started = true;
103            self.start_internal().await;
104        }
105    }
106
107    async fn stop(&mut self) {
108        if self.started {
109            self.started = false;
110            self.stop_internal().await;
111        }
112    }
113
114    fn get_default_bind_address() -> String {
115        // Get main network interface
116        let interfaces = pnet::datalink::interfaces();
117        let interface = interfaces.iter().find(|interface| {
118            interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
119        });
120
121        // Only ipv4 addresses are supported
122        let ipv4_addresses: Vec<String> = interface
123            .expect("No available network interfaces found")
124            .ips
125            .iter()
126            .filter_map(|ip| {
127                let ip = ip.ip();
128                ip.is_ipv4().then(|| ip.to_string())
129            })
130            .collect();
131
132        // Return the first address
133        ipv4_addresses
134            .first()
135            .cloned()
136            .unwrap_or("0.0.0.0:0".to_string())
137    }
138
139    async fn start_internal(&mut self) {
140        if !self.started {
141            self.stop_internal().await;
142            return;
143        }
144
145        let request = match url::Url::parse(&self.streamer_url) {
146            Ok(url) => url,
147            Err(e) => {
148                error!("Failed to parse URL: {}", e);
149                return;
150            }
151        };
152
153        match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
154            Ok(Ok((ws_stream, _))) => {
155                debug!("Connected to {}", self.streamer_url);
156                let (writer, reader) = ws_stream.split();
157                self.ws_writer = Some(writer);
158                self.start_websocket_receiver(reader);
159            }
160            Ok(Err(error)) => {
161                debug!(
162                    "Failed to connect to {} with error: {}",
163                    self.streamer_url, error
164                );
165                self.reconnect_soon().await;
166            }
167            Err(_elapsed) => {
168                debug!(
169                    "Failed to connect to {} within 10 seconds",
170                    self.streamer_url
171                );
172                self.reconnect_soon().await;
173            }
174        }
175    }
176
177    fn start_websocket_receiver(
178        &mut self,
179        mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
180    ) {
181        // Task to process messages received from the channel.
182        let relay = self.me.clone();
183
184        tokio::spawn(async move {
185            let Some(relay_arc) = relay.upgrade() else {
186                return;
187            };
188
189            while let Some(result) = reader.next().await {
190                let mut relay = relay_arc.lock().await;
191                match result {
192                    Ok(message) => match message {
193                        Message::Text(text) => {
194                            match serde_json::from_str::<MessageToRelay>(&text) {
195                                Ok(message) => {
196                                    if let Err(error) = relay.handle_message(message).await {
197                                        error!("Message handling failed with error: {}", error);
198                                        relay.reconnect_soon().await;
199                                        break;
200                                    }
201                                }
202                                _ => {
203                                    error!("Failed to deserialize message: {}", text);
204                                }
205                            }
206                        }
207                        Message::Binary(data) => {
208                            debug!("Received binary message of length: {}", data.len());
209                        }
210                        Message::Ping(data) => {
211                            relay.send_message(Message::Pong(data)).await.ok();
212                        }
213                        Message::Pong(_) => {
214                            debug!("Received pong message");
215                        }
216                        Message::Close(frame) => {
217                            info!("Received close message: {:?}", frame);
218                            relay.reconnect_soon().await;
219                            break;
220                        }
221                        Message::Frame(_) => {
222                            unreachable!("This is never used")
223                        }
224                    },
225                    Err(e) => {
226                        debug!("Error processing message: {}", e);
227                        // TODO: There has to be a better way to handle this
228                        if e.to_string()
229                            .contains("Connection reset without closing handshake")
230                        {
231                            relay.reconnect_soon().await;
232                        }
233                        break;
234                    }
235                }
236            }
237        });
238    }
239
240    async fn stop_internal(&mut self) {
241        if let Some(mut ws_writer) = self.ws_writer.take() {
242            match ws_writer.close().await {
243                Err(e) => {
244                    error!("Error closing WebSocket: {}", e);
245                }
246                _ => {
247                    debug!("WebSocket closed successfully");
248                }
249            }
250        }
251        self.connected = false;
252        self.wrong_password = false;
253        *self.reconnect_on_tunnel_error.lock().await = false;
254        *self.start_on_reconnect_soon.lock().await = false;
255        self.update_status();
256    }
257
258    fn update_status(&self) {
259        let Some(on_status_updated) = &self.on_status_updated else {
260            return;
261        };
262        let status = if self.connected {
263            "Connected to streamer"
264        } else if self.wrong_password {
265            "Wrong password"
266        } else if self.started {
267            "Connecting to streamer"
268        } else {
269            "Disconnected from streamer"
270        };
271        on_status_updated(status.to_string());
272    }
273
274    async fn reconnect_soon(&mut self) {
275        self.stop_internal().await;
276        *self.start_on_reconnect_soon.lock().await = false;
277        let start_on_reconnect_soon = Arc::new(Mutex::new(true));
278        self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
279        self.start_soon(start_on_reconnect_soon);
280    }
281
282    fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
283        let relay = self.me.clone();
284
285        tokio::spawn(async move {
286            sleep(Duration::from_secs(5)).await;
287
288            if *start_on_reconnect_soon.lock().await {
289                debug!("Reconnecting...");
290                if let Some(relay) = relay.upgrade() {
291                    relay.lock().await.start_internal().await;
292                }
293            }
294        });
295    }
296
297    async fn handle_message(&mut self, message: MessageToRelay) -> Result<(), AnyError> {
298        match message {
299            MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
300            MessageToRelay::Identified(identified) => {
301                self.handle_message_identified(identified).await
302            }
303            MessageToRelay::Request(request) => self.handle_message_request(request).await,
304        }
305    }
306
307    async fn handle_message_hello(&mut self, hello: Hello) -> Result<(), AnyError> {
308        let authentication = calculate_authentication(
309            &self.password,
310            &hello.authentication.salt,
311            &hello.authentication.challenge,
312        );
313        let identify = Identify {
314            id: self.relay_id,
315            name: self.name.clone(),
316            authentication,
317        };
318        self.send(MessageToStreamer::Identify(identify)).await
319    }
320
321    async fn handle_message_identified(&mut self, identified: Identified) -> Result<(), AnyError> {
322        match identified.result {
323            MoblinkResult::Ok(_) => {
324                self.connected = true;
325            }
326            MoblinkResult::WrongPassword(_) => {
327                self.wrong_password = true;
328            }
329        }
330        self.update_status();
331        Ok(())
332    }
333
334    async fn handle_message_request(&mut self, request: MessageRequest) -> Result<(), AnyError> {
335        match &request.data {
336            MessageRequestData::StartTunnel(start_tunnel) => {
337                self.handle_message_request_start_tunnel(&request, start_tunnel)
338                    .await
339            }
340            MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
341        }
342    }
343
344    async fn handle_message_request_start_tunnel(
345        &mut self,
346        request: &MessageRequest,
347        start_tunnel: &StartTunnelRequest,
348    ) -> Result<(), AnyError> {
349        // Pick bind addresses from the relay
350        let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
351        let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
352
353        debug!(
354            "Binding streamer socket on: {}, destination socket on: {}",
355            local_bind_addr_for_streamer, local_bind_addr_for_destination
356        );
357        // Create a UDP socket bound for receiving packets from the server.
358        // Use dual-stack socket creation.
359        let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
360        let streamer_port = streamer_socket.local_addr()?.port();
361        let streamer_socket = Arc::new(streamer_socket);
362
363        // Inform the server about the chosen port.
364        let data = ResponseData::StartTunnel(StartTunnelResponseData {
365            port: streamer_port,
366        });
367        let response = request.to_ok_response(data);
368        self.send(MessageToStreamer::Response(response)).await?;
369
370        // Create a new UDP socket for communication with the destination.
371        // Use dual-stack socket creation.
372        let destination_socket =
373            create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
374
375        let destination_socket = Arc::new(destination_socket);
376        let destination_address = resolve_host(&start_tunnel.address).await?;
377        let destination_address = match IpAddr::from_str(&destination_address)? {
378            IpAddr::V4(v4) => IpAddr::V4(v4),
379            IpAddr::V6(v6) => {
380                // If it’s an IPv4-mapped IPv6 like ::ffff:x.x.x.x, convert to real IPv4
381                if let Some(mapped_v4) = v6.to_ipv4() {
382                    IpAddr::V4(mapped_v4)
383                } else {
384                    // Otherwise, keep it as IPv6
385                    IpAddr::V6(v6)
386                }
387            }
388        };
389
390        let destination_address = SocketAddr::new(destination_address, start_tunnel.port);
391        info!("Destination address: {}", destination_address);
392        let streamer_address: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
393
394        let relay_to_destination = start_relay_from_streamer_to_destination(
395            streamer_socket.clone(),
396            destination_socket.clone(),
397            streamer_address.clone(),
398            destination_address,
399        );
400        let relay_to_streamer = start_relay_from_destination_to_streamer(
401            streamer_socket,
402            destination_socket,
403            streamer_address,
404        );
405
406        *self.reconnect_on_tunnel_error.lock().await = false;
407        let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
408        self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
409        let relay = self.me.clone();
410
411        tokio::spawn(async move {
412            let Some(relay) = relay.upgrade() else {
413                return;
414            };
415
416            // Wait for relay tasks to complete (they won't unless an error occurs or the
417            // socket is closed).
418            tokio::select! {
419                res = relay_to_destination => {
420                    if let Err(e) = res {
421                        error!("relay_to_destination task failed: {}", e);
422                    }
423                }
424                res = relay_to_streamer => {
425                    if let Err(e) = res {
426                        error!("relay_to_streamer task failed: {}", e);
427                    }
428                }
429            }
430
431            if *reconnect_on_tunnel_error.lock().await {
432                relay.lock().await.reconnect_soon().await;
433            } else {
434                info!("Not reconnecting after tunnel error");
435            }
436        });
437
438        Ok(())
439    }
440
441    async fn handle_message_request_status(
442        &mut self,
443        request: MessageRequest,
444    ) -> Result<(), AnyError> {
445        let mut battery_percentage = None;
446        if let Some(get_status) = self.get_status.as_ref() {
447            battery_percentage = get_status().await.battery_percentage;
448        }
449        let data = ResponseData::Status(StatusResponseData { battery_percentage });
450        let response = request.to_ok_response(data);
451        self.send(MessageToStreamer::Response(response)).await
452    }
453
454    async fn send(&mut self, message: MessageToStreamer) -> Result<(), AnyError> {
455        let text = serde_json::to_string(&message)?;
456        self.send_message(Message::Text(text.into())).await
457    }
458
459    async fn send_message(&mut self, message: Message) -> Result<(), AnyError> {
460        let Some(writer) = self.ws_writer.as_mut() else {
461            return Err("No websocket writer".into());
462        };
463        writer.send(message).await?;
464        Ok(())
465    }
466}
467
/// Public handle to a relay that connects to a streamer over a websocket and
/// forwards UDP tunnels on the streamer's behalf.
///
/// Every method locks the shared inner state, so calls on one relay are serialized.
pub struct Relay {
    /// Shared relay state; spawned tasks hold only weak references to it.
    inner: Arc<Mutex<RelayInner>>,
}
471
472impl Default for Relay {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
478impl Relay {
479    pub fn new() -> Self {
480        Self {
481            inner: RelayInner::new(),
482        }
483    }
484
485    pub async fn set_bind_address(&self, address: String) {
486        self.inner.lock().await.set_bind_address(address);
487    }
488
489    pub async fn setup<F>(
490        &self,
491        streamer_url: String,
492        password: String,
493        relay_id: Uuid,
494        name: String,
495        on_status_updated: F,
496        get_status: Option<GetStatusClosure>,
497    ) where
498        F: Fn(String) + Send + Sync + 'static,
499    {
500        self.inner
501            .lock()
502            .await
503            .setup(
504                streamer_url,
505                password,
506                relay_id,
507                name,
508                on_status_updated,
509                get_status,
510            )
511            .await;
512    }
513
514    pub async fn is_started(&self) -> bool {
515        self.inner.lock().await.is_started()
516    }
517
518    pub async fn start(&self) {
519        self.inner.lock().await.start().await;
520    }
521
522    pub async fn stop(&self) {
523        self.inner.lock().await.stop().await;
524    }
525}
526
527fn start_relay_from_streamer_to_destination(
528    streamer_socket: Arc<UdpSocket>,
529    destination_socket: Arc<UdpSocket>,
530    streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
531    destination_addr: SocketAddr,
532) -> tokio::task::JoinHandle<()> {
533    tokio::spawn(async move {
534        loop {
535            if let Err(error) = relay_one_packet_from_streamer_to_destination(
536                &streamer_socket,
537                &destination_socket,
538                &streamer_addr,
539                &destination_addr,
540            )
541            .await
542            {
543                error!("(relay_to_destination) Failed with error: {}", error);
544                break;
545            }
546        }
547    })
548}
549
550async fn relay_one_packet_from_streamer_to_destination(
551    streamer_socket: &Arc<UdpSocket>,
552    destination_socket: &Arc<UdpSocket>,
553    streamer_addr: &Arc<Mutex<Option<SocketAddr>>>,
554    destination_addr: &SocketAddr,
555) -> Result<(), AnyError> {
556    let mut buf = [0; 2048];
557    let (size, remote_addr) =
558        timeout(Duration::from_secs(30), streamer_socket.recv_from(&mut buf)).await??;
559    destination_socket
560        .send_to(&buf[..size], &destination_addr)
561        .await?;
562    streamer_addr.lock().await.replace(remote_addr);
563    Ok(())
564}
565
566fn start_relay_from_destination_to_streamer(
567    streamer_socket: Arc<UdpSocket>,
568    destination_socket: Arc<UdpSocket>,
569    streamer_address: Arc<Mutex<Option<SocketAddr>>>,
570) -> tokio::task::JoinHandle<()> {
571    tokio::spawn(async move {
572        loop {
573            if let Err(error) = relay_one_packet_from_destination_to_streamer(
574                &streamer_socket,
575                &destination_socket,
576                &streamer_address,
577            )
578            .await
579            {
580                error!("(relay_to_streamer) Failed with error: {}", error);
581                break;
582            }
583        }
584    })
585}
586
587async fn relay_one_packet_from_destination_to_streamer(
588    streamer_socket: &Arc<UdpSocket>,
589    destination_socket: &Arc<UdpSocket>,
590    streamer_address: &Arc<Mutex<Option<SocketAddr>>>,
591) -> Result<(), AnyError> {
592    let mut buf = [0; 2048];
593    let size = timeout(Duration::from_secs(30), destination_socket.recv(&mut buf)).await??;
594    let streamer_addr = streamer_address
595        .lock()
596        .await
597        .ok_or("Failed to get address lock")?;
598    streamer_socket
599        .send_to(&buf[..size], &streamer_addr)
600        .await?;
601    Ok(())
602}
603
604async fn create_dual_stack_udp_socket(
605    addr: SocketAddr,
606) -> Result<tokio::net::UdpSocket, std::io::Error> {
607    let socket = match addr.is_ipv4() {
608        true => {
609            // Create an IPv4 socket
610            tokio::net::UdpSocket::bind(addr).await?
611        }
612        false => {
613            // Create a dual-stack socket (supporting both IPv4 and IPv6)
614            let socket = socket2::Socket::new(
615                socket2::Domain::IPV6,
616                socket2::Type::DGRAM,
617                Some(socket2::Protocol::UDP),
618            )?;
619
620            // Set IPV6_V6ONLY to false to enable dual-stack support
621            socket.set_only_v6(false)?;
622
623            // Bind the socket
624            socket.bind(&socket2::SockAddr::from(addr))?;
625
626            // Convert to a tokio UdpSocket
627            tokio::net::UdpSocket::from_std(socket.into())?
628        }
629    };
630
631    Ok(socket)
632}
633
/// Parses `addr_str` as either `IP:port` or a bare `IP`.
///
/// A bare IP gets port 0, so the OS picks a free port when binding.
///
/// # Errors
/// Returns an [`std::io::ErrorKind::InvalidInput`] error if neither form parses.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    SocketAddr::from_str(addr_str)
        .or_else(|_| IpAddr::from_str(addr_str).map(|ip| SocketAddr::new(ip, 0)))
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}
655
656pub fn create_get_status_closure(
657    status_executable: &Option<String>,
658    status_file: &Option<String>,
659) -> Option<GetStatusClosure> {
660    let status_executable = status_executable.clone();
661    let status_file = status_file.clone();
662    Some(Box::new(move || {
663        let status_executable = status_executable.clone();
664        let status_file = status_file.clone();
665        Box::pin(async move {
666            let output = if let Some(status_executable) = &status_executable {
667                let Ok(output) = Command::new(status_executable).output().await else {
668                    return Default::default();
669                };
670                output.stdout
671            } else if let Some(status_file) = &status_file {
672                let Ok(mut file) = File::open(status_file).await else {
673                    return Default::default();
674                };
675                let mut contents = vec![];
676                if file.read_to_end(&mut contents).await.is_err() {
677                    return Default::default();
678                }
679                contents
680            } else {
681                return Default::default();
682            };
683            let output = String::from_utf8(output).unwrap_or_default();
684            match serde_json::from_str(&output) {
685                Ok(status) => status,
686                Err(e) => {
687                    error!("Failed to decode status with error: {e}");
688                    Default::default()
689                }
690            }
691        })
692    }))
693}