// nt_neural/src/lib.rs

//! Neural forecasting models for time series prediction in trading.
//!
//! This crate provides high-performance neural network models optimized for
//! financial time series forecasting with GPU acceleration support.
//!
//! # Models
//!
//! - **NHITS**: Neural Hierarchical Interpolation for Time Series
//! - **LSTM-Attention**: LSTM with multi-head attention mechanism
//! - **Transformer**: Transformer architecture for time series
//! - **GRU**: Gated Recurrent Unit (simpler than LSTM)
//! - **TCN**: Temporal Convolutional Network
//! - **DeepAR**: Probabilistic forecasting with LSTM
//! - **N-BEATS**: Pure MLP with interpretable decomposition
//! - **Prophet**: Time series decomposition (trend + seasonality)
//!
//! # Features
//!
//! - GPU acceleration (CUDA, Metal)
//! - Mixed precision training (FP16/FP32)
//! - Quantile regression for confidence intervals
//! - Model checkpointing and versioning
//! - Integration with AgentDB for model storage
//! - SIMD acceleration for CPU operations (requires nightly Rust)
//!
//! # Examples
//!
//! ```no_run
//! use nt_neural::{NHITSModel, ModelConfig, TrainingConfig};
//!
//! # async fn example() -> anyhow::Result<()> {
//! // Create model configuration
//! let config = ModelConfig {
//!     input_size: 168,  // 1 week of hourly data
//!     horizon: 24,      // 24 hour forecast
//!     hidden_size: 512,
//!     ..Default::default()
//! };
//!
//! // Initialize model
//! let _model = NHITSModel::new(config)?;
//!
//! // Train model (data preparation omitted)
//! // let trained_model = model.train(train_data, val_data).await?;
//! # Ok(())
//! # }
//! ```

49#![cfg_attr(feature = "simd", feature(portable_simd))]
50
51pub mod error;
52#[cfg(feature = "candle")]
53pub mod models;
54pub mod training;
55pub mod inference;
56pub mod utils;
57pub mod storage;
58
59#[cfg(not(feature = "candle"))]
60pub mod stubs;
61
62// Re-export main types
63pub use error::{NeuralError, Result};
64
65#[cfg(feature = "candle")]
66pub use models::{
67    ModelConfig, ModelType,
68};
69
70#[cfg(feature = "candle")]
71pub use models::{
72    nhits::{NHITSModel, NHITSConfig},
73    lstm_attention::{LSTMAttentionModel, LSTMAttentionConfig},
74    transformer::{TransformerModel, TransformerConfig},
75};
76
77// Re-export new models
78#[cfg(feature = "candle")]
79pub use models::{
80    gru::{GRUModel, GRUConfig},
81    tcn::{TCNModel, TCNConfig},
82    deepar::{DeepARModel, DeepARConfig, DistributionType},
83    nbeats::{NBeatsModel, NBeatsConfig, StackType},
84    prophet::{ProphetModel, ProphetConfig, GrowthModel},
85};
86
87#[cfg(feature = "candle")]
88pub use training::{
89    Trainer, TrainingConfig, TrainingMetrics,
90    data_loader::{DataLoader, TimeSeriesDataset},
91    optimizer::{Optimizer, OptimizerConfig},
92    nhits_trainer::{NHITSTrainer, NHITSTrainingConfig},
93};
94pub use training::TrainingConfig;
95pub use training::TrainingMetrics;
96
97#[cfg(feature = "candle")]
98pub use inference::{
99    Predictor, BatchPredictor,
100};
101pub use inference::PredictionResult;
102
103use serde::Serialize;
104use std::path::Path;
105
106// Re-export Device and Tensor from appropriate source
107#[cfg(feature = "candle")]
108pub use candle_core::{Device, Tensor};
109#[cfg(not(feature = "candle"))]
110pub use stubs::{Device, Tensor};
111
/// Initialize the neural module with optimal device selection.
///
/// Backends are tried in priority order — CUDA, then Metal, then an
/// Accelerate-optimized CPU, then plain CPU — and the first device that can
/// be constructed is returned. A failed GPU probe falls through silently to
/// the next backend, so this currently always returns `Ok`.
///
/// # Errors
///
/// None at present; the CPU fallback is infallible.
#[cfg(feature = "candle")]
pub fn initialize() -> Result<Device> {
    #[cfg(feature = "cuda")]
    {
        if let Ok(device) = Device::new_cuda(0) {
            tracing::info!("Neural module initialized with CUDA GPU acceleration");
            return Ok(device);
        }
    }

    #[cfg(feature = "metal")]
    {
        if let Ok(device) = Device::new_metal(0) {
            tracing::info!("Neural module initialized with Metal GPU acceleration");
            return Ok(device);
        }
    }

    #[cfg(feature = "accelerate")]
    {
        tracing::info!("Neural module initialized with Accelerate CPU optimization");
        return Ok(Device::Cpu);
    }

    // The accelerate block above returns unconditionally, so this fallback
    // must be gated out of accelerate builds to avoid an `unreachable_code`
    // warning there.
    #[cfg(not(feature = "accelerate"))]
    {
        tracing::warn!("Neural module initialized with standard CPU (no GPU acceleration)");
        Ok(Device::Cpu)
    }
}

/// Placeholder when candle is not enabled
///
/// Returns `Ok(())` rather than a `Device`, since no compute backend exists
/// in this configuration; callers compiled without `candle` get only a
/// logged warning.
#[cfg(not(feature = "candle"))]
pub fn initialize() -> Result<()> {
    tracing::warn!("Neural module compiled without candle support");
    Ok(())
}

/// Model version information for tracking and reproducibility
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct ModelVersion {
    /// Format/version string; `new` initializes this to "1.0.0".
    pub version: String,
    /// Unique identifier for this model instance (UUID v4, as a string).
    pub model_id: String,
    /// UTC timestamp captured when the version record was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Which model architecture this record describes.
    pub model_type: ModelType,
    /// Architecture-specific configuration, stored as free-form JSON.
    pub config: serde_json::Value,
    /// Training metrics; `None` until populated after training.
    pub metrics: Option<TrainingMetrics>,
}

#[cfg(feature = "candle")]
impl ModelVersion {
    /// Build a fresh version record for `model_type` with the supplied JSON
    /// config, stamped with a newly generated UUID and the current UTC time.
    pub fn new(model_type: ModelType, config: serde_json::Value) -> Self {
        Self {
            version: String::from("1.0.0"),
            model_id: uuid::Uuid::new_v4().to_string(),
            created_at: chrono::Utc::now(),
            model_type,
            config,
            metrics: None,
        }
    }

    /// Serialize this record as pretty-printed JSON and write it to `path`,
    /// overwriting any existing file.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        std::fs::write(path, serde_json::to_string_pretty(self)?)?;
        Ok(())
    }

    /// Read the file at `path` and deserialize a `ModelVersion` from its
    /// JSON contents.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let contents = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str(&contents)?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The initializer must hand back a device of some supported kind.
    #[test]
    #[cfg(feature = "candle")]
    fn test_initialize_device() {
        let device = initialize().unwrap();
        assert!(device.is_cpu() || device.is_cuda() || device.is_metal());
    }

    /// Without candle, the stub initializer should still succeed.
    #[test]
    #[cfg(not(feature = "candle"))]
    fn test_initialize_without_candle() {
        let result = initialize();
        assert!(result.is_ok());
    }

    /// `ModelVersion` must round-trip through serde JSON.
    // `ModelVersion`/`ModelType` only exist with the `candle` feature, so
    // this test must be feature-gated or `cargo test` without `candle`
    // fails to compile.
    #[test]
    #[cfg(feature = "candle")]
    fn test_model_version_serialization() {
        let model_config = serde_json::json!({
            "input_size": 168,
            "horizon": 24,
        });

        let version = ModelVersion::new(ModelType::NHITS, model_config);

        // Serialize to JSON
        let json = serde_json::to_string(&version).unwrap();
        assert!(json.contains("version"));
        assert!(json.contains("model_id"));

        // Deserialize back
        let deserialized: ModelVersion = serde_json::from_str(&json).unwrap();
        assert_eq!(version.version, deserialized.version);
        assert_eq!(version.model_id, deserialized.model_id);
    }

    /// Compile-time check that every newly added `ModelType` variant exists.
    // Same gating requirement as above: `ModelType` is candle-only.
    #[test]
    #[cfg(feature = "candle")]
    fn test_new_model_types() {
        let _gru = ModelType::GRU;
        let _tcn = ModelType::TCN;
        let _deepar = ModelType::DeepAR;
        let _nbeats = ModelType::NBeats;
        let _prophet = ModelType::Prophet;
    }
}