1use std::collections::{HashMap, HashSet, VecDeque};
86
87use serde::{Deserialize, Serialize};
88
89use super::ir::{PlanNode, Predicate, PredicateValue, QueryPlan};
90
/// The result of fusing a batch of [`QueryPlan`]s with [`fuse_plans`].
///
/// Plans whose leading operator is identical are bucketed into one
/// [`FusionGroup`] so the shared prefix can be evaluated once. Predicate
/// subqueries are deduplicated and fused recursively into `subquery_batch`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FusedPlanBatch {
    /// One group per distinct prefix, in order of first appearance.
    groups: Vec<FusionGroup>,
    /// Fused batch of the distinct subquery plans, or `None` if the inputs
    /// contained no subqueries.
    subquery_batch: Option<Box<FusedPlanBatch>>,
    /// Summary statistics for this batch.
    stats: FusionStats,
    /// Subtrees occurring more than once that were promoted to shared nodes.
    /// `#[serde(default)]` makes a missing field deserialize as empty
    /// (presumably for data serialized before this field existed — confirm).
    #[serde(default)]
    shared_nodes: Vec<SharedNode>,
}
126
127impl FusedPlanBatch {
128 #[inline]
130 #[must_use]
131 pub fn groups(&self) -> &[FusionGroup] {
132 &self.groups
133 }
134
135 #[inline]
138 #[must_use]
139 pub fn subquery_batch(&self) -> Option<&FusedPlanBatch> {
140 self.subquery_batch.as_deref()
141 }
142
143 #[inline]
145 #[must_use]
146 pub fn stats(&self) -> &FusionStats {
147 &self.stats
148 }
149
150 #[inline]
152 #[must_use]
153 pub fn shared_nodes(&self) -> &[SharedNode] {
154 &self.shared_nodes
155 }
156
157 #[inline]
159 #[must_use]
160 pub fn len(&self) -> usize {
161 self.groups.len()
162 }
163
164 #[inline]
167 #[must_use]
168 pub fn is_empty(&self) -> bool {
169 self.groups.is_empty()
170 }
171
172 pub fn iter_groups(&self) -> impl Iterator<Item = &FusionGroup> {
174 self.groups.iter()
175 }
176
177 #[must_use]
184 pub fn input_plans(&self) -> Vec<QueryPlan> {
185 let total_tails: usize = self.groups.iter().map(|g| g.tails.len()).sum();
187 let mut out: Vec<Option<QueryPlan>> = (0..total_tails).map(|_| None).collect();
188
189 for group in &self.groups {
190 for tail in &group.tails {
191 let plan = tail.reconstruct(&group.prefix);
192 let idx = tail.original_index;
193 debug_assert!(idx < out.len(), "original_index out of bounds");
194 out[idx] = Some(plan);
195 }
196 }
197
198 out.into_iter()
199 .map(|p| p.expect("every original index must be filled"))
200 .collect()
201 }
202
203 #[must_use]
206 pub fn total_plans(&self) -> usize {
207 self.groups.iter().map(|g| g.tails.len()).sum()
208 }
209}
210
/// Identifier of a [`SharedNode`]. `fuse_plans` assigns it as the node's
/// position in the batch's shared-node list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SharedNodeId(u32);
214
215impl SharedNodeId {
216 #[inline]
217 #[must_use]
218 pub const fn get(self) -> u32 {
219 self.0
220 }
221}
222
/// A plan subtree that occurs at least twice across a fused batch (and is not
/// itself a group prefix), promoted so it can be evaluated once.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SharedNode {
    /// The shared subtree.
    canonical_plan: PlanNode,
    /// Deduplicated indices of the input plans containing the subtree.
    consumers: Vec<usize>,
    /// Identifier of this node within its batch.
    id: SharedNodeId,
}
231
232impl SharedNode {
233 #[inline]
235 #[must_use]
236 pub fn canonical_plan(&self) -> &PlanNode {
237 &self.canonical_plan
238 }
239
240 #[inline]
242 #[must_use]
243 pub fn consumers(&self) -> &[usize] {
244 &self.consumers
245 }
246
247 #[inline]
249 #[must_use]
250 pub const fn id(&self) -> SharedNodeId {
251 self.id
252 }
253}
254
/// A set of input plans that share an identical prefix operator.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FusionGroup {
    /// The operator shared by every plan in the group.
    prefix: PlanNode,
    /// Per-plan remainders, in input order, each tagged with its original index.
    tails: Vec<FusedTail>,
}
277
278impl FusionGroup {
279 #[inline]
281 #[must_use]
282 pub fn prefix(&self) -> &PlanNode {
283 &self.prefix
284 }
285
286 #[inline]
289 #[must_use]
290 pub fn tails(&self) -> &[FusedTail] {
291 &self.tails
292 }
293
294 #[inline]
296 #[must_use]
297 pub fn tail_count(&self) -> usize {
298 self.tails.len()
299 }
300}
301
/// One input plan's remainder after its shared prefix has been factored out.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FusedTail {
    /// Position of the plan in the batch passed to `fuse_plans`.
    pub original_index: usize,
    /// What to append to the group prefix to rebuild the plan.
    pub tail: FusionTail,
}
316
317impl FusedTail {
318 #[must_use]
322 pub fn reconstruct(&self, prefix: &PlanNode) -> QueryPlan {
323 let root = match &self.tail {
324 FusionTail::Identity => prefix.clone(),
325 FusionTail::ChainContinuation { remaining_steps } => {
326 let mut steps = Vec::with_capacity(remaining_steps.len() + 1);
327 steps.push(prefix.clone());
328 steps.extend(remaining_steps.iter().cloned());
329 PlanNode::Chain { steps }
330 }
331 };
332 QueryPlan::new(root)
333 }
334}
335
/// What to append to a group's shared prefix to rebuild one input plan.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionTail {
    /// The input plan is exactly the group prefix.
    Identity,
    /// The input plan is a `Chain` whose first step is the group prefix,
    /// followed by `remaining_steps`.
    ChainContinuation {
        /// The chain steps after the prefix, in their original order.
        remaining_steps: Vec<PlanNode>,
    },
}
368
/// Summary statistics produced by [`fuse_plans`].
///
/// `PartialEq`/`Eq` are implemented by hand because `plan_tree_reduction_ratio`
/// is an `f64`; that field is compared by bit pattern.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct FusionStats {
    /// Number of input plans.
    pub total_plans: usize,
    /// Number of distinct prefixes (fusion groups) the inputs were bucketed into.
    pub fusion_groups: usize,
    /// `total_plans - fusion_groups` (saturating): prefix evaluations avoided
    /// by sharing identical prefixes.
    pub scans_eliminated: usize,
    /// Number of predicate subquery occurrences. Duplicates are counted, and
    /// so are subqueries nested inside other subqueries.
    pub subqueries_total: usize,
    /// Number of distinct subquery plans (the size of the fused subquery batch).
    pub subqueries_unique: usize,
    /// Number of subtrees promoted to shared nodes.
    pub shared_nodes_promoted: usize,
    /// Estimated fraction, in `[0, 1]`, of operators saved by evaluating each
    /// shared node once. Serialized via `f64_bits`: a float in human-readable
    /// formats, and the raw bit pattern otherwise.
    #[serde(with = "f64_bits")]
    pub plan_tree_reduction_ratio: f64,
}
396
397impl Default for FusionStats {
398 fn default() -> Self {
399 Self {
400 total_plans: 0,
401 fusion_groups: 0,
402 scans_eliminated: 0,
403 subqueries_total: 0,
404 subqueries_unique: 0,
405 shared_nodes_promoted: 0,
406 plan_tree_reduction_ratio: 0.0,
407 }
408 }
409}
410
411impl PartialEq for FusionStats {
412 fn eq(&self, other: &Self) -> bool {
413 self.total_plans == other.total_plans
414 && self.fusion_groups == other.fusion_groups
415 && self.scans_eliminated == other.scans_eliminated
416 && self.subqueries_total == other.subqueries_total
417 && self.subqueries_unique == other.subqueries_unique
418 && self.shared_nodes_promoted == other.shared_nodes_promoted
419 && self.plan_tree_reduction_ratio.to_bits() == other.plan_tree_reduction_ratio.to_bits()
420 }
421}
422
// Sound because `PartialEq` compares the `f64` field by bit pattern, which is
// reflexive (a NaN ratio equals itself bitwise).
impl Eq for FusionStats {}
424
425mod f64_bits {
426 use serde::{Deserialize, Deserializer, Serializer};
427
428 pub fn serialize<S>(value: &f64, serializer: S) -> Result<S::Ok, S::Error>
429 where
430 S: Serializer,
431 {
432 if serializer.is_human_readable() {
433 serializer.serialize_f64(*value)
434 } else {
435 serializer.serialize_u64(value.to_bits())
436 }
437 }
438
439 pub fn deserialize<'de, D>(deserializer: D) -> Result<f64, D::Error>
440 where
441 D: Deserializer<'de>,
442 {
443 if deserializer.is_human_readable() {
444 f64::deserialize(deserializer)
445 } else {
446 u64::deserialize(deserializer).map(f64::from_bits)
447 }
448 }
449}
450
451impl FusionStats {
452 #[inline]
455 #[must_use]
456 pub fn subqueries_eliminated(&self) -> usize {
457 self.subqueries_total.saturating_sub(self.subqueries_unique)
458 }
459}
460
461#[must_use]
482pub fn fuse_plans(plans: Vec<QueryPlan>) -> FusedPlanBatch {
483 let original_plans = plans.clone();
484 let total_plans = plans.len();
485
486 let (subqueries_total, subquery_plans) = collect_subquery_plans(&plans);
491 let subquery_batch = if subquery_plans.is_empty() {
492 None
493 } else {
494 Some(Box::new(fuse_plans(subquery_plans)))
495 };
496
497 let mut groups: Vec<FusionGroup> = Vec::new();
505 let mut prefix_index: HashMap<PlanNode, usize> = HashMap::new();
506
507 for (original_index, plan) in plans.into_iter().enumerate() {
508 let (prefix, tail) = split_prefix_and_tail(plan.root);
509
510 let fused_tail = FusedTail {
511 original_index,
512 tail,
513 };
514
515 if let Some(&idx) = prefix_index.get(&prefix) {
517 groups[idx].tails.push(fused_tail);
518 } else {
519 let group = FusionGroup {
520 prefix: prefix.clone(),
521 tails: vec![fused_tail],
522 };
523 prefix_index.insert(prefix, groups.len());
524 groups.push(group);
525 }
526 }
527
528 let fusion_groups = groups.len();
529 let scans_eliminated = total_plans.saturating_sub(fusion_groups);
530 let promoted_shared_candidates = collect_promoted_candidates(&original_plans, &groups);
531 let promoted_shared_nodes = materialize_shared_nodes(&promoted_shared_candidates);
532 let shared_nodes_promoted = promoted_shared_nodes.len();
533
534 let subqueries_unique = subquery_batch
535 .as_deref()
536 .map_or(0, FusedPlanBatch::total_plans);
537 let plan_tree_reduction_ratio =
538 estimate_plan_tree_reduction_ratio(&original_plans, &promoted_shared_candidates);
539
540 let stats = FusionStats {
541 total_plans,
542 fusion_groups,
543 scans_eliminated,
544 subqueries_total,
545 subqueries_unique,
546 shared_nodes_promoted,
547 plan_tree_reduction_ratio,
548 };
549
550 FusedPlanBatch {
551 groups,
552 subquery_batch,
553 stats,
554 shared_nodes: promoted_shared_nodes,
555 }
556}
557
558#[must_use]
564pub fn fuse_single(plan: QueryPlan) -> FusedPlanBatch {
565 fuse_plans(vec![plan])
566}
567
568fn split_prefix_and_tail(root: PlanNode) -> (PlanNode, FusionTail) {
589 if let PlanNode::Chain { mut steps } = root {
590 match steps.len() {
591 0 => {
592 (PlanNode::Chain { steps }, FusionTail::Identity)
595 }
596 1 => {
597 let only = steps.pop().expect("len == 1");
599 if only.is_context_free() {
600 (only, FusionTail::Identity)
601 } else {
602 (PlanNode::Chain { steps: vec![only] }, FusionTail::Identity)
606 }
607 }
608 _ => {
609 let first = steps.remove(0);
610 if first.is_context_free() {
611 (
612 first,
613 FusionTail::ChainContinuation {
614 remaining_steps: steps,
615 },
616 )
617 } else {
618 let mut original = Vec::with_capacity(steps.len() + 1);
620 original.push(first);
621 original.extend(steps);
622 (PlanNode::Chain { steps: original }, FusionTail::Identity)
623 }
624 }
625 }
626 } else {
627 (root, FusionTail::Identity)
631 }
632}
633
/// One step of a [`SubtreePath`] locating a subtree inside a plan.
///
/// NOTE: variant declaration order defines the derived `Ord`, which is used
/// when candidate positions are sorted. Reordering variants changes that order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum PathSegment {
    /// The `n`-th step of a `Chain`.
    ChainStep(usize),
    /// The synthetic chain made of the first `n` steps of a `Chain`.
    ChainPrefix(usize),
    /// Left operand of a `SetOp`.
    SetLeft,
    /// Right operand of a `SetOp`.
    SetRight,
    /// The plan inside a `PredicateValue::Subquery`.
    PredicateValueSubquery,
    /// The `n`-th operand of `Predicate::And`.
    PredicateAnd(usize),
    /// The `n`-th operand of `Predicate::Or`.
    PredicateOr(usize),
    /// The operand of `Predicate::Not`.
    PredicateNot,
}
645
/// Root-to-node sequence of [`PathSegment`]s. The empty path is the plan root.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
struct SubtreePath(Vec<PathSegment>);
648
649impl SubtreePath {
650 fn push(&mut self, segment: PathSegment) {
651 self.0.push(segment);
652 }
653
654 fn pop(&mut self) {
655 self.0.pop();
656 }
657}
658
/// Accumulates, for every independently executable subtree seen, each place it
/// occurs, as `(input plan index, path within that plan)`.
#[derive(Debug, Default)]
struct SharedSubtreeCollector {
    candidates: HashMap<PlanNode, Vec<(usize, SubtreePath)>>,
}
663
664impl SharedSubtreeCollector {
665 fn register(&mut self, plan_index: usize, path: &SubtreePath, plan: PlanNode) {
666 self.candidates
667 .entry(plan)
668 .or_default()
669 .push((plan_index, path.clone()));
670 }
671}
672
/// A subtree selected for promotion to a shared node.
#[derive(Debug, Clone)]
struct PromotedCandidate {
    /// The repeated subtree.
    canonical_plan: PlanNode,
    /// Every occurrence, as `(input plan index, path)`, sorted ascending.
    positions: Vec<(usize, SubtreePath)>,
}
678
679fn collect_promoted_candidates(
680 plans: &[QueryPlan],
681 groups: &[FusionGroup],
682) -> Vec<PromotedCandidate> {
683 let mut collector = SharedSubtreeCollector::default();
684
685 for (plan_index, plan) in plans.iter().enumerate() {
686 let mut path = SubtreePath::default();
687 walk_plan_for_shared_subtrees(&plan.root, plan_index, &mut path, &mut collector);
688 }
689
690 let promoted = promote_candidates(collector.candidates, groups);
691 sort_promoted_candidates_by_containment(promoted)
692}
693
694fn materialize_shared_nodes(candidates: &[PromotedCandidate]) -> Vec<SharedNode> {
695 candidates
696 .iter()
697 .enumerate()
698 .map(|(index, candidate)| SharedNode {
699 canonical_plan: candidate.canonical_plan.clone(),
700 consumers: candidate_consumers(&candidate.positions),
701 id: SharedNodeId(index as u32),
702 })
703 .collect()
704}
705
706fn is_independently_executable_root(plan: &PlanNode) -> bool {
707 match plan {
708 PlanNode::NodeScan { .. } | PlanNode::SetOp { .. } => true,
709 PlanNode::Chain { steps } => steps.first().is_some_and(PlanNode::is_context_free),
710 PlanNode::Filter { .. } | PlanNode::EdgeTraversal { .. } => false,
711 }
712}
713
714fn walk_plan_for_shared_subtrees(
715 plan: &PlanNode,
716 plan_index: usize,
717 path: &mut SubtreePath,
718 collector: &mut SharedSubtreeCollector,
719) {
720 if is_independently_executable_root(plan) {
721 collector.register(plan_index, path, plan.clone());
722 }
723
724 match plan {
725 PlanNode::NodeScan { .. } | PlanNode::EdgeTraversal { .. } => {}
726 PlanNode::Chain { steps } => {
727 register_executable_chain_prefixes(steps, plan_index, path, collector);
728 for (step_index, step) in steps.iter().enumerate() {
729 path.push(PathSegment::ChainStep(step_index));
730 walk_plan_for_shared_subtrees(step, plan_index, path, collector);
731 path.pop();
732 }
733 }
734 PlanNode::SetOp { left, right, .. } => {
735 path.push(PathSegment::SetLeft);
736 walk_plan_for_shared_subtrees(left, plan_index, path, collector);
737 path.pop();
738
739 path.push(PathSegment::SetRight);
740 walk_plan_for_shared_subtrees(right, plan_index, path, collector);
741 path.pop();
742 }
743 PlanNode::Filter { predicate } => {
744 walk_predicate_for_shared_subtrees(predicate, plan_index, path, collector);
745 }
746 }
747}
748
749fn register_executable_chain_prefixes(
750 steps: &[PlanNode],
751 plan_index: usize,
752 path: &mut SubtreePath,
753 collector: &mut SharedSubtreeCollector,
754) {
755 if !steps.first().is_some_and(PlanNode::is_context_free) {
756 return;
757 }
758
759 for prefix_len in 2..steps.len() {
760 path.push(PathSegment::ChainPrefix(prefix_len));
761 collector.register(
762 plan_index,
763 path,
764 PlanNode::Chain {
765 steps: steps[..prefix_len].to_vec(),
766 },
767 );
768 path.pop();
769 }
770}
771
772fn walk_predicate_for_shared_subtrees(
773 pred: &Predicate,
774 plan_index: usize,
775 path: &mut SubtreePath,
776 collector: &mut SharedSubtreeCollector,
777) {
778 match pred {
779 Predicate::HasCaller
780 | Predicate::HasCallee
781 | Predicate::IsUnused
782 | Predicate::InFile(_)
783 | Predicate::InScope(_)
784 | Predicate::MatchesName(_)
785 | Predicate::Returns(_) => {}
786 Predicate::Callers(value)
787 | Predicate::Callees(value)
788 | Predicate::Imports(value)
789 | Predicate::Exports(value)
790 | Predicate::References(value)
791 | Predicate::Implements(value) => {
792 walk_predicate_value_for_shared_subtrees(value, plan_index, path, collector);
793 }
794 Predicate::And(list) => {
795 for (index, inner) in list.iter().enumerate() {
796 path.push(PathSegment::PredicateAnd(index));
797 walk_predicate_for_shared_subtrees(inner, plan_index, path, collector);
798 path.pop();
799 }
800 }
801 Predicate::Or(list) => {
802 for (index, inner) in list.iter().enumerate() {
803 path.push(PathSegment::PredicateOr(index));
804 walk_predicate_for_shared_subtrees(inner, plan_index, path, collector);
805 path.pop();
806 }
807 }
808 Predicate::Not(inner) => {
809 path.push(PathSegment::PredicateNot);
810 walk_predicate_for_shared_subtrees(inner, plan_index, path, collector);
811 path.pop();
812 }
813 }
814}
815
816fn walk_predicate_value_for_shared_subtrees(
817 value: &PredicateValue,
818 plan_index: usize,
819 path: &mut SubtreePath,
820 collector: &mut SharedSubtreeCollector,
821) {
822 if let PredicateValue::Subquery(plan) = value {
823 path.push(PathSegment::PredicateValueSubquery);
824 walk_plan_for_shared_subtrees(plan, plan_index, path, collector);
825 path.pop();
826 }
827}
828
829fn promote_candidates(
830 candidates: HashMap<PlanNode, Vec<(usize, SubtreePath)>>,
831 groups: &[FusionGroup],
832) -> Vec<PromotedCandidate> {
833 let existing_prefixes: HashSet<PlanNode> =
834 groups.iter().map(|group| group.prefix.clone()).collect();
835 let mut promoted: Vec<PromotedCandidate> = candidates
836 .into_iter()
837 .filter_map(|(canonical_plan, mut positions)| {
838 if positions.len() < 2 || existing_prefixes.contains(&canonical_plan) {
839 return None;
840 }
841
842 positions
843 .sort_by(|left, right| left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1)));
844 Some(PromotedCandidate {
845 canonical_plan,
846 positions,
847 })
848 })
849 .collect();
850
851 promoted.sort_by(|left, right| {
852 left.positions[0]
853 .0
854 .cmp(&right.positions[0].0)
855 .then_with(|| left.positions[0].1.cmp(&right.positions[0].1))
856 .then_with(|| {
857 left.canonical_plan
858 .operator_count()
859 .cmp(&right.canonical_plan.operator_count())
860 })
861 });
862
863 promoted
864}
865
866fn candidate_consumers(positions: &[(usize, SubtreePath)]) -> Vec<usize> {
867 let mut consumers = Vec::new();
868 for (plan_index, _) in positions {
869 if !consumers.contains(plan_index) {
870 consumers.push(*plan_index);
871 }
872 }
873 consumers
874}
875
876fn sort_promoted_candidates_by_containment(
877 candidates: Vec<PromotedCandidate>,
878) -> Vec<PromotedCandidate> {
879 if candidates.len() <= 1 {
880 return candidates;
881 }
882
883 let mut outgoing: HashMap<usize, Vec<usize>> = HashMap::new();
884 let mut indegree = vec![0_usize; candidates.len()];
885
886 for (parent_index, parent) in candidates.iter().enumerate() {
887 for (child_index, child) in candidates.iter().enumerate() {
888 if parent_index == child_index {
889 continue;
890 }
891 if is_proper_subtree(&child.canonical_plan, &parent.canonical_plan) {
892 outgoing.entry(child_index).or_default().push(parent_index);
893 indegree[parent_index] += 1;
894 }
895 }
896 }
897
898 let mut ready: VecDeque<usize> = indegree
899 .iter()
900 .enumerate()
901 .filter_map(|(index, degree)| (*degree == 0).then_some(index))
902 .collect();
903 let mut order = Vec::with_capacity(candidates.len());
904
905 while let Some(index) = ready.pop_front() {
906 order.push(index);
907 if let Some(edges) = outgoing.get(&index) {
908 for &dependent in edges {
909 indegree[dependent] -= 1;
910 if indegree[dependent] == 0 {
911 ready.push_back(dependent);
912 }
913 }
914 }
915 }
916
917 if order.len() != candidates.len() {
918 return candidates;
919 }
920
921 order
922 .into_iter()
923 .map(|index| candidates[index].clone())
924 .collect()
925}
926
927fn is_proper_subtree(needle: &PlanNode, haystack: &PlanNode) -> bool {
928 let mut found = false;
929 visit_proper_plan_subtrees(haystack, &mut |candidate| {
930 if !found && candidate == needle {
931 found = true;
932 }
933 });
934 found
935}
936
937fn visit_proper_plan_subtrees(plan: &PlanNode, visitor: &mut dyn FnMut(&PlanNode)) {
938 match plan {
939 PlanNode::NodeScan { .. } | PlanNode::EdgeTraversal { .. } => {}
940 PlanNode::Filter { predicate } => {
941 visit_proper_predicate_subtrees(predicate, visitor);
942 }
943 PlanNode::SetOp { left, right, .. } => {
944 visitor(left);
945 visit_proper_plan_subtrees(left, visitor);
946 visitor(right);
947 visit_proper_plan_subtrees(right, visitor);
948 }
949 PlanNode::Chain { steps } => {
950 for prefix_len in 2..steps.len() {
951 let prefix = PlanNode::Chain {
952 steps: steps[..prefix_len].to_vec(),
953 };
954 visitor(&prefix);
955 }
956
957 for step in steps {
958 visitor(step);
959 visit_proper_plan_subtrees(step, visitor);
960 }
961 }
962 }
963}
964
965fn visit_proper_predicate_subtrees(predicate: &Predicate, visitor: &mut dyn FnMut(&PlanNode)) {
966 match predicate {
967 Predicate::HasCaller
968 | Predicate::HasCallee
969 | Predicate::IsUnused
970 | Predicate::InFile(_)
971 | Predicate::InScope(_)
972 | Predicate::MatchesName(_)
973 | Predicate::Returns(_) => {}
974 Predicate::Callers(value)
975 | Predicate::Callees(value)
976 | Predicate::Imports(value)
977 | Predicate::Exports(value)
978 | Predicate::References(value)
979 | Predicate::Implements(value) => {
980 if let PredicateValue::Subquery(plan) = value {
981 visitor(plan);
982 visit_proper_plan_subtrees(plan, visitor);
983 }
984 }
985 Predicate::And(list) | Predicate::Or(list) => {
986 for inner in list {
987 visit_proper_predicate_subtrees(inner, visitor);
988 }
989 }
990 Predicate::Not(inner) => {
991 visit_proper_predicate_subtrees(inner, visitor);
992 }
993 }
994}
995
996fn estimate_plan_tree_reduction_ratio(
997 plans: &[QueryPlan],
998 candidates: &[PromotedCandidate],
999) -> f64 {
1000 let total_nodes_before: usize = plans.iter().map(|plan| plan.root.operator_count()).sum();
1001
1002 if total_nodes_before == 0 {
1003 return 0.0;
1004 }
1005
1006 let total_saved_nodes: usize = candidates
1007 .iter()
1008 .map(|candidate| {
1009 candidate
1010 .canonical_plan
1011 .operator_count()
1012 .saturating_mul(candidate.positions.len().saturating_sub(1))
1013 })
1014 .sum();
1015
1016 let bounded_saved_nodes = total_saved_nodes.min(total_nodes_before);
1017 bounded_saved_nodes as f64 / total_nodes_before as f64
1018}
1019
1020fn collect_subquery_plans(plans: &[QueryPlan]) -> (usize, Vec<QueryPlan>) {
1035 let mut total = 0_usize;
1036 let mut seen: HashMap<PlanNode, ()> = HashMap::new();
1037 let mut ordered: Vec<PlanNode> = Vec::new();
1038
1039 for plan in plans {
1040 walk_plan_for_subqueries(&plan.root, &mut total, &mut seen, &mut ordered);
1041 }
1042
1043 let dedup_plans = ordered.into_iter().map(QueryPlan::new).collect();
1044 (total, dedup_plans)
1045}
1046
1047fn walk_plan_for_subqueries(
1052 node: &PlanNode,
1053 total: &mut usize,
1054 seen: &mut HashMap<PlanNode, ()>,
1055 ordered: &mut Vec<PlanNode>,
1056) {
1057 match node {
1058 PlanNode::NodeScan { .. } | PlanNode::EdgeTraversal { .. } => {}
1059 PlanNode::Filter { predicate } => {
1060 walk_predicate_for_subqueries(predicate, total, seen, ordered);
1061 }
1062 PlanNode::SetOp { left, right, .. } => {
1063 walk_plan_for_subqueries(left, total, seen, ordered);
1064 walk_plan_for_subqueries(right, total, seen, ordered);
1065 }
1066 PlanNode::Chain { steps } => {
1067 for step in steps {
1068 walk_plan_for_subqueries(step, total, seen, ordered);
1069 }
1070 }
1071 }
1072}
1073
1074fn walk_predicate_for_subqueries(
1078 pred: &Predicate,
1079 total: &mut usize,
1080 seen: &mut HashMap<PlanNode, ()>,
1081 ordered: &mut Vec<PlanNode>,
1082) {
1083 match pred {
1084 Predicate::HasCaller
1085 | Predicate::HasCallee
1086 | Predicate::IsUnused
1087 | Predicate::InFile(_)
1088 | Predicate::InScope(_)
1089 | Predicate::MatchesName(_)
1090 | Predicate::Returns(_) => {}
1091 Predicate::Callers(v)
1092 | Predicate::Callees(v)
1093 | Predicate::Imports(v)
1094 | Predicate::Exports(v)
1095 | Predicate::References(v)
1096 | Predicate::Implements(v) => {
1097 walk_predicate_value_for_subqueries(v, total, seen, ordered);
1098 }
1099 Predicate::And(list) | Predicate::Or(list) => {
1100 for inner in list {
1101 walk_predicate_for_subqueries(inner, total, seen, ordered);
1102 }
1103 }
1104 Predicate::Not(inner) => {
1105 walk_predicate_for_subqueries(inner, total, seen, ordered);
1106 }
1107 }
1108}
1109
1110fn walk_predicate_value_for_subqueries(
1127 value: &PredicateValue,
1128 total: &mut usize,
1129 seen: &mut HashMap<PlanNode, ()>,
1130 ordered: &mut Vec<PlanNode>,
1131) {
1132 if let PredicateValue::Subquery(inner) = value {
1133 *total += 1;
1134 let key: PlanNode = (**inner).clone();
1136 if !seen.contains_key(&key) {
1137 seen.insert(key.clone(), ());
1138 ordered.push(key);
1139 }
1140 walk_plan_for_subqueries(inner, total, seen, ordered);
1142 }
1143}
1144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planner::ir::{Direction, MatchMode, PathPattern, SetOperation, StringPattern};
    use sqry_core::graph::unified::node::kind::NodeKind;

    fn make_scan(kind: NodeKind) -> PlanNode {
        PlanNode::NodeScan {
            kind: Some(kind),
            visibility: None,
            name_pattern: None,
        }
    }

    fn make_filter_has_caller() -> PlanNode {
        PlanNode::Filter {
            predicate: Predicate::HasCaller,
        }
    }

    fn make_forward_edge() -> PlanNode {
        PlanNode::EdgeTraversal {
            direction: Direction::Forward,
            edge_kind: None,
            max_depth: 1,
        }
    }

    #[test]
    fn split_chain_with_multiple_steps() {
        let chain = PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                make_filter_has_caller(),
                PlanNode::EdgeTraversal {
                    direction: Direction::Forward,
                    edge_kind: None,
                    max_depth: 1,
                },
            ],
        };
        let (prefix, tail) = split_prefix_and_tail(chain);
        assert_eq!(prefix, make_scan(NodeKind::Function));
        match tail {
            FusionTail::ChainContinuation { remaining_steps } => {
                assert_eq!(remaining_steps.len(), 2);
            }
            FusionTail::Identity => panic!("expected ChainContinuation"),
        }
    }

    #[test]
    fn split_chain_with_one_step_collapses_to_identity() {
        let chain = PlanNode::Chain {
            steps: vec![make_scan(NodeKind::Class)],
        };
        let (prefix, tail) = split_prefix_and_tail(chain);
        assert_eq!(prefix, make_scan(NodeKind::Class));
        assert_eq!(tail, FusionTail::Identity);
    }

    #[test]
    fn split_standalone_scan_is_identity() {
        let scan = make_scan(NodeKind::Method);
        let (prefix, tail) = split_prefix_and_tail(scan.clone());
        assert_eq!(prefix, scan);
        assert_eq!(tail, FusionTail::Identity);
    }

    #[test]
    fn split_standalone_setop_is_identity() {
        let set = PlanNode::SetOp {
            op: SetOperation::Union,
            left: Box::new(make_scan(NodeKind::Function)),
            right: Box::new(make_scan(NodeKind::Method)),
        };
        let (prefix, tail) = split_prefix_and_tail(set.clone());
        assert_eq!(prefix, set);
        assert_eq!(tail, FusionTail::Identity);
    }

    #[test]
    fn split_malformed_chain_with_filter_first_passes_through() {
        let chain = PlanNode::Chain {
            steps: vec![make_filter_has_caller(), make_scan(NodeKind::Function)],
        };
        let original = chain.clone();
        let (prefix, tail) = split_prefix_and_tail(chain);
        assert_eq!(prefix, original);
        assert_eq!(tail, FusionTail::Identity);
    }

    #[test]
    fn split_empty_chain_passes_through() {
        let chain = PlanNode::Chain { steps: vec![] };
        let (prefix, tail) = split_prefix_and_tail(chain.clone());
        assert_eq!(prefix, chain);
        assert_eq!(tail, FusionTail::Identity);
    }

    #[test]
    fn collect_subquery_plans_empty_when_no_filters() {
        let plans = vec![
            QueryPlan::new(make_scan(NodeKind::Function)),
            QueryPlan::new(make_scan(NodeKind::Method)),
        ];
        let (total, dedup) = collect_subquery_plans(&plans);
        assert_eq!(total, 0);
        assert!(dedup.is_empty());
    }

    #[test]
    fn collect_subquery_plans_dedupes_identical_subqueries() {
        let inner = make_scan(NodeKind::Method);
        let pred = |v: PlanNode| Predicate::Callers(PredicateValue::Subquery(Box::new(v)));

        let plan_a = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                PlanNode::Filter {
                    predicate: pred(inner.clone()),
                },
            ],
        });
        let plan_b = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Class),
                PlanNode::Filter {
                    predicate: pred(inner.clone()),
                },
            ],
        });

        let (total, dedup) = collect_subquery_plans(&[plan_a, plan_b]);
        assert_eq!(total, 2);
        assert_eq!(dedup.len(), 1);
        assert_eq!(dedup[0].root(), &inner);
    }

    #[test]
    fn collect_subquery_plans_walks_all_predicate_arms() {
        // Two distinct subquery plans, each used several times across every
        // subquery-carrying predicate arm.
        let inner_a = make_scan(NodeKind::Function);
        let inner_b = make_scan(NodeKind::Method);

        let sub_a = || PredicateValue::Subquery(Box::new(inner_a.clone()));
        let sub_b = || PredicateValue::Subquery(Box::new(inner_b.clone()));

        let predicate = Predicate::And(vec![
            Predicate::Or(vec![
                Predicate::Callers(sub_a()),
                Predicate::Callees(sub_a()),
                Predicate::Imports(sub_b()),
                Predicate::Exports(sub_b()),
                Predicate::References(sub_a()),
                Predicate::Implements(sub_b()),
            ]),
            Predicate::Not(Box::new(Predicate::Callers(sub_a()))),
        ]);

        let plan = QueryPlan::new(PlanNode::Chain {
            steps: vec![make_scan(NodeKind::Class), PlanNode::Filter { predicate }],
        });
        let (total, dedup) = collect_subquery_plans(&[plan]);
        // Six occurrences under `Or`, plus one under `Not`.
        assert_eq!(total, 7);
        assert_eq!(dedup.len(), 2);
    }

    #[test]
    fn collect_subquery_plans_recurses_into_nested_subqueries() {
        // `leaf` is only reachable through the subquery of `mid_plan`.
        let leaf = make_scan(NodeKind::Function);
        let nested_pred = Predicate::Callers(PredicateValue::Subquery(Box::new(leaf.clone())));
        let mid_plan = PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Method),
                PlanNode::Filter {
                    predicate: nested_pred,
                },
            ],
        };
        let outer = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Class),
                PlanNode::Filter {
                    predicate: Predicate::Callees(PredicateValue::Subquery(Box::new(mid_plan))),
                },
            ],
        });

        let (total, dedup) = collect_subquery_plans(&[outer]);
        assert_eq!(total, 2);
        // Both the mid-level plan and the nested leaf are collected.
        assert_eq!(dedup.len(), 2);
    }

    #[test]
    fn fused_tail_reconstruct_identity() {
        let scan = make_scan(NodeKind::Function);
        let tail = FusedTail {
            original_index: 0,
            tail: FusionTail::Identity,
        };
        let plan = tail.reconstruct(&scan);
        assert_eq!(plan.root(), &scan);
    }

    #[test]
    fn fused_tail_reconstruct_chain_continuation() {
        let scan = make_scan(NodeKind::Function);
        let f = make_filter_has_caller();
        let tail = FusedTail {
            original_index: 7,
            tail: FusionTail::ChainContinuation {
                remaining_steps: vec![f.clone()],
            },
        };
        let plan = tail.reconstruct(&scan);
        match plan.root() {
            PlanNode::Chain { steps } => {
                assert_eq!(steps.len(), 2);
                assert_eq!(&steps[0], &scan);
                assert_eq!(&steps[1], &f);
            }
            other => panic!("expected Chain, got {other:?}"),
        }
    }

    #[test]
    fn fusion_stats_subqueries_eliminated() {
        let stats = FusionStats {
            total_plans: 5,
            fusion_groups: 3,
            scans_eliminated: 2,
            subqueries_total: 7,
            subqueries_unique: 3,
            shared_nodes_promoted: 0,
            plan_tree_reduction_ratio: 0.0,
        };
        assert_eq!(stats.subqueries_eliminated(), 4);
    }

    #[test]
    fn fusion_stats_subqueries_eliminated_saturates() {
        // Inconsistent counts (unique > total) must not underflow.
        let stats = FusionStats {
            total_plans: 0,
            fusion_groups: 0,
            scans_eliminated: 0,
            subqueries_total: 1,
            subqueries_unique: 5,
            shared_nodes_promoted: 0,
            plan_tree_reduction_ratio: 0.0,
        };
        assert_eq!(stats.subqueries_eliminated(), 0);
    }

    #[test]
    fn fuse_single_round_trip() {
        let scan = make_scan(NodeKind::Function);
        let plan = QueryPlan::new(scan.clone());
        let batch = fuse_single(plan.clone());
        assert_eq!(batch.len(), 1);
        assert_eq!(batch.stats().total_plans, 1);
        assert_eq!(batch.stats().scans_eliminated, 0);
        let recovered = batch.input_plans();
        assert_eq!(recovered, vec![plan]);
    }

    #[test]
    fn fuse_plans_empty_batch() {
        let batch = fuse_plans(Vec::new());
        assert!(batch.is_empty());
        assert_eq!(batch.total_plans(), 0);
        assert!(batch.input_plans().is_empty());
        assert!(batch.subquery_batch().is_none());
        assert!(batch.shared_nodes().is_empty());
        assert_eq!(*batch.stats(), FusionStats::default());
    }

    #[test]
    fn fuse_plans_groups_plans_sharing_a_prefix() {
        let plan_a = QueryPlan::new(PlanNode::Chain {
            steps: vec![make_scan(NodeKind::Function), make_filter_has_caller()],
        });
        let plan_b = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                PlanNode::Filter {
                    predicate: Predicate::IsUnused,
                },
            ],
        });
        let plan_c = QueryPlan::new(make_scan(NodeKind::Class));
        let inputs = vec![plan_a, plan_b, plan_c];

        let batch = fuse_plans(inputs.clone());
        assert_eq!(batch.len(), 2);
        assert_eq!(batch.total_plans(), 3);
        assert_eq!(batch.stats().total_plans, 3);
        assert_eq!(batch.stats().fusion_groups, 2);
        assert_eq!(batch.stats().scans_eliminated, 1);
        assert_eq!(batch.groups()[0].prefix(), &make_scan(NodeKind::Function));
        assert_eq!(batch.groups()[0].tail_count(), 2);
        assert_eq!(batch.groups()[1].prefix(), &make_scan(NodeKind::Class));
        assert_eq!(batch.input_plans(), inputs);
    }

    #[test]
    fn fuse_plans_promotes_shared_chain_prefix() {
        // Both plans share `scan -> has_caller` but diverge afterwards. The
        // bare scan is the group prefix, so only the two-step prefix is promoted.
        let plan_a = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                make_filter_has_caller(),
                make_forward_edge(),
            ],
        });
        let plan_b = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                make_filter_has_caller(),
                PlanNode::Filter {
                    predicate: Predicate::IsUnused,
                },
            ],
        });
        let inputs = vec![plan_a, plan_b];

        let batch = fuse_plans(inputs.clone());
        assert_eq!(batch.len(), 1);
        assert_eq!(batch.stats().shared_nodes_promoted, 1);
        assert_eq!(batch.shared_nodes().len(), 1);

        let shared = &batch.shared_nodes()[0];
        assert_eq!(
            shared.canonical_plan(),
            &PlanNode::Chain {
                steps: vec![make_scan(NodeKind::Function), make_filter_has_caller()],
            }
        );
        assert_eq!(shared.consumers(), [0_usize, 1].as_slice());
        assert_eq!(shared.id().get(), 0);

        let ratio = batch.stats().plan_tree_reduction_ratio;
        assert!(ratio > 0.0 && ratio <= 1.0, "ratio out of range: {ratio}");
        assert_eq!(batch.input_plans(), inputs);
    }

    #[test]
    fn fuse_plans_dedupes_subqueries_into_subquery_batch() {
        let inner = make_scan(NodeKind::Method);
        let pred = |v: PlanNode| Predicate::Callers(PredicateValue::Subquery(Box::new(v)));
        let plan_a = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                PlanNode::Filter {
                    predicate: pred(inner.clone()),
                },
            ],
        });
        let plan_b = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Class),
                PlanNode::Filter {
                    predicate: pred(inner.clone()),
                },
            ],
        });
        let inputs = vec![plan_a, plan_b];

        let batch = fuse_plans(inputs.clone());
        assert_eq!(batch.len(), 2);
        assert_eq!(batch.stats().subqueries_total, 2);
        assert_eq!(batch.stats().subqueries_unique, 1);
        assert_eq!(batch.stats().subqueries_eliminated(), 1);

        let sub = batch.subquery_batch().expect("subquery batch must exist");
        assert_eq!(sub.input_plans(), vec![QueryPlan::new(inner)]);
        assert_eq!(batch.input_plans(), inputs);
    }

    #[test]
    fn helpers_do_not_double_count_distinct_pattern_arms() {
        // Pattern-valued and path predicates carry no subqueries and must not
        // be counted.
        let predicate = Predicate::And(vec![
            Predicate::Callers(PredicateValue::Pattern(StringPattern::glob("foo*"))),
            Predicate::References(PredicateValue::Pattern(StringPattern {
                raw: "bar".into(),
                mode: MatchMode::Exact,
                case_insensitive: true,
            })),
            Predicate::InFile(PathPattern::new("src/**")),
        ]);
        let plan = QueryPlan::new(PlanNode::Chain {
            steps: vec![
                make_scan(NodeKind::Function),
                PlanNode::Filter { predicate },
            ],
        });
        let (total, dedup) = collect_subquery_plans(&[plan]);
        assert_eq!(total, 0);
        assert!(dedup.is_empty());
    }
}