irox_networking/
pool.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2023 IROX Contributors
3
4use std::fmt::Debug;
5use std::io::Write;
6use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::thread;
10use std::thread::JoinHandle;
11
12use log::{error, info};
13
/// Boxed, mutable callback that receives a TCP stream and a socket address.
// NOTE(review): not used within this file section; presumably invoked when a client
// connects (stream + peer address) — confirm against callers.
pub type OnConnectionCallback = Box<dyn FnMut(&std::net::TcpStream, &std::net::SocketAddr)>;

/// Boxed, immutable callback that receives a TCP stream.
// NOTE(review): not used within this file section; presumably a per-connection worker —
// confirm against callers.
pub type ConnectionWorker = Box<dyn Fn(&std::net::TcpStream)>;
16
/// Listens on a TCP address in a background thread and keeps every accepted client in a
/// shared list, so data can be written to all of them at once.
///
/// Created with [`TCPConnectionManager::start`]; streams whose writes fail are dropped from
/// the list by [`TCPConnectionManager::write_to_all_connected`].
pub struct TCPConnectionManager {
    /// Accepted client streams. Shared with the accept thread, which pushes new clients
    /// into it; guarded by a mutex because both threads touch it.
    active_connections: Arc<Mutex<Vec<TcpStream>>>,
    /// Handle to the background accept thread spawned by `start`; consumed by `join`.
    running_thread: JoinHandle<()>,
}
21
22impl TCPConnectionManager {
23    pub fn start<A: ToSocketAddrs + Debug>(
24        addr: A,
25        close: Arc<AtomicBool>,
26    ) -> Result<TCPConnectionManager, std::io::Error> {
27        let mut addr: Vec<SocketAddr> = match addr.to_socket_addrs() {
28            Ok(a) => a.collect(),
29            Err(e) => {
30                error!("Error converting {addr:?} to socketaddr");
31                return Err(e);
32            }
33        };
34        info!("Collected SocketAddrs: {addr:?}");
35        let Some(addr) = addr.pop() else {
36            return Err(std::io::ErrorKind::InvalidInput.into());
37        };
38
39        let sock = match TcpListener::bind(addr) {
40            Ok(s) => s,
41            Err(e) => {
42                error!("Error binding to address {:?}: {e:?}", &addr);
43                return Err(e);
44            }
45        };
46
47        let active_connections = Arc::new(Mutex::new(Vec::new()));
48
49        let conns = active_connections.clone();
50        let handle = thread::spawn(move || {
51            while !close.load(Ordering::Relaxed) {
52                let client = match sock.accept() {
53                    Ok(c) => c,
54                    Err(e) => {
55                        error!("SocketAccept error: {e:?}");
56                        continue;
57                    }
58                };
59                info!("New client connected: {}", client.1);
60
61                let Ok(ref mut conns) = conns.lock() else {
62                    continue;
63                };
64                conns.push(client.0);
65            }
66        });
67
68        Ok(TCPConnectionManager {
69            active_connections,
70            running_thread: handle,
71        })
72    }
73
74    pub fn join(self) -> thread::Result<()> {
75        self.running_thread.join()
76    }
77
78    pub fn write_to_all_connected(&mut self, data: &[u8]) {
79        let Ok(ref mut conns) = self.active_connections.lock() else {
80            return;
81        };
82        conns.retain_mut(|x| {
83            let Ok(()) = x.write_all(data) else {
84                // remove and close the TCP stream if there was an error writing to it.
85                return false;
86            };
87            true
88        });
89    }
90
91    pub fn for_each_connected<T: FnMut(&mut TcpStream) -> bool>(&mut self, func: T) {
92        let Ok(ref mut conns) = self.active_connections.lock() else {
93            return;
94        };
95        conns.retain_mut(func);
96    }
97}