#![cfg_attr(feature = "simd", feature(portable_simd))]
pub mod error;
#[cfg(feature = "candle")]
pub mod models;
pub mod training;
pub mod inference;
pub mod utils;
pub mod storage;
#[cfg(not(feature = "candle"))]
pub mod stubs;
pub use error::{NeuralError, Result};
#[cfg(feature = "candle")]
pub use models::{
ModelConfig, ModelType,
};
#[cfg(feature = "candle")]
pub use models::{
nhits::{NHITSModel, NHITSConfig},
lstm_attention::{LSTMAttentionModel, LSTMAttentionConfig},
transformer::{TransformerModel, TransformerConfig},
};
#[cfg(feature = "candle")]
pub use models::{
gru::{GRUModel, GRUConfig},
tcn::{TCNModel, TCNConfig},
deepar::{DeepARModel, DeepARConfig, DistributionType},
nbeats::{NBeatsModel, NBeatsConfig, StackType},
prophet::{ProphetModel, ProphetConfig, GrowthModel},
};
#[cfg(feature = "candle")]
pub use training::{
Trainer, TrainingConfig, TrainingMetrics,
data_loader::{DataLoader, TimeSeriesDataset},
optimizer::{Optimizer, OptimizerConfig},
nhits_trainer::{NHITSTrainer, NHITSTrainingConfig},
};
pub use training::TrainingConfig;
pub use training::TrainingMetrics;
#[cfg(feature = "candle")]
pub use inference::{
Predictor, BatchPredictor,
};
pub use inference::PredictionResult;
use serde::Serialize;
use std::path::Path;
#[cfg(feature = "candle")]
pub use candle_core::{Device, Tensor};
#[cfg(not(feature = "candle"))]
pub use stubs::{Device, Tensor};
/// Selects the compute device for the neural module.
///
/// Backends are probed in this order, each gated by its Cargo feature:
///
/// 1. `cuda`  — CUDA device 0
/// 2. `metal` — Metal device 0
/// 3. CPU     — logged as Accelerate-optimized when the `accelerate` feature
///    is enabled, otherwise logged as plain CPU with a warning
///
/// A GPU backend that is compiled in but fails to initialize is not fatal:
/// the failure is logged and the next candidate is tried.
///
/// # Errors
///
/// Every path in the current implementation returns `Ok`; the `Result`
/// return type is kept so the signature stays stable for callers.
#[cfg(feature = "candle")]
pub fn initialize() -> Result<Device> {
    #[cfg(feature = "cuda")]
    {
        match Device::new_cuda(0) {
            Ok(device) => {
                tracing::info!("Neural module initialized with CUDA GPU acceleration");
                return Ok(device);
            }
            // Keep the cause visible instead of silently dropping it.
            Err(err) => {
                tracing::warn!("CUDA device 0 unavailable, trying next backend: {}", err);
            }
        }
    }

    #[cfg(feature = "metal")]
    {
        match Device::new_metal(0) {
            Ok(device) => {
                tracing::info!("Neural module initialized with Metal GPU acceleration");
                return Ok(device);
            }
            // Keep the cause visible instead of silently dropping it.
            Err(err) => {
                tracing::warn!("Metal device 0 unavailable, trying next backend: {}", err);
            }
        }
    }

    // `cfg!` picks the message at compile time without leaving statements
    // after an unconditional `return`, which would trip `unreachable_code`
    // whenever `accelerate` is enabled.
    if cfg!(feature = "accelerate") {
        tracing::info!("Neural module initialized with Accelerate CPU optimization");
    } else {
        tracing::warn!("Neural module initialized with standard CPU (no GPU acceleration)");
    }
    Ok(Device::Cpu)
}
/// Placeholder initializer used when the crate is built without the
/// `candle` feature.
///
/// Emits a warning through `tracing` and returns `Ok(())`; no device is
/// selected or returned in this configuration.
#[cfg(not(feature = "candle"))]
pub fn initialize() -> Result<()> {
tracing::warn!("Neural module compiled without candle support");
Ok(())
}
/// Serializable metadata record describing one version of a model.
///
/// Only available when the `candle` feature is enabled. Instances can be
/// written to and read back from JSON files via `ModelVersion::save` and
/// `ModelVersion::load`.
#[cfg(feature = "candle")]
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct ModelVersion {
/// Version string; `ModelVersion::new` sets this to `"1.0.0"`.
pub version: String,
/// Identifier of the model; `ModelVersion::new` fills it with a random UUID v4.
pub model_id: String,
/// Creation timestamp in UTC; `ModelVersion::new` uses the current time.
pub created_at: chrono::DateTime<chrono::Utc>,
/// Which model architecture this record describes.
pub model_type: ModelType,
/// Model configuration stored as an untyped JSON value.
pub config: serde_json::Value,
/// Training metrics, if any; `ModelVersion::new` leaves this as `None`.
pub metrics: Option<TrainingMetrics>,
}
#[cfg(feature = "candle")]
impl ModelVersion {
    /// Builds the metadata record for a newly created model.
    ///
    /// The record receives a random UUID v4 as `model_id`, the current UTC
    /// time as `created_at`, the version string `"1.0.0"`, and no metrics.
    pub fn new(model_type: ModelType, config: serde_json::Value) -> Self {
        Self {
            version: String::from("1.0.0"),
            model_id: uuid::Uuid::new_v4().to_string(),
            created_at: chrono::Utc::now(),
            model_type,
            config,
            metrics: None,
        }
    }

    /// Writes this record to `path` as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Fails if the record cannot be serialized or the file cannot be written.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let serialized = serde_json::to_string_pretty(self)?;
        std::fs::write(path, serialized)?;
        Ok(())
    }

    /// Reads a record previously stored as JSON at `path`.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be read or its contents do not deserialize
    /// into a `ModelVersion`.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let contents = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str(&contents)?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With candle enabled, initialization must yield one of the known device kinds.
    #[test]
    #[cfg(feature = "candle")]
    fn test_initialize_device() {
        let device = initialize().unwrap();
        assert!(device.is_cpu() || device.is_cuda() || device.is_metal());
    }

    /// Without candle, initialization is a no-op that still succeeds.
    #[test]
    #[cfg(not(feature = "candle"))]
    fn test_initialize_without_candle() {
        let result = initialize();
        assert!(result.is_ok());
    }

    /// A `ModelVersion` must survive a JSON round trip.
    ///
    /// Gated on `candle` because `ModelVersion` and `ModelType` only exist
    /// with that feature; ungated, this module would not compile without it.
    #[test]
    #[cfg(feature = "candle")]
    fn test_model_version_serialization() {
        let model_config = serde_json::json!({
            "input_size": 168,
            "horizon": 24,
        });
        let version = ModelVersion::new(ModelType::NHITS, model_config);
        let json = serde_json::to_string(&version).unwrap();
        assert!(json.contains("version"));
        assert!(json.contains("model_id"));

        let deserialized: ModelVersion = serde_json::from_str(&json).unwrap();
        assert_eq!(version.version, deserialized.version);
        assert_eq!(version.model_id, deserialized.model_id);
    }

    /// Compile-time check that the newer `ModelType` variants exist.
    ///
    /// Gated on `candle` because `ModelType` is only re-exported with that feature.
    #[test]
    #[cfg(feature = "candle")]
    fn test_new_model_types() {
        let _gru = ModelType::GRU;
        let _tcn = ModelType::TCN;
        let _deepar = ModelType::DeepAR;
        let _nbeats = ModelType::NBeats;
        let _prophet = ModelType::Prophet;
    }
}