// datasynth_test_utils/server.rs
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::Duration;

use thiserror::Error;
use tokio::time::timeout;

10#[derive(Debug, Error)]
12pub enum TestServerError {
13 #[error("Server startup timeout")]
14 StartupTimeout,
15
16 #[error("Health check failed: {0}")]
17 HealthCheckFailed(String),
18
19 #[error("Request failed: {0}")]
20 RequestFailed(String),
21
22 #[error("Invalid response: {0}")]
23 InvalidResponse(String),
24}
25
/// First port handed out by [`get_test_port`]; the counter wraps back here after `u16::MAX`.
const FIRST_TEST_PORT: u16 = 50100;

/// Next port [`get_test_port`] will return. Shared by every caller in this process.
static PORT_COUNTER: AtomicU16 = AtomicU16::new(FIRST_TEST_PORT);

/// Returns the next port from a process-wide sequence starting at 50100.
///
/// Successive calls return distinct ports until the sequence is exhausted. After `u16::MAX`
/// has been returned, the counter restarts at 50100. It does not silently overflow to 0,
/// which would hand out port 0 and the privileged range.
///
/// The sequence is per-process: separate test processes start from the same value. The
/// returned port is not checked for availability.
pub fn get_test_port() -> u16 {
    PORT_COUNTER
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |port| {
            Some(port.checked_add(1).unwrap_or(FIRST_TEST_PORT))
        })
        // The closure never returns `None`, so this is always `Ok(previous)`.
        .unwrap_or_else(|previous| previous)
}

/// Connection settings for a test server: where it is reached and how long to wait for it.
#[derive(Clone, Debug)]
pub struct TestServerConfig {
    /// Host used by `addr()` and every URL helper; defaults to `127.0.0.1`.
    pub host: String,
    /// TCP port shared by the REST, gRPC and WebSocket URLs.
    pub port: u16,
    /// Startup timeout in seconds; defaults to 10.
    pub startup_timeout_secs: u64,
    /// Pause between health checks in milliseconds; defaults to 100.
    pub health_check_interval_ms: u64,
}

47impl Default for TestServerConfig {
48 fn default() -> Self {
49 Self {
50 host: "127.0.0.1".to_string(),
51 port: get_test_port(),
52 startup_timeout_secs: 10,
53 health_check_interval_ms: 100,
54 }
55 }
56}
57
58impl TestServerConfig {
59 pub fn addr(&self) -> SocketAddr {
61 format!("{}:{}", self.host, self.port)
62 .parse()
63 .expect("Invalid address")
64 }
65
66 pub fn rest_url(&self) -> String {
68 format!("http://{}:{}", self.host, self.port)
69 }
70
71 pub fn grpc_url(&self) -> String {
73 format!("http://{}:{}", self.host, self.port)
74 }
75
76 pub fn ws_url(&self, path: &str) -> String {
78 format!("ws://{}:{}{}", self.host, self.port, path)
79 }
80}
81
82pub async fn wait_for_health(
84 base_url: &str,
85 timeout_secs: u64,
86 interval_ms: u64,
87) -> Result<(), TestServerError> {
88 let client = reqwest::Client::new();
89 let health_url = format!("{}/health", base_url);
90
91 let result = timeout(Duration::from_secs(timeout_secs), async {
92 loop {
93 match client.get(&health_url).send().await {
94 Ok(response) if response.status().is_success() => {
95 return Ok(());
96 }
97 Ok(response) => {
98 let status = response.status();
100 let body = response.text().await.unwrap_or_default();
101 tracing::debug!("Health check returned {}: {}", status, body);
102 }
103 Err(e) => {
104 tracing::debug!("Health check failed: {}", e);
105 }
106 }
107 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
108 }
109 })
110 .await;
111
112 match result {
113 Ok(Ok(())) => Ok(()),
114 Ok(Err(e)) => Err(e),
115 Err(_) => Err(TestServerError::StartupTimeout),
116 }
117}
118
119pub async fn is_healthy(base_url: &str) -> bool {
121 let client = reqwest::Client::new();
122 let health_url = format!("{}/health", base_url);
123
124 match client.get(&health_url).send().await {
125 Ok(response) => response.status().is_success(),
126 Err(_) => false,
127 }
128}
129
130pub struct TestHttpClient {
132 client: reqwest::Client,
133 base_url: String,
134}
135
136impl TestHttpClient {
137 pub fn new(base_url: &str) -> Self {
138 Self {
139 client: reqwest::Client::builder()
140 .timeout(Duration::from_secs(30))
141 .build()
142 .expect("Failed to create HTTP client"),
143 base_url: base_url.to_string(),
144 }
145 }
146
147 pub async fn get(&self, path: &str) -> Result<reqwest::Response, TestServerError> {
149 let url = format!("{}{}", self.base_url, path);
150 self.client
151 .get(&url)
152 .send()
153 .await
154 .map_err(|e| TestServerError::RequestFailed(e.to_string()))
155 }
156
157 pub async fn get_json<T: serde::de::DeserializeOwned>(
159 &self,
160 path: &str,
161 ) -> Result<T, TestServerError> {
162 let response = self.get(path).await?;
163 response
164 .json()
165 .await
166 .map_err(|e| TestServerError::InvalidResponse(e.to_string()))
167 }
168
169 pub async fn post<T: serde::Serialize>(
171 &self,
172 path: &str,
173 body: &T,
174 ) -> Result<reqwest::Response, TestServerError> {
175 let url = format!("{}{}", self.base_url, path);
176 self.client
177 .post(&url)
178 .json(body)
179 .send()
180 .await
181 .map_err(|e| TestServerError::RequestFailed(e.to_string()))
182 }
183
184 pub async fn post_json<T: serde::Serialize, R: serde::de::DeserializeOwned>(
186 &self,
187 path: &str,
188 body: &T,
189 ) -> Result<R, TestServerError> {
190 let response = self.post(path, body).await?;
191 response
192 .json()
193 .await
194 .map_err(|e| TestServerError::InvalidResponse(e.to_string()))
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Consecutive calls must never hand out the same port.
    #[test]
    fn test_get_test_port_unique() {
        let [a, b, c] = [get_test_port(), get_test_port(), get_test_port()];

        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
    }

    /// Defaults: loopback host, a high allocated port, and a 10 s startup timeout.
    #[test]
    fn test_server_config_default() {
        let TestServerConfig {
            host,
            port,
            startup_timeout_secs,
            ..
        } = TestServerConfig::default();

        assert_eq!(host, "127.0.0.1");
        assert!(port > 50000);
        assert_eq!(startup_timeout_secs, 10);
    }

    /// All URL helpers share the configured host and port.
    #[test]
    fn test_server_config_urls() {
        let config = TestServerConfig {
            port: 3000,
            ..TestServerConfig::default()
        };

        assert_eq!(config.rest_url(), "http://127.0.0.1:3000");
        assert_eq!(config.grpc_url(), "http://127.0.0.1:3000");
        assert_eq!(config.ws_url("/ws/events"), "ws://127.0.0.1:3000/ws/events");
    }
}