1use serde::Deserialize;
2use std::future::Future;
3use std::net::{IpAddr, SocketAddr};
4use std::pin::Pin;
5use std::str::FromStr;
6use std::sync::{Arc, Weak};
7
8use base64::engine::general_purpose;
9use base64::Engine as _;
10use futures_util::stream::{SplitSink, SplitStream};
11use futures_util::{SinkExt, StreamExt};
12use log::{debug, error, info};
13use sha2::{Digest, Sha256};
14use tokio::net::{TcpStream, UdpSocket};
15use tokio::sync::Mutex;
16use tokio::time::{sleep, timeout, Duration};
17use tokio_tungstenite::tungstenite::protocol::Message;
18use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
19
20use crate::protocol::*;
21
22#[derive(Default, Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub struct Status {
25 pub battery_percentage: Option<i32>,
26}
27
28pub type GetStatusClosure =
29 Box<dyn Fn() -> Pin<Box<dyn Future<Output = Status> + Send + Sync>> + Send + Sync>;
30
/// Moblink relay: connects to a streamer over a WebSocket, identifies itself, and on
/// request opens UDP tunnels between the streamer and a destination address.
///
/// Construct it with [`Relay::new`], which returns it wrapped in `Arc<Mutex<_>>`.
pub struct Relay {
    /// Weak handle to the `Arc<Mutex<Relay>>` that owns this value; cloned into spawned tasks.
    me: Weak<Mutex<Self>>,
    /// Local address the destination-facing UDP socket binds to. It is parsed by
    /// `parse_socket_addr`, so both `IP` and `IP:port` are accepted.
    bind_address: String,
    /// Identifier sent to the streamer in the `Identify` message.
    relay_id: String,
    /// WebSocket URL of the streamer.
    streamer_url: String,
    /// Password used to answer the streamer's authentication challenge.
    password: String,
    /// Relay name sent to the streamer in the `Identify` message.
    name: String,
    /// Receives the human-readable status string produced by `update_status`.
    on_status_updated: Option<Box<dyn Fn(String) + Send + Sync>>,
    /// Queried when the streamer sends a status request.
    get_status: Option<Arc<GetStatusClosure>>,
    /// Sending half of the current WebSocket connection, if any.
    ws_writer: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    /// Set by `start`, cleared by `stop`: whether the relay is supposed to be running.
    started: bool,
    /// Whether the streamer accepted our identification.
    connected: bool,
    /// Whether the streamer rejected our password.
    wrong_password: bool,
    /// Flag shared with the current tunnel's supervisor task. If it is still `true` when
    /// the tunnel ends, a reconnect is scheduled. It is cleared (and replaced) when a new
    /// tunnel starts or the relay stops.
    reconnect_on_tunnel_error: Arc<Mutex<bool>>,
    /// Flag shared with the pending delayed-reconnect task; clearing it cancels that reconnect.
    start_on_reconnect_soon: Arc<Mutex<bool>>,
}
48
49impl Relay {
50 pub fn new() -> Arc<Mutex<Self>> {
51 Arc::new_cyclic(|me| {
52 Mutex::new(Self {
53 me: me.clone(),
54 bind_address: Self::get_default_bind_address(),
55 relay_id: "".to_string(),
56 streamer_url: "".to_string(),
57 password: "".to_string(),
58 name: "".to_string(),
59 on_status_updated: None,
60 get_status: None,
61 ws_writer: None,
62 started: false,
63 connected: false,
64 wrong_password: false,
65 reconnect_on_tunnel_error: Arc::new(Mutex::new(false)),
66 start_on_reconnect_soon: Arc::new(Mutex::new(false)),
67 })
68 })
69 }
70
71 pub fn set_bind_address(&mut self, address: String) {
72 self.bind_address = address;
73 }
74
75 pub async fn setup<F>(
76 &mut self,
77 streamer_url: String,
78 password: String,
79 relay_id: String,
80 name: String,
81 on_status_updated: F,
82 get_status: GetStatusClosure,
83 ) where
84 F: Fn(String) + Send + Sync + 'static,
85 {
86 self.on_status_updated = Some(Box::new(on_status_updated));
87 self.get_status = Some(Arc::new(get_status));
88 self.relay_id = relay_id;
89 self.streamer_url = streamer_url;
90 self.password = password;
91 self.name = name;
92 info!("Binding to address: {:?}", self.bind_address);
93 }
94
95 pub fn is_started(&self) -> bool {
96 self.started
97 }
98
99 pub async fn start(&mut self) {
100 if !self.started {
101 self.started = true;
102 self.start_internal().await;
103 }
104 }
105
106 pub async fn stop(&mut self) {
107 if self.started {
108 self.started = false;
109 self.stop_internal().await;
110 }
111 }
112
113 fn get_default_bind_address() -> String {
114 let interfaces = pnet::datalink::interfaces();
116 let interface = interfaces.iter().find(|interface| {
117 interface.is_up() && !interface.is_loopback() && !interface.ips.is_empty()
118 });
119
120 let ipv4_addresses: Vec<String> = interface
122 .expect("No available network interfaces found")
123 .ips
124 .iter()
125 .filter_map(|ip| {
126 let ip = ip.ip();
127 ip.is_ipv4().then(|| ip.to_string())
128 })
129 .collect();
130
131 ipv4_addresses
133 .first()
134 .cloned()
135 .unwrap_or("0.0.0.0:0".to_string())
136 }
137
138 async fn start_internal(&mut self) {
139 info!("Start internal");
140 if !self.started {
141 self.stop_internal().await;
142 return;
143 }
144
145 let request = match url::Url::parse(&self.streamer_url) {
146 Ok(url) => url,
147 Err(e) => {
148 error!("Failed to parse URL: {}", e);
149 return;
150 }
151 };
152
153 match timeout(Duration::from_secs(10), connect_async(request.to_string())).await {
154 Ok(Ok((ws_stream, _))) => {
155 info!("WebSocket connected");
156 let (writer, reader) = ws_stream.split();
157 self.ws_writer = Some(writer);
158 self.start_websocket_receiver(reader);
159 }
160 Ok(Err(error)) => {
161 error!("WebSocket connection failed immediately: {}", error);
163 self.reconnect_soon().await;
164 }
165 Err(_elapsed) => {
166 error!("WebSocket connection attempt timed out after 10 seconds");
168 self.reconnect_soon().await;
169 }
170 }
171 }
172
173 fn start_websocket_receiver(
174 &mut self,
175 mut reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
176 ) {
177 let relay = self.me.clone();
179
180 tokio::spawn(async move {
181 let Some(relay_arc) = relay.upgrade() else {
182 return;
183 };
184
185 while let Some(result) = reader.next().await {
186 let mut relay = relay_arc.lock().await;
187 match result {
188 Ok(message) => match message {
189 Message::Text(text) => {
190 if let Ok(message) = serde_json::from_str::<MessageToRelay>(&text) {
191 relay.handle_message(message).await.ok();
192 } else {
193 error!("Failed to deserialize message: {}", text);
194 }
195 }
196 Message::Binary(data) => {
197 debug!("Received binary message of length: {}", data.len());
198 }
199 Message::Ping(_) => {
200 debug!("Received ping message");
201 }
202 Message::Pong(_) => {
203 debug!("Received pong message");
204 }
205 Message::Close(frame) => {
206 info!("Received close message: {:?}", frame);
207 relay.reconnect_soon().await;
208 break;
209 }
210 Message::Frame(_) => {
211 unreachable!("This is never used")
212 }
213 },
214 Err(e) => {
215 error!("Error processing message: {}", e);
216 if e.to_string()
218 .contains("Connection reset without closing handshake")
219 {
220 relay.reconnect_soon().await;
221 }
222 break;
223 }
224 }
225 }
226 });
227 }
228
229 async fn stop_internal(&mut self) {
230 info!("Stop internal");
231 if let Some(mut ws_writer) = self.ws_writer.take() {
232 if let Err(e) = ws_writer.close().await {
233 error!("Error closing WebSocket: {}", e);
234 } else {
235 info!("WebSocket closed successfully");
236 }
237 }
238 self.connected = false;
239 self.wrong_password = false;
240 *self.reconnect_on_tunnel_error.lock().await = false;
241 *self.start_on_reconnect_soon.lock().await = false;
242 self.update_status();
243 }
244
245 fn update_status(&self) {
246 let Some(on_status_updated) = &self.on_status_updated else {
247 return;
248 };
249 let status = if self.connected {
250 "Connected to streamer"
251 } else if self.wrong_password {
252 "Wrong password"
253 } else if self.started {
254 "Connecting to streamer"
255 } else {
256 "Disconnected from streamer"
257 };
258 on_status_updated(status.to_string());
259 }
260
261 async fn reconnect_soon(&mut self) {
262 self.stop_internal().await;
263 *self.start_on_reconnect_soon.lock().await = false;
264 let start_on_reconnect_soon = Arc::new(Mutex::new(true));
265 self.start_on_reconnect_soon = start_on_reconnect_soon.clone();
266 self.start_soon(start_on_reconnect_soon);
267 }
268
269 fn start_soon(&mut self, start_on_reconnect_soon: Arc<Mutex<bool>>) {
270 let relay = self.me.clone();
271
272 tokio::spawn(async move {
273 info!("Reconnecting in 5 seconds...");
274 sleep(Duration::from_secs(5)).await;
275
276 if *start_on_reconnect_soon.lock().await {
277 info!("Reconnecting...");
278 if let Some(relay) = relay.upgrade() {
279 relay.lock().await.start_internal().await;
280 }
281 }
282 });
283 }
284
285 async fn handle_message(
286 &mut self,
287 message: MessageToRelay,
288 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
289 match message {
290 MessageToRelay::Hello(hello) => self.handle_message_hello(hello).await,
291 MessageToRelay::Identified(identified) => {
292 self.handle_message_identified(identified).await
293 }
294 MessageToRelay::Request(request) => self.handle_message_request(request).await,
295 }
296 }
297
298 async fn handle_message_hello(
299 &mut self,
300 hello: Hello,
301 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
302 let authentication = calculate_authentication(
303 &self.password,
304 &hello.authentication.salt,
305 &hello.authentication.challenge,
306 );
307 let identify = Identify {
308 id: self.relay_id.clone(),
309 name: self.name.clone(),
310 authentication,
311 };
312 self.send(MessageToStreamer::Identify(identify)).await
313 }
314
315 async fn handle_message_identified(
316 &mut self,
317 identified: Identified,
318 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
319 match identified.result {
320 MoblinkResult::Ok(_) => {
321 self.connected = true;
322 }
323 MoblinkResult::WrongPassword(_) => {
324 self.wrong_password = true;
325 }
326 }
327 self.update_status();
328 Ok(())
329 }
330
331 async fn handle_message_request(
332 &mut self,
333 request: MessageRequest,
334 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
335 match &request.data {
336 MessageRequestData::StartTunnel(start_tunnel) => {
337 self.handle_message_request_start_tunnel(&request, start_tunnel)
338 .await
339 }
340 MessageRequestData::Status(_) => self.handle_message_request_status(request).await,
341 }
342 }
343
344 async fn handle_message_request_start_tunnel(
345 &mut self,
346 request: &MessageRequest,
347 start_tunnel: &StartTunnelRequest,
348 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
349 let local_bind_addr_for_streamer = parse_socket_addr("0.0.0.0")?;
351 let local_bind_addr_for_destination = parse_socket_addr(&self.bind_address)?;
352
353 info!(
354 "Binding streamer socket on: {}, destination socket on: {}",
355 local_bind_addr_for_streamer, local_bind_addr_for_destination
356 );
357 let streamer_socket = create_dual_stack_udp_socket(local_bind_addr_for_streamer).await?;
360 let streamer_port = streamer_socket.local_addr()?.port();
361 info!("Listening on UDP port: {}", streamer_port);
362 let streamer_socket = Arc::new(streamer_socket);
363
364 let data = ResponseData::StartTunnel(StartTunnelResponseData {
366 port: streamer_port,
367 });
368 let response = request.to_ok_response(data);
369 self.send(MessageToStreamer::Response(response)).await?;
370
371 let destination_socket =
374 create_dual_stack_udp_socket(local_bind_addr_for_destination).await?;
375
376 info!(
377 "Bound destination socket to: {:?}",
378 destination_socket.local_addr()?
379 );
380 let destination_socket = Arc::new(destination_socket);
381
382 let normalized_ip = match IpAddr::from_str(&start_tunnel.address)? {
383 IpAddr::V4(v4) => IpAddr::V4(v4),
384 IpAddr::V6(v6) => {
385 if let Some(mapped_v4) = v6.to_ipv4() {
387 IpAddr::V4(mapped_v4)
388 } else {
389 IpAddr::V6(v6)
391 }
392 }
393 };
394 let destination_addr = SocketAddr::new(normalized_ip, start_tunnel.port);
395 info!("Destination address resolved: {}", destination_addr);
396
397 let streamer_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
399
400 let relay_to_destination = start_relay_from_streamer_to_destination(
401 streamer_socket.clone(),
402 destination_socket.clone(),
403 streamer_addr.clone(),
404 destination_addr,
405 );
406 let relay_to_streamer = start_relay_from_destination_to_streamer(
407 streamer_socket,
408 destination_socket,
409 streamer_addr,
410 );
411
412 *self.reconnect_on_tunnel_error.lock().await = false;
413 let reconnect_on_tunnel_error = Arc::new(Mutex::new(true));
414 self.reconnect_on_tunnel_error = reconnect_on_tunnel_error.clone();
415 let relay = self.me.clone();
416
417 tokio::spawn(async move {
418 let Some(relay) = relay.upgrade() else {
419 return;
420 };
421
422 tokio::select! {
425 res = relay_to_destination => {
426 if let Err(e) = res {
427 error!("relay_to_destination task failed: {}", e);
428 }
429 }
430 res = relay_to_streamer => {
431 if let Err(e) = res {
432 error!("relay_to_streamer task failed: {}", e);
433 }
434 }
435 }
436
437 if *reconnect_on_tunnel_error.lock().await {
438 relay.lock().await.reconnect_soon().await;
439 } else {
440 info!("Not reconnecting after tunnel error");
441 }
442 });
443
444 Ok(())
445 }
446
447 async fn handle_message_request_status(
448 &mut self,
449 request: MessageRequest,
450 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
451 let Some(get_status) = self.get_status.as_ref() else {
452 error!("get_battery_percentage is not set");
453 return Err("get_battery_percentage function not set".into());
454 };
455 let status = get_status().await;
456 let data = ResponseData::Status(StatusResponseData {
457 battery_percentage: status.battery_percentage,
458 });
459 let response = request.to_ok_response(data);
460 self.send(MessageToStreamer::Response(response)).await
461 }
462
463 async fn send(
464 &mut self,
465 message: MessageToStreamer,
466 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
467 let text = serde_json::to_string(&message)?;
468 let Some(writer) = self.ws_writer.as_mut() else {
469 return Err("No websocket writer".into());
470 };
471 writer.send(Message::Text(text.into())).await?;
472 Ok(())
473 }
474}
475
476fn start_relay_from_streamer_to_destination(
477 streamer_socket: Arc<UdpSocket>,
478 destination_socket: Arc<UdpSocket>,
479 streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
480 destination_addr: SocketAddr,
481) -> tokio::task::JoinHandle<()> {
482 tokio::spawn(async move {
483 debug!("(relay_to_destination) Task started");
484 loop {
485 let mut buf = [0; 2048];
486 let (size, remote_addr) =
487 match timeout(Duration::from_secs(30), streamer_socket.recv_from(&mut buf)).await {
488 Ok(result) => match result {
489 Ok((size, addr)) => (size, addr),
490 Err(e) => {
491 error!("(relay_to_destination) Error receiving from server: {}", e);
492 continue;
493 }
494 },
495 Err(e) => {
496 error!(
497 "(relay_to_destination) Timeout receiving from server: {}",
498 e
499 );
500 break;
501 }
502 };
503
504 debug!(
505 "(relay_to_destination) Received {} bytes from server: {}",
506 size, remote_addr
507 );
508
509 match destination_socket
511 .send_to(&buf[..size], &destination_addr)
512 .await
513 {
514 Ok(bytes_sent) => {
515 debug!(
516 "(relay_to_destination) Sent {} bytes to destination",
517 bytes_sent
518 )
519 }
520 Err(e) => {
521 error!(
522 "(relay_to_destination) Failed to send to destination: {}",
523 e
524 );
525 break;
526 }
527 }
528
529 let mut streamer_addr_lock = streamer_addr.lock().await;
531 if streamer_addr_lock.is_none() {
532 *streamer_addr_lock = Some(remote_addr);
533 debug!(
534 "(relay_to_destination) Server remote address set to: {}",
535 remote_addr
536 );
537 }
538 }
539 info!("(relay_to_destination) Task exiting");
540 })
541}
542
543fn start_relay_from_destination_to_streamer(
544 streamer_socket: Arc<UdpSocket>,
545 destination_socket: Arc<UdpSocket>,
546 streamer_addr: Arc<Mutex<Option<SocketAddr>>>,
547) -> tokio::task::JoinHandle<()> {
548 tokio::spawn(async move {
549 debug!("(relay_to_streamer) Task started");
550 loop {
551 let mut buf = [0; 2048];
552 let (size, remote_addr) = match timeout(
553 Duration::from_secs(30),
554 destination_socket.recv_from(&mut buf),
555 )
556 .await
557 {
558 Ok(result) => match result {
559 Ok((size, addr)) => (size, addr),
560 Err(e) => {
561 error!(
562 "(relay_to_streamer) Error receiving from destination: {}",
563 e
564 );
565 continue;
566 }
567 },
568 Err(e) => {
569 error!(
570 "(relay_to_streamer) Timeout receiving from destination: {}",
571 e
572 );
573 break;
574 }
575 };
576
577 debug!(
578 "(relay_to_streamer) Received {} bytes from destination: {}",
579 size, remote_addr
580 );
581 let streamer_addr_lock = streamer_addr.lock().await;
583 match *streamer_addr_lock {
584 Some(streamer_addr) => {
585 match streamer_socket.send_to(&buf[..size], &streamer_addr).await {
586 Ok(bytes_sent) => {
587 debug!("(relay_to_streamer) Sent {} bytes to server", bytes_sent)
588 }
589 Err(e) => {
590 error!("(relay_to_streamer) Failed to send to server: {}", e);
591 break;
592 }
593 }
594 }
595 None => {
596 error!("(relay_to_streamer) Server address not set, cannot forward packet");
597 }
598 }
599 }
600 info!("(relay_to_streamer) Task exiting");
601 })
602}
603
604fn calculate_authentication(password: &str, salt: &str, challenge: &str) -> String {
605 let mut hasher = Sha256::new();
606 hasher.update(format!("{}{}", password, salt).as_bytes());
607 let hash1 = hasher.finalize_reset();
608 hasher.update(format!("{}{}", general_purpose::STANDARD.encode(hash1), challenge).as_bytes());
609 let hash2 = hasher.finalize();
610 general_purpose::STANDARD.encode(hash2)
611}
612
613async fn create_dual_stack_udp_socket(
614 addr: SocketAddr,
615) -> Result<tokio::net::UdpSocket, std::io::Error> {
616 let socket = match addr.is_ipv4() {
617 true => {
618 tokio::net::UdpSocket::bind(addr).await?
620 }
621 false => {
622 let socket = socket2::Socket::new(
624 socket2::Domain::IPV6,
625 socket2::Type::DGRAM,
626 Some(socket2::Protocol::UDP),
627 )?;
628
629 socket.set_only_v6(false)?;
631
632 socket.bind(&socket2::SockAddr::from(addr))?;
634
635 tokio::net::UdpSocket::from_std(socket.into())?
637 }
638 };
639
640 Ok(socket)
641}
642
/// Parses `addr_str` as `IP:port` or as a bare `IP`; a bare IP gets port 0.
///
/// IPv6 addresses with a port use the bracketed form, e.g. `[::1]:9000`.
///
/// # Errors
/// Returns an `ErrorKind::InvalidInput` error when the string matches neither form.
fn parse_socket_addr(addr_str: &str) -> Result<SocketAddr, std::io::Error> {
    SocketAddr::from_str(addr_str)
        .or_else(|_| IpAddr::from_str(addr_str).map(|ip| SocketAddr::new(ip, 0)))
        .map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid socket address syntax. Expected 'IP:port' or 'IP'.",
            )
        })
}