moblink_rust/
relay_service.rs

1use std::collections::HashMap;
2use std::net::Ipv4Addr;
3use std::path::PathBuf;
4use std::sync::{Arc, Weak};
5use std::time::Duration;
6
7use log::{error, info};
8use mdns_sd::{ServiceDaemon, ServiceEvent};
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use regex::Regex;
11use serde::{Deserialize, Serialize};
12use tokio::fs::File;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::sync::Mutex;
15use tokio::task::JoinHandle;
16use uuid::Uuid;
17
18use crate::MDNS_SERVICE_TYPE;
19use crate::relay::{GetStatusClosure, Relay, Status};
20use crate::utils::{any_address_belongs_to_this_machine, get_first_ipv4_address};
21
22#[derive(Serialize, Deserialize, Default)]
23struct DatabaseContent {
24    relay_ids: HashMap<String, Uuid>,
25}
26
27struct Database {
28    path: PathBuf,
29    content: DatabaseContent,
30}
31
32impl Database {
33    async fn new(path: PathBuf) -> Self {
34        let content = Self::load(&path).await;
35        Self { path, content }
36    }
37
38    async fn load(path: &PathBuf) -> DatabaseContent {
39        let mut content = "".to_string();
40        if let Ok(mut file) = File::open(path).await {
41            let mut buffer = vec![];
42            if file.read_to_end(&mut buffer).await.is_ok() {
43                content = String::from_utf8(buffer).unwrap_or_default();
44            }
45        }
46        serde_json::from_str(&content).unwrap_or_default()
47    }
48
49    async fn store(&self) {
50        let content = serde_json::to_string(&self.content).unwrap_or_default();
51        if let Ok(mut file) = File::create(&self.path).await {
52            file.write_all(content.as_bytes()).await.ok();
53        }
54    }
55
56    async fn get_relay_id(&mut self, name: &str) -> Uuid {
57        if !self.content.relay_ids.contains_key(name) {
58            self.content
59                .relay_ids
60                .insert(name.to_string(), Uuid::new_v4());
61            self.store().await;
62        }
63        *self.content.relay_ids.get(name).unwrap()
64    }
65}
66
67struct ServiceRelay {
68    interface_name: String,
69    interface_address: Ipv4Addr,
70    streamer_name: String,
71    streamer_url: String,
72    relay: Relay,
73}
74
75impl ServiceRelay {
76    async fn new(
77        interface_name: String,
78        interface_address: Ipv4Addr,
79        streamer_name: String,
80        streamer_url: String,
81        password: String,
82        get_status: Option<GetStatusClosure>,
83        database: Arc<Mutex<Database>>,
84    ) -> Self {
85        let relay = Relay::new();
86        relay.set_bind_address(interface_address.to_string()).await;
87        relay
88            .setup(
89                streamer_url.clone(),
90                password,
91                database.lock().await.get_relay_id(&interface_name).await,
92                interface_name.clone(),
93                |_| {},
94                get_status,
95            )
96            .await;
97        relay.start().await;
98        Self {
99            interface_name,
100            interface_address,
101            streamer_name,
102            streamer_url,
103            relay,
104        }
105    }
106}
107
/// A streamer found through mDNS-SD service discovery.
struct Streamer {
    /// Value of the streamer's advertised `name` property.
    name: String,
    /// WebSocket URL of the form `ws://address:port` that relays connect to.
    url: String,
}
112
113struct NetworkInterfaceFilter {
114    patterns_to_allow: Option<Regex>,
115    patterns_to_ignore: Option<Regex>,
116}
117
118impl NetworkInterfaceFilter {
119    fn new(patterns_to_allow: Vec<String>, patterns_to_ignore: Vec<String>) -> Self {
120        Self {
121            patterns_to_allow: Self::compile(patterns_to_allow),
122            patterns_to_ignore: Self::compile(patterns_to_ignore),
123        }
124    }
125
126    fn filter(&self, interfaces: &mut Vec<NetworkInterface>) {
127        if let Some(patterns_to_allow) = &self.patterns_to_allow {
128            interfaces.retain(|interface| patterns_to_allow.is_match(&interface.name));
129        }
130        if let Some(patterns_to_ignore) = &self.patterns_to_ignore {
131            interfaces.retain(|interface| !patterns_to_ignore.is_match(&interface.name));
132        }
133    }
134
135    fn compile(patterns: Vec<String>) -> Option<Regex> {
136        if !patterns.is_empty() {
137            let pattern = format!("^{}$", patterns.join("|"));
138            match Regex::new(&pattern) {
139                Ok(regex) => return Some(regex),
140                Err(error) => {
141                    error!("Failed to compile regex {} with error: {}", pattern, error);
142                }
143            }
144        }
145        None
146    }
147}
148
149struct RelayServiceInner {
150    me: Weak<Mutex<Self>>,
151    password: String,
152    network_interface_filter: NetworkInterfaceFilter,
153    get_status: Option<GetStatusClosure>,
154    status: Status,
155    relays: Vec<ServiceRelay>,
156    network_interfaces: Vec<NetworkInterface>,
157    streamers: Vec<Streamer>,
158    network_interface_monitor: Option<JoinHandle<()>>,
159    streamers_monitor: Option<JoinHandle<()>>,
160    get_status_updater: Option<JoinHandle<()>>,
161    database: Arc<Mutex<Database>>,
162}
163
164impl RelayServiceInner {
165    async fn new(
166        password: String,
167        network_interfaces_to_allow: Vec<String>,
168        network_interfaces_to_ignore: Vec<String>,
169        get_status: Option<GetStatusClosure>,
170        database: PathBuf,
171    ) -> Arc<Mutex<Self>> {
172        let database = Arc::new(Mutex::new(Database::new(database).await));
173        Arc::new_cyclic(|me| {
174            Mutex::new(Self {
175                me: me.clone(),
176                password,
177                network_interface_filter: NetworkInterfaceFilter::new(
178                    network_interfaces_to_allow,
179                    network_interfaces_to_ignore,
180                ),
181                get_status,
182                status: Default::default(),
183                relays: Vec::new(),
184                network_interfaces: Vec::new(),
185                streamers: Vec::new(),
186                network_interface_monitor: None,
187                streamers_monitor: None,
188                get_status_updater: None,
189                database,
190            })
191        })
192    }
193
194    async fn start(&mut self) {
195        self.start_network_interfaces_monitor();
196        self.start_streamers_monitor();
197        self.start_get_status_updater();
198    }
199
200    async fn stop(&mut self) {
201        if let Some(network_interface_monitor) = self.network_interface_monitor.take() {
202            network_interface_monitor.abort();
203            network_interface_monitor.await.ok();
204        }
205        if let Some(streamers_finder) = self.streamers_monitor.take() {
206            streamers_finder.abort();
207            streamers_finder.await.ok();
208        }
209    }
210
211    fn start_network_interfaces_monitor(&mut self) {
212        let relay_service = self.me.clone();
213        self.network_interface_monitor = Some(tokio::spawn(async move {
214            loop {
215                let Ok(interfaces) = NetworkInterface::show() else {
216                    break;
217                };
218                let Some(relay_service) = relay_service.upgrade() else {
219                    break;
220                };
221                {
222                    let mut relay_service = relay_service.lock().await;
223                    relay_service.update_network_interfaces(interfaces);
224                    relay_service.updated().await;
225                }
226                tokio::time::sleep(Duration::from_secs(3)).await;
227            }
228        }));
229    }
230
231    fn update_network_interfaces(&mut self, mut interfaces: Vec<NetworkInterface>) {
232        self.network_interface_filter.filter(&mut interfaces);
233        self.network_interfaces = interfaces;
234    }
235
236    fn start_streamers_monitor(&mut self) {
237        let relay_service = self.me.clone();
238        self.streamers_monitor = Some(tokio::spawn(async move {
239            loop {
240                let Ok(browser) = ServiceDaemon::new() else {
241                    return;
242                };
243                let Ok(receiver) = browser.browse(MDNS_SERVICE_TYPE) else {
244                    return;
245                };
246                while let Ok(event) = receiver.recv_async().await {
247                    if let ServiceEvent::ServiceResolved(info) = event {
248                        info!(
249                            "mDNS-SD: Found streamer {} {:?} {}",
250                            info.get_fullname(),
251                            info.get_addresses(),
252                            info.get_port()
253                        );
254                        let Some(name) = info.get_property_val_str("name") else {
255                            continue;
256                        };
257                        let addresses = info.get_addresses_v4();
258                        if any_address_belongs_to_this_machine(&addresses) {
259                            continue;
260                        }
261                        let Some(address) = addresses.iter().next().cloned() else {
262                            continue;
263                        };
264                        let Some(relay_service) = relay_service.upgrade() else {
265                            break;
266                        };
267                        {
268                            let mut relay_service = relay_service.lock().await;
269                            relay_service.add_streamer(name.to_string(), *address, info.get_port());
270                            relay_service.updated().await;
271                        }
272                    }
273                }
274            }
275        }));
276    }
277
278    fn start_get_status_updater(&mut self) {
279        let relay_service = self.me.clone();
280        self.get_status_updater = Some(tokio::spawn(async move {
281            loop {
282                let Some(relay_service) = relay_service.upgrade() else {
283                    break;
284                };
285                relay_service.lock().await.update_status().await;
286                tokio::time::sleep(Duration::from_secs(5)).await;
287            }
288        }));
289    }
290
291    async fn update_status(&mut self) {
292        self.status = if let Some(get_status) = &self.get_status {
293            get_status().await
294        } else {
295            Status::default()
296        }
297    }
298
299    fn add_streamer(&mut self, name: String, address: Ipv4Addr, port: u16) {
300        let url = format!("ws://{}:{}", address, port);
301        self.streamers.retain(|streamer| streamer.url != url);
302        self.streamers.push(Streamer { name, url });
303    }
304
305    async fn updated(&mut self) {
306        let old_number_of_relays = self.relays.len();
307        self.add_relays().await;
308        self.remove_relays().await;
309        let new_number_of_relays = self.relays.len();
310        if new_number_of_relays != old_number_of_relays {
311            info!("Number of relays: {}", new_number_of_relays);
312        }
313    }
314
315    async fn add_relays(&mut self) {
316        for interface in &self.network_interfaces {
317            let Some(interface_address) = get_first_ipv4_address(interface) else {
318                continue;
319            };
320            if interface_address.is_loopback() {
321                continue;
322            }
323            for streamer in &self.streamers {
324                if self.relay_already_added(interface_address, &streamer.url) {
325                    continue;
326                }
327                info!(
328                    "Adding relay called {} with interface address {} for streamer name {} and \
329                     URL {}",
330                    interface.name, interface_address, streamer.name, streamer.url
331                );
332                self.relays.push(
333                    ServiceRelay::new(
334                        interface.name.clone(),
335                        interface_address,
336                        streamer.name.clone(),
337                        streamer.url.clone(),
338                        self.password.clone(),
339                        self.create_get_status_closure(),
340                        self.database.clone(),
341                    )
342                    .await,
343                );
344            }
345        }
346    }
347
348    pub fn create_get_status_closure(&self) -> Option<GetStatusClosure> {
349        let relay_service = self.me.clone();
350        Some(Box::new(move || {
351            let relay_service = relay_service.clone();
352            Box::pin(async move {
353                if let Some(relay_service) = relay_service.upgrade() {
354                    relay_service.lock().await.status.clone()
355                } else {
356                    Status::default()
357                }
358            })
359        }))
360    }
361
362    fn relay_already_added(&self, interface_address: Ipv4Addr, streamer_url: &str) -> bool {
363        self.relays.iter().any(|relay| {
364            relay.interface_address == interface_address && relay.streamer_url == streamer_url
365        })
366    }
367
368    async fn remove_relays(&mut self) {
369        let mut relays_to_keep: Vec<ServiceRelay> = Vec::new();
370        let mut relays_to_remove: Vec<ServiceRelay> = Vec::new();
371        for relay in self.relays.drain(..) {
372            if Self::should_keep_relay(&self.network_interfaces, relay.interface_address) {
373                relays_to_keep.push(relay);
374            } else {
375                relays_to_remove.push(relay);
376            }
377        }
378        self.relays = relays_to_keep;
379        for relay in relays_to_remove {
380            info!(
381                "Removing relay called {} with interface address {} for streamer name {} and URL \
382                 {}",
383                relay.interface_name,
384                relay.interface_address,
385                relay.streamer_name,
386                relay.streamer_url
387            );
388            relay.relay.stop().await;
389        }
390    }
391
392    fn should_keep_relay(
393        network_interfaces: &[NetworkInterface],
394        interface_address: Ipv4Addr,
395    ) -> bool {
396        network_interfaces
397            .iter()
398            .any(|interface| get_first_ipv4_address(interface) == Some(interface_address))
399    }
400}
401
402pub struct RelayService {
403    inner: Arc<Mutex<RelayServiceInner>>,
404}
405
406impl RelayService {
407    pub async fn new(
408        password: String,
409        network_interfaces_to_allow: Vec<String>,
410        network_interfaces_to_ignore: Vec<String>,
411        get_status: Option<GetStatusClosure>,
412        database: PathBuf,
413    ) -> Self {
414        Self {
415            inner: RelayServiceInner::new(
416                password,
417                network_interfaces_to_allow,
418                network_interfaces_to_ignore,
419                get_status,
420                database,
421            )
422            .await,
423        }
424    }
425
426    pub async fn start(&self) {
427        self.inner.lock().await.start().await;
428    }
429
430    pub async fn stop(&self) {
431        self.inner.lock().await.stop().await;
432    }
433}