Skip to main content

harn_rules/
constraint.rs

1//! `where` constraints: predicates on captured metavars (Semgrep
2//! `metavariable-regex` / `metavariable-comparison` / `metavariable-pattern`).
3//!
4//! A match survives only when every constraint holds. Constraints are
5//! compiled once (regex compiled, sub-pattern lowered to a tree-sitter
6//! query) and evaluated against each match's metavar bindings.
7
8use regex::Regex;
9use streaming_iterator::StreamingIterator;
10use tree_sitter::{Query, QueryCursor};
11
12use harn_hostlib::ast::{api, Language};
13
14use crate::engine::{Binding, ResolvedBinding};
15use crate::error::RulesError;
16use crate::model::{Constraint, ResolvedBindingConstraint};
17use crate::pattern::compile_pattern;
18
/// A compiled `where` constraint bound to one metavar.
///
/// Produced by [`CompiledConstraint::compile`] and applied to a metavar's
/// binding with [`CompiledConstraint::evaluate`].
pub struct CompiledConstraint {
    /// The metavar this constraint filters on (without `$`).
    pub metavar: String,
    // The compiled predicate; private so it can only be built via `compile`.
    kind: Kind,
}
25
/// The compiled form of a constraint; one variant per supported kind.
enum Kind {
    /// Holds when the regex matches the binding's text.
    Regex(Regex),
    /// Holds when `text <op> value` is true (see `evaluate_comparison`).
    Comparison { op: CmpOp, value: toml::Value },
    /// Holds when the binding's text, parsed as `language`, yields at least
    /// one match for `query`.
    SubPattern { language: Language, query: Query },
    /// Holds when the binding carries resolution metadata that agrees with
    /// every field set on the expected identity.
    ResolvesTo(ResolvedBindingConstraint),
    /// Holds when the binding carries type metadata equal to this string.
    Type(String),
}
33
/// Comparison operator accepted by a `comparison` constraint.
#[derive(Clone, Copy)]
enum CmpOp {
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `==`
    Eq,
    /// `!=`
    Ne,
}

impl CmpOp {
    /// Look up the operator spelled exactly as `op`.
    ///
    /// Returns `None` for any other spelling; no trimming or normalisation
    /// is applied to `op`.
    fn parse(op: &str) -> Option<Self> {
        // Every accepted spelling paired with the operator it denotes.
        const OPERATORS: [(&str, CmpOp); 6] = [
            ("<", CmpOp::Lt),
            ("<=", CmpOp::Le),
            (">", CmpOp::Gt),
            (">=", CmpOp::Ge),
            ("==", CmpOp::Eq),
            ("!=", CmpOp::Ne),
        ];

        OPERATORS
            .iter()
            .find_map(|&(symbol, parsed)| (symbol == op).then_some(parsed))
    }
}
57
58impl CompiledConstraint {
59    /// Compile a constraint. `default_language` is the rule's language,
60    /// used for a sub-pattern that does not name its own.
61    pub fn compile(
62        rule_id: &str,
63        default_language: Language,
64        constraint: &Constraint,
65    ) -> Result<Self, RulesError> {
66        let err = |message: String| RulesError::PatternCompile {
67            rule: rule_id.to_string(),
68            message,
69        };
70
71        let set = [
72            constraint.regex.is_some(),
73            constraint.comparison.is_some(),
74            constraint.pattern.is_some(),
75            constraint.resolves_to.is_some(),
76            constraint.type_.is_some(),
77        ]
78        .into_iter()
79        .filter(|b| *b)
80        .count();
81        if set != 1 {
82            return Err(err(format!(
83                "where-constraint on `{}` must set exactly one of `regex` / `comparison` / `pattern` / `resolves_to` / `type`",
84                constraint.metavar
85            )));
86        }
87
88        let kind = if let Some(re) = &constraint.regex {
89            Kind::Regex(
90                Regex::new(re)
91                    .map_err(|e| err(format!("constraint regex `{re}` is invalid: {e}")))?,
92            )
93        } else if let Some(cmp) = &constraint.comparison {
94            let op = CmpOp::parse(&cmp.op)
95                .ok_or_else(|| err(format!("unknown comparison operator `{}`", cmp.op)))?;
96            Kind::Comparison {
97                op,
98                value: cmp.value.clone(),
99            }
100        } else if let Some(snippet) = &constraint.pattern {
101            let language = match &constraint.language {
102                Some(name) => Language::from_name(name)
103                    .ok_or_else(|| err(format!("unknown sub-pattern language `{name}`")))?,
104                None => default_language,
105            };
106            let ts_language = language
107                .ts_language()
108                .ok_or_else(|| err(format!("grammar for `{}` is unavailable", language.name())))?;
109            let compiled = compile_pattern(snippet, language)
110                .map_err(|m| err(format!("sub-pattern on `{}`: {m}", constraint.metavar)))?;
111            let query = Query::new(&ts_language, &compiled.query)
112                .map_err(|e| err(format!("sub-pattern query rejected: {e}")))?;
113            Kind::SubPattern { language, query }
114        } else if let Some(resolves_to) = &constraint.resolves_to {
115            if default_language != Language::Harn {
116                return Err(err(format!(
117                    "`resolves_to` on `{}` is only supported for Harn rules",
118                    constraint.metavar
119                )));
120            }
121            if resolves_to.id.is_none()
122                && resolves_to.name.is_none()
123                && resolves_to.kind.is_none()
124                && resolves_to.line.is_none()
125                && resolves_to.column.is_none()
126            {
127                return Err(err(format!(
128                    "`resolves_to` on `{}` must set at least one identity field",
129                    constraint.metavar
130                )));
131            }
132            Kind::ResolvesTo(resolves_to.clone())
133        } else {
134            if default_language != Language::Harn {
135                return Err(err(format!(
136                    "`type` on `{}` is only supported for Harn rules",
137                    constraint.metavar
138                )));
139            }
140            let expected = constraint.type_.as_ref().unwrap();
141            if expected.trim().is_empty() {
142                return Err(err(format!(
143                    "`type` on `{}` must not be empty",
144                    constraint.metavar
145                )));
146            }
147            Kind::Type(expected.clone())
148        };
149
150        Ok(CompiledConstraint {
151            metavar: constraint.metavar.clone(),
152            kind,
153        })
154    }
155
156    /// Evaluate the constraint against a metavar binding.
157    pub fn evaluate(&self, binding: &Binding) -> bool {
158        match &self.kind {
159            Kind::Regex(re) => re.is_match(&binding.text),
160            Kind::Comparison { op, value } => evaluate_comparison(*op, &binding.text, value),
161            Kind::SubPattern { language, query } => {
162                let Ok(tree) = api::parse_tree(&binding.text, *language) else {
163                    return false;
164                };
165                let mut cursor = QueryCursor::new();
166                let mut it = cursor.matches(query, tree.root_node(), binding.text.as_bytes());
167                it.next().is_some()
168            }
169            Kind::ResolvesTo(expected) => binding
170                .metadata
171                .resolved
172                .as_ref()
173                .is_some_and(|actual| resolved_matches(expected, actual)),
174            Kind::Type(expected) => binding
175                .metadata
176                .ty
177                .as_ref()
178                .is_some_and(|actual| actual == expected),
179        }
180    }
181}
182
183fn resolved_matches(expected: &ResolvedBindingConstraint, actual: &ResolvedBinding) -> bool {
184    expected.id.as_ref().is_none_or(|id| id == &actual.id)
185        && expected
186            .name
187            .as_ref()
188            .is_none_or(|name| name == &actual.name)
189        && expected
190            .kind
191            .as_ref()
192            .is_none_or(|kind| kind == &actual.kind)
193        && expected
194            .line
195            .is_none_or(|line| line == actual.span.start_row + 1)
196        && expected
197            .column
198            .is_none_or(|column| column == actual.span.start_col + 1)
199}
200
201fn evaluate_comparison(op: CmpOp, text: &str, value: &toml::Value) -> bool {
202    // Numeric comparison when the RHS is a number and the captured text
203    // parses as one; otherwise fall back to string equality for `==` / `!=`.
204    if let Some(rhs) = value
205        .as_float()
206        .or_else(|| value.as_integer().map(|i| i as f64))
207    {
208        if let Ok(lhs) = text.trim().parse::<f64>() {
209            return match op {
210                CmpOp::Lt => lhs < rhs,
211                CmpOp::Le => lhs <= rhs,
212                CmpOp::Gt => lhs > rhs,
213                CmpOp::Ge => lhs >= rhs,
214                CmpOp::Eq => (lhs - rhs).abs() < f64::EPSILON,
215                CmpOp::Ne => (lhs - rhs).abs() >= f64::EPSILON,
216            };
217        }
218        // RHS numeric but LHS not a number: only `!=` can be satisfied.
219        return matches!(op, CmpOp::Ne);
220    }
221
222    let rhs = match value {
223        toml::Value::String(s) => s.clone(),
224        toml::Value::Boolean(b) => b.to_string(),
225        other => other.to_string(),
226    };
227    match op {
228        CmpOp::Eq => text == rhs,
229        CmpOp::Ne => text != rhs,
230        // Ordering on non-numbers falls back to lexicographic compare.
231        CmpOp::Lt => text < rhs.as_str(),
232        CmpOp::Le => text <= rhs.as_str(),
233        CmpOp::Gt => text > rhs.as_str(),
234        CmpOp::Ge => text >= rhs.as_str(),
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::{BindingMetadata, Span};
    use crate::model::Comparison;

    /// A binding holding `text`, spanning row 0 from column 0, with default
    /// metadata.
    fn binding(text: &str) -> Binding {
        Binding {
            text: text.into(),
            span: Span {
                start_byte: 0,
                end_byte: text.len(),
                start_row: 0,
                start_col: 0,
                end_row: 0,
                end_col: text.len(),
            },
            metadata: BindingMetadata::default(),
        }
    }

    /// A constraint on `metavar` with no kind set. Tests fill in the field
    /// under test with struct-update syntax.
    fn blank(metavar: &str) -> Constraint {
        Constraint {
            metavar: metavar.into(),
            regex: None,
            comparison: None,
            pattern: None,
            resolves_to: None,
            type_: None,
            language: None,
        }
    }

    /// Compile a regex-only constraint for a Rust rule.
    fn regex_constraint(metavar: &str, re: &str) -> CompiledConstraint {
        let c = Constraint {
            regex: Some(re.into()),
            ..blank(metavar)
        };
        CompiledConstraint::compile("r", Language::Rust, &c).unwrap()
    }

    #[test]
    fn regex_constraint_matches() {
        let c = regex_constraint("KEY", "^[a-z][a-zA-Z]*$");
        assert!(c.evaluate(&binding("userId")));
        assert!(!c.evaluate(&binding("0bad")));
    }

    #[test]
    fn numeric_comparison() {
        let c = Constraint {
            comparison: Some(Comparison {
                op: ">".into(),
                value: toml::Value::Integer(0),
            }),
            ..blank("N")
        };
        let c = CompiledConstraint::compile("r", Language::Rust, &c).unwrap();
        assert!(c.evaluate(&binding("5")));
        assert!(!c.evaluate(&binding("0")));
        assert!(!c.evaluate(&binding("-3")));
    }

    #[test]
    fn string_equality_comparison() {
        let c = Constraint {
            comparison: Some(Comparison {
                op: "!=".into(),
                value: toml::Value::String("nil".into()),
            }),
            ..blank("S")
        };
        let c = CompiledConstraint::compile("r", Language::Rust, &c).unwrap();
        assert!(c.evaluate(&binding("something")));
        assert!(!c.evaluate(&binding("nil")));
    }

    #[test]
    fn sub_pattern_constraint() {
        // The captured metavar text must itself be a call expression.
        let c = Constraint {
            pattern: Some("$FN($ARG)".into()),
            language: Some("typescript".into()),
            ..blank("VALUE")
        };
        let c = CompiledConstraint::compile("r", Language::TypeScript, &c).unwrap();
        assert!(c.evaluate(&binding("compute(x)")));
        assert!(!c.evaluate(&binding("42")));
    }

    #[test]
    fn rejects_zero_or_multiple_kinds() {
        // No kind set at all.
        let none = blank("X");
        assert!(CompiledConstraint::compile("r", Language::Rust, &none).is_err());

        // Two kinds that are each valid on their own, so the only reason to
        // reject is that more than one is set.
        let two = Constraint {
            regex: Some("^a$".into()),
            comparison: Some(Comparison {
                op: "==".into(),
                value: toml::Value::Integer(1),
            }),
            ..blank("X")
        };
        assert!(CompiledConstraint::compile("r", Language::Rust, &two).is_err());
    }
}