distant_net/
server.rs

1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use distant_auth::Verifier;
7use log::*;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use tokio::sync::{broadcast, RwLock};
11
12use crate::common::{ConnectionId, Listener, Response, Transport, Version};
13
14mod builder;
15pub use builder::*;
16
17mod config;
18pub use config::*;
19
20mod connection;
21use connection::*;
22
23mod context;
24pub use context::*;
25
26mod r#ref;
27pub use r#ref::*;
28
29mod reply;
30pub use reply::*;
31
32mod state;
33use state::*;
34
35mod shutdown_timer;
36use shutdown_timer::*;
37
38/// Represents a server that can be used to receive requests & send responses to clients.
39pub struct Server<T> {
40    /// Custom configuration details associated with the server
41    config: ServerConfig,
42
43    /// Handler used to process various server events
44    handler: T,
45
46    /// Performs authentication using various methods
47    verifier: Verifier,
48
49    /// Version associated with the server used by clients to verify compatibility
50    version: Version,
51}
52
53/// Interface for a handler that receives connections and requests
54#[async_trait]
55pub trait ServerHandler: Send {
56    /// Type of data received by the server
57    type Request;
58
59    /// Type of data sent back by the server
60    type Response;
61
62    /// Invoked upon a new connection becoming established.
63    #[allow(unused_variables)]
64    async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
65        Ok(())
66    }
67
68    /// Invoked upon an existing connection getting dropped.
69    #[allow(unused_variables)]
70    async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
71        Ok(())
72    }
73
74    /// Invoked upon receiving a request from a client. The server should process this
75    /// request, which can be found in `ctx`, and send one or more replies in response.
76    async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
77}
78
79impl Server<()> {
80    /// Creates a new [`Server`], starting with a default configuration, no authentication methods,
81    /// and no [`ServerHandler`].
82    pub fn new() -> Self {
83        Self {
84            config: Default::default(),
85            handler: (),
86            verifier: Verifier::empty(),
87            version: Default::default(),
88        }
89    }
90
91    /// Creates a new [`TcpServerBuilder`] that is used to construct a [`Server`].
92    pub fn tcp() -> TcpServerBuilder<()> {
93        TcpServerBuilder::default()
94    }
95
96    /// Creates a new [`UnixSocketServerBuilder`] that is used to construct a [`Server`].
97    #[cfg(unix)]
98    pub fn unix_socket() -> UnixSocketServerBuilder<()> {
99        UnixSocketServerBuilder::default()
100    }
101
102    /// Creates a new [`WindowsPipeServerBuilder`] that is used to construct a [`Server`].
103    #[cfg(windows)]
104    pub fn windows_pipe() -> WindowsPipeServerBuilder<()> {
105        WindowsPipeServerBuilder::default()
106    }
107}
108
109impl Default for Server<()> {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl<T> Server<T> {
116    /// Consumes the current server, replacing its config with `config` and returning it.
117    pub fn config(self, config: ServerConfig) -> Self {
118        Self {
119            config,
120            handler: self.handler,
121            verifier: self.verifier,
122            version: self.version,
123        }
124    }
125
126    /// Consumes the current server, replacing its handler with `handler` and returning it.
127    pub fn handler<U>(self, handler: U) -> Server<U> {
128        Server {
129            config: self.config,
130            handler,
131            verifier: self.verifier,
132            version: self.version,
133        }
134    }
135
136    /// Consumes the current server, replacing its verifier with `verifier` and returning it.
137    pub fn verifier(self, verifier: Verifier) -> Self {
138        Self {
139            config: self.config,
140            handler: self.handler,
141            verifier,
142            version: self.version,
143        }
144    }
145
146    /// Consumes the current server, replacing its version with `version` and returning it.
147    pub fn version(self, version: Version) -> Self {
148        Self {
149            config: self.config,
150            handler: self.handler,
151            verifier: self.verifier,
152            version,
153        }
154    }
155}
156
157impl<T> Server<T>
158where
159    T: ServerHandler + Sync + 'static,
160    T::Request: DeserializeOwned + Send + Sync + 'static,
161    T::Response: Serialize + Send + 'static,
162{
163    /// Consumes the server, starting a task to process connections from the `listener` and
164    /// returning a [`ServerRef`] that can be used to control the active server instance.
165    pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
166    where
167        L: Listener + 'static,
168        L::Output: Transport + 'static,
169    {
170        let state = Arc::new(ServerState::new());
171        let (tx, rx) = broadcast::channel(1);
172        let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
173
174        Ok(ServerRef { shutdown: tx, task })
175    }
176
177    /// Internal task that is run to receive connections and spawn connection tasks
178    async fn task<L>(
179        self,
180        state: Arc<ServerState<Response<T::Response>>>,
181        mut listener: L,
182        shutdown_tx: broadcast::Sender<()>,
183        shutdown_rx: broadcast::Receiver<()>,
184    ) where
185        L: Listener + 'static,
186        L::Output: Transport + 'static,
187    {
188        let Server {
189            config,
190            handler,
191            verifier,
192            version,
193        } = self;
194
195        let handler = Arc::new(handler);
196        let timer = ShutdownTimer::start(config.shutdown);
197        let mut notification = timer.clone_notification();
198        let timer = Arc::new(RwLock::new(timer));
199        let verifier = Arc::new(verifier);
200
201        let mut connection_tasks = Vec::new();
202        loop {
203            // Receive a new connection, exiting if no longer accepting connections or if the shutdown
204            // signal has been received
205            let transport = tokio::select! {
206                result = listener.accept() => {
207                    match result {
208                        Ok(x) => x,
209                        Err(x) => {
210                            error!("Server no longer accepting connections: {x}");
211                            timer.read().await.abort();
212                            break;
213                        }
214                    }
215                }
216                _ = notification.wait() => {
217                    info!(
218                        "Server shutdown triggered after {}s",
219                        config.shutdown.duration().unwrap_or_default().as_secs_f32(),
220                    );
221
222                    let _ = shutdown_tx.send(());
223
224                    break;
225                }
226            };
227
228            // Ensure that the shutdown timer is cancelled now that we have a connection
229            timer.read().await.stop();
230
231            connection_tasks.push(
232                ConnectionTask::build()
233                    .handler(Arc::downgrade(&handler))
234                    .state(Arc::downgrade(&state))
235                    .keychain(state.keychain.clone())
236                    .transport(transport)
237                    .shutdown(shutdown_rx.resubscribe())
238                    .shutdown_timer(Arc::downgrade(&timer))
239                    .sleep_duration(config.connection_sleep)
240                    .heartbeat_duration(config.connection_heartbeat)
241                    .verifier(Arc::downgrade(&verifier))
242                    .version(version.clone())
243                    .spawn(),
244            );
245
246            // Clean up current tasks being tracked
247            connection_tasks.retain(|task| !task.is_finished());
248        }
249
250        // Once we stop listening, we still want to wait until all connections have terminated
251        info!("Server waiting for active connections to terminate");
252        loop {
253            connection_tasks.retain(|task| !task.is_finished());
254            if connection_tasks.is_empty() {
255                break;
256            }
257            tokio::time::sleep(Duration::from_millis(50)).await;
258        }
259        info!("Server task terminated");
260    }
261}
262
// Tests for `Server`: request handling and the `Shutdown` policies (`Lonely`, `After`, `Never`).
// NOTE(review): the shutdown tests are timing-based (100ms shutdown duration vs. a 300ms sleep);
// they may be sensitive to heavily loaded CI machines — confirm if flakiness appears.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_trait::async_trait;
    use distant_auth::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod};
    use test_log::test;
    use tokio::sync::mpsc;

    use super::*;
    use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response};

    // Version shared by the test server and the test client so their versions match.
    macro_rules! server_version {
        () => {
            Version::new(1, 2, 3)
        };
    }

    /// Test handler that replies to every request with the string `"hello"`.
    pub struct TestServerHandler;

    #[async_trait]
    impl ServerHandler for TestServerHandler {
        type Request = u16;
        type Response = String;

        async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
            // Always send back "hello"
            ctx.reply.send("hello".to_string()).unwrap();
        }
    }

    /// Builds a `Server` using `config`, the `TestServerHandler`, a verifier containing only a
    /// `NoneAuthenticationMethod`, and the shared test version.
    #[inline]
    fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
        // Presumably `NoneAuthenticationMethod` performs no real authentication, pairing with the
        // client-side `DummyAuthHandler` — verify against `distant_auth`.
        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(NoneAuthenticationMethod::new())];

        Server {
            config,
            handler: TestServerHandler,
            verifier: Verifier::new(methods),
            version: server_version!(),
        }
    }

    /// Creates an in-memory listener plus the sender used to feed it connections; `buffer` is
    /// passed through to `MpscListener::channel`.
    // The tuple return type trips clippy's type_complexity lint; it is intentionally spelled out.
    #[allow(clippy::type_complexity)]
    fn make_listener(
        buffer: usize,
    ) -> (
        mpsc::Sender<InmemoryTransport>,
        MpscListener<InmemoryTransport>,
    ) {
        MpscListener::channel(buffer)
    }

    #[test(tokio::test)]
    async fn should_invoke_handler_upon_receiving_a_request() {
        // Create a test listener where we will forward a connection
        let (tx, listener) = make_listener(100);

        // Make bounded transport pair and send off one of them to act as our connection
        let (transport, connection) = InmemoryTransport::pair(100);
        tx.send(connection)
            .await
            .expect("Failed to feed listener a connection");

        // Bound to `_server` (not `_`) so the `ServerRef` stays alive until the test ends.
        let _server = make_test_server(ServerConfig::default())
            .start(listener)
            .expect("Failed to start server");

        // Perform handshake and authentication with the server before beginning to send data
        let mut connection = Connection::client(transport, DummyAuthHandler, server_version!())
            .await
            .expect("Failed to connect to server");

        connection
            .write_frame(Request::new(123).to_vec().unwrap())
            .await
            .expect("Failed to send request");

        // Wait for a response
        let frame = connection.read_frame().await.unwrap().unwrap();
        let response: Response<String> = Response::from_slice(frame.as_item()).unwrap();
        assert_eq!(response.payload, "hello");
    }

    #[test(tokio::test)]
    async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
        // `_tx` keeps the listener's sender alive; presumably dropping it would make `accept`
        // fail and stop the server for a different reason — verify against `MpscListener`.
        let (_tx, listener) = make_listener(100);

        let server = make_test_server(ServerConfig {
            shutdown: Shutdown::Lonely(Duration::from_millis(100)),
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server");

        // Wait for some time
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs(
    ) {
        // Create a test listener where we will forward a connection
        let (tx, listener) = make_listener(100);

        // Make bounded transport pair and send off one of them to act as our connection
        let (transport, connection) = InmemoryTransport::pair(100);
        tx.send(connection)
            .await
            .expect("Failed to feed listener a connection");

        let server = make_test_server(ServerConfig {
            shutdown: Shutdown::Lonely(Duration::from_millis(100)),
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server");

        // Drop the connection by dropping the transport
        // (presumably the server then ends that connection, letting the lonely timer run again)
        drop(transport);

        // Wait for some time
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() {
        // Create a test listener where we will forward a connection
        let (tx, listener) = make_listener(100);

        // Make bounded transport pair and send off one of them to act as our connection
        // (`_transport`, unlike `_`, keeps the client half alive for the whole test)
        let (_transport, connection) = InmemoryTransport::pair(100);
        tx.send(connection)
            .await
            .expect("Failed to feed listener a connection");

        let server = make_test_server(ServerConfig {
            shutdown: Shutdown::Lonely(Duration::from_millis(100)),
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server");

        // Wait for some time
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert!(!server.is_finished(), "Server shutdown when it should not!");
    }

    #[test(tokio::test)]
    async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() {
        let (tx, listener) = make_listener(100);

        // Make bounded transport pair and send off one of them to act as our connection
        // (`_transport`, unlike `_`, keeps the client half alive for the whole test)
        let (_transport, connection) = InmemoryTransport::pair(100);
        tx.send(connection)
            .await
            .expect("Failed to feed listener a connection");

        let server = make_test_server(ServerConfig {
            shutdown: Shutdown::After(Duration::from_millis(100)),
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server");

        // Wait for some time
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_shutdown_after_n_seconds_if_config_set_to_after() {
        // `_tx` keeps the listener's sender alive for the duration of the test.
        let (_tx, listener) = make_listener(100);

        let server = make_test_server(ServerConfig {
            shutdown: Shutdown::After(Duration::from_millis(100)),
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server");

        // Wait for some time
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_never_shutdown_if_config_set_to_never() {
        // `_tx` keeps the listener's sender alive for the duration of the test.
        let (_tx, listener) = make_listener(100);

        let server = make_test_server(ServerConfig {
            shutdown: Shutdown::Never,
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server");

        // Wait for some time
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert!(!server.is_finished(), "Server shutdown when it should not!");
    }
}