cortenforge_tools/warehouse_commands/
builder.rs1use super::common::{CmdConfig, WarehouseStore};
2use crate::ToolConfig;
3
/// Target shell dialect for a generated command line.
///
/// Controls how environment-variable assignments are written and which
/// separator joins them to each other and to the trailing command.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Shell {
    /// PowerShell: assignments are written as `$env:KEY="value"` and joined with `; `.
    PowerShell,
    /// Bash: assignments are written as `KEY="value"` and joined with a single space.
    Bash,
}

impl Shell {
    /// Formats one environment-variable assignment for this shell.
    ///
    /// `val` is wrapped in double quotes, and every character that is special
    /// inside a double-quoted string of this shell is escaped, so the shell
    /// receives `val` verbatim (paths containing `$`, `"` or backticks no longer
    /// break the quoting or trigger expansion). `key` is inserted as-is and is
    /// expected to be a plain identifier.
    fn env_kv(&self, key: &str, val: &str) -> String {
        let mut quoted = String::with_capacity(val.len() + 2);
        quoted.push('"');
        for ch in val.chars() {
            if let Some(esc) = self.escape_for(ch) {
                quoted.push(esc);
            }
            quoted.push(ch);
        }
        quoted.push('"');
        match self {
            Shell::PowerShell => format!("$env:{key}={quoted}"),
            Shell::Bash => format!("{key}={quoted}"),
        }
    }

    /// Returns the escape character this shell needs in front of `ch` inside a
    /// double-quoted string, or `None` when `ch` may appear there literally.
    fn escape_for(&self, ch: char) -> Option<char> {
        match self {
            // The backtick is PowerShell's escape character and `$` starts variable
            // expansion. PowerShell also treats the typographic quotes U+201C, U+201D
            // and U+201E as `"`. Backslash is not special (Windows paths stay as-is).
            Shell::PowerShell => match ch {
                '`' | '$' | '"' | '\u{201C}' | '\u{201D}' | '\u{201E}' => Some('`'),
                _ => None,
            },
            // Inside bash double quotes, `\`, `$`, backtick and `"` are special.
            // `!` is deliberately left alone: in bash, `\!` keeps its backslash.
            Shell::Bash => match ch {
                '\\' | '$' | '`' | '"' => Some('\\'),
                _ => None,
            },
        }
    }

    /// Separator placed between consecutive environment assignments and before
    /// the command itself.
    fn separator(&self) -> &'static str {
        match self {
            Shell::PowerShell => "; ",
            Shell::Bash => " ",
        }
    }
}
25
26#[allow(dead_code)]
27pub fn build_command(cfg: &CmdConfig<'_>, shell: Shell) -> String {
28 let tools_cfg = ToolConfig::load();
29 build_command_with_template(cfg, shell, &tools_cfg.warehouse_train_template)
30}
31
32pub fn build_command_with_template(cfg: &CmdConfig<'_>, shell: Shell, template: &str) -> String {
33 let mut env_parts = Vec::new();
34 env_parts.push(shell.env_kv("TENSOR_WAREHOUSE_MANIFEST", cfg.manifest.as_ref()));
35 env_parts.push(shell.env_kv("WAREHOUSE_STORE", cfg.store.as_str()));
36 if matches!(cfg.store, WarehouseStore::Stream) {
37 let depth = cfg.prefetch.unwrap_or(2);
38 env_parts.push(shell.env_kv("WAREHOUSE_PREFETCH", &depth.to_string()));
39 }
40 env_parts.push(shell.env_kv("WGPU_BACKEND", cfg.wgpu_backend.as_ref()));
41 if let Some(adapter) = &cfg.wgpu_adapter {
42 env_parts.push(shell.env_kv("WGPU_ADAPTER_NAME", adapter.as_ref()));
43 }
44 env_parts.push(shell.env_kv("WGPU_POWER_PREF", "high-performance"));
45 env_parts.push(shell.env_kv("RUST_LOG", "trace,wgpu_core=trace,wgpu_hal=trace"));
46
47 let cmd = render_template(
48 template,
49 &[
50 ("MODEL", cfg.model.as_str()),
51 ("BATCH", &cfg.batch_size.to_string()),
52 ("LOG_EVERY", &cfg.log_every.to_string()),
53 ("EXTRA_ARGS", cfg.extra_args.as_ref()),
54 ("MANIFEST", cfg.manifest.as_ref()),
55 ("STORE", cfg.store.as_str()),
56 ("WGPU_BACKEND", cfg.wgpu_backend.as_ref()),
57 (
58 "WGPU_ADAPTER",
59 cfg.wgpu_adapter.as_ref().map(|v| v.as_ref()).unwrap_or(""),
60 ),
61 ],
62 )
63 .trim()
64 .to_string();
65 let cmd = if cmd.trim().is_empty() {
66 eprintln!("warehouse_cmd: empty train_template; falling back to legacy command");
67 default_command(cfg)
68 } else {
69 cmd
70 };
71
72 let sep = shell.separator();
73 match shell {
74 Shell::PowerShell => format!("{}; {}", env_parts.join(sep), cmd),
75 Shell::Bash => format!("{} {}", env_parts.join(sep), cmd),
76 }
77}
78
79fn default_command(cfg: &CmdConfig<'_>) -> String {
80 let mut cmd_parts = Vec::new();
81 cmd_parts.push("cargo train_hp".to_string());
82 cmd_parts.push(format!("--model {}", cfg.model.as_str()));
83 cmd_parts.push(format!("--batch-size {}", cfg.batch_size));
84 cmd_parts.push(format!("--log-every {}", cfg.log_every));
85 if !cfg.extra_args.as_ref().trim().is_empty() {
86 cmd_parts.push(cfg.extra_args.as_ref().trim().to_string());
87 }
88 cmd_parts.join(" ")
89}
90
/// Substitutes `${KEY}` placeholders in `template` with values from `replacements`.
///
/// The template is scanned once from left to right, so text inserted for one
/// placeholder is never itself searched for further placeholders. Behavior at
/// the edges:
/// - Placeholders whose key is not listed are copied through unchanged.
/// - A trailing `${` with no closing `}` is copied through unchanged.
/// - If a key appears more than once in `replacements`, the first entry wins.
/// - A key runs up to the first `}`, so keys cannot contain `}`.
fn render_template(template: &str, replacements: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];
        let resolved = after_open.find('}').and_then(|close| {
            let key = &after_open[..close];
            replacements
                .iter()
                .find(|(k, _)| *k == key)
                .map(|(_, v)| (*v, close))
        });
        match resolved {
            Some((value, close)) => {
                out.push_str(value);
                rest = &after_open[close + 1..];
            }
            None => {
                // Not a known placeholder: keep the `$` literally and resume right after
                // it, so a placeholder nested inside (e.g. `${${KEY}}`) is still found.
                out.push('$');
                rest = &rest[start + 1..];
            }
        }
    }
    out.push_str(rest);
    out
}