1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::sync::{Arc, Weak};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use distant_auth::Verifier;
9use log::*;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
13use tokio::task::JoinHandle;
14
15use super::{ConnectionState, RequestCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
16use crate::common::{
17 Backup, Connection, Frame, Interest, Keychain, Response, Transport, UntypedRequest, Version,
18};
19
/// Alias for the [`Keychain`] specialization used by connection tasks, where the
/// keychain is parameterized with a oneshot receiver of [`Backup`].
pub type ServerKeychain = Keychain<oneshot::Receiver<Backup>>;

/// Default value of the builder's `sleep_duration`: how long the read/write loop
/// sleeps when neither a read nor a write made progress.
const SLEEP_DURATION: Duration = Duration::from_millis(1);

/// Default value of the builder's `heartbeat_duration`: the time that must elapse
/// since the previous heartbeat before another empty frame is sent.
const MINIMUM_HEARTBEAT_DURATION: Duration = Duration::from_secs(5);
27
28pub(super) struct ConnectionTask(JoinHandle<io::Result<()>>);
30
31impl ConnectionTask {
32 pub fn build() -> ConnectionTaskBuilder<(), (), ()> {
34 ConnectionTaskBuilder::new()
35 }
36
37 pub fn is_finished(&self) -> bool {
39 self.0.is_finished()
40 }
41}
42
43impl Future for ConnectionTask {
44 type Output = io::Result<()>;
45
46 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
47 match Future::poll(Pin::new(&mut self.0), cx) {
48 Poll::Pending => Poll::Pending,
49 Poll::Ready(x) => match x {
50 Ok(x) => Poll::Ready(x),
51 Err(x) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, x))),
52 },
53 }
54 }
55}
56
/// Builder holding everything a connection task needs before it is spawned.
///
/// `H` is the handler type, `S` the type parameter of the shared [`ServerState`],
/// and `T` the transport type.
pub(super) struct ConnectionTaskBuilder<H, S, T> {
    /// Weak reference to the server handler; upgraded when the task starts running
    handler: Weak<H>,
    /// Weak reference to the shared server state (holds the map of connections)
    state: Weak<ServerState<S>>,
    /// Keychain handed to `Connection::server` while establishing the connection
    keychain: Keychain<oneshot::Receiver<Backup>>,
    /// Transport handed to `Connection::server` while establishing the connection
    transport: T,
    /// Receiver of the broadcast shutdown signal checked while the task awaits
    shutdown: broadcast::Receiver<()>,
    /// Weak reference to the shutdown timer, restarted when the last connection terminates
    shutdown_timer: Weak<RwLock<ShutdownTimer>>,
    /// Time to sleep when both reading and writing are blocked
    sleep_duration: Duration,
    /// Time that must elapse since the last heartbeat before another one is sent
    heartbeat_duration: Duration,
    /// Weak reference to the verifier handed to `Connection::server`
    verifier: Weak<Verifier>,
    /// Version handed to `Connection::server`
    version: Version,
}
70
71impl ConnectionTaskBuilder<(), (), ()> {
72 pub fn new() -> Self {
74 Self {
75 handler: Weak::new(),
76 state: Weak::new(),
77 keychain: Keychain::new(),
78 transport: (),
79 shutdown: broadcast::channel(1).1,
80 shutdown_timer: Weak::new(),
81 sleep_duration: SLEEP_DURATION,
82 heartbeat_duration: MINIMUM_HEARTBEAT_DURATION,
83 verifier: Weak::new(),
84 version: Version::default(),
85 }
86 }
87}
88
89impl<H, S, T> ConnectionTaskBuilder<H, S, T> {
90 pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, S, T> {
91 ConnectionTaskBuilder {
92 handler,
93 state: self.state,
94 keychain: self.keychain,
95 transport: self.transport,
96 shutdown: self.shutdown,
97 shutdown_timer: self.shutdown_timer,
98 sleep_duration: self.sleep_duration,
99 heartbeat_duration: self.heartbeat_duration,
100 verifier: self.verifier,
101 version: self.version,
102 }
103 }
104
105 pub fn state<U>(self, state: Weak<ServerState<U>>) -> ConnectionTaskBuilder<H, U, T> {
106 ConnectionTaskBuilder {
107 handler: self.handler,
108 state,
109 keychain: self.keychain,
110 transport: self.transport,
111 shutdown: self.shutdown,
112 shutdown_timer: self.shutdown_timer,
113 sleep_duration: self.sleep_duration,
114 heartbeat_duration: self.heartbeat_duration,
115 verifier: self.verifier,
116 version: self.version,
117 }
118 }
119
120 pub fn keychain(self, keychain: ServerKeychain) -> ConnectionTaskBuilder<H, S, T> {
121 ConnectionTaskBuilder {
122 handler: self.handler,
123 state: self.state,
124 keychain,
125 transport: self.transport,
126 shutdown: self.shutdown,
127 shutdown_timer: self.shutdown_timer,
128 sleep_duration: self.sleep_duration,
129 heartbeat_duration: self.heartbeat_duration,
130 verifier: self.verifier,
131 version: self.version,
132 }
133 }
134
135 pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, S, U> {
136 ConnectionTaskBuilder {
137 handler: self.handler,
138 keychain: self.keychain,
139 state: self.state,
140 transport,
141 shutdown: self.shutdown,
142 shutdown_timer: self.shutdown_timer,
143 sleep_duration: self.sleep_duration,
144 heartbeat_duration: self.heartbeat_duration,
145 verifier: self.verifier,
146 version: self.version,
147 }
148 }
149
150 pub fn shutdown(self, shutdown: broadcast::Receiver<()>) -> ConnectionTaskBuilder<H, S, T> {
151 ConnectionTaskBuilder {
152 handler: self.handler,
153 state: self.state,
154 keychain: self.keychain,
155 transport: self.transport,
156 shutdown,
157 shutdown_timer: self.shutdown_timer,
158 sleep_duration: self.sleep_duration,
159 heartbeat_duration: self.heartbeat_duration,
160 verifier: self.verifier,
161 version: self.version,
162 }
163 }
164
165 pub fn shutdown_timer(
166 self,
167 shutdown_timer: Weak<RwLock<ShutdownTimer>>,
168 ) -> ConnectionTaskBuilder<H, S, T> {
169 ConnectionTaskBuilder {
170 handler: self.handler,
171 state: self.state,
172 keychain: self.keychain,
173 transport: self.transport,
174 shutdown: self.shutdown,
175 shutdown_timer,
176 sleep_duration: self.sleep_duration,
177 heartbeat_duration: self.heartbeat_duration,
178 verifier: self.verifier,
179 version: self.version,
180 }
181 }
182
183 pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, S, T> {
184 ConnectionTaskBuilder {
185 handler: self.handler,
186 state: self.state,
187 keychain: self.keychain,
188 transport: self.transport,
189 shutdown: self.shutdown,
190 shutdown_timer: self.shutdown_timer,
191 sleep_duration,
192 heartbeat_duration: self.heartbeat_duration,
193 verifier: self.verifier,
194 version: self.version,
195 }
196 }
197
198 pub fn heartbeat_duration(
199 self,
200 heartbeat_duration: Duration,
201 ) -> ConnectionTaskBuilder<H, S, T> {
202 ConnectionTaskBuilder {
203 handler: self.handler,
204 state: self.state,
205 keychain: self.keychain,
206 transport: self.transport,
207 shutdown: self.shutdown,
208 shutdown_timer: self.shutdown_timer,
209 sleep_duration: self.sleep_duration,
210 heartbeat_duration,
211 verifier: self.verifier,
212 version: self.version,
213 }
214 }
215
216 pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, S, T> {
217 ConnectionTaskBuilder {
218 handler: self.handler,
219 state: self.state,
220 keychain: self.keychain,
221 transport: self.transport,
222 shutdown: self.shutdown,
223 shutdown_timer: self.shutdown_timer,
224 sleep_duration: self.sleep_duration,
225 heartbeat_duration: self.heartbeat_duration,
226 verifier,
227 version: self.version,
228 }
229 }
230
231 pub fn version(self, version: Version) -> ConnectionTaskBuilder<H, S, T> {
232 ConnectionTaskBuilder {
233 handler: self.handler,
234 state: self.state,
235 keychain: self.keychain,
236 transport: self.transport,
237 shutdown: self.shutdown,
238 shutdown_timer: self.shutdown_timer,
239 sleep_duration: self.sleep_duration,
240 heartbeat_duration: self.heartbeat_duration,
241 verifier: self.verifier,
242 version,
243 }
244 }
245}
246
247impl<H, T> ConnectionTaskBuilder<H, Response<H::Response>, T>
248where
249 H: ServerHandler + Sync + 'static,
250 H::Request: DeserializeOwned + Send + Sync + 'static,
251 H::Response: Serialize + Send + 'static,
252 T: Transport + 'static,
253{
254 pub fn spawn(self) -> ConnectionTask {
255 ConnectionTask(tokio::spawn(self.run()))
256 }
257
258 async fn run(self) -> io::Result<()> {
259 let ConnectionTaskBuilder {
260 handler,
261 state,
262 keychain,
263 transport,
264 mut shutdown,
265 shutdown_timer,
266 sleep_duration,
267 heartbeat_duration,
268 verifier,
269 version,
270 } = self;
271
272 let (mut local_shutdown, channel_tx, connection_state) = ConnectionState::channel();
274
275 macro_rules! terminate_connection {
277 (@fatal $($msg:tt)+) => {
279 error!($($msg)+);
280 terminate_connection!();
281 return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
282 };
283
284 (@error($tx:ident, $rx:ident) $($msg:tt)+) => {
286 error!($($msg)+);
287 terminate_connection!($tx, $rx);
288 return Err(io::Error::new(io::ErrorKind::Other, format!($($msg)+)));
289 };
290
291 (@debug($tx:ident, $rx:ident) $($msg:tt)+) => {
293 debug!($($msg)+);
294 terminate_connection!($tx, $rx);
295 return Ok(());
296 };
297
298 (@shutdown) => {
300 debug!("Shutdown triggered before a connection could be fully established");
301 terminate_connection!();
302 return Ok(());
303 };
304
305 (@shutdown) => {
307 debug!("Shutdown triggered before a connection could be fully established");
308 terminate_connection!();
309 return Ok(());
310 };
311
312 (@shutdown($id:ident, $tx:ident, $rx:ident)) => {{
314 debug!("[Conn {}] Shutdown triggered", $id);
315 terminate_connection!($tx, $rx);
316 return Ok(());
317 }};
318
319 ($tx:ident, $rx:ident) => {
322 let _ = channel_tx.send(($tx, $rx));
324
325 terminate_connection!();
326 };
327
328 () => {
331 if let Some(state) = Weak::upgrade(&state) {
333 if let Some(timer) = Weak::upgrade(&shutdown_timer) {
334 if state.connections.read().await.values().filter(|conn| !conn.is_finished()).count() <= 1 {
335 debug!("Last connection terminating, so restarting shutdown timer");
336 timer.write().await.restart();
337 }
338 }
339 }
340 };
341 }
342
343 macro_rules! await_or_shutdown {
347 ($(@save($id:ident, $tx:ident, $rx:ident))? $future:expr) => {{
348 let mut f = $future;
349
350 loop {
351 let use_shutdown = match shutdown.try_recv() {
352 Ok(_) => {
353 terminate_connection!(@shutdown $(($id, $tx, $rx))?);
354 }
355 Err(broadcast::error::TryRecvError::Empty) => true,
356 Err(broadcast::error::TryRecvError::Lagged(_)) => true,
357 Err(broadcast::error::TryRecvError::Closed) => false,
358 };
359
360 let use_local_shutdown = match local_shutdown.try_recv() {
361 Ok(_) => {
362 terminate_connection!(@shutdown $(($id, $tx, $rx))?);
363 }
364 Err(oneshot::error::TryRecvError::Empty) => true,
365 Err(oneshot::error::TryRecvError::Closed) => false,
366 };
367
368 if use_shutdown && use_local_shutdown {
369 tokio::select! {
370 x = shutdown.recv() => {
371 if x.is_err() {
372 continue;
373 }
374
375 terminate_connection!(@shutdown $(($id, $tx, $rx))?);
376 }
377 x = &mut local_shutdown => {
378 if x.is_err() {
379 continue;
380 }
381
382 terminate_connection!(@shutdown $(($id, $tx, $rx))?);
383 }
384 x = &mut f => { break x; }
385 }
386 } else if use_shutdown {
387 tokio::select! {
388 x = shutdown.recv() => {
389 if x.is_err() {
390 continue;
391 }
392
393 terminate_connection!(@shutdown $(($id, $tx, $rx))?);
394 }
395 x = &mut f => { break x; }
396 }
397 } else if use_local_shutdown {
398 tokio::select! {
399 x = &mut local_shutdown => {
400 if x.is_err() {
401 continue;
402 }
403
404 terminate_connection!(@shutdown $(($id, $tx, $rx))?);
405 }
406 x = &mut f => { break x; }
407 }
408 } else {
409 break f.await;
410 }
411 }
412 }};
413 }
414
415 let handler = match Weak::upgrade(&handler) {
417 Some(handler) => handler,
418 None => {
419 terminate_connection!(@fatal "Failed to setup connection because handler dropped");
420 }
421 };
422
423 let state = match Weak::upgrade(&state) {
425 Some(state) => state,
426 None => {
427 terminate_connection!(@fatal "Failed to setup connection because state dropped");
428 }
429 };
430
431 debug!("Establishing full connection using {transport:?}");
433 let mut connection = match Weak::upgrade(&verifier) {
434 Some(verifier) => {
435 match await_or_shutdown!(Box::pin(Connection::server(
436 transport,
437 verifier.as_ref(),
438 keychain,
439 version
440 ))) {
441 Ok(connection) => connection,
442 Err(x) => {
443 terminate_connection!(@fatal "Failed to setup connection: {x}");
444 }
445 }
446 }
447 None => {
448 terminate_connection!(@fatal "Verifier has been dropped");
449 }
450 };
451
452 let id = connection.id();
454
455 info!("[Conn {id}] Connection established");
457 if let Err(x) = await_or_shutdown!(handler.on_connect(id)) {
458 terminate_connection!(@fatal "[Conn {id}] Accepting connection failed: {x}");
459 }
460
461 let mut last_heartbeat = Instant::now();
462
463 let (tx, mut rx) = match state.connections.write().await.remove(&id) {
465 Some(conn) => match conn.shutdown_and_wait().await {
466 Some(x) => {
467 debug!("[Conn {id}] Marked as existing connection");
468 x
469 }
470 None => {
471 warn!("[Conn {id}] Existing connection with id, but channels not saved");
472 mpsc::unbounded_channel::<Response<H::Response>>()
473 }
474 },
475 None => {
476 debug!("[Conn {id}] Marked as new connection");
477 mpsc::unbounded_channel::<Response<H::Response>>()
478 }
479 };
480
481 state.connections.write().await.insert(id, connection_state);
483
484 debug!("[Conn {id}] Beginning read/write loop");
485 loop {
486 let ready = match await_or_shutdown!(
487 @save(id, tx, rx)
488 Box::pin(connection.ready(Interest::READABLE | Interest::WRITABLE))
489 ) {
490 Ok(ready) => ready,
491 Err(x) => {
492 terminate_connection!(@error(tx, rx) "[Conn {id}] Failed to examine ready state: {x}");
493 }
494 };
495
496 let mut read_blocked = !ready.is_readable();
498 let mut write_blocked = !ready.is_writable();
499
500 if ready.is_readable() {
501 match connection.try_read_frame() {
502 Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
503 Ok(request) => match request.to_typed_request() {
504 Ok(request) => {
505 if log::log_enabled!(Level::Debug) {
506 let debug_header = if !request.header.is_empty() {
507 format!(" | header {}", request.header)
508 } else {
509 String::new()
510 };
511 debug!("[Conn {id}] New request {}{debug_header}", request.id);
512 }
513 let origin_id = request.id.clone();
514 let ctx = RequestCtx {
515 connection_id: id,
516 request,
517 reply: ServerReply {
518 origin_id,
519 tx: tx.clone(),
520 },
521 };
522
523 let handler = Arc::clone(&handler);
526 tokio::spawn(async move { handler.on_request(ctx).await });
527 }
528 Err(x) => {
529 if log::log_enabled!(Level::Debug) {
530 error!(
531 "[Conn {id}] Failed receiving {}",
532 String::from_utf8_lossy(&request.payload),
533 );
534 }
535
536 error!("[Conn {id}] Invalid request: {x}");
537 }
538 },
539 Err(x) => {
540 error!("[Conn {id}] Invalid request payload: {x}");
541 }
542 },
543 Ok(None) => {
544 terminate_connection!(@debug(tx, rx) "[Conn {id}] Connection closed");
545 }
546 Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
547 Err(x) => {
548 terminate_connection!(@error(tx, rx) "[Conn {id}] {x}");
549 }
550 }
551 }
552
553 if ready.is_writable() {
556 if last_heartbeat.elapsed() >= heartbeat_duration {
558 trace!("[Conn {id}] Sending heartbeat via empty frame");
559 match connection.try_write_frame(Frame::empty()) {
560 Ok(()) => (),
561 Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
562 Err(x) => error!("[Conn {id}] Send failed: {x}"),
563 }
564 last_heartbeat = Instant::now();
565 }
566 else if let Ok(response) = rx.try_recv() {
570 if log_enabled!(Level::Trace) {
572 trace!(
573 "[Conn {id}] Sending {}",
574 &response
575 .to_vec()
576 .map(|x| String::from_utf8_lossy(&x).to_string())
577 .unwrap_or_else(|_| "<Cannot serialize>".to_string())
578 );
579 }
580
581 match response.to_vec() {
582 Ok(data) => match connection.try_write_frame(data) {
583 Ok(()) => (),
584 Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
585 Err(x) => error!("[Conn {id}] Send failed: {x}"),
586 },
587 Err(x) => {
588 error!("[Conn {id}] Unable to serialize outgoing response: {x}");
589 }
590 }
591 } else {
592 match connection.try_flush() {
599 Ok(0) => write_blocked = true,
600 Ok(_) => (),
601 Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
602 Err(x) => {
603 error!("[Conn {id}] Failed to flush outgoing data: {x}");
604 }
605 }
606 }
607 }
608
609 if read_blocked && write_blocked {
611 tokio::time::sleep(sleep_duration).await;
612 }
613 }
614 }
615}
616
617#[cfg(test)]
618mod tests {
619 use std::sync::atomic::{AtomicBool, Ordering};
620
621 use async_trait::async_trait;
622 use distant_auth::DummyAuthHandler;
623 use test_log::test;
624
625 use super::*;
626 use crate::common::{
627 HeapSecretKey, InmemoryTransport, Ready, Reconnectable, Request, Response,
628 };
629 use crate::server::{ConnectionId, Shutdown};
630
631 struct TestServerHandler;
632
633 #[async_trait]
634 impl ServerHandler for TestServerHandler {
635 type Request = u16;
636 type Response = String;
637
638 async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
639 ctx.reply.send("hello".to_string()).unwrap();
641 }
642 }
643
644 macro_rules! wait_for_termination {
645 ($task:ident) => {{
646 let timeout_millis = 500;
647 let sleep_millis = 50;
648 let start = std::time::Instant::now();
649 while !$task.is_finished() {
650 if start.elapsed() > std::time::Duration::from_millis(timeout_millis) {
651 panic!("Exceeded timeout of {timeout_millis}ms");
652 }
653 tokio::time::sleep(std::time::Duration::from_millis(sleep_millis)).await;
654 }
655 }};
656 }
657
    /// Expands to the fixed version (1.2.3) shared by the server and client sides
    /// of the tests.
    macro_rules! server_version {
        () => {
            Version::new(1, 2, 3)
        };
    }
663
664 #[test(tokio::test)]
665 async fn should_terminate_if_fails_access_verifier() {
666 let handler = Arc::new(TestServerHandler);
667 let state = Arc::new(ServerState::default());
668 let keychain = ServerKeychain::new();
669 let (t1, _t2) = InmemoryTransport::pair(100);
670 let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
671
672 let task = ConnectionTask::build()
673 .handler(Arc::downgrade(&handler))
674 .state(Arc::downgrade(&state))
675 .keychain(keychain)
676 .transport(t1)
677 .shutdown_timer(Arc::downgrade(&shutdown_timer))
678 .verifier(Weak::new())
679 .spawn();
680
681 wait_for_termination!(task);
682
683 let err = task.await.unwrap_err();
684 assert!(
685 err.to_string().contains("Verifier has been dropped"),
686 "Unexpected error: {err}"
687 );
688 }
689
690 #[test(tokio::test)]
691 async fn should_terminate_if_fails_to_setup_server_connection() {
692 let handler = Arc::new(TestServerHandler);
693 let state = Arc::new(ServerState::default());
694 let keychain = ServerKeychain::new();
695 let (t1, t2) = InmemoryTransport::pair(100);
696 let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
697
698 let verifier = Arc::new(Verifier::static_key(HeapSecretKey::generate(32).unwrap()));
700
701 let task = ConnectionTask::build()
702 .handler(Arc::downgrade(&handler))
703 .state(Arc::downgrade(&state))
704 .keychain(keychain)
705 .transport(t1)
706 .shutdown_timer(Arc::downgrade(&shutdown_timer))
707 .verifier(Arc::downgrade(&verifier))
708 .version(server_version!())
709 .spawn();
710
711 tokio::spawn(async move {
713 let _client = Connection::client(t2, DummyAuthHandler, server_version!())
714 .await
715 .expect("Fail to establish client-side connection");
716 });
717
718 wait_for_termination!(task);
719
720 let err = task.await.unwrap_err();
721 assert!(
722 err.to_string().contains("Failed to setup connection"),
723 "Unexpected error: {err}"
724 );
725 }
726
727 #[test(tokio::test)]
728 async fn should_terminate_if_fails_access_server_handler() {
729 let state = Arc::new(ServerState::default());
730 let keychain = ServerKeychain::new();
731 let (t1, t2) = InmemoryTransport::pair(100);
732 let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
733 let verifier = Arc::new(Verifier::none());
734
735 let task = ConnectionTask::build()
736 .handler(Weak::<TestServerHandler>::new())
737 .state(Arc::downgrade(&state))
738 .keychain(keychain)
739 .transport(t1)
740 .shutdown_timer(Arc::downgrade(&shutdown_timer))
741 .verifier(Arc::downgrade(&verifier))
742 .version(server_version!())
743 .spawn();
744
745 tokio::spawn(async move {
747 let _client = Connection::client(t2, DummyAuthHandler, server_version!())
748 .await
749 .expect("Fail to establish client-side connection");
750 });
751
752 wait_for_termination!(task);
753
754 let err = task.await.unwrap_err();
755 assert!(
756 err.to_string().contains("handler dropped"),
757 "Unexpected error: {err}"
758 );
759 }
760
761 #[test(tokio::test)]
762 async fn should_terminate_if_accepting_connection_fails_on_server_handler() {
763 struct BadAcceptServerHandler;
764
765 #[async_trait]
766 impl ServerHandler for BadAcceptServerHandler {
767 type Request = u16;
768 type Response = String;
769
770 async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
771 Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
772 }
773
774 async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
775 unreachable!();
776 }
777 }
778
779 let handler = Arc::new(BadAcceptServerHandler);
780 let state = Arc::new(ServerState::default());
781 let keychain = ServerKeychain::new();
782 let (t1, t2) = InmemoryTransport::pair(100);
783 let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
784 let verifier = Arc::new(Verifier::none());
785
786 let task = ConnectionTask::build()
787 .handler(Arc::downgrade(&handler))
788 .state(Arc::downgrade(&state))
789 .keychain(keychain)
790 .transport(t1)
791 .shutdown_timer(Arc::downgrade(&shutdown_timer))
792 .verifier(Arc::downgrade(&verifier))
793 .version(server_version!())
794 .spawn();
795
796 tokio::spawn(async move {
799 let _client = Connection::client(t2, DummyAuthHandler, server_version!())
800 .await
801 .expect("Fail to establish client-side connection");
802 });
803
804 wait_for_termination!(task);
805
806 let err = task.await.unwrap_err();
807 assert!(
808 err.to_string().contains("Accepting connection failed"),
809 "Unexpected error: {err}"
810 );
811 }
812
    /// When the transport's `ready` call starts failing after the connection is
    /// established, the task must terminate with that failure.
    #[test(tokio::test)]
    async fn should_terminate_if_connection_fails_to_become_ready() {
        let handler = Arc::new(TestServerHandler);
        let state = Arc::new(ServerState::default());
        let keychain = ServerKeychain::new();
        let (t1, t2) = InmemoryTransport::pair(100);
        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
        let verifier = Arc::new(Verifier::none());

        /// Transport delegating to an in-memory transport, except that `ready`
        /// fails once the `fail_ready` flag has been set
        #[derive(Debug)]
        struct FakeTransport {
            inner: InmemoryTransport,
            fail_ready: Arc<AtomicBool>,
        }

        #[async_trait]
        impl Transport for FakeTransport {
            fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
                self.inner.try_read(buf)
            }

            fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
                self.inner.try_write(buf)
            }

            async fn ready(&self, interest: Interest) -> io::Result<Ready> {
                if self.fail_ready.load(Ordering::Relaxed) {
                    Err(io::Error::new(
                        io::ErrorKind::Other,
                        "targeted ready failure",
                    ))
                } else {
                    self.inner.ready(interest).await
                }
            }
        }

        #[async_trait]
        impl Reconnectable for FakeTransport {
            async fn reconnect(&mut self) -> io::Result<()> {
                self.inner.reconnect().await
            }
        }

        // Flag starts cleared so the connection can be established normally
        let fail_ready = Arc::new(AtomicBool::new(false));
        let task = ConnectionTask::build()
            .handler(Arc::downgrade(&handler))
            .state(Arc::downgrade(&state))
            .keychain(keychain)
            .transport(FakeTransport {
                inner: t1,
                fail_ready: Arc::clone(&fail_ready),
            })
            .shutdown_timer(Arc::downgrade(&shutdown_timer))
            .verifier(Arc::downgrade(&verifier))
            .version(server_version!())
            .spawn();

        tokio::spawn(async move {
            // Establish the client side so the server task reaches its read/write loop
            let _client = Connection::client(t2, DummyAuthHandler, server_version!())
                .await
                .expect("Fail to establish client-side connection");

            // Give the server a moment before switching the failure on
            tokio::time::sleep(Duration::from_millis(50)).await;

            // From now on every `ready` call on the server's transport fails
            fail_ready.store(true, Ordering::Relaxed);

            // Keep the client alive so the server fails on `ready` rather than
            // observing a closed connection
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        wait_for_termination!(task);

        let err = task.await.unwrap_err();
        assert!(
            err.to_string().contains("targeted ready failure"),
            "Unexpected error: {err}"
        );
    }
896
897 #[test(tokio::test)]
898 async fn should_terminate_if_connection_closes() {
899 let handler = Arc::new(TestServerHandler);
900 let state = Arc::new(ServerState::default());
901 let keychain = ServerKeychain::new();
902 let (t1, t2) = InmemoryTransport::pair(100);
903 let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
904 let verifier = Arc::new(Verifier::none());
905
906 let task = ConnectionTask::build()
907 .handler(Arc::downgrade(&handler))
908 .state(Arc::downgrade(&state))
909 .keychain(keychain)
910 .transport(t1)
911 .shutdown_timer(Arc::downgrade(&shutdown_timer))
912 .verifier(Arc::downgrade(&verifier))
913 .version(server_version!())
914 .spawn();
915
916 tokio::spawn(async move {
919 let _client = Connection::client(t2, DummyAuthHandler, server_version!())
920 .await
921 .expect("Fail to establish client-side connection");
922 });
923
924 wait_for_termination!(task);
925 task.await.unwrap();
926 }
927
928 #[test(tokio::test)]
929 async fn should_invoke_server_handler_to_process_request_in_new_task_and_forward_responses() {
930 let handler = Arc::new(TestServerHandler);
931 let state = Arc::new(ServerState::default());
932 let keychain = ServerKeychain::new();
933 let (t1, t2) = InmemoryTransport::pair(100);
934 let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
935 let verifier = Arc::new(Verifier::none());
936
937 let _conn = ConnectionTask::build()
938 .handler(Arc::downgrade(&handler))
939 .state(Arc::downgrade(&state))
940 .keychain(keychain)
941 .transport(t1)
942 .shutdown_timer(Arc::downgrade(&shutdown_timer))
943 .verifier(Arc::downgrade(&verifier))
944 .version(server_version!())
945 .spawn();
946
947 let task = tokio::spawn(async move {
949 let mut client = Connection::client(t2, DummyAuthHandler, server_version!())
950 .await
951 .expect("Fail to establish client-side connection");
952
953 client.write_frame_for(&Request::new(123u16)).await.unwrap();
954 client
955 .read_frame_as::<Response<String>>()
956 .await
957 .unwrap()
958 .unwrap()
959 });
960
961 let response = task.await.unwrap();
962 assert_eq!(response.payload, "hello");
963 }
964
    /// With a 200ms heartbeat duration, the client must receive an empty frame
    /// after each interval and nothing in between.
    #[test(tokio::test)]
    async fn should_send_heartbeat_via_empty_frame_every_minimum_duration() {
        let handler = Arc::new(TestServerHandler);
        let state = Arc::new(ServerState::default());
        let keychain = ServerKeychain::new();
        let (t1, t2) = InmemoryTransport::pair(100);
        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
        let verifier = Arc::new(Verifier::none());

        let _conn = ConnectionTask::build()
            .handler(Arc::downgrade(&handler))
            .state(Arc::downgrade(&state))
            .keychain(keychain)
            .transport(t1)
            .shutdown_timer(Arc::downgrade(&shutdown_timer))
            .heartbeat_duration(Duration::from_millis(200))
            .verifier(Arc::downgrade(&verifier))
            .version(server_version!())
            .spawn();

        let task = tokio::spawn(async move {
            let mut client = Connection::client(t2, DummyAuthHandler, server_version!())
                .await
                .expect("Fail to establish client-side connection");

            // No heartbeat should be available right after connecting
            assert_eq!(
                client.try_read_frame().unwrap_err().kind(),
                io::ErrorKind::WouldBlock,
                "got a frame early"
            );

            // Wait longer than the heartbeat duration; an empty frame should arrive
            tokio::time::sleep(Duration::from_millis(250)).await;
            assert_eq!(
                client.read_frame().await.unwrap().unwrap(),
                Frame::empty(),
                "non-empty frame"
            );

            // Only a single heartbeat should have been sent so far
            assert_eq!(
                client.try_read_frame().unwrap_err().kind(),
                io::ErrorKind::WouldBlock,
                "got a frame early"
            );

            // Wait out another interval; a second empty frame should arrive
            tokio::time::sleep(Duration::from_millis(250)).await;
            assert_eq!(
                client.read_frame().await.unwrap().unwrap(),
                Frame::empty(),
                "non-empty frame"
            );
        });

        task.await.unwrap();
    }
1024
    /// A shutdown signal sent while the connection is still being established
    /// must make the task finish successfully.
    #[test(tokio::test)]
    async fn should_be_able_to_shutdown_while_establishing_connection() {
        let handler = Arc::new(TestServerHandler);
        let state = Arc::new(ServerState::default());
        let keychain = ServerKeychain::new();
        // The other half of the transport is kept alive but never driven by a
        // client, so the server cannot finish establishing the connection
        let (t1, _t2) = InmemoryTransport::pair(100);
        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
        let verifier = Arc::new(Verifier::none());

        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
        let conn = ConnectionTask::build()
            .handler(Arc::downgrade(&handler))
            .state(Arc::downgrade(&state))
            .keychain(keychain)
            .transport(t1)
            .shutdown(shutdown_rx)
            .shutdown_timer(Arc::downgrade(&shutdown_timer))
            .heartbeat_duration(Duration::from_millis(200))
            .verifier(Arc::downgrade(&verifier))
            .spawn();

        // Trigger the shutdown and expect the task to resolve to `Ok(())`
        shutdown_tx
            .send(())
            .expect("Failed to send shutdown signal");
        conn.await.unwrap();
    }
1053
    /// A shutdown signal must make the task finish successfully even though the
    /// handler's `on_connect` never completes.
    ///
    /// NOTE(review): the signal is sent immediately after spawning, so it may be
    /// observed before `on_connect` is reached — confirm that this exercises the
    /// accepting phase as intended.
    #[test(tokio::test)]
    async fn should_be_able_to_shutdown_while_accepting_connection() {
        /// Handler whose `on_connect` sleeps for the maximum duration
        struct HangingAcceptServerHandler;

        #[async_trait]
        impl ServerHandler for HangingAcceptServerHandler {
            type Request = ();
            type Response = ();

            async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
                // Effectively never returns within the lifetime of the test
                tokio::time::sleep(Duration::MAX).await;
                Err(io::Error::new(io::ErrorKind::Other, "bad connect"))
            }

            async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
                unreachable!();
            }
        }

        let handler = Arc::new(HangingAcceptServerHandler);
        let state = Arc::new(ServerState::default());
        let keychain = ServerKeychain::new();
        let (t1, t2) = InmemoryTransport::pair(100);
        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
        let verifier = Arc::new(Verifier::none());

        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
        let conn = ConnectionTask::build()
            .handler(Arc::downgrade(&handler))
            .state(Arc::downgrade(&state))
            .keychain(keychain)
            .transport(t1)
            .shutdown(shutdown_rx)
            .shutdown_timer(Arc::downgrade(&shutdown_timer))
            .heartbeat_duration(Duration::from_millis(200))
            .verifier(Arc::downgrade(&verifier))
            .version(server_version!())
            .spawn();

        // Drive the client side of the handshake in the background
        let _client_task =
            tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!()));

        // Trigger the shutdown and expect the task to resolve to `Ok(())`
        shutdown_tx
            .send(())
            .expect("Failed to send shutdown signal");
        conn.await.unwrap();
    }
1105
    /// A shutdown signal sent after `on_connect` has run must make the task
    /// finish successfully.
    #[test(tokio::test)]
    async fn should_be_able_to_shutdown_while_waiting_for_connection_to_be_ready() {
        /// Handler that reports through a channel when `on_connect` is invoked
        struct AcceptServerHandler {
            tx: mpsc::Sender<()>,
        }

        #[async_trait]
        impl ServerHandler for AcceptServerHandler {
            type Request = ();
            type Response = ();

            async fn on_connect(&self, _: ConnectionId) -> io::Result<()> {
                // Let the test know the connection has been accepted
                self.tx.send(()).await.unwrap();
                Ok(())
            }

            async fn on_request(&self, _: RequestCtx<Self::Request, Self::Response>) {
                unreachable!();
            }
        }

        let (tx, mut rx) = mpsc::channel(100);
        let handler = Arc::new(AcceptServerHandler { tx });
        let state = Arc::new(ServerState::default());
        let keychain = ServerKeychain::new();
        let (t1, t2) = InmemoryTransport::pair(100);
        let shutdown_timer = Arc::new(RwLock::new(ShutdownTimer::start(Shutdown::Never)));
        let verifier = Arc::new(Verifier::none());

        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
        let conn = ConnectionTask::build()
            .handler(Arc::downgrade(&handler))
            .state(Arc::downgrade(&state))
            .keychain(keychain)
            .transport(t1)
            .shutdown(shutdown_rx)
            .shutdown_timer(Arc::downgrade(&shutdown_timer))
            .heartbeat_duration(Duration::from_millis(200))
            .verifier(Arc::downgrade(&verifier))
            .version(server_version!())
            .spawn();

        // Drive the client side of the handshake in the background
        let _client_task =
            tokio::spawn(Connection::client(t2, DummyAuthHandler, server_version!()));

        // Wait until the handler reports that `on_connect` was invoked
        let _ = rx.recv().await;

        // Trigger the shutdown and expect the task to resolve to `Ok(())`
        shutdown_tx
            .send(())
            .expect("Failed to send shutdown signal");
        conn.await.unwrap();
    }
1162}