Skip to main content

binary_options_tools_core_pre/
testing.rs

1use crate::builder::ClientBuilder;
2use crate::client::{Client, ClientRunner};
3use crate::connector::Connector;
4use crate::error::{CoreError, CoreResult};
5use crate::middleware::{MiddlewareContext, WebSocketMiddleware};
6use crate::statistics::{ConnectionStats, StatisticsTracker};
7use crate::traits::AppState;
8use async_trait::async_trait;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{debug, error, info, warn};
13
/// Settings shared by [`TestingWrapper`] and [`TestingConnector`].
#[derive(Debug, Clone)]
pub struct TestingConfig {
    /// Period between two statistics reports of the logging task spawned by
    /// `TestingWrapper::start`.
    pub stats_interval: Duration,
    /// When `true`, `TestingWrapper::start` spawns the periodic logging task.
    pub log_stats: bool,
    /// Whether detailed connection events should be tracked.
    /// NOTE(review): not read anywhere in this module — confirm it is consumed elsewhere.
    pub track_events: bool,
    /// Upper bound on reconnection attempts (`None` presumably means "no limit").
    /// NOTE(review): not read anywhere in this module — confirm it is consumed elsewhere.
    pub max_reconnect_attempts: Option<u32>,
    /// Pause between two reconnection attempts.
    /// NOTE(review): not read anywhere in this module — confirm it is consumed elsewhere.
    pub reconnect_delay: Duration,
    /// Deadline applied to each `TestingConnector::connect` call.
    pub connection_timeout: Duration,
    /// Whether a dropped connection should be re-established automatically.
    /// NOTE(review): not read anywhere in this module — confirm it is consumed elsewhere.
    pub auto_reconnect: bool,
}
32
33impl Default for TestingConfig {
34    fn default() -> Self {
35        Self {
36            stats_interval: Duration::from_secs(30),
37            log_stats: true,
38            track_events: true,
39            max_reconnect_attempts: Some(5),
40            reconnect_delay: Duration::from_secs(5),
41            connection_timeout: Duration::from_secs(10),
42            auto_reconnect: true,
43        }
44    }
45}
46
47/// A testing wrapper around the Client that provides comprehensive statistics
48/// and monitoring capabilities for WebSocket connections.
49pub struct TestingWrapper<S: AppState> {
50    client: Client<S>,
51    runner: Option<ClientRunner<S>>,
52    stats: Arc<StatisticsTracker>,
53    config: TestingConfig,
54    is_running: Arc<std::sync::atomic::AtomicBool>,
55    stats_task: Option<tokio::task::JoinHandle<()>>,
56    runner_task: Option<tokio::task::JoinHandle<()>>,
57}
58
59/// A testing middleware that tracks connection statistics using the shared StatisticsTracker
60pub struct TestingMiddleware<S: AppState> {
61    stats: Arc<StatisticsTracker>,
62    _phantom: std::marker::PhantomData<S>,
63}
64
65impl<S: AppState> TestingMiddleware<S> {
66    /// Create a new testing middleware with the provided StatisticsTracker
67    pub fn new(stats: Arc<StatisticsTracker>) -> Self {
68        Self {
69            stats,
70            _phantom: std::marker::PhantomData,
71        }
72    }
73}
74
75#[async_trait]
76impl<S: AppState> WebSocketMiddleware<S> for TestingMiddleware<S> {
77    async fn on_connection_attempt(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
78        // 🎯 This is the missing piece!
79        self.stats.record_connection_attempt().await;
80        debug!(target: "TestingMiddleware", "Connection attempt recorded");
81        Ok(())
82    }
83
84    async fn on_connection_failure(
85        &self,
86        _context: &MiddlewareContext<S>,
87        reason: Option<String>,
88    ) -> CoreResult<()> {
89        // 🎯 This will give you proper failure tracking
90        self.stats.record_connection_failure(reason).await;
91        debug!(target: "TestingMiddleware", "Connection failure recorded");
92        Ok(())
93    }
94
95    async fn on_connect(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
96        // This calls record_connection_success - already implemented
97        self.stats.record_connection_success().await;
98        debug!(target: "TestingMiddleware", "Connection established");
99        Ok(())
100    }
101
102    async fn on_disconnect(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
103        // Record disconnection with reason
104        self.stats
105            .record_disconnection(Some("Connection lost".to_string()))
106            .await;
107        debug!(target: "TestingMiddleware", "Connection lost");
108        Ok(())
109    }
110
111    async fn on_send(&self, message: &Message, _context: &MiddlewareContext<S>) -> CoreResult<()> {
112        // Record message sent with size tracking
113        self.stats.record_message_sent(message).await;
114        debug!(target: "TestingMiddleware", "Message sent: {} bytes", Self::get_message_size(message));
115        Ok(())
116    }
117
118    async fn on_receive(
119        &self,
120        message: &Message,
121        _context: &MiddlewareContext<S>,
122    ) -> CoreResult<()> {
123        // Record message received with size tracking
124        self.stats.record_message_received(message).await;
125        debug!(target: "TestingMiddleware", "Message received: {} bytes", Self::get_message_size(message));
126        Ok(())
127    }
128}
129
130impl<S: AppState> TestingMiddleware<S> {
131    /// Get the size of a message in bytes
132    fn get_message_size(message: &Message) -> usize {
133        match message {
134            Message::Text(text) => text.len(),
135            Message::Binary(data) => data.len(),
136            Message::Ping(data) => data.len(),
137            Message::Pong(data) => data.len(),
138            Message::Close(_) => 0,
139            Message::Frame(_) => 0,
140        }
141    }
142}
143
144impl<S: AppState> TestingWrapper<S> {
145    /// Create a new testing wrapper with the provided client and runner
146    pub fn new(client: Client<S>, runner: ClientRunner<S>, config: TestingConfig) -> Self {
147        let stats = Arc::new(StatisticsTracker::new());
148
149        Self {
150            client,
151            runner: Some(runner),
152            stats,
153            config,
154            is_running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
155            stats_task: None,
156            runner_task: None,
157        }
158    }
159
160    /// Create a new testing wrapper with a shared StatisticsTracker
161    /// This is useful when you want to share statistics between multiple components
162    pub fn new_with_stats(
163        client: Client<S>,
164        runner: ClientRunner<S>,
165        config: TestingConfig,
166        stats: Arc<StatisticsTracker>,
167    ) -> Self {
168        Self {
169            client,
170            runner: Some(runner),
171            stats,
172            config,
173            is_running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
174            stats_task: None,
175            runner_task: None,
176        }
177    }
178
179    /// Create a TestingMiddleware that shares the same StatisticsTracker
180    pub fn create_middleware(&self) -> TestingMiddleware<S> {
181        TestingMiddleware::new(Arc::clone(&self.stats))
182    }
183
184    /// Start the testing wrapper, which will run the client and begin collecting statistics
185    pub async fn start(&mut self) -> CoreResult<()> {
186        self.is_running
187            .store(true, std::sync::atomic::Ordering::SeqCst);
188
189        // Start statistics collection task
190        if self.config.log_stats {
191            let stats = self.stats.clone();
192            let interval = self.config.stats_interval;
193            let is_running = self.is_running.clone();
194
195            self.stats_task = Some(tokio::spawn(async move {
196                let mut interval = tokio::time::interval(interval);
197                interval.tick().await; // Skip first tick
198
199                while is_running.load(std::sync::atomic::Ordering::SeqCst) {
200                    interval.tick().await;
201
202                    let stats = stats.get_stats().await;
203                    Self::log_statistics(&stats);
204                }
205            }));
206        }
207
208        // Record initial connection attempt
209        self.stats.record_connection_attempt().await;
210
211        // Start the actual ClientRunner in a separate task
212        // We need to take ownership of the runner to move it into the task
213        let runner = self.runner.take().ok_or_else(|| {
214            CoreError::Other("Runner has already been started or consumed".to_string())
215        })?;
216        let stats = self.stats.clone();
217        let is_running = self.is_running.clone();
218
219        self.runner_task = Some(tokio::spawn(async move {
220            let mut runner = runner;
221
222            // Create a wrapper around the runner that tracks statistics
223            let result = Self::run_with_stats(&mut runner, stats.clone()).await;
224
225            // Mark as not running when the runner exits
226            is_running.store(false, std::sync::atomic::Ordering::SeqCst);
227
228            match result {
229                Ok(_) => {
230                    info!("ClientRunner completed successfully");
231                }
232                Err(e) => {
233                    error!("ClientRunner failed: {}", e);
234                    // Record connection failure
235                    stats.record_connection_failure(Some(e.to_string())).await;
236                }
237            }
238        }));
239
240        info!("Testing wrapper started successfully");
241        Ok(())
242    }
243
244    /// Run the ClientRunner with statistics tracking
245    async fn run_with_stats(
246        runner: &mut ClientRunner<S>,
247        stats: Arc<StatisticsTracker>,
248    ) -> CoreResult<()> {
249        // For now, we'll just run the runner directly
250        // In a future enhancement, we could intercept connection events
251        // and track them more granularly
252
253        // Since ClientRunner.run() doesn't return a Result, we'll assume it succeeds
254        // and track the connection success
255        stats.record_connection_success().await;
256        runner.run().await;
257        Ok(())
258    }
259
260    /// Stop the testing wrapper
261    pub async fn stop(mut self) -> CoreResult<ConnectionStats> {
262        self.is_running
263            .store(false, std::sync::atomic::Ordering::SeqCst);
264
265        // Abort the statistics task
266        if let Some(task) = self.stats_task.take() {
267            task.abort();
268        }
269
270        // Shutdown the client, which will signal the runner to stop
271        // Note: This consumes the client, so we need to handle this carefully
272        info!("Sending shutdown command to client...");
273
274        // Record the disconnection before shutting down
275        self.stats
276            .record_disconnection(Some("Manual stop".to_string()))
277            .await;
278
279        // We can't consume self.client here because we need to return self
280        // Instead, we'll wait for the runner task to complete naturally
281        // The runner should stop when the connection is closed or on error
282
283        if let Some(runner_task) = self.runner_task.take() {
284            // Wait for the runner task to complete with a timeout
285            match tokio::time::timeout(Duration::from_secs(10), runner_task).await {
286                Ok(Ok(())) => {
287                    info!("Runner task completed successfully");
288                }
289                Ok(Err(e)) => {
290                    if e.is_cancelled() {
291                        info!("Runner task was cancelled");
292                    } else {
293                        error!("Runner task failed: {}", e);
294                    }
295                }
296                Err(_) => {
297                    warn!("Runner task did not complete within timeout, it may still be running");
298                }
299            }
300        }
301
302        let stats = self.get_stats().await;
303
304        // Shutdown the client
305        info!("Shutting down client...");
306        self.client.shutdown().await?;
307
308        info!("Testing wrapper stopped");
309        Ok(stats)
310    }
311
312    /// Get the current connection statistics
313    pub async fn get_stats(&self) -> ConnectionStats {
314        self.stats.get_stats().await
315    }
316
317    /// Get a reference to the underlying client
318    pub fn client(&self) -> &Client<S> {
319        &self.client
320    }
321
322    /// Get a mutable reference to the underlying client
323    pub fn client_mut(&mut self) -> &mut Client<S> {
324        &mut self.client
325    }
326
327    /// Reset all statistics
328    pub async fn reset_stats(&self) {
329        // Create a new statistics tracker and replace the current one
330        // Note: This is a simplified approach. In a real implementation,
331        // you might want to use Arc::make_mut or other techniques
332        // to properly reset the statistics while maintaining thread safety
333        warn!("Statistics reset requested, but not fully implemented");
334    }
335
336    /// Export statistics to JSON
337    pub async fn export_stats_json(&self) -> CoreResult<String> {
338        let stats = self.get_stats().await;
339        serde_json::to_string_pretty(&stats)
340            .map_err(|e| CoreError::Other(format!("Failed to serialize stats: {e}")))
341    }
342
343    /// Export statistics to CSV
344    pub async fn export_stats_csv(&self) -> CoreResult<String> {
345        let stats = self.get_stats().await;
346
347        let mut csv = String::new();
348        csv.push_str("metric,value\n");
349        csv.push_str(&format!(
350            "connection_attempts,{}\n",
351            stats.connection_attempts
352        ));
353        csv.push_str(&format!(
354            "successful_connections,{}\n",
355            stats.successful_connections
356        ));
357        csv.push_str(&format!(
358            "failed_connections,{}\n",
359            stats.failed_connections
360        ));
361        csv.push_str(&format!("disconnections,{}\n", stats.disconnections));
362        csv.push_str(&format!("reconnections,{}\n", stats.reconnections));
363        csv.push_str(&format!(
364            "avg_connection_latency_ms,{}\n",
365            stats.avg_connection_latency_ms
366        ));
367        csv.push_str(&format!(
368            "last_connection_latency_ms,{}\n",
369            stats.last_connection_latency_ms
370        ));
371        csv.push_str(&format!(
372            "total_uptime_seconds,{}\n",
373            stats.total_uptime_seconds
374        ));
375        csv.push_str(&format!(
376            "current_uptime_seconds,{}\n",
377            stats.current_uptime_seconds
378        ));
379        csv.push_str(&format!(
380            "time_since_last_disconnection_seconds,{}\n",
381            stats.time_since_last_disconnection_seconds
382        ));
383        csv.push_str(&format!("messages_sent,{}\n", stats.messages_sent));
384        csv.push_str(&format!("messages_received,{}\n", stats.messages_received));
385        csv.push_str(&format!("bytes_sent,{}\n", stats.bytes_sent));
386        csv.push_str(&format!("bytes_received,{}\n", stats.bytes_received));
387        csv.push_str(&format!(
388            "avg_messages_sent_per_second,{}\n",
389            stats.avg_messages_sent_per_second
390        ));
391        csv.push_str(&format!(
392            "avg_messages_received_per_second,{}\n",
393            stats.avg_messages_received_per_second
394        ));
395        csv.push_str(&format!(
396            "avg_bytes_sent_per_second,{}\n",
397            stats.avg_bytes_sent_per_second
398        ));
399        csv.push_str(&format!(
400            "avg_bytes_received_per_second,{}\n",
401            stats.avg_bytes_received_per_second
402        ));
403        csv.push_str(&format!("is_connected,{}\n", stats.is_connected));
404
405        Ok(csv)
406    }
407
408    /// Log current statistics to console
409    fn log_statistics(stats: &ConnectionStats) {
410        info!("=== WebSocket Connection Statistics ===");
411        info!(
412            "Connection Status: {}",
413            if stats.is_connected {
414                "CONNECTED"
415            } else {
416                "DISCONNECTED"
417            }
418        );
419        info!("Connection Attempts: {}", stats.connection_attempts);
420        info!("Successful Connections: {}", stats.successful_connections);
421        info!("Failed Connections: {}", stats.failed_connections);
422        info!("Disconnections: {}", stats.disconnections);
423        info!("Reconnections: {}", stats.reconnections);
424
425        if stats.avg_connection_latency_ms > 0.0 {
426            info!(
427                "Average Connection Latency: {:.2}ms",
428                stats.avg_connection_latency_ms
429            );
430            info!(
431                "Last Connection Latency: {:.2}ms",
432                stats.last_connection_latency_ms
433            );
434        }
435
436        info!("Total Uptime: {:.2}s", stats.total_uptime_seconds);
437        if stats.is_connected {
438            info!(
439                "Current Connection Uptime: {:.2}s",
440                stats.current_uptime_seconds
441            );
442        }
443        if stats.time_since_last_disconnection_seconds > 0.0 {
444            info!(
445                "Time Since Last Disconnection: {:.2}s",
446                stats.time_since_last_disconnection_seconds
447            );
448        }
449
450        info!(
451            "Messages Sent: {} ({:.2}/s)",
452            stats.messages_sent, stats.avg_messages_sent_per_second
453        );
454        info!(
455            "Messages Received: {} ({:.2}/s)",
456            stats.messages_received, stats.avg_messages_received_per_second
457        );
458        info!(
459            "Bytes Sent: {} ({:.2}/s)",
460            stats.bytes_sent, stats.avg_bytes_sent_per_second
461        );
462        info!(
463            "Bytes Received: {} ({:.2}/s)",
464            stats.bytes_received, stats.avg_bytes_received_per_second
465        );
466
467        if stats.connection_attempts > 0 {
468            let success_rate =
469                (stats.successful_connections as f64 / stats.connection_attempts as f64) * 100.0;
470            info!("Connection Success Rate: {:.1}%", success_rate);
471        }
472
473        info!("========================================");
474    }
475}
476
477/// A testing connector wrapper that tracks connection statistics
478pub struct TestingConnector<C, S> {
479    inner: C,
480    stats: Arc<StatisticsTracker>,
481    config: TestingConfig,
482    _phantom: std::marker::PhantomData<S>,
483}
484
485impl<C, S> TestingConnector<C, S> {
486    pub fn new(inner: C, stats: Arc<StatisticsTracker>, config: TestingConfig) -> Self {
487        Self {
488            inner,
489            stats,
490            config,
491            _phantom: std::marker::PhantomData,
492        }
493    }
494}
495
496#[async_trait]
497impl<C, S> Connector<S> for TestingConnector<C, S>
498where
499    C: Connector<S> + Send + Sync,
500    S: AppState,
501{
502    async fn connect(
503        &self,
504        state: Arc<S>,
505    ) -> crate::connector::ConnectorResult<crate::connector::WsStream> {
506        self.stats.record_connection_attempt().await;
507
508        let start_time = std::time::Instant::now();
509
510        // Apply connection timeout
511        let result =
512            tokio::time::timeout(self.config.connection_timeout, self.inner.connect(state)).await;
513
514        match result {
515            Ok(Ok(stream)) => {
516                self.stats.record_connection_success().await;
517                debug!("Connection established in {:?}", start_time.elapsed());
518                Ok(stream)
519            }
520            Ok(Err(err)) => {
521                self.stats
522                    .record_connection_failure(Some(err.to_string()))
523                    .await;
524                error!("Connection failed: {}", err);
525                Err(err)
526            }
527            Err(_) => {
528                let timeout_error = crate::connector::ConnectorError::Timeout;
529                self.stats
530                    .record_connection_failure(Some(timeout_error.to_string()))
531                    .await;
532                error!(
533                    "Connection timed out after {:?}",
534                    self.config.connection_timeout
535                );
536                Err(timeout_error)
537            }
538        }
539    }
540
541    async fn disconnect(&self) -> crate::connector::ConnectorResult<()> {
542        self.stats
543            .record_disconnection(Some("Manual disconnect".to_string()))
544            .await;
545        self.inner.disconnect().await
546    }
547}
548
549/// Builder for creating a testing wrapper with custom configuration
550pub struct TestingWrapperBuilder<S: AppState> {
551    config: TestingConfig,
552    _phantom: std::marker::PhantomData<S>,
553}
554
555impl<S: AppState> TestingWrapperBuilder<S> {
556    pub fn new() -> Self {
557        Self {
558            config: TestingConfig::default(),
559            _phantom: std::marker::PhantomData,
560        }
561    }
562
563    pub fn with_stats_interval(mut self, interval: Duration) -> Self {
564        self.config.stats_interval = interval;
565        self
566    }
567
568    pub fn with_log_stats(mut self, log_stats: bool) -> Self {
569        self.config.log_stats = log_stats;
570        self
571    }
572
573    pub fn with_track_events(mut self, track_events: bool) -> Self {
574        self.config.track_events = track_events;
575        self
576    }
577
578    pub fn with_max_reconnect_attempts(mut self, max_attempts: Option<u32>) -> Self {
579        self.config.max_reconnect_attempts = max_attempts;
580        self
581    }
582
583    pub fn with_reconnect_delay(mut self, delay: Duration) -> Self {
584        self.config.reconnect_delay = delay;
585        self
586    }
587
588    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
589        self.config.connection_timeout = timeout;
590        self
591    }
592
593    pub fn with_auto_reconnect(mut self, auto_reconnect: bool) -> Self {
594        self.config.auto_reconnect = auto_reconnect;
595        self
596    }
597
598    pub fn build(self, client: Client<S>, runner: ClientRunner<S>) -> TestingWrapper<S> {
599        TestingWrapper::new(client, runner, self.config)
600    }
601
602    /// Build the testing wrapper and return both the wrapper and a compatible middleware
603    pub async fn build_with_middleware(
604        self,
605        builder: ClientBuilder<S>,
606    ) -> CoreResult<TestingWrapper<S>> {
607        let stats = Arc::new(StatisticsTracker::new());
608        let middleware = TestingMiddleware::new(Arc::clone(&stats));
609        let (client, runner) = builder
610            .with_middleware(Box::new(middleware))
611            .build()
612            .await?;
613        let wrapper = TestingWrapper::new_with_stats(client, runner, self.config, stats);
614
615        Ok(wrapper)
616    }
617}
618
619impl<S: AppState> Default for TestingWrapperBuilder<S> {
620    fn default() -> Self {
621        Self::new()
622    }
623}