1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{Arc, Weak},
5};
6
7use moire::sync::SyncMutex;
8use tokio::sync::Semaphore;
9
10use moire::task::FutureExt as _;
11use roam_types::{
12 Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelId, ChannelItem,
13 ChannelLivenessHandle, ChannelMessage, ChannelSink, CreditSink, Handler, IdAllocator,
14 IncomingChannelMessage, MaybeSend, Payload, ReplySink, RequestBody, RequestCall, RequestId,
15 RequestMessage, RequestResponse, RoamError, SelfRef, TxError,
16};
17
18use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
19use moire::sync::mpsc;
20
21type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
22
23struct DriverShared {
25 pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
26 request_ids: SyncMutex<IdAllocator<RequestId>>,
27 channel_ids: SyncMutex<IdAllocator<ChannelId>>,
28 channel_senders:
30 SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
31 channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
38 channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
41}
42
43struct CallerDropGuard {
44 control_tx: mpsc::UnboundedSender<DropControlRequest>,
45 request: DropControlRequest,
46}
47
48impl Drop for CallerDropGuard {
49 fn drop(&mut self) {
50 let _ = self.control_tx.send(self.request);
51 }
52}
53
54pub struct DriverReplySink {
61 sender: Option<ConnectionSender>,
62 request_id: RequestId,
63 binder: DriverChannelBinder,
64}
65
66impl ReplySink for DriverReplySink {
67 async fn send_reply(mut self, response: RequestResponse<'_>) {
68 let sender = self
69 .sender
70 .take()
71 .expect("unreachable: send_reply takes self by value");
72 if let Err(_e) = sender.send_response(self.request_id, response).await {
73 sender.mark_failure(self.request_id, "send_response failed");
74 }
75 }
76
77 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
78 Some(&self.binder)
79 }
80}
81
82impl Drop for DriverReplySink {
84 fn drop(&mut self) {
85 if let Some(sender) = self.sender.take() {
86 sender.mark_failure(self.request_id, "no reply sent")
87 }
88 }
89}
90
91pub struct DriverChannelSink {
99 sender: ConnectionSender,
100 channel_id: ChannelId,
101 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
102}
103
104impl ChannelSink for DriverChannelSink {
105 fn send_payload<'payload>(
106 &self,
107 payload: Payload<'payload>,
108 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
109 let sender = self.sender.clone();
110 let channel_id = self.channel_id;
111 Box::pin(async move {
112 sender
113 .send(ConnectionMessage::Channel(ChannelMessage {
114 id: channel_id,
115 body: ChannelBody::Item(ChannelItem { item: payload }),
116 }))
117 .await
118 .map_err(|()| TxError::Transport("connection closed".into()))
119 })
120 }
121
122 fn close_channel(
123 &self,
124 _metadata: roam_types::Metadata,
125 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
126 let sender = self.sender.clone();
130 let channel_id = self.channel_id;
131 Box::pin(async move {
132 sender
133 .send(ConnectionMessage::Channel(ChannelMessage {
134 id: channel_id,
135 body: ChannelBody::Close(ChannelClose {
136 metadata: Default::default(),
137 }),
138 }))
139 .await
140 .map_err(|()| TxError::Transport("connection closed".into()))
141 })
142 }
143
144 fn close_channel_on_drop(&self) {
145 let _ = self
146 .local_control_tx
147 .send(DriverLocalControl::CloseChannel {
148 channel_id: self.channel_id,
149 });
150 }
151}
152
153#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
157#[derive(Clone)]
158pub struct NoopCaller(#[allow(dead_code)] DriverCaller);
159
160impl From<DriverCaller> for NoopCaller {
161 fn from(caller: DriverCaller) -> Self {
162 Self(caller)
163 }
164}
165
166#[derive(Clone)]
167struct DriverChannelBinder {
168 sender: ConnectionSender,
169 shared: Arc<DriverShared>,
170 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
171 drop_guard: Option<Arc<CallerDropGuard>>,
172}
173
174impl DriverChannelBinder {
175 fn create_tx_channel(
176 &self,
177 initial_credit: u32,
178 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
179 let channel_id = self.shared.channel_ids.lock().alloc();
180 let inner = DriverChannelSink {
181 sender: self.sender.clone(),
182 channel_id,
183 local_control_tx: self.local_control_tx.clone(),
184 };
185 let sink = Arc::new(CreditSink::new(inner, initial_credit));
186 self.shared
187 .channel_credits
188 .lock()
189 .insert(channel_id, Arc::clone(sink.credit()));
190 (channel_id, sink)
191 }
192
193 fn register_rx_channel(
194 &self,
195 channel_id: ChannelId,
196 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
197 let (tx, rx) = tokio::sync::mpsc::channel(64);
198 let mut terminal_buffered = false;
199 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
200 for msg in buffered {
201 let is_terminal = matches!(
202 msg,
203 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
204 );
205 let _ = tx.try_send(msg);
206 if is_terminal {
207 terminal_buffered = true;
208 break;
209 }
210 }
211 }
212 if terminal_buffered {
213 self.shared.channel_credits.lock().remove(&channel_id);
214 return rx;
215 }
216
217 self.shared.channel_senders.lock().insert(channel_id, tx);
218 rx
219 }
220}
221
222impl ChannelBinder for DriverChannelBinder {
223 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
224 let (id, sink) = self.create_tx_channel(initial_credit);
225 (id, sink as Arc<dyn ChannelSink>)
226 }
227
228 fn create_rx(
229 &self,
230 ) -> (
231 ChannelId,
232 tokio::sync::mpsc::Receiver<IncomingChannelMessage>,
233 ) {
234 let channel_id = self.shared.channel_ids.lock().alloc();
235 let rx = self.register_rx_channel(channel_id);
236 (channel_id, rx)
237 }
238
239 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
240 let inner = DriverChannelSink {
241 sender: self.sender.clone(),
242 channel_id,
243 local_control_tx: self.local_control_tx.clone(),
244 };
245 let sink = Arc::new(CreditSink::new(inner, initial_credit));
246 self.shared
247 .channel_credits
248 .lock()
249 .insert(channel_id, Arc::clone(sink.credit()));
250 sink
251 }
252
253 fn register_rx(
254 &self,
255 channel_id: ChannelId,
256 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
257 self.register_rx_channel(channel_id)
258 }
259
260 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
261 self.drop_guard
262 .as_ref()
263 .map(|guard| guard.clone() as ChannelLivenessHandle)
264 }
265}
266
267#[derive(Clone)]
270pub struct DriverCaller {
271 sender: ConnectionSender,
272 shared: Arc<DriverShared>,
273 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
274 _drop_guard: Option<Arc<CallerDropGuard>>,
275}
276
277impl DriverCaller {
278 pub fn create_tx_channel(
284 &self,
285 initial_credit: u32,
286 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
287 let channel_id = self.shared.channel_ids.lock().alloc();
288 let inner = DriverChannelSink {
289 sender: self.sender.clone(),
290 channel_id,
291 local_control_tx: self.local_control_tx.clone(),
292 };
293 let sink = Arc::new(CreditSink::new(inner, initial_credit));
294 self.shared
295 .channel_credits
296 .lock()
297 .insert(channel_id, Arc::clone(sink.credit()));
298 (channel_id, sink)
299 }
300
301 #[cfg(test)]
306 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
307 &self.sender
308 }
309
310 pub fn register_rx_channel(
315 &self,
316 channel_id: ChannelId,
317 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
318 let (tx, rx) = tokio::sync::mpsc::channel(64);
319 let mut terminal_buffered = false;
320 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
322 for msg in buffered {
323 let is_terminal = matches!(
324 msg,
325 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
326 );
327 let _ = tx.try_send(msg);
328 if is_terminal {
329 terminal_buffered = true;
330 break;
331 }
332 }
333 }
334 if terminal_buffered {
335 self.shared.channel_credits.lock().remove(&channel_id);
336 return rx;
337 }
338
339 self.shared.channel_senders.lock().insert(channel_id, tx);
340 rx
341 }
342}
343
344impl ChannelBinder for DriverCaller {
345 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
346 let (id, sink) = self.create_tx_channel(initial_credit);
347 (id, sink as Arc<dyn ChannelSink>)
348 }
349
350 fn create_rx(
351 &self,
352 ) -> (
353 ChannelId,
354 tokio::sync::mpsc::Receiver<IncomingChannelMessage>,
355 ) {
356 let channel_id = self.shared.channel_ids.lock().alloc();
357 let rx = self.register_rx_channel(channel_id);
358 (channel_id, rx)
359 }
360
361 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
362 let inner = DriverChannelSink {
363 sender: self.sender.clone(),
364 channel_id,
365 local_control_tx: self.local_control_tx.clone(),
366 };
367 let sink = Arc::new(CreditSink::new(inner, initial_credit));
368 self.shared
369 .channel_credits
370 .lock()
371 .insert(channel_id, Arc::clone(sink.credit()));
372 sink
373 }
374
375 fn register_rx(
376 &self,
377 channel_id: ChannelId,
378 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
379 self.register_rx_channel(channel_id)
380 }
381
382 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
383 self._drop_guard
384 .as_ref()
385 .map(|guard| guard.clone() as ChannelLivenessHandle)
386 }
387}
388
389impl Caller for DriverCaller {
390 fn call<'a>(
391 &'a self,
392 call: RequestCall<'a>,
393 ) -> impl std::future::Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>>
394 + MaybeSend
395 + 'a {
396 async {
397 let req_id = self.shared.request_ids.lock().alloc();
399
400 let (tx, rx) = moire::sync::oneshot::channel("driver.response");
403 self.shared.pending_responses.lock().insert(req_id, tx);
404
405 let send_result = self
408 .sender
409 .send(ConnectionMessage::Request(RequestMessage {
410 id: req_id,
411 body: RequestBody::Call(call),
412 }))
413 .await;
414
415 if send_result.is_err() {
416 self.shared.pending_responses.lock().remove(&req_id);
418 return Err(RoamError::Cancelled);
419 }
420
421 let response_msg: SelfRef<RequestMessage<'static>> = rx
423 .named("awaiting_response")
424 .await
425 .map_err(|_| RoamError::Cancelled)?;
426
427 let response = response_msg.map(|m| match m.body {
429 RequestBody::Response(r) => r,
430 _ => unreachable!("pending_responses only gets Response variants"),
431 });
432
433 Ok(response)
434 }
435 .named("Caller::call")
436 }
437
438 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
439 Some(self)
440 }
441}
442
443pub struct Driver<H: Handler<DriverReplySink>> {
450 sender: ConnectionSender,
451 rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
452 failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
453 local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
454 handler: Arc<H>,
455 shared: Arc<DriverShared>,
456 in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
459 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
460 drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
461 drop_control_request: DropControlRequest,
462 drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
463}
464
465enum DriverLocalControl {
466 CloseChannel { channel_id: ChannelId },
467}
468
469impl<H: Handler<DriverReplySink>> Driver<H> {
470 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
471 let conn_id = handle.connection_id();
472 let ConnectionHandle {
473 sender,
474 rx,
475 failures_rx,
476 control_tx,
477 parity,
478 } = handle;
479 let drop_control_request = DropControlRequest::Close(conn_id);
480 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
481 Self {
482 sender,
483 rx,
484 failures_rx,
485 local_control_rx,
486 handler: Arc::new(handler),
487 shared: Arc::new(DriverShared {
488 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
489 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
490 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
491 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
492 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
493 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
494 }),
495 in_flight_handlers: BTreeMap::new(),
496 local_control_tx,
497 drop_control_seed: control_tx,
498 drop_control_request,
499 drop_guard: SyncMutex::new("driver.drop_guard", None),
500 }
501 }
502
503 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
509 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
510 }
511
512 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
513 let drop_guard = if let Some(existing) = self.existing_drop_guard() {
514 Some(existing)
515 } else if let Some(seed) = &self.drop_control_seed {
516 let mut guard = self.drop_guard.lock();
517 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
518 Some(existing)
519 } else {
520 let arc = Arc::new(CallerDropGuard {
521 control_tx: seed.clone(),
522 request: self.drop_control_request,
523 });
524 *guard = Some(Arc::downgrade(&arc));
525 Some(arc)
526 }
527 } else {
528 None
529 };
530 drop_guard
531 }
532
533 pub fn caller(&self) -> DriverCaller {
534 let drop_guard = self.connection_drop_guard();
535 DriverCaller {
536 sender: self.sender.clone(),
537 shared: Arc::clone(&self.shared),
538 local_control_tx: self.local_control_tx.clone(),
539 _drop_guard: drop_guard,
540 }
541 }
542
543 fn internal_binder(&self) -> DriverChannelBinder {
544 DriverChannelBinder {
545 sender: self.sender.clone(),
546 shared: Arc::clone(&self.shared),
547 local_control_tx: self.local_control_tx.clone(),
548 drop_guard: self.existing_drop_guard(),
549 }
550 }
551
552 pub async fn run(&mut self) {
557 loop {
558 tokio::select! {
559 msg = self.rx.recv() => {
560 match msg {
561 Some(msg) => self.handle_msg(msg),
562 None => break,
563 }
564 }
565 Some((req_id, _reason)) = self.failures_rx.recv() => {
566 self.in_flight_handlers.remove(&req_id);
568 if self.shared.pending_responses.lock().remove(&req_id).is_none() {
569 let error: Result<(), RoamError<core::convert::Infallible>> =
573 Err(RoamError::Cancelled);
574 let _ = self.sender.send_response(req_id, RequestResponse {
575 ret: Payload::outgoing(&error),
576 channels: vec![],
577 metadata: Default::default(),
578 }).await;
579 }
580 }
581 Some(ctrl) = self.local_control_rx.recv() => {
582 self.handle_local_control(ctrl).await;
583 }
584 }
585 }
586
587 for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
588 handle.abort();
589 }
590 self.shared.pending_responses.lock().clear();
591
592 self.shared.channel_senders.lock().clear();
595 self.shared.channel_buffers.lock().clear();
596 self.shared.channel_credits.lock().clear();
597 }
598
599 async fn handle_local_control(&mut self, control: DriverLocalControl) {
600 match control {
601 DriverLocalControl::CloseChannel { channel_id } => {
602 let _ = self
603 .sender
604 .send(ConnectionMessage::Channel(ChannelMessage {
605 id: channel_id,
606 body: ChannelBody::Close(ChannelClose {
607 metadata: Default::default(),
608 }),
609 }))
610 .await;
611 }
612 }
613 }
614
615 fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
616 let is_request = matches!(&*msg, ConnectionMessage::Request(_));
617 if is_request {
618 let msg = msg.map(|m| match m {
619 ConnectionMessage::Request(r) => r,
620 _ => unreachable!(),
621 });
622 self.handle_request(msg);
623 } else {
624 let msg = msg.map(|m| match m {
625 ConnectionMessage::Channel(c) => c,
626 _ => unreachable!(),
627 });
628 self.handle_channel(msg);
629 }
630 }
631
632 fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
633 let req_id = msg.id;
634 let is_call = matches!(&msg.body, RequestBody::Call(_));
635 let is_response = matches!(&msg.body, RequestBody::Response(_));
636 let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
637
638 if is_call {
639 let reply = DriverReplySink {
642 sender: Some(self.sender.clone()),
643 request_id: req_id,
644 binder: self.internal_binder(),
645 };
646 let call = msg.map(|m| match m.body {
647 RequestBody::Call(c) => c,
648 _ => unreachable!(),
649 });
650 let handler = Arc::clone(&self.handler);
651 let join_handle = moire::task::spawn(
652 async move {
653 handler.handle(call, reply).await;
654 }
655 .named("handler"),
656 );
657 self.in_flight_handlers.insert(req_id, join_handle);
658 } else if is_response {
659 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
661 let _: Result<(), _> = tx.send(msg);
662 }
663 } else if is_cancel {
664 if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
669 handle.abort();
670 }
671 }
674 }
675
676 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
677 let chan_id = msg.id;
678
679 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
682
683 match &msg.body {
684 ChannelBody::Item(_item) => {
686 if let Some(tx) = &sender {
687 let item = msg.map(|m| match m.body {
688 ChannelBody::Item(item) => item,
689 _ => unreachable!(),
690 });
691 let _ = tx.try_send(IncomingChannelMessage::Item(item));
693 } else {
694 let item = msg.map(|m| match m.body {
696 ChannelBody::Item(item) => item,
697 _ => unreachable!(),
698 });
699 self.shared
700 .channel_buffers
701 .lock()
702 .entry(chan_id)
703 .or_default()
704 .push(IncomingChannelMessage::Item(item));
705 }
706 }
707 ChannelBody::Close(_close) => {
709 if let Some(tx) = &sender {
710 let close = msg.map(|m| match m.body {
711 ChannelBody::Close(close) => close,
712 _ => unreachable!(),
713 });
714 let _ = tx.try_send(IncomingChannelMessage::Close(close));
715 } else {
716 let close = msg.map(|m| match m.body {
718 ChannelBody::Close(close) => close,
719 _ => unreachable!(),
720 });
721 self.shared
722 .channel_buffers
723 .lock()
724 .entry(chan_id)
725 .or_default()
726 .push(IncomingChannelMessage::Close(close));
727 }
728 self.shared.channel_senders.lock().remove(&chan_id);
729 self.shared.channel_credits.lock().remove(&chan_id);
730 }
731 ChannelBody::Reset(_reset) => {
733 if let Some(tx) = &sender {
734 let reset = msg.map(|m| match m.body {
735 ChannelBody::Reset(reset) => reset,
736 _ => unreachable!(),
737 });
738 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
739 } else {
740 let reset = msg.map(|m| match m.body {
742 ChannelBody::Reset(reset) => reset,
743 _ => unreachable!(),
744 });
745 self.shared
746 .channel_buffers
747 .lock()
748 .entry(chan_id)
749 .or_default()
750 .push(IncomingChannelMessage::Reset(reset));
751 }
752 self.shared.channel_senders.lock().remove(&chan_id);
753 self.shared.channel_credits.lock().remove(&chan_id);
754 }
755 ChannelBody::GrantCredit(grant) => {
758 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
759 semaphore.add_permits(grant.additional as usize);
760 }
761 }
762 }
763 }
764}