Skip to main content

clasp_test_utils/
lib.rs

//! Common test helpers and utilities for CLASP tests.
//!
//! This crate provides test utilities including:
//! - Condition-based waiting (poll a condition until a deadline instead of sleeping a fixed time)
//! - Resource cleanup on drop (RAII): a `TestRouter` stops its server task when dropped
//! - Assertion helpers that return descriptive `Result<_, String>` errors
//! - Test router management
//! - Value collectors for subscription testing
9
10use clasp_client::Clasp;
11use clasp_core::{SecurityMode, Value};
12use clasp_router::{Router, RouterConfig};
13use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::sync::Notify;
17use tokio::time::timeout;
18
/// Default timeout for test waits: ten seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Polling interval used by the condition-based wait helpers (`wait_for_count`,
/// `wait_for_flag`): ten milliseconds.
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(10);
24
25// ============================================================================
26// Port Allocation
27// ============================================================================
28
29/// Find an available TCP port for testing
30pub async fn find_available_port() -> u16 {
31    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
32    listener.local_addr().unwrap().port()
33}
34
/// Ask the OS for a free UDP port on the loopback interface.
///
/// Binds `127.0.0.1:0`, reads back the port the OS assigned, and closes the
/// socket. Because the port is released before returning, it is only *likely* to
/// still be free when the caller binds it; another process may take it first.
///
/// # Panics
/// Panics if binding or reading the local address fails.
pub fn find_available_udp_port() -> u16 {
    std::net::UdpSocket::bind("127.0.0.1:0")
        .and_then(|socket| socket.local_addr())
        .map(|addr| addr.port())
        .unwrap()
}
40
41// ============================================================================
42// Condition-Based Waiting
43// ============================================================================
44
45/// Wait for a condition with timeout - condition-based, not time-based
46pub async fn wait_for<F, Fut>(check: F, interval: Duration, max_wait: Duration) -> bool
47where
48    F: Fn() -> Fut,
49    Fut: std::future::Future<Output = bool>,
50{
51    let start = Instant::now();
52    while start.elapsed() < max_wait {
53        if check().await {
54            return true;
55        }
56        tokio::time::sleep(interval).await;
57    }
58    false
59}
60
61/// Wait for an atomic counter to reach a target value
62pub async fn wait_for_count(counter: &AtomicU32, target: u32, max_wait: Duration) -> bool {
63    wait_for(
64        || async { counter.load(Ordering::SeqCst) >= target },
65        DEFAULT_CHECK_INTERVAL,
66        max_wait,
67    )
68    .await
69}
70
71/// Wait for a boolean flag to become true
72pub async fn wait_for_flag(flag: &AtomicBool, max_wait: Duration) -> bool {
73    wait_for(
74        || async { flag.load(Ordering::SeqCst) },
75        DEFAULT_CHECK_INTERVAL,
76        max_wait,
77    )
78    .await
79}
80
81/// Wait with notification - more efficient than polling
82pub async fn wait_with_notify(notify: &Notify, max_wait: Duration) -> bool {
83    timeout(max_wait, notify.notified()).await.is_ok()
84}
85
86// ============================================================================
87// Test Router - RAII wrapper with proper cleanup
88// ============================================================================
89
90/// A test router that automatically cleans up on drop
91pub struct TestRouter {
92    port: u16,
93    handle: Option<tokio::task::JoinHandle<()>>,
94    ready: Arc<AtomicBool>,
95}
96
97impl TestRouter {
98    /// Start a test router with default configuration
99    pub async fn start() -> Self {
100        Self::start_with_config(RouterConfig {
101            name: "Test Router".to_string(),
102            max_sessions: 100,
103            session_timeout: 60,
104            features: vec![
105                "param".to_string(),
106                "event".to_string(),
107                "stream".to_string(),
108            ],
109            security_mode: SecurityMode::Open,
110            max_subscriptions_per_session: 1000,
111            gesture_coalescing: true,
112            gesture_coalesce_interval_ms: 16,
113            max_messages_per_second: 0, // Disable rate limiting for tests
114            rate_limiting_enabled: false,
115            state_config: clasp_router::RouterStateConfig::unlimited(), // No TTL in tests
116        })
117        .await
118    }
119
120    /// Start a test router with custom configuration
121    pub async fn start_with_config(config: RouterConfig) -> Self {
122        let port = find_available_port().await;
123        let addr = format!("127.0.0.1:{}", port);
124        let ready = Arc::new(AtomicBool::new(false));
125        let ready_clone = ready.clone();
126
127        let router = Router::new(config);
128
129        let handle = tokio::spawn(async move {
130            ready_clone.store(true, Ordering::SeqCst);
131            let _ = router.serve_websocket(&addr).await;
132        });
133
134        // Wait for router to be ready using condition-based wait
135        let start = Instant::now();
136        while !ready.load(Ordering::SeqCst) && start.elapsed() < Duration::from_secs(5) {
137            tokio::time::sleep(Duration::from_millis(10)).await;
138        }
139
140        // Additional check: try to connect to verify the port is listening
141        let _ = wait_for(
142            || {
143                let port = port;
144                async move {
145                    tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
146                        .await
147                        .is_ok()
148                }
149            },
150            Duration::from_millis(10),
151            Duration::from_secs(5),
152        )
153        .await;
154
155        Self {
156            port,
157            handle: Some(handle),
158            ready,
159        }
160    }
161
162    /// Get the WebSocket URL for this router
163    pub fn url(&self) -> String {
164        format!("ws://127.0.0.1:{}", self.port)
165    }
166
167    /// Get the port number
168    pub fn port(&self) -> u16 {
169        self.port
170    }
171
172    /// Check if router is ready
173    pub fn is_ready(&self) -> bool {
174        self.ready.load(Ordering::SeqCst)
175    }
176
177    /// Connect a client to this router
178    pub async fn connect_client(&self) -> Result<Clasp, clasp_client::ClientError> {
179        Clasp::connect_to(&self.url()).await
180    }
181
182    /// Connect a client with a custom name
183    pub async fn connect_client_named(
184        &self,
185        name: &str,
186    ) -> Result<Clasp, clasp_client::ClientError> {
187        Clasp::builder(&self.url()).name(name).connect().await
188    }
189
190    /// Stop the router explicitly (also happens on drop)
191    pub fn stop(&mut self) {
192        if let Some(handle) = self.handle.take() {
193            handle.abort();
194        }
195    }
196}
197
198impl Drop for TestRouter {
199    fn drop(&mut self) {
200        self.stop();
201    }
202}
203
204// ============================================================================
205// Assertion Helpers
206// ============================================================================
207
/// Check that `actual` lies within `epsilon` of `expected` (inclusive).
///
/// Exactly equal values always pass, including equal infinities and the case
/// `epsilon == 0.0`. A NaN on either side always fails.
///
/// # Errors
/// Returns `"<msg>: expected <expected> +/- <epsilon>, got <actual>"` when the
/// values are further apart than `epsilon`.
pub fn assert_approx_eq(actual: f64, expected: f64, epsilon: f64, msg: &str) -> Result<(), String> {
    // The exact-equality check covers equal infinities, where `inf - inf` is NaN
    // and would otherwise fail the tolerance comparison.
    if actual == expected || (actual - expected).abs() <= epsilon {
        Ok(())
    } else {
        Err(format!(
            "{}: expected {} +/- {}, got {}",
            msg, expected, epsilon, actual
        ))
    }
}
219
/// Turn a boolean check into a `Result`.
///
/// # Errors
/// Returns `msg` as the error when `condition` is `false`.
pub fn assert_that(condition: bool, msg: &str) -> Result<(), String> {
    if !condition {
        return Err(msg.to_string());
    }
    Ok(())
}
228
/// Unwrap an `Option` into a `Result`, yielding the inner value when present.
///
/// # Errors
/// Returns `msg` as the error when `opt` is `None`.
pub fn assert_some<T>(opt: Option<T>, msg: &str) -> Result<T, String> {
    match opt {
        Some(inner) => Ok(inner),
        None => Err(msg.to_string()),
    }
}
233
/// Pass through the `Ok` value of `result`, describing any error as a string.
///
/// # Errors
/// Returns `"<msg>: <error:?>"` when `result` is `Err`.
pub fn assert_ok<T, E: std::fmt::Debug>(result: Result<T, E>, msg: &str) -> Result<T, String> {
    match result {
        Ok(inner) => Ok(inner),
        Err(error) => Err(format!("{}: {:?}", msg, error)),
    }
}
238
/// Check that `result` is an error, discarding the error value.
///
/// # Errors
/// Returns `"<msg>: expected error, got Ok(<value:?>)"` when `result` is `Ok`.
pub fn assert_err<T: std::fmt::Debug, E>(result: Result<T, E>, msg: &str) -> Result<(), String> {
    if let Ok(unexpected) = result {
        return Err(format!("{}: expected error, got Ok({:?})", msg, unexpected));
    }
    Ok(())
}
246
247// ============================================================================
248// Test Collectors - for verifying received values
249// ============================================================================
250
251/// Collector for subscription values with thread-safe access
252#[derive(Clone)]
253pub struct ValueCollector {
254    values: Arc<parking_lot::Mutex<Vec<(String, Value)>>>,
255    notify: Arc<Notify>,
256    count: Arc<AtomicU32>,
257}
258
259impl ValueCollector {
260    pub fn new() -> Self {
261        Self {
262            values: Arc::new(parking_lot::Mutex::new(Vec::new())),
263            notify: Arc::new(Notify::new()),
264            count: Arc::new(AtomicU32::new(0)),
265        }
266    }
267
268    /// Create a callback function for subscriptions
269    pub fn callback(&self) -> impl Fn(Value, String) + Send + 'static {
270        let values = self.values.clone();
271        let notify = self.notify.clone();
272        let count = self.count.clone();
273
274        move |value, address| {
275            {
276                let mut guard = values.lock();
277                guard.push((address, value));
278            }
279            count.fetch_add(1, Ordering::SeqCst);
280            notify.notify_waiters();
281        }
282    }
283
284    /// Create a callback function for subscriptions that takes &str address
285    pub fn callback_ref(&self) -> impl Fn(Value, &str) + Send + Sync + 'static {
286        let values = self.values.clone();
287        let notify = self.notify.clone();
288        let count = self.count.clone();
289
290        move |value, address| {
291            {
292                let mut guard = values.lock();
293                guard.push((address.to_string(), value));
294            }
295            count.fetch_add(1, Ordering::SeqCst);
296            notify.notify_waiters();
297        }
298    }
299
300    /// Get the count of received values
301    pub fn count(&self) -> u32 {
302        self.count.load(Ordering::SeqCst)
303    }
304
305    /// Wait for at least n values to be received
306    pub async fn wait_for_count(&self, n: u32, max_wait: Duration) -> bool {
307        wait_for_count(&self.count, n, max_wait).await
308    }
309
310    /// Get all collected values
311    pub fn values(&self) -> Vec<(String, Value)> {
312        self.values.lock().clone()
313    }
314
315    /// Check if a specific address was received
316    pub fn has_address(&self, addr: &str) -> bool {
317        self.values.lock().iter().any(|(a, _)| a == addr)
318    }
319
320    /// Get values for a specific address pattern
321    pub fn values_for(&self, addr: &str) -> Vec<Value> {
322        self.values
323            .lock()
324            .iter()
325            .filter(|(a, _)| a == addr)
326            .map(|(_, v)| v.clone())
327            .collect()
328    }
329
330    /// Get the last value received
331    pub fn last_value(&self) -> Option<(String, Value)> {
332        self.values.lock().last().cloned()
333    }
334
335    /// Clear all collected values
336    pub fn clear(&self) {
337        self.values.lock().clear();
338        self.count.store(0, Ordering::SeqCst);
339    }
340}
341
342impl Default for ValueCollector {
343    fn default() -> Self {
344        Self::new()
345    }
346}