1use anyhow::Result;
2use futures_util::FutureExt;
3use serde_json::Value;
4use std::ops::Deref;
5use std::sync::atomic::{AtomicI64, Ordering};
6use std::time::Duration;
7use std::{collections::HashMap, sync::Arc};
8use tokio::sync::{oneshot, watch, Mutex};
9
10use crate::drivers::{DriverHandle, DriverStopReason, TransportKind};
11use crate::encoding::EncodingKind;
12use crate::{backoff::Backoff, protocol::*};
13use tracing::debug;
14
15use super::protocol;
16
17type RpcResponse = Result<RpcResponseOk, RpcResponseError>;
18type EventCallback = dyn Fn(&Vec<Value>) + Send + Sync;
19
/// Options controlling how `ActorHandleInner::send_msg` treats a message that
/// cannot be delivered immediately.
///
/// The default (`ephemeral: false`) queues undeliverable messages.
#[derive(Default)]
struct SendMsgOpts {
    /// When `true`, a message that cannot be sent right now is dropped instead
    /// of being queued for delivery after the next successful connection.
    ephemeral: bool,
}
29
30type WatchPair = (watch::Sender<bool>, watch::Receiver<bool>);
35
36pub type ActorHandle = Arc<ActorHandleInner>;
37
38struct ConnectionAttempt {
39 did_open: bool,
40 _task_end_reason: DriverStopReason,
41}
42
43pub struct ActorHandleInner {
44 pub endpoint: String,
45 transport_kind: TransportKind,
46 encoding_kind: EncodingKind,
47 parameters: Option<Value>,
48
49 driver: Mutex<Option<DriverHandle>>,
50 msg_queue: Mutex<Vec<Arc<protocol::ToServer>>>,
51
52 rpc_counter: AtomicI64,
53 in_flight_rpcs: Mutex<HashMap<i64, oneshot::Sender<RpcResponse>>>,
54
55 event_subscriptions: Mutex<HashMap<String, Vec<Box<EventCallback>>>>,
56
57 dc_watch: WatchPair,
58 disconnection_rx: Mutex<Option<oneshot::Receiver<()>>>,
59}
60
61impl ActorHandleInner {
62 pub(crate) fn new(
63 endpoint: String,
64 transport_kind: TransportKind,
65 encoding_kind: EncodingKind,
66 parameters: Option<Value>,
67 ) -> Result<ActorHandle> {
68 Ok(Arc::new(Self {
69 endpoint: endpoint.clone(),
70 transport_kind,
71 encoding_kind,
72 parameters,
73 driver: Mutex::new(None),
74 msg_queue: Mutex::new(Vec::new()),
75 rpc_counter: AtomicI64::new(0),
76 in_flight_rpcs: Mutex::new(HashMap::new()),
77 event_subscriptions: Mutex::new(HashMap::new()),
78 dc_watch: watch::channel(false),
79 disconnection_rx: Mutex::new(None),
80 }))
81 }
82
83 fn is_disconnecting(self: &Arc<Self>) -> bool {
84 *self.dc_watch.1.borrow() == true
85 }
86
87 async fn try_connect(self: &Arc<Self>) -> ConnectionAttempt {
88 let (driver, mut recver, task) = match self
89 .transport_kind
90 .connect(self.endpoint.clone(), self.encoding_kind, &self.parameters)
91 .await
92 {
93 Ok(a) => a,
94 Err(_) => {
95 return ConnectionAttempt {
98 did_open: false,
99 _task_end_reason: DriverStopReason::TaskError,
100 };
101 }
102 };
103
104 {
105 let mut my_driver = self.driver.lock().await;
106 *my_driver = Some(driver);
107 }
108
109 let mut task_end_reason = task.map(|res| match res {
110 Ok(a) => a,
111 Err(task_err) => {
112 if task_err.is_cancelled() {
113 DriverStopReason::UserAborted
114 } else {
115 DriverStopReason::TaskError
116 }
117 }
118 });
119
120 let mut did_connection_open = false;
121
122 let task_end_reason = loop {
124 tokio::select! {
125 reason = &mut task_end_reason => {
126 debug!("Connection closed: {:?}", reason);
127
128 break reason;
129 },
130 msg = recver.recv() => {
131 let Some(msg) = msg else {
133 continue;
135 };
136
137 if let ToClientBody::Init { i: _ } = &msg.b {
138 did_connection_open = true;
139 }
140
141 self.on_message(msg).await;
142 }
143 }
144 };
145
146 'destroy_driver: {
147 let mut d_guard = self.driver.lock().await;
148 let Some(d) = d_guard.take() else {
149 break 'destroy_driver;
152 };
153
154 d.disconnect();
155 }
156
157 ConnectionAttempt {
158 did_open: did_connection_open,
159 _task_end_reason: task_end_reason,
160 }
161 }
162
163 pub(crate) async fn start_connection(self: &Arc<Self>) {
164 let (tx, rx) = oneshot::channel();
165
166 {
167 let mut stop_rx = self.disconnection_rx.lock().await;
168 if stop_rx.is_some() {
169 return;
172 }
173
174 *stop_rx = Some(rx);
175 }
176
177 let handle = self.clone();
178
179 tokio::spawn(async move {
180 'keepalive: loop {
181 debug!("Attempting to reconnect");
182 let mut backoff = Backoff::new(Duration::from_secs(1), Duration::from_secs(30));
183 let mut retry_attempt = 0;
184 'retry: loop {
185 retry_attempt += 1;
186 debug!(
187 "Establish conn: attempt={}, timeout={:?}",
188 retry_attempt,
189 backoff.delay()
190 );
191 let attempt = handle.try_connect().await;
192
193 if handle.is_disconnecting() {
194 break 'keepalive;
195 }
196
197 if attempt.did_open {
198 break 'retry;
199 }
200
201 let mut dc_rx = handle.dc_watch.0.subscribe();
202
203 tokio::select! {
204 _ = backoff.tick() => {},
205 _ = dc_rx.wait_for(|x| *x == true) => {
206 break 'keepalive;
207 }
208 }
209 }
210 }
211
212 tx.send(()).ok();
213 handle.disconnection_rx.lock().await.take();
214 });
215 }
216
217 async fn on_open(self: &Arc<Self>, init: &protocol::Init) {
218 debug!("Connected to server: {:?}", init);
219
220 for (event_name, _) in self.event_subscriptions.lock().await.iter() {
221 self.send_subscription(event_name.clone(), true).await;
222 }
223
224 for msg in self.msg_queue.lock().await.drain(..) {
226 self.send_msg(msg, SendMsgOpts::default()).await;
229 }
230 }
231
232 async fn on_message(self: &Arc<Self>, msg: Arc<protocol::ToClient>) {
233 let body = &msg.b;
234
235 match body {
236 protocol::ToClientBody::Init { i: init } => {
237 self.on_open(init).await;
238 }
239 protocol::ToClientBody::ResponseOk { ro } => {
240 let id = ro.i;
241 let mut in_flight_rpcs = self.in_flight_rpcs.lock().await;
242 let Some(tx) = in_flight_rpcs.remove(&id) else {
243 debug!("Unexpected response: rpc id not found");
244 return;
245 };
246 if let Err(e) = tx.send(Ok(ro.clone())) {
247 debug!("{:?}", e);
248 return;
249 }
250 }
251 protocol::ToClientBody::ResponseError { re } => {
252 let id = re.i;
253 let mut in_flight_rpcs = self.in_flight_rpcs.lock().await;
254 let Some(tx) = in_flight_rpcs.remove(&id) else {
255 debug!("Unexpected response: rpc id not found");
256 return;
257 };
258 if let Err(e) = tx.send(Err(re.clone())) {
259 debug!("{:?}", e);
260 return;
261 }
262 }
263 protocol::ToClientBody::EventMessage { ev } => {
264 let listeners = self.event_subscriptions.lock().await;
265 if let Some(callbacks) = listeners.get(&ev.n) {
266 for cb in callbacks {
267 cb(&ev.a);
268 }
269 }
270 }
271 protocol::ToClientBody::EventError { er } => {
272 debug!("Event error: {:?}", er);
273 }
274 }
275 }
276
277 async fn send_msg(self: &Arc<Self>, msg: Arc<protocol::ToServer>, opts: SendMsgOpts) {
278 let guard = self.driver.lock().await;
279
280 'send_immediately: {
281 let Some(driver) = guard.deref() else {
282 break 'send_immediately;
283 };
284
285 let Ok(_) = driver.send(msg.clone()).await else {
286 break 'send_immediately;
287 };
288
289 return;
290 }
291
292 if opts.ephemeral == false {
294 self.msg_queue.lock().await.push(msg.clone());
295 }
296
297 return;
298 }
299
300 pub async fn action(self: &Arc<Self>, method: &str, params: Vec<Value>) -> Result<Value> {
301 let id: i64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst);
302
303 let (tx, rx) = oneshot::channel();
304 self.in_flight_rpcs.lock().await.insert(id, tx);
305
306 self.send_msg(
307 Arc::new(protocol::ToServer {
308 b: protocol::ToServerBody::RpcRequest {
309 rr: protocol::RpcRequest {
310 i: id,
311 n: method.to_string(),
312 a: params,
313 },
314 },
315 }),
316 SendMsgOpts::default(),
317 )
318 .await;
319
320 let Ok(res) = rx.await else {
322 return Err(anyhow::anyhow!("Socket closed during rpc"));
324 };
325
326 match res {
327 Ok(ok) => Ok(ok.o),
328 Err(err) => {
329 let metadata = err.md.unwrap_or(Value::Null);
330
331 Err(anyhow::anyhow!(
332 "RPC Error({}): {:?}, {:#}",
333 err.c,
334 err.m,
335 metadata
336 ))
337 }
338 }
339 }
340
341 async fn send_subscription(self: &Arc<Self>, event_name: String, subscribe: bool) {
342 self.send_msg(
343 Arc::new(protocol::ToServer {
344 b: protocol::ToServerBody::SubscriptionRequest {
345 sr: protocol::SubscriptionRequest {
346 e: event_name,
347 s: subscribe,
348 },
349 },
350 }),
351 SendMsgOpts { ephemeral: true },
352 )
353 .await;
354 }
355
356 async fn add_event_subscription(
357 self: &Arc<Self>,
358 event_name: String,
359 callback: Box<EventCallback>,
360 ) {
361 let mut listeners = self.event_subscriptions.lock().await;
363
364 let is_new_subscription = listeners.contains_key(&event_name) == false;
365
366 listeners
367 .entry(event_name.clone())
368 .or_insert(Vec::new())
369 .push(callback);
370
371 if is_new_subscription {
372 self.send_subscription(event_name, true).await;
373 }
374 }
375
376 pub async fn on_event<F>(self: &Arc<Self>, event_name: &str, callback: F)
377 where
378 F: Fn(&Vec<Value>) + Send + Sync + 'static,
379 {
380 self.add_event_subscription(event_name.to_string(), Box::new(callback))
381 .await
382 }
383
384 pub async fn disconnect(self: &Arc<Self>) {
385 if self.is_disconnecting() {
386 return;
388 }
389
390 self.dc_watch.0.send(true).ok();
391
392 if let Some(d) = self.driver.lock().await.deref() {
393 d.disconnect()
394 }
395 self.in_flight_rpcs.lock().await.clear();
396 self.event_subscriptions.lock().await.clear();
397 let Some(rx) = self.disconnection_rx.lock().await.take() else {
398 return;
399 };
400
401 rx.await.ok();
402 }
403}