// moblink_rust/relay.rs

use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, Weak};

use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use log::{debug, error, info};
use serde::Deserialize;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::net::{TcpStream, UdpSocket};
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::time::{Duration, sleep, timeout};
use tokio_tungstenite::tungstenite::error::{Error as WsError, ProtocolError};
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use uuid::Uuid;

use crate::protocol::*;
use crate::utils::{AnyError, resolve_host};

24#[derive(Default, Deserialize, Clone)]
25#[serde(rename_all = "camelCase")]
26pub struct Status {
27    pub battery_percentage: Option<i32>,
28}
29
30pub type GetStatusClosure =
31    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
32
33struct RelayInner {
34    me: Weak<Mutex<Self>>,
35    /// Store a local IP address  for binding UDP sockets
36    bind_address: String,
37    relay_id: Uuid,
38    streamer_url: String,
39    password: String,
40    name: String,
41    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
42    get_status: Option<Arc<GetStatusClosure>>,
43    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
44    started: bool,
45    connected: bool,
46    wrong_password: bool,
47    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
48    start_on_reconnect_soon: Arc<Mutex<bool>>,
49    relay_to_destination: Option<tokio::task::JoinHandle<Result<(), AnyError>>>,
50}
51
52impl RelayInner {
53    fn new() -> Arc<Mutex<Self>> {
54        Arc::new_cyclic(|me| {
55            Mutex::new(Self {
56                me: me.clone(),
57                bind_address: Self::get_default_bind_address(),
58                relay_id: Uuid::new_v4(),
59                streamer_url: "".to_string(),
60                password: "".to_string(),
61                name: "".to_string(),
62                on_status_updated: None,
63                get_status: None,
64                ws_writer: None,
65                started: false,
66                connected: false,
67                wrong_password: false,
68                reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
69                start_on_reconnect_soon: Arc::new(Mutex::new(false)),
70                relay_to_destination: None,
71            })
72        })
73    }
74
75    fn set_bind_address(&mut self, address: String) {
76        self.bind_address = address;
77    }
78
79    async fn setup<F>(
80        &mut self,
81        streamer_url: String,
82        password: String,
83        relay_id: Uuid,
84        name: String,
85        on_status_updated: F,
86        get_status: Option<GetStatusClosure>,
87    ) where
88        F: Fn(String) + Send + Sync + 'static,
89    {
90        self.on_status_updated = Some(Box::new(on_status_updated));
91        self.get_status = get_status.map(Arc::new);
92        self.relay_id = relay_id;
93        self.streamer_url = streamer_url;
94        self.password = password;
95        self.name = name;
96    }
97
98    fn is_started(&self) -> bool {
99        self.started
100    }
101
102    async fn start(&mut self) {
103        if !self.started {
104            self.started = true;
105            self.start_internal().await;
106        }
107    }
108
109    async fn stop(&mut self) {
110        if self.started {
111            self.started = false;
112            self.stop_internal().await;
113        }
114    }
115
116    fn get_default_bind_address() -> String {
117        // Get main network interface
118        let interfaces = pnet::datalink::interfaces();
119        let interface = interfaces.iter().find(|interface| {
120            interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
121        });
122
123        // Only ipv4 addresses are supported
124        let ipv4_addresses: Vec<String> = interface
125            .expect("No available network interfaces found")
126            .ips
127            .iter()
128            .filter_map(|ip| {
129                let ip = ip.ip();
130                ip.is_ipv4().then(|| ip.to_string())
131            })
132            .collect();
133
134        // Return the first address
135        ipv4_addresses
136            .first()
137            .cloned()
138            .unwrap_or("0.0.0.0:0".to_string())
139    }
140
141    async fn start_internal(&mut self) {
142        if !self.started {
143            self.stop_internal().await;
144            return;
145        }
146
147        let request = match url::Url::parse(&self.streamer_url) {
148            Ok(url) => url,
149            Err(e) => {
150                error!("Failed to parse URL: {}", e);
151                return;
152            }
153        };
154
155        match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
156            Ok(Ok((ws_stream, _))) => {
157                debug!("Connected to {}", self.streamer_url);
158                let (writer, reader) = ws_stream.split();
159                self.ws_writer = Some(writer);
160                self.start_websocket_receiver(reader);
161            }
162            Ok(Err(error)) => {
163                debug!(
164                    "Failed to connect to {} with error: {}",
165                    self.streamer_url, error
166                );
167                self.reconnect_soon().await;
168            }
169            Err(_elapsed) => {
170                debug!(
171                    "Failed to connect to {} within 10 seconds",
172                    self.streamer_url
173                );
174                self.reconnect_soon().await;
175            }
176        }
177    }
178
179    fn start_websocket_receiver(
180        &mut self,
181        mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
182    ) {
183        // Task to process messages received from the channel.
184        let relay = self.me.clone();
185
186        tokio::spawn(async move {
187            let Some(relay_arc) = relay.upgrade() else {
188                return;
189            };
190
191            while let Some(result) = reader.next().await {
192                let mut relay = relay_arc.lock().await;
193                match result {
194                    Ok(message) => match message {
195                        Message::Text(text) => {
196                            match serde_json::from_str::<MessageToRelay>(&text) {
197                                Ok(message) => {
198                                    if let Err(error) = relay.handle_message(message).await {
199                                        error!("Message handling failed with error: {}", error);
200                                        relay.reconnect_soon().await;
201                                        break;
202                                    }
203                                }
204                                _ => {
205                                    error!("Failed to deserialize message: {}", text);
206                                }
207                            }
208                        }
209                        Message::Binary(data) => {
210                            debug!("Received binary message of length: {}", data.len());
211                        }
212                        Message::Ping(data) => {
213                            relay.send_message(Message::Pong(data)).await.ok();
214                        }
215                        Message::Pong(_) => {
216                            debug!("Received pong message");
217                        }
218                        Message::Close(frame) => {
219                            info!("Received close message: {:?}", frame);
220                            relay.reconnect_soon().await;
221                            break;
222                        }
223                        Message::Frame(_) => {
224                            unreachable!("This is never used")
225                        }
226                    },
227                    Err(e) => {
228                        debug!("Error processing message: {}", e);
229                        // TODO: There has to be a better way to handle this
230                        if e.to_string()
231                            .contains("Connection reset without closing handshake")
232                        {
233                            relay.reconnect_soon().await;
234                        }
235                        break;
236                    }
237                }
238            }
239        });
240    }
241
242    async fn stop_internal(&mut self) {
243        if let Some(mut ws_writer) = self.ws_writer.take() {
244            match ws_writer.close().await {
245                Err(e) => {
246                    error!("Error closing WebSocket: {}", e);
247                }
248                _ => {
249                    debug!("WebSocket closed successfully");
250                }
251            }
252        }
253        self.connected = false;
254        self.wrong_password = false;
255        *self.reconnect_on_tunnel_error.lock().await = false;
256        *self.start_on_reconnect_soon.lock().await = false;
257        if let Some(relay_to_destination) = self.relay_to_destination.take() {
258            relay_to_destination.abort();
259            relay_to_destination.await.ok();
260        }
261        self.update_status();
262    }
263
264    fn update_status(&self) {
265        let Some(on_status_updated) = &self.on_status_updated else {
266            return;
267        };
268        let status = if self.connected {
269            "Connected to streamer"
270        } else if self.wrong_password {
271            "Wrong password"
272        } else if self.started {
273            "Connecting to streamer"
274        } else {
275            "Disconnected from streamer"
276        };
277        on_status_updated(status.to_string());
278    }
279
280    async fn reconnect_soon(&mut self) {
281        self.stop_internal().await;
282        *self.start_on_reconnect_soon.lock().await = false;
283        let start_on_reconnect_soon = Arc::new(Mutex::new(true));
284        self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
285        self.start_soon(start_on_reconnect_soon);
286    }
287
288    fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
289        let relay = self.me.clone();
290
291        tokio::spawn(async move {
292            sleep(Duration::from_secs(5)).await;
293
294            if *start_on_reconnect_soon.lock().await {
295                debug!("Reconnecting...");
296                if let Some(relay) = relay.upgrade() {
297                    relay.lock().await.start_internal().await;
298                }
299            }
300        });
301    }
302
303    async fn handle_message(&mut self, message: MessageToRelay) -> Result<(), AnyError> {
304        match message {
305            MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
306            MessageToRelay::Identified(identified) => {
307                self.handle_message_identified(identified).await
308            }
309            MessageToRelay::Request(request) => self.handle_message_request(request).await,
310        }
311    }
312
313    async fn handle_message_hello(&mut self, hello: Hello) -> Result<(), AnyError> {
314        let authentication = calculate_authentication(
315            &self.password,
316            &hello.authentication.salt,
317            &hello.authentication.challenge,
318        );
319        let identify = Identify {
320            id: self.relay_id,
321            name: self.name.clone(),
322            authentication,
323        };
324        self.send(MessageToStreamer::Identify(identify)).await
325    }
326
327    async fn handle_message_identified(&mut self, identified: Identified) -> Result<(), AnyError> {
328        match identified.result {
329            MoblinkResult::Ok(_) => {
330                self.connected = true;
331            }
332            MoblinkResult::WrongPassword(_) => {
333                self.wrong_password = true;
334            }
335        }
336        self.update_status();
337        Ok(())
338    }
339
340    async fn handle_message_request(&mut self, request: MessageRequest) -> Result<(), AnyError> {
341        match &request.data {
342            MessageRequestData::StartTunnel(start_tunnel) => {
343                self.handle_message_request_start_tunnel(&request, start_tunnel)
344                    .await
345            }
346            MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
347        }
348    }
349
350    async fn handle_message_request_start_tunnel(
351        &mut self,
352        request: &MessageRequest,
353        start_tunnel: &StartTunnelRequest,
354    ) -> Result<(), AnyError> {
355        // Pick bind addresses from the relay
356        let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
357        let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
358
359        debug!(
360            "Binding streamer socket on: {}, destination socket on: {}",
361            local_bind_addr_for_streamer, local_bind_addr_for_destination
362        );
363        // Create a UDP socket bound for receiving packets from the server.
364        // Use dual-stack socket creation.
365        let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
366        let streamer_port = streamer_socket.local_addr()?.port();
367        let streamer_socket = Arc::new(streamer_socket);
368
369        // Inform the server about the chosen port.
370        let data = ResponseData::StartTunnel(StartTunnelResponseData {
371            port: streamer_port,
372        });
373        let response = request.to_ok_response(data);
374        self.send(MessageToStreamer::Response(response)).await?;
375
376        // Create a new UDP socket for communication with the destination.
377        // Use dual-stack socket creation.
378        let destination_socket =
379            create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
380
381        let destination_socket = Arc::new(destination_socket);
382        let destination_address = resolve_host(&start_tunnel.address).await?;
383        let destination_address = match IpAddr::from_str(&destination_address)? {
384            IpAddr::V4(v4) => IpAddr::V4(v4),
385            IpAddr::V6(v6) => {
386                // If it’s an IPv4-mapped IPv6 like ::ffff:x.x.x.x, convert to real IPv4
387                if let Some(mapped_v4) = v6.to_ipv4() {
388                    IpAddr::V4(mapped_v4)
389                } else {
390                    // Otherwise, keep it as IPv6
391                    IpAddr::V6(v6)
392                }
393            }
394        };
395
396        let destination_address = SocketAddr::new(destination_address, start_tunnel.port);
397        info!("Destination address: {}", destination_address);
398
399        self.relay_to_destination = Some(
400            self.start_relay_from_streamer_to_destination(
401                streamer_socket,
402                destination_socket,
403                destination_address,
404            )
405            .await,
406        );
407
408        Ok(())
409    }
410
411    async fn start_relay_from_streamer_to_destination(
412        &mut self,
413        streamer_socket: Arc<UdpSocket>,
414        destination_socket: Arc<UdpSocket>,
415        destination_addr: SocketAddr,
416    ) -> tokio::task::JoinHandle<Result<(), AnyError>> {
417        *self.reconnect_on_tunnel_error.lock().await = false;
418        let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
419        self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
420        let relay = self.me.clone();
421
422        tokio::spawn(async move {
423            let streamer_address = Arc::new(Mutex::new(None));
424            let mut relay_to_destination_started = false;
425            let mut buf = [0; 2048];
426
427            loop {
428                let (size, remote_addr) = streamer_socket.recv_from(&mut buf).await?;
429                destination_socket
430                    .send_to(&buf[..size], &destination_addr)
431                    .await?;
432                streamer_address.lock().await.replace(remote_addr);
433
434                if !relay_to_destination_started {
435                    start_relay_from_destination_to_streamer(
436                        relay.clone(),
437                        streamer_socket.clone(),
438                        destination_socket.clone(),
439                        streamer_address.clone(),
440                        reconnect_on_tunnel_error.clone(),
441                    );
442                    relay_to_destination_started = true;
443                }
444            }
445        })
446    }
447
448    async fn handle_message_request_status(
449        &mut self,
450        request: MessageRequest,
451    ) -> Result<(), AnyError> {
452        let mut battery_percentage = None;
453        if let Some(get_status) = self.get_status.as_ref() {
454            battery_percentage = get_status().await.battery_percentage;
455        }
456        let data = ResponseData::Status(StatusResponseData { battery_percentage });
457        let response = request.to_ok_response(data);
458        self.send(MessageToStreamer::Response(response)).await
459    }
460
461    async fn send(&mut self, message: MessageToStreamer) -> Result<(), AnyError> {
462        let text = serde_json::to_string(&message)?;
463        self.send_message(Message::Text(text.into())).await
464    }
465
466    async fn send_message(&mut self, message: Message) -> Result<(), AnyError> {
467        let Some(writer) = self.ws_writer.as_mut() else {
468            return Err("No websocket writer".into());
469        };
470        writer.send(message).await?;
471        Ok(())
472    }
473}
474
475pub struct Relay {
476    inner: Arc<Mutex<RelayInner>>,
477}
478
479impl Default for Relay {
480    fn default() -> Self {
481        Self::new()
482    }
483}
484
485impl Relay {
486    pub fn new() -> Self {
487        Self {
488            inner: RelayInner::new(),
489        }
490    }
491
492    pub async fn set_bind_address(&self, address: String) {
493        self.inner.lock().await.set_bind_address(address);
494    }
495
496    pub async fn setup<F>(
497        &self,
498        streamer_url: String,
499        password: String,
500        relay_id: Uuid,
501        name: String,
502        on_status_updated: F,
503        get_status: Option<GetStatusClosure>,
504    ) where
505        F: Fn(String) + Send + Sync + 'static,
506    {
507        self.inner
508            .lock()
509            .await
510            .setup(
511                streamer_url,
512                password,
513                relay_id,
514                name,
515                on_status_updated,
516                get_status,
517            )
518            .await;
519    }
520
521    pub async fn is_started(&self) -> bool {
522        self.inner.lock().await.is_started()
523    }
524
525    pub async fn start(&self) {
526        self.inner.lock().await.start().await;
527    }
528
529    pub async fn stop(&self) {
530        self.inner.lock().await.stop().await;
531    }
532}
533
534fn start_relay_from_destination_to_streamer(
535    relay: Weak<Mutex<RelayInner>>,
536    streamer_socket: Arc<UdpSocket>,
537    destination_socket: Arc<UdpSocket>,
538    streamer_address: Arc<Mutex<Option<SocketAddr>>>,
539    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
540) {
541    tokio::spawn(async move {
542        loop {
543            if let Err(error) = relay_one_packet_from_destination_to_streamer(
544                &streamer_socket,
545                &destination_socket,
546                &streamer_address,
547            )
548            .await
549            {
550                info!("(relay_to_streamer) Failed with error: {}", error);
551                break;
552            }
553        }
554
555        if *reconnect_on_tunnel_error.lock().await {
556            if let Some(relay) = relay.upgrade() {
557                relay.lock().await.reconnect_soon().await;
558            }
559        } else {
560            info!("Not reconnecting after tunnel error");
561        }
562    });
563}
564
565async fn relay_one_packet_from_destination_to_streamer(
566    streamer_socket: &Arc<UdpSocket>,
567    destination_socket: &Arc<UdpSocket>,
568    streamer_address: &Arc<Mutex<Option<SocketAddr>>>,
569) -> Result<(), AnyError> {
570    let mut buf = [0; 2048];
571    let size = timeout(Duration::from_secs(30), destination_socket.recv(&mut buf)).await??;
572    let streamer_addr = streamer_address
573        .lock()
574        .await
575        .ok_or("Failed to get address lock")?;
576    streamer_socket
577        .send_to(&buf[..size], &streamer_addr)
578        .await?;
579    Ok(())
580}
581
582async fn create_dual_stack_udp_socket(
583    addr: SocketAddr,
584) -> Result<tokio::net::UdpSocket, std::io::Error> {
585    let socket = match addr.is_ipv4() {
586        true => {
587            // Create an IPv4 socket
588            tokio::net::UdpSocket::bind(addr).await?
589        }
590        false => {
591            // Create a dual-stack socket (supporting both IPv4 and IPv6)
592            let socket = socket2::Socket::new(
593                socket2::Domain::IPV6,
594                socket2::Type::DGRAM,
595                Some(socket2::Protocol::UDP),
596            )?;
597
598            // Set IPV6_V6ONLY to false to enable dual-stack support
599            socket.set_only_v6(false)?;
600
601            // Bind the socket
602            socket.bind(&socket2::SockAddr::from(addr))?;
603
604            // Convert to a tokio UdpSocket
605            tokio::net::UdpSocket::from_std(socket.into())?
606        }
607    };
608
609    Ok(socket)
610}
611
/// Parses `addr_str` as either `IP:port` or a bare `IP`.
///
/// A bare IP gets port 0, so the OS picks a free port when the address is bound.
///
/// # Errors
/// Returns an [`std::io::ErrorKind::InvalidInput`] error when neither form parses.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    SocketAddr::from_str(addr_str)
        .or_else(|_| IpAddr::from_str(addr_str).map(|ip| SocketAddr::new(ip, 0)))
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}

634pub fn create_get_status_closure(
635    status_executable: &Option<String>,
636    status_file: &Option<String>,
637) -> Option<GetStatusClosure> {
638    let status_executable = status_executable.clone();
639    let status_file = status_file.clone();
640    Some(Box::new(move || {
641        let status_executable = status_executable.clone();
642        let status_file = status_file.clone();
643        Box::pin(async move {
644            let output = if let Some(status_executable) = &status_executable {
645                let Ok(output) = Command::new(status_executable).output().await else {
646                    return Default::default();
647                };
648                output.stdout
649            } else if let Some(status_file) = &status_file {
650                let Ok(mut file) = File::open(status_file).await else {
651                    return Default::default();
652                };
653                let mut contents = vec![];
654                if file.read_to_end(&mut contents).await.is_err() {
655                    return Default::default();
656                }
657                contents
658            } else {
659                return Default::default();
660            };
661            let output = String::from_utf8(output).unwrap_or_default();
662            match serde_json::from_str(&output) {
663                Ok(status) => status,
664                Err(e) => {
665                    error!("Failed to decode status with error: {e}");
666                    Default::default()
667                }
668            }
669        })
670    }))
671}