cortenforge_tools/warehouse_commands/
builder.rs1use super::common::{CmdConfig, WarehouseStore};
2use crate::ToolConfig;
3
/// Target shell syntax used when rendering a copy-pasteable command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Shell {
    /// Windows PowerShell / PowerShell Core (`$env:KEY="value"` assignments).
    PowerShell,
    /// POSIX-style shells such as Bash (`KEY="value"` command prefixes).
    Bash,
}

impl Shell {
    /// Renders one environment-variable assignment in this shell's syntax.
    ///
    /// The value is wrapped in double quotes, and every character that keeps a
    /// special meaning inside double quotes is escaped so the shell reads back
    /// exactly `val`. Values without such characters render unchanged.
    fn env_kv(&self, key: &str, val: &str) -> String {
        let val = self.escape_double_quoted(val);
        match self {
            Shell::PowerShell => format!("$env:{key}=\"{val}\""),
            Shell::Bash => format!("{key}=\"{val}\""),
        }
    }

    /// Returns the text placed between consecutive assignments (and between the
    /// last assignment and the command itself).
    fn separator(&self) -> &'static str {
        match self {
            Shell::PowerShell => "; ",
            Shell::Bash => " ",
        }
    }

    /// Escapes `val` for use inside a double-quoted string of this shell.
    ///
    /// PowerShell escapes with a backtick; Bash escapes with a backslash. In
    /// Bash the backslash itself is doubled so that a value ending in `\` (or
    /// containing `\"`) cannot swallow the closing quote.
    fn escape_double_quoted(&self, val: &str) -> String {
        let (escape, special) = match self {
            Shell::PowerShell => ('`', "`\"$"),
            Shell::Bash => ('\\', "\\\"$`"),
        };
        let mut out = String::with_capacity(val.len());
        for ch in val.chars() {
            if special.contains(ch) {
                out.push(escape);
            }
            out.push(ch);
        }
        out
    }
}
25
26pub fn build_command(cfg: &CmdConfig<'_>, shell: Shell) -> String {
27 let tools_cfg = ToolConfig::load();
28 build_command_with_template(cfg, shell, &tools_cfg.warehouse_train_template)
29}
30
31pub fn build_command_with_template(cfg: &CmdConfig<'_>, shell: Shell, template: &str) -> String {
32 let mut env_parts = Vec::new();
33 env_parts.push(shell.env_kv("TENSOR_WAREHOUSE_MANIFEST", cfg.manifest.as_ref()));
34 env_parts.push(shell.env_kv("WAREHOUSE_STORE", cfg.store.as_str()));
35 if matches!(cfg.store, WarehouseStore::Stream) {
36 let depth = cfg.prefetch.unwrap_or(2);
37 env_parts.push(shell.env_kv("WAREHOUSE_PREFETCH", &depth.to_string()));
38 }
39 env_parts.push(shell.env_kv("WGPU_BACKEND", cfg.wgpu_backend.as_ref()));
40 if let Some(adapter) = &cfg.wgpu_adapter {
41 env_parts.push(shell.env_kv("WGPU_ADAPTER_NAME", adapter.as_ref()));
42 }
43 env_parts.push(shell.env_kv("WGPU_POWER_PREF", "high-performance"));
44 env_parts.push(shell.env_kv("RUST_LOG", "trace,wgpu_core=trace,wgpu_hal=trace"));
45
46 let cmd = render_template(
47 template,
48 &[
49 ("MODEL", cfg.model.as_str()),
50 ("BATCH", &cfg.batch_size.to_string()),
51 ("LOG_EVERY", &cfg.log_every.to_string()),
52 ("EXTRA_ARGS", cfg.extra_args.as_ref()),
53 ("MANIFEST", cfg.manifest.as_ref()),
54 ("STORE", cfg.store.as_str()),
55 ("WGPU_BACKEND", cfg.wgpu_backend.as_ref()),
56 (
57 "WGPU_ADAPTER",
58 cfg.wgpu_adapter.as_ref().map(|v| v.as_ref()).unwrap_or(""),
59 ),
60 ],
61 )
62 .trim()
63 .to_string();
64 if cmd.trim().is_empty() {
65 eprintln!(
66 "warehouse_cmd: empty train_template; set [warehouse].train_template in cortenforge-tools.toml"
67 );
68 std::process::exit(2);
69 }
70
71 let sep = shell.separator();
72 match shell {
73 Shell::PowerShell => format!("{}; {}", env_parts.join(sep), cmd),
74 Shell::Bash => format!("{} {}", env_parts.join(sep), cmd),
75 }
76}
77
/// Expands `${KEY}` placeholders in `template` using `replacements`.
///
/// The template is scanned once from left to right. Each `${KEY}` whose key
/// appears in `replacements` is replaced by the paired value (the first match
/// wins if a key is listed twice); unknown or unterminated placeholders are
/// copied through verbatim.
///
/// Substituted values are never rescanned, so a value that itself contains
/// `${...}` text is inserted literally, regardless of the order of
/// `replacements`. Keys are expected not to contain `}`.
fn render_template(template: &str, replacements: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(start) = rest.find("${") {
        out.push_str(&rest[..start]);
        let after = &rest[start + 2..];

        // Resolve the placeholder: needs a closing brace and a known key.
        let hit = after.find('}').and_then(|end| {
            let key = &after[..end];
            replacements
                .iter()
                .find(|(k, _)| *k == key)
                .map(|(_, v)| (end, *v))
        });

        match hit {
            Some((end, val)) => {
                out.push_str(val);
                rest = &after[end + 1..];
            }
            None => {
                // Not a known placeholder: keep the `${` and continue after it,
                // so a valid placeholder nested right behind it is still found.
                out.push_str("${");
                rest = after;
            }
        }
    }

    out.push_str(rest);
    out
}