1use std::collections::HashMap;
49use std::sync::atomic::{AtomicU64, Ordering};
50use std::sync::{Arc, RwLock};
51
52use super::chunk::DataChunk;
53use super::operators::OperatorError;
54use super::pipeline::{ChunkSizeHint, PushOperator, Sink};
55
/// Default ratio between observed and estimated cardinality (in either
/// direction) beyond which a deviation is treated as significant.
pub const DEFAULT_REOPTIMIZATION_THRESHOLD: f64 = 3.0_f64;

/// Default number of rows a checkpoint must observe before it may trigger
/// re-optimization.
pub const MIN_ROWS_FOR_REOPTIMIZATION: u64 = 1_000;
63
/// Estimated versus observed row count for a single operator.
#[derive(Debug, Clone)]
pub struct CardinalityCheckpoint {
    /// Identifier of the operator this checkpoint describes.
    pub operator_id: String,
    /// Row count predicted by the optimizer.
    pub estimated: f64,
    /// Observed row count; 0 until [`CardinalityCheckpoint::record`] is called.
    pub actual: u64,
    /// Whether an observed count has been recorded.
    pub recorded: bool,
}

impl CardinalityCheckpoint {
    /// Creates an unrecorded checkpoint carrying only the optimizer estimate.
    #[must_use]
    pub fn new(operator_id: &str, estimated: f64) -> Self {
        Self {
            operator_id: operator_id.to_owned(),
            estimated,
            actual: 0,
            recorded: false,
        }
    }

    /// Stores the observed row count, replacing any earlier value.
    pub fn record(&mut self, actual: u64) {
        self.recorded = true;
        self.actual = actual;
    }

    /// Returns `actual / estimated`.
    ///
    /// A non-positive estimate yields `1.0` when nothing was observed and
    /// `f64::INFINITY` otherwise.
    #[must_use]
    pub fn deviation_ratio(&self) -> f64 {
        let no_estimate = self.estimated <= 0.0;
        match (no_estimate, self.actual) {
            (true, 0) => 1.0,
            (true, _) => f64::INFINITY,
            (false, rows) => rows as f64 / self.estimated,
        }
    }

    /// Absolute difference between the observed and estimated row counts.
    #[must_use]
    pub fn absolute_deviation(&self) -> f64 {
        let difference = self.actual as f64 - self.estimated;
        difference.abs()
    }

    /// Whether a recorded checkpoint's ratio lies above `threshold` or below
    /// `1 / threshold`. Unrecorded checkpoints never count as deviating.
    #[must_use]
    pub fn is_significant_deviation(&self, threshold: f64) -> bool {
        self.recorded && {
            let ratio = self.deviation_ratio();
            ratio > threshold || ratio < threshold.recip()
        }
    }
}
123
/// Collects per-operator row counts, either as final values or as running
/// counters that can be incremented through a shared reference.
#[derive(Debug, Default)]
pub struct CardinalityFeedback {
    /// Final row counts keyed by operator id.
    actuals: HashMap<String, u64>,
    /// In-progress counters keyed by operator id; bumped via `add_rows(&self, ..)`.
    running_counts: HashMap<String, AtomicU64>,
}

impl CardinalityFeedback {
    /// Creates an empty feedback collector.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `count` as the final row count for `operator_id`, overwriting any
    /// earlier value.
    pub fn record(&mut self, operator_id: &str, count: u64) {
        self.actuals.insert(operator_id.to_owned(), count);
    }

    /// Adds `count` to the running counter of `operator_id`.
    ///
    /// The call is a no-op when no counter was set up with `init_counter`.
    pub fn add_rows(&self, operator_id: &str, count: u64) {
        let Some(counter) = self.running_counts.get(operator_id) else {
            return;
        };
        counter.fetch_add(count, Ordering::Relaxed);
    }

    /// Creates (or resets to zero) the running counter for `operator_id`.
    pub fn init_counter(&mut self, operator_id: &str) {
        let fresh = AtomicU64::new(0);
        self.running_counts.insert(operator_id.to_owned(), fresh);
    }

    /// Copies the running counter of `operator_id` into the final counts.
    ///
    /// Operators without a counter are ignored.
    pub fn finalize_counter(&mut self, operator_id: &str) {
        if let Some(count) = self.get_running(operator_id) {
            self.actuals.insert(operator_id.to_owned(), count);
        }
    }

    /// Final row count recorded for `operator_id`, if any.
    #[must_use]
    pub fn get(&self, operator_id: &str) -> Option<u64> {
        self.actuals.get(operator_id).copied()
    }

    /// Current value of the running counter for `operator_id`, if one exists.
    #[must_use]
    pub fn get_running(&self, operator_id: &str) -> Option<u64> {
        let counter = self.running_counts.get(operator_id)?;
        Some(counter.load(Ordering::Relaxed))
    }

    /// All final row counts, keyed by operator id.
    #[must_use]
    pub fn all_actuals(&self) -> &HashMap<String, u64> {
        &self.actuals
    }
}
192
193#[derive(Debug)]
198pub struct AdaptiveContext {
199 checkpoints: HashMap<String, CardinalityCheckpoint>,
201 reoptimization_threshold: f64,
203 min_rows: u64,
205 reoptimization_triggered: bool,
207 trigger_operator: Option<String>,
209}
210
211impl AdaptiveContext {
212 #[must_use]
214 pub fn new() -> Self {
215 Self {
216 checkpoints: HashMap::new(),
217 reoptimization_threshold: DEFAULT_REOPTIMIZATION_THRESHOLD,
218 min_rows: MIN_ROWS_FOR_REOPTIMIZATION,
219 reoptimization_triggered: false,
220 trigger_operator: None,
221 }
222 }
223
224 #[must_use]
226 pub fn with_thresholds(threshold: f64, min_rows: u64) -> Self {
227 Self {
228 checkpoints: HashMap::new(),
229 reoptimization_threshold: threshold,
230 min_rows,
231 reoptimization_triggered: false,
232 trigger_operator: None,
233 }
234 }
235
236 pub fn set_estimate(&mut self, operator_id: &str, estimate: f64) {
238 self.checkpoints.insert(
239 operator_id.to_string(),
240 CardinalityCheckpoint::new(operator_id, estimate),
241 );
242 }
243
244 pub fn record_actual(&mut self, operator_id: &str, actual: u64) {
246 if let Some(checkpoint) = self.checkpoints.get_mut(operator_id) {
247 checkpoint.record(actual);
248 } else {
249 let mut checkpoint = CardinalityCheckpoint::new(operator_id, 0.0);
251 checkpoint.record(actual);
252 self.checkpoints.insert(operator_id.to_string(), checkpoint);
253 }
254 }
255
256 pub fn apply_feedback(&mut self, feedback: &CardinalityFeedback) {
258 for (op_id, &actual) in feedback.all_actuals() {
259 self.record_actual(op_id, actual);
260 }
261 }
262
263 #[must_use]
265 pub fn has_significant_deviation(&self) -> bool {
266 self.checkpoints
267 .values()
268 .any(|cp| cp.is_significant_deviation(self.reoptimization_threshold))
269 }
270
271 #[must_use]
278 pub fn should_reoptimize(&mut self) -> bool {
279 if self.reoptimization_triggered {
280 return false;
281 }
282
283 for (op_id, checkpoint) in &self.checkpoints {
284 if checkpoint.actual < self.min_rows {
285 continue;
286 }
287
288 if checkpoint.is_significant_deviation(self.reoptimization_threshold) {
289 self.reoptimization_triggered = true;
290 self.trigger_operator = Some(op_id.clone());
291 return true;
292 }
293 }
294
295 false
296 }
297
298 #[must_use]
300 pub fn trigger_operator(&self) -> Option<&str> {
301 self.trigger_operator.as_deref()
302 }
303
304 #[must_use]
306 pub fn get_checkpoint(&self, operator_id: &str) -> Option<&CardinalityCheckpoint> {
307 self.checkpoints.get(operator_id)
308 }
309
310 #[must_use]
312 pub fn all_checkpoints(&self) -> &HashMap<String, CardinalityCheckpoint> {
313 &self.checkpoints
314 }
315
316 #[must_use]
320 pub fn correction_factor(&self, operator_id: &str) -> f64 {
321 self.checkpoints
322 .get(operator_id)
323 .filter(|cp| cp.recorded)
324 .map(CardinalityCheckpoint::deviation_ratio)
325 .unwrap_or(1.0)
326 }
327
328 #[must_use]
330 pub fn summary(&self) -> AdaptiveSummary {
331 let recorded_count = self.checkpoints.values().filter(|cp| cp.recorded).count();
332 let deviation_count = self
333 .checkpoints
334 .values()
335 .filter(|cp| cp.is_significant_deviation(self.reoptimization_threshold))
336 .count();
337
338 let avg_deviation = if recorded_count > 0 {
339 self.checkpoints
340 .values()
341 .filter(|cp| cp.recorded)
342 .map(CardinalityCheckpoint::deviation_ratio)
343 .sum::<f64>()
344 / recorded_count as f64
345 } else {
346 1.0
347 };
348
349 let max_deviation = self
350 .checkpoints
351 .values()
352 .filter(|cp| cp.recorded)
353 .map(|cp| {
354 let ratio = cp.deviation_ratio();
355 if ratio > 1.0 { ratio } else { 1.0 / ratio }
356 })
357 .fold(1.0_f64, f64::max);
358
359 AdaptiveSummary {
360 checkpoint_count: self.checkpoints.len(),
361 recorded_count,
362 deviation_count,
363 avg_deviation_ratio: avg_deviation,
364 max_deviation_ratio: max_deviation,
365 reoptimization_triggered: self.reoptimization_triggered,
366 trigger_operator: self.trigger_operator.clone(),
367 }
368 }
369
370 pub fn reset(&mut self) {
372 for checkpoint in self.checkpoints.values_mut() {
373 checkpoint.actual = 0;
374 checkpoint.recorded = false;
375 }
376 self.reoptimization_triggered = false;
377 self.trigger_operator = None;
378 }
379}
380
381impl Default for AdaptiveContext {
382 fn default() -> Self {
383 Self::new()
384 }
385}
386
/// Snapshot of aggregate statistics produced by `AdaptiveContext::summary`.
#[derive(Debug, Clone, Default)]
pub struct AdaptiveSummary {
    /// Number of registered checkpoints.
    pub checkpoint_count: usize,
    /// Number of checkpoints with a recorded actual count.
    pub recorded_count: usize,
    /// Number of recorded checkpoints whose deviation exceeds the threshold.
    pub deviation_count: usize,
    /// Mean observed/estimated ratio over recorded checkpoints.
    pub avg_deviation_ratio: f64,
    /// Largest deviation in either direction over recorded checkpoints.
    pub max_deviation_ratio: f64,
    /// Whether re-optimization has been triggered.
    pub reoptimization_triggered: bool,
    /// Operator id that triggered re-optimization, if any.
    pub trigger_operator: Option<String>,
}
405
406#[derive(Debug, Clone)]
410pub struct SharedAdaptiveContext {
411 inner: Arc<RwLock<AdaptiveContext>>,
412}
413
414impl SharedAdaptiveContext {
415 #[must_use]
417 pub fn new() -> Self {
418 Self {
419 inner: Arc::new(RwLock::new(AdaptiveContext::new())),
420 }
421 }
422
423 #[must_use]
425 pub fn from_context(ctx: AdaptiveContext) -> Self {
426 Self {
427 inner: Arc::new(RwLock::new(ctx)),
428 }
429 }
430
431 pub fn record_actual(&self, operator_id: &str, actual: u64) {
433 if let Ok(mut ctx) = self.inner.write() {
434 ctx.record_actual(operator_id, actual);
435 }
436 }
437
438 #[must_use]
440 pub fn should_reoptimize(&self) -> bool {
441 if let Ok(mut ctx) = self.inner.write() {
442 ctx.should_reoptimize()
443 } else {
444 false
445 }
446 }
447
448 #[must_use]
450 pub fn snapshot(&self) -> Option<AdaptiveContext> {
451 self.inner.read().ok().map(|guard| AdaptiveContext {
452 checkpoints: guard.checkpoints.clone(),
453 reoptimization_threshold: guard.reoptimization_threshold,
454 min_rows: guard.min_rows,
455 reoptimization_triggered: guard.reoptimization_triggered,
456 trigger_operator: guard.trigger_operator.clone(),
457 })
458 }
459}
460
461impl Default for SharedAdaptiveContext {
462 fn default() -> Self {
463 Self::new()
464 }
465}
466
467pub struct CardinalityTrackingOperator {
472 inner: Box<dyn PushOperator>,
474 operator_id: String,
476 row_count: u64,
478 context: SharedAdaptiveContext,
480}
481
482impl CardinalityTrackingOperator {
483 pub fn new(
485 inner: Box<dyn PushOperator>,
486 operator_id: &str,
487 context: SharedAdaptiveContext,
488 ) -> Self {
489 Self {
490 inner,
491 operator_id: operator_id.to_string(),
492 row_count: 0,
493 context,
494 }
495 }
496
497 #[must_use]
499 pub fn current_count(&self) -> u64 {
500 self.row_count
501 }
502}
503
504impl PushOperator for CardinalityTrackingOperator {
505 fn push(&mut self, chunk: DataChunk, sink: &mut dyn Sink) -> Result<bool, OperatorError> {
506 self.row_count += chunk.len() as u64;
508
509 self.inner.push(chunk, sink)
511 }
512
513 fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
514 self.context
516 .record_actual(&self.operator_id, self.row_count);
517
518 self.inner.finalize(sink)
520 }
521
522 fn preferred_chunk_size(&self) -> ChunkSizeHint {
523 self.inner.preferred_chunk_size()
524 }
525
526 fn name(&self) -> &'static str {
527 self.inner.name()
529 }
530}
531
532pub struct CardinalityTrackingSink {
534 inner: Box<dyn Sink>,
536 operator_id: String,
538 row_count: u64,
540 context: SharedAdaptiveContext,
542}
543
544impl CardinalityTrackingSink {
545 pub fn new(inner: Box<dyn Sink>, operator_id: &str, context: SharedAdaptiveContext) -> Self {
547 Self {
548 inner,
549 operator_id: operator_id.to_string(),
550 row_count: 0,
551 context,
552 }
553 }
554
555 #[must_use]
557 pub fn current_count(&self) -> u64 {
558 self.row_count
559 }
560}
561
562impl Sink for CardinalityTrackingSink {
563 fn consume(&mut self, chunk: DataChunk) -> Result<bool, OperatorError> {
564 self.row_count += chunk.len() as u64;
565 self.inner.consume(chunk)
566 }
567
568 fn finalize(&mut self) -> Result<(), OperatorError> {
569 self.context
571 .record_actual(&self.operator_id, self.row_count);
572 self.inner.finalize()
573 }
574
575 fn name(&self) -> &'static str {
576 self.inner.name()
577 }
578}
579
/// Outcome of evaluating an adaptive context mid-execution.
#[derive(Debug, Clone, PartialEq)]
pub enum ReoptimizationDecision {
    /// Keep executing the current plan.
    Continue,
    /// Re-plan with corrected cardinalities.
    Reoptimize {
        /// Operator id whose deviation triggered re-optimization.
        trigger: String,
        /// Observed/estimated ratio per recorded operator id.
        corrections: HashMap<String, f64>,
    },
    /// Stop execution.
    Abort {
        /// Human-readable explanation.
        reason: String,
    },
}
595
596#[must_use]
598pub fn evaluate_reoptimization(ctx: &AdaptiveContext) -> ReoptimizationDecision {
599 let summary = ctx.summary();
600
601 if !summary.reoptimization_triggered {
603 return ReoptimizationDecision::Continue;
604 }
605
606 if summary.max_deviation_ratio > 100.0 {
608 return ReoptimizationDecision::Abort {
609 reason: format!(
610 "Catastrophic cardinality misestimate: {}x deviation",
611 summary.max_deviation_ratio
612 ),
613 };
614 }
615
616 let corrections: HashMap<String, f64> = ctx
618 .all_checkpoints()
619 .iter()
620 .filter(|(_, cp)| cp.recorded)
621 .map(|(id, cp)| (id.clone(), cp.deviation_ratio()))
622 .collect();
623
624 ReoptimizationDecision::Reoptimize {
625 trigger: summary.trigger_operator.unwrap_or_default(),
626 corrections,
627 }
628}
629
630pub type PlanFactory = Box<dyn Fn(&AdaptiveContext) -> Vec<Box<dyn PushOperator>> + Send + Sync>;
636
637#[derive(Debug, Clone)]
639pub struct AdaptivePipelineConfig {
640 pub check_interval: u64,
642 pub reoptimization_threshold: f64,
644 pub min_rows_for_reoptimization: u64,
646 pub max_reoptimizations: usize,
648}
649
650impl Default for AdaptivePipelineConfig {
651 fn default() -> Self {
652 Self {
653 check_interval: 10_000,
654 reoptimization_threshold: DEFAULT_REOPTIMIZATION_THRESHOLD,
655 min_rows_for_reoptimization: MIN_ROWS_FOR_REOPTIMIZATION,
656 max_reoptimizations: 3,
657 }
658 }
659}
660
661impl AdaptivePipelineConfig {
662 #[must_use]
664 pub fn new(check_interval: u64, threshold: f64, min_rows: u64) -> Self {
665 Self {
666 check_interval,
667 reoptimization_threshold: threshold,
668 min_rows_for_reoptimization: min_rows,
669 max_reoptimizations: 3,
670 }
671 }
672
673 #[must_use]
675 pub fn with_max_reoptimizations(mut self, max: usize) -> Self {
676 self.max_reoptimizations = max;
677 self
678 }
679}
680
681#[derive(Debug, Clone)]
683pub struct AdaptiveExecutionResult {
684 pub total_rows: u64,
686 pub reoptimization_count: usize,
688 pub triggers: Vec<String>,
690 pub final_context: AdaptiveSummary,
692}
693
/// A named observation point placed after a given operator in a pipeline.
/// It accumulates the rows seen there.
#[derive(Debug)]
pub struct AdaptiveCheckpoint {
    /// Checkpoint identifier.
    pub id: String,
    /// Index of the operator this checkpoint follows.
    pub after_operator: usize,
    /// Row count predicted by the optimizer.
    pub estimated_cardinality: f64,
    /// Rows observed so far.
    pub actual_rows: u64,
    /// Whether this checkpoint has already triggered re-optimization.
    pub triggered: bool,
}

impl AdaptiveCheckpoint {
    /// Creates an untriggered checkpoint with no observed rows.
    #[must_use]
    pub fn new(id: &str, after_operator: usize, estimated: f64) -> Self {
        Self {
            id: String::from(id),
            after_operator,
            estimated_cardinality: estimated,
            actual_rows: 0,
            triggered: false,
        }
    }

    /// Adds `count` to the observed rows.
    pub fn record_rows(&mut self, count: u64) {
        self.actual_rows += count;
    }

    /// Whether the observed rows deviate from the estimate by more than
    /// `threshold` (in either direction).
    ///
    /// Returns `false` while fewer than `min_rows` rows have been observed. A
    /// non-positive estimate counts as exceeded as soon as any row is observed.
    #[must_use]
    pub fn exceeds_threshold(&self, threshold: f64, min_rows: u64) -> bool {
        let observed = self.actual_rows;
        if observed < min_rows {
            false
        } else if self.estimated_cardinality <= 0.0 {
            observed > 0
        } else {
            let ratio = observed as f64 / self.estimated_cardinality;
            ratio > threshold || ratio < threshold.recip()
        }
    }
}
743
/// Notifications emitted while adaptive execution progresses.
#[derive(Debug, Clone)]
pub enum AdaptiveEvent {
    /// A checkpoint received an observed row count.
    CheckpointReached {
        id: String,
        actual_rows: u64,
        estimated: f64,
    },
    /// A checkpoint was marked as having triggered re-optimization.
    ReoptimizationTriggered {
        checkpoint_id: String,
        deviation_ratio: f64,
    },
    /// The running plan was replaced by a new one.
    PlanSwitched {
        old_operator_count: usize,
        new_operator_count: usize,
    },
    /// Execution finished.
    ExecutionCompleted {
        total_rows: u64,
    },
}

/// Thread-shareable observer invoked synchronously for each [`AdaptiveEvent`].
pub type AdaptiveEventCallback = Box<dyn Fn(AdaptiveEvent) + Send + Sync>;
769
770pub struct AdaptivePipelineBuilder {
772 checkpoints: Vec<AdaptiveCheckpoint>,
773 config: AdaptivePipelineConfig,
774 context: AdaptiveContext,
775 event_callback: Option<AdaptiveEventCallback>,
776}
777
778impl AdaptivePipelineBuilder {
779 #[must_use]
781 pub fn new() -> Self {
782 Self {
783 checkpoints: Vec::new(),
784 config: AdaptivePipelineConfig::default(),
785 context: AdaptiveContext::new(),
786 event_callback: None,
787 }
788 }
789
790 #[must_use]
792 pub fn with_config(mut self, config: AdaptivePipelineConfig) -> Self {
793 self.config = config;
794 self
795 }
796
797 #[must_use]
799 pub fn with_checkpoint(mut self, id: &str, after_operator: usize, estimated: f64) -> Self {
800 self.checkpoints
801 .push(AdaptiveCheckpoint::new(id, after_operator, estimated));
802 self.context.set_estimate(id, estimated);
803 self
804 }
805
806 #[must_use]
808 pub fn with_event_callback(mut self, callback: AdaptiveEventCallback) -> Self {
809 self.event_callback = Some(callback);
810 self
811 }
812
813 #[must_use]
815 pub fn with_context(mut self, context: AdaptiveContext) -> Self {
816 self.context = context;
817 self
818 }
819
820 #[must_use]
822 pub fn build(self) -> AdaptiveExecutionConfig {
823 AdaptiveExecutionConfig {
824 checkpoints: self.checkpoints,
825 config: self.config,
826 context: self.context,
827 event_callback: self.event_callback,
828 }
829 }
830}
831
832impl Default for AdaptivePipelineBuilder {
833 fn default() -> Self {
834 Self::new()
835 }
836}
837
838pub struct AdaptiveExecutionConfig {
840 pub checkpoints: Vec<AdaptiveCheckpoint>,
842 pub config: AdaptivePipelineConfig,
844 pub context: AdaptiveContext,
846 pub event_callback: Option<AdaptiveEventCallback>,
848}
849
850impl AdaptiveExecutionConfig {
851 #[must_use]
853 pub fn summary(&self) -> AdaptiveSummary {
854 self.context.summary()
855 }
856
857 pub fn record_checkpoint(&mut self, checkpoint_id: &str, actual: u64) {
859 self.context.record_actual(checkpoint_id, actual);
860
861 if let Some(cp) = self.checkpoints.iter_mut().find(|c| c.id == checkpoint_id) {
862 cp.actual_rows = actual;
863 }
864
865 if let Some(ref callback) = self.event_callback {
866 let estimated = self
867 .context
868 .get_checkpoint(checkpoint_id)
869 .map(|cp| cp.estimated)
870 .unwrap_or(0.0);
871 callback(AdaptiveEvent::CheckpointReached {
872 id: checkpoint_id.to_string(),
873 actual_rows: actual,
874 estimated,
875 });
876 }
877 }
878
879 #[must_use]
881 pub fn should_reoptimize(&self) -> Option<&AdaptiveCheckpoint> {
882 self.checkpoints.iter().find(|cp| {
883 !cp.triggered
884 && cp.exceeds_threshold(
885 self.config.reoptimization_threshold,
886 self.config.min_rows_for_reoptimization,
887 )
888 })
889 }
890
891 pub fn mark_triggered(&mut self, checkpoint_id: &str) {
893 if let Some(cp) = self.checkpoints.iter_mut().find(|c| c.id == checkpoint_id) {
894 cp.triggered = true;
895 }
896
897 if let Some(ref callback) = self.event_callback {
898 let ratio = self
899 .context
900 .get_checkpoint(checkpoint_id)
901 .filter(|cp| cp.recorded)
902 .map(|cp| cp.deviation_ratio())
903 .unwrap_or(1.0);
904 callback(AdaptiveEvent::ReoptimizationTriggered {
905 checkpoint_id: checkpoint_id.to_string(),
906 deviation_ratio: ratio,
907 });
908 }
909 }
910}
911
912use super::operators::{Operator, OperatorResult}; pub struct CardinalityTrackingWrapper {
922 inner: Box<dyn Operator>,
924 operator_id: String,
926 row_count: u64,
928 context: SharedAdaptiveContext,
930 finalized: bool,
932}
933
934impl CardinalityTrackingWrapper {
935 pub fn new(
937 inner: Box<dyn Operator>,
938 operator_id: &str,
939 context: SharedAdaptiveContext,
940 ) -> Self {
941 Self {
942 inner,
943 operator_id: operator_id.to_string(),
944 row_count: 0,
945 context,
946 finalized: false,
947 }
948 }
949
950 #[must_use]
952 pub fn current_count(&self) -> u64 {
953 self.row_count
954 }
955
956 fn report_final(&mut self) {
958 if !self.finalized {
959 self.context
960 .record_actual(&self.operator_id, self.row_count);
961 self.finalized = true;
962 }
963 }
964}
965
966impl Operator for CardinalityTrackingWrapper {
967 fn next(&mut self) -> OperatorResult {
968 match self.inner.next() {
969 Ok(Some(chunk)) => {
970 self.row_count += chunk.row_count() as u64;
972 Ok(Some(chunk))
973 }
974 Ok(None) => {
975 self.report_final();
977 Ok(None)
978 }
979 Err(e) => {
980 self.report_final();
982 Err(e)
983 }
984 }
985 }
986
987 fn reset(&mut self) {
988 self.row_count = 0;
989 self.finalized = false;
990 self.inner.reset();
991 }
992
993 fn name(&self) -> &'static str {
994 self.inner.name()
995 }
996}
997
998impl Drop for CardinalityTrackingWrapper {
999 fn drop(&mut self) {
1000 self.report_final();
1002 }
1003}
1004
1005use super::pipeline::{DEFAULT_CHUNK_SIZE, Source}; use super::sink::CollectorSink;
1009use super::source::OperatorSource;
1010
1011pub struct AdaptivePipelineExecutor {
1028 source: OperatorSource,
1029 context: SharedAdaptiveContext,
1030 config: AdaptivePipelineConfig,
1031}
1032
1033impl AdaptivePipelineExecutor {
1034 pub fn new(operator: Box<dyn Operator>, context: AdaptiveContext) -> Self {
1041 Self {
1042 source: OperatorSource::new(operator),
1043 context: SharedAdaptiveContext::from_context(context),
1044 config: AdaptivePipelineConfig::default(),
1045 }
1046 }
1047
1048 pub fn with_config(
1050 operator: Box<dyn Operator>,
1051 context: AdaptiveContext,
1052 config: AdaptivePipelineConfig,
1053 ) -> Self {
1054 Self {
1055 source: OperatorSource::new(operator),
1056 context: SharedAdaptiveContext::from_context(context),
1057 config,
1058 }
1059 }
1060
1061 pub fn execute(mut self) -> Result<(Vec<DataChunk>, AdaptiveSummary), OperatorError> {
1071 let mut sink = CardinalityTrackingSink::new(
1072 Box::new(CollectorSink::new()),
1073 "output",
1074 self.context.clone(),
1075 );
1076
1077 let chunk_size = DEFAULT_CHUNK_SIZE;
1078 let mut total_rows: u64 = 0;
1079 let check_interval = self.config.check_interval;
1080
1081 while let Some(chunk) = self.source.next_chunk(chunk_size)? {
1083 let chunk_rows = chunk.len() as u64;
1084 total_rows += chunk_rows;
1085
1086 let continue_exec = sink.consume(chunk)?;
1088 if !continue_exec {
1089 break;
1090 }
1091
1092 if total_rows >= check_interval
1094 && total_rows.is_multiple_of(check_interval)
1095 && self.context.should_reoptimize()
1096 {
1097 }
1100 }
1101
1102 sink.finalize()?;
1104
1105 let summary = self
1107 .context
1108 .snapshot()
1109 .map(|ctx| ctx.summary())
1110 .unwrap_or_default();
1111
1112 Ok((Vec::new(), summary))
1116 }
1117
1118 pub fn execute_collecting(
1122 mut self,
1123 ) -> Result<(Vec<DataChunk>, AdaptiveSummary), OperatorError> {
1124 let mut chunks = Vec::new();
1125 let chunk_size = DEFAULT_CHUNK_SIZE;
1126 let mut total_rows: u64 = 0;
1127 let check_interval = self.config.check_interval;
1128
1129 while let Some(chunk) = self.source.next_chunk(chunk_size)? {
1131 let chunk_rows = chunk.len() as u64;
1132 total_rows += chunk_rows;
1133
1134 self.context.record_actual("root", total_rows);
1136
1137 if !chunk.is_empty() {
1139 chunks.push(chunk);
1140 }
1141
1142 if total_rows >= check_interval && total_rows.is_multiple_of(check_interval) {
1144 let _ = self.context.should_reoptimize();
1145 }
1146 }
1147
1148 let summary = self
1149 .context
1150 .snapshot()
1151 .map(|ctx| ctx.summary())
1152 .unwrap_or_default();
1153
1154 Ok((chunks, summary))
1155 }
1156
1157 pub fn context(&self) -> &SharedAdaptiveContext {
1159 &self.context
1160 }
1161}
1162
1163pub fn execute_adaptive(
1177 operator: Box<dyn Operator>,
1178 context: Option<AdaptiveContext>,
1179 config: Option<AdaptivePipelineConfig>,
1180) -> Result<(Vec<DataChunk>, Option<AdaptiveSummary>), OperatorError> {
1181 let ctx = context.unwrap_or_default();
1182 let cfg = config.unwrap_or_default();
1183
1184 let executor = AdaptivePipelineExecutor::with_config(operator, ctx, cfg);
1185 let (chunks, summary) = executor.execute_collecting()?;
1186
1187 Ok((chunks, Some(summary)))
1188}
1189
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for every floating-point comparison in this module.
    const EPS: f64 = 0.001;

    /// Whether `actual` is within [`EPS`] of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Checkpoint with `estimated` rows on which `actual` rows were recorded.
    fn recorded_checkpoint(estimated: f64, actual: u64) -> CardinalityCheckpoint {
        let mut cp = CardinalityCheckpoint::new("test", estimated);
        cp.record(actual);
        cp
    }

    /// Sets an estimate and then records an actual for each `(op, estimate, actual)`.
    fn observe(ctx: &mut AdaptiveContext, observations: &[(&str, f64, u64)]) {
        for &(op, estimate, actual) in observations {
            ctx.set_estimate(op, estimate);
            ctx.record_actual(op, actual);
        }
    }

    #[test]
    fn test_checkpoint_deviation_ratio() {
        let cp = recorded_checkpoint(100.0, 200);
        assert!(approx_eq(cp.deviation_ratio(), 2.0));
    }

    #[test]
    fn test_checkpoint_underestimate() {
        let cp = recorded_checkpoint(100.0, 500);
        assert!(approx_eq(cp.deviation_ratio(), 5.0));
        assert!(cp.is_significant_deviation(3.0));
    }

    #[test]
    fn test_checkpoint_overestimate() {
        let cp = recorded_checkpoint(100.0, 20);
        assert!(approx_eq(cp.deviation_ratio(), 0.2));
        // 0.2 < 1/3, so an overestimate is significant too.
        assert!(cp.is_significant_deviation(3.0));
    }

    #[test]
    fn test_checkpoint_accurate() {
        let cp = recorded_checkpoint(100.0, 110);
        assert!(approx_eq(cp.deviation_ratio(), 1.1));
        assert!(!cp.is_significant_deviation(3.0));
    }

    #[test]
    fn test_checkpoint_zero_estimate() {
        assert!(recorded_checkpoint(0.0, 100).deviation_ratio().is_infinite());
    }

    #[test]
    fn test_checkpoint_zero_both() {
        assert!(approx_eq(recorded_checkpoint(0.0, 0).deviation_ratio(), 1.0));
    }

    #[test]
    fn test_feedback_collection() {
        let mut feedback = CardinalityFeedback::new();
        for (op, rows) in [("scan_1", 1000), ("filter_1", 100)] {
            feedback.record(op, rows);
        }

        assert_eq!(feedback.get("scan_1"), Some(1000));
        assert_eq!(feedback.get("filter_1"), Some(100));
        assert_eq!(feedback.get("unknown"), None);
    }

    #[test]
    fn test_feedback_running_counter() {
        let mut feedback = CardinalityFeedback::new();
        feedback.init_counter("op_1");
        for rows in [100, 200, 50] {
            feedback.add_rows("op_1", rows);
        }
        assert_eq!(feedback.get_running("op_1"), Some(350));

        feedback.finalize_counter("op_1");
        assert_eq!(feedback.get("op_1"), Some(350));
    }

    #[test]
    fn test_adaptive_context_basic() {
        let mut ctx = AdaptiveContext::new();
        observe(&mut ctx, &[("scan", 1000.0, 1000), ("filter", 100.0, 500)]);

        let cp = ctx.get_checkpoint("filter").expect("filter checkpoint registered");
        assert!(approx_eq(cp.deviation_ratio(), 5.0));
    }

    #[test]
    fn test_adaptive_context_should_reoptimize() {
        let mut ctx = AdaptiveContext::with_thresholds(2.0, 100);
        observe(&mut ctx, &[("scan", 10000.0, 10000), ("filter", 1000.0, 5000)]);

        assert!(ctx.should_reoptimize());
        assert_eq!(ctx.trigger_operator(), Some("filter"));
        // The decision is latched: a second check does not fire again.
        assert!(!ctx.should_reoptimize());
    }

    #[test]
    fn test_adaptive_context_min_rows() {
        let mut ctx = AdaptiveContext::with_thresholds(2.0, 1000);
        observe(&mut ctx, &[("filter", 100.0, 500)]);
        // 5x off, but only 500 of the required 1000 rows observed.
        assert!(!ctx.should_reoptimize());
    }

    #[test]
    fn test_adaptive_context_no_deviation() {
        let mut ctx = AdaptiveContext::new();
        observe(&mut ctx, &[("scan", 1000.0, 1100)]);
        assert!(!ctx.has_significant_deviation());
        assert!(!ctx.should_reoptimize());
    }

    #[test]
    fn test_adaptive_context_correction_factor() {
        let mut ctx = AdaptiveContext::new();
        observe(&mut ctx, &[("filter", 100.0, 300)]);

        assert!(approx_eq(ctx.correction_factor("filter"), 3.0));
        assert!(approx_eq(ctx.correction_factor("unknown"), 1.0));
    }

    #[test]
    fn test_adaptive_context_apply_feedback() {
        let mut ctx = AdaptiveContext::new();
        ctx.set_estimate("scan", 1000.0);
        ctx.set_estimate("filter", 100.0);

        let mut feedback = CardinalityFeedback::new();
        feedback.record("scan", 1000);
        feedback.record("filter", 500);
        ctx.apply_feedback(&feedback);

        let actual_of = |op: &str| ctx.get_checkpoint(op).map(|cp| cp.actual);
        assert_eq!(actual_of("scan"), Some(1000));
        assert_eq!(actual_of("filter"), Some(500));
    }

    #[test]
    fn test_adaptive_summary() {
        let mut ctx = AdaptiveContext::with_thresholds(2.0, 0);
        // op3 is estimated but never observed.
        ctx.set_estimate("op3", 300.0);
        // op1 is accurate; op2 is 3x off.
        observe(&mut ctx, &[("op1", 100.0, 100), ("op2", 200.0, 600)]);
        let _ = ctx.should_reoptimize();

        let summary = ctx.summary();
        assert_eq!(summary.checkpoint_count, 3);
        assert_eq!(summary.recorded_count, 2);
        assert_eq!(summary.deviation_count, 1);
        assert!(summary.reoptimization_triggered);
    }

    #[test]
    fn test_adaptive_context_reset() {
        let mut ctx = AdaptiveContext::new();
        observe(&mut ctx, &[("scan", 1000.0, 5000)]);
        let _ = ctx.should_reoptimize();
        assert!(ctx.reoptimization_triggered);

        ctx.reset();

        assert!(!ctx.reoptimization_triggered);
        let cp = ctx.get_checkpoint("scan").unwrap();
        assert_eq!(cp.actual, 0);
        assert!(!cp.recorded);
        // Estimates survive a reset.
        assert!(approx_eq(cp.estimated, 1000.0));
    }

    #[test]
    fn test_shared_context() {
        let ctx = SharedAdaptiveContext::new();
        ctx.record_actual("op1", 1000);

        let snapshot = ctx.snapshot().expect("lock not poisoned");
        assert_eq!(snapshot.get_checkpoint("op1").map(|cp| cp.actual), Some(1000));
    }

    #[test]
    fn test_reoptimization_decision_continue() {
        let mut ctx = AdaptiveContext::new();
        observe(&mut ctx, &[("scan", 1000.0, 1100)]);
        assert_eq!(evaluate_reoptimization(&ctx), ReoptimizationDecision::Continue);
    }

    #[test]
    fn test_reoptimization_decision_reoptimize() {
        let mut ctx = AdaptiveContext::with_thresholds(2.0, 0);
        observe(&mut ctx, &[("filter", 100.0, 500)]);
        let _ = ctx.should_reoptimize();

        let ReoptimizationDecision::Reoptimize {
            trigger,
            corrections,
        } = evaluate_reoptimization(&ctx)
        else {
            panic!("Expected Reoptimize decision");
        };
        assert_eq!(trigger, "filter");
        assert!(approx_eq(corrections.get("filter").copied().unwrap_or(0.0), 5.0));
    }

    #[test]
    fn test_reoptimization_decision_abort() {
        let mut ctx = AdaptiveContext::with_thresholds(2.0, 0);
        observe(&mut ctx, &[("filter", 1.0, 1000)]);
        let _ = ctx.should_reoptimize();

        match evaluate_reoptimization(&ctx) {
            ReoptimizationDecision::Abort { reason } => assert!(reason.contains("Catastrophic")),
            _ => panic!("Expected Abort decision"),
        }
    }

    #[test]
    fn test_absolute_deviation() {
        assert!(approx_eq(recorded_checkpoint(100.0, 150).absolute_deviation(), 50.0));
    }

    #[test]
    fn test_adaptive_checkpoint_basic() {
        let mut cp = AdaptiveCheckpoint::new("filter_1", 0, 100.0);
        assert_eq!(cp.actual_rows, 0);
        assert!(!cp.triggered);

        for (rows, expected_total) in [(50, 50), (100, 150)] {
            cp.record_rows(rows);
            assert_eq!(cp.actual_rows, expected_total);
        }
    }

    #[test]
    fn test_adaptive_checkpoint_exceeds_threshold() {
        let mut cp = AdaptiveCheckpoint::new("filter", 0, 100.0);

        // 50 rows: below the 100-row minimum.
        cp.record_rows(50);
        assert!(!cp.exceeds_threshold(2.0, 100));

        // 100 rows: exactly on estimate.
        cp.record_rows(50);
        assert!(!cp.exceeds_threshold(2.0, 100));

        // 500 rows: 5x underestimate.
        cp.actual_rows = 0;
        cp.record_rows(500);
        assert!(cp.exceeds_threshold(2.0, 100));

        // 200 of an estimated 1000 rows: 5x overestimate.
        let mut over = AdaptiveCheckpoint::new("filter2", 0, 1000.0);
        over.record_rows(200);
        assert!(over.exceeds_threshold(2.0, 100));
    }

    #[test]
    fn test_adaptive_pipeline_config_default() {
        let AdaptivePipelineConfig {
            check_interval,
            reoptimization_threshold,
            min_rows_for_reoptimization,
            max_reoptimizations,
        } = AdaptivePipelineConfig::default();

        assert_eq!(check_interval, 10_000);
        assert!(approx_eq(reoptimization_threshold, DEFAULT_REOPTIMIZATION_THRESHOLD));
        assert_eq!(min_rows_for_reoptimization, MIN_ROWS_FOR_REOPTIMIZATION);
        assert_eq!(max_reoptimizations, 3);
    }

    #[test]
    fn test_adaptive_pipeline_config_custom() {
        let AdaptivePipelineConfig {
            check_interval,
            reoptimization_threshold,
            min_rows_for_reoptimization,
            max_reoptimizations,
        } = AdaptivePipelineConfig::new(5000, 2.0, 500).with_max_reoptimizations(5);

        assert_eq!(check_interval, 5000);
        assert!(approx_eq(reoptimization_threshold, 2.0));
        assert_eq!(min_rows_for_reoptimization, 500);
        assert_eq!(max_reoptimizations, 5);
    }

    #[test]
    fn test_adaptive_pipeline_builder() {
        let config = AdaptivePipelineBuilder::new()
            .with_config(AdaptivePipelineConfig::new(1000, 2.0, 100))
            .with_checkpoint("scan", 0, 10000.0)
            .with_checkpoint("filter", 1, 1000.0)
            .build();

        let registered: Vec<(&str, f64)> = config
            .checkpoints
            .iter()
            .map(|cp| (cp.id.as_str(), cp.estimated_cardinality))
            .collect();
        assert_eq!(registered.len(), 2);

        let expected = [("scan", 10000.0), ("filter", 1000.0)];
        for (&(id, estimate), (want_id, want_estimate)) in registered.iter().zip(expected) {
            assert_eq!(id, want_id);
            assert!(approx_eq(estimate, want_estimate));
        }
    }

    #[test]
    fn test_adaptive_execution_config_record_checkpoint() {
        let mut config = AdaptivePipelineBuilder::new()
            .with_checkpoint("filter", 0, 100.0)
            .build();

        config.record_checkpoint("filter", 500);

        let cp = config.context.get_checkpoint("filter").unwrap();
        assert_eq!(cp.actual, 500);
        assert!(cp.recorded);

        let acp = config.checkpoints.iter().find(|c| c.id == "filter").unwrap();
        assert_eq!(acp.actual_rows, 500);
    }

    #[test]
    fn test_adaptive_execution_config_should_reoptimize() {
        let mut config = AdaptivePipelineBuilder::new()
            .with_config(AdaptivePipelineConfig::new(1000, 2.0, 100))
            .with_checkpoint("filter", 0, 100.0)
            .build();

        // Nothing observed yet.
        assert!(config.should_reoptimize().is_none());

        // 150 of 100 estimated rows: within the 2x threshold.
        config.record_checkpoint("filter", 150);
        assert!(config.should_reoptimize().is_none());

        // 500 of 100 estimated rows: 5x, above the threshold.
        config.checkpoints[0].actual_rows = 0;
        config.record_checkpoint("filter", 500);
        config.checkpoints[0].actual_rows = 500;

        let trigger = config.should_reoptimize().map(|cp| cp.id.as_str());
        assert_eq!(trigger, Some("filter"));
    }

    #[test]
    fn test_adaptive_execution_config_mark_triggered() {
        let mut config = AdaptivePipelineBuilder::new()
            .with_checkpoint("filter", 0, 100.0)
            .build();
        let is_triggered = |cfg: &AdaptiveExecutionConfig| cfg.checkpoints[0].triggered;

        assert!(!is_triggered(&config));
        config.mark_triggered("filter");
        assert!(is_triggered(&config));
    }

    #[test]
    fn test_adaptive_event_callback() {
        use std::sync::atomic::AtomicUsize;

        let event_count = Arc::new(AtomicUsize::new(0));
        let callback: AdaptiveEventCallback = {
            let counter = Arc::clone(&event_count);
            Box::new(move |_event: AdaptiveEvent| {
                counter.fetch_add(1, Ordering::Relaxed);
            })
        };

        let mut config = AdaptivePipelineBuilder::new()
            .with_checkpoint("filter", 0, 100.0)
            .with_event_callback(callback)
            .build();

        config.record_checkpoint("filter", 500);
        assert_eq!(event_count.load(Ordering::Relaxed), 1);

        config.mark_triggered("filter");
        assert_eq!(event_count.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn test_adaptive_checkpoint_with_zero_estimate() {
        let mut cp = AdaptiveCheckpoint::new("test", 0, 0.0);
        cp.record_rows(100);
        // Any observed rows against a zero estimate count as a deviation.
        assert!(cp.exceeds_threshold(2.0, 50));
    }
}