Skip to main content

fft_symmetric/
lib.rs

1#![warn(missing_docs)]
2
3//! Fast Fourier transforms for symmetric groups over prime fields.
4//!
5//! This crate implements the ordinary, semisimple representation-theoretic
6//! Fourier transform for `S_n` over a prime field `F_p`, with the standing
7//! requirement `p > n`. That condition is deliberate: it ensures
8//! `char(F_p)` does not divide `|S_n| = n!`, so the group algebra is
9//! semisimple and Young's seminormal representations are valid over the
10//! field.
11//!
12//! The transform of a function `f: S_n -> F_p` is the block diagonal family
13//!
14//! ```text
15//! f_hat(lambda) = sum_{g in S_n} f(g) rho_lambda(g),
16//! ```
17//!
18//! one matrix block for each partition `lambda` of `n`. The irreducible
19//! representations `rho_lambda` are Young seminormal representations indexed
20//! by standard tableaux in last-letter order.
21//!
22//! # Ordering conventions
23//!
24//! Input coefficients are ordered by [`SymmetricFft::permutations`], which is
25//! lexicographic order on permutation images. Permutations use zero-based
26//! images internally: the identity in `S_3` has images `[0, 1, 2]`, and the
27//! transposition `(1 2)` in usual one-based notation has images `[1, 0, 2]`.
28//!
29//! Matrix entries are stored row-major. Blocks in a [`FourierTransform`] are
30//! keyed by [`Partition`].
31//!
32//! # Example
33//!
34//! ```
35//! use fft_symmetric::{Partition, SymmetricFft};
36//!
37//! let fft = SymmetricFft::new(3, 101)?;
38//! let values = vec![1; fft.input_len()];
39//! let transform = fft.fft(&values)?;
40//! let recovered = fft.ifft(&transform)?;
41//! assert_eq!(recovered, values);
42//! let product = fft.multiply(&values, &values)?;
43//! assert_eq!(product.len(), fft.input_len());
44//!
45//! let mut unit = vec![0; fft.input_len()];
46//! unit[0] = 5;
47//! let inverse = fft.invert(&unit)?;
48//! assert_eq!(fft.multiply(&unit, &inverse)?, {
49//!     let mut identity = vec![0; fft.input_len()];
50//!     identity[0] = 1;
51//!     identity
52//! });
53//!
54//! let shape = Partition::new(vec![2, 1])?;
55//! let block = transform.block(&shape).unwrap();
56//! assert_eq!((block.rows(), block.cols()), (2, 2));
57//!
58//! # Ok::<(), Box<dyn std::error::Error>>(())
59//! ```
60
61use std::collections::BTreeMap;
62use std::error::Error;
63use std::fmt;
64#[cfg(test)]
65use std::time::Instant;
66
/// Errors produced by this crate's constructors, arithmetic helpers, and transforms.
///
/// Nearly every variant reports a failed validation: a modulus that cannot be
/// used as the field characteristic, incompatible matrix or block shapes, or
/// an input vector whose length is not `|S_n|`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FftError {
    /// The supplied modulus failed the primality check.
    CompositeModulus(u64),
    /// The prime `p` satisfies `p <= n`, so `F_p[S_n]` is not semisimple.
    ///
    /// Only the semisimple case `p > n` is supported at present.
    CharacteristicTooSmall {
        /// The prime modulus that was requested.
        modulus: u64,
        /// The rank `n` of the symmetric group.
        n: usize,
    },
    /// Computing `n!` overflowed `usize`.
    FactorialOverflow(usize),
    /// An input vector did not contain exactly `|S_n| = n!` coefficients.
    InputLength {
        /// The number of coefficients required.
        expected: usize,
        /// The number of coefficients actually supplied.
        got: usize,
    },
    /// A matrix operation received incompatible dimensions or moduli.
    ///
    /// [`Partition::new`] also reports malformed parts with this variant.
    MatrixShape,
    /// A Fourier block family had the wrong rank, field, set of blocks, or block dimensions.
    TransformShape,
    /// A group-algebra element or one of its matrix blocks is singular.
    NonInvertibleMatrix,
    /// An FFT plan was requested for rank `n = 0`.
    RankZero,
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `Copy`, so matching on `*self` binds by value.
        match *self {
            Self::CompositeModulus(modulus) => write!(f, "{modulus} is not a prime modulus"),
            Self::CharacteristicTooSmall { modulus, n } => write!(
                f,
                "expected characteristic p > n, got p = {modulus}, n = {n}"
            ),
            Self::FactorialOverflow(rank) => write!(f, "{rank}! does not fit in usize"),
            Self::InputLength { expected, got } => {
                write!(f, "expected {expected} input coefficients, got {got}")
            }
            Self::MatrixShape => f.write_str("matrix shape mismatch"),
            Self::TransformShape => f.write_str("Fourier transform block shape mismatch"),
            Self::NonInvertibleMatrix => f.write_str("matrix block is not invertible"),
            Self::RankZero => f.write_str("rank n must be at least 1"),
        }
    }
}

impl Error for FftError {}
127
128/// A runtime prime field `F_p`.
129///
130/// Elements are represented as canonical `u64` residues in the range
131/// `0..p`. Arithmetic methods normalize their results modulo `p`.
132///
133/// This type is intentionally small and dependency-free. It is enough for the
134/// current FFT because Young's seminormal formulas only require arithmetic in
135/// the base field and inverses of axial distances, which are nonzero when
136/// `p > n`.
137#[derive(Clone, Copy, Debug, PartialEq, Eq)]
138pub struct PrimeField {
139    modulus: u64,
140}
141
142impl PrimeField {
143    /// Constructs `F_p` for a prime modulus `p`.
144    ///
145    /// This only checks primality. Use [`PrimeField::for_symmetric_group`] or
146    /// [`SymmetricFft::new`] when the field will be used for `S_n`.
147    pub fn new(modulus: u64) -> Result<Self, FftError> {
148        if !is_prime(modulus) {
149            return Err(FftError::CompositeModulus(modulus));
150        }
151
152        Ok(Self { modulus })
153    }
154
155    /// Constructs a prime field suitable for the ordinary FFT of `S_n`.
156    ///
157    /// The method enforces `p > n`, which is equivalent to
158    /// `char(F_p)` not dividing `n!`.
159    pub fn for_symmetric_group(n: usize, modulus: u64) -> Result<Self, FftError> {
160        let field = Self::new(modulus)?;
161        if modulus <= n as u64 {
162            return Err(FftError::CharacteristicTooSmall { modulus, n });
163        }
164
165        Ok(field)
166    }
167
168    /// Returns the prime modulus `p`.
169    pub fn modulus(self) -> u64 {
170        self.modulus
171    }
172
173    /// Returns the additive identity.
174    pub fn zero(self) -> u64 {
175        0
176    }
177
178    /// Returns the multiplicative identity.
179    pub fn one(self) -> u64 {
180        1 % self.modulus
181    }
182
183    /// Reduces an unsigned integer modulo `p`.
184    pub fn normalize(self, value: u64) -> u64 {
185        value % self.modulus
186    }
187
188    /// Converts a signed integer into its canonical residue modulo `p`.
189    pub fn from_i64(self, value: i64) -> u64 {
190        let modulus = self.modulus as i128;
191        let mut value = (value as i128) % modulus;
192        if value < 0 {
193            value += modulus;
194        }
195        value as u64
196    }
197
198    /// Adds two residues modulo `p`.
199    pub fn add(self, lhs: u64, rhs: u64) -> u64 {
200        ((lhs as u128 + rhs as u128) % self.modulus as u128) as u64
201    }
202
203    /// Subtracts two residues modulo `p`.
204    pub fn sub(self, lhs: u64, rhs: u64) -> u64 {
205        ((lhs as u128 + self.modulus as u128 - rhs as u128) % self.modulus as u128) as u64
206    }
207
208    /// Negates a residue modulo `p`.
209    pub fn neg(self, value: u64) -> u64 {
210        if value == 0 { 0 } else { self.modulus - value }
211    }
212
213    /// Multiplies two residues modulo `p`.
214    pub fn mul(self, lhs: u64, rhs: u64) -> u64 {
215        ((lhs as u128 * rhs as u128) % self.modulus as u128) as u64
216    }
217
218    /// Raises a residue to a nonnegative power modulo `p`.
219    pub fn pow(self, mut base: u64, mut exp: u64) -> u64 {
220        let mut acc = self.one();
221        base = self.normalize(base);
222
223        while exp > 0 {
224            if exp & 1 == 1 {
225                acc = self.mul(acc, base);
226            }
227            base = self.mul(base, base);
228            exp >>= 1;
229        }
230
231        acc
232    }
233
234    /// Returns the multiplicative inverse, or `None` for zero.
235    pub fn inv(self, value: u64) -> Option<u64> {
236        if value == 0 {
237            None
238        } else {
239            Some(self.pow(value, self.modulus - 2))
240        }
241    }
242}
243
244/// A dense matrix over a [`PrimeField`].
245///
246/// The matrix stores canonical residues in row-major order. The type is kept
247/// intentionally simple because transform outputs are naturally dense matrix
248/// blocks, while the implementation uses private sparse helpers for the
249/// seminormal generators.
250#[derive(Clone, Debug, PartialEq, Eq)]
251pub struct Matrix {
252    rows: usize,
253    cols: usize,
254    modulus: u64,
255    data: Vec<u64>,
256}
257
258impl Matrix {
259    /// Creates the zero matrix with the given shape.
260    pub fn zero(rows: usize, cols: usize, field: PrimeField) -> Self {
261        Self {
262            rows,
263            cols,
264            modulus: field.modulus(),
265            data: vec![0; rows * cols],
266        }
267    }
268
269    /// Creates an identity matrix of size `size`.
270    pub fn identity(size: usize, field: PrimeField) -> Self {
271        let mut matrix = Self::zero(size, size, field);
272        for i in 0..size {
273            matrix.set(i, i, field.one());
274        }
275        matrix
276    }
277
278    /// Creates a matrix from row-major data.
279    ///
280    /// The input data is normalized into the field. Returns
281    /// [`FftError::MatrixShape`] if `data.len() != rows * cols`.
282    pub fn from_vec(
283        rows: usize,
284        cols: usize,
285        field: PrimeField,
286        data: Vec<u64>,
287    ) -> Result<Self, FftError> {
288        if data.len() != rows * cols {
289            return Err(FftError::MatrixShape);
290        }
291
292        Ok(Self {
293            rows,
294            cols,
295            modulus: field.modulus(),
296            data: data
297                .into_iter()
298                .map(|value| field.normalize(value))
299                .collect(),
300        })
301    }
302
303    /// Returns the number of rows.
304    pub fn rows(&self) -> usize {
305        self.rows
306    }
307
308    /// Returns the number of columns.
309    pub fn cols(&self) -> usize {
310        self.cols
311    }
312
313    /// Returns the prime modulus of the matrix entries.
314    pub fn modulus(&self) -> u64 {
315        self.modulus
316    }
317
318    /// Returns the row-major entries.
319    pub fn data(&self) -> &[u64] {
320        &self.data
321    }
322
323    /// Returns the entry at `(row, col)`.
324    ///
325    /// Panics if the indices are out of bounds.
326    pub fn get(&self, row: usize, col: usize) -> u64 {
327        self.data[row * self.cols + col]
328    }
329
330    /// Sets the entry at `(row, col)`, reducing `value` modulo the matrix field.
331    ///
332    /// Panics if the indices are out of bounds.
333    pub fn set(&mut self, row: usize, col: usize, value: u64) {
334        self.data[row * self.cols + col] = value % self.modulus;
335    }
336
337    /// Adds another matrix to this one in place.
338    ///
339    /// Returns [`FftError::MatrixShape`] if the shapes or moduli differ.
340    pub fn add_assign(&mut self, rhs: &Self) -> Result<(), FftError> {
341        if self.rows != rhs.rows || self.cols != rhs.cols || self.modulus != rhs.modulus {
342            return Err(FftError::MatrixShape);
343        }
344
345        let field = PrimeField {
346            modulus: self.modulus,
347        };
348        for (lhs, rhs) in self.data.iter_mut().zip(rhs.data.iter()) {
349            *lhs = field.add(*lhs, *rhs);
350        }
351
352        Ok(())
353    }
354
355    /// Adds `scalar * rhs` to this matrix in place.
356    ///
357    /// Returns [`FftError::MatrixShape`] if the shapes or moduli differ.
358    pub fn add_scaled_assign(&mut self, scalar: u64, rhs: &Self) -> Result<(), FftError> {
359        if self.rows != rhs.rows || self.cols != rhs.cols || self.modulus != rhs.modulus {
360            return Err(FftError::MatrixShape);
361        }
362
363        let field = PrimeField {
364            modulus: self.modulus,
365        };
366        let scalar = field.normalize(scalar);
367        for (lhs, rhs) in self.data.iter_mut().zip(rhs.data.iter()) {
368            *lhs = field.add(*lhs, field.mul(scalar, *rhs));
369        }
370
371        Ok(())
372    }
373
374    /// Multiplies two dense matrices over the same prime field.
375    ///
376    /// Returns [`FftError::MatrixShape`] if the inner dimensions or moduli do
377    /// not match.
378    pub fn mul(&self, rhs: &Self) -> Result<Self, FftError> {
379        if self.cols != rhs.rows || self.modulus != rhs.modulus {
380            return Err(FftError::MatrixShape);
381        }
382
383        let field = PrimeField {
384            modulus: self.modulus,
385        };
386        let mut out = Self::zero(self.rows, rhs.cols, field);
387
388        for row in 0..self.rows {
389            for mid in 0..self.cols {
390                let lhs = self.get(row, mid);
391                if lhs == 0 {
392                    continue;
393                }
394                for col in 0..rhs.cols {
395                    let idx = row * rhs.cols + col;
396                    out.data[idx] = field.add(out.data[idx], field.mul(lhs, rhs.get(mid, col)));
397                }
398            }
399        }
400
401        Ok(out)
402    }
403
404    /// Returns the inverse of a square matrix over its prime field.
405    ///
406    /// This uses Gauss-Jordan elimination with row swaps. It returns
407    /// [`FftError::MatrixShape`] for non-square matrices and
408    /// [`FftError::NonInvertibleMatrix`] when no nonzero pivot exists.
409    pub fn inverse(&self) -> Result<Self, FftError> {
410        if self.rows != self.cols {
411            return Err(FftError::MatrixShape);
412        }
413
414        let field = PrimeField {
415            modulus: self.modulus,
416        };
417        let size = self.rows;
418        let mut rows = vec![vec![field.zero(); 2 * size]; size];
419
420        for (row_index, row) in rows.iter_mut().enumerate() {
421            for (col, entry) in row.iter_mut().take(size).enumerate() {
422                *entry = self.get(row_index, col);
423            }
424            row[size + row_index] = field.one();
425        }
426
427        for col in 0..size {
428            let pivot_row = (col..size)
429                .find(|&row| rows[row][col] != field.zero())
430                .ok_or(FftError::NonInvertibleMatrix)?;
431            if pivot_row != col {
432                rows.swap(col, pivot_row);
433            }
434
435            let pivot_inverse = field
436                .inv(rows[col][col])
437                .ok_or(FftError::NonInvertibleMatrix)?;
438            for entry in &mut rows[col] {
439                *entry = field.mul(*entry, pivot_inverse);
440            }
441
442            let pivot = rows[col].clone();
443            for (row_index, row) in rows.iter_mut().enumerate() {
444                if row_index == col {
445                    continue;
446                }
447                let factor = row[col];
448                if factor == field.zero() {
449                    continue;
450                }
451                for (entry, pivot_entry) in row.iter_mut().zip(pivot.iter()) {
452                    *entry = field.sub(*entry, field.mul(factor, *pivot_entry));
453                }
454            }
455        }
456
457        let mut data = Vec::with_capacity(size * size);
458        for row in &rows {
459            data.extend_from_slice(&row[size..]);
460        }
461        Self::from_vec(size, size, field, data)
462    }
463
464    fn left_multiply_sparse_rows(&self, rows: &[Vec<(usize, u64)>]) -> Self {
465        debug_assert_eq!(self.rows, rows.len());
466
467        let field = PrimeField {
468            modulus: self.modulus,
469        };
470        let mut out = Self::zero(self.rows, self.cols, field);
471
472        for (row, terms) in rows.iter().enumerate() {
473            for &(src_row, coeff) in terms {
474                if coeff == 0 {
475                    continue;
476                }
477                for col in 0..self.cols {
478                    let idx = row * self.cols + col;
479                    out.data[idx] =
480                        field.add(out.data[idx], field.mul(coeff, self.get(src_row, col)));
481                }
482            }
483        }
484
485        out
486    }
487
488    fn submatrix(&self, start_row: usize, start_col: usize, rows: usize, cols: usize) -> Self {
489        let field = PrimeField {
490            modulus: self.modulus,
491        };
492        let mut out = Self::zero(rows, cols, field);
493
494        for row in 0..rows {
495            for col in 0..cols {
496                out.set(row, col, self.get(start_row + row, start_col + col));
497            }
498        }
499
500        out
501    }
502}
503
504/// An integer partition, used to index irreducible representations of `S_n`.
505///
506/// Parts are stored in nonincreasing order. For example, the partition
507/// `(3, 1, 1)` is represented by `Partition::new(vec![3, 1, 1])`.
508///
509/// The empty partition can occur internally when recursively removing boxes,
510/// but [`SymmetricFft`] itself only supports ranks `n >= 1`.
511#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
512pub struct Partition(Vec<usize>);
513
514impl Partition {
515    /// Constructs a partition from nonzero, nonincreasing parts.
516    ///
517    /// The empty vector is allowed and denotes the empty partition of `0`.
518    pub fn new(parts: Vec<usize>) -> Result<Self, FftError> {
519        if parts.contains(&0) {
520            return Err(FftError::MatrixShape);
521        }
522
523        for pair in parts.windows(2) {
524            if pair[0] < pair[1] {
525                return Err(FftError::MatrixShape);
526            }
527        }
528
529        Ok(Self(parts))
530    }
531
532    /// Returns the parts of the partition.
533    pub fn parts(&self) -> &[usize] {
534        &self.0
535    }
536
537    /// Returns the integer being partitioned.
538    pub fn n(&self) -> usize {
539        self.0.iter().sum()
540    }
541
542    /// Returns row indices whose final box can be removed.
543    ///
544    /// Row indices are zero-based. Removing one of these rows with
545    /// [`Partition::remove_box`] gives a valid partition of `n - 1`.
546    pub fn removable_rows(&self) -> Vec<usize> {
547        let mut rows = Vec::new();
548        for row in 0..self.0.len() {
549            let next = self.0.get(row + 1).copied().unwrap_or(0);
550            if self.0[row] > next {
551                rows.push(row);
552            }
553        }
554        rows
555    }
556
557    /// Removes the final box from a row.
558    ///
559    /// The caller should pass a zero-based row returned by
560    /// [`Partition::removable_rows`]. Passing another row may produce a
561    /// non-partition or panic.
562    pub fn remove_box(&self, row: usize) -> Self {
563        let mut parts = self.0.clone();
564        parts[row] -= 1;
565        if parts[row] == 0 {
566            parts.remove(row);
567        }
568        Self(parts)
569    }
570}
571
572impl fmt::Display for Partition {
573    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
574        write!(f, "(")?;
575        for (i, part) in self.0.iter().enumerate() {
576            if i > 0 {
577                write!(f, ",")?;
578            }
579            write!(f, "{part}")?;
580        }
581        write!(f, ")")
582    }
583}
584
585/// A standard Young tableau.
586///
587/// Entries are the one-based numbers `1..=n`. Row and column coordinates
588/// returned by [`Tableau::position`] are zero-based.
589#[derive(Clone, Debug, PartialEq, Eq)]
590pub struct Tableau {
591    shape: Partition,
592    rows: Vec<Vec<usize>>,
593    positions: Vec<(usize, usize)>,
594}
595
596impl Tableau {
597    /// Returns the tableau shape.
598    pub fn shape(&self) -> &Partition {
599        &self.shape
600    }
601
602    /// Returns tableau rows, with entries written left to right.
603    pub fn rows(&self) -> &[Vec<usize>] {
604        &self.rows
605    }
606
607    /// Returns the zero-based `(row, column)` containing `entry`.
608    ///
609    /// Panics if `entry` is not in `1..=n`.
610    pub fn position(&self, entry: usize) -> (usize, usize) {
611        assert!(entry > 0 && entry < self.positions.len());
612        self.positions[entry]
613    }
614
615    fn key(&self) -> Vec<usize> {
616        self.rows.iter().flatten().copied().collect()
617    }
618
619    fn swapped_key(&self, lhs: usize, rhs: usize) -> Vec<usize> {
620        self.rows
621            .iter()
622            .flatten()
623            .map(|entry| {
624                if *entry == lhs {
625                    rhs
626                } else if *entry == rhs {
627                    lhs
628                } else {
629                    *entry
630                }
631            })
632            .collect()
633    }
634
635    fn content(&self, entry: usize) -> i64 {
636        let (row, col) = self.position(entry);
637        col as i64 - row as i64
638    }
639}
640
/// A bijection of `{0, ..., n - 1}`, stored by its images.
///
/// Index `i` of the image vector holds `g(i)`. Letters are zero-based here,
/// even though the representation-theoretic formulas are usually stated with
/// one-based letters.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Permutation {
    images: Vec<usize>,
}

impl Permutation {
    /// Returns the identity permutation of `{0, ..., n - 1}`.
    pub fn identity(n: usize) -> Self {
        let images = (0..n).collect::<Vec<_>>();
        Self { images }
    }

    /// Returns the image vector, with `g(i)` at index `i`.
    pub fn images(&self) -> &[usize] {
        self.images.as_slice()
    }

    /// Returns `n`, the number of letters permuted.
    pub fn len(&self) -> usize {
        self.images.len()
    }

    /// Returns `true` exactly for the permutation of the empty set.
    pub fn is_empty(&self) -> bool {
        self.images.is_empty()
    }

    /// Returns the composition `self o rhs`, which sends `i` to `self(rhs(i))`.
    ///
    /// # Panics
    ///
    /// Panics if the two permutations act on different numbers of letters.
    pub fn compose(&self, rhs: &Self) -> Self {
        assert_eq!(self.len(), rhs.len());
        Self {
            images: rhs.images.iter().map(|&inner| self.images[inner]).collect(),
        }
    }

    /// Returns the adjacent transposition exchanging `index` and `index + 1`.
    ///
    /// The index is zero-based: `Permutation::adjacent(3, 0)` is the usual
    /// transposition `(1 2)` of `S_3`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1 >= n`.
    pub fn adjacent(n: usize, index: usize) -> Self {
        let mut permutation = Self::identity(n);
        permutation.images.swap(index, index + 1);
        permutation
    }

    /// Returns the cycle sending `n - 1` to `target` and each `i` in `target..n - 1` to `i + 1`.
    ///
    /// When `target >= n - 1` this is the identity.
    fn cycle_moving_last_to(n: usize, target: usize) -> Self {
        let mut permutation = Self::identity(n);
        let last = n - 1;
        if target < last {
            for (offset, image) in permutation.images[target..last].iter_mut().enumerate() {
                *image = target + offset + 1;
            }
            permutation.images[last] = target;
        }
        permutation
    }

    /// Extends the permutation to one more letter, which is left fixed.
    fn embed_fixing_last(&self) -> Self {
        let mut images = Vec::with_capacity(self.len() + 1);
        images.extend_from_slice(&self.images);
        images.push(self.len());
        Self { images }
    }

    /// Returns the zero-based adjacent-swap positions that sort the identity into `self`.
    ///
    /// Starting from `[0, 1, ..., n - 1]`, each target image is bubbled left
    /// into place. Each recorded `k` means "swap positions `k` and `k + 1`",
    /// in the order the swaps are performed.
    fn adjacent_word(&self) -> Vec<usize> {
        let mut current: Vec<usize> = (0..self.len()).collect();
        let mut word = Vec::new();

        for (pos, &target) in self.images.iter().enumerate() {
            let found = current
                .iter()
                .position(|value| *value == target)
                .expect("permutation image");
            if found > pos {
                // Bubbling `target` from `found` down to `pos` is one right rotation.
                current[pos..=found].rotate_right(1);
                word.extend((pos..found).rev());
            }
        }

        word
    }
}
735
736/// The Fourier transform of a function on `S_n`.
737///
738/// A transform is a block family indexed by partitions of `n`. The block for
739/// `lambda` is the matrix
740///
741/// ```text
742/// sum_{g in S_n} f(g) rho_lambda(g),
743/// ```
744///
745/// where `rho_lambda` is the Young seminormal irreducible representation.
746/// Rows and columns within each block follow the standard-tableau ordering
747/// returned by [`SymmetricFft::standard_tableaux`].
748#[derive(Clone, Debug)]
749pub struct FourierTransform {
750    n: usize,
751    modulus: u64,
752    blocks: BTreeMap<Partition, Matrix>,
753}
754
755impl FourierTransform {
756    /// Returns the rank `n` of the symmetric group.
757    pub fn n(&self) -> usize {
758        self.n
759    }
760
761    /// Returns the prime modulus of all block entries.
762    pub fn modulus(&self) -> u64 {
763        self.modulus
764    }
765
766    /// Returns all transform blocks keyed by partition.
767    pub fn blocks(&self) -> &BTreeMap<Partition, Matrix> {
768        &self.blocks
769    }
770
771    /// Returns the block for a partition of `n`, if present.
772    pub fn block(&self, partition: &Partition) -> Option<&Matrix> {
773        self.blocks.get(partition)
774    }
775}
776
777/// Precomputed plan for Fourier transforms on `S_n` over `F_p`.
778///
779/// Construction precomputes partitions, standard tableaux, permutations, and
780/// Young seminormal adjacent-transposition matrices for every rank up to `n`.
781/// The actual [`SymmetricFft::fft`] method can then be reused for many input
782/// vectors over the same group and field.
783///
784/// The current implementation is exact and dependency-free, but intentionally
785/// limited to prime fields and the semisimple case `p > n`.
786#[derive(Clone, Debug)]
787pub struct SymmetricFft {
788    n: usize,
789    field: PrimeField,
790    levels: Vec<Level>,
791}
792
793impl SymmetricFft {
794    /// Builds an FFT plan for `S_n` over the prime field `F_modulus`.
795    ///
796    /// This returns an error unless `n >= 1`, `modulus` is prime, and
797    /// `modulus > n`.
798    pub fn new(n: usize, modulus: u64) -> Result<Self, FftError> {
799        if n == 0 {
800            return Err(FftError::RankZero);
801        }
802
803        checked_factorial(n)?;
804        let field = PrimeField::for_symmetric_group(n, modulus)?;
805        let mut levels = Vec::with_capacity(n + 1);
806        levels.push(Level::empty(field));
807
808        for k in 1..=n {
809            let previous = levels.last().expect("previous level");
810            levels.push(Level::new(k, field, previous));
811        }
812
813        Ok(Self { n, field, levels })
814    }
815
816    /// Returns the rank `n`.
817    pub fn n(&self) -> usize {
818        self.n
819    }
820
821    /// Returns the prime field used by this plan.
822    pub fn field(&self) -> PrimeField {
823        self.field
824    }
825
826    /// Returns the expected number of input coefficients, equal to `n!`.
827    pub fn input_len(&self) -> usize {
828        self.levels[self.n].permutations.len()
829    }
830
831    /// Returns the input permutation ordering.
832    ///
833    /// The coefficient at index `i` in [`SymmetricFft::fft`] is interpreted as
834    /// the value of the input function on `self.permutations()[i]`.
835    pub fn permutations(&self) -> &[Permutation] {
836        &self.levels[self.n].permutations
837    }
838
839    /// Returns the partitions of `n` indexing the Fourier blocks.
840    pub fn partitions(&self) -> &[Partition] {
841        &self.levels[self.n].partitions
842    }
843
844    /// Returns the standard-tableau basis order for an irreducible block.
845    ///
846    /// The same order is used for rows and columns of the corresponding
847    /// transform block.
848    pub fn standard_tableaux(&self, partition: &Partition) -> Option<&[Tableau]> {
849        self.levels[self.n]
850            .irreps
851            .get(partition)
852            .map(|irrep| irrep.tableaux.as_slice())
853    }
854
855    /// Applies the fast Fourier transform.
856    ///
857    /// `values` must contain one coefficient for every group element in the
858    /// order returned by [`SymmetricFft::permutations`]. Values are normalized
859    /// into the field before the transform is evaluated.
860    pub fn fft(&self, values: &[u64]) -> Result<FourierTransform, FftError> {
861        self.validate_input(values)?;
862        let values: Vec<_> = values
863            .iter()
864            .map(|value| self.field.normalize(*value))
865            .collect();
866        let blocks = self.fft_level(self.n, &values)?;
867
868        Ok(FourierTransform {
869            n: self.n,
870            modulus: self.field.modulus(),
871            blocks,
872        })
873    }
874
875    /// Applies the inverse fast Fourier transform.
876    ///
877    /// The returned vector is ordered by [`SymmetricFft::permutations`], so if
878    /// `transform = self.fft(values)?`, then `self.ifft(&transform)? == values`
879    /// after reducing the input values modulo the field.
880    pub fn ifft(&self, transform: &FourierTransform) -> Result<Vec<u64>, FftError> {
881        self.validate_transform(transform)?;
882        self.ifft_level(self.n, transform.blocks())
883    }
884
885    /// Multiplies two elements of the group algebra `F_p[S_n]`.
886    ///
887    /// Inputs and output are ordered by [`SymmetricFft::permutations`]. The
888    /// multiplication convention is
889    ///
890    /// ```text
891    /// (v * w)_x = sum_{g h = x} v_g w_h.
892    /// ```
893    ///
894    /// This method uses the convolution theorem:
895    ///
896    /// ```text
897    /// v * w = FFT^{-1}(FFT(v) FFT(w)),
898    /// ```
899    ///
900    /// where the middle product is ordinary matrix multiplication inside each
901    /// Fourier block.
902    pub fn multiply(&self, lhs: &[u64], rhs: &[u64]) -> Result<Vec<u64>, FftError> {
903        self.validate_input(lhs)?;
904        self.validate_input(rhs)?;
905
906        let lhs_transform = self.fft(lhs)?;
907        let rhs_transform = self.fft(rhs)?;
908        let mut blocks = BTreeMap::new();
909
910        for partition in &self.levels[self.n].partitions {
911            let lhs_block = lhs_transform
912                .block(partition)
913                .ok_or(FftError::TransformShape)?;
914            let rhs_block = rhs_transform
915                .block(partition)
916                .ok_or(FftError::TransformShape)?;
917            blocks.insert(partition.clone(), lhs_block.mul(rhs_block)?);
918        }
919
920        self.ifft(&FourierTransform {
921            n: self.n,
922            modulus: self.field.modulus(),
923            blocks,
924        })
925    }
926
927    /// Inverts an element of the group algebra `F_p[S_n]` using the FFT.
928    ///
929    /// Inputs and output are ordered by [`SymmetricFft::permutations`]. The
930    /// method transforms the element, inverts each Young block, and applies the
931    /// inverse FFT. It returns [`FftError::NonInvertibleMatrix`] when any block
932    /// is singular, which is exactly the semisimple obstruction to being a unit.
933    pub fn invert(&self, values: &[u64]) -> Result<Vec<u64>, FftError> {
934        self.validate_input(values)?;
935        let transform = self.fft(values)?;
936        let inverse = self.invert_transform(&transform)?;
937        self.ifft(&inverse)
938    }
939
940    /// Inverts a Fourier transform block-by-block.
941    ///
942    /// The symmetric transform uses the unnormalized convention
943    /// `fhat(rho) = sum_g f(g) rho(g)`, so a group-algebra inverse is obtained
944    /// by ordinary matrix inversion inside every irreducible block.
945    pub fn invert_transform(
946        &self,
947        transform: &FourierTransform,
948    ) -> Result<FourierTransform, FftError> {
949        self.validate_transform(transform)?;
950        let mut blocks = BTreeMap::new();
951
952        for partition in &self.levels[self.n].partitions {
953            let block = transform.block(partition).ok_or(FftError::TransformShape)?;
954            blocks.insert(partition.clone(), block.inverse()?);
955        }
956
957        Ok(FourierTransform {
958            n: self.n,
959            modulus: self.field.modulus(),
960            blocks,
961        })
962    }
963
964    /// Multiplies two group-algebra elements by the direct convolution formula.
965    ///
966    /// This is `O((n!)^2)` and is mainly intended as a correctness oracle and
967    /// performance baseline for [`SymmetricFft::multiply`].
968    pub fn naive_multiply(&self, lhs: &[u64], rhs: &[u64]) -> Result<Vec<u64>, FftError> {
969        self.validate_input(lhs)?;
970        self.validate_input(rhs)?;
971
972        let level = &self.levels[self.n];
973        let mut out = vec![self.field.zero(); level.permutations.len()];
974
975        for (lhs_index, lhs_perm) in level.permutations.iter().enumerate() {
976            let lhs_value = self.field.normalize(lhs[lhs_index]);
977            if lhs_value == 0 {
978                continue;
979            }
980
981            for (rhs_index, rhs_perm) in level.permutations.iter().enumerate() {
982                let rhs_value = self.field.normalize(rhs[rhs_index]);
983                if rhs_value == 0 {
984                    continue;
985                }
986
987                let product = lhs_perm.compose(rhs_perm);
988                let product_index = *level
989                    .permutation_index
990                    .get(product.images())
991                    .expect("permutation product index");
992                out[product_index] = self
993                    .field
994                    .add(out[product_index], self.field.mul(lhs_value, rhs_value));
995            }
996        }
997
998        Ok(out)
999    }
1000
1001    /// Applies the direct definition of the Fourier transform.
1002    ///
1003    /// This is much slower than [`SymmetricFft::fft`] and is mainly intended
1004    /// as a correctness oracle for small `n`.
1005    pub fn naive_dft(&self, values: &[u64]) -> Result<FourierTransform, FftError> {
1006        self.validate_input(values)?;
1007        let level = &self.levels[self.n];
1008        let mut blocks = BTreeMap::new();
1009
1010        for partition in &level.partitions {
1011            let irrep = level.irreps.get(partition).expect("irrep data");
1012            let mut block = Matrix::zero(irrep.dimension(), irrep.dimension(), self.field);
1013
1014            for (perm_index, perm) in level.permutations.iter().enumerate() {
1015                let value = self.field.normalize(values[perm_index]);
1016                if value == 0 {
1017                    continue;
1018                }
1019                let rho = self.representation_matrix(partition, perm)?;
1020                block.add_scaled_assign(value, &rho)?;
1021            }
1022
1023            blocks.insert(partition.clone(), block);
1024        }
1025
1026        Ok(FourierTransform {
1027            n: self.n,
1028            modulus: self.field.modulus(),
1029            blocks,
1030        })
1031    }
1032
1033    /// Returns the Young seminormal matrix for an adjacent transposition.
1034    ///
1035    /// `adjacent_index` is zero-based: index `0` means the usual transposition
1036    /// `(1 2)`, index `1` means `(2 3)`, and so on. The returned matrix is for
1037    /// the irreducible representation indexed by `partition`.
1038    pub fn generator_matrix(&self, partition: &Partition, adjacent_index: usize) -> Option<Matrix> {
1039        let level = self.levels.get(partition.n())?;
1040        let irrep = level.irreps.get(partition)?;
1041        let rows = irrep.generator_rows.get(adjacent_index)?;
1042        Some(matrix_from_sparse_rows(rows, self.field))
1043    }
1044
1045    /// Evaluates an irreducible Young seminormal representation on a permutation.
1046    ///
1047    /// The partition determines the representation. The permutation size must
1048    /// match `partition.n()`.
1049    pub fn representation_matrix(
1050        &self,
1051        partition: &Partition,
1052        permutation: &Permutation,
1053    ) -> Result<Matrix, FftError> {
1054        if permutation.len() != partition.n() {
1055            return Err(FftError::MatrixShape);
1056        }
1057
1058        let level = self
1059            .levels
1060            .get(partition.n())
1061            .ok_or(FftError::MatrixShape)?;
1062        let irrep = level.irreps.get(partition).ok_or(FftError::MatrixShape)?;
1063        let mut matrix = Matrix::identity(irrep.dimension(), self.field);
1064
1065        for adjacent_index in permutation.adjacent_word() {
1066            let generator =
1067                matrix_from_sparse_rows(&irrep.generator_rows[adjacent_index], self.field);
1068            matrix = matrix.mul(&generator)?;
1069        }
1070
1071        Ok(matrix)
1072    }
1073
1074    fn validate_input(&self, values: &[u64]) -> Result<(), FftError> {
1075        let expected = self.input_len();
1076        if values.len() != expected {
1077            return Err(FftError::InputLength {
1078                expected,
1079                got: values.len(),
1080            });
1081        }
1082
1083        Ok(())
1084    }
1085
1086    fn validate_transform(&self, transform: &FourierTransform) -> Result<(), FftError> {
1087        if transform.n() != self.n || transform.modulus() != self.field.modulus() {
1088            return Err(FftError::TransformShape);
1089        }
1090
1091        let level = &self.levels[self.n];
1092        if transform.blocks().len() != level.partitions.len() {
1093            return Err(FftError::TransformShape);
1094        }
1095
1096        for partition in &level.partitions {
1097            let irrep = level.irreps.get(partition).expect("irrep data");
1098            let block = transform.block(partition).ok_or(FftError::TransformShape)?;
1099            if block.rows() != irrep.dimension()
1100                || block.cols() != irrep.dimension()
1101                || block.modulus() != self.field.modulus()
1102            {
1103                return Err(FftError::TransformShape);
1104            }
1105        }
1106
1107        Ok(())
1108    }
1109
    /// Recursively computes the Fourier blocks of `values` over `S_k`.
    ///
    /// `values` holds one coefficient per entry of `self.levels[k].permutations`,
    /// in that order. The result has one square block per partition of `k`,
    /// sized by that irreducible's dimension.
    ///
    /// The recursion runs in three steps:
    /// 1. For each `target` in `0..k`, gather the coefficients at
    ///    `cycle_moving_last_to(k, target).compose(&sigma.embed_fixing_last())`,
    ///    with `sigma` ranging over `S_{k-1}`. They form a function on
    ///    `S_{k-1}`, which is transformed recursively.
    /// 2. Embed each of those `k` transforms on the branching diagonal of every
    ///    `S_k` irreducible.
    /// 3. Left-multiply each embedded matrix by the sparse seminormal
    ///    generators, applied in the order `k - 2, k - 3, ..., target`, and sum
    ///    the `k` contributions into the block.
    ///
    /// NOTE(review): the coset-to-index lookups depend only on `k`. They are
    /// nevertheless recomputed on every recursive call (level `k - 1` is
    /// rebuilt `k` times), so caching them per level would avoid repeated
    /// `BTreeMap` lookups.
    fn fft_level(&self, k: usize, values: &[u64]) -> Result<BTreeMap<Partition, Matrix>, FftError> {
        let level = &self.levels[k];

        // Base case: S_1 has a single 1x1 block (partition (1)) holding the
        // only coefficient.
        if k == 1 {
            let partition = Partition(vec![1]);
            let matrix = Matrix::from_vec(1, 1, self.field, vec![values[0]])?;
            return Ok(BTreeMap::from([(partition, matrix)]));
        }

        let previous = &self.levels[k - 1];
        let mut sub_transforms = Vec::with_capacity(k);

        for target in 0..k {
            // Pull out this coset's coefficients, indexed like
            // `previous.permutations`, and transform them over S_{k-1}.
            let coset_rep = Permutation::cycle_moving_last_to(k, target);
            let mut sub_values = vec![self.field.zero(); previous.permutations.len()];

            for (sub_index, sub_perm) in previous.permutations.iter().enumerate() {
                let embedded = sub_perm.embed_fixing_last();
                let perm = coset_rep.compose(&embedded);
                let value_index = *level
                    .permutation_index
                    .get(perm.images())
                    .expect("coset permutation index");
                sub_values[sub_index] = values[value_index];
            }

            sub_transforms.push(self.fft_level(k - 1, &sub_values)?);
        }

        let mut out = BTreeMap::new();
        for partition in &level.partitions {
            let irrep = level.irreps.get(partition).expect("irrep data");
            let mut block = Matrix::zero(irrep.dimension(), irrep.dimension(), self.field);

            for (target, transform) in sub_transforms.iter().enumerate() {
                // Place the S_{k-1} blocks on this irreducible's branch diagonal.
                let mut embedded = embed_restricted_blocks(irrep, transform, self.field);

                // Account for the coset representative by left-multiplying with
                // generators k-2, k-3, ..., target (in that order of application).
                for adjacent_index in (target..k - 1).rev() {
                    embedded =
                        embedded.left_multiply_sparse_rows(&irrep.generator_rows[adjacent_index]);
                }

                block.add_assign(&embedded)?;
            }

            out.insert(partition.clone(), block);
        }

        Ok(out)
    }
1160
    /// Recursively inverts a Fourier transform over `S_k`.
    ///
    /// `blocks` holds one square block per partition of `k`; for `k == 1`, only
    /// the block for `Partition(vec![1])` is read. A missing block yields
    /// [`FftError::TransformShape`]. The returned coefficients are indexed like
    /// `self.levels[k].permutations`.
    ///
    /// This undoes `fft_level` one coset at a time. For each `target`:
    /// 1. Left-multiply every `S_k` block by the seminormal generators
    ///    `target, target + 1, ..., k - 2`, in that order of application. The
    ///    generators square to the identity (see the Coxeter-relation test), so
    ///    this reverses the product applied in the forward transform.
    /// 2. Scale each diagonal branch sub-block for `mu` by `d_lambda / (k * d_mu)`
    ///    and accumulate it into the `S_{k-1}` block for `mu`.
    /// 3. Invert the resulting `S_{k-1}` transform recursively and scatter it
    ///    back to the coefficients at
    ///    `cycle_moving_last_to(k, target).compose(&sigma.embed_fixing_last())`.
    ///
    /// NOTE(review): the scalar depends only on `(lambda, mu, k)`, not on
    /// `target`. It could be hoisted out of the `target` loop.
    fn ifft_level(
        &self,
        k: usize,
        blocks: &BTreeMap<Partition, Matrix>,
    ) -> Result<Vec<u64>, FftError> {
        let level = &self.levels[k];

        // Base case: the single 1x1 block of S_1 is the function value.
        if k == 1 {
            let block = blocks
                .get(&Partition(vec![1]))
                .ok_or(FftError::TransformShape)?;
            return Ok(vec![block.get(0, 0)]);
        }

        let previous = &self.levels[k - 1];
        let mut values = vec![self.field.zero(); level.permutations.len()];

        for target in 0..k {
            // Zero-initialized accumulator for this coset's S_{k-1} transform.
            let mut sub_blocks = previous
                .partitions
                .iter()
                .map(|partition| {
                    let irrep = previous.irreps.get(partition).expect("previous irrep");
                    (
                        partition.clone(),
                        Matrix::zero(irrep.dimension(), irrep.dimension(), self.field),
                    )
                })
                .collect::<BTreeMap<_, _>>();

            for partition in &level.partitions {
                let irrep = level.irreps.get(partition).expect("irrep data");
                let mut shifted = blocks
                    .get(partition)
                    .ok_or(FftError::TransformShape)?
                    .clone();

                // Undo the coset representative: generators applied in
                // increasing order target, ..., k-2.
                for adjacent_index in target..k - 1 {
                    shifted =
                        shifted.left_multiply_sparse_rows(&irrep.generator_rows[adjacent_index]);
                }

                // Fourier-inversion weight d_lambda / (k * d_mu). The
                // denominator is invertible when p > n: k <= n, and d_mu
                // divides (k-1)!.
                let numerator = self.field.normalize(irrep.dimension() as u64);
                for branch in &irrep.branches {
                    let denominator = self.field.mul(
                        self.field.normalize(k as u64),
                        self.field.normalize(branch.size as u64),
                    );
                    let scalar = self
                        .field
                        .mul(numerator, self.field.inv(denominator).expect("p > n"));
                    // Restrict to the diagonal branch block belonging to mu.
                    let projected =
                        shifted.submatrix(branch.start, branch.start, branch.size, branch.size);
                    sub_blocks
                        .get_mut(&branch.partition)
                        .expect("sub-block")
                        .add_scaled_assign(scalar, &projected)?;
                }
            }

            // Recover the coset's values and scatter them to their S_k positions.
            let sub_values = self.ifft_level(k - 1, &sub_blocks)?;
            let coset_rep = Permutation::cycle_moving_last_to(k, target);

            for (sub_index, sub_perm) in previous.permutations.iter().enumerate() {
                let embedded = sub_perm.embed_fixing_last();
                let perm = coset_rep.compose(&embedded);
                let value_index = *level
                    .permutation_index
                    .get(perm.images())
                    .expect("coset permutation index");
                values[value_index] = sub_values[sub_index];
            }
        }

        Ok(values)
    }
1237}
1238
/// Precomputed data for one rank `k` of the recursive transform.
///
/// `Level::new` fills every field for `S_k`; `Level::empty` leaves them all
/// empty.
#[derive(Clone, Debug)]
struct Level {
    /// Partitions of `k`, in the order returned by [`partitions`].
    partitions: Vec<Partition>,
    /// Seminormal representation data for each entry of `partitions`.
    irreps: BTreeMap<Partition, IrrepData>,
    /// All permutations of size `k`, in the order returned by [`all_permutations`].
    permutations: Vec<Permutation>,
    /// Maps a permutation's image vector to its position in `permutations`.
    permutation_index: BTreeMap<Vec<usize>, usize>,
}
1246
1247impl Level {
1248    fn empty(_field: PrimeField) -> Self {
1249        Self {
1250            partitions: Vec::new(),
1251            irreps: BTreeMap::new(),
1252            permutations: Vec::new(),
1253            permutation_index: BTreeMap::new(),
1254        }
1255    }
1256
1257    fn new(n: usize, field: PrimeField, previous: &Self) -> Self {
1258        let partitions = partitions(n);
1259        let permutations = all_permutations(n);
1260        let permutation_index = permutations
1261            .iter()
1262            .enumerate()
1263            .map(|(index, permutation)| (permutation.images.clone(), index))
1264            .collect();
1265        let mut irreps = BTreeMap::new();
1266
1267        for partition in &partitions {
1268            let irrep = IrrepData::new(partition.clone(), field, previous);
1269            irreps.insert(partition.clone(), irrep);
1270        }
1271
1272        Self {
1273            partitions,
1274            irreps,
1275            permutations,
1276            permutation_index,
1277        }
1278    }
1279}
1280
/// Young seminormal data for one irreducible representation of `S_k`.
#[derive(Clone, Debug)]
struct IrrepData {
    /// Standard tableaux of the shape, in last-letter order. They index the
    /// basis, so `tableaux.len()` is the representation's dimension.
    tableaux: Vec<Tableau>,
    /// Contiguous diagonal blocks, one per partition of `k - 1` reached by
    /// removing a box, in `removable_rows` order. Empty when `k <= 1`.
    branches: Vec<BranchBlock>,
    /// `generator_rows[i]` is the seminormal matrix of the adjacent
    /// transposition with zero-based index `i` (the usual `(i+1 i+2)`). Each
    /// row is stored sparsely as `(column, coefficient)` pairs.
    generator_rows: Vec<Vec<Vec<(usize, u64)>>>,
}
1287
1288impl IrrepData {
1289    fn new(partition: Partition, field: PrimeField, previous: &Level) -> Self {
1290        let tableaux = standard_tableaux(&partition);
1291        let tableau_index = tableaux
1292            .iter()
1293            .enumerate()
1294            .map(|(index, tableau)| (tableau.key(), index))
1295            .collect::<BTreeMap<_, _>>();
1296
1297        let mut branches = Vec::new();
1298        let mut start = 0;
1299        if partition.n() > 1 {
1300            for row in partition.removable_rows() {
1301                let subpartition = partition.remove_box(row);
1302                let size = previous
1303                    .irreps
1304                    .get(&subpartition)
1305                    .expect("branch irrep")
1306                    .dimension();
1307                branches.push(BranchBlock {
1308                    partition: subpartition,
1309                    start,
1310                    size,
1311                });
1312                start += size;
1313            }
1314        }
1315
1316        let generator_rows = (0..partition.n().saturating_sub(1))
1317            .map(|adjacent_index| {
1318                seminormal_generator_rows(&tableaux, &tableau_index, adjacent_index, field)
1319            })
1320            .collect();
1321
1322        Self {
1323            tableaux,
1324            branches,
1325            generator_rows,
1326        }
1327    }
1328
1329    fn dimension(&self) -> usize {
1330        self.tableaux.len()
1331    }
1332}
1333
/// A contiguous diagonal block of an irreducible's basis that corresponds to
/// one irreducible of the next-smaller symmetric group.
#[derive(Clone, Debug)]
struct BranchBlock {
    /// Partition obtained by removing one box from the parent shape.
    partition: Partition,
    /// First basis index (row and column) of the block.
    start: usize,
    /// Side length of the block: the dimension of `partition`'s irreducible.
    size: usize,
}
1340
1341/// Returns all integer partitions of `n` in reverse lexicographic order.
1342///
1343/// For example, `partitions(4)` starts with `(4)` and ends with `(1,1,1,1)`.
1344pub fn partitions(n: usize) -> Vec<Partition> {
1345    fn go(remaining: usize, max_part: usize, current: &mut Vec<usize>, out: &mut Vec<Partition>) {
1346        if remaining == 0 {
1347            out.push(Partition(current.clone()));
1348            return;
1349        }
1350
1351        for part in (1..=remaining.min(max_part)).rev() {
1352            current.push(part);
1353            go(remaining - part, part, current, out);
1354            current.pop();
1355        }
1356    }
1357
1358    let mut out = Vec::new();
1359    go(n, n, &mut Vec::new(), &mut out);
1360    out
1361}
1362
1363/// Returns all standard Young tableaux of a given shape.
1364///
1365/// The order is the last-letter order used by Young's seminormal
1366/// representations, so it is also the row and column basis order for Fourier
1367/// blocks.
1368pub fn standard_tableaux(partition: &Partition) -> Vec<Tableau> {
1369    fn go(shape: &Partition) -> Vec<Vec<Vec<usize>>> {
1370        if shape.n() == 0 {
1371            return vec![Vec::new()];
1372        }
1373
1374        let mut out = Vec::new();
1375        let entry = shape.n();
1376        for row in shape.removable_rows() {
1377            let subshape = shape.remove_box(row);
1378            for mut rows in go(&subshape) {
1379                if row == rows.len() {
1380                    rows.push(vec![entry]);
1381                } else {
1382                    rows[row].push(entry);
1383                }
1384                out.push(rows);
1385            }
1386        }
1387
1388        out
1389    }
1390
1391    go(partition)
1392        .into_iter()
1393        .map(|rows| tableau_from_rows(partition.clone(), rows))
1394        .collect()
1395}
1396
1397/// Returns all permutations of size `n` in lexicographic image order.
1398///
1399/// This is the same ordering used for transform input coefficients.
1400pub fn all_permutations(n: usize) -> Vec<Permutation> {
1401    let mut current: Vec<_> = (0..n).collect();
1402    let mut out = vec![Permutation {
1403        images: current.clone(),
1404    }];
1405
1406    while next_permutation(&mut current) {
1407        out.push(Permutation {
1408            images: current.clone(),
1409        });
1410    }
1411
1412    out
1413}
1414
/// Builds the sparse Young seminormal matrix of one adjacent transposition.
///
/// `adjacent_index` is zero-based, so the transposition swaps the entries
/// `lhs = adjacent_index + 1` and `rhs = adjacent_index + 2`. The result has
/// one row per tableau (basis vector). `rows[r]` lists the nonzero
/// `(column, coefficient)` pairs of row `r`.
///
/// Each tableau is handled by where `lhs` and `rhs` sit:
/// - Same row: the only entry is a diagonal `1`.
/// - Same column: the only entry is a diagonal `-1`.
/// - Neither: swapping the two entries gives another standard tableau. The
///   pair forms the 2x2 block `[[1/d, 1 - 1/d^2], [1, -1/d]]`. Its first row
///   and column belong to the lower-indexed tableau of the pair, and `d` is
///   `content(rhs) - content(lhs)` measured in that tableau. This block
///   squares to the identity.
///
/// # Panics
///
/// Panics if a swapped tableau is missing from `tableau_index`, or if `d`
/// has no inverse in the field.
fn seminormal_generator_rows(
    tableaux: &[Tableau],
    tableau_index: &BTreeMap<Vec<usize>, usize>,
    adjacent_index: usize,
    field: PrimeField,
) -> Vec<Vec<(usize, u64)>> {
    let dim = tableaux.len();
    // Zero-based generator index `i` swaps the entries `i + 1` and `i + 2`.
    let lhs = adjacent_index + 1;
    let rhs = adjacent_index + 2;
    let mut rows = vec![Vec::new(); dim];
    // Marks rows already written, so each swap pair is handled only once.
    let mut done = vec![false; dim];

    for index in 0..dim {
        if done[index] {
            continue;
        }

        let tableau = &tableaux[index];
        let lhs_pos = tableau.position(lhs);
        let rhs_pos = tableau.position(rhs);

        if lhs_pos.0 == rhs_pos.0 {
            // Same row: diagonal entry 1 only.
            rows[index] = vec![(index, field.one())];
            done[index] = true;
        } else if lhs_pos.1 == rhs_pos.1 {
            // Same column: diagonal entry -1 only.
            rows[index] = vec![(index, field.neg(field.one()))];
            done[index] = true;
        } else {
            // Different row and column: rows `first` and `second` only
            // reference each other, forming a 2x2 block.
            let swapped_key = tableau.swapped_key(lhs, rhs);
            let pair_index = *tableau_index
                .get(&swapped_key)
                .expect("standard tableau adjacent swap");
            let first = index.min(pair_index);
            let second = index.max(pair_index);
            let first_tableau = &tableaux[first];
            // Axial distance measured in the lower-indexed tableau of the pair.
            let distance = first_tableau.content(rhs) - first_tableau.content(lhs);
            let distance = field.from_i64(distance);
            let inv_distance = field
                .inv(distance)
                .expect("p > n keeps axial distance invertible");
            let inv_distance_squared = field.mul(inv_distance, inv_distance);

            // Block [[1/d, 1 - 1/d^2], [1, -1/d]] on (first, second).
            rows[first] = vec![
                (first, inv_distance),
                (second, field.sub(field.one(), inv_distance_squared)),
            ];
            rows[second] = vec![(first, field.one()), (second, field.neg(inv_distance))];
            done[first] = true;
            done[second] = true;
        }
    }

    rows
}
1469
1470fn embed_restricted_blocks(
1471    irrep: &IrrepData,
1472    transform: &BTreeMap<Partition, Matrix>,
1473    field: PrimeField,
1474) -> Matrix {
1475    let mut embedded = Matrix::zero(irrep.dimension(), irrep.dimension(), field);
1476
1477    for branch in &irrep.branches {
1478        let block = transform
1479            .get(&branch.partition)
1480            .expect("restricted transform block");
1481        debug_assert_eq!(block.rows(), branch.size);
1482        debug_assert_eq!(block.cols(), branch.size);
1483
1484        for row in 0..branch.size {
1485            for col in 0..branch.size {
1486                embedded.set(branch.start + row, branch.start + col, block.get(row, col));
1487            }
1488        }
1489    }
1490
1491    embedded
1492}
1493
1494fn matrix_from_sparse_rows(rows: &[Vec<(usize, u64)>], field: PrimeField) -> Matrix {
1495    let mut matrix = Matrix::zero(rows.len(), rows.len(), field);
1496    for (row, terms) in rows.iter().enumerate() {
1497        for &(col, coeff) in terms {
1498            matrix.set(row, col, coeff);
1499        }
1500    }
1501    matrix
1502}
1503
1504fn tableau_from_rows(shape: Partition, rows: Vec<Vec<usize>>) -> Tableau {
1505    let n = shape.n();
1506    let mut positions = vec![(usize::MAX, usize::MAX); n + 1];
1507    for (row, entries) in rows.iter().enumerate() {
1508        for (col, entry) in entries.iter().enumerate() {
1509            positions[*entry] = (row, col);
1510        }
1511    }
1512
1513    Tableau {
1514        shape,
1515        rows,
1516        positions,
1517    }
1518}
1519
1520fn checked_factorial(n: usize) -> Result<usize, FftError> {
1521    let mut out = 1usize;
1522    for value in 2..=n {
1523        out = out
1524            .checked_mul(value)
1525            .ok_or(FftError::FactorialOverflow(n))?;
1526    }
1527    Ok(out)
1528}
1529
/// Advances `values` to the next permutation in lexicographic order.
///
/// Returns `true` on success. If `values` is already the last permutation
/// (non-increasing), it is reset to the first (ascending) and `false` is
/// returned. Slices shorter than two elements are left unchanged and yield
/// `false`.
fn next_permutation(values: &mut [usize]) -> bool {
    // Rightmost position `i` with values[i] < values[i + 1].
    let ascent = values.windows(2).rposition(|pair| pair[0] < pair[1]);
    let pivot = match ascent {
        Some(index) => index,
        None => {
            // Non-increasing, or fewer than two elements: wrap around.
            values.reverse();
            return false;
        }
    };

    // Rightmost element after the pivot that is strictly larger than it. One
    // always exists because values[pivot + 1] > values[pivot].
    let pivot_value = values[pivot];
    let successor = pivot
        + 1
        + values[pivot + 1..]
            .iter()
            .rposition(|&candidate| candidate > pivot_value)
            .unwrap_or(0);

    values.swap(pivot, successor);
    values[pivot + 1..].reverse();
    true
}
1553
/// Deterministic trial-division primality test.
///
/// Even numbers other than 2 are rejected immediately. Odd candidates are
/// tested against odd divisors `d` while `d <= value / d`; the division form
/// avoids overflowing `d * d`.
fn is_prime(value: u64) -> bool {
    match value {
        0 | 1 => false,
        2 => true,
        _ if value % 2 == 0 => false,
        _ => (1u64..)
            .map(|k| 2 * k + 1)
            .take_while(|&divisor| divisor <= value / divisor)
            .all(|divisor| value % divisor != 0),
    }
}
1575
1576#[cfg(test)]
1577mod tests {
1578    use super::*;
1579
1580    #[test]
1581    fn field_arithmetic_works() {
1582        let field = PrimeField::new(17).unwrap();
1583        assert_eq!(field.add(16, 3), 2);
1584        assert_eq!(field.sub(2, 5), 14);
1585        assert_eq!(field.mul(6, 8), 14);
1586        assert_eq!(field.from_i64(-3), 14);
1587        assert_eq!(field.mul(5, field.inv(5).unwrap()), 1);
1588    }
1589
1590    #[test]
1591    fn matrix_inverse_multiplies_to_identity() {
1592        let field = PrimeField::new(101).unwrap();
1593        let matrix = Matrix::from_vec(2, 2, field, vec![1, 2, 3, 5]).unwrap();
1594        let inverse = matrix.inverse().unwrap();
1595        let identity = Matrix::identity(2, field);
1596
1597        assert_eq!(matrix.mul(&inverse).unwrap(), identity);
1598        assert_eq!(inverse.mul(&matrix).unwrap(), identity);
1599    }
1600
1601    #[test]
1602    fn matrix_inverse_rejects_singular_matrix() {
1603        let field = PrimeField::new(101).unwrap();
1604        let matrix = Matrix::from_vec(2, 2, field, vec![1, 2, 2, 4]).unwrap();
1605
1606        assert_eq!(matrix.inverse(), Err(FftError::NonInvertibleMatrix));
1607    }
1608
1609    #[test]
1610    fn rejects_bad_characteristics() {
1611        assert!(matches!(
1612            SymmetricFft::new(5, 5),
1613            Err(FftError::CharacteristicTooSmall { .. })
1614        ));
1615        assert!(matches!(
1616            SymmetricFft::new(5, 9),
1617            Err(FftError::CompositeModulus(9))
1618        ));
1619    }
1620
1621    #[test]
1622    fn partition_counts_are_correct_for_small_n() {
1623        let counts: Vec<_> = (1..=7).map(|n| partitions(n).len()).collect();
1624        assert_eq!(counts, vec![1, 2, 3, 5, 7, 11, 15]);
1625    }
1626
1627    #[test]
1628    fn tableaux_are_in_last_letter_order() {
1629        let shape = Partition::new(vec![2, 1]).unwrap();
1630        let tableaux = standard_tableaux(&shape);
1631        assert_eq!(tableaux.len(), 2);
1632        assert_eq!(tableaux[0].rows(), &[vec![1, 3], vec![2]]);
1633        assert_eq!(tableaux[1].rows(), &[vec![1, 2], vec![3]]);
1634    }
1635
1636    #[test]
1637    fn seminormal_generators_satisfy_coxeter_relations() {
1638        let plan = SymmetricFft::new(5, 101).unwrap();
1639
1640        for n in 2..=5 {
1641            for partition in &plan.levels[n].partitions {
1642                let irrep = plan.levels[n].irreps.get(partition).unwrap();
1643                let identity = Matrix::identity(irrep.dimension(), plan.field);
1644
1645                for i in 0..n - 1 {
1646                    let generator = plan.generator_matrix(partition, i).unwrap();
1647                    assert_eq!(generator.mul(&generator).unwrap(), identity);
1648                }
1649
1650                for i in 0..n.saturating_sub(2) {
1651                    let left = plan.generator_matrix(partition, i).unwrap();
1652                    let right = plan.generator_matrix(partition, i + 1).unwrap();
1653                    let lhs = left.mul(&right).unwrap().mul(&left).unwrap();
1654                    let rhs = right.mul(&left).unwrap().mul(&right).unwrap();
1655                    assert_eq!(lhs, rhs);
1656                }
1657
1658                for i in 0..n - 1 {
1659                    for j in i + 2..n - 1 {
1660                        let left = plan.generator_matrix(partition, i).unwrap();
1661                        let right = plan.generator_matrix(partition, j).unwrap();
1662                        assert_eq!(left.mul(&right).unwrap(), right.mul(&left).unwrap());
1663                    }
1664                }
1665            }
1666        }
1667    }
1668
1669    #[test]
1670    fn fft_matches_naive_dft_for_small_ranks() {
1671        for n in 1..=5 {
1672            let plan = SymmetricFft::new(n, 101).unwrap();
1673            let values: Vec<_> = (0..plan.input_len())
1674                .map(|i| ((i * i + 3 * i + 7) % 101) as u64)
1675                .collect();
1676
1677            let fast = plan.fft(&values).unwrap();
1678            let naive = plan.naive_dft(&values).unwrap();
1679            assert_eq!(fast.blocks(), naive.blocks());
1680        }
1681    }
1682
1683    #[test]
1684    fn inverse_fft_recovers_input_values() {
1685        for n in 1..=6 {
1686            let plan = SymmetricFft::new(n, 101).unwrap();
1687            let values: Vec<_> = (0..plan.input_len())
1688                .map(|i| ((7 * i * i + 11 * i + 103) % 211) as u64)
1689                .collect();
1690            let expected: Vec<_> = values
1691                .iter()
1692                .map(|value| plan.field.normalize(*value))
1693                .collect();
1694
1695            let transform = plan.fft(&values).unwrap();
1696            let recovered = plan.ifft(&transform).unwrap();
1697
1698            assert_eq!(recovered, expected, "failed roundtrip for S_{n}");
1699        }
1700    }
1701
1702    #[test]
1703    fn inverse_fft_is_two_sided_on_transform_image() {
1704        for n in 1..=5 {
1705            let plan = SymmetricFft::new(n, 101).unwrap();
1706            let values: Vec<_> = (0..plan.input_len())
1707                .map(|i| ((13 * i * i + 5 * i + 19) % 101) as u64)
1708                .collect();
1709
1710            let transform = plan.fft(&values).unwrap();
1711            let recovered = plan.ifft(&transform).unwrap();
1712            let transform_again = plan.fft(&recovered).unwrap();
1713
1714            assert_eq!(
1715                transform_again.blocks(),
1716                transform.blocks(),
1717                "failed transform roundtrip for S_{n}"
1718            );
1719        }
1720    }
1721
1722    #[test]
1723    fn inverse_fft_rejects_malformed_transforms() {
1724        let plan = SymmetricFft::new(3, 101).unwrap();
1725        let mut transform = plan.fft(&vec![1; plan.input_len()]).unwrap();
1726        transform.blocks.remove(&Partition::new(vec![3]).unwrap());
1727
1728        assert_eq!(plan.ifft(&transform), Err(FftError::TransformShape));
1729    }
1730
1731    #[test]
1732    fn group_algebra_multiply_matches_naive_convolution() {
1733        for n in 1..=5 {
1734            let plan = SymmetricFft::new(n, 101).unwrap();
1735            let lhs: Vec<_> = (0..plan.input_len())
1736                .map(|i| ((3 * i * i + 7 * i + 11) % 101) as u64)
1737                .collect();
1738            let rhs: Vec<_> = (0..plan.input_len())
1739                .map(|i| ((5 * i * i + 13 * i + 17) % 101) as u64)
1740                .collect();
1741
1742            let fast = plan.multiply(&lhs, &rhs).unwrap();
1743            let naive = plan.naive_multiply(&lhs, &rhs).unwrap();
1744
1745            assert_eq!(fast, naive, "failed multiplication for S_{n}");
1746        }
1747    }
1748
1749    #[test]
1750    fn group_algebra_invert_inverts_group_basis_units() {
1751        for n in 2..=5 {
1752            let plan = SymmetricFft::new(n, 101).unwrap();
1753            let mut values = vec![0; plan.input_len()];
1754            let unit_index = (n + 1).min(plan.input_len() - 1);
1755            values[unit_index] = 7;
1756
1757            let inverse = plan.invert(&values).unwrap();
1758            let mut identity = vec![0; plan.input_len()];
1759            let identity_images = Permutation::identity(n).images().to_vec();
1760            let identity_index = plan
1761                .permutations()
1762                .iter()
1763                .position(|permutation| permutation.images() == identity_images.as_slice())
1764                .unwrap();
1765            identity[identity_index] = 1;
1766
1767            assert_eq!(
1768                plan.multiply(&values, &inverse).unwrap(),
1769                identity,
1770                "failed right inverse for S_{n}"
1771            );
1772            assert_eq!(
1773                plan.multiply(&inverse, &values).unwrap(),
1774                identity,
1775                "failed left inverse for S_{n}"
1776            );
1777        }
1778    }
1779
1780    #[test]
1781    fn group_algebra_invert_rejects_zero_element() {
1782        let plan = SymmetricFft::new(4, 101).unwrap();
1783        let zero = vec![0; plan.input_len()];
1784
1785        assert_eq!(plan.invert(&zero), Err(FftError::NonInvertibleMatrix));
1786    }
1787
1788    #[test]
1789    fn multiplication_transform_matches_block_products() {
1790        for n in 1..=5 {
1791            let plan = SymmetricFft::new(n, 101).unwrap();
1792            let lhs: Vec<_> = (0..plan.input_len())
1793                .map(|i| ((i * i + 2 * i + 3) % 101) as u64)
1794                .collect();
1795            let rhs: Vec<_> = (0..plan.input_len())
1796                .map(|i| ((7 * i * i + 5 * i + 1) % 101) as u64)
1797                .collect();
1798
1799            let product = plan.multiply(&lhs, &rhs).unwrap();
1800            let product_transform = plan.fft(&product).unwrap();
1801            let lhs_transform = plan.fft(&lhs).unwrap();
1802            let rhs_transform = plan.fft(&rhs).unwrap();
1803
1804            for partition in plan.partitions() {
1805                let expected = lhs_transform
1806                    .block(partition)
1807                    .unwrap()
1808                    .mul(rhs_transform.block(partition).unwrap())
1809                    .unwrap();
1810                assert_eq!(
1811                    product_transform.block(partition).unwrap(),
1812                    &expected,
1813                    "failed block product for {partition}"
1814                );
1815            }
1816        }
1817    }
1818
1819    #[test]
1820    #[ignore = "timing-dependent; run with `cargo test --release multiplication_is_faster_than_naive -- --ignored`"]
1821    fn multiplication_is_faster_than_naive() {
1822        let plan = SymmetricFft::new(7, 1_000_003).unwrap();
1823        let lhs: Vec<_> = (0..plan.input_len())
1824            .map(|i| ((3 * i * i + 7 * i + 11) as u64) % plan.field.modulus())
1825            .collect();
1826        let rhs: Vec<_> = (0..plan.input_len())
1827            .map(|i| ((5 * i * i + 13 * i + 17) as u64) % plan.field.modulus())
1828            .collect();
1829
1830        let start = Instant::now();
1831        let fast = plan.multiply(&lhs, &rhs).unwrap();
1832        let fast_elapsed = start.elapsed();
1833
1834        let start = Instant::now();
1835        let naive = plan.naive_multiply(&lhs, &rhs).unwrap();
1836        let naive_elapsed = start.elapsed();
1837
1838        assert_eq!(fast, naive);
1839        assert!(
1840            fast_elapsed < naive_elapsed,
1841            "FFT multiplication took {fast_elapsed:?}, naive multiplication took {naive_elapsed:?}"
1842        );
1843    }
1844
1845    #[test]
1846    fn permutation_words_match_composition_convention() {
1847        for n in 1..=5 {
1848            for permutation in all_permutations(n) {
1849                let mut rebuilt = Permutation::identity(n);
1850                for adjacent_index in permutation.adjacent_word() {
1851                    rebuilt = rebuilt.compose(&Permutation::adjacent(n, adjacent_index));
1852                }
1853                assert_eq!(rebuilt, permutation);
1854            }
1855        }
1856    }
1857}