cortenforge_tools/warehouse_commands/builder.rs

1use super::common::{CmdConfig, WarehouseStore};
2use crate::ToolConfig;
3
/// Target shell for generated warehouse command lines.
///
/// Controls how environment variables are assigned and how the assignments are joined
/// with the training command.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Shell {
    /// PowerShell: emits `$env:KEY="value"` statements separated by `; `.
    PowerShell,
    /// Bash: emits `KEY="value"` assignments separated by spaces, prefixing the command.
    Bash,
}

impl Shell {
    /// Formats one environment-variable assignment for this shell.
    ///
    /// The value is wrapped in double quotes. Characters that are special inside a
    /// double-quoted string for the target shell are escaped. This keeps a value containing
    /// `"`, `$` or a backtick from terminating the quote or triggering substitution.
    fn env_kv(&self, key: &str, val: &str) -> String {
        let quoted = self.escape_double_quoted(val);
        match self {
            Shell::PowerShell => format!("$env:{key}=\"{quoted}\""),
            Shell::Bash => format!("{key}=\"{quoted}\""),
        }
    }

    /// Escapes `val` so it can be embedded verbatim inside a double-quoted string literal.
    fn escape_double_quoted(&self, val: &str) -> String {
        let (escape, is_special): (char, fn(char) -> bool) = match self {
            // PowerShell escapes with a backtick. `$` starts interpolation, and the typographic
            // double quotes (U+201C/U+201D/U+201E) are also accepted as string delimiters.
            Shell::PowerShell => ('`', |c| {
                matches!(c, '`' | '$' | '"' | '\u{201C}' | '\u{201D}' | '\u{201E}')
            }),
            // Inside bash double quotes, a backslash only escapes these characters.
            Shell::Bash => ('\\', |c| matches!(c, '\\' | '$' | '`' | '"')),
        };
        let mut out = String::with_capacity(val.len());
        for ch in val.chars() {
            if is_special(ch) {
                out.push(escape);
            }
            out.push(ch);
        }
        out
    }

    /// Text placed between consecutive env assignments and before the command.
    ///
    /// PowerShell needs a statement separator. Bash uses a plain space so the assignments
    /// form a prefix of the command.
    fn separator(&self) -> &'static str {
        match self {
            Shell::PowerShell => "; ",
            Shell::Bash => " ",
        }
    }
}
25
26#[allow(dead_code)]
27pub fn build_command(cfg: &CmdConfig<'_>, shell: Shell) -> String {
28    let tools_cfg = ToolConfig::load();
29    build_command_with_template(cfg, shell, &tools_cfg.warehouse_train_template)
30}
31
32pub fn build_command_with_template(cfg: &CmdConfig<'_>, shell: Shell, template: &str) -> String {
33    let mut env_parts = Vec::new();
34    env_parts.push(shell.env_kv("TENSOR_WAREHOUSE_MANIFEST", cfg.manifest.as_ref()));
35    env_parts.push(shell.env_kv("WAREHOUSE_STORE", cfg.store.as_str()));
36    if matches!(cfg.store, WarehouseStore::Stream) {
37        let depth = cfg.prefetch.unwrap_or(2);
38        env_parts.push(shell.env_kv("WAREHOUSE_PREFETCH", &depth.to_string()));
39    }
40    env_parts.push(shell.env_kv("WGPU_BACKEND", cfg.wgpu_backend.as_ref()));
41    if let Some(adapter) = &cfg.wgpu_adapter {
42        env_parts.push(shell.env_kv("WGPU_ADAPTER_NAME", adapter.as_ref()));
43    }
44    env_parts.push(shell.env_kv("WGPU_POWER_PREF", "high-performance"));
45    env_parts.push(shell.env_kv("RUST_LOG", "trace,wgpu_core=trace,wgpu_hal=trace"));
46
47    let cmd = render_template(
48        template,
49        &[
50            ("MODEL", cfg.model.as_str()),
51            ("BATCH", &cfg.batch_size.to_string()),
52            ("LOG_EVERY", &cfg.log_every.to_string()),
53            ("EXTRA_ARGS", cfg.extra_args.as_ref()),
54            ("MANIFEST", cfg.manifest.as_ref()),
55            ("STORE", cfg.store.as_str()),
56            ("WGPU_BACKEND", cfg.wgpu_backend.as_ref()),
57            (
58                "WGPU_ADAPTER",
59                cfg.wgpu_adapter.as_ref().map(|v| v.as_ref()).unwrap_or(""),
60            ),
61        ],
62    )
63    .trim()
64    .to_string();
65    let cmd = if cmd.trim().is_empty() {
66        eprintln!("warehouse_cmd: empty train_template; falling back to legacy command");
67        default_command(cfg)
68    } else {
69        cmd
70    };
71
72    let sep = shell.separator();
73    match shell {
74        Shell::PowerShell => format!("{}; {}", env_parts.join(sep), cmd),
75        Shell::Bash => format!("{} {}", env_parts.join(sep), cmd),
76    }
77}
78
79fn default_command(cfg: &CmdConfig<'_>) -> String {
80    let mut cmd_parts = Vec::new();
81    cmd_parts.push("cargo train_hp".to_string());
82    cmd_parts.push(format!("--model {}", cfg.model.as_str()));
83    cmd_parts.push(format!("--batch-size {}", cfg.batch_size));
84    cmd_parts.push(format!("--log-every {}", cfg.log_every));
85    if !cfg.extra_args.as_ref().trim().is_empty() {
86        cmd_parts.push(cfg.extra_args.as_ref().trim().to_string());
87    }
88    cmd_parts.join(" ")
89}
90
/// Substitutes `${KEY}` placeholders in `template` with the matching value from `replacements`.
///
/// The template is scanned once, left to right. Text inserted from a replacement value is
/// never scanned again. Earlier versions replaced keys one after another, so a value that
/// happened to contain another placeholder (for example extra args containing `${MANIFEST}`)
/// was expanded as well.
///
/// Placeholders whose key is not in `replacements`, and a trailing `${` without a closing
/// `}`, are left untouched. When a key appears more than once in `replacements`, the first
/// entry wins.
fn render_template(template: &str, replacements: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let after = &rest[start + 2..];
        let hit = after.find('}').and_then(|end| {
            let key = &after[..end];
            replacements
                .iter()
                .find(|(k, _)| *k == key)
                .map(|&(_, v)| (v, end))
        });
        match hit {
            Some((value, end)) => {
                out.push_str(value);
                rest = &after[end + 1..];
            }
            None => {
                // Not a known placeholder: keep the `${` literally and keep scanning after it,
                // so a nested known placeholder is still substituted.
                out.push_str("${");
                rest = after;
            }
        }
    }
    out.push_str(rest);
    out
}