1use std::net::{Ipv4Addr, SocketAddrV4};
2use std::sync::Arc;
3
4use async_compat::Compat;
5use futures_rustls::TlsAcceptor;
6use rustls::ServerConfig;
7use tokio::net::TcpListener;
8use tokio::sync::mpsc;
9
10use crate::*;
11
12pub struct Server {
14 active_sessions: Sessions,
16
17 listening_socket: SocketAddrV4,
19
20 tls_config: Option<ServerConfig>,
22
23 #[cfg(any(test, feature = "__tests"))]
24 receiver: Option<mpsc::Receiver<SessionMessage>>,
25
26 #[cfg(any(test, feature = "__tests"))]
27 sender: Option<mpsc::Sender<SessionMessage>>,
28}
29
30impl Default for Server {
31 #[inline]
32 fn default() -> Self {
33 Self::builder().build()
34 }
35}
36
37impl Server {
38 #[inline]
40 pub fn builder() -> ServerBuilder {
41 ServerBuilder::default()
42 }
43
44 #[inline]
46 fn handle_message(&mut self, msg: SessionMessage) {
47 use SessionMessage::*;
48
49 match msg {
50 JoinSession(peer_id, session_id, sender) => {
51 let res = self.active_sessions.join(session_id, peer_id);
52 let _ = sender.send(res);
53 },
54
55 PeerLeft(peer_id, session_id) => {
56 self.active_sessions.remove_peer(session_id, peer_id);
57 },
58
59 StartSession(peer_id, sender) => {
60 let infos = self.active_sessions.start(peer_id);
61 let _ = sender.send(infos);
62 },
63
64 #[cfg(any(test, feature = "__tests"))]
65 Ping(sender) => {
66 let _ = sender.send(());
67 },
68
69 #[cfg(any(test, feature = "__tests"))]
70 Sessions(sender) => {
71 let _ = sender.send(self.active_sessions.clone());
72 },
73 }
74 }
75
76 #[inline]
78 pub fn new() -> Self {
79 Self::default()
80 }
81
82 #[inline]
84 pub async fn run(mut self) -> Result<(), anyhow::Error> {
85 let tcp_listener = TcpListener::bind(self.listening_socket).await?;
86
87 let tls_acceptor =
88 self.tls_config.take().map(Arc::new).map(TlsAcceptor::from);
89
90 #[cfg(not(any(test, feature = "__tests")))]
91 let (sender, mut receiver) = mpsc::channel(1024);
92
93 #[cfg(any(test, feature = "__tests"))]
94 let (sender, mut receiver) =
95 (self.sender.take().unwrap(), self.receiver.take().unwrap());
96
97 tokio::spawn(async move {
98 while let Some(msg) = receiver.recv().await {
99 self.handle_message(msg);
100 }
101 });
102
103 loop {
104 let (tcp_stream, client_addr) = tcp_listener.accept().await?;
105
106 let tcp_stream = Compat::new(tcp_stream);
107
108 let sender = sender.clone();
109
110 if let Some(tls_acceptor) = &tls_acceptor {
111 let tls_stream = tls_acceptor.accept(tcp_stream).await?;
112 let run = Connection::run(tls_stream, client_addr, sender);
113 tokio::spawn(run);
114 } else {
115 let run = Connection::run(tcp_stream, client_addr, sender);
116 tokio::spawn(run);
117 }
118 }
119 }
120
121 #[cfg(any(test, feature = "__tests"))]
123 pub fn test_sender(&self) -> TestSender {
124 TestSender::new(self.sender.as_ref().unwrap().clone())
125 }
126}
127
128#[derive(Clone)]
130pub struct ServerBuilder {
131 ip: Ipv4Addr,
133
134 port: u16,
136
137 tls_config: Option<ServerConfig>,
139}
140
141impl Default for ServerBuilder {
142 #[inline]
143 fn default() -> Self {
144 Self {
145 ip: Ipv4Addr::new(0, 0, 0, 0),
146 port: common::SERVER_LISTENING_PORT,
147 tls_config: None,
148 }
149 }
150}
151
152impl ServerBuilder {
153 #[inline]
155 pub fn build(self) -> Server {
156 #[cfg(any(test, feature = "__tests"))]
157 let (sender, receiver) = mpsc::channel(1024);
158
159 Server {
160 active_sessions: Sessions::new(),
161 listening_socket: SocketAddrV4::new(self.ip, self.port),
162 tls_config: self.tls_config,
163 #[cfg(any(test, feature = "__tests"))]
164 sender: Some(sender),
165 #[cfg(any(test, feature = "__tests"))]
166 receiver: Some(receiver),
167 }
168 }
169
170 #[inline]
172 pub fn with_ip(mut self, ip: Ipv4Addr) -> Self {
173 self.ip = ip;
174 self
175 }
176
177 #[inline]
179 pub fn with_port(mut self, port: u16) -> Self {
180 self.port = port;
181 self
182 }
183
184 #[inline]
186 pub fn with_tls_config(mut self, config: ServerConfig) -> Self {
187 self.tls_config = Some(config);
188 self
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A second server on the same address must fail to start while the first
    /// one still holds the listening socket.
    #[tokio::test]
    async fn server_start_twice() -> Result<(), anyhow::Error> {
        let first = Server::new();
        let probe = first.test_sender();

        tokio::spawn(async move { first.run().await });

        // NOTE(review): presumably `start` round-trips through the running
        // server, guaranteeing it is listening before the second bind — confirm.
        probe.start().await?;

        let second = Server::new();
        let outcome = second.run().await;

        assert!(
            outcome.is_err(),
            "running the server should fail if the socket is being used"
        );

        Ok(())
    }
}