1#[cfg(feature = "opentelemetry")]
6use super::trace;
7use crate::transport::error::{NetworkError, NetworkResult};
8use actr_protocol::prost::Message as ProstMessage;
9use actr_protocol::{
10 AIdCredential, ActrId, ActrIdExt, ActrToSignaling, PeerToSignaling, Ping, RegisterRequest,
11 RegisterResponse, RouteCandidatesRequest, RouteCandidatesResponse, ServiceAvailabilityState,
12 SignalingEnvelope, UnregisterRequest, UnregisterResponse, actr_to_signaling, peer_to_signaling,
13 signaling_envelope, signaling_to_actr,
14};
15use async_trait::async_trait;
16use futures_util::{SinkExt, StreamExt};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::{
20 Arc,
21 atomic::{AtomicBool, AtomicU64, Ordering},
22};
23use tokio::net::TcpStream;
24use tokio::sync::{mpsc, oneshot, watch};
25use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
26use tokio_util::sync::CancellationToken;
27use tracing::instrument;
28use url::Url;
29
/// Seconds a request/response exchange waits for the matching reply envelope
/// before giving up with a timeout error.
const RESPONSE_TIMEOUT_SECS: u64 = 5;
36
37#[derive(Debug, Clone)]
43pub struct SignalingConfig {
44 pub server_url: Url,
46
47 pub connection_timeout: u64,
49
50 pub heartbeat_interval: u64,
52
53 pub reconnect_config: ReconnectConfig,
55
56 pub auth_config: Option<AuthConfig>,
58}
59
/// Retry policy for establishing the signaling connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// When `false`, exactly one connection attempt is made.
    pub enabled: bool,

    /// Maximum number of attempts; at least one attempt is always made.
    pub max_attempts: u32,

    /// Delay before the first retry, in seconds (values below 1 behave as 1).
    pub initial_delay: u64,

    /// Upper bound on a single retry delay, in seconds (values below 1 behave as 1).
    pub max_delay: u64,

    /// Factor the delay is multiplied by after every failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Retries up to 10 times, starting at 1s and doubling up to a 60s ceiling.
    fn default() -> Self {
        ReconnectConfig {
            enabled: true,
            max_attempts: 10,
            initial_delay: 1,
            max_delay: 60,
            backoff_multiplier: 2.0,
        }
    }
}
90
/// Authentication settings for the signaling connection.
///
/// NOTE(review): no code in this file reads these values; how each field is
/// applied (headers, query parameters, ...) should be confirmed with its consumers.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Which authentication scheme the credentials belong to.
    pub auth_type: AuthType,

    /// Scheme-specific key/value credentials.
    pub credentials: HashMap<String, String>,
}

/// Authentication scheme selector used by [`AuthConfig`].
#[derive(Debug, Clone)]
pub enum AuthType {
    /// No authentication.
    None,

    /// Bearer-token authentication.
    BearerToken,

    /// API-key authentication.
    ApiKey,

    /// JWT authentication.
    Jwt,
}
113
114#[async_trait]
124pub trait SignalingClient: Send + Sync {
125 async fn connect(&self) -> NetworkResult<()>;
127
128 async fn disconnect(&self) -> NetworkResult<()>;
130
131 async fn send_register_request(
133 &self,
134 request: RegisterRequest,
135 ) -> NetworkResult<RegisterResponse>;
136
137 async fn send_unregister_request(
142 &self,
143 actor_id: ActrId,
144 credential: AIdCredential,
145 reason: Option<String>,
146 ) -> NetworkResult<UnregisterResponse>;
147
148 async fn send_heartbeat(
150 &self,
151 actor_id: ActrId,
152 credential: AIdCredential,
153 availability: ServiceAvailabilityState,
154 power_reserve: f32,
155 mailbox_backlog: f32,
156 ) -> NetworkResult<()>;
157
158 async fn send_route_candidates_request(
160 &self,
161 actor_id: ActrId,
162 credential: AIdCredential,
163 request: RouteCandidatesRequest,
164 ) -> NetworkResult<RouteCandidatesResponse>;
165
166 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;
168
169 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;
171
172 fn is_connected(&self) -> bool;
174
175 fn get_stats(&self) -> SignalingStats;
177 fn subscribe_state(&self) -> watch::Receiver<ConnectionState>;
179}
180
/// Coarse connection status published through `SignalingClient::subscribe_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No usable connection; the reconnect helper reacts to this value.
    Disconnected,
    /// A connection has been established.
    Connected,
}
187
188pub struct WebSocketSignalingClient {
190 config: SignalingConfig,
191 ws_sink: tokio::sync::Mutex<
193 Option<
194 futures_util::stream::SplitSink<
195 WebSocketStream<MaybeTlsStream<TcpStream>>,
196 tokio_tungstenite::tungstenite::Message,
197 >,
198 >,
199 >,
200 ws_stream: tokio::sync::Mutex<
202 Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
203 >,
204 connected: Arc<AtomicBool>,
206 stats: Arc<AtomicSignalingStats>,
208 envelope_counter: tokio::sync::Mutex<u64>,
210 pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
212 inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
214 inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
215 receiver_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
217 state_tx: watch::Sender<ConnectionState>,
219}
220
221impl WebSocketSignalingClient {
222 pub fn new(config: SignalingConfig) -> Self {
224 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
225 let (state_tx, _state_rx) = watch::channel(ConnectionState::Disconnected);
226 Self {
227 config,
228 ws_sink: tokio::sync::Mutex::new(None),
229 ws_stream: tokio::sync::Mutex::new(None),
230 connected: Arc::new(AtomicBool::new(false)),
231 stats: Arc::new(AtomicSignalingStats::default()),
232 envelope_counter: tokio::sync::Mutex::new(0),
233 pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
234 inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
235 inbound_tx: tokio::sync::Mutex::new(inbound_tx),
236 receiver_task: tokio::sync::Mutex::new(None),
237 state_tx,
238 }
239 }
240
241 pub async fn connect_to(url: &str) -> NetworkResult<Self> {
243 let config = SignalingConfig {
244 server_url: url.parse()?,
245 connection_timeout: 30,
246 heartbeat_interval: 30,
247 reconnect_config: ReconnectConfig::default(),
248 auth_config: None,
249 };
250
251 let client = Self::new(config);
252 client.connect().await?;
253 Ok(client)
254 }
255
256 async fn next_envelope_id(&self) -> String {
258 let mut counter = self.envelope_counter.lock().await;
259 *counter += 1;
260 format!("env-{}", *counter)
261 }
262
263 async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
265 SignalingEnvelope {
266 envelope_version: 1,
267 envelope_id: self.next_envelope_id().await,
268 reply_for: None,
269 timestamp: prost_types::Timestamp {
270 seconds: chrono::Utc::now().timestamp(),
271 nanos: 0,
272 },
273 traceparent: None,
274 tracestate: None,
275 flow: Some(flow),
276 }
277 }
278
279 async fn reset_inbound_channel(&self) {
281 let (tx, rx) = mpsc::unbounded_channel();
282 *self.inbound_tx.lock().await = tx;
283 *self.inbound_rx.lock().await = rx;
284 }
285
286 async fn establish_connection_once(&self) -> NetworkResult<()> {
290 let url = self.config.server_url.clone();
291 let timeout_secs = self.config.connection_timeout;
292
293 let connect_result = if timeout_secs == 0 {
295 connect_async(url.as_str()).await
296 } else {
297 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
298 tokio::time::timeout(timeout_duration, connect_async(url.as_str()))
299 .await
300 .map_err(|_| {
301 NetworkError::ConnectionError(format!(
302 "Signaling connect timeout after {}s",
303 timeout_secs
304 ))
305 })?
306 }?;
307
308 let (ws_stream, _) = connect_result;
309
310 let (sink, stream) = ws_stream.split();
312
313 *self.ws_sink.lock().await = Some(sink);
314 *self.ws_stream.lock().await = Some(stream);
315 self.connected.store(true, Ordering::Release);
316 let _ = self.state_tx.send(ConnectionState::Connected);
318
319 self.stats.connections.fetch_add(1, Ordering::Relaxed);
320
321 Ok(())
322 }
323
324 async fn connect_with_retries(&self) -> NetworkResult<()> {
326 let cfg = &self.config.reconnect_config;
327
328 if !cfg.enabled {
330 return self.establish_connection_once().await;
331 }
332
333 let mut attempt: u32 = 0;
334 let mut delay_secs = cfg.initial_delay.max(1);
335
336 loop {
337 attempt += 1;
338
339 match self.establish_connection_once().await {
340 Ok(()) => {
341 return Ok(());
342 }
343 Err(e) => {
344 tracing::warn!("Signaling connect attempt {} failed: {e:?}", attempt);
345
346 if attempt >= cfg.max_attempts {
347 tracing::error!(
348 "Signaling connect failed after {} attempts, giving up",
349 attempt
350 );
351 return Err(e);
352 }
353
354 let sleep_secs = delay_secs.min(cfg.max_delay.max(1));
355 tracing::info!("Retry signaling connect after {}s", sleep_secs);
356 tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
357
358 delay_secs = ((delay_secs as f64) * cfg.backoff_multiplier)
360 .round()
361 .max(1.0) as u64;
362 }
363 }
364 }
365 }
366
367 async fn send_envelope_and_wait_response(
369 &self,
370 envelope: SignalingEnvelope,
371 ) -> NetworkResult<SignalingEnvelope> {
372 let reply_for = envelope.envelope_id.clone();
373
374 let (tx, rx) = oneshot::channel();
376 self.pending_replies
377 .lock()
378 .await
379 .insert(reply_for.clone(), tx);
380
381 if let Err(e) = self.send_envelope(envelope).await {
382 self.pending_replies.lock().await.remove(&reply_for);
384 return Err(e);
385 }
386
387 let result =
388 tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
389 if result.is_err() {
391 self.pending_replies.lock().await.remove(&reply_for);
392 }
393
394 let response_envelope = result
395 .map_err(|_| {
396 NetworkError::ConnectionError(
397 "Timed out waiting for signaling response".to_string(),
398 )
399 })?
400 .map_err(|_| {
401 NetworkError::ConnectionError(
402 "Receiver dropped while waiting for signaling response".to_string(),
403 )
404 })?;
405
406 Ok(response_envelope)
407 }
408
409 async fn start_receiver(&self) {
411 let mut stream_guard = self.ws_stream.lock().await;
412 if stream_guard.is_none() {
413 return;
414 }
415
416 let mut stream = stream_guard.take().expect("stream exists");
417 let pending = self.pending_replies.clone();
418 let inbound_tx = { self.inbound_tx.lock().await.clone() };
419 let stats = self.stats.clone();
420 let connected = self.connected.clone();
421 let state_tx = self.state_tx.clone();
422 tracing::debug!("Start receiver");
423 let handle = tokio::spawn(async move {
424 while let Some(msg) = stream.next().await {
425 match msg {
426 Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
427 match SignalingEnvelope::decode(&data[..]) {
428 Ok(envelope) => {
429 stats.messages_received.fetch_add(1, Ordering::Relaxed);
430 tracing::debug!("Received message: {:?}", envelope);
431 if let Some(reply_for) = envelope.reply_for.clone() {
432 let mut pending_guard = pending.lock().await;
433 if let Some(sender) = pending_guard.remove(&reply_for) {
434 if let Err(e) = sender.send(envelope) {
435 stats.errors.fetch_add(1, Ordering::Relaxed);
436 tracing::warn!(
437 "Failed to send reply envelope to waiter: {e:?}",
438 );
439 }
440 continue;
441 }
442 }
443 tracing::debug!(
444 "Unmatched or push message -> forward to inbound channel"
445 );
446 if let Err(e) = inbound_tx.send(envelope) {
448 stats.errors.fetch_add(1, Ordering::Relaxed);
449 tracing::warn!(
450 "Failed to send envelope to inbound channel: {e:?}"
451 );
452 }
453 }
454 Err(e) => {
455 stats.errors.fetch_add(1, Ordering::Relaxed);
456 tracing::warn!("Failed to decode SignalingEnvelope: {e}");
457 }
458 }
459 }
460 Ok(_) => {
461 tracing::warn!("Received non-binary frame, ignoring");
462 }
463 Err(e) => {
464 stats.errors.fetch_add(1, Ordering::Relaxed);
465 tracing::error!("Signaling receive error: {e}");
466 break;
467 }
468 }
469 }
470
471 connected.store(false, Ordering::Release);
473 stats.disconnections.fetch_add(1, Ordering::Relaxed);
474 let _ = state_tx.send(ConnectionState::Disconnected);
475 });
476
477 *self.receiver_task.lock().await = Some(handle);
478 }
479}
480
481#[async_trait]
482impl SignalingClient for WebSocketSignalingClient {
483 async fn connect(&self) -> NetworkResult<()> {
484 self.connect_with_retries().await?;
485 self.start_receiver().await;
486 Ok(())
487 }
488
489 async fn disconnect(&self) -> NetworkResult<()> {
490 let mut sink_guard = self.ws_sink.lock().await;
492 let mut stream_guard = self.ws_stream.lock().await;
493
494 if let Some(mut sink) = sink_guard.take() {
496 let _ = sink.close().await;
497 }
498
499 stream_guard.take();
501
502 if let Some(handle) = self.receiver_task.lock().await.take() {
504 handle.abort();
505 }
506
507 self.reset_inbound_channel().await;
508
509 self.connected.store(false, Ordering::Release);
510 self.stats.disconnections.fetch_add(1, Ordering::Relaxed);
511
512 Ok(())
513 }
514
515 async fn send_register_request(
516 &self,
517 request: RegisterRequest,
518 ) -> NetworkResult<RegisterResponse> {
519 let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
521 payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
522 });
523
524 let envelope = self.create_envelope(flow).await;
525 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
526
527 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
528 {
529 if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
530 server_to_actr.payload
531 {
532 return Ok(response);
533 }
534 }
535
536 Err(NetworkError::ConnectionError(
537 "Invalid registration response".to_string(),
538 ))
539 }
540
541 async fn send_unregister_request(
542 &self,
543 actor_id: ActrId,
544 credential: AIdCredential,
545 reason: Option<String>,
546 ) -> NetworkResult<UnregisterResponse> {
547 let request = UnregisterRequest {
549 actr_id: actor_id.clone(),
550 reason,
551 };
552
553 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
555 source: actor_id,
556 credential,
557 payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
558 });
559
560 let envelope = self.create_envelope(flow).await;
562 self.send_envelope(envelope).await?;
563
564 Ok(UnregisterResponse {
569 result: Some(actr_protocol::unregister_response::Result::Success(
570 actr_protocol::unregister_response::UnregisterOk {},
571 )),
572 })
573 }
574
575 #[instrument(level = "debug", skip_all, fields(actor_id = %actor_id.to_string_repr()))]
576 async fn send_heartbeat(
577 &self,
578 actor_id: ActrId,
579 credential: AIdCredential,
580 availability: ServiceAvailabilityState,
581 power_reserve: f32,
582 mailbox_backlog: f32,
583 ) -> NetworkResult<()> {
584 let ping = Ping {
585 availability: availability as i32,
586 power_reserve,
587 mailbox_backlog,
588 sticky_client_ids: vec![], };
590
591 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
592 source: actor_id,
593 credential,
594 payload: Some(actr_to_signaling::Payload::Ping(ping)),
595 });
596
597 let envelope = self.create_envelope(flow).await;
598 self.send_envelope(envelope).await
599 }
600
601 async fn send_route_candidates_request(
602 &self,
603 actor_id: ActrId,
604 credential: AIdCredential,
605 request: RouteCandidatesRequest,
606 ) -> NetworkResult<RouteCandidatesResponse> {
607 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
608 source: actor_id,
609 credential,
610 payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
611 });
612
613 let envelope = self.create_envelope(flow).await;
614 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
615
616 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
617 {
618 match server_to_actr.payload {
619 Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
620 return Ok(response);
621 }
622 Some(signaling_to_actr::Payload::Error(err)) => {
623 return Err(NetworkError::ServiceDiscoveryError(format!(
624 "{} ({})",
625 err.message, err.code
626 )));
627 }
628 _ => {}
629 }
630 }
631
632 Err(NetworkError::ConnectionError(
633 "Invalid route candidates response".to_string(),
634 ))
635 }
636
637 #[allow(unused_mut)]
638 #[tracing::instrument(
639 level = "debug",
640 skip_all,
641 fields(envelope_id = %envelope.envelope_id)
642 )]
643 async fn send_envelope(&self, mut envelope: SignalingEnvelope) -> NetworkResult<()> {
644 #[cfg(feature = "opentelemetry")]
645 trace::inject_span_context(&tracing::Span::current(), &mut envelope);
646
647 let mut sink_guard = self.ws_sink.lock().await;
648
649 if let Some(sink) = sink_guard.as_mut() {
650 let mut buf = Vec::new();
652 envelope.encode(&mut buf)?;
653 let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
654 sink.send(msg).await?;
655
656 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
657
658 Ok(())
659 } else {
660 Err(NetworkError::ConnectionError("Not connected".to_string()))
661 }
662 }
663
664 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
665 let mut rx = self.inbound_rx.lock().await;
666 match rx.recv().await {
667 Some(envelope) => Ok(Some(envelope)),
668 None => {
669 tracing::error!("Inbound channel closed");
670 Err(NetworkError::ConnectionError(
671 "Inbound channel closed".to_string(),
672 ))
673 }
674 }
675 }
676
677 fn is_connected(&self) -> bool {
678 self.connected.load(Ordering::Acquire)
679 }
680
681 fn get_stats(&self) -> SignalingStats {
682 self.stats.snapshot()
683 }
684
685 fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
686 self.state_tx.subscribe()
687 }
688}
689
690pub fn spawn_signaling_reconnector(client: Arc<dyn SignalingClient>, shutdown: CancellationToken) {
693 let mut state_rx = client.subscribe_state();
694
695 tokio::spawn(async move {
696 loop {
697 tokio::select! {
698 _ = shutdown.cancelled() => {
699 tracing::info!("🛑 Stopping signaling reconnect helper due to shutdown");
700 break;
701 }
702 changed = state_rx.changed() => {
703 if changed.is_err() {
704 tracing::info!("Signaling state channel closed, stopping reconnect helper");
705 break;
706 }
707
708 if *state_rx.borrow() == ConnectionState::Disconnected {
709 if shutdown.is_cancelled() {
711 tracing::info!(
712 "Shutdown already requested when disconnect event observed; skipping reconnect"
713 );
714 break;
715 }
716
717 tracing::warn!("📡 Signaling state is DISCONNECTED, attempting reconnect");
718 if let Err(e) = client.connect().await {
719 tracing::error!("❌ Signaling reconnect failed: {e}");
720 } else {
721 tracing::info!("✅ Signaling reconnect succeeded");
722 }
723
724 }
725 }
726 }
727 }
728 });
729}
730
/// Lock-free counters behind `SignalingStats`; every counter starts at zero.
///
/// `AtomicU64::default()` is zero, so deriving `Default` is equivalent to the
/// former hand-written impl.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Connections established.
    pub connections: AtomicU64,

    /// Connections that ended.
    pub disconnections: AtomicU64,

    /// Envelopes written to the transport.
    pub messages_sent: AtomicU64,

    /// Envelopes received and decoded.
    pub messages_received: AtomicU64,

    /// Heartbeats sent.
    pub heartbeats_sent: AtomicU64,

    /// Heartbeats received. NOTE(review): no code shown in this file increments it.
    pub heartbeats_received: AtomicU64,

    /// Errors observed (send/decode/receive failures).
    pub errors: AtomicU64,
}
771
772#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
774pub struct SignalingStats {
775 pub connections: u64,
777
778 pub disconnections: u64,
780
781 pub messages_sent: u64,
783
784 pub messages_received: u64,
786
787 pub heartbeats_sent: u64,
789
790 pub heartbeats_received: u64,
792
793 pub errors: u64,
795}
796
797impl AtomicSignalingStats {
798 pub fn snapshot(&self) -> SignalingStats {
800 SignalingStats {
801 connections: self.connections.load(Ordering::Relaxed),
802 disconnections: self.disconnections.load(Ordering::Relaxed),
803 messages_sent: self.messages_sent.load(Ordering::Relaxed),
804 messages_received: self.messages_received.load(Ordering::Relaxed),
805 heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
806 heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
807 errors: self.errors.load(Ordering::Relaxed),
808 }
809 }
810}
811
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};
    use std::time::Duration;
    use tokio_util::sync::CancellationToken;

    /// How long each test lets the spawned reconnect helper react.
    const SETTLE_DELAY: Duration = Duration::from_millis(50);

    /// Test double whose state is driven by the test via `state_tx`; it only
    /// records how many times `connect` was invoked.
    struct FakeSignalingClient {
        state_tx: watch::Sender<ConnectionState>,
        connect_calls: Arc<AtomicUsize>,
    }

    impl FakeSignalingClient {
        /// Number of `connect` calls observed so far.
        fn connect_count(&self) -> usize {
            self.connect_calls.load(UsizeOrdering::SeqCst)
        }
    }

    #[async_trait]
    impl SignalingClient for FakeSignalingClient {
        async fn connect(&self) -> NetworkResult<()> {
            self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
            Ok(())
        }

        async fn disconnect(&self) -> NetworkResult<()> {
            Ok(())
        }

        async fn send_register_request(
            &self,
            _request: RegisterRequest,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_unregister_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _reason: Option<String>,
        ) -> NetworkResult<UnregisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_heartbeat(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _availability: ServiceAvailabilityState,
            _power_reserve: f32,
            _mailbox_backlog: f32,
        ) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn send_route_candidates_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _request: RouteCandidatesRequest,
        ) -> NetworkResult<RouteCandidatesResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
            unimplemented!("not needed in tests");
        }

        fn is_connected(&self) -> bool {
            matches!(*self.state_tx.borrow(), ConnectionState::Connected)
        }

        fn get_stats(&self) -> SignalingStats {
            SignalingStats::default()
        }

        fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
            self.state_tx.subscribe()
        }
    }

    /// Builds a fake client (initially `Disconnected`) plus a handle to drive its state.
    fn make_fake_client() -> (Arc<FakeSignalingClient>, watch::Sender<ConnectionState>) {
        let (driver, _unused_rx) = watch::channel(ConnectionState::Disconnected);
        let fake = Arc::new(FakeSignalingClient {
            state_tx: driver.clone(),
            connect_calls: Arc::new(AtomicUsize::new(0)),
        });
        (fake, driver)
    }

    /// Starts the reconnect helper for `fake`, upcast to the trait object.
    fn start_reconnector(fake: &Arc<FakeSignalingClient>, shutdown: &CancellationToken) {
        let as_trait: Arc<dyn SignalingClient> = Arc::clone(fake);
        spawn_signaling_reconnector(as_trait, shutdown.clone());
    }

    #[tokio::test]
    async fn test_spawn_signaling_reconnector_does_not_trigger_on_connected() {
        let (fake, driver) = make_fake_client();
        let shutdown = CancellationToken::new();
        start_reconnector(&fake, &shutdown);

        let _ = driver.send(ConnectionState::Connected);
        tokio::time::sleep(SETTLE_DELAY).await;

        assert_eq!(
            fake.connect_count(),
            0,
            "connect() should not be called on Connected state"
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn test_spawn_signaling_reconnector_triggers_connect_on_disconnect() {
        let (fake, driver) = make_fake_client();
        let shutdown = CancellationToken::new();
        start_reconnector(&fake, &shutdown);

        let _ = driver.send(ConnectionState::Disconnected);
        tokio::time::sleep(SETTLE_DELAY).await;

        let observed = fake.connect_count();
        assert!(
            observed >= 1,
            "expected at least one reconnect attempt, got {}",
            observed
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn test_spawn_signaling_reconnector_stops_on_shutdown_before_disconnect() {
        let (fake, driver) = make_fake_client();
        let shutdown = CancellationToken::new();
        start_reconnector(&fake, &shutdown);

        shutdown.cancel();
        let _ = driver.send(ConnectionState::Disconnected);
        tokio::time::sleep(SETTLE_DELAY).await;

        assert_eq!(
            fake.connect_count(),
            0,
            "reconnect helper should not run after shutdown"
        );
    }

    #[test]
    fn test_websocket_signaling_client_initial_state_disconnected() {
        let client = WebSocketSignalingClient::new(SignalingConfig {
            server_url: Url::parse("ws://example.com/signaling/ws").unwrap(),
            connection_timeout: 30,
            heartbeat_interval: 30,
            reconnect_config: ReconnectConfig::default(),
            auth_config: None,
        });

        let observer = client.subscribe_state();
        assert_eq!(*observer.borrow(), ConnectionState::Disconnected);
    }
}