1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use futures_util::stream::{SplitSink, SplitStream};
8use futures_util::{SinkExt, StreamExt};
9use log::{debug, error, info};
10use serde::Deserialize;
11use tokio::fs::File;
12use tokio::io::AsyncReadExt;
13use tokio::net::{TcpStream, UdpSocket};
14use tokio::process::Command;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, sleep, timeout};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
19use uuid::Uuid;
20
21use crate::protocol::*;
22use crate::utils::AnyError;
23
24#[derive(Default, Deserialize, Clone)]
25#[serde(rename_all = "camelCase")]
26pub struct Status {
27 pub battery_percentage: Option<i32>,
28}
29
30pub type GetStatusClosure =
31 Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
32
33struct RelayInner {
34 me: Weak<Mutex<Self>>,
35 bind_address: String,
37 relay_id: Uuid,
38 streamer_url: String,
39 password: String,
40 name: String,
41 on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
42 get_status: Option<Arc<GetStatusClosure>>,
43 ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
44 started: bool,
45 connected: bool,
46 wrong_password: bool,
47 reconnect_on_tunnel_error: Arc<Mutex<bool>>,
48 start_on_reconnect_soon: Arc<Mutex<bool>>,
49}
50
51impl RelayInner {
52 fn new() -> Arc<Mutex<Self>> {
53 Arc::new_cyclic(|me| {
54 Mutex::new(Self {
55 me: me.clone(),
56 bind_address: Self::get_default_bind_address(),
57 relay_id: Uuid::new_v4(),
58 streamer_url: "".to_string(),
59 password: "".to_string(),
60 name: "".to_string(),
61 on_status_updated: None,
62 get_status: None,
63 ws_writer: None,
64 started: false,
65 connected: false,
66 wrong_password: false,
67 reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
68 start_on_reconnect_soon: Arc::new(Mutex::new(false)),
69 })
70 })
71 }
72
73 fn set_bind_address(&mut self, address: String) {
74 self.bind_address = address;
75 }
76
77 async fn setup<F>(
78 &mut self,
79 streamer_url: String,
80 password: String,
81 relay_id: Uuid,
82 name: String,
83 on_status_updated: F,
84 get_status: Option<GetStatusClosure>,
85 ) where
86 F: Fn(String) + Send + Sync + 'static,
87 {
88 self.on_status_updated = Some(Box::new(on_status_updated));
89 self.get_status = get_status.map(Arc::new);
90 self.relay_id = relay_id;
91 self.streamer_url = streamer_url;
92 self.password = password;
93 self.name = name;
94 info!("Binding to address: {:?}", self.bind_address);
95 }
96
97 fn is_started(&self) -> bool {
98 self.started
99 }
100
101 async fn start(&mut self) {
102 if !self.started {
103 self.started = true;
104 self.start_internal().await;
105 }
106 }
107
108 async fn stop(&mut self) {
109 if self.started {
110 self.started = false;
111 self.stop_internal().await;
112 }
113 }
114
115 fn get_default_bind_address() -> String {
116 let interfaces = pnet::datalink::interfaces();
118 let interface = interfaces.iter().find(|interface| {
119 interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
120 });
121
122 let ipv4_addresses: Vec<String> = interface
124 .expect("No available network interfaces found")
125 .ips
126 .iter()
127 .filter_map(|ip| {
128 let ip = ip.ip();
129 ip.is_ipv4().then(|| ip.to_string())
130 })
131 .collect();
132
133 ipv4_addresses
135 .first()
136 .cloned()
137 .unwrap_or("0.0.0.0:0".to_string())
138 }
139
140 async fn start_internal(&mut self) {
141 info!("Start internal");
142 if !self.started {
143 self.stop_internal().await;
144 return;
145 }
146
147 let request = match url::Url::parse(&self.streamer_url) {
148 Ok(url) => url,
149 Err(e) => {
150 error!("Failed to parse URL: {}", e);
151 return;
152 }
153 };
154
155 match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
156 Ok(Ok((ws_stream, _))) => {
157 info!("WebSocket connected");
158 let (writer, reader) = ws_stream.split();
159 self.ws_writer = Some(writer);
160 self.start_websocket_receiver(reader);
161 }
162 Ok(Err(error)) => {
163 error!("WebSocket connection failed immediately: {}", error);
165 self.reconnect_soon().await;
166 }
167 Err(_elapsed) => {
168 error!("WebSocket connection attempt timed out after 10 seconds");
170 self.reconnect_soon().await;
171 }
172 }
173 }
174
175 fn start_websocket_receiver(
176 &mut self,
177 mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
178 ) {
179 let relay = self.me.clone();
181
182 tokio::spawn(async move {
183 let Some(relay_arc) = relay.upgrade() else {
184 return;
185 };
186
187 while let Some(result) = reader.next().await {
188 let mut relay = relay_arc.lock().await;
189 match result {
190 Ok(message) => match message {
191 Message::Text(text) => {
192 match serde_json::from_str::<MessageToRelay>(&text) {
193 Ok(message) => {
194 relay.handle_message(message).await.ok();
195 }
196 _ => {
197 error!("Failed to deserialize message: {}", text);
198 }
199 }
200 }
201 Message::Binary(data) => {
202 debug!("Received binary message of length: {}", data.len());
203 }
204 Message::Ping(data) => {
205 relay.send_message(Message::Pong(data)).await.ok();
206 }
207 Message::Pong(_) => {
208 debug!("Received pong message");
209 }
210 Message::Close(frame) => {
211 info!("Received close message: {:?}", frame);
212 relay.reconnect_soon().await;
213 break;
214 }
215 Message::Frame(_) => {
216 unreachable!("This is never used")
217 }
218 },
219 Err(e) => {
220 error!("Error processing message: {}", e);
221 if e.to_string()
223 .contains("Connection reset without closing handshake")
224 {
225 relay.reconnect_soon().await;
226 }
227 break;
228 }
229 }
230 }
231 });
232 }
233
234 async fn stop_internal(&mut self) {
235 info!("Stop internal");
236 if let Some(mut ws_writer) = self.ws_writer.take() {
237 match ws_writer.close().await {
238 Err(e) => {
239 error!("Error closing WebSocket: {}", e);
240 }
241 _ => {
242 info!("WebSocket closed successfully");
243 }
244 }
245 }
246 self.connected = false;
247 self.wrong_password = false;
248 *self.reconnect_on_tunnel_error.lock().await = false;
249 *self.start_on_reconnect_soon.lock().await = false;
250 self.update_status();
251 }
252
253 fn update_status(&self) {
254 let Some(on_status_updated) = &self.on_status_updated else {
255 return;
256 };
257 let status = if self.connected {
258 "Connected to streamer"
259 } else if self.wrong_password {
260 "Wrong password"
261 } else if self.started {
262 "Connecting to streamer"
263 } else {
264 "Disconnected from streamer"
265 };
266 on_status_updated(status.to_string());
267 }
268
269 async fn reconnect_soon(&mut self) {
270 self.stop_internal().await;
271 *self.start_on_reconnect_soon.lock().await = false;
272 let start_on_reconnect_soon = Arc::new(Mutex::new(true));
273 self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
274 self.start_soon(start_on_reconnect_soon);
275 }
276
277 fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
278 let relay = self.me.clone();
279
280 tokio::spawn(async move {
281 info!("Reconnecting in 5 seconds...");
282 sleep(Duration::from_secs(5)).await;
283
284 if *start_on_reconnect_soon.lock().await {
285 info!("Reconnecting...");
286 if let Some(relay) = relay.upgrade() {
287 relay.lock().await.start_internal().await;
288 }
289 }
290 });
291 }
292
293 async fn handle_message(&mut self, message: MessageToRelay) -> Result<(), AnyError> {
294 match message {
295 MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
296 MessageToRelay::Identified(identified) => {
297 self.handle_message_identified(identified).await
298 }
299 MessageToRelay::Request(request) => self.handle_message_request(request).await,
300 }
301 }
302
303 async fn handle_message_hello(&mut self, hello: Hello) -> Result<(), AnyError> {
304 let authentication = calculate_authentication(
305 &self.password,
306 &hello.authentication.salt,
307 &hello.authentication.challenge,
308 );
309 let identify = Identify {
310 id: self.relay_id,
311 name: self.name.clone(),
312 authentication,
313 };
314 self.send(MessageToStreamer::Identify(identify)).await
315 }
316
317 async fn handle_message_identified(&mut self, identified: Identified) -> Result<(), AnyError> {
318 match identified.result {
319 MoblinkResult::Ok(_) => {
320 self.connected = true;
321 }
322 MoblinkResult::WrongPassword(_) => {
323 self.wrong_password = true;
324 }
325 }
326 self.update_status();
327 Ok(())
328 }
329
330 async fn handle_message_request(&mut self, request: MessageRequest) -> Result<(), AnyError> {
331 match &request.data {
332 MessageRequestData::StartTunnel(start_tunnel) => {
333 self.handle_message_request_start_tunnel(&request, start_tunnel)
334 .await
335 }
336 MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
337 }
338 }
339
340 async fn handle_message_request_start_tunnel(
341 &mut self,
342 request: &MessageRequest,
343 start_tunnel: &StartTunnelRequest,
344 ) -> Result<(), AnyError> {
345 let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
347 let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
348
349 info!(
350 "Binding streamer socket on: {}, destination socket on: {}",
351 local_bind_addr_for_streamer, local_bind_addr_for_destination
352 );
353 let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
356 let streamer_port = streamer_socket.local_addr()?.port();
357 info!("Listening on UDP port: {}", streamer_port);
358 let streamer_socket = Arc::new(streamer_socket);
359
360 let data = ResponseData::StartTunnel(StartTunnelResponseData {
362 port: streamer_port,
363 });
364 let response = request.to_ok_response(data);
365 self.send(MessageToStreamer::Response(response)).await?;
366
367 let destination_socket =
370 create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
371
372 info!(
373 "Bound destination socket to: {:?}",
374 destination_socket.local_addr()?
375 );
376 let destination_socket = Arc::new(destination_socket);
377
378 let normalized_ip = match IpAddr::from_str(&start_tunnel.address)? {
379 IpAddr::V4(v4) => IpAddr::V4(v4),
380 IpAddr::V6(v6) => {
381 if let Some(mapped_v4) = v6.to_ipv4() {
383 IpAddr::V4(mapped_v4)
384 } else {
385 IpAddr::V6(v6)
387 }
388 }
389 };
390 let destination_addr = SocketAddr::new(normalized_ip, start_tunnel.port);
391 info!("Destination address resolved: {}", destination_addr);
392
393 let streamer_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
395
396 let relay_to_destination = start_relay_from_streamer_to_destination(
397 streamer_socket.clone(),
398 destination_socket.clone(),
399 streamer_addr.clone(),
400 destination_addr,
401 );
402 let relay_to_streamer = start_relay_from_destination_to_streamer(
403 streamer_socket,
404 destination_socket,
405 streamer_addr,
406 );
407
408 *self.reconnect_on_tunnel_error.lock().await = false;
409 let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
410 self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
411 let relay = self.me.clone();
412
413 tokio::spawn(async move {
414 let Some(relay) = relay.upgrade() else {
415 return;
416 };
417
418 tokio::select! {
421 res = relay_to_destination => {
422 if let Err(e) = res {
423 error!("relay_to_destination task failed: {}", e);
424 }
425 }
426 res = relay_to_streamer => {
427 if let Err(e) = res {
428 error!("relay_to_streamer task failed: {}", e);
429 }
430 }
431 }
432
433 if *reconnect_on_tunnel_error.lock().await {
434 relay.lock().await.reconnect_soon().await;
435 } else {
436 info!("Not reconnecting after tunnel error");
437 }
438 });
439
440 Ok(())
441 }
442
443 async fn handle_message_request_status(
444 &mut self,
445 request: MessageRequest,
446 ) -> Result<(), AnyError> {
447 let mut battery_percentage = None;
448 if let Some(get_status) = self.get_status.as_ref() {
449 battery_percentage = get_status().await.battery_percentage;
450 }
451 let data = ResponseData::Status(StatusResponseData { battery_percentage });
452 let response = request.to_ok_response(data);
453 self.send(MessageToStreamer::Response(response)).await
454 }
455
456 async fn send(&mut self, message: MessageToStreamer) -> Result<(), AnyError> {
457 let text = serde_json::to_string(&message)?;
458 self.send_message(Message::Text(text.into())).await
459 }
460
461 async fn send_message(&mut self, message: Message) -> Result<(), AnyError> {
462 let Some(writer) = self.ws_writer.as_mut() else {
463 return Err("No websocket writer".into());
464 };
465 writer.send(message).await?;
466 Ok(())
467 }
468}
469
470pub struct Relay {
471 inner: Arc<Mutex<RelayInner>>,
472}
473
474impl Default for Relay {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480impl Relay {
481 pub fn new() -> Self {
482 Self {
483 inner: RelayInner::new(),
484 }
485 }
486
487 pub async fn set_bind_address(&self, address: String) {
488 self.inner.lock().await.set_bind_address(address);
489 }
490
491 pub async fn setup<F>(
492 &self,
493 streamer_url: String,
494 password: String,
495 relay_id: Uuid,
496 name: String,
497 on_status_updated: F,
498 get_status: Option<GetStatusClosure>,
499 ) where
500 F: Fn(String) + Send + Sync + 'static,
501 {
502 self.inner
503 .lock()
504 .await
505 .setup(
506 streamer_url,
507 password,
508 relay_id,
509 name,
510 on_status_updated,
511 get_status,
512 )
513 .await;
514 }
515
516 pub async fn is_started(&self) -> bool {
517 self.inner.lock().await.is_started()
518 }
519
520 pub async fn start(&self) {
521 self.inner.lock().await.start().await;
522 }
523
524 pub async fn stop(&self) {
525 self.inner.lock().await.stop().await;
526 }
527}
528
529fn start_relay_from_streamer_to_destination(
530 streamer_socket: Arc<UdpSocket>,
531 destination_socket: Arc<UdpSocket>,
532 streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
533 destination_addr: SocketAddr,
534) -> tokio::task::JoinHandle<()> {
535 tokio::spawn(async move {
536 debug!("(relay_to_destination) Task started");
537 loop {
538 let mut buf = [0; 2048];
539 let (size, remote_addr) =
540 match timeout(Duration::from_secs(30), streamer_socket.recv_from(&mut buf)).await {
541 Ok(result) => match result {
542 Ok((size, addr)) => (size, addr),
543 Err(e) => {
544 error!("(relay_to_destination) Error receiving from server: {}", e);
545 continue;
546 }
547 },
548 Err(e) => {
549 error!(
550 "(relay_to_destination) Timeout receiving from server: {}",
551 e
552 );
553 break;
554 }
555 };
556
557 debug!(
558 "(relay_to_destination) Received {} bytes from server: {}",
559 size, remote_addr
560 );
561
562 match destination_socket
564 .send_to(&buf[..size], &destination_addr)
565 .await
566 {
567 Ok(bytes_sent) => {
568 debug!(
569 "(relay_to_destination) Sent {} bytes to destination",
570 bytes_sent
571 )
572 }
573 Err(e) => {
574 error!(
575 "(relay_to_destination) Failed to send to destination: {}",
576 e
577 );
578 break;
579 }
580 }
581
582 let mut streamer_addr_lock = streamer_addr.lock().await;
584 if streamer_addr_lock.is_none() {
585 *streamer_addr_lock = Some(remote_addr);
586 debug!(
587 "(relay_to_destination) Server remote address set to: {}",
588 remote_addr
589 );
590 }
591 }
592 info!("(relay_to_destination) Task exiting");
593 })
594}
595
596fn start_relay_from_destination_to_streamer(
597 streamer_socket: Arc<UdpSocket>,
598 destination_socket: Arc<UdpSocket>,
599 streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
600) -> tokio::task::JoinHandle<()> {
601 tokio::spawn(async move {
602 debug!("(relay_to_streamer) Task started");
603 loop {
604 let mut buf = [0; 2048];
605 let (size, remote_addr) = match timeout(
606 Duration::from_secs(30),
607 destination_socket.recv_from(&mut buf),
608 )
609 .await
610 {
611 Ok(result) => match result {
612 Ok((size, addr)) => (size, addr),
613 Err(e) => {
614 error!(
615 "(relay_to_streamer) Error receiving from destination: {}",
616 e
617 );
618 continue;
619 }
620 },
621 Err(e) => {
622 error!(
623 "(relay_to_streamer) Timeout receiving from destination: {}",
624 e
625 );
626 break;
627 }
628 };
629
630 debug!(
631 "(relay_to_streamer) Received {} bytes from destination: {}",
632 size, remote_addr
633 );
634 let streamer_addr_lock = streamer_addr.lock().await;
636 match *streamer_addr_lock {
637 Some(streamer_addr) => {
638 match streamer_socket.send_to(&buf[..size], &streamer_addr).await {
639 Ok(bytes_sent) => {
640 debug!("(relay_to_streamer) Sent {} bytes to server", bytes_sent)
641 }
642 Err(e) => {
643 error!("(relay_to_streamer) Failed to send to server: {}", e);
644 break;
645 }
646 }
647 }
648 None => {
649 error!("(relay_to_streamer) Server address not set, cannot forward packet");
650 }
651 }
652 }
653 info!("(relay_to_streamer) Task exiting");
654 })
655}
656
657async fn create_dual_stack_udp_socket(
658 addr: SocketAddr,
659) -> Result<tokio::net::UdpSocket, std::io::Error> {
660 let socket = match addr.is_ipv4() {
661 true => {
662 tokio::net::UdpSocket::bind(addr).await?
664 }
665 false => {
666 let socket = socket2::Socket::new(
668 socket2::Domain::IPV6,
669 socket2::Type::DGRAM,
670 Some(socket2::Protocol::UDP),
671 )?;
672
673 socket.set_only_v6(false)?;
675
676 socket.bind(&socket2::SockAddr::from(addr))?;
678
679 tokio::net::UdpSocket::from_std(socket.into())?
681 }
682 };
683
684 Ok(socket)
685}
686
/// Parses either an `IP:port` string or a bare `IP` (which gets port 0).
///
/// # Errors
/// Returns an `InvalidInput` I/O error when `addr_str` is neither form.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    SocketAddr::from_str(addr_str)
        .or_else(|_| IpAddr::from_str(addr_str).map(|ip| SocketAddr::new(ip, 0)))
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}
708
709pub fn create_get_status_closure(
710 status_executable: &Option<String>,
711 status_file: &Option<String>,
712) -> Option<GetStatusClosure> {
713 let status_executable = status_executable.clone();
714 let status_file = status_file.clone();
715 Some(Box::new(move || {
716 let status_executable = status_executable.clone();
717 let status_file = status_file.clone();
718 Box::pin(async move {
719 let output = if let Some(status_executable) = &status_executable {
720 let Ok(output) = Command::new(status_executable).output().await else {
721 return Default::default();
722 };
723 output.stdout
724 } else if let Some(status_file) = &status_file {
725 let Ok(mut file) = File::open(status_file).await else {
726 return Default::default();
727 };
728 let mut contents = vec![];
729 if file.read_to_end(&mut contents).await.is_err() {
730 return Default::default();
731 }
732 contents
733 } else {
734 return Default::default();
735 };
736 let output = String::from_utf8(output).unwrap_or_default();
737 match serde_json::from_str(&output) {
738 Ok(status) => status,
739 Err(e) => {
740 error!("Failed to decode status with error: {e}");
741 Default::default()
742 }
743 }
744 })
745 }))
746}