moblink_rust/
relay.rs

1use serde::Deserialize;
2use std::future::Future;
3use std::net::{IpAddr, SocketAddr};
4use std::pin::Pin;
5use std::str::FromStr;
6use std::sync::{Arc, Weak};
7
8use base64::engine::general_purpose;
9use base64::Engine as _;
10use futures_util::stream::{SplitSink, SplitStream};
11use futures_util::{SinkExt, StreamExt};
12use log::{debug, error, info};
13use sha2::{Digest, Sha256};
14use tokio::net::{TcpStream, UdpSocket};
15use tokio::sync::Mutex;
16use tokio::time::{sleep, timeout, Duration};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
19
20use crate::protocol::*;
21
22#[derive(Default, Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub struct Status {
25    pub battery_percentage: Option<i32>,
26}
27
28pub type GetStatusClosure =
29    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
30
31pub struct Relay {
32    me: Weak<Mutex<Self>>,
33    /// Store a local IP address  for binding UDP sockets
34    bind_address: String,
35    relay_id: String,
36    streamer_url: String,
37    password: String,
38    name: String,
39    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
40    get_status: Option<Arc<GetStatusClosure>>,
41    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
42    started: bool,
43    connected: bool,
44    wrong_password: bool,
45    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
46    start_on_reconnect_soon: Arc<Mutex<bool>>,
47}
48
49impl Relay {
50    pub fn new() -> Arc<Mutex<Self>> {
51        Arc::new_cyclic(|me| {
52            Mutex::new(Self {
53                me: me.clone(),
54                bind_address: Self::get_default_bind_address(),
55                relay_id: "".to_string(),
56                streamer_url: "".to_string(),
57                password: "".to_string(),
58                name: "".to_string(),
59                on_status_updated: None,
60                get_status: None,
61                ws_writer: None,
62                started: false,
63                connected: false,
64                wrong_password: false,
65                reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
66                start_on_reconnect_soon: Arc::new(Mutex::new(false)),
67            })
68        })
69    }
70
71    pub fn set_bind_address(&mut self, address: String) {
72        self.bind_address = address;
73    }
74
75    pub async fn setup<F>(
76        &mut self,
77        streamer_url: String,
78        password: String,
79        relay_id: String,
80        name: String,
81        on_status_updated: F,
82        get_status: GetStatusClosure,
83    ) where
84        F: Fn(String) + Send + Sync + 'static,
85    {
86        self.on_status_updated = Some(Box::new(on_status_updated));
87        self.get_status = Some(Arc::new(get_status));
88        self.relay_id = relay_id;
89        self.streamer_url = streamer_url;
90        self.password = password;
91        self.name = name;
92        info!("Binding to address: {:?}", self.bind_address);
93    }
94
95    pub fn is_started(&self) -> bool {
96        self.started
97    }
98
99    pub async fn start(&mut self) {
100        if !self.started {
101            self.started = true;
102            self.start_internal().await;
103        }
104    }
105
106    pub async fn stop(&mut self) {
107        if self.started {
108            self.started = false;
109            self.stop_internal().await;
110        }
111    }
112
113    fn get_default_bind_address() -> String {
114        // Get main network interface
115        let interfaces = pnet::datalink::interfaces();
116        let interface = interfaces.iter().find(|interface| {
117            interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
118        });
119
120        // Only ipv4 addresses are supported
121        let ipv4_addresses: Vec<String> = interface
122            .expect("No available network interfaces found")
123            .ips
124            .iter()
125            .filter_map(|ip| {
126                let ip = ip.ip();
127                ip.is_ipv4().then(|| ip.to_string())
128            })
129            .collect();
130
131        // Return the first address
132        ipv4_addresses
133            .first()
134            .cloned()
135            .unwrap_or("0.0.0.0:0".to_string())
136    }
137
138    async fn start_internal(&mut self) {
139        info!("Start internal");
140        if !self.started {
141            self.stop_internal().await;
142            return;
143        }
144
145        let request = match url::Url::parse(&self.streamer_url) {
146            Ok(url) => url,
147            Err(e) => {
148                error!("Failed to parse URL: {}", e);
149                return;
150            }
151        };
152
153        match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
154            Ok(Ok((ws_stream, _))) => {
155                info!("WebSocket connected");
156                let (writer, reader) = ws_stream.split();
157                self.ws_writer = Some(writer);
158                self.start_websocket_receiver(reader);
159            }
160            Ok(Err(error)) => {
161                // This means the future completed but the connection failed
162                error!("WebSocket connection failed immediately: {}", error);
163                self.reconnect_soon().await;
164            }
165            Err(_elapsed) => {
166                // This means the future did NOT complete within 10 seconds
167                error!("WebSocket connection attempt timed out after 10 seconds");
168                self.reconnect_soon().await;
169            }
170        }
171    }
172
173    fn start_websocket_receiver(
174        &mut self,
175        mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
176    ) {
177        // Task to process messages received from the channel.
178        let relay = self.me.clone();
179
180        tokio::spawn(async move {
181            let Some(relay_arc) = relay.upgrade() else {
182                return;
183            };
184
185            while let Some(result) = reader.next().await {
186                let mut relay = relay_arc.lock().await;
187                match result {
188                    Ok(message) => match message {
189                        Message::Text(text) => {
190                            if let Ok(message) = serde_json::from_str::<MessageToRelay>(&text) {
191                                relay.handle_message(message).await.ok();
192                            } else {
193                                error!("Failed to deserialize message: {}", text);
194                            }
195                        }
196                        Message::Binary(data) => {
197                            debug!("Received binary message of length: {}", data.len());
198                        }
199                        Message::Ping(_) => {
200                            debug!("Received ping message");
201                        }
202                        Message::Pong(_) => {
203                            debug!("Received pong message");
204                        }
205                        Message::Close(frame) => {
206                            info!("Received close message: {:?}", frame);
207                            relay.reconnect_soon().await;
208                            break;
209                        }
210                        Message::Frame(_) => {
211                            unreachable!("This is never used")
212                        }
213                    },
214                    Err(e) => {
215                        error!("Error processing message: {}", e);
216                        // TODO: There has to be a better way to handle this
217                        if e.to_string()
218                            .contains("Connection reset without closing handshake")
219                        {
220                            relay.reconnect_soon().await;
221                        }
222                        break;
223                    }
224                }
225            }
226        });
227    }
228
229    async fn stop_internal(&mut self) {
230        info!("Stop internal");
231        if let Some(mut ws_writer) = self.ws_writer.take() {
232            if let Err(e) = ws_writer.close().await {
233                error!("Error closing WebSocket: {}", e);
234            } else {
235                info!("WebSocket closed successfully");
236            }
237        }
238        self.connected = false;
239        self.wrong_password = false;
240        *self.reconnect_on_tunnel_error.lock().await = false;
241        *self.start_on_reconnect_soon.lock().await = false;
242        self.update_status();
243    }
244
245    fn update_status(&self) {
246        let Some(on_status_updated) = &self.on_status_updated else {
247            return;
248        };
249        let status = if self.connected {
250            "Connected to streamer"
251        } else if self.wrong_password {
252            "Wrong password"
253        } else if self.started {
254            "Connecting to streamer"
255        } else {
256            "Disconnected from streamer"
257        };
258        on_status_updated(status.to_string());
259    }
260
261    async fn reconnect_soon(&mut self) {
262        self.stop_internal().await;
263        *self.start_on_reconnect_soon.lock().await = false;
264        let start_on_reconnect_soon = Arc::new(Mutex::new(true));
265        self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
266        self.start_soon(start_on_reconnect_soon);
267    }
268
269    fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
270        let relay = self.me.clone();
271
272        tokio::spawn(async move {
273            info!("Reconnecting in 5 seconds...");
274            sleep(Duration::from_secs(5)).await;
275
276            if *start_on_reconnect_soon.lock().await {
277                info!("Reconnecting...");
278                if let Some(relay) = relay.upgrade() {
279                    relay.lock().await.start_internal().await;
280                }
281            }
282        });
283    }
284
285    async fn handle_message(
286        &mut self,
287        message: MessageToRelay,
288    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
289        match message {
290            MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
291            MessageToRelay::Identified(identified) => {
292                self.handle_message_identified(identified).await
293            }
294            MessageToRelay::Request(request) => self.handle_message_request(request).await,
295        }
296    }
297
298    async fn handle_message_hello(
299        &mut self,
300        hello: Hello,
301    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
302        let authentication = calculate_authentication(
303            &self.password,
304            &hello.authentication.salt,
305            &hello.authentication.challenge,
306        );
307        let identify = Identify {
308            id: self.relay_id.clone(),
309            name: self.name.clone(),
310            authentication,
311        };
312        self.send(MessageToStreamer::Identify(identify)).await
313    }
314
315    async fn handle_message_identified(
316        &mut self,
317        identified: Identified,
318    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
319        match identified.result {
320            MoblinkResult::Ok(_) => {
321                self.connected = true;
322            }
323            MoblinkResult::WrongPassword(_) => {
324                self.wrong_password = true;
325            }
326        }
327        self.update_status();
328        Ok(())
329    }
330
331    async fn handle_message_request(
332        &mut self,
333        request: MessageRequest,
334    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
335        match &request.data {
336            MessageRequestData::StartTunnel(start_tunnel) => {
337                self.handle_message_request_start_tunnel(&request, start_tunnel)
338                    .await
339            }
340            MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
341        }
342    }
343
344    async fn handle_message_request_start_tunnel(
345        &mut self,
346        request: &MessageRequest,
347        start_tunnel: &StartTunnelRequest,
348    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
349        // Pick bind addresses from the relay
350        let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
351        let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
352
353        info!(
354            "Binding streamer socket on: {}, destination socket on: {}",
355            local_bind_addr_for_streamer, local_bind_addr_for_destination
356        );
357        // Create a UDP socket bound for receiving packets from the server.
358        // Use dual-stack socket creation.
359        let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
360        let streamer_port = streamer_socket.local_addr()?.port();
361        info!("Listening on UDP port: {}", streamer_port);
362        let streamer_socket = Arc::new(streamer_socket);
363
364        // Inform the server about the chosen port.
365        let data = ResponseData::StartTunnel(StartTunnelResponseData {
366            port: streamer_port,
367        });
368        let response = request.to_ok_response(data);
369        self.send(MessageToStreamer::Response(response)).await?;
370
371        // Create a new UDP socket for communication with the destination.
372        // Use dual-stack socket creation.
373        let destination_socket =
374            create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
375
376        info!(
377            "Bound destination socket to: {:?}",
378            destination_socket.local_addr()?
379        );
380        let destination_socket = Arc::new(destination_socket);
381
382        let normalized_ip = match IpAddr::from_str(&start_tunnel.address)? {
383            IpAddr::V4(v4) => IpAddr::V4(v4),
384            IpAddr::V6(v6) => {
385                // If it’s an IPv4-mapped IPv6 like ::ffff:x.x.x.x, convert to real IPv4
386                if let Some(mapped_v4) = v6.to_ipv4() {
387                    IpAddr::V4(mapped_v4)
388                } else {
389                    // Otherwise, keep it as IPv6
390                    IpAddr::V6(v6)
391                }
392            }
393        };
394        let destination_addr = SocketAddr::new(normalized_ip, start_tunnel.port);
395        info!("Destination address resolved: {}", destination_addr);
396
397        // Use an Arc<Mutex> to share the server_remote_addr between tasks.
398        let streamer_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
399
400        let relay_to_destination = start_relay_from_streamer_to_destination(
401            streamer_socket.clone(),
402            destination_socket.clone(),
403            streamer_addr.clone(),
404            destination_addr,
405        );
406        let relay_to_streamer = start_relay_from_destination_to_streamer(
407            streamer_socket,
408            destination_socket,
409            streamer_addr,
410        );
411
412        *self.reconnect_on_tunnel_error.lock().await = false;
413        let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
414        self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
415        let relay = self.me.clone();
416
417        tokio::spawn(async move {
418            let Some(relay) = relay.upgrade() else {
419                return;
420            };
421
422            // Wait for relay tasks to complete (they won't unless an error occurs or the
423            // socket is closed).
424            tokio::select! {
425                res = relay_to_destination => {
426                    if let Err(e) = res {
427                        error!("relay_to_destination task failed: {}", e);
428                    }
429                }
430                res = relay_to_streamer => {
431                    if let Err(e) = res {
432                        error!("relay_to_streamer task failed: {}", e);
433                    }
434                }
435            }
436
437            if *reconnect_on_tunnel_error.lock().await {
438                relay.lock().await.reconnect_soon().await;
439            } else {
440                info!("Not reconnecting after tunnel error");
441            }
442        });
443
444        Ok(())
445    }
446
447    async fn handle_message_request_status(
448        &mut self,
449        request: MessageRequest,
450    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
451        let Some(get_status) = self.get_status.as_ref() else {
452            error!("get_battery_percentage is not set");
453            return Err("get_battery_percentage function not set".into());
454        };
455        let status = get_status().await;
456        let data = ResponseData::Status(StatusResponseData {
457            battery_percentage: status.battery_percentage,
458        });
459        let response = request.to_ok_response(data);
460        self.send(MessageToStreamer::Response(response)).await
461    }
462
463    async fn send(
464        &mut self,
465        message: MessageToStreamer,
466    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
467        let text = serde_json::to_string(&message)?;
468        let Some(writer) = self.ws_writer.as_mut() else {
469            return Err("No websocket writer".into());
470        };
471        writer.send(Message::Text(text.into())).await?;
472        Ok(())
473    }
474}
475
476fn start_relay_from_streamer_to_destination(
477    streamer_socket: Arc<UdpSocket>,
478    destination_socket: Arc<UdpSocket>,
479    streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
480    destination_addr: SocketAddr,
481) -> tokio::task::JoinHandle<()> {
482    tokio::spawn(async move {
483        debug!("(relay_to_destination) Task started");
484        loop {
485            let mut buf = [0; 2048];
486            let (size, remote_addr) =
487                match timeout(Duration::from_secs(30), streamer_socket.recv_from(&mut buf)).await {
488                    Ok(result) => match result {
489                        Ok((size, addr)) => (size, addr),
490                        Err(e) => {
491                            error!("(relay_to_destination) Error receiving from server: {}", e);
492                            continue;
493                        }
494                    },
495                    Err(e) => {
496                        error!(
497                            "(relay_to_destination) Timeout receiving from server: {}",
498                            e
499                        );
500                        break;
501                    }
502                };
503
504            debug!(
505                "(relay_to_destination) Received {} bytes from server: {}",
506                size, remote_addr
507            );
508
509            // Forward to destination.
510            match destination_socket
511                .send_to(&buf[..size], &destination_addr)
512                .await
513            {
514                Ok(bytes_sent) => {
515                    debug!(
516                        "(relay_to_destination) Sent {} bytes to destination",
517                        bytes_sent
518                    )
519                }
520                Err(e) => {
521                    error!(
522                        "(relay_to_destination) Failed to send to destination: {}",
523                        e
524                    );
525                    break;
526                }
527            }
528
529            // Set the remote address if it hasn't been set yet.
530            let mut streamer_addr_lock = streamer_addr.lock().await;
531            if streamer_addr_lock.is_none() {
532                *streamer_addr_lock = Some(remote_addr);
533                debug!(
534                    "(relay_to_destination) Server remote address set to: {}",
535                    remote_addr
536                );
537            }
538        }
539        info!("(relay_to_destination) Task exiting");
540    })
541}
542
543fn start_relay_from_destination_to_streamer(
544    streamer_socket: Arc<UdpSocket>,
545    destination_socket: Arc<UdpSocket>,
546    streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
547) -> tokio::task::JoinHandle<()> {
548    tokio::spawn(async move {
549        debug!("(relay_to_streamer) Task started");
550        loop {
551            let mut buf = [0; 2048];
552            let (size, remote_addr) = match timeout(
553                Duration::from_secs(30),
554                destination_socket.recv_from(&mut buf),
555            )
556            .await
557            {
558                Ok(result) => match result {
559                    Ok((size, addr)) => (size, addr),
560                    Err(e) => {
561                        error!(
562                            "(relay_to_streamer) Error receiving from destination: {}",
563                            e
564                        );
565                        continue;
566                    }
567                },
568                Err(e) => {
569                    error!(
570                        "(relay_to_streamer) Timeout receiving from destination: {}",
571                        e
572                    );
573                    break;
574                }
575            };
576
577            debug!(
578                "(relay_to_streamer) Received {} bytes from destination: {}",
579                size, remote_addr
580            );
581            // Forward to server.
582            let streamer_addr_lock = streamer_addr.lock().await;
583            match *streamer_addr_lock {
584                Some(streamer_addr) => {
585                    match streamer_socket.send_to(&buf[..size], &streamer_addr).await {
586                        Ok(bytes_sent) => {
587                            debug!("(relay_to_streamer) Sent {} bytes to server", bytes_sent)
588                        }
589                        Err(e) => {
590                            error!("(relay_to_streamer) Failed to send to server: {}", e);
591                            break;
592                        }
593                    }
594                }
595                None => {
596                    error!("(relay_to_streamer) Server address not set, cannot forward packet");
597                }
598            }
599        }
600        info!("(relay_to_streamer) Task exiting");
601    })
602}
603
604fn calculate_authentication(password: &str, salt: &str, challenge: &str) -> String {
605    let mut hasher = Sha256::new();
606    hasher.update(format!("{}{}", password, salt).as_bytes());
607    let hash1 = hasher.finalize_reset();
608    hasher.update(format!("{}{}", general_purpose::STANDARD.encode(hash1), challenge).as_bytes());
609    let hash2 = hasher.finalize();
610    general_purpose::STANDARD.encode(hash2)
611}
612
613async fn create_dual_stack_udp_socket(
614    addr: SocketAddr,
615) -> Result<tokio::net::UdpSocket, std::io::Error> {
616    let socket = match addr.is_ipv4() {
617        true => {
618            // Create an IPv4 socket
619            tokio::net::UdpSocket::bind(addr).await?
620        }
621        false => {
622            // Create a dual-stack socket (supporting both IPv4 and IPv6)
623            let socket = socket2::Socket::new(
624                socket2::Domain::IPV6,
625                socket2::Type::DGRAM,
626                Some(socket2::Protocol::UDP),
627            )?;
628
629            // Set IPV6_V6ONLY to false to enable dual-stack support
630            socket.set_only_v6(false)?;
631
632            // Bind the socket
633            socket.bind(&socket2::SockAddr::from(addr))?;
634
635            // Convert to a tokio UdpSocket
636            tokio::net::UdpSocket::from_std(socket.into())?
637        }
638    };
639
640    Ok(socket)
641}
642
/// Parses `addr_str` as a socket address, accepting either `IP:port` or a
/// bare `IP`.
///
/// A bare IP address gets port 0, which lets the OS assign an available port
/// when the address is used for binding.
///
/// # Errors
///
/// Returns an `InvalidInput` error if the string is neither form.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    addr_str
        .parse::<SocketAddr>()
        .or_else(|_| {
            addr_str
                .parse::<IpAddr>()
                .map(|ip_addr| SocketAddr::new(ip_addr, 0))
        })
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}