1use std::{
2 collections::{HashMap, HashSet},
3 net::{Ipv4Addr, SocketAddr, TcpStream},
4 str::FromStr,
5};
6
7use log::{debug, error, info, warn};
8use nalgebra::DMatrix;
9
10use rkyv::{Archive, Deserialize, Serialize, api::low::from_bytes, rancor};
11
12use crate::{Action, ShipKind, ShipName, VariableHuman, WindData, client::Client};
13
14pub const PROTO_IDENTIFIER: u8 = 69;
15pub const CONTROLLER_CLIENT_ID: ShipName = 0;
/// Client registration timeout. NOTE(review): not referenced in this file;
/// meaning inferred from the name — confirm at the use site.
pub const CLIENT_REGISTER_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(150);
/// UDP port the coordinator binds to receive `JoinRequest` datagrams.
pub const CLIENT_LISTEN_PORT: u16 = 6594;
/// Interval between a client's rejoin attempts (not referenced in this file).
pub const CLIENT_REJOIN_POLL_INTERVAL: std::time::Duration =
    std::time::Duration::from_millis(1_000);
/// Heartbeat timeout on the TCP link (not referenced in this file).
pub const CLIENT_HEARTBEAT_TCP_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(1_000);
/// Heartbeat interval on the TCP link (not referenced in this file).
pub const CLIENT_HEARTBEAT_TCP_INTERVAL: std::time::Duration =
    std::time::Duration::from_millis(200);
/// Maximum time `Sea::cleanup` waits for the dissolve handler's confirmation.
pub const SERVER_DROP_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);
/// Client-to-client timeout; `Duration::MAX` means "never". Do not add it to an
/// `Instant` directly — that overflows and panics.
pub const CLIENT_TO_CLIENT_TIMEOUT: std::time::Duration = std::time::Duration::MAX;
/// Delay between client-to-client connection set-up retries (not referenced in this file).
pub const CLIENT_TO_CLIENT_INIT_RETRY_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(50);
26
27pub fn get_domain_id() -> u16 {
28 let val = std::env::var("MINOT_DOMAIN_ID")
29 .ok()
30 .unwrap_or("0".to_owned());
31 let parsed = val.parse::<u16>().ok();
32 match parsed {
33 Some(parsed) => parsed,
34 None => {
35 warn!("Invalid MINOT_DOMAIN_ID, selecting default 0");
36 0
37 }
38 }
39}
40
/// Wind data paired with an optional variable name; carried in `PacketKind::Wind`.
///
/// Serialized with rkyv, so the field order is part of the wire format.
/// NOTE(review): `at_var` presumably names the variable the data belongs to
/// (`None` = no specific variable) — confirm against the producers.
#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
pub struct WindAt {
    pub data: WindData,
    pub at_var: Option<String>,
}
46
/// Serializable, shape-tagged copy of a dynamically sized nalgebra matrix.
///
/// Built from and converted back into [`DMatrix`]. `data` holds the elements
/// in the matrix's storage order (nalgebra stores matrices column-major).
/// Serialized with rkyv, so the field order is part of the wire format.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct NetArray<T: nalgebra::Scalar> {
    /// Number of matrix columns.
    cols: usize,
    /// Matrix elements; converting back to `DMatrix` requires exactly `rows * cols` of them.
    data: Vec<T>,
    /// Number of matrix rows.
    rows: usize,
}
54
55impl<T: nalgebra::Scalar> From<DMatrix<T>> for NetArray<T> {
56 fn from(value: DMatrix<T>) -> Self {
57 Self {
58 rows: value.nrows(),
59 cols: value.ncols(),
60 data: value.data.into(),
61 }
62 }
63}
64
65impl<T: nalgebra::Scalar> From<NetArray<T>> for DMatrix<T> {
66 fn from(value: NetArray<T>) -> Self {
67 Self::from_data(nalgebra::VecStorage::new(
68 nalgebra::Dyn(value.rows),
69 nalgebra::Dyn(value.cols),
70 value.data,
71 ))
72 }
73}
74
/// Whether a ship registers at a variable as publisher or subscriber
/// (see `PacketKind::RegisterShipAtVar`).
///
/// The derived `Ord` follows declaration order, and the type is serialized
/// with rkyv; keep the variant order stable.
#[derive(Serialize, Deserialize, Archive, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum RatPubRegisterKind {
    Publish,
    Subscribe,
}
80
/// Payload of a [`Packet`].
///
/// Serialized with rkyv. At startup the coordinator computes the exact byte
/// size of a `JoinRequest` carrying a padded ship name, and only decodes
/// join datagrams of exactly that size. Any change to this enum that alters
/// the `JoinRequest` encoding must therefore be rolled out to every peer at
/// once.
/// NOTE(review): the archived enum tag presumably follows variant order, so
/// reordering variants likely breaks compatibility as well — confirm.
#[derive(Serialize, Deserialize, Archive, Clone, Debug)]
pub enum PacketKind {
    Acknowledge,
    Retry,
    RequestVarSend(String),
    /// Sent by a ship over UDP (prefixed with `PROTO_IDENTIFIER`) to join the network.
    JoinRequest {
        /// TCP port on the sender's IP that the coordinator connects back to.
        tcp_port: u16,
        /// Stored as `ShipHandle::other_client_port`.
        other_client_entrance: u16,
        /// Ship identity; the name is padded with `Sea::pad_ship_kind_name`.
        kind: ShipKind,
        /// Stored as `ShipHandle::remove_rules_on_disconnect`.
        remove_rules_on_disconnect: bool,
        /// Requests whose domain differs from the coordinator's `get_domain_id()` are ignored.
        domain_id: u16,
    },
    /// First packet the coordinator queues on a newly connected ship's TCP link.
    Welcome {
        /// Address and randomly generated ID assigned to the ship.
        addr: crate::NetworkShipAddress,
        /// Value of the coordinator's `clients_wait_for_ack` flag when the ship joined.
        wait_for_ack: bool,
    },
    Heartbeat,
    Disconnect,
    RuleAppend {
        variable: String,
        commands: Vec<VariableHuman>,
    },
    RulesClear,
    LockNext {
        unlock_first: bool,
    },
    Unlock,
    RawDataf64(NetArray<f64>),
    RawDataf32(NetArray<f32>),
    RawDatai32(NetArray<i32>),
    RawDatau8(NetArray<u8>),
    VariableTaskRequest(String),
    RatAction {
        action: Action,
        lock_until_ack: bool,
    },
    Wind(Vec<WindAt>),
    WindDynamic(String),
    RegisterShipAtVar {
        ship: String,
        var: String,
        kind: RatPubRegisterKind,
    },
}
125
/// Routing information of a [`Packet`].
#[derive(Archive, Serialize, Deserialize, Copy, Clone, Debug, Default)]
pub struct Header {
    /// ID of the sending ship (`CONTROLLER_CLIENT_ID` when the coordinator sends).
    pub source: ShipName,
    /// ID of the receiving ship (`CONTROLLER_CLIENT_ID` when addressed to the coordinator).
    pub target: ShipName,
}
144
/// Unit of communication between ships: a routing [`Header`] plus a
/// [`PacketKind`] payload. Serialized with rkyv.
#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
pub struct Packet {
    pub header: Header,
    pub data: PacketKind,
}
150
/// Coordinator-side handle to one admitted ship.
///
/// Broadcast on `Sea::network_clients_chan` after the coordinator has
/// connected back to the ship over TCP and queued its `Welcome` packet.
#[derive(Clone, Debug)]
pub struct ShipHandle {
    /// Ship kind with its name already unpadded.
    pub name: ShipKind,
    /// Address assigned by the coordinator (same value as in the `Welcome` packet).
    pub addr_from_coord: crate::NetworkShipAddress,
    /// Randomly generated ID of the ship.
    pub ship: ShipName,
    /// The coordinator sends `true` here when the network is dissolved via `Sea::cleanup`.
    pub disconnect: tokio::sync::broadcast::Sender<bool>,
    /// Sender handed to `Client::receive_from_socket` for the ship's TCP read half;
    /// presumably incoming packets are broadcast here — subscribe to receive them.
    pub recv: tokio::sync::broadcast::Sender<(Packet, Option<SocketAddr>)>,
    /// Packets queued here are drained by `Client::send_to_socket` into the ship's TCP write half.
    pub send: tokio::sync::mpsc::Sender<Packet>,
    /// `other_client_entrance` from the ship's join request.
    pub other_client_port: u16,
    /// Copied from the ship's join request.
    pub remove_rules_on_disconnect: bool,
}
167
/// Coordinator of the ship network, created by `Sea::init`.
#[derive(Debug)]
pub struct Sea {
    /// Broadcasts a [`ShipHandle`] for every ship admitted to the network.
    pub network_clients_chan: tokio::sync::broadcast::Sender<ShipHandle>,
    /// Requests network dissolution; the enclosed sender is used to confirm completion.
    dissolve_network: tokio::sync::mpsc::Sender<tokio::sync::mpsc::Sender<()>>,
}
173
174impl Sea {
176 pub async fn init(
177 external_ip: Option<[u8; 4]>,
178 clients_wait_for_ack: std::sync::Arc<std::sync::RwLock<bool>>,
179 ) -> Self {
180 let (rejoin_req_tx, mut rejoin_req_rx) = tokio::sync::mpsc::channel::<(String, Packet)>(10);
181
182 let (clients_tx, mut clients_rx) = tokio::sync::broadcast::channel::<ShipHandle>(10);
183
184 let (dissolve_network_tx, mut dissolve_network_rx) =
185 tokio::sync::mpsc::channel::<tokio::sync::mpsc::Sender<()>>(10);
186
187 tokio::spawn(async move {
189 let mut clients: Vec<tokio::sync::broadcast::Sender<bool>> = Vec::new();
190 loop {
191 tokio::select! {
192 answer = dissolve_network_rx.recv() => {
193 match answer {
194 None => {
195 return;
197 }
198 Some(answer) => {
199 for c in clients.iter() {
200 c.send(true).unwrap();
201 }
202 answer.send(()).await.unwrap();
204 return;
205 }
206 }
207 }
208 newclient = clients_rx.recv() => {
209 match newclient {
210 Err(e) => {
211 error!("Error receiving new client in dissolve handler: {e}");
212 }
213 Ok(client) => {
214 clients.push(client.disconnect);
215 }
216 }
217 }
218 }
219 }
220 });
221
222 tokio::spawn(async move {
224 let coordinator_domain_id = get_domain_id();
225 if coordinator_domain_id > 0 {
226 info!("Coordinator using domain ID {}", coordinator_domain_id);
227 }
228 let udp_listener = Client::get_udp_socket(external_ip, Some(CLIENT_LISTEN_PORT)).await;
229 let rejoin_request = Packet {
230 header: Header {
231 source: ShipName::MAX,
232 target: CONTROLLER_CLIENT_ID,
233 },
234 data: PacketKind::JoinRequest {
236 tcp_port: 0,
237 other_client_entrance: 0,
238 kind: Sea::pad_ship_kind_name(&ShipKind::Rat("".to_string())),
239 remove_rules_on_disconnect: false,
240 domain_id: 0,
241 },
242 };
243 let bytes_rejoin_request = rkyv::api::high::to_bytes::<rancor::Error>(&rejoin_request)
244 .expect("could not serialize rejoin request");
245 let expected_n_bytes_for_rejoin_request = bytes_rejoin_request.len();
246 let mut new_clients_without_response = HashMap::<String, (usize, Vec<u8>)>::new();
247
248 info!("Listening {:?}", udp_listener.local_addr().unwrap());
249 debug!(
250 "Expecting {} bytes for JoinRequest including PROTO_IDENTIFIER",
251 expected_n_bytes_for_rejoin_request + 1
252 );
253
254 loop {
255 let mut buf = [0; 256]; let (n, addr) = udp_listener.recv_from(&mut buf).await.unwrap();
257 let id = format!("{}:{}", addr.ip(), addr.port());
258 debug!("Receiving {} bytes from {} via UDP", n, id);
259
260 match new_clients_without_response.get_mut(&id) {
261 Some((kum, buffer)) => {
262 *kum += n;
263 buffer.extend_from_slice(&buf[..n]);
264 }
265 None => {
266 let mut buffer = Vec::with_capacity(1024);
267 if buf[0] != PROTO_IDENTIFIER {
268 continue; } else {
270 buffer.extend_from_slice(&buf[1..n]);
271 }
272 new_clients_without_response.insert(id, (buffer.len(), buffer));
273 }
274 }
275
276 let mut to_delete = Vec::<String>::new();
277 for (id, (kum, buffer)) in new_clients_without_response.iter_mut() {
278 if *kum != expected_n_bytes_for_rejoin_request {
279 continue;
280 }
281
282 let packet: Packet = match from_bytes::<Packet, rancor::Error>(buffer) {
283 Err(e) => {
284 error!("Received package is broken: {e}");
285 continue;
286 }
287 Ok(packet) => packet,
288 };
289
290 match rejoin_req_tx.send((id.clone(), packet)).await {
291 Err(e) => {
292 error!("Could not send rejoin request to internal channel: {e}");
293 }
294 Ok(_) => {
295 to_delete.push(id.clone());
296 }
297 };
298 }
299
300 for id in to_delete {
301 new_clients_without_response.remove(&id);
302 }
303
304 tokio::task::yield_now().await; }
306 });
307
308 let clients_tx_inner = clients_tx.clone();
310 tokio::spawn(async move {
311 let coordinator_domain_id = get_domain_id();
312 let rat_lock = std::sync::Arc::new(std::sync::Mutex::new(HashSet::new()));
313 loop {
315 let receive = rejoin_req_rx.recv().await;
316 if let Some((addr, packet)) = receive {
317 match packet.data {
318 PacketKind::JoinRequest {
319 tcp_port: client_tcp_port,
320 other_client_entrance: other_client_port,
321 kind: ship_kind,
322 remove_rules_on_disconnect,
323 domain_id: client_domain_id,
324 } => {
325 if client_domain_id != coordinator_domain_id {
327 debug!(
328 "Rejecting join request from domain {} (coordinator is domain {})",
329 client_domain_id, coordinator_domain_id
330 );
331 continue;
332 }
333
334 let ship_kind = Sea::unpad_ship_kind_name(&ship_kind);
335 debug!("Received RejoinRequest: {:?} from {:?}", ship_kind, addr);
336 {
337 let mut lock = rat_lock.lock().unwrap();
338 if lock.get(&ship_kind).is_some() {
339 debug!(
340 "requested client already exists or is in the progress of joining the network"
341 );
342 continue;
343 }
344 lock.insert(ship_kind.clone());
345 }
346 let generated_id = rand::random::<ShipName>().abs();
347 let (disconnect_tx, _disconnect_rx) =
348 tokio::sync::broadcast::channel::<bool>(1);
349
350 let curr_client_create_sender = clients_tx_inner.clone();
352 let ships_lock_for_disconnect = std::sync::Arc::clone(&rat_lock);
353 let clwa = std::sync::Arc::clone(&clients_wait_for_ack);
354 tokio::spawn(async move {
355 let ip = addr.split(':').next().unwrap();
356 let client_stream = tokio::net::TcpStream::connect(format!(
357 "{}:{}",
358 ip, client_tcp_port
359 ))
360 .await
361 .expect("could not connect to client");
362
363 let socket =
364 socket2::Socket::from(client_stream.into_std().unwrap());
365 socket.set_keepalive(true).unwrap();
366
367 socket
375 .set_linger(Some(std::time::Duration::from_secs(30)))
376 .unwrap();
377 let stream: TcpStream = socket.into();
378 let client_stream =
379 tokio::net::TcpStream::from_std(stream).unwrap();
380
381 let (rh, wh) = client_stream.into_split();
382
383 let (tx, _) = tokio::sync::broadcast::channel::<(
384 Packet,
385 Option<SocketAddr>,
386 )>(10);
387 let tx_out = tx.clone();
388
389 let (client_sender_tx, client_sender_rx) =
390 tokio::sync::mpsc::channel::<Packet>(10);
391
392 let current_ship = ship_kind.clone();
394
395 let ship_kind_for_disconnect = ship_kind.clone();
396 tokio::spawn(async move {
397 Client::receive_from_socket(rh, tx, None).await;
398 warn!("Client {:?} disconnected.", current_ship);
399 {
400 let mut lock = ships_lock_for_disconnect.lock().unwrap();
401 lock.remove(&ship_kind_for_disconnect);
402 }
403 });
405
406 tokio::spawn(async move {
408 Client::send_to_socket(client_sender_rx, wh).await;
409 });
410
411 let ip_parsed = Ipv4Addr::from_str(ip).expect("Strange ip format");
412
413 let client_addr = crate::NetworkShipAddress {
414 ip: ip_parsed.octets(),
415 port: client_tcp_port,
416 ship: generated_id,
417 kind: ship_kind.clone(),
418 };
419
420 let current_clients_wait_for_ack = { *clwa.read().unwrap() };
421 let welcome_packet = Packet {
422 header: Header {
423 source: CONTROLLER_CLIENT_ID,
424 target: generated_id,
425 },
426 data: PacketKind::Welcome {
427 addr: client_addr.clone(),
428 wait_for_ack: current_clients_wait_for_ack,
429 },
430 };
431
432 match client_sender_tx.send(welcome_packet).await {
433 Ok(_) => {}
434 Err(e) => {
435 error!("Could not send welcome packet to channel: {e}");
436 }
437 }
438
439 let ship_handle = ShipHandle {
440 ship: generated_id,
441 disconnect: disconnect_tx,
442 recv: tx_out,
443 send: client_sender_tx,
444 name: ship_kind,
445 addr_from_coord: client_addr,
446 other_client_port,
447 remove_rules_on_disconnect,
448 };
449 curr_client_create_sender.send(ship_handle).unwrap();
450 debug!("ShipHandle created and sent");
451 });
452 }
453 _ => {
454 warn!("Received unexpected packet: {packet:?}");
455 }
456 }
457 } else {
458 error!("Channel closed, could not receive rejoin requests in channel.");
459 }
460 }
461 });
462
463 Self {
464 network_clients_chan: clients_tx,
465 dissolve_network: dissolve_network_tx,
466 }
467 }
468
469 pub fn pad_string(input: &str) -> String {
470 if input.len() >= 64 {
471 return input.to_string(); }
473 let padding_count = 64 - input.len();
474 let padding = "#".repeat(padding_count);
475 format!("{}{}", input, padding)
476 }
477
478 pub fn reverse_padding(input: &str) -> String {
479 let trimmed: &str = input.trim_end_matches('#');
480 trimmed.to_string()
481 }
482
483 pub fn pad_ship_kind_name(kind: &ShipKind) -> ShipKind {
484 match kind {
485 ShipKind::Rat(name) => ShipKind::Rat(Self::pad_string(name)),
486 ShipKind::Wind(name) => ShipKind::Wind(Self::pad_string(name)),
487 }
488 }
489
490 pub fn unpad_ship_kind_name(kind: &ShipKind) -> ShipKind {
491 match kind {
492 ShipKind::Rat(name) => ShipKind::Rat(Self::reverse_padding(name)),
493 ShipKind::Wind(name) => ShipKind::Wind(Self::reverse_padding(name)),
494 }
495 }
496
497 pub async fn cleanup(&mut self) {
499 let (answer_tx, mut answer_rx) = tokio::sync::mpsc::channel(1);
500 if let Err(e) = self.dissolve_network.send(answer_tx).await {
501 error!("Error while droppping network: {e}");
502 }
503
504 let answer_timeout = tokio::time::timeout(SERVER_DROP_TIMEOUT, answer_rx.recv());
505 match answer_timeout.await {
507 Err(e) => {
508 error!("Dropping network timeout, discarding waiting for completion: {e}");
509 }
510 Ok(None) => {
511 warn!("Sender already closed in dissolving answer");
512 }
513 _ => {}
514 }
515 }
516}
517
518