1use crate::schubert::{IntersectionResult, SchubertCalculus, SchubertClass};
21use crate::EnumerativeResult;
22use num_rational::Rational64;
23
24#[derive(Debug, Clone)]
29pub struct TorusWeights {
30 pub weights: Vec<i64>,
32}
33
34impl TorusWeights {
35 #[must_use]
45 pub fn standard(n: usize) -> Self {
46 Self {
47 weights: (1..=n as i64).collect(),
48 }
49 }
50
51 pub fn custom(weights: Vec<i64>) -> EnumerativeResult<Self> {
60 if weights.contains(&0) {
61 return Err(crate::EnumerativeError::InvalidDimension(
62 "Torus weights must be nonzero".to_string(),
63 ));
64 }
65 let mut sorted = weights.clone();
66 sorted.sort_unstable();
67 if sorted.windows(2).any(|w| w[0] == w[1]) {
68 return Err(crate::EnumerativeError::InvalidDimension(
69 "Torus weights must be distinct".to_string(),
70 ));
71 }
72 Ok(Self { weights })
73 }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
81pub struct FixedPoint {
82 pub subset: Vec<usize>,
84 pub grassmannian: (usize, usize),
86}
87
88impl FixedPoint {
89 pub fn new(mut subset: Vec<usize>, grassmannian: (usize, usize)) -> EnumerativeResult<Self> {
98 let (k, n) = grassmannian;
99 if subset.len() != k {
100 return Err(crate::EnumerativeError::InvalidDimension(format!(
101 "Fixed point subset must have {} elements, got {}",
102 k,
103 subset.len()
104 )));
105 }
106 if subset.iter().any(|&i| i >= n) {
107 return Err(crate::EnumerativeError::InvalidDimension(format!(
108 "Subset elements must be < {n}"
109 )));
110 }
111 subset.sort_unstable();
112 Ok(Self {
113 subset,
114 grassmannian,
115 })
116 }
117
118 #[must_use]
130 pub fn tangent_euler_class(&self, weights: &TorusWeights) -> Rational64 {
131 let (_, n) = self.grassmannian;
132 let complement: Vec<usize> = (0..n).filter(|i| !self.subset.contains(i)).collect();
133
134 let mut product = Rational64::from(1);
135 for &i in &self.subset {
136 for &j in &complement {
137 product *= Rational64::from(weights.weights[j] - weights.weights[i]);
138 }
139 }
140 product
141 }
142
143 #[must_use]
147 pub fn to_partition(&self) -> Vec<usize> {
148 let mut partition: Vec<usize> = self
149 .subset
150 .iter()
151 .enumerate()
152 .map(|(a, &i_a)| i_a - a)
153 .collect();
154 partition.sort_unstable_by(|a, b| b.cmp(a));
155 partition.retain(|&x| x > 0);
156 partition
157 }
158}
159
160#[derive(Debug)]
165pub struct EquivariantLocalizer {
166 pub grassmannian: (usize, usize),
168 pub weights: TorusWeights,
170 fixed_points: Option<Vec<FixedPoint>>,
172 schubert_engine: SchubertCalculus,
174}
175
176impl EquivariantLocalizer {
177 pub fn new(grassmannian: (usize, usize)) -> EnumerativeResult<Self> {
186 let (k, n) = grassmannian;
187 if k > n {
188 return Err(crate::EnumerativeError::InvalidDimension(format!(
189 "k={k} must be <= n={n} for Gr(k,n)"
190 )));
191 }
192 Ok(Self {
193 grassmannian,
194 weights: TorusWeights::standard(n),
195 fixed_points: None,
196 schubert_engine: SchubertCalculus::new(grassmannian),
197 })
198 }
199
200 pub fn with_weights(
209 grassmannian: (usize, usize),
210 weights: TorusWeights,
211 ) -> EnumerativeResult<Self> {
212 let (k, n) = grassmannian;
213 if k > n {
214 return Err(crate::EnumerativeError::InvalidDimension(format!(
215 "k={k} must be <= n={n}"
216 )));
217 }
218 if weights.weights.len() != n {
219 return Err(crate::EnumerativeError::InvalidDimension(format!(
220 "Need {n} weights, got {}",
221 weights.weights.len()
222 )));
223 }
224 Ok(Self {
225 grassmannian,
226 weights,
227 fixed_points: None,
228 schubert_engine: SchubertCalculus::new(grassmannian),
229 })
230 }
231
232 #[must_use]
234 pub fn fixed_point_count(&self) -> usize {
235 let (k, n) = self.grassmannian;
236 binomial_usize(n, k)
237 }
238
239 fn ensure_fixed_points(&mut self) {
241 if self.fixed_points.is_some() {
242 return;
243 }
244 let (k, n) = self.grassmannian;
245 let subsets = k_subsets(n, k);
246 let points: Vec<FixedPoint> = subsets
247 .into_iter()
248 .map(|s| FixedPoint {
249 subset: s,
250 grassmannian: self.grassmannian,
251 })
252 .collect();
253 self.fixed_points = Some(points);
254 }
255
256 pub fn fixed_points(&mut self) -> &[FixedPoint] {
258 self.ensure_fixed_points();
259 self.fixed_points.as_ref().unwrap()
260 }
261
262 pub fn localized_intersection(&mut self, classes: &[SchubertClass]) -> Rational64 {
274 let result = self.schubert_engine.multi_intersect(classes);
275 match result {
276 IntersectionResult::Finite(n) => Rational64::from(n as i64),
277 _ => Rational64::from(0),
278 }
279 }
280
281 pub fn intersection_result(&mut self, classes: &[SchubertClass]) -> IntersectionResult {
294 self.schubert_engine.multi_intersect(classes)
295 }
296
297 pub fn fixed_point_analysis(&mut self) -> Vec<(&FixedPoint, Rational64)> {
308 self.ensure_fixed_points();
309 let weights = &self.weights;
310 self.fixed_points
311 .as_ref()
312 .unwrap()
313 .iter()
314 .map(|fp| {
315 let euler = fp.tangent_euler_class(weights);
316 (fp, euler)
317 })
318 .collect()
319 }
320
321 #[cfg(feature = "parallel")]
323 pub fn localized_intersection_parallel(&mut self, classes: &[SchubertClass]) -> Rational64 {
324 self.localized_intersection(classes)
326 }
327}
328
/// Lists every `k`-element subset of `{0, 1, ..., n - 1}`.
///
/// Each subset is strictly increasing, and the subsets are produced in
/// lexicographic order. For example, `k_subsets(4, 2)` starts with `[0, 1]` and
/// ends with `[2, 3]`. `k == 0` yields the single empty subset. `k > n` yields
/// no subsets, matching `binomial_usize(n, k) == 0`.
fn k_subsets(n: usize, k: usize) -> Vec<Vec<usize>> {
    let mut result = Vec::new();
    let mut current = Vec::with_capacity(k);
    generate_subsets(n, k, 0, &mut current, &mut result);
    result
}

/// Depth-first helper for [`k_subsets`].
///
/// It extends `current` with increasing indices `>= start` until it holds `k`
/// of them, and pushes each completed subset onto `result`. It assumes
/// `current.len() <= k`.
fn generate_subsets(
    n: usize,
    k: usize,
    start: usize,
    current: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    if current.len() == k {
        result.push(current.clone());
        return;
    }
    let remaining = k - current.len();
    // The next index may be at most `n - remaining`, which leaves room for the
    // `remaining - 1` larger indices that follow it. If `remaining > n`, no
    // completion exists. Return instead of letting the subtraction underflow.
    let last = match n.checked_sub(remaining) {
        Some(last) => last,
        None => return,
    };
    for i in start..=last {
        current.push(i);
        generate_subsets(n, k, i + 1, current, result);
        current.pop();
    }
}
355
/// Binomial coefficient `C(n, k)`, or `0` when `k > n`.
///
/// Uses the multiplicative recurrence `C(n, i + 1) = C(n, i) * (n - i) / (i + 1)`,
/// in which every division is exact. The product `C(n, i) * (n - i)` is formed
/// in `u128`, so it cannot overflow whenever the final result fits in `usize`.
/// For example, `C(64, 32)` fits on 64-bit targets even though
/// `C(64, 31) * 33` does not.
///
/// # Panics
///
/// Panics if `C(n, k)` does not fit in `usize`.
fn binomial_usize(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // By symmetry, iterate over the smaller of k and n - k. For i <= n/2 the
    // running value C(n, i) never decreases, so it exceeds usize only if the
    // final answer does.
    let steps = k.min(n - k);
    let mut acc: usize = 1;
    for i in 0..steps {
        // acc <= usize::MAX and (n - i) <= usize::MAX, so the u128 product
        // cannot overflow.
        let next = acc as u128 * (n - i) as u128 / (i + 1) as u128;
        assert!(
            next <= usize::MAX as u128,
            "binomial coefficient C({n}, {k}) does not fit in usize"
        );
        acc = next as usize;
    }
    acc
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// A localizer on `Gr(2, 4)` with the standard weights.
    fn gr24() -> EquivariantLocalizer {
        EquivariantLocalizer::new((2, 4)).unwrap()
    }

    /// `copies` instances of the class `sigma_1` on `Gr(2, 4)`.
    fn sigma_1_times(copies: usize) -> Vec<SchubertClass> {
        let sigma_1 = SchubertClass::new(vec![1], (2, 4)).unwrap();
        vec![sigma_1; copies]
    }

    #[test]
    fn test_fixed_point_count() {
        // C(4, 2) = 6
        assert_eq!(gr24().fixed_point_count(), 6);
    }

    #[test]
    fn test_fixed_point_count_gr35() {
        // C(5, 3) = 10
        let localizer = EquivariantLocalizer::new((3, 5)).unwrap();
        assert_eq!(localizer.fixed_point_count(), 10);
    }

    #[test]
    fn test_tangent_euler_class_nonzero() {
        let point = FixedPoint::new(vec![0, 1], (2, 4)).unwrap();
        let euler = point.tangent_euler_class(&TorusWeights::standard(4));
        assert_ne!(euler, Rational64::from(0));
    }

    #[test]
    fn test_fixed_point_partition() {
        // Parts of {0, 2} are 0 - 0 = 0 and 2 - 1 = 1; the zero part is dropped.
        let point = FixedPoint::new(vec![0, 2], (2, 4)).unwrap();
        assert_eq!(point.to_partition(), vec![1]);
    }

    #[test]
    fn test_localization_four_lines() {
        let mut localizer = gr24();
        let count = localizer.localized_intersection(&sigma_1_times(4));
        assert_eq!(count, Rational64::from(2));
    }

    #[test]
    fn test_localization_point_class() {
        let mut localizer = gr24();
        let sigma_22 = SchubertClass::new(vec![2, 2], (2, 4)).unwrap();
        let count = localizer.localized_intersection(&[sigma_22]);
        assert_eq!(count, Rational64::from(1));
    }

    #[test]
    fn test_localization_sigma1_squared_gr24() {
        let mut localizer = gr24();
        let outcome = localizer.intersection_result(&sigma_1_times(2));
        assert!(matches!(
            outcome,
            IntersectionResult::PositiveDimensional { dimension: 2, .. }
        ));
    }

    #[test]
    fn test_localization_overdetermined() {
        let mut localizer = gr24();
        let outcome = localizer.intersection_result(&sigma_1_times(5));
        assert_eq!(outcome, IntersectionResult::Empty);
    }

    #[test]
    fn test_localization_sigma1_cubed_sigma1() {
        let mut localizer = gr24();
        let outcome = localizer.intersection_result(&sigma_1_times(4));
        assert_eq!(outcome, IntersectionResult::Finite(2));
    }

    #[test]
    fn test_custom_weights() {
        let weights = TorusWeights::custom(vec![1, 3, 7, 11]).unwrap();
        let mut localizer = EquivariantLocalizer::with_weights((2, 4), weights).unwrap();
        let count = localizer.localized_intersection(&sigma_1_times(4));
        assert_eq!(count, Rational64::from(2));
    }

    #[test]
    fn test_invalid_weights_zero() {
        assert!(TorusWeights::custom(vec![0, 1, 2, 3]).is_err());
    }

    #[test]
    fn test_invalid_weights_duplicate() {
        assert!(TorusWeights::custom(vec![1, 2, 2, 3]).is_err());
    }

    #[test]
    fn test_k_subsets() {
        let subsets = k_subsets(4, 2);
        assert_eq!(subsets.len(), 6);
        assert_eq!(subsets.first(), Some(&vec![0, 1]));
        assert_eq!(subsets.last(), Some(&vec![2, 3]));
    }

    #[test]
    fn test_fixed_point_analysis() {
        let mut localizer = gr24();
        let analysis = localizer.fixed_point_analysis();
        assert_eq!(analysis.len(), 6);
        let zero = Rational64::from(0);
        for (point, euler) in &analysis {
            assert_ne!(*euler, zero, "vanishing Euler class at {:?}", point);
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_agrees() {
        let mut localizer = gr24();
        let classes = sigma_1_times(4);
        let sequential = localizer.localized_intersection(&classes);
        assert_eq!(localizer.localized_intersection_parallel(&classes), sequential);
    }
}