actor_core_client/drivers/
mod.rs1use std::sync::Arc;
2
3use crate::{encoding::EncodingKind, protocol};
4use anyhow::Result;
5use serde_json::Value;
6use tokio::{
7 sync::mpsc,
8 task::{AbortHandle, JoinHandle},
9};
10use urlencoding::encode;
11
12pub mod sse;
13pub mod ws;
14
/// Upper bound, in bytes, on the JSON-serialized connection parameters accepted by
/// `build_conn_url` (4 KiB).
const MAX_CONN_PARAMS_SIZE: usize = 4 * 1024;
16
/// Why a transport driver task finished.
///
/// This is the output of the `JoinHandle` returned by `TransportKind::connect`.
/// NOTE(review): the exact conditions mapped to each variant are decided by the
/// `ws` / `sse` drivers; the per-variant notes below are inferred from the names.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DriverStopReason {
    /// Presumably: the local side requested the stop.
    UserAborted,
    /// Presumably: the server closed the connection.
    ServerDisconnect,
    /// Presumably: the server reported or caused an error.
    ServerError,
    /// Presumably: the driver task itself failed.
    TaskError,
}
24
25pub(crate) type MessageToClient = Arc<protocol::ToClient>;
26pub(crate) type MessageToServer = Arc<protocol::ToServer>;
27
28pub(crate) struct DriverHandle {
29 abort_handle: AbortHandle,
30 sender: mpsc::Sender<MessageToServer>,
31}
32
33impl DriverHandle {
34 pub fn new(sender: mpsc::Sender<MessageToServer>, abort_handle: AbortHandle) -> Self {
35 Self {
36 sender,
37 abort_handle,
38 }
39 }
40
41 pub async fn send(&self, msg: Arc<protocol::ToServer>) -> Result<()> {
42 self.sender.send(msg).await?;
43
44 Ok(())
45 }
46
47 pub fn disconnect(&self) {
48 self.abort_handle.abort();
49 }
50}
51
52impl Drop for DriverHandle {
53 fn drop(&mut self) {
54 self.disconnect()
55 }
56}
57
/// Wire transport used to talk to an actor.
///
/// Derives `PartialEq`/`Eq` (like `DriverStopReason`) so callers can compare kinds with
/// `==` instead of pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportKind {
    /// WebSocket transport (connect path `websocket`, `ws://`/`wss://` scheme).
    WebSocket,
    /// Server-sent events transport (connect path `sse`).
    Sse,
}
63
64impl TransportKind {
65 pub(crate) async fn connect(
66 &self,
67 endpoint: String,
68 encoding_kind: EncodingKind,
69 parameters: &Option<Value>,
70 ) -> Result<(
71 DriverHandle,
72 mpsc::Receiver<MessageToClient>,
73 JoinHandle<DriverStopReason>,
74 )> {
75 match *self {
76 TransportKind::WebSocket => ws::connect(endpoint, encoding_kind, parameters).await,
77 TransportKind::Sse => sse::connect(endpoint, encoding_kind, parameters).await,
78 }
79 }
80}
81
82fn build_conn_url(
83 endpoint: &str,
84 transport_kind: &TransportKind,
85 encoding_kind: EncodingKind,
86 params: &Option<Value>,
87) -> Result<String> {
88 let connect_path = {
89 match transport_kind {
90 TransportKind::WebSocket => "websocket",
91 TransportKind::Sse => "sse",
92 }
93 };
94
95 let endpoint = match transport_kind {
96 TransportKind::WebSocket => endpoint
97 .to_string()
98 .replace("http://", "ws://")
99 .replace("https://", "wss://"),
100 TransportKind::Sse => endpoint.to_string(),
101 };
102
103 let Some(params) = params else {
104 return Ok(format!(
105 "{}/connect/{}?encoding={}",
106 endpoint,
107 connect_path,
108 encoding_kind.as_str()
109 ));
110 };
111
112 let params_str = serde_json::to_string(params)?;
113 if params_str.len() > MAX_CONN_PARAMS_SIZE {
114 return Err(anyhow::anyhow!("Connection parameters too long"));
115 }
116
117 let params_str = encode(¶ms_str);
118
119 Ok(format!(
120 "{}/connect/{}?encoding={}¶ms={}",
121 endpoint,
122 connect_path,
123 encoding_kind.as_str(),
124 params_str
125 ))
126}