collab_server/
server.rs

1use std::net::{Ipv4Addr, SocketAddrV4};
2use std::sync::Arc;
3
4use async_compat::Compat;
5use futures_rustls::TlsAcceptor;
6use rustls::ServerConfig;
7use tokio::net::TcpListener;
8use tokio::sync::mpsc;
9
10use crate::*;
11
/// The collaboration server.
///
/// Accepts TCP (optionally TLS) connections and keeps track of the sessions
/// that connected peers start, join and leave. Create one with
/// [`Server::new`] or [`Server::builder`], then drive it with [`Server::run`].
pub struct Server {
    /// The sessions currently known to the server.
    ///
    /// Only mutated by `handle_message`, in response to the
    /// [`SessionMessage`]s sent by connections.
    active_sessions: Sessions,

    /// The socket the server is listening on.
    listening_socket: SocketAddrV4,

    /// Optional TLS configuration. When set, every accepted TCP stream goes
    /// through a TLS handshake before being handed to [`Connection::run`];
    /// otherwise connections are served over plain TCP.
    tls_config: Option<ServerConfig>,

    /// Test-only receiving end of the message channel. It is created by
    /// `ServerBuilder::build` and taken by `Server::run`.
    #[cfg(any(test, feature = "__tests"))]
    receiver: Option<mpsc::Receiver<SessionMessage>>,

    /// Test-only sending end of the message channel. It is handed out through
    /// `Server::test_sender` so tests can message the server, and is taken by
    /// `Server::run`.
    #[cfg(any(test, feature = "__tests"))]
    sender: Option<mpsc::Sender<SessionMessage>>,
}
29
30impl Default for Server {
31    #[inline]
32    fn default() -> Self {
33        Self::builder().build()
34    }
35}
36
37impl Server {
38    /// TODO: docs
39    #[inline]
40    pub fn builder() -> ServerBuilder {
41        ServerBuilder::default()
42    }
43
44    /// Handles a message received from a connection.
45    #[inline]
46    fn handle_message(&mut self, msg: SessionMessage) {
47        use SessionMessage::*;
48
49        match msg {
50            JoinSession(peer_id, session_id, sender) => {
51                let res = self.active_sessions.join(session_id, peer_id);
52                let _ = sender.send(res);
53            },
54
55            PeerLeft(peer_id, session_id) => {
56                self.active_sessions.remove_peer(session_id, peer_id);
57            },
58
59            StartSession(peer_id, sender) => {
60                let infos = self.active_sessions.start(peer_id);
61                let _ = sender.send(infos);
62            },
63
64            #[cfg(any(test, feature = "__tests"))]
65            Ping(sender) => {
66                let _ = sender.send(());
67            },
68
69            #[cfg(any(test, feature = "__tests"))]
70            Sessions(sender) => {
71                let _ = sender.send(self.active_sessions.clone());
72            },
73        }
74    }
75
76    /// TODO: docs
77    #[inline]
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// TODO: docs
83    #[inline]
84    pub async fn run(mut self) -> Result<(), anyhow::Error> {
85        let tcp_listener = TcpListener::bind(self.listening_socket).await?;
86
87        let tls_acceptor =
88            self.tls_config.take().map(Arc::new).map(TlsAcceptor::from);
89
90        #[cfg(not(any(test, feature = "__tests")))]
91        let (sender, mut receiver) = mpsc::channel(1024);
92
93        #[cfg(any(test, feature = "__tests"))]
94        let (sender, mut receiver) =
95            (self.sender.take().unwrap(), self.receiver.take().unwrap());
96
97        tokio::spawn(async move {
98            while let Some(msg) = receiver.recv().await {
99                self.handle_message(msg);
100            }
101        });
102
103        loop {
104            let (tcp_stream, client_addr) = tcp_listener.accept().await?;
105
106            let tcp_stream = Compat::new(tcp_stream);
107
108            let sender = sender.clone();
109
110            if let Some(tls_acceptor) = &tls_acceptor {
111                let tls_stream = tls_acceptor.accept(tcp_stream).await?;
112                let run = Connection::run(tls_stream, client_addr, sender);
113                tokio::spawn(run);
114            } else {
115                let run = Connection::run(tcp_stream, client_addr, sender);
116                tokio::spawn(run);
117            }
118        }
119    }
120
121    /// TODO: docs
122    #[cfg(any(test, feature = "__tests"))]
123    pub fn test_sender(&self) -> TestSender {
124        TestSender::new(self.sender.as_ref().unwrap().clone())
125    }
126}
127
/// Builder for [`Server`].
///
/// Obtain one via [`Server::builder`] or [`ServerBuilder::default`].
/// Configure it with the `with_*` methods and finish with
/// [`ServerBuilder::build`].
#[derive(Clone)]
pub struct ServerBuilder {
    /// The IP address to run the server on.
    ip: Ipv4Addr,

    /// The port the server will listen on.
    port: u16,

    /// Optional TLS configuration passed on to the built [`Server`]. `None`
    /// means connections are served over plain TCP.
    tls_config: Option<ServerConfig>,
}
140
141impl Default for ServerBuilder {
142    #[inline]
143    fn default() -> Self {
144        Self {
145            ip: Ipv4Addr::new(0, 0, 0, 0),
146            port: common::SERVER_LISTENING_PORT,
147            tls_config: None,
148        }
149    }
150}
151
152impl ServerBuilder {
153    /// TODO: docs
154    #[inline]
155    pub fn build(self) -> Server {
156        #[cfg(any(test, feature = "__tests"))]
157        let (sender, receiver) = mpsc::channel(1024);
158
159        Server {
160            active_sessions: Sessions::new(),
161            listening_socket: SocketAddrV4::new(self.ip, self.port),
162            tls_config: self.tls_config,
163            #[cfg(any(test, feature = "__tests"))]
164            sender: Some(sender),
165            #[cfg(any(test, feature = "__tests"))]
166            receiver: Some(receiver),
167        }
168    }
169
170    /// TODO: docs
171    #[inline]
172    pub fn with_ip(mut self, ip: Ipv4Addr) -> Self {
173        self.ip = ip;
174        self
175    }
176
177    /// TODO: docs
178    #[inline]
179    pub fn with_port(mut self, port: u16) -> Self {
180        self.port = port;
181        self
182    }
183
184    /// TODO: docs
185    #[inline]
186    pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
187        self.tls_config = Some(config);
188        self
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A second server cannot start on a socket that a running server is
    /// already listening on.
    #[tokio::test]
    async fn server_start_twice() -> Result<(), anyhow::Error> {
        let first_server = Server::new();
        let test_sender = first_server.test_sender();

        tokio::spawn(async move { first_server.run().await });

        // `run` only spawns its message loop after binding the socket. A
        // reply to `start` (presumably answered by that loop) therefore means
        // the first server is listening.
        test_sender.start().await?;

        let second_run = Server::new().run().await;
        assert!(
            second_run.is_err(),
            "running the server should fail if the socket is being used"
        );

        Ok(())
    }
}
215}