Skip to main content

rivetkit_client/drivers/mod.rs

1use std::sync::Arc;
2
3use crate::{
4	protocol::{query, to_client, to_server},
5	remote_manager::RemoteManager,
6	EncodingKind, TransportKind,
7};
8use anyhow::Result;
9use serde_json::Value;
10use tokio::{
11	sync::mpsc,
12	task::{AbortHandle, JoinHandle},
13};
14use tracing::debug;
15
16pub mod sse;
17pub mod ws;
18
19pub type MessageToClient = Arc<to_client::ToClient>;
20pub type MessageToServer = Arc<to_server::ToServer>;
21
/// Why a driver task finished.
///
/// This is the output type of the `JoinHandle` returned as part of a
/// [`DriverConnection`]. Variant meanings below are inferred from their names;
/// confirm against the `ws` / `sse` drivers that construct them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DriverStopReason {
	/// Presumably: the local user requested the stop. NOTE(review): `disconnect`
	/// aborts the task outright, which cancels the `JoinHandle` instead of
	/// yielding this value — confirm where this variant is produced.
	UserAborted,
	/// Presumably: the server closed the connection.
	ServerDisconnect,
	/// Presumably: the server reported an error.
	ServerError,
	/// Presumably: the driver task itself failed.
	TaskError,
}
29
30#[derive(Debug)]
31pub struct DriverHandle {
32	abort_handle: AbortHandle,
33	sender: mpsc::UnboundedSender<MessageToServer>,
34}
35
36impl DriverHandle {
37	pub fn new(sender: mpsc::UnboundedSender<MessageToServer>, abort_handle: AbortHandle) -> Self {
38		Self {
39			sender,
40			abort_handle,
41		}
42	}
43
44	pub async fn send(&self, msg: Arc<to_server::ToServer>) -> Result<()> {
45		self.sender.send(msg)?;
46
47		Ok(())
48	}
49
50	pub fn disconnect(&self) {
51		self.abort_handle.abort();
52	}
53}
54
55impl Drop for DriverHandle {
56	fn drop(&mut self) {
57		debug!("DriverHandle dropped, aborting task");
58		self.disconnect()
59	}
60}
61
62pub type DriverConnection = (
63	DriverHandle,
64	mpsc::UnboundedReceiver<MessageToClient>,
65	JoinHandle<DriverStopReason>,
66);
67
68pub struct DriverConnectArgs {
69	pub remote_manager: RemoteManager,
70	pub encoding_kind: EncodingKind,
71	pub query: query::ActorQuery,
72	pub parameters: Option<Value>,
73	pub conn_id: Option<String>,
74	pub conn_token: Option<String>,
75}
76
77pub async fn connect_driver(
78	transport_kind: TransportKind,
79	args: DriverConnectArgs,
80) -> Result<DriverConnection> {
81	let res = match transport_kind {
82		TransportKind::WebSocket => ws::connect(args).await?,
83		TransportKind::Sse => sse::connect(args).await?,
84	};
85
86	Ok(res)
87}