#![allow(clippy::result_large_err)]
use crate::query::types::{Condition, Expr, Operator, Query, Value};
use thiserror::Error;
/// Documentation page linked from every [`CostGateError::QueryTooBroad`] rejection.
pub const QUERY_TOO_BROAD_DOC_URL: &str = "https://docs.verivus.dev/sqry/query-cost-gate";
/// Value of the `kind` key in the JSON produced by
/// [`CostGateError::to_query_too_broad_details`].
pub const KIND_QUERY_TOO_BROAD: &str = "query_too_broad";
/// Value of the `source` key in the JSON details: the gate judges cost from the
/// query's shape plus the graph node count, without executing the query.
pub const SOURCE_STATIC_ESTIMATE: &str = "static_estimate";
/// Predicate fields treated as scope filters. A condition on one of these fields
/// appearing as a direct `AND` operand lets otherwise prohibitive sibling regexes
/// through. The list is also quoted in the rejection message (`scope_hint`) and in
/// the `suggested_predicates` details key.
pub const SCOPE_FILTER_FIELDS: &[&str] = &["kind", "lang", "language", "path", "file"];
/// Tuning knobs for the static query cost gate.
///
/// Build one with [`Default`] and override individual fields with struct-update
/// syntax, e.g. `CostGateConfig { node_count_threshold: None, ..Default::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CostGateConfig {
    /// Literal-prefix length threshold (in bytes) consulted by
    /// `regex_shape_is_acceptable`; also quoted in rejection messages.
    pub min_prefix_len: usize,
    /// Threshold on a regex's minimum match length (in bytes) consulted by
    /// `regex_shape_is_acceptable`.
    pub min_literal_len: usize,
    /// Graph size above which the gate engages. `None` and `Some(0)` both
    /// disable the gate entirely.
    pub node_count_threshold: Option<usize>,
}

impl CostGateConfig {
    /// Default for [`CostGateConfig::min_prefix_len`].
    pub const DEFAULT_MIN_PREFIX_LEN: usize = 3;
    /// Default for [`CostGateConfig::min_literal_len`].
    pub const DEFAULT_MIN_LITERAL_LEN: usize = 4;
    /// Default for [`CostGateConfig::node_count_threshold`].
    pub const DEFAULT_NODE_COUNT_THRESHOLD: usize = 50_000;
}

impl Default for CostGateConfig {
    /// Assembles the configuration from the `DEFAULT_*` associated constants,
    /// with the node-count cap enabled.
    fn default() -> Self {
        let node_count_threshold = Some(Self::DEFAULT_NODE_COUNT_THRESHOLD);
        CostGateConfig {
            node_count_threshold,
            min_literal_len: Self::DEFAULT_MIN_LITERAL_LEN,
            min_prefix_len: Self::DEFAULT_MIN_PREFIX_LEN,
        }
    }
}
/// Errors raised by the static query cost gate.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum CostGateError {
/// A prohibitive predicate was found with no scope filter while the graph is
/// larger than the configured node-count threshold.
///
/// The `Display` text echoes the raw `pattern`; the JSON built by
/// [`CostGateError::to_query_too_broad_details`] elides it.
#[error(
"query rejected: predicate `{field}{op}{pattern}` is unbounded over {node_count} nodes; \
add a scope filter (one of: {scope_hint}) or anchor the regex with `^` / a literal \
prefix \u{2265} {min_prefix_len} chars. See {doc_url}"
)]
QueryTooBroad {
// Field name of the offending predicate (`"search"` for the raw-pattern check).
field: String,
// Operator spelling as rendered in the message (`" "` for the raw-pattern check).
op: &'static str,
// Rendered value of the predicate, e.g. `/.*_set$/` for regexes.
pattern: String,
// Node count of the graph the query would run against.
node_count: usize,
// Configured threshold (`node_count_threshold`, or 0 when unset).
node_limit: usize,
// Comma-separated list of `SCOPE_FILTER_FIELDS`.
scope_hint: String,
// `CostGateConfig::min_prefix_len` at the time of rejection.
min_prefix_len: usize,
// Always `QUERY_TOO_BROAD_DOC_URL` in this module.
doc_url: &'static str,
},
}
impl CostGateError {
    /// Builds the structured JSON `details` payload for a `query_too_broad` rejection.
    ///
    /// The payload has seven keys:
    /// - `source`: always [`SOURCE_STATIC_ESTIMATE`].
    /// - `kind`: always [`KIND_QUERY_TOO_BROAD`].
    /// - `estimated_visited_nodes`: the graph node count.
    /// - `limit`: the configured threshold.
    /// - `predicate_shape`: the field and operator, with the value elided.
    /// - `suggested_predicates`: the [`SCOPE_FILTER_FIELDS`] list.
    /// - `doc_url`: the documentation link.
    ///
    /// The offending pattern is deliberately not echoed, so `predicate_shape`
    /// never contains user-supplied regex text. It is capped at 256 bytes.
    #[must_use]
    pub fn to_query_too_broad_details(&self) -> serde_json::Value {
        let Self::QueryTooBroad {
            field,
            op,
            node_count,
            node_limit,
            doc_url,
            ..
        } = self;
        let predicate_shape = cap_predicate_shape(format!("{field}{op}<elided>"));
        serde_json::json!({
            "source": SOURCE_STATIC_ESTIMATE,
            "kind": KIND_QUERY_TOO_BROAD,
            "estimated_visited_nodes": node_count,
            "limit": node_limit,
            "predicate_shape": predicate_shape,
            "suggested_predicates": SCOPE_FILTER_FIELDS,
            "doc_url": doc_url,
        })
    }
}

/// Maximum byte length of the `predicate_shape` string in the details payload.
const PREDICATE_SHAPE_MAX_BYTES: usize = 256;

/// Caps `shape` at [`PREDICATE_SHAPE_MAX_BYTES`] bytes, ending it with `…` when cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary.
/// `field` comes from user input and may contain multi-byte characters, and
/// `String::truncate` panics when asked to cut inside one.
fn cap_predicate_shape(mut shape: String) -> String {
    const ELLIPSIS: char = '\u{2026}';
    if shape.len() > PREDICATE_SHAPE_MAX_BYTES {
        // Leave room for the 3-byte ellipsis so the result stays within budget.
        let mut cut = PREDICATE_SHAPE_MAX_BYTES - ELLIPSIS.len_utf8();
        // Index 0 is always a char boundary, so this loop terminates.
        while !shape.is_char_boundary(cut) {
            cut -= 1;
        }
        shape.truncate(cut);
        shape.push(ELLIPSIS);
    }
    shape
}
/// Runs the cost gate over a parsed query expression.
///
/// The walk starts with no scope filter in effect. `AND` operands that filter on
/// one of [`SCOPE_FILTER_FIELDS`] couple their siblings (see `walk_expr`).
///
/// # Errors
///
/// Returns [`CostGateError::QueryTooBroad`] for the first prohibitive, uncoupled
/// predicate encountered (depth-first) while the node-count cap is engaged.
pub fn check_query(
    expr: &Expr,
    node_count: usize,
    cfg: &CostGateConfig,
) -> Result<(), CostGateError> {
    let starts_scoped = false;
    walk_expr(expr, starts_scoped, node_count, cfg)
}
/// Convenience wrapper around [`check_query`] that gates a whole [`Query`]
/// by checking its root expression.
///
/// # Errors
///
/// Propagates any [`CostGateError`] returned by [`check_query`].
pub fn check_query_root(
    query: &Query,
    node_count: usize,
    cfg: &CostGateConfig,
) -> Result<(), CostGateError> {
    let root = &query.root;
    check_query(root, node_count, cfg)
}
/// Applies the regex-shape check to a bare pattern string rather than a parsed query.
///
/// The rejection is reported against the field `search`. NOTE(review): this is
/// presumably the CLI `search` path; confirm against callers.
///
/// # Errors
///
/// Returns [`CostGateError::QueryTooBroad`] when the node-count cap is engaged
/// and the pattern's shape is not acceptable. The shape is not inspected at all
/// while the cap is disengaged.
pub fn check_regex_pattern_text(
    pattern: &str,
    node_count: usize,
    cfg: &CostGateConfig,
) -> Result<(), CostGateError> {
    // `&&` short-circuits, so the regex is only parsed when the cap is engaged.
    if cap_engaged(node_count, cfg) && !regex_shape_is_acceptable(pattern, cfg) {
        return Err(CostGateError::QueryTooBroad {
            field: "search".to_string(),
            op: " ",
            pattern: format!("/{pattern}/"),
            node_count,
            node_limit: cfg.node_count_threshold.unwrap_or(0),
            scope_hint: SCOPE_FILTER_FIELDS.join(", "),
            min_prefix_len: cfg.min_prefix_len,
            doc_url: QUERY_TOO_BROAD_DOC_URL,
        });
    }
    Ok(())
}
/// Coarse cost class assigned to a single predicate by `classify_condition`.
///
/// Only `Prohibitive` can lead to a rejection. `walk_condition` currently treats
/// `Cheap` and `Medium` identically.
enum Class {
// Equality on a string/boolean/number literal, or a variable reference used
// with a non-range operator.
Cheap,
// Range comparisons, subqueries, acceptable regexes, and everything else.
Medium,
// A regex on a field other than kind/lang/language whose shape fails
// `regex_shape_is_acceptable`.
Prohibitive,
}
/// Reports whether the graph is large enough for the gate to apply.
///
/// A missing threshold (`None`) or a zero threshold disables the gate.
/// Otherwise the gate applies when `node_count` strictly exceeds the threshold.
fn cap_engaged(node_count: usize, cfg: &CostGateConfig) -> bool {
    cfg.node_count_threshold
        .filter(|&limit| limit != 0)
        .is_some_and(|limit| node_count > limit)
}
/// Recursively gates every predicate in `expr`, stopping at the first rejection.
///
/// `scope_in_scope` is true when an enclosing `AND` already contains a scope
/// filter. An `AND` becomes scoped for all of its operands as soon as any direct
/// operand is a scope-filter condition. `OR`, `NOT` and joins only forward the
/// flag they received.
fn walk_expr(
    expr: &Expr,
    scope_in_scope: bool,
    node_count: usize,
    cfg: &CostGateConfig,
) -> Result<(), CostGateError> {
    match expr {
        Expr::Condition(cond) => walk_condition(cond, scope_in_scope, node_count, cfg),
        Expr::And(operands) => {
            let scoped = scope_in_scope || operands.iter().any(is_scope_filter_at);
            operands
                .iter()
                .try_for_each(|operand| walk_expr(operand, scoped, node_count, cfg))
        }
        Expr::Or(branches) => branches
            .iter()
            .try_for_each(|branch| walk_expr(branch, scope_in_scope, node_count, cfg)),
        Expr::Not(inner) => walk_expr(inner, scope_in_scope, node_count, cfg),
        Expr::Join(join) => walk_expr(&join.left, scope_in_scope, node_count, cfg)
            .and_then(|()| walk_expr(&join.right, scope_in_scope, node_count, cfg)),
    }
}
/// Gates a single condition.
///
/// A subquery value is walked first, inheriting the caller's scope flag.
/// Conditions whose value is a variable are always allowed. Otherwise the
/// condition is rejected only when it classifies as `Prohibitive`, the cap is
/// engaged, and no scope filter couples it.
fn walk_condition(
    cond: &Condition,
    scope_in_scope: bool,
    node_count: usize,
    cfg: &CostGateConfig,
) -> Result<(), CostGateError> {
    match &cond.value {
        Value::Subquery(inner) => walk_expr(inner, scope_in_scope, node_count, cfg)?,
        Value::Variable(_) => return Ok(()),
        _ => {}
    }
    let prohibitive = matches!(classify_condition(cond, cfg), Class::Prohibitive);
    if prohibitive && cap_engaged(node_count, cfg) && !scope_in_scope {
        Err(build_query_too_broad(cond, node_count, cfg))
    } else {
        Ok(())
    }
}
/// Assigns a cost class to one condition based on its value kind and operator.
///
/// Arms are tried top to bottom, so a variable used with a range operator lands
/// in `Medium` before the `Variable => Cheap` arm is reached. Subqueries fall
/// through to the `Medium` catch-all.
fn classify_condition(cond: &Condition, cfg: &CostGateConfig) -> Class {
    match (&cond.value, &cond.operator) {
        (Value::String(_) | Value::Boolean(_) | Value::Number(_), Operator::Equal) => {
            Class::Cheap
        }
        (Value::Regex(rv), Operator::Regex) => {
            regex_class(cond.field.as_str(), &rv.pattern, cfg)
        }
        (_, Operator::Greater | Operator::Less | Operator::GreaterEq | Operator::LessEq) => {
            Class::Medium
        }
        (Value::Variable(_), _) => Class::Cheap,
        _ => Class::Medium,
    }
}
/// Classifies a regex predicate.
///
/// Regexes on `kind`, `lang` and `language` are always `Medium`, and the pattern
/// shape is not even inspected for them. Other fields are `Medium` only when
/// `regex_shape_is_acceptable` approves the pattern.
fn regex_class(field: &str, pattern: &str, cfg: &CostGateConfig) -> Class {
    let exempt_field = matches!(field, "kind" | "lang" | "language");
    if exempt_field || regex_shape_is_acceptable(pattern, cfg) {
        Class::Medium
    } else {
        Class::Prohibitive
    }
}
/// Decides whether a regex is selective enough to run without a scope filter.
///
/// A pattern is acceptable when either:
/// - every extracted literal prefix is at least `cfg.min_prefix_len` bytes long
///   (e.g. `^get_`), or
/// - every match must be strictly longer than `cfg.min_literal_len` bytes
///   (e.g. `.*deserialize.*`).
///
/// Patterns that `regex_syntax` cannot parse are reported as acceptable, so the
/// gate fails open. NOTE(review): presumably the query engine reports the syntax
/// error itself; confirm.
///
/// NOTE(review): prefix extraction is relative to the match start, not the
/// string start. An unanchored literal such as `_set` therefore also yields a
/// usable prefix; only constructs like a leading `.*` yield none.
fn regex_shape_is_acceptable(pattern: &str, cfg: &CostGateConfig) -> bool {
    let Ok(hir) = regex_syntax::parse(pattern) else {
        return true;
    };
    let mut extractor = regex_syntax::hir::literal::Extractor::new();
    extractor.kind(regex_syntax::hir::literal::ExtractKind::Prefix);
    // Every match starts with one of the extracted literals, so the regex is only
    // as selective as its *shortest* one. Using the longest let `(|getter)` or
    // `^(a|getter)` through even though one alternative matches almost anything.
    // `literals()` is `None` for an infinite set (e.g. a leading `.*`), which
    // counts as "no usable prefix".
    let shortest_prefix = extractor
        .extract(&hir)
        .literals()
        .and_then(|lits| lits.iter().map(|lit| lit.as_bytes().len()).min())
        .unwrap_or(0);
    // Inclusive, matching the "literal prefix ≥ {min_prefix_len} chars" promise
    // in the `QueryTooBroad` message.
    if shortest_prefix >= cfg.min_prefix_len {
        return true;
    }
    // Strict on purpose: a required literal of exactly `min_literal_len` bytes
    // (e.g. `.*_set$`) is still considered too broad.
    hir.properties()
        .minimum_len()
        .is_some_and(|min_len| min_len > cfg.min_literal_len)
}
/// Returns true when `expr` is itself a condition on one of
/// [`SCOPE_FILTER_FIELDS`]. Compound expressions are never treated as scope
/// filters, even if they contain one.
fn is_scope_filter_at(expr: &Expr) -> bool {
    matches!(
        expr,
        Expr::Condition(cond) if SCOPE_FILTER_FIELDS.contains(&cond.field.as_str())
    )
}
/// Builds the [`CostGateError::QueryTooBroad`] error for a rejected condition,
/// rendering its operator and value as they appear in the query language.
fn build_query_too_broad(
    cond: &Condition,
    node_count: usize,
    cfg: &CostGateConfig,
) -> CostGateError {
    let op: &'static str = match cond.operator {
        Operator::Equal => ":",
        Operator::Regex => "~=",
        Operator::Greater => ">",
        Operator::Less => "<",
        Operator::GreaterEq => ">=",
        Operator::LessEq => "<=",
    };
    let rendered_value = match &cond.value {
        Value::String(s) => s.clone(),
        Value::Regex(rv) => format!("/{}/", rv.pattern),
        Value::Number(n) => n.to_string(),
        Value::Boolean(b) => b.to_string(),
        Value::Variable(name) => format!("${name}"),
        Value::Subquery(_) => String::from("(<subquery>)"),
    };
    CostGateError::QueryTooBroad {
        field: cond.field.as_str().to_owned(),
        op,
        pattern: rendered_value,
        node_count,
        node_limit: cfg.node_count_threshold.unwrap_or_default(),
        scope_hint: SCOPE_FILTER_FIELDS.join(", "),
        min_prefix_len: cfg.min_prefix_len,
        doc_url: QUERY_TOO_BROAD_DOC_URL,
    }
}
/// Unit tests for the cost gate: query walking, scope coupling, threshold
/// handling, the JSON details payload, and the bare-pattern check.
#[cfg(test)]
mod tests {
use super::*;
use crate::query::QueryParser;
// Parses a query string, panicking on syntax errors (test inputs are trusted).
fn parse(q: &str) -> Query {
QueryParser::parse_query(q).expect("parse")
}
// Stock configuration: cap engaged above 50_000 nodes.
fn cfg() -> CostGateConfig {
CostGateConfig::default()
}
// Configuration with the node-count cap disabled.
fn cfg_no_cap() -> CostGateConfig {
CostGateConfig {
node_count_threshold: None,
..CostGateConfig::default()
}
}
// `.*_set$` has no usable prefix and a minimum match length of exactly
// `min_literal_len` (4), so it must be rejected when unscoped.
#[test]
fn gate_rejects_bare_unanchored_suffix_regex() {
let q = parse("name~=/.*_set$/");
let err = check_query_root(&q, 200_000, &cfg()).expect_err("must reject");
assert!(
matches!(err, CostGateError::QueryTooBroad { ref field, .. } if field == "name"),
"expected name-field rejection, got {err:?}"
);
}
#[test]
fn gate_rejects_bare_unanchored_substring_regex() {
let q = parse("name~=/.*foo.*/");
let err = check_query_root(&q, 200_000, &cfg()).expect_err("must reject");
let CostGateError::QueryTooBroad { ref pattern, .. } = err;
assert!(
pattern.contains(".*foo.*"),
"envelope must echo the offending pattern, got {pattern}"
);
}
// Below the node threshold the gate is disengaged regardless of shape.
#[test]
fn gate_allows_unanchored_regex_below_node_threshold() {
let q = parse("name~=/.*_set$/");
check_query_root(&q, 1_000, &cfg()).expect("below threshold must pass");
}
// Scope-filter siblings (kind / lang / path) couple the prohibitive regex.
#[test]
fn gate_allows_unanchored_regex_with_kind_coupling() {
let q = parse("kind:function AND name~=/.*_set$/");
check_query_root(&q, 1_000_000, &cfg()).expect("kind coupling must pass");
}
#[test]
fn gate_allows_unanchored_regex_with_lang_coupling() {
let q = parse("lang:rust AND name~=/.*_set$/");
check_query_root(&q, 1_000_000, &cfg()).expect("lang coupling must pass");
}
#[test]
fn gate_allows_unanchored_regex_with_path_coupling() {
let q = parse("path:src/**/*.rs AND name~=/.*_set$/");
check_query_root(&q, 1_000_000, &cfg()).expect("path coupling must pass");
}
#[test]
fn gate_allows_anchored_prefix_regex_without_coupling() {
let q = parse("name~=/^get_/");
check_query_root(&q, 1_000_000, &cfg()).expect("anchored prefix must pass");
}
// No prefix, but the required literal makes every match > min_literal_len bytes.
#[test]
fn gate_allows_long_required_literal_without_anchor() {
let q = parse("name~=/.*deserialize.*/");
check_query_root(&q, 1_000_000, &cfg()).expect("long literal must pass");
}
#[test]
fn gate_rejects_short_anchored_regex_below_prefix_len() {
let q = parse("name~=/^a/");
let err = check_query_root(&q, 1_000_000, &cfg()).expect_err("short prefix must reject");
assert!(matches!(err, CostGateError::QueryTooBroad { .. }));
}
// Coupling does not leak across OR branches.
#[test]
fn gate_rejects_or_branch_with_uncoupled_prohibitive() {
let q = parse("(kind:function AND name~=/.*_set$/) OR (name~=/.*foo.*/)");
let err = check_query_root(&q, 1_000_000, &cfg()).expect_err("uncoupled Or must reject");
let CostGateError::QueryTooBroad { ref pattern, .. } = err;
assert!(
pattern.contains(".*foo.*"),
"rejection must point at the uncoupled branch, got {pattern}"
);
}
#[test]
fn gate_passes_known_good_canonical_queries() {
let canonical = [
"kind:function",
"name:foo",
"path:src/**/*.rs",
"lang:rust AND kind:method",
"kind:method AND callers:foo",
];
for q in canonical {
let parsed = parse(q);
check_query_root(&parsed, 1_000_000, &cfg())
.unwrap_or_else(|e| panic!("canonical query {q:?} must pass; got {e:?}"));
}
}
#[test]
fn gate_threshold_disabled_when_node_count_threshold_is_none() {
let q = parse("name~=/.*_set$/");
check_query_root(&q, 1_000_000_000, &cfg_no_cap())
.expect("None threshold must disable cap entirely");
}
#[test]
fn gate_threshold_disabled_when_node_count_threshold_is_zero() {
let q = parse("name~=/.*_set$/");
let cfg = CostGateConfig {
node_count_threshold: Some(0),
..CostGateConfig::default()
};
check_query_root(&q, 1_000_000_000, &cfg).expect("Some(0) threshold must disable cap");
}
// NOTE(review): this test accepts both Ok and Err; it only checks the field
// name when the subquery predicate is rejected. Pin the intended outcome
// (whether outer `kind:` coupling should cover the subquery) once decided.
#[test]
fn gate_recurses_into_subquery_value() {
let q = parse("kind:function AND callers:(name~=/.*foo.*/)");
let err = check_query_root(&q, 1_000_000, &cfg());
if let Err(CostGateError::QueryTooBroad { ref field, .. }) = err {
assert_eq!(field, "name");
}
}
// The details payload must never echo the user pattern (`_set`).
#[test]
fn to_query_too_broad_details_emits_canonical_cc2_seven_keys() {
let err = CostGateError::QueryTooBroad {
field: "name".into(),
op: "~=",
pattern: "/.*_set$/".into(),
node_count: 312_487,
node_limit: 50_000,
scope_hint: SCOPE_FILTER_FIELDS.join(", "),
min_prefix_len: 3,
doc_url: QUERY_TOO_BROAD_DOC_URL,
};
let details = err.to_query_too_broad_details();
assert_eq!(details["source"], SOURCE_STATIC_ESTIMATE);
assert_eq!(details["kind"], KIND_QUERY_TOO_BROAD);
assert_eq!(details["estimated_visited_nodes"], 312_487);
assert_eq!(details["limit"], 50_000);
let shape = details["predicate_shape"].as_str().unwrap();
assert_eq!(shape, "name~=<elided>");
assert!(!shape.contains("_set"));
assert!(details["suggested_predicates"].is_array());
assert_eq!(details["doc_url"], QUERY_TOO_BROAD_DOC_URL);
}
// Bare-pattern check (`check_regex_pattern_text`) mirrors the query-path rules.
#[test]
fn cli_search_shape_check_rejects_unanchored_substring() {
let err = check_regex_pattern_text(".*foo.*", 1_000_000, &cfg())
.expect_err("CLI shape check must reject .*foo.*");
assert!(matches!(err, CostGateError::QueryTooBroad { .. }));
}
#[test]
fn cli_search_shape_check_passes_anchored_prefix() {
check_regex_pattern_text("^get_", 1_000_000, &cfg())
.expect("anchored prefix must pass CLI shape check");
}
#[test]
fn cli_search_shape_check_passes_long_literal() {
check_regex_pattern_text(".*deserialize.*", 1_000_000, &cfg())
.expect("long literal must pass CLI shape check");
}
#[test]
fn cli_search_shape_check_below_threshold_passes() {
check_regex_pattern_text(".*foo.*", 1_000, &cfg())
.expect("below cap must pass shape check");
}
}