Skip to main content

oxirs_stream/
stream_fusion.rs

1//! # Stream Fusion Optimizer
2//!
3//! Automatically detects and fuses consecutive stream operations into single passes,
4//! eliminating intermediate allocations, reducing function call overhead, and improving
5//! cache locality for significant performance gains.
6//!
7//! ## Features
8//!
9//! - **Automatic Fusion**: Detects fusable operation sequences
10//! - **Multiple Fusion Rules**: Map-Map, Filter-Filter, Map-Filter combinations
11//! - **Cost-Based Optimization**: Only fuses when beneficial
12//! - **Performance Metrics**: Tracks fusion benefits and overhead reduction
13//! - **Safe Transformations**: Validates fusion correctness
14//! - **Configurable**: Enable/disable specific fusion types
15//!
16//! ## Fusion Rules
17//!
18//! 1. **Map Fusion**: `map(f) → map(g)` becomes `map(g ∘ f)`
19//! 2. **Filter Fusion**: `filter(p) → filter(q)` becomes `filter(p && q)`
//! 3. **Map-Filter Fusion**: `map(f) → filter(p)` becomes `filter_map(|x| Some(f(x)).filter(p))`
21//! 4. **Filter-Map Reordering**: Sometimes safe to reorder for better fusion
22//!
23//! ## Example
24//!
25//! ```ignore
26//! use oxirs_stream::stream_fusion::{FusionOptimizer, FusionConfig};
27//! use oxirs_stream::stream_fusion::Operation;
28//!
29//! # async fn example() -> anyhow::Result<()> {
30//! let config = FusionConfig {
31//!     enable_fusion: true,
32//!     max_fusion_depth: 5,
33//!     enable_map_fusion: true,
34//!     enable_filter_fusion: true,
35//!     enable_cross_fusion: true,
36//!     ..Default::default()
37//! };
38//!
39//! let mut optimizer = FusionOptimizer::new(config);
40//!
41//! // Define a pipeline with multiple operations
42//! let pipeline = vec![
43//!     Operation::Map { name: "normalize".to_string() },
44//!     Operation::Map { name: "transform".to_string() },
45//!     Operation::Filter { name: "validate".to_string() },
46//!     Operation::Filter { name: "check_bounds".to_string() },
47//! ];
48//!
//! // Optimize the pipeline (`optimize_pipeline` is async and must be awaited)
//! let optimized = optimizer.optimize_pipeline(&pipeline).await?;
//!
//! // Get fusion statistics (`get_stats` is async as well)
//! let stats = optimizer.get_stats().await;
//! println!("Fused {} operations, saved {}% overhead",
//!          stats.operations_fused, stats.overhead_reduction_percent);
56//! # Ok(())
57//! # }
58//! ```
59
60use anyhow::{anyhow, Result};
61use serde::{Deserialize, Serialize};
62use std::collections::HashMap;
63use std::sync::Arc;
64use std::time::{Duration, Instant};
65use tokio::sync::RwLock;
66use tracing::{debug, info};
67
/// Configuration for stream fusion optimizer
///
/// All fields are public, so individual knobs can be overridden with struct
/// update syntax (`..Default::default()`). The struct derives
/// `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Enable stream fusion optimization
    ///
    /// When `false`, `FusionOptimizer::optimize_pipeline` returns a copy of the
    /// input pipeline without touching the statistics.
    pub enable_fusion: bool,

    /// Maximum number of operations to fuse in a single chain
    ///
    /// A chain always contains its starting operation, so values of 0 or 1
    /// effectively disable fusion.
    pub max_fusion_depth: usize,

    /// Enable map-map fusion
    pub enable_map_fusion: bool,

    /// Enable filter-filter fusion
    pub enable_filter_fusion: bool,

    /// Enable map-filter cross fusion
    pub enable_cross_fusion: bool,

    /// Enable filter-map reordering (requires analysis)
    ///
    /// A `Filter` followed by a `Map` is only fused when this flag and
    /// `enable_cross_fusion` are both set.
    pub enable_reordering: bool,

    /// Minimum operations required to consider fusion (avoid overhead for small chains)
    ///
    /// `optimize_pipeline` compares this against the length of the whole
    /// pipeline, not against the length of an individual chain.
    pub min_fusion_size: usize,

    /// Cost threshold for fusion (only fuse if benefit > cost)
    ///
    /// A pair of operations is fused when its estimated benefit is greater than
    /// or equal to this value.
    pub cost_threshold: f32,

    /// Enable aggressive fusion (may increase compilation time)
    ///
    /// Scales the estimated fusion benefit by 1.2.
    pub aggressive_mode: bool,

    /// Enable fusion metrics collection
    ///
    /// NOTE(review): no code visible in this file reads this flag; statistics
    /// appear to be recorded unconditionally — confirm.
    pub collect_metrics: bool,
}
101
102impl Default for FusionConfig {
103    fn default() -> Self {
104        Self {
105            enable_fusion: true,
106            max_fusion_depth: 10,
107            enable_map_fusion: true,
108            enable_filter_fusion: true,
109            enable_cross_fusion: true,
110            enable_reordering: false, // Conservative default
111            min_fusion_size: 2,
112            cost_threshold: 0.1,
113            aggressive_mode: false,
114            collect_metrics: true,
115        }
116    }
117}
118
/// Stream operation types that can be fused
///
/// Operations are purely descriptive: they carry a name (or a count) but no
/// executable closure.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Operation {
    /// Map operation: transforms each element
    Map { name: String },

    /// Filter operation: selects elements based on predicate
    Filter { name: String },

    /// FlatMap operation: maps and flattens
    FlatMap { name: String },

    /// Distinct operation: removes duplicates
    Distinct { name: String },

    /// Take operation: limits number of elements
    Take { count: usize },

    /// Skip operation: skips first n elements
    Skip { count: usize },

    /// Custom operation: user-defined
    ///
    /// `fusable` marks whether the operation may be fused with an adjacent
    /// custom operation that is also fusable.
    Custom { name: String, fusable: bool },
}
143
144impl Operation {
145    /// Check if this operation can be fused with another
146    pub fn can_fuse_with(&self, other: &Operation) -> bool {
147        match (self, other) {
148            // Map can fuse with map
149            (Operation::Map { .. }, Operation::Map { .. }) => true,
150            // Filter can fuse with filter
151            (Operation::Filter { .. }, Operation::Filter { .. }) => true,
152            // Map can fuse with filter
153            (Operation::Map { .. }, Operation::Filter { .. }) => true,
154            // Custom operations check fusable flag
155            (Operation::Custom { fusable: true, .. }, Operation::Custom { fusable: true, .. }) => {
156                true
157            }
158            _ => false,
159        }
160    }
161
162    /// Get operation name for debugging
163    pub fn name(&self) -> String {
164        match self {
165            Operation::Map { name } => format!("map({})", name),
166            Operation::Filter { name } => format!("filter({})", name),
167            Operation::FlatMap { name } => format!("flat_map({})", name),
168            Operation::Distinct { name } => format!("distinct({})", name),
169            Operation::Take { count } => format!("take({})", count),
170            Operation::Skip { count } => format!("skip({})", count),
171            Operation::Custom { name, .. } => format!("custom({})", name),
172        }
173    }
174}
175
/// Fused operation combining multiple operations
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// type; it only appears as the value type of the optimizer's fusion cache —
/// confirm where it is produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedOperation {
    /// Original operations that were fused
    pub original_ops: Vec<Operation>,

    /// Fused operation type
    pub fused_type: FusedType,

    /// Estimated cost savings (0.0-1.0)
    pub cost_savings: f32,

    /// Fusion timestamp
    pub fused_at: chrono::DateTime<chrono::Utc>,
}
191
/// Type of fused operation
///
/// Produced by `FusionOptimizer::classify_fusion` from the operations in a chain.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FusedType {
    /// Multiple maps fused into one
    MapChain,

    /// Multiple filters fused into one
    FilterChain,

    /// Map and filter fused
    ///
    /// Chosen when the chain contains at least one map and at least one filter.
    MapFilter,

    /// Complex fusion
    ///
    /// Fallback for chains matching none of the other categories.
    Complex,
}
207
/// Fusion statistics
///
/// Accumulated by `FusionOptimizer` across calls to `optimize_pipeline` and
/// zeroed by `reset_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FusionStats {
    /// Total pipelines optimized
    pub pipelines_optimized: u64,

    /// Total operations analyzed
    pub operations_analyzed: u64,

    /// Total operations fused
    ///
    /// Counts every input operation that ended up inside a fused chain (the
    /// chain length is added per chain).
    pub operations_fused: u64,

    /// Number of fusion chains created
    pub fusion_chains_created: u64,

    /// Estimated overhead reduction (percentage)
    ///
    /// Aggregated from the per-chain overhead estimates.
    pub overhead_reduction_percent: f32,

    /// Average fusion chain length
    pub avg_fusion_chain_length: f32,

    /// Map fusions performed
    pub map_fusions: u64,

    /// Filter fusions performed
    pub filter_fusions: u64,

    /// Cross fusions performed (map+filter)
    pub cross_fusions: u64,

    /// Reorderings performed
    ///
    /// NOTE(review): never incremented by code visible in this file — confirm.
    pub reorderings: u64,

    /// Total optimization time
    pub total_optimization_time: Duration,

    /// Last optimization timestamp
    pub last_optimization: Option<chrono::DateTime<chrono::Utc>>,
}
247
/// Stream fusion optimizer
///
/// Holds the configuration plus shared, lock-protected statistics.
pub struct FusionOptimizer {
    /// Knobs controlling which fusions are attempted.
    config: FusionConfig,
    /// Running statistics, behind an async `RwLock`.
    stats: Arc<RwLock<FusionStats>>,
    /// NOTE(review): in the visible code this map is only ever cleared (by
    /// `clear_cache`); nothing inserts into or reads from it — confirm intent.
    fusion_cache: Arc<RwLock<HashMap<String, Vec<FusedOperation>>>>,
}
254
255impl FusionOptimizer {
256    /// Create a new fusion optimizer
257    pub fn new(config: FusionConfig) -> Self {
258        Self {
259            config,
260            stats: Arc::new(RwLock::new(FusionStats::default())),
261            fusion_cache: Arc::new(RwLock::new(HashMap::new())),
262        }
263    }
264
265    /// Optimize a pipeline by fusing operations
266    pub async fn optimize_pipeline(&mut self, pipeline: &[Operation]) -> Result<Vec<Operation>> {
267        if !self.config.enable_fusion {
268            return Ok(pipeline.to_vec());
269        }
270
271        if pipeline.len() < self.config.min_fusion_size {
272            debug!(
273                "Pipeline too small for fusion: {} operations",
274                pipeline.len()
275            );
276            return Ok(pipeline.to_vec());
277        }
278
279        let start_time = Instant::now();
280
281        // Update stats
282        let mut stats = self.stats.write().await;
283        stats.pipelines_optimized += 1;
284        stats.operations_analyzed += pipeline.len() as u64;
285        drop(stats);
286
287        // Perform fusion optimization
288        let optimized = self.fuse_operations(pipeline).await?;
289
290        // Update optimization time
291        let optimization_time = start_time.elapsed();
292        let mut stats = self.stats.write().await;
293        stats.total_optimization_time += optimization_time;
294        stats.last_optimization = Some(chrono::Utc::now());
295
296        info!(
297            "Optimized pipeline: {} ops -> {} ops in {:?}",
298            pipeline.len(),
299            optimized.len(),
300            optimization_time
301        );
302
303        Ok(optimized)
304    }
305
306    /// Fuse consecutive operations in the pipeline
307    async fn fuse_operations(&self, operations: &[Operation]) -> Result<Vec<Operation>> {
308        let mut result = Vec::new();
309        let mut i = 0;
310
311        while i < operations.len() {
312            // Try to fuse starting from this position
313            let fusion_chain = self.find_fusion_chain(operations, i).await?;
314
315            if fusion_chain.len() > 1 {
316                // Multiple operations can be fused
317                let fused_op = self.create_fused_operation(&fusion_chain).await?;
318                result.push(fused_op);
319
320                // Update stats
321                let mut stats = self.stats.write().await;
322                stats.operations_fused += fusion_chain.len() as u64;
323                stats.fusion_chains_created += 1;
324
325                // Update fusion type counts
326                let fusion_type = self.classify_fusion(&fusion_chain);
327                match fusion_type {
328                    FusedType::MapChain => stats.map_fusions += 1,
329                    FusedType::FilterChain => stats.filter_fusions += 1,
330                    FusedType::MapFilter => stats.cross_fusions += 1,
331                    FusedType::Complex => {}
332                }
333
334                // Calculate overhead reduction
335                let reduction = self.estimate_overhead_reduction(&fusion_chain);
336                stats.overhead_reduction_percent =
337                    (stats.overhead_reduction_percent + reduction) / 2.0;
338
339                // Update average chain length
340                let chain_len = fusion_chain.len() as f32;
341                stats.avg_fusion_chain_length = (stats.avg_fusion_chain_length
342                    * (stats.fusion_chains_created - 1) as f32
343                    + chain_len)
344                    / stats.fusion_chains_created as f32;
345
346                i += fusion_chain.len();
347            } else {
348                // Cannot fuse, keep original operation
349                result.push(operations[i].clone());
350                i += 1;
351            }
352        }
353
354        Ok(result)
355    }
356
357    /// Find the longest chain of fusable operations starting from position
358    async fn find_fusion_chain(
359        &self,
360        operations: &[Operation],
361        start: usize,
362    ) -> Result<Vec<Operation>> {
363        let mut chain = vec![operations[start].clone()];
364        let mut current = start;
365
366        while current + 1 < operations.len() && chain.len() < self.config.max_fusion_depth {
367            let current_op = &operations[current];
368            let next_op = &operations[current + 1];
369
370            // Check if operations can be fused based on config
371            let can_fuse = match (current_op, next_op) {
372                (Operation::Map { .. }, Operation::Map { .. }) => self.config.enable_map_fusion,
373                (Operation::Filter { .. }, Operation::Filter { .. }) => {
374                    self.config.enable_filter_fusion
375                }
376                (Operation::Map { .. }, Operation::Filter { .. }) => {
377                    self.config.enable_cross_fusion
378                }
379                (Operation::Filter { .. }, Operation::Map { .. }) => {
380                    self.config.enable_cross_fusion && self.config.enable_reordering
381                }
382                _ => current_op.can_fuse_with(next_op),
383            };
384
385            if can_fuse {
386                // Check cost-benefit
387                let benefit = self.estimate_fusion_benefit(current_op, next_op);
388                if benefit >= self.config.cost_threshold {
389                    chain.push(next_op.clone());
390                    current += 1;
391                } else {
392                    debug!("Fusion benefit too low: {}", benefit);
393                    break;
394                }
395            } else {
396                break;
397            }
398        }
399
400        Ok(chain)
401    }
402
403    /// Create a fused operation from a chain
404    async fn create_fused_operation(&self, chain: &[Operation]) -> Result<Operation> {
405        if chain.is_empty() {
406            return Err(anyhow!("Cannot create fused operation from empty chain"));
407        }
408
409        if chain.len() == 1 {
410            return Ok(chain[0].clone());
411        }
412
413        // Classify the fusion type
414        let fusion_type = self.classify_fusion(chain);
415
416        // Create appropriate fused operation
417        match fusion_type {
418            FusedType::MapChain => {
419                // Combine map operations
420                let names: Vec<String> = chain
421                    .iter()
422                    .filter_map(|op| {
423                        if let Operation::Map { name } = op {
424                            Some(name.clone())
425                        } else {
426                            None
427                        }
428                    })
429                    .collect();
430
431                Ok(Operation::Map {
432                    name: format!("fused[{}]", names.join(" → ")),
433                })
434            }
435            FusedType::FilterChain => {
436                // Combine filter operations
437                let names: Vec<String> = chain
438                    .iter()
439                    .filter_map(|op| {
440                        if let Operation::Filter { name } = op {
441                            Some(name.clone())
442                        } else {
443                            None
444                        }
445                    })
446                    .collect();
447
448                Ok(Operation::Filter {
449                    name: format!("fused[{} && ...]", names.join(" && ")),
450                })
451            }
452            FusedType::MapFilter => {
453                // Combine map and filter
454                let op_names: Vec<String> = chain.iter().map(|op| op.name()).collect();
455
456                Ok(Operation::Custom {
457                    name: format!("fused_map_filter[{}]", op_names.join(" → ")),
458                    fusable: true,
459                })
460            }
461            FusedType::Complex => {
462                // Complex fusion
463                let op_names: Vec<String> = chain.iter().map(|op| op.name()).collect();
464
465                Ok(Operation::Custom {
466                    name: format!("fused_complex[{}]", op_names.join(" → ")),
467                    fusable: true,
468                })
469            }
470        }
471    }
472
473    /// Classify the type of fusion for a chain
474    fn classify_fusion(&self, chain: &[Operation]) -> FusedType {
475        let all_maps = chain.iter().all(|op| matches!(op, Operation::Map { .. }));
476        let all_filters = chain
477            .iter()
478            .all(|op| matches!(op, Operation::Filter { .. }));
479
480        if all_maps {
481            FusedType::MapChain
482        } else if all_filters {
483            FusedType::FilterChain
484        } else if chain.iter().any(|op| matches!(op, Operation::Map { .. }))
485            && chain
486                .iter()
487                .any(|op| matches!(op, Operation::Filter { .. }))
488        {
489            FusedType::MapFilter
490        } else {
491            FusedType::Complex
492        }
493    }
494
495    /// Estimate the benefit of fusing two operations
496    fn estimate_fusion_benefit(&self, op1: &Operation, op2: &Operation) -> f32 {
497        // Base benefit from eliminating intermediate overhead
498        let base_benefit = 0.3;
499
500        // Additional benefit based on operation types
501        let type_benefit = match (op1, op2) {
502            // Map-map fusion has high benefit (eliminates intermediate allocation)
503            (Operation::Map { .. }, Operation::Map { .. }) => 0.4,
504            // Filter-filter fusion has good benefit (combines predicates)
505            (Operation::Filter { .. }, Operation::Filter { .. }) => 0.35,
506            // Map-filter has moderate benefit
507            (Operation::Map { .. }, Operation::Filter { .. }) => 0.25,
508            // Other combinations
509            _ => 0.2,
510        };
511
512        // Aggressive mode increases benefit estimates
513        let aggressive_multiplier = if self.config.aggressive_mode {
514            1.2
515        } else {
516            1.0
517        };
518
519        (base_benefit + type_benefit) * aggressive_multiplier
520    }
521
522    /// Estimate overhead reduction from fusing a chain
523    fn estimate_overhead_reduction(&self, chain: &[Operation]) -> f32 {
524        if chain.len() <= 1 {
525            return 0.0;
526        }
527
528        // Each fused operation eliminates one intermediate step
529        // Estimate 15-25% overhead reduction per eliminated step
530        let eliminated_steps = (chain.len() - 1) as f32;
531        let reduction_per_step = 0.20; // 20% average
532
533        (eliminated_steps * reduction_per_step * 100.0).min(95.0)
534    }
535
536    /// Get fusion statistics
537    pub async fn get_stats(&self) -> FusionStats {
538        self.stats.read().await.clone()
539    }
540
541    /// Reset statistics
542    pub async fn reset_stats(&self) {
543        let mut stats = self.stats.write().await;
544        *stats = FusionStats::default();
545    }
546
547    /// Analyze a pipeline without applying fusion (dry run)
548    pub async fn analyze_pipeline(&self, pipeline: &[Operation]) -> Result<FusionAnalysis> {
549        let mut fusable_chains = Vec::new();
550        let mut i = 0;
551
552        while i < pipeline.len() {
553            let chain = self.find_fusion_chain(pipeline, i).await?;
554
555            if chain.len() > 1 {
556                let fusion_type = self.classify_fusion(&chain);
557                let benefit = self.estimate_overhead_reduction(&chain);
558
559                fusable_chains.push(FusableChain {
560                    start_index: i,
561                    operations: chain.clone(),
562                    fusion_type,
563                    estimated_benefit: benefit,
564                });
565
566                i += chain.len();
567            } else {
568                i += 1;
569            }
570        }
571
572        let ops_saved: usize = fusable_chains.iter().map(|c| c.operations.len() - 1).sum();
573        let estimated_final_count = pipeline.len() - ops_saved;
574
575        Ok(FusionAnalysis {
576            original_operation_count: pipeline.len(),
577            fusable_chains,
578            estimated_final_count,
579        })
580    }
581
582    /// Clear the fusion cache
583    pub async fn clear_cache(&self) {
584        self.fusion_cache.write().await.clear();
585    }
586}
587
/// Result of pipeline analysis
///
/// Returned by `FusionOptimizer::analyze_pipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionAnalysis {
    /// Original number of operations
    pub original_operation_count: usize,

    /// Chains that can be fused
    ///
    /// Only chains of two or more operations are listed.
    pub fusable_chains: Vec<FusableChain>,

    /// Estimated operation count after fusion
    ///
    /// The original count minus one less than the length of each fusable chain.
    pub estimated_final_count: usize,
}
600
/// A chain of operations that can be fused
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusableChain {
    /// Starting index in the original pipeline
    pub start_index: usize,

    /// Operations in the chain
    pub operations: Vec<Operation>,

    /// Type of fusion
    pub fusion_type: FusedType,

    /// Estimated benefit (percentage)
    ///
    /// `analyze_pipeline` fills this with the chain's estimated overhead reduction.
    pub estimated_benefit: f32,
}
616
617impl FusionAnalysis {
618    /// Get a summary of the analysis
619    pub fn summary(&self) -> String {
620        format!(
621            "Pipeline Analysis: {} ops -> {} ops ({} fusable chains, {:.1}% reduction)",
622            self.original_operation_count,
623            self.estimated_final_count,
624            self.fusable_chains.len(),
625            ((self.original_operation_count - self.estimated_final_count) as f32
626                / self.original_operation_count as f32
627                * 100.0)
628        )
629    }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Map` operation with the given name.
    fn map(name: &str) -> Operation {
        Operation::Map {
            name: name.to_string(),
        }
    }

    /// Builds a `Filter` operation with the given name.
    fn filter(name: &str) -> Operation {
        Operation::Filter {
            name: name.to_string(),
        }
    }

    /// Builds a `Custom` operation with the given name and fusability.
    fn custom(name: &str, fusable: bool) -> Operation {
        Operation::Custom {
            name: name.to_string(),
            fusable,
        }
    }

    #[tokio::test]
    async fn test_fusion_optimizer_creation() {
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert!(optimizer.config.enable_fusion);
    }

    #[tokio::test]
    async fn test_map_fusion() {
        let config = FusionConfig {
            enable_map_fusion: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2"), map("step3")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // Should fuse all three maps into one
        assert_eq!(optimized.len(), 1);
        assert!(matches!(optimized[0], Operation::Map { .. }));

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.operations_fused, 3);
        assert_eq!(stats.map_fusions, 1);
    }

    #[tokio::test]
    async fn test_filter_fusion() {
        let config = FusionConfig {
            enable_filter_fusion: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![filter("check1"), filter("check2")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // Should fuse filters
        assert_eq!(optimized.len(), 1);
        assert!(matches!(optimized[0], Operation::Filter { .. }));

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.filter_fusions, 1);
    }

    #[tokio::test]
    async fn test_mixed_fusion() {
        let config = FusionConfig {
            enable_cross_fusion: true,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("transform"), filter("validate")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // Should fuse map and filter into one fusable custom operation
        assert_eq!(optimized.len(), 1);
        assert!(matches!(
            optimized[0],
            Operation::Custom { fusable: true, .. }
        ));

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.cross_fusions, 1);
    }

    #[tokio::test]
    async fn test_no_fusion_when_disabled() {
        let config = FusionConfig {
            enable_fusion: false,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // Should not fuse
        assert_eq!(optimized, pipeline);
    }

    #[tokio::test]
    async fn test_min_fusion_size() {
        let config = FusionConfig {
            min_fusion_size: 3,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // Should not fuse (too small)
        assert_eq!(optimized, pipeline);
    }

    #[tokio::test]
    async fn test_max_fusion_depth() {
        let config = FusionConfig {
            max_fusion_depth: 2,
            ..Default::default()
        };
        let mut optimizer = FusionOptimizer::new(config);

        let pipeline = vec![map("step1"), map("step2"), map("step3")];
        let optimized = optimizer.optimize_pipeline(&pipeline).await.unwrap();

        // The first two maps are fused; the third has no partner left and is
        // passed through unchanged.
        assert_eq!(optimized.len(), 2);
        assert_eq!(optimized[1], map("step3"));

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.operations_fused, 2);
        assert_eq!(stats.fusion_chains_created, 1);
    }

    #[tokio::test]
    async fn test_fusion_analysis() {
        let optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![map("step1"), map("step2"), filter("check")];
        let analysis = optimizer.analyze_pipeline(&pipeline).await.unwrap();

        assert_eq!(analysis.original_operation_count, 3);
        assert_eq!(analysis.fusable_chains.len(), 1);
        assert_eq!(analysis.fusable_chains[0].start_index, 0);
        assert_eq!(analysis.fusable_chains[0].operations, pipeline);
        assert_eq!(analysis.fusable_chains[0].fusion_type, FusedType::MapFilter);
        assert_eq!(analysis.estimated_final_count, 1);

        // A dry run must not touch the statistics.
        let stats = optimizer.get_stats().await;
        assert_eq!(stats.pipelines_optimized, 0);
        assert_eq!(stats.operations_fused, 0);
    }

    #[tokio::test]
    async fn test_operation_can_fuse() {
        let map1 = map("map1");
        let map2 = map("map2");
        let filter1 = filter("filter1");
        let filter2 = filter("filter2");
        let take = Operation::Take { count: 5 };

        // Fusable pairs
        assert!(map1.can_fuse_with(&map2));
        assert!(map1.can_fuse_with(&filter1));
        assert!(filter1.can_fuse_with(&filter2));
        assert!(custom("a", true).can_fuse_with(&custom("b", true)));

        // The relation is directional and excludes every other kind
        assert!(!filter1.can_fuse_with(&map1));
        assert!(!map1.can_fuse_with(&take));
        assert!(!take.can_fuse_with(&map1));
        assert!(!custom("a", true).can_fuse_with(&custom("b", false)));
        assert!(!custom("a", false).can_fuse_with(&custom("b", true)));
    }

    #[tokio::test]
    async fn test_stats_tracking() {
        let mut optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![map("step1"), map("step2")];
        optimizer.optimize_pipeline(&pipeline).await.unwrap();

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.pipelines_optimized, 1);
        assert_eq!(stats.operations_analyzed, 2);
        assert!(stats.operations_fused > 0);
        assert!(stats.fusion_chains_created > 0);
        assert!(stats.last_optimization.is_some());
    }

    #[tokio::test]
    async fn test_reset_stats() {
        let mut optimizer = FusionOptimizer::new(FusionConfig::default());

        let pipeline = vec![map("step1"), map("step2")];
        optimizer.optimize_pipeline(&pipeline).await.unwrap();
        optimizer.reset_stats().await;

        let stats = optimizer.get_stats().await;
        assert_eq!(stats.pipelines_optimized, 0);
        assert_eq!(stats.operations_fused, 0);
    }
}