1use crate::schubert::{IntersectionResult, SchubertCalculus, SchubertClass};
21use crate::EnumerativeResult;
22use num_rational::Rational64;
23
/// Integer torus weights, one per coordinate of `C^n`.
///
/// At a fixed point indexed by a subset `I`, the tangent weights are the
/// differences `weights[j] - weights[i]` with `i ∈ I`, `j ∉ I`.
#[derive(Debug, Clone)]
pub struct TorusWeights {
    /// Weight attached to each coordinate index `0..n`.
    pub weights: Vec<i64>,
}
33
34impl TorusWeights {
35 #[must_use]
45 pub fn standard(n: usize) -> Self {
46 Self {
47 weights: (1..=n as i64).collect(),
48 }
49 }
50
51 pub fn custom(weights: Vec<i64>) -> EnumerativeResult<Self> {
60 if weights.contains(&0) {
61 return Err(crate::EnumerativeError::InvalidDimension(
62 "Torus weights must be nonzero".to_string(),
63 ));
64 }
65 let mut sorted = weights.clone();
66 sorted.sort_unstable();
67 if sorted.windows(2).any(|w| w[0] == w[1]) {
68 return Err(crate::EnumerativeError::InvalidDimension(
69 "Torus weights must be distinct".to_string(),
70 ));
71 }
72 Ok(Self { weights })
73 }
74}
75
/// A torus-fixed point of `Gr(k, n)`: the coordinate `k`-plane spanned by the
/// basis vectors whose indices are listed in `subset`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FixedPoint {
    /// Coordinate indices spanning the plane (kept sorted by `FixedPoint::new`).
    pub subset: Vec<usize>,
    /// The ambient Grassmannian as `(k, n)`.
    pub grassmannian: (usize, usize),
}
87
88impl FixedPoint {
89 pub fn new(mut subset: Vec<usize>, grassmannian: (usize, usize)) -> EnumerativeResult<Self> {
98 let (k, n) = grassmannian;
99 if subset.len() != k {
100 return Err(crate::EnumerativeError::InvalidDimension(format!(
101 "Fixed point subset must have {} elements, got {}",
102 k,
103 subset.len()
104 )));
105 }
106 if subset.iter().any(|&i| i >= n) {
107 return Err(crate::EnumerativeError::InvalidDimension(format!(
108 "Subset elements must be < {n}"
109 )));
110 }
111 subset.sort_unstable();
112 Ok(Self {
113 subset,
114 grassmannian,
115 })
116 }
117
118 #[must_use]
130 pub fn tangent_euler_class(&self, weights: &TorusWeights) -> Rational64 {
131 let (_, n) = self.grassmannian;
132 let complement: Vec<usize> = (0..n).filter(|i| !self.subset.contains(i)).collect();
133
134 let mut product = Rational64::from(1);
135 for &i in &self.subset {
136 for &j in &complement {
137 product *= Rational64::from(weights.weights[j] - weights.weights[i]);
138 }
139 }
140 product
141 }
142
143 #[must_use]
147 pub fn to_partition(&self) -> Vec<usize> {
148 let mut partition: Vec<usize> = self
149 .subset
150 .iter()
151 .enumerate()
152 .map(|(a, &i_a)| i_a - a)
153 .collect();
154 partition.sort_unstable_by(|a, b| b.cmp(a));
155 partition.retain(|&x| x > 0);
156 partition
157 }
158}
159
160#[derive(Debug)]
165pub struct EquivariantLocalizer {
166 pub grassmannian: (usize, usize),
168 pub weights: TorusWeights,
170 fixed_points: Option<Vec<FixedPoint>>,
172 schubert_engine: SchubertCalculus,
174}
175
176impl EquivariantLocalizer {
177 pub fn new(grassmannian: (usize, usize)) -> EnumerativeResult<Self> {
186 let (k, n) = grassmannian;
187 if k > n {
188 return Err(crate::EnumerativeError::InvalidDimension(format!(
189 "k={k} must be <= n={n} for Gr(k,n)"
190 )));
191 }
192 Ok(Self {
193 grassmannian,
194 weights: TorusWeights::standard(n),
195 fixed_points: None,
196 schubert_engine: SchubertCalculus::new(grassmannian),
197 })
198 }
199
200 pub fn with_weights(
209 grassmannian: (usize, usize),
210 weights: TorusWeights,
211 ) -> EnumerativeResult<Self> {
212 let (k, n) = grassmannian;
213 if k > n {
214 return Err(crate::EnumerativeError::InvalidDimension(format!(
215 "k={k} must be <= n={n}"
216 )));
217 }
218 if weights.weights.len() != n {
219 return Err(crate::EnumerativeError::InvalidDimension(format!(
220 "Need {n} weights, got {}",
221 weights.weights.len()
222 )));
223 }
224 Ok(Self {
225 grassmannian,
226 weights,
227 fixed_points: None,
228 schubert_engine: SchubertCalculus::new(grassmannian),
229 })
230 }
231
232 #[must_use]
234 pub fn fixed_point_count(&self) -> usize {
235 let (k, n) = self.grassmannian;
236 binomial_usize(n, k)
237 }
238
239 fn ensure_fixed_points(&mut self) {
241 if self.fixed_points.is_some() {
242 return;
243 }
244 let (k, n) = self.grassmannian;
245 let subsets = k_subsets(n, k);
246 let points: Vec<FixedPoint> = subsets
247 .into_iter()
248 .map(|s| FixedPoint {
249 subset: s,
250 grassmannian: self.grassmannian,
251 })
252 .collect();
253 self.fixed_points = Some(points);
254 }
255
256 pub fn fixed_points(&mut self) -> &[FixedPoint] {
258 self.ensure_fixed_points();
259 self.fixed_points.as_ref().unwrap()
260 }
261
262 pub fn localized_intersection(&mut self, classes: &[SchubertClass]) -> Rational64 {
274 let result = self.schubert_engine.multi_intersect(classes);
275 match result {
276 IntersectionResult::Finite(n) => Rational64::from(n as i64),
277 _ => Rational64::from(0),
278 }
279 }
280
281 pub fn intersection_result(&mut self, classes: &[SchubertClass]) -> IntersectionResult {
294 self.schubert_engine.multi_intersect(classes)
295 }
296
297 pub fn fixed_point_analysis(&mut self) -> Vec<(&FixedPoint, Rational64)> {
308 self.ensure_fixed_points();
309 let weights = &self.weights;
310 self.fixed_points
311 .as_ref()
312 .unwrap()
313 .iter()
314 .map(|fp| {
315 let euler = fp.tangent_euler_class(weights);
316 (fp, euler)
317 })
318 .collect()
319 }
320
321 #[cfg(feature = "parallel")]
323 pub fn localized_intersection_parallel(&mut self, classes: &[SchubertClass]) -> Rational64 {
324 self.localized_intersection(classes)
326 }
327
328 pub fn fq_point_count(&mut self, class: &SchubertClass, q: u64) -> EnumerativeResult<u64> {
333 self.ensure_fixed_points();
334
335 let lambda = crate::littlewood_richardson::Partition::new(class.partition.clone());
336 let mut count = 0u64;
337
338 for fp in self.fixed_points.as_ref().unwrap() {
339 let mu = crate::littlewood_richardson::Partition::new(fp.to_partition());
340 if lambda.contains(&mu) {
341 count += q.pow(mu.size() as u32);
342 }
343 }
344 Ok(count)
345 }
346
347 pub fn atiyah_bott_intersection(
352 &mut self,
353 classes: &[SchubertClass],
354 ) -> EnumerativeResult<Rational64> {
355 self.ensure_fixed_points();
356
357 let weights = &self.weights;
358 let mut total = Rational64::from(0);
359
360 for fp in self.fixed_points.as_ref().unwrap() {
361 let euler = fp.tangent_euler_class(weights);
362 if euler == Rational64::from(0) {
363 continue;
364 }
365
366 let mut product = Rational64::from(1);
367 for class in classes {
368 let restriction = self.schubert_restriction(class, fp);
369 product *= restriction;
370 }
371
372 total += product / euler;
373 }
374
375 Ok(total)
376 }
377
378 pub fn schubert_restriction(
383 &self,
384 class: &SchubertClass,
385 fixed_point: &FixedPoint,
386 ) -> Rational64 {
387 let (k, n) = self.grassmannian;
388 let m = n - k;
389 let partition = &class.partition;
390
391 let mut lambda = vec![0usize; k];
393 for (i, &p) in partition.iter().enumerate() {
394 if i < k {
395 lambda[i] = p;
396 }
397 }
398
399 let subset = &fixed_point.subset;
400 let complement: Vec<usize> = (0..n).filter(|i| !subset.contains(i)).collect();
401
402 let mut product = Rational64::from(1);
403
404 for (i, &li) in lambda.iter().enumerate() {
407 for j in 0..li {
408 if i < k && j < m {
414 let row_idx = k - 1 - i; let col_idx = m - 1 - j; if row_idx < subset.len() && col_idx < complement.len() {
417 let t_i = self.weights.weights[subset[row_idx]];
418 let t_j = self.weights.weights[complement[col_idx]];
419 product *= Rational64::from(t_i - t_j);
420 }
421 }
422 }
423 }
424
425 product
426 }
427}
428
429fn k_subsets(n: usize, k: usize) -> Vec<Vec<usize>> {
431 let mut result = Vec::new();
432 let mut current = Vec::with_capacity(k);
433 generate_subsets(n, k, 0, &mut current, &mut result);
434 result
435}
436
/// Depth-first enumeration of strictly increasing index selections.
///
/// Extends the partial selection `current` with indices drawn from `start..n`
/// until it holds `k` entries. Each completed selection is appended to `result`
/// in lexicographic order, and `current` is restored to its input state before
/// returning.
fn generate_subsets(
    n: usize,
    k: usize,
    start: usize,
    current: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    let still_needed = k - current.len();
    if still_needed == 0 {
        result.push(current.to_vec());
        return;
    }
    // Leave room for the `still_needed - 1` larger indices that must follow.
    let last_candidate = n - still_needed;
    for candidate in start..=last_candidate {
        current.push(candidate);
        generate_subsets(n, k, candidate + 1, current, result);
        current.pop();
    }
}
455
/// Binomial coefficient `C(n, k)`, or 0 when `k > n`.
///
/// Uses `C(n, i+1) = C(n, i)·(n-i)/(i+1)`, where each quotient is exact. The
/// product is formed in `u128`. The running value `C(n, i)` for
/// `i <= min(k, n-k)` never exceeds the final result, so the intermediate
/// cannot overflow whenever the answer fits in `usize`.
///
/// # Panics
///
/// Panics if `C(n, k)` does not fit in `usize`.
fn binomial_usize(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // Symmetry keeps the loop short and the running value monotone increasing.
    let k = k.min(n - k);
    let mut result: usize = 1;
    for i in 0..k {
        // usize -> u128 widening is lossless on all supported targets.
        let next = result as u128 * (n - i) as u128 / (i + 1) as u128;
        result = usize::try_from(next).expect("binomial coefficient overflows usize");
    }
    result
}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// `copies` copies of the hyperplane class σ₁ on Gr(2, 4).
    fn sigma_one_gr24(copies: usize) -> Vec<SchubertClass> {
        let sigma_1 = SchubertClass::new(vec![1], (2, 4)).unwrap();
        vec![sigma_1; copies]
    }

    #[test]
    fn test_fixed_point_count() {
        // C(4, 2) coordinate 2-planes.
        let localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        assert_eq!(localizer.fixed_point_count(), 6);
    }

    #[test]
    fn test_fixed_point_count_gr35() {
        // C(5, 3) coordinate 3-planes.
        let localizer = EquivariantLocalizer::new((3, 5)).unwrap();
        assert_eq!(localizer.fixed_point_count(), 10);
    }

    #[test]
    fn test_tangent_euler_class_nonzero() {
        let point = FixedPoint::new(vec![0, 1], (2, 4)).unwrap();
        let euler = point.tangent_euler_class(&TorusWeights::standard(4));
        assert_ne!(euler, Rational64::from(0));
    }

    #[test]
    fn test_fixed_point_partition() {
        // {0, 2} -> parts (0 - 0, 2 - 1) = (0, 1) -> [1].
        let point = FixedPoint::new(vec![0, 2], (2, 4)).unwrap();
        assert_eq!(point.to_partition(), vec![1]);
    }

    #[test]
    fn test_localization_four_lines() {
        // Two lines in P^3 meet four general lines.
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let result = localizer.localized_intersection(&sigma_one_gr24(4));
        assert_eq!(result, Rational64::from(2));
    }

    #[test]
    fn test_localization_point_class() {
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let sigma_22 = SchubertClass::new(vec![2, 2], (2, 4)).unwrap();
        let result = localizer.localized_intersection(&[sigma_22]);
        assert_eq!(result, Rational64::from(1));
    }

    #[test]
    fn test_localization_sigma1_squared_gr24() {
        // Codimension 2 in a 4-dimensional space leaves a surface.
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let result = localizer.intersection_result(&sigma_one_gr24(2));
        assert!(matches!(
            result,
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
    }

    #[test]
    fn test_localization_overdetermined() {
        // Codimension 5 exceeds dim Gr(2, 4) = 4.
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let result = localizer.intersection_result(&sigma_one_gr24(5));
        assert_eq!(result, IntersectionResult::Empty);
    }

    #[test]
    fn test_localization_sigma1_cubed_sigma1() {
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let result = localizer.intersection_result(&sigma_one_gr24(4));
        assert_eq!(result, IntersectionResult::Finite(2));
    }

    #[test]
    fn test_custom_weights() {
        let weights = TorusWeights::custom(vec![1, 3, 7, 11]).unwrap();
        let mut localizer = EquivariantLocalizer::with_weights((2, 4), weights).unwrap();
        let result = localizer.localized_intersection(&sigma_one_gr24(4));
        assert_eq!(result, Rational64::from(2));
    }

    #[test]
    fn test_invalid_weights_zero() {
        assert!(TorusWeights::custom(vec![0, 1, 2, 3]).is_err());
    }

    #[test]
    fn test_invalid_weights_duplicate() {
        assert!(TorusWeights::custom(vec![1, 2, 2, 3]).is_err());
    }

    #[test]
    fn test_k_subsets() {
        let subsets = k_subsets(4, 2);
        assert_eq!(subsets.len(), 6);
        assert_eq!(subsets.first(), Some(&vec![0, 1]));
        assert_eq!(subsets[5], vec![2, 3]);
    }

    #[test]
    fn test_fixed_point_analysis() {
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let analysis = localizer.fixed_point_analysis();
        assert_eq!(analysis.len(), 6);
        // Distinct standard weights make every tangent weight nonzero.
        assert!(analysis
            .iter()
            .all(|(_, euler)| *euler != Rational64::from(0)));
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_agrees() {
        let mut localizer = EquivariantLocalizer::new((2, 4)).unwrap();
        let classes = sigma_one_gr24(4);
        let sequential = localizer.localized_intersection(&classes);
        let parallel = localizer.localized_intersection_parallel(&classes);
        assert_eq!(sequential, parallel);
    }
}