cortenforge_tools/warehouse_commands/builder.rs

1use super::common::{CmdConfig, WarehouseStore};
2use crate::ToolConfig;
3
/// Target shell whose syntax is used for the generated training command.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Shell {
    /// PowerShell: `$env:KEY="value"` assignments joined with `; `.
    PowerShell,
    /// Bash: `KEY="value"` prefix assignments joined with a single space.
    Bash,
}

impl Shell {
    /// Formats one environment-variable assignment in this shell's syntax.
    ///
    /// The value is always wrapped in double quotes. Characters the shell would otherwise
    /// interpret inside a double-quoted string are escaped (see
    /// [`Shell::escape_double_quoted`]), so the variable receives `val` verbatim. This
    /// covers, for example, a path containing `$` or ending in `\`.
    fn env_kv(&self, key: &str, val: &str) -> String {
        let val = self.escape_double_quoted(val);
        match self {
            Shell::PowerShell => format!("$env:{key}=\"{val}\""),
            Shell::Bash => format!("{key}=\"{val}\""),
        }
    }

    /// Escapes `val` for use inside a double-quoted string of this shell.
    ///
    /// - PowerShell: `` ` ``, `"` and `$` are prefixed with a backtick.
    /// - Bash: `\`, `"`, `$` and `` ` `` are prefixed with a backslash.
    ///
    /// Values containing none of these characters are returned unchanged.
    fn escape_double_quoted(&self, val: &str) -> String {
        let mut escaped = String::with_capacity(val.len());
        for c in val.chars() {
            let escape = match (*self, c) {
                (Shell::PowerShell, '`' | '"' | '$') => Some('`'),
                (Shell::Bash, '\\' | '"' | '$' | '`') => Some('\\'),
                _ => None,
            };
            if let Some(e) = escape {
                escaped.push(e);
            }
            escaped.push(c);
        }
        escaped
    }

    /// Separator placed between consecutive environment assignments.
    fn separator(&self) -> &'static str {
        match self {
            Shell::PowerShell => "; ",
            Shell::Bash => " ",
        }
    }
}
25
26#[allow(dead_code)]
27pub fn build_command(cfg: &CmdConfig<'_>, shell: Shell) -> String {
28    let tools_cfg = ToolConfig::load();
29    build_command_with_template(cfg, shell, &tools_cfg.warehouse_train_template)
30}
31
32pub fn build_command_with_template(
33    cfg: &CmdConfig<'_>,
34    shell: Shell,
35    template: &str,
36) -> String {
37    let mut env_parts = Vec::new();
38    env_parts.push(shell.env_kv("TENSOR_WAREHOUSE_MANIFEST", cfg.manifest.as_ref()));
39    env_parts.push(shell.env_kv("WAREHOUSE_STORE", cfg.store.as_str()));
40    if matches!(cfg.store, WarehouseStore::Stream) {
41        let depth = cfg.prefetch.unwrap_or(2);
42        env_parts.push(shell.env_kv("WAREHOUSE_PREFETCH", &depth.to_string()));
43    }
44    env_parts.push(shell.env_kv("WGPU_BACKEND", cfg.wgpu_backend.as_ref()));
45    if let Some(adapter) = &cfg.wgpu_adapter {
46        env_parts.push(shell.env_kv("WGPU_ADAPTER_NAME", adapter.as_ref()));
47    }
48    env_parts.push(shell.env_kv("WGPU_POWER_PREF", "high-performance"));
49    env_parts.push(shell.env_kv("RUST_LOG", "trace,wgpu_core=trace,wgpu_hal=trace"));
50
51    let cmd = render_template(
52        template,
53        &[
54            ("MODEL", cfg.model.as_str()),
55            ("BATCH", &cfg.batch_size.to_string()),
56            ("LOG_EVERY", &cfg.log_every.to_string()),
57            ("EXTRA_ARGS", cfg.extra_args.as_ref()),
58            ("MANIFEST", cfg.manifest.as_ref()),
59            ("STORE", cfg.store.as_str()),
60            ("WGPU_BACKEND", cfg.wgpu_backend.as_ref()),
61            (
62                "WGPU_ADAPTER",
63                cfg.wgpu_adapter.as_ref().map(|v| v.as_ref()).unwrap_or(""),
64            ),
65        ],
66    )
67    .trim()
68    .to_string();
69    let cmd = if cmd.trim().is_empty() {
70        eprintln!("warehouse_cmd: empty train_template; falling back to legacy command");
71        default_command(cfg)
72    } else {
73        cmd
74    };
75
76    let sep = shell.separator();
77    match shell {
78        Shell::PowerShell => format!("{}; {}", env_parts.join(sep), cmd),
79        Shell::Bash => format!("{} {}", env_parts.join(sep), cmd),
80    }
81}
82
83fn default_command(cfg: &CmdConfig<'_>) -> String {
84    let mut cmd_parts = Vec::new();
85    cmd_parts.push("cargo train_hp".to_string());
86    cmd_parts.push(format!("--model {}", cfg.model.as_str()));
87    cmd_parts.push(format!("--batch-size {}", cfg.batch_size));
88    cmd_parts.push(format!("--log-every {}", cfg.log_every));
89    if !cfg.extra_args.as_ref().trim().is_empty() {
90        cmd_parts.push(cfg.extra_args.as_ref().trim().to_string());
91    }
92    cmd_parts.join(" ")
93}
94
/// Substitutes `${KEY}` placeholders in `template` using `replacements`.
///
/// The template is scanned once, left to right. Each `${KEY}` whose `KEY` matches an entry
/// in `replacements` is replaced by that entry's value; when a key appears more than once,
/// the first entry wins. Unknown placeholders and an unterminated `${` are copied through
/// unchanged.
///
/// Substituted values are never re-scanned. A value that itself contains `${...}` is
/// therefore emitted literally, rather than being expanded or not depending on where its
/// key sits in `replacements`.
fn render_template(template: &str, replacements: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    // `${` and `}` are ASCII, so every index below falls on a UTF-8 char boundary.
    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let after = &rest[start + 2..];
        let hit = after.find('}').and_then(|end| {
            let key = &after[..end];
            replacements
                .iter()
                .find(|(k, _)| *k == key)
                .map(|(_, v)| (*v, end))
        });
        match hit {
            Some((val, end)) => {
                out.push_str(val);
                rest = &after[end + 1..];
            }
            None => {
                // Not a known placeholder: keep `${` literally and continue scanning after it.
                out.push_str("${");
                rest = after;
            }
        }
    }
    out.push_str(rest);
    out
}
102}