1use parking_lot::RwLock;
47use std::collections::{HashMap, HashSet};
48use std::sync::Arc;
49
/// Tunable constants for the cost model shared by [`CostBasedOptimizer`] and
/// [`JoinOrderOptimizer`].
///
/// The weights are plain multipliers; `CostBasedOptimizer::explain` labels the
/// accumulated totals as milliseconds.
#[derive(Clone, Debug)]
pub struct CostModelConfig {
    /// Cost charged per block read sequentially (table-scan blocks, index leaf pages).
    pub c_seq: f64,
    /// Cost charged per random page access (index descent, row fetches).
    pub c_random: f64,
    /// Per-row cost of evaluating a predicate; also used for join build/probe rows.
    pub c_filter: f64,
    /// Cost of a single comparison when estimating sorts.
    pub c_compare: f64,
    /// Bytes per storage block; table size is divided by this to obtain a block count.
    pub block_size: usize,
    /// B-tree fan-out (not consulted by the cost formulas in this module).
    pub btree_fanout: usize,
    /// Memory bandwidth figure (not consulted by the cost formulas in this module).
    pub memory_bandwidth: f64,
}
72
73impl Default for CostModelConfig {
74 fn default() -> Self {
75 Self {
76 c_seq: 0.1, c_random: 5.0, c_filter: 0.001, c_compare: 0.0001, block_size: 4096, btree_fanout: 100, memory_bandwidth: 10000.0, }
84 }
85}
86
87#[derive(Debug, Clone)]
93pub struct TableStats {
94 pub name: String,
96 pub row_count: u64,
98 pub size_bytes: u64,
100 pub column_stats: HashMap<String, ColumnStats>,
102 pub indices: Vec<IndexStats>,
104 pub last_updated: u64,
106}
107
108#[derive(Debug, Clone)]
110pub struct ColumnStats {
111 pub name: String,
113 pub distinct_count: u64,
115 pub null_count: u64,
117 pub min_value: Option<String>,
119 pub max_value: Option<String>,
121 pub avg_length: f64,
123 pub mcv: Vec<(String, f64)>,
125 pub histogram: Option<Histogram>,
127}
128
/// Bucketed distribution of a numeric column.
#[derive(Clone, Debug)]
pub struct Histogram {
    /// Bucket edges. `estimate_range_selectivity` treats `boundaries[i]` as the upper
    /// edge of bucket `i` and the lower edge of bucket `i + 1`.
    pub boundaries: Vec<f64>,
    /// Number of rows in each bucket.
    pub counts: Vec<u64>,
    /// Total rows represented; the denominator for selectivity estimates.
    pub total_rows: u64,
}
139
140impl Histogram {
141 pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
143 if self.total_rows == 0 {
144 return 0.5; }
146
147 let mut selected_rows = 0u64;
148
149 for (i, &count) in self.counts.iter().enumerate() {
150 let bucket_min = if i == 0 {
151 f64::NEG_INFINITY
152 } else {
153 self.boundaries[i - 1]
154 };
155 let bucket_max = if i == self.boundaries.len() {
156 f64::INFINITY
157 } else {
158 self.boundaries[i]
159 };
160
161 let overlaps = match (min, max) {
162 (Some(min_val), Some(max_val)) => bucket_max >= min_val && bucket_min <= max_val,
163 (Some(min_val), None) => bucket_max >= min_val,
164 (None, Some(max_val)) => bucket_min <= max_val,
165 (None, None) => true,
166 };
167
168 if overlaps {
169 selected_rows += count;
170 }
171 }
172
173 selected_rows as f64 / self.total_rows as f64
174 }
175}
176
177#[derive(Debug, Clone)]
179pub struct IndexStats {
180 pub name: String,
182 pub columns: Vec<String>,
184 pub is_primary: bool,
186 pub is_unique: bool,
188 pub index_type: IndexType,
190 pub leaf_pages: u64,
192 pub height: u32,
194 pub avg_leaf_density: f64,
196}
197
/// Kind of index structure.
///
/// The cost formulas in this module do not currently distinguish between kinds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndexType {
    /// B-tree index.
    BTree,
    /// Hash index.
    Hash,
    /// Log-structured merge index.
    LSM,
    /// Learned index.
    Learned,
    /// Vector index.
    Vector,
    /// Bloom-filter index.
    Bloom,
}
208
/// A boolean filter over table columns.
///
/// Literal values are carried as strings. The optimizer parses them as `f64` when it
/// needs a numeric comparison.
#[derive(Clone, Debug)]
pub enum Predicate {
    /// `column = value`
    Eq { column: String, value: String },
    /// `column <> value`
    Ne { column: String, value: String },
    /// `column < value`
    Lt { column: String, value: String },
    /// `column <= value`
    Le { column: String, value: String },
    /// `column > value`
    Gt { column: String, value: String },
    /// `column >= value`
    Ge { column: String, value: String },
    /// `column BETWEEN min AND max` (bound inclusivity is not defined by this type).
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// `column IN (values...)`
    In { column: String, values: Vec<String> },
    /// `column LIKE pattern`
    Like { column: String, pattern: String },
    /// `column IS NULL`
    IsNull { column: String },
    /// `column IS NOT NULL`
    IsNotNull { column: String },
    /// Conjunction of two predicates.
    And(Box<Predicate>, Box<Predicate>),
    /// Disjunction of two predicates.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of a predicate.
    Not(Box<Predicate>),
}
249
250impl Predicate {
251 pub fn referenced_columns(&self) -> HashSet<String> {
253 let mut cols = HashSet::new();
254 self.collect_columns(&mut cols);
255 cols
256 }
257
258 fn collect_columns(&self, cols: &mut HashSet<String>) {
259 match self {
260 Self::Eq { column, .. }
261 | Self::Ne { column, .. }
262 | Self::Lt { column, .. }
263 | Self::Le { column, .. }
264 | Self::Gt { column, .. }
265 | Self::Ge { column, .. }
266 | Self::Between { column, .. }
267 | Self::In { column, .. }
268 | Self::Like { column, .. }
269 | Self::IsNull { column }
270 | Self::IsNotNull { column } => {
271 cols.insert(column.clone());
272 }
273 Self::And(left, right) | Self::Or(left, right) => {
274 left.collect_columns(cols);
275 right.collect_columns(cols);
276 }
277 Self::Not(inner) => inner.collect_columns(cols),
278 }
279 }
280}
281
282#[derive(Debug, Clone)]
288pub enum PhysicalPlan {
289 TableScan {
291 table: String,
292 columns: Vec<String>,
293 predicate: Option<Box<Predicate>>,
294 estimated_rows: u64,
295 estimated_cost: f64,
296 },
297 IndexSeek {
299 table: String,
300 index: String,
301 columns: Vec<String>,
302 key_range: KeyRange,
303 predicate: Option<Box<Predicate>>,
304 estimated_rows: u64,
305 estimated_cost: f64,
306 },
307 Filter {
309 input: Box<PhysicalPlan>,
310 predicate: Predicate,
311 estimated_rows: u64,
312 estimated_cost: f64,
313 },
314 Project {
316 input: Box<PhysicalPlan>,
317 columns: Vec<String>,
318 estimated_cost: f64,
319 },
320 Sort {
322 input: Box<PhysicalPlan>,
323 order_by: Vec<(String, SortDirection)>,
324 estimated_cost: f64,
325 },
326 Limit {
328 input: Box<PhysicalPlan>,
329 limit: u64,
330 offset: u64,
331 estimated_cost: f64,
332 },
333 NestedLoopJoin {
335 outer: Box<PhysicalPlan>,
336 inner: Box<PhysicalPlan>,
337 condition: Predicate,
338 join_type: JoinType,
339 estimated_rows: u64,
340 estimated_cost: f64,
341 },
342 HashJoin {
344 build: Box<PhysicalPlan>,
345 probe: Box<PhysicalPlan>,
346 build_keys: Vec<String>,
347 probe_keys: Vec<String>,
348 join_type: JoinType,
349 estimated_rows: u64,
350 estimated_cost: f64,
351 },
352 MergeJoin {
354 left: Box<PhysicalPlan>,
355 right: Box<PhysicalPlan>,
356 left_keys: Vec<String>,
357 right_keys: Vec<String>,
358 join_type: JoinType,
359 estimated_rows: u64,
360 estimated_cost: f64,
361 },
362 Aggregate {
364 input: Box<PhysicalPlan>,
365 group_by: Vec<String>,
366 aggregates: Vec<AggregateExpr>,
367 estimated_rows: u64,
368 estimated_cost: f64,
369 },
370}
371
/// A range of encoded index keys.
///
/// A `None` bound means the range is unbounded on that side.
#[derive(Clone, Debug)]
pub struct KeyRange {
    /// Lower bound, if any.
    pub start: Option<Vec<u8>>,
    /// Upper bound, if any.
    pub end: Option<Vec<u8>>,
    /// Whether `start` itself is included.
    pub start_inclusive: bool,
    /// Whether `end` itself is included.
    pub end_inclusive: bool,
}
380
381impl KeyRange {
382 pub fn all() -> Self {
383 Self {
384 start: None,
385 end: None,
386 start_inclusive: true,
387 end_inclusive: true,
388 }
389 }
390
391 pub fn point(key: Vec<u8>) -> Self {
392 Self {
393 start: Some(key.clone()),
394 end: Some(key),
395 start_inclusive: true,
396 end_inclusive: true,
397 }
398 }
399
400 pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
401 Self {
402 start,
403 end,
404 start_inclusive: inclusive,
405 end_inclusive: inclusive,
406 }
407 }
408}
409
/// Ordering direction for a sort key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortDirection {
    /// Smallest values first.
    Ascending,
    /// Largest values first.
    Descending,
}
416
/// Logical join semantics.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    Left,
    /// Right outer join.
    Right,
    /// Full outer join.
    Full,
    /// Cross join (Cartesian product).
    Cross,
}
426
427#[derive(Debug, Clone)]
429pub struct AggregateExpr {
430 pub function: AggregateFunction,
431 pub column: Option<String>,
432 pub alias: String,
433}
434
/// Supported aggregate functions. The `Debug` name is used verbatim in `explain` output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunction {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
    /// `COUNT(DISTINCT ...)`
    CountDistinct,
}
445
446pub struct CostBasedOptimizer {
452 config: CostModelConfig,
454 stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
456 token_budget: Option<u64>,
458 tokens_per_row: f64,
460}
461
462impl CostBasedOptimizer {
463 pub fn new(config: CostModelConfig) -> Self {
464 Self {
465 config,
466 stats_cache: Arc::new(RwLock::new(HashMap::new())),
467 token_budget: None,
468 tokens_per_row: 25.0, }
470 }
471
472 pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
474 self.token_budget = Some(budget);
475 self.tokens_per_row = tokens_per_row;
476 self
477 }
478
479 pub fn update_stats(&self, stats: TableStats) {
481 self.stats_cache.write().insert(stats.name.clone(), stats);
482 }
483
484 pub fn get_stats(&self, table: &str) -> Option<TableStats> {
486 self.stats_cache.read().get(table).cloned()
487 }
488
489 pub fn optimize(
491 &self,
492 table: &str,
493 columns: Vec<String>,
494 predicate: Option<Predicate>,
495 order_by: Vec<(String, SortDirection)>,
496 limit: Option<u64>,
497 ) -> PhysicalPlan {
498 let stats = self.get_stats(table);
499
500 let effective_limit = self.calculate_token_limit(limit);
502
503 let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
505
506 plan = self.apply_projection_pushdown(plan, columns.clone());
508
509 if !order_by.is_empty() {
511 plan = self.add_sort(plan, order_by, &stats);
512 }
513
514 if let Some(lim) = effective_limit {
516 plan = PhysicalPlan::Limit {
517 estimated_cost: 0.0,
518 input: Box::new(plan),
519 limit: lim,
520 offset: 0,
521 };
522 }
523
524 plan
525 }
526
527 fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
529 match (self.token_budget, user_limit) {
530 (Some(budget), Some(limit)) => {
531 let header_tokens = 50u64;
532 let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
533 Some(limit.min(max_rows))
534 }
535 (Some(budget), None) => {
536 let header_tokens = 50u64;
537 let max_rows = ((budget - header_tokens) as f64 / self.tokens_per_row) as u64;
538 Some(max_rows)
539 }
540 (None, limit) => limit,
541 }
542 }
543
544 fn choose_access_path(
546 &self,
547 table: &str,
548 columns: &[String],
549 predicate: Option<&Predicate>,
550 stats: &Option<TableStats>,
551 ) -> PhysicalPlan {
552 let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
553 let size_bytes = stats
554 .as_ref()
555 .map(|s| s.size_bytes)
556 .unwrap_or(row_count * 100);
557
558 let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
560
561 let mut best_index_cost = f64::MAX;
563 let mut best_index: Option<&IndexStats> = None;
564
565 if let Some(table_stats) = stats.as_ref()
566 && let Some(pred) = predicate
567 {
568 let pred_columns = pred.referenced_columns();
569
570 for index in &table_stats.indices {
571 if self.index_covers_predicate(index, &pred_columns) {
572 let selectivity = self.estimate_selectivity(pred, table_stats);
573 let index_cost = self.estimate_index_cost(index, row_count, selectivity);
574
575 if index_cost < best_index_cost {
576 best_index_cost = index_cost;
577 best_index = Some(index);
578 }
579 }
580 }
581 }
582
583 if best_index_cost < scan_cost {
585 let index = best_index.unwrap();
586 let selectivity = predicate
587 .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
588 .unwrap_or(1.0);
589
590 PhysicalPlan::IndexSeek {
591 table: table.to_string(),
592 index: index.name.clone(),
593 columns: columns.to_vec(),
594 key_range: KeyRange::all(), predicate: predicate.map(|p| Box::new(p.clone())),
596 estimated_rows: (row_count as f64 * selectivity) as u64,
597 estimated_cost: best_index_cost,
598 }
599 } else {
600 PhysicalPlan::TableScan {
601 table: table.to_string(),
602 columns: columns.to_vec(),
603 predicate: predicate.map(|p| Box::new(p.clone())),
604 estimated_rows: row_count,
605 estimated_cost: scan_cost,
606 }
607 }
608 }
609
610 fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
612 if let Some(first_col) = index.columns.first() {
614 pred_columns.contains(first_col)
615 } else {
616 false
617 }
618 }
619
620 fn estimate_scan_cost(
622 &self,
623 row_count: u64,
624 size_bytes: u64,
625 predicate: Option<&Predicate>,
626 ) -> f64 {
627 let blocks = (size_bytes as f64 / self.config.block_size as f64).ceil() as u64;
628
629 let io_cost = blocks as f64 * self.config.c_seq;
631
632 let selectivity = predicate.map(|_| 0.1).unwrap_or(1.0);
634 let cpu_cost = row_count as f64 * self.config.c_filter * selectivity;
635
636 io_cost + cpu_cost
637 }
638
639 fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
643 let tree_cost = index.height as f64 * self.config.c_random;
645
646 let matching_rows = (total_rows as f64 * selectivity) as u64;
648 let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
649 let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
650
651 let fetch_cost = if index.is_primary {
653 0.0 } else {
655 matching_rows.min(1000) as f64 * self.config.c_random * 0.1 };
657
658 tree_cost + leaf_cost + fetch_cost
659 }
660
661 #[allow(clippy::only_used_in_recursion)]
663 fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
664 match predicate {
665 Predicate::Eq { column, value } => {
666 if let Some(col_stats) = stats.column_stats.get(column) {
667 for (mcv_val, freq) in &col_stats.mcv {
669 if mcv_val == value {
670 return *freq;
671 }
672 }
673 1.0 / col_stats.distinct_count.max(1) as f64
675 } else {
676 0.1 }
678 }
679 Predicate::Ne { .. } => 0.9, Predicate::Lt { column, value }
681 | Predicate::Le { column, value }
682 | Predicate::Gt { column, value }
683 | Predicate::Ge { column, value } => {
684 if let Some(col_stats) = stats.column_stats.get(column) {
685 if let Some(ref hist) = col_stats.histogram {
686 let val: f64 = value.parse().unwrap_or(0.0);
687 match predicate {
688 Predicate::Lt { .. } | Predicate::Le { .. } => {
689 hist.estimate_range_selectivity(None, Some(val))
690 }
691 _ => hist.estimate_range_selectivity(Some(val), None),
692 }
693 } else {
694 0.25 }
696 } else {
697 0.25
698 }
699 }
700 Predicate::Between { column, min, max } => {
701 if let Some(col_stats) = stats.column_stats.get(column) {
702 if let Some(ref hist) = col_stats.histogram {
703 let min_val: f64 = min.parse().unwrap_or(0.0);
704 let max_val: f64 = max.parse().unwrap_or(f64::MAX);
705 hist.estimate_range_selectivity(Some(min_val), Some(max_val))
706 } else {
707 0.2
708 }
709 } else {
710 0.2
711 }
712 }
713 Predicate::In { column, values } => {
714 if let Some(col_stats) = stats.column_stats.get(column) {
715 (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
716 } else {
717 (values.len() as f64 * 0.1).min(0.5)
718 }
719 }
720 Predicate::Like { .. } => 0.15, Predicate::IsNull { column } => {
722 if let Some(col_stats) = stats.column_stats.get(column) {
723 col_stats.null_count as f64 / stats.row_count.max(1) as f64
724 } else {
725 0.01
726 }
727 }
728 Predicate::IsNotNull { column } => {
729 if let Some(col_stats) = stats.column_stats.get(column) {
730 1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
731 } else {
732 0.99
733 }
734 }
735 Predicate::And(left, right) => {
736 self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
738 }
739 Predicate::Or(left, right) => {
740 let s1 = self.estimate_selectivity(left, stats);
741 let s2 = self.estimate_selectivity(right, stats);
742 (s1 + s2 - s1 * s2).min(1.0)
744 }
745 Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
746 }
747 }
748
749 fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
751 match plan {
753 PhysicalPlan::TableScan {
754 table,
755 predicate,
756 estimated_rows,
757 estimated_cost,
758 ..
759 } => {
760 PhysicalPlan::TableScan {
761 table,
762 columns, predicate,
764 estimated_rows,
765 estimated_cost: estimated_cost * 0.2, }
767 }
768 PhysicalPlan::IndexSeek {
769 table,
770 index,
771 key_range,
772 predicate,
773 estimated_rows,
774 estimated_cost,
775 ..
776 } => {
777 PhysicalPlan::IndexSeek {
778 table,
779 index,
780 columns, key_range,
782 predicate,
783 estimated_rows,
784 estimated_cost,
785 }
786 }
787 other => PhysicalPlan::Project {
788 input: Box::new(other),
789 columns,
790 estimated_cost: 0.0,
791 },
792 }
793 }
794
795 fn add_sort(
797 &self,
798 plan: PhysicalPlan,
799 order_by: Vec<(String, SortDirection)>,
800 _stats: &Option<TableStats>,
801 ) -> PhysicalPlan {
802 let estimated_rows = self.get_plan_rows(&plan);
803 let sort_cost = if estimated_rows > 0 {
804 estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
805 } else {
806 0.0
807 };
808
809 PhysicalPlan::Sort {
810 input: Box::new(plan),
811 order_by,
812 estimated_cost: sort_cost,
813 }
814 }
815
816 #[allow(clippy::only_used_in_recursion)]
818 fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
819 match plan {
820 PhysicalPlan::TableScan { estimated_rows, .. }
821 | PhysicalPlan::IndexSeek { estimated_rows, .. }
822 | PhysicalPlan::Filter { estimated_rows, .. }
823 | PhysicalPlan::Aggregate { estimated_rows, .. }
824 | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
825 | PhysicalPlan::HashJoin { estimated_rows, .. }
826 | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
827 PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
828 self.get_plan_rows(input)
829 }
830 PhysicalPlan::Limit { limit, .. } => *limit,
831 }
832 }
833
834 #[allow(clippy::only_used_in_recursion)]
836 pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
837 match plan {
838 PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
839 PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
840 PhysicalPlan::Filter {
841 estimated_cost,
842 input,
843 ..
844 } => *estimated_cost + self.get_plan_cost(input),
845 PhysicalPlan::Project {
846 estimated_cost,
847 input,
848 ..
849 } => *estimated_cost + self.get_plan_cost(input),
850 PhysicalPlan::Sort {
851 estimated_cost,
852 input,
853 ..
854 } => *estimated_cost + self.get_plan_cost(input),
855 PhysicalPlan::Limit {
856 estimated_cost,
857 input,
858 ..
859 } => *estimated_cost + self.get_plan_cost(input),
860 PhysicalPlan::NestedLoopJoin {
861 estimated_cost,
862 outer,
863 inner,
864 ..
865 } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
866 PhysicalPlan::HashJoin {
867 estimated_cost,
868 build,
869 probe,
870 ..
871 } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
872 PhysicalPlan::MergeJoin {
873 estimated_cost,
874 left,
875 right,
876 ..
877 } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
878 PhysicalPlan::Aggregate {
879 estimated_cost,
880 input,
881 ..
882 } => *estimated_cost + self.get_plan_cost(input),
883 }
884 }
885
886 pub fn explain(&self, plan: &PhysicalPlan) -> String {
888 self.explain_impl(plan, 0)
889 }
890
891 fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
892 let prefix = " ".repeat(indent);
893 let cost = self.get_plan_cost(plan);
894
895 match plan {
896 PhysicalPlan::TableScan {
897 table,
898 columns,
899 estimated_rows,
900 ..
901 } => {
902 format!(
903 "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
904 prefix, table, columns, estimated_rows, cost
905 )
906 }
907 PhysicalPlan::IndexSeek {
908 table,
909 index,
910 columns,
911 estimated_rows,
912 ..
913 } => {
914 format!(
915 "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
916 prefix, table, index, columns, estimated_rows, cost
917 )
918 }
919 PhysicalPlan::Filter {
920 input,
921 estimated_rows,
922 ..
923 } => {
924 format!(
925 "{}Filter [rows={}, cost={:.2}ms]\n{}",
926 prefix,
927 estimated_rows,
928 cost,
929 self.explain_impl(input, indent + 1)
930 )
931 }
932 PhysicalPlan::Project { input, columns, .. } => {
933 format!(
934 "{}Project [columns={:?}, cost={:.2}ms]\n{}",
935 prefix,
936 columns,
937 cost,
938 self.explain_impl(input, indent + 1)
939 )
940 }
941 PhysicalPlan::Sort {
942 input, order_by, ..
943 } => {
944 let order: Vec<_> = order_by
945 .iter()
946 .map(|(c, d)| format!("{} {:?}", c, d))
947 .collect();
948 format!(
949 "{}Sort [order={:?}, cost={:.2}ms]\n{}",
950 prefix,
951 order,
952 cost,
953 self.explain_impl(input, indent + 1)
954 )
955 }
956 PhysicalPlan::Limit {
957 input,
958 limit,
959 offset,
960 ..
961 } => {
962 format!(
963 "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
964 prefix,
965 limit,
966 offset,
967 cost,
968 self.explain_impl(input, indent + 1)
969 )
970 }
971 PhysicalPlan::HashJoin {
972 build,
973 probe,
974 join_type,
975 estimated_rows,
976 ..
977 } => {
978 format!(
979 "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
980 prefix,
981 join_type,
982 estimated_rows,
983 cost,
984 self.explain_impl(build, indent + 1),
985 self.explain_impl(probe, indent + 1)
986 )
987 }
988 PhysicalPlan::MergeJoin {
989 left,
990 right,
991 join_type,
992 estimated_rows,
993 ..
994 } => {
995 format!(
996 "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
997 prefix,
998 join_type,
999 estimated_rows,
1000 cost,
1001 self.explain_impl(left, indent + 1),
1002 self.explain_impl(right, indent + 1)
1003 )
1004 }
1005 PhysicalPlan::NestedLoopJoin {
1006 outer,
1007 inner,
1008 join_type,
1009 estimated_rows,
1010 ..
1011 } => {
1012 format!(
1013 "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1014 prefix,
1015 join_type,
1016 estimated_rows,
1017 cost,
1018 self.explain_impl(outer, indent + 1),
1019 self.explain_impl(inner, indent + 1)
1020 )
1021 }
1022 PhysicalPlan::Aggregate {
1023 input,
1024 group_by,
1025 aggregates,
1026 estimated_rows,
1027 ..
1028 } => {
1029 let aggs: Vec<_> = aggregates
1030 .iter()
1031 .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1032 .collect();
1033 format!(
1034 "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1035 prefix,
1036 group_by,
1037 aggs,
1038 estimated_rows,
1039 cost,
1040 self.explain_impl(input, indent + 1)
1041 )
1042 }
1043 }
1044 }
1045}
1046
1047pub struct JoinOrderOptimizer {
1053 stats: HashMap<String, TableStats>,
1055 config: CostModelConfig,
1057}
1058
1059impl JoinOrderOptimizer {
1060 pub fn new(config: CostModelConfig) -> Self {
1061 Self {
1062 stats: HashMap::new(),
1063 config,
1064 }
1065 }
1066
1067 pub fn add_stats(&mut self, stats: TableStats) {
1069 self.stats.insert(stats.name.clone(), stats);
1070 }
1071
1072 pub fn find_optimal_order(
1077 &self,
1078 tables: &[String],
1079 join_conditions: &[(String, String, String, String)], ) -> Vec<(String, String)> {
1081 let n = tables.len();
1082 if n <= 1 {
1083 return vec![];
1084 }
1085
1086 let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1088
1089 for (i, _table) in tables.iter().enumerate() {
1091 let mask = 1u32 << i;
1092 dp.insert(mask, (0.0, vec![]));
1093 }
1094
1095 for size in 2..=n {
1097 for mask in 0..(1u32 << n) {
1098 if mask.count_ones() != size as u32 {
1099 continue;
1100 }
1101
1102 let mut best_cost = f64::MAX;
1103 let mut best_order = vec![];
1104
1105 for sub in 1..mask {
1107 if sub & mask != sub || sub == 0 {
1108 continue;
1109 }
1110 let other = mask ^ sub;
1111 if other == 0 {
1112 continue;
1113 }
1114
1115 if !self.has_join_condition(tables, sub, other, join_conditions) {
1117 continue;
1118 }
1119
1120 if let (Some((cost1, order1)), Some((cost2, order2))) =
1121 (dp.get(&sub), dp.get(&other))
1122 {
1123 let join_cost = self.estimate_join_cost(tables, sub, other);
1124 let total_cost = cost1 + cost2 + join_cost;
1125
1126 if total_cost < best_cost {
1127 best_cost = total_cost;
1128 best_order = order1.clone();
1129 best_order.extend(order2.clone());
1130
1131 let (t1, t2) =
1133 self.get_join_tables(tables, sub, other, join_conditions);
1134 if let Some((t1, t2)) = Some((t1, t2)) {
1135 best_order.push((t1, t2));
1136 }
1137 }
1138 }
1139 }
1140
1141 if best_cost < f64::MAX {
1142 dp.insert(mask, (best_cost, best_order));
1143 }
1144 }
1145 }
1146
1147 let full_mask = (1u32 << n) - 1;
1148 dp.get(&full_mask)
1149 .map(|(_, order)| order.clone())
1150 .unwrap_or_default()
1151 }
1152
1153 fn has_join_condition(
1154 &self,
1155 tables: &[String],
1156 mask1: u32,
1157 mask2: u32,
1158 conditions: &[(String, String, String, String)],
1159 ) -> bool {
1160 for (t1, _, t2, _) in conditions {
1161 let idx1 = tables.iter().position(|t| t == t1);
1162 let idx2 = tables.iter().position(|t| t == t2);
1163
1164 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1165 let in_mask1 = (mask1 >> i1) & 1 == 1;
1166 let in_mask2 = (mask2 >> i2) & 1 == 1;
1167
1168 if in_mask1 && in_mask2 {
1169 return true;
1170 }
1171 }
1172 }
1173 false
1174 }
1175
1176 fn get_join_tables(
1177 &self,
1178 tables: &[String],
1179 mask1: u32,
1180 mask2: u32,
1181 conditions: &[(String, String, String, String)],
1182 ) -> (String, String) {
1183 for (t1, _, t2, _) in conditions {
1184 let idx1 = tables.iter().position(|t| t == t1);
1185 let idx2 = tables.iter().position(|t| t == t2);
1186
1187 if let (Some(i1), Some(i2)) = (idx1, idx2) {
1188 let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1189 let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1190
1191 if t1_in_mask1 && t2_in_mask2 {
1192 return (t1.clone(), t2.clone());
1193 }
1194 }
1195 }
1196 (String::new(), String::new())
1197 }
1198
1199 fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1200 let rows1 = self.estimate_rows_for_mask(tables, mask1);
1201 let rows2 = self.estimate_rows_for_mask(tables, mask2);
1202
1203 let build_cost = rows1 as f64 * self.config.c_filter;
1206 let probe_cost = rows2 as f64 * self.config.c_filter;
1207
1208 build_cost + probe_cost
1209 }
1210
1211 fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1212 let mut total = 1u64;
1213
1214 for (i, table) in tables.iter().enumerate() {
1215 if (mask >> i) & 1 == 1 {
1216 let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1217 total = total.saturating_mul(rows);
1218 }
1219 }
1220
1221 let num_tables = mask.count_ones();
1223 if num_tables > 1 {
1224 total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1225 }
1226
1227 total.max(1)
1228 }
1229}
1230
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics for a 100k-row `users` table:
    /// - `id`: unique values, covered by the primary index `pk_users`;
    /// - `score`: 100 distinct values with a uniform four-bucket histogram over 0–100,
    ///   covered by the secondary index `idx_score`.
    fn create_test_stats() -> TableStats {
        let mut column_stats = HashMap::new();
        column_stats.insert(
            "id".to_string(),
            ColumnStats {
                name: "id".to_string(),
                distinct_count: 100000,
                null_count: 0,
                min_value: Some("1".to_string()),
                max_value: Some("100000".to_string()),
                avg_length: 8.0,
                mcv: vec![],
                histogram: None,
            },
        );
        column_stats.insert(
            "score".to_string(),
            ColumnStats {
                name: "score".to_string(),
                distinct_count: 100,
                null_count: 1000,
                min_value: Some("0".to_string()),
                max_value: Some("100".to_string()),
                avg_length: 8.0,
                mcv: vec![("50".to_string(), 0.05)],
                histogram: Some(Histogram {
                    boundaries: vec![25.0, 50.0, 75.0, 100.0],
                    counts: vec![25000, 25000, 25000, 25000],
                    total_rows: 100000,
                }),
            },
        );

        TableStats {
            name: "users".to_string(),
            row_count: 100000,
            size_bytes: 10_000_000,
            column_stats,
            indices: vec![
                IndexStats {
                    name: "pk_users".to_string(),
                    columns: vec!["id".to_string()],
                    is_primary: true,
                    is_unique: true,
                    index_type: IndexType::BTree,
                    leaf_pages: 1000,
                    height: 3,
                    avg_leaf_density: 100.0,
                },
                IndexStats {
                    name: "idx_score".to_string(),
                    columns: vec!["score".to_string()],
                    is_primary: false,
                    is_unique: false,
                    index_type: IndexType::BTree,
                    leaf_pages: 500,
                    height: 2,
                    avg_leaf_density: 200.0,
                },
            ],
            last_updated: 0,
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats.clone());

        // Equality on a column with 100k distinct values: about 1 / 100_000.
        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel < 0.001);

        // Whole overlapping buckets are counted, so the bucket ending at 75 is included
        // as well as the last bucket.
        let pred = Predicate::Gt {
            column: "score".to_string(),
            value: "75".to_string(),
        };
        let sel = optimizer.estimate_selectivity(&pred, &stats);
        assert!(sel > 0.4 && sel < 0.6);
    }

    #[test]
    fn test_access_path_selection() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats);

        let pred = Predicate::Eq {
            column: "id".to_string(),
            value: "12345".to_string(),
        };
        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(pred),
            vec![],
            None,
        );

        match plan {
            PhysicalPlan::IndexSeek { index, .. } => {
                assert_eq!(index, "pk_users");
            }
            _ => panic!("Expected IndexSeek for equality on primary key"),
        }
    }

    #[test]
    fn test_token_budget_limit() {
        let config = CostModelConfig::default();
        // (2048 - 50 header tokens) / 25 tokens per row = 79 rows.
        let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);

        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);

        match plan {
            PhysicalPlan::Limit { limit, .. } => {
                assert!(limit <= 80);
            }
            _ => panic!("Expected Limit to be injected"),
        }
    }

    #[test]
    fn test_explain_output() {
        let config = CostModelConfig::default();
        let optimizer = CostBasedOptimizer::new(config);

        let stats = create_test_stats();
        optimizer.update_stats(stats);

        let plan = optimizer.optimize(
            "users",
            vec!["id".to_string(), "score".to_string()],
            Some(Predicate::Gt {
                column: "score".to_string(),
                value: "80".to_string(),
            }),
            vec![("score".to_string(), SortDirection::Descending)],
            Some(10),
        );

        let explain = optimizer.explain(&plan);
        assert!(explain.contains("Limit"));
        assert!(explain.contains("Sort"));
        // A quarter of the table matches, so the secondary index's random fetches make
        // the full scan cheaper.
        assert!(explain.contains("TableScan"));
    }

    #[test]
    fn test_referenced_columns() {
        let pred = Predicate::And(
            Box::new(Predicate::Eq {
                column: "id".to_string(),
                value: "1".to_string(),
            }),
            Box::new(Predicate::Or(
                Box::new(Predicate::IsNull {
                    column: "score".to_string(),
                }),
                Box::new(Predicate::Not(Box::new(Predicate::Like {
                    column: "name".to_string(),
                    pattern: "a%".to_string(),
                }))),
            )),
        );

        let cols = pred.referenced_columns();
        assert_eq!(cols.len(), 3);
        assert!(cols.contains("id") && cols.contains("score") && cols.contains("name"));
    }

    #[test]
    fn test_join_order_requires_join_condition() {
        let optimizer = JoinOrderOptimizer::new(CostModelConfig::default());
        let tables = vec!["a".to_string(), "b".to_string()];
        let conditions = vec![(
            "a".to_string(),
            "id".to_string(),
            "b".to_string(),
            "a_id".to_string(),
        )];

        assert_eq!(
            optimizer.find_optimal_order(&tables, &conditions),
            vec![("a".to_string(), "b".to_string())]
        );
        // Without a linking condition, no cross product is produced.
        assert!(optimizer.find_optimal_order(&tables, &[]).is_empty());
        // A single table needs no joins.
        assert!(optimizer.find_optimal_order(&tables[..1], &conditions).is_empty());
    }
}