// mt_sea/net.rs

use std::{
    collections::{HashMap, HashSet},
    net::{Ipv4Addr, SocketAddr, TcpStream},
    str::FromStr,
};

use log::{debug, error, info, warn};
use nalgebra::DMatrix;

use rkyv::{Archive, Deserialize, Serialize, api::low::from_bytes, rancor};

use crate::{Action, ShipKind, ShipName, VariableHuman, WindData, client::Client};

pub const PROTO_IDENTIFIER: u8 = 69;
pub const CONTROLLER_CLIENT_ID: ShipName = 0;
pub const CLIENT_REGISTER_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(150);
pub const CLIENT_LISTEN_PORT: u16 = 6594;
pub const CLIENT_REJOIN_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);
pub const CLIENT_HEARTBEAT_TCP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1);
pub const CLIENT_HEARTBEAT_TCP_INTERVAL: std::time::Duration =
    std::time::Duration::from_millis(200);
pub const SERVER_DROP_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);
pub const CLIENT_TO_CLIENT_TIMEOUT: std::time::Duration = std::time::Duration::MAX; // TODO: effectively no timeout, so this constant can likely be removed
pub const CLIENT_TO_CLIENT_INIT_RETRY_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(50);

pub fn get_domain_id() -> u16 {
    let val = std::env::var("MINOT_DOMAIN_ID")
        .ok()
        .unwrap_or("0".to_owned());
    match val.parse::<u16>() {
        Ok(parsed) => parsed,
        Err(_) => {
            warn!("Invalid MINOT_DOMAIN_ID, selecting default 0");
            0
        }
    }
}
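
// A minimal usage sketch; it assumes (as the join handling below suggests)
// that clients fill `JoinRequest::domain_id` from the same variable:
//
//     std::env::set_var("MINOT_DOMAIN_ID", "3");
//     assert_eq!(get_domain_id(), 3);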

#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
pub struct WindAt {
    pub data: WindData,
    pub at_var: Option<String>,
}

/// A wrapper type for using 0.8 rkyv APIs with nalgebra
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct NetArray<T: nalgebra::Scalar> {
    cols: usize,
    data: Vec<T>,
    rows: usize,
}

impl<T: nalgebra::Scalar> From<DMatrix<T>> for NetArray<T> {
    fn from(value: DMatrix<T>) -> Self {
        Self {
            rows: value.nrows(),
            cols: value.ncols(),
            data: value.data.into(),
        }
    }
}

impl<T: nalgebra::Scalar> From<NetArray<T>> for DMatrix<T> {
    fn from(value: NetArray<T>) -> Self {
        Self::from_data(nalgebra::VecStorage::new(
            nalgebra::Dyn(value.rows),
            nalgebra::Dyn(value.cols),
            value.data,
        ))
    }
}
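
// A round-trip sanity check for the conversions above; a minimal sketch.
// nalgebra stores matrices column-major, and both impls pass the raw data
// vector through unchanged, so the round trip should be lossless.
#[cfg(test)]
mod net_array_tests {
    use super::*;

    #[test]
    fn dmatrix_round_trip() {
        let m = DMatrix::from_row_slice(2, 3, &[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let net: NetArray<f64> = m.clone().into();
        let back: DMatrix<f64> = net.into();
        assert_eq!(m, back);
    }
}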

#[derive(Serialize, Deserialize, Archive, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum RatPubRegisterKind {
    Publish,
    Subscribe,
}

#[derive(Serialize, Deserialize, Archive, Clone, Debug)]
pub enum PacketKind {
    Acknowledge,
    Retry,
    RequestVarSend(String),
    JoinRequest {
        tcp_port: u16,
        other_client_entrance: u16,
        kind: ShipKind,
        remove_rules_on_disconnect: bool,
        domain_id: u16,
    },
    Welcome {
        addr: crate::NetworkShipAddress,
        wait_for_ack: bool,
    }, // carries the assigned ship id (so the coordinator can tell clients apart) and the tcp port for 1:1 traffic and heartbeats
    Heartbeat,
    Disconnect,
    RuleAppend {
        variable: String,
        commands: Vec<VariableHuman>,
    },
    RulesClear,
    LockNext {
        unlock_first: bool,
    },
    Unlock,
    RawDataf64(NetArray<f64>),
    RawDataf32(NetArray<f32>),
    RawDatai32(NetArray<i32>),
    RawDatau8(NetArray<u8>),
    VariableTaskRequest(String),
    RatAction {
        action: Action,
        lock_until_ack: bool,
    },
    Wind(Vec<WindAt>),
    WindDynamic(String),
    RegisterShipAtVar {
        ship: String,
        var: String,
        kind: RatPubRegisterKind,
    },
}

// unsafe impl Send for PacketKind {}

// With a join request, the client sends a JoinRequest via UDP to all
// available broadcast addresses. The client's UDP source port is not
// important here; the target port is the fixed UDP port of the coordinator.
// Since the coordinator owns the only fixed port on a device, the clients
// need to use dynamically assigned ports everywhere else.
//
// The coordinator listens for these requests on the fixed port.
// The join request includes the TCP listener port of the client.
// The coordinator must save this port: it always uses it when communicating
// with that client, and it also hands it to other clients that need a
// direct connection.
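
// A sketch of the client side of this handshake (illustrative; the real
// logic lives in client.rs). The datagram is the PROTO_IDENTIFIER byte
// followed by the rkyv-serialized packet; `join_packet` stands for an
// already-built JoinRequest `Packet`:
//
//     let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
//     socket.set_broadcast(true)?;
//     let mut frame = vec![PROTO_IDENTIFIER];
//     frame.extend_from_slice(&rkyv::api::high::to_bytes::<rancor::Error>(&join_packet)?);
//     socket.send_to(&frame, ("255.255.255.255", CLIENT_LISTEN_PORT)).await?;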

#[derive(Archive, Serialize, Deserialize, Copy, Clone, Debug, Default)]
pub struct Header {
    pub source: ShipName,
    pub target: ShipName,
}

#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
pub struct Packet {
    pub header: Header,
    pub data: PacketKind,
}
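
// A serialization round trip for `Packet`, mirroring the calls this module
// already makes (`rkyv::api::high::to_bytes` on the send side, `from_bytes`
// on the receive side); a minimal sketch.
#[cfg(test)]
mod packet_codec_tests {
    use super::*;

    #[test]
    fn heartbeat_round_trip() {
        let packet = Packet {
            header: Header {
                source: CONTROLLER_CLIENT_ID,
                target: CONTROLLER_CLIENT_ID,
            },
            data: PacketKind::Heartbeat,
        };
        let bytes = rkyv::api::high::to_bytes::<rancor::Error>(&packet).expect("serialize");
        let decoded = from_bytes::<Packet, rancor::Error>(&bytes).expect("deserialize");
        assert!(matches!(decoded.data, PacketKind::Heartbeat));
    }
}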

// unsafe impl Send for Packet {}

#[derive(Clone, Debug)]
pub struct ShipHandle {
    pub name: ShipKind,
    pub addr_from_coord: crate::NetworkShipAddress,
    pub ship: ShipName,
    // Send here to disconnect the tcp listener
    pub disconnect: tokio::sync::broadcast::Sender<bool>,
    // get requests from the client
    pub recv: tokio::sync::broadcast::Sender<(Packet, Option<SocketAddr>)>,
    // send to this client
    pub send: tokio::sync::mpsc::Sender<Packet>,
    pub other_client_port: u16,
    pub remove_rules_on_disconnect: bool,
}
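
// Illustrative: pushing a packet to a connected ship through its handle
// (a sketch; `handle` would come from `network_clients_chan` below):
//
//     handle.send.send(Packet {
//         header: Header { source: CONTROLLER_CLIENT_ID, target: handle.ship },
//         data: PacketKind::Heartbeat,
//     }).await?;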

#[derive(Debug)]
pub struct Sea {
    pub network_clients_chan: tokio::sync::broadcast::Sender<ShipHandle>,
    dissolve_network: tokio::sync::mpsc::Sender<tokio::sync::mpsc::Sender<()>>,
}

/// Server handling the Sea network
impl Sea {
    pub async fn init(
        external_ip: Option<[u8; 4]>,
        clients_wait_for_ack: std::sync::Arc<std::sync::RwLock<bool>>,
    ) -> Self {
        let (rejoin_req_tx, mut rejoin_req_rx) = tokio::sync::mpsc::channel::<(String, Packet)>(10);

        let (clients_tx, mut clients_rx) = tokio::sync::broadcast::channel::<ShipHandle>(10);

        let (dissolve_network_tx, mut dissolve_network_rx) =
            tokio::sync::mpsc::channel::<tokio::sync::mpsc::Sender<()>>(10);

        // task to disconnect all clients i.e. dissolve
        tokio::spawn(async move {
            let mut clients: Vec<tokio::sync::broadcast::Sender<bool>> = Vec::new();
            loop {
                tokio::select! {
                    answer = dissolve_network_rx.recv() => {
                        match answer {
                            None => {
                                // channel closed
                                return;
                            }
                            Some(answer) => {
                                for c in clients.iter() {
                                    c.send(true).unwrap();
                                }
                                // notify that we are finished
                                answer.send(()).await.unwrap();
                                return;
                            }
                        }
                    }
                    newclient = clients_rx.recv() => {
                        match newclient {
                            Err(e) => {
                                error!("Error receiving new client in dissolve handler: {e}");
                            }
                            Ok(client) => {
                                clients.push(client.disconnect);
                            }
                        }
                    }
                }
            }
        });

        // task to handle join requests on udp socket
        tokio::spawn(async move {
            let coordinator_domain_id = get_domain_id();
            if coordinator_domain_id > 0 {
                info!("Coordinator using domain ID {}", coordinator_domain_id);
            }
            let udp_listener = Client::get_udp_socket(external_ip, Some(CLIENT_LISTEN_PORT)).await;
            let rejoin_request = Packet {
                header: Header {
                    source: ShipName::MAX,
                    target: CONTROLLER_CLIENT_ID,
                },
                // dummy payload; names are padded to a fixed length of 64 chars,
                // so the serialized size matches any real JoinRequest
                data: PacketKind::JoinRequest {
                    tcp_port: 0,
                    other_client_entrance: 0,
                    kind: Sea::pad_ship_kind_name(&ShipKind::Rat("".to_string())),
                    remove_rules_on_disconnect: false,
                    domain_id: 0,
                },
            };
            let bytes_rejoin_request = rkyv::api::high::to_bytes::<rancor::Error>(&rejoin_request)
                .expect("could not serialize rejoin request");
            let expected_n_bytes_for_rejoin_request = bytes_rejoin_request.len();
            let mut new_clients_without_response = HashMap::<String, (usize, Vec<u8>)>::new();

            info!("Listening {:?}", udp_listener.local_addr().unwrap());
            debug!(
                "Expecting {} bytes for JoinRequest including PROTO_IDENTIFIER",
                expected_n_bytes_for_rejoin_request + 1
            );

            loop {
                let mut buf = [0; 256]; // a JoinRequest is normally 115 bytes
                let (n, addr) = udp_listener.recv_from(&mut buf).await.unwrap();
                let id = format!("{}:{}", addr.ip(), addr.port());
                debug!("Receiving {} bytes from {} via UDP", n, id);

                match new_clients_without_response.get_mut(&id) {
                    Some((received, buffer)) => {
                        *received += n;
                        buffer.extend_from_slice(&buf[..n]);
                    }
                    None => {
                        let mut buffer = Vec::with_capacity(1024);
                        if buf[0] != PROTO_IDENTIFIER {
                            continue; // not meant for us
                        } else {
                            buffer.extend_from_slice(&buf[1..n]);
                        }
                        new_clients_without_response.insert(id, (buffer.len(), buffer));
                    }
                }

                let mut to_delete = Vec::<String>::new();
                for (id, (received, buffer)) in new_clients_without_response.iter_mut() {
                    if *received != expected_n_bytes_for_rejoin_request {
                        continue;
                    }

                    let packet: Packet = match from_bytes::<Packet, rancor::Error>(buffer) {
                        Err(e) => {
                            error!("Received packet is broken: {e}");
                            continue;
                        }
                        Ok(packet) => packet,
                    };

                    match rejoin_req_tx.send((id.clone(), packet)).await {
                        Err(e) => {
                            error!("Could not send rejoin request to internal channel: {e}");
                        }
                        Ok(_) => {
                            to_delete.push(id.clone());
                        }
                    };
                }

                for id in to_delete {
                    new_clients_without_response.remove(&id);
                }

                // yield so the channel receiver gets a chance to run; without
                // this the loop can keep the thread busy and block it
                tokio::task::yield_now().await;
            }
        });

        // task to wait for each new join request
        let clients_tx_inner = clients_tx.clone();
        tokio::spawn(async move {
            let coordinator_domain_id = get_domain_id();
            let rat_lock = std::sync::Arc::new(std::sync::Mutex::new(HashSet::new()));
            // let rat_lock_unlock = Arc::clone(&rat_lock); // currently we only allow unique names, so no unlock needed.
            loop {
                let receive = rejoin_req_rx.recv().await;
                if let Some((addr, packet)) = receive {
                    match packet.data {
                        PacketKind::JoinRequest {
                            tcp_port: client_tcp_port,
                            other_client_entrance: other_client_port,
                            kind: ship_kind,
                            remove_rules_on_disconnect,
                            domain_id: client_domain_id,
                        } => {
                            // Filter by domain ID
                            if client_domain_id != coordinator_domain_id {
                                debug!(
                                    "Rejecting join request from domain {} (coordinator is domain {})",
                                    client_domain_id, coordinator_domain_id
                                );
                                continue;
                            }

                            let ship_kind = Sea::unpad_ship_kind_name(&ship_kind);
                            debug!("Received RejoinRequest: {:?} from {:?}", ship_kind, addr);
                            {
                                let mut lock = rat_lock.lock().unwrap();
                                if lock.get(&ship_kind).is_some() {
                                    debug!(
                                        "requested client already exists or is in the process of joining the network"
                                    );
                                    continue;
                                }
                                lock.insert(ship_kind.clone());
                            }
                            let generated_id = rand::random::<ShipName>().abs();
                            let (disconnect_tx, _disconnect_rx) =
                                tokio::sync::broadcast::channel::<bool>(1);

                            // task to handle tcp connection to this client
                            let curr_client_create_sender = clients_tx_inner.clone();
                            let ships_lock_for_disconnect = std::sync::Arc::clone(&rat_lock);
                            let clwa = std::sync::Arc::clone(&clients_wait_for_ack);
                            tokio::spawn(async move {
                                let ip = addr.split(':').next().unwrap();
                                let client_stream = tokio::net::TcpStream::connect(format!(
                                    "{}:{}",
                                    ip, client_tcp_port
                                ))
                                .await
                                .expect("could not connect to client");

                                let socket =
                                    socket2::Socket::from(client_stream.into_std().unwrap());
                                socket.set_keepalive(true).unwrap();

                                // socket
                                //     .set_tcp_keepalive(
                                //         &socket2::TcpKeepalive::new()
                                //             .with_time(CLIENT_HEARTBEAT_TCP_TIMEOUT)
                                //             .with_interval(CLIENT_HEARTBEAT_TCP_INTERVAL),
                                //     )
                                //     .unwrap();
                                socket
                                    .set_linger(Some(std::time::Duration::from_secs(30)))
                                    .unwrap();
                                let stream: TcpStream = socket.into();
                                let client_stream =
                                    tokio::net::TcpStream::from_std(stream).unwrap();

                                let (rh, wh) = client_stream.into_split();

                                let (tx, _) = tokio::sync::broadcast::channel::<(
                                    Packet,
                                    Option<SocketAddr>,
                                )>(10);
                                let tx_out = tx.clone();

                                let (client_sender_tx, client_sender_rx) =
                                    tokio::sync::mpsc::channel::<Packet>(10);

                                // reading stream from client
                                let current_ship = ship_kind.clone();

                                let ship_kind_for_disconnect = ship_kind.clone();
                                tokio::spawn(async move {
                                    Client::receive_from_socket(rh, tx, None).await;
                                    warn!("Client {:?} disconnected.", current_ship);
                                    {
                                        let mut lock = ships_lock_for_disconnect.lock().unwrap();
                                        lock.remove(&ship_kind_for_disconnect);
                                    }
                                    // TODO: when we get here the client is disconnected, so stop
                                    // all processing (block or send a wait signal) until the
                                    // client is back
                                });

                                // writing stream to client
                                tokio::spawn(async move {
                                    Client::send_to_socket(client_sender_rx, wh).await;
                                });

                                let ip_parsed = Ipv4Addr::from_str(ip).expect("Strange ip format");

                                let client_addr = crate::NetworkShipAddress {
                                    ip: ip_parsed.octets(),
                                    port: client_tcp_port,
                                    ship: generated_id,
                                    kind: ship_kind.clone(),
                                };

                                let current_clients_wait_for_ack = { *clwa.read().unwrap() };
                                let welcome_packet = Packet {
                                    header: Header {
                                        source: CONTROLLER_CLIENT_ID,
                                        target: generated_id,
                                    },
                                    data: PacketKind::Welcome {
                                        addr: client_addr.clone(),
                                        wait_for_ack: current_clients_wait_for_ack,
                                    },
                                };

                                match client_sender_tx.send(welcome_packet).await {
                                    Ok(_) => {}
                                    Err(e) => {
                                        error!("Could not send welcome packet to channel: {e}");
                                    }
                                }

                                let ship_handle = ShipHandle {
                                    ship: generated_id,
                                    disconnect: disconnect_tx,
                                    recv: tx_out,
                                    send: client_sender_tx,
                                    name: ship_kind,
                                    addr_from_coord: client_addr,
                                    other_client_port,
                                    remove_rules_on_disconnect,
                                };
                                curr_client_create_sender.send(ship_handle).unwrap();
                                debug!("ShipHandle created and sent");
                            });
                        }
                        _ => {
                            warn!("Received unexpected packet: {packet:?}");
                        }
                    }
                } else {
                    error!("Channel closed, no more rejoin requests can be received.");
                    // the sender is gone for good, so stop this task
                    return;
                }
            }
        });

        Self {
            network_clients_chan: clients_tx,
            dissolve_network: dissolve_network_tx,
        }
    }

    pub fn pad_string(input: &str) -> String {
        if input.len() >= 64 {
            return input.to_string(); // return the string if it's already 64 bytes or longer
        }
        let padding_count = 64 - input.len();
        let padding = "#".repeat(padding_count);
        format!("{}{}", input, padding)
    }

    pub fn reverse_padding(input: &str) -> String {
        let trimmed: &str = input.trim_end_matches('#');
        trimmed.to_string()
    }

    pub fn pad_ship_kind_name(kind: &ShipKind) -> ShipKind {
        match kind {
            ShipKind::Rat(name) => ShipKind::Rat(Self::pad_string(name)),
            ShipKind::Wind(name) => ShipKind::Wind(Self::pad_string(name)),
        }
    }

    pub fn unpad_ship_kind_name(kind: &ShipKind) -> ShipKind {
        match kind {
            ShipKind::Rat(name) => ShipKind::Rat(Self::reverse_padding(name)),
            ShipKind::Wind(name) => ShipKind::Wind(Self::reverse_padding(name)),
        }
    }
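
    // Illustrative round trip for the helpers above (a sketch). Two caveats
    // worth noting: `reverse_padding` also strips legitimate trailing '#'
    // characters from a name, and `pad_string` counts bytes, not chars.
    //
    //     let padded = Sea::pad_string("orca");
    //     assert_eq!(padded.len(), 64);
    //     assert_eq!(Sea::reverse_padding(&padded), "orca");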

    // TODO: this never worked
    pub async fn cleanup(&mut self) {
        let (answer_tx, mut answer_rx) = tokio::sync::mpsc::channel(1);
        if let Err(e) = self.dissolve_network.send(answer_tx).await {
            error!("Error while dropping network: {e}");
        }

        let answer_timeout = tokio::time::timeout(SERVER_DROP_TIMEOUT, answer_rx.recv());
        // tokio::task::yield_now().await;
        match answer_timeout.await {
            Err(e) => {
                error!("Timed out while dropping the network, not waiting for completion: {e}");
            }
            Ok(None) => {
                warn!("Sender already closed while waiting for the dissolve answer");
            }
            _ => {}
        }
    }
}
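
// A minimal usage sketch (illustrative; assumes a tokio runtime and clients
// reachable on the local network):
//
//     let wait_for_ack = std::sync::Arc::new(std::sync::RwLock::new(false));
//     let mut sea = Sea::init(None, wait_for_ack).await;
//     let mut new_ships = sea.network_clients_chan.subscribe();
//     while let Ok(handle) = new_ships.recv().await {
//         info!("ship {} joined", handle.ship);
//     }
//     sea.cleanup().await;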

// TODO: block_in_place blocks the current thread. So if the dissolve_network
// receivers run on the same thread, is this a deadlock, since we are blocking
// that thread?
// impl Drop for Sea {
//     fn drop(&mut self) {
//         tokio::task::block_in_place(move || {
//             let rt = tokio::runtime::Handle::current();
//             rt.block_on(self.cleanup());
//         });
//     }
// }

// impl std::future::AsyncDrop for Sea {
//     type Dropper<'a> = impl std::future::Future<Output = ()>;

//     fn async_drop(self: std::pin::Pin<&mut Self>) -> Self::Dropper<'_> {
//         println!("calling async drop");
//         async move {
//             let (answer_tx, mut answer_rx) = tokio::sync::mpsc::channel(1);
//             match self.dissolve_network.send(answer_tx).await {
//                 Err(e) => {
//                     error!("Error while dropping network: {e}");
//                 }
//                 Ok(_) => {}
//             }

//             let answer_timeout = tokio::time::timeout(SERVER_DROP_TIMEOUT, answer_rx.recv());
//             match answer_timeout.await {
//                 Err(e) => {
//                     error!("Timed out while dropping the network, not waiting for completion: {e}");
//                 }
//                 Ok(None) => {
//                     error!("Sender already closed while waiting for the dissolve answer");
//                 }
//                 _ => {}
//             }
//         }
//     }
// }