1use anyhow::{anyhow, Result};
61use serde::{Deserialize, Serialize};
62use std::collections::HashMap;
63use std::sync::Arc;
64use std::time::{Duration, Instant};
65use tokio::sync::RwLock;
66use tracing::{debug, info};
67
/// Tuning knobs consulted by [`FusionOptimizer`] when deciding whether adjacent
/// pipeline operations may be merged into a single fused operation.
///
/// See the `Default` impl for the default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Master switch: when `false`, `optimize_pipeline` returns the pipeline unchanged.
    pub enable_fusion: bool,

    /// Maximum number of operations a single fused chain may contain.
    pub max_fusion_depth: usize,

    /// Allow a `Map` immediately followed by a `Map` to be fused.
    pub enable_map_fusion: bool,

    /// Allow a `Filter` immediately followed by a `Filter` to be fused.
    pub enable_filter_fusion: bool,

    /// Allow a `Map` followed by a `Filter` to be fused. A `Filter` followed by a
    /// `Map` additionally requires `enable_reordering`.
    pub enable_cross_fusion: bool,

    /// Together with `enable_cross_fusion`, allows a `Filter` followed by a `Map`
    /// to be fused.
    pub enable_reordering: bool,

    /// Pipelines with fewer operations than this are returned unchanged.
    /// Note: this is compared against the whole pipeline length, not per chain.
    pub min_fusion_size: usize,

    /// Minimum estimated benefit an adjacent pair must reach to extend a chain.
    pub cost_threshold: f32,

    /// When `true`, each pair's estimated benefit is scaled by 1.2, making it
    /// easier to clear `cost_threshold`.
    pub aggressive_mode: bool,

    /// Statistics-collection toggle.
    /// NOTE(review): confirm every stats update path in `FusionOptimizer` checks this flag.
    pub collect_metrics: bool,
}
101
102impl Default for FusionConfig {
103 fn default() -> Self {
104 Self {
105 enable_fusion: true,
106 max_fusion_depth: 10,
107 enable_map_fusion: true,
108 enable_filter_fusion: true,
109 enable_cross_fusion: true,
110 enable_reordering: false, min_fusion_size: 2,
112 cost_threshold: 0.1,
113 aggressive_mode: false,
114 collect_metrics: true,
115 }
116 }
117}
118
/// Descriptor for a single pipeline step, as seen by the fusion optimizer.
///
/// Within this file the optimizer only inspects, clones, and merges these
/// descriptors; it never executes them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Operation {
    /// A mapping step labelled by `name`. Fusable with following `Map`/`Filter` steps.
    Map { name: String },

    /// A filtering step labelled by `name`. Fusable with following `Filter` steps,
    /// and with a following `Map` only when reordering is enabled.
    Filter { name: String },

    /// A flat-map step labelled by `name`; the optimizer never fuses it.
    FlatMap { name: String },

    /// A distinct step labelled by `name`; the optimizer never fuses it.
    Distinct { name: String },

    /// A take step parameterised by `count`; the optimizer never fuses it.
    Take { count: usize },

    /// A skip step parameterised by `count`; the optimizer never fuses it.
    Skip { count: usize },

    /// A user-defined step. Two adjacent `Custom` steps fuse only if both have
    /// `fusable: true`.
    Custom { name: String, fusable: bool },
}
143
144impl Operation {
145 pub fn can_fuse_with(&self, other: &Operation) -> bool {
147 match (self, other) {
148 (Operation::Map { .. }, Operation::Map { .. }) => true,
150 (Operation::Filter { .. }, Operation::Filter { .. }) => true,
152 (Operation::Map { .. }, Operation::Filter { .. }) => true,
154 (Operation::Custom { fusable: true, .. }, Operation::Custom { fusable: true, .. }) => {
156 true
157 }
158 _ => false,
159 }
160 }
161
162 pub fn name(&self) -> String {
164 match self {
165 Operation::Map { name } => format!("map({})", name),
166 Operation::Filter { name } => format!("filter({})", name),
167 Operation::FlatMap { name } => format!("flat_map({})", name),
168 Operation::Distinct { name } => format!("distinct({})", name),
169 Operation::Take { count } => format!("take({})", count),
170 Operation::Skip { count } => format!("skip({})", count),
171 Operation::Custom { name, .. } => format!("custom({})", name),
172 }
173 }
174}
175
/// Record describing one fusion: the source operations and what they became.
///
/// NOTE(review): nothing in this file constructs this type; `FusionOptimizer::fusion_cache`
/// is typed to hold it, but is only ever cleared. Confirm where it is populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedOperation {
    /// The operations that were merged (presumably in pipeline order; confirm at producer).
    pub original_ops: Vec<Operation>,

    /// Category of the fusion.
    pub fused_type: FusedType,

    /// Estimated savings from the fusion. NOTE(review): units/scale not established here.
    pub cost_savings: f32,

    /// UTC timestamp at which the fusion was recorded.
    pub fused_at: chrono::DateTime<chrono::Utc>,
}
191
/// Classification of a fused chain, as computed by `FusionOptimizer::classify_fusion`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FusedType {
    /// Every operation in the chain is a `Map`.
    MapChain,

    /// Every operation in the chain is a `Filter`.
    FilterChain,

    /// The chain contains at least one `Map` and at least one `Filter`.
    MapFilter,

    /// Any other chain (e.g. a run of fusable `Custom` operations).
    Complex,
}
207
/// Cumulative counters maintained by [`FusionOptimizer`]; reset with `reset_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FusionStats {
    /// Pipelines that passed the early-out checks and were run through fusion.
    pub pipelines_optimized: u64,

    /// Total number of operations across those pipelines.
    pub operations_analyzed: u64,

    /// Total number of operations absorbed into multi-operation fused chains.
    pub operations_fused: u64,

    /// Number of multi-operation chains that were fused.
    pub fusion_chains_created: u64,

    /// Aggregate of the per-chain estimated overhead reduction, in percent
    /// (each per-chain estimate is capped at 95%).
    pub overhead_reduction_percent: f32,

    /// Running mean length of the fused chains.
    pub avg_fusion_chain_length: f32,

    /// Chains classified as `FusedType::MapChain`.
    pub map_fusions: u64,

    /// Chains classified as `FusedType::FilterChain`.
    pub filter_fusions: u64,

    /// Chains classified as `FusedType::MapFilter`.
    pub cross_fusions: u64,

    /// NOTE(review): not updated anywhere in this file. Confirm whether reorderings
    /// are meant to be counted.
    pub reorderings: u64,

    /// Cumulative wall-clock time spent fusing pipelines (measured with `Instant`).
    pub total_optimization_time: Duration,

    /// UTC time at which the most recent optimization finished, if any.
    pub last_optimization: Option<chrono::DateTime<chrono::Utc>>,
}
247
/// Merges runs of adjacent, compatible [`Operation`]s into single fused operations.
///
/// Statistics and the cache sit behind `Arc<RwLock<..>>` (tokio's async lock), so the
/// async methods can update them through `&self`.
pub struct FusionOptimizer {
    /// Settings supplied at construction; no method in this file modifies them.
    config: FusionConfig,
    /// Accumulated optimization statistics.
    stats: Arc<RwLock<FusionStats>>,
    /// NOTE(review): only ever cleared in this file, never read or populated.
    fusion_cache: Arc<RwLock<HashMap<String, Vec<FusedOperation>>>>,
}
254
255impl FusionOptimizer {
256 pub fn new(config: FusionConfig) -> Self {
258 Self {
259 config,
260 stats: Arc::new(RwLock::new(FusionStats::default())),
261 fusion_cache: Arc::new(RwLock::new(HashMap::new())),
262 }
263 }
264
265 pub async fn optimize_pipeline(&mut self, pipeline: &[Operation]) -> Result<Vec<Operation>> {
267 if !self.config.enable_fusion {
268 return Ok(pipeline.to_vec());
269 }
270
271 if pipeline.len() < self.config.min_fusion_size {
272 debug!(
273 "Pipeline too small for fusion: {} operations",
274 pipeline.len()
275 );
276 return Ok(pipeline.to_vec());
277 }
278
279 let start_time = Instant::now();
280
281 let mut stats = self.stats.write().await;
283 stats.pipelines_optimized += 1;
284 stats.operations_analyzed += pipeline.len() as u64;
285 drop(stats);
286
287 let optimized = self.fuse_operations(pipeline).await?;
289
290 let optimization_time = start_time.elapsed();
292 let mut stats = self.stats.write().await;
293 stats.total_optimization_time += optimization_time;
294 stats.last_optimization = Some(chrono::Utc::now());
295
296 info!(
297 "Optimized pipeline: {} ops -> {} ops in {:?}",
298 pipeline.len(),
299 optimized.len(),
300 optimization_time
301 );
302
303 Ok(optimized)
304 }
305
306 async fn fuse_operations(&self, operations: &[Operation]) -> Result<Vec<Operation>> {
308 let mut result = Vec::new();
309 let mut i = 0;
310
311 while i < operations.len() {
312 let fusion_chain = self.find_fusion_chain(operations, i).await?;
314
315 if fusion_chain.len() > 1 {
316 let fused_op = self.create_fused_operation(&fusion_chain).await?;
318 result.push(fused_op);
319
320 let mut stats = self.stats.write().await;
322 stats.operations_fused += fusion_chain.len() as u64;
323 stats.fusion_chains_created += 1;
324
325 let fusion_type = self.classify_fusion(&fusion_chain);
327 match fusion_type {
328 FusedType::MapChain => stats.map_fusions += 1,
329 FusedType::FilterChain => stats.filter_fusions += 1,
330 FusedType::MapFilter => stats.cross_fusions += 1,
331 FusedType::Complex => {}
332 }
333
334 let reduction = self.estimate_overhead_reduction(&fusion_chain);
336 stats.overhead_reduction_percent =
337 (stats.overhead_reduction_percent + reduction) / 2.0;
338
339 let chain_len = fusion_chain.len() as f32;
341 stats.avg_fusion_chain_length = (stats.avg_fusion_chain_length
342 * (stats.fusion_chains_created - 1) as f32
343 + chain_len)
344 / stats.fusion_chains_created as f32;
345
346 i += fusion_chain.len();
347 } else {
348 result.push(operations[i].clone());
350 i += 1;
351 }
352 }
353
354 Ok(result)
355 }
356
357 async fn find_fusion_chain(
359 &self,
360 operations: &[Operation],
361 start: usize,
362 ) -> Result<Vec<Operation>> {
363 let mut chain = vec![operations[start].clone()];
364 let mut current = start;
365
366 while current + 1 < operations.len() && chain.len() < self.config.max_fusion_depth {
367 let current_op = &operations[current];
368 let next_op = &operations[current + 1];
369
370 let can_fuse = match (current_op, next_op) {
372 (Operation::Map { .. }, Operation::Map { .. }) => self.config.enable_map_fusion,
373 (Operation::Filter { .. }, Operation::Filter { .. }) => {
374 self.config.enable_filter_fusion
375 }
376 (Operation::Map { .. }, Operation::Filter { .. }) => {
377 self.config.enable_cross_fusion
378 }
379 (Operation::Filter { .. }, Operation::Map { .. }) => {
380 self.config.enable_cross_fusion && self.config.enable_reordering
381 }
382 _ => current_op.can_fuse_with(next_op),
383 };
384
385 if can_fuse {
386 let benefit = self.estimate_fusion_benefit(current_op, next_op);
388 if benefit >= self.config.cost_threshold {
389 chain.push(next_op.clone());
390 current += 1;
391 } else {
392 debug!("Fusion benefit too low: {}", benefit);
393 break;
394 }
395 } else {
396 break;
397 }
398 }
399
400 Ok(chain)
401 }
402
403 async fn create_fused_operation(&self, chain: &[Operation]) -> Result<Operation> {
405 if chain.is_empty() {
406 return Err(anyhow!("Cannot create fused operation from empty chain"));
407 }
408
409 if chain.len() == 1 {
410 return Ok(chain[0].clone());
411 }
412
413 let fusion_type = self.classify_fusion(chain);
415
416 match fusion_type {
418 FusedType::MapChain => {
419 let names: Vec<String> = chain
421 .iter()
422 .filter_map(|op| {
423 if let Operation::Map { name } = op {
424 Some(name.clone())
425 } else {
426 None
427 }
428 })
429 .collect();
430
431 Ok(Operation::Map {
432 name: format!("fused[{}]", names.join(" → ")),
433 })
434 }
435 FusedType::FilterChain => {
436 let names: Vec<String> = chain
438 .iter()
439 .filter_map(|op| {
440 if let Operation::Filter { name } = op {
441 Some(name.clone())
442 } else {
443 None
444 }
445 })
446 .collect();
447
448 Ok(Operation::Filter {
449 name: format!("fused[{} && ...]", names.join(" && ")),
450 })
451 }
452 FusedType::MapFilter => {
453 let op_names: Vec<String> = chain.iter().map(|op| op.name()).collect();
455
456 Ok(Operation::Custom {
457 name: format!("fused_map_filter[{}]", op_names.join(" → ")),
458 fusable: true,
459 })
460 }
461 FusedType::Complex => {
462 let op_names: Vec<String> = chain.iter().map(|op| op.name()).collect();
464
465 Ok(Operation::Custom {
466 name: format!("fused_complex[{}]", op_names.join(" → ")),
467 fusable: true,
468 })
469 }
470 }
471 }
472
473 fn classify_fusion(&self, chain: &[Operation]) -> FusedType {
475 let all_maps = chain.iter().all(|op| matches!(op, Operation::Map { .. }));
476 let all_filters = chain
477 .iter()
478 .all(|op| matches!(op, Operation::Filter { .. }));
479
480 if all_maps {
481 FusedType::MapChain
482 } else if all_filters {
483 FusedType::FilterChain
484 } else if chain.iter().any(|op| matches!(op, Operation::Map { .. }))
485 && chain
486 .iter()
487 .any(|op| matches!(op, Operation::Filter { .. }))
488 {
489 FusedType::MapFilter
490 } else {
491 FusedType::Complex
492 }
493 }
494
495 fn estimate_fusion_benefit(&self, op1: &Operation, op2: &Operation) -> f32 {
497 let base_benefit = 0.3;
499
500 let type_benefit = match (op1, op2) {
502 (Operation::Map { .. }, Operation::Map { .. }) => 0.4,
504 (Operation::Filter { .. }, Operation::Filter { .. }) => 0.35,
506 (Operation::Map { .. }, Operation::Filter { .. }) => 0.25,
508 _ => 0.2,
510 };
511
512 let aggressive_multiplier = if self.config.aggressive_mode {
514 1.2
515 } else {
516 1.0
517 };
518
519 (base_benefit + type_benefit) * aggressive_multiplier
520 }
521
522 fn estimate_overhead_reduction(&self, chain: &[Operation]) -> f32 {
524 if chain.len() <= 1 {
525 return 0.0;
526 }
527
528 let eliminated_steps = (chain.len() - 1) as f32;
531 let reduction_per_step = 0.20; (eliminated_steps * reduction_per_step * 100.0).min(95.0)
534 }
535
536 pub async fn get_stats(&self) -> FusionStats {
538 self.stats.read().await.clone()
539 }
540
541 pub async fn reset_stats(&self) {
543 let mut stats = self.stats.write().await;
544 *stats = FusionStats::default();
545 }
546
547 pub async fn analyze_pipeline(&self, pipeline: &[Operation]) -> Result<FusionAnalysis> {
549 let mut fusable_chains = Vec::new();
550 let mut i = 0;
551
552 while i < pipeline.len() {
553 let chain = self.find_fusion_chain(pipeline, i).await?;
554
555 if chain.len() > 1 {
556 let fusion_type = self.classify_fusion(&chain);
557 let benefit = self.estimate_overhead_reduction(&chain);
558
559 fusable_chains.push(FusableChain {
560 start_index: i,
561 operations: chain.clone(),
562 fusion_type,
563 estimated_benefit: benefit,
564 });
565
566 i += chain.len();
567 } else {
568 i += 1;
569 }
570 }
571
572 let ops_saved: usize = fusable_chains.iter().map(|c| c.operations.len() - 1).sum();
573 let estimated_final_count = pipeline.len() - ops_saved;
574
575 Ok(FusionAnalysis {
576 original_operation_count: pipeline.len(),
577 fusable_chains,
578 estimated_final_count,
579 })
580 }
581
582 pub async fn clear_cache(&self) {
584 self.fusion_cache.write().await.clear();
585 }
586}
587
/// Result of `FusionOptimizer::analyze_pipeline`: a prediction of what fusion would do.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionAnalysis {
    /// Number of operations in the analyzed pipeline.
    pub original_operation_count: usize,

    /// Non-overlapping fusable chains, in pipeline order.
    pub fusable_chains: Vec<FusableChain>,

    /// Predicted operation count after fusion. Each chain collapses to one operation.
    pub estimated_final_count: usize,
}
600
/// One run of adjacent operations that the optimizer would fuse.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusableChain {
    /// Index, within the analyzed pipeline, of the chain's first operation.
    pub start_index: usize,

    /// The operations making up the chain.
    pub operations: Vec<Operation>,

    /// Classification of the chain.
    pub fusion_type: FusedType,

    /// Estimated overhead reduction in percent: 20 points per eliminated step,
    /// capped at 95.
    pub estimated_benefit: f32,
}
616
617impl FusionAnalysis {
618 pub fn summary(&self) -> String {
620 format!(
621 "Pipeline Analysis: {} ops -> {} ops ({} fusable chains, {:.1}% reduction)",
622 self.original_operation_count,
623 self.estimated_final_count,
624 self.fusable_chains.len(),
625 ((self.original_operation_count - self.estimated_final_count) as f32
626 / self.original_operation_count as f32
627 * 100.0)
628 )
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `Operation::Map` labelled `name`.
    fn map(name: &str) -> Operation {
        Operation::Map {
            name: name.to_string(),
        }
    }

    /// Shorthand for an `Operation::Filter` labelled `name`.
    fn filter(name: &str) -> Operation {
        Operation::Filter {
            name: name.to_string(),
        }
    }

    #[test]
    fn test_fusion_optimizer_creation() {
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert!(optimizer.config.enable_fusion);
    }

    #[tokio::test]
    async fn test_map_fusion() {
        let config = FusionConfig {
            enable_map_fusion: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2"), map("step3")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized, vec![map("fused[step1 → step2 → step3]")]);

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.operations_fused, 3);
        assert_eq!(stats.map_fusions, 1);
    }

    #[tokio::test]
    async fn test_filter_fusion() {
        let config = FusionConfig {
            enable_filter_fusion: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![filter("check1"), filter("check2")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized.len(), 1);
        assert!(matches!(optimized[0], Operation::Filter { .. }));

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.filter_fusions, 1);
    }

    #[tokio::test]
    async fn test_mixed_fusion() {
        let config = FusionConfig {
            enable_cross_fusion: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("transform"), filter("validate")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized.len(), 1);
        assert!(matches!(
            optimized[0],
            Operation::Custom { fusable: true, .. }
        ));

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.cross_fusions, 1);
    }

    #[tokio::test]
    async fn test_filter_then_map_not_fused_without_reordering() {
        // Reordering is off by default, so Filter -> Map must stay separate.
        let mut optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![filter("validate"), map("transform")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized, pipeline);
        assert_eq!(optimizer.get_stats().await.cross_fusions, 0);
    }

    #[tokio::test]
    async fn test_filter_then_map_fused_with_reordering() {
        let config = FusionConfig {
            enable_cross_fusion: true,
            enable_reordering: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![filter("validate"), map("transform")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized.len(), 1);
        assert_eq!(optimizer.get_stats().await.cross_fusions, 1);
    }

    #[tokio::test]
    async fn test_non_fusable_operations_pass_through() {
        let mut optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![
            Operation::Take { count: 2 },
            Operation::Skip { count: 1 },
            Operation::Distinct {
                name: "id".to_string(),
            },
        ];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized, pipeline);
        assert_eq!(optimizer.get_stats().await.fusion_chains_created, 0);
    }

    #[tokio::test]
    async fn test_no_fusion_when_disabled() {
        let config = FusionConfig {
            enable_fusion: false,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized, pipeline);
        assert_eq!(optimizer.get_stats().await.pipelines_optimized, 0);
    }

    #[tokio::test]
    async fn test_min_fusion_size() {
        let config = FusionConfig {
            min_fusion_size: 3,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        assert_eq!(optimized, pipeline);
    }

    #[tokio::test]
    async fn test_max_fusion_depth() {
        let config = FusionConfig {
            max_fusion_depth: 2,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2"), map("step3")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // The first chain is capped at two ops; the third op is left on its own.
        assert_eq!(optimized, vec![map("fused[step1 → step2]"), map("step3")]);
    }

    #[tokio::test]
    async fn test_fusion_analysis() {
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![map("step1"), map("step2"), filter("check")];
        let analysis = optimizer.analyze_pipeline(&pipeline).await.unwrap();

        assert_eq!(analysis.original_operation_count, 3);
        assert_eq!(analysis.fusable_chains.len(), 1);
        assert_eq!(analysis.fusable_chains[0].start_index, 0);
        assert_eq!(analysis.fusable_chains[0].fusion_type, FusedType::MapFilter);
        assert_eq!(analysis.estimated_final_count, 1);
    }

    #[test]
    fn test_operation_can_fuse() {
        let map1 = map("map1");
        let map2 = map("map2");
        let filter1 = filter("filter1");

        assert!(map1.can_fuse_with(&map2));
        assert!(map1.can_fuse_with(&filter1));
        assert!(!filter1.can_fuse_with(&map1));
        assert!(!Operation::Take { count: 1 }.can_fuse_with(&Operation::Take { count: 1 }));
    }

    #[tokio::test]
    async fn test_stats_tracking() {
        let mut optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![map("step1"), map("step2")];
        optimizer.optimize_pipeline(&pipeline).await.unwrap();

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.pipelines_optimized, 1);
        assert_eq!(stats.operations_analyzed, 2);
        assert_eq!(stats.operations_fused, 2);
        assert_eq!(stats.fusion_chains_created, 1);
        assert_eq!(stats.map_fusions, 1);
        assert_eq!(stats.avg_fusion_chain_length, 2.0);
        assert!(stats.last_optimization.is_some());
    }

    #[tokio::test]
    async fn test_reset_stats() {
        let mut optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![map("step1"), map("step2")];
        optimizer.optimize_pipeline(&pipeline).await.unwrap();
        optimizer.reset_stats().await;

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.pipelines_optimized, 0);
        assert_eq!(stats.operations_fused, 0);
        assert_eq!(stats.fusion_chains_created, 0);
        assert!(stats.last_optimization.is_none());
    }
}