clawdentity_core/connector/
client.rs1use std::cmp;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::time::{Duration, Instant};
5
6use futures_util::{Sink, SinkExt, StreamExt};
7use serde::Serialize;
8use tokio::sync::{mpsc, watch};
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest};
11
12use crate::connector_frames::{
13 CONNECTOR_FRAME_VERSION, ConnectorFrame, HeartbeatAckFrame, HeartbeatFrame, new_frame_id,
14 now_iso, parse_frame, serialize_frame,
15};
16use crate::error::{CoreError, Result};
17
/// Default period between heartbeats sent to the relay.
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(20);
/// Default time a heartbeat may wait for its ack before the session is recycled.
const DEFAULT_HEARTBEAT_ACK_TIMEOUT: Duration = Duration::from_secs(15);
/// Default delay before the first redial (and after a session that had connected).
const DEFAULT_RECONNECT_MIN_DELAY: Duration = Duration::from_millis(500);
/// Default ceiling for the doubling reconnect backoff.
const DEFAULT_RECONNECT_MAX_DELAY: Duration = Duration::from_secs(15);

/// Configuration for [`spawn_connector_client`].
#[derive(Debug, Clone)]
pub struct ConnectorClientOptions {
    /// URL dialed for the relay websocket handshake.
    pub relay_connect_url: String,
    /// Extra HTTP headers added to the handshake request. When a name appears
    /// more than once, only its last value is sent.
    pub headers: Vec<(String, String)>,
    /// How often a heartbeat is sent while connected; at most one heartbeat is
    /// outstanding at a time. Must be non-zero (`tokio::time::interval` panics
    /// on a zero period).
    pub heartbeat_interval: Duration,
    /// How long a heartbeat may go unacknowledged before the session is
    /// dropped and redialed. Expiry is noticed on the next heartbeat tick or
    /// socket event, so detection can lag by up to `heartbeat_interval`.
    pub heartbeat_ack_timeout: Duration,
    /// First reconnect delay; the backoff also resets to this after a session
    /// that managed to connect.
    pub reconnect_min_delay: Duration,
    /// Upper bound for the doubling reconnect backoff.
    pub reconnect_max_delay: Duration,
}

impl ConnectorClientOptions {
    /// Builds options for `relay_connect_url` with the given handshake
    /// `headers` and the module's default heartbeat and reconnect timings.
    pub fn with_defaults(
        relay_connect_url: impl Into<String>,
        headers: Vec<(String, String)>,
    ) -> Self {
        let relay_connect_url: String = relay_connect_url.into();
        Self {
            relay_connect_url,
            headers,
            heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL,
            heartbeat_ack_timeout: DEFAULT_HEARTBEAT_ACK_TIMEOUT,
            reconnect_min_delay: DEFAULT_RECONNECT_MIN_DELAY,
            reconnect_max_delay: DEFAULT_RECONNECT_MAX_DELAY,
        }
    }
}
49
50#[derive(Debug, Clone, Serialize)]
51pub struct ConnectorClientMetricsSnapshot {
52 pub connected: bool,
53 pub reconnect_attempts: u64,
54 pub heartbeat_sent: u64,
55 pub heartbeat_ack_timeouts: u64,
56}
57
58struct ConnectorClientMetrics {
59 connected: AtomicBool,
60 reconnect_attempts: AtomicU64,
61 heartbeat_sent: AtomicU64,
62 heartbeat_ack_timeouts: AtomicU64,
63}
64
65impl ConnectorClientMetrics {
66 fn new() -> Self {
67 Self {
68 connected: AtomicBool::new(false),
69 reconnect_attempts: AtomicU64::new(0),
70 heartbeat_sent: AtomicU64::new(0),
71 heartbeat_ack_timeouts: AtomicU64::new(0),
72 }
73 }
74
75 fn snapshot(&self) -> ConnectorClientMetricsSnapshot {
76 ConnectorClientMetricsSnapshot {
77 connected: self.connected.load(Ordering::SeqCst),
78 reconnect_attempts: self.reconnect_attempts.load(Ordering::SeqCst),
79 heartbeat_sent: self.heartbeat_sent.load(Ordering::SeqCst),
80 heartbeat_ack_timeouts: self.heartbeat_ack_timeouts.load(Ordering::SeqCst),
81 }
82 }
83}
84
85#[derive(Clone)]
86pub struct ConnectorClientSender {
87 sender: mpsc::Sender<ConnectorFrame>,
88 metrics: Arc<ConnectorClientMetrics>,
89 shutdown_tx: watch::Sender<bool>,
90}
91
92impl ConnectorClientSender {
93 pub async fn send_frame(&self, frame: ConnectorFrame) -> Result<()> {
95 self.sender
96 .send(frame)
97 .await
98 .map_err(|_| CoreError::InvalidInput("connector client is not running".to_string()))
99 }
100
101 pub fn is_connected(&self) -> bool {
103 self.metrics.connected.load(Ordering::SeqCst)
104 }
105
106 pub fn metrics_snapshot(&self) -> ConnectorClientMetricsSnapshot {
108 self.metrics.snapshot()
109 }
110
111 pub fn shutdown(&self) {
113 let _ = self.shutdown_tx.send(true);
114 }
115}
116
117pub struct ConnectorClient {
118 sender: ConnectorClientSender,
119 inbound_rx: mpsc::Receiver<ConnectorFrame>,
120}
121
122impl ConnectorClient {
123 pub fn sender(&self) -> ConnectorClientSender {
125 self.sender.clone()
126 }
127
128 pub async fn recv_frame(&mut self) -> Option<ConnectorFrame> {
130 self.inbound_rx.recv().await
131 }
132}
133
134pub fn spawn_connector_client(options: ConnectorClientOptions) -> ConnectorClient {
136 let (outbound_tx, outbound_rx) = mpsc::channel::<ConnectorFrame>(256);
137 let (inbound_tx, inbound_rx) = mpsc::channel::<ConnectorFrame>(256);
138 let (shutdown_tx, shutdown_rx) = watch::channel(false);
139 let metrics = Arc::new(ConnectorClientMetrics::new());
140
141 tokio::spawn(run_connector_loop(
142 options,
143 outbound_rx,
144 inbound_tx,
145 metrics.clone(),
146 shutdown_rx,
147 ));
148
149 ConnectorClient {
150 sender: ConnectorClientSender {
151 sender: outbound_tx,
152 metrics,
153 shutdown_tx,
154 },
155 inbound_rx,
156 }
157}
158
159enum SessionExit {
160 Reconnect,
161 Shutdown,
162}
163
164#[allow(clippy::too_many_lines)]
165async fn run_connector_loop(
166 options: ConnectorClientOptions,
167 mut outbound_rx: mpsc::Receiver<ConnectorFrame>,
168 inbound_tx: mpsc::Sender<ConnectorFrame>,
169 metrics: Arc<ConnectorClientMetrics>,
170 mut shutdown_rx: watch::Receiver<bool>,
171) {
172 let mut backoff = options.reconnect_min_delay;
173 loop {
174 if *shutdown_rx.borrow() {
175 break;
176 }
177
178 let attempt = metrics.reconnect_attempts.fetch_add(1, Ordering::SeqCst) + 1;
179 tracing::info!(
180 relay_connect_url = %options.relay_connect_url,
181 attempt,
182 "connector websocket connect attempt"
183 );
184 let stream = match connect_socket(&options).await {
185 Ok(stream) => {
186 tracing::info!(
187 relay_connect_url = %options.relay_connect_url,
188 "connector websocket connected"
189 );
190 Some(stream)
191 }
192 Err(error) => {
193 tracing::warn!(
194 relay_connect_url = %options.relay_connect_url,
195 attempt,
196 error = %error,
197 "connector websocket connect failed"
198 );
199 None
200 }
201 };
202 if let Some(stream) = stream {
203 metrics.connected.store(true, Ordering::SeqCst);
204 let exit = run_socket_session(
205 stream,
206 &options,
207 &mut outbound_rx,
208 &inbound_tx,
209 metrics.clone(),
210 &mut shutdown_rx,
211 )
212 .await;
213 metrics.connected.store(false, Ordering::SeqCst);
214
215 match exit {
216 SessionExit::Shutdown => break,
217 SessionExit::Reconnect => {
218 tracing::warn!(
219 relay_connect_url = %options.relay_connect_url,
220 "connector websocket session ended; reconnecting"
221 );
222 backoff = options.reconnect_min_delay;
223 }
224 }
225 }
226
227 if *shutdown_rx.borrow() {
228 break;
229 }
230
231 tokio::select! {
232 _ = shutdown_rx.changed() => {
233 if *shutdown_rx.borrow() {
234 break;
235 }
236 }
237 _ = tokio::time::sleep(backoff) => {}
238 }
239 backoff = next_backoff(backoff, options.reconnect_max_delay);
240 }
241}
242
/// Returns the next reconnect delay: `current` doubled, capped at `max`.
///
/// Overflow saturates to `Duration::MAX` before the cap is applied, so huge
/// delays clamp to `max` instead of panicking. A zero `current` stays zero.
fn next_backoff(current: Duration, max: Duration) -> Duration {
    let doubled = current.checked_mul(2).unwrap_or(Duration::MAX);
    cmp::min(doubled, max)
}
247
/// Reports whether the outstanding heartbeat has waited at least
/// `heartbeat_ack_timeout` for its ack.
///
/// `pending_heartbeat_ack` holds the id and send time of the heartbeat still
/// awaiting an ack; `None` means nothing is outstanding, which never times out.
fn heartbeat_ack_timed_out(
    pending_heartbeat_ack: &Option<(String, Instant)>,
    heartbeat_ack_timeout: Duration,
) -> bool {
    let Some((_, sent_at)) = pending_heartbeat_ack else {
        return false;
    };
    sent_at.elapsed() >= heartbeat_ack_timeout
}
256
257async fn connect_socket(
258 options: &ConnectorClientOptions,
259) -> Result<
260 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
261> {
262 let mut request = options
263 .relay_connect_url
264 .clone()
265 .into_client_request()
266 .map_err(|error| CoreError::InvalidInput(error.to_string()))?;
267
268 for (name, value) in &options.headers {
269 let header_name =
270 tokio_tungstenite::tungstenite::http::header::HeaderName::from_bytes(name.as_bytes())
271 .map_err(|error| CoreError::InvalidInput(error.to_string()))?;
272 let header_value =
273 tokio_tungstenite::tungstenite::http::header::HeaderValue::from_str(value)
274 .map_err(|error| CoreError::InvalidInput(error.to_string()))?;
275 request.headers_mut().insert(header_name, header_value);
276 }
277
278 let (stream, _response) = connect_async(request)
279 .await
280 .map_err(|error| CoreError::Http(error.to_string()))?;
281 Ok(stream)
282}
283
284#[allow(clippy::too_many_lines)]
285async fn run_socket_session(
286 stream: tokio_tungstenite::WebSocketStream<
287 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
288 >,
289 options: &ConnectorClientOptions,
290 outbound_rx: &mut mpsc::Receiver<ConnectorFrame>,
291 inbound_tx: &mpsc::Sender<ConnectorFrame>,
292 metrics: Arc<ConnectorClientMetrics>,
293 shutdown_rx: &mut watch::Receiver<bool>,
294) -> SessionExit {
295 let (mut write, mut read) = stream.split();
296 let mut heartbeat_tick = tokio::time::interval(options.heartbeat_interval);
297 heartbeat_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
298
299 let mut pending_heartbeat_ack: Option<(String, Instant)> = None;
300
301 loop {
302 tokio::select! {
303 _ = shutdown_rx.changed() => {
304 if *shutdown_rx.borrow() {
305 let _ = write.send(Message::Close(None)).await;
306 return SessionExit::Shutdown;
307 }
308 }
309 outbound = outbound_rx.recv() => {
310 let Some(frame) = outbound else {
311 let _ = write.send(Message::Close(None)).await;
312 return SessionExit::Shutdown;
313 };
314 let payload = match serialize_frame(&frame) {
315 Ok(payload) => payload,
316 Err(_) => continue,
317 };
318 if write.send(Message::Text(payload.into())).await.is_err() {
319 return SessionExit::Reconnect;
320 }
321 }
322 _ = heartbeat_tick.tick() => {
323 if heartbeat_ack_timed_out(&pending_heartbeat_ack, options.heartbeat_ack_timeout) {
324 metrics
325 .heartbeat_ack_timeouts
326 .fetch_add(1, Ordering::SeqCst);
327 tracing::warn!("connector heartbeat ack timeout; reconnecting");
328 return SessionExit::Reconnect;
329 }
330
331 if pending_heartbeat_ack.is_some() {
332 continue;
333 }
334
335 let heartbeat = ConnectorFrame::Heartbeat(HeartbeatFrame {
336 v: CONNECTOR_FRAME_VERSION,
337 id: new_frame_id(),
338 ts: now_iso(),
339 });
340 let frame_id = match &heartbeat {
341 ConnectorFrame::Heartbeat(frame) => frame.id.clone(),
342 _ => String::new(),
343 };
344 let payload = match serialize_frame(&heartbeat) {
345 Ok(payload) => payload,
346 Err(_) => continue,
347 };
348 if write.send(Message::Text(payload.into())).await.is_err() {
349 return SessionExit::Reconnect;
350 }
351 metrics.heartbeat_sent.fetch_add(1, Ordering::SeqCst);
352 pending_heartbeat_ack = Some((frame_id, Instant::now()));
353 }
354 incoming = read.next() => {
355 match incoming {
356 Some(Ok(Message::Text(text))) => {
357 if handle_incoming_frame(
358 &text,
359 &mut write,
360 inbound_tx,
361 &mut pending_heartbeat_ack,
362 ).await.is_err() {
363 return SessionExit::Reconnect;
364 }
365 }
366 Some(Ok(Message::Binary(bytes))) => {
367 if handle_incoming_frame(
368 &bytes,
369 &mut write,
370 inbound_tx,
371 &mut pending_heartbeat_ack,
372 ).await.is_err() {
373 return SessionExit::Reconnect;
374 }
375 }
376 Some(Ok(Message::Ping(payload))) => {
377 if write.send(Message::Pong(payload)).await.is_err() {
378 return SessionExit::Reconnect;
379 }
380 }
381 Some(Ok(Message::Close(_))) => {
382 return SessionExit::Reconnect;
383 }
384 Some(Ok(Message::Pong(_))) => {}
385 Some(Ok(Message::Frame(_))) => {}
386 Some(Err(_)) | None => {
387 return SessionExit::Reconnect;
388 }
389 }
390 }
391 }
392
393 if heartbeat_ack_timed_out(&pending_heartbeat_ack, options.heartbeat_ack_timeout) {
394 metrics
395 .heartbeat_ack_timeouts
396 .fetch_add(1, Ordering::SeqCst);
397 tracing::warn!("connector heartbeat ack timeout; reconnecting");
398 return SessionExit::Reconnect;
399 }
400 }
401}
402
403async fn handle_incoming_frame(
404 payload: impl AsRef<[u8]>,
405 write: &mut (impl Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Unpin),
406 inbound_tx: &mpsc::Sender<ConnectorFrame>,
407 pending_heartbeat_ack: &mut Option<(String, Instant)>,
408) -> Result<()> {
409 let frame = parse_frame(payload)?;
410 match &frame {
411 ConnectorFrame::Heartbeat(heartbeat) => {
412 let ack = ConnectorFrame::HeartbeatAck(HeartbeatAckFrame {
413 v: CONNECTOR_FRAME_VERSION,
414 id: new_frame_id(),
415 ts: now_iso(),
416 ack_id: heartbeat.id.clone(),
417 });
418 let payload = serialize_frame(&ack)?;
419 write
420 .send(Message::Text(payload.into()))
421 .await
422 .map_err(|error| CoreError::Http(error.to_string()))?;
423 }
424 ConnectorFrame::HeartbeatAck(ack) => {
425 if let Some((pending_id, _)) = pending_heartbeat_ack
426 && pending_id == &ack.ack_id
427 {
428 *pending_heartbeat_ack = None;
429 }
430 }
431 _ => {
432 let _ = inbound_tx.send(frame).await;
433 }
434 }
435 Ok(())
436}
437
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::{
        ConnectorClientOptions, heartbeat_ack_timed_out, next_backoff, spawn_connector_client,
    };

    #[tokio::test]
    async fn client_sender_exposes_default_metrics_snapshot() {
        // Assumes nothing accepts websocket handshakes on local port 9, so the
        // client stays disconnected while it retries.
        let client = spawn_connector_client(ConnectorClientOptions::with_defaults(
            "ws://127.0.0.1:9/v1/relay/connect",
            vec![],
        ));
        let sender = client.sender();

        // Poll instead of one fixed short sleep: on a loaded machine the
        // background task may not have made its first attempt within 50 ms.
        let deadline = Instant::now() + Duration::from_secs(5);
        while sender.metrics_snapshot().reconnect_attempts == 0 && Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        let snapshot = sender.metrics_snapshot();
        assert!(!snapshot.connected);
        assert!(snapshot.reconnect_attempts >= 1);
        sender.shutdown();
    }

    #[test]
    fn heartbeat_ack_timeout_helper_handles_missing_pending_ack() {
        let timed_out = heartbeat_ack_timed_out(&None, Duration::from_secs(15));
        assert!(!timed_out);
    }

    #[test]
    fn heartbeat_ack_timeout_helper_detects_expired_ack() {
        let pending = Some(("hb-1".to_string(), Instant::now() - Duration::from_secs(20)));
        let timed_out = heartbeat_ack_timed_out(&pending, Duration::from_secs(15));
        assert!(timed_out);
    }

    #[test]
    fn heartbeat_ack_timeout_helper_allows_recent_ack() {
        let pending = Some(("hb-1".to_string(), Instant::now()));
        let timed_out = heartbeat_ack_timed_out(&pending, Duration::from_secs(15));
        assert!(!timed_out);
    }

    #[test]
    fn next_backoff_doubles_until_capped() {
        let max = Duration::from_secs(15);
        assert_eq!(next_backoff(Duration::from_millis(500), max), Duration::from_secs(1));
        assert_eq!(next_backoff(Duration::from_secs(10), max), max);
        assert_eq!(next_backoff(max, max), max);
    }

    #[test]
    fn next_backoff_saturates_instead_of_overflowing() {
        assert_eq!(next_backoff(Duration::MAX, Duration::MAX), Duration::MAX);
    }
}