1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{Arc, Weak},
5};
6
7use moire::sync::SyncMutex;
8use tokio::sync::{Semaphore, watch};
9
10use moire::task::FutureExt as _;
11use roam_types::{
12 Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
13 ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
14 ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage, MaybeSend, Payload,
15 ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse, RoamError,
16 SelfRef, TxError,
17};
18
19use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
20use moire::sync::mpsc;
21
/// Completion side of an outgoing call: the driver sends the peer's response
/// message through it once a `Response` with the matching request id arrives.
type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
23
/// State shared between the [`Driver`] task and every [`DriverCaller`],
/// [`DriverChannelBinder`] and reply sink created for the same connection.
struct DriverShared {
    /// Response slots for outgoing calls still waiting for a reply, keyed by
    /// the request id that was sent to the peer.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for outgoing request ids (created with the connection parity).
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Allocator for locally created channel ids (created with the connection parity).
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Senders feeding registered channel receivers; the driver delivers
    /// incoming items/close/reset for a channel through its entry here.
    channel_senders:
        SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
    /// Messages that arrived for a channel before a receiver was registered;
    /// replayed (and removed) when the receiver is registered.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of outgoing sinks; the driver adds permits to the
    /// matching entry when the peer sends `GrantCredit`.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
}
43
/// Sends `request` on `control_tx` when dropped.
///
/// The driver shares one instance (behind an `Arc`) among callers and channel
/// liveness handles, so the request is sent once the last holder goes away.
struct CallerDropGuard {
    /// Control channel the request is sent on.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request to send on drop (the driver uses `DropControlRequest::Close`).
    request: DropControlRequest,
}
48
49impl Drop for CallerDropGuard {
50 fn drop(&mut self) {
51 let _ = self.control_tx.send(self.request);
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use roam_types::{ChannelCreditReplenisher, ChannelId};
    use tokio::sync::mpsc::error::TryRecvError;

    /// With a window of 16 the replenisher stays silent for the first seven
    /// consumed items and grants 8 credits on the eighth.
    #[test]
    fn replenisher_batches_at_half_the_initial_window() {
        let (control_tx, mut control_rx) =
            moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, control_tx);

        (0..7).for_each(|_| replenisher.on_item_consumed());
        let before_threshold = control_rx.try_recv();
        assert!(
            matches!(before_threshold, Err(TryRecvError::Empty)),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        match control_rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(7));
                assert_eq!(additional, 8);
            }
            _ => panic!("expected batched credit grant"),
        }
    }

    /// A single-credit window must be replenished after every item.
    #[test]
    fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (control_tx, mut control_rx) =
            moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, control_tx);

        replenisher.on_item_consumed();
        match control_rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(9));
                assert_eq!(additional, 1);
            }
            _ => panic!("expected immediate credit grant"),
        }
    }
}
103
/// Reply sink handed to the handler for one incoming call.
///
/// If it is dropped without `send_reply` having taken the sender, the request
/// is reported through `mark_failure` with reason "no reply sent".
pub struct DriverReplySink {
    /// `Some` until `send_reply` takes it (or `Drop` does).
    sender: Option<ConnectionSender>,
    /// Id of the incoming request this sink answers.
    request_id: RequestId,
    /// Binder exposed to the handler via `channel_binder()`.
    binder: DriverChannelBinder,
}
115
116impl ReplySink for DriverReplySink {
117 async fn send_reply(mut self, response: RequestResponse<'_>) {
118 let sender = self
119 .sender
120 .take()
121 .expect("unreachable: send_reply takes self by value");
122 if let Err(_e) = sender.send_response(self.request_id, response).await {
123 sender.mark_failure(self.request_id, "send_response failed");
124 }
125 }
126
127 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
128 Some(&self.binder)
129 }
130}
131
132impl Drop for DriverReplySink {
134 fn drop(&mut self) {
135 if let Some(sender) = self.sender.take() {
136 sender.mark_failure(self.request_id, "no reply sent")
137 }
138 }
139}
140
/// Low-level sink for one outgoing channel.
///
/// Instances are wrapped in a `CreditSink` by the binders and callers in this
/// module; this type itself only frames and sends messages.
pub struct DriverChannelSink {
    /// Connection used to send item and close frames.
    sender: ConnectionSender,
    /// Channel this sink writes to.
    channel_id: ChannelId,
    /// Used by `close_channel_on_drop` to ask the driver task to send the close.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
153
154impl ChannelSink for DriverChannelSink {
155 fn send_payload<'payload>(
156 &self,
157 payload: Payload<'payload>,
158 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
159 let sender = self.sender.clone();
160 let channel_id = self.channel_id;
161 Box::pin(async move {
162 sender
163 .send(ConnectionMessage::Channel(ChannelMessage {
164 id: channel_id,
165 body: ChannelBody::Item(ChannelItem { item: payload }),
166 }))
167 .await
168 .map_err(|()| TxError::Transport("connection closed".into()))
169 })
170 }
171
172 fn close_channel(
173 &self,
174 _metadata: roam_types::Metadata,
175 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
176 let sender = self.sender.clone();
180 let channel_id = self.channel_id;
181 Box::pin(async move {
182 sender
183 .send(ConnectionMessage::Channel(ChannelMessage {
184 id: channel_id,
185 body: ChannelBody::Close(ChannelClose {
186 metadata: Default::default(),
187 }),
188 }))
189 .await
190 .map_err(|()| TxError::Transport("connection closed".into()))
191 })
192 }
193
194 fn close_channel_on_drop(&self) {
195 let _ = self
196 .local_control_tx
197 .send(DriverLocalControl::CloseChannel {
198 channel_id: self.channel_id,
199 });
200 }
201}
202
203#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
207#[derive(Clone)]
208pub struct NoopCaller(#[allow(dead_code)] DriverCaller);
209
210impl From<DriverCaller> for NoopCaller {
211 fn from(caller: DriverCaller) -> Self {
212 Self(caller)
213 }
214}
215
/// Channel binder handed to incoming-call handlers through their reply sink.
///
/// Mirrors the channel-binding behaviour of [`DriverCaller`], but only carries
/// a drop guard that already existed when it was created (see
/// `Driver::internal_binder`).
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection used by the sinks this binder creates.
    sender: ConnectionSender,
    /// Shared id allocators and channel routing tables.
    shared: Arc<DriverShared>,
    /// Control channel given to sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Connection drop guard, if one was alive; returned as the channel
    /// liveness handle.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
223
224impl DriverChannelBinder {
225 fn create_tx_channel(
226 &self,
227 initial_credit: u32,
228 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
229 let channel_id = self.shared.channel_ids.lock().alloc();
230 let inner = DriverChannelSink {
231 sender: self.sender.clone(),
232 channel_id,
233 local_control_tx: self.local_control_tx.clone(),
234 };
235 let sink = Arc::new(CreditSink::new(inner, initial_credit));
236 self.shared
237 .channel_credits
238 .lock()
239 .insert(channel_id, Arc::clone(sink.credit()));
240 (channel_id, sink)
241 }
242
243 fn register_rx_channel(
244 &self,
245 channel_id: ChannelId,
246 initial_credit: u32,
247 ) -> roam_types::BoundChannelReceiver {
248 let (tx, rx) = tokio::sync::mpsc::channel(64);
249 let mut terminal_buffered = false;
250 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
251 for msg in buffered {
252 let is_terminal = matches!(
253 msg,
254 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
255 );
256 let _ = tx.try_send(msg);
257 if is_terminal {
258 terminal_buffered = true;
259 break;
260 }
261 }
262 }
263 if terminal_buffered {
264 self.shared.channel_credits.lock().remove(&channel_id);
265 return roam_types::BoundChannelReceiver {
266 receiver: rx,
267 liveness: self.channel_liveness(),
268 replenisher: None,
269 };
270 }
271
272 self.shared.channel_senders.lock().insert(channel_id, tx);
273 roam_types::BoundChannelReceiver {
274 receiver: rx,
275 liveness: self.channel_liveness(),
276 replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
277 channel_id,
278 initial_credit,
279 self.local_control_tx.clone(),
280 )) as ChannelCreditReplenisherHandle),
281 }
282 }
283}
284
285impl ChannelBinder for DriverChannelBinder {
286 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
287 let (id, sink) = self.create_tx_channel(initial_credit);
288 (id, sink as Arc<dyn ChannelSink>)
289 }
290
291 fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
292 let channel_id = self.shared.channel_ids.lock().alloc();
293 let rx = self.register_rx_channel(channel_id, initial_credit);
294 (channel_id, rx)
295 }
296
297 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
298 let inner = DriverChannelSink {
299 sender: self.sender.clone(),
300 channel_id,
301 local_control_tx: self.local_control_tx.clone(),
302 };
303 let sink = Arc::new(CreditSink::new(inner, initial_credit));
304 self.shared
305 .channel_credits
306 .lock()
307 .insert(channel_id, Arc::clone(sink.credit()));
308 sink
309 }
310
311 fn register_rx(
312 &self,
313 channel_id: ChannelId,
314 initial_credit: u32,
315 ) -> roam_types::BoundChannelReceiver {
316 self.register_rx_channel(channel_id, initial_credit)
317 }
318
319 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
320 self.drop_guard
321 .as_ref()
322 .map(|guard| guard.clone() as ChannelLivenessHandle)
323 }
324}
325
/// Cloneable handle for issuing calls and binding channels on a connection
/// driven by a [`Driver`].
///
/// Obtained from [`Driver::caller`]. When the connection has a control sender,
/// every caller holds the connection's shared drop guard.
#[derive(Clone)]
pub struct DriverCaller {
    /// Outbound side of the connection.
    sender: ConnectionSender,
    /// Id allocators, pending responses and channel routing shared with the driver.
    shared: Arc<DriverShared>,
    /// Control channel given to sinks and credit replenishers created here.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Connection-closed flag backing `closed()` and `is_connected()`.
    closed_rx: watch::Receiver<bool>,
    /// Keeps the connection drop guard alive; also handed out as the channel
    /// liveness handle.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
336
337impl DriverCaller {
338 pub fn create_tx_channel(
344 &self,
345 initial_credit: u32,
346 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
347 let channel_id = self.shared.channel_ids.lock().alloc();
348 let inner = DriverChannelSink {
349 sender: self.sender.clone(),
350 channel_id,
351 local_control_tx: self.local_control_tx.clone(),
352 };
353 let sink = Arc::new(CreditSink::new(inner, initial_credit));
354 self.shared
355 .channel_credits
356 .lock()
357 .insert(channel_id, Arc::clone(sink.credit()));
358 (channel_id, sink)
359 }
360
361 #[cfg(test)]
366 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
367 &self.sender
368 }
369
370 pub fn register_rx_channel(
375 &self,
376 channel_id: ChannelId,
377 initial_credit: u32,
378 ) -> roam_types::BoundChannelReceiver {
379 let (tx, rx) = tokio::sync::mpsc::channel(64);
380 let mut terminal_buffered = false;
381 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
383 for msg in buffered {
384 let is_terminal = matches!(
385 msg,
386 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
387 );
388 let _ = tx.try_send(msg);
389 if is_terminal {
390 terminal_buffered = true;
391 break;
392 }
393 }
394 }
395 if terminal_buffered {
396 self.shared.channel_credits.lock().remove(&channel_id);
397 return roam_types::BoundChannelReceiver {
398 receiver: rx,
399 liveness: self.channel_liveness(),
400 replenisher: None,
401 };
402 }
403
404 self.shared.channel_senders.lock().insert(channel_id, tx);
405 roam_types::BoundChannelReceiver {
406 receiver: rx,
407 liveness: self.channel_liveness(),
408 replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
409 channel_id,
410 initial_credit,
411 self.local_control_tx.clone(),
412 )) as ChannelCreditReplenisherHandle),
413 }
414 }
415}
416
417impl ChannelBinder for DriverCaller {
418 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
419 let (id, sink) = self.create_tx_channel(initial_credit);
420 (id, sink as Arc<dyn ChannelSink>)
421 }
422
423 fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
424 let channel_id = self.shared.channel_ids.lock().alloc();
425 let rx = self.register_rx_channel(channel_id, initial_credit);
426 (channel_id, rx)
427 }
428
429 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
430 let inner = DriverChannelSink {
431 sender: self.sender.clone(),
432 channel_id,
433 local_control_tx: self.local_control_tx.clone(),
434 };
435 let sink = Arc::new(CreditSink::new(inner, initial_credit));
436 self.shared
437 .channel_credits
438 .lock()
439 .insert(channel_id, Arc::clone(sink.credit()));
440 sink
441 }
442
443 fn register_rx(
444 &self,
445 channel_id: ChannelId,
446 initial_credit: u32,
447 ) -> roam_types::BoundChannelReceiver {
448 self.register_rx_channel(channel_id, initial_credit)
449 }
450
451 fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
452 self._drop_guard
453 .as_ref()
454 .map(|guard| guard.clone() as ChannelLivenessHandle)
455 }
456}
457
458impl Caller for DriverCaller {
459 fn call<'a>(
460 &'a self,
461 call: RequestCall<'a>,
462 ) -> impl std::future::Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>>
463 + MaybeSend
464 + 'a {
465 async {
466 let req_id = self.shared.request_ids.lock().alloc();
468
469 let (tx, rx) = moire::sync::oneshot::channel("driver.response");
472 self.shared.pending_responses.lock().insert(req_id, tx);
473
474 let send_result = self
477 .sender
478 .send(ConnectionMessage::Request(RequestMessage {
479 id: req_id,
480 body: RequestBody::Call(call),
481 }))
482 .await;
483
484 if send_result.is_err() {
485 self.shared.pending_responses.lock().remove(&req_id);
487 return Err(RoamError::Cancelled);
488 }
489
490 let response_msg: SelfRef<RequestMessage<'static>> = rx
492 .named("awaiting_response")
493 .await
494 .map_err(|_| RoamError::Cancelled)?;
495
496 let response = response_msg.map(|m| match m.body {
498 RequestBody::Response(r) => r,
499 _ => unreachable!("pending_responses only gets Response variants"),
500 });
501
502 Ok(response)
503 }
504 .named("Caller::call")
505 }
506
507 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
508 Box::pin(async move {
509 if *self.closed_rx.borrow() {
510 return;
511 }
512 let mut rx = self.closed_rx.clone();
513 while rx.changed().await.is_ok() {
514 if *rx.borrow() {
515 return;
516 }
517 }
518 })
519 }
520
521 fn is_connected(&self) -> bool {
522 !*self.closed_rx.borrow()
523 }
524
525 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
526 Some(self)
527 }
528}
529
/// Per-connection event loop.
///
/// It reads inbound messages, spawns handler tasks for incoming calls, routes
/// responses and channel traffic, and executes control requests from channel
/// sinks and credit replenishers.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Outbound side of the connection.
    sender: ConnectionSender,
    /// Inbound messages; `run` returns when this yields `None`.
    rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Failed request ids with a reason (presumably fed by
    /// `ConnectionSender::mark_failure`; confirm in the session module).
    failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Connection-closed flag; cloned into every `DriverCaller`.
    closed_rx: watch::Receiver<bool>,
    /// Control requests from channel sinks and credit replenishers.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Handler invoked (on a spawned task) for every incoming call.
    handler: Arc<H>,
    /// State shared with callers, binders and reply sinks.
    shared: Arc<DriverShared>,
    /// Spawned handler tasks keyed by incoming request id, so `Cancel` and
    /// shutdown can abort them.
    in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
    /// Cloned into sinks and replenishers; holding it here also keeps
    /// `local_control_rx` open.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Sender used to build the connection drop guard; `None` means no guard
    /// is ever created.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request the drop guard sends (`DropControlRequest::Close` for this connection).
    drop_control_request: DropControlRequest,
    /// Weak reference to the current drop guard so every holder shares one.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
552
/// Requests that must be executed on the driver task because they send
/// frames but originate from synchronous contexts (drops, item consumption).
enum DriverLocalControl {
    /// Send a `Close` frame for `channel_id`.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// Send a `GrantCredit` frame giving the peer `additional` more credit on
    /// `channel_id`.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
}
562
/// Returns credit to the peer for an incoming channel in batches, instead of
/// one grant per consumed item.
struct DriverChannelCreditReplenisher {
    /// Channel the credit is granted for.
    channel_id: ChannelId,
    /// Number of consumed items that triggers a grant (half the initial
    /// window, at least 1).
    threshold: u32,
    /// Where `GrantCredit` requests are sent for the driver to forward.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Items consumed since the last grant.
    pending: std::sync::Mutex<u32>,
}
569
570impl DriverChannelCreditReplenisher {
571 fn new(
572 channel_id: ChannelId,
573 initial_credit: u32,
574 local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
575 ) -> Self {
576 Self {
577 channel_id,
578 threshold: (initial_credit / 2).max(1),
579 local_control_tx,
580 pending: std::sync::Mutex::new(0),
581 }
582 }
583}
584
585impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
586 fn on_item_consumed(&self) {
587 let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
588 *pending += 1;
589 if *pending < self.threshold {
590 return;
591 }
592
593 let additional = *pending;
594 *pending = 0;
595 let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
596 channel_id: self.channel_id,
597 additional,
598 });
599 }
600}
601
602impl<H: Handler<DriverReplySink>> Driver<H> {
603 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
604 let conn_id = handle.connection_id();
605 let ConnectionHandle {
606 sender,
607 rx,
608 failures_rx,
609 control_tx,
610 closed_rx,
611 parity,
612 } = handle;
613 let drop_control_request = DropControlRequest::Close(conn_id);
614 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
615 Self {
616 sender,
617 rx,
618 failures_rx,
619 closed_rx,
620 local_control_rx,
621 handler: Arc::new(handler),
622 shared: Arc::new(DriverShared {
623 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
624 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
625 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
626 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
627 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
628 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
629 }),
630 in_flight_handlers: BTreeMap::new(),
631 local_control_tx,
632 drop_control_seed: control_tx,
633 drop_control_request,
634 drop_guard: SyncMutex::new("driver.drop_guard", None),
635 }
636 }
637
638 fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
644 self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
645 }
646
647 fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
648 let drop_guard = if let Some(existing) = self.existing_drop_guard() {
649 Some(existing)
650 } else if let Some(seed) = &self.drop_control_seed {
651 let mut guard = self.drop_guard.lock();
652 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
653 Some(existing)
654 } else {
655 let arc = Arc::new(CallerDropGuard {
656 control_tx: seed.clone(),
657 request: self.drop_control_request,
658 });
659 *guard = Some(Arc::downgrade(&arc));
660 Some(arc)
661 }
662 } else {
663 None
664 };
665 drop_guard
666 }
667
668 pub fn caller(&self) -> DriverCaller {
669 let drop_guard = self.connection_drop_guard();
670 DriverCaller {
671 sender: self.sender.clone(),
672 shared: Arc::clone(&self.shared),
673 local_control_tx: self.local_control_tx.clone(),
674 closed_rx: self.closed_rx.clone(),
675 _drop_guard: drop_guard,
676 }
677 }
678
679 fn internal_binder(&self) -> DriverChannelBinder {
680 DriverChannelBinder {
681 sender: self.sender.clone(),
682 shared: Arc::clone(&self.shared),
683 local_control_tx: self.local_control_tx.clone(),
684 drop_guard: self.existing_drop_guard(),
685 }
686 }
687
688 pub async fn run(&mut self) {
693 loop {
694 tokio::select! {
695 msg = self.rx.recv() => {
696 match msg {
697 Some(msg) => self.handle_msg(msg),
698 None => break,
699 }
700 }
701 Some((req_id, _reason)) = self.failures_rx.recv() => {
702 self.in_flight_handlers.remove(&req_id);
704 if self.shared.pending_responses.lock().remove(&req_id).is_none() {
705 let error: Result<(), RoamError<core::convert::Infallible>> =
709 Err(RoamError::Cancelled);
710 let _ = self.sender.send_response(req_id, RequestResponse {
711 ret: Payload::outgoing(&error),
712 channels: vec![],
713 metadata: Default::default(),
714 }).await;
715 }
716 }
717 Some(ctrl) = self.local_control_rx.recv() => {
718 self.handle_local_control(ctrl).await;
719 }
720 }
721 }
722
723 for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
724 handle.abort();
725 }
726 self.shared.pending_responses.lock().clear();
727
728 self.shared.channel_senders.lock().clear();
731 self.shared.channel_buffers.lock().clear();
732 self.shared.channel_credits.lock().clear();
733 }
734
735 async fn handle_local_control(&mut self, control: DriverLocalControl) {
736 match control {
737 DriverLocalControl::CloseChannel { channel_id } => {
738 let _ = self
739 .sender
740 .send(ConnectionMessage::Channel(ChannelMessage {
741 id: channel_id,
742 body: ChannelBody::Close(ChannelClose {
743 metadata: Default::default(),
744 }),
745 }))
746 .await;
747 }
748 DriverLocalControl::GrantCredit {
749 channel_id,
750 additional,
751 } => {
752 let _ = self
753 .sender
754 .send(ConnectionMessage::Channel(ChannelMessage {
755 id: channel_id,
756 body: ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
757 additional,
758 }),
759 }))
760 .await;
761 }
762 }
763 }
764
765 fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
766 let is_request = matches!(&*msg, ConnectionMessage::Request(_));
767 if is_request {
768 let msg = msg.map(|m| match m {
769 ConnectionMessage::Request(r) => r,
770 _ => unreachable!(),
771 });
772 self.handle_request(msg);
773 } else {
774 let msg = msg.map(|m| match m {
775 ConnectionMessage::Channel(c) => c,
776 _ => unreachable!(),
777 });
778 self.handle_channel(msg);
779 }
780 }
781
782 fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
783 let req_id = msg.id;
784 let is_call = matches!(&msg.body, RequestBody::Call(_));
785 let is_response = matches!(&msg.body, RequestBody::Response(_));
786 let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
787
788 if is_call {
789 let reply = DriverReplySink {
792 sender: Some(self.sender.clone()),
793 request_id: req_id,
794 binder: self.internal_binder(),
795 };
796 let call = msg.map(|m| match m.body {
797 RequestBody::Call(c) => c,
798 _ => unreachable!(),
799 });
800 let handler = Arc::clone(&self.handler);
801 let join_handle = moire::task::spawn(
802 async move {
803 handler.handle(call, reply).await;
804 }
805 .named("handler"),
806 );
807 self.in_flight_handlers.insert(req_id, join_handle);
808 } else if is_response {
809 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
811 let _: Result<(), _> = tx.send(msg);
812 }
813 } else if is_cancel {
814 if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
819 handle.abort();
820 }
821 }
824 }
825
826 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
827 let chan_id = msg.id;
828
829 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
832
833 match &msg.body {
834 ChannelBody::Item(_item) => {
836 if let Some(tx) = &sender {
837 let item = msg.map(|m| match m.body {
838 ChannelBody::Item(item) => item,
839 _ => unreachable!(),
840 });
841 let _ = tx.try_send(IncomingChannelMessage::Item(item));
843 } else {
844 let item = msg.map(|m| match m.body {
846 ChannelBody::Item(item) => item,
847 _ => unreachable!(),
848 });
849 self.shared
850 .channel_buffers
851 .lock()
852 .entry(chan_id)
853 .or_default()
854 .push(IncomingChannelMessage::Item(item));
855 }
856 }
857 ChannelBody::Close(_close) => {
859 if let Some(tx) = &sender {
860 let close = msg.map(|m| match m.body {
861 ChannelBody::Close(close) => close,
862 _ => unreachable!(),
863 });
864 let _ = tx.try_send(IncomingChannelMessage::Close(close));
865 } else {
866 let close = msg.map(|m| match m.body {
868 ChannelBody::Close(close) => close,
869 _ => unreachable!(),
870 });
871 self.shared
872 .channel_buffers
873 .lock()
874 .entry(chan_id)
875 .or_default()
876 .push(IncomingChannelMessage::Close(close));
877 }
878 self.shared.channel_senders.lock().remove(&chan_id);
879 self.shared.channel_credits.lock().remove(&chan_id);
880 }
881 ChannelBody::Reset(_reset) => {
883 if let Some(tx) = &sender {
884 let reset = msg.map(|m| match m.body {
885 ChannelBody::Reset(reset) => reset,
886 _ => unreachable!(),
887 });
888 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
889 } else {
890 let reset = msg.map(|m| match m.body {
892 ChannelBody::Reset(reset) => reset,
893 _ => unreachable!(),
894 });
895 self.shared
896 .channel_buffers
897 .lock()
898 .entry(chan_id)
899 .or_default()
900 .push(IncomingChannelMessage::Reset(reset));
901 }
902 self.shared.channel_senders.lock().remove(&chan_id);
903 self.shared.channel_credits.lock().remove(&chan_id);
904 }
905 ChannelBody::GrantCredit(grant) => {
908 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
909 semaphore.add_permits(grant.additional as usize);
910 }
911 }
912 }
913 }
914}