1use anyhow::Result;
2use futures_util::FutureExt;
3use parking_lot::Mutex as SyncMutex;
4use scc::{hash_map::Entry as SccEntry, HashMap as SccHashMap};
5use serde_json::Value;
6use std::fmt::Debug;
7use std::ops::Deref;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::{Arc, Weak};
10use std::time::Duration;
11use tokio::sync::{broadcast, oneshot, watch, Mutex};
12
13use crate::{
14 backoff::Backoff,
15 drivers::*,
16 protocol::{query::ActorQuery, *},
17 remote_manager::RemoteManager,
18 EncodingKind, TransportKind,
19};
20use tracing::debug;
21
/// Outcome delivered to a pending `action` call: the server's action response, or the
/// error it reported for that action.
type RpcResponse = Result<to_client::ActionResponse, to_client::Error>;
/// Callback invoked with each event delivered to a subscription.
type EventCallback = dyn Fn(Event) + Send + Sync;
/// Argument-less callback, used for `on_open` / `on_close` notifications.
type VoidCallback = dyn Fn() + Send + Sync;
/// Callback receiving an error message, used for `on_error` notifications.
type ErrorCallback = dyn Fn(&str) + Send + Sync;
/// Callback receiving the new status, used for `on_status_change` notifications.
type StatusCallback = dyn Fn(ConnectionStatus) + Send + Sync;
27
/// An event broadcast by the actor and delivered to subscribers.
#[derive(Debug, Clone)]
pub struct Event {
    /// Event name as sent by the server.
    pub name: String,
    /// Arguments decoded from `raw_args`. A non-array payload becomes a one-element
    /// list, and undecodable bytes become an empty list.
    pub args: Vec<Value>,
    /// The argument bytes exactly as received (CBOR-encoded).
    pub raw_args: Vec<u8>,
}
34
/// One registered listener for a single event name.
struct EventSubscription {
    /// Identifier used to remove this listener when its handle unsubscribes.
    id: u64,
    /// Invoked for every matching event.
    callback: Box<EventCallback>,
}
39
/// Handle to an event subscription, returned by `on_event`, `on_event_raw` and
/// `once_event`.
///
/// Clones share the same state, so unsubscribing through any clone deactivates all of
/// them.
#[derive(Clone)]
pub struct SubscriptionHandle {
    inner: Arc<SubscriptionHandleInner>,
}
44
/// Shared state behind a `SubscriptionHandle`.
struct SubscriptionHandleInner {
    /// Weak, so an outstanding handle does not keep the connection alive.
    conn: Weak<ActorConnectionInner>,
    /// Event name the subscription was registered under.
    event_name: String,
    /// Id allocated from the connection's subscription counter.
    id: u64,
    /// Cleared by the first `unsubscribe`; later calls see `false` and do nothing.
    active: AtomicBool,
}
51
52impl SubscriptionHandle {
53 fn new(conn: &Arc<ActorConnectionInner>, event_name: String, id: u64) -> Self {
54 Self {
55 inner: Arc::new(SubscriptionHandleInner {
56 conn: Arc::downgrade(conn),
57 event_name,
58 id,
59 active: AtomicBool::new(true),
60 }),
61 }
62 }
63
64 pub async fn unsubscribe(&self) {
65 if !self.inner.active.swap(false, Ordering::SeqCst) {
66 return;
67 }
68
69 let Some(conn) = self.inner.conn.upgrade() else {
70 return;
71 };
72
73 conn.remove_event_subscription(&self.inner.event_name, self.inner.id)
74 .await;
75 }
76}
77
/// Options controlling what `send_msg` does with a message it cannot deliver right now.
///
/// The default (`ephemeral: false`) queues such messages for the next session.
#[derive(Default)]
struct SendMsgOpts {
    /// When `true`, an undeliverable message is dropped instead of queued for delivery
    /// once the next connection opens.
    ephemeral: bool,
}
87
/// A watch channel's sender and receiver stored together, so the owner can both
/// publish and read the current value.
type WatchPair = (watch::Sender<bool>, watch::Receiver<bool>);
93
/// Lifecycle state of an `ActorConnection`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Initial state: no connection attempt has started yet.
    Idle,
    /// A connection attempt is in progress.
    Connecting,
    /// The server acknowledged the connection (its `Init` message was received).
    Connected,
    /// The last attempt ended or failed, or `disconnect` was called.
    Disconnected,
}
101
/// Shared, reference-counted handle to an actor connection.
pub type ActorConnection = Arc<ActorConnectionInner>;
103
/// Result of a single `try_connect` attempt.
struct ConnectionAttempt {
    /// Whether the server sent `Init` during the attempt. `start_connection` uses this
    /// to decide between reconnecting immediately and backing off.
    did_open: bool,
    /// Why the driver task stopped; recorded but not read anywhere in this module.
    _task_end_reason: DriverStopReason,
}
108
/// State shared by every clone of an `ActorConnection`.
///
/// Mutable state sits behind async mutexes, concurrent maps or watch channels.
/// Methods therefore take `self: &Arc<Self>`.
pub struct ActorConnectionInner {
    // Connection parameters, passed to `connect_driver` on every attempt.
    remote_manager: RemoteManager,
    transport_kind: TransportKind,
    encoding_kind: EncodingKind,
    query: ActorQuery,
    parameters: Option<Value>,

    /// Driver for the current attempt; `None` outside a connected attempt.
    driver: Mutex<Option<DriverHandle>>,
    /// Non-ephemeral messages that could not be sent; flushed after the next `Init`.
    msg_queue: Mutex<Vec<Arc<to_server::ToServer>>>,

    /// Source of action request ids.
    rpc_counter: AtomicU64,
    /// Source of event subscription ids.
    event_subscription_counter: AtomicU64,
    /// Pending `action` calls keyed by request id, resolved by the server's reply.
    in_flight_rpcs: SccHashMap<u64, oneshot::Sender<RpcResponse>>,

    /// Event listeners keyed by event name.
    event_subscriptions: SccHashMap<String, Vec<Arc<EventSubscription>>>,
    /// Callbacks registered via `on_open`, `on_close`, `on_error`, `on_status_change`.
    on_open_callbacks: Mutex<Vec<Box<VoidCallback>>>,
    on_close_callbacks: Mutex<Vec<Box<VoidCallback>>>,
    on_error_callbacks: Mutex<Vec<Box<ErrorCallback>>>,
    on_status_change_callbacks: Mutex<Vec<Box<StatusCallback>>>,

    /// Session identity recorded from the last `Init`. The connection id and token are
    /// passed back to `connect_driver` on the next attempt.
    actor_id: Mutex<Option<String>>,
    connection_id: Mutex<Option<String>>,
    connection_token: Mutex<Option<String>>,

    /// Set to `true` by `disconnect`; nothing in this module resets it.
    dc_watch: WatchPair,
    /// Current `ConnectionStatus`, observable through `status_receiver`.
    status_watch: (
        watch::Sender<ConnectionStatus>,
        watch::Receiver<ConnectionStatus>,
    ),
    /// Completion signal of the task spawned by `start_connection`; awaited by
    /// `disconnect`.
    disconnection_rx: Mutex<Option<oneshot::Receiver<()>>>,
}
141
142impl ActorConnectionInner {
143 pub(crate) fn new(
144 remote_manager: RemoteManager,
145 query: ActorQuery,
146 transport_kind: TransportKind,
147 encoding_kind: EncodingKind,
148 parameters: Option<Value>,
149 ) -> ActorConnection {
150 Arc::new(Self {
151 remote_manager,
152 transport_kind,
153 encoding_kind,
154 query,
155 parameters,
156 driver: Mutex::new(None),
157 msg_queue: Mutex::new(Vec::new()),
158 rpc_counter: AtomicU64::new(0),
159 event_subscription_counter: AtomicU64::new(0),
160 in_flight_rpcs: SccHashMap::new(),
161 event_subscriptions: SccHashMap::new(),
162 on_open_callbacks: Mutex::new(Vec::new()),
163 on_close_callbacks: Mutex::new(Vec::new()),
164 on_error_callbacks: Mutex::new(Vec::new()),
165 on_status_change_callbacks: Mutex::new(Vec::new()),
166 actor_id: Mutex::new(None),
167 connection_id: Mutex::new(None),
168 connection_token: Mutex::new(None),
169 dc_watch: watch::channel(false),
170 status_watch: watch::channel(ConnectionStatus::Idle),
171 disconnection_rx: Mutex::new(None),
172 })
173 }
174
175 fn is_disconnecting(self: &Arc<Self>) -> bool {
176 *self.dc_watch.1.borrow() == true
177 }
178
179 async fn try_connect(self: &Arc<Self>) -> ConnectionAttempt {
180 self.set_status(ConnectionStatus::Connecting).await;
181
182 let conn_id = self.connection_id.lock().await.clone();
184 let conn_token = self.connection_token.lock().await.clone();
185
186 let (driver, mut recver, task) = match connect_driver(
187 self.transport_kind,
188 DriverConnectArgs {
189 remote_manager: self.remote_manager.clone(),
190 query: self.query.clone(),
191 encoding_kind: self.encoding_kind,
192 parameters: self.parameters.clone(),
193 conn_id,
194 conn_token,
195 },
196 )
197 .await
198 {
199 Ok(value) => value,
200 Err(error) => {
201 let message = error.to_string();
202 self.emit_error(&message).await;
203 self.set_status(ConnectionStatus::Disconnected).await;
204 return ConnectionAttempt {
205 did_open: false,
206 _task_end_reason: DriverStopReason::TaskError,
207 };
208 }
209 };
210
211 {
212 let mut my_driver = self.driver.lock().await;
213 *my_driver = Some(driver);
214 }
215
216 let mut task_end_reason = task.map(|res| match res {
217 Ok(a) => a,
218 Err(task_err) => {
219 if task_err.is_cancelled() {
220 debug!("Connection task was cancelled");
221 DriverStopReason::UserAborted
222 } else {
223 DriverStopReason::TaskError
224 }
225 }
226 });
227
228 let mut did_connection_open = false;
229
230 let task_end_reason = loop {
232 tokio::select! {
233 reason = &mut task_end_reason => {
234 debug!("Connection closed: {:?}", reason);
235
236 break reason;
237 },
238 msg = recver.recv() => {
239 let Some(msg) = msg else {
241 continue;
243 };
244
245 if let to_client::ToClientBody::Init(_) = &msg.body {
246 did_connection_open = true;
247 }
248
249 self.on_message(msg).await;
250 }
251 }
252 };
253
254 'destroy_driver: {
255 debug!("Destroying driver");
256 let mut d_guard = self.driver.lock().await;
257 let Some(d) = d_guard.take() else {
258 break 'destroy_driver;
261 };
262
263 d.disconnect();
264 }
265
266 self.set_status(ConnectionStatus::Disconnected).await;
267 self.emit_close().await;
268
269 ConnectionAttempt {
270 did_open: did_connection_open,
271 _task_end_reason: task_end_reason,
272 }
273 }
274
275 async fn handle_open(self: &Arc<Self>, init: &to_client::Init) {
276 debug!("Connected to server: {:?}", init);
277
278 *self.actor_id.lock().await = Some(init.actor_id.clone());
280 *self.connection_id.lock().await = Some(init.connection_id.clone());
281 *self.connection_token.lock().await = init.connection_token.clone();
282 self.set_status(ConnectionStatus::Connected).await;
283 self.emit_open().await;
284
285 let mut event_names = Vec::new();
286 self.event_subscriptions
287 .iter_async(|event_name, _| {
288 event_names.push(event_name.clone());
289 true
290 })
291 .await;
292 for event_name in event_names {
293 self.send_subscription(event_name.clone(), true).await;
294 }
295
296 for msg in self.msg_queue.lock().await.drain(..) {
298 self.send_msg(msg, SendMsgOpts::default()).await;
301 }
302 }
303
304 async fn on_message(self: &Arc<Self>, msg: Arc<to_client::ToClient>) {
305 let body = &msg.body;
306
307 match body {
308 to_client::ToClientBody::Init(init) => {
309 self.handle_open(init).await;
310 }
311 to_client::ToClientBody::ActionResponse(ar) => {
312 let id = ar.id;
313 let Some((_, tx)) = self.in_flight_rpcs.remove_async(&id).await else {
314 debug!("Unexpected response: rpc id not found");
315 return;
316 };
317 if let Err(e) = tx.send(Ok(ar.clone())) {
318 debug!("{:?}", e);
319 return;
320 }
321 }
322 to_client::ToClientBody::Event(ev) => {
323 let args = decode_event_args(&ev.args);
324
325 let callbacks = {
326 self.event_subscriptions
327 .read_async(&ev.name, |_, listeners| listeners.clone())
328 .await
329 .unwrap_or_default()
330 };
331 let event = Event {
332 name: ev.name.clone(),
333 args,
334 raw_args: ev.args.clone(),
335 };
336 for subscription in callbacks {
337 (subscription.callback)(event.clone());
338 }
339 }
340 to_client::ToClientBody::Error(e) => {
341 if let Some(action_id) = e.action_id {
342 let Some((_, tx)) = self.in_flight_rpcs.remove_async(&action_id).await else {
343 debug!("Unexpected response: rpc id not found");
344 return;
345 };
346 if let Err(e) = tx.send(Err(e.clone())) {
347 debug!("{:?}", e);
348 return;
349 }
350
351 return;
352 }
353
354 debug!("Connection error: {} - {}", e.code, e.message);
355 self.emit_error(&e.message).await;
356 }
357 }
358 }
359
360 async fn set_status(self: &Arc<Self>, status: ConnectionStatus) {
361 if *self.status_watch.1.borrow() == status {
362 return;
363 }
364 self.status_watch.0.send(status).ok();
365 for callback in self.on_status_change_callbacks.lock().await.iter() {
366 callback(status);
367 }
368 }
369
370 async fn emit_open(self: &Arc<Self>) {
371 for callback in self.on_open_callbacks.lock().await.iter() {
372 callback();
373 }
374 }
375
376 async fn emit_close(self: &Arc<Self>) {
377 for callback in self.on_close_callbacks.lock().await.iter() {
378 callback();
379 }
380 }
381
382 async fn emit_error(self: &Arc<Self>, message: &str) {
383 for callback in self.on_error_callbacks.lock().await.iter() {
384 callback(message);
385 }
386 }
387
388 async fn send_msg(self: &Arc<Self>, msg: Arc<to_server::ToServer>, opts: SendMsgOpts) {
389 let guard = self.driver.lock().await;
390
391 'send_immediately: {
392 let Some(driver) = guard.deref() else {
393 break 'send_immediately;
394 };
395
396 let Ok(_) = driver.send(msg.clone()).await else {
397 break 'send_immediately;
398 };
399
400 return;
401 }
402
403 if opts.ephemeral == false {
405 self.msg_queue.lock().await.push(msg.clone());
406 }
407
408 return;
409 }
410
411 pub async fn action(self: &Arc<Self>, method: &str, params: Vec<Value>) -> Result<Value> {
412 let id: u64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst);
413
414 let (tx, rx) = oneshot::channel();
415 if self.in_flight_rpcs.insert_async(id, tx).await.is_err() {
416 return Err(anyhow::anyhow!("duplicate rpc id"));
417 }
418
419 let args_cbor = serde_cbor::to_vec(¶ms)?;
421
422 self.send_msg(
423 Arc::new(to_server::ToServer {
424 body: to_server::ToServerBody::ActionRequest(to_server::ActionRequest {
425 id,
426 name: method.to_string(),
427 args: args_cbor,
428 }),
429 }),
430 SendMsgOpts::default(),
431 )
432 .await;
433
434 let Ok(res) = rx.await else {
435 return Err(anyhow::anyhow!("Socket closed during rpc"));
436 };
437
438 match res {
439 Ok(ok) => {
440 let output: Value = serde_cbor::from_slice(&ok.output)?;
442 Ok(output)
443 }
444 Err(err) => {
445 let metadata = if let Some(md) = &err.metadata {
446 match serde_cbor::from_slice::<Value>(md) {
447 Ok(v) => v,
448 Err(_) => Value::Null,
449 }
450 } else {
451 Value::Null
452 };
453
454 Err(anyhow::anyhow!(
455 "RPC Error({}/{}): {}, {:#}",
456 err.group,
457 err.code,
458 err.message,
459 metadata
460 ))
461 }
462 }
463 }
464
465 async fn send_subscription(self: &Arc<Self>, event_name: String, subscribe: bool) {
466 self.send_msg(
467 Arc::new(to_server::ToServer {
468 body: to_server::ToServerBody::SubscriptionRequest(
469 to_server::SubscriptionRequest {
470 event_name,
471 subscribe,
472 },
473 ),
474 }),
475 SendMsgOpts { ephemeral: true },
476 )
477 .await;
478 }
479
480 async fn add_event_subscription(
481 self: &Arc<Self>,
482 event_name: String,
483 callback: Box<EventCallback>,
484 ) -> SubscriptionHandle {
485 let id = self
486 .event_subscription_counter
487 .fetch_add(1, Ordering::SeqCst);
488 let handle = SubscriptionHandle::new(self, event_name.clone(), id);
489
490 self.insert_event_subscription(event_name, id, callback)
491 .await;
492
493 handle
494 }
495
496 async fn insert_event_subscription(
497 self: &Arc<Self>,
498 event_name: String,
499 id: u64,
500 callback: Box<EventCallback>,
501 ) {
502 let is_new_subscription = {
503 let mut listeners = self
504 .event_subscriptions
505 .entry_async(event_name.clone())
506 .await
507 .or_insert_with(Vec::new);
508 let is_new_subscription = listeners.is_empty();
509
510 listeners.push(Arc::new(EventSubscription { id, callback }));
511
512 is_new_subscription
513 };
514
515 if is_new_subscription {
516 self.send_subscription(event_name, true).await;
517 }
518 }
519
520 async fn remove_event_subscription(self: &Arc<Self>, event_name: &str, id: u64) {
521 let should_unsubscribe = {
522 match self
523 .event_subscriptions
524 .entry_async(event_name.to_string())
525 .await
526 {
527 SccEntry::Occupied(mut entry) => {
528 entry.retain(|subscription| subscription.id != id);
529 if entry.is_empty() {
530 let _ = entry.remove_entry();
531 true
532 } else {
533 false
534 }
535 }
536 SccEntry::Vacant(entry) => {
537 drop(entry);
538 false
539 }
540 }
541 };
542
543 if should_unsubscribe {
544 self.send_subscription(event_name.to_string(), false).await;
545 }
546 }
547
548 pub async fn on_event<F>(self: &Arc<Self>, event_name: &str, callback: F) -> SubscriptionHandle
549 where
550 F: Fn(&Vec<Value>) + Send + Sync + 'static,
551 {
552 self.add_event_subscription(
553 event_name.to_string(),
554 Box::new(move |event| callback(&event.args)),
555 )
556 .await
557 }
558
559 pub async fn on_event_raw<F>(
560 self: &Arc<Self>,
561 event_name: &str,
562 callback: F,
563 ) -> SubscriptionHandle
564 where
565 F: Fn(Event) + Send + Sync + 'static,
566 {
567 self.add_event_subscription(event_name.to_string(), Box::new(callback))
568 .await
569 }
570
571 pub async fn once_event<F>(
572 self: &Arc<Self>,
573 event_name: &str,
574 callback: F,
575 ) -> SubscriptionHandle
576 where
577 F: FnOnce(Event) + Send + 'static,
578 {
579 let id = self
580 .event_subscription_counter
581 .fetch_add(1, Ordering::SeqCst);
582 let handle = SubscriptionHandle::new(self, event_name.to_string(), id);
583 let callback = Arc::new(SyncMutex::new(Some(callback)));
585 let unsubscribe_handle = handle.clone();
586 let fired = Arc::new(AtomicBool::new(false));
587 self.insert_event_subscription(
588 event_name.to_string(),
589 id,
590 Box::new(move |event| {
591 if fired.swap(true, Ordering::SeqCst) {
592 return;
593 }
594
595 let unsubscribe_handle = unsubscribe_handle.clone();
596 tokio::spawn(async move {
597 unsubscribe_handle.unsubscribe().await;
598 });
599
600 let Some(callback) = callback.lock().take() else {
601 return;
602 };
603 callback(event);
604 }),
605 )
606 .await;
607
608 handle
609 }
610
611 pub async fn on_open<F>(self: &Arc<Self>, callback: F)
612 where
613 F: Fn() + Send + Sync + 'static,
614 {
615 self.on_open_callbacks.lock().await.push(Box::new(callback));
616 }
617
618 pub async fn on_close<F>(self: &Arc<Self>, callback: F)
619 where
620 F: Fn() + Send + Sync + 'static,
621 {
622 self.on_close_callbacks
623 .lock()
624 .await
625 .push(Box::new(callback));
626 }
627
628 pub async fn on_error<F>(self: &Arc<Self>, callback: F)
629 where
630 F: Fn(&str) + Send + Sync + 'static,
631 {
632 self.on_error_callbacks
633 .lock()
634 .await
635 .push(Box::new(callback));
636 }
637
638 pub async fn on_status_change<F>(self: &Arc<Self>, callback: F)
639 where
640 F: Fn(ConnectionStatus) + Send + Sync + 'static,
641 {
642 self.on_status_change_callbacks
643 .lock()
644 .await
645 .push(Box::new(callback));
646 }
647
648 pub fn conn_status(self: &Arc<Self>) -> ConnectionStatus {
649 *self.status_watch.1.borrow()
650 }
651
652 pub fn status_receiver(self: &Arc<Self>) -> watch::Receiver<ConnectionStatus> {
653 self.status_watch.1.clone()
654 }
655
656 pub async fn disconnect(self: &Arc<Self>) {
657 if self.is_disconnecting() {
658 return;
660 }
661
662 debug!("Disconnecting from actor conn");
663
664 self.dc_watch.0.send(true).ok();
665 self.set_status(ConnectionStatus::Disconnected).await;
666
667 if let Some(d) = self.driver.lock().await.deref() {
668 d.disconnect();
669 }
670 self.in_flight_rpcs.clear_async().await;
671 self.event_subscriptions.clear_async().await;
672 let Some(rx) = self.disconnection_rx.lock().await.take() else {
673 return;
674 };
675
676 rx.await.ok();
677 }
678
679 pub async fn dispose(self: &Arc<Self>) {
680 self.disconnect().await
681 }
682}
683
684pub fn start_connection(
685 conn: &Arc<ActorConnectionInner>,
686 mut shutdown_rx: broadcast::Receiver<()>,
687) {
688 let (tx, rx) = oneshot::channel();
689
690 let conn = conn.clone();
691
692 tokio::spawn(async move {
693 {
694 let mut stop_rx = conn.disconnection_rx.lock().await;
695 if stop_rx.is_some() {
696 return;
699 }
700
701 *stop_rx = Some(rx);
702 }
703
704 'keepalive: loop {
705 debug!("Attempting to reconnect");
706 let mut backoff = Backoff::new(Duration::from_secs(1), Duration::from_secs(30));
707 let mut retry_attempt = 0;
708 'retry: loop {
709 retry_attempt += 1;
710 debug!(
711 "Establish conn: attempt={}, timeout={:?}",
712 retry_attempt,
713 backoff.delay()
714 );
715 let attempt = conn.try_connect().await;
716
717 if conn.is_disconnecting() {
718 break 'keepalive;
719 }
720
721 if attempt.did_open {
722 break 'retry;
723 }
724
725 let mut dc_rx = conn.dc_watch.0.subscribe();
726
727 tokio::select! {
728 _ = backoff.tick() => {},
729 _ = dc_rx.wait_for(|x| *x == true) => {
730 break 'keepalive;
731 }
732 _ = shutdown_rx.recv() => {
733 debug!("Received shutdown signal, stopping connection attempts");
734 break 'keepalive;
735 }
736 }
737 }
738 }
739
740 tx.send(()).ok();
741 conn.disconnection_rx.lock().await.take();
742 });
743}
744
745impl Debug for ActorConnectionInner {
746 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
747 f.debug_struct("ActorConnection")
748 .field("transport_kind", &self.transport_kind)
749 .field("encoding_kind", &self.encoding_kind)
750 .finish()
751 }
752}
753
754fn decode_event_args(raw_args: &[u8]) -> Vec<Value> {
755 match serde_cbor::from_slice::<Vec<Value>>(raw_args) {
756 Ok(args) => args,
757 Err(vector_error) => match serde_cbor::from_slice::<Value>(raw_args) {
758 Ok(Value::Array(args)) => args,
759 Ok(value) => vec![value],
760 Err(value_error) => {
761 debug!(?vector_error, ?value_error, "failed to decode event args");
762 Vec::new()
763 }
764 },
765 }
766}