#![allow(dead_code)]
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tracing::{debug, info};
/// Top-level application configuration, one section per subsystem.
///
/// Read by `load_config` / `load_config_from_file` (YAML, JSON or TOML) and
/// written by `save_config`. No `#[serde(default)]` is applied, so every
/// section (and every non-`Option` field inside it) must be present in a
/// config file for it to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Paths, device, worker count and dtype settings shared by all commands.
pub general: GeneralConfig,
/// Model format, optimization and validation settings.
pub model: ModelConfig,
/// Training directories, checkpointing and distributed settings.
pub training: TrainingConfig,
/// Model hub endpoint, credentials and cache location.
pub hub: HubConfig,
/// Benchmark iteration counts, batch sizes and output location.
pub benchmark: BenchmarkConfig,
/// Developer options (code generation, testing).
pub dev: DevConfig,
}
/// Settings shared across commands.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
/// Output directory; `Config::default` uses the process's current working
/// directory, or `.` if that cannot be determined.
pub output_dir: PathBuf,
/// Cache directory; defaults to `<home>/.torsh/cache` and is created by
/// `init_config_dirs`.
pub cache_dir: PathBuf,
/// Device name used when none is given; `"cpu"` by default.
pub default_device: String,
/// Worker count; defaults to `std::thread::available_parallelism()`, or 4
/// when that is unavailable.
pub num_workers: usize,
/// Optional memory limit in gigabytes; `None` by default.
/// NOTE(review): enforcement is not visible in this file.
pub memory_limit_gb: Option<f64>,
/// Whether progress output is shown; `true` by default.
pub show_progress: bool,
/// Default dtype name; `"f32"` by default.
pub default_dtype: String,
}
/// Model-handling settings: default format, optimization passes and validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Default model format name; `"torsh"` in `Config::default`.
pub default_format: String,
/// Quantization, pruning and fusion settings.
pub optimization: OptimizationConfig,
/// Validation settings (dataset, sample count, accuracy threshold).
pub validation: ValidationConfig,
}
/// Groups the individual model optimization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
/// Auto-optimize flag; `false` by default.
/// NOTE(review): the consumer of this flag is not visible in this file.
pub auto_optimize: bool,
/// Quantization settings (disabled by default).
pub quantization: QuantizationConfig,
/// Pruning settings (disabled by default).
pub pruning: PruningConfig,
/// Operator-fusion settings (enabled by default).
pub fusion: FusionConfig,
}
/// Quantization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
/// Whether quantization is enabled; `false` by default.
pub enabled: bool,
/// Quantization method name; `"dynamic"` by default.
pub method: String,
/// Target precision name; `"int8"` by default.
pub precision: String,
/// Number of calibration samples; 1000 by default.
pub calibration_samples: usize,
}
/// Pruning settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
/// Whether pruning is enabled; `false` by default.
pub enabled: bool,
/// Target sparsity; 0.5 by default (presumably a fraction in [0, 1] — confirm
/// with the pruning implementation).
pub sparsity: f64,
/// Pruning method name; `"magnitude"` by default.
pub method: String,
}
/// Operator-fusion settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
/// Whether fusion is enabled; `true` by default.
pub enabled: bool,
/// Fusion pattern names; defaults to `["conv_bn_relu", "linear_relu"]`.
pub patterns: Vec<String>,
}
/// Model validation settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
/// Whether validation is enabled; `true` by default.
pub enabled: bool,
/// Optional validation dataset location; `None` by default.
pub dataset_path: Option<PathBuf>,
/// Number of samples to validate on; 1000 by default.
pub num_samples: usize,
/// Accuracy threshold; 0.95 by default (presumably a fraction — confirm).
pub accuracy_threshold: f64,
}
/// Training settings.
///
/// `config_dir`, `checkpoint_dir` and `log_dir` are created by
/// `init_config_dirs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Directory for training configs; defaults to `<home>/.torsh/configs`.
pub config_dir: PathBuf,
/// Checkpoint directory; defaults to the relative path `./checkpoints`.
pub checkpoint_dir: PathBuf,
/// Log directory; defaults to the relative path `./logs`.
pub log_dir: PathBuf,
/// Auto-resume flag; `false` by default.
pub auto_resume: bool,
/// Checkpoint frequency; 1 by default (unit — steps vs. epochs — not
/// visible here; confirm with the trainer).
pub checkpoint_frequency: usize,
/// Early-stopping patience; 10 by default.
pub early_stopping_patience: usize,
/// Mixed-precision flag; `true` by default.
pub mixed_precision: bool,
/// Distributed-training settings.
pub distributed: DistributedConfig,
}
/// Distributed-training settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
/// Backend name; `"nccl"` by default.
pub backend: String,
/// Master node address; `"localhost"` by default.
pub master_addr: String,
/// Master node port; 29500 by default.
pub master_port: u16,
/// Auto-detect flag; `true` by default.
pub auto_detect: bool,
}
/// Model hub settings.
///
/// NOTE(review): `auth_token` is printed verbatim by the derived `Debug` impl
/// and is written in plain text by `save_config`; consider redacting it in
/// `Debug` and restricting the config file's permissions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HubConfig {
/// Hub API endpoint; `"https://hub.torsh.dev"` by default.
pub api_endpoint: String,
/// Optional authentication token; `None` by default.
pub auth_token: Option<String>,
/// Optional organization name; `None` by default.
pub organization: Option<String>,
/// Hub cache directory; defaults to `<home>/.torsh/hub` and is created by
/// `init_config_dirs`.
pub cache_dir: PathBuf,
/// Signature-verification flag; `true` by default.
pub verify_signatures: bool,
/// Request timeout in seconds; 300 by default.
pub timeout_seconds: u64,
}
/// Benchmark settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
/// Warm-up iterations; 10 by default.
pub warmup_iterations: usize,
/// Measured iterations; 100 by default.
pub benchmark_iterations: usize,
/// Batch sizes to benchmark; `[1, 4, 8, 16, 32, 64]` by default.
pub batch_sizes: Vec<usize>,
/// Memory-tracking flag; `true` by default.
pub track_memory: bool,
/// Power-tracking flag; `false` by default.
pub track_power: bool,
/// Result directory; defaults to the relative path `./benchmarks` and is
/// created by `init_config_dirs`.
pub output_dir: PathBuf,
/// Detailed-profiling flag; `false` by default.
pub detailed_profiling: bool,
}
/// Developer options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DevConfig {
/// Developer-mode flag; `false` by default.
pub enabled: bool,
/// Debug-logging flag; `false` by default.
pub debug_logging: bool,
/// Experimental-features flag; `false` by default.
pub experimental_features: bool,
/// Code-generation settings.
pub codegen: CodegenConfig,
/// Test-support settings.
pub testing: TestingConfig,
}
/// Code-generation settings.
///
/// These directories are not created by `init_config_dirs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodegenConfig {
/// Code-generation flag; `false` by default.
pub enabled: bool,
/// Generated-code directory; defaults to the relative path `./generated`.
pub output_dir: PathBuf,
/// Template directory; defaults to `<home>/.torsh/templates`.
pub templates_dir: PathBuf,
}
/// Test-support settings.
///
/// `test_data_dir` is not created by `init_config_dirs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestingConfig {
/// Testing flag; `true` by default.
pub enabled: bool,
/// Test-data directory; defaults to the relative path `./test_data`.
pub test_data_dir: PathBuf,
/// Numerical tolerance; `1e-6` by default.
pub numerical_tolerance: f64,
}
impl Default for Config {
fn default() -> Self {
let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
let torsh_dir = home_dir.join(".torsh");
Self {
general: GeneralConfig {
output_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
cache_dir: torsh_dir.join("cache"),
default_device: "cpu".to_string(),
num_workers: std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4),
memory_limit_gb: None,
show_progress: true,
default_dtype: "f32".to_string(),
},
model: ModelConfig {
default_format: "torsh".to_string(),
optimization: OptimizationConfig {
auto_optimize: false,
quantization: QuantizationConfig {
enabled: false,
method: "dynamic".to_string(),
precision: "int8".to_string(),
calibration_samples: 1000,
},
pruning: PruningConfig {
enabled: false,
sparsity: 0.5,
method: "magnitude".to_string(),
},
fusion: FusionConfig {
enabled: true,
patterns: vec!["conv_bn_relu".to_string(), "linear_relu".to_string()],
},
},
validation: ValidationConfig {
enabled: true,
dataset_path: None,
num_samples: 1000,
accuracy_threshold: 0.95,
},
},
training: TrainingConfig {
config_dir: torsh_dir.join("configs"),
checkpoint_dir: PathBuf::from("./checkpoints"),
log_dir: PathBuf::from("./logs"),
auto_resume: false,
checkpoint_frequency: 1,
early_stopping_patience: 10,
mixed_precision: true,
distributed: DistributedConfig {
backend: "nccl".to_string(),
master_addr: "localhost".to_string(),
master_port: 29500,
auto_detect: true,
},
},
hub: HubConfig {
api_endpoint: "https://hub.torsh.dev".to_string(),
auth_token: None,
organization: None,
cache_dir: torsh_dir.join("hub"),
verify_signatures: true,
timeout_seconds: 300,
},
benchmark: BenchmarkConfig {
warmup_iterations: 10,
benchmark_iterations: 100,
batch_sizes: vec![1, 4, 8, 16, 32, 64],
track_memory: true,
track_power: false,
output_dir: PathBuf::from("./benchmarks"),
detailed_profiling: false,
},
dev: DevConfig {
enabled: false,
debug_logging: false,
experimental_features: false,
codegen: CodegenConfig {
enabled: false,
output_dir: PathBuf::from("./generated"),
templates_dir: torsh_dir.join("templates"),
},
testing: TestingConfig {
enabled: true,
test_data_dir: PathBuf::from("./test_data"),
numerical_tolerance: 1e-6,
},
},
}
}
}
pub async fn load_config(config_path: Option<&Path>) -> Result<Config> {
let config_path = if let Some(path) = config_path {
path.to_path_buf()
} else {
get_default_config_path()?
};
if config_path.exists() {
debug!("Loading configuration from: {}", config_path.display());
load_config_from_file(&config_path).await
} else {
info!("Configuration file not found, using defaults");
let config = Config::default();
if let Some(parent) = config_path.parent() {
tokio::fs::create_dir_all(parent).await.with_context(|| {
format!("Failed to create config directory: {}", parent.display())
})?;
}
save_config(&config, &config_path)
.await
.with_context(|| "Failed to save default configuration")?;
Ok(config)
}
}
/// Reads `path` and deserializes it into a [`Config`].
///
/// The format is chosen by file extension (matched case-sensitively):
/// `yaml`/`yml` → YAML, `json` → JSON, `toml` → TOML. For any other or missing
/// extension the content is sniffed: if the first non-whitespace character is
/// `{` it is parsed as JSON, otherwise as YAML. A leading UTF-8 byte-order
/// mark is ignored.
///
/// # Errors
/// Returns an error that names `path` if the file cannot be read or does not
/// parse into a complete `Config`.
async fn load_config_from_file(path: &Path) -> Result<Config> {
    let raw = tokio::fs::read_to_string(path)
        .await
        .with_context(|| format!("Failed to read config file: {}", path.display()))?;
    // Some editors prepend a UTF-8 BOM. serde_json rejects it, and it would
    // also hide a leading `{` from the JSON sniff below.
    let content = raw.strip_prefix('\u{feff}').unwrap_or(raw.as_str());

    // Parse errors carry the path so the user knows which file is malformed.
    let parse_error =
        |format: &str| format!("Failed to parse {} configuration: {}", format, path.display());

    let config: Config = match path.extension().and_then(|ext| ext.to_str()) {
        Some("yaml" | "yml") => {
            serde_yaml::from_str(content).with_context(|| parse_error("YAML"))?
        }
        Some("json") => serde_json::from_str(content).with_context(|| parse_error("JSON"))?,
        Some("toml") => toml::from_str(content).with_context(|| parse_error("TOML"))?,
        // Unknown or missing extension: JSON if it looks like an object, else YAML.
        _ if content.trim_start().starts_with('{') => {
            serde_json::from_str(content).with_context(|| parse_error("JSON"))?
        }
        _ => serde_yaml::from_str(content).with_context(|| parse_error("YAML"))?,
    };
    Ok(config)
}
pub async fn save_config(config: &Config, path: &Path) -> Result<()> {
let content = match path.extension().and_then(|ext| ext.to_str()) {
Some("json") => serde_json::to_string_pretty(config)
.with_context(|| "Failed to serialize configuration to JSON")?,
Some("toml") => toml::to_string_pretty(config)
.with_context(|| "Failed to serialize configuration to TOML")?,
_ => {
serde_yaml::to_string(config)
.with_context(|| "Failed to serialize configuration to YAML")?
}
};
tokio::fs::write(path, content)
.await
.with_context(|| format!("Failed to write config file: {}", path.display()))?;
info!("Configuration saved to: {}", path.display());
Ok(())
}
/// Returns the default config file location, `<config dir>/torsh/config.yaml`.
///
/// `<config dir>` is `dirs::config_dir()`. If that is unavailable, it falls
/// back to `<home>/.config`, and finally to the current directory (`.`).
/// Always returns `Ok` at present.
fn get_default_config_path() -> Result<PathBuf> {
    let base = match dirs::config_dir() {
        Some(dir) => dir,
        None => match dirs::home_dir() {
            Some(home) => home.join(".config"),
            None => PathBuf::from("."),
        },
    };
    let file = base.join("torsh").join("config.yaml");
    Ok(file)
}
/// Ensures that the working directories referenced by `config` exist.
///
/// Creates the following directories (and any missing parents):
/// `general.cache_dir`, `training.config_dir`, `training.checkpoint_dir`,
/// `training.log_dir`, `hub.cache_dir` and `benchmark.output_dir`.
/// `general.output_dir` and the directories under `dev` are not touched.
///
/// # Errors
/// Returns an error naming the first path that could not be created. This
/// includes a path occupied by a regular file.
pub async fn init_config_dirs(config: &Config) -> Result<()> {
    let required = [
        &config.general.cache_dir,
        &config.training.config_dir,
        &config.training.checkpoint_dir,
        &config.training.log_dir,
        &config.hub.cache_dir,
        &config.benchmark.output_dir,
    ];
    for dir in required {
        // Async stat instead of `Path::exists`, which would block the runtime
        // thread. Require an actual directory, so a regular file occupying the
        // path makes `create_dir_all` fail here with a clear message instead of
        // failing later at first use.
        let is_dir = match tokio::fs::metadata(dir).await {
            Ok(meta) => meta.is_dir(),
            Err(_) => false,
        };
        if is_dir {
            continue;
        }
        tokio::fs::create_dir_all(dir)
            .await
            .with_context(|| format!("Failed to create directory: {}", dir.display()))?;
        debug!("Created directory: {}", dir.display());
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Compares a representative set of fields (strings, optionals, vectors,
    /// paths, nested sections) between two configs after a round trip.
    fn assert_same_config(expected: &Config, actual: &Config) {
        assert_eq!(expected.general.default_device, actual.general.default_device);
        assert_eq!(expected.general.memory_limit_gb, actual.general.memory_limit_gb);
        assert_eq!(expected.general.cache_dir, actual.general.cache_dir);
        assert_eq!(expected.model.default_format, actual.model.default_format);
        assert_eq!(
            expected.model.optimization.fusion.patterns,
            actual.model.optimization.fusion.patterns
        );
        assert_eq!(
            expected.training.distributed.master_port,
            actual.training.distributed.master_port
        );
        assert_eq!(expected.hub.api_endpoint, actual.hub.api_endpoint);
        assert_eq!(expected.hub.auth_token, actual.hub.auth_token);
        assert_eq!(expected.benchmark.batch_sizes, actual.benchmark.batch_sizes);
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.general.default_device, "cpu");
        assert_eq!(config.model.default_format, "torsh");
    }

    #[test]
    fn test_config_serialization() {
        let config = Config::default();

        let yaml = serde_yaml::to_string(&config).unwrap();
        let parsed: Config = serde_yaml::from_str(&yaml).unwrap();
        assert_same_config(&config, &parsed);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: Config = serde_json::from_str(&json).unwrap();
        assert_same_config(&config, &parsed);

        let toml_text = toml::to_string_pretty(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_text).unwrap();
        assert_same_config(&config, &parsed);
    }

    #[tokio::test]
    async fn test_config_file_operations() {
        let temp_dir = tempdir().unwrap();
        let config = Config::default();
        // Cover every extension branch of save_config / load_config_from_file,
        // plus the extension-less content-sniffing path.
        for file_name in [
            "test_config.yaml",
            "test_config.yml",
            "test_config.json",
            "test_config.toml",
            "test_config",
        ] {
            let config_path = temp_dir.path().join(file_name);
            save_config(&config, &config_path).await.unwrap();
            assert!(config_path.exists(), "{} was not written", file_name);
            let loaded = load_config_from_file(&config_path).await.unwrap();
            assert_same_config(&config, &loaded);
        }
    }

    #[tokio::test]
    async fn test_load_config_writes_defaults_when_missing() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("nested").join("config.yaml");
        let config = load_config(Some(&config_path)).await.unwrap();
        assert!(config_path.exists());
        let reloaded = load_config_from_file(&config_path).await.unwrap();
        assert_same_config(&config, &reloaded);
    }
}