cortenforge_tools/
config.rs

1use std::path::{Path, PathBuf};
2
3use serde::Deserialize;
4
/// File name that `ToolConfig::load` reads when the `CORTENFORGE_TOOLS_CONFIG`
/// environment variable is not set.
const DEFAULT_CONFIG_NAME: &str = "cortenforge-tools.toml";

/// Fallback for `ToolConfig::warehouse_train_template` when the config file does
/// not provide one. The `${...}` placeholders are stored verbatim; nothing in
/// this file substitutes them.
const DEFAULT_TRAIN_TEMPLATE: &str = concat!(
    "cargo train_hp",
    " --model ${MODEL}",
    " --batch-size ${BATCH}",
    " --log-every ${LOG_EVERY}",
    " ${EXTRA_ARGS}",
);
8
/// Fully resolved tool configuration: every field has a concrete value, either
/// taken from the config file or filled in with a built-in default.
#[derive(Debug, Clone)]
pub struct ToolConfig {
    // --- binaries ---------------------------------------------------------
    /// Simulator binary to launch. Defaults to `sim_view`.
    pub sim_bin: PathBuf,
    /// Training binary to launch. Defaults to `train`.
    pub train_bin: PathBuf,

    // --- asset locations --------------------------------------------------
    /// Root of the assets tree. Defaults to `assets`.
    pub assets_root: PathBuf,
    /// Defaults to `<assets_root>/datasets/captures`.
    pub captures_root: PathBuf,
    /// Defaults to `<assets_root>/datasets/captures_filtered`.
    pub captures_filtered_root: PathBuf,
    /// Defaults to `<assets_root>/warehouse/manifest.json`.
    pub warehouse_manifest: PathBuf,

    // --- log locations ----------------------------------------------------
    /// Root of the logs tree. Defaults to `logs`.
    pub logs_root: PathBuf,
    /// Defaults to `<logs_root>/metrics.jsonl`.
    pub metrics_path: PathBuf,
    /// Defaults to `<logs_root>/train.log`.
    pub train_log_path: PathBuf,
    /// Training status files. Defaults to `logs/train_hp_status.json` and
    /// `logs/train_status.json`; an empty list triggers a load-time warning
    /// that TUI status will be disabled.
    pub train_status_paths: Vec<PathBuf>,

    // --- command-line extras ----------------------------------------------
    /// Arguments from the `datagen` section's `args`. Empty by default.
    pub datagen_args: Vec<String>,
    /// Arguments from the `training` section's `args`. Empty by default.
    pub training_args: Vec<String>,
    /// Training command template from the `warehouse` section's `train_template`.
    pub warehouse_train_template: String,

    // --- presentation -----------------------------------------------------
    /// UI title. Defaults to `CortenForge Tools`.
    pub ui_title: String,
}
26
27impl Default for ToolConfig {
28    fn default() -> Self {
29        let assets_root = PathBuf::from("assets");
30        let logs_root = PathBuf::from("logs");
31        Self {
32            sim_bin: PathBuf::from("sim_view"),
33            train_bin: PathBuf::from("train"),
34            captures_root: assets_root.join("datasets/captures"),
35            captures_filtered_root: assets_root.join("datasets/captures_filtered"),
36            warehouse_manifest: assets_root.join("warehouse/manifest.json"),
37            assets_root,
38            logs_root: logs_root.clone(),
39            metrics_path: logs_root.join("metrics.jsonl"),
40            train_log_path: logs_root.join("train.log"),
41            train_status_paths: vec![
42                PathBuf::from("logs/train_hp_status.json"),
43                PathBuf::from("logs/train_status.json"),
44            ],
45            datagen_args: Vec::new(),
46            training_args: Vec::new(),
47            warehouse_train_template: DEFAULT_TRAIN_TEMPLATE.to_string(),
48            ui_title: "CortenForge Tools".to_string(),
49        }
50    }
51}
52
53#[derive(Debug, Deserialize, Default)]
54struct ToolConfigFile {
55    sim_bin: Option<String>,
56    train_bin: Option<String>,
57    assets_root: Option<String>,
58    captures_root: Option<String>,
59    captures_filtered_root: Option<String>,
60    warehouse_manifest: Option<String>,
61    logs_root: Option<String>,
62    metrics_path: Option<String>,
63    train_log_path: Option<String>,
64    train_status_paths: Option<Vec<String>>,
65    datagen: Option<ArgSection>,
66    training: Option<ArgSection>,
67    warehouse: Option<WarehouseSection>,
68    ui: Option<UiSection>,
69}
70
71#[derive(Debug, Deserialize, Default)]
72struct ArgSection {
73    args: Option<Vec<String>>,
74}
75
76#[derive(Debug, Deserialize, Default)]
77struct WarehouseSection {
78    train_template: Option<String>,
79}
80
81#[derive(Debug, Deserialize, Default)]
82struct UiSection {
83    title: Option<String>,
84}
85
86impl ToolConfig {
87    pub fn load() -> Self {
88        if let Ok(path) = std::env::var("CORTENFORGE_TOOLS_CONFIG") {
89            let cfg = Self::from_path(Path::new(&path)).unwrap_or_default();
90            cfg.warn_if_invalid();
91            return cfg;
92        }
93        let cfg = Self::from_path(Path::new(DEFAULT_CONFIG_NAME)).unwrap_or_default();
94        cfg.warn_if_invalid();
95        cfg
96    }
97
98    pub fn from_path(path: &Path) -> Option<Self> {
99        if !path.exists() {
100            return None;
101        }
102        let raw = std::fs::read_to_string(path).ok()?;
103        let file: ToolConfigFile = toml::from_str(&raw).ok()?;
104        Some(Self::from_file(file))
105    }
106
107    fn from_file(file: ToolConfigFile) -> Self {
108        let assets_root = file
109            .assets_root
110            .map(|v| expand_path(&v))
111            .unwrap_or_else(|| PathBuf::from("assets"));
112        let logs_root = file
113            .logs_root
114            .map(|v| expand_path(&v))
115            .unwrap_or_else(|| PathBuf::from("logs"));
116
117        let captures_root = file
118            .captures_root
119            .map(|v| expand_path(&v))
120            .unwrap_or_else(|| assets_root.join("datasets/captures"));
121        let captures_filtered_root = file
122            .captures_filtered_root
123            .map(|v| expand_path(&v))
124            .unwrap_or_else(|| assets_root.join("datasets/captures_filtered"));
125        let warehouse_manifest = file
126            .warehouse_manifest
127            .map(|v| expand_path(&v))
128            .unwrap_or_else(|| assets_root.join("warehouse/manifest.json"));
129
130        let metrics_path = file
131            .metrics_path
132            .map(|v| expand_path(&v))
133            .unwrap_or_else(|| logs_root.join("metrics.jsonl"));
134        let train_log_path = file
135            .train_log_path
136            .map(|v| expand_path(&v))
137            .unwrap_or_else(|| logs_root.join("train.log"));
138        let train_status_paths = file
139            .train_status_paths
140            .map(|paths| paths.into_iter().map(|v| expand_path(&v)).collect())
141            .unwrap_or_else(|| {
142                vec![
143                    PathBuf::from("logs/train_hp_status.json"),
144                    PathBuf::from("logs/train_status.json"),
145                ]
146            });
147
148        ToolConfig {
149            sim_bin: file
150                .sim_bin
151                .map(|v| expand_path(&v))
152                .unwrap_or_else(|| PathBuf::from("sim_view")),
153            train_bin: file
154                .train_bin
155                .map(|v| expand_path(&v))
156                .unwrap_or_else(|| PathBuf::from("train")),
157            assets_root,
158            captures_root,
159            captures_filtered_root,
160            warehouse_manifest,
161            logs_root,
162            metrics_path,
163            train_log_path,
164            train_status_paths,
165            datagen_args: file
166                .datagen
167                .and_then(|d| d.args)
168                .unwrap_or_default(),
169            training_args: file
170                .training
171                .and_then(|t| t.args)
172                .unwrap_or_default(),
173            warehouse_train_template: file
174                .warehouse
175                .and_then(|w| w.train_template)
176                .unwrap_or_else(|| DEFAULT_TRAIN_TEMPLATE.to_string()),
177            ui_title: file
178                .ui
179                .and_then(|u| u.title)
180                .filter(|t| !t.trim().is_empty())
181                .unwrap_or_else(|| "CortenForge Tools".to_string()),
182        }
183    }
184
185    fn warn_if_invalid(&self) {
186        if self.sim_bin.as_os_str().is_empty() {
187            eprintln!("tools config: sim_bin is empty; sim tools may fail to launch");
188        }
189        if self.train_bin.as_os_str().is_empty() {
190            eprintln!("tools config: train_bin is empty; training tools may fail to launch");
191        }
192        if self.train_status_paths.is_empty() {
193            eprintln!("tools config: train_status_paths is empty; TUI status will be disabled");
194        }
195    }
196}
197
198fn expand_path(raw: &str) -> PathBuf {
199    let mut out = raw.to_string();
200    if let Some(stripped) = out.strip_prefix("~") {
201        if let Ok(home) = std::env::var("HOME") {
202            out = format!("{home}{stripped}");
203        }
204    }
205    PathBuf::from(expand_env(&out))
206}
207
/// Replaces every `${NAME}` in `input` with the value of the environment
/// variable `NAME`.
///
/// * A reference to a variable that is unset (or not valid Unicode) is kept
///   verbatim, braces included.
/// * A `${` with no closing `}` after it is kept verbatim, along with the rest
///   of the string.
/// * A `$` that is not followed by `{` is ordinary text.
///
/// All other text is copied through unchanged as string slices, so non-ASCII
/// characters survive intact (copying byte-by-byte as `char` would turn each
/// UTF-8 byte into a separate character).
fn expand_env(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(open) = rest.find("${") {
        let after_open = &rest[open + 2..];
        let close = match after_open.find('}') {
            Some(close) => close,
            // No `}` anywhere ahead, so no later `${` can be terminated either.
            None => break,
        };

        out.push_str(&rest[..open]);
        let key = &after_open[..close];
        match std::env::var(key) {
            Ok(val) => out.push_str(&val),
            Err(_) => {
                out.push_str("${");
                out.push_str(key);
                out.push('}');
            }
        }
        rest = &after_open[close + 1..];
    }

    out.push_str(rest);
    out
}