ecksport_net/worker.rs

//! Concurrent connection worker that runs in the background, exposing an event
//! queue of new channel handles and notifications.
// TODO make this also send and respond to ping messages

use std::collections::HashMap;
use std::sync::Arc;

use ecksport_core::frame::MsgFlags;
use ecksport_core::peer::PeerData;
use futures::{future, pin_mut};
use tokio::sync::{mpsc, oneshot};
use tracing::*;

use ecksport_core::topic;
use ecksport_core::traits::{AsyncRecvFrame, AsyncSendFrame};

use crate::channel::InbMsg;
use crate::channel_state::Creator;
use crate::event::{InbEvent, OpenChanCmd, PushFlags, WorkerCommand};
use crate::shared_state::{ChanSharedState, ConnSharedState};
use crate::{channel, connection::Connection, errors::Error};

/// Event about the connection that whatever is holding the connection handle
/// would be interested in.
pub enum WorkerEvent {
    /// New channel that's been initiated by the remote.
    NewChan(channel::ChannelHandle),

    /// Oneshot notification sent by the remote.
    Notification(topic::Topic, Vec<u8>),
}

/// Handle to a connection worker that accepts new channels as channel handles
/// and relays messages.
pub struct ConnectionHandle {
    shared: Arc<ConnSharedState>,
    event_rx: mpsc::Receiver<Result<WorkerEvent, Error>>,
    cmd_tx: mpsc::Sender<WorkerCommand>,
}

impl ConnectionHandle {
    /// Returns the protocol that was negotiated for the connection.
    pub fn protocol(&self) -> topic::Topic {
        self.shared.protocol()
    }

    /// Returns the peer data structure.
    pub fn peer(&self) -> &PeerData {
        self.shared.peer_data()
    }

    /// The party that initiated the underlying connection.
    pub fn initiator(&self) -> Creator {
        self.shared.initiator()
    }

    async fn submit_command(&self, cmd: WorkerCommand) -> Result<(), Error> {
        if self.cmd_tx.send(cmd).await.is_err() {
            return Err(Error::ConnWorkerExit);
        }

        Ok(())
    }

    /// Opens a channel by queueing a command and waiting for it to be executed
    /// by the worker task.
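    ///
    /// A minimal usage sketch (illustrative only; `conn_handle`, `my_topic`,
    /// and `flags` are placeholders for a connected handle, a `topic::Topic`,
    /// and whatever `MsgFlags` the caller needs):
    ///
    /// ```ignore
    /// let chan = conn_handle
    ///     .open_channel(my_topic, b"hello".to_vec(), flags)
    ///     .await?;
    /// ```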
    pub async fn open_channel(
        &self,
        topic: topic::Topic,
        init_msg: Vec<u8>,
        flags: MsgFlags,
    ) -> Result<channel::ChannelHandle, Error> {
        // Create and send the open channel command.
        let (chh_tx, chh_rx) = oneshot::channel();
        let cmd = OpenChanCmd {
            topic,
            init_msg,
            flags,
            chh_tx,
        };

        self.submit_command(WorkerCommand::OpenChannel(cmd)).await?;

        // Now wait for a response.
        match chh_rx.await {
            Ok(chh) => Ok(chh),
            // TODO more expressive handling
            Err(_) => Err(Error::ConnWorkerExit),
        }
    }

    /// Queues a notification to be sent.
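    ///
    /// A minimal usage sketch (illustrative only; `conn_handle` and `my_topic`
    /// are placeholders):
    ///
    /// ```ignore
    /// conn_handle.send_notification(my_topic, b"ping".to_vec()).await?;
    /// ```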
    pub async fn send_notification(&self, topic: topic::Topic, msg: Vec<u8>) -> Result<(), Error> {
        self.submit_command(WorkerCommand::SendNotification(topic, msg))
            .await?;
        Ok(())
    }

    /// Returns whether there are any events in the queue.  An event may be an
    /// error message.
    pub fn has_event(&mut self) -> bool {
        !self.event_rx.is_empty()
    }

    /// Waits for an event, asynchronously.
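    ///
    /// A sketch of a receive loop built on this method (illustrative only; how
    /// each event is handled is up to the caller):
    ///
    /// ```ignore
    /// while let Some(ev) = conn_handle.wait_event().await? {
    ///     match ev {
    ///         WorkerEvent::NewChan(chan) => {
    ///             // The remote opened a channel; hand `chan` off somewhere.
    ///         }
    ///         WorkerEvent::Notification(topic, payload) => {
    ///             // One-shot message with no channel state to track.
    ///         }
    ///     }
    /// }
    /// ```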
    pub async fn wait_event(&mut self) -> Result<Option<WorkerEvent>, Error> {
        self.event_rx.recv().await.transpose()
    }

    /// Waits for an event, blocking the current thread.
    pub fn wait_event_blocking(&mut self) -> Result<Option<WorkerEvent>, Error> {
        self.event_rx.blocking_recv().transpose()
    }
}

impl Drop for ConnectionHandle {
    fn drop(&mut self) {
        self.shared.set_dropped();
    }
}

/// Encapsulates worker IO so that we can select across all the different inputs
/// easily.
struct WorkerIo<'c, T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static> {
    /// Low-level connection that manages the underlying protocol state machine.
    conn: &'c mut Connection<T>,

    /// Queue on which consumer code holding channel and connection handles
    /// sends commands to the worker.
    cmd_rx: mpsc::Receiver<WorkerCommand>,
}

impl<'c, T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static> WorkerIo<'c, T> {
    pub fn new(conn: &'c mut Connection<T>, cmd_rx: mpsc::Receiver<WorkerCommand>) -> Self {
        Self { conn, cmd_rx }
    }

    /// Selects across the IO inputs and returns a signal to act on.
    async fn wait_for_signal(&mut self) -> Result<Signal, Error> {
        let ev_fut = self.conn.next_event();
        let cmd_fut = self.cmd_rx.recv();
        pin_mut!(ev_fut);
        pin_mut!(cmd_fut);

        match future::select(ev_fut, cmd_fut).await {
            future::Either::Left((ev, _)) => match ev? {
                Some(ev) => Ok(Signal::ConnEvent(ev)),
                None => Ok(Signal::RemoteClosed),
            },

            future::Either::Right((cmd, _)) => match cmd {
                Some(cmd) => Ok(Signal::Command(cmd)),
                // All command senders have been dropped, so shut down.
                None => Ok(Signal::Shutdown),
            },
        }
    }
}

/// Signal from IO handles to be processed by the connection worker.
enum Signal {
    ConnEvent(InbEvent),
    Command(WorkerCommand),
    RemoteClosed,
    Shutdown,
}

/// Wraps the persistent worker bookkeeping state we need to share in different
/// places.
struct WorkerState {
    /// Protocol the worker is operating on.
    protocol: topic::Topic,

    /// State shared between the worker and channels, used to track when
    /// channels are ungracefully dropped so we can still clean them up.
    shared_state: Arc<ConnSharedState>,

    /// Inbound queues we maintain for each channel.  When the remote closes
    /// their side of the channel, we remove these queues to signal to the
    /// consumer that the channel has been closed.
    chan_inb_tbl: HashMap<u32, mpsc::Sender<Result<channel::InbMsg, Error>>>,

    /// Event queue sent to whatever is monitoring the worker.
    event_tx: mpsc::Sender<Result<WorkerEvent, Error>>,

    /// Queue cloned when constructing new channels so they can send commands
    /// back to the worker.
    cmd_tx: mpsc::Sender<WorkerCommand>,
}

impl WorkerState {
    fn new(
        protocol: topic::Topic,
        shared_state: Arc<ConnSharedState>,
        event_tx: mpsc::Sender<Result<WorkerEvent, Error>>,
        cmd_tx: mpsc::Sender<WorkerCommand>,
    ) -> Self {
        Self {
            protocol,
            shared_state,
            chan_inb_tbl: HashMap::new(),
            event_tx,
            cmd_tx,
        }
    }

    /// Relays an event to the connection handle.
    async fn relay_event(&self, ev: WorkerEvent) -> Result<(), Error> {
        if self.event_tx.send(Ok(ev)).await.is_err() {
            // TODO should we do something to handle this more?
            return Err(Error::ConnRecvDropped);
        }
        Ok(())
    }

    /// Relays an error message to the receiver.
    async fn relay_err(&self, e: Error) -> Result<(), Error> {
        if self.event_tx.send(Err(e)).await.is_err() {
            return Err(Error::ConnRecvDropped);
        }
        Ok(())
    }

    /// Closes the inbound queue for a channel, in response to the remote cleaning it up.
    fn close_chan_inb(&mut self, id: u32) {
        assert!(self.chan_inb_tbl.remove(&id).is_some());
    }

    /// After we've seen that a channel has been closed at the connection level,
    /// we should clean it up here.
    async fn cleanup_chan(&mut self, id: u32) {
        if self.chan_inb_tbl.contains_key(&id) {
            warn!(%id, "cleaning up channel that still has its inbound queue open");
            self.chan_inb_tbl.remove(&id);
        }

        let mut states = self.shared_state.chan_shared.write().await;
        assert!(states.remove(&id).is_some());
    }

    /// Updates the internal bookkeeping and creates the appropriate channels
    /// to construct a new channel handle.
    async fn create_chan(
        &mut self,
        new_id: u32,
        topic: topic::Topic,
    ) -> Result<channel::ChannelHandle, Error> {
        // TODO make configurable
        let (inb_tx, inb_rx) = mpsc::channel(64);
        assert!(self.chan_inb_tbl.insert(new_id, inb_tx).is_none());

        let css = Arc::new(ChanSharedState::new(self.protocol, topic));
        {
            let mut states = self.shared_state.chan_shared.write().await;
            assert!(!states.contains_key(&new_id));
            states.insert(new_id, css.clone());
        }

        let cmd_tx = self.cmd_tx.clone();
        let pd = self.shared_state.peer_data().clone();
        let handle = channel::ChannelHandle::new(new_id, pd, css, inb_rx, cmd_tx);
        Ok(handle)
    }

    /// Relays an inbound message by sending it on the channel to the handle.
    async fn relay_inb_msg(
        &mut self,
        id: u32,
        flags: PushFlags,
        payload: Vec<u8>,
    ) -> Result<(), Error> {
        let ch_inb_tx = self.chan_inb_tbl.get(&id).ok_or(Error::RecvOnUnkChan(id))?;
        if ch_inb_tx
            .send(Ok(InbMsg::new(flags, payload)))
            .await
            .is_err()
        {
            // This means the channel handle was dropped.  We'll get around to
            // checking for this on the next loop around.
            warn!(%id, "channel dropped without being explicitly closed");
        }

        Ok(())
    }

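    /// Relays an inbound error to a channel's handle by sending it on the
    /// channel's inbound queue.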
    async fn relay_inb_err(&mut self, id: u32, err: Error) -> Result<(), Error> {
        let ch_inb_tx = self.chan_inb_tbl.get(&id).ok_or(Error::RecvOnUnkChan(id))?;
        if ch_inb_tx.send(Err(err)).await.is_err() {
            // This means the channel handle was dropped.  We'll get around to
            // checking for this on the next loop around.
            warn!(%id, "channel dropped without being explicitly closed");
        }

        Ok(())
    }

    /// Relays a shutdown error message to all channels that can still receive it,
    /// and drops our end of each channel's inbound queue.
    async fn shutdown_channels(&mut self) -> Result<(), Error> {
        // Calling .drain() here will drop the channels.
        for (id, ch) in self.chan_inb_tbl.drain() {
            if ch.send(Err(Error::ConnWorkerExit)).await.is_err() {
                warn!(%id, "channel dropped without being explicitly closed");
            }
        }

        Ok(())
    }
}

/// Spawns a connection worker and returns a [`ConnectionHandle`] to it.
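///
/// A minimal wiring sketch (illustrative only; establishing the underlying
/// `Connection` happens elsewhere and `conn` here is a placeholder):
///
/// ```ignore
/// let mut handle = spawn_connection_worker(conn).await;
/// while let Some(ev) = handle.wait_event().await? {
///     // React to `WorkerEvent::NewChan` / `WorkerEvent::Notification` here.
/// }
/// ```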
pub async fn spawn_connection_worker<
    T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
>(
    conn: Connection<T>,
) -> ConnectionHandle {
    // Create the channels.
    // TODO make these configurable
    let (event_tx, event_rx) = mpsc::channel(64);
    let (cmd_tx, cmd_rx) = mpsc::channel(256);

    // Copy connection metadata into a new shared state instance.
    let proto = conn.protocol();
    let pd = conn.peer_data().clone();
    let peer = pd.location().clone();
    let initer = conn.initiator();
    let shared = Arc::new(ConnSharedState::new(proto, pd, initer));

    let worker_span = debug_span!("conn", %peer);
    debug!(parent: &worker_span, "spawning worker task");

    // Actually spawn the worker task.
    tokio::spawn(
        conn_worker_task(conn, shared.clone(), event_tx, cmd_rx, cmd_tx.clone())
            .instrument(worker_span),
    );

    ConnectionHandle {
        shared,
        cmd_tx,
        event_rx,
    }
}

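/// Body of the connection worker task: wires up the worker IO and bookkeeping
/// state, then runs the worker loop until it exits or fails.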
pub async fn conn_worker_task<
    T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
>(
    mut conn: Connection<T>,
    shared: Arc<ConnSharedState>,
    event_tx: mpsc::Sender<Result<WorkerEvent, Error>>,
    cmd_rx: mpsc::Receiver<WorkerCommand>,
    cmd_tx: mpsc::Sender<WorkerCommand>,
) {
    let proto = conn.protocol();
    let wio = WorkerIo::new(&mut conn, cmd_rx);
    let mut wstate = WorkerState::new(proto, shared, event_tx, cmd_tx);
    if let Err(e) = do_worker(wio, &mut wstate).await {
        warn!(err = %e, "connection worker task exited");
    }
}

// TODO add some provision so that we can clean up dropped channel handles
// without having to receive a message on them first, maybe using an atomicbool
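/// Core worker loop: waits for either an inbound connection event or a command
/// from a handle, dispatches it, and shuts the channels down on error or when
/// the command senders go away.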
async fn do_worker<'c, T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static>(
    mut wio: WorkerIo<'c, T>,
    wstate: &mut WorkerState,
) -> Result<(), Error> {
    loop {
        let signal = match wio.wait_for_signal().await {
            Ok(s) => s,
            Err(e) => {
                // Shut down all the channels with an error, just so the
                // receivers can know what happened.
                // TODO maybe make this more explicit about being an error
                // condition so the channels can take appropriate action
                if let Err(e) = wstate.shutdown_channels().await {
                    warn!(err = %e, "encountered error while shutting down channels");
                }

                return Err(e.into());
            }
        };

        match signal {
            Signal::ConnEvent(ev) => handle_conn_event(ev, wstate).await?,

            Signal::Command(cmd) => match cmd {
                WorkerCommand::OpenChannel(occ) => {
                    let topic = occ.topic;
                    let flags = occ.flags;

                    // Send the initial opening so we can get the chan ID.
                    let id = wio.conn.open_channel(topic, occ.init_msg, flags).await?;

                    let ch_handle = wstate.create_chan(id, topic).await?;

                    if occ.chh_tx.send(ch_handle).is_err() {
                        warn!(%topic, %id, "channel sendback closed before open completed");
                        // TODO should we drop the rest of this?
                    }
                }

                WorkerCommand::SendMsg(msg) => {
                    let id = msg.id();
                    let flags = *msg.flags();
                    let still_open = wio.conn.send_message(id, msg.into_payload(), flags).await?;
                    if !still_open {
                        wstate.cleanup_chan(id).await;
                    }
                }

                WorkerCommand::CloseChannel(id) => {
                    let still_open = wio.conn.close_channel(id).await?;
                    if !still_open {
                        wstate.cleanup_chan(id).await;
                    }
                }

                WorkerCommand::SendNotification(topic, notif) => {
                    wio.conn.send_notification(topic, notif).await?;
                }
            },

            Signal::RemoteClosed => {
                if !wstate.chan_inb_tbl.is_empty() {
                    // Is this even a warning?
                    let channels = wstate.chan_inb_tbl.len();
                    warn!(%channels, "remote side closed with channels still open");
                }

                return Ok(());
            }

            Signal::Shutdown => {
                // TODO Is there any other cleanup we should do?
                wstate.shutdown_channels().await?;
                return Ok(());
            }
        }
    }
}

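/// Handles a single inbound connection event, updating worker bookkeeping and
/// relaying messages or events to the appropriate channel or to the connection
/// handle's event queue.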
async fn handle_conn_event(conn_ev: InbEvent, wstate: &mut WorkerState) -> Result<(), Error> {
    match conn_ev {
        InbEvent::NewChannel(id, topic, flags, payload) => {
            // Create the channel and write the payload to it.
            let ch_handle = wstate.create_chan(id, topic).await?;
            wstate.relay_inb_msg(id, flags, payload).await?;

            // Send off the new channel event.
            wstate.relay_event(WorkerEvent::NewChan(ch_handle)).await?;
            Ok(())
        }

        InbEvent::PushChannel(id, flags, payload) => {
            wstate.relay_inb_msg(id, flags, payload).await?;
            Ok(())
        }

        InbEvent::CloseChannel(id, still_alive) => {
            // Close the queue.
            wstate.close_chan_inb(id);

            // In this case we just remove the channel entirely.
            if !still_alive {
                wstate.cleanup_chan(id).await;
            }

            Ok(())
        }

        InbEvent::Notification(topic, payload) => {
            // Send the notification, exit if we have to.
            wstate
                .relay_event(WorkerEvent::Notification(topic, payload))
                .await?;
            Ok(())
        }
    }
}