1use crate::builder::ClientBuilder;
2use crate::client::{Client, ClientRunner};
3use crate::connector::Connector;
4use crate::error::{CoreError, CoreResult};
5use crate::middleware::{MiddlewareContext, WebSocketMiddleware};
6use crate::statistics::{ConnectionStats, StatisticsTracker};
7use crate::traits::AppState;
8use async_trait::async_trait;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{debug, error, info, warn};
13
/// Settings for [`TestingWrapper`] and [`TestingConnector`].
///
/// Only `stats_interval`, `log_stats` and `connection_timeout` are read by code in
/// this module; the remaining fields are carried along for other consumers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestingConfig {
    /// Period between statistics reports logged by `TestingWrapper::start`
    /// (only used when `log_stats` is `true`).
    pub stats_interval: Duration,
    /// Whether `TestingWrapper::start` spawns a task that periodically logs statistics.
    pub log_stats: bool,
    /// NOTE(review): not read by any code in this module; presumably toggles event
    /// tracking elsewhere — confirm.
    pub track_events: bool,
    /// NOTE(review): not read by any code in this module; `None` presumably means
    /// "no limit" — confirm.
    pub max_reconnect_attempts: Option<u32>,
    /// NOTE(review): not read by any code in this module — confirm its consumer.
    pub reconnect_delay: Duration,
    /// Upper bound applied by `TestingConnector::connect` to each connection attempt.
    pub connection_timeout: Duration,
    /// NOTE(review): not read by any code in this module — confirm its consumer.
    pub auto_reconnect: bool,
}

impl Default for TestingConfig {
    /// 30 s stats interval with logging and event tracking on, up to 5 reconnect
    /// attempts 5 s apart, a 10 s connection timeout, and auto-reconnect enabled.
    fn default() -> Self {
        Self {
            stats_interval: Duration::from_secs(30),
            log_stats: true,
            track_events: true,
            max_reconnect_attempts: Some(5),
            reconnect_delay: Duration::from_secs(5),
            connection_timeout: Duration::from_secs(10),
            auto_reconnect: true,
        }
    }
}
46
/// Test harness around a [`Client`] and its [`ClientRunner`] that gathers
/// connection and message statistics in a shared [`StatisticsTracker`].
///
/// Call [`TestingWrapper::start`] to spawn the runner on a background task and
/// [`TestingWrapper::stop`] to shut everything down and obtain the statistics.
pub struct TestingWrapper<S: AppState> {
    /// Client handle; shut down by `stop`.
    client: Client<S>,
    /// Runner, present until `start` moves it into a background task.
    runner: Option<ClientRunner<S>>,
    /// Shared statistics sink; also handed to middleware from `create_middleware`.
    stats: Arc<StatisticsTracker>,
    /// Settings; `start` reads `log_stats` and `stats_interval`.
    config: TestingConfig,
    /// Set to `true` by `start`; cleared by `stop` or when the runner task returns.
    is_running: Arc<std::sync::atomic::AtomicBool>,
    /// Periodic statistics-logging task (only when `log_stats`); aborted by `stop`.
    stats_task: Option<tokio::task::JoinHandle<()>>,
    /// Background task driving the runner; awaited by `stop`.
    runner_task: Option<tokio::task::JoinHandle<()>>,
}
58
/// [`WebSocketMiddleware`] that forwards connection and message events to a
/// shared [`StatisticsTracker`].
pub struct TestingMiddleware<S: AppState> {
    /// Tracker that every middleware hook records into.
    stats: Arc<StatisticsTracker>,
    /// Ties the middleware to the application state type `S` without storing one.
    _phantom: std::marker::PhantomData<S>,
}
64
65impl<S: AppState> TestingMiddleware<S> {
66 pub fn new(stats: Arc<StatisticsTracker>) -> Self {
68 Self {
69 stats,
70 _phantom: std::marker::PhantomData,
71 }
72 }
73}
74
75#[async_trait]
76impl<S: AppState> WebSocketMiddleware<S> for TestingMiddleware<S> {
77 async fn on_connection_attempt(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
78 self.stats.record_connection_attempt().await;
80 debug!(target: "TestingMiddleware", "Connection attempt recorded");
81 Ok(())
82 }
83
84 async fn on_connection_failure(
85 &self,
86 _context: &MiddlewareContext<S>,
87 reason: Option<String>,
88 ) -> CoreResult<()> {
89 self.stats.record_connection_failure(reason).await;
91 debug!(target: "TestingMiddleware", "Connection failure recorded");
92 Ok(())
93 }
94
95 async fn on_connect(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
96 self.stats.record_connection_success().await;
98 debug!(target: "TestingMiddleware", "Connection established");
99 Ok(())
100 }
101
102 async fn on_disconnect(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
103 self.stats
105 .record_disconnection(Some("Connection lost".to_string()))
106 .await;
107 debug!(target: "TestingMiddleware", "Connection lost");
108 Ok(())
109 }
110
111 async fn on_send(&self, message: &Message, _context: &MiddlewareContext<S>) -> CoreResult<()> {
112 self.stats.record_message_sent(message).await;
114 debug!(target: "TestingMiddleware", "Message sent: {} bytes", Self::get_message_size(message));
115 Ok(())
116 }
117
118 async fn on_receive(
119 &self,
120 message: &Message,
121 _context: &MiddlewareContext<S>,
122 ) -> CoreResult<()> {
123 self.stats.record_message_received(message).await;
125 debug!(target: "TestingMiddleware", "Message received: {} bytes", Self::get_message_size(message));
126 Ok(())
127 }
128}
129
130impl<S: AppState> TestingMiddleware<S> {
131 fn get_message_size(message: &Message) -> usize {
133 match message {
134 Message::Text(text) => text.len(),
135 Message::Binary(data) => data.len(),
136 Message::Ping(data) => data.len(),
137 Message::Pong(data) => data.len(),
138 Message::Close(_) => 0,
139 Message::Frame(_) => 0,
140 }
141 }
142}
143
144impl<S: AppState> TestingWrapper<S> {
145 pub fn new(client: Client<S>, runner: ClientRunner<S>, config: TestingConfig) -> Self {
147 let stats = Arc::new(StatisticsTracker::new());
148
149 Self {
150 client,
151 runner: Some(runner),
152 stats,
153 config,
154 is_running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
155 stats_task: None,
156 runner_task: None,
157 }
158 }
159
160 pub fn new_with_stats(
163 client: Client<S>,
164 runner: ClientRunner<S>,
165 config: TestingConfig,
166 stats: Arc<StatisticsTracker>,
167 ) -> Self {
168 Self {
169 client,
170 runner: Some(runner),
171 stats,
172 config,
173 is_running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
174 stats_task: None,
175 runner_task: None,
176 }
177 }
178
179 pub fn create_middleware(&self) -> TestingMiddleware<S> {
181 TestingMiddleware::new(Arc::clone(&self.stats))
182 }
183
184 pub async fn start(&mut self) -> CoreResult<()> {
186 self.is_running
187 .store(true, std::sync::atomic::Ordering::SeqCst);
188
189 if self.config.log_stats {
191 let stats = self.stats.clone();
192 let interval = self.config.stats_interval;
193 let is_running = self.is_running.clone();
194
195 self.stats_task = Some(tokio::spawn(async move {
196 let mut interval = tokio::time::interval(interval);
197 interval.tick().await; while is_running.load(std::sync::atomic::Ordering::SeqCst) {
200 interval.tick().await;
201
202 let stats = stats.get_stats().await;
203 Self::log_statistics(&stats);
204 }
205 }));
206 }
207
208 self.stats.record_connection_attempt().await;
210
211 let runner = self.runner.take().ok_or_else(|| {
214 CoreError::Other("Runner has already been started or consumed".to_string())
215 })?;
216 let stats = self.stats.clone();
217 let is_running = self.is_running.clone();
218
219 self.runner_task = Some(tokio::spawn(async move {
220 let mut runner = runner;
221
222 let result = Self::run_with_stats(&mut runner, stats.clone()).await;
224
225 is_running.store(false, std::sync::atomic::Ordering::SeqCst);
227
228 match result {
229 Ok(_) => {
230 info!("ClientRunner completed successfully");
231 }
232 Err(e) => {
233 error!("ClientRunner failed: {}", e);
234 stats.record_connection_failure(Some(e.to_string())).await;
236 }
237 }
238 }));
239
240 info!("Testing wrapper started successfully");
241 Ok(())
242 }
243
244 async fn run_with_stats(
246 runner: &mut ClientRunner<S>,
247 stats: Arc<StatisticsTracker>,
248 ) -> CoreResult<()> {
249 stats.record_connection_success().await;
256 runner.run().await;
257 Ok(())
258 }
259
260 pub async fn stop(mut self) -> CoreResult<ConnectionStats> {
262 self.is_running
263 .store(false, std::sync::atomic::Ordering::SeqCst);
264
265 if let Some(task) = self.stats_task.take() {
267 task.abort();
268 }
269
270 info!("Sending shutdown command to client...");
273
274 self.stats
276 .record_disconnection(Some("Manual stop".to_string()))
277 .await;
278
279 if let Some(runner_task) = self.runner_task.take() {
284 match tokio::time::timeout(Duration::from_secs(10), runner_task).await {
286 Ok(Ok(())) => {
287 info!("Runner task completed successfully");
288 }
289 Ok(Err(e)) => {
290 if e.is_cancelled() {
291 info!("Runner task was cancelled");
292 } else {
293 error!("Runner task failed: {}", e);
294 }
295 }
296 Err(_) => {
297 warn!("Runner task did not complete within timeout, it may still be running");
298 }
299 }
300 }
301
302 let stats = self.get_stats().await;
303
304 info!("Shutting down client...");
306 self.client.shutdown().await?;
307
308 info!("Testing wrapper stopped");
309 Ok(stats)
310 }
311
312 pub async fn get_stats(&self) -> ConnectionStats {
314 self.stats.get_stats().await
315 }
316
317 pub fn client(&self) -> &Client<S> {
319 &self.client
320 }
321
322 pub fn client_mut(&mut self) -> &mut Client<S> {
324 &mut self.client
325 }
326
327 pub async fn reset_stats(&self) {
329 warn!("Statistics reset requested, but not fully implemented");
334 }
335
336 pub async fn export_stats_json(&self) -> CoreResult<String> {
338 let stats = self.get_stats().await;
339 serde_json::to_string_pretty(&stats)
340 .map_err(|e| CoreError::Other(format!("Failed to serialize stats: {e}")))
341 }
342
343 pub async fn export_stats_csv(&self) -> CoreResult<String> {
345 let stats = self.get_stats().await;
346
347 let mut csv = String::new();
348 csv.push_str("metric,value\n");
349 csv.push_str(&format!(
350 "connection_attempts,{}\n",
351 stats.connection_attempts
352 ));
353 csv.push_str(&format!(
354 "successful_connections,{}\n",
355 stats.successful_connections
356 ));
357 csv.push_str(&format!(
358 "failed_connections,{}\n",
359 stats.failed_connections
360 ));
361 csv.push_str(&format!("disconnections,{}\n", stats.disconnections));
362 csv.push_str(&format!("reconnections,{}\n", stats.reconnections));
363 csv.push_str(&format!(
364 "avg_connection_latency_ms,{}\n",
365 stats.avg_connection_latency_ms
366 ));
367 csv.push_str(&format!(
368 "last_connection_latency_ms,{}\n",
369 stats.last_connection_latency_ms
370 ));
371 csv.push_str(&format!(
372 "total_uptime_seconds,{}\n",
373 stats.total_uptime_seconds
374 ));
375 csv.push_str(&format!(
376 "current_uptime_seconds,{}\n",
377 stats.current_uptime_seconds
378 ));
379 csv.push_str(&format!(
380 "time_since_last_disconnection_seconds,{}\n",
381 stats.time_since_last_disconnection_seconds
382 ));
383 csv.push_str(&format!("messages_sent,{}\n", stats.messages_sent));
384 csv.push_str(&format!("messages_received,{}\n", stats.messages_received));
385 csv.push_str(&format!("bytes_sent,{}\n", stats.bytes_sent));
386 csv.push_str(&format!("bytes_received,{}\n", stats.bytes_received));
387 csv.push_str(&format!(
388 "avg_messages_sent_per_second,{}\n",
389 stats.avg_messages_sent_per_second
390 ));
391 csv.push_str(&format!(
392 "avg_messages_received_per_second,{}\n",
393 stats.avg_messages_received_per_second
394 ));
395 csv.push_str(&format!(
396 "avg_bytes_sent_per_second,{}\n",
397 stats.avg_bytes_sent_per_second
398 ));
399 csv.push_str(&format!(
400 "avg_bytes_received_per_second,{}\n",
401 stats.avg_bytes_received_per_second
402 ));
403 csv.push_str(&format!("is_connected,{}\n", stats.is_connected));
404
405 Ok(csv)
406 }
407
408 fn log_statistics(stats: &ConnectionStats) {
410 info!("=== WebSocket Connection Statistics ===");
411 info!(
412 "Connection Status: {}",
413 if stats.is_connected {
414 "CONNECTED"
415 } else {
416 "DISCONNECTED"
417 }
418 );
419 info!("Connection Attempts: {}", stats.connection_attempts);
420 info!("Successful Connections: {}", stats.successful_connections);
421 info!("Failed Connections: {}", stats.failed_connections);
422 info!("Disconnections: {}", stats.disconnections);
423 info!("Reconnections: {}", stats.reconnections);
424
425 if stats.avg_connection_latency_ms > 0.0 {
426 info!(
427 "Average Connection Latency: {:.2}ms",
428 stats.avg_connection_latency_ms
429 );
430 info!(
431 "Last Connection Latency: {:.2}ms",
432 stats.last_connection_latency_ms
433 );
434 }
435
436 info!("Total Uptime: {:.2}s", stats.total_uptime_seconds);
437 if stats.is_connected {
438 info!(
439 "Current Connection Uptime: {:.2}s",
440 stats.current_uptime_seconds
441 );
442 }
443 if stats.time_since_last_disconnection_seconds > 0.0 {
444 info!(
445 "Time Since Last Disconnection: {:.2}s",
446 stats.time_since_last_disconnection_seconds
447 );
448 }
449
450 info!(
451 "Messages Sent: {} ({:.2}/s)",
452 stats.messages_sent, stats.avg_messages_sent_per_second
453 );
454 info!(
455 "Messages Received: {} ({:.2}/s)",
456 stats.messages_received, stats.avg_messages_received_per_second
457 );
458 info!(
459 "Bytes Sent: {} ({:.2}/s)",
460 stats.bytes_sent, stats.avg_bytes_sent_per_second
461 );
462 info!(
463 "Bytes Received: {} ({:.2}/s)",
464 stats.bytes_received, stats.avg_bytes_received_per_second
465 );
466
467 if stats.connection_attempts > 0 {
468 let success_rate =
469 (stats.successful_connections as f64 / stats.connection_attempts as f64) * 100.0;
470 info!("Connection Success Rate: {:.1}%", success_rate);
471 }
472
473 info!("========================================");
474 }
475}
476
/// [`Connector`] decorator that wraps another connector.
///
/// It applies `config.connection_timeout` to each `connect` call and records
/// attempts, successes, failures and disconnects in a shared [`StatisticsTracker`].
pub struct TestingConnector<C, S> {
    /// Connector that performs the actual connection.
    inner: C,
    /// Tracker that connection outcomes are recorded into.
    stats: Arc<StatisticsTracker>,
    /// Settings; only `connection_timeout` is read by this connector.
    config: TestingConfig,
    /// Ties the connector to the application state type `S` without storing one.
    _phantom: std::marker::PhantomData<S>,
}
484
485impl<C, S> TestingConnector<C, S> {
486 pub fn new(inner: C, stats: Arc<StatisticsTracker>, config: TestingConfig) -> Self {
487 Self {
488 inner,
489 stats,
490 config,
491 _phantom: std::marker::PhantomData,
492 }
493 }
494}
495
496#[async_trait]
497impl<C, S> Connector<S> for TestingConnector<C, S>
498where
499 C: Connector<S> + Send + Sync,
500 S: AppState,
501{
502 async fn connect(
503 &self,
504 state: Arc<S>,
505 ) -> crate::connector::ConnectorResult<crate::connector::WsStream> {
506 self.stats.record_connection_attempt().await;
507
508 let start_time = std::time::Instant::now();
509
510 let result =
512 tokio::time::timeout(self.config.connection_timeout, self.inner.connect(state)).await;
513
514 match result {
515 Ok(Ok(stream)) => {
516 self.stats.record_connection_success().await;
517 debug!("Connection established in {:?}", start_time.elapsed());
518 Ok(stream)
519 }
520 Ok(Err(err)) => {
521 self.stats
522 .record_connection_failure(Some(err.to_string()))
523 .await;
524 error!("Connection failed: {}", err);
525 Err(err)
526 }
527 Err(_) => {
528 let timeout_error = crate::connector::ConnectorError::Timeout;
529 self.stats
530 .record_connection_failure(Some(timeout_error.to_string()))
531 .await;
532 error!(
533 "Connection timed out after {:?}",
534 self.config.connection_timeout
535 );
536 Err(timeout_error)
537 }
538 }
539 }
540
541 async fn disconnect(&self) -> crate::connector::ConnectorResult<()> {
542 self.stats
543 .record_disconnection(Some("Manual disconnect".to_string()))
544 .await;
545 self.inner.disconnect().await
546 }
547}
548
/// Fluent builder for a [`TestingWrapper`], starting from [`TestingConfig::default`].
pub struct TestingWrapperBuilder<S: AppState> {
    /// Configuration accumulated by the `with_*` methods.
    config: TestingConfig,
    /// Ties the builder to the application state type `S` without storing one.
    _phantom: std::marker::PhantomData<S>,
}
554
555impl<S: AppState> TestingWrapperBuilder<S> {
556 pub fn new() -> Self {
557 Self {
558 config: TestingConfig::default(),
559 _phantom: std::marker::PhantomData,
560 }
561 }
562
563 pub fn with_stats_interval(mut self, interval: Duration) -> Self {
564 self.config.stats_interval = interval;
565 self
566 }
567
568 pub fn with_log_stats(mut self, log_stats: bool) -> Self {
569 self.config.log_stats = log_stats;
570 self
571 }
572
573 pub fn with_track_events(mut self, track_events: bool) -> Self {
574 self.config.track_events = track_events;
575 self
576 }
577
578 pub fn with_max_reconnect_attempts(mut self, max_attempts: Option<u32>) -> Self {
579 self.config.max_reconnect_attempts = max_attempts;
580 self
581 }
582
583 pub fn with_reconnect_delay(mut self, delay: Duration) -> Self {
584 self.config.reconnect_delay = delay;
585 self
586 }
587
588 pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
589 self.config.connection_timeout = timeout;
590 self
591 }
592
593 pub fn with_auto_reconnect(mut self, auto_reconnect: bool) -> Self {
594 self.config.auto_reconnect = auto_reconnect;
595 self
596 }
597
598 pub fn build(self, client: Client<S>, runner: ClientRunner<S>) -> TestingWrapper<S> {
599 TestingWrapper::new(client, runner, self.config)
600 }
601
602 pub async fn build_with_middleware(
604 self,
605 builder: ClientBuilder<S>,
606 ) -> CoreResult<TestingWrapper<S>> {
607 let stats = Arc::new(StatisticsTracker::new());
608 let middleware = TestingMiddleware::new(Arc::clone(&stats));
609 let (client, runner) = builder
610 .with_middleware(Box::new(middleware))
611 .build()
612 .await?;
613 let wrapper = TestingWrapper::new_with_stats(client, runner, self.config, stats);
614
615 Ok(wrapper)
616 }
617}
618
619impl<S: AppState> Default for TestingWrapperBuilder<S> {
620 fn default() -> Self {
621 Self::new()
622 }
623}