Skip to main content

batuta/playbook/
template.rs

1//! Template resolution for playbook commands (PB-001)
2//!
3//! Handles `{{params.key}}`, `{{deps[N].path}}`, `{{outs[N].path}}` substitution.
4
5use super::types::{yaml_value_to_string, Dependency, Output};
6use anyhow::{bail, Result};
7use std::collections::HashMap;
8
9/// Resolve all template variables in a command string
10///
11/// Uses UTF-8-safe string scanning (no byte-level char casting).
12pub fn resolve_template(
13    cmd: &str,
14    global_params: &HashMap<String, serde_yaml_ng::Value>,
15    _stage_param_keys: &Option<Vec<String>>,
16    deps: &[Dependency],
17    outs: &[Output],
18) -> Result<String> {
19    let mut result = String::with_capacity(cmd.len());
20    let mut pos = 0;
21
22    while pos < cmd.len() {
23        if cmd[pos..].starts_with("{{") {
24            let start = pos + 2;
25            if let Some(end_offset) = cmd[start..].find("}}") {
26                let ref_str = cmd[start..start + end_offset].trim();
27                let replacement = resolve_ref(ref_str, global_params, deps, outs)?;
28                result.push_str(&replacement);
29                pos = start + end_offset + 2;
30            } else {
31                bail!("unclosed template expression at position {}", pos);
32            }
33        } else {
34            // UTF-8-safe: advance by one character
35            let ch = cmd[pos..].chars().next().expect("iterator empty");
36            result.push(ch);
37            pos += ch.len_utf8();
38        }
39    }
40
41    Ok(result)
42}
43
44fn resolve_ref(
45    ref_str: &str,
46    global_params: &HashMap<String, serde_yaml_ng::Value>,
47    deps: &[Dependency],
48    outs: &[Output],
49) -> Result<String> {
50    // {{params.key}} — resolved from global params
51    if let Some(key) = ref_str.strip_prefix("params.") {
52        if let Some(val) = global_params.get(key) {
53            return Ok(yaml_value_to_string(val));
54        }
55        bail!("undefined param '{}'", key);
56    }
57
58    // {{deps[N].path}}
59    if let Some(idx_str) = ref_str.strip_prefix("deps[").and_then(|s| s.strip_suffix("].path")) {
60        let idx: usize =
61            idx_str.parse().map_err(|_| anyhow::anyhow!("invalid deps index '{}'", idx_str))?;
62        if idx >= deps.len() {
63            bail!("deps[{}] out of range (only {} deps)", idx, deps.len());
64        }
65        return Ok(deps[idx].path.clone());
66    }
67
68    // {{outs[N].path}}
69    if let Some(idx_str) = ref_str.strip_prefix("outs[").and_then(|s| s.strip_suffix("].path")) {
70        let idx: usize =
71            idx_str.parse().map_err(|_| anyhow::anyhow!("invalid outs index '{}'", idx_str))?;
72        if idx >= outs.len() {
73            bail!("outs[{}] out of range (only {} outs)", idx, outs.len());
74        }
75        return Ok(outs[idx].path.clone());
76    }
77
78    bail!("unknown template reference '{}'", ref_str);
79}
80
#[cfg(test)]
// Test names embed the upper-case requirement id (PB001), hence the lint exemption.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    /// Builds a global-params map in which every value is a YAML string.
    fn string_params(entries: &[(&str, &str)]) -> HashMap<String, serde_yaml_ng::Value> {
        let mut params = HashMap::new();
        for (key, value) in entries {
            params.insert((*key).to_string(), serde_yaml_ng::Value::String((*value).to_string()));
        }
        params
    }

    /// Builds dependencies that carry only a path.
    fn deps_from(paths: &[&str]) -> Vec<Dependency> {
        let mut deps = Vec::with_capacity(paths.len());
        for path in paths {
            deps.push(Dependency { path: (*path).to_string(), dep_type: None });
        }
        deps
    }

    /// Builds outputs that carry only a path.
    fn outs_from(paths: &[&str]) -> Vec<Output> {
        let mut outs = Vec::with_capacity(paths.len());
        for path in paths {
            outs.push(Output { path: (*path).to_string(), out_type: None, remote: None });
        }
        outs
    }

    /// Calls `resolve_template` with no stage-level param keys.
    fn render(
        cmd: &str,
        params: &HashMap<String, serde_yaml_ng::Value>,
        deps: &[Dependency],
        outs: &[Output],
    ) -> anyhow::Result<String> {
        resolve_template(cmd, params, &None, deps, outs)
    }

    #[test]
    fn test_PB001_param_substitution() {
        let params = string_params(&[("model", "whisper-base")]);
        let rendered = render("run --model {{params.model}}", &params, &[], &[])
            .expect("unexpected failure");
        assert_eq!(rendered, "run --model whisper-base");
    }

    #[test]
    fn test_PB001_numeric_param_substitution() {
        let chunk_size = serde_yaml_ng::Value::Number(serde_yaml_ng::Number::from(512));
        let mut params = HashMap::new();
        params.insert("chunk_size".to_string(), chunk_size);

        let rendered = render("split --size {{params.chunk_size}}", &params, &[], &[])
            .expect("unexpected failure");
        assert_eq!(rendered, "split --size 512");
    }

    #[test]
    fn test_PB001_deps_path_ref() {
        let deps = deps_from(&["/data/input.wav", "/data/config.json"]);
        let rendered = render("cat {{deps[0].path}} {{deps[1].path}}", &HashMap::new(), &deps, &[])
            .expect("unexpected failure");
        assert_eq!(rendered, "cat /data/input.wav /data/config.json");
    }

    #[test]
    fn test_PB001_outs_path_ref() {
        let outs = outs_from(&["/tmp/output.txt"]);
        let rendered = render("echo hello > {{outs[0].path}}", &HashMap::new(), &[], &outs)
            .expect("unexpected failure");
        assert_eq!(rendered, "echo hello > /tmp/output.txt");
    }

    #[test]
    fn test_PB001_multiple_substitutions() {
        let params = string_params(&[("model", "base"), ("lang", "en")]);
        let deps = deps_from(&["/input.wav"]);
        let outs = outs_from(&["/output.txt"]);
        let cmd = "transcribe --model {{params.model}} --lang {{params.lang}} {{deps[0].path}} > {{outs[0].path}}";

        let rendered = render(cmd, &params, &deps, &outs).expect("unexpected failure");
        assert_eq!(rendered, "transcribe --model base --lang en /input.wav > /output.txt");
    }

    #[test]
    fn test_PB001_no_templates() {
        let rendered =
            render("echo hello world", &HashMap::new(), &[], &[]).expect("unexpected failure");
        assert_eq!(rendered, "echo hello world");
    }

    #[test]
    fn test_PB001_missing_param_error() {
        let message =
            render("echo {{params.missing}}", &HashMap::new(), &[], &[]).unwrap_err().to_string();
        assert!(message.contains("undefined param"));
    }

    #[test]
    fn test_PB001_deps_out_of_range() {
        let message =
            render("cat {{deps[5].path}}", &HashMap::new(), &[], &[]).unwrap_err().to_string();
        assert!(message.contains("out of range"));
    }

    #[test]
    fn test_PB001_outs_out_of_range() {
        let message =
            render("cat {{outs[0].path}}", &HashMap::new(), &[], &[]).unwrap_err().to_string();
        assert!(message.contains("out of range"));
    }

    #[test]
    fn test_PB001_unclosed_template() {
        let message =
            render("echo {{params.model", &HashMap::new(), &[], &[]).unwrap_err().to_string();
        assert!(message.contains("unclosed"));
    }

    #[test]
    fn test_PB001_whitespace_in_template() {
        let params = string_params(&[("name", "world")]);
        let rendered =
            render("echo {{ params.name }}", &params, &[], &[]).expect("unexpected failure");
        assert_eq!(rendered, "echo world");
    }

    #[test]
    fn test_PB001_unicode_safe() {
        let params = string_params(&[("name", "héllo")]);
        let rendered = render("echo {{params.name}} — résumé", &params, &[], &[])
            .expect("unexpected failure");
        assert_eq!(rendered, "echo héllo — résumé");
    }
}
207}