1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3use std::pin::Pin;
4use std::str::FromStr;
5use std::sync::{Arc, Weak};
6
7use futures_util::stream::{SplitSink, SplitStream};
8use futures_util::{SinkExt, StreamExt};
9use log::{debug, error, info};
10use serde::Deserialize;
11use tokio::fs::File;
12use tokio::io::AsyncReadExt;
13use tokio::net::{TcpStream, UdpSocket};
14use tokio::process::Command;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, sleep, timeout};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
19use uuid::Uuid;
20
21use crate::protocol::*;
22use crate::utils::{AnyError, resolve_host};
23
24#[derive(Default, Deserialize, Clone)]
25#[serde(rename_all = "camelCase")]
26pub struct Status {
27 pub battery_percentage: Option<i32>,
28}
29
30pub type GetStatusClosure =
31 Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
32
/// Mutable state of a relay: its WebSocket connection to the streamer and the UDP tunnel it
/// forwards. Always accessed through `Arc<Mutex<RelayInner>>` (created by `RelayInner::new`).
struct RelayInner {
    /// Weak reference to this relay's own `Arc`; handed to spawned tasks so they do not keep the
    /// relay alive.
    me: Weak<Mutex<Self>>,
    /// Local address (`IP` or `IP:port`) that the destination-side UDP socket binds to.
    bind_address: String,
    /// Identifier sent to the streamer in `Identify`.
    relay_id: Uuid,
    /// URL of the streamer's WebSocket endpoint.
    streamer_url: String,
    /// Password used to answer the streamer's authentication challenge.
    password: String,
    /// Relay name sent to the streamer in `Identify`.
    name: String,
    /// Called with a status string ("Connected to streamer", ...) when the status may have changed.
    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
    /// Provides the battery level reported for status requests.
    get_status: Option<Arc<GetStatusClosure>>,
    /// Write half of the current WebSocket connection; `None` when not connected.
    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    /// Set by `start`, cleared by `stop`.
    started: bool,
    /// Set once the streamer accepted our `Identify`.
    connected: bool,
    /// Set when the streamer reported a wrong password.
    wrong_password: bool,
    /// Shared with the current tunnel's destination-to-streamer task; if still `true` when that
    /// task fails, it triggers a reconnect. Set to `false` when the tunnel is torn down or replaced.
    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
    /// Shared with the pending delayed-reconnect task; set to `false` to cancel that reconnect.
    start_on_reconnect_soon: Arc<Mutex<bool>>,
    /// Task forwarding packets from the streamer to the destination for the current tunnel.
    relay_to_destination: Option<tokio::task::JoinHandle<Result<(), AnyError>>>,
}
51
52impl RelayInner {
53 fn new() -> Arc<Mutex<Self>> {
54 Arc::new_cyclic(|me| {
55 Mutex::new(Self {
56 me: me.clone(),
57 bind_address: Self::get_default_bind_address(),
58 relay_id: Uuid::new_v4(),
59 streamer_url: "".to_string(),
60 password: "".to_string(),
61 name: "".to_string(),
62 on_status_updated: None,
63 get_status: None,
64 ws_writer: None,
65 started: false,
66 connected: false,
67 wrong_password: false,
68 reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
69 start_on_reconnect_soon: Arc::new(Mutex::new(false)),
70 relay_to_destination: None,
71 })
72 })
73 }
74
75 fn set_bind_address(&mut self, address: String) {
76 self.bind_address = address;
77 }
78
79 async fn setup<F>(
80 &mut self,
81 streamer_url: String,
82 password: String,
83 relay_id: Uuid,
84 name: String,
85 on_status_updated: F,
86 get_status: Option<GetStatusClosure>,
87 ) where
88 F: Fn(String) + Send + Sync + 'static,
89 {
90 self.on_status_updated = Some(Box::new(on_status_updated));
91 self.get_status = get_status.map(Arc::new);
92 self.relay_id = relay_id;
93 self.streamer_url = streamer_url;
94 self.password = password;
95 self.name = name;
96 }
97
98 fn is_started(&self) -> bool {
99 self.started
100 }
101
102 async fn start(&mut self) {
103 if !self.started {
104 self.started = true;
105 self.start_internal().await;
106 }
107 }
108
109 async fn stop(&mut self) {
110 if self.started {
111 self.started = false;
112 self.stop_internal().await;
113 }
114 }
115
116 fn get_default_bind_address() -> String {
117 let interfaces = pnet::datalink::interfaces();
119 let interface = interfaces.iter().find(|interface| {
120 interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
121 });
122
123 let ipv4_addresses: Vec<String> = interface
125 .expect("No available network interfaces found")
126 .ips
127 .iter()
128 .filter_map(|ip| {
129 let ip = ip.ip();
130 ip.is_ipv4().then(|| ip.to_string())
131 })
132 .collect();
133
134 ipv4_addresses
136 .first()
137 .cloned()
138 .unwrap_or("0.0.0.0:0".to_string())
139 }
140
141 async fn start_internal(&mut self) {
142 if !self.started {
143 self.stop_internal().await;
144 return;
145 }
146
147 let request = match url::Url::parse(&self.streamer_url) {
148 Ok(url) => url,
149 Err(e) => {
150 error!("Failed to parse URL: {}", e);
151 return;
152 }
153 };
154
155 match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
156 Ok(Ok((ws_stream, _))) => {
157 debug!("Connected to {}", self.streamer_url);
158 let (writer, reader) = ws_stream.split();
159 self.ws_writer = Some(writer);
160 self.start_websocket_receiver(reader);
161 }
162 Ok(Err(error)) => {
163 debug!(
164 "Failed to connect to {} with error: {}",
165 self.streamer_url, error
166 );
167 self.reconnect_soon().await;
168 }
169 Err(_elapsed) => {
170 debug!(
171 "Failed to connect to {} within 10 seconds",
172 self.streamer_url
173 );
174 self.reconnect_soon().await;
175 }
176 }
177 }
178
179 fn start_websocket_receiver(
180 &mut self,
181 mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
182 ) {
183 let relay = self.me.clone();
185
186 tokio::spawn(async move {
187 let Some(relay_arc) = relay.upgrade() else {
188 return;
189 };
190
191 while let Some(result) = reader.next().await {
192 let mut relay = relay_arc.lock().await;
193 match result {
194 Ok(message) => match message {
195 Message::Text(text) => {
196 match serde_json::from_str::<MessageToRelay>(&text) {
197 Ok(message) => {
198 if let Err(error) = relay.handle_message(message).await {
199 error!("Message handling failed with error: {}", error);
200 relay.reconnect_soon().await;
201 break;
202 }
203 }
204 _ => {
205 error!("Failed to deserialize message: {}", text);
206 }
207 }
208 }
209 Message::Binary(data) => {
210 debug!("Received binary message of length: {}", data.len());
211 }
212 Message::Ping(data) => {
213 relay.send_message(Message::Pong(data)).await.ok();
214 }
215 Message::Pong(_) => {
216 debug!("Received pong message");
217 }
218 Message::Close(frame) => {
219 info!("Received close message: {:?}", frame);
220 relay.reconnect_soon().await;
221 break;
222 }
223 Message::Frame(_) => {
224 unreachable!("This is never used")
225 }
226 },
227 Err(e) => {
228 debug!("Error processing message: {}", e);
229 if e.to_string()
231 .contains("Connection reset without closing handshake")
232 {
233 relay.reconnect_soon().await;
234 }
235 break;
236 }
237 }
238 }
239 });
240 }
241
242 async fn stop_internal(&mut self) {
243 if let Some(mut ws_writer) = self.ws_writer.take() {
244 match ws_writer.close().await {
245 Err(e) => {
246 error!("Error closing WebSocket: {}", e);
247 }
248 _ => {
249 debug!("WebSocket closed successfully");
250 }
251 }
252 }
253 self.connected = false;
254 self.wrong_password = false;
255 *self.reconnect_on_tunnel_error.lock().await = false;
256 *self.start_on_reconnect_soon.lock().await = false;
257 if let Some(relay_to_destination) = self.relay_to_destination.take() {
258 relay_to_destination.abort();
259 relay_to_destination.await.ok();
260 }
261 self.update_status();
262 }
263
264 fn update_status(&self) {
265 let Some(on_status_updated) = &self.on_status_updated else {
266 return;
267 };
268 let status = if self.connected {
269 "Connected to streamer"
270 } else if self.wrong_password {
271 "Wrong password"
272 } else if self.started {
273 "Connecting to streamer"
274 } else {
275 "Disconnected from streamer"
276 };
277 on_status_updated(status.to_string());
278 }
279
280 async fn reconnect_soon(&mut self) {
281 self.stop_internal().await;
282 *self.start_on_reconnect_soon.lock().await = false;
283 let start_on_reconnect_soon = Arc::new(Mutex::new(true));
284 self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
285 self.start_soon(start_on_reconnect_soon);
286 }
287
288 fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
289 let relay = self.me.clone();
290
291 tokio::spawn(async move {
292 sleep(Duration::from_secs(5)).await;
293
294 if *start_on_reconnect_soon.lock().await {
295 debug!("Reconnecting...");
296 if let Some(relay) = relay.upgrade() {
297 relay.lock().await.start_internal().await;
298 }
299 }
300 });
301 }
302
303 async fn handle_message(&mut self, message: MessageToRelay) -> Result<(), AnyError> {
304 match message {
305 MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
306 MessageToRelay::Identified(identified) => {
307 self.handle_message_identified(identified).await
308 }
309 MessageToRelay::Request(request) => self.handle_message_request(request).await,
310 }
311 }
312
313 async fn handle_message_hello(&mut self, hello: Hello) -> Result<(), AnyError> {
314 let authentication = calculate_authentication(
315 &self.password,
316 &hello.authentication.salt,
317 &hello.authentication.challenge,
318 );
319 let identify = Identify {
320 id: self.relay_id,
321 name: self.name.clone(),
322 authentication,
323 };
324 self.send(MessageToStreamer::Identify(identify)).await
325 }
326
327 async fn handle_message_identified(&mut self, identified: Identified) -> Result<(), AnyError> {
328 match identified.result {
329 MoblinkResult::Ok(_) => {
330 self.connected = true;
331 }
332 MoblinkResult::WrongPassword(_) => {
333 self.wrong_password = true;
334 }
335 }
336 self.update_status();
337 Ok(())
338 }
339
340 async fn handle_message_request(&mut self, request: MessageRequest) -> Result<(), AnyError> {
341 match &request.data {
342 MessageRequestData::StartTunnel(start_tunnel) => {
343 self.handle_message_request_start_tunnel(&request, start_tunnel)
344 .await
345 }
346 MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
347 }
348 }
349
350 async fn handle_message_request_start_tunnel(
351 &mut self,
352 request: &MessageRequest,
353 start_tunnel: &StartTunnelRequest,
354 ) -> Result<(), AnyError> {
355 let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
357 let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
358
359 debug!(
360 "Binding streamer socket on: {}, destination socket on: {}",
361 local_bind_addr_for_streamer, local_bind_addr_for_destination
362 );
363 let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
366 let streamer_port = streamer_socket.local_addr()?.port();
367 let streamer_socket = Arc::new(streamer_socket);
368
369 let data = ResponseData::StartTunnel(StartTunnelResponseData {
371 port: streamer_port,
372 });
373 let response = request.to_ok_response(data);
374 self.send(MessageToStreamer::Response(response)).await?;
375
376 let destination_socket =
379 create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
380
381 let destination_socket = Arc::new(destination_socket);
382 let destination_address = resolve_host(&start_tunnel.address).await?;
383 let destination_address = match IpAddr::from_str(&destination_address)? {
384 IpAddr::V4(v4) => IpAddr::V4(v4),
385 IpAddr::V6(v6) => {
386 if let Some(mapped_v4) = v6.to_ipv4() {
388 IpAddr::V4(mapped_v4)
389 } else {
390 IpAddr::V6(v6)
392 }
393 }
394 };
395
396 let destination_address = SocketAddr::new(destination_address, start_tunnel.port);
397 info!("Destination address: {}", destination_address);
398
399 self.relay_to_destination = Some(
400 self.start_relay_from_streamer_to_destination(
401 streamer_socket,
402 destination_socket,
403 destination_address,
404 )
405 .await,
406 );
407
408 Ok(())
409 }
410
411 async fn start_relay_from_streamer_to_destination(
412 &mut self,
413 streamer_socket: Arc<UdpSocket>,
414 destination_socket: Arc<UdpSocket>,
415 destination_addr: SocketAddr,
416 ) -> tokio::task::JoinHandle<Result<(), AnyError>> {
417 *self.reconnect_on_tunnel_error.lock().await = false;
418 let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
419 self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
420 let relay = self.me.clone();
421
422 tokio::spawn(async move {
423 let streamer_address = Arc::new(Mutex::new(None));
424 let mut relay_to_destination_started = false;
425 let mut buf = [0; 2048];
426
427 loop {
428 let (size, remote_addr) = streamer_socket.recv_from(&mut buf).await?;
429 destination_socket
430 .send_to(&buf[..size], &destination_addr)
431 .await?;
432 streamer_address.lock().await.replace(remote_addr);
433
434 if !relay_to_destination_started {
435 start_relay_from_destination_to_streamer(
436 relay.clone(),
437 streamer_socket.clone(),
438 destination_socket.clone(),
439 streamer_address.clone(),
440 reconnect_on_tunnel_error.clone(),
441 );
442 relay_to_destination_started = true;
443 }
444 }
445 })
446 }
447
448 async fn handle_message_request_status(
449 &mut self,
450 request: MessageRequest,
451 ) -> Result<(), AnyError> {
452 let mut battery_percentage = None;
453 if let Some(get_status) = self.get_status.as_ref() {
454 battery_percentage = get_status().await.battery_percentage;
455 }
456 let data = ResponseData::Status(StatusResponseData { battery_percentage });
457 let response = request.to_ok_response(data);
458 self.send(MessageToStreamer::Response(response)).await
459 }
460
461 async fn send(&mut self, message: MessageToStreamer) -> Result<(), AnyError> {
462 let text = serde_json::to_string(&message)?;
463 self.send_message(Message::Text(text.into())).await
464 }
465
466 async fn send_message(&mut self, message: Message) -> Result<(), AnyError> {
467 let Some(writer) = self.ws_writer.as_mut() else {
468 return Err("No websocket writer".into());
469 };
470 writer.send(message).await?;
471 Ok(())
472 }
473}
474
/// Public handle to a relay.
///
/// Each method locks the shared inner state and delegates to the corresponding `RelayInner`
/// method.
pub struct Relay {
    /// Shared relay state; the relay's background tasks hold only weak references to it.
    inner: Arc<Mutex<RelayInner>>,
}
478
479impl Default for Relay {
480 fn default() -> Self {
481 Self::new()
482 }
483}
484
485impl Relay {
486 pub fn new() -> Self {
487 Self {
488 inner: RelayInner::new(),
489 }
490 }
491
492 pub async fn set_bind_address(&self, address: String) {
493 self.inner.lock().await.set_bind_address(address);
494 }
495
496 pub async fn setup<F>(
497 &self,
498 streamer_url: String,
499 password: String,
500 relay_id: Uuid,
501 name: String,
502 on_status_updated: F,
503 get_status: Option<GetStatusClosure>,
504 ) where
505 F: Fn(String) + Send + Sync + 'static,
506 {
507 self.inner
508 .lock()
509 .await
510 .setup(
511 streamer_url,
512 password,
513 relay_id,
514 name,
515 on_status_updated,
516 get_status,
517 )
518 .await;
519 }
520
521 pub async fn is_started(&self) -> bool {
522 self.inner.lock().await.is_started()
523 }
524
525 pub async fn start(&self) {
526 self.inner.lock().await.start().await;
527 }
528
529 pub async fn stop(&self) {
530 self.inner.lock().await.stop().await;
531 }
532}
533
534fn start_relay_from_destination_to_streamer(
535 relay: Weak<Mutex<RelayInner>>,
536 streamer_socket: Arc<UdpSocket>,
537 destination_socket: Arc<UdpSocket>,
538 streamer_address: Arc<Mutex<Option<SocketAddr>>>,
539 reconnect_on_tunnel_error: Arc<Mutex<bool>>,
540) {
541 tokio::spawn(async move {
542 loop {
543 if let Err(error) = relay_one_packet_from_destination_to_streamer(
544 &streamer_socket,
545 &destination_socket,
546 &streamer_address,
547 )
548 .await
549 {
550 info!("(relay_to_streamer) Failed with error: {}", error);
551 break;
552 }
553 }
554
555 if *reconnect_on_tunnel_error.lock().await {
556 if let Some(relay) = relay.upgrade() {
557 relay.lock().await.reconnect_soon().await;
558 }
559 } else {
560 info!("Not reconnecting after tunnel error");
561 }
562 });
563}
564
565async fn relay_one_packet_from_destination_to_streamer(
566 streamer_socket: &Arc<UdpSocket>,
567 destination_socket: &Arc<UdpSocket>,
568 streamer_address: &Arc<Mutex<Option<SocketAddr>>>,
569) -> Result<(), AnyError> {
570 let mut buf = [0; 2048];
571 let size = timeout(Duration::from_secs(30), destination_socket.recv(&mut buf)).await??;
572 let streamer_addr = streamer_address
573 .lock()
574 .await
575 .ok_or("Failed to get address lock")?;
576 streamer_socket
577 .send_to(&buf[..size], &streamer_addr)
578 .await?;
579 Ok(())
580}
581
582async fn create_dual_stack_udp_socket(
583 addr: SocketAddr,
584) -> Result<tokio::net::UdpSocket, std::io::Error> {
585 let socket = match addr.is_ipv4() {
586 true => {
587 tokio::net::UdpSocket::bind(addr).await?
589 }
590 false => {
591 let socket = socket2::Socket::new(
593 socket2::Domain::IPV6,
594 socket2::Type::DGRAM,
595 Some(socket2::Protocol::UDP),
596 )?;
597
598 socket.set_only_v6(false)?;
600
601 socket.bind(&socket2::SockAddr::from(addr))?;
603
604 tokio::net::UdpSocket::from_std(socket.into())?
606 }
607 };
608
609 Ok(socket)
610}
611
/// Parses `IP:port` (e.g. `1.2.3.4:5`, `[::1]:80`) or a bare `IP` (port `0`) into a socket
/// address.
///
/// # Errors
/// Returns an `InvalidInput` I/O error when the text is neither form.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    addr_str
        .parse::<SocketAddr>()
        .or_else(|_| {
            addr_str
                .parse::<IpAddr>()
                .map(|ip_addr| SocketAddr::new(ip_addr, 0))
        })
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}
633
634pub fn create_get_status_closure(
635 status_executable: &Option<String>,
636 status_file: &Option<String>,
637) -> Option<GetStatusClosure> {
638 let status_executable = status_executable.clone();
639 let status_file = status_file.clone();
640 Some(Box::new(move || {
641 let status_executable = status_executable.clone();
642 let status_file = status_file.clone();
643 Box::pin(async move {
644 let output = if let Some(status_executable) = &status_executable {
645 let Ok(output) = Command::new(status_executable).output().await else {
646 return Default::default();
647 };
648 output.stdout
649 } else if let Some(status_file) = &status_file {
650 let Ok(mut file) = File::open(status_file).await else {
651 return Default::default();
652 };
653 let mut contents = vec![];
654 if file.read_to_end(&mut contents).await.is_err() {
655 return Default::default();
656 }
657 contents
658 } else {
659 return Default::default();
660 };
661 let output = String::from_utf8(output).unwrap_or_default();
662 match serde_json::from_str(&output) {
663 Ok(status) => status,
664 Err(e) => {
665 error!("Failed to decode status with error: {e}");
666 Default::default()
667 }
668 }
669 })
670 }))
671}