Skip to main content

harn_vm/stdlib/template/
lint.rs

1//! AST surface that `harn-lint` consumes to enforce `.harn.prompt`
2//! drift-prevention rules (#1669).
3//!
4//! The template parser and AST are otherwise internal — exposing a
5//! shallow read-only view through this module keeps the lint crate
6//! free of template-engine internals while still giving rules enough
7//! structure to walk conditionals, sections, and includes.
8
9use super::ast::{BinOp, Expr, Node, PathSeg};
10use super::parser::parse as parse_template;
11use crate::runtime_limits::RuntimeLimits;
12
/// Deepest nesting level `walk_node` accepts before it returns a depth
/// error. Taken directly from the default runtime limits.
const TEMPLATE_LINT_AST_MAX_DEPTH: usize = RuntimeLimits::DEFAULT.max_template_ast_depth;
14
15/// Parse a template source string into a flat list of lintable
16/// constructs (conditionals + sections). Returns `Err` when the
17/// template doesn't parse — callers should surface the underlying
18/// `validate_template_syntax` error to the user before linting.
19pub fn parse(src: &str) -> Result<Vec<LintConstruct>, String> {
20    let nodes = parse_template(src).map_err(|error| error.message())?;
21    let mut out = Vec::new();
22    walk_nodes(&nodes, &mut out, 0)?;
23    Ok(out)
24}
25
/// One lintable construct, materialized in source order so rules can
/// reason about counts (e.g. branch-explosion) and individual call
/// sites (e.g. provider-identity comparisons).
#[derive(Debug, Clone)]
pub enum LintConstruct {
    /// An `{{ if .. }}` / `{{ elif }}` chain. One entry per condition
    /// in the chain (the trailing `{{ else }}` is implicit and not
    /// listed). Conditions are flattened across `elif` to make
    /// branch-count rules straightforward. `branches` keeps the order
    /// in which the parser lists the chain's branches.
    IfChain { branches: Vec<IfBranch> },
    /// A `{{ section "..." }}` block. Sections are themselves
    /// capability-adaptive but never look identity-driven; rules use
    /// this to count capability-aware partials.
    Section {
        /// Section name, cloned from the parsed section node.
        name: String,
        /// Line recorded on the parsed section node.
        line: usize,
        /// Column recorded on the parsed section node.
        col: usize,
    },
}
45
/// One `if` / `elif` arm of an [`LintConstruct::IfChain`]: where the
/// arm sits in the template and how its condition was classified.
#[derive(Debug, Clone)]
pub struct IfBranch {
    /// Line copied from the parsed branch.
    pub line: usize,
    /// Column copied from the parsed branch.
    pub col: usize,
    /// Coarse shape of the branch condition, as lint rules see it.
    pub condition: ConditionShape,
}
52
/// Coarse classification of an `{{ if expr }}` condition. The lint
/// rules don't need to evaluate or fully reconstruct expressions —
/// just enough structure to detect the two failure patterns called
/// out in #1669:
///
/// - Identity comparisons (`llm.provider == "..."`).
/// - Capability-flag branches (`llm.capabilities.<flag>`), which the
///   variant-explosion rule counts.
///
/// Conditions outside these shapes resolve to `Other` and don't
/// participate in either rule.
#[derive(Debug, Clone)]
pub enum ConditionShape {
    /// `llm.provider == "..."` / `llm.model == "..."` /
    /// `llm.family == "..."` (or `!=`).
    ProviderIdentity(IdentityField),
    /// Any path-based condition mentioning `llm.capabilities.<flag>`
    /// (including negation and use as a comparison operand). The
    /// variant-explosion rule counts every branch with this shape.
    /// Source position lives on the surrounding [`IfBranch`].
    CapabilityFlag {
        /// The path segment that follows `llm.capabilities`. When a
        /// condition mentions several capability paths, only the first
        /// one found is kept.
        flag: String,
    },
    /// Neither an identity comparison nor a capability-flag condition.
    Other,
}
78
/// Which `llm.*` identity field an identity comparison is made against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IdentityField {
    /// The `llm.provider` field.
    Provider,
    /// The `llm.model` field.
    Model,
    /// The `llm.family` field.
    Family,
}

impl IdentityField {
    /// Lowercase field name, i.e. the segment written after `llm.` in a
    /// template path.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Provider => "provider",
            Self::Model => "model",
            Self::Family => "family",
        }
    }
}
95
96fn walk_nodes(nodes: &[Node], out: &mut Vec<LintConstruct>, depth: usize) -> Result<(), String> {
97    for node in nodes {
98        walk_node(node, out, depth)?;
99    }
100    Ok(())
101}
102
103fn walk_node(node: &Node, out: &mut Vec<LintConstruct>, depth: usize) -> Result<(), String> {
104    if depth > TEMPLATE_LINT_AST_MAX_DEPTH {
105        return Err(lint_depth_error(node));
106    }
107
108    match node {
109        Node::Text(_) | Node::Expr { .. } | Node::LegacyBareInterp { .. } => {}
110        Node::If {
111            branches,
112            else_branch,
113            line: _,
114            col: _,
115        } => {
116            let mut summary = Vec::with_capacity(branches.len());
117            for branch in branches {
118                summary.push(IfBranch {
119                    line: branch.line,
120                    col: branch.col,
121                    condition: classify_condition(&branch.cond),
122                });
123                walk_nodes(&branch.body, out, depth + 1)?;
124            }
125            out.push(LintConstruct::IfChain { branches: summary });
126            if let Some(else_body) = else_branch {
127                walk_nodes(else_body, out, depth + 1)?;
128            }
129        }
130        Node::For { body, empty, .. } => {
131            walk_nodes(body, out, depth + 1)?;
132            if let Some(empty) = empty {
133                walk_nodes(empty, out, depth + 1)?;
134            }
135        }
136        Node::Include { .. } => {
137            // Include resolution happens at render time. Linting only
138            // walks the calling template; the included partial gets
139            // linted independently when the linter encounters it.
140        }
141        Node::Section {
142            name,
143            body,
144            line,
145            col,
146            ..
147        } => {
148            out.push(LintConstruct::Section {
149                name: name.clone(),
150                line: *line,
151                col: *col,
152            });
153            walk_nodes(body, out, depth + 1)?;
154        }
155    }
156    Ok(())
157}
158
159fn lint_depth_error(node: &Node) -> String {
160    let prefix = format!("template lint AST depth exceeded ({TEMPLATE_LINT_AST_MAX_DEPTH} levels)");
161    match node_location(node) {
162        Some((line, col)) => format!("{prefix} at {line}:{col}"),
163        None => prefix,
164    }
165}
166
167fn node_location(node: &Node) -> Option<(usize, usize)> {
168    match node {
169        Node::Expr { line, col, .. }
170        | Node::If { line, col, .. }
171        | Node::For { line, col, .. }
172        | Node::Include { line, col, .. }
173        | Node::Section { line, col, .. } => Some((*line, *col)),
174        Node::Text(_) | Node::LegacyBareInterp { .. } => None,
175    }
176}
177
178/// Classify the top-level shape of an `{{ if expr }}` condition.
179fn classify_condition(expr: &Expr) -> ConditionShape {
180    if let Some(identity) = match_identity_compare(expr) {
181        return ConditionShape::ProviderIdentity(identity);
182    }
183    if let Some(capability) = match_capability_path(expr) {
184        return capability;
185    }
186    ConditionShape::Other
187}
188
189/// Match `llm.<provider|model|family> == "..."` or `!= "..."`,
190/// returning the LHS identity field that was compared.
191fn match_identity_compare(expr: &Expr) -> Option<IdentityField> {
192    let Expr::Binary(op, lhs, rhs) = expr else {
193        return None;
194    };
195    if !matches!(op, BinOp::Eq | BinOp::Neq) {
196        return None;
197    }
198    let path = match (lhs.as_ref(), rhs.as_ref()) {
199        (Expr::Path(p), Expr::Str(_)) | (Expr::Str(_), Expr::Path(p)) => p,
200        _ => return None,
201    };
202    if !path_starts_with_llm(path) {
203        return None;
204    }
205    match path.get(1) {
206        Some(PathSeg::Field(name) | PathSeg::Key(name)) if name == "provider" => {
207            Some(IdentityField::Provider)
208        }
209        Some(PathSeg::Field(name) | PathSeg::Key(name)) if name == "model" => {
210            Some(IdentityField::Model)
211        }
212        Some(PathSeg::Field(name) | PathSeg::Key(name)) if name == "family" => {
213            Some(IdentityField::Family)
214        }
215        _ => None,
216    }
217}
218
219/// Match `llm.capabilities.<flag>` (possibly negated by `!`) or
220/// `llm.capabilities.<flag> == <literal>`, returning the flag name.
221fn match_capability_path(expr: &Expr) -> Option<ConditionShape> {
222    fn find_capability_path(expr: &Expr) -> Option<String> {
223        let mut stack = vec![expr];
224        while let Some(expr) = stack.pop() {
225            match expr {
226                Expr::Path(path) => {
227                    if let Some(flag) = capability_flag_from_path(path) {
228                        return Some(flag);
229                    }
230                }
231                Expr::Unary(_, inner) => stack.push(inner),
232                Expr::Binary(_, lhs, rhs) => {
233                    stack.push(rhs);
234                    stack.push(lhs);
235                }
236                Expr::Filter(inner, _, _) => stack.push(inner),
237                _ => {}
238            }
239        }
240        None
241    }
242    let flag = find_capability_path(expr)?;
243    Some(ConditionShape::CapabilityFlag { flag })
244}
245
246fn capability_flag_from_path(path: &[PathSeg]) -> Option<String> {
247    if !path_starts_with_llm(path) {
248        return None;
249    }
250    let Some(PathSeg::Field(name) | PathSeg::Key(name)) = path.get(1) else {
251        return None;
252    };
253    if name != "capabilities" {
254        return None;
255    }
256    let Some(PathSeg::Field(flag) | PathSeg::Key(flag)) = path.get(2) else {
257        return None;
258    };
259    Some(flag.clone())
260}
261
262fn path_starts_with_llm(path: &[PathSeg]) -> bool {
263    matches!(
264        path.first(),
265        Some(PathSeg::Field(name)) if name == "llm",
266    )
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, panicking if the template is rejected.
    fn parse_ok(src: &str) -> Vec<LintConstruct> {
        parse(src).expect("template should parse")
    }

    /// Branches of the first `IfChain` in `constructs`; panics when
    /// there is none.
    fn first_if(constructs: &[LintConstruct]) -> &[IfBranch] {
        match constructs
            .iter()
            .find(|c| matches!(c, LintConstruct::IfChain { .. }))
            .expect("if chain present")
        {
            LintConstruct::IfChain { branches } => branches.as_slice(),
            _ => unreachable!(),
        }
    }

    // `llm.provider == "..."` is classified as a provider identity
    // comparison, and the `else` arm adds no branch entry.
    #[test]
    fn provider_identity_eq_detected() {
        let constructs = parse_ok("{{ if llm.provider == \"anthropic\" }}x{{ else }}y{{ end }}");
        let branches = first_if(&constructs);
        assert_eq!(branches.len(), 1);
        assert!(matches!(
            branches[0].condition,
            ConditionShape::ProviderIdentity(IdentityField::Provider)
        ));
    }

    // `!=` against `llm.model` is an identity comparison too.
    #[test]
    fn model_identity_neq_detected() {
        let constructs = parse_ok("{{ if llm.model != \"gpt-5\" }}x{{ end }}");
        let branches = first_if(&constructs);
        assert!(matches!(
            branches[0].condition,
            ConditionShape::ProviderIdentity(IdentityField::Model)
        ));
    }

    // A capability path is found both under `!` and as a filter input.
    #[test]
    fn capability_flag_detected_in_negation_and_filter() {
        let constructs = parse_ok(
            "{{ if !llm.capabilities.native_tools }}x{{ end }}\
             {{ if llm.capabilities.prefers_xml_scaffolding | default: false }}y{{ end }}",
        );
        let if_chains: Vec<_> = constructs
            .iter()
            .filter_map(|c| match c {
                LintConstruct::IfChain { branches } => Some(branches.clone()),
                _ => None,
            })
            .collect();
        assert_eq!(if_chains.len(), 2);
        assert!(matches!(
            if_chains[0][0].condition,
            ConditionShape::CapabilityFlag { ref flag, .. } if flag == "native_tools"
        ));
        assert!(matches!(
            if_chains[1][0].condition,
            ConditionShape::CapabilityFlag { ref flag, .. } if flag == "prefers_xml_scaffolding"
        ));
    }

    // 300 unrelated paths joined with `or`, with the capability path as
    // the final term: the search must still reach it.
    #[test]
    fn capability_flag_detection_handles_wide_binary_expression() {
        let mut terms = (0..300).map(|idx| format!("flag{idx}")).collect::<Vec<_>>();
        terms.push("llm.capabilities.native_tools".to_string());
        let src = format!("{{{{ if {} }}}}x{{{{ end }}}}", terms.join(" or "));

        let constructs = parse_ok(&src);
        let branches = first_if(&constructs);

        assert!(matches!(
            branches[0].condition,
            ConditionShape::CapabilityFlag { ref flag, .. } if flag == "native_tools"
        ));
    }

    // Nests one more `if` than the configured limit. The asserted text
    // differs from the message built by `lint_depth_error`, so the
    // error is presumably raised by the template parser before the lint
    // walker runs — confirm against the parser.
    #[test]
    fn parse_reports_template_control_depth_limit() {
        let depth = RuntimeLimits::DEFAULT.max_template_ast_depth + 1;
        let mut src = String::new();
        for _ in 0..depth {
            src.push_str("{{ if true }}");
        }
        src.push('x');
        for _ in 0..depth {
            src.push_str("{{ end }}");
        }

        let err = parse(&src).expect_err("depth limit");

        assert!(err.contains("template nesting depth exceeded"));
        assert!(err.contains(&format!(
            "({} levels)",
            RuntimeLimits::DEFAULT.max_template_ast_depth
        )));
    }

    // Prefixes a capability path with one more `!` than the configured
    // limit and expects an expression-depth error from `parse`.
    #[test]
    fn parse_reports_template_expression_depth_limit() {
        let depth = RuntimeLimits::DEFAULT.max_template_ast_depth + 1;
        let condition = format!("{}llm.capabilities.native_tools", "!".repeat(depth));
        let src = format!("{{{{ if {condition} }}}}x{{{{ end }}}}");

        let err = parse(&src).expect_err("depth limit");

        assert!(err.contains("template expression depth exceeded"));
        assert!(err.contains(&format!(
            "({} levels)",
            RuntimeLimits::DEFAULT.max_template_ast_depth
        )));
    }

    // Each `if` / `elif` condition gets its own entry and its own
    // classification; `else` is not listed.
    #[test]
    fn elif_chain_lifts_per_branch_condition() {
        let constructs = parse_ok(
            "{{ if llm.provider == \"openai\" }}a\
             {{ elif llm.capabilities.native_tools }}b\
             {{ else }}c{{ end }}",
        );
        let branches = first_if(&constructs);
        assert_eq!(branches.len(), 2);
        assert!(matches!(
            branches[0].condition,
            ConditionShape::ProviderIdentity(IdentityField::Provider)
        ));
        assert!(matches!(
            branches[1].condition,
            ConditionShape::CapabilityFlag { ref flag, .. } if flag == "native_tools"
        ));
    }

    // A condition that never mentions `llm` classifies as `Other`.
    #[test]
    fn unrelated_condition_falls_through_to_other() {
        let constructs = parse_ok("{{ if score > 0.5 }}a{{ end }}");
        let branches = first_if(&constructs);
        assert!(matches!(branches[0].condition, ConditionShape::Other));
    }

    // Sibling sections come out in the order they appear in the source.
    #[test]
    fn sections_listed_in_source_order() {
        let constructs = parse_ok(
            "{{ section \"task\" }}t{{ endsection }}\
             {{ section \"output_format\" }}o{{ endsection }}",
        );
        let names: Vec<_> = constructs
            .iter()
            .filter_map(|c| match c {
                LintConstruct::Section { name, .. } => Some(name.clone()),
                _ => None,
            })
            .collect();
        assert_eq!(names, vec!["task", "output_format"]);
    }
}