1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use distant_auth::Verifier;
7use log::*;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use tokio::sync::{broadcast, RwLock};
11
12use crate::common::{ConnectionId, Listener, Response, Transport, Version};
13
14mod builder;
15pub use builder::*;
16
17mod config;
18pub use config::*;
19
20mod connection;
21use connection::*;
22
23mod context;
24pub use context::*;
25
26mod r#ref;
27pub use r#ref::*;
28
29mod reply;
30pub use reply::*;
31
32mod state;
33use state::*;
34
35mod shutdown_timer;
36use shutdown_timer::*;
37
38pub struct Server<T> {
40 config: ServerConfig,
42
43 handler: T,
45
46 verifier: Verifier,
48
49 version: Version,
51}
52
53#[async_trait]
55pub trait ServerHandler: Send {
56 type Request;
58
59 type Response;
61
62 #[allow(unused_variables)]
64 async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
65 Ok(())
66 }
67
68 #[allow(unused_variables)]
70 async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
71 Ok(())
72 }
73
74 async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
77}
78
79impl Server<()> {
80 pub fn new() -> Self {
83 Self {
84 config: Default::default(),
85 handler: (),
86 verifier: Verifier::empty(),
87 version: Default::default(),
88 }
89 }
90
91 pub fn tcp() -> TcpServerBuilder<()> {
93 TcpServerBuilder::default()
94 }
95
96 #[cfg(unix)]
98 pub fn unix_socket() -> UnixSocketServerBuilder<()> {
99 UnixSocketServerBuilder::default()
100 }
101
102 #[cfg(windows)]
104 pub fn windows_pipe() -> WindowsPipeServerBuilder<()> {
105 WindowsPipeServerBuilder::default()
106 }
107}
108
109impl Default for Server<()> {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl<T> Server<T> {
116 pub fn config(self, config: ServerConfig) -> Self {
118 Self {
119 config,
120 handler: self.handler,
121 verifier: self.verifier,
122 version: self.version,
123 }
124 }
125
126 pub fn handler<U>(self, handler: U) -> Server<U> {
128 Server {
129 config: self.config,
130 handler,
131 verifier: self.verifier,
132 version: self.version,
133 }
134 }
135
136 pub fn verifier(self, verifier: Verifier) -> Self {
138 Self {
139 config: self.config,
140 handler: self.handler,
141 verifier,
142 version: self.version,
143 }
144 }
145
146 pub fn version(self, version: Version) -> Self {
148 Self {
149 config: self.config,
150 handler: self.handler,
151 verifier: self.verifier,
152 version,
153 }
154 }
155}
156
157impl<T> Server<T>
158where
159 T: ServerHandler + Sync + 'static,
160 T::Request: DeserializeOwned + Send + Sync + 'static,
161 T::Response: Serialize + Send + 'static,
162{
163 pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
166 where
167 L: Listener + 'static,
168 L::Output: Transport + 'static,
169 {
170 let state = Arc::new(ServerState::new());
171 let (tx, rx) = broadcast::channel(1);
172 let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
173
174 Ok(ServerRef { shutdown: tx, task })
175 }
176
177 async fn task<L>(
179 self,
180 state: Arc<ServerState<Response<T::Response>>>,
181 mut listener: L,
182 shutdown_tx: broadcast::Sender<()>,
183 shutdown_rx: broadcast::Receiver<()>,
184 ) where
185 L: Listener + 'static,
186 L::Output: Transport + 'static,
187 {
188 let Server {
189 config,
190 handler,
191 verifier,
192 version,
193 } = self;
194
195 let handler = Arc::new(handler);
196 let timer = ShutdownTimer::start(config.shutdown);
197 let mut notification = timer.clone_notification();
198 let timer = Arc::new(RwLock::new(timer));
199 let verifier = Arc::new(verifier);
200
201 let mut connection_tasks = Vec::new();
202 loop {
203 let transport = tokio::select! {
206 result = listener.accept() => {
207 match result {
208 Ok(x) => x,
209 Err(x) => {
210 error!("Server no longer accepting connections: {x}");
211 timer.read().await.abort();
212 break;
213 }
214 }
215 }
216 _ = notification.wait() => {
217 info!(
218 "Server shutdown triggered after {}s",
219 config.shutdown.duration().unwrap_or_default().as_secs_f32(),
220 );
221
222 let _ = shutdown_tx.send(());
223
224 break;
225 }
226 };
227
228 timer.read().await.stop();
230
231 connection_tasks.push(
232 ConnectionTask::build()
233 .handler(Arc::downgrade(&handler))
234 .state(Arc::downgrade(&state))
235 .keychain(state.keychain.clone())
236 .transport(transport)
237 .shutdown(shutdown_rx.resubscribe())
238 .shutdown_timer(Arc::downgrade(&timer))
239 .sleep_duration(config.connection_sleep)
240 .heartbeat_duration(config.connection_heartbeat)
241 .verifier(Arc::downgrade(&verifier))
242 .version(version.clone())
243 .spawn(),
244 );
245
246 connection_tasks.retain(|task| !task.is_finished());
248 }
249
250 info!("Server waiting for active connections to terminate");
252 loop {
253 connection_tasks.retain(|task| !task.is_finished());
254 if connection_tasks.is_empty() {
255 break;
256 }
257 tokio::time::sleep(Duration::from_millis(50)).await;
258 }
259 info!("Server task terminated");
260 }
261}
262
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_trait::async_trait;
    use distant_auth::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod};
    use test_log::test;
    use tokio::sync::mpsc;

    use super::*;
    use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response};

    /// Delay used for every `Shutdown::Lonely` / `Shutdown::After` configuration below
    const SHUTDOWN_DELAY: Duration = Duration::from_millis(100);

    /// How long a test waits before checking whether the server task has finished
    const SETTLE_TIME: Duration = Duration::from_millis(300);

    /// Version reported by the test server and expected by the test client
    macro_rules! server_version {
        () => {
            Version::new(1, 2, 3)
        };
    }

    /// Handler that answers every request with the string "hello"
    pub struct TestServerHandler;

    #[async_trait]
    impl ServerHandler for TestServerHandler {
        type Request = u16;
        type Response = String;

        async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
            // The request payload is irrelevant; always send the same reply
            ctx.reply.send(String::from("hello")).unwrap();
        }
    }

    /// Builds a server around [`TestServerHandler`] whose verifier only offers the
    /// "none" authentication method
    #[inline]
    fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
        let none_method: Box<dyn AuthenticationMethod> = Box::new(NoneAuthenticationMethod::new());

        Server {
            config,
            handler: TestServerHandler,
            verifier: Verifier::new(vec![none_method]),
            version: server_version!(),
        }
    }

    /// Creates a listener fed by an mpsc channel with the given buffer size
    #[allow(clippy::type_complexity)]
    fn make_listener(
        buffer: usize,
    ) -> (
        mpsc::Sender<InmemoryTransport>,
        MpscListener<InmemoryTransport>,
    ) {
        MpscListener::channel(buffer)
    }

    /// Creates an in-memory transport pair, feeds one end to the listener through `tx`,
    /// and returns the other end to the caller
    async fn feed_connection(tx: &mpsc::Sender<InmemoryTransport>) -> InmemoryTransport {
        let (transport, connection) = InmemoryTransport::pair(100);
        tx.send(connection)
            .await
            .expect("Failed to feed listener a connection");
        transport
    }

    /// Starts a test server on `listener` using the given shutdown policy and otherwise
    /// default configuration
    fn start_with_shutdown(
        shutdown: Shutdown,
        listener: MpscListener<InmemoryTransport>,
    ) -> ServerRef {
        make_test_server(ServerConfig {
            shutdown,
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server")
    }

    #[test(tokio::test)]
    async fn should_invoke_handler_upon_receiving_a_request() {
        let (tx, listener) = make_listener(100);

        // Queue a connection for the server before it starts
        let transport = feed_connection(&tx).await;

        let _server = make_test_server(ServerConfig::default())
            .start(listener)
            .expect("Failed to start server");

        // Complete the client side of the connection over the other end of the transport
        let mut client = Connection::client(transport, DummyAuthHandler, server_version!())
            .await
            .expect("Failed to connect to server");

        let request = Request::new(123).to_vec().unwrap();
        client
            .write_frame(request)
            .await
            .expect("Failed to send request");

        // The handler should have replied with "hello"
        let frame = client.read_frame().await.unwrap().unwrap();
        let response: Response<String> = Response::from_slice(frame.as_item()).unwrap();
        assert_eq!(response.payload, "hello");
    }

    #[test(tokio::test)]
    async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
        let (_tx, listener) = make_listener(100);

        let server = start_with_shutdown(Shutdown::Lonely(SHUTDOWN_DELAY), listener);

        // Wait past the shutdown delay without ever connecting
        tokio::time::sleep(SETTLE_TIME).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs(
    ) {
        let (tx, listener) = make_listener(100);

        // Queue a connection for the server before it starts
        let transport = feed_connection(&tx).await;

        let server = start_with_shutdown(Shutdown::Lonely(SHUTDOWN_DELAY), listener);

        // Terminate the only connection by dropping our end of the transport
        drop(transport);

        // Wait past the shutdown delay with no further connections
        tokio::time::sleep(SETTLE_TIME).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() {
        let (tx, listener) = make_listener(100);

        // Keep our end of the transport alive for the whole test
        let _transport = feed_connection(&tx).await;

        let server = start_with_shutdown(Shutdown::Lonely(SHUTDOWN_DELAY), listener);

        // Wait past the shutdown delay while the connection still exists
        tokio::time::sleep(SETTLE_TIME).await;

        assert!(!server.is_finished(), "Server shutdown when it should not!");
    }

    #[test(tokio::test)]
    async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() {
        let (tx, listener) = make_listener(100);

        // Keep our end of the transport alive for the whole test
        let _transport = feed_connection(&tx).await;

        let server = start_with_shutdown(Shutdown::After(SHUTDOWN_DELAY), listener);

        // Wait past the shutdown delay while the connection still exists
        tokio::time::sleep(SETTLE_TIME).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_shutdown_after_n_seconds_if_config_set_to_after() {
        let (_tx, listener) = make_listener(100);

        let server = start_with_shutdown(Shutdown::After(SHUTDOWN_DELAY), listener);

        // Wait past the shutdown delay without ever connecting
        tokio::time::sleep(SETTLE_TIME).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_never_shutdown_if_config_set_to_never() {
        let (_tx, listener) = make_listener(100);

        let server = start_with_shutdown(Shutdown::Never, listener);

        // Wait the same amount of time as the other tests; nothing should trigger a shutdown
        tokio::time::sleep(SETTLE_TIME).await;

        assert!(!server.is_finished(), "Server shutdown when it should not!");
    }
}