use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Top-level configuration, (de)serializable through serde.
///
/// Because of `#[serde(default)]`, any field missing from the input is taken
/// from this type's `Default` implementation instead of causing a
/// deserialization error.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Config {
/// Strategy name; the built-in default is `"soft_differentiable"`.
// NOTE(review): the set of accepted strategy names is not visible here — confirm against the consumer.
pub strategy: String,
/// Map from a domain name to a `usize`; the built-in default holds `"D" -> 100`.
// NOTE(review): the value is presumably the domain size — confirm against the consumer.
pub domains: HashMap<String, usize>,
/// Output format name; the built-in default is `"graph"`.
pub output_format: String,
/// Validation switch; off by default.
pub validate: bool,
/// Debug switch; off by default.
pub debug: bool,
/// Colored-output switch; on by default.
pub colored: bool,
/// Settings grouped under the `repl` key.
pub repl: ReplConfig,
/// Settings grouped under the `watch` key.
pub watch: WatchConfig,
}
/// REPL-related settings.
///
/// Because of `#[serde(default)]`, any field missing from the input is taken
/// from this type's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ReplConfig {
/// Prompt string; the built-in default is `"tensorlogic> "`.
pub prompt: String,
/// History file name; the built-in default is `".tensorlogic_history"`.
// NOTE(review): how this path is resolved (cwd vs. home) is decided by the consumer — confirm.
pub history_file: String,
/// History limit; the built-in default is `1000`.
// NOTE(review): presumably the maximum number of retained entries — confirm.
pub max_history: usize,
/// Auto-save switch; on by default.
pub auto_save: bool,
}
/// Watch-mode settings.
///
/// Because of `#[serde(default)]`, any field missing from the input is taken
/// from this type's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WatchConfig {
/// Debounce interval in milliseconds (per the field name); the built-in default is `500`.
pub debounce_ms: u64,
/// Clear-screen switch; on by default.
pub clear_screen: bool,
/// Timestamp-display switch; on by default.
pub show_timestamps: bool,
}
impl Default for Config {
fn default() -> Self {
let mut domains = HashMap::new();
domains.insert("D".to_string(), 100);
Self {
strategy: "soft_differentiable".to_string(),
domains,
output_format: "graph".to_string(),
validate: false,
debug: false,
colored: true,
repl: ReplConfig::default(),
watch: WatchConfig::default(),
}
}
}
impl Default for ReplConfig {
fn default() -> Self {
Self {
prompt: "tensorlogic> ".to_string(),
history_file: ".tensorlogic_history".to_string(),
max_history: 1000,
auto_save: true,
}
}
}
impl Default for WatchConfig {
fn default() -> Self {
Self {
debounce_ms: 500,
clear_screen: true,
show_timestamps: true,
}
}
}
impl Config {
    /// File name looked up in the current directory and in the home directory.
    const RC_FILE_NAME: &'static str = ".tensorlogicrc";

    /// Reads the file at `path` and parses it as TOML into a [`Config`].
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the offending path, when the file
    /// cannot be read or when its content is not a valid configuration.
    pub fn load(path: &Path) -> Result<Self> {
        let content = fs::read_to_string(path)
            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
        toml::from_str(&content)
            .with_context(|| format!("Failed to parse config file: {}", path.display()))
    }

    /// Serializes this configuration as pretty-printed TOML and writes it to
    /// `path`, replacing any existing content.
    ///
    /// # Errors
    ///
    /// Returns an error when serialization fails or the file cannot be written.
    pub fn save(&self, path: &Path) -> Result<()> {
        let content = toml::to_string_pretty(self).context("Failed to serialize configuration")?;
        fs::write(path, content)
            .with_context(|| format!("Failed to write config file: {}", path.display()))
    }

    /// Loads the first configuration that can be read and parsed, trying in
    /// order:
    ///
    /// 1. the path in the `TENSORLOGIC_CONFIG` environment variable,
    /// 2. `.tensorlogicrc` in the current directory,
    /// 3. `.tensorlogicrc` in the home directory.
    ///
    /// This is deliberately best-effort: a candidate that is missing,
    /// unreadable or malformed is skipped without reporting an error, and the
    /// built-in defaults are returned when no candidate loads.
    pub fn load_default() -> Self {
        // `var_os` (rather than `var`) so that a non-UTF-8 path is still honored.
        let from_env = std::env::var_os("TENSORLOGIC_CONFIG").map(PathBuf::from);
        let from_cwd = Some(PathBuf::from(Self::RC_FILE_NAME));
        let from_home = dirs::home_dir().map(|home| home.join(Self::RC_FILE_NAME));

        // No separate `exists()` probe: a missing file already makes `load`
        // fail, which skips the candidate just the same.
        from_env
            .into_iter()
            .chain(from_cwd)
            .chain(from_home)
            .find_map(|candidate| Self::load(&candidate).ok())
            .unwrap_or_default()
    }

    /// Returns the path a configuration file should be written to.
    ///
    /// An existing `.tensorlogicrc` in the current directory wins; otherwise
    /// the home-directory `.tensorlogicrc` is returned. When no home directory
    /// can be determined, the current-directory path is returned even though
    /// it does not exist.
    pub fn config_path() -> PathBuf {
        let current = PathBuf::from(Self::RC_FILE_NAME);
        if current.exists() {
            return current;
        }
        dirs::home_dir()
            .map(|home| home.join(Self::RC_FILE_NAME))
            .unwrap_or(current)
    }

    /// Writes the built-in default configuration to [`Config::config_path`]
    /// and returns the path that was written.
    ///
    /// Note that this overwrites whatever file already exists at that path.
    ///
    /// # Errors
    ///
    /// Returns an error when the configuration cannot be serialized or written.
    pub fn create_default() -> Result<PathBuf> {
        let config = Self::default();
        let path = Self::config_path();
        config.save(&path)?;
        Ok(path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults expose the expected strategy, the `"D"` domain
    /// and the expected output format.
    #[test]
    fn test_default_config() {
        let defaults = Config::default();
        assert_eq!(defaults.strategy, "soft_differentiable");
        assert!(defaults.domains.contains_key("D"));
        assert_eq!(defaults.output_format, "graph");
    }

    /// A default configuration keeps its strategy across a TOML round trip.
    #[test]
    fn test_serialize_deserialize() {
        let original = Config::default();
        let encoded = toml::to_string(&original).unwrap();
        let decoded: Config = toml::from_str(&encoded).unwrap();
        assert_eq!(original.strategy, decoded.strategy);
    }
}