rivetkit_client/drivers/
mod.rs1use std::sync::Arc;
2
3use crate::{
4 protocol::{query, to_client, to_server},
5 remote_manager::RemoteManager,
6 EncodingKind, TransportKind,
7};
8use anyhow::Result;
9use serde_json::Value;
10use tokio::{
11 sync::mpsc,
12 task::{AbortHandle, JoinHandle},
13};
14use tracing::debug;
15
16pub mod sse;
17pub mod ws;
18
19pub type MessageToClient = Arc<to_client::ToClient>;
20pub type MessageToServer = Arc<to_server::ToServer>;
21
/// Why a driver task finished.
///
/// This is the value produced by the driver task's [`JoinHandle`] inside a
/// [`DriverConnection`]. The exact conditions for each variant are decided by the
/// transport drivers (`ws`, `sse`).
///
/// `Hash` is derived alongside `Eq` so reasons can be used as map/set keys
/// (e.g. for counting disconnect causes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriverStopReason {
    /// Stopped at the client's request (presumably via `DriverHandle::disconnect`;
    /// confirm in the drivers).
    UserAborted,
    /// The server side ended the connection (presumably a clean close).
    ServerDisconnect,
    /// The server side reported or caused an error.
    ServerError,
    /// The driver task itself failed.
    TaskError,
}

30#[derive(Debug)]
31pub struct DriverHandle {
32 abort_handle: AbortHandle,
33 sender: mpsc::UnboundedSender<MessageToServer>,
34}
35
36impl DriverHandle {
37 pub fn new(sender: mpsc::UnboundedSender<MessageToServer>, abort_handle: AbortHandle) -> Self {
38 Self {
39 sender,
40 abort_handle,
41 }
42 }
43
44 pub async fn send(&self, msg: Arc<to_server::ToServer>) -> Result<()> {
45 self.sender.send(msg)?;
46
47 Ok(())
48 }
49
50 pub fn disconnect(&self) {
51 self.abort_handle.abort();
52 }
53}
54
55impl Drop for DriverHandle {
56 fn drop(&mut self) {
57 debug!("DriverHandle dropped, aborting task");
58 self.disconnect()
59 }
60}
61
62pub type DriverConnection = (
63 DriverHandle,
64 mpsc::UnboundedReceiver<MessageToClient>,
65 JoinHandle<DriverStopReason>,
66);
67
68pub struct DriverConnectArgs {
69 pub remote_manager: RemoteManager,
70 pub encoding_kind: EncodingKind,
71 pub query: query::ActorQuery,
72 pub parameters: Option<Value>,
73 pub conn_id: Option<String>,
74 pub conn_token: Option<String>,
75}
76
77pub async fn connect_driver(
78 transport_kind: TransportKind,
79 args: DriverConnectArgs,
80) -> Result<DriverConnection> {
81 let res = match transport_kind {
82 TransportKind::WebSocket => ws::connect(args).await?,
83 TransportKind::Sse => sse::connect(args).await?,
84 };
85
86 Ok(res)
87}