distant_net/server/
connection.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::sync::{Arc, Weak};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use distant_auth::Verifier;
9use log::*;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
13use tokio::task::JoinHandle;
14
15use super::{ConnectionState, RequestCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
16use crate::common::{
17    Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest, Version,
18};
19
20pub type ServerKeychain = Keychain<oneshot::Receiver<Backup>>;
21
/// Time to wait in between connection reads/writes when nothing was read or written on the
/// last pass of the connection loop.
const SLEEP_DURATION: Duration = Duration::from_millis(1);

/// Default minimum time that must elapse between heartbeats sent to the client connection.
/// A new connection builder starts with this value as its heartbeat duration.
const MINIMUM_HEARTBEAT_DURATION: Duration = Duration::from_secs(5);
27
28/// Represents an individual connection on the server.
29pub(super) struct ConnectionTask(JoinHandle<io::Result<()>>);
30
31impl ConnectionTask {
32    /// Starts building a new connection
33    pub fn build() -> ConnectionTaskBuilder<(), (), ()> {
34        ConnectionTaskBuilder::new()
35    }
36
37    /// Returns true if the task has finished
38    pub fn is_finished(&self) -> bool {
39        self.0.is_finished()
40    }
41}
42
43impl Future for ConnectionTask {
44    type Output = io::Result<()>;
45
46    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
47        match Future::poll(Pin::new(&mut self.0), cx) {
48            Poll::Pending => Poll::Pending,
49            Poll::Ready(x) => match x {
50                Ok(x) => Poll::Ready(x),
51                Err(x) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, x))),
52            },
53        }
54    }
55}
56
57/// Represents a builder for a new connection task.
58pub(super) struct ConnectionTaskBuilder<H, S, T> {
59    handler: Weak<H>,
60    state: Weak<ServerState<S>>,
61    keychain: Keychain<oneshot::Receiver<Backup>>,
62    transport: T,
63    shutdown: broadcast::Receiver<()>,
64    shutdown_timer: Weak<RwLock<ShutdownTimer>>,
65    sleep_duration: Duration,
66    heartbeat_duration: Duration,
67    verifier: Weak<Verifier>,
68    version: Version,
69}
70
71impl ConnectionTaskBuilder<(), (), ()> {
72    /// Starts building a new connection.
73    pub fn new() -> Self {
74        Self {
75            handler: Weak::new(),
76            state: Weak::new(),
77            keychain: Keychain::new(),
78            transport: (),
79            shutdown: broadcast::channel(1).1,
80            shutdown_timer: Weak::new(),
81            sleep_duration: SLEEP_DURATION,
82            heartbeat_duration: MINIMUM_HEARTBEAT_DURATION,
83            verifier: Weak::new(),
84            version: Version::default(),
85        }
86    }
87}
88
89impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
90    pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, S, T> {
91        ConnectionTaskBuilder {
92            handler,
93            state: self.state,
94            keychain: self.keychain,
95            transport: self.transport,
96            shutdown: self.shutdown,
97            shutdown_timer: self.shutdown_timer,
98            sleep_duration: self.sleep_duration,
99            heartbeat_duration: self.heartbeat_duration,
100            verifier: self.verifier,
101            version: self.version,
102        }
103    }
104
105    pub fn state<U>(self, state: Weak<ServerState<U>>) -> ConnectionTaskBuilder<H, U, T> {
106        ConnectionTaskBuilder {
107            handler: self.handler,
108            state,
109            keychain: self.keychain,
110            transport: self.transport,
111            shutdown: self.shutdown,
112            shutdown_timer: self.shutdown_timer,
113            sleep_duration: self.sleep_duration,
114            heartbeat_duration: self.heartbeat_duration,
115            verifier: self.verifier,
116            version: self.version,
117        }
118    }
119
120    pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder<H, S, T> {
121        ConnectionTaskBuilder {
122            handler: self.handler,
123            state: self.state,
124            keychain,
125            transport: self.transport,
126            shutdown: self.shutdown,
127            shutdown_timer: self.shutdown_timer,
128            sleep_duration: self.sleep_duration,
129            heartbeat_duration: self.heartbeat_duration,
130            verifier: self.verifier,
131            version: self.version,
132        }
133    }
134
135    pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, S, U> {
136        ConnectionTaskBuilder {
137            handler: self.handler,
138            keychain: self.keychain,
139            state: self.state,
140            transport,
141            shutdown: self.shutdown,
142            shutdown_timer: self.shutdown_timer,
143            sleep_duration: self.sleep_duration,
144            heartbeat_duration: self.heartbeat_duration,
145            verifier: self.verifier,
146            version: self.version,
147        }
148    }
149
150    pub fn shutdown(self, shutdown: broadcast::Receiver<()>) -> ConnectionTaskBuilder<H, S, T> {
151        ConnectionTaskBuilder {
152            handler: self.handler,
153            state: self.state,
154            keychain: self.keychain,
155            transport: self.transport,
156            shutdown,
157            shutdown_timer: self.shutdown_timer,
158            sleep_duration: self.sleep_duration,
159            heartbeat_duration: self.heartbeat_duration,
160            verifier: self.verifier,
161            version: self.version,
162        }
163    }
164
165    pub fn shutdown_timer(
166        self,
167        shutdown_timer: Weak<RwLock<ShutdownTimer>>,
168    ) -> ConnectionTaskBuilder<H, S, T> {
169        ConnectionTaskBuilder {
170            handler: self.handler,
171            state: self.state,
172            keychain: self.keychain,
173            transport: self.transport,
174            shutdown: self.shutdown,
175            shutdown_timer,
176            sleep_duration: self.sleep_duration,
177            heartbeat_duration: self.heartbeat_duration,
178            verifier: self.verifier,
179            version: self.version,
180        }
181    }
182
183    pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, S, T> {
184        ConnectionTaskBuilder {
185            handler: self.handler,
186            state: self.state,
187            keychain: self.keychain,
188            transport: self.transport,
189            shutdown: self.shutdown,
190            shutdown_timer: self.shutdown_timer,
191            sleep_duration,
192            heartbeat_duration: self.heartbeat_duration,
193            verifier: self.verifier,
194            version: self.version,
195        }
196    }
197
198    pub fn heartbeat_duration(
199        self,
200        heartbeat_duration: Duration,
201    ) -> ConnectionTaskBuilder<H, S, T> {
202        ConnectionTaskBuilder {
203            handler: self.handler,
204            state: self.state,
205            keychain: self.keychain,
206            transport: self.transport,
207            shutdown: self.shutdown,
208            shutdown_timer: self.shutdown_timer,
209            sleep_duration: self.sleep_duration,
210            heartbeat_duration,
211            verifier: self.verifier,
212            version: self.version,
213        }
214    }
215
216    pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, S, T> {
217        ConnectionTaskBuilder {
218            handler: self.handler,
219            state: self.state,
220            keychain: self.keychain,
221            transport: self.transport,
222            shutdown: self.shutdown,
223            shutdown_timer: self.shutdown_timer,
224            sleep_duration: self.sleep_duration,
225            heartbeat_duration: self.heartbeat_duration,
226            verifier,
227            version: self.version,
228        }
229    }
230
231    pub fn version(self, version: Version) -> ConnectionTaskBuilder<H, S, T> {
232        ConnectionTaskBuilder {
233            handler: self.handler,
234            state: self.state,
235            keychain: self.keychain,
236            transport: self.transport,
237            shutdown: self.shutdown,
238            shutdown_timer: self.shutdown_timer,
239            sleep_duration: self.sleep_duration,
240            heartbeat_duration: self.heartbeat_duration,
241            verifier: self.verifier,
242            version,
243        }
244    }
245}
246
247impl<H, T> ConnectionTaskBuilder<H, Response<H::Response>, T>
248where
249    H: ServerHandler + Sync + 'static,
250    H::Request: DeserializeOwned + Send + Sync + 'static,
251    H::Response: Serialize + Send + 'static,
252    T: Transport + 'static,
253{
254    pub fn spawn(self) -> ConnectionTask {
255        ConnectionTask(tokio::spawn(self.run()))
256    }
257
258    async fn run(self) -> io::Result<()> {
259        let ConnectionTaskBuilder {
260            handler,
261            state,
262            keychain,
263            transport,
264            mut shutdown,
265            shutdown_timer,
266            sleep_duration,
267            heartbeat_duration,
268            verifier,
269            version,
270        } = self;
271
272        // NOTE: This exists purely to make the compiler happy for macro_rules declaration order.
273        let (mut local_shutdown, channel_tx, connection_state) = ConnectionState::channel();
274
275        // Will check if no more connections and restart timer if that's the case
276        macro_rules! terminate_connection {
277            // Prints an error message and does not store state
278            (@fatal $($msg:tt)+) => {
279                error!($($msg)+);
280                terminate_connection!();
281                return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
282            };
283
284            // Prints an error message and stores state before terminating
285            (@error($tx:ident, $rx:ident) $($msg:tt)+) => {
286                error!($($msg)+);
287                terminate_connection!($tx, $rx);
288                return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
289            };
290
291            // Prints a debug message and stores state before terminating
292            (@debug($tx:ident, $rx:ident) $($msg:tt)+) => {
293                debug!($($msg)+);
294                terminate_connection!($tx, $rx);
295                return Ok(());
296            };
297
298            // Prints a shutdown message with no connection id and exit without sending state
299            (@shutdown) => {
300                debug!("Shutdown triggered before a connection could be fully established");
301                terminate_connection!();
302                return Ok(());
303            };
304
305            // Prints a shutdown message with no connection id and stores state before terminating
306            (@shutdown) => {
307                debug!("Shutdown triggered before a connection could be fully established");
308                terminate_connection!();
309                return Ok(());
310            };
311
312            // Prints a shutdown message and stores state before terminating
313            (@shutdown($id:ident, $tx:ident, $rx:ident)) => {{
314                debug!("[Conn {}] Shutdown triggered", $id);
315                terminate_connection!($tx, $rx);
316                return Ok(());
317            }};
318
319            // Performs the connection termination by removing it from server state and
320            // restarting the shutdown timer if it was the last connection
321            ($tx:ident, $rx:ident) => {
322                // Send the channels back
323                let _ = channel_tx.send(($tx, $rx));
324
325                terminate_connection!();
326            };
327
328            // Performs the connection termination by removing it from server state and
329            // restarting the shutdown timer if it was the last connection
330            () => {
331                // Restart our shutdown timer if this is the last connection
332                if let Some(state) = Weak::upgrade(&state) {
333                    if let Some(timer) = Weak::upgrade(&shutdown_timer) {
334                        if state.connections.read().await.values().filter(|conn| !conn.is_finished()).count() <= 1 {
335                            debug!("Last connection terminating, so restarting shutdown timer");
336                            timer.write().await.restart();
337                        }
338                    }
339                }
340            };
341        }
342
343        /// Awaits a future to complete, or detects if a signal was received by either the global
344        /// or local shutdown channel. Shutdown only occurs if a signal was received, and any
345        /// errors received by either shutdown channel are ignored.
346        macro_rules! await_or_shutdown {
347            ($(@save($id:ident, $tx:ident, $rx:ident))? $future:expr) => {{
348                let mut f = $future;
349
350                loop {
351                    let use_shutdown = match shutdown.try_recv() {
352                        Ok(_) => {
353                            terminate_connection!(@shutdown $(($id, $tx, $rx))?);
354                        }
355                        Err(broadcast::error::TryRecvError::Empty) => true,
356                        Err(broadcast::error::TryRecvError::Lagged(_)) => true,
357                        Err(broadcast::error::TryRecvError::Closed) => false,
358                    };
359
360                    let use_local_shutdown = match local_shutdown.try_recv() {
361                        Ok(_) => {
362                            terminate_connection!(@shutdown $(($id, $tx, $rx))?);
363                        }
364                        Err(oneshot::error::TryRecvError::Empty) => true,
365                        Err(oneshot::error::TryRecvError::Closed) => false,
366                    };
367
368                    if use_shutdown && use_local_shutdown {
369                        tokio::select! {
370                            x = shutdown.recv() => {
371                                if x.is_err() {
372                                    continue;
373                                }
374
375                                terminate_connection!(@shutdown $(($id, $tx, $rx))?);
376                            }
377                            x = &mut local_shutdown => {
378                                if x.is_err() {
379                                    continue;
380                                }
381
382                                terminate_connection!(@shutdown $(($id, $tx, $rx))?);
383                            }
384                            x = &mut f => { break x; }
385                        }
386                    } else if use_shutdown {
387                        tokio::select! {
388                            x = shutdown.recv() => {
389                                if x.is_err() {
390                                    continue;
391                                }
392
393                                terminate_connection!(@shutdown $(($id, $tx, $rx))?);
394                            }
395                            x = &mut f => { break x; }
396                        }
397                    } else if use_local_shutdown {
398                        tokio::select! {
399                            x = &mut local_shutdown => {
400                                if x.is_err() {
401                                    continue;
402                                }
403
404                                terminate_connection!(@shutdown $(($id, $tx, $rx))?);
405                            }
406                            x = &mut f => { break x; }
407                        }
408                    } else {
409                        break f.await;
410                    }
411                }
412            }};
413        }
414
415        // Attempt to upgrade our handler for use with the connection going forward
416        let handler = match Weak::upgrade(&handler) {
417            Some(handler) => handler,
418            None => {
419                terminate_connection!(@fatal "Failed to setup connection because handler dropped");
420            }
421        };
422
423        // Attempt to upgrade our state for use with the connection going forward
424        let state = match Weak::upgrade(&state) {
425            Some(state) => state,
426            None => {
427                terminate_connection!(@fatal "Failed to setup connection because state dropped");
428            }
429        };
430
431        // Properly establish the connection's transport
432        debug!("Establishing full connection using {transport:?}");
433        let mut connection = match Weak::upgrade(&verifier) {
434            Some(verifier) => {
435                match await_or_shutdown!(Box::pin(Connection::server(
436                    transport,
437                    verifier.as_ref(),
438                    keychain,
439                    version
440                ))) {
441                    Ok(connection) => connection,
442                    Err(x) => {
443                        terminate_connection!(@fatal "Failed to setup connection: {x}");
444                    }
445                }
446            }
447            None => {
448                terminate_connection!(@fatal "Verifier has been dropped");
449            }
450        };
451
452        // Update our id to be the connection id
453        let id = connection.id();
454
455        // Create local data for the connection and then process it
456        info!("[Conn {id}] Connection established");
457        if let Err(x) = await_or_shutdown!(handler.on_connect(id)) {
458            terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}");
459        }
460
461        let mut last_heartbeat = Instant::now();
462
463        // Restore our connection's channels if we have them, otherwise make new ones
464        let (tx, mut rx) = match state.connections.write().await.remove(&id) {
465            Some(conn) => match conn.shutdown_and_wait().await {
466                Some(x) => {
467                    debug!("[Conn {id}] Marked as existing connection");
468                    x
469                }
470                None => {
471                    warn!("[Conn {id}] Existing connection with id, but channels not saved");
472                    mpsc::unbounded_channel::<Response<H::Response>>()
473                }
474            },
475            None => {
476                debug!("[Conn {id}] Marked as new connection");
477                mpsc::unbounded_channel::<Response<H::Response>>()
478            }
479        };
480
481        // Store our connection details
482        state.connections.write().await.insert(id, connection_state);
483
484        debug!("[Conn {id}] Beginning read/write loop");
485        loop {
486            let ready = match await_or_shutdown!(
487                @save(id, tx, rx)
488                Box::pin(connection.ready(Interest::READABLE | Interest::WRITABLE))
489            ) {
490                Ok(ready) => ready,
491                Err(x) => {
492                    terminate_connection!(@error(tx, rx) "[Conn {id}] Failed to examine ready state: {x}");
493                }
494            };
495
496            // Keep track of whether we read or wrote anything
497            let mut read_blocked = !ready.is_readable();
498            let mut write_blocked = !ready.is_writable();
499
500            if ready.is_readable() {
501                match connection.try_read_frame() {
502                    Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
503                        Ok(request) => match request.to_typed_request() {
504                            Ok(request) => {
505                                if log::log_enabled!(Level::Debug) {
506                                    let debug_header = if !request.header.is_empty() {
507                                        format!(" | header {}", request.header)
508                                    } else {
509                                        String::new()
510                                    };
511                                    debug!("[Conn {id}] New request {}{debug_header}", request.id);
512                                }
513                                let origin_id = request.id.clone();
514                                let ctx = RequestCtx {
515                                    connection_id: id,
516                                    request,
517                                    reply: ServerReply {
518                                        origin_id,
519                                        tx: tx.clone(),
520                                    },
521                                };
522
523                                // Spawn a new task to run the request handler so we don't block
524                                // our connection from processing other requests
525                                let handler = Arc::clone(&handler);
526                                tokio::spawn(async move { handler.on_request(ctx).await });
527                            }
528                            Err(x) => {
529                                if log::log_enabled!(Level::Debug) {
530                                    error!(
531                                        "[Conn {id}] Failed receiving {}",
532                                        String::from_utf8_lossy(&request.payload),
533                                    );
534                                }
535
536                                error!("[Conn {id}] Invalid request: {x}");
537                            }
538                        },
539                        Err(x) => {
540                            error!("[Conn {id}] Invalid request payload: {x}");
541                        }
542                    },
543                    Ok(None) => {
544                        terminate_connection!(@debug(tx, rx) "[Conn {id}] Connection closed");
545                    }
546                    Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
547                    Err(x) => {
548                        terminate_connection!(@error(tx, rx) "[Conn {id}] {x}");
549                    }
550                }
551            }
552
553            // If our socket is ready to be written to, we try to get the next item from
554            // the queue and process it
555            if ready.is_writable() {
556                // Send a heartbeat if we have exceeded our last time
557                if last_heartbeat.elapsed() >= heartbeat_duration {
558                    trace!("[Conn {id}] Sending heartbeat via empty frame");
559                    match connection.try_write_frame(Frame::empty()) {
560                        Ok(()) => (),
561                        Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
562                        Err(x) => error!("[Conn {id}] Send failed: {x}"),
563                    }
564                    last_heartbeat = Instant::now();
565                }
566                // If we get more data to write, attempt to write it, which will result in writing
567                // any queued bytes as well. Othewise, we attempt to flush any pending outgoing
568                // bytes that weren't sent earlier.
569                else if let Ok(response) = rx.try_recv() {
570                    // Log our message as a string, which can be expensive
571                    if log_enabled!(Level::Trace) {
572                        trace!(
573                            "[Conn {id}] Sending {}",
574                            &response
575                                .to_vec()
576                                .map(|x| String::from_utf8_lossy(&x).to_string())
577                                .unwrap_or_else(|_| "<Cannot serialize>".to_string())
578                        );
579                    }
580
581                    match response.to_vec() {
582                        Ok(data) => match connection.try_write_frame(data) {
583                            Ok(()) => (),
584                            Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
585                            Err(x) => error!("[Conn {id}] Send failed: {x}"),
586                        },
587                        Err(x) => {
588                            error!("[Conn {id}] Unable to serialize outgoing response: {x}");
589                        }
590                    }
591                } else {
592                    // In the case of flushing, there are two scenarios in which we want to
593                    // mark no write occurring:
594                    //
595                    // 1. When flush did not write any bytes, which can happen when the buffer
596                    //    is empty
597                    // 2. When the call to write bytes blocks
598                    match connection.try_flush() {
599                        Ok(0) => write_blocked = true,
600                        Ok(_) => (),
601                        Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
602                        Err(x) => {
603                            error!("[Conn {id}] Failed to flush outgoing data: {x}");
604                        }
605                    }
606                }
607            }
608
609            // If we did not read or write anything, sleep a bit to offload CPU usage
610            if read_blocked && write_blocked {
611                tokio::time::sleep(sleep_duration).await;
612            }
613        }
614    }
615}
616
617#[cfg(test)]
618mod tests {
619    use std::sync::atomic::{AtomicBool, Ordering};
620
621    use async_trait::async_trait;
622    use distant_auth::DummyAuthHandler;
623    use test_log::test;
624
625    use super::*;
626    use crate::common::{
627        HeapSecretKey, InmemoryTransport, Ready, Reconnectable, Request, Response,
628    };
629    use crate::server::{ConnectionId, Shutdown};
630
631    struct TestServerHandler;
632
633    #[async_trait]
634    impl ServerHandler for TestServerHandler {
635        type Request = u16;
636        type Response = String;
637
638        async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
639            // Always send back "hello"
640            ctx.reply.send("hello".to_string()).unwrap();
641        }
642    }
643
644    macro_rules! wait_for_termination {
645        ($task:ident) => {{
646            let timeout_millis = 500;
647            let sleep_millis = 50;
648            let start = std::time::Instant::now();
649            while !$task.is_finished() {
650                if start.elapsed() > std::time::Duration::from_millis(timeout_millis) {
651                    panic!("Exceeded timeout of {timeout_millis}ms");
652                }
653                tokio::time::sleep(std::time::Duration::from_millis(sleep_millis)).await;
654            }
655        }};
656    }
657
658    macro_rules! server_version {
659        () => {
660            Version::new(1, 2, 3)
661        };
662    }
663
664    #[test(tokio::test)]
665    async fn should_terminate_if_fails_access_verifier() {
666        let handler = Arc::new(TestServerHandler);
667        let state = Arc::new(ServerState::default());
668        let keychain = ServerKeychain::new();
669        let (t1, _t2) = InmemoryTransport::pair(100);
670        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
671
672        let task = ConnectionTask::build()
673            .handler(Arc::downgrade(&handler))
674            .state(Arc::downgrade(&state))
675            .keychain(keychain)
676            .transport(t1)
677            .shutdown_timer(Arc::downgrade(&shutdown_timer))
678            .verifier(Weak::new())
679            .spawn();
680
681        wait_for_termination!(task);
682
683        let err = task.await.unwrap_err();
684        assert!(
685            err.to_string().contains("Verifier has been dropped"),
686            "Unexpected error: {err}"
687        );
688    }
689
690    #[test(tokio::test)]
691    async fn should_terminate_if_fails_to_setup_server_connection() {
692        let handler = Arc::new(TestServerHandler);
693        let state = Arc::new(ServerState::default());
694        let keychain = ServerKeychain::new();
695        let (t1, t2) = InmemoryTransport::pair(100);
696        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
697
698        // Create a verifier that wants a key, so we will fail from client-side
699        let verifier = Arc::new(Verifier::static_key(HeapSecretKey::generate(32).unwrap()));
700
701        let task = ConnectionTask::build()
702            .handler(Arc::downgrade(&handler))
703            .state(Arc::downgrade(&state))
704            .keychain(keychain)
705            .transport(t1)
706            .shutdown_timer(Arc::downgrade(&shutdown_timer))
707            .verifier(Arc::downgrade(&verifier))
708            .version(server_version!())
709            .spawn();
710
711        // Spawn a task to handle establishing connection from client-side
712        tokio::spawn(async move {
713            let _client = Connection::client(t2, DummyAuthHandler, server_version!())
714                .await
715                .expect("Fail to establish client-side connection");
716        });
717
718        wait_for_termination!(task);
719
720        let err = task.await.unwrap_err();
721        assert!(
722            err.to_string().contains("Failed to setup connection"),
723            "Unexpected error: {err}"
724        );
725    }
726
727    #[test(tokio::test)]
728    async fn should_terminate_if_fails_access_server_handler() {
729        let state = Arc::new(ServerState::default());
730        let keychain = ServerKeychain::new();
731        let (t1, t2) = InmemoryTransport::pair(100);
732        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
733        let verifier = Arc::new(Verifier::none());
734
735        let task = ConnectionTask::build()
736            .handler(Weak::<TestServerHandler>::new())
737            .state(Arc::downgrade(&state))
738            .keychain(keychain)
739            .transport(t1)
740            .shutdown_timer(Arc::downgrade(&shutdown_timer))
741            .verifier(Arc::downgrade(&verifier))
742            .version(server_version!())
743            .spawn();
744
745        // Spawn a task to handle establishing connection from client-side
746        tokio::spawn(async move {
747            let _client = Connection::client(t2, DummyAuthHandler, server_version!())
748                .await
749                .expect("Fail to establish client-side connection");
750        });
751
752        wait_for_termination!(task);
753
754        let err = task.await.unwrap_err();
755        assert!(
756            err.to_string().contains("handler dropped"),
757            "Unexpected error: {err}"
758        );
759    }
760
761    #[test(tokio::test)]
762    async fn should_terminate_if_accepting_connection_fails_on_server_handler() {
763        struct BadAcceptServerHandler;
764
765        #[async_trait]
766        impl ServerHandler for BadAcceptServerHandler {
767            type Request = u16;
768            type Response = String;
769
770            async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
771                Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
772            }
773
774            async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
775                unreachable!();
776            }
777        }
778
779        let handler = Arc::new(BadAcceptServerHandler);
780        let state = Arc::new(ServerState::default());
781        let keychain = ServerKeychain::new();
782        let (t1, t2) = InmemoryTransport::pair(100);
783        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
784        let verifier = Arc::new(Verifier::none());
785
786        let task = ConnectionTask::build()
787            .handler(Arc::downgrade(&handler))
788            .state(Arc::downgrade(&state))
789            .keychain(keychain)
790            .transport(t1)
791            .shutdown_timer(Arc::downgrade(&shutdown_timer))
792            .verifier(Arc::downgrade(&verifier))
793            .version(server_version!())
794            .spawn();
795
796        // Spawn a task to handle establishing connection from client-side, and then closes to
797        // trigger the server-side to close
798        tokio::spawn(async move {
799            let _client = Connection::client(t2, DummyAuthHandler, server_version!())
800                .await
801                .expect("Fail to establish client-side connection");
802        });
803
804        wait_for_termination!(task);
805
806        let err = task.await.unwrap_err();
807        assert!(
808            err.to_string().contains("Accepting connection failed"),
809            "Unexpected error: {err}"
810        );
811    }
812
813    #[test(tokio::test)]
814    async fn should_terminate_if_connection_fails_to_become_ready() {
815        let handler = Arc::new(TestServerHandler);
816        let state = Arc::new(ServerState::default());
817        let keychain = ServerKeychain::new();
818        let (t1, t2) = InmemoryTransport::pair(100);
819        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
820        let verifier = Arc::new(Verifier::none());
821
822        #[derive(Debug)]
823        struct FakeTransport {
824            inner: InmemoryTransport,
825            fail_ready: Arc<AtomicBool>,
826        }
827
828        #[async_trait]
829        impl Transport for FakeTransport {
830            fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
831                self.inner.try_read(buf)
832            }
833
834            fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
835                self.inner.try_write(buf)
836            }
837
838            async fn ready(&self, interest: Interest) -> io::Result<Ready> {
839                if self.fail_ready.load(Ordering::Relaxed) {
840                    Err(io::Error::new(
841                        io::ErrorKind::Other,
842                        "targeted ready failure",
843                    ))
844                } else {
845                    self.inner.ready(interest).await
846                }
847            }
848        }
849
850        #[async_trait]
851        impl Reconnectable for FakeTransport {
852            async fn reconnect(&mut self) -> io::Result<()> {
853                self.inner.reconnect().await
854            }
855        }
856
857        let fail_ready = Arc::new(AtomicBool::new(false));
858        let task = ConnectionTask::build()
859            .handler(Arc::downgrade(&handler))
860            .state(Arc::downgrade(&state))
861            .keychain(keychain)
862            .transport(FakeTransport {
863                inner: t1,
864                fail_ready: Arc::clone(&fail_ready),
865            })
866            .shutdown_timer(Arc::downgrade(&shutdown_timer))
867            .verifier(Arc::downgrade(&verifier))
868            .version(server_version!())
869            .spawn();
870
871        // Spawn a task to handle establishing connection from client-side, set ready to fail
872        // for the server-side after client connection completes, and wait a bit
873        tokio::spawn(async move {
874            let _client = Connection::client(t2, DummyAuthHandler, server_version!())
875                .await
876                .expect("Fail to establish client-side connection");
877
878            // NOTE: Need to sleep for a little bit to hand control back to server to finish
879            //       its side of the connection before toggling ready to fail
880            tokio::time::sleep(Duration::from_millis(50)).await;
881
882            // Toggle ready to fail and then wait awhile so we fail by ready and not connection
883            // being dropped
884            fail_ready.store(true, Ordering::Relaxed);
885            tokio::time::sleep(Duration::from_secs(1)).await;
886        });
887
888        wait_for_termination!(task);
889
890        let err = task.await.unwrap_err();
891        assert!(
892            err.to_string().contains("targeted ready failure"),
893            "Unexpected error: {err}"
894        );
895    }
896
897    #[test(tokio::test)]
898    async fn should_terminate_if_connection_closes() {
899        let handler = Arc::new(TestServerHandler);
900        let state = Arc::new(ServerState::default());
901        let keychain = ServerKeychain::new();
902        let (t1, t2) = InmemoryTransport::pair(100);
903        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
904        let verifier = Arc::new(Verifier::none());
905
906        let task = ConnectionTask::build()
907            .handler(Arc::downgrade(&handler))
908            .state(Arc::downgrade(&state))
909            .keychain(keychain)
910            .transport(t1)
911            .shutdown_timer(Arc::downgrade(&shutdown_timer))
912            .verifier(Arc::downgrade(&verifier))
913            .version(server_version!())
914            .spawn();
915
916        // Spawn a task to handle establishing connection from client-side, and then closes to
917        // trigger the server-side to close
918        tokio::spawn(async move {
919            let _client = Connection::client(t2, DummyAuthHandler, server_version!())
920                .await
921                .expect("Fail to establish client-side connection");
922        });
923
924        wait_for_termination!(task);
925        task.await.unwrap();
926    }
927
928    #[test(tokio::test)]
929    async fn should_invoke_server_handler_to_process_request_in_new_task_and_forward_responses() {
930        let handler = Arc::new(TestServerHandler);
931        let state = Arc::new(ServerState::default());
932        let keychain = ServerKeychain::new();
933        let (t1, t2) = InmemoryTransport::pair(100);
934        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
935        let verifier = Arc::new(Verifier::none());
936
937        let _conn = ConnectionTask::build()
938            .handler(Arc::downgrade(&handler))
939            .state(Arc::downgrade(&state))
940            .keychain(keychain)
941            .transport(t1)
942            .shutdown_timer(Arc::downgrade(&shutdown_timer))
943            .verifier(Arc::downgrade(&verifier))
944            .version(server_version!())
945            .spawn();
946
947        // Spawn a task to handle establishing connection from client-side
948        let task = tokio::spawn(async move {
949            let mut client = Connection::client(t2, DummyAuthHandler, server_version!())
950                .await
951                .expect("Fail to establish client-side connection");
952
953            client.write_frame_for(&Request::new(123u16)).await.unwrap();
954            client
955                .read_frame_as::<Response<String>>()
956                .await
957                .unwrap()
958                .unwrap()
959        });
960
961        let response = task.await.unwrap();
962        assert_eq!(response.payload, "hello");
963    }
964
965    #[test(tokio::test)]
966    async fn should_send_heartbeat_via_empty_frame_every_minimum_duration() {
967        let handler = Arc::new(TestServerHandler);
968        let state = Arc::new(ServerState::default());
969        let keychain = ServerKeychain::new();
970        let (t1, t2) = InmemoryTransport::pair(100);
971        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
972        let verifier = Arc::new(Verifier::none());
973
974        let _conn = ConnectionTask::build()
975            .handler(Arc::downgrade(&handler))
976            .state(Arc::downgrade(&state))
977            .keychain(keychain)
978            .transport(t1)
979            .shutdown_timer(Arc::downgrade(&shutdown_timer))
980            .heartbeat_duration(Duration::from_millis(200))
981            .verifier(Arc::downgrade(&verifier))
982            .version(server_version!())
983            .spawn();
984
985        // Spawn a task to handle establishing connection from client-side
986        let task = tokio::spawn(async move {
987            let mut client = Connection::client(t2, DummyAuthHandler, server_version!())
988                .await
989                .expect("Fail to establish client-side connection");
990
991            // Verify we don't get a frame immediately
992            assert_eq!(
993                client.try_read_frame().unwrap_err().kind(),
994                io::ErrorKind::WouldBlock,
995                "got a frame early"
996            );
997
998            // Sleep more than our minimum heartbeat duration to ensure we get one
999            tokio::time::sleep(Duration::from_millis(250)).await;
1000            assert_eq!(
1001                client.read_frame().await.unwrap().unwrap(),
1002                Frame::empty(),
1003                "non-empty frame"
1004            );
1005
1006            // Verify we don't get a frame immediately
1007            assert_eq!(
1008                client.try_read_frame().unwrap_err().kind(),
1009                io::ErrorKind::WouldBlock,
1010                "got a frame early"
1011            );
1012
1013            // Sleep more than our minimum heartbeat duration to ensure we get one
1014            tokio::time::sleep(Duration::from_millis(250)).await;
1015            assert_eq!(
1016                client.read_frame().await.unwrap().unwrap(),
1017                Frame::empty(),
1018                "non-empty frame"
1019            );
1020        });
1021
1022        task.await.unwrap();
1023    }
1024
1025    #[test(tokio::test)]
1026    async fn should_be_able_to_shutdown_while_establishing_connection() {
1027        let handler = Arc::new(TestServerHandler);
1028        let state = Arc::new(ServerState::default());
1029        let keychain = ServerKeychain::new();
1030        let (t1, _t2) = InmemoryTransport::pair(100);
1031        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
1032        let verifier = Arc::new(Verifier::none());
1033
1034        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
1035        let conn = ConnectionTask::build()
1036            .handler(Arc::downgrade(&handler))
1037            .state(Arc::downgrade(&state))
1038            .keychain(keychain)
1039            .transport(t1)
1040            .shutdown(shutdown_rx)
1041            .shutdown_timer(Arc::downgrade(&shutdown_timer))
1042            .heartbeat_duration(Duration::from_millis(200))
1043            .verifier(Arc::downgrade(&verifier))
1044            .spawn();
1045
1046        // Shutdown server connection task while it is establishing a full connection with the
1047        // client, verifying that we do not get an error in return
1048        shutdown_tx
1049            .send(())
1050            .expect("Failed to send shutdown signal");
1051        conn.await.unwrap();
1052    }
1053
1054    #[test(tokio::test)]
1055    async fn should_be_able_to_shutdown_while_accepting_connection() {
1056        struct HangingAcceptServerHandler;
1057
1058        #[async_trait]
1059        impl ServerHandler for HangingAcceptServerHandler {
1060            type Request = ();
1061            type Response = ();
1062
1063            async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
1064                // Wait "forever" so we can ensure that we fail at this step
1065                tokio::time::sleep(Duration::MAX).await;
1066                Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
1067            }
1068
1069            async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
1070                unreachable!();
1071            }
1072        }
1073
1074        let handler = Arc::new(HangingAcceptServerHandler);
1075        let state = Arc::new(ServerState::default());
1076        let keychain = ServerKeychain::new();
1077        let (t1, t2) = InmemoryTransport::pair(100);
1078        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
1079        let verifier = Arc::new(Verifier::none());
1080
1081        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
1082        let conn = ConnectionTask::build()
1083            .handler(Arc::downgrade(&handler))
1084            .state(Arc::downgrade(&state))
1085            .keychain(keychain)
1086            .transport(t1)
1087            .shutdown(shutdown_rx)
1088            .shutdown_timer(Arc::downgrade(&shutdown_timer))
1089            .heartbeat_duration(Duration::from_millis(200))
1090            .verifier(Arc::downgrade(&verifier))
1091            .version(server_version!())
1092            .spawn();
1093
1094        // Spawn a task to handle the client-side establishment of a full connection
1095        let _client_task =
1096            tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!()));
1097
1098        // Shutdown server connection task while it is accepting the connection, verifying that we
1099        // do not get an error in return
1100        shutdown_tx
1101            .send(())
1102            .expect("Failed to send shutdown signal");
1103        conn.await.unwrap();
1104    }
1105
1106    #[test(tokio::test)]
1107    async fn should_be_able_to_shutdown_while_waiting_for_connection_to_be_ready() {
1108        struct AcceptServerHandler {
1109            tx: mpsc::Sender<()>,
1110        }
1111
1112        #[async_trait]
1113        impl ServerHandler for AcceptServerHandler {
1114            type Request = ();
1115            type Response = ();
1116
1117            async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
1118                self.tx.send(()).await.unwrap();
1119                Ok(())
1120            }
1121
1122            async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
1123                unreachable!();
1124            }
1125        }
1126
1127        let (tx, mut rx) = mpsc::channel(100);
1128        let handler = Arc::new(AcceptServerHandler { tx });
1129        let state = Arc::new(ServerState::default());
1130        let keychain = ServerKeychain::new();
1131        let (t1, t2) = InmemoryTransport::pair(100);
1132        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
1133        let verifier = Arc::new(Verifier::none());
1134
1135        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
1136        let conn = ConnectionTask::build()
1137            .handler(Arc::downgrade(&handler))
1138            .state(Arc::downgrade(&state))
1139            .keychain(keychain)
1140            .transport(t1)
1141            .shutdown(shutdown_rx)
1142            .shutdown_timer(Arc::downgrade(&shutdown_timer))
1143            .heartbeat_duration(Duration::from_millis(200))
1144            .verifier(Arc::downgrade(&verifier))
1145            .version(server_version!())
1146            .spawn();
1147
1148        // Spawn a task to handle the client-side establishment of a full connection
1149        let _client_task =
1150            tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!()));
1151
1152        // Wait to ensure we complete the accept call first
1153        let _ = rx.recv().await;
1154
1155        // Shutdown server connection task while it is accepting the connection, verifying that we
1156        // do not get an error in return
1157        shutdown_tx
1158            .send(())
1159            .expect("Failed to send shutdown signal");
1160        conn.await.unwrap();
1161    }
1162}