1#![cfg_attr(feature = "simd", feature(portable_simd))]
50
51pub mod error;
52#[cfg(feature = "candle")]
53pub mod models;
54pub mod training;
55pub mod inference;
56pub mod utils;
57pub mod storage;
58
59#[cfg(not(feature = "candle"))]
60pub mod stubs;
61
62pub use error::{NeuralError, Result};
64
65#[cfg(feature = "candle")]
66pub use models::{
67 ModelConfig, ModelType,
68};
69
70#[cfg(feature = "candle")]
71pub use models::{
72 nhits::{NHITSModel, NHITSConfig},
73 lstm_attention::{LSTMAttentionModel, LSTMAttentionConfig},
74 transformer::{TransformerModel, TransformerConfig},
75};
76
77#[cfg(feature = "candle")]
79pub use models::{
80 gru::{GRUModel, GRUConfig},
81 tcn::{TCNModel, TCNConfig},
82 deepar::{DeepARModel, DeepARConfig, DistributionType},
83 nbeats::{NBeatsModel, NBeatsConfig, StackType},
84 prophet::{ProphetModel, ProphetConfig, GrowthModel},
85};
86
87#[cfg(feature = "candle")]
88pub use training::{
89 Trainer, TrainingConfig, TrainingMetrics,
90 data_loader::{DataLoader, TimeSeriesDataset},
91 optimizer::{Optimizer, OptimizerConfig},
92 nhits_trainer::{NHITSTrainer, NHITSTrainingConfig},
93};
94pub use training::TrainingConfig;
95pub use training::TrainingMetrics;
96
97#[cfg(feature = "candle")]
98pub use inference::{
99 Predictor, BatchPredictor,
100};
101pub use inference::PredictionResult;
102
103use serde::Serialize;
104use std::path::Path;
105
106#[cfg(feature = "candle")]
108pub use candle_core::{Device, Tensor};
109#[cfg(not(feature = "candle"))]
110pub use stubs::{Device, Tensor};
111
/// Select the compute device used by the neural module.
///
/// Backends are probed in order, each only when its Cargo feature is enabled:
/// CUDA device 0 (`cuda`), then Metal device 0 (`metal`). If no GPU backend is
/// compiled in, or every compiled-in backend fails to open, the CPU device is
/// returned. With the `accelerate` feature the CPU fallback is reported as
/// Accelerate-optimized; otherwise a warning notes that no GPU is in use.
///
/// # Errors
///
/// Currently never returns `Err`: a GPU backend that fails to open is logged
/// (with the underlying error) and the next option is tried.
#[cfg(feature = "candle")]
pub fn initialize() -> Result<Device> {
    #[cfg(feature = "cuda")]
    {
        match Device::new_cuda(0) {
            Ok(device) => {
                tracing::info!("Neural module initialized with CUDA GPU acceleration");
                return Ok(device);
            }
            // Surface the reason instead of silently dropping it, so a missing
            // driver / device is diagnosable from the logs.
            Err(err) => {
                tracing::warn!(error = %err, "CUDA device 0 unavailable; trying next backend");
            }
        }
    }

    #[cfg(feature = "metal")]
    {
        match Device::new_metal(0) {
            Ok(device) => {
                tracing::info!("Neural module initialized with Metal GPU acceleration");
                return Ok(device);
            }
            Err(err) => {
                tracing::warn!(error = %err, "Metal device 0 unavailable; falling back to CPU");
            }
        }
    }

    // Single CPU exit path: choosing the log message with `cfg!` (instead of an
    // early `return` inside a `#[cfg]` block) avoids unreachable code when
    // `accelerate` is enabled.
    if cfg!(feature = "accelerate") {
        tracing::info!("Neural module initialized with Accelerate CPU optimization");
    } else {
        tracing::warn!("Neural module initialized with standard CPU (no GPU acceleration)");
    }
    Ok(Device::Cpu)
}
140
141#[cfg(not(feature = "candle"))]
143pub fn initialize() -> Result<()> {
144 tracing::warn!("Neural module compiled without candle support");
145 Ok(())
146}
147
/// Metadata describing one model artifact, persisted as pretty-printed JSON.
///
/// Field declaration order is also the serialized key order.
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct ModelVersion {
    /// Version label; [`ModelVersion::new`] always sets `"1.0.0"`.
    pub version: String,
    /// Identifier; [`ModelVersion::new`] fills it with a random UUID v4 string.
    pub model_id: String,
    /// Creation timestamp; [`ModelVersion::new`] records the current UTC time.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Architecture the metadata belongs to.
    pub model_type: ModelType,
    /// Model configuration stored as raw JSON.
    pub config: serde_json::Value,
    /// Training metrics, if recorded; [`ModelVersion::new`] leaves this `None`.
    pub metrics: Option<TrainingMetrics>,
}

#[cfg(feature = "candle")]
impl ModelVersion {
    /// Build metadata for a new model of `model_type` with the given `config`.
    ///
    /// The result has version `"1.0.0"`, a fresh random id, the current UTC
    /// time as `created_at`, and no metrics.
    pub fn new(model_type: ModelType, config: serde_json::Value) -> Self {
        Self {
            version: String::from("1.0.0"),
            model_id: uuid::Uuid::new_v4().to_string(),
            created_at: chrono::Utc::now(),
            model_type,
            config,
            metrics: None,
        }
    }

    /// Serialize to pretty-printed JSON and write it to `path`, creating the
    /// file or replacing its contents.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails or the file cannot be written.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let serialized = serde_json::to_string_pretty(self)?;
        Ok(std::fs::write(path, serialized)?)
    }

    /// Read `path` and parse its JSON contents into a `ModelVersion`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text or its
    /// contents do not deserialize into this type.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let contents = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str::<Self>(&contents)?)
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "candle")]
    fn test_initialize_device() {
        let device = initialize().unwrap();
        assert!(device.is_cpu() || device.is_cuda() || device.is_metal());
    }

    #[test]
    #[cfg(not(feature = "candle"))]
    fn test_initialize_without_candle() {
        assert!(initialize().is_ok());
    }

    // `ModelVersion` and `ModelType` are only defined with the `candle` feature,
    // so every test that names them must be gated the same way; otherwise the
    // test build fails to compile in the default (no-candle) configuration.
    #[test]
    #[cfg(feature = "candle")]
    fn test_model_version_serialization() {
        let model_config = serde_json::json!({
            "input_size": 168,
            "horizon": 24,
        });

        let version = ModelVersion::new(ModelType::NHITS, model_config);

        let json = serde_json::to_string(&version).unwrap();
        assert!(json.contains("version"));
        assert!(json.contains("model_id"));

        let deserialized: ModelVersion = serde_json::from_str(&json).unwrap();
        assert_eq!(version.version, deserialized.version);
        assert_eq!(version.model_id, deserialized.model_id);
        assert_eq!(version.config, deserialized.config);
    }

    #[test]
    #[cfg(feature = "candle")]
    fn test_model_version_save_load_roundtrip() {
        let version = ModelVersion::new(
            ModelType::NHITS,
            serde_json::json!({ "input_size": 168, "horizon": 24 }),
        );
        // Unique per run (the id is a fresh UUID), so parallel tests don't collide.
        let path = std::env::temp_dir().join(format!("model_version_{}.json", version.model_id));

        version.save(&path).unwrap();
        let loaded = ModelVersion::load(&path);
        // Clean up before asserting so a failure doesn't leak the temp file.
        let _ = std::fs::remove_file(&path);
        let loaded = loaded.unwrap();

        assert_eq!(loaded.version, version.version);
        assert_eq!(loaded.model_id, version.model_id);
        assert_eq!(loaded.config, version.config);
        assert!(loaded.metrics.is_none());
    }

    #[test]
    #[cfg(feature = "candle")]
    fn test_new_model_types() {
        let _gru = ModelType::GRU;
        let _tcn = ModelType::TCN;
        let _deepar = ModelType::DeepAR;
        let _nbeats = ModelType::NBeats;
        let _prophet = ModelType::Prophet;
    }
}