use sqry_core::graph::unified::edge::kind::EdgeKind;
use sqry_core::graph::unified::node::kind::NodeKind;
use sqry_core::graph::unified::string::StringId;
use sqry_core::schema::Visibility;
use thiserror::Error;
use super::ir::{
Direction, PathPattern, PlanNode, Predicate, PredicateValue, QueryPlan, StringPattern,
};
/// Errors reported by [`QueryBuilder::build`] when the accumulated steps do not form a
/// valid [`QueryPlan`].
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum BuildError {
    /// `build()` was called before any step was added.
    #[error("query builder is empty: at least one step is required before build()")]
    EmptyBuilder,
    /// The first step of a chain (the top-level chain built by `build()`, or any nested
    /// `PlanNode::Chain`) is not context-free according to `PlanNode::is_context_free`.
    #[error(
        "first step is not context-free: chains must start with NodeScan or SetOp, not {first_kind:?}"
    )]
    FirstStepNotContextFree {
        /// Kind of the offending first step.
        first_kind: PlanNodeKind,
    },
    /// An `EdgeTraversal` step anywhere in the plan tree has `max_depth == 0`.
    #[error("traversal step has max_depth = 0: must be >= 1 to produce any output")]
    ZeroDepth,
    /// A set-operation operand does not start with a context-free step.
    ///
    /// Despite the name, this is also returned for a subquery embedded in a filter
    /// predicate (`PredicateValue::Subquery`), which is checked the same way.
    #[error(
        "set-op operand is not a valid sub-plan: {reason} (the operand must itself build cleanly)"
    )]
    InvalidSetOpOperand {
        /// Human-readable detail naming the [`PlanNodeKind`] of the rejected root.
        reason: String,
    },
}
/// Payload-free discriminant of [`PlanNode`]: one variant per `PlanNode` variant.
///
/// Used in [`BuildError`] diagnostics (the `first_kind` field of
/// `FirstStepNotContextFree` and the `reason` text of `InvalidSetOpOperand`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlanNodeKind {
    /// Mirrors `PlanNode::NodeScan`.
    NodeScan,
    /// Mirrors `PlanNode::EdgeTraversal`.
    EdgeTraversal,
    /// Mirrors `PlanNode::Filter`.
    Filter,
    /// Mirrors `PlanNode::SetOp`.
    SetOp,
    /// Mirrors `PlanNode::Chain`.
    Chain,
}
impl PlanNodeKind {
    /// Maps a plan node to its payload-free kind.
    ///
    /// The match is exhaustive, so adding a `PlanNode` variant forces this mapping
    /// (and the enum above) to be extended.
    #[must_use]
    pub const fn of(node: &PlanNode) -> Self {
        match *node {
            PlanNode::Chain { .. } => Self::Chain,
            PlanNode::SetOp { .. } => Self::SetOp,
            PlanNode::Filter { .. } => Self::Filter,
            PlanNode::EdgeTraversal { .. } => Self::EdgeTraversal,
            PlanNode::NodeScan { .. } => Self::NodeScan,
        }
    }
}
/// Canonicalizes an [`EdgeKind`] so it can be stored as a traversal filter.
///
/// Per-edge payload is overwritten with fixed placeholder values (`0`, `false`, `None`
/// or [`StringId::INVALID`]). As a result, two values of the same variant that differ
/// only in that payload normalize to equal results.
///
/// Some fields are copied through unchanged by their arm:
/// - `Imports::is_wildcard`
/// - `Exports::kind`
/// - `TypeOf::context`
/// - `MacroExpansion::expansion_kind`
/// - `HttpRequest::method`
/// - `DbQuery::query_type`
/// - `TableWrite::operation`
/// - `MessageQueue::protocol`
///
/// Variants in the first arm are returned as-is. This includes `LifetimeConstraint` and
/// `FfiCall`, which keep whatever fields they carry.
///
/// [`QueryBuilder::traverse`] applies this to its `edge_kind` argument before storing it.
/// NOTE(review): presumably whatever consumes the stored kind compares it against edges
/// normalized the same way; keep both sides in sync when a variant gains a field.
#[must_use]
pub fn normalize_edge_kind(kind: EdgeKind) -> EdgeKind {
    match kind {
        // Returned unchanged (either field-less, or all fields deliberately kept).
        EdgeKind::Defines
        | EdgeKind::Contains
        | EdgeKind::References
        | EdgeKind::Inherits
        | EdgeKind::Implements
        | EdgeKind::WebAssemblyCall
        | EdgeKind::GenericBound
        | EdgeKind::AnnotatedWith
        | EdgeKind::AnnotationParam
        | EdgeKind::LambdaCaptures
        | EdgeKind::ModuleExports
        | EdgeKind::ModuleRequires
        | EdgeKind::ModuleOpens
        | EdgeKind::ModuleProvides
        | EdgeKind::TypeArgument
        | EdgeKind::ExtensionReceiver
        | EdgeKind::CompanionOf
        | EdgeKind::SealedPermit
        | EdgeKind::LifetimeConstraint { .. }
        | EdgeKind::FfiCall { .. } => kind,
        // All call metadata is reset.
        EdgeKind::Calls { .. } => EdgeKind::Calls {
            argument_count: 0,
            is_async: false,
        },
        // Variants below keep the explicitly bound field(s) and reset the rest.
        EdgeKind::Imports { is_wildcard, .. } => EdgeKind::Imports {
            alias: None,
            is_wildcard,
        },
        EdgeKind::Exports { kind, .. } => EdgeKind::Exports { kind, alias: None },
        EdgeKind::TypeOf { context, .. } => EdgeKind::TypeOf {
            context,
            index: None,
            name: None,
        },
        EdgeKind::MacroExpansion { expansion_kind, .. } => EdgeKind::MacroExpansion {
            expansion_kind,
            is_verified: false,
        },
        EdgeKind::HttpRequest { method, .. } => EdgeKind::HttpRequest { method, url: None },
        EdgeKind::DbQuery { query_type, .. } => EdgeKind::DbQuery {
            query_type,
            table: None,
        },
        EdgeKind::TableWrite { operation, .. } => EdgeKind::TableWrite {
            table_name: StringId::INVALID,
            schema: None,
            operation,
        },
        EdgeKind::MessageQueue { protocol, .. } => EdgeKind::MessageQueue {
            protocol,
            topic: None,
        },
        // Variants below are reset entirely; only the variant itself survives.
        EdgeKind::TraitMethodBinding { .. } => EdgeKind::TraitMethodBinding {
            trait_name: StringId::INVALID,
            impl_type: StringId::INVALID,
            is_ambiguous: false,
        },
        EdgeKind::GrpcCall { .. } => EdgeKind::GrpcCall {
            service: StringId::INVALID,
            method: StringId::INVALID,
        },
        EdgeKind::TableRead { .. } => EdgeKind::TableRead {
            table_name: StringId::INVALID,
            schema: None,
        },
        EdgeKind::TriggeredBy { .. } => EdgeKind::TriggeredBy {
            trigger_name: StringId::INVALID,
            schema: None,
        },
        EdgeKind::WebSocket { .. } => EdgeKind::WebSocket { event: None },
        EdgeKind::GraphQLOperation { .. } => EdgeKind::GraphQLOperation {
            operation: StringId::INVALID,
        },
        EdgeKind::ProcessExec { .. } => EdgeKind::ProcessExec {
            command: StringId::INVALID,
        },
        EdgeKind::FileIpc { .. } => EdgeKind::FileIpc { path_pattern: None },
        EdgeKind::ProtocolCall { .. } => EdgeKind::ProtocolCall {
            protocol: StringId::INVALID,
            metadata: None,
        },
    }
}
/// Optional constraints for a `NodeScan` step, consumed by [`QueryBuilder::scan_with`].
///
/// Each field is copied verbatim into the matching field of `PlanNode::NodeScan`.
/// The default (and [`ScanFilters::new`]) has every field `None`, the same shape that
/// [`QueryBuilder::scan_all`] pushes.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ScanFilters {
    /// Node kind to scan for, if restricted.
    pub kind: Option<NodeKind>,
    /// Visibility to scan for, if restricted.
    pub visibility: Option<Visibility>,
    /// Name pattern to scan for, if restricted.
    pub name_pattern: Option<StringPattern>,
}
impl ScanFilters {
    /// Returns a filter set with no constraints (every field `None`).
    ///
    /// Same value as `ScanFilters::default()`, but callable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            name_pattern: None,
            visibility: None,
            kind: None,
        }
    }

    /// Returns `self` with the kind constraint set to `kind`, replacing any previous one.
    #[must_use]
    pub fn with_kind(self, kind: NodeKind) -> Self {
        Self {
            kind: Some(kind),
            ..self
        }
    }

    /// Returns `self` with the visibility constraint set to `visibility`, replacing any
    /// previous one.
    #[must_use]
    pub fn with_visibility(self, visibility: Visibility) -> Self {
        Self {
            visibility: Some(visibility),
            ..self
        }
    }

    /// Returns `self` with the name-pattern constraint set to `pattern`, replacing any
    /// previous one.
    #[must_use]
    pub fn with_name_pattern(self, pattern: StringPattern) -> Self {
        Self {
            name_pattern: Some(pattern),
            ..self
        }
    }
}
/// Fluent builder that accumulates plan steps and validates them into a [`QueryPlan`].
///
/// Steps are stored in insertion order. [`QueryBuilder::build`] validates them and wraps
/// them in a single `PlanNode::Chain` root.
#[derive(Debug, Clone, Default)]
pub struct QueryBuilder {
    /// Top-level steps in the order they were added. A set operation collapses all
    /// existing steps into one `PlanNode::SetOp` entry.
    steps: Vec<PlanNode>,
}
impl QueryBuilder {
#[must_use]
pub const fn new() -> Self {
Self { steps: Vec::new() }
}
#[must_use]
pub fn scan(mut self, kind: NodeKind) -> Self {
self.steps.push(PlanNode::NodeScan {
kind: Some(kind),
visibility: None,
name_pattern: None,
});
self
}
#[must_use]
pub fn scan_with(mut self, filters: ScanFilters) -> Self {
let ScanFilters {
kind,
visibility,
name_pattern,
} = filters;
self.steps.push(PlanNode::NodeScan {
kind,
visibility,
name_pattern,
});
self
}
#[must_use]
pub fn scan_all(mut self) -> Self {
self.steps.push(PlanNode::NodeScan {
kind: None,
visibility: None,
name_pattern: None,
});
self
}
#[must_use]
pub fn filter(mut self, predicate: Predicate) -> Self {
self.steps.push(PlanNode::Filter { predicate });
self
}
#[must_use]
pub fn filter_name(self, pattern: StringPattern) -> Self {
self.filter(Predicate::MatchesName(pattern))
}
#[must_use]
pub fn filter_in_file(self, path: impl Into<PathPattern>) -> Self {
self.filter(Predicate::InFile(path.into()))
}
#[must_use]
pub fn traverse(mut self, direction: Direction, edge_kind: EdgeKind, max_depth: u32) -> Self {
self.steps.push(PlanNode::EdgeTraversal {
direction,
edge_kind: Some(normalize_edge_kind(edge_kind)),
max_depth,
});
self
}
#[must_use]
pub fn traverse_any(mut self, direction: Direction, max_depth: u32) -> Self {
self.steps.push(PlanNode::EdgeTraversal {
direction,
edge_kind: None,
max_depth,
});
self
}
#[must_use]
pub fn union(self, other: QueryPlan) -> Self {
self.combine(super::ir::SetOperation::Union, other)
}
#[must_use]
pub fn intersect(self, other: QueryPlan) -> Self {
self.combine(super::ir::SetOperation::Intersect, other)
}
#[must_use]
pub fn difference(self, other: QueryPlan) -> Self {
self.combine(super::ir::SetOperation::Difference, other)
}
fn combine(mut self, op: super::ir::SetOperation, other: QueryPlan) -> Self {
let right = Box::new(other.root);
if self.steps.is_empty() {
adopt_into_steps(&mut self.steps, *right);
return self;
}
let left: Box<PlanNode> = if self.steps.len() == 1 {
Box::new(self.steps.pop().expect("len == 1"))
} else {
Box::new(PlanNode::Chain {
steps: std::mem::take(&mut self.steps),
})
};
self.steps.push(PlanNode::SetOp { op, left, right });
self
}
#[inline]
#[must_use]
pub fn step_count(&self) -> usize {
self.steps.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.steps.is_empty()
}
pub fn build(self) -> Result<QueryPlan, BuildError> {
if self.steps.is_empty() {
return Err(BuildError::EmptyBuilder);
}
let first = &self.steps[0];
if !first.is_context_free() {
return Err(BuildError::FirstStepNotContextFree {
first_kind: PlanNodeKind::of(first),
});
}
for step in &self.steps {
validate_subtree(step)?;
}
Ok(QueryPlan::new(PlanNode::Chain { steps: self.steps }))
}
}
/// Appends `root` to `steps`, splicing in the children of a `Chain` root rather than
/// nesting the chain itself.
fn adopt_into_steps(steps: &mut Vec<PlanNode>, root: PlanNode) {
    if let PlanNode::Chain { steps: children } = root {
        steps.extend(children);
    } else {
        steps.push(root);
    }
}
/// Recursively validates `node` and everything beneath it.
///
/// # Errors
///
/// Returns the first problem found, as follows:
/// - an `EdgeTraversal` with `max_depth == 0` gives [`BuildError::ZeroDepth`];
/// - a `Filter` returns whatever its predicate's validation returns;
/// - a `SetOp` operand that is not context-free gives [`BuildError::InvalidSetOpOperand`]
///   (both operands are checked this way before either is validated recursively);
/// - a non-empty `Chain` whose first step is not context-free gives
///   [`BuildError::FirstStepNotContextFree`].
///
/// An empty `Chain` is accepted.
fn validate_subtree(node: &PlanNode) -> Result<(), BuildError> {
    match node {
        PlanNode::NodeScan { .. } => Ok(()),
        PlanNode::EdgeTraversal { max_depth: 0, .. } => Err(BuildError::ZeroDepth),
        PlanNode::EdgeTraversal { .. } => Ok(()),
        PlanNode::Filter { predicate } => validate_predicate(predicate),
        PlanNode::SetOp { left, right, .. } => {
            for operand in [left, right] {
                ensure_context_free(operand)?;
            }
            for operand in [left, right] {
                validate_subtree(operand)?;
            }
            Ok(())
        }
        PlanNode::Chain { steps } => {
            if let Some(head) = steps.first().filter(|head| !head.is_context_free()) {
                return Err(BuildError::FirstStepNotContextFree {
                    first_kind: PlanNodeKind::of(head),
                });
            }
            steps.iter().try_for_each(validate_subtree)
        }
    }
}
/// Validates every subquery reachable from `predicate`, recursing through `And`, `Or`
/// and `Not`.
///
/// # Errors
///
/// Propagates the first error from [`validate_predicate_value`] on any value-carrying
/// predicate. Leaf predicates without a value always pass.
fn validate_predicate(predicate: &Predicate) -> Result<(), BuildError> {
    match predicate {
        // Boolean combinators: each operand must be valid on its own.
        Predicate::Not(operand) => validate_predicate(operand),
        Predicate::And(operands) | Predicate::Or(operands) => {
            for operand in operands {
                validate_predicate(operand)?;
            }
            Ok(())
        }
        // Value-carrying predicates: delegate to the value check.
        Predicate::Callers(value)
        | Predicate::Callees(value)
        | Predicate::Imports(value)
        | Predicate::Exports(value)
        | Predicate::References(value)
        | Predicate::Implements(value) => validate_predicate_value(value),
        // Leaves: nothing nested to check.
        Predicate::HasCaller
        | Predicate::HasCallee
        | Predicate::IsUnused
        | Predicate::InFile(_)
        | Predicate::InScope(_)
        | Predicate::MatchesName(_) => Ok(()),
    }
}
/// Validates a predicate value. Only `Subquery` values carry a nested plan.
///
/// # Errors
///
/// For a subquery, returns [`BuildError::InvalidSetOpOperand`] if its root is not
/// context-free. Otherwise returns the first error from validating its subtree.
fn validate_predicate_value(value: &PredicateValue) -> Result<(), BuildError> {
    match value {
        PredicateValue::Subquery(plan) => {
            ensure_context_free(plan).and_then(|()| validate_subtree(plan))
        }
        PredicateValue::Pattern(_) | PredicateValue::Regex(_) => Ok(()),
    }
}
/// Accepts `node` as the root of an independent sub-plan (a set-op operand or a
/// predicate subquery).
///
/// A root qualifies if it is itself context-free, or if it is a `Chain` whose first step
/// is context-free.
///
/// # Errors
///
/// Returns [`BuildError::InvalidSetOpOperand`] naming the root's [`PlanNodeKind`]
/// otherwise. This includes an empty `Chain`, unless `Chain` itself is context-free.
fn ensure_context_free(node: &PlanNode) -> Result<(), BuildError> {
    let accepted = node.is_context_free()
        || matches!(
            node,
            PlanNode::Chain { steps } if steps.first().is_some_and(|head| head.is_context_free())
        );
    if accepted {
        Ok(())
    } else {
        let root_kind = PlanNodeKind::of(node);
        Err(BuildError::InvalidSetOpOperand {
            reason: format!("operand root is {root_kind:?}"),
        })
    }
}
/// Conversions from a [`QueryPlan`] into a `PredicateValue::Subquery`. Such a value can
/// be supplied to value-carrying predicates such as `Predicate::Callers`.
pub trait QueryPlanExt {
    /// Consumes `self` and wraps its plan root in `PredicateValue::Subquery`.
    fn into_subquery(self) -> PredicateValue;
    /// Wraps a clone of `self`'s plan root in `PredicateValue::Subquery`.
    ///
    /// NOTE(review): despite the `as_` prefix, this allocates a deep copy of the root.
    fn as_subquery(&self) -> PredicateValue;
}
impl QueryPlanExt for QueryPlan {
    /// Moves the root out of the plan and boxes it as a subquery value.
    fn into_subquery(self) -> PredicateValue {
        let root = Box::new(self.root);
        PredicateValue::Subquery(root)
    }

    /// Clones the root and boxes the copy as a subquery value; `self` is left intact.
    fn as_subquery(&self) -> PredicateValue {
        let root = self.root.clone();
        PredicateValue::Subquery(Box::new(root))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_builder_reports_empty_error() {
        let err = QueryBuilder::new().build().unwrap_err();
        assert_eq!(err, BuildError::EmptyBuilder);
    }

    #[test]
    fn first_step_filter_is_rejected() {
        let mut b = QueryBuilder::new();
        b.steps.push(PlanNode::Filter {
            predicate: Predicate::HasCaller,
        });
        let err = b.build().unwrap_err();
        assert!(matches!(
            err,
            BuildError::FirstStepNotContextFree {
                first_kind: PlanNodeKind::Filter
            }
        ));
    }

    #[test]
    fn first_step_traversal_is_rejected() {
        let mut b = QueryBuilder::new();
        b.steps.push(PlanNode::EdgeTraversal {
            direction: Direction::Forward,
            edge_kind: None,
            max_depth: 1,
        });
        let err = b.build().unwrap_err();
        assert!(matches!(
            err,
            BuildError::FirstStepNotContextFree {
                first_kind: PlanNodeKind::EdgeTraversal
            }
        ));
    }

    #[test]
    fn zero_depth_is_rejected() {
        let err = QueryBuilder::new()
            .scan(NodeKind::Function)
            .traverse_any(Direction::Forward, 0)
            .build()
            .unwrap_err();
        assert_eq!(err, BuildError::ZeroDepth);
    }

    #[test]
    fn scan_then_filter_builds() {
        let plan = QueryBuilder::new()
            .scan(NodeKind::Function)
            .filter(Predicate::HasCaller)
            .build()
            .expect("plan");
        let PlanNode::Chain { steps } = &plan.root else {
            panic!("expected Chain root");
        };
        assert_eq!(steps.len(), 2);
        assert!(matches!(steps[0], PlanNode::NodeScan { .. }));
        assert!(matches!(steps[1], PlanNode::Filter { .. }));
    }

    #[test]
    fn step_count_and_is_empty_track_state() {
        let b = QueryBuilder::new();
        assert!(b.is_empty());
        assert_eq!(b.step_count(), 0);
        let b = b.scan(NodeKind::Function).filter(Predicate::HasCaller);
        assert!(!b.is_empty());
        assert_eq!(b.step_count(), 2);
    }

    #[test]
    fn normalize_edge_kind_zeroes_calls_metadata() {
        let a = normalize_edge_kind(EdgeKind::Calls {
            argument_count: 7,
            is_async: true,
        });
        let b = normalize_edge_kind(EdgeKind::Calls {
            argument_count: 0,
            is_async: false,
        });
        assert_eq!(a, b);
    }

    #[test]
    fn normalize_edge_kind_passes_through_metadata_free_variants() {
        let cases = [
            EdgeKind::Defines,
            EdgeKind::Contains,
            EdgeKind::References,
            EdgeKind::Inherits,
            EdgeKind::Implements,
            EdgeKind::WebAssemblyCall,
        ];
        for c in cases {
            assert_eq!(normalize_edge_kind(c.clone()), c);
        }
    }

    #[test]
    fn plan_node_kind_of_covers_all_variants() {
        assert_eq!(
            PlanNodeKind::of(&PlanNode::NodeScan {
                kind: None,
                visibility: None,
                name_pattern: None
            }),
            PlanNodeKind::NodeScan
        );
        assert_eq!(
            PlanNodeKind::of(&PlanNode::EdgeTraversal {
                direction: Direction::Forward,
                edge_kind: None,
                max_depth: 1
            }),
            PlanNodeKind::EdgeTraversal
        );
        assert_eq!(
            PlanNodeKind::of(&PlanNode::Filter {
                predicate: Predicate::HasCaller
            }),
            PlanNodeKind::Filter
        );
        let scan = PlanNode::NodeScan {
            kind: None,
            visibility: None,
            name_pattern: None,
        };
        assert_eq!(
            PlanNodeKind::of(&PlanNode::SetOp {
                op: super::super::ir::SetOperation::Union,
                left: Box::new(scan.clone()),
                right: Box::new(scan)
            }),
            PlanNodeKind::SetOp
        );
        assert_eq!(
            PlanNodeKind::of(&PlanNode::Chain { steps: vec![] }),
            PlanNodeKind::Chain
        );
    }

    #[test]
    fn into_subquery_consumes_plan() {
        let plan = QueryBuilder::new()
            .scan(NodeKind::Function)
            .build()
            .expect("plan");
        let value = plan.into_subquery();
        assert!(value.is_subquery());
    }

    // --- Coverage for `traverse` normalization ---

    #[test]
    fn traverse_stores_normalized_edge_kind() {
        let plan = QueryBuilder::new()
            .scan(NodeKind::Function)
            .traverse(
                Direction::Forward,
                EdgeKind::Calls {
                    argument_count: 3,
                    is_async: true,
                },
                2,
            )
            .build()
            .expect("plan");
        let PlanNode::Chain { steps } = &plan.root else {
            panic!("expected Chain root");
        };
        let PlanNode::EdgeTraversal { edge_kind, .. } = &steps[1] else {
            panic!("expected EdgeTraversal step");
        };
        assert_eq!(
            edge_kind,
            &Some(EdgeKind::Calls {
                argument_count: 0,
                is_async: false,
            })
        );
    }

    // --- Coverage for set operations (`combine` / `adopt_into_steps`) ---

    #[test]
    fn union_on_empty_builder_adopts_other_steps() {
        let other = QueryBuilder::new()
            .scan(NodeKind::Function)
            .filter(Predicate::HasCaller)
            .build()
            .expect("plan");
        let b = QueryBuilder::new().union(other);
        // The other plan's root Chain is flattened into the builder, not nested.
        assert_eq!(b.step_count(), 2);
    }

    #[test]
    fn single_step_left_operand_is_not_wrapped() {
        let other = QueryBuilder::new().scan_all().build().expect("plan");
        let b = QueryBuilder::new().scan(NodeKind::Function).intersect(other);
        assert_eq!(b.step_count(), 1);
        let plan = b.build().expect("plan");
        let PlanNode::Chain { steps } = &plan.root else {
            panic!("expected Chain root");
        };
        let PlanNode::SetOp { left, .. } = &steps[0] else {
            panic!("expected SetOp step");
        };
        assert!(matches!(&**left, PlanNode::NodeScan { .. }));
    }

    #[test]
    fn union_wraps_multi_step_left_in_chain() {
        let other = QueryBuilder::new().scan_all().build().expect("plan");
        let plan = QueryBuilder::new()
            .scan(NodeKind::Function)
            .filter(Predicate::HasCaller)
            .union(other)
            .build()
            .expect("plan");
        let PlanNode::Chain { steps } = &plan.root else {
            panic!("expected Chain root");
        };
        assert_eq!(steps.len(), 1);
        let PlanNode::SetOp { left, right, .. } = &steps[0] else {
            panic!("expected SetOp step");
        };
        assert!(matches!(&**left, PlanNode::Chain { steps } if steps.len() == 2));
        assert!(matches!(&**right, PlanNode::Chain { .. }));
    }

    // --- Coverage for subquery validation (`validate_predicate_value`) ---

    #[test]
    fn filter_rooted_subquery_is_rejected() {
        let sub = QueryPlan::new(PlanNode::Filter {
            predicate: Predicate::HasCaller,
        });
        let err = QueryBuilder::new()
            .scan(NodeKind::Function)
            .filter(Predicate::Callers(sub.into_subquery()))
            .build()
            .unwrap_err();
        assert!(matches!(err, BuildError::InvalidSetOpOperand { .. }));
    }

    #[test]
    fn zero_depth_inside_subquery_is_rejected() {
        let sub = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                PlanNode::NodeScan {
                    kind: None,
                    visibility: None,
                    name_pattern: None,
                },
                PlanNode::EdgeTraversal {
                    direction: Direction::Forward,
                    edge_kind: None,
                    max_depth: 0,
                },
            ],
        });
        let err = QueryBuilder::new()
            .scan(NodeKind::Function)
            .filter(Predicate::Callers(sub.into_subquery()))
            .build()
            .unwrap_err();
        assert_eq!(err, BuildError::ZeroDepth);
    }
}