Skip to main content

clasp_test_utils/
lib.rs

//! Common test helpers and utilities for CLASP tests.
//!
//! This crate provides test utilities including:
//! - Condition-based waiting (polling with a timeout instead of fixed sleeps)
//! - Resource cleanup on drop (RAII) for the test router
//! - Assertion helpers that return a `Result` carrying a descriptive message
//! - Test router management
//! - Value collectors for subscription testing
9
10use clasp_client::Clasp;
11use clasp_core::{SecurityMode, Value};
12use clasp_router::{Router, RouterConfig};
13use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::sync::Notify;
17use tokio::time::timeout;
18
/// Ten-second default timeout intended for test waits.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

/// Pause between condition checks used by `wait_for_count` and
/// `wait_for_flag` (10 milliseconds).
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(10);
24
25// ============================================================================
26// Port Allocation
27// ============================================================================
28
29/// Find an available TCP port for testing
30pub async fn find_available_port() -> u16 {
31    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
32    listener.local_addr().unwrap().port()
33}
34
/// Ask the OS for a free loopback UDP port and return its number.
///
/// A socket is bound to `127.0.0.1:0` so the OS picks the port. The socket is
/// dropped before this function returns, so the port is no longer held once
/// the caller receives the number.
///
/// # Panics
///
/// Panics if binding the socket or reading its local address fails.
pub fn find_available_udp_port() -> u16 {
    std::net::UdpSocket::bind("127.0.0.1:0")
        .and_then(|probe| probe.local_addr())
        .map(|addr| addr.port())
        .unwrap()
}
40
41// ============================================================================
42// Condition-Based Waiting
43// ============================================================================
44
45/// Wait for a condition with timeout - condition-based, not time-based
46pub async fn wait_for<F, Fut>(check: F, interval: Duration, max_wait: Duration) -> bool
47where
48    F: Fn() -> Fut,
49    Fut: std::future::Future<Output = bool>,
50{
51    let start = Instant::now();
52    while start.elapsed() < max_wait {
53        if check().await {
54            return true;
55        }
56        tokio::time::sleep(interval).await;
57    }
58    false
59}
60
61/// Wait for an atomic counter to reach a target value
62pub async fn wait_for_count(counter: &AtomicU32, target: u32, max_wait: Duration) -> bool {
63    wait_for(
64        || async { counter.load(Ordering::SeqCst) >= target },
65        DEFAULT_CHECK_INTERVAL,
66        max_wait,
67    )
68    .await
69}
70
71/// Wait for a boolean flag to become true
72pub async fn wait_for_flag(flag: &AtomicBool, max_wait: Duration) -> bool {
73    wait_for(
74        || async { flag.load(Ordering::SeqCst) },
75        DEFAULT_CHECK_INTERVAL,
76        max_wait,
77    )
78    .await
79}
80
81/// Wait with notification - more efficient than polling
82pub async fn wait_with_notify(notify: &Notify, max_wait: Duration) -> bool {
83    timeout(max_wait, notify.notified()).await.is_ok()
84}
85
86// ============================================================================
87// Test Router - RAII wrapper with proper cleanup
88// ============================================================================
89
90/// A test router that automatically cleans up on drop
91pub struct TestRouter {
92    port: u16,
93    handle: Option<tokio::task::JoinHandle<()>>,
94    ready: Arc<AtomicBool>,
95}
96
97impl TestRouter {
98    /// Start a test router with default configuration
99    pub async fn start() -> Self {
100        Self::start_with_config(RouterConfig {
101            name: "Test Router".to_string(),
102            max_sessions: 100,
103            session_timeout: 60,
104            features: vec![
105                "param".to_string(),
106                "event".to_string(),
107                "stream".to_string(),
108            ],
109            security_mode: SecurityMode::Open,
110            max_subscriptions_per_session: 1000,
111            gesture_coalescing: true,
112            gesture_coalesce_interval_ms: 16,
113            max_messages_per_second: 0, // Disable rate limiting for tests
114            rate_limiting_enabled: false,
115            state_config: clasp_router::RouterStateConfig::unlimited(), // No TTL in tests
116        })
117        .await
118    }
119
120    /// Start a test router with custom configuration
121    pub async fn start_with_config(config: RouterConfig) -> Self {
122        let port = find_available_port().await;
123        let addr = format!("127.0.0.1:{}", port);
124        let ready = Arc::new(AtomicBool::new(false));
125        let ready_clone = ready.clone();
126
127        let router = Router::new(config);
128
129        let handle = tokio::spawn(async move {
130            ready_clone.store(true, Ordering::SeqCst);
131            let _ = router.serve_websocket(&addr).await;
132        });
133
134        // Wait for router to be ready using condition-based wait
135        let start = Instant::now();
136        while !ready.load(Ordering::SeqCst) && start.elapsed() < Duration::from_secs(5) {
137            tokio::time::sleep(Duration::from_millis(10)).await;
138        }
139
140        // Additional check: try to connect to verify the port is listening
141        let _ = wait_for(
142            || async move {
143                tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
144                    .await
145                    .is_ok()
146            },
147            Duration::from_millis(10),
148            Duration::from_secs(5),
149        )
150        .await;
151
152        Self {
153            port,
154            handle: Some(handle),
155            ready,
156        }
157    }
158
159    /// Get the WebSocket URL for this router
160    pub fn url(&self) -> String {
161        format!("ws://127.0.0.1:{}", self.port)
162    }
163
164    /// Get the port number
165    pub fn port(&self) -> u16 {
166        self.port
167    }
168
169    /// Check if router is ready
170    pub fn is_ready(&self) -> bool {
171        self.ready.load(Ordering::SeqCst)
172    }
173
174    /// Connect a client to this router
175    pub async fn connect_client(&self) -> Result<Clasp, clasp_client::ClientError> {
176        Clasp::connect_to(&self.url()).await
177    }
178
179    /// Connect a client with a custom name
180    pub async fn connect_client_named(
181        &self,
182        name: &str,
183    ) -> Result<Clasp, clasp_client::ClientError> {
184        Clasp::builder(&self.url()).name(name).connect().await
185    }
186
187    /// Stop the router explicitly (also happens on drop)
188    pub fn stop(&mut self) {
189        if let Some(handle) = self.handle.take() {
190            handle.abort();
191        }
192    }
193}
194
195impl Drop for TestRouter {
196    fn drop(&mut self) {
197        self.stop();
198    }
199}
200
201// ============================================================================
202// Assertion Helpers
203// ============================================================================
204
/// Assert that two values are approximately equal (for floating point)
///
/// Succeeds when `actual` equals `expected` exactly, or when their absolute
/// difference is at most `epsilon` (the bound is inclusive, matching the
/// "+/-" wording of the error message).
///
/// The exact-equality check makes `epsilon == 0.0` usable and lets two equal
/// infinities compare as equal. Their difference is NaN, which would
/// otherwise fail. A NaN `actual` or `expected` never passes.
///
/// # Errors
///
/// Returns `"{msg}: expected {expected} +/- {epsilon}, got {actual}"` when the
/// values are not within tolerance.
pub fn assert_approx_eq(
    actual: f64,
    expected: f64,
    epsilon: f64,
    msg: &str,
) -> Result<(), String> {
    let within_tolerance = actual == expected || (actual - expected).abs() <= epsilon;
    if within_tolerance {
        Ok(())
    } else {
        Err(format!(
            "{}: expected {} +/- {}, got {}",
            msg, expected, epsilon, actual
        ))
    }
}
216
/// Turn a boolean check into a `Result`.
///
/// Returns `Ok(())` when `condition` holds, otherwise `Err` carrying a copy
/// of `msg`.
pub fn assert_that(condition: bool, msg: &str) -> Result<(), String> {
    if !condition {
        return Err(String::from(msg));
    }
    Ok(())
}
225
/// Unwrap an `Option` into a `Result`.
///
/// Returns the contained value for `Some`, otherwise `Err` carrying a copy of
/// `msg`.
pub fn assert_some<T>(opt: Option<T>, msg: &str) -> Result<T, String> {
    match opt {
        Some(inner) => Ok(inner),
        None => Err(String::from(msg)),
    }
}
230
/// Unwrap a `Result`, converting any error into a descriptive `String`.
///
/// Returns the success value unchanged. On failure the error becomes
/// `"{msg}: {error:?}"`, using the error's `Debug` representation.
pub fn assert_ok<T, E: std::fmt::Debug>(result: Result<T, E>, msg: &str) -> Result<T, String> {
    match result {
        Ok(value) => Ok(value),
        Err(cause) => Err(format!("{}: {:?}", msg, cause)),
    }
}
235
/// Check that a `Result` is an error.
///
/// Returns `Ok(())` for `Err(_)`; the error value itself is discarded. For
/// `Ok(v)` it returns `"{msg}: expected error, got Ok({v:?})"`.
pub fn assert_err<T: std::fmt::Debug, E>(result: Result<T, E>, msg: &str) -> Result<(), String> {
    if let Ok(unexpected) = result {
        return Err(format!(
            "{}: expected error, got Ok({:?})",
            msg, unexpected
        ));
    }
    Ok(())
}
243
244// ============================================================================
245// Test Collectors - for verifying received values
246// ============================================================================
247
248/// Collector for subscription values with thread-safe access
249#[derive(Clone)]
250pub struct ValueCollector {
251    values: Arc<parking_lot::Mutex<Vec<(String, Value)>>>,
252    notify: Arc<Notify>,
253    count: Arc<AtomicU32>,
254}
255
256impl ValueCollector {
257    pub fn new() -> Self {
258        Self {
259            values: Arc::new(parking_lot::Mutex::new(Vec::new())),
260            notify: Arc::new(Notify::new()),
261            count: Arc::new(AtomicU32::new(0)),
262        }
263    }
264
265    /// Create a callback function for subscriptions
266    pub fn callback(&self) -> impl Fn(Value, String) + Send + 'static {
267        let values = self.values.clone();
268        let notify = self.notify.clone();
269        let count = self.count.clone();
270
271        move |value, address| {
272            {
273                let mut guard = values.lock();
274                guard.push((address, value));
275            }
276            count.fetch_add(1, Ordering::SeqCst);
277            notify.notify_waiters();
278        }
279    }
280
281    /// Create a callback function for subscriptions that takes &str address
282    pub fn callback_ref(&self) -> impl Fn(Value, &str) + Send + Sync + 'static {
283        let values = self.values.clone();
284        let notify = self.notify.clone();
285        let count = self.count.clone();
286
287        move |value, address| {
288            {
289                let mut guard = values.lock();
290                guard.push((address.to_string(), value));
291            }
292            count.fetch_add(1, Ordering::SeqCst);
293            notify.notify_waiters();
294        }
295    }
296
297    /// Get the count of received values
298    pub fn count(&self) -> u32 {
299        self.count.load(Ordering::SeqCst)
300    }
301
302    /// Wait for at least n values to be received
303    pub async fn wait_for_count(&self, n: u32, max_wait: Duration) -> bool {
304        wait_for_count(&self.count, n, max_wait).await
305    }
306
307    /// Get all collected values
308    pub fn values(&self) -> Vec<(String, Value)> {
309        self.values.lock().clone()
310    }
311
312    /// Check if a specific address was received
313    pub fn has_address(&self, addr: &str) -> bool {
314        self.values.lock().iter().any(|(a, _)| a == addr)
315    }
316
317    /// Get values for a specific address pattern
318    pub fn values_for(&self, addr: &str) -> Vec<Value> {
319        self.values
320            .lock()
321            .iter()
322            .filter(|(a, _)| a == addr)
323            .map(|(_, v)| v.clone())
324            .collect()
325    }
326
327    /// Get the last value received
328    pub fn last_value(&self) -> Option<(String, Value)> {
329        self.values.lock().last().cloned()
330    }
331
332    /// Clear all collected values
333    pub fn clear(&self) {
334        self.values.lock().clear();
335        self.count.store(0, Ordering::SeqCst);
336    }
337}
338
339impl Default for ValueCollector {
340    fn default() -> Self {
341        Self::new()
342    }
343}