nt_core/
config.rs

//! Configuration management for the Neural Trading system.
//!
//! This module defines the configuration types, deserialized with `serde` and validated with
//! `validator`. A complete configuration can be loaded from a TOML or JSON file; broker
//! credentials are meant to be supplied separately (for example from environment variables).
5
6use crate::error::{Result, TradingError};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use validator::Validate;
10
11// ============================================================================
12// Main Configuration
13// ============================================================================
14
/// Root configuration for the trading application.
///
/// Usually built with [`AppConfig::from_toml_file`] or [`AppConfig::from_json_file`], both of
/// which call `validate()` before returning. The bare `#[validate]` on each field asks the
/// `validator` derive to recurse into that section (and into every entry of `strategies`).
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct AppConfig {
    /// HTTP server settings.
    #[validate]
    pub server: ServerConfig,

    /// Broker connection settings (Alpaca, etc.).
    #[validate]
    pub broker: BrokerConfig,

    /// Per-strategy settings. An empty list is accepted (`default_test_config` uses one).
    #[validate]
    pub strategies: Vec<StrategyConfig>,

    /// Portfolio-level risk limits.
    #[validate]
    pub risk: RiskConfig,

    /// Database connection settings.
    #[validate]
    pub database: DatabaseConfig,

    /// Logging settings.
    #[validate]
    pub logging: LoggingConfig,
}
42
43impl AppConfig {
44    /// Load configuration from a TOML file
45    ///
46    /// # Arguments
47    ///
48    /// * `path` - Path to TOML configuration file
49    ///
50    /// # Returns
51    ///
52    /// Validated configuration
53    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
54        let contents = std::fs::read_to_string(path.as_ref())
55            .map_err(|e| TradingError::config(format!("Failed to read config file: {}", e)))?;
56
57        let config: Self = toml::from_str(&contents)
58            .map_err(|e| TradingError::config(format!("Failed to parse TOML config: {}", e)))?;
59
60        config
61            .validate()
62            .map_err(|e| TradingError::config(format!("Configuration validation failed: {}", e)))?;
63
64        Ok(config)
65    }
66
67    /// Load configuration from a JSON file
68    ///
69    /// # Arguments
70    ///
71    /// * `path` - Path to JSON configuration file
72    ///
73    /// # Returns
74    ///
75    /// Validated configuration
76    pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
77        let contents = std::fs::read_to_string(path.as_ref())
78            .map_err(|e| TradingError::config(format!("Failed to read config file: {}", e)))?;
79
80        let config: Self = serde_json::from_str(&contents)
81            .map_err(|e| TradingError::config(format!("Failed to parse JSON config: {}", e)))?;
82
83        config
84            .validate()
85            .map_err(|e| TradingError::config(format!("Configuration validation failed: {}", e)))?;
86
87        Ok(config)
88    }
89
90    /// Create a default configuration for testing
91    pub fn default_test_config() -> Self {
92        Self {
93            server: ServerConfig::default(),
94            broker: BrokerConfig::default(),
95            strategies: vec![],
96            risk: RiskConfig::default(),
97            database: DatabaseConfig::default(),
98            logging: LoggingConfig::default(),
99        }
100    }
101}
102
103// ============================================================================
104// Server Configuration
105// ============================================================================
106
/// HTTP server configuration.
///
/// The bounds listed on each field are enforced by `validate()`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ServerConfig {
    /// Server host (address or hostname); must be non-empty.
    #[validate(length(min = 1))]
    pub host: String,

    /// Server port, restricted to 1024–65535 (ports below 1024 are rejected).
    #[validate(range(min = 1024, max = 65535))]
    pub port: u16,

    /// Enable HTTPS.
    /// NOTE(review): no certificate/key settings live in this struct — confirm where they come from.
    pub enable_https: bool,

    /// Maximum request size in bytes.
    #[validate(range(min = 1024, max = 104857600))] // 1 KiB to 100 MiB
    pub max_request_size: usize,

    /// Request timeout in seconds (1–300).
    #[validate(range(min = 1, max = 300))]
    pub request_timeout_secs: u64,
}
129
130impl Default for ServerConfig {
131    fn default() -> Self {
132        Self {
133            host: "127.0.0.1".to_string(),
134            port: 8080,
135            enable_https: false,
136            max_request_size: 10485760, // 10MB
137            request_timeout_secs: 30,
138        }
139    }
140}
141
142// ============================================================================
143// Broker Configuration
144// ============================================================================
145
146/// Broker API configuration
147#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
148pub struct BrokerConfig {
149    /// Broker name (alpaca, polygon, etc.)
150    #[validate(length(min = 1))]
151    pub name: String,
152
153    /// API base URL
154    #[validate(url)]
155    pub api_url: String,
156
157    /// WebSocket URL for real-time data
158    #[validate(url)]
159    pub ws_url: String,
160
161    /// API key (should be loaded from environment)
162    #[serde(skip_serializing)]
163    pub api_key: String,
164
165    /// API secret (should be loaded from environment)
166    #[serde(skip_serializing)]
167    pub api_secret: String,
168
169    /// Paper trading mode
170    pub paper_trading: bool,
171
172    /// Connection timeout in seconds
173    #[validate(range(min = 1, max = 60))]
174    pub connection_timeout_secs: u64,
175
176    /// Maximum retry attempts
177    #[validate(range(min = 0, max = 10))]
178    pub max_retry_attempts: u32,
179}
180
181impl Default for BrokerConfig {
182    fn default() -> Self {
183        Self {
184            name: "alpaca".to_string(),
185            api_url: "https://paper-api.alpaca.markets".to_string(),
186            ws_url: "wss://stream.data.alpaca.markets".to_string(),
187            api_key: String::new(),
188            api_secret: String::new(),
189            paper_trading: true,
190            connection_timeout_secs: 30,
191            max_retry_attempts: 3,
192        }
193    }
194}
195
196// ============================================================================
197// Strategy Configuration
198// ============================================================================
199
/// Configuration for a single trading strategy.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct StrategyConfig {
    /// Strategy identifier; must be non-empty.
    #[validate(length(min = 1))]
    pub id: String,

    /// Strategy type (momentum, mean_reversion, etc.); must be non-empty.
    #[validate(length(min = 1))]
    pub strategy_type: String,

    /// Symbols to trade. At least one is required; individual entries are not checked here.
    #[validate(length(min = 1))]
    pub symbols: Vec<String>,

    /// Whether this strategy is enabled.
    pub enabled: bool,

    /// Free-form, strategy-specific parameters (not validated here).
    pub parameters: serde_json::Value,
}
221
222// ============================================================================
223// Risk Configuration
224// ============================================================================
225
/// Portfolio-level risk management limits.
///
/// Values described as "fractions" are shares of portfolio value in `[0.0, 1.0]`
/// (e.g. `0.1` = 10%).
/// NOTE(review): these are `f64` range checks; confirm that the `validator` version in use
/// rejects NaN, since NaN fails every ordered comparison.
/// Relationships between fields (e.g. stop loss vs. take profit) are not checked.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct RiskConfig {
    /// Maximum size of a single position, as a fraction of the portfolio (0.0–1.0).
    #[validate(range(min = 0.0, max = 1.0))]
    pub max_position_size: f64,

    /// Maximum daily loss, as a fraction of the portfolio (0.0–1.0).
    #[validate(range(min = 0.0, max = 1.0))]
    pub max_daily_loss: f64,

    /// Maximum drawdown, as a fraction (0.0–1.0).
    #[validate(range(min = 0.0, max = 1.0))]
    pub max_drawdown: f64,

    /// Maximum leverage allowed (1.0 = no leverage, up to 10.0).
    #[validate(range(min = 1.0, max = 10.0))]
    pub max_leverage: f64,

    /// Default stop loss, as a fraction (0.0–1.0).
    #[validate(range(min = 0.0, max = 1.0))]
    pub default_stop_loss: f64,

    /// Default take profit, as a fraction (0.0–1.0).
    #[validate(range(min = 0.0, max = 1.0))]
    pub default_take_profit: f64,

    /// Maximum sector concentration, as a fraction (0.0–1.0).
    #[validate(range(min = 0.0, max = 1.0))]
    pub max_sector_concentration: f64,

    /// Enable circuit breakers.
    pub enable_circuit_breakers: bool,

    /// Circuit breaker cool-down period in seconds.
    #[validate(range(min = 60, max = 86400))] // 1 minute to 1 day
    pub circuit_breaker_cooldown_secs: u64,
}
264
265impl Default for RiskConfig {
266    fn default() -> Self {
267        Self {
268            max_position_size: 0.1,        // 10% per position
269            max_daily_loss: 0.05,          // 5% max daily loss
270            max_drawdown: 0.2,             // 20% max drawdown
271            max_leverage: 1.0,             // No leverage
272            default_stop_loss: 0.02,       // 2% stop loss
273            default_take_profit: 0.05,     // 5% take profit
274            max_sector_concentration: 0.3, // 30% max per sector
275            enable_circuit_breakers: true,
276            circuit_breaker_cooldown_secs: 300, // 5 minutes
277        }
278    }
279}
280
281// ============================================================================
282// Database Configuration
283// ============================================================================
284
/// Database connection configuration.
///
/// NOTE(review): if `connection_url` ever embeds credentials, the derived `Debug` will print
/// them — consider redacting.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct DatabaseConfig {
    /// Database type (sqlite, postgres, etc.); must be non-empty.
    #[validate(length(min = 1))]
    pub database_type: String,

    /// Connection URL; must be non-empty (its format is not checked here).
    #[validate(length(min = 1))]
    pub connection_url: String,

    /// Maximum number of connections in the pool (1–100).
    #[validate(range(min = 1, max = 100))]
    pub max_connections: u32,

    /// Connection timeout in seconds (1–60).
    #[validate(range(min = 1, max = 60))]
    pub connection_timeout_secs: u64,
}
304
305impl Default for DatabaseConfig {
306    fn default() -> Self {
307        Self {
308            database_type: "sqlite".to_string(),
309            connection_url: "sqlite::memory:".to_string(),
310            max_connections: 10,
311            connection_timeout_secs: 30,
312        }
313    }
314}
315
316// ============================================================================
317// Logging Configuration
318// ============================================================================
319
/// Logging configuration.
///
/// NOTE(review): `validate()` does not reject `enable_file_logging = true` together with
/// `log_file_path = None`; confirm the logger handles that case.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct LoggingConfig {
    /// Log level (trace, debug, info, warn, error). Only non-emptiness is validated.
    #[validate(length(min = 1))]
    pub level: String,

    /// Log format (json, pretty, compact). Only non-emptiness is validated.
    #[validate(length(min = 1))]
    pub format: String,

    /// Enable file logging.
    pub enable_file_logging: bool,

    /// Log file path; may be omitted from the input (`None`).
    pub log_file_path: Option<String>,

    /// Maximum log file size in bytes.
    #[validate(range(min = 1048576, max = 1073741824))] // 1 MiB to 1 GiB
    pub max_log_file_size: usize,

    /// Number of log files to keep (1–100).
    #[validate(range(min = 1, max = 100))]
    pub log_file_count: usize,
}
345
346impl Default for LoggingConfig {
347    fn default() -> Self {
348        Self {
349            level: "info".to_string(),
350            format: "pretty".to_string(),
351            enable_file_logging: false,
352            log_file_path: None,
353            max_log_file_size: 10485760, // 10MB
354            log_file_count: 5,
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// A complete TOML document covering every section of `AppConfig`.
    const SAMPLE_TOML: &str = r#"
[server]
host = "0.0.0.0"
port = 8080
enable_https = false
max_request_size = 10485760
request_timeout_secs = 30

[broker]
name = "alpaca"
api_url = "https://paper-api.alpaca.markets"
ws_url = "wss://stream.data.alpaca.markets"
api_key = "test_key"
api_secret = "test_secret"
paper_trading = true
connection_timeout_secs = 30
max_retry_attempts = 3

[[strategies]]
id = "momentum_1"
strategy_type = "momentum"
symbols = ["AAPL", "GOOGL"]
enabled = true
parameters = {}

[risk]
max_position_size = 0.1
max_daily_loss = 0.05
max_drawdown = 0.2
max_leverage = 1.0
default_stop_loss = 0.02
default_take_profit = 0.05
max_sector_concentration = 0.3
enable_circuit_breakers = true
circuit_breaker_cooldown_secs = 300

[database]
database_type = "sqlite"
connection_url = "sqlite::memory:"
max_connections = 10
connection_timeout_secs = 30

[logging]
level = "info"
format = "pretty"
enable_file_logging = false
max_log_file_size = 10485760
log_file_count = 5
        "#;

    #[test]
    fn test_default_config() {
        assert!(AppConfig::default_test_config().validate().is_ok());
    }

    #[test]
    fn test_server_config_validation() {
        let baseline = ServerConfig::default();
        assert!(baseline.validate().is_ok());

        // Ports below the 1024 floor are rejected.
        let privileged = ServerConfig {
            port: 80,
            ..baseline.clone()
        };
        assert!(privileged.validate().is_err());

        // Moving back to an unprivileged port makes it valid again.
        let restored = ServerConfig {
            port: 8080,
            ..privileged
        };
        assert!(restored.validate().is_ok());
    }

    #[test]
    fn test_risk_config_validation() {
        let baseline = RiskConfig::default();
        assert!(baseline.validate().is_ok());

        // A position cap above 1.0 (100% of the portfolio) is rejected.
        let oversized = RiskConfig {
            max_position_size: 1.5,
            ..baseline.clone()
        };
        assert!(oversized.validate().is_err());

        // Back inside [0.0, 1.0] it passes again.
        let resized = RiskConfig {
            max_position_size: 0.2,
            ..oversized
        };
        assert!(resized.validate().is_ok());
    }

    #[test]
    fn test_load_from_toml() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(SAMPLE_TOML.as_bytes()).unwrap();
        file.flush().unwrap();

        let loaded = AppConfig::from_toml_file(file.path()).unwrap();
        assert_eq!(loaded.server.port, 8080);
        assert_eq!(loaded.broker.name, "alpaca");
        assert_eq!(loaded.strategies.len(), 1);
        assert_eq!(loaded.risk.max_position_size, 0.1);
    }

    #[test]
    fn test_broker_config_default() {
        let broker = BrokerConfig::default();
        assert_eq!(broker.name, "alpaca");
        assert!(broker.paper_trading);
        assert_eq!(broker.max_retry_attempts, 3);
    }

    #[test]
    fn test_logging_config_validation() {
        let baseline = LoggingConfig::default();
        assert!(baseline.validate().is_ok());

        // 100 bytes is below the 1 MiB minimum.
        let tiny = LoggingConfig {
            max_log_file_size: 100,
            ..baseline.clone()
        };
        assert!(tiny.validate().is_err());

        // 10 MiB is inside the allowed range.
        let sized = LoggingConfig {
            max_log_file_size: 10485760,
            ..tiny
        };
        assert!(sized.validate().is_ok());
    }
}