1use std::collections::HashMap;
6use std::sync::Arc;
7
8use ecksport_core::frame::MsgFlags;
9use ecksport_core::peer::PeerData;
10use futures::{future, pin_mut};
11use tokio::sync::{mpsc, oneshot};
12use tracing::*;
13
14use ecksport_core::topic;
15use ecksport_core::traits::{AsyncRecvFrame, AsyncSendFrame};
16
17use crate::channel::InbMsg;
18use crate::channel_state::Creator;
19use crate::event::{InbEvent, OpenChanCmd, PushFlags, WorkerCommand};
20use crate::shared_state::{ChanSharedState, ConnSharedState};
21use crate::{channel, connection::Connection, errors::Error};
22
/// Events the connection worker forwards to its [`ConnectionHandle`].
pub enum WorkerEvent {
    /// The remote side opened a new channel. The channel's initial message is
    /// already queued on the handle's inbound side when this is delivered.
    NewChan(channel::ChannelHandle),

    /// A notification received from the remote side: its topic and raw payload.
    Notification(topic::Topic, Vec<u8>),
}
32
/// Caller-side handle to a spawned connection worker task.
///
/// Created by `spawn_connection_worker`; talks to the worker only through
/// the command and event queues below plus the shared state.
pub struct ConnectionHandle {
    /// State shared with the worker (protocol, peer data, initiator, and the
    /// per-channel shared-state map the worker maintains).
    shared: Arc<ConnSharedState>,
    /// Events and errors forwarded by the worker.
    event_rx: mpsc::Receiver<Result<WorkerEvent, Error>>,
    /// Queue used to send commands to the worker.
    cmd_tx: mpsc::Sender<WorkerCommand>,
}
40
41impl ConnectionHandle {
42 pub fn protocol(&self) -> topic::Topic {
44 self.shared.protocol()
45 }
46
47 pub fn peer(&self) -> &PeerData {
49 self.shared.peer_data()
50 }
51
52 pub fn initiator(&self) -> Creator {
54 self.shared.initiator()
55 }
56
57 async fn submit_command(&self, cmd: WorkerCommand) -> Result<(), Error> {
58 if self.cmd_tx.send(cmd).await.is_err() {
59 return Err(Error::ConnWorkerExit);
60 }
61
62 Ok(())
63 }
64
65 pub async fn open_channel(
68 &self,
69 topic: topic::Topic,
70 init_msg: Vec<u8>,
71 flags: MsgFlags,
72 ) -> Result<channel::ChannelHandle, Error> {
73 let (chh_tx, chh_rx) = oneshot::channel();
75 let cmd = OpenChanCmd {
76 topic,
77 init_msg,
78 flags,
79 chh_tx,
80 };
81
82 self.submit_command(WorkerCommand::OpenChannel(cmd)).await?;
83
84 match chh_rx.await {
86 Ok(chh) => Ok(chh),
87 Err(_) => Err(Error::ConnWorkerExit),
89 }
90 }
91
92 pub async fn send_notification(&self, topic: topic::Topic, msg: Vec<u8>) -> Result<(), Error> {
94 self.submit_command(WorkerCommand::SendNotification(topic, msg))
95 .await?;
96 Ok(())
97 }
98
99 pub fn has_event(&mut self) -> bool {
102 !self.event_rx.is_empty()
103 }
104
105 pub async fn wait_event(&mut self) -> Result<Option<WorkerEvent>, Error> {
107 self.event_rx.recv().await.transpose()
108 }
109
110 pub fn wait_event_blocking(&mut self) -> Result<Option<WorkerEvent>, Error> {
112 self.event_rx.blocking_recv().transpose()
113 }
114}
115
impl Drop for ConnectionHandle {
    /// Flags the shared connection state as dropped.
    // NOTE(review): nothing in this module reads that flag, and the worker keeps
    // its own `cmd_tx` clone, so dropping the handle does not close the worker's
    // command queue. The worker only notices on its next `relay_event` (an
    // inbound new channel or notification), which then fails with
    // `ConnRecvDropped` because `event_rx` is gone. Confirm this is intended.
    fn drop(&mut self) {
        self.shared.set_dropped();
    }
}
121
/// I/O endpoints the worker loop waits on: the connection it drives and the
/// receiving end of the command queue.
struct WorkerIo<'c, T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static> {
    /// Connection being driven, borrowed from `conn_worker_task` for the
    /// lifetime of the loop.
    conn: &'c mut Connection<T>,

    /// Commands whose senders are held by the `ConnectionHandle`, the worker
    /// state, and every channel handle created by the worker.
    cmd_rx: mpsc::Receiver<WorkerCommand>,
}
132
133impl<'c, T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static> WorkerIo<'c, T> {
134 pub fn new(conn: &'c mut Connection<T>, cmd_rx: mpsc::Receiver<WorkerCommand>) -> Self {
135 Self { conn, cmd_rx }
136 }
137
138 async fn wait_for_signal(&mut self) -> Result<Signal, Error> {
140 let ev_fut = self.conn.next_event();
141 let cmd_fut = self.cmd_rx.recv();
142 pin_mut!(ev_fut);
143 pin_mut!(cmd_fut);
144
145 match future::select(ev_fut, cmd_fut).await {
146 future::Either::Left((ev, _)) => match ev? {
147 Some(ev) => Ok(Signal::ConnEvent(ev)),
148 None => Ok(Signal::RemoteClosed),
149 },
150
151 future::Either::Right((cmd, _)) => match cmd {
152 Some(cmd) => Ok(Signal::Command(cmd)),
153 None => Ok(Signal::Shutdown),
155 },
156 }
157 }
158}
159
/// Outcome of a single `WorkerIo::wait_for_signal` call.
enum Signal {
    /// The connection produced an inbound event.
    ConnEvent(InbEvent),
    /// A command arrived on the worker's command queue.
    Command(WorkerCommand),
    /// `next_event` returned `Ok(None)`; treated as the remote side closing.
    RemoteClosed,
    /// The command queue reported closed (every sender dropped).
    // NOTE(review): `WorkerState` keeps its own `cmd_tx` clone, so the worker's
    // command queue cannot close while the worker is running. This variant
    // looks unreachable; confirm.
    Shutdown,
}
167
/// Mutable bookkeeping owned by the connection worker task.
struct WorkerState {
    /// Connection protocol, copied into each new channel's shared state.
    protocol: topic::Topic,

    /// State shared with the `ConnectionHandle`; its `chan_shared` map holds
    /// per-channel shared state, keyed by channel id.
    shared_state: Arc<ConnSharedState>,

    /// Inbound queue senders keyed by channel id. An entry lives from channel
    /// creation until its inbound side is closed, the channel is cleaned up,
    /// or the worker shuts the channels down.
    chan_inb_tbl: HashMap<u32, mpsc::Sender<Result<channel::InbMsg, Error>>>,

    /// Events and errors destined for the `ConnectionHandle`.
    event_tx: mpsc::Sender<Result<WorkerEvent, Error>>,

    /// Command sender cloned into every channel handle the worker creates.
    // NOTE(review): because the worker holds this sender, its own command
    // receiver never reports "closed" while the worker is alive. Confirm.
    cmd_tx: mpsc::Sender<WorkerCommand>,
}
190
191impl WorkerState {
192 fn new(
193 protocol: topic::Topic,
194 shared_state: Arc<ConnSharedState>,
195 event_tx: mpsc::Sender<Result<WorkerEvent, Error>>,
196 cmd_tx: mpsc::Sender<WorkerCommand>,
197 ) -> Self {
198 Self {
199 protocol,
200 shared_state,
201 chan_inb_tbl: HashMap::new(),
202 event_tx,
203 cmd_tx,
204 }
205 }
206
207 async fn relay_event(&self, ev: WorkerEvent) -> Result<(), Error> {
209 if self.event_tx.send(Ok(ev)).await.is_err() {
210 return Err(Error::ConnRecvDropped);
212 }
213 Ok(())
214 }
215
216 async fn relay_err(&self, e: Error) -> Result<(), Error> {
218 if self.event_tx.send(Err(e)).await.is_err() {
219 return Err(Error::ConnRecvDropped);
220 }
221 Ok(())
222 }
223
224 fn close_chan_inb(&mut self, id: u32) {
226 assert!(self.chan_inb_tbl.remove(&id).is_some());
227 }
228
229 async fn cleanup_chan(&mut self, id: u32) {
232 if self.chan_inb_tbl.contains_key(&id) {
233 warn!(%id, "cleaning up channel that we still have inbound queue open");
234 self.chan_inb_tbl.remove(&id);
235 }
236
237 let mut states = self.shared_state.chan_shared.write().await;
238 assert!(states.remove(&id).is_some());
239 }
240
241 async fn create_chan(
244 &mut self,
245 new_id: u32,
246 topic: topic::Topic,
247 ) -> Result<channel::ChannelHandle, Error> {
248 let (inb_tx, inb_rx) = mpsc::channel(64);
250 assert!(self.chan_inb_tbl.insert(new_id, inb_tx).is_none());
251
252 let css = Arc::new(ChanSharedState::new(self.protocol, topic));
253 {
254 let mut states = self.shared_state.chan_shared.write().await;
255 assert!(!states.contains_key(&new_id));
256 states.insert(new_id, css.clone());
257 }
258
259 let cmd_tx = self.cmd_tx.clone();
260 let pd = self.shared_state.peer_data().clone();
261 let handle = channel::ChannelHandle::new(new_id, pd, css, inb_rx, cmd_tx);
262 Ok(handle)
263 }
264
265 async fn relay_inb_msg(
267 &mut self,
268 id: u32,
269 flags: PushFlags,
270 payload: Vec<u8>,
271 ) -> Result<(), Error> {
272 let ch_inb_tx = self.chan_inb_tbl.get(&id).ok_or(Error::RecvOnUnkChan(id))?;
273 if ch_inb_tx
274 .send(Ok(InbMsg::new(flags, payload)))
275 .await
276 .is_err()
277 {
278 warn!(%id, "channel dropped without being explicitly closed");
281 }
282
283 Ok(())
284 }
285
286 async fn relay_inb_err(&mut self, id: u32, err: Error) -> Result<(), Error> {
287 let ch_inb_tx = self.chan_inb_tbl.get(&id).ok_or(Error::RecvOnUnkChan(id))?;
288 if ch_inb_tx.send(Err(err)).await.is_err() {
289 warn!(%id, "channel dropped without being explicitly closed");
292 }
293
294 Ok(())
295 }
296
297 async fn shutdown_channels(&mut self) -> Result<(), Error> {
300 for (id, ch) in self.chan_inb_tbl.drain() {
302 if ch.send(Err(Error::ConnWorkerExit)).await.is_err() {
303 warn!(%id, "channel dropped without being explicitly closed");
304 }
305 }
306
307 Ok(())
308 }
309}
310
311pub async fn spawn_connection_worker<
313 T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
314>(
315 conn: Connection<T>,
316) -> ConnectionHandle {
317 let (event_tx, event_rx) = mpsc::channel(64);
320 let (cmd_tx, cmd_rx) = mpsc::channel(256);
321
322 let proto = conn.protocol();
324 let pd = conn.peer_data().clone();
325 let peer = pd.location().clone();
326 let initer = conn.initiator();
327 let shared = Arc::new(ConnSharedState::new(proto, pd, initer));
328
329 let worker_span = debug_span!("conn", %peer);
330 debug!(parent: &worker_span, "spawning worker task");
331
332 tokio::spawn(
334 conn_worker_task(conn, shared.clone(), event_tx, cmd_rx, cmd_tx.clone())
335 .instrument(worker_span),
336 );
337
338 ConnectionHandle {
339 shared,
340 cmd_tx,
341 event_rx,
342 }
343}
344
345pub async fn conn_worker_task<
346 T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static,
347>(
348 mut conn: Connection<T>,
349 shared: Arc<ConnSharedState>,
350 event_tx: mpsc::Sender<Result<WorkerEvent, Error>>,
351 cmd_rx: mpsc::Receiver<WorkerCommand>,
352 cmd_tx: mpsc::Sender<WorkerCommand>,
353) {
354 let proto = conn.protocol();
355 let wio = WorkerIo::new(&mut conn, cmd_rx);
356 let mut wstate = WorkerState::new(proto, shared, event_tx, cmd_tx);
357 if let Err(e) = do_worker(wio, &mut wstate).await {
358 warn!(err = %e, "connection worker task exited");
359 }
360}
361
362async fn do_worker<'c, T: AsyncRecvFrame + AsyncSendFrame + Sync + Send + Unpin + 'static>(
365 mut wio: WorkerIo<'c, T>,
366 wstate: &mut WorkerState,
367) -> Result<(), Error> {
368 loop {
369 let signal = match wio.wait_for_signal().await {
370 Ok(s) => s,
371 Err(e) => {
372 if let Err(e) = wstate.shutdown_channels().await {
377 warn!(err = %e, "encountered error while shutting down channels");
378 }
379
380 return Err(e.into());
381 }
382 };
383
384 match signal {
385 Signal::ConnEvent(ev) => handle_conn_event(ev, wstate).await?,
386
387 Signal::Command(cmd) => match cmd {
388 WorkerCommand::OpenChannel(occ) => {
389 let topic = occ.topic;
390 let flags = occ.flags;
391
392 let id = wio.conn.open_channel(topic, occ.init_msg, flags).await?;
394
395 let ch_handle = wstate.create_chan(id, topic).await?;
396
397 if occ.chh_tx.send(ch_handle).is_err() {
398 warn!(%topic, %id, "channel sendback closed before open completed");
399 }
401 }
402
403 WorkerCommand::SendMsg(msg) => {
404 let id = msg.id();
405 let flags = *msg.flags();
406 let still_open = wio.conn.send_message(id, msg.into_payload(), flags).await?;
407 if !still_open {
408 wstate.cleanup_chan(id).await;
409 }
410 }
411
412 WorkerCommand::CloseChannel(id) => {
413 let still_open = wio.conn.close_channel(id).await?;
414 if !still_open {
415 wstate.cleanup_chan(id).await;
416 }
417 }
418
419 WorkerCommand::SendNotification(topic, notif) => {
420 wio.conn.send_notification(topic, notif).await?;
421 }
422 },
423
424 Signal::RemoteClosed => {
425 if !wstate.chan_inb_tbl.is_empty() {
426 let channels = wstate.chan_inb_tbl.len();
428 warn!(%channels, "remote side closed with channels still open");
429 }
430
431 return Ok(());
432 }
433
434 Signal::Shutdown => {
435 wstate.shutdown_channels().await?;
437 return Ok(());
438 }
439 }
440 }
441}
442
443async fn handle_conn_event(conn_ev: InbEvent, wstate: &mut WorkerState) -> Result<(), Error> {
444 match conn_ev {
445 InbEvent::NewChannel(id, topic, flags, payload) => {
446 let ch_handle = wstate.create_chan(id, topic).await?;
448 wstate.relay_inb_msg(id, flags, payload).await?;
449
450 wstate.relay_event(WorkerEvent::NewChan(ch_handle)).await?;
452 Ok(())
453 }
454
455 InbEvent::PushChannel(id, flags, payload) => {
456 wstate.relay_inb_msg(id, flags, payload).await?;
457 Ok(())
458 }
459
460 InbEvent::CloseChannel(id, still_alive) => {
461 wstate.close_chan_inb(id);
463
464 if !still_alive {
466 wstate.cleanup_chan(id).await;
467 }
468
469 Ok(())
470 }
471
472 InbEvent::Notification(topic, payload) => {
473 wstate
475 .relay_event(WorkerEvent::Notification(topic, payload))
476 .await?;
477 Ok(())
478 }
479 }
480}