Skip to main content

datasynth_test_utils/
server.rs

1//! Test server utilities for integration testing.
2
3use std::net::SocketAddr;
4use std::sync::atomic::{AtomicU16, Ordering};
5use std::time::Duration;
6
7use thiserror::Error;
8use tokio::time::timeout;
9
10/// Error type for test server operations.
11#[derive(Debug, Error)]
12pub enum TestServerError {
13    #[error("Server startup timeout")]
14    StartupTimeout,
15
16    #[error("Health check failed: {0}")]
17    HealthCheckFailed(String),
18
19    #[error("Request failed: {0}")]
20    RequestFailed(String),
21
22    #[error("Invalid response: {0}")]
23    InvalidResponse(String),
24}
25
/// First port handed out by [`get_test_port`], and the value the sequence
/// restarts from once the `u16` range is exhausted.
const FIRST_TEST_PORT: u16 = 50100;

/// Global port counter for unique test ports.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(FIRST_TEST_PORT);

/// Get a unique port for testing.
///
/// Ports are handed out sequentially, starting at 50100. Once 65535 has been
/// returned the sequence restarts at 50100 instead of wrapping around to 0,
/// so callers never receive port 0 or a low, privileged port.
///
/// Uniqueness only holds within the current process (and until the sequence
/// restarts); this function does not check that the port is actually free.
pub fn get_test_port() -> u16 {
    let previous = PORT_COUNTER.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
        Some(if current == u16::MAX {
            FIRST_TEST_PORT
        } else {
            current + 1
        })
    });

    // The closure always returns `Some`, so the update cannot be rejected;
    // either way the payload is the value held before the update.
    match previous {
        Ok(port) | Err(port) => port,
    }
}
33
/// Settings describing where a test server listens and how its startup is
/// polled.
#[derive(Debug, Clone)]
pub struct TestServerConfig {
    /// Host the server is reached on (for example `"127.0.0.1"`).
    pub host: String,
    /// TCP port the server listens on.
    pub port: u16,
    /// How long to wait for the server to start, in seconds.
    pub startup_timeout_secs: u64,
    /// Pause between consecutive health checks, in milliseconds.
    pub health_check_interval_ms: u64,
}
46
47impl Default for TestServerConfig {
48    fn default() -> Self {
49        Self {
50            host: "127.0.0.1".to_string(),
51            port: get_test_port(),
52            startup_timeout_secs: 10,
53            health_check_interval_ms: 100,
54        }
55    }
56}
57
58impl TestServerConfig {
59    /// Get the address as a SocketAddr.
60    pub fn addr(&self) -> SocketAddr {
61        format!("{}:{}", self.host, self.port)
62            .parse()
63            .expect("Invalid address")
64    }
65
66    /// Get the base URL for REST API.
67    pub fn rest_url(&self) -> String {
68        format!("http://{}:{}", self.host, self.port)
69    }
70
71    /// Get the gRPC address.
72    pub fn grpc_url(&self) -> String {
73        format!("http://{}:{}", self.host, self.port)
74    }
75
76    /// Get the WebSocket URL.
77    pub fn ws_url(&self, path: &str) -> String {
78        format!("ws://{}:{}{}", self.host, self.port, path)
79    }
80}
81
82/// Wait for a server to become healthy.
83pub async fn wait_for_health(
84    base_url: &str,
85    timeout_secs: u64,
86    interval_ms: u64,
87) -> Result<(), TestServerError> {
88    let client = reqwest::Client::new();
89    let health_url = format!("{}/health", base_url);
90
91    let result = timeout(Duration::from_secs(timeout_secs), async {
92        loop {
93            match client.get(&health_url).send().await {
94                Ok(response) if response.status().is_success() => {
95                    return Ok(());
96                }
97                Ok(response) => {
98                    // Server responded but not healthy
99                    let status = response.status();
100                    let body = response.text().await.unwrap_or_default();
101                    tracing::debug!("Health check returned {}: {}", status, body);
102                }
103                Err(e) => {
104                    tracing::debug!("Health check failed: {}", e);
105                }
106            }
107            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
108        }
109    })
110    .await;
111
112    match result {
113        Ok(Ok(())) => Ok(()),
114        Ok(Err(e)) => Err(e),
115        Err(_) => Err(TestServerError::StartupTimeout),
116    }
117}
118
119/// Check if a server is healthy.
120pub async fn is_healthy(base_url: &str) -> bool {
121    let client = reqwest::Client::new();
122    let health_url = format!("{}/health", base_url);
123
124    match client.get(&health_url).send().await {
125        Ok(response) => response.status().is_success(),
126        Err(_) => false,
127    }
128}
129
130/// HTTP client wrapper for testing REST APIs.
131pub struct TestHttpClient {
132    client: reqwest::Client,
133    base_url: String,
134}
135
136impl TestHttpClient {
137    pub fn new(base_url: &str) -> Self {
138        Self {
139            client: reqwest::Client::builder()
140                .timeout(Duration::from_secs(30))
141                .build()
142                .expect("Failed to create HTTP client"),
143            base_url: base_url.to_string(),
144        }
145    }
146
147    /// GET request.
148    pub async fn get(&self, path: &str) -> Result<reqwest::Response, TestServerError> {
149        let url = format!("{}{}", self.base_url, path);
150        self.client
151            .get(&url)
152            .send()
153            .await
154            .map_err(|e| TestServerError::RequestFailed(e.to_string()))
155    }
156
157    /// GET request returning JSON.
158    pub async fn get_json<T: serde::de::DeserializeOwned>(
159        &self,
160        path: &str,
161    ) -> Result<T, TestServerError> {
162        let response = self.get(path).await?;
163        response
164            .json()
165            .await
166            .map_err(|e| TestServerError::InvalidResponse(e.to_string()))
167    }
168
169    /// POST request with JSON body.
170    pub async fn post<T: serde::Serialize>(
171        &self,
172        path: &str,
173        body: &T,
174    ) -> Result<reqwest::Response, TestServerError> {
175        let url = format!("{}{}", self.base_url, path);
176        self.client
177            .post(&url)
178            .json(body)
179            .send()
180            .await
181            .map_err(|e| TestServerError::RequestFailed(e.to_string()))
182    }
183
184    /// POST request returning JSON.
185    pub async fn post_json<T: serde::Serialize, R: serde::de::DeserializeOwned>(
186        &self,
187        path: &str,
188        body: &T,
189    ) -> Result<R, TestServerError> {
190        let response = self.post(path, body).await?;
191        response
192            .json()
193            .await
194            .map_err(|e| TestServerError::InvalidResponse(e.to_string()))
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Three consecutive allocations must all differ from one another.
    #[test]
    fn test_get_test_port_unique() {
        let ports = [get_test_port(), get_test_port(), get_test_port()];

        assert_ne!(ports[0], ports[1]);
        assert_ne!(ports[1], ports[2]);
        assert_ne!(ports[0], ports[2]);
    }

    /// The default configuration targets loopback on a high port with a
    /// 10-second startup timeout.
    #[test]
    fn test_server_config_default() {
        let TestServerConfig {
            host,
            port,
            startup_timeout_secs,
            ..
        } = TestServerConfig::default();

        assert_eq!(host, "127.0.0.1");
        assert!(port > 50000);
        assert_eq!(startup_timeout_secs, 10);
    }

    /// REST, gRPC and WebSocket URLs are all derived from host and port.
    #[test]
    fn test_server_config_urls() {
        let config = TestServerConfig {
            port: 3000,
            ..TestServerConfig::default()
        };

        assert_eq!(config.rest_url(), "http://127.0.0.1:3000");
        assert_eq!(config.grpc_url(), "http://127.0.0.1:3000");
        assert_eq!(config.ws_url("/ws/events"), "ws://127.0.0.1:3000/ws/events");
    }
}