// actor_core_client/handle.rs

1use anyhow::Result;
2use futures_util::FutureExt;
3use serde_json::Value;
4use std::ops::Deref;
5use std::sync::atomic::{AtomicI64, Ordering};
6use std::time::Duration;
7use std::{collections::HashMap, sync::Arc};
8use tokio::sync::{oneshot, watch, Mutex};
9
10use crate::drivers::{DriverHandle, DriverStopReason, TransportKind};
11use crate::encoding::EncodingKind;
12use crate::{backoff::Backoff, protocol::*};
13use tracing::debug;
14
15use super::protocol;
16
17type RpcResponse = Result<RpcResponseOk, RpcResponseError>;
18type EventCallback = dyn Fn(&Vec<Value>) + Send + Sync;
19
/// Options controlling what `ActorHandleInner::send_msg` does with a message
/// that cannot be delivered immediately.
///
/// The default is a non-ephemeral message, i.e. one that is queued for later.
#[derive(Debug, Clone, Copy, Default)]
struct SendMsgOpts {
    /// When `true`, a message that cannot be sent right away is dropped
    /// instead of being queued for delivery once a connection opens.
    ephemeral: bool,
}
29
30// struct WatchPair {
31//     tx: watch::Sender<bool>,
32//     rx: watch::Receiver<bool>,
33// }
34type WatchPair = (watch::Sender<bool>, watch::Receiver<bool>);
35
36pub type ActorHandle = Arc<ActorHandleInner>;
37
38struct ConnectionAttempt {
39    did_open: bool,
40    _task_end_reason: DriverStopReason,
41}
42
43pub struct ActorHandleInner {
44    pub endpoint: String,
45    transport_kind: TransportKind,
46    encoding_kind: EncodingKind,
47    parameters: Option<Value>,
48
49    driver: Mutex<Option<DriverHandle>>,
50    msg_queue: Mutex<Vec<Arc<protocol::ToServer>>>,
51
52    rpc_counter: AtomicI64,
53    in_flight_rpcs: Mutex<HashMap<i64, oneshot::Sender<RpcResponse>>>,
54
55    event_subscriptions: Mutex<HashMap<String, Vec<Box<EventCallback>>>>,
56
57    dc_watch: WatchPair,
58    disconnection_rx: Mutex<Option<oneshot::Receiver<()>>>,
59}
60
61impl ActorHandleInner {
62    pub(crate) fn new(
63        endpoint: String,
64        transport_kind: TransportKind,
65        encoding_kind: EncodingKind,
66        parameters: Option<Value>,
67    ) -> Result<ActorHandle> {
68        Ok(Arc::new(Self {
69            endpoint: endpoint.clone(),
70            transport_kind,
71            encoding_kind,
72            parameters,
73            driver: Mutex::new(None),
74            msg_queue: Mutex::new(Vec::new()),
75            rpc_counter: AtomicI64::new(0),
76            in_flight_rpcs: Mutex::new(HashMap::new()),
77            event_subscriptions: Mutex::new(HashMap::new()),
78            dc_watch: watch::channel(false),
79            disconnection_rx: Mutex::new(None),
80        }))
81    }
82
83    fn is_disconnecting(self: &Arc<Self>) -> bool {
84        *self.dc_watch.1.borrow() == true
85    }
86
87    async fn try_connect(self: &Arc<Self>) -> ConnectionAttempt {
88        let (driver, mut recver, task) = match self
89            .transport_kind
90            .connect(self.endpoint.clone(), self.encoding_kind, &self.parameters)
91            .await
92        {
93            Ok(a) => a,
94            Err(_) => {
95                // Either from immediate disconnect (local device connection refused)
96                // or from error like invalid URL
97                return ConnectionAttempt {
98                    did_open: false,
99                    _task_end_reason: DriverStopReason::TaskError,
100                };
101            }
102        };
103
104        {
105            let mut my_driver = self.driver.lock().await;
106            *my_driver = Some(driver);
107        }
108
109        let mut task_end_reason = task.map(|res| match res {
110            Ok(a) => a,
111            Err(task_err) => {
112                if task_err.is_cancelled() {
113                    DriverStopReason::UserAborted
114                } else {
115                    DriverStopReason::TaskError
116                }
117            }
118        });
119
120        let mut did_connection_open = false;
121
122        // spawn listener for rpcs
123        let task_end_reason = loop {
124            tokio::select! {
125                reason = &mut task_end_reason => {
126                    debug!("Connection closed: {:?}", reason);
127
128                    break reason;
129                },
130                msg = recver.recv() => {
131                    // If the sender is dropped, break the loop
132                    let Some(msg) = msg else {
133                        // break DriverStopReason::ServerDisconnect;
134                        continue;
135                    };
136
137                    if let ToClientBody::Init { i: _ } = &msg.b {
138                        did_connection_open = true;
139                    }
140
141                    self.on_message(msg).await;
142                }
143            }
144        };
145
146        'destroy_driver: {
147            let mut d_guard = self.driver.lock().await;
148            let Some(d) = d_guard.take() else {
149                // We destroyed the driver already,
150                // e.g. .disconnect() was called
151                break 'destroy_driver;
152            };
153
154            d.disconnect();
155        }
156
157        ConnectionAttempt {
158            did_open: did_connection_open,
159            _task_end_reason: task_end_reason,
160        }
161    }
162
163    pub(crate) async fn start_connection(self: &Arc<Self>) {
164        let (tx, rx) = oneshot::channel();
165
166        {
167            let mut stop_rx = self.disconnection_rx.lock().await;
168            if stop_rx.is_some() {
169                // Already doing connection_with_retry
170                // - this drops the oneshot
171                return;
172            }
173
174            *stop_rx = Some(rx);
175        }
176
177        let handle = self.clone();
178
179        tokio::spawn(async move {
180            'keepalive: loop {
181                debug!("Attempting to reconnect");
182                let mut backoff = Backoff::new(Duration::from_secs(1), Duration::from_secs(30));
183                let mut retry_attempt = 0;
184                'retry: loop {
185                    retry_attempt += 1;
186                    debug!(
187                        "Establish conn: attempt={}, timeout={:?}",
188                        retry_attempt,
189                        backoff.delay()
190                    );
191                    let attempt = handle.try_connect().await;
192
193                    if handle.is_disconnecting() {
194                        break 'keepalive;
195                    }
196
197                    if attempt.did_open {
198                        break 'retry;
199                    }
200
201                    let mut dc_rx = handle.dc_watch.0.subscribe();
202
203                    tokio::select! {
204                        _ = backoff.tick() => {},
205                        _ = dc_rx.wait_for(|x| *x == true) => {
206                            break 'keepalive;
207                        }
208                    }
209                }
210            }
211
212            tx.send(()).ok();
213            handle.disconnection_rx.lock().await.take();
214        });
215    }
216
217    async fn on_open(self: &Arc<Self>, init: &protocol::Init) {
218        debug!("Connected to server: {:?}", init);
219
220        for (event_name, _) in self.event_subscriptions.lock().await.iter() {
221            self.send_subscription(event_name.clone(), true).await;
222        }
223
224        // Flush message queue
225        for msg in self.msg_queue.lock().await.drain(..) {
226            // If its in the queue, it isn't ephemeral, so we pass
227            // default SendMsgOpts
228            self.send_msg(msg, SendMsgOpts::default()).await;
229        }
230    }
231
232    async fn on_message(self: &Arc<Self>, msg: Arc<protocol::ToClient>) {
233        let body = &msg.b;
234
235        match body {
236            protocol::ToClientBody::Init { i: init } => {
237                self.on_open(init).await;
238            }
239            protocol::ToClientBody::ResponseOk { ro } => {
240                let id = ro.i;
241                let mut in_flight_rpcs = self.in_flight_rpcs.lock().await;
242                let Some(tx) = in_flight_rpcs.remove(&id) else {
243                    debug!("Unexpected response: rpc id not found");
244                    return;
245                };
246                if let Err(e) = tx.send(Ok(ro.clone())) {
247                    debug!("{:?}", e);
248                    return;
249                }
250            }
251            protocol::ToClientBody::ResponseError { re } => {
252                let id = re.i;
253                let mut in_flight_rpcs = self.in_flight_rpcs.lock().await;
254                let Some(tx) = in_flight_rpcs.remove(&id) else {
255                    debug!("Unexpected response: rpc id not found");
256                    return;
257                };
258                if let Err(e) = tx.send(Err(re.clone())) {
259                    debug!("{:?}", e);
260                    return;
261                }
262            }
263            protocol::ToClientBody::EventMessage { ev } => {
264                let listeners = self.event_subscriptions.lock().await;
265                if let Some(callbacks) = listeners.get(&ev.n) {
266                    for cb in callbacks {
267                        cb(&ev.a);
268                    }
269                }
270            }
271            protocol::ToClientBody::EventError { er } => {
272                debug!("Event error: {:?}", er);
273            }
274        }
275    }
276
277    async fn send_msg(self: &Arc<Self>, msg: Arc<protocol::ToServer>, opts: SendMsgOpts) {
278        let guard = self.driver.lock().await;
279
280        'send_immediately: {
281            let Some(driver) = guard.deref() else {
282                break 'send_immediately;
283            };
284
285            let Ok(_) = driver.send(msg.clone()).await else {
286                break 'send_immediately;
287            };
288
289            return;
290        }
291
292        // Otherwise queue
293        if opts.ephemeral == false {
294            self.msg_queue.lock().await.push(msg.clone());
295        }
296
297        return;
298    }
299
300    pub async fn action(self: &Arc<Self>, method: &str, params: Vec<Value>) -> Result<Value> {
301        let id: i64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst);
302
303        let (tx, rx) = oneshot::channel();
304        self.in_flight_rpcs.lock().await.insert(id, tx);
305
306        self.send_msg(
307            Arc::new(protocol::ToServer {
308                b: protocol::ToServerBody::RpcRequest {
309                    rr: protocol::RpcRequest {
310                        i: id,
311                        n: method.to_string(),
312                        a: params,
313                    },
314                },
315            }),
316            SendMsgOpts::default(),
317        )
318        .await;
319
320        // TODO: Support reconnection
321        let Ok(res) = rx.await else {
322            // Verbosity
323            return Err(anyhow::anyhow!("Socket closed during rpc"));
324        };
325
326        match res {
327            Ok(ok) => Ok(ok.o),
328            Err(err) => {
329                let metadata = err.md.unwrap_or(Value::Null);
330
331                Err(anyhow::anyhow!(
332                    "RPC Error({}): {:?}, {:#}",
333                    err.c,
334                    err.m,
335                    metadata
336                ))
337            }
338        }
339    }
340
341    async fn send_subscription(self: &Arc<Self>, event_name: String, subscribe: bool) {
342        self.send_msg(
343            Arc::new(protocol::ToServer {
344                b: protocol::ToServerBody::SubscriptionRequest {
345                    sr: protocol::SubscriptionRequest {
346                        e: event_name,
347                        s: subscribe,
348                    },
349                },
350            }),
351            SendMsgOpts { ephemeral: true },
352        )
353        .await;
354    }
355
356    async fn add_event_subscription(
357        self: &Arc<Self>,
358        event_name: String,
359        callback: Box<EventCallback>,
360    ) {
361        // TODO: Support for once
362        let mut listeners = self.event_subscriptions.lock().await;
363
364        let is_new_subscription = listeners.contains_key(&event_name) == false;
365
366        listeners
367            .entry(event_name.clone())
368            .or_insert(Vec::new())
369            .push(callback);
370
371        if is_new_subscription {
372            self.send_subscription(event_name, true).await;
373        }
374    }
375
376    pub async fn on_event<F>(self: &Arc<Self>, event_name: &str, callback: F)
377    where
378        F: Fn(&Vec<Value>) + Send + Sync + 'static,
379    {
380        self.add_event_subscription(event_name.to_string(), Box::new(callback))
381            .await
382    }
383
384    pub async fn disconnect(self: &Arc<Self>) {
385        if self.is_disconnecting() {
386            // We are already disconnecting
387            return;
388        }
389
390        self.dc_watch.0.send(true).ok();
391
392        if let Some(d) = self.driver.lock().await.deref() {
393            d.disconnect()
394        }
395        self.in_flight_rpcs.lock().await.clear();
396        self.event_subscriptions.lock().await.clear();
397        let Some(rx) = self.disconnection_rx.lock().await.take() else {
398            return;
399        };
400
401        rx.await.ok();
402    }
403}