cortenforge_tools/
config.rs1use std::path::{Path, PathBuf};
2
3use serde::Deserialize;
4
/// Config file name used when `CORTENFORGE_TOOLS_CONFIG` is not set.
///
/// It is a bare relative path, so it resolves against the current working directory.
const DEFAULT_CONFIG_NAME: &str = "cortenforge-tools.toml";

/// Training command template used when the config file has no `[warehouse] train_template`.
///
/// The `${MODEL}`, `${BATCH}`, `${LOG_EVERY}` and `${EXTRA_ARGS}` placeholders are not expanded
/// in this module. NOTE(review): presumably the warehouse command runner fills them in — confirm.
const DEFAULT_TRAIN_TEMPLATE: &str =
    "cargo train_hp --model ${MODEL} --batch-size ${BATCH} --log-every ${LOG_EVERY} ${EXTRA_ARGS}";
8
/// Fully resolved settings shared by the CortenForge tool binaries.
///
/// Built either from a config file plus fallbacks (`ToolConfig::load` / `ToolConfig::from_path`)
/// or from `Default`. Path values read from a config file have already had `~` and `${VAR}`
/// references expanded by the time they land here.
#[derive(Debug, Clone)]
pub struct ToolConfig {
    /// Executable started by the simulation tools.
    pub sim_bin: PathBuf,
    /// Executable started by the training tools.
    pub train_bin: PathBuf,

    /// Root of the asset tree; the capture and warehouse paths default to locations beneath it.
    pub assets_root: PathBuf,
    /// Capture dataset directory (`<assets_root>/datasets/captures` unless configured).
    pub captures_root: PathBuf,
    /// Filtered capture dataset directory
    /// (`<assets_root>/datasets/captures_filtered` unless configured).
    pub captures_filtered_root: PathBuf,
    /// Warehouse manifest file (`<assets_root>/warehouse/manifest.json` unless configured).
    pub warehouse_manifest: PathBuf,

    /// Root for log output; the metrics and training-log paths default to locations beneath it.
    pub logs_root: PathBuf,
    /// Metrics file (`<logs_root>/metrics.jsonl` unless configured).
    pub metrics_path: PathBuf,
    /// Training log file (`<logs_root>/train.log` unless configured).
    pub train_log_path: PathBuf,
    /// Training status files used for TUI status; an empty list disables it.
    pub train_status_paths: Vec<PathBuf>,

    /// Extra arguments from the `[datagen]` table's `args` key (not path-expanded).
    pub datagen_args: Vec<String>,
    /// Extra arguments from the `[training]` table's `args` key (not path-expanded).
    pub training_args: Vec<String>,
    /// Warehouse training command template containing `${...}` placeholders.
    pub warehouse_train_template: String,
    /// Title shown by the tools UI.
    pub ui_title: String,
}
26
27impl Default for ToolConfig {
28 fn default() -> Self {
29 let assets_root = PathBuf::from("assets");
30 let logs_root = PathBuf::from("logs");
31 Self {
32 sim_bin: PathBuf::from("sim_view"),
33 train_bin: PathBuf::from("train"),
34 captures_root: assets_root.join("datasets/captures"),
35 captures_filtered_root: assets_root.join("datasets/captures_filtered"),
36 warehouse_manifest: assets_root.join("warehouse/manifest.json"),
37 assets_root,
38 logs_root: logs_root.clone(),
39 metrics_path: logs_root.join("metrics.jsonl"),
40 train_log_path: logs_root.join("train.log"),
41 train_status_paths: vec![
42 PathBuf::from("logs/train_hp_status.json"),
43 PathBuf::from("logs/train_status.json"),
44 ],
45 datagen_args: Vec::new(),
46 training_args: Vec::new(),
47 warehouse_train_template: DEFAULT_TRAIN_TEMPLATE.to_string(),
48 ui_title: "CortenForge Tools".to_string(),
49 }
50 }
51}
52
53#[derive(Debug, Deserialize, Default)]
54struct ToolConfigFile {
55 sim_bin: Option<String>,
56 train_bin: Option<String>,
57 assets_root: Option<String>,
58 captures_root: Option<String>,
59 captures_filtered_root: Option<String>,
60 warehouse_manifest: Option<String>,
61 logs_root: Option<String>,
62 metrics_path: Option<String>,
63 train_log_path: Option<String>,
64 train_status_paths: Option<Vec<String>>,
65 datagen: Option<ArgSection>,
66 training: Option<ArgSection>,
67 warehouse: Option<WarehouseSection>,
68 ui: Option<UiSection>,
69}
70
71#[derive(Debug, Deserialize, Default)]
72struct ArgSection {
73 args: Option<Vec<String>>,
74}
75
76#[derive(Debug, Deserialize, Default)]
77struct WarehouseSection {
78 train_template: Option<String>,
79}
80
81#[derive(Debug, Deserialize, Default)]
82struct UiSection {
83 title: Option<String>,
84}
85
86impl ToolConfig {
87 pub fn load() -> Self {
88 if let Ok(path) = std::env::var("CORTENFORGE_TOOLS_CONFIG") {
89 let cfg = Self::from_path(Path::new(&path)).unwrap_or_default();
90 cfg.warn_if_invalid();
91 return cfg;
92 }
93 let cfg = Self::from_path(Path::new(DEFAULT_CONFIG_NAME)).unwrap_or_default();
94 cfg.warn_if_invalid();
95 cfg
96 }
97
98 pub fn from_path(path: &Path) -> Option<Self> {
99 if !path.exists() {
100 return None;
101 }
102 let raw = std::fs::read_to_string(path).ok()?;
103 let file: ToolConfigFile = toml::from_str(&raw).ok()?;
104 Some(Self::from_file(file))
105 }
106
107 fn from_file(file: ToolConfigFile) -> Self {
108 let assets_root = file
109 .assets_root
110 .map(|v| expand_path(&v))
111 .unwrap_or_else(|| PathBuf::from("assets"));
112 let logs_root = file
113 .logs_root
114 .map(|v| expand_path(&v))
115 .unwrap_or_else(|| PathBuf::from("logs"));
116
117 let captures_root = file
118 .captures_root
119 .map(|v| expand_path(&v))
120 .unwrap_or_else(|| assets_root.join("datasets/captures"));
121 let captures_filtered_root = file
122 .captures_filtered_root
123 .map(|v| expand_path(&v))
124 .unwrap_or_else(|| assets_root.join("datasets/captures_filtered"));
125 let warehouse_manifest = file
126 .warehouse_manifest
127 .map(|v| expand_path(&v))
128 .unwrap_or_else(|| assets_root.join("warehouse/manifest.json"));
129
130 let metrics_path = file
131 .metrics_path
132 .map(|v| expand_path(&v))
133 .unwrap_or_else(|| logs_root.join("metrics.jsonl"));
134 let train_log_path = file
135 .train_log_path
136 .map(|v| expand_path(&v))
137 .unwrap_or_else(|| logs_root.join("train.log"));
138 let train_status_paths = file
139 .train_status_paths
140 .map(|paths| paths.into_iter().map(|v| expand_path(&v)).collect())
141 .unwrap_or_else(|| {
142 vec![
143 PathBuf::from("logs/train_hp_status.json"),
144 PathBuf::from("logs/train_status.json"),
145 ]
146 });
147
148 ToolConfig {
149 sim_bin: file
150 .sim_bin
151 .map(|v| expand_path(&v))
152 .unwrap_or_else(|| PathBuf::from("sim_view")),
153 train_bin: file
154 .train_bin
155 .map(|v| expand_path(&v))
156 .unwrap_or_else(|| PathBuf::from("train")),
157 assets_root,
158 captures_root,
159 captures_filtered_root,
160 warehouse_manifest,
161 logs_root,
162 metrics_path,
163 train_log_path,
164 train_status_paths,
165 datagen_args: file.datagen.and_then(|d| d.args).unwrap_or_default(),
166 training_args: file.training.and_then(|t| t.args).unwrap_or_default(),
167 warehouse_train_template: file
168 .warehouse
169 .and_then(|w| w.train_template)
170 .unwrap_or_else(|| DEFAULT_TRAIN_TEMPLATE.to_string()),
171 ui_title: file
172 .ui
173 .and_then(|u| u.title)
174 .filter(|t| !t.trim().is_empty())
175 .unwrap_or_else(|| "CortenForge Tools".to_string()),
176 }
177 }
178
179 fn warn_if_invalid(&self) {
180 if self.sim_bin.as_os_str().is_empty() {
181 eprintln!("tools config: sim_bin is empty; sim tools may fail to launch");
182 }
183 if self.train_bin.as_os_str().is_empty() {
184 eprintln!("tools config: train_bin is empty; training tools may fail to launch");
185 }
186 if self.warehouse_train_template.trim().is_empty() {
187 eprintln!(
188 "tools config: warehouse.train_template is empty; warehouse_cmd will fail to run"
189 );
190 }
191 if self.train_status_paths.is_empty() {
192 eprintln!("tools config: train_status_paths is empty; TUI status will be disabled");
193 }
194 }
195}
196
/// Expands a configured path. A leading `~` becomes `$HOME`, then `${VAR}` references are
/// substituted via `expand_env`.
///
/// Only a bare `~` or a `~` followed by a path separator means the home directory. `~user/...`
/// is left untouched, because other users' home directories are not resolved. If `HOME` is unset
/// or not valid Unicode, the `~` is kept literally.
fn expand_path(raw: &str) -> PathBuf {
    let with_home = match raw.strip_prefix('~') {
        Some(rest) if rest.is_empty() || rest.starts_with(std::path::is_separator) => {
            match std::env::var("HOME") {
                Ok(home) => format!("{home}{rest}"),
                Err(_) => raw.to_string(),
            }
        }
        _ => raw.to_string(),
    };
    PathBuf::from(expand_env(&with_home))
}

/// Replaces every `${NAME}` in `input` with the value of environment variable `NAME`.
///
/// References to unset (or non-Unicode) variables are kept verbatim, as is an unterminated `${`.
/// Substituted values are not rescanned. All other text, including multi-byte UTF-8 characters,
/// is copied unchanged.
fn expand_env(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];
        let Some(len) = after_open.find('}') else {
            // No closing brace anywhere after this point: copy the remainder as-is.
            out.push_str(&rest[start..]);
            return out;
        };
        let key = &after_open[..len];
        match std::env::var(key) {
            Ok(value) => out.push_str(&value),
            Err(_) => {
                out.push_str("${");
                out.push_str(key);
                out.push('}');
            }
        }
        rest = &after_open[len + 1..];
    }
    out.push_str(rest);
    out
}