cortenforge_tools/warehouse_commands/
builder.rs

1use super::common::{CmdConfig, WarehouseStore};
2use crate::ToolConfig;
3
/// Target shell dialect for the generated command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Shell {
    /// PowerShell: assignments are `$env:KEY="value"` statements joined with `; `.
    PowerShell,
    /// Bash: assignments are `KEY="value"` prefixes joined with a single space.
    Bash,
}

impl Shell {
    /// Renders one environment-variable assignment for this shell.
    ///
    /// `val` is placed inside double quotes. Characters that are special inside a
    /// double-quoted string for this shell are escaped, so the shell sees `val` literally:
    /// - PowerShell: `` ` ``, `"` and `$` are prefixed with a backtick.
    /// - Bash: `\`, `"`, `$` and `` ` `` are prefixed with a backslash.
    fn env_kv(&self, key: &str, val: &str) -> String {
        let val = self.escape_double_quoted(val);
        match self {
            Shell::PowerShell => format!("$env:{key}=\"{val}\""),
            Shell::Bash => format!("{key}=\"{val}\""),
        }
    }

    /// Escapes `val` so it can be embedded verbatim between double quotes in this shell.
    fn escape_double_quoted(&self, val: &str) -> String {
        let escape = match self {
            Shell::PowerShell => '`',
            Shell::Bash => '\\',
        };
        let is_special = |c: char| match self {
            Shell::PowerShell => matches!(c, '`' | '"' | '$'),
            Shell::Bash => matches!(c, '\\' | '"' | '$' | '`'),
        };
        let mut out = String::with_capacity(val.len());
        for c in val.chars() {
            if is_special(c) {
                out.push(escape);
            }
            out.push(c);
        }
        out
    }

    /// Text placed between consecutive assignments and before the final command.
    fn separator(&self) -> &'static str {
        match self {
            Shell::PowerShell => "; ",
            Shell::Bash => " ",
        }
    }
}
25
26pub fn build_command(cfg: &CmdConfig<'_>, shell: Shell) -> String {
27    let tools_cfg = ToolConfig::load();
28    build_command_with_template(cfg, shell, &tools_cfg.warehouse_train_template)
29}
30
31pub fn build_command_with_template(cfg: &CmdConfig<'_>, shell: Shell, template: &str) -> String {
32    let mut env_parts = Vec::new();
33    env_parts.push(shell.env_kv("TENSOR_WAREHOUSE_MANIFEST", cfg.manifest.as_ref()));
34    env_parts.push(shell.env_kv("WAREHOUSE_STORE", cfg.store.as_str()));
35    if matches!(cfg.store, WarehouseStore::Stream) {
36        let depth = cfg.prefetch.unwrap_or(2);
37        env_parts.push(shell.env_kv("WAREHOUSE_PREFETCH", &depth.to_string()));
38    }
39    env_parts.push(shell.env_kv("WGPU_BACKEND", cfg.wgpu_backend.as_ref()));
40    if let Some(adapter) = &cfg.wgpu_adapter {
41        env_parts.push(shell.env_kv("WGPU_ADAPTER_NAME", adapter.as_ref()));
42    }
43    env_parts.push(shell.env_kv("WGPU_POWER_PREF", "high-performance"));
44    env_parts.push(shell.env_kv("RUST_LOG", "trace,wgpu_core=trace,wgpu_hal=trace"));
45
46    let cmd = render_template(
47        template,
48        &[
49            ("MODEL", cfg.model.as_str()),
50            ("BATCH", &cfg.batch_size.to_string()),
51            ("LOG_EVERY", &cfg.log_every.to_string()),
52            ("EXTRA_ARGS", cfg.extra_args.as_ref()),
53            ("MANIFEST", cfg.manifest.as_ref()),
54            ("STORE", cfg.store.as_str()),
55            ("WGPU_BACKEND", cfg.wgpu_backend.as_ref()),
56            (
57                "WGPU_ADAPTER",
58                cfg.wgpu_adapter.as_ref().map(|v| v.as_ref()).unwrap_or(""),
59            ),
60        ],
61    )
62    .trim()
63    .to_string();
64    if cmd.trim().is_empty() {
65        eprintln!(
66            "warehouse_cmd: empty train_template; set [warehouse].train_template in cortenforge-tools.toml"
67        );
68        std::process::exit(2);
69    }
70
71    let sep = shell.separator();
72    match shell {
73        Shell::PowerShell => format!("{}; {}", env_parts.join(sep), cmd),
74        Shell::Bash => format!("{} {}", env_parts.join(sep), cmd),
75    }
76}
77
/// Substitutes `${KEY}` placeholders in `template` with values from `replacements`.
///
/// The template is scanned once, left to right. Substituted values are never re-scanned, so
/// a value that itself contains `${OTHER}` is emitted literally instead of being expanded.
/// Placeholders whose key is not in `replacements`, and a trailing `${` with no closing
/// `}`, are copied through unchanged. If a key appears more than once in `replacements`,
/// the first entry wins.
fn render_template(template: &str, replacements: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];
        let hit = after_open.find('}').and_then(|end| {
            let key = &after_open[..end];
            replacements
                .iter()
                .find(|(k, _)| *k == key)
                .map(|(_, v)| (*v, end))
        });
        match hit {
            Some((val, end)) => {
                out.push_str(val);
                rest = &after_open[end + 1..];
            }
            None => {
                // Not a known placeholder: keep `${` and continue scanning just after it,
                // so a placeholder nested inside (e.g. `${A${B}}`) can still match.
                out.push_str("${");
                rest = after_open;
            }
        }
    }
    out.push_str(rest);
    out
}