actor_core_client/drivers/
mod.rs

1use std::sync::Arc;
2
3use crate::{encoding::EncodingKind, protocol};
4use anyhow::Result;
5use serde_json::Value;
6use tokio::{
7    sync::mpsc,
8    task::{AbortHandle, JoinHandle},
9};
10use urlencoding::encode;
11
12pub mod sse;
13pub mod ws;
14
/// Upper bound, in bytes, on the JSON-serialized connection parameters.
///
/// `build_conn_url` compares the serialized length against this limit before
/// the parameters are percent-encoded into the query string.
const MAX_CONN_PARAMS_SIZE: usize = 4 * 1024;
16
/// Why a driver task stopped.
///
/// This is the output type of the `JoinHandle` returned by
/// `TransportKind::connect`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DriverStopReason {
    /// The stop was requested on the client side.
    // NOTE(review): presumably produced when the task is aborted through
    // `DriverHandle::disconnect` — confirm against the driver implementations.
    UserAborted,
    /// The server ended the connection.
    ServerDisconnect,
    /// The server reported an error.
    ServerError,
    /// The driver task itself failed.
    TaskError,
}
24
25pub(crate) type MessageToClient = Arc<protocol::ToClient>;
26pub(crate) type MessageToServer = Arc<protocol::ToServer>;
27
28pub(crate) struct DriverHandle {
29    abort_handle: AbortHandle,
30    sender: mpsc::Sender<MessageToServer>,
31}
32
33impl DriverHandle {
34    pub fn new(sender: mpsc::Sender<MessageToServer>, abort_handle: AbortHandle) -> Self {
35        Self {
36            sender,
37            abort_handle,
38        }
39    }
40
41    pub async fn send(&self, msg: Arc<protocol::ToServer>) -> Result<()> {
42        self.sender.send(msg).await?;
43
44        Ok(())
45    }
46
47    pub fn disconnect(&self) {
48        self.abort_handle.abort();
49    }
50}
51
52impl Drop for DriverHandle {
53    fn drop(&mut self) {
54        self.disconnect()
55    }
56}
57
/// Transport used to talk to the server.
///
/// The variant selects both the driver module used by `connect` and the
/// `/connect/...` path segment chosen by `build_conn_url`.
#[derive(Clone, Copy, Debug)]
pub enum TransportKind {
    /// Connect through the `ws` driver (`/connect/websocket`).
    WebSocket,
    /// Connect through the `sse` driver (`/connect/sse`).
    Sse,
}
63
64impl TransportKind {
65    pub(crate) async fn connect(
66        &self,
67        endpoint: String,
68        encoding_kind: EncodingKind,
69        parameters: &Option<Value>,
70    ) -> Result<(
71        DriverHandle,
72        mpsc::Receiver<MessageToClient>,
73        JoinHandle<DriverStopReason>,
74    )> {
75        match *self {
76            TransportKind::WebSocket => ws::connect(endpoint, encoding_kind, parameters).await,
77            TransportKind::Sse => sse::connect(endpoint, encoding_kind, parameters).await,
78        }
79    }
80}
81
82fn build_conn_url(
83    endpoint: &str,
84    transport_kind: &TransportKind,
85    encoding_kind: EncodingKind,
86    params: &Option<Value>,
87) -> Result<String> {
88    let connect_path = {
89        match transport_kind {
90            TransportKind::WebSocket => "websocket",
91            TransportKind::Sse => "sse",
92        }
93    };
94
95    let endpoint = match transport_kind {
96        TransportKind::WebSocket => endpoint
97            .to_string()
98            .replace("http://", "ws://")
99            .replace("https://", "wss://"),
100        TransportKind::Sse => endpoint.to_string(),
101    };
102
103    let Some(params) = params else {
104        return Ok(format!(
105            "{}/connect/{}?encoding={}",
106            endpoint,
107            connect_path,
108            encoding_kind.as_str()
109        ));
110    };
111
112    let params_str = serde_json::to_string(params)?;
113    if params_str.len() > MAX_CONN_PARAMS_SIZE {
114        return Err(anyhow::anyhow!("Connection parameters too long"));
115    }
116
117    let params_str = encode(&params_str);
118
119    Ok(format!(
120        "{}/connect/{}?encoding={}&params={}",
121        endpoint,
122        connect_path,
123        encoding_kind.as_str(),
124        params_str
125    ))
126}