1use regex::Regex;
9use streaming_iterator::StreamingIterator;
10use tree_sitter::{Query, QueryCursor};
11
12use harn_hostlib::ast::{api, Language};
13
14use crate::error::RulesError;
15use crate::model::Constraint;
16use crate::pattern::compile_pattern;
17
/// A where-constraint in compiled, ready-to-evaluate form.
///
/// Built by [`CompiledConstraint::compile`] and checked against captured text
/// with [`CompiledConstraint::evaluate`].
pub struct CompiledConstraint {
    /// Name of the metavariable this constraint applies to, copied from the
    /// source `Constraint`.
    pub metavar: String,
    /// The concrete check (regex, comparison, or sub-pattern) to run.
    kind: Kind,
}
24
/// The check a [`CompiledConstraint`] performs on the captured text.
enum Kind {
    /// The text must match this regular expression (`Regex::is_match`).
    Regex(Regex),
    /// The text is compared against `value` using `op`; see `evaluate_comparison`.
    Comparison { op: CmpOp, value: toml::Value },
    /// The text is parsed as `language` and must yield at least one match
    /// for `query`.
    SubPattern { language: Language, query: Query },
}
30
/// Comparison operator of a `comparison` where-constraint.
#[derive(Clone, Copy)]
enum CmpOp {
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `==`
    Eq,
    /// `!=`
    Ne,
}
40
41impl CmpOp {
42 fn parse(op: &str) -> Option<Self> {
43 Some(match op {
44 "<" => CmpOp::Lt,
45 "<=" => CmpOp::Le,
46 ">" => CmpOp::Gt,
47 ">=" => CmpOp::Ge,
48 "==" => CmpOp::Eq,
49 "!=" => CmpOp::Ne,
50 _ => return None,
51 })
52 }
53}
54
55impl CompiledConstraint {
56 pub fn compile(
59 rule_id: &str,
60 default_language: Language,
61 constraint: &Constraint,
62 ) -> Result<Self, RulesError> {
63 let err = |message: String| RulesError::PatternCompile {
64 rule: rule_id.to_string(),
65 message,
66 };
67
68 let set = [
69 constraint.regex.is_some(),
70 constraint.comparison.is_some(),
71 constraint.pattern.is_some(),
72 ]
73 .into_iter()
74 .filter(|b| *b)
75 .count();
76 if set != 1 {
77 return Err(err(format!(
78 "where-constraint on `{}` must set exactly one of `regex` / `comparison` / `pattern`",
79 constraint.metavar
80 )));
81 }
82
83 let kind = if let Some(re) = &constraint.regex {
84 Kind::Regex(
85 Regex::new(re)
86 .map_err(|e| err(format!("constraint regex `{re}` is invalid: {e}")))?,
87 )
88 } else if let Some(cmp) = &constraint.comparison {
89 let op = CmpOp::parse(&cmp.op)
90 .ok_or_else(|| err(format!("unknown comparison operator `{}`", cmp.op)))?;
91 Kind::Comparison {
92 op,
93 value: cmp.value.clone(),
94 }
95 } else {
96 let snippet = constraint.pattern.as_ref().unwrap();
97 let language = match &constraint.language {
98 Some(name) => Language::from_name(name)
99 .ok_or_else(|| err(format!("unknown sub-pattern language `{name}`")))?,
100 None => default_language,
101 };
102 let ts_language = language
103 .ts_language()
104 .ok_or_else(|| err(format!("grammar for `{}` is unavailable", language.name())))?;
105 let compiled = compile_pattern(snippet, language)
106 .map_err(|m| err(format!("sub-pattern on `{}`: {m}", constraint.metavar)))?;
107 let query = Query::new(&ts_language, &compiled.query)
108 .map_err(|e| err(format!("sub-pattern query rejected: {e}")))?;
109 Kind::SubPattern { language, query }
110 };
111
112 Ok(CompiledConstraint {
113 metavar: constraint.metavar.clone(),
114 kind,
115 })
116 }
117
118 pub fn evaluate(&self, text: &str) -> bool {
120 match &self.kind {
121 Kind::Regex(re) => re.is_match(text),
122 Kind::Comparison { op, value } => evaluate_comparison(*op, text, value),
123 Kind::SubPattern { language, query } => {
124 let Ok(tree) = api::parse_tree(text, *language) else {
125 return false;
126 };
127 let mut cursor = QueryCursor::new();
128 let mut it = cursor.matches(query, tree.root_node(), text.as_bytes());
129 it.next().is_some()
130 }
131 }
132 }
133}
134
135fn evaluate_comparison(op: CmpOp, text: &str, value: &toml::Value) -> bool {
136 if let Some(rhs) = value
139 .as_float()
140 .or_else(|| value.as_integer().map(|i| i as f64))
141 {
142 if let Ok(lhs) = text.trim().parse::<f64>() {
143 return match op {
144 CmpOp::Lt => lhs < rhs,
145 CmpOp::Le => lhs <= rhs,
146 CmpOp::Gt => lhs > rhs,
147 CmpOp::Ge => lhs >= rhs,
148 CmpOp::Eq => (lhs - rhs).abs() < f64::EPSILON,
149 CmpOp::Ne => (lhs - rhs).abs() >= f64::EPSILON,
150 };
151 }
152 return matches!(op, CmpOp::Ne);
154 }
155
156 let rhs = match value {
157 toml::Value::String(s) => s.clone(),
158 toml::Value::Boolean(b) => b.to_string(),
159 other => other.to_string(),
160 };
161 match op {
162 CmpOp::Eq => text == rhs,
163 CmpOp::Ne => text != rhs,
164 CmpOp::Lt => text < rhs.as_str(),
166 CmpOp::Le => text <= rhs.as_str(),
167 CmpOp::Gt => text > rhs.as_str(),
168 CmpOp::Ge => text >= rhs.as_str(),
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Comparison;

    /// Compiles a regex-only constraint on `metavar`, panicking if it is rejected.
    fn regex_constraint(metavar: &str, re: &str) -> CompiledConstraint {
        let c = Constraint {
            metavar: metavar.into(),
            regex: Some(re.into()),
            comparison: None,
            pattern: None,
            language: None,
        };
        CompiledConstraint::compile("r", Language::Rust, &c).unwrap()
    }

    #[test]
    fn regex_constraint_matches() {
        let c = regex_constraint("KEY", "^[a-z][a-zA-Z]*$");
        assert!(c.evaluate("userId"));
        assert!(!c.evaluate("0bad"));
    }

    #[test]
    fn numeric_comparison() {
        let c = Constraint {
            metavar: "N".into(),
            regex: None,
            comparison: Some(Comparison {
                op: ">".into(),
                value: toml::Value::Integer(0),
            }),
            pattern: None,
            language: None,
        };
        let c = CompiledConstraint::compile("r", Language::Rust, &c).unwrap();
        assert!(c.evaluate("5"));
        assert!(!c.evaluate("0"));
        assert!(!c.evaluate("-3"));
    }

    #[test]
    fn string_equality_comparison() {
        let c = Constraint {
            metavar: "S".into(),
            regex: None,
            comparison: Some(Comparison {
                op: "!=".into(),
                value: toml::Value::String("nil".into()),
            }),
            pattern: None,
            language: None,
        };
        let c = CompiledConstraint::compile("r", Language::Rust, &c).unwrap();
        assert!(c.evaluate("something"));
        assert!(!c.evaluate("nil"));
    }

    #[test]
    fn sub_pattern_constraint() {
        let c = Constraint {
            metavar: "VALUE".into(),
            regex: None,
            comparison: None,
            pattern: Some("$FN($ARG)".into()),
            language: Some("typescript".into()),
        };
        let c = CompiledConstraint::compile("r", Language::TypeScript, &c).unwrap();
        assert!(c.evaluate("compute(x)"));
        assert!(!c.evaluate("42"));
    }

    #[test]
    fn rejects_zero_or_multiple_kinds() {
        // No kind set at all.
        let none = Constraint {
            metavar: "X".into(),
            regex: None,
            comparison: None,
            pattern: None,
            language: None,
        };
        assert!(CompiledConstraint::compile("r", Language::Rust, &none).is_err());

        // Two kinds set at once: both are individually valid, so the only
        // reason to reject is the "exactly one" rule.
        let multiple = Constraint {
            metavar: "X".into(),
            regex: Some("^x$".into()),
            comparison: Some(Comparison {
                op: "==".into(),
                value: toml::Value::Integer(1),
            }),
            pattern: None,
            language: None,
        };
        assert!(CompiledConstraint::compile("r", Language::Rust, &multiple).is_err());
    }
}