1use super::types::{yaml_value_to_string, Dependency, Output};
6use anyhow::{bail, Result};
7use std::collections::HashMap;
8
9pub fn resolve_template(
13 cmd: &str,
14 global_params: &HashMap<String, serde_yaml_ng::Value>,
15 _stage_param_keys: &Option<Vec<String>>,
16 deps: &[Dependency],
17 outs: &[Output],
18) -> Result<String> {
19 let mut result = String::with_capacity(cmd.len());
20 let mut pos = 0;
21
22 while pos < cmd.len() {
23 if cmd[pos..].starts_with("{{") {
24 let start = pos + 2;
25 if let Some(end_offset) = cmd[start..].find("}}") {
26 let ref_str = cmd[start..start + end_offset].trim();
27 let replacement = resolve_ref(ref_str, global_params, deps, outs)?;
28 result.push_str(&replacement);
29 pos = start + end_offset + 2;
30 } else {
31 bail!("unclosed template expression at position {}", pos);
32 }
33 } else {
34 let ch = cmd[pos..].chars().next().expect("iterator empty");
36 result.push(ch);
37 pos += ch.len_utf8();
38 }
39 }
40
41 Ok(result)
42}
43
44fn resolve_ref(
45 ref_str: &str,
46 global_params: &HashMap<String, serde_yaml_ng::Value>,
47 deps: &[Dependency],
48 outs: &[Output],
49) -> Result<String> {
50 if let Some(key) = ref_str.strip_prefix("params.") {
52 if let Some(val) = global_params.get(key) {
53 return Ok(yaml_value_to_string(val));
54 }
55 bail!("undefined param '{}'", key);
56 }
57
58 if let Some(idx_str) = ref_str.strip_prefix("deps[").and_then(|s| s.strip_suffix("].path")) {
60 let idx: usize =
61 idx_str.parse().map_err(|_| anyhow::anyhow!("invalid deps index '{}'", idx_str))?;
62 if idx >= deps.len() {
63 bail!("deps[{}] out of range (only {} deps)", idx, deps.len());
64 }
65 return Ok(deps[idx].path.clone());
66 }
67
68 if let Some(idx_str) = ref_str.strip_prefix("outs[").and_then(|s| s.strip_suffix("].path")) {
70 let idx: usize =
71 idx_str.parse().map_err(|_| anyhow::anyhow!("invalid outs index '{}'", idx_str))?;
72 if idx >= outs.len() {
73 bail!("outs[{}] out of range (only {} outs)", idx, outs.len());
74 }
75 return Ok(outs[idx].path.clone());
76 }
77
78 bail!("unknown template reference '{}'", ref_str);
79}
80
#[cfg(test)]
// Test names embed the upper-case requirement ID (`PB001`), which would otherwise trip the
// `non_snake_case` lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Builds a params map whose values are all YAML strings.
    fn make_params(pairs: &[(&str, &str)]) -> HashMap<String, serde_yaml_ng::Value> {
        let mut params = HashMap::with_capacity(pairs.len());
        for &(key, value) in pairs {
            params.insert(key.to_owned(), serde_yaml_ng::Value::String(value.to_owned()));
        }
        params
    }

    /// Builds dependencies with the given paths and no `dep_type`.
    fn make_deps(paths: &[&str]) -> Vec<Dependency> {
        paths
            .iter()
            .copied()
            .map(|path| Dependency { path: path.to_owned(), dep_type: None })
            .collect()
    }

    /// Builds outputs with the given paths and no `out_type` / `remote`.
    fn make_outs(paths: &[&str]) -> Vec<Output> {
        paths
            .iter()
            .copied()
            .map(|path| Output { path: path.to_owned(), out_type: None, remote: None })
            .collect()
    }

    /// Runs `resolve_template` with no stage param keys.
    fn render(
        cmd: &str,
        params: &HashMap<String, serde_yaml_ng::Value>,
        deps: &[Dependency],
        outs: &[Output],
    ) -> Result<String> {
        resolve_template(cmd, params, &None, deps, outs)
    }

    /// Renders `cmd` with no params/deps/outs and returns the error text; panics on success.
    fn render_err(cmd: &str) -> String {
        render(cmd, &HashMap::new(), &[], &[]).unwrap_err().to_string()
    }

    #[test]
    fn test_PB001_param_substitution() {
        let params = make_params(&[("model", "whisper-base")]);
        let rendered =
            render("run --model {{params.model}}", &params, &[], &[]).expect("unexpected failure");
        assert_eq!(rendered, "run --model whisper-base");
    }

    #[test]
    fn test_PB001_numeric_param_substitution() {
        let chunk_size = serde_yaml_ng::Value::Number(serde_yaml_ng::Number::from(512));
        let params: HashMap<String, serde_yaml_ng::Value> =
            std::iter::once(("chunk_size".to_string(), chunk_size)).collect();
        let rendered = render("split --size {{params.chunk_size}}", &params, &[], &[])
            .expect("unexpected failure");
        assert_eq!(rendered, "split --size 512");
    }

    #[test]
    fn test_PB001_deps_path_ref() {
        let deps = make_deps(&["/data/input.wav", "/data/config.json"]);
        let template = "cat {{deps[0].path}} {{deps[1].path}}";
        let rendered =
            render(template, &HashMap::new(), &deps, &[]).expect("unexpected failure");
        assert_eq!(rendered, "cat /data/input.wav /data/config.json");
    }

    #[test]
    fn test_PB001_outs_path_ref() {
        let outs = make_outs(&["/tmp/output.txt"]);
        let template = "echo hello > {{outs[0].path}}";
        let rendered =
            render(template, &HashMap::new(), &[], &outs).expect("unexpected failure");
        assert_eq!(rendered, "echo hello > /tmp/output.txt");
    }

    #[test]
    fn test_PB001_multiple_substitutions() {
        let params = make_params(&[("model", "base"), ("lang", "en")]);
        let deps = make_deps(&["/input.wav"]);
        let outs = make_outs(&["/output.txt"]);
        let template = "transcribe --model {{params.model}} --lang {{params.lang}} \
                        {{deps[0].path}} > {{outs[0].path}}";
        let rendered = render(template, &params, &deps, &outs).expect("unexpected failure");
        assert_eq!(rendered, "transcribe --model base --lang en /input.wav > /output.txt");
    }

    #[test]
    fn test_PB001_no_templates() {
        let rendered =
            render("echo hello world", &HashMap::new(), &[], &[]).expect("unexpected failure");
        assert_eq!(rendered, "echo hello world");
    }

    #[test]
    fn test_PB001_missing_param_error() {
        let message = render_err("echo {{params.missing}}");
        assert!(message.contains("undefined param"), "got: {}", message);
    }

    #[test]
    fn test_PB001_deps_out_of_range() {
        let message = render_err("cat {{deps[5].path}}");
        assert!(message.contains("out of range"), "got: {}", message);
    }

    #[test]
    fn test_PB001_outs_out_of_range() {
        let message = render_err("cat {{outs[0].path}}");
        assert!(message.contains("out of range"), "got: {}", message);
    }

    #[test]
    fn test_PB001_unclosed_template() {
        let message = render_err("echo {{params.model");
        assert!(message.contains("unclosed"), "got: {}", message);
    }

    #[test]
    fn test_PB001_whitespace_in_template() {
        let params = make_params(&[("name", "world")]);
        let rendered =
            render("echo {{ params.name }}", &params, &[], &[]).expect("unexpected failure");
        assert_eq!(rendered, "echo world");
    }

    #[test]
    fn test_PB001_unicode_safe() {
        let params = make_params(&[("name", "héllo")]);
        let rendered = render("echo {{params.name}} — résumé", &params, &[], &[])
            .expect("unexpected failure");
        assert_eq!(rendered, "echo héllo — résumé");
    }
}