use crate::planner::ir::{
MatchMode, PlanNode, Predicate, PredicateValue, QueryPlan, StringPattern,
};
use thiserror::Error;
pub use sqry_core::query::cost_gate::CostGateConfig as PlannerCostGateConfig;
pub use sqry_core::query::cost_gate::QUERY_TOO_BROAD_DOC_URL;
pub use sqry_core::query::cost_gate::SCOPE_FILTER_FIELDS;
/// Error produced by [`check_plan`] when a plan is judged too broad to run.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum PlannerCostGateError {
    /// A predicate has no enclosing scope and no sufficiently long literal,
    /// while the node count exceeds the configured threshold.
    //
    // NOTE(review): the message promises a prefix "≥ {min_prefix_len} chars",
    // but the name-pattern and regex checks in this module accept only lengths
    // strictly greater than the configured minimum, measured in bytes rather
    // than chars. Confirm which bound is intended and align message and code.
    #[error(
        "query rejected: planner predicate `{predicate_shape}` is unbounded over {node_count} \
nodes; add a scope filter (one of: {scope_hint}) or anchor the regex with `^` / a \
literal prefix \u{2265} {min_prefix_len} chars. See {doc_url}"
    )]
    QueryTooBroad {
        /// Redacted shape of the offending predicate, e.g. `name:<glob-elided>`.
        predicate_shape: String,
        /// Node count the plan was checked against (as passed to [`check_plan`]).
        node_count: usize,
        /// Configured node-count threshold (`0` when the config leaves it unset).
        node_limit: usize,
        /// Comma-separated list of scope filter fields ([`SCOPE_FILTER_FIELDS`]).
        scope_hint: String,
        /// Configured minimum literal-prefix length.
        min_prefix_len: usize,
        /// Documentation link for this rejection ([`QUERY_TOO_BROAD_DOC_URL`]).
        doc_url: &'static str,
    },
}
impl PlannerCostGateError {
    /// Builds the structured JSON `details` payload for a `QueryTooBroad` rejection.
    ///
    /// Keys: `source` and `kind` (tags from `sqry_core::query::cost_gate`),
    /// `estimated_visited_nodes`, `limit`, `predicate_shape`,
    /// `suggested_predicates` (the [`SCOPE_FILTER_FIELDS`]) and `doc_url`.
    ///
    /// `predicate_shape` is capped at 256 bytes. A longer shape is cut back to a
    /// UTF-8 character boundary and terminated with `…`, so shapes containing
    /// non-ASCII text never cause a panic.
    #[must_use]
    pub fn to_query_too_broad_details(&self) -> serde_json::Value {
        // Fields are listed explicitly (no `..`), so adding a field is a
        // compile error here until this payload is reviewed.
        let Self::QueryTooBroad {
            predicate_shape,
            node_count,
            node_limit,
            scope_hint: _,
            min_prefix_len: _,
            doc_url,
        } = self;
        let suggested: Vec<&str> = SCOPE_FILTER_FIELDS.to_vec();
        serde_json::json!({
            "source": sqry_core::query::cost_gate::SOURCE_STATIC_ESTIMATE,
            "kind": sqry_core::query::cost_gate::KIND_QUERY_TOO_BROAD,
            "estimated_visited_nodes": node_count,
            "limit": node_limit,
            "predicate_shape": Self::truncated_shape(predicate_shape),
            "suggested_predicates": suggested,
            "doc_url": doc_url,
        })
    }

    /// Caps `shape` at 256 bytes, ending it with `…` when it had to be shortened.
    fn truncated_shape(shape: &str) -> String {
        const MAX_BYTES: usize = 256;
        const ELLIPSIS: char = '\u{2026}';
        if shape.len() <= MAX_BYTES {
            return shape.to_owned();
        }
        // Reserve room for the 3-byte ellipsis, then back up to a char boundary.
        // Slicing or `String::truncate` at a non-boundary index would panic.
        let mut cut = MAX_BYTES - ELLIPSIS.len_utf8();
        while !shape.is_char_boundary(cut) {
            cut -= 1;
        }
        let mut out = String::with_capacity(cut + ELLIPSIS.len_utf8());
        out.push_str(&shape[..cut]);
        out.push(ELLIPSIS);
        out
    }
}
/// Runs the static cost gate over every node of `plan`.
///
/// `node_count` is the node count the plan is evaluated against. Nothing is
/// rejected unless it exceeds the configured `node_count_threshold`.
///
/// # Errors
///
/// Returns [`PlannerCostGateError::QueryTooBroad`] for the first unbounded
/// predicate met in a depth-first walk of the plan.
pub fn check_plan(
    plan: &QueryPlan,
    node_count: usize,
    cfg: &PlannerCostGateConfig,
) -> Result<(), PlannerCostGateError> {
    // The root starts with no inherited scope.
    let inherited_scope = false;
    let root = &plan.root;
    walk_node(root, inherited_scope, node_count, cfg)
}
/// Reports whether the gate is active for a graph of `node_count` nodes.
///
/// A missing threshold or a threshold of `0` disables the gate. Otherwise the
/// gate engages once `node_count` strictly exceeds the threshold.
fn cap_engaged(node_count: usize, cfg: &PlannerCostGateConfig) -> bool {
    cfg.node_count_threshold
        .is_some_and(|limit| limit != 0 && node_count > limit)
}
/// Recursively gates one plan node and its children.
///
/// `scope_in_scope` is `true` when an enclosing context (an earlier chain
/// step, or a caller) is already scoped. In that case name patterns and
/// regexes underneath are not gated.
fn walk_node(
    node: &PlanNode,
    scope_in_scope: bool,
    node_count: usize,
    cfg: &PlannerCostGateConfig,
) -> Result<(), PlannerCostGateError> {
    match node {
        PlanNode::NodeScan {
            kind,
            visibility,
            name_pattern,
        } => {
            // A kind or visibility filter on the scan itself bounds its own name pattern.
            let scan_scoped = scope_in_scope || kind.is_some() || visibility.is_some();
            match name_pattern {
                Some(pattern) => check_name_pattern(pattern, scan_scoped, node_count, cfg),
                None => Ok(()),
            }
        }
        // The gate inspects nothing inside an edge traversal.
        PlanNode::EdgeTraversal { .. } => Ok(()),
        PlanNode::Filter { predicate } => {
            check_predicate(predicate, scope_in_scope, node_count, cfg)
        }
        // Each operand is checked with the inherited scope only; neither scopes the other.
        PlanNode::SetOp { left, right, .. } => walk_node(left, scope_in_scope, node_count, cfg)
            .and_then(|()| walk_node(right, scope_in_scope, node_count, cfg)),
        PlanNode::Chain { steps } => {
            // Once a step introduces a scope, every later step inherits it.
            steps
                .iter()
                .try_fold(scope_in_scope, |scoped_so_far, step| {
                    walk_node(step, scoped_so_far, node_count, cfg)?;
                    Ok(scoped_so_far || node_introduces_scope(step))
                })
                .map(|_| ())
        }
    }
}
fn node_introduces_scope(node: &PlanNode) -> bool {
matches!(
node,
PlanNode::NodeScan { kind: Some(_), .. }
| PlanNode::NodeScan {
visibility: Some(_),
..
}
| PlanNode::Filter {
predicate: Predicate::InFile(_) | Predicate::InScope(_),
}
)
}
fn check_predicate(
predicate: &Predicate,
scope_in_scope: bool,
node_count: usize,
cfg: &PlannerCostGateConfig,
) -> Result<(), PlannerCostGateError> {
match predicate {
Predicate::And(list) => {
let coupled = scope_in_scope
|| list
.iter()
.any(|p| matches!(p, Predicate::InFile(_) | Predicate::InScope(_)));
for p in list {
check_predicate(p, coupled, node_count, cfg)?;
}
Ok(())
}
Predicate::Or(list) => {
for p in list {
check_predicate(p, scope_in_scope, node_count, cfg)?;
}
Ok(())
}
Predicate::Not(inner) => check_predicate(inner, scope_in_scope, node_count, cfg),
Predicate::MatchesName(pattern) => {
check_name_pattern(pattern, scope_in_scope, node_count, cfg)
}
Predicate::Callers(v)
| Predicate::Callees(v)
| Predicate::Imports(v)
| Predicate::Exports(v)
| Predicate::References(v)
| Predicate::Implements(v) => check_predicate_value(v, scope_in_scope, node_count, cfg),
Predicate::HasCaller
| Predicate::HasCallee
| Predicate::IsUnused
| Predicate::InFile(_)
| Predicate::InScope(_)
| Predicate::Returns(_) => Ok(()),
}
}
fn check_predicate_value(
value: &PredicateValue,
scope_in_scope: bool,
node_count: usize,
cfg: &PlannerCostGateConfig,
) -> Result<(), PlannerCostGateError> {
match value {
PredicateValue::Pattern(_) => Ok(()),
PredicateValue::Regex(rx) => {
if !cap_engaged(node_count, cfg) || scope_in_scope {
return Ok(());
}
if regex_shape_is_acceptable(&rx.pattern, cfg) {
return Ok(());
}
Err(reject("references~=<elided>".to_string(), node_count, cfg))
}
PredicateValue::Subquery(plan) => walk_node(plan, scope_in_scope, node_count, cfg),
}
}
/// Rejects a name pattern that would force a near-full scan of a large graph.
///
/// Nothing is checked when the cap is not engaged (see [`cap_engaged`]) or when
/// an enclosing scope applies. Otherwise:
/// - `Exact` always passes.
/// - `Prefix` passes when `raw` is longer than `cfg.min_prefix_len` bytes.
/// - `Suffix` / `Contains` pass when `raw` is longer than `cfg.min_literal_len` bytes.
/// - `Glob` passes when its longest literal run (see [`longest_glob_literal`])
///   is longer than `cfg.min_literal_len` bytes.
///
/// NOTE(review): comparisons are strict (`>`) and count bytes, while the
/// user-facing error text says "≥ … chars". Confirm the intended bound.
///
/// # Errors
///
/// Returns [`PlannerCostGateError::QueryTooBroad`] with a redacted shape; the
/// raw pattern is never echoed back.
fn check_name_pattern(
    pattern: &StringPattern,
    scope_in_scope: bool,
    node_count: usize,
    cfg: &PlannerCostGateConfig,
) -> Result<(), PlannerCostGateError> {
    if !cap_engaged(node_count, cfg) || scope_in_scope {
        return Ok(());
    }
    let shape = match pattern.mode {
        MatchMode::Exact => return Ok(()),
        MatchMode::Prefix => {
            if pattern.raw.len() > cfg.min_prefix_len {
                return Ok(());
            }
            "name:<prefix-elided>*"
        }
        MatchMode::Suffix => {
            if pattern.raw.len() > cfg.min_literal_len {
                return Ok(());
            }
            "name:*<suffix-elided>"
        }
        MatchMode::Contains => {
            if pattern.raw.len() > cfg.min_literal_len {
                return Ok(());
            }
            "name:*<contains-elided>*"
        }
        MatchMode::Glob => {
            if longest_glob_literal(&pattern.raw) > cfg.min_literal_len {
                return Ok(());
            }
            "name:<glob-elided>"
        }
    };
    Err(reject(shape.to_string(), node_count, cfg))
}

/// Length in bytes of the longest run of literal characters in a glob.
///
/// `*` and `?` are wildcards, and a bracket expression `[...]` matches a single
/// character, so none of them (nor the class contents) count as literal text.
/// A `]` right after `[`, `[!` or `[^` belongs to the set rather than closing
/// it. An unterminated `[` is treated as a class running to the end of the
/// pattern. Working on bytes is safe because every metacharacter is ASCII and
/// cannot occur inside a multi-byte UTF-8 sequence.
fn longest_glob_literal(raw: &str) -> usize {
    let bytes = raw.as_bytes();
    let mut longest = 0;
    let mut run = 0;
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'*' | b'?' => {
                longest = longest.max(run);
                run = 0;
                i += 1;
            }
            b'[' => {
                longest = longest.max(run);
                run = 0;
                i += 1;
                if i < bytes.len() && matches!(bytes[i], b'!' | b'^') {
                    i += 1;
                }
                if i < bytes.len() && bytes[i] == b']' {
                    i += 1;
                }
                while i < bytes.len() && bytes[i] != b']' {
                    i += 1;
                }
                // Step past the closing `]` (or past the end if unterminated).
                i += 1;
            }
            _ => {
                run += 1;
                i += 1;
            }
        }
    }
    longest.max(run)
}
/// Decides whether a regex is selective enough to run unscoped on a large graph.
///
/// A pattern that fails to parse is let through (returns `true`). Otherwise it
/// is acceptable when either holds:
/// - every literal prefix the regex can start with is longer than
///   `cfg.min_prefix_len` bytes, or
/// - every match is longer than `cfg.min_literal_len` bytes
///   (`Properties::minimum_len`).
fn regex_shape_is_acceptable(pattern: &str, cfg: &PlannerCostGateConfig) -> bool {
    // NOTE(review): parse failures fail open here; presumably an invalid regex
    // is reported elsewhere in the pipeline. Confirm.
    let Ok(hir) = regex_syntax::parse(pattern) else {
        return true;
    };
    let mut extractor = regex_syntax::hir::literal::Extractor::new();
    extractor.kind(regex_syntax::hir::literal::ExtractKind::Prefix);
    let prefixes = extractor.extract(&hir);
    // The guaranteed prefix is the *shortest* alternative. `abcdefgh|x` can
    // start with just `x`, so taking the longest would let a broad branch ride
    // on a selective one. An infinite set (`None`) or an empty set gives no
    // guarantee and counts as 0.
    let shortest_prefix = prefixes
        .literals()
        .and_then(|lits| lits.iter().map(|lit| lit.as_bytes().len()).min())
        .unwrap_or(0);
    if shortest_prefix > cfg.min_prefix_len {
        return true;
    }
    // NOTE(review): `minimum_len` counts any characters, not literal ones, so a
    // pattern like `.{20}` passes this check while matching most names.
    hir.properties()
        .minimum_len()
        .is_some_and(|min_len| min_len > cfg.min_literal_len)
}
/// Builds a [`PlannerCostGateError::QueryTooBroad`] for `predicate_shape`.
///
/// The limit and minimum prefix length come from `cfg` (an unset threshold is
/// reported as `0`), and the scope hint is [`SCOPE_FILTER_FIELDS`] joined with
/// `", "`.
fn reject(
    predicate_shape: String,
    node_count: usize,
    cfg: &PlannerCostGateConfig,
) -> PlannerCostGateError {
    let node_limit = cfg.node_count_threshold.unwrap_or_default();
    let scope_hint = SCOPE_FILTER_FIELDS.join(", ");
    let min_prefix_len = cfg.min_prefix_len;
    PlannerCostGateError::QueryTooBroad {
        predicate_shape,
        node_count,
        node_limit,
        scope_hint,
        min_prefix_len,
        doc_url: QUERY_TOO_BROAD_DOC_URL,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::ir::RegexPattern;
    use sqry_core::graph::unified::node::kind::NodeKind;

    /// Default gate configuration shared by every test.
    fn cfg() -> PlannerCostGateConfig {
        PlannerCostGateConfig::default()
    }

    fn plan(root: PlanNode) -> QueryPlan {
        QueryPlan::new(root)
    }

    /// Case-sensitive `Contains` name pattern over `raw`.
    fn contains(raw: &str) -> StringPattern {
        StringPattern {
            raw: raw.to_string(),
            mode: MatchMode::Contains,
            case_insensitive: false,
        }
    }

    /// `NodeScan` with no visibility filter.
    fn scan(kind: Option<NodeKind>, name_pattern: Option<StringPattern>) -> PlanNode {
        PlanNode::NodeScan {
            kind,
            visibility: None,
            name_pattern,
        }
    }

    #[test]
    fn planner_gate_rejects_bare_substring_below_threshold_arena_passes() {
        // Despite the name, this asserts the gate stays inactive below the cap.
        let query = plan(scan(None, Some(contains("foo"))));
        check_plan(&query, 1_000, &cfg()).expect("below cap must pass");
    }

    #[test]
    fn planner_gate_rejects_bare_substring_above_threshold() {
        let query = plan(scan(None, Some(contains("foo"))));
        let outcome = check_plan(&query, 1_000_000, &cfg());
        let err = outcome.expect_err("must reject");
        assert!(matches!(err, PlannerCostGateError::QueryTooBroad { .. }));
    }

    #[test]
    fn planner_gate_allows_kind_coupled_substring() {
        let query = plan(scan(Some(NodeKind::Function), Some(contains("foo"))));
        check_plan(&query, 1_000_000, &cfg()).expect("kind coupling must pass");
    }

    #[test]
    fn planner_gate_allows_long_substring_without_coupling() {
        let query = plan(scan(None, Some(contains("deserialize"))));
        check_plan(&query, 1_000_000, &cfg()).expect("long literal must pass");
    }

    #[test]
    fn planner_gate_rejects_prohibitive_regex_in_filter_without_coupling() {
        let broad_regex = PredicateValue::Regex(RegexPattern::new(".*foo.*"));
        let steps = vec![
            scan(None, None),
            PlanNode::Filter {
                predicate: Predicate::References(broad_regex),
            },
        ];
        let query = plan(PlanNode::Chain { steps });
        let err = check_plan(&query, 1_000_000, &cfg()).expect_err("must reject");
        assert!(matches!(err, PlannerCostGateError::QueryTooBroad { .. }));
    }

    #[test]
    fn planner_gate_passes_canonical_kind_only_query() {
        let query = plan(scan(Some(NodeKind::Function), None));
        check_plan(&query, 1_000_000_000, &cfg()).expect("canonical kind:function must pass");
    }
}