// amari_enumerative/schubert.rs

//! Schubert calculus on Grassmannians and flag varieties
//!
//! This module implements Schubert classes and their intersection theory
//! on Grassmannians and flag varieties.
//!
//! # Contracts
//!
//! The key mathematical invariants maintained:
//!
//! - **Grassmannian bounds**: Partitions fit in the k × (n-k) box
//! - **Intersection dimension**: codim(σ_λ ∩ σ_μ) = codim(σ_λ) + codim(σ_μ) (generically)
//! - **Transversality**: When ∑ codim = dim(Gr), the intersection is finite
//! - **Commutativity**: σ_λ · σ_μ = σ_μ · σ_λ
//!
//! # Rayon Parallelization
//!
//! When the `parallel` feature is enabled, computationally intensive operations
//! use parallel iterators for improved performance on multi-core systems.

20use crate::littlewood_richardson::{lr_coefficient, schubert_product, Partition};
21use crate::{ChowClass, EnumerativeError, EnumerativeResult};
22use num_rational::Rational64;
23use std::collections::{BTreeMap, HashMap};
24
25#[cfg(feature = "parallel")]
26use rayon::prelude::*;
27
/// Outcome of intersecting Schubert classes on a Grassmannian.
///
/// # Contracts
///
/// Writing `total_codim` for the sum of the codimensions of the input classes:
///
/// ```text
/// total_codim >  dim(Gr)  =>  Empty
/// total_codim == dim(Gr)  =>  Finite(n), n >= 0
/// total_codim <  dim(Gr)  =>  PositiveDimensional { dimension: dim(Gr) - total_codim, .. }
/// ```
///
/// `Empty` is the `Default` value.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum IntersectionResult {
    /// The conditions are overdetermined: nothing satisfies all of them.
    #[default]
    Empty,
    /// The intersection consists of this many points.
    Finite(u64),
    /// The intersection still has positive dimension.
    PositiveDimensional {
        /// Dimension of the intersection.
        dimension: usize,
        /// Degree of the intersection, when it could be computed.
        degree: Option<u64>,
    },
}
60
/// A Schubert class σ_λ on Gr(k, n), labelled by a partition λ.
///
/// # Contracts
///
/// - Every entry of the partition is expected to be at most `n - k`.
/// - Entries are expected to be weakly decreasing, although this is not
///   enforced by the type itself.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SchubertClass {
    /// The partition λ labelling the class.
    pub partition: Vec<usize>,
    /// The pair `(k, n)` identifying the Grassmannian Gr(k, n).
    pub grassmannian_dim: (usize, usize),
}
74
75impl SchubertClass {
76    /// Create a new Schubert class
77    ///
78    /// # Contract
79    ///
80    /// ```text
81    /// requires: forall i. partition[i] <= n - k
82    /// ensures: result.codimension() == partition.iter().sum()
83    /// ```
84    ///
85    /// # Errors
86    ///
87    /// Returns `SchubertError` if any partition entry exceeds n-k.
88    pub fn new(partition: Vec<usize>, grassmannian_dim: (usize, usize)) -> EnumerativeResult<Self> {
89        let (k, n) = grassmannian_dim;
90
91        // Validate partition - each part must be ≤ n-k
92        for &part in &partition {
93            if part > n - k {
94                return Err(EnumerativeError::SchubertError(format!(
95                    "Partition entry {} exceeds n-k = {}",
96                    part,
97                    n - k
98                )));
99            }
100        }
101
102        Ok(Self {
103            partition,
104            grassmannian_dim,
105        })
106    }
107
108    /// Create from a Partition type
109    pub fn from_partition(
110        partition: Partition,
111        grassmannian_dim: (usize, usize),
112    ) -> EnumerativeResult<Self> {
113        Self::new(partition.parts, grassmannian_dim)
114    }
115
116    /// Convert to a Partition type
117    #[must_use]
118    pub fn to_partition(&self) -> Partition {
119        Partition::new(self.partition.clone())
120    }
121
122    /// Convert to a Chow class
123    #[must_use]
124    pub fn to_chow_class(&self) -> ChowClass {
125        let codimension = self.partition.iter().sum::<usize>();
126        let degree = Rational64::from(1);
127
128        ChowClass::new(codimension, degree)
129    }
130
131    /// Compute the dimension of this Schubert variety
132    ///
133    /// # Contract
134    ///
135    /// ```text
136    /// ensures: result == k * (n - k) - self.codimension()
137    /// ensures: result <= k * (n - k)
138    /// ```
139    #[must_use]
140    pub fn dimension(&self) -> usize {
141        let (k, n) = self.grassmannian_dim;
142        let total_dim = k * (n - k);
143        let codim = self.partition.iter().sum::<usize>();
144        total_dim - codim
145    }
146
147    /// Compute the codimension of this Schubert variety
148    ///
149    /// # Contract
150    ///
151    /// ```text
152    /// ensures: result == partition.iter().sum()
153    /// ```
154    #[must_use]
155    pub fn codimension(&self) -> usize {
156        self.partition.iter().sum()
157    }
158
159    /// Raise Schubert class to a power (repeated intersection)
160    #[must_use]
161    pub fn power(&self, exponent: usize) -> SchubertClass {
162        // Simplified - real power requires sophisticated Schubert calculus
163        let mut new_partition = self.partition.clone();
164        for _ in 1..exponent {
165            if !new_partition.is_empty() {
166                new_partition[0] += 1;
167            } else {
168                new_partition.push(1);
169            }
170        }
171
172        SchubertClass {
173            partition: new_partition,
174            grassmannian_dim: self.grassmannian_dim,
175        }
176    }
177
178    /// Giambelli determinant formula
179    pub fn giambelli_determinant(
180        partition: &[usize],
181        grassmannian_dim: (usize, usize),
182    ) -> EnumerativeResult<Self> {
183        Self::new(partition.to_vec(), grassmannian_dim)
184    }
185}
186
187/// Schubert calculus engine
188#[derive(Debug)]
189pub struct SchubertCalculus {
190    /// The underlying Grassmannian
191    pub grassmannian_dim: (usize, usize),
192    /// Cache for computed intersection numbers
193    intersection_cache: HashMap<(Vec<usize>, Vec<usize>), Rational64>,
194    /// Cache for LR coefficients
195    lr_cache: BTreeMap<(Partition, Partition, Partition), u64>,
196}
197
198impl Default for SchubertCalculus {
199    fn default() -> Self {
200        Self::new((2, 4)) // Default to Gr(2,4)
201    }
202}
203
204impl SchubertCalculus {
205    /// Create a new Schubert calculus engine
206    #[must_use]
207    pub fn new(grassmannian_dim: (usize, usize)) -> Self {
208        Self {
209            grassmannian_dim,
210            intersection_cache: HashMap::new(),
211            lr_cache: BTreeMap::new(),
212        }
213    }
214
215    /// Get the dimension of the Grassmannian
216    ///
217    /// # Contract
218    ///
219    /// ```text
220    /// ensures: result == k * (n - k)
221    /// ```
222    #[must_use]
223    pub fn grassmannian_dimension(&self) -> usize {
224        let (k, n) = self.grassmannian_dim;
225        k * (n - k)
226    }
227
228    /// Compute intersection number of two Schubert classes
229    ///
230    /// # Contract
231    ///
232    /// ```text
233    /// requires: class1.grassmannian_dim == class2.grassmannian_dim == self.grassmannian_dim
234    /// ensures: result >= 0
235    /// ensures: class1.dimension() + class2.dimension() != self.grassmannian_dimension()
236    ///          => result == 0
237    /// ```
238    pub fn intersection_number(
239        &mut self,
240        class1: &SchubertClass,
241        class2: &SchubertClass,
242    ) -> EnumerativeResult<Rational64> {
243        // Check cache first
244        let key = (class1.partition.clone(), class2.partition.clone());
245        if let Some(&cached) = self.intersection_cache.get(&key) {
246            return Ok(cached);
247        }
248
249        // Use actual Schubert calculus for two classes
250        let result = if class1.dimension() + class2.dimension() == self.grassmannian_dimension() {
251            // Transverse intersection - compute via LR coefficients
252            let p1 = class1.to_partition();
253            let p2 = class2.to_partition();
254            let (k, n) = self.grassmannian_dim;
255            let fundamental = Partition::new(vec![n - k; k]);
256
257            let coeff = lr_coefficient(&p1, &p2, &fundamental);
258            Rational64::from(coeff as i64)
259        } else {
260            Rational64::from(0)
261        };
262
263        // Cache the result
264        self.intersection_cache.insert(key, result);
265        Ok(result)
266    }
267
268    /// Intersect multiple Schubert classes
269    ///
270    /// Given classes σ_{λ_1}, ..., σ_{λ_m}, compute their intersection number
271    /// in the Grassmannian Gr(k, n).
272    ///
273    /// # Contract
274    ///
275    /// ```text
276    /// requires: forall c in classes. c.grassmannian_dim == self.grassmannian_dim
277    /// ensures:
278    ///   - sum(c.codimension() for c in classes) > dim(Gr) => Empty
279    ///   - sum(c.codimension() for c in classes) == dim(Gr) => Finite(n)
280    ///   - sum(c.codimension() for c in classes) < dim(Gr) => PositiveDimensional
281    /// ```
282    pub fn multi_intersect(&mut self, classes: &[SchubertClass]) -> IntersectionResult {
283        if classes.is_empty() {
284            return IntersectionResult::PositiveDimensional {
285                dimension: self.grassmannian_dimension(),
286                degree: Some(1),
287            };
288        }
289
290        let grassmannian_dim = self.grassmannian_dimension();
291
292        // Total codimension
293        let total_codim: usize = classes.iter().map(|c| c.codimension()).sum();
294
295        match total_codim.cmp(&grassmannian_dim) {
296            std::cmp::Ordering::Greater => IntersectionResult::Empty,
297            std::cmp::Ordering::Less => {
298                let remaining_dim = grassmannian_dim - total_codim;
299                IntersectionResult::PositiveDimensional {
300                    dimension: remaining_dim,
301                    degree: self.compute_degree_if_easy(classes),
302                }
303            }
304            std::cmp::Ordering::Equal => {
305                // Transverse intersection
306                let count = self.compute_transverse_intersection(classes);
307                IntersectionResult::Finite(count)
308            }
309        }
310    }
311
312    /// Compute intersection number when codimensions sum to Grassmannian dimension
313    fn compute_transverse_intersection(&mut self, classes: &[SchubertClass]) -> u64 {
314        if classes.is_empty() {
315            return 1;
316        }
317
318        if classes.len() == 1 {
319            // Single class at top dimension
320            let (k, n) = self.grassmannian_dim;
321            let fundamental = vec![n - k; k];
322            if classes[0].partition == fundamental {
323                return 1;
324            } else {
325                return 0;
326            }
327        }
328
329        // Convert to partitions and iteratively multiply
330        let partitions: Vec<Partition> = classes.iter().map(|c| c.to_partition()).collect();
331
332        self.multiply_partitions(&partitions)
333    }
334
335    /// Multiply partitions and extract fundamental class coefficient
336    ///
337    /// # Rayon Parallelization
338    ///
339    /// When many partitions need to be multiplied and the intermediate
340    /// products generate many terms, parallel computation can speed this up.
341    fn multiply_partitions(&mut self, partitions: &[Partition]) -> u64 {
342        let (k, n) = self.grassmannian_dim;
343
344        // Start with first partition
345        let mut current: BTreeMap<Partition, u64> = BTreeMap::new();
346        current.insert(partitions[0].clone(), 1);
347
348        // Iteratively multiply
349        for partition in &partitions[1..] {
350            let next = self.multiply_step(&current, partition, k, n);
351            current = next;
352        }
353
354        // Extract coefficient of fundamental class
355        let fundamental = Partition::new(vec![n - k; k]);
356        current.get(&fundamental).copied().unwrap_or(0)
357    }
358
359    /// Single multiplication step (can be parallelized)
360    #[cfg(feature = "parallel")]
361    fn multiply_step(
362        &self,
363        current: &BTreeMap<Partition, u64>,
364        partition: &Partition,
365        k: usize,
366        n: usize,
367    ) -> BTreeMap<Partition, u64> {
368        // Parallel version: collect into pairs and merge
369        let pairs: Vec<_> = current.iter().collect();
370
371        let partial_results: Vec<BTreeMap<Partition, u64>> = pairs
372            .par_iter()
373            .map(|(nu, coeff)| {
374                let products = schubert_product(nu, partition, (k, n));
375                let mut local: BTreeMap<Partition, u64> = BTreeMap::new();
376                for (rho, lr_coeff) in products {
377                    *local.entry(rho).or_insert(0) += **coeff * lr_coeff;
378                }
379                local
380            })
381            .collect();
382
383        // Merge all partial results
384        let mut next: BTreeMap<Partition, u64> = BTreeMap::new();
385        for partial in partial_results {
386            for (rho, coeff) in partial {
387                *next.entry(rho).or_insert(0) += coeff;
388            }
389        }
390        next
391    }
392
393    /// Single multiplication step (sequential version)
394    #[cfg(not(feature = "parallel"))]
395    fn multiply_step(
396        &self,
397        current: &BTreeMap<Partition, u64>,
398        partition: &Partition,
399        k: usize,
400        n: usize,
401    ) -> BTreeMap<Partition, u64> {
402        let mut next: BTreeMap<Partition, u64> = BTreeMap::new();
403
404        for (nu, coeff) in current {
405            let products = schubert_product(nu, partition, (k, n));
406            for (rho, lr_coeff) in products {
407                *next.entry(rho).or_insert(0) += *coeff * lr_coeff;
408            }
409        }
410
411        next
412    }
413
414    fn compute_degree_if_easy(&self, _classes: &[SchubertClass]) -> Option<u64> {
415        // Degree computation for positive-dimensional intersection
416        // is more complex; return None for now
417        None
418    }
419
420    /// Get or compute LR coefficient with caching
421    ///
422    /// # Contract
423    ///
424    /// ```text
425    /// ensures: result == lr_coefficient(lambda, mu, nu)
426    /// ensures: lr_cached(lambda, mu, nu) == lr_cached(mu, lambda, nu)  // symmetry
427    /// ```
428    pub fn lr_cached(&mut self, lambda: &Partition, mu: &Partition, nu: &Partition) -> u64 {
429        // Normalize key (LR coefficients are symmetric in λ, μ)
430        let (a, b) = if lambda <= mu {
431            (lambda.clone(), mu.clone())
432        } else {
433            (mu.clone(), lambda.clone())
434        };
435
436        let key = (a, b, nu.clone());
437
438        if let Some(&cached) = self.lr_cache.get(&key) {
439            return cached;
440        }
441
442        let result = lr_coefficient(lambda, mu, nu);
443        self.lr_cache.insert(key, result);
444        result
445    }
446
447    /// Expand product of two Schubert classes
448    ///
449    /// # Contract
450    ///
451    /// ```text
452    /// requires: class1.grassmannian_dim == class2.grassmannian_dim == self.grassmannian_dim
453    /// ensures: forall (c, coeff) in result. coeff > 0
454    /// ensures: product(class1, class2) == product(class2, class1)  // commutativity
455    /// ```
456    #[must_use]
457    pub fn product(
458        &mut self,
459        class1: &SchubertClass,
460        class2: &SchubertClass,
461    ) -> Vec<(SchubertClass, u64)> {
462        let p1 = class1.to_partition();
463        let p2 = class2.to_partition();
464
465        let products = schubert_product(&p1, &p2, self.grassmannian_dim);
466
467        products
468            .into_iter()
469            .filter_map(|(partition, coeff)| {
470                SchubertClass::new(partition.parts, self.grassmannian_dim)
471                    .ok()
472                    .map(|class| (class, coeff))
473            })
474            .collect()
475    }
476
477    /// Multiply two Schubert classes using Pieri's rule (simplified)
478    pub fn pieri_multiply(
479        &self,
480        schubert_class: &SchubertClass,
481        special_class: usize,
482    ) -> EnumerativeResult<Vec<SchubertClass>> {
483        // Simplified Pieri rule - adds horizontal strips
484        let mut results = Vec::new();
485        let (k, n) = self.grassmannian_dim;
486
487        // Option 1: Add to the first row (if it exists)
488        if !schubert_class.partition.is_empty() {
489            let mut new_partition = schubert_class.partition.clone();
490            new_partition[0] += special_class;
491
492            // Check if this partition is valid
493            if new_partition[0] <= n - k {
494                if let Ok(new_class) = SchubertClass::new(new_partition, self.grassmannian_dim) {
495                    results.push(new_class);
496                }
497            }
498        }
499
500        // Option 2: Add a new row
501        let mut new_partition = schubert_class.partition.clone();
502        new_partition.push(special_class);
503
504        // Check if this partition is valid
505        if special_class <= n - k {
506            if let Ok(new_class) = SchubertClass::new(new_partition, self.grassmannian_dim) {
507                results.push(new_class);
508            }
509        }
510
511        // If no valid results, return the original approach
512        if results.is_empty() {
513            let mut new_partition = schubert_class.partition.clone();
514            if !new_partition.is_empty() {
515                new_partition[0] += special_class;
516            } else {
517                new_partition.push(special_class);
518            }
519            results.push(SchubertClass::new(new_partition, self.grassmannian_dim)?);
520        }
521
522        Ok(results)
523    }
524}
525
/// Partial flag variety F(n1, n2, ..., nk; n).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FlagVariety {
    /// Dimensions of the nested subspaces making up a flag.
    pub flag_dims: Vec<usize>,
    /// Dimension of the ambient space.
    pub ambient_dim: usize,
}
534
535impl FlagVariety {
536    /// Create a new flag variety
537    ///
538    /// # Contract
539    ///
540    /// ```text
541    /// requires: flag_dims is strictly increasing
542    /// requires: flag_dims.last() < ambient_dim
543    /// ensures: result.dimension() >= 0
544    /// ```
545    pub fn new(flag_dims: Vec<usize>, ambient_dim: usize) -> EnumerativeResult<Self> {
546        // Validate that flag dimensions are increasing
547        for i in 1..flag_dims.len() {
548            if flag_dims[i] <= flag_dims[i - 1] {
549                return Err(EnumerativeError::SchubertError(
550                    "Flag dimensions must be strictly increasing".to_string(),
551                ));
552            }
553        }
554
555        if flag_dims.last().copied().unwrap_or(0) >= ambient_dim {
556            return Err(EnumerativeError::SchubertError(
557                "Largest flag dimension must be less than ambient dimension".to_string(),
558            ));
559        }
560
561        Ok(Self {
562            flag_dims,
563            ambient_dim,
564        })
565    }
566
567    /// Compute the dimension of the flag variety
568    #[must_use]
569    pub fn dimension(&self) -> usize {
570        let mut dim = 0;
571        let mut prev_dim = 0;
572
573        for &flag_dim in &self.flag_dims {
574            dim += (flag_dim - prev_dim) * (self.ambient_dim - prev_dim);
575            prev_dim = flag_dim;
576        }
577
578        dim
579    }
580}
581
582// ============================================================================
583// Parallel Batch Operations
584// ============================================================================
585
/// Solve several independent Schubert intersection problems in parallel.
///
/// Each batch entry pairs a list of classes with the `(k, n)` of the
/// Grassmannian to intersect them on. Every entry gets its own fresh
/// `SchubertCalculus`, so no cache is shared between entries.
///
/// # Contract
///
/// ```text
/// requires: forall batch in batches. batch.classes all have same grassmannian_dim
/// ensures: result.len() == batches.len()
/// ```
#[cfg(feature = "parallel")]
pub fn multi_intersect_batch(
    batches: &[(Vec<SchubertClass>, (usize, usize))],
) -> Vec<IntersectionResult> {
    batches
        .par_iter()
        .map(|batch| {
            let (classes, grassmannian_dim) = batch;
            SchubertCalculus::new(*grassmannian_dim).multi_intersect(classes)
        })
        .collect()
}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// The special class σ_1 on Gr(2, 4), used by most tests below.
    fn sigma_1_gr24() -> SchubertClass {
        SchubertClass::new(vec![1], (2, 4)).unwrap()
    }

    #[test]
    fn test_schubert_class_creation() {
        let sigma = SchubertClass::new(vec![2, 1], (3, 6)).unwrap();
        assert_eq!(sigma.partition, vec![2, 1]);
        assert_eq!(sigma.codimension(), 3);
    }

    #[test]
    fn test_intersection_result_default() {
        assert_eq!(IntersectionResult::default(), IntersectionResult::Empty);
    }

    #[test]
    fn test_schubert_calculus_default() {
        let engine = SchubertCalculus::default();
        assert_eq!(engine.grassmannian_dim, (2, 4));
    }

    #[test]
    fn test_multi_intersect_four_lines() {
        // Classic: how many lines meet 4 general lines in P³?
        // This is σ_1^4 in Gr(2,4) = 2
        let mut engine = SchubertCalculus::new((2, 4));
        let conditions = vec![sigma_1_gr24(); 4];

        let outcome = engine.multi_intersect(&conditions);
        assert_eq!(outcome, IntersectionResult::Finite(2));
    }

    #[test]
    fn test_multi_intersect_underdetermined() {
        let mut engine = SchubertCalculus::new((2, 4));

        // Only 2 conditions in Gr(2,4) which has dimension 4
        let conditions = vec![sigma_1_gr24(); 2];

        let outcome = engine.multi_intersect(&conditions);
        assert!(matches!(
            outcome,
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
    }

    #[test]
    fn test_multi_intersect_overdetermined() {
        let mut engine = SchubertCalculus::new((2, 4));

        // 5 conditions exceeds dimension 4
        let conditions = vec![sigma_1_gr24(); 5];

        let outcome = engine.multi_intersect(&conditions);
        assert_eq!(outcome, IntersectionResult::Empty);
    }

    #[test]
    fn test_product() {
        let mut engine = SchubertCalculus::new((2, 4));
        let sigma_1 = sigma_1_gr24();

        let terms = engine.product(&sigma_1, &sigma_1);

        // σ_1 · σ_1 = σ_2 + σ_{1,1}
        assert_eq!(terms.len(), 2);

        let shapes: Vec<Vec<usize>> = terms
            .iter()
            .map(|(class, _)| class.partition.clone())
            .collect();
        assert!(shapes.contains(&vec![2]));
        assert!(shapes.contains(&vec![1, 1]));
    }

    #[test]
    fn test_partition_conversion() {
        let original = SchubertClass::new(vec![3, 2, 1], (4, 8)).unwrap();
        let as_partition = original.to_partition();
        assert_eq!(as_partition.parts, vec![3, 2, 1]);

        let round_trip = SchubertClass::from_partition(as_partition, (4, 8)).unwrap();
        assert_eq!(round_trip.partition, vec![3, 2, 1]);
    }

    #[test]
    fn test_flag_variety() {
        let flags = FlagVariety::new(vec![1, 2], 4).unwrap();
        assert!(flags.dimension() > 0);
    }
}
713
714// ============================================================================
715// Parallel Batch Operation Tests
716// ============================================================================
717
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;

    /// Several unrelated intersection problems solved in one parallel batch.
    #[test]
    fn test_multi_intersect_batch() {
        let on_gr24 = SchubertClass::new(vec![1], (2, 4)).unwrap();
        let on_gr25 = SchubertClass::new(vec![1], (2, 5)).unwrap();

        let problems = vec![
            // σ_1^4 in Gr(2,4) = 2
            (vec![on_gr24.clone(); 4], (2, 4)),
            // σ_1^6 in Gr(2,5) should be finite
            (vec![on_gr25; 6], (2, 5)),
            // Overdetermined: 5 conditions in dim 4
            (vec![on_gr24.clone(); 5], (2, 4)),
            // Underdetermined: 2 conditions in dim 4
            (vec![on_gr24; 2], (2, 4)),
        ];

        let outcomes = multi_intersect_batch(&problems);

        assert_eq!(outcomes.len(), 4);
        assert_eq!(outcomes[0], IntersectionResult::Finite(2));
        assert!(matches!(outcomes[1], IntersectionResult::Finite(_)));
        assert_eq!(outcomes[2], IntersectionResult::Empty);
        assert!(matches!(
            outcomes[3],
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
    }

    /// An empty batch produces an empty result list.
    #[test]
    fn test_multi_intersect_batch_empty() {
        let problems: Vec<(Vec<SchubertClass>, (usize, usize))> = Vec::new();
        assert!(multi_intersect_batch(&problems).is_empty());
    }

    /// A batch with a single problem behaves like a direct call.
    #[test]
    fn test_multi_intersect_batch_single() {
        let sigma_1 = SchubertClass::new(vec![1], (2, 4)).unwrap();
        let problems = vec![(vec![sigma_1; 4], (2, 4))];

        let outcomes = multi_intersect_batch(&problems);
        assert_eq!(outcomes.len(), 1);
        assert_eq!(outcomes[0], IntersectionResult::Finite(2));
    }
}
767}