Skip to main content

entrenar/server/
mod.rs

1//! REST/HTTP API Server (#67)
2//!
3//! Remote access to experiment tracking with built-in quality stops.
4//!
5//! # Toyota Principle: Jidoka (自働化)
6//!
7//! Built-in quality - Remote access enables team-wide visibility while
8//! maintaining quality through input validation and error handling.
9//!
10//! # Example
11//!
12//! ```ignore
13//! use entrenar::server::{TrackingServer, ServerConfig};
14//! use std::net::SocketAddr;
15//!
16//! let config = ServerConfig::default();
17//! let server = TrackingServer::new(config);
18//! server.run("127.0.0.1:5000".parse().expect("valid address")).await?;
19//! ```
20
21#[cfg(feature = "server")]
22mod api;
23#[cfg(feature = "server")]
24mod handlers;
25#[cfg(feature = "server")]
26mod state;
27
28#[cfg(feature = "server")]
29pub use api::*;
30#[cfg(feature = "server")]
31pub use handlers::*;
32#[cfg(feature = "server")]
33pub use state::*;
34
35use serde::{Deserialize, Serialize};
36use std::net::SocketAddr;
37use thiserror::Error;
38
39/// Server errors
40#[derive(Debug, Error)]
41pub enum ServerError {
42    #[error("Bind error: {0}")]
43    Bind(String),
44
45    #[error("IO error: {0}")]
46    Io(#[from] std::io::Error),
47
48    #[error("Storage error: {0}")]
49    Storage(String),
50
51    #[error("Not found: {0}")]
52    NotFound(String),
53
54    #[error("Validation error: {0}")]
55    Validation(String),
56
57    #[error("Internal error: {0}")]
58    Internal(String),
59}
60
61/// Result type for server operations
62pub type Result<T> = std::result::Result<T, ServerError>;
63
64/// Server configuration
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ServerConfig {
67    /// Server address
68    pub address: SocketAddr,
69    /// Enable CORS
70    pub cors_enabled: bool,
71    /// Allowed origins for CORS
72    pub cors_origins: Vec<String>,
73    /// API key for authentication (None = no auth)
74    pub api_key: Option<String>,
75    /// Request timeout in seconds
76    pub timeout_secs: u64,
77    /// Maximum request body size in bytes
78    pub max_body_size: usize,
79}
80
81impl Default for ServerConfig {
82    fn default() -> Self {
83        Self {
84            address: "127.0.0.1:5000".parse().expect("default server address must be valid"),
85            cors_enabled: true,
86            cors_origins: vec!["*".to_string()],
87            api_key: None,
88            timeout_secs: 30,
89            max_body_size: 10 * 1024 * 1024, // 10MB
90        }
91    }
92}
93
94impl ServerConfig {
95    /// Create config with custom address
96    pub fn with_address(mut self, addr: SocketAddr) -> Self {
97        self.address = addr;
98        self
99    }
100
101    /// Create config with API key authentication
102    pub fn with_api_key(mut self, key: &str) -> Self {
103        self.api_key = Some(key.to_string());
104        self
105    }
106
107    /// Disable CORS
108    pub fn without_cors(mut self) -> Self {
109        self.cors_enabled = false;
110        self
111    }
112
113    /// Set allowed CORS origins
114    pub fn with_cors_origins(mut self, origins: Vec<String>) -> Self {
115        self.cors_origins = origins;
116        self
117    }
118}
119
120/// API response wrapper
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ApiResponse<T> {
123    /// Whether the request was successful
124    pub success: bool,
125    /// Response data (if successful)
126    pub data: Option<T>,
127    /// Error message (if failed)
128    pub error: Option<String>,
129    /// Request ID for tracing
130    pub request_id: String,
131}
132
133impl<T> ApiResponse<T> {
134    /// Create success response
135    pub fn success(data: T, request_id: &str) -> Self {
136        Self { success: true, data: Some(data), error: None, request_id: request_id.to_string() }
137    }
138
139    /// Create error response
140    pub fn error(message: &str, request_id: &str) -> Self {
141        Self {
142            success: false,
143            data: None,
144            error: Some(message.to_string()),
145            request_id: request_id.to_string(),
146        }
147    }
148}
149
150/// Health check response
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct HealthResponse {
153    /// Server status
154    pub status: String,
155    /// Server version
156    pub version: String,
157    /// Uptime in seconds
158    pub uptime_secs: u64,
159    /// Number of active experiments
160    pub experiments_count: usize,
161    /// Number of active runs
162    pub runs_count: usize,
163}
164
165// =============================================================================
166// Request/Response DTOs
167// =============================================================================
168
169/// Create experiment request
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct CreateExperimentRequest {
172    /// Experiment name
173    pub name: String,
174    /// Optional description
175    pub description: Option<String>,
176    /// Optional tags
177    pub tags: Option<std::collections::HashMap<String, String>>,
178}
179
180/// Create run request
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct CreateRunRequest {
183    /// Experiment ID
184    pub experiment_id: String,
185    /// Optional run name
186    pub name: Option<String>,
187    /// Optional tags
188    pub tags: Option<std::collections::HashMap<String, String>>,
189}
190
191/// Log parameters request
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct LogParamsRequest {
194    /// Parameters to log
195    pub params: std::collections::HashMap<String, serde_json::Value>,
196}
197
198/// Log metrics request
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct LogMetricsRequest {
201    /// Metrics to log (name -> value)
202    pub metrics: std::collections::HashMap<String, f64>,
203    /// Optional step number
204    pub step: Option<u64>,
205}
206
207/// Update run request
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct UpdateRunRequest {
210    /// New status
211    pub status: Option<String>,
212    /// End time (ISO 8601)
213    pub end_time: Option<String>,
214}
215
216/// Experiment response
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ExperimentResponse {
219    /// Experiment ID
220    pub id: String,
221    /// Experiment name
222    pub name: String,
223    /// Description
224    pub description: Option<String>,
225    /// Creation time
226    pub created_at: String,
227    /// Tags
228    pub tags: std::collections::HashMap<String, String>,
229}
230
231/// Run response
232#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct RunResponse {
234    /// Run ID
235    pub id: String,
236    /// Experiment ID
237    pub experiment_id: String,
238    /// Run name
239    pub name: Option<String>,
240    /// Status
241    pub status: String,
242    /// Start time
243    pub start_time: String,
244    /// End time
245    pub end_time: Option<String>,
246    /// Parameters
247    pub params: std::collections::HashMap<String, serde_json::Value>,
248    /// Latest metrics
249    pub metrics: std::collections::HashMap<String, f64>,
250    /// Tags
251    pub tags: std::collections::HashMap<String, String>,
252}
253
254// =============================================================================
255// Tests
256// =============================================================================
257
#[cfg(test)]
mod tests {
    use super::*;

    // --- ServerConfig ---------------------------------------------------------

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.address.port(), 5000);
        assert!(config.cors_enabled);
        assert!(config.api_key.is_none());
    }

    #[test]
    fn test_server_config_default_limits() {
        let config = ServerConfig::default();
        assert!(config.address.ip().is_loopback());
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
        assert_eq!(config.timeout_secs, 30);
        assert_eq!(config.max_body_size, 10 * 1024 * 1024);
    }

    #[test]
    fn test_server_config_with_address() {
        let addr: SocketAddr = "0.0.0.0:8080".parse().expect("parsing should succeed");
        let config = ServerConfig::default().with_address(addr);
        assert_eq!(config.address.port(), 8080);
        assert_eq!(config.address, addr);
    }

    #[test]
    fn test_server_config_with_api_key() {
        let config = ServerConfig::default().with_api_key("secret123");
        assert_eq!(config.api_key, Some("secret123".to_string()));
    }

    #[test]
    fn test_server_config_without_cors() {
        let config = ServerConfig::default().without_cors();
        assert!(!config.cors_enabled);
    }

    #[test]
    fn test_server_config_with_cors_origins() {
        let origins = vec!["https://a.example".to_string(), "https://b.example".to_string()];
        let config = ServerConfig::default().with_cors_origins(origins.clone());
        assert_eq!(config.cors_origins, origins);
        // Replacing the origin list must not toggle CORS on or off.
        assert!(config.cors_enabled);
    }

    // --- ServerError ----------------------------------------------------------

    #[test]
    fn test_server_error_display() {
        assert_eq!(ServerError::Bind("port in use".into()).to_string(), "Bind error: port in use");
        assert_eq!(ServerError::Storage("db locked".into()).to_string(), "Storage error: db locked");
        assert_eq!(ServerError::NotFound("run-1".into()).to_string(), "Not found: run-1");
        assert_eq!(
            ServerError::Validation("empty name".into()).to_string(),
            "Validation error: empty name"
        );
        assert_eq!(ServerError::Internal("boom".into()).to_string(), "Internal error: boom");
    }

    #[test]
    fn test_server_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk full");
        let err: ServerError = io_err.into();
        assert!(matches!(err, ServerError::Io(_)));
        assert_eq!(err.to_string(), "IO error: disk full");
    }

    // --- ApiResponse / HealthResponse -----------------------------------------

    #[test]
    fn test_api_response_success() {
        let response = ApiResponse::success("hello", "req-123");
        assert!(response.success);
        assert_eq!(response.data, Some("hello"));
        assert!(response.error.is_none());
        assert_eq!(response.request_id, "req-123");
    }

    #[test]
    fn test_api_response_error() {
        let response: ApiResponse<String> = ApiResponse::error("not found", "req-456");
        assert!(!response.success);
        assert!(response.data.is_none());
        assert_eq!(response.error, Some("not found".to_string()));
        assert_eq!(response.request_id, "req-456");
    }

    #[test]
    fn test_health_response_serialize() {
        let health = HealthResponse {
            status: "healthy".to_string(),
            version: "0.2.3".to_string(),
            uptime_secs: 3600,
            experiments_count: 10,
            runs_count: 50,
        };
        let json = serde_json::to_string(&health).expect("JSON serialization should succeed");
        assert!(json.contains("healthy"));

        // Round-trip to check every field, not just a substring.
        let back: HealthResponse =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(back.status, "healthy");
        assert_eq!(back.version, "0.2.3");
        assert_eq!(back.uptime_secs, 3600);
        assert_eq!(back.experiments_count, 10);
        assert_eq!(back.runs_count, 50);
    }

    // --- Request DTOs ---------------------------------------------------------

    #[test]
    fn test_create_experiment_request() {
        let json = r#"{"name": "test-exp", "description": "A test"}"#;
        let req: CreateExperimentRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.name, "test-exp");
        assert_eq!(req.description, Some("A test".to_string()));
        assert!(req.tags.is_none());
    }

    #[test]
    fn test_create_run_request() {
        let json = r#"{"experiment_id": "exp-123", "name": "run-1"}"#;
        let req: CreateRunRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.experiment_id, "exp-123");
        assert_eq!(req.name, Some("run-1".to_string()));
    }

    #[test]
    fn test_create_run_request_optional_fields_absent() {
        let json = r#"{"experiment_id": "exp-1"}"#;
        let req: CreateRunRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.experiment_id, "exp-1");
        assert!(req.name.is_none());
        assert!(req.tags.is_none());
    }

    #[test]
    fn test_log_params_request() {
        let json = r#"{"params": {"lr": 0.001, "batch_size": 32}}"#;
        let req: LogParamsRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert!(req.params.contains_key("lr"));
        assert!(req.params.contains_key("batch_size"));
        assert_eq!(req.params.get("batch_size").and_then(|v| v.as_u64()), Some(32));
        assert_eq!(req.params.get("lr").and_then(|v| v.as_f64()), Some(0.001));
    }

    #[test]
    fn test_log_metrics_request() {
        let json = r#"{"metrics": {"loss": 0.5, "accuracy": 0.9}, "step": 100}"#;
        let req: LogMetricsRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.metrics.get("loss"), Some(&0.5));
        assert_eq!(req.metrics.get("accuracy"), Some(&0.9));
        assert_eq!(req.step, Some(100));
    }

    #[test]
    fn test_log_metrics_request_without_step() {
        let json = r#"{"metrics": {"loss": 0.25}}"#;
        let req: LogMetricsRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.metrics.get("loss"), Some(&0.25));
        assert!(req.step.is_none());
    }

    #[test]
    fn test_update_run_request() {
        let json = r#"{"status": "completed", "end_time": "2024-01-15T10:30:00Z"}"#;
        let req: UpdateRunRequest =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(req.status, Some("completed".to_string()));
        assert_eq!(req.end_time, Some("2024-01-15T10:30:00Z".to_string()));
    }

    // --- Response DTOs --------------------------------------------------------

    #[test]
    fn test_experiment_response_serialize() {
        let exp = ExperimentResponse {
            id: "exp-123".to_string(),
            name: "My Experiment".to_string(),
            description: Some("Test".to_string()),
            created_at: "2024-01-15T10:00:00Z".to_string(),
            tags: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&exp).expect("JSON serialization should succeed");
        assert!(json.contains("exp-123"));
    }

    #[test]
    fn test_run_response_serialize() {
        let run = RunResponse {
            id: "run-456".to_string(),
            experiment_id: "exp-123".to_string(),
            name: Some("training-run".to_string()),
            status: "running".to_string(),
            start_time: "2024-01-15T10:00:00Z".to_string(),
            end_time: None,
            params: std::collections::HashMap::new(),
            metrics: std::collections::HashMap::new(),
            tags: std::collections::HashMap::new(),
        };
        let json = serde_json::to_string(&run).expect("JSON serialization should succeed");
        assert!(json.contains("run-456"));
    }
}
392
393// =============================================================================
394// Property Tests
395// =============================================================================
396
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        // Every unprivileged port, including the top of the range (65535),
        // survives `with_address` unchanged.
        #[test]
        fn prop_server_config_port_preserved(port in 1024u16..=u16::MAX) {
            let addr = SocketAddr::from(([127, 0, 0, 1], port));
            let config = ServerConfig::default().with_address(addr);
            prop_assert_eq!(config.address.port(), port);
        }

        #[test]
        fn prop_api_response_success_has_data(data in "[a-zA-Z0-9]{1,100}") {
            let response = ApiResponse::success(data.clone(), "req-1");
            prop_assert!(response.success);
            prop_assert_eq!(response.data, Some(data));
        }

        #[test]
        fn prop_api_response_error_has_message(msg in "[a-zA-Z0-9 ]{1,100}") {
            let response: ApiResponse<String> = ApiResponse::error(&msg, "req-1");
            prop_assert!(!response.success);
            prop_assert_eq!(response.error, Some(msg));
        }

        #[test]
        fn prop_create_experiment_roundtrip(name in "[a-zA-Z0-9-]{1,50}") {
            let req = CreateExperimentRequest {
                name: name.clone(),
                description: None,
                tags: None,
            };
            let json = serde_json::to_string(&req).expect("JSON serialization should succeed");
            let parsed: CreateExperimentRequest = serde_json::from_str(&json).expect("JSON deserialization should succeed");
            prop_assert_eq!(parsed.name, name);
        }

        // Metric values and the optional step both survive a JSON round-trip.
        #[test]
        fn prop_log_metrics_roundtrip(
            metric_name in "[a-z_]{1,20}",
            value in -1000.0f64..1000.0,
            step in prop::option::of(any::<u64>())
        ) {
            let mut metrics = std::collections::HashMap::new();
            metrics.insert(metric_name.clone(), value);
            let req = LogMetricsRequest { metrics, step };
            let json = serde_json::to_string(&req).expect("JSON serialization should succeed");
            let parsed: LogMetricsRequest = serde_json::from_str(&json).expect("JSON deserialization should succeed");
            prop_assert_eq!(parsed.step, step);
            let roundtripped = parsed
                .metrics
                .get(&metric_name)
                .copied()
                .expect("metric key should survive the JSON roundtrip");
            // serde_json's default float parsing is best-effort, so the exact
            // bit pattern may not round-trip; compare with a tolerance instead.
            prop_assert!((roundtripped - value).abs() < 1e-10);
        }
    }
}