cortenforge_tools/warehouse_commands/
builder.rs1use super::common::{CmdConfig, WarehouseStore};
2use crate::ToolConfig;
3
/// Target shell syntax for the generated training command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Shell {
    PowerShell,
    Bash,
}

impl Shell {
    /// Renders one environment-variable assignment for this shell.
    ///
    /// * PowerShell: `$env:KEY="val"`
    /// * Bash: `KEY="val"`
    ///
    /// The value is placed inside double quotes. Characters the shell would
    /// otherwise interpret inside a double-quoted string are escaped, so the
    /// variable receives `val` verbatim (e.g. a path containing `$`). Values
    /// without such characters render exactly as `KEY="val"`.
    ///
    /// NOTE(review): `!` is not escaped for Bash; interactive shells with
    /// history expansion enabled may still treat it specially.
    fn env_kv(&self, key: &str, val: &str) -> String {
        let escaped = self.escape_double_quoted(val);
        match self {
            Shell::PowerShell => format!("$env:{key}=\"{escaped}\""),
            Shell::Bash => format!("{key}=\"{escaped}\""),
        }
    }

    /// Escapes `val` for safe embedding inside a double-quoted string of this shell.
    fn escape_double_quoted(&self, val: &str) -> String {
        // PowerShell's escape character is the backtick; Bash uses backslash.
        let escape = match self {
            Shell::PowerShell => '`',
            Shell::Bash => '\\',
        };
        let mut out = String::with_capacity(val.len());
        for ch in val.chars() {
            let special = match self {
                // PowerShell also accepts typographic double quotes as string delimiters.
                Shell::PowerShell => {
                    matches!(ch, '`' | '$' | '"' | '\u{201C}' | '\u{201D}' | '\u{201E}')
                }
                Shell::Bash => matches!(ch, '\\' | '$' | '`' | '"'),
            };
            if special {
                out.push(escape);
            }
            out.push(ch);
        }
        out
    }

    /// Text placed between consecutive statements: `; ` for PowerShell, because
    /// each `$env:` assignment is its own statement. A single space for Bash,
    /// because `KEY="val"` pairs are written as a prefix before the command.
    fn separator(&self) -> &'static str {
        match self {
            Shell::PowerShell => "; ",
            Shell::Bash => " ",
        }
    }
}
25
26#[allow(dead_code)]
27pub fn build_command(cfg: &CmdConfig<'_>, shell: Shell) -> String {
28 let tools_cfg = ToolConfig::load();
29 build_command_with_template(cfg, shell, &tools_cfg.warehouse_train_template)
30}
31
32pub fn build_command_with_template(
33 cfg: &CmdConfig<'_>,
34 shell: Shell,
35 template: &str,
36) -> String {
37 let mut env_parts = Vec::new();
38 env_parts.push(shell.env_kv("TENSOR_WAREHOUSE_MANIFEST", cfg.manifest.as_ref()));
39 env_parts.push(shell.env_kv("WAREHOUSE_STORE", cfg.store.as_str()));
40 if matches!(cfg.store, WarehouseStore::Stream) {
41 let depth = cfg.prefetch.unwrap_or(2);
42 env_parts.push(shell.env_kv("WAREHOUSE_PREFETCH", &depth.to_string()));
43 }
44 env_parts.push(shell.env_kv("WGPU_BACKEND", cfg.wgpu_backend.as_ref()));
45 if let Some(adapter) = &cfg.wgpu_adapter {
46 env_parts.push(shell.env_kv("WGPU_ADAPTER_NAME", adapter.as_ref()));
47 }
48 env_parts.push(shell.env_kv("WGPU_POWER_PREF", "high-performance"));
49 env_parts.push(shell.env_kv("RUST_LOG", "trace,wgpu_core=trace,wgpu_hal=trace"));
50
51 let cmd = render_template(
52 template,
53 &[
54 ("MODEL", cfg.model.as_str()),
55 ("BATCH", &cfg.batch_size.to_string()),
56 ("LOG_EVERY", &cfg.log_every.to_string()),
57 ("EXTRA_ARGS", cfg.extra_args.as_ref()),
58 ("MANIFEST", cfg.manifest.as_ref()),
59 ("STORE", cfg.store.as_str()),
60 ("WGPU_BACKEND", cfg.wgpu_backend.as_ref()),
61 (
62 "WGPU_ADAPTER",
63 cfg.wgpu_adapter.as_ref().map(|v| v.as_ref()).unwrap_or(""),
64 ),
65 ],
66 )
67 .trim()
68 .to_string();
69 let cmd = if cmd.trim().is_empty() {
70 eprintln!("warehouse_cmd: empty train_template; falling back to legacy command");
71 default_command(cfg)
72 } else {
73 cmd
74 };
75
76 let sep = shell.separator();
77 match shell {
78 Shell::PowerShell => format!("{}; {}", env_parts.join(sep), cmd),
79 Shell::Bash => format!("{} {}", env_parts.join(sep), cmd),
80 }
81}
82
83fn default_command(cfg: &CmdConfig<'_>) -> String {
84 let mut cmd_parts = Vec::new();
85 cmd_parts.push("cargo train_hp".to_string());
86 cmd_parts.push(format!("--model {}", cfg.model.as_str()));
87 cmd_parts.push(format!("--batch-size {}", cfg.batch_size));
88 cmd_parts.push(format!("--log-every {}", cfg.log_every));
89 if !cfg.extra_args.as_ref().trim().is_empty() {
90 cmd_parts.push(cfg.extra_args.as_ref().trim().to_string());
91 }
92 cmd_parts.join(" ")
93}
94
/// Substitutes `${KEY}` placeholders in `template` with values from `replacements`.
///
/// The template is scanned once, left to right. Substituted values are copied
/// verbatim and never re-scanned, so a value that itself contains `${OTHER}`
/// text is not expanded by a later key. Placeholders with no matching key,
/// and an unterminated `${`, are left untouched. If a key appears more than
/// once in `replacements`, the first entry wins.
fn render_template(template: &str, replacements: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let tail = &rest[start..];
        let hit = replacements.iter().find_map(|&(key, val)| {
            tail.strip_prefix("${")
                .and_then(|t| t.strip_prefix(key))
                .and_then(|t| t.strip_prefix('}'))
                .map(|after| (val, after))
        });
        match hit {
            Some((val, after)) => {
                out.push_str(val);
                rest = after;
            }
            None => {
                // Not a known placeholder: emit the `$` and resume scanning just
                // after it, so an inner `${KEY}` (as in `${${KEY}}`) is still found.
                out.push('$');
                rest = &tail[1..];
            }
        }
    }
    out.push_str(rest);
    out
}