1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{Arc, Weak},
5};
6
7use moire::sync::SyncMutex;
8use tokio::sync::Semaphore;
9
10use moire::task::FutureExt as _;
11use roam_types::{
12 Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelId, ChannelItem, ChannelMessage,
13 ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage, MaybeSend, Payload,
14 ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse, RoamError,
15 SelfRef, TxError,
16};
17
18use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
19use moire::sync::mpsc;
20
/// Sending half of the one-shot channel on which an outgoing call waits for its
/// reply. The value delivered is the complete response [`RequestMessage`].
type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
22
/// State shared (behind an `Arc`) between the [`Driver`] and the caller/binder
/// handles it hands out. Every field is guarded by its own mutex.
struct DriverShared {
    /// Reply slots of outgoing calls that are still waiting for a response,
    /// keyed by request id.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for the ids of outgoing requests.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Allocator for the ids of channels created on this side.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Delivery queues of the receive channels that have been registered,
    /// keyed by channel id.
    channel_senders:
        SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
    /// Messages that arrived for a channel id with no registered receiver;
    /// they are drained when the channel gets registered.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of the sending channels; incoming credit grants add
    /// permits to the matching entry.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
}
42
/// Guard that posts `request` on `control_tx` when it is dropped.
struct CallerDropGuard {
    /// Channel on which the drop-control request is sent.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// The request to send on drop.
    request: DropControlRequest,
}
47
48impl Drop for CallerDropGuard {
49 fn drop(&mut self) {
50 let _ = self.control_tx.send(self.request);
51 }
52}
53
/// Reply sink handed to a handler for a single incoming call.
///
/// If the sink is dropped without `send_reply` having been called, its `Drop`
/// implementation reports the request as failed.
pub struct DriverReplySink {
    /// `Some` until the reply is sent; `send_reply` takes it so that `Drop`
    /// can tell whether a reply went out.
    sender: Option<ConnectionSender>,
    /// Id of the request this sink answers.
    request_id: RequestId,
    /// Binder exposed to the handler through `channel_binder`.
    binder: DriverChannelBinder,
}
65
66impl ReplySink for DriverReplySink {
67 async fn send_reply(mut self, response: RequestResponse<'_>) {
68 let sender = self
69 .sender
70 .take()
71 .expect("unreachable: send_reply takes self by value");
72 if let Err(_e) = sender.send_response(self.request_id, response).await {
73 sender.mark_failure(self.request_id, "send_response failed");
74 }
75 }
76
77 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
78 Some(&self.binder)
79 }
80}
81
82impl Drop for DriverReplySink {
84 fn drop(&mut self) {
85 if let Some(sender) = self.sender.take() {
86 sender.mark_failure(self.request_id, "no reply sent")
87 }
88 }
89}
90
/// Sending side of one channel: wraps channel items and close notifications in
/// [`ConnectionMessage`]s and pushes them through the connection sender.
pub struct DriverChannelSink {
    /// Outbound connection handle used for items and explicit closes.
    sender: ConnectionSender,
    /// Id stamped on every message sent through this sink.
    channel_id: ChannelId,
    /// Used by `close_channel_on_drop` to ask the driver to send the close.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
103
104impl ChannelSink for DriverChannelSink {
105 fn send_payload<'payload>(
106 &self,
107 payload: Payload<'payload>,
108 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
109 let sender = self.sender.clone();
110 let channel_id = self.channel_id;
111 Box::pin(async move {
112 sender
113 .send(ConnectionMessage::Channel(ChannelMessage {
114 id: channel_id,
115 body: ChannelBody::Item(ChannelItem { item: payload }),
116 }))
117 .await
118 .map_err(|()| TxError::Transport("connection closed".into()))
119 })
120 }
121
122 fn close_channel(
123 &self,
124 _metadata: roam_types::Metadata,
125 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
126 let sender = self.sender.clone();
130 let channel_id = self.channel_id;
131 Box::pin(async move {
132 sender
133 .send(ConnectionMessage::Channel(ChannelMessage {
134 id: channel_id,
135 body: ChannelBody::Close(ChannelClose {
136 metadata: Default::default(),
137 }),
138 }))
139 .await
140 .map_err(|()| TxError::Transport("connection closed".into()))
141 })
142 }
143
144 fn close_channel_on_drop(&self) {
145 let _ = self
146 .local_control_tx
147 .send(DriverLocalControl::CloseChannel {
148 channel_id: self.channel_id,
149 });
150 }
151}
152
153impl From<DriverCaller> for () {
156 fn from(_: DriverCaller) {}
157}
158
/// Crate-internal channel binder: creates and registers channels against the
/// driver's shared state. One is embedded in every [`DriverReplySink`].
#[derive(Clone)]
struct DriverChannelBinder {
    /// Outbound connection handle cloned into every sink this binder creates.
    sender: ConnectionSender,
    /// Id allocators and channel tables shared with the driver.
    shared: Arc<DriverShared>,
    /// Local control channel cloned into every sink this binder creates.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
165
166impl DriverChannelBinder {
167 fn create_tx_channel(
168 &self,
169 initial_credit: u32,
170 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
171 let channel_id = self.shared.channel_ids.lock().alloc();
172 let inner = DriverChannelSink {
173 sender: self.sender.clone(),
174 channel_id,
175 local_control_tx: self.local_control_tx.clone(),
176 };
177 let sink = Arc::new(CreditSink::new(inner, initial_credit));
178 self.shared
179 .channel_credits
180 .lock()
181 .insert(channel_id, Arc::clone(sink.credit()));
182 (channel_id, sink)
183 }
184
185 fn register_rx_channel(
186 &self,
187 channel_id: ChannelId,
188 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
189 let (tx, rx) = tokio::sync::mpsc::channel(64);
190 let mut terminal_buffered = false;
191 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
192 for msg in buffered {
193 let is_terminal = matches!(
194 msg,
195 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
196 );
197 let _ = tx.try_send(msg);
198 if is_terminal {
199 terminal_buffered = true;
200 break;
201 }
202 }
203 }
204 if terminal_buffered {
205 self.shared.channel_credits.lock().remove(&channel_id);
206 return rx;
207 }
208
209 self.shared.channel_senders.lock().insert(channel_id, tx);
210 rx
211 }
212}
213
214impl ChannelBinder for DriverChannelBinder {
215 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
216 let (id, sink) = self.create_tx_channel(initial_credit);
217 (id, sink as Arc<dyn ChannelSink>)
218 }
219
220 fn create_rx(
221 &self,
222 ) -> (
223 ChannelId,
224 tokio::sync::mpsc::Receiver<IncomingChannelMessage>,
225 ) {
226 let channel_id = self.shared.channel_ids.lock().alloc();
227 let rx = self.register_rx_channel(channel_id);
228 (channel_id, rx)
229 }
230
231 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
232 let inner = DriverChannelSink {
233 sender: self.sender.clone(),
234 channel_id,
235 local_control_tx: self.local_control_tx.clone(),
236 };
237 let sink = Arc::new(CreditSink::new(inner, initial_credit));
238 self.shared
239 .channel_credits
240 .lock()
241 .insert(channel_id, Arc::clone(sink.credit()));
242 sink
243 }
244
245 fn register_rx(
246 &self,
247 channel_id: ChannelId,
248 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
249 self.register_rx_channel(channel_id)
250 }
251}
252
/// Cloneable handle for issuing calls and creating channels on the connection
/// served by a [`Driver`].
#[derive(Clone)]
pub struct DriverCaller {
    /// Outbound connection handle used to send requests.
    sender: ConnectionSender,
    /// Id allocators and response/channel tables shared with the driver.
    shared: Arc<DriverShared>,
    /// Local control channel cloned into every sink this caller creates.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Guard shared by callers of one driver; its `Drop` posts the driver's
    /// drop-control request once the last holder is gone. `None` when the
    /// driver has no drop-control sender.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
262
263impl DriverCaller {
264 pub fn create_tx_channel(
270 &self,
271 initial_credit: u32,
272 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
273 let channel_id = self.shared.channel_ids.lock().alloc();
274 let inner = DriverChannelSink {
275 sender: self.sender.clone(),
276 channel_id,
277 local_control_tx: self.local_control_tx.clone(),
278 };
279 let sink = Arc::new(CreditSink::new(inner, initial_credit));
280 self.shared
281 .channel_credits
282 .lock()
283 .insert(channel_id, Arc::clone(sink.credit()));
284 (channel_id, sink)
285 }
286
287 #[cfg(test)]
292 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
293 &self.sender
294 }
295
296 pub fn register_rx_channel(
301 &self,
302 channel_id: ChannelId,
303 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
304 let (tx, rx) = tokio::sync::mpsc::channel(64);
305 let mut terminal_buffered = false;
306 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
308 for msg in buffered {
309 let is_terminal = matches!(
310 msg,
311 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
312 );
313 let _ = tx.try_send(msg);
314 if is_terminal {
315 terminal_buffered = true;
316 break;
317 }
318 }
319 }
320 if terminal_buffered {
321 self.shared.channel_credits.lock().remove(&channel_id);
322 return rx;
323 }
324
325 self.shared.channel_senders.lock().insert(channel_id, tx);
326 rx
327 }
328}
329
330impl ChannelBinder for DriverCaller {
331 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
332 let (id, sink) = self.create_tx_channel(initial_credit);
333 (id, sink as Arc<dyn ChannelSink>)
334 }
335
336 fn create_rx(
337 &self,
338 ) -> (
339 ChannelId,
340 tokio::sync::mpsc::Receiver<IncomingChannelMessage>,
341 ) {
342 let channel_id = self.shared.channel_ids.lock().alloc();
343 let rx = self.register_rx_channel(channel_id);
344 (channel_id, rx)
345 }
346
347 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
348 let inner = DriverChannelSink {
349 sender: self.sender.clone(),
350 channel_id,
351 local_control_tx: self.local_control_tx.clone(),
352 };
353 let sink = Arc::new(CreditSink::new(inner, initial_credit));
354 self.shared
355 .channel_credits
356 .lock()
357 .insert(channel_id, Arc::clone(sink.credit()));
358 sink
359 }
360
361 fn register_rx(
362 &self,
363 channel_id: ChannelId,
364 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
365 self.register_rx_channel(channel_id)
366 }
367}
368
369impl Caller for DriverCaller {
370 fn call<'a>(
371 &'a self,
372 call: RequestCall<'a>,
373 ) -> impl std::future::Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>>
374 + MaybeSend
375 + 'a {
376 async {
377 let req_id = self.shared.request_ids.lock().alloc();
379
380 let (tx, rx) = moire::sync::oneshot::channel("driver.response");
383 self.shared.pending_responses.lock().insert(req_id, tx);
384
385 let send_result = self
388 .sender
389 .send(ConnectionMessage::Request(RequestMessage {
390 id: req_id,
391 body: RequestBody::Call(call),
392 }))
393 .await;
394
395 if send_result.is_err() {
396 self.shared.pending_responses.lock().remove(&req_id);
398 return Err(RoamError::Cancelled);
399 }
400
401 let response_msg: SelfRef<RequestMessage<'static>> = rx
403 .named("awaiting_response")
404 .await
405 .map_err(|_| RoamError::Cancelled)?;
406
407 let response = response_msg.map(|m| match m.body {
409 RequestBody::Response(r) => r,
410 _ => unreachable!("pending_responses only gets Response variants"),
411 });
412
413 Ok(response)
414 }
415 .named("Caller::call")
416 }
417
418 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
419 Some(self)
420 }
421}
422
/// Per-connection event loop: dispatches incoming calls to `handler`, routes
/// responses to waiting callers and channel traffic to registered receivers.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Outbound connection handle; cloned into callers, binders and reply sinks.
    sender: ConnectionSender,
    /// Incoming connection messages.
    rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Failure notices as `(request id, reason)`.
    /// NOTE(review): presumably fed by `ConnectionSender::mark_failure` — confirm in `session`.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Control requests posted by channel sinks (see [`DriverLocalControl`]).
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Handler invoked for every incoming call.
    handler: Arc<H>,
    /// State shared with callers and binders.
    shared: Arc<DriverShared>,
    /// Join handles of the spawned handler tasks, keyed by request id.
    in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
    /// Sending half of the local control channel; cloned into callers and binders.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Sender used to build the callers' drop guard; `None` means callers get no guard.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request the drop guard posts when it is dropped.
    drop_control_request: DropControlRequest,
    /// Weak reference to the drop guard shared by the callers that are still alive.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
444
/// Requests posted to the driver loop from synchronous contexts.
enum DriverLocalControl {
    /// Ask the driver to send a close message for `channel_id`.
    CloseChannel { channel_id: ChannelId },
}
448
449impl<H: Handler<DriverReplySink>> Driver<H> {
450 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
451 let conn_id = handle.connection_id();
452 let ConnectionHandle {
453 sender,
454 rx,
455 failures_rx,
456 control_tx,
457 parity,
458 } = handle;
459 let drop_control_request = DropControlRequest::Close(conn_id);
460 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
461 Self {
462 sender,
463 rx,
464 failures_rx,
465 local_control_rx,
466 handler: Arc::new(handler),
467 shared: Arc::new(DriverShared {
468 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
469 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
470 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
471 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
472 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
473 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
474 }),
475 in_flight_handlers: BTreeMap::new(),
476 local_control_tx,
477 drop_control_seed: control_tx,
478 drop_control_request,
479 drop_guard: SyncMutex::new("driver.drop_guard", None),
480 }
481 }
482
483 pub fn caller(&self) -> DriverCaller {
489 let drop_guard = if let Some(seed) = &self.drop_control_seed {
490 let mut guard = self.drop_guard.lock();
491 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
492 Some(existing)
493 } else {
494 let arc = Arc::new(CallerDropGuard {
495 control_tx: seed.clone(),
496 request: self.drop_control_request,
497 });
498 *guard = Some(Arc::downgrade(&arc));
499 Some(arc)
500 }
501 } else {
502 None
503 };
504 DriverCaller {
505 sender: self.sender.clone(),
506 shared: Arc::clone(&self.shared),
507 local_control_tx: self.local_control_tx.clone(),
508 _drop_guard: drop_guard,
509 }
510 }
511
512 fn internal_binder(&self) -> DriverChannelBinder {
513 DriverChannelBinder {
514 sender: self.sender.clone(),
515 shared: Arc::clone(&self.shared),
516 local_control_tx: self.local_control_tx.clone(),
517 }
518 }
519
520 pub async fn run(&mut self) {
525 loop {
526 tokio::select! {
527 msg = self.rx.recv() => {
528 match msg {
529 Some(msg) => self.handle_msg(msg),
530 None => break,
531 }
532 }
533 Some((req_id, _reason)) = self.failures_rx.recv() => {
534 self.in_flight_handlers.remove(&req_id);
536 if self.shared.pending_responses.lock().remove(&req_id).is_none() {
537 let error: Result<(), RoamError<core::convert::Infallible>> =
541 Err(RoamError::Cancelled);
542 let _ = self.sender.send_response(req_id, RequestResponse {
543 ret: Payload::outgoing(&error),
544 channels: vec![],
545 metadata: Default::default(),
546 }).await;
547 }
548 }
549 Some(ctrl) = self.local_control_rx.recv() => {
550 self.handle_local_control(ctrl).await;
551 }
552 }
553 }
554
555 for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
556 handle.abort();
557 }
558 self.shared.pending_responses.lock().clear();
559
560 self.shared.channel_senders.lock().clear();
563 self.shared.channel_buffers.lock().clear();
564 self.shared.channel_credits.lock().clear();
565 }
566
567 async fn handle_local_control(&mut self, control: DriverLocalControl) {
568 match control {
569 DriverLocalControl::CloseChannel { channel_id } => {
570 let _ = self
571 .sender
572 .send(ConnectionMessage::Channel(ChannelMessage {
573 id: channel_id,
574 body: ChannelBody::Close(ChannelClose {
575 metadata: Default::default(),
576 }),
577 }))
578 .await;
579 }
580 }
581 }
582
583 fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
584 let is_request = matches!(&*msg, ConnectionMessage::Request(_));
585 if is_request {
586 let msg = msg.map(|m| match m {
587 ConnectionMessage::Request(r) => r,
588 _ => unreachable!(),
589 });
590 self.handle_request(msg);
591 } else {
592 let msg = msg.map(|m| match m {
593 ConnectionMessage::Channel(c) => c,
594 _ => unreachable!(),
595 });
596 self.handle_channel(msg);
597 }
598 }
599
600 fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
601 let req_id = msg.id;
602 let is_call = matches!(&msg.body, RequestBody::Call(_));
603 let is_response = matches!(&msg.body, RequestBody::Response(_));
604 let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
605
606 if is_call {
607 let reply = DriverReplySink {
610 sender: Some(self.sender.clone()),
611 request_id: req_id,
612 binder: self.internal_binder(),
613 };
614 let call = msg.map(|m| match m.body {
615 RequestBody::Call(c) => c,
616 _ => unreachable!(),
617 });
618 let handler = Arc::clone(&self.handler);
619 let join_handle = moire::task::spawn(
620 async move {
621 handler.handle(call, reply).await;
622 }
623 .named("handler"),
624 );
625 self.in_flight_handlers.insert(req_id, join_handle);
626 } else if is_response {
627 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
629 let _: Result<(), _> = tx.send(msg);
630 }
631 } else if is_cancel {
632 if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
637 handle.abort();
638 }
639 }
642 }
643
644 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
645 let chan_id = msg.id;
646
647 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
650
651 match &msg.body {
652 ChannelBody::Item(_item) => {
654 if let Some(tx) = &sender {
655 let item = msg.map(|m| match m.body {
656 ChannelBody::Item(item) => item,
657 _ => unreachable!(),
658 });
659 let _ = tx.try_send(IncomingChannelMessage::Item(item));
661 } else {
662 let item = msg.map(|m| match m.body {
664 ChannelBody::Item(item) => item,
665 _ => unreachable!(),
666 });
667 self.shared
668 .channel_buffers
669 .lock()
670 .entry(chan_id)
671 .or_default()
672 .push(IncomingChannelMessage::Item(item));
673 }
674 }
675 ChannelBody::Close(_close) => {
677 if let Some(tx) = &sender {
678 let close = msg.map(|m| match m.body {
679 ChannelBody::Close(close) => close,
680 _ => unreachable!(),
681 });
682 let _ = tx.try_send(IncomingChannelMessage::Close(close));
683 } else {
684 let close = msg.map(|m| match m.body {
686 ChannelBody::Close(close) => close,
687 _ => unreachable!(),
688 });
689 self.shared
690 .channel_buffers
691 .lock()
692 .entry(chan_id)
693 .or_default()
694 .push(IncomingChannelMessage::Close(close));
695 }
696 self.shared.channel_senders.lock().remove(&chan_id);
697 self.shared.channel_credits.lock().remove(&chan_id);
698 }
699 ChannelBody::Reset(_reset) => {
701 if let Some(tx) = &sender {
702 let reset = msg.map(|m| match m.body {
703 ChannelBody::Reset(reset) => reset,
704 _ => unreachable!(),
705 });
706 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
707 } else {
708 let reset = msg.map(|m| match m.body {
710 ChannelBody::Reset(reset) => reset,
711 _ => unreachable!(),
712 });
713 self.shared
714 .channel_buffers
715 .lock()
716 .entry(chan_id)
717 .or_default()
718 .push(IncomingChannelMessage::Reset(reset));
719 }
720 self.shared.channel_senders.lock().remove(&chan_id);
721 self.shared.channel_credits.lock().remove(&chan_id);
722 }
723 ChannelBody::GrantCredit(grant) => {
726 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
727 semaphore.add_permits(grant.additional as usize);
728 }
729 }
730 }
731 }
732}