Skip to main content

clawdentity_core/connector/
client.rs

1use std::cmp;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::time::{Duration, Instant};
5
6use futures_util::{Sink, SinkExt, StreamExt};
7use serde::Serialize;
8use tokio::sync::{mpsc, watch};
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest};
11
12use crate::connector_frames::{
13    CONNECTOR_FRAME_VERSION, ConnectorFrame, HeartbeatAckFrame, HeartbeatFrame, new_frame_id,
14    now_iso, parse_frame, serialize_frame,
15};
16use crate::error::{CoreError, Result};
17
/// Default period between outgoing heartbeat frames (20 seconds).
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(20);
/// Default time a heartbeat may stay unacknowledged before the session is abandoned (15 seconds).
const DEFAULT_HEARTBEAT_ACK_TIMEOUT: Duration = Duration::from_secs(15);
/// Default initial delay between connection attempts (500 milliseconds).
const DEFAULT_RECONNECT_MIN_DELAY: Duration = Duration::from_millis(500);
/// Default upper bound for the delay between connection attempts (15 seconds).
const DEFAULT_RECONNECT_MAX_DELAY: Duration = Duration::from_secs(15);
22
/// Configuration for a connector client: where to connect and how to pace
/// heartbeats and reconnect attempts.
#[derive(Debug, Clone)]
pub struct ConnectorClientOptions {
    /// WebSocket URL of the relay; converted into the client handshake request.
    pub relay_connect_url: String,
    /// Extra `(name, value)` headers set on the handshake request.
    pub headers: Vec<(String, String)>,
    /// Period of the heartbeat ticker inside a session.
    pub heartbeat_interval: Duration,
    /// How long a sent heartbeat may remain unacknowledged before the session
    /// is ended and a reconnect is triggered.
    pub heartbeat_ack_timeout: Duration,
    /// Initial reconnect delay; the delay is also reset to this value after a
    /// session that had been established ends.
    pub reconnect_min_delay: Duration,
    /// Cap applied to the doubling reconnect delay.
    pub reconnect_max_delay: Duration,
}
32
33impl ConnectorClientOptions {
34    /// TODO(clawdentity): document `with_defaults`.
35    pub fn with_defaults(
36        relay_connect_url: impl Into<String>,
37        headers: Vec<(String, String)>,
38    ) -> Self {
39        Self {
40            relay_connect_url: relay_connect_url.into(),
41            headers,
42            heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL,
43            heartbeat_ack_timeout: DEFAULT_HEARTBEAT_ACK_TIMEOUT,
44            reconnect_min_delay: DEFAULT_RECONNECT_MIN_DELAY,
45            reconnect_max_delay: DEFAULT_RECONNECT_MAX_DELAY,
46        }
47    }
48}
49
/// Plain-value, serializable copy of the connector's runtime counters.
#[derive(Debug, Clone, Serialize)]
pub struct ConnectorClientMetricsSnapshot {
    /// `true` while a WebSocket session is established.
    pub connected: bool,
    /// Number of connection attempts started so far (the very first attempt
    /// is counted as well).
    pub reconnect_attempts: u64,
    /// Number of heartbeat frames successfully written to the socket.
    pub heartbeat_sent: u64,
    /// Number of sessions ended because a heartbeat was not acknowledged in time.
    pub heartbeat_ack_timeouts: u64,
}
57
/// Atomic counters shared (via `Arc`) between the background connector task
/// and the handles that report on it.
struct ConnectorClientMetrics {
    /// Set while a WebSocket session is running.
    connected: AtomicBool,
    /// Incremented before every connection attempt.
    reconnect_attempts: AtomicU64,
    /// Incremented after each heartbeat frame is written to the socket.
    heartbeat_sent: AtomicU64,
    /// Incremented each time a pending heartbeat ack exceeds its timeout.
    heartbeat_ack_timeouts: AtomicU64,
}
64
65impl ConnectorClientMetrics {
66    fn new() -> Self {
67        Self {
68            connected: AtomicBool::new(false),
69            reconnect_attempts: AtomicU64::new(0),
70            heartbeat_sent: AtomicU64::new(0),
71            heartbeat_ack_timeouts: AtomicU64::new(0),
72        }
73    }
74
75    fn snapshot(&self) -> ConnectorClientMetricsSnapshot {
76        ConnectorClientMetricsSnapshot {
77            connected: self.connected.load(Ordering::SeqCst),
78            reconnect_attempts: self.reconnect_attempts.load(Ordering::SeqCst),
79            heartbeat_sent: self.heartbeat_sent.load(Ordering::SeqCst),
80            heartbeat_ack_timeouts: self.heartbeat_ack_timeouts.load(Ordering::SeqCst),
81        }
82    }
83}
84
/// Cloneable handle for talking to the background connector task: queue
/// outbound frames, read metrics, and request shutdown.
#[derive(Clone)]
pub struct ConnectorClientSender {
    /// Queue of frames waiting to be written to the socket.
    sender: mpsc::Sender<ConnectorFrame>,
    /// Counters shared with the background task.
    metrics: Arc<ConnectorClientMetrics>,
    /// Shutdown flag; `true` asks the background task to stop.
    shutdown_tx: watch::Sender<bool>,
}
91
92impl ConnectorClientSender {
93    /// TODO(clawdentity): document `send_frame`.
94    pub async fn send_frame(&self, frame: ConnectorFrame) -> Result<()> {
95        self.sender
96            .send(frame)
97            .await
98            .map_err(|_| CoreError::InvalidInput("connector client is not running".to_string()))
99    }
100
101    /// TODO(clawdentity): document `is_connected`.
102    pub fn is_connected(&self) -> bool {
103        self.metrics.connected.load(Ordering::SeqCst)
104    }
105
106    /// TODO(clawdentity): document `metrics_snapshot`.
107    pub fn metrics_snapshot(&self) -> ConnectorClientMetricsSnapshot {
108        self.metrics.snapshot()
109    }
110
111    /// TODO(clawdentity): document `shutdown`.
112    pub fn shutdown(&self) {
113        let _ = self.shutdown_tx.send(true);
114    }
115}
116
/// Owner-side handle of a spawned connector: gives access to the sending
/// handle and receives frames coming from the relay.
pub struct ConnectorClient {
    /// Handle used to queue frames, read metrics, and request shutdown.
    sender: ConnectorClientSender,
    /// Frames forwarded by the background task from the socket.
    inbound_rx: mpsc::Receiver<ConnectorFrame>,
}
121
122impl ConnectorClient {
123    /// TODO(clawdentity): document `sender`.
124    pub fn sender(&self) -> ConnectorClientSender {
125        self.sender.clone()
126    }
127
128    /// TODO(clawdentity): document `recv_frame`.
129    pub async fn recv_frame(&mut self) -> Option<ConnectorFrame> {
130        self.inbound_rx.recv().await
131    }
132}
133
134/// TODO(clawdentity): document `spawn_connector_client`.
135pub fn spawn_connector_client(options: ConnectorClientOptions) -> ConnectorClient {
136    let (outbound_tx, outbound_rx) = mpsc::channel::<ConnectorFrame>(256);
137    let (inbound_tx, inbound_rx) = mpsc::channel::<ConnectorFrame>(256);
138    let (shutdown_tx, shutdown_rx) = watch::channel(false);
139    let metrics = Arc::new(ConnectorClientMetrics::new());
140
141    tokio::spawn(run_connector_loop(
142        options,
143        outbound_rx,
144        inbound_tx,
145        metrics.clone(),
146        shutdown_rx,
147    ));
148
149    ConnectorClient {
150        sender: ConnectorClientSender {
151            sender: outbound_tx,
152            metrics,
153            shutdown_tx,
154        },
155        inbound_rx,
156    }
157}
158
/// Reason a socket session ended, telling the outer loop what to do next.
enum SessionExit {
    /// The connection was lost or became unusable; connect again.
    Reconnect,
    /// Shutdown was requested or the outbound queue was closed; stop the loop.
    Shutdown,
}
163
164#[allow(clippy::too_many_lines)]
165async fn run_connector_loop(
166    options: ConnectorClientOptions,
167    mut outbound_rx: mpsc::Receiver<ConnectorFrame>,
168    inbound_tx: mpsc::Sender<ConnectorFrame>,
169    metrics: Arc<ConnectorClientMetrics>,
170    mut shutdown_rx: watch::Receiver<bool>,
171) {
172    let mut backoff = options.reconnect_min_delay;
173    loop {
174        if *shutdown_rx.borrow() {
175            break;
176        }
177
178        let attempt = metrics.reconnect_attempts.fetch_add(1, Ordering::SeqCst) + 1;
179        tracing::info!(
180            relay_connect_url = %options.relay_connect_url,
181            attempt,
182            "connector websocket connect attempt"
183        );
184        let stream = match connect_socket(&options).await {
185            Ok(stream) => {
186                tracing::info!(
187                    relay_connect_url = %options.relay_connect_url,
188                    "connector websocket connected"
189                );
190                Some(stream)
191            }
192            Err(error) => {
193                tracing::warn!(
194                    relay_connect_url = %options.relay_connect_url,
195                    attempt,
196                    error = %error,
197                    "connector websocket connect failed"
198                );
199                None
200            }
201        };
202        if let Some(stream) = stream {
203            metrics.connected.store(true, Ordering::SeqCst);
204            let exit = run_socket_session(
205                stream,
206                &options,
207                &mut outbound_rx,
208                &inbound_tx,
209                metrics.clone(),
210                &mut shutdown_rx,
211            )
212            .await;
213            metrics.connected.store(false, Ordering::SeqCst);
214
215            match exit {
216                SessionExit::Shutdown => break,
217                SessionExit::Reconnect => {
218                    tracing::warn!(
219                        relay_connect_url = %options.relay_connect_url,
220                        "connector websocket session ended; reconnecting"
221                    );
222                    backoff = options.reconnect_min_delay;
223                }
224            }
225        }
226
227        if *shutdown_rx.borrow() {
228            break;
229        }
230
231        tokio::select! {
232            _ = shutdown_rx.changed() => {
233                if *shutdown_rx.borrow() {
234                    break;
235                }
236            }
237            _ = tokio::time::sleep(backoff) => {}
238        }
239        backoff = next_backoff(backoff, options.reconnect_max_delay);
240    }
241}
242
/// Returns the next reconnect delay: twice `current`, capped at `max`.
///
/// The doubling saturates instead of overflowing. A zero `current` stays zero.
fn next_backoff(current: Duration, max: Duration) -> Duration {
    current.saturating_mul(2).min(max)
}
247
/// Reports whether the pending heartbeat, if any, has waited at least
/// `heartbeat_ack_timeout` for its acknowledgement.
///
/// `pending_heartbeat_ack` holds the heartbeat's frame id and the instant it
/// was sent; `None` means nothing is pending and therefore nothing timed out.
fn heartbeat_ack_timed_out(
    pending_heartbeat_ack: &Option<(String, Instant)>,
    heartbeat_ack_timeout: Duration,
) -> bool {
    match pending_heartbeat_ack {
        Some((_frame_id, sent_at)) => sent_at.elapsed() >= heartbeat_ack_timeout,
        None => false,
    }
}
256
257async fn connect_socket(
258    options: &ConnectorClientOptions,
259) -> Result<
260    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
261> {
262    let mut request = options
263        .relay_connect_url
264        .clone()
265        .into_client_request()
266        .map_err(|error| CoreError::InvalidInput(error.to_string()))?;
267
268    for (name, value) in &options.headers {
269        let header_name =
270            tokio_tungstenite::tungstenite::http::header::HeaderName::from_bytes(name.as_bytes())
271                .map_err(|error| CoreError::InvalidInput(error.to_string()))?;
272        let header_value =
273            tokio_tungstenite::tungstenite::http::header::HeaderValue::from_str(value)
274                .map_err(|error| CoreError::InvalidInput(error.to_string()))?;
275        request.headers_mut().insert(header_name, header_value);
276    }
277
278    let (stream, _response) = connect_async(request)
279        .await
280        .map_err(|error| CoreError::Http(error.to_string()))?;
281    Ok(stream)
282}
283
284#[allow(clippy::too_many_lines)]
285async fn run_socket_session(
286    stream: tokio_tungstenite::WebSocketStream<
287        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
288    >,
289    options: &ConnectorClientOptions,
290    outbound_rx: &mut mpsc::Receiver<ConnectorFrame>,
291    inbound_tx: &mpsc::Sender<ConnectorFrame>,
292    metrics: Arc<ConnectorClientMetrics>,
293    shutdown_rx: &mut watch::Receiver<bool>,
294) -> SessionExit {
295    let (mut write, mut read) = stream.split();
296    let mut heartbeat_tick = tokio::time::interval(options.heartbeat_interval);
297    heartbeat_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
298
299    let mut pending_heartbeat_ack: Option<(String, Instant)> = None;
300
301    loop {
302        tokio::select! {
303            _ = shutdown_rx.changed() => {
304                if *shutdown_rx.borrow() {
305                    let _ = write.send(Message::Close(None)).await;
306                    return SessionExit::Shutdown;
307                }
308            }
309            outbound = outbound_rx.recv() => {
310                let Some(frame) = outbound else {
311                    let _ = write.send(Message::Close(None)).await;
312                    return SessionExit::Shutdown;
313                };
314                let payload = match serialize_frame(&frame) {
315                    Ok(payload) => payload,
316                    Err(_) => continue,
317                };
318                if write.send(Message::Text(payload.into())).await.is_err() {
319                    return SessionExit::Reconnect;
320                }
321            }
322            _ = heartbeat_tick.tick() => {
323                if heartbeat_ack_timed_out(&pending_heartbeat_ack, options.heartbeat_ack_timeout) {
324                    metrics
325                        .heartbeat_ack_timeouts
326                        .fetch_add(1, Ordering::SeqCst);
327                    tracing::warn!("connector heartbeat ack timeout; reconnecting");
328                    return SessionExit::Reconnect;
329                }
330
331                if pending_heartbeat_ack.is_some() {
332                    continue;
333                }
334
335                let heartbeat = ConnectorFrame::Heartbeat(HeartbeatFrame {
336                    v: CONNECTOR_FRAME_VERSION,
337                    id: new_frame_id(),
338                    ts: now_iso(),
339                });
340                let frame_id = match &heartbeat {
341                    ConnectorFrame::Heartbeat(frame) => frame.id.clone(),
342                    _ => String::new(),
343                };
344                let payload = match serialize_frame(&heartbeat) {
345                    Ok(payload) => payload,
346                    Err(_) => continue,
347                };
348                if write.send(Message::Text(payload.into())).await.is_err() {
349                    return SessionExit::Reconnect;
350                }
351                metrics.heartbeat_sent.fetch_add(1, Ordering::SeqCst);
352                pending_heartbeat_ack = Some((frame_id, Instant::now()));
353            }
354            incoming = read.next() => {
355                match incoming {
356                    Some(Ok(Message::Text(text))) => {
357                        if handle_incoming_frame(
358                            &text,
359                            &mut write,
360                            inbound_tx,
361                            &mut pending_heartbeat_ack,
362                        ).await.is_err() {
363                            return SessionExit::Reconnect;
364                        }
365                    }
366                    Some(Ok(Message::Binary(bytes))) => {
367                        if handle_incoming_frame(
368                            &bytes,
369                            &mut write,
370                            inbound_tx,
371                            &mut pending_heartbeat_ack,
372                        ).await.is_err() {
373                            return SessionExit::Reconnect;
374                        }
375                    }
376                    Some(Ok(Message::Ping(payload))) => {
377                        if write.send(Message::Pong(payload)).await.is_err() {
378                            return SessionExit::Reconnect;
379                        }
380                    }
381                    Some(Ok(Message::Close(_))) => {
382                        return SessionExit::Reconnect;
383                    }
384                    Some(Ok(Message::Pong(_))) => {}
385                    Some(Ok(Message::Frame(_))) => {}
386                    Some(Err(_)) | None => {
387                        return SessionExit::Reconnect;
388                    }
389                }
390            }
391        }
392
393        if heartbeat_ack_timed_out(&pending_heartbeat_ack, options.heartbeat_ack_timeout) {
394            metrics
395                .heartbeat_ack_timeouts
396                .fetch_add(1, Ordering::SeqCst);
397            tracing::warn!("connector heartbeat ack timeout; reconnecting");
398            return SessionExit::Reconnect;
399        }
400    }
401}
402
403async fn handle_incoming_frame(
404    payload: impl AsRef<[u8]>,
405    write: &mut (impl Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Unpin),
406    inbound_tx: &mpsc::Sender<ConnectorFrame>,
407    pending_heartbeat_ack: &mut Option<(String, Instant)>,
408) -> Result<()> {
409    let frame = parse_frame(payload)?;
410    match &frame {
411        ConnectorFrame::Heartbeat(heartbeat) => {
412            let ack = ConnectorFrame::HeartbeatAck(HeartbeatAckFrame {
413                v: CONNECTOR_FRAME_VERSION,
414                id: new_frame_id(),
415                ts: now_iso(),
416                ack_id: heartbeat.id.clone(),
417            });
418            let payload = serialize_frame(&ack)?;
419            write
420                .send(Message::Text(payload.into()))
421                .await
422                .map_err(|error| CoreError::Http(error.to_string()))?;
423        }
424        ConnectorFrame::HeartbeatAck(ack) => {
425            if let Some((pending_id, _)) = pending_heartbeat_ack
426                && pending_id == &ack.ack_id
427            {
428                *pending_heartbeat_ack = None;
429            }
430        }
431        _ => {
432            let _ = inbound_tx.send(frame).await;
433        }
434    }
435    Ok(())
436}
437
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::{ConnectorClientOptions, heartbeat_ack_timed_out, spawn_connector_client};

    /// A client pointed at an address that is expected to refuse connections
    /// reports "not connected" and at least one attempt shortly after spawning.
    #[tokio::test]
    async fn client_sender_exposes_default_metrics_snapshot() {
        let client = spawn_connector_client(ConnectorClientOptions::with_defaults(
            "ws://127.0.0.1:9/v1/relay/connect",
            vec![],
        ));
        // Give the background task time to make its first attempt.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let snapshot = client.sender().metrics_snapshot();
        assert!(!snapshot.connected);
        assert!(snapshot.reconnect_attempts >= 1);
        client.sender().shutdown();
    }

    #[test]
    fn heartbeat_ack_timeout_helper_handles_missing_pending_ack() {
        let timed_out = heartbeat_ack_timed_out(&None, Duration::from_secs(15));
        assert!(!timed_out);
    }

    #[test]
    fn heartbeat_ack_timeout_helper_detects_expired_ack() {
        // Age the pending ack by actually waiting rather than computing
        // `Instant::now() - 20s`: subtracting from an `Instant` may panic when
        // the result cannot be represented on the current platform.
        let timeout = Duration::from_millis(1);
        let pending = Some(("hb-1".to_string(), Instant::now()));
        std::thread::sleep(timeout * 5);
        let timed_out = heartbeat_ack_timed_out(&pending, timeout);
        assert!(timed_out);
    }

    #[test]
    fn heartbeat_ack_timeout_helper_allows_recent_ack() {
        let pending = Some(("hb-1".to_string(), Instant::now()));
        let timed_out = heartbeat_ack_timed_out(&pending, Duration::from_secs(15));
        assert!(!timed_out);
    }
}