1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{Arc, Weak},
5};
6
7use moire::sync::SyncMutex;
8use tokio::sync::Semaphore;
9
10use moire::task::FutureExt as _;
11use roam_types::{
12 Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelId, ChannelItem, ChannelMessage,
13 ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage, Payload, ReplySink,
14 RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse, RoamError, SelfRef,
15 TxError,
16};
17
18use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
19use moire::sync::mpsc;
20
/// One-shot slot through which the driver hands the response message of an
/// outgoing request back to the waiting caller (stored in
/// `DriverShared::pending_responses`).
type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
22
/// State shared between the [`Driver`], the [`DriverCaller`]s it hands out and
/// the channel binders given to request handlers.
struct DriverShared {
    /// Response slots of outgoing requests still waiting for a reply, keyed by
    /// request ID. The driver removes a slot when the matching `Response` arrives.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for the IDs of outgoing requests.
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Allocator for the IDs of channels created on this side.
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Senders of registered receive channels; incoming channel messages are
    /// forwarded to them.
    channel_senders:
        SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
    /// Channel messages that arrived before a receiver was registered; drained
    /// by `register_rx_channel`.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores of outgoing channels; `GrantCredit` messages add permits.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
}
42
/// Sends `request` on `control_tx` when dropped.
///
/// `Driver::caller` shares a single guard (behind an `Arc`) among the callers it
/// hands out, so the request goes out when the last caller holding it is dropped.
struct CallerDropGuard {
    /// Drop-control channel taken from the connection handle.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request sent on drop.
    request: DropControlRequest,
}
47
48impl Drop for CallerDropGuard {
49 fn drop(&mut self) {
50 let _ = self.control_tx.send(self.request);
51 }
52}
53
/// Reply sink handed to the handler of one incoming call.
///
/// If it is dropped without `send_reply` having run, it reports the request as
/// failed ("no reply sent") through `ConnectionSender::mark_failure`.
pub struct DriverReplySink {
    /// `Some` until `send_reply` takes it, which tells `Drop` that a reply was
    /// attempted.
    sender: Option<ConnectionSender>,
    /// ID of the incoming request this sink answers.
    request_id: RequestId,
    /// Channel binder exposed through `ReplySink::channel_binder`.
    binder: DriverChannelBinder,
}
65
66impl ReplySink for DriverReplySink {
67 async fn send_reply(mut self, response: RequestResponse<'_>) {
68 let sender = self
69 .sender
70 .take()
71 .expect("unreachable: send_reply takes self by value");
72 if let Err(_e) = sender.send_response(self.request_id, response).await {
73 sender.mark_failure(self.request_id, "send_response failed");
74 }
75 }
76
77 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
78 Some(&self.binder)
79 }
80}
81
82impl Drop for DriverReplySink {
84 fn drop(&mut self) {
85 if let Some(sender) = self.sender.take() {
86 sender.mark_failure(self.request_id, "no reply sent")
87 }
88 }
89}
90
/// Writes items and close frames for one channel onto the connection.
pub struct DriverChannelSink {
    /// Connection the channel frames are sent on.
    sender: ConnectionSender,
    /// Channel this sink writes to.
    channel_id: ChannelId,
    /// Used by `close_channel_on_drop` to ask the driver loop to send a `Close`.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
103
104impl ChannelSink for DriverChannelSink {
105 fn send_payload<'payload>(
106 &self,
107 payload: Payload<'payload>,
108 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
109 let sender = self.sender.clone();
110 let channel_id = self.channel_id;
111 Box::pin(async move {
112 sender
113 .send(ConnectionMessage::Channel(ChannelMessage {
114 id: channel_id,
115 body: ChannelBody::Item(ChannelItem { item: payload }),
116 }))
117 .await
118 .map_err(|()| TxError::Transport("connection closed".into()))
119 })
120 }
121
122 fn close_channel(
123 &self,
124 _metadata: roam_types::Metadata,
125 ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
126 let sender = self.sender.clone();
130 let channel_id = self.channel_id;
131 Box::pin(async move {
132 sender
133 .send(ConnectionMessage::Channel(ChannelMessage {
134 id: channel_id,
135 body: ChannelBody::Close(ChannelClose {
136 metadata: Default::default(),
137 }),
138 }))
139 .await
140 .map_err(|()| TxError::Transport("connection closed".into()))
141 })
142 }
143
144 fn close_channel_on_drop(&self) {
145 let _ = self
146 .local_control_tx
147 .send(DriverLocalControl::CloseChannel {
148 channel_id: self.channel_id,
149 });
150 }
151}
152
153impl From<DriverCaller> for () {
156 fn from(_: DriverCaller) {}
157}
158
/// Channel binder given to the reply sinks of incoming calls.
///
/// It allocates channel IDs and registers senders, buffers and credits in the
/// driver's shared state.
#[derive(Clone)]
struct DriverChannelBinder {
    /// Connection that sinks created by this binder write to.
    sender: ConnectionSender,
    /// Driver state shared with callers and other binders.
    shared: Arc<DriverShared>,
    /// Cloned into every sink for `close_channel_on_drop`.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
165
166impl DriverChannelBinder {
167 fn create_tx_channel(
168 &self,
169 initial_credit: u32,
170 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
171 let channel_id = self.shared.channel_ids.lock().alloc();
172 let inner = DriverChannelSink {
173 sender: self.sender.clone(),
174 channel_id,
175 local_control_tx: self.local_control_tx.clone(),
176 };
177 let sink = Arc::new(CreditSink::new(inner, initial_credit));
178 self.shared
179 .channel_credits
180 .lock()
181 .insert(channel_id, Arc::clone(sink.credit()));
182 (channel_id, sink)
183 }
184
185 fn register_rx_channel(
186 &self,
187 channel_id: ChannelId,
188 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
189 let (tx, rx) = tokio::sync::mpsc::channel(64);
190 let mut terminal_buffered = false;
191 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
192 for msg in buffered {
193 let is_terminal = matches!(
194 msg,
195 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
196 );
197 let _ = tx.try_send(msg);
198 if is_terminal {
199 terminal_buffered = true;
200 break;
201 }
202 }
203 }
204 if terminal_buffered {
205 self.shared.channel_credits.lock().remove(&channel_id);
206 return rx;
207 }
208
209 self.shared.channel_senders.lock().insert(channel_id, tx);
210 rx
211 }
212}
213
214impl ChannelBinder for DriverChannelBinder {
215 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
216 let (id, sink) = self.create_tx_channel(initial_credit);
217 (id, sink as Arc<dyn ChannelSink>)
218 }
219
220 fn create_rx(
221 &self,
222 ) -> (
223 ChannelId,
224 tokio::sync::mpsc::Receiver<IncomingChannelMessage>,
225 ) {
226 let channel_id = self.shared.channel_ids.lock().alloc();
227 let rx = self.register_rx_channel(channel_id);
228 (channel_id, rx)
229 }
230
231 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
232 let inner = DriverChannelSink {
233 sender: self.sender.clone(),
234 channel_id,
235 local_control_tx: self.local_control_tx.clone(),
236 };
237 let sink = Arc::new(CreditSink::new(inner, initial_credit));
238 self.shared
239 .channel_credits
240 .lock()
241 .insert(channel_id, Arc::clone(sink.credit()));
242 sink
243 }
244
245 fn register_rx(
246 &self,
247 channel_id: ChannelId,
248 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
249 self.register_rx_channel(channel_id)
250 }
251}
252
/// Handle for making outgoing calls on a connection run by a [`Driver`].
///
/// Clones share the driver's state. When the driver was given a drop-control
/// sender, callers also share one `CallerDropGuard`, whose request is sent once
/// the last caller holding it is dropped.
#[derive(Clone)]
pub struct DriverCaller {
    /// Connection that requests and channel frames are sent on.
    sender: ConnectionSender,
    /// Driver state shared with the driver and other callers.
    shared: Arc<DriverShared>,
    /// Cloned into channel sinks for `close_channel_on_drop`.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Keeps the shared drop guard alive; never read.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
262
263impl DriverCaller {
264 pub fn create_tx_channel(
270 &self,
271 initial_credit: u32,
272 ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
273 let channel_id = self.shared.channel_ids.lock().alloc();
274 let inner = DriverChannelSink {
275 sender: self.sender.clone(),
276 channel_id,
277 local_control_tx: self.local_control_tx.clone(),
278 };
279 let sink = Arc::new(CreditSink::new(inner, initial_credit));
280 self.shared
281 .channel_credits
282 .lock()
283 .insert(channel_id, Arc::clone(sink.credit()));
284 (channel_id, sink)
285 }
286
287 #[cfg(test)]
292 pub(crate) fn connection_sender(&self) -> &ConnectionSender {
293 &self.sender
294 }
295
296 pub fn register_rx_channel(
301 &self,
302 channel_id: ChannelId,
303 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
304 let (tx, rx) = tokio::sync::mpsc::channel(64);
305 let mut terminal_buffered = false;
306 if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
308 for msg in buffered {
309 let is_terminal = matches!(
310 msg,
311 IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
312 );
313 let _ = tx.try_send(msg);
314 if is_terminal {
315 terminal_buffered = true;
316 break;
317 }
318 }
319 }
320 if terminal_buffered {
321 self.shared.channel_credits.lock().remove(&channel_id);
322 return rx;
323 }
324
325 self.shared.channel_senders.lock().insert(channel_id, tx);
326 rx
327 }
328}
329
330impl ChannelBinder for DriverCaller {
331 fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
332 let (id, sink) = self.create_tx_channel(initial_credit);
333 (id, sink as Arc<dyn ChannelSink>)
334 }
335
336 fn create_rx(
337 &self,
338 ) -> (
339 ChannelId,
340 tokio::sync::mpsc::Receiver<IncomingChannelMessage>,
341 ) {
342 let channel_id = self.shared.channel_ids.lock().alloc();
343 let rx = self.register_rx_channel(channel_id);
344 (channel_id, rx)
345 }
346
347 fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
348 let inner = DriverChannelSink {
349 sender: self.sender.clone(),
350 channel_id,
351 local_control_tx: self.local_control_tx.clone(),
352 };
353 let sink = Arc::new(CreditSink::new(inner, initial_credit));
354 self.shared
355 .channel_credits
356 .lock()
357 .insert(channel_id, Arc::clone(sink.credit()));
358 sink
359 }
360
361 fn register_rx(
362 &self,
363 channel_id: ChannelId,
364 ) -> tokio::sync::mpsc::Receiver<IncomingChannelMessage> {
365 self.register_rx_channel(channel_id)
366 }
367}
368
369impl Caller for DriverCaller {
370 async fn call<'a>(
371 &self,
372 call: RequestCall<'a>,
373 ) -> Result<SelfRef<RequestResponse<'static>>, RoamError> {
374 async {
375 let req_id = self.shared.request_ids.lock().alloc();
377
378 let (tx, rx) = moire::sync::oneshot::channel("driver.response");
381 self.shared.pending_responses.lock().insert(req_id, tx);
382
383 let send_result = self
386 .sender
387 .send(ConnectionMessage::Request(RequestMessage {
388 id: req_id,
389 body: RequestBody::Call(call),
390 }))
391 .await;
392
393 if send_result.is_err() {
394 self.shared.pending_responses.lock().remove(&req_id);
396 return Err(RoamError::Cancelled);
397 }
398
399 let response_msg: SelfRef<RequestMessage<'static>> = rx
401 .named("awaiting_response")
402 .await
403 .map_err(|_| RoamError::Cancelled)?;
404
405 let response = response_msg.map(|m| match m.body {
407 RequestBody::Response(r) => r,
408 _ => unreachable!("pending_responses only gets Response variants"),
409 });
410
411 Ok(response)
412 }
413 .named("Caller::call")
414 .await
415 }
416
417 fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
418 Some(self)
419 }
420}
421
/// Drives one connection.
///
/// It dispatches incoming calls to `handler`, completes the response slots of
/// outgoing requests, and forwards channel traffic.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Outbound side of the connection.
    sender: ConnectionSender,
    /// Inbound connection messages; `run` stops when this yields `None`.
    rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Failure reports `(request_id, reason)`. Presumably fed by
    /// `ConnectionSender::mark_failure` (which `DriverReplySink` calls); confirm.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Local control commands, e.g. closing a channel whose sink was dropped.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Handler invoked for each incoming call.
    handler: Arc<H>,
    /// State shared with callers and channel binders.
    shared: Arc<DriverShared>,
    /// Join handles of spawned handler tasks, keyed by incoming request ID.
    in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
    /// Sender half of `local_control_rx`, cloned into channel sinks.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Drop-control sender from the connection handle, if any; cloned into the
    /// `CallerDropGuard` that `caller()` creates.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request a `CallerDropGuard` sends when it is dropped.
    drop_control_request: DropControlRequest,
    /// Weak reference to the current shared `CallerDropGuard`, so `caller()`
    /// can reuse it while any caller still holds it.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
443
/// Commands queued to the driver's event loop and handled by
/// `Driver::handle_local_control`.
enum DriverLocalControl {
    /// Send a `Close` frame for `channel_id`; queued by
    /// `DriverChannelSink::close_channel_on_drop`.
    CloseChannel { channel_id: ChannelId },
}
447
448impl<H: Handler<DriverReplySink>> Driver<H> {
449 pub fn new(handle: ConnectionHandle, handler: H) -> Self {
450 let conn_id = handle.connection_id();
451 let ConnectionHandle {
452 sender,
453 rx,
454 failures_rx,
455 control_tx,
456 parity,
457 } = handle;
458 let drop_control_request = DropControlRequest::Close(conn_id);
459 let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
460 Self {
461 sender,
462 rx,
463 failures_rx,
464 local_control_rx,
465 handler: Arc::new(handler),
466 shared: Arc::new(DriverShared {
467 pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
468 request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
469 channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
470 channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
471 channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
472 channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
473 }),
474 in_flight_handlers: BTreeMap::new(),
475 local_control_tx,
476 drop_control_seed: control_tx,
477 drop_control_request,
478 drop_guard: SyncMutex::new("driver.drop_guard", None),
479 }
480 }
481
482 pub fn caller(&self) -> DriverCaller {
488 let drop_guard = if let Some(seed) = &self.drop_control_seed {
489 let mut guard = self.drop_guard.lock();
490 if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
491 Some(existing)
492 } else {
493 let arc = Arc::new(CallerDropGuard {
494 control_tx: seed.clone(),
495 request: self.drop_control_request,
496 });
497 *guard = Some(Arc::downgrade(&arc));
498 Some(arc)
499 }
500 } else {
501 None
502 };
503 DriverCaller {
504 sender: self.sender.clone(),
505 shared: Arc::clone(&self.shared),
506 local_control_tx: self.local_control_tx.clone(),
507 _drop_guard: drop_guard,
508 }
509 }
510
511 fn internal_binder(&self) -> DriverChannelBinder {
512 DriverChannelBinder {
513 sender: self.sender.clone(),
514 shared: Arc::clone(&self.shared),
515 local_control_tx: self.local_control_tx.clone(),
516 }
517 }
518
519 pub async fn run(&mut self) {
524 loop {
525 tokio::select! {
526 msg = self.rx.recv() => {
527 match msg {
528 Some(msg) => self.handle_msg(msg),
529 None => break,
530 }
531 }
532 Some((req_id, _reason)) = self.failures_rx.recv() => {
533 self.in_flight_handlers.remove(&req_id);
535 if self.shared.pending_responses.lock().remove(&req_id).is_none() {
536 let error: Result<(), RoamError<core::convert::Infallible>> =
540 Err(RoamError::Cancelled);
541 let _ = self.sender.send_response(req_id, RequestResponse {
542 ret: Payload::outgoing(&error),
543 channels: vec![],
544 metadata: Default::default(),
545 }).await;
546 }
547 }
548 Some(ctrl) = self.local_control_rx.recv() => {
549 self.handle_local_control(ctrl).await;
550 }
551 }
552 }
553
554 for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
555 handle.abort();
556 }
557 self.shared.pending_responses.lock().clear();
558
559 self.shared.channel_senders.lock().clear();
562 self.shared.channel_buffers.lock().clear();
563 self.shared.channel_credits.lock().clear();
564 }
565
566 async fn handle_local_control(&mut self, control: DriverLocalControl) {
567 match control {
568 DriverLocalControl::CloseChannel { channel_id } => {
569 let _ = self
570 .sender
571 .send(ConnectionMessage::Channel(ChannelMessage {
572 id: channel_id,
573 body: ChannelBody::Close(ChannelClose {
574 metadata: Default::default(),
575 }),
576 }))
577 .await;
578 }
579 }
580 }
581
582 fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
583 let is_request = matches!(&*msg, ConnectionMessage::Request(_));
584 if is_request {
585 let msg = msg.map(|m| match m {
586 ConnectionMessage::Request(r) => r,
587 _ => unreachable!(),
588 });
589 self.handle_request(msg);
590 } else {
591 let msg = msg.map(|m| match m {
592 ConnectionMessage::Channel(c) => c,
593 _ => unreachable!(),
594 });
595 self.handle_channel(msg);
596 }
597 }
598
599 fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
600 let req_id = msg.id;
601 let is_call = matches!(&msg.body, RequestBody::Call(_));
602 let is_response = matches!(&msg.body, RequestBody::Response(_));
603 let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
604
605 if is_call {
606 let reply = DriverReplySink {
609 sender: Some(self.sender.clone()),
610 request_id: req_id,
611 binder: self.internal_binder(),
612 };
613 let call = msg.map(|m| match m.body {
614 RequestBody::Call(c) => c,
615 _ => unreachable!(),
616 });
617 let handler = Arc::clone(&self.handler);
618 let join_handle = moire::task::spawn(
619 async move {
620 handler.handle(call, reply).await;
621 }
622 .named("handler"),
623 );
624 self.in_flight_handlers.insert(req_id, join_handle);
625 } else if is_response {
626 if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
628 let _: Result<(), _> = tx.send(msg);
629 }
630 } else if is_cancel {
631 if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
636 handle.abort();
637 }
638 }
641 }
642
643 fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
644 let chan_id = msg.id;
645
646 let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
649
650 match &msg.body {
651 ChannelBody::Item(_item) => {
653 if let Some(tx) = &sender {
654 let item = msg.map(|m| match m.body {
655 ChannelBody::Item(item) => item,
656 _ => unreachable!(),
657 });
658 let _ = tx.try_send(IncomingChannelMessage::Item(item));
660 } else {
661 let item = msg.map(|m| match m.body {
663 ChannelBody::Item(item) => item,
664 _ => unreachable!(),
665 });
666 self.shared
667 .channel_buffers
668 .lock()
669 .entry(chan_id)
670 .or_default()
671 .push(IncomingChannelMessage::Item(item));
672 }
673 }
674 ChannelBody::Close(_close) => {
676 if let Some(tx) = &sender {
677 let close = msg.map(|m| match m.body {
678 ChannelBody::Close(close) => close,
679 _ => unreachable!(),
680 });
681 let _ = tx.try_send(IncomingChannelMessage::Close(close));
682 } else {
683 let close = msg.map(|m| match m.body {
685 ChannelBody::Close(close) => close,
686 _ => unreachable!(),
687 });
688 self.shared
689 .channel_buffers
690 .lock()
691 .entry(chan_id)
692 .or_default()
693 .push(IncomingChannelMessage::Close(close));
694 }
695 self.shared.channel_senders.lock().remove(&chan_id);
696 self.shared.channel_credits.lock().remove(&chan_id);
697 }
698 ChannelBody::Reset(_reset) => {
700 if let Some(tx) = &sender {
701 let reset = msg.map(|m| match m.body {
702 ChannelBody::Reset(reset) => reset,
703 _ => unreachable!(),
704 });
705 let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
706 } else {
707 let reset = msg.map(|m| match m.body {
709 ChannelBody::Reset(reset) => reset,
710 _ => unreachable!(),
711 });
712 self.shared
713 .channel_buffers
714 .lock()
715 .entry(chan_id)
716 .or_default()
717 .push(IncomingChannelMessage::Reset(reset));
718 }
719 self.shared.channel_senders.lock().remove(&chan_id);
720 self.shared.channel_credits.lock().remove(&chan_id);
721 }
722 ChannelBody::GrantCredit(grant) => {
725 if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
726 semaphore.add_permits(grant.additional as usize);
727 }
728 }
729 }
730 }
731}