1use crate::transport::error::{NetworkError, NetworkResult};
6use actr_protocol::prost::Message as ProstMessage;
7use actr_protocol::{
8 AIdCredential, ActrId, ActrToSignaling, PeerToSignaling, Ping, RegisterRequest,
9 RegisterResponse, RouteCandidatesRequest, RouteCandidatesResponse, ServiceAvailabilityState,
10 SignalingEnvelope, UnregisterRequest, UnregisterResponse, actr_to_signaling, peer_to_signaling,
11 signaling_envelope, signaling_to_actr,
12};
13use async_trait::async_trait;
14use futures_util::{SinkExt, StreamExt};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::{
18 Arc,
19 atomic::{AtomicBool, AtomicU64, Ordering},
20};
21use tokio::net::TcpStream;
22use tokio::sync::{mpsc, oneshot, watch};
23use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
24use tokio_util::sync::CancellationToken;
25use url::Url;
26
/// Upper bound, in seconds, on how long a request waits for its matching reply envelope.
const RESPONSE_TIMEOUT_SECS: u64 = 5;
33
34#[derive(Debug, Clone)]
40pub struct SignalingConfig {
41 pub server_url: Url,
43
44 pub connection_timeout: u64,
46
47 pub heartbeat_interval: u64,
49
50 pub reconnect_config: ReconnectConfig,
52
53 pub auth_config: Option<AuthConfig>,
55}
56
/// Retry policy used while establishing the signaling connection.
#[derive(Clone, Debug)]
pub struct ReconnectConfig {
    /// When `false`, a single connection attempt is made and its result is returned as-is.
    pub enabled: bool,

    /// Number of failed attempts after which connecting gives up.
    pub max_attempts: u32,

    /// Delay before the first retry, in seconds (values below 1 are treated as 1).
    pub initial_delay: u64,

    /// Upper bound for the delay between retries, in seconds (values below 1 are treated as 1).
    pub max_delay: u64,

    /// Factor the delay is multiplied by after every failed attempt.
    pub backoff_multiplier: f64,
}
75
76impl Default for ReconnectConfig {
77 fn default() -> Self {
78 Self {
79 enabled: true,
80 max_attempts: 10,
81 initial_delay: 1,
82 max_delay: 60,
83 backoff_multiplier: 2.0,
84 }
85 }
86}
87
88#[derive(Debug, Clone)]
90pub struct AuthConfig {
91 pub auth_type: AuthType,
93
94 pub credentials: HashMap<String, String>,
96}
97
/// Authentication scheme selector.
// NOTE(review): no code in this file branches on these variants; their exact wire semantics
// should be confirmed against whatever consumes `AuthConfig`.
#[derive(Clone, Debug)]
pub enum AuthType {
    /// No authentication.
    None,
    /// Bearer-token based authentication.
    BearerToken,
    /// API-key based authentication.
    ApiKey,
    /// JWT based authentication.
    Jwt,
}
110
111#[async_trait]
121pub trait SignalingClient: Send + Sync {
122 async fn connect(&self) -> NetworkResult<()>;
124
125 async fn disconnect(&self) -> NetworkResult<()>;
127
128 async fn send_register_request(
130 &self,
131 request: RegisterRequest,
132 ) -> NetworkResult<RegisterResponse>;
133
134 async fn send_unregister_request(
139 &self,
140 actor_id: ActrId,
141 credential: AIdCredential,
142 reason: Option<String>,
143 ) -> NetworkResult<UnregisterResponse>;
144
145 async fn send_heartbeat(
147 &self,
148 actor_id: ActrId,
149 credential: AIdCredential,
150 availability: ServiceAvailabilityState,
151 power_reserve: f32,
152 mailbox_backlog: f32,
153 ) -> NetworkResult<()>;
154
155 async fn send_route_candidates_request(
157 &self,
158 actor_id: ActrId,
159 credential: AIdCredential,
160 request: RouteCandidatesRequest,
161 ) -> NetworkResult<RouteCandidatesResponse>;
162
163 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()>;
165
166 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>>;
168
169 fn is_connected(&self) -> bool;
171
172 fn get_stats(&self) -> SignalingStats;
174 fn subscribe_state(&self) -> watch::Receiver<ConnectionState>;
176}
177
/// Connection state published to subscribers of a signaling client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No live connection to the signaling server.
    Disconnected,
    /// A connection to the signaling server is established.
    Connected,
}
184
185pub struct WebSocketSignalingClient {
187 config: SignalingConfig,
188 ws_sink: tokio::sync::Mutex<
190 Option<
191 futures_util::stream::SplitSink<
192 WebSocketStream<MaybeTlsStream<TcpStream>>,
193 tokio_tungstenite::tungstenite::Message,
194 >,
195 >,
196 >,
197 ws_stream: tokio::sync::Mutex<
199 Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
200 >,
201 connected: Arc<AtomicBool>,
203 stats: Arc<AtomicSignalingStats>,
205 envelope_counter: tokio::sync::Mutex<u64>,
207 pending_replies: Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<SignalingEnvelope>>>>,
209 inbound_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<SignalingEnvelope>>>,
211 inbound_tx: tokio::sync::Mutex<mpsc::UnboundedSender<SignalingEnvelope>>,
212 receiver_task: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
214 state_tx: watch::Sender<ConnectionState>,
216}
217
218impl WebSocketSignalingClient {
219 pub fn new(config: SignalingConfig) -> Self {
221 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
222 let (state_tx, _state_rx) = watch::channel(ConnectionState::Disconnected);
223 Self {
224 config,
225 ws_sink: tokio::sync::Mutex::new(None),
226 ws_stream: tokio::sync::Mutex::new(None),
227 connected: Arc::new(AtomicBool::new(false)),
228 stats: Arc::new(AtomicSignalingStats::default()),
229 envelope_counter: tokio::sync::Mutex::new(0),
230 pending_replies: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
231 inbound_rx: Arc::new(tokio::sync::Mutex::new(inbound_rx)),
232 inbound_tx: tokio::sync::Mutex::new(inbound_tx),
233 receiver_task: tokio::sync::Mutex::new(None),
234 state_tx,
235 }
236 }
237
238 pub async fn connect_to(url: &str) -> NetworkResult<Self> {
240 let config = SignalingConfig {
241 server_url: url.parse()?,
242 connection_timeout: 30,
243 heartbeat_interval: 30,
244 reconnect_config: ReconnectConfig::default(),
245 auth_config: None,
246 };
247
248 let client = Self::new(config);
249 client.connect().await?;
250 Ok(client)
251 }
252
253 async fn next_envelope_id(&self) -> String {
255 let mut counter = self.envelope_counter.lock().await;
256 *counter += 1;
257 format!("env-{}", *counter)
258 }
259
260 async fn create_envelope(&self, flow: signaling_envelope::Flow) -> SignalingEnvelope {
262 SignalingEnvelope {
263 envelope_version: 1,
264 envelope_id: self.next_envelope_id().await,
265 reply_for: None,
266 timestamp: prost_types::Timestamp {
267 seconds: chrono::Utc::now().timestamp(),
268 nanos: 0,
269 },
270 flow: Some(flow),
271 }
272 }
273
274 async fn reset_inbound_channel(&self) {
276 let (tx, rx) = mpsc::unbounded_channel();
277 *self.inbound_tx.lock().await = tx;
278 *self.inbound_rx.lock().await = rx;
279 }
280
281 async fn establish_connection_once(&self) -> NetworkResult<()> {
285 let url = self.config.server_url.clone();
286 let timeout_secs = self.config.connection_timeout;
287
288 let connect_result = if timeout_secs == 0 {
290 connect_async(url.as_str()).await
291 } else {
292 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
293 tokio::time::timeout(timeout_duration, connect_async(url.as_str()))
294 .await
295 .map_err(|_| {
296 NetworkError::ConnectionError(format!(
297 "Signaling connect timeout after {}s",
298 timeout_secs
299 ))
300 })?
301 }?;
302
303 let (ws_stream, _) = connect_result;
304
305 let (sink, stream) = ws_stream.split();
307
308 *self.ws_sink.lock().await = Some(sink);
309 *self.ws_stream.lock().await = Some(stream);
310 self.connected.store(true, Ordering::Release);
311 let _ = self.state_tx.send(ConnectionState::Connected);
313
314 self.stats.connections.fetch_add(1, Ordering::Relaxed);
315
316 Ok(())
317 }
318
319 async fn connect_with_retries(&self) -> NetworkResult<()> {
321 let cfg = &self.config.reconnect_config;
322
323 if !cfg.enabled {
325 return self.establish_connection_once().await;
326 }
327
328 let mut attempt: u32 = 0;
329 let mut delay_secs = cfg.initial_delay.max(1);
330
331 loop {
332 attempt += 1;
333
334 match self.establish_connection_once().await {
335 Ok(()) => {
336 return Ok(());
337 }
338 Err(e) => {
339 tracing::warn!("Signaling connect attempt {} failed: {e:?}", attempt);
340
341 if attempt >= cfg.max_attempts {
342 tracing::error!(
343 "Signaling connect failed after {} attempts, giving up",
344 attempt
345 );
346 return Err(e);
347 }
348
349 let sleep_secs = delay_secs.min(cfg.max_delay.max(1));
350 tracing::info!("Retry signaling connect after {}s", sleep_secs);
351 tokio::time::sleep(std::time::Duration::from_secs(sleep_secs)).await;
352
353 delay_secs = ((delay_secs as f64) * cfg.backoff_multiplier)
355 .round()
356 .max(1.0) as u64;
357 }
358 }
359 }
360 }
361
362 async fn send_envelope_and_wait_response(
364 &self,
365 envelope: SignalingEnvelope,
366 ) -> NetworkResult<SignalingEnvelope> {
367 let reply_for = envelope.envelope_id.clone();
368
369 let (tx, rx) = oneshot::channel();
371 self.pending_replies
372 .lock()
373 .await
374 .insert(reply_for.clone(), tx);
375
376 if let Err(e) = self.send_envelope(envelope).await {
377 self.pending_replies.lock().await.remove(&reply_for);
379 return Err(e);
380 }
381
382 let result =
383 tokio::time::timeout(std::time::Duration::from_secs(RESPONSE_TIMEOUT_SECS), rx).await;
384 if result.is_err() {
386 self.pending_replies.lock().await.remove(&reply_for);
387 }
388
389 let response_envelope = result
390 .map_err(|_| {
391 NetworkError::ConnectionError(
392 "Timed out waiting for signaling response".to_string(),
393 )
394 })?
395 .map_err(|_| {
396 NetworkError::ConnectionError(
397 "Receiver dropped while waiting for signaling response".to_string(),
398 )
399 })?;
400
401 Ok(response_envelope)
402 }
403
404 async fn start_receiver(&self) {
406 let mut stream_guard = self.ws_stream.lock().await;
407 if stream_guard.is_none() {
408 return;
409 }
410
411 let mut stream = stream_guard.take().expect("stream exists");
412 let pending = self.pending_replies.clone();
413 let inbound_tx = { self.inbound_tx.lock().await.clone() };
414 let stats = self.stats.clone();
415 let connected = self.connected.clone();
416 let state_tx = self.state_tx.clone();
417 tracing::debug!("Start receiver");
418 let handle = tokio::spawn(async move {
419 while let Some(msg) = stream.next().await {
420 match msg {
421 Ok(tokio_tungstenite::tungstenite::Message::Binary(data)) => {
422 match SignalingEnvelope::decode(&data[..]) {
423 Ok(envelope) => {
424 stats.messages_received.fetch_add(1, Ordering::Relaxed);
425 tracing::debug!("Received message: {:?}", envelope);
426 if let Some(reply_for) = envelope.reply_for.clone() {
427 let mut pending_guard = pending.lock().await;
428 if let Some(sender) = pending_guard.remove(&reply_for) {
429 if let Err(e) = sender.send(envelope) {
430 stats.errors.fetch_add(1, Ordering::Relaxed);
431 tracing::warn!(
432 "Failed to send reply envelope to waiter: {e:?}",
433 );
434 }
435 continue;
436 }
437 }
438 tracing::debug!(
439 "Unmatched or push message -> forward to inbound channel"
440 );
441 if let Err(e) = inbound_tx.send(envelope) {
443 stats.errors.fetch_add(1, Ordering::Relaxed);
444 tracing::warn!(
445 "Failed to send envelope to inbound channel: {e:?}"
446 );
447 }
448 }
449 Err(e) => {
450 stats.errors.fetch_add(1, Ordering::Relaxed);
451 tracing::warn!("Failed to decode SignalingEnvelope: {e}");
452 }
453 }
454 }
455 Ok(_) => {
456 tracing::warn!("Received non-binary frame, ignoring");
457 }
458 Err(e) => {
459 stats.errors.fetch_add(1, Ordering::Relaxed);
460 tracing::error!("Signaling receive error: {e}");
461 break;
462 }
463 }
464 }
465
466 connected.store(false, Ordering::Release);
468 stats.disconnections.fetch_add(1, Ordering::Relaxed);
469 let _ = state_tx.send(ConnectionState::Disconnected);
470 });
471
472 *self.receiver_task.lock().await = Some(handle);
473 }
474}
475
476#[async_trait]
477impl SignalingClient for WebSocketSignalingClient {
478 async fn connect(&self) -> NetworkResult<()> {
479 self.connect_with_retries().await?;
480 self.start_receiver().await;
481 Ok(())
482 }
483
484 async fn disconnect(&self) -> NetworkResult<()> {
485 let mut sink_guard = self.ws_sink.lock().await;
487 let mut stream_guard = self.ws_stream.lock().await;
488
489 if let Some(mut sink) = sink_guard.take() {
491 let _ = sink.close().await;
492 }
493
494 stream_guard.take();
496
497 if let Some(handle) = self.receiver_task.lock().await.take() {
499 handle.abort();
500 }
501
502 self.reset_inbound_channel().await;
503
504 self.connected.store(false, Ordering::Release);
505 self.stats.disconnections.fetch_add(1, Ordering::Relaxed);
506
507 Ok(())
508 }
509
510 async fn send_register_request(
511 &self,
512 request: RegisterRequest,
513 ) -> NetworkResult<RegisterResponse> {
514 let flow = signaling_envelope::Flow::PeerToServer(PeerToSignaling {
516 payload: Some(peer_to_signaling::Payload::RegisterRequest(request)),
517 });
518
519 let envelope = self.create_envelope(flow).await;
520 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
521
522 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
523 {
524 if let Some(signaling_to_actr::Payload::RegisterResponse(response)) =
525 server_to_actr.payload
526 {
527 return Ok(response);
528 }
529 }
530
531 Err(NetworkError::ConnectionError(
532 "Invalid registration response".to_string(),
533 ))
534 }
535
536 async fn send_unregister_request(
537 &self,
538 actor_id: ActrId,
539 credential: AIdCredential,
540 reason: Option<String>,
541 ) -> NetworkResult<UnregisterResponse> {
542 let request = UnregisterRequest {
544 actr_id: actor_id.clone(),
545 reason,
546 };
547
548 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
550 source: actor_id,
551 credential,
552 payload: Some(actr_to_signaling::Payload::UnregisterRequest(request)),
553 });
554
555 let envelope = self.create_envelope(flow).await;
557 self.send_envelope(envelope).await?;
558
559 Ok(UnregisterResponse {
564 result: Some(actr_protocol::unregister_response::Result::Success(
565 actr_protocol::unregister_response::UnregisterOk {},
566 )),
567 })
568 }
569
570 async fn send_heartbeat(
571 &self,
572 actor_id: ActrId,
573 credential: AIdCredential,
574 availability: ServiceAvailabilityState,
575 power_reserve: f32,
576 mailbox_backlog: f32,
577 ) -> NetworkResult<()> {
578 let ping = Ping {
579 availability: availability as i32,
580 power_reserve,
581 mailbox_backlog,
582 sticky_client_ids: vec![], };
584
585 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
586 source: actor_id,
587 credential,
588 payload: Some(actr_to_signaling::Payload::Ping(ping)),
589 });
590
591 let envelope = self.create_envelope(flow).await;
592 self.send_envelope(envelope).await
593 }
594
595 async fn send_route_candidates_request(
596 &self,
597 actor_id: ActrId,
598 credential: AIdCredential,
599 request: RouteCandidatesRequest,
600 ) -> NetworkResult<RouteCandidatesResponse> {
601 let flow = signaling_envelope::Flow::ActrToServer(ActrToSignaling {
602 source: actor_id,
603 credential,
604 payload: Some(actr_to_signaling::Payload::RouteCandidatesRequest(request)),
605 });
606
607 let envelope = self.create_envelope(flow).await;
608 let response_envelope = self.send_envelope_and_wait_response(envelope).await?;
609
610 if let Some(signaling_envelope::Flow::ServerToActr(server_to_actr)) = response_envelope.flow
611 {
612 match server_to_actr.payload {
613 Some(signaling_to_actr::Payload::RouteCandidatesResponse(response)) => {
614 return Ok(response);
615 }
616 Some(signaling_to_actr::Payload::Error(err)) => {
617 return Err(NetworkError::ServiceDiscoveryError(format!(
618 "{} ({})",
619 err.message, err.code
620 )));
621 }
622 _ => {}
623 }
624 }
625
626 Err(NetworkError::ConnectionError(
627 "Invalid route candidates response".to_string(),
628 ))
629 }
630
631 async fn send_envelope(&self, envelope: SignalingEnvelope) -> NetworkResult<()> {
632 let mut sink_guard = self.ws_sink.lock().await;
633
634 if let Some(sink) = sink_guard.as_mut() {
635 let mut buf = Vec::new();
637 envelope.encode(&mut buf)?;
638 let msg = tokio_tungstenite::tungstenite::Message::Binary(buf.into());
639 sink.send(msg).await?;
640
641 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
642
643 Ok(())
644 } else {
645 Err(NetworkError::ConnectionError("Not connected".to_string()))
646 }
647 }
648
649 async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
650 let mut rx = self.inbound_rx.lock().await;
651 match rx.recv().await {
652 Some(envelope) => Ok(Some(envelope)),
653 None => {
654 tracing::error!("Inbound channel closed");
655 Err(NetworkError::ConnectionError(
656 "Inbound channel closed".to_string(),
657 ))
658 }
659 }
660 }
661
662 fn is_connected(&self) -> bool {
663 self.connected.load(Ordering::Acquire)
664 }
665
666 fn get_stats(&self) -> SignalingStats {
667 self.stats.snapshot()
668 }
669
670 fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
671 self.state_tx.subscribe()
672 }
673}
674
675pub fn spawn_signaling_reconnector(client: Arc<dyn SignalingClient>, shutdown: CancellationToken) {
678 let mut state_rx = client.subscribe_state();
679
680 tokio::spawn(async move {
681 loop {
682 tokio::select! {
683 _ = shutdown.cancelled() => {
684 tracing::info!("🛑 Stopping signaling reconnect helper due to shutdown");
685 break;
686 }
687 changed = state_rx.changed() => {
688 if changed.is_err() {
689 tracing::info!("Signaling state channel closed, stopping reconnect helper");
690 break;
691 }
692
693 if *state_rx.borrow() == ConnectionState::Disconnected {
694 if shutdown.is_cancelled() {
696 tracing::info!(
697 "Shutdown already requested when disconnect event observed; skipping reconnect"
698 );
699 break;
700 }
701
702 tracing::warn!("📡 Signaling state is DISCONNECTED, attempting reconnect");
703 if let Err(e) = client.connect().await {
704 tracing::error!("❌ Signaling reconnect failed: {e}");
705 } else {
706 tracing::info!("✅ Signaling reconnect succeeded");
707 }
708
709 }
710 }
711 }
712 }
713 });
714}
715
/// Lock-free counters describing the activity of a signaling client.
///
/// `Default` is derived: every counter starts at zero.
#[derive(Debug, Default)]
pub(crate) struct AtomicSignalingStats {
    /// Number of connections established.
    pub connections: AtomicU64,

    /// Number of disconnections recorded.
    pub disconnections: AtomicU64,

    /// Number of envelopes sent.
    pub messages_sent: AtomicU64,

    /// Number of envelopes received and decoded.
    pub messages_received: AtomicU64,

    /// Number of heartbeats sent.
    pub heartbeats_sent: AtomicU64,

    /// Number of heartbeat replies received.
    pub heartbeats_received: AtomicU64,

    /// Number of errors recorded.
    pub errors: AtomicU64,
}
756
757#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
759pub struct SignalingStats {
760 pub connections: u64,
762
763 pub disconnections: u64,
765
766 pub messages_sent: u64,
768
769 pub messages_received: u64,
771
772 pub heartbeats_sent: u64,
774
775 pub heartbeats_received: u64,
777
778 pub errors: u64,
780}
781
782impl AtomicSignalingStats {
783 pub fn snapshot(&self) -> SignalingStats {
785 SignalingStats {
786 connections: self.connections.load(Ordering::Relaxed),
787 disconnections: self.disconnections.load(Ordering::Relaxed),
788 messages_sent: self.messages_sent.load(Ordering::Relaxed),
789 messages_received: self.messages_received.load(Ordering::Relaxed),
790 heartbeats_sent: self.heartbeats_sent.load(Ordering::Relaxed),
791 heartbeats_received: self.heartbeats_received.load(Ordering::Relaxed),
792 errors: self.errors.load(Ordering::Relaxed),
793 }
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as UsizeOrdering};
    use tokio_util::sync::CancellationToken;

    /// Test double that only counts `connect()` calls and exposes a state channel the
    /// tests can drive by hand.
    struct FakeSignalingClient {
        state_tx: watch::Sender<ConnectionState>,
        connect_calls: Arc<AtomicUsize>,
    }

    impl FakeSignalingClient {
        /// Number of `connect()` calls observed so far.
        fn connect_count(&self) -> usize {
            self.connect_calls.load(UsizeOrdering::SeqCst)
        }
    }

    #[async_trait]
    impl SignalingClient for FakeSignalingClient {
        async fn connect(&self) -> NetworkResult<()> {
            self.connect_calls.fetch_add(1, UsizeOrdering::SeqCst);
            Ok(())
        }

        async fn disconnect(&self) -> NetworkResult<()> {
            Ok(())
        }

        async fn send_register_request(
            &self,
            _request: RegisterRequest,
        ) -> NetworkResult<RegisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_unregister_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _reason: Option<String>,
        ) -> NetworkResult<UnregisterResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_heartbeat(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _availability: ServiceAvailabilityState,
            _power_reserve: f32,
            _mailbox_backlog: f32,
        ) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn send_route_candidates_request(
            &self,
            _actor_id: ActrId,
            _credential: AIdCredential,
            _request: RouteCandidatesRequest,
        ) -> NetworkResult<RouteCandidatesResponse> {
            unimplemented!("not needed in tests");
        }

        async fn send_envelope(&self, _envelope: SignalingEnvelope) -> NetworkResult<()> {
            unimplemented!("not needed in tests");
        }

        async fn receive_envelope(&self) -> NetworkResult<Option<SignalingEnvelope>> {
            unimplemented!("not needed in tests");
        }

        fn is_connected(&self) -> bool {
            matches!(*self.state_tx.borrow(), ConnectionState::Connected)
        }

        fn get_stats(&self) -> SignalingStats {
            SignalingStats::default()
        }

        fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
            self.state_tx.subscribe()
        }
    }

    /// Builds a fake client plus a second handle to its state sender for the test to use.
    fn make_fake_client() -> (Arc<FakeSignalingClient>, watch::Sender<ConnectionState>) {
        let (state_tx, _rx) = watch::channel(ConnectionState::Disconnected);
        let fake = FakeSignalingClient {
            state_tx: state_tx.clone(),
            connect_calls: Arc::new(AtomicUsize::new(0)),
        };
        (Arc::new(fake), state_tx)
    }

    /// Starts the reconnect helper for `client` and returns the token that stops it.
    fn start_reconnector(client: &Arc<FakeSignalingClient>) -> CancellationToken {
        let shutdown = CancellationToken::new();
        let as_trait: Arc<dyn SignalingClient> = client.clone();
        spawn_signaling_reconnector(as_trait, shutdown.clone());
        shutdown
    }

    /// Gives the background helper time to react to a published state.
    async fn settle() {
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    }

    #[tokio::test]
    async fn test_spawn_signaling_reconnector_does_not_trigger_on_connected() {
        let (fake_client, state_tx) = make_fake_client();
        let shutdown = start_reconnector(&fake_client);

        let _ = state_tx.send(ConnectionState::Connected);
        settle().await;

        assert_eq!(
            fake_client.connect_count(),
            0,
            "connect() should not be called on Connected state"
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn test_spawn_signaling_reconnector_triggers_connect_on_disconnect() {
        let (fake_client, state_tx) = make_fake_client();
        let shutdown = start_reconnector(&fake_client);

        let _ = state_tx.send(ConnectionState::Disconnected);
        settle().await;

        let calls = fake_client.connect_count();
        assert!(
            calls >= 1,
            "expected at least one reconnect attempt, got {}",
            calls
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn test_spawn_signaling_reconnector_stops_on_shutdown_before_disconnect() {
        let (fake_client, state_tx) = make_fake_client();
        let shutdown = start_reconnector(&fake_client);

        // Shutdown is requested before the disconnect is published.
        shutdown.cancel();
        let _ = state_tx.send(ConnectionState::Disconnected);
        settle().await;

        assert_eq!(
            fake_client.connect_count(),
            0,
            "reconnect helper should not run after shutdown"
        );
    }

    #[test]
    fn test_websocket_signaling_client_initial_state_disconnected() {
        let server_url = Url::parse("ws://example.com/signaling/ws").unwrap();
        let client = WebSocketSignalingClient::new(SignalingConfig {
            server_url,
            connection_timeout: 30,
            heartbeat_interval: 30,
            reconnect_config: ReconnectConfig::default(),
            auth_config: None,
        });

        let state_rx = client.subscribe_state();
        assert_eq!(*state_rx.borrow(), ConnectionState::Disconnected);
    }
}