// rivetkit_client/connection.rs

1use anyhow::Result;
2use futures_util::FutureExt;
3use parking_lot::Mutex as SyncMutex;
4use scc::{hash_map::Entry as SccEntry, HashMap as SccHashMap};
5use serde_json::Value;
6use std::fmt::Debug;
7use std::ops::Deref;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::{Arc, Weak};
10use std::time::Duration;
11use tokio::sync::{broadcast, oneshot, watch, Mutex};
12
13use crate::{
14	backoff::Backoff,
15	drivers::*,
16	protocol::{query::ActorQuery, *},
17	remote_manager::RemoteManager,
18	EncodingKind, TransportKind,
19};
20use tracing::debug;
21
/// Outcome delivered to a pending action: the server's response or its error message.
type RpcResponse = Result<to_client::ActionResponse, to_client::Error>;
/// Listener invoked with every matching [`Event`].
type EventCallback = dyn Fn(Event) + Send + Sync;
/// Argument-less listener (used for open/close notifications).
type VoidCallback = dyn Fn() + Send + Sync;
/// Listener invoked with an error message.
type ErrorCallback = dyn Fn(&str) + Send + Sync;
/// Listener invoked with the new [`ConnectionStatus`] after a status change.
type StatusCallback = dyn Fn(ConnectionStatus) + Send + Sync;
27
/// An event received from the server, as handed to event listeners.
#[derive(Debug, Clone)]
pub struct Event {
	/// Name of the event.
	pub name: String,
	/// Arguments decoded from `raw_args`; empty if decoding failed.
	pub args: Vec<Value>,
	/// The argument payload exactly as received (CBOR-encoded bytes).
	pub raw_args: Vec<u8>,
}
34
/// A single registered listener for one event name.
struct EventSubscription {
	/// Identifier used to find and remove this listener again.
	id: u64,
	/// Function invoked for every matching event.
	callback: Box<EventCallback>,
}
39
/// Cloneable handle to an event subscription; see [`SubscriptionHandle::unsubscribe`].
///
/// All clones share the same underlying state.
#[derive(Clone)]
pub struct SubscriptionHandle {
	inner: Arc<SubscriptionHandleInner>,
}
44
/// State shared by every clone of a [`SubscriptionHandle`].
struct SubscriptionHandleInner {
	/// Weak link back to the connection, so a handle does not keep it alive.
	conn: Weak<ActorConnectionInner>,
	/// Event name the listener was registered under.
	event_name: String,
	/// Identifier of the listener within the connection.
	id: u64,
	/// `true` until `unsubscribe` has been called once.
	active: AtomicBool,
}
51
52impl SubscriptionHandle {
53	fn new(conn: &Arc<ActorConnectionInner>, event_name: String, id: u64) -> Self {
54		Self {
55			inner: Arc::new(SubscriptionHandleInner {
56				conn: Arc::downgrade(conn),
57				event_name,
58				id,
59				active: AtomicBool::new(true),
60			}),
61		}
62	}
63
64	pub async fn unsubscribe(&self) {
65		if !self.inner.active.swap(false, Ordering::SeqCst) {
66			return;
67		}
68
69		let Some(conn) = self.inner.conn.upgrade() else {
70			return;
71		};
72
73		conn.remove_event_subscription(&self.inner.event_name, self.inner.id)
74			.await;
75	}
76}
77
/// Options controlling how an outgoing message is handled.
///
/// The default is a non-ephemeral message (`ephemeral == false`).
#[derive(Default)]
struct SendMsgOpts {
	/// When `true`, the message is dropped instead of queued if it cannot be
	/// sent immediately.
	ephemeral: bool,
}
87
/// Both halves of a `watch` channel carrying a `bool`: `.0` is the sender,
/// `.1` the receiver.
type WatchPair = (watch::Sender<bool>, watch::Receiver<bool>);
93
/// Lifecycle state of an actor connection, as reported to status listeners.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
	/// Initial state of a freshly created connection.
	Idle,
	/// A connection attempt has started.
	Connecting,
	/// The server's init message has been received.
	Connected,
	/// A connection attempt failed or ended, or `disconnect` was called.
	Disconnected,
}
101
/// Shared, reference-counted handle to an [`ActorConnectionInner`].
pub type ActorConnection = Arc<ActorConnectionInner>;
103
/// Result of one `try_connect` call.
struct ConnectionAttempt {
	/// Whether the server's init message was received during this attempt.
	did_open: bool,
	/// Why the attempt ended; currently stored but not read (hence the `_`).
	_task_end_reason: DriverStopReason,
}
108
/// State behind an [`ActorConnection`]: connection settings, the active
/// driver, pending RPCs, event listeners and lifecycle callbacks.
pub struct ActorConnectionInner {
	// Settings handed to `connect_driver` on every connection attempt.
	remote_manager: RemoteManager,
	transport_kind: TransportKind,
	encoding_kind: EncodingKind,
	query: ActorQuery,
	parameters: Option<Value>,

	// Handle of the currently installed driver; `None` when there is none.
	driver: Mutex<Option<DriverHandle>>,
	// Non-ephemeral messages that could not be sent; flushed after the next open.
	msg_queue: Mutex<Vec<Arc<to_server::ToServer>>>,

	// Source of action request ids.
	rpc_counter: AtomicU64,
	// Source of event listener ids.
	event_subscription_counter: AtomicU64,
	// Action id -> channel on which the caller awaits the response.
	in_flight_rpcs: SccHashMap<u64, oneshot::Sender<RpcResponse>>,

	// Event name -> listeners registered for it.
	event_subscriptions: SccHashMap<String, Vec<Arc<EventSubscription>>>,
	on_open_callbacks: Mutex<Vec<Box<VoidCallback>>>,
	on_close_callbacks: Mutex<Vec<Box<VoidCallback>>>,
	on_error_callbacks: Mutex<Vec<Box<ErrorCallback>>>,
	on_status_change_callbacks: Mutex<Vec<Box<StatusCallback>>>,

	// Connection info for reconnection (filled in from the server's init message)
	actor_id: Mutex<Option<String>>,
	connection_id: Mutex<Option<String>>,
	connection_token: Mutex<Option<String>>,

	// Set to `true` once `disconnect` has been requested.
	dc_watch: WatchPair,
	// Current status plus a receiver that can be cloned for observers.
	status_watch: (
		watch::Sender<ConnectionStatus>,
		watch::Receiver<ConnectionStatus>,
	),
	// Installed by `start_connection`; signalled when its background task exits.
	disconnection_rx: Mutex<Option<oneshot::Receiver<()>>>,
}
141
142impl ActorConnectionInner {
143	pub(crate) fn new(
144		remote_manager: RemoteManager,
145		query: ActorQuery,
146		transport_kind: TransportKind,
147		encoding_kind: EncodingKind,
148		parameters: Option<Value>,
149	) -> ActorConnection {
150		Arc::new(Self {
151			remote_manager,
152			transport_kind,
153			encoding_kind,
154			query,
155			parameters,
156			driver: Mutex::new(None),
157			msg_queue: Mutex::new(Vec::new()),
158			rpc_counter: AtomicU64::new(0),
159			event_subscription_counter: AtomicU64::new(0),
160			in_flight_rpcs: SccHashMap::new(),
161			event_subscriptions: SccHashMap::new(),
162			on_open_callbacks: Mutex::new(Vec::new()),
163			on_close_callbacks: Mutex::new(Vec::new()),
164			on_error_callbacks: Mutex::new(Vec::new()),
165			on_status_change_callbacks: Mutex::new(Vec::new()),
166			actor_id: Mutex::new(None),
167			connection_id: Mutex::new(None),
168			connection_token: Mutex::new(None),
169			dc_watch: watch::channel(false),
170			status_watch: watch::channel(ConnectionStatus::Idle),
171			disconnection_rx: Mutex::new(None),
172		})
173	}
174
175	fn is_disconnecting(self: &Arc<Self>) -> bool {
176		*self.dc_watch.1.borrow() == true
177	}
178
179	async fn try_connect(self: &Arc<Self>) -> ConnectionAttempt {
180		self.set_status(ConnectionStatus::Connecting).await;
181
182		// Get connection info for reconnection
183		let conn_id = self.connection_id.lock().await.clone();
184		let conn_token = self.connection_token.lock().await.clone();
185
186		let (driver, mut recver, task) = match connect_driver(
187			self.transport_kind,
188			DriverConnectArgs {
189				remote_manager: self.remote_manager.clone(),
190				query: self.query.clone(),
191				encoding_kind: self.encoding_kind,
192				parameters: self.parameters.clone(),
193				conn_id,
194				conn_token,
195			},
196		)
197		.await
198		{
199			Ok(value) => value,
200			Err(error) => {
201				let message = error.to_string();
202				self.emit_error(&message).await;
203				self.set_status(ConnectionStatus::Disconnected).await;
204				return ConnectionAttempt {
205					did_open: false,
206					_task_end_reason: DriverStopReason::TaskError,
207				};
208			}
209		};
210
211		{
212			let mut my_driver = self.driver.lock().await;
213			*my_driver = Some(driver);
214		}
215
216		let mut task_end_reason = task.map(|res| match res {
217			Ok(a) => a,
218			Err(task_err) => {
219				if task_err.is_cancelled() {
220					debug!("Connection task was cancelled");
221					DriverStopReason::UserAborted
222				} else {
223					DriverStopReason::TaskError
224				}
225			}
226		});
227
228		let mut did_connection_open = false;
229
230		// spawn listener for rpcs
231		let task_end_reason = loop {
232			tokio::select! {
233				reason = &mut task_end_reason => {
234					debug!("Connection closed: {:?}", reason);
235
236					break reason;
237				},
238				msg = recver.recv() => {
239					// If the sender is dropped, break the loop
240					let Some(msg) = msg else {
241						// break DriverStopReason::ServerDisconnect;
242						continue;
243					};
244
245					if let to_client::ToClientBody::Init(_) = &msg.body {
246						did_connection_open = true;
247					}
248
249					self.on_message(msg).await;
250				}
251			}
252		};
253
254		'destroy_driver: {
255			debug!("Destroying driver");
256			let mut d_guard = self.driver.lock().await;
257			let Some(d) = d_guard.take() else {
258				// We destroyed the driver already,
259				// e.g. .disconnect() was called
260				break 'destroy_driver;
261			};
262
263			d.disconnect();
264		}
265
266		self.set_status(ConnectionStatus::Disconnected).await;
267		self.emit_close().await;
268
269		ConnectionAttempt {
270			did_open: did_connection_open,
271			_task_end_reason: task_end_reason,
272		}
273	}
274
275	async fn handle_open(self: &Arc<Self>, init: &to_client::Init) {
276		debug!("Connected to server: {:?}", init);
277
278		// Store connection info for reconnection
279		*self.actor_id.lock().await = Some(init.actor_id.clone());
280		*self.connection_id.lock().await = Some(init.connection_id.clone());
281		*self.connection_token.lock().await = init.connection_token.clone();
282		self.set_status(ConnectionStatus::Connected).await;
283		self.emit_open().await;
284
285		let mut event_names = Vec::new();
286		self.event_subscriptions
287			.iter_async(|event_name, _| {
288				event_names.push(event_name.clone());
289				true
290			})
291			.await;
292		for event_name in event_names {
293			self.send_subscription(event_name.clone(), true).await;
294		}
295
296		// Flush message queue
297		for msg in self.msg_queue.lock().await.drain(..) {
298			// If its in the queue, it isn't ephemeral, so we pass
299			// default SendMsgOpts
300			self.send_msg(msg, SendMsgOpts::default()).await;
301		}
302	}
303
304	async fn on_message(self: &Arc<Self>, msg: Arc<to_client::ToClient>) {
305		let body = &msg.body;
306
307		match body {
308			to_client::ToClientBody::Init(init) => {
309				self.handle_open(init).await;
310			}
311			to_client::ToClientBody::ActionResponse(ar) => {
312				let id = ar.id;
313				let Some((_, tx)) = self.in_flight_rpcs.remove_async(&id).await else {
314					debug!("Unexpected response: rpc id not found");
315					return;
316				};
317				if let Err(e) = tx.send(Ok(ar.clone())) {
318					debug!("{:?}", e);
319					return;
320				}
321			}
322			to_client::ToClientBody::Event(ev) => {
323				let args = decode_event_args(&ev.args);
324
325				let callbacks = {
326					self.event_subscriptions
327						.read_async(&ev.name, |_, listeners| listeners.clone())
328						.await
329						.unwrap_or_default()
330				};
331				let event = Event {
332					name: ev.name.clone(),
333					args,
334					raw_args: ev.args.clone(),
335				};
336				for subscription in callbacks {
337					(subscription.callback)(event.clone());
338				}
339			}
340			to_client::ToClientBody::Error(e) => {
341				if let Some(action_id) = e.action_id {
342					let Some((_, tx)) = self.in_flight_rpcs.remove_async(&action_id).await else {
343						debug!("Unexpected response: rpc id not found");
344						return;
345					};
346					if let Err(e) = tx.send(Err(e.clone())) {
347						debug!("{:?}", e);
348						return;
349					}
350
351					return;
352				}
353
354				debug!("Connection error: {} - {}", e.code, e.message);
355				self.emit_error(&e.message).await;
356			}
357		}
358	}
359
360	async fn set_status(self: &Arc<Self>, status: ConnectionStatus) {
361		if *self.status_watch.1.borrow() == status {
362			return;
363		}
364		self.status_watch.0.send(status).ok();
365		for callback in self.on_status_change_callbacks.lock().await.iter() {
366			callback(status);
367		}
368	}
369
370	async fn emit_open(self: &Arc<Self>) {
371		for callback in self.on_open_callbacks.lock().await.iter() {
372			callback();
373		}
374	}
375
376	async fn emit_close(self: &Arc<Self>) {
377		for callback in self.on_close_callbacks.lock().await.iter() {
378			callback();
379		}
380	}
381
382	async fn emit_error(self: &Arc<Self>, message: &str) {
383		for callback in self.on_error_callbacks.lock().await.iter() {
384			callback(message);
385		}
386	}
387
388	async fn send_msg(self: &Arc<Self>, msg: Arc<to_server::ToServer>, opts: SendMsgOpts) {
389		let guard = self.driver.lock().await;
390
391		'send_immediately: {
392			let Some(driver) = guard.deref() else {
393				break 'send_immediately;
394			};
395
396			let Ok(_) = driver.send(msg.clone()).await else {
397				break 'send_immediately;
398			};
399
400			return;
401		}
402
403		// Otherwise queue
404		if opts.ephemeral == false {
405			self.msg_queue.lock().await.push(msg.clone());
406		}
407
408		return;
409	}
410
411	pub async fn action(self: &Arc<Self>, method: &str, params: Vec<Value>) -> Result<Value> {
412		let id: u64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst);
413
414		let (tx, rx) = oneshot::channel();
415		if self.in_flight_rpcs.insert_async(id, tx).await.is_err() {
416			return Err(anyhow::anyhow!("duplicate rpc id"));
417		}
418
419		// Encode params as CBOR
420		let args_cbor = serde_cbor::to_vec(&params)?;
421
422		self.send_msg(
423			Arc::new(to_server::ToServer {
424				body: to_server::ToServerBody::ActionRequest(to_server::ActionRequest {
425					id,
426					name: method.to_string(),
427					args: args_cbor,
428				}),
429			}),
430			SendMsgOpts::default(),
431		)
432		.await;
433
434		let Ok(res) = rx.await else {
435			return Err(anyhow::anyhow!("Socket closed during rpc"));
436		};
437
438		match res {
439			Ok(ok) => {
440				// Decode CBOR output
441				let output: Value = serde_cbor::from_slice(&ok.output)?;
442				Ok(output)
443			}
444			Err(err) => {
445				let metadata = if let Some(md) = &err.metadata {
446					match serde_cbor::from_slice::<Value>(md) {
447						Ok(v) => v,
448						Err(_) => Value::Null,
449					}
450				} else {
451					Value::Null
452				};
453
454				Err(anyhow::anyhow!(
455					"RPC Error({}/{}): {}, {:#}",
456					err.group,
457					err.code,
458					err.message,
459					metadata
460				))
461			}
462		}
463	}
464
465	async fn send_subscription(self: &Arc<Self>, event_name: String, subscribe: bool) {
466		self.send_msg(
467			Arc::new(to_server::ToServer {
468				body: to_server::ToServerBody::SubscriptionRequest(
469					to_server::SubscriptionRequest {
470						event_name,
471						subscribe,
472					},
473				),
474			}),
475			SendMsgOpts { ephemeral: true },
476		)
477		.await;
478	}
479
480	async fn add_event_subscription(
481		self: &Arc<Self>,
482		event_name: String,
483		callback: Box<EventCallback>,
484	) -> SubscriptionHandle {
485		let id = self
486			.event_subscription_counter
487			.fetch_add(1, Ordering::SeqCst);
488		let handle = SubscriptionHandle::new(self, event_name.clone(), id);
489
490		self.insert_event_subscription(event_name, id, callback)
491			.await;
492
493		handle
494	}
495
496	async fn insert_event_subscription(
497		self: &Arc<Self>,
498		event_name: String,
499		id: u64,
500		callback: Box<EventCallback>,
501	) {
502		let is_new_subscription = {
503			let mut listeners = self
504				.event_subscriptions
505				.entry_async(event_name.clone())
506				.await
507				.or_insert_with(Vec::new);
508			let is_new_subscription = listeners.is_empty();
509
510			listeners.push(Arc::new(EventSubscription { id, callback }));
511
512			is_new_subscription
513		};
514
515		if is_new_subscription {
516			self.send_subscription(event_name, true).await;
517		}
518	}
519
520	async fn remove_event_subscription(self: &Arc<Self>, event_name: &str, id: u64) {
521		let should_unsubscribe = {
522			match self
523				.event_subscriptions
524				.entry_async(event_name.to_string())
525				.await
526			{
527				SccEntry::Occupied(mut entry) => {
528					entry.retain(|subscription| subscription.id != id);
529					if entry.is_empty() {
530						let _ = entry.remove_entry();
531						true
532					} else {
533						false
534					}
535				}
536				SccEntry::Vacant(entry) => {
537					drop(entry);
538					false
539				}
540			}
541		};
542
543		if should_unsubscribe {
544			self.send_subscription(event_name.to_string(), false).await;
545		}
546	}
547
548	pub async fn on_event<F>(self: &Arc<Self>, event_name: &str, callback: F) -> SubscriptionHandle
549	where
550		F: Fn(&Vec<Value>) + Send + Sync + 'static,
551	{
552		self.add_event_subscription(
553			event_name.to_string(),
554			Box::new(move |event| callback(&event.args)),
555		)
556		.await
557	}
558
559	pub async fn on_event_raw<F>(
560		self: &Arc<Self>,
561		event_name: &str,
562		callback: F,
563	) -> SubscriptionHandle
564	where
565		F: Fn(Event) + Send + Sync + 'static,
566	{
567		self.add_event_subscription(event_name.to_string(), Box::new(callback))
568			.await
569	}
570
571	pub async fn once_event<F>(
572		self: &Arc<Self>,
573		event_name: &str,
574		callback: F,
575	) -> SubscriptionHandle
576	where
577		F: FnOnce(Event) + Send + 'static,
578	{
579		let id = self
580			.event_subscription_counter
581			.fetch_add(1, Ordering::SeqCst);
582		let handle = SubscriptionHandle::new(self, event_name.to_string(), id);
583		// Event callbacks are synchronous, so a FnOnce lives behind a short sync lock.
584		let callback = Arc::new(SyncMutex::new(Some(callback)));
585		let unsubscribe_handle = handle.clone();
586		let fired = Arc::new(AtomicBool::new(false));
587		self.insert_event_subscription(
588			event_name.to_string(),
589			id,
590			Box::new(move |event| {
591				if fired.swap(true, Ordering::SeqCst) {
592					return;
593				}
594
595				let unsubscribe_handle = unsubscribe_handle.clone();
596				tokio::spawn(async move {
597					unsubscribe_handle.unsubscribe().await;
598				});
599
600				let Some(callback) = callback.lock().take() else {
601					return;
602				};
603				callback(event);
604			}),
605		)
606		.await;
607
608		handle
609	}
610
611	pub async fn on_open<F>(self: &Arc<Self>, callback: F)
612	where
613		F: Fn() + Send + Sync + 'static,
614	{
615		self.on_open_callbacks.lock().await.push(Box::new(callback));
616	}
617
618	pub async fn on_close<F>(self: &Arc<Self>, callback: F)
619	where
620		F: Fn() + Send + Sync + 'static,
621	{
622		self.on_close_callbacks
623			.lock()
624			.await
625			.push(Box::new(callback));
626	}
627
628	pub async fn on_error<F>(self: &Arc<Self>, callback: F)
629	where
630		F: Fn(&str) + Send + Sync + 'static,
631	{
632		self.on_error_callbacks
633			.lock()
634			.await
635			.push(Box::new(callback));
636	}
637
638	pub async fn on_status_change<F>(self: &Arc<Self>, callback: F)
639	where
640		F: Fn(ConnectionStatus) + Send + Sync + 'static,
641	{
642		self.on_status_change_callbacks
643			.lock()
644			.await
645			.push(Box::new(callback));
646	}
647
648	pub fn conn_status(self: &Arc<Self>) -> ConnectionStatus {
649		*self.status_watch.1.borrow()
650	}
651
652	pub fn status_receiver(self: &Arc<Self>) -> watch::Receiver<ConnectionStatus> {
653		self.status_watch.1.clone()
654	}
655
656	pub async fn disconnect(self: &Arc<Self>) {
657		if self.is_disconnecting() {
658			// We are already disconnecting
659			return;
660		}
661
662		debug!("Disconnecting from actor conn");
663
664		self.dc_watch.0.send(true).ok();
665		self.set_status(ConnectionStatus::Disconnected).await;
666
667		if let Some(d) = self.driver.lock().await.deref() {
668			d.disconnect();
669		}
670		self.in_flight_rpcs.clear_async().await;
671		self.event_subscriptions.clear_async().await;
672		let Some(rx) = self.disconnection_rx.lock().await.take() else {
673			return;
674		};
675
676		rx.await.ok();
677	}
678
679	pub async fn dispose(self: &Arc<Self>) {
680		self.disconnect().await
681	}
682}
683
684pub fn start_connection(
685	conn: &Arc<ActorConnectionInner>,
686	mut shutdown_rx: broadcast::Receiver<()>,
687) {
688	let (tx, rx) = oneshot::channel();
689
690	let conn = conn.clone();
691
692	tokio::spawn(async move {
693		{
694			let mut stop_rx = conn.disconnection_rx.lock().await;
695			if stop_rx.is_some() {
696				// Already doing connection_with_retry
697				// - this drops the oneshot
698				return;
699			}
700
701			*stop_rx = Some(rx);
702		}
703
704		'keepalive: loop {
705			debug!("Attempting to reconnect");
706			let mut backoff = Backoff::new(Duration::from_secs(1), Duration::from_secs(30));
707			let mut retry_attempt = 0;
708			'retry: loop {
709				retry_attempt += 1;
710				debug!(
711					"Establish conn: attempt={}, timeout={:?}",
712					retry_attempt,
713					backoff.delay()
714				);
715				let attempt = conn.try_connect().await;
716
717				if conn.is_disconnecting() {
718					break 'keepalive;
719				}
720
721				if attempt.did_open {
722					break 'retry;
723				}
724
725				let mut dc_rx = conn.dc_watch.0.subscribe();
726
727				tokio::select! {
728					_ = backoff.tick() => {},
729					_ = dc_rx.wait_for(|x| *x == true) => {
730						break 'keepalive;
731					}
732					_ = shutdown_rx.recv() => {
733						debug!("Received shutdown signal, stopping connection attempts");
734						break 'keepalive;
735					}
736				}
737			}
738		}
739
740		tx.send(()).ok();
741		conn.disconnection_rx.lock().await.take();
742	});
743}
744
745impl Debug for ActorConnectionInner {
746	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
747		f.debug_struct("ActorConnection")
748			.field("transport_kind", &self.transport_kind)
749			.field("encoding_kind", &self.encoding_kind)
750			.finish()
751	}
752}
753
754fn decode_event_args(raw_args: &[u8]) -> Vec<Value> {
755	match serde_cbor::from_slice::<Vec<Value>>(raw_args) {
756		Ok(args) => args,
757		Err(vector_error) => match serde_cbor::from_slice::<Value>(raw_args) {
758			Ok(Value::Array(args)) => args,
759			Ok(value) => vec![value],
760			Err(value_error) => {
761				debug!(?vector_error, ?value_error, "failed to decode event args");
762				Vec::new()
763			}
764		},
765	}
766}