distant_net/server/
state.rs

1use std::collections::HashMap;
2
3use tokio::sync::{mpsc, oneshot, RwLock};
4use tokio::task::JoinHandle;
5
6use crate::common::{Backup, ConnectionId, Keychain};
7
8/// Contains all top-level state for the server
9pub struct ServerState<T> {
10    /// Mapping of connection ids to their tasks.
11    pub connections: RwLock<HashMap<ConnectionId, ConnectionState<T>>>,
12
13    /// Mapping of connection ids to (OTP, backup)
14    pub keychain: Keychain<oneshot::Receiver<Backup>>,
15}
16
17impl<T> ServerState<T> {
18    pub fn new() -> Self {
19        Self {
20            connections: RwLock::new(HashMap::new()),
21            keychain: Keychain::new(),
22        }
23    }
24}
25
26impl<T> Default for ServerState<T> {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32pub struct ConnectionState<T> {
33    shutdown_tx: oneshot::Sender<()>,
34    task: JoinHandle<Option<(mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)>>,
35}
36
37impl<T: Send + 'static> ConnectionState<T> {
38    /// Creates new state with appropriate channels, returning
39    /// (shutdown receiver, channel sender, state).
40    #[allow(clippy::type_complexity)]
41    pub fn channel() -> (
42        oneshot::Receiver<()>,
43        oneshot::Sender<(mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)>,
44        Self,
45    ) {
46        let (shutdown_tx, shutdown_rx) = oneshot::channel();
47        let (channel_tx, channel_rx) = oneshot::channel();
48
49        (
50            shutdown_rx,
51            channel_tx,
52            Self {
53                shutdown_tx,
54                task: tokio::spawn(async move {
55                    match channel_rx.await {
56                        Ok(x) => Some(x),
57                        Err(_) => None,
58                    }
59                }),
60            },
61        )
62    }
63
64    pub fn is_finished(&self) -> bool {
65        self.task.is_finished()
66    }
67
68    pub async fn shutdown_and_wait(
69        self,
70    ) -> Option<(mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)> {
71        let _ = self.shutdown_tx.send(());
72        self.task.await.unwrap()
73    }
74}