1use crate::littlewood_richardson::{lr_coefficient, schubert_product, Partition};
21use crate::{ChowClass, EnumerativeError, EnumerativeResult};
22use num_rational::Rational64;
23use std::collections::{BTreeMap, HashMap};
24
25#[cfg(feature = "parallel")]
26use rayon::prelude::*;
27
/// Outcome of intersecting a collection of Schubert classes on a Grassmannian.
///
/// The default value is [`IntersectionResult::Empty`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum IntersectionResult {
    /// The classes have no common point (e.g. their total codimension exceeds the
    /// dimension of the Grassmannian).
    #[default]
    Empty,
    /// The intersection is zero-dimensional; the payload is the number of points.
    Finite(u64),
    /// The intersection has positive dimension.
    PositiveDimensional {
        /// Dimension of the intersection locus.
        dimension: usize,
        /// Degree of the intersection, when it could be determined.
        degree: Option<u64>,
    },
}
60
/// A Schubert class `σ_λ` on the Grassmannian `Gr(k, n)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SchubertClass {
    /// The partition `λ` indexing the class; its sum is the codimension.
    pub partition: Vec<usize>,
    /// `(k, n)` of the ambient Grassmannian `Gr(k, n)`.
    pub grassmannian_dim: (usize, usize),
}
74
75impl SchubertClass {
76 pub fn new(partition: Vec<usize>, grassmannian_dim: (usize, usize)) -> EnumerativeResult<Self> {
89 let (k, n) = grassmannian_dim;
90
91 for &part in &partition {
93 if part > n - k {
94 return Err(EnumerativeError::SchubertError(format!(
95 "Partition entry {} exceeds n-k = {}",
96 part,
97 n - k
98 )));
99 }
100 }
101
102 Ok(Self {
103 partition,
104 grassmannian_dim,
105 })
106 }
107
108 pub fn from_partition(
110 partition: Partition,
111 grassmannian_dim: (usize, usize),
112 ) -> EnumerativeResult<Self> {
113 Self::new(partition.parts, grassmannian_dim)
114 }
115
116 #[must_use]
118 pub fn to_partition(&self) -> Partition {
119 Partition::new(self.partition.clone())
120 }
121
122 #[must_use]
124 pub fn to_chow_class(&self) -> ChowClass {
125 let codimension = self.partition.iter().sum::<usize>();
126 let degree = Rational64::from(1);
127
128 ChowClass::new(codimension, degree)
129 }
130
131 #[must_use]
140 pub fn dimension(&self) -> usize {
141 let (k, n) = self.grassmannian_dim;
142 let total_dim = k * (n - k);
143 let codim = self.partition.iter().sum::<usize>();
144 total_dim - codim
145 }
146
147 #[must_use]
155 pub fn codimension(&self) -> usize {
156 self.partition.iter().sum()
157 }
158
159 #[must_use]
161 pub fn power(&self, exponent: usize) -> SchubertClass {
162 let mut new_partition = self.partition.clone();
164 for _ in 1..exponent {
165 if !new_partition.is_empty() {
166 new_partition[0] += 1;
167 } else {
168 new_partition.push(1);
169 }
170 }
171
172 SchubertClass {
173 partition: new_partition,
174 grassmannian_dim: self.grassmannian_dim,
175 }
176 }
177
178 pub fn giambelli_determinant(
180 partition: &[usize],
181 grassmannian_dim: (usize, usize),
182 ) -> EnumerativeResult<Self> {
183 Self::new(partition.to_vec(), grassmannian_dim)
184 }
185}
186
187#[derive(Debug)]
189pub struct SchubertCalculus {
190 pub grassmannian_dim: (usize, usize),
192 intersection_cache: HashMap<(Vec<usize>, Vec<usize>), Rational64>,
194 lr_cache: BTreeMap<(Partition, Partition, Partition), u64>,
196}
197
198impl Default for SchubertCalculus {
199 fn default() -> Self {
200 Self::new((2, 4)) }
202}
203
204impl SchubertCalculus {
205 #[must_use]
207 pub fn new(grassmannian_dim: (usize, usize)) -> Self {
208 Self {
209 grassmannian_dim,
210 intersection_cache: HashMap::new(),
211 lr_cache: BTreeMap::new(),
212 }
213 }
214
215 #[must_use]
223 pub fn grassmannian_dimension(&self) -> usize {
224 let (k, n) = self.grassmannian_dim;
225 k * (n - k)
226 }
227
228 pub fn intersection_number(
239 &mut self,
240 class1: &SchubertClass,
241 class2: &SchubertClass,
242 ) -> EnumerativeResult<Rational64> {
243 let key = (class1.partition.clone(), class2.partition.clone());
245 if let Some(&cached) = self.intersection_cache.get(&key) {
246 return Ok(cached);
247 }
248
249 let result = if class1.dimension() + class2.dimension() == self.grassmannian_dimension() {
251 let p1 = class1.to_partition();
253 let p2 = class2.to_partition();
254 let (k, n) = self.grassmannian_dim;
255 let fundamental = Partition::new(vec![n - k; k]);
256
257 let coeff = lr_coefficient(&p1, &p2, &fundamental);
258 Rational64::from(coeff as i64)
259 } else {
260 Rational64::from(0)
261 };
262
263 self.intersection_cache.insert(key, result);
265 Ok(result)
266 }
267
268 pub fn multi_intersect(&mut self, classes: &[SchubertClass]) -> IntersectionResult {
283 if classes.is_empty() {
284 return IntersectionResult::PositiveDimensional {
285 dimension: self.grassmannian_dimension(),
286 degree: Some(1),
287 };
288 }
289
290 let grassmannian_dim = self.grassmannian_dimension();
291
292 let total_codim: usize = classes.iter().map(|c| c.codimension()).sum();
294
295 match total_codim.cmp(&grassmannian_dim) {
296 std::cmp::Ordering::Greater => IntersectionResult::Empty,
297 std::cmp::Ordering::Less => {
298 let remaining_dim = grassmannian_dim - total_codim;
299 IntersectionResult::PositiveDimensional {
300 dimension: remaining_dim,
301 degree: self.compute_degree_if_easy(classes),
302 }
303 }
304 std::cmp::Ordering::Equal => {
305 let count = self.compute_transverse_intersection(classes);
307 IntersectionResult::Finite(count)
308 }
309 }
310 }
311
312 fn compute_transverse_intersection(&mut self, classes: &[SchubertClass]) -> u64 {
314 if classes.is_empty() {
315 return 1;
316 }
317
318 if classes.len() == 1 {
319 let (k, n) = self.grassmannian_dim;
321 let fundamental = vec![n - k; k];
322 if classes[0].partition == fundamental {
323 return 1;
324 } else {
325 return 0;
326 }
327 }
328
329 let partitions: Vec<Partition> = classes.iter().map(|c| c.to_partition()).collect();
331
332 self.multiply_partitions(&partitions)
333 }
334
335 fn multiply_partitions(&mut self, partitions: &[Partition]) -> u64 {
342 let (k, n) = self.grassmannian_dim;
343
344 let mut current: BTreeMap<Partition, u64> = BTreeMap::new();
346 current.insert(partitions[0].clone(), 1);
347
348 for partition in &partitions[1..] {
350 let next = self.multiply_step(¤t, partition, k, n);
351 current = next;
352 }
353
354 let fundamental = Partition::new(vec![n - k; k]);
356 current.get(&fundamental).copied().unwrap_or(0)
357 }
358
359 #[cfg(feature = "parallel")]
361 fn multiply_step(
362 &self,
363 current: &BTreeMap<Partition, u64>,
364 partition: &Partition,
365 k: usize,
366 n: usize,
367 ) -> BTreeMap<Partition, u64> {
368 let pairs: Vec<_> = current.iter().collect();
370
371 let partial_results: Vec<BTreeMap<Partition, u64>> = pairs
372 .par_iter()
373 .map(|(nu, coeff)| {
374 let products = schubert_product(nu, partition, (k, n));
375 let mut local: BTreeMap<Partition, u64> = BTreeMap::new();
376 for (rho, lr_coeff) in products {
377 *local.entry(rho).or_insert(0) += **coeff * lr_coeff;
378 }
379 local
380 })
381 .collect();
382
383 let mut next: BTreeMap<Partition, u64> = BTreeMap::new();
385 for partial in partial_results {
386 for (rho, coeff) in partial {
387 *next.entry(rho).or_insert(0) += coeff;
388 }
389 }
390 next
391 }
392
393 #[cfg(not(feature = "parallel"))]
395 fn multiply_step(
396 &self,
397 current: &BTreeMap<Partition, u64>,
398 partition: &Partition,
399 k: usize,
400 n: usize,
401 ) -> BTreeMap<Partition, u64> {
402 let mut next: BTreeMap<Partition, u64> = BTreeMap::new();
403
404 for (nu, coeff) in current {
405 let products = schubert_product(nu, partition, (k, n));
406 for (rho, lr_coeff) in products {
407 *next.entry(rho).or_insert(0) += *coeff * lr_coeff;
408 }
409 }
410
411 next
412 }
413
414 fn compute_degree_if_easy(&self, _classes: &[SchubertClass]) -> Option<u64> {
415 None
418 }
419
420 pub fn lr_cached(&mut self, lambda: &Partition, mu: &Partition, nu: &Partition) -> u64 {
429 let (a, b) = if lambda <= mu {
431 (lambda.clone(), mu.clone())
432 } else {
433 (mu.clone(), lambda.clone())
434 };
435
436 let key = (a, b, nu.clone());
437
438 if let Some(&cached) = self.lr_cache.get(&key) {
439 return cached;
440 }
441
442 let result = lr_coefficient(lambda, mu, nu);
443 self.lr_cache.insert(key, result);
444 result
445 }
446
447 #[must_use]
457 pub fn product(
458 &mut self,
459 class1: &SchubertClass,
460 class2: &SchubertClass,
461 ) -> Vec<(SchubertClass, u64)> {
462 let p1 = class1.to_partition();
463 let p2 = class2.to_partition();
464
465 let products = schubert_product(&p1, &p2, self.grassmannian_dim);
466
467 products
468 .into_iter()
469 .filter_map(|(partition, coeff)| {
470 SchubertClass::new(partition.parts, self.grassmannian_dim)
471 .ok()
472 .map(|class| (class, coeff))
473 })
474 .collect()
475 }
476
477 pub fn pieri_multiply(
479 &self,
480 schubert_class: &SchubertClass,
481 special_class: usize,
482 ) -> EnumerativeResult<Vec<SchubertClass>> {
483 let mut results = Vec::new();
485 let (k, n) = self.grassmannian_dim;
486
487 if !schubert_class.partition.is_empty() {
489 let mut new_partition = schubert_class.partition.clone();
490 new_partition[0] += special_class;
491
492 if new_partition[0] <= n - k {
494 if let Ok(new_class) = SchubertClass::new(new_partition, self.grassmannian_dim) {
495 results.push(new_class);
496 }
497 }
498 }
499
500 let mut new_partition = schubert_class.partition.clone();
502 new_partition.push(special_class);
503
504 if special_class <= n - k {
506 if let Ok(new_class) = SchubertClass::new(new_partition, self.grassmannian_dim) {
507 results.push(new_class);
508 }
509 }
510
511 if results.is_empty() {
513 let mut new_partition = schubert_class.partition.clone();
514 if !new_partition.is_empty() {
515 new_partition[0] += special_class;
516 } else {
517 new_partition.push(special_class);
518 }
519 results.push(SchubertClass::new(new_partition, self.grassmannian_dim)?);
520 }
521
522 Ok(results)
523 }
524}
525
/// A partial flag variety `Fl(d_1, …, d_m; n)`: flags of subspaces
/// `V_1 ⊂ … ⊂ V_m` of an `n`-dimensional space with `dim V_i = d_i`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FlagVariety {
    /// The subspace dimensions `d_1 < … < d_m`.
    pub flag_dims: Vec<usize>,
    /// Dimension `n` of the ambient vector space.
    pub ambient_dim: usize,
}
534
535impl FlagVariety {
536 pub fn new(flag_dims: Vec<usize>, ambient_dim: usize) -> EnumerativeResult<Self> {
546 for i in 1..flag_dims.len() {
548 if flag_dims[i] <= flag_dims[i - 1] {
549 return Err(EnumerativeError::SchubertError(
550 "Flag dimensions must be strictly increasing".to_string(),
551 ));
552 }
553 }
554
555 if flag_dims.last().copied().unwrap_or(0) >= ambient_dim {
556 return Err(EnumerativeError::SchubertError(
557 "Largest flag dimension must be less than ambient dimension".to_string(),
558 ));
559 }
560
561 Ok(Self {
562 flag_dims,
563 ambient_dim,
564 })
565 }
566
567 #[must_use]
569 pub fn dimension(&self) -> usize {
570 let mut dim = 0;
571 let mut prev_dim = 0;
572
573 for &flag_dim in &self.flag_dims {
574 dim += (flag_dim - prev_dim) * (self.ambient_dim - prev_dim);
575 prev_dim = flag_dim;
576 }
577
578 dim
579 }
580}
581
/// Solves several independent [`SchubertCalculus::multi_intersect`] problems in parallel.
///
/// Each entry pairs the classes to intersect with the `(k, n)` of their Grassmannian.
/// Every entry gets its own freshly constructed [`SchubertCalculus`], so no cache is
/// shared between rayon workers. Results come back in the same order as `batches`.
#[cfg(feature = "parallel")]
pub fn multi_intersect_batch(
    batches: &[(Vec<SchubertClass>, (usize, usize))],
) -> Vec<IntersectionResult> {
    batches
        .par_iter()
        .map(|(classes, dims)| SchubertCalculus::new(*dims).multi_intersect(classes))
        .collect()
}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// `σ_1` on `Gr(2, 4)`: the lines in projective 3-space meeting a fixed line.
    fn sigma_1_gr24() -> SchubertClass {
        SchubertClass::new(vec![1], (2, 4)).unwrap()
    }

    #[test]
    fn test_schubert_class_creation() {
        let sigma_21 = SchubertClass::new(vec![2, 1], (3, 6)).unwrap();
        assert_eq!(sigma_21.partition, vec![2, 1]);
        assert_eq!(sigma_21.codimension(), 3);
    }

    #[test]
    fn test_intersection_result_default() {
        assert_eq!(IntersectionResult::default(), IntersectionResult::Empty);
    }

    #[test]
    fn test_schubert_calculus_default() {
        assert_eq!(SchubertCalculus::default().grassmannian_dim, (2, 4));
    }

    #[test]
    fn test_multi_intersect_four_lines() {
        // Four general lines in P^3 are met by exactly two lines.
        let mut calc = SchubertCalculus::new((2, 4));
        let four_conditions = vec![sigma_1_gr24(); 4];

        let outcome = calc.multi_intersect(&four_conditions);
        assert_eq!(outcome, IntersectionResult::Finite(2));
    }

    #[test]
    fn test_multi_intersect_underdetermined() {
        let mut calc = SchubertCalculus::new((2, 4));
        let two_conditions = vec![sigma_1_gr24(); 2];

        let outcome = calc.multi_intersect(&two_conditions);
        assert!(matches!(
            outcome,
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
    }

    #[test]
    fn test_multi_intersect_overdetermined() {
        let mut calc = SchubertCalculus::new((2, 4));
        let five_conditions = vec![sigma_1_gr24(); 5];

        let outcome = calc.multi_intersect(&five_conditions);
        assert_eq!(outcome, IntersectionResult::Empty);
    }

    #[test]
    fn test_product() {
        let mut calc = SchubertCalculus::new((2, 4));
        let sigma_1 = sigma_1_gr24();

        // σ_1 · σ_1 = σ_2 + σ_{1,1} on Gr(2, 4).
        let terms = calc.product(&sigma_1, &sigma_1);
        assert_eq!(terms.len(), 2);

        let shapes: Vec<Vec<usize>> = terms
            .into_iter()
            .map(|(class, _)| class.partition)
            .collect();
        assert!(shapes.contains(&vec![2]));
        assert!(shapes.contains(&vec![1, 1]));
    }

    #[test]
    fn test_partition_conversion() {
        let original = SchubertClass::new(vec![3, 2, 1], (4, 8)).unwrap();
        let as_partition = original.to_partition();
        assert_eq!(as_partition.parts, vec![3, 2, 1]);

        let round_trip = SchubertClass::from_partition(as_partition, (4, 8)).unwrap();
        assert_eq!(round_trip.partition, vec![3, 2, 1]);
    }

    #[test]
    fn test_flag_variety() {
        let flags = FlagVariety::new(vec![1, 2], 4).unwrap();
        assert!(flags.dimension() > 0);
    }
}
713
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;

    /// `σ_1` on the Grassmannian `Gr(2, n)`.
    fn sigma_1_on(n: usize) -> SchubertClass {
        SchubertClass::new(vec![1], (2, n)).unwrap()
    }

    #[test]
    fn test_multi_intersect_batch() {
        let gr24 = sigma_1_on(4);
        let gr25 = sigma_1_on(5);

        let batches = vec![
            // Four lines in P^3: finitely many solutions.
            (vec![gr24.clone(); 4], (2, 4)),
            // Six conditions on Gr(2, 5), whose dimension is 6: zero-dimensional.
            (vec![gr25; 6], (2, 5)),
            // Too many conditions: empty.
            (vec![gr24.clone(); 5], (2, 4)),
            // Too few conditions: a surface of solutions.
            (vec![gr24; 2], (2, 4)),
        ];

        let outcomes = multi_intersect_batch(&batches);

        assert_eq!(outcomes.len(), 4);
        assert_eq!(outcomes[0], IntersectionResult::Finite(2));
        assert!(matches!(outcomes[1], IntersectionResult::Finite(_)));
        assert_eq!(outcomes[2], IntersectionResult::Empty);
        assert!(matches!(
            outcomes[3],
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
    }

    #[test]
    fn test_multi_intersect_batch_empty() {
        let no_batches: Vec<(Vec<SchubertClass>, (usize, usize))> = Vec::new();
        assert!(multi_intersect_batch(&no_batches).is_empty());
    }

    #[test]
    fn test_multi_intersect_batch_single() {
        let batches = vec![(vec![sigma_1_on(4); 4], (2, 4))];

        let outcomes = multi_intersect_batch(&batches);
        assert_eq!(outcomes.len(), 1);
        assert_eq!(outcomes[0], IntersectionResult::Finite(2));
    }
}