distant_net/manager/server/connection.rs

1use std::collections::HashMap;
2use std::{fmt, io};
3
4use log::*;
5use tokio::sync::{mpsc, oneshot};
6use tokio::task::JoinHandle;
7
8use crate::client::{Mailbox, UntypedClient};
9use crate::common::{ConnectionId, Destination, Map, UntypedRequest, UntypedResponse};
10use crate::manager::data::{ManagerChannelId, ManagerResponse};
11use crate::server::ServerReply;
12
13/// Represents a connection a distant manager has with some distant-compatible server
14pub struct ManagerConnection {
15    pub id: ConnectionId,
16    pub destination: Destination,
17    pub options: Map,
18    tx: mpsc::UnboundedSender<Action>,
19
20    action_task: JoinHandle<()>,
21    request_task: JoinHandle<()>,
22    response_task: JoinHandle<()>,
23}
24
25#[derive(Clone)]
26pub struct ManagerChannel {
27    channel_id: ManagerChannelId,
28    tx: mpsc::UnboundedSender<Action>,
29}
30
31impl ManagerChannel {
32    /// Returns the id associated with the channel.
33    pub fn id(&self) -> ManagerChannelId {
34        self.channel_id
35    }
36
37    /// Sends the untyped request to the server on the other side of the channel.
38    pub fn send(&self, req: UntypedRequest<'static>) -> io::Result<()> {
39        let id = self.channel_id;
40
41        self.tx.send(Action::Write { id, req }).map_err(|x| {
42            io::Error::new(
43                io::ErrorKind::BrokenPipe,
44                format!("channel {id} send failed: {x}"),
45            )
46        })
47    }
48
49    /// Closes the channel, unregistering it with the connection.
50    pub fn close(&self) -> io::Result<()> {
51        let id = self.channel_id;
52        self.tx.send(Action::Unregister { id }).map_err(|x| {
53            io::Error::new(
54                io::ErrorKind::BrokenPipe,
55                format!("channel {id} close failed: {x}"),
56            )
57        })
58    }
59}
60
61impl ManagerConnection {
62    pub async fn spawn(
63        spawn: Destination,
64        options: Map,
65        mut client: UntypedClient,
66    ) -> io::Result<Self> {
67        let connection_id = rand::random();
68        let (tx, rx) = mpsc::unbounded_channel();
69
70        // NOTE: Ensure that the connection is severed when the client is dropped; otherwise, when
71        // the connection is terminated via aborting it or the connection being dropped, the
72        // connection will persist which can cause problems such as lonely shutdown of the server
73        // never triggering!
74        client.shutdown_on_drop(true);
75
76        let (request_tx, request_rx) = mpsc::unbounded_channel();
77        let action_task = tokio::spawn(action_task(connection_id, rx, request_tx));
78        let response_task = tokio::spawn(response_task(
79            connection_id,
80            client.assign_default_mailbox(100).await?,
81            tx.clone(),
82        ));
83        let request_task = tokio::spawn(request_task(connection_id, client, request_rx));
84
85        Ok(Self {
86            id: connection_id,
87            destination: spawn,
88            options,
89            tx,
90            action_task,
91            request_task,
92            response_task,
93        })
94    }
95
96    pub fn open_channel(&self, reply: ServerReply<ManagerResponse>) -> io::Result<ManagerChannel> {
97        let channel_id = rand::random();
98        self.tx
99            .send(Action::Register {
100                id: channel_id,
101                reply,
102            })
103            .map_err(|x| {
104                io::Error::new(
105                    io::ErrorKind::BrokenPipe,
106                    format!("open_channel failed: {x}"),
107                )
108            })?;
109        Ok(ManagerChannel {
110            channel_id,
111            tx: self.tx.clone(),
112        })
113    }
114
115    pub async fn channel_ids(&self) -> io::Result<Vec<ManagerChannelId>> {
116        let (tx, rx) = oneshot::channel();
117        self.tx
118            .send(Action::GetRegistered { cb: tx })
119            .map_err(|x| {
120                io::Error::new(
121                    io::ErrorKind::BrokenPipe,
122                    format!("channel_ids failed: {x}"),
123                )
124            })?;
125
126        let channel_ids = rx.await.map_err(|x| {
127            io::Error::new(
128                io::ErrorKind::BrokenPipe,
129                format!("channel_ids callback dropped: {x}"),
130            )
131        })?;
132        Ok(channel_ids)
133    }
134
135    /// Aborts the tasks used to engage with the connection.
136    pub fn abort(&self) {
137        self.action_task.abort();
138        self.request_task.abort();
139        self.response_task.abort();
140    }
141}
142
143impl Drop for ManagerConnection {
144    fn drop(&mut self) {
145        self.abort();
146    }
147}
148
149enum Action {
150    Register {
151        id: ManagerChannelId,
152        reply: ServerReply<ManagerResponse>,
153    },
154
155    Unregister {
156        id: ManagerChannelId,
157    },
158
159    GetRegistered {
160        cb: oneshot::Sender<Vec<ManagerChannelId>>,
161    },
162
163    Read {
164        res: UntypedResponse<'static>,
165    },
166
167    Write {
168        id: ManagerChannelId,
169        req: UntypedRequest<'static>,
170    },
171}
172
173impl fmt::Debug for Action {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        match self {
176            Self::Register { id, .. } => write!(f, "Action::Register {{ id: {id}, .. }}"),
177            Self::Unregister { id } => write!(f, "Action::Unregister {{ id: {id} }}"),
178            Self::GetRegistered { .. } => write!(f, "Action::GetRegistered {{ .. }}"),
179            Self::Read { .. } => write!(f, "Action::Read {{ .. }}"),
180            Self::Write { id, .. } => write!(f, "Action::Write {{ id: {id}, .. }}"),
181        }
182    }
183}
184
185/// Internal task to process outgoing [`UntypedRequest`]s.
186async fn request_task(
187    id: ConnectionId,
188    mut client: UntypedClient,
189    mut rx: mpsc::UnboundedReceiver<UntypedRequest<'static>>,
190) {
191    while let Some(req) = rx.recv().await {
192        trace!("[Conn {id}] Firing off request {}", req.id);
193        if let Err(x) = client.fire(req).await {
194            error!("[Conn {id}] Failed to send request: {x}");
195        }
196    }
197
198    trace!("[Conn {id}] Manager request task closed");
199}
200
201/// Internal task to process incoming [`UntypedResponse`]s.
202async fn response_task(
203    id: ConnectionId,
204    mut mailbox: Mailbox<UntypedResponse<'static>>,
205    tx: mpsc::UnboundedSender<Action>,
206) {
207    while let Some(res) = mailbox.next().await {
208        trace!(
209            "[Conn {id}] Receiving response {} to request {}",
210            res.id,
211            res.origin_id
212        );
213        if let Err(x) = tx.send(Action::Read { res }) {
214            error!("[Conn {id}] Failed to forward received response: {x}");
215        }
216    }
217
218    trace!("[Conn {id}] Manager response task closed");
219}
220
221/// Internal task to process [`Action`] items.
222///
223/// * `id` - the id of the connection.
224/// * `rx` - used to receive new [`Action`]s to process.
225/// * `tx` - used to send outgoing requests through the connection.
226async fn action_task(
227    id: ConnectionId,
228    mut rx: mpsc::UnboundedReceiver<Action>,
229    tx: mpsc::UnboundedSender<UntypedRequest<'static>>,
230) {
231    let mut registered = HashMap::new();
232
233    while let Some(action) = rx.recv().await {
234        trace!("[Conn {id}] {action:?}");
235
236        match action {
237            Action::Register { id, reply } => {
238                registered.insert(id, reply);
239            }
240            Action::Unregister { id } => {
241                registered.remove(&id);
242            }
243            Action::GetRegistered { cb } => {
244                let _ = cb.send(registered.keys().copied().collect());
245            }
246            Action::Read { mut res } => {
247                // Split {channel id}_{request id} back into pieces and
248                // update the origin id to match the request id only
249                let channel_id = match res.origin_id.split_once('_') {
250                    Some((cid_str, oid_str)) => {
251                        if let Ok(cid) = cid_str.parse::<ManagerChannelId>() {
252                            res.set_origin_id(oid_str.to_string());
253                            cid
254                        } else {
255                            continue;
256                        }
257                    }
258                    None => continue,
259                };
260
261                if let Some(reply) = registered.get(&channel_id) {
262                    let response = ManagerResponse::Channel {
263                        id: channel_id,
264                        response: res,
265                    };
266
267                    if let Err(x) = reply.send(response) {
268                        error!("[Conn {id}] {x}");
269                    }
270                }
271            }
272            Action::Write { id, mut req } => {
273                // Combine channel id with request id so we can properly forward
274                // the response containing this in the origin id
275                req.set_id(format!("{id}_{}", req.id));
276
277                if let Err(x) = tx.send(req) {
278                    error!("[Conn {id}] {x}");
279                }
280            }
281        }
282    }
283
284    trace!("[Conn {id}] Manager action task closed");
285}