1#[cfg(feature = "server")]
22mod api;
23#[cfg(feature = "server")]
24mod handlers;
25#[cfg(feature = "server")]
26mod state;
27
28#[cfg(feature = "server")]
29pub use api::*;
30#[cfg(feature = "server")]
31pub use handlers::*;
32#[cfg(feature = "server")]
33pub use state::*;
34
35use serde::{Deserialize, Serialize};
36use std::net::SocketAddr;
37use thiserror::Error;
38
39#[derive(Debug, Error)]
41pub enum ServerError {
42 #[error("Bind error: {0}")]
43 Bind(String),
44
45 #[error("IO error: {0}")]
46 Io(#[from] std::io::Error),
47
48 #[error("Storage error: {0}")]
49 Storage(String),
50
51 #[error("Not found: {0}")]
52 NotFound(String),
53
54 #[error("Validation error: {0}")]
55 Validation(String),
56
57 #[error("Internal error: {0}")]
58 Internal(String),
59}
60
61pub type Result<T> = std::result::Result<T, ServerError>;
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ServerConfig {
67 pub address: SocketAddr,
69 pub cors_enabled: bool,
71 pub cors_origins: Vec<String>,
73 pub api_key: Option<String>,
75 pub timeout_secs: u64,
77 pub max_body_size: usize,
79}
80
81impl Default for ServerConfig {
82 fn default() -> Self {
83 Self {
84 address: "127.0.0.1:5000".parse().expect("default server address must be valid"),
85 cors_enabled: true,
86 cors_origins: vec!["*".to_string()],
87 api_key: None,
88 timeout_secs: 30,
89 max_body_size: 10 * 1024 * 1024, }
91 }
92}
93
94impl ServerConfig {
95 pub fn with_address(mut self, addr: SocketAddr) -> Self {
97 self.address = addr;
98 self
99 }
100
101 pub fn with_api_key(mut self, key: &str) -> Self {
103 self.api_key = Some(key.to_string());
104 self
105 }
106
107 pub fn without_cors(mut self) -> Self {
109 self.cors_enabled = false;
110 self
111 }
112
113 pub fn with_cors_origins(mut self, origins: Vec<String>) -> Self {
115 self.cors_origins = origins;
116 self
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ApiResponse<T> {
123 pub success: bool,
125 pub data: Option<T>,
127 pub error: Option<String>,
129 pub request_id: String,
131}
132
133impl<T> ApiResponse<T> {
134 pub fn success(data: T, request_id: &str) -> Self {
136 Self { success: true, data: Some(data), error: None, request_id: request_id.to_string() }
137 }
138
139 pub fn error(message: &str, request_id: &str) -> Self {
141 Self {
142 success: false,
143 data: None,
144 error: Some(message.to_string()),
145 request_id: request_id.to_string(),
146 }
147 }
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct HealthResponse {
153 pub status: String,
155 pub version: String,
157 pub uptime_secs: u64,
159 pub experiments_count: usize,
161 pub runs_count: usize,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct CreateExperimentRequest {
172 pub name: String,
174 pub description: Option<String>,
176 pub tags: Option<std::collections::HashMap<String, String>>,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct CreateRunRequest {
183 pub experiment_id: String,
185 pub name: Option<String>,
187 pub tags: Option<std::collections::HashMap<String, String>>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct LogParamsRequest {
194 pub params: std::collections::HashMap<String, serde_json::Value>,
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct LogMetricsRequest {
201 pub metrics: std::collections::HashMap<String, f64>,
203 pub step: Option<u64>,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct UpdateRunRequest {
210 pub status: Option<String>,
212 pub end_time: Option<String>,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ExperimentResponse {
219 pub id: String,
221 pub name: String,
223 pub description: Option<String>,
225 pub created_at: String,
227 pub tags: std::collections::HashMap<String, String>,
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct RunResponse {
234 pub id: String,
236 pub experiment_id: String,
238 pub name: Option<String>,
240 pub status: String,
242 pub start_time: String,
244 pub end_time: Option<String>,
246 pub params: std::collections::HashMap<String, serde_json::Value>,
248 pub metrics: std::collections::HashMap<String, f64>,
250 pub tags: std::collections::HashMap<String, String>,
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: `use super::*` brings the one-parameter `Result<T>` alias into scope
    // and shadows the prelude `Result`, so fallible parses below use a turbofish
    // instead of naming `Result<_, _>`.

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.address.port(), 5000);
        assert!(config.address.ip().is_loopback());
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, ["*"]);
        assert!(config.api_key.is_none());
        assert_eq!(config.timeout_secs, 30);
        assert_eq!(config.max_body_size, 10 * 1024 * 1024);
    }

    #[test]
    fn test_server_config_with_address() {
        let addr: SocketAddr = "0.0.0.0:8080".parse().expect("parsing should succeed");
        let config = ServerConfig::default().with_address(addr);
        assert_eq!(config.address.port(), 8080);
        assert_eq!(config.address, addr);
    }

    #[test]
    fn test_server_config_with_api_key() {
        let config = ServerConfig::default().with_api_key("secret123");
        assert_eq!(config.api_key, Some("secret123".to_string()));
    }

    #[test]
    fn test_server_config_without_cors() {
        let config = ServerConfig::default().without_cors();
        assert!(!config.cors_enabled);
        // Disabling CORS leaves the configured origins untouched.
        assert_eq!(config.cors_origins, ["*"]);
    }

    #[test]
    fn test_server_config_with_cors_origins() {
        let origins = vec!["https://example.com".to_string()];
        let config = ServerConfig::default().with_cors_origins(origins.clone());
        assert_eq!(config.cors_origins, origins);
        // Replacing origins does not toggle CORS on or off.
        assert!(config.cors_enabled);
    }

    #[test]
    fn test_server_config_builders_chain() {
        let config = ServerConfig::default().without_cors().with_api_key("k");
        assert!(!config.cors_enabled);
        assert_eq!(config.api_key.as_deref(), Some("k"));
    }

    #[test]
    fn test_server_config_json_roundtrip() {
        let config = ServerConfig::default().with_api_key("k").without_cors();
        let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
        let parsed: ServerConfig =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(parsed.address, config.address);
        assert_eq!(parsed.cors_enabled, config.cors_enabled);
        assert_eq!(parsed.cors_origins, config.cors_origins);
        assert_eq!(parsed.api_key, config.api_key);
        assert_eq!(parsed.timeout_secs, config.timeout_secs);
        assert_eq!(parsed.max_body_size, config.max_body_size);
    }

    #[test]
    fn test_server_error_display() {
        assert_eq!(ServerError::Bind("port in use".into()).to_string(), "Bind error: port in use");
        assert_eq!(ServerError::Storage("disk full".into()).to_string(), "Storage error: disk full");
        assert_eq!(ServerError::NotFound("run-1".into()).to_string(), "Not found: run-1");
        assert_eq!(
            ServerError::Validation("bad name".into()).to_string(),
            "Validation error: bad name"
        );
        assert_eq!(ServerError::Internal("oops".into()).to_string(), "Internal error: oops");
    }

    #[test]
    fn test_server_error_from_io() {
        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "boom");
        let err: ServerError = io.into();
        assert!(matches!(err, ServerError::Io(_)));
        assert_eq!(err.to_string(), "IO error: boom");
    }

    #[test]
    fn test_api_response_success() {
        let response = ApiResponse::success("hello", "req-123");
        assert!(response.success);
        assert_eq!(response.data, Some("hello"));
        assert!(response.error.is_none());
        assert_eq!(response.request_id, "req-123");
    }

    #[test]
    fn test_api_response_error() {
        let response: ApiResponse<String> = ApiResponse::error("not found", "req-456");
        assert!(!response.success);
        assert!(response.data.is_none());
        assert_eq!(response.error, Some("not found".to_string()));
        assert_eq!(response.request_id, "req-456");
    }

    #[test]
    fn test_api_response_serialize_shape() {
        let json = serde_json::to_value(ApiResponse::success(42, "req-789"))
            .expect("JSON serialization should succeed");
        assert_eq!(json["success"], true);
        assert_eq!(json["data"], 42);
        // Unset options are emitted as explicit nulls, not omitted.
        assert!(json["error"].is_null());
        assert_eq!(json["request_id"], "req-789");
    }

    #[test]
    fn test_health_response_serialize() {
        let health = HealthResponse {
            status: "healthy".to_string(),
            version: "0.2.3".to_string(),
            uptime_secs: 3600,
            experiments_count: 10,
            runs_count: 50,
        };
        let json = serde_json::to_string(&health).expect("JSON serialization should succeed");
        assert!(json.contains("healthy"));
    }

    #[test]
    fn test_create_experiment_request() {
        let json = r#"{"name": "test-exp", "description": "A test"}"#;
        let req: CreateExperimentRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.name, "test-exp");
        assert_eq!(req.description, Some("A test".to_string()));
        assert!(req.tags.is_none());
    }

    #[test]
    fn test_create_run_request() {
        let json = r#"{"experiment_id": "exp-123", "name": "run-1"}"#;
        let req: CreateRunRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.experiment_id, "exp-123");
        assert_eq!(req.name, Some("run-1".to_string()));
        assert!(req.tags.is_none());
    }

    #[test]
    fn test_create_run_request_requires_experiment_id() {
        let parsed = serde_json::from_str::<CreateRunRequest>(r#"{"name": "run-1"}"#);
        assert!(parsed.is_err());
    }

    #[test]
    fn test_log_params_request() {
        let json = r#"{"params": {"lr": 0.001, "batch_size": 32}}"#;
        let req: LogParamsRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert!(req.params.contains_key("lr"));
        assert!(req.params.contains_key("batch_size"));
    }

    #[test]
    fn test_log_metrics_request() {
        let json = r#"{"metrics": {"loss": 0.5, "accuracy": 0.9}, "step": 100}"#;
        let req: LogMetricsRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.metrics.get("loss"), Some(&0.5));
        assert_eq!(req.step, Some(100));
    }

    #[test]
    fn test_update_run_request() {
        let json = r#"{"status": "completed", "end_time": "2024-01-15T10:30:00Z"}"#;
        let req: UpdateRunRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.status, Some("completed".to_string()));
        assert_eq!(req.end_time.as_deref(), Some("2024-01-15T10:30:00Z"));
    }

    #[test]
    fn test_experiment_response_serialize() {
        let exp = ExperimentResponse {
            id: "exp-123".to_string(),
            name: "My Experiment".to_string(),
            description: Some("Test".to_string()),
            created_at: "2024-01-15T10:00:00Z".to_string(),
            tags: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&exp).expect("JSON serialization should succeed");
        assert!(json.contains("exp-123"));
    }

    #[test]
    fn test_run_response_serialize() {
        let run = RunResponse {
            id: "run-456".to_string(),
            experiment_id: "exp-123".to_string(),
            name: Some("training-run".to_string()),
            status: "running".to_string(),
            start_time: "2024-01-15T10:00:00Z".to_string(),
            end_time: None,
            params: std::collections::HashMap::new(),
            metrics: std::collections::HashMap::new(),
            tags: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&run).expect("JSON serialization should succeed");
        assert!(json.contains("run-456"));
    }
}
392
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // Inclusive upper bound so the top port (65535) is also generated;
        // a half-open `..65535` range never produces it.
        #[test]
        fn prop_server_config_port_preserved(port in 1024u16..=u16::MAX) {
            let addr: SocketAddr = format!("127.0.0.1:{port}").parse().expect("parsing should succeed");
            let config = ServerConfig::default().with_address(addr);
            prop_assert_eq!(config.address.port(), port);
        }

        #[test]
        fn prop_api_response_success_has_data(data in "[a-zA-Z0-9]{1,100}") {
            let response = ApiResponse::success(data.clone(), "req-1");
            prop_assert!(response.success);
            prop_assert_eq!(response.data, Some(data));
        }

        #[test]
        fn prop_api_response_error_has_message(msg in "[a-zA-Z0-9 ]{1,100}") {
            let response: ApiResponse<String> = ApiResponse::error(&msg, "req-1");
            prop_assert!(!response.success);
            prop_assert_eq!(response.error, Some(msg));
        }

        // Serializing and re-parsing preserves every field of a success envelope.
        #[test]
        fn prop_api_response_json_roundtrip(
            data in "[a-zA-Z0-9]{1,100}",
            id in "[a-z0-9-]{1,36}"
        ) {
            let response = ApiResponse::success(data.clone(), &id);
            let json = serde_json::to_string(&response).expect("JSON serialization should succeed");
            let parsed: ApiResponse<String> = serde_json::from_str(&json).expect("JSON deserialization should succeed");
            prop_assert!(parsed.success);
            prop_assert_eq!(parsed.data, Some(data));
            prop_assert!(parsed.error.is_none());
            prop_assert_eq!(parsed.request_id, id);
        }

        #[test]
        fn prop_create_experiment_roundtrip(name in "[a-zA-Z0-9-]{1,50}") {
            let req = CreateExperimentRequest {
                name: name.clone(),
                description: None,
                tags: None,
            };
            let json = serde_json::to_string(&req).expect("JSON serialization should succeed");
            let parsed: CreateExperimentRequest = serde_json::from_str(&json).expect("JSON deserialization should succeed");
            prop_assert_eq!(parsed.name, name);
        }

        // A small tolerance is kept because serde_json's default float parser is
        // not guaranteed to be bit-exact on every value.
        #[test]
        fn prop_log_metrics_roundtrip(
            metric_name in "[a-z_]{1,20}",
            value in -1000.0f64..1000.0
        ) {
            let mut metrics = std::collections::HashMap::new();
            metrics.insert(metric_name.clone(), value);
            let req = LogMetricsRequest { metrics, step: None };
            let json = serde_json::to_string(&req).expect("JSON serialization should succeed");
            let parsed: LogMetricsRequest = serde_json::from_str(&json).expect("JSON deserialization should succeed");
            let restored = parsed
                .metrics
                .get(&metric_name)
                .expect("metric should survive the JSON roundtrip");
            prop_assert!((restored - value).abs() < 1e-10);
        }
    }
}