1use super::ast::{BinOp, Expr, Node, PathSeg};
10use super::parser::parse as parse_template;
11use crate::runtime_limits::RuntimeLimits;
12
13const TEMPLATE_LINT_AST_MAX_DEPTH: usize = RuntimeLimits::DEFAULT.max_template_ast_depth;
14
15pub fn parse(src: &str) -> Result<Vec<LintConstruct>, String> {
20 let nodes = parse_template(src).map_err(|error| error.message())?;
21 let mut out = Vec::new();
22 walk_nodes(&nodes, &mut out, 0)?;
23 Ok(out)
24}
25
26#[derive(Debug, Clone)]
30pub enum LintConstruct {
31 IfChain { branches: Vec<IfBranch> },
36 Section {
40 name: String,
41 line: usize,
42 col: usize,
43 },
44}
45
46#[derive(Debug, Clone)]
47pub struct IfBranch {
48 pub line: usize,
49 pub col: usize,
50 pub condition: ConditionShape,
51}
52
53#[derive(Debug, Clone)]
65pub enum ConditionShape {
66 ProviderIdentity(IdentityField),
69 CapabilityFlag {
74 flag: String,
75 },
76 Other,
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq)]
80pub enum IdentityField {
81 Provider,
82 Model,
83 Family,
84}
85
86impl IdentityField {
87 pub fn as_str(self) -> &'static str {
88 match self {
89 IdentityField::Provider => "provider",
90 IdentityField::Model => "model",
91 IdentityField::Family => "family",
92 }
93 }
94}
95
96fn walk_nodes(nodes: &[Node], out: &mut Vec<LintConstruct>, depth: usize) -> Result<(), String> {
97 for node in nodes {
98 walk_node(node, out, depth)?;
99 }
100 Ok(())
101}
102
103fn walk_node(node: &Node, out: &mut Vec<LintConstruct>, depth: usize) -> Result<(), String> {
104 if depth > TEMPLATE_LINT_AST_MAX_DEPTH {
105 return Err(lint_depth_error(node));
106 }
107
108 match node {
109 Node::Text(_) | Node::Expr { .. } | Node::LegacyBareInterp { .. } => {}
110 Node::If {
111 branches,
112 else_branch,
113 line: _,
114 col: _,
115 } => {
116 let mut summary = Vec::with_capacity(branches.len());
117 for branch in branches {
118 summary.push(IfBranch {
119 line: branch.line,
120 col: branch.col,
121 condition: classify_condition(&branch.cond),
122 });
123 walk_nodes(&branch.body, out, depth + 1)?;
124 }
125 out.push(LintConstruct::IfChain { branches: summary });
126 if let Some(else_body) = else_branch {
127 walk_nodes(else_body, out, depth + 1)?;
128 }
129 }
130 Node::For { body, empty, .. } => {
131 walk_nodes(body, out, depth + 1)?;
132 if let Some(empty) = empty {
133 walk_nodes(empty, out, depth + 1)?;
134 }
135 }
136 Node::Include { .. } => {
137 }
141 Node::Section {
142 name,
143 body,
144 line,
145 col,
146 ..
147 } => {
148 out.push(LintConstruct::Section {
149 name: name.clone(),
150 line: *line,
151 col: *col,
152 });
153 walk_nodes(body, out, depth + 1)?;
154 }
155 }
156 Ok(())
157}
158
159fn lint_depth_error(node: &Node) -> String {
160 let prefix = format!("template lint AST depth exceeded ({TEMPLATE_LINT_AST_MAX_DEPTH} levels)");
161 match node_location(node) {
162 Some((line, col)) => format!("{prefix} at {line}:{col}"),
163 None => prefix,
164 }
165}
166
167fn node_location(node: &Node) -> Option<(usize, usize)> {
168 match node {
169 Node::Expr { line, col, .. }
170 | Node::If { line, col, .. }
171 | Node::For { line, col, .. }
172 | Node::Include { line, col, .. }
173 | Node::Section { line, col, .. } => Some((*line, *col)),
174 Node::Text(_) | Node::LegacyBareInterp { .. } => None,
175 }
176}
177
178fn classify_condition(expr: &Expr) -> ConditionShape {
180 if let Some(identity) = match_identity_compare(expr) {
181 return ConditionShape::ProviderIdentity(identity);
182 }
183 if let Some(capability) = match_capability_path(expr) {
184 return capability;
185 }
186 ConditionShape::Other
187}
188
189fn match_identity_compare(expr: &Expr) -> Option<IdentityField> {
192 let Expr::Binary(op, lhs, rhs) = expr else {
193 return None;
194 };
195 if !matches!(op, BinOp::Eq | BinOp::Neq) {
196 return None;
197 }
198 let path = match (lhs.as_ref(), rhs.as_ref()) {
199 (Expr::Path(p), Expr::Str(_)) | (Expr::Str(_), Expr::Path(p)) => p,
200 _ => return None,
201 };
202 if !path_starts_with_llm(path) {
203 return None;
204 }
205 match path.get(1) {
206 Some(PathSeg::Field(name) | PathSeg::Key(name)) if name == "provider" => {
207 Some(IdentityField::Provider)
208 }
209 Some(PathSeg::Field(name) | PathSeg::Key(name)) if name == "model" => {
210 Some(IdentityField::Model)
211 }
212 Some(PathSeg::Field(name) | PathSeg::Key(name)) if name == "family" => {
213 Some(IdentityField::Family)
214 }
215 _ => None,
216 }
217}
218
219fn match_capability_path(expr: &Expr) -> Option<ConditionShape> {
222 fn find_capability_path(expr: &Expr) -> Option<String> {
223 let mut stack = vec![expr];
224 while let Some(expr) = stack.pop() {
225 match expr {
226 Expr::Path(path) => {
227 if let Some(flag) = capability_flag_from_path(path) {
228 return Some(flag);
229 }
230 }
231 Expr::Unary(_, inner) => stack.push(inner),
232 Expr::Binary(_, lhs, rhs) => {
233 stack.push(rhs);
234 stack.push(lhs);
235 }
236 Expr::Filter(inner, _, _) => stack.push(inner),
237 _ => {}
238 }
239 }
240 None
241 }
242 let flag = find_capability_path(expr)?;
243 Some(ConditionShape::CapabilityFlag { flag })
244}
245
246fn capability_flag_from_path(path: &[PathSeg]) -> Option<String> {
247 if !path_starts_with_llm(path) {
248 return None;
249 }
250 let Some(PathSeg::Field(name) | PathSeg::Key(name)) = path.get(1) else {
251 return None;
252 };
253 if name != "capabilities" {
254 return None;
255 }
256 let Some(PathSeg::Field(flag) | PathSeg::Key(flag)) = path.get(2) else {
257 return None;
258 };
259 Some(flag.clone())
260}
261
262fn path_starts_with_llm(path: &[PathSeg]) -> bool {
263 matches!(
264 path.first(),
265 Some(PathSeg::Field(name)) if name == "llm",
266 )
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Lints `src`, failing the test if parsing or walking returns an error.
    fn parse_ok(src: &str) -> Vec<LintConstruct> {
        parse(src).expect("template should parse")
    }

    /// Branch summaries of every `IfChain` in `constructs`, in emission order.
    fn if_chains(constructs: &[LintConstruct]) -> Vec<&[IfBranch]> {
        constructs
            .iter()
            .filter_map(|construct| match construct {
                LintConstruct::IfChain { branches } => Some(branches.as_slice()),
                _ => None,
            })
            .collect()
    }

    /// Branch summaries of the first `IfChain` in `constructs`.
    fn first_if(constructs: &[LintConstruct]) -> &[IfBranch] {
        if_chains(constructs)
            .into_iter()
            .next()
            .expect("if chain present")
    }

    /// True when `branch` was classified as an identity comparison on `expected`.
    fn is_identity(branch: &IfBranch, expected: IdentityField) -> bool {
        matches!(&branch.condition, ConditionShape::ProviderIdentity(field) if *field == expected)
    }

    /// True when `branch` was classified as referencing capability flag `expected`.
    fn is_capability(branch: &IfBranch, expected: &str) -> bool {
        matches!(&branch.condition, ConditionShape::CapabilityFlag { flag, .. } if flag == expected)
    }

    /// Asserts that `err` is a depth-limit error of the given kind naming the default limit.
    fn assert_depth_error(err: &str, kind: &str) {
        let limit = RuntimeLimits::DEFAULT.max_template_ast_depth;
        assert!(err.contains(kind));
        assert!(err.contains(&format!("({} levels)", limit)));
    }

    #[test]
    fn provider_identity_eq_detected() {
        let constructs = parse_ok("{{ if llm.provider == \"anthropic\" }}x{{ else }}y{{ end }}");
        let branches = first_if(&constructs);
        assert_eq!(branches.len(), 1);
        assert!(is_identity(&branches[0], IdentityField::Provider));
    }

    #[test]
    fn model_identity_neq_detected() {
        let constructs = parse_ok("{{ if llm.model != \"gpt-5\" }}x{{ end }}");
        assert!(is_identity(&first_if(&constructs)[0], IdentityField::Model));
    }

    #[test]
    fn capability_flag_detected_in_negation_and_filter() {
        let constructs = parse_ok(
            "{{ if !llm.capabilities.native_tools }}x{{ end }}\
             {{ if llm.capabilities.prefers_xml_scaffolding | default: false }}y{{ end }}",
        );
        let chains = if_chains(&constructs);
        assert_eq!(chains.len(), 2);
        assert!(is_capability(&chains[0][0], "native_tools"));
        assert!(is_capability(&chains[1][0], "prefers_xml_scaffolding"));
    }

    #[test]
    fn capability_flag_detection_handles_wide_binary_expression() {
        let condition = (0..300)
            .map(|idx| format!("flag{idx}"))
            .chain(std::iter::once(String::from("llm.capabilities.native_tools")))
            .collect::<Vec<_>>()
            .join(" or ");
        let src = format!("{{{{ if {condition} }}}}x{{{{ end }}}}");

        let constructs = parse_ok(&src);

        assert!(is_capability(&first_if(&constructs)[0], "native_tools"));
    }

    #[test]
    fn parse_reports_template_control_depth_limit() {
        let depth = RuntimeLimits::DEFAULT.max_template_ast_depth + 1;
        let src = format!(
            "{}x{}",
            "{{ if true }}".repeat(depth),
            "{{ end }}".repeat(depth)
        );

        let err = parse(&src).expect_err("depth limit");

        assert_depth_error(&err, "template nesting depth exceeded");
    }

    #[test]
    fn parse_reports_template_expression_depth_limit() {
        let depth = RuntimeLimits::DEFAULT.max_template_ast_depth + 1;
        let src = format!(
            "{{{{ if {}llm.capabilities.native_tools }}}}x{{{{ end }}}}",
            "!".repeat(depth)
        );

        let err = parse(&src).expect_err("depth limit");

        assert_depth_error(&err, "template expression depth exceeded");
    }

    #[test]
    fn elif_chain_lifts_per_branch_condition() {
        let constructs = parse_ok(
            "{{ if llm.provider == \"openai\" }}a\
             {{ elif llm.capabilities.native_tools }}b\
             {{ else }}c{{ end }}",
        );
        let branches = first_if(&constructs);
        assert_eq!(branches.len(), 2);
        assert!(is_identity(&branches[0], IdentityField::Provider));
        assert!(is_capability(&branches[1], "native_tools"));
    }

    #[test]
    fn unrelated_condition_falls_through_to_other() {
        let constructs = parse_ok("{{ if score > 0.5 }}a{{ end }}");
        assert!(matches!(
            first_if(&constructs)[0].condition,
            ConditionShape::Other
        ));
    }

    #[test]
    fn sections_listed_in_source_order() {
        let constructs = parse_ok(
            "{{ section \"task\" }}t{{ endsection }}\
             {{ section \"output_format\" }}o{{ endsection }}",
        );
        let names: Vec<&str> = constructs
            .iter()
            .filter_map(|construct| match construct {
                LintConstruct::Section { name, .. } => Some(name.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(names, ["task", "output_format"]);
    }
}