#![warn(missing_docs)]
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;
#[cfg(test)]
use std::time::Instant;
/// Errors reported by the symmetric-group FFT and its supporting field,
/// matrix and partition types.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FftError {
    /// The requested modulus is not a prime number.
    CompositeModulus(u64),
    /// The field characteristic does not exceed the rank of the group.
    CharacteristicTooSmall {
        /// The rejected modulus `p`.
        modulus: u64,
        /// The rank `n` of the symmetric group.
        n: usize,
    },
    /// `n!` overflows `usize`.
    FactorialOverflow(usize),
    /// A coefficient vector had the wrong length.
    InputLength {
        /// Required number of coefficients.
        expected: usize,
        /// Number of coefficients supplied.
        got: usize,
    },
    /// Matrix dimensions are incompatible, or a partition/permutation argument
    /// is malformed or unknown.
    MatrixShape,
    /// A Fourier transform has the wrong rank, modulus, partition set or block sizes.
    TransformShape,
    /// A matrix block is singular over the field.
    NonInvertibleMatrix,
    /// The rank `n` was zero.
    RankZero,
}

impl fmt::Display for FftError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload is `Copy`, so matching on `*self` binds plain values.
        match *self {
            Self::CompositeModulus(modulus) => write!(f, "{modulus} is not a prime modulus"),
            Self::CharacteristicTooSmall { modulus, n } => write!(
                f,
                "expected characteristic p > n, got p = {modulus}, n = {n}"
            ),
            Self::FactorialOverflow(rank) => write!(f, "{rank}! does not fit in usize"),
            Self::InputLength { expected, got } => {
                write!(f, "expected {expected} input coefficients, got {got}")
            }
            Self::MatrixShape => f.write_str("matrix shape mismatch"),
            Self::TransformShape => f.write_str("Fourier transform block shape mismatch"),
            Self::NonInvertibleMatrix => f.write_str("matrix block is not invertible"),
            Self::RankZero => f.write_str("rank n must be at least 1"),
        }
    }
}

impl Error for FftError {}
/// The prime field `F_p = Z/pZ`; elements are `u64` values, canonically in `0..p`.
///
/// Every arithmetic method accepts arbitrary `u64` operands and returns a
/// canonical representative in `0..p`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PrimeField {
    modulus: u64,
}

impl PrimeField {
    /// Creates `F_modulus`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::CompositeModulus`] if `modulus` is not prime
    /// (this includes 0 and 1).
    pub fn new(modulus: u64) -> Result<Self, FftError> {
        if !is_prime(modulus) {
            return Err(FftError::CompositeModulus(modulus));
        }
        Ok(Self { modulus })
    }

    /// Creates `F_modulus` for use with `S_n`, which additionally requires `p > n`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::CompositeModulus`] if `modulus` is not prime, or
    /// [`FftError::CharacteristicTooSmall`] if `modulus <= n`.
    pub fn for_symmetric_group(n: usize, modulus: u64) -> Result<Self, FftError> {
        let field = Self::new(modulus)?;
        if modulus <= n as u64 {
            return Err(FftError::CharacteristicTooSmall { modulus, n });
        }
        Ok(field)
    }

    /// The characteristic `p`.
    pub fn modulus(self) -> u64 {
        self.modulus
    }

    /// The additive identity.
    pub fn zero(self) -> u64 {
        0
    }

    /// The multiplicative identity.
    pub fn one(self) -> u64 {
        1 % self.modulus
    }

    /// Reduces `value` into `0..p`.
    pub fn normalize(self, value: u64) -> u64 {
        value % self.modulus
    }

    /// Maps a signed integer to its residue in `0..p`.
    pub fn from_i64(self, value: i64) -> u64 {
        let modulus = self.modulus as i128;
        let mut value = (value as i128) % modulus;
        if value < 0 {
            value += modulus;
        }
        value as u64
    }

    /// `lhs + rhs (mod p)`; the sum is taken in `u128`, so it cannot overflow.
    pub fn add(self, lhs: u64, rhs: u64) -> u64 {
        ((lhs as u128 + rhs as u128) % self.modulus as u128) as u64
    }

    /// `lhs - rhs (mod p)`.
    pub fn sub(self, lhs: u64, rhs: u64) -> u64 {
        let modulus = self.modulus as u128;
        // Reduce `rhs` first so `lhs + p - rhs` can never underflow for
        // operands that are not already in `0..p`.
        let rhs = rhs as u128 % modulus;
        ((lhs as u128 + modulus - rhs) % modulus) as u64
    }

    /// `-value (mod p)`.
    pub fn neg(self, value: u64) -> u64 {
        // Reduce first: `p - value` would underflow for `value > p`.
        let value = self.normalize(value);
        if value == 0 {
            0
        } else {
            self.modulus - value
        }
    }

    /// `lhs * rhs (mod p)`, computed in `u128` to avoid overflow.
    pub fn mul(self, lhs: u64, rhs: u64) -> u64 {
        ((lhs as u128 * rhs as u128) % self.modulus as u128) as u64
    }

    /// `base^exp (mod p)` by square-and-multiply.
    pub fn pow(self, mut base: u64, mut exp: u64) -> u64 {
        let mut acc = self.one();
        base = self.normalize(base);
        while exp > 0 {
            if exp & 1 == 1 {
                acc = self.mul(acc, base);
            }
            base = self.mul(base, base);
            exp >>= 1;
        }
        acc
    }

    /// Multiplicative inverse `value^(p-2)` (Fermat), or `None` when
    /// `value ≡ 0 (mod p)`.
    pub fn inv(self, value: u64) -> Option<u64> {
        // Test the residue, not the raw value: a nonzero multiple of `p` is
        // zero in the field and has no inverse.
        let value = self.normalize(value);
        if value == 0 {
            None
        } else {
            Some(self.pow(value, self.modulus - 2))
        }
    }
}
/// Dense row-major matrix with entries in a prime field `F_p`.
///
/// Entries are stored reduced modulo [`Matrix::modulus`]. Binary operations
/// require both operands to use the same modulus.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Matrix {
    rows: usize,
    cols: usize,
    modulus: u64,
    data: Vec<u64>,
}

impl Matrix {
    /// Returns the `rows x cols` zero matrix over `field`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn zero(rows: usize, cols: usize, field: PrimeField) -> Self {
        let len = rows
            .checked_mul(cols)
            .expect("matrix dimensions overflow usize");
        Self {
            rows,
            cols,
            modulus: field.modulus(),
            data: vec![0; len],
        }
    }

    /// Returns the `size x size` identity matrix over `field`.
    pub fn identity(size: usize, field: PrimeField) -> Self {
        let mut matrix = Self::zero(size, size, field);
        for i in 0..size {
            matrix.set(i, i, field.one());
        }
        matrix
    }

    /// Builds a matrix from row-major `data`, reducing every entry modulo `p`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::MatrixShape`] when `data.len() != rows * cols`,
    /// including when `rows * cols` overflows `usize`.
    pub fn from_vec(
        rows: usize,
        cols: usize,
        field: PrimeField,
        data: Vec<u64>,
    ) -> Result<Self, FftError> {
        // `checked_mul` stops an overflowing shape from wrapping around to a
        // length that happens to equal `data.len()`.
        if rows.checked_mul(cols) != Some(data.len()) {
            return Err(FftError::MatrixShape);
        }
        Ok(Self {
            rows,
            cols,
            modulus: field.modulus(),
            data: data
                .into_iter()
                .map(|value| field.normalize(value))
                .collect(),
        })
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// The field characteristic `p` of the entries.
    pub fn modulus(&self) -> u64 {
        self.modulus
    }

    /// All entries in row-major order.
    pub fn data(&self) -> &[u64] {
        &self.data
    }

    /// Entry at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows()` or `col >= cols()`.
    pub fn get(&self, row: usize, col: usize) -> u64 {
        self.data[self.offset(row, col)]
    }

    /// Stores `value mod p` at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows()` or `col >= cols()`.
    pub fn set(&mut self, row: usize, col: usize, value: u64) {
        let offset = self.offset(row, col);
        self.data[offset] = value % self.modulus;
    }

    /// Entrywise `self += rhs`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::MatrixShape`] if the dimensions or moduli differ.
    pub fn add_assign(&mut self, rhs: &Self) -> Result<(), FftError> {
        self.check_compatible(rhs)?;
        let field = self.field();
        for (lhs, rhs) in self.data.iter_mut().zip(rhs.data.iter()) {
            *lhs = field.add(*lhs, *rhs);
        }
        Ok(())
    }

    /// Entrywise `self += scalar * rhs`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::MatrixShape`] if the dimensions or moduli differ.
    pub fn add_scaled_assign(&mut self, scalar: u64, rhs: &Self) -> Result<(), FftError> {
        self.check_compatible(rhs)?;
        let field = self.field();
        let scalar = field.normalize(scalar);
        for (lhs, rhs) in self.data.iter_mut().zip(rhs.data.iter()) {
            *lhs = field.add(*lhs, field.mul(scalar, *rhs));
        }
        Ok(())
    }

    /// Matrix product `self * rhs`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::MatrixShape`] if `self.cols() != rhs.rows()` or
    /// the moduli differ.
    pub fn mul(&self, rhs: &Self) -> Result<Self, FftError> {
        if self.cols != rhs.rows || self.modulus != rhs.modulus {
            return Err(FftError::MatrixShape);
        }
        let field = self.field();
        let mut out = Self::zero(self.rows, rhs.cols, field);
        for row in 0..self.rows {
            for mid in 0..self.cols {
                let lhs = self.get(row, mid);
                // Skipping zero entries pays off for sparse operands such as
                // the generator matrices.
                if lhs == 0 {
                    continue;
                }
                for col in 0..rhs.cols {
                    let idx = row * rhs.cols + col;
                    out.data[idx] = field.add(out.data[idx], field.mul(lhs, rhs.get(mid, col)));
                }
            }
        }
        Ok(out)
    }

    /// Inverse matrix, computed by Gauss-Jordan elimination on `[A | I]`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::MatrixShape`] for a non-square matrix and
    /// [`FftError::NonInvertibleMatrix`] if the matrix is singular.
    pub fn inverse(&self) -> Result<Self, FftError> {
        if self.rows != self.cols {
            return Err(FftError::MatrixShape);
        }
        let field = self.field();
        let size = self.rows;
        // Augmented rows: the left half holds A, the right half starts as I.
        let mut rows = vec![vec![field.zero(); 2 * size]; size];
        for (row_index, row) in rows.iter_mut().enumerate() {
            for (col, entry) in row.iter_mut().take(size).enumerate() {
                *entry = self.get(row_index, col);
            }
            row[size + row_index] = field.one();
        }
        for col in 0..size {
            let pivot_row = (col..size)
                .find(|&row| rows[row][col] != field.zero())
                .ok_or(FftError::NonInvertibleMatrix)?;
            if pivot_row != col {
                rows.swap(col, pivot_row);
            }
            let pivot_inverse = field
                .inv(rows[col][col])
                .ok_or(FftError::NonInvertibleMatrix)?;
            for entry in &mut rows[col] {
                *entry = field.mul(*entry, pivot_inverse);
            }
            // Clone the normalized pivot row so the other rows can be mutated.
            let pivot = rows[col].clone();
            for (row_index, row) in rows.iter_mut().enumerate() {
                if row_index == col {
                    continue;
                }
                let factor = row[col];
                if factor == field.zero() {
                    continue;
                }
                for (entry, pivot_entry) in row.iter_mut().zip(pivot.iter()) {
                    *entry = field.sub(*entry, field.mul(factor, *pivot_entry));
                }
            }
        }
        let mut data = Vec::with_capacity(size * size);
        for row in &rows {
            data.extend_from_slice(&row[size..]);
        }
        Self::from_vec(size, size, field, data)
    }

    /// Returns `S * self`, where row `r` of `S` is given sparsely by
    /// `rows[r]` as `(column, coefficient)` pairs.
    fn left_multiply_sparse_rows(&self, rows: &[Vec<(usize, u64)>]) -> Self {
        debug_assert_eq!(self.rows, rows.len());
        let field = self.field();
        let mut out = Self::zero(self.rows, self.cols, field);
        for (row, terms) in rows.iter().enumerate() {
            for &(src_row, coeff) in terms {
                if coeff == 0 {
                    continue;
                }
                for col in 0..self.cols {
                    let idx = row * self.cols + col;
                    out.data[idx] =
                        field.add(out.data[idx], field.mul(coeff, self.get(src_row, col)));
                }
            }
        }
        out
    }

    /// Copies the `rows x cols` window whose top-left corner is
    /// `(start_row, start_col)`.
    fn submatrix(&self, start_row: usize, start_col: usize, rows: usize, cols: usize) -> Self {
        let mut out = Self::zero(rows, cols, self.field());
        for row in 0..rows {
            for col in 0..cols {
                out.set(row, col, self.get(start_row + row, start_col + col));
            }
        }
        out
    }

    /// The field `F_p` for this matrix's modulus.
    fn field(&self) -> PrimeField {
        PrimeField {
            modulus: self.modulus,
        }
    }

    /// Row-major offset of `(row, col)`, checking both coordinates.
    ///
    /// Checking the flat index alone is not enough: an out-of-range column
    /// would silently alias an entry of the next row.
    fn offset(&self, row: usize, col: usize) -> usize {
        assert!(
            row < self.rows && col < self.cols,
            "matrix index ({row}, {col}) out of bounds for a {}x{} matrix",
            self.rows,
            self.cols
        );
        row * self.cols + col
    }

    /// Requires `rhs` to have the same dimensions and modulus as `self`.
    fn check_compatible(&self, rhs: &Self) -> Result<(), FftError> {
        if self.rows != rhs.rows || self.cols != rhs.cols || self.modulus != rhs.modulus {
            return Err(FftError::MatrixShape);
        }
        Ok(())
    }
}
/// An integer partition: positive parts listed in non-increasing order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Partition(Vec<usize>);

impl Partition {
    /// Validates and wraps `parts`.
    ///
    /// # Errors
    ///
    /// Returns [`FftError::MatrixShape`] if any part is zero or if a part is
    /// larger than the part before it.
    pub fn new(parts: Vec<usize>) -> Result<Self, FftError> {
        let has_zero_part = parts.iter().any(|&part| part == 0);
        let has_ascent = parts.windows(2).any(|pair| pair[1] > pair[0]);
        if has_zero_part || has_ascent {
            Err(FftError::MatrixShape)
        } else {
            Ok(Self(parts))
        }
    }

    /// The parts, largest first.
    pub fn parts(&self) -> &[usize] {
        self.0.as_slice()
    }

    /// The integer being partitioned (the sum of the parts).
    pub fn n(&self) -> usize {
        self.0.iter().copied().sum()
    }

    /// Indices of the rows whose last box can be removed while the shape
    /// stays a partition, i.e. rows strictly longer than the row below them.
    pub fn removable_rows(&self) -> Vec<usize> {
        let parts = &self.0;
        (0..parts.len())
            .filter(|&row| parts[row] > parts.get(row + 1).copied().unwrap_or(0))
            .collect()
    }

    /// The partition obtained by deleting the last box of `row`; a row that
    /// becomes empty is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range.
    pub fn remove_box(&self, row: usize) -> Self {
        let mut parts = self.0.to_vec();
        let shortened = parts[row] - 1;
        if shortened == 0 {
            parts.remove(row);
        } else {
            parts[row] = shortened;
        }
        Self(parts)
    }
}

impl fmt::Display for Partition {
    /// Formats as `(p1,p2,...)` with no spaces.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("(")?;
        let mut separator = "";
        for part in &self.0 {
            write!(f, "{separator}{part}")?;
            separator = ",";
        }
        f.write_str(")")
    }
}
/// A standard Young tableau: entries `1..=n` filling the boxes of a
/// partition shape.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tableau {
    // Shape of the tableau.
    shape: Partition,
    // Entries row by row, top row first.
    rows: Vec<Vec<usize>>,
    // `positions[entry]` is the zero-based (row, col) of `entry`; index 0 is unused.
    positions: Vec<(usize, usize)>,
}

impl Tableau {
    /// The partition giving this tableau's shape.
    pub fn shape(&self) -> &Partition {
        &self.shape
    }

    /// Entries row by row, top row first.
    pub fn rows(&self) -> &[Vec<usize>] {
        self.rows.as_slice()
    }

    /// Zero-based `(row, column)` of `entry`.
    ///
    /// # Panics
    ///
    /// Panics unless `1 <= entry <= n`.
    pub fn position(&self, entry: usize) -> (usize, usize) {
        assert!(entry > 0 && entry < self.positions.len());
        self.positions[entry]
    }

    /// Entries flattened in row-reading order; used as a lookup key.
    fn key(&self) -> Vec<usize> {
        self.rows.concat()
    }

    /// Like [`Tableau::key`], with the values `lhs` and `rhs` exchanged.
    fn swapped_key(&self, lhs: usize, rhs: usize) -> Vec<usize> {
        let mut key = self.key();
        for entry in key.iter_mut() {
            if *entry == lhs {
                *entry = rhs;
            } else if *entry == rhs {
                *entry = lhs;
            }
        }
        key
    }

    /// Content `column - row` of the box holding `entry`.
    fn content(&self, entry: usize) -> i64 {
        let (row, col) = self.position(entry);
        let (row, col) = (row as i64, col as i64);
        col - row
    }
}
/// A permutation of `{0, ..., n-1}` in one-line notation: `images[i]` is the
/// image of `i`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Permutation {
    images: Vec<usize>,
}

impl Permutation {
    /// The identity permutation on `n` points.
    pub fn identity(n: usize) -> Self {
        let images: Vec<usize> = (0..n).collect();
        Self { images }
    }

    /// The one-line notation of this permutation.
    pub fn images(&self) -> &[usize] {
        self.images.as_slice()
    }

    /// Number of points acted on.
    pub fn len(&self) -> usize {
        self.images().len()
    }

    /// `true` only for the permutation of zero points.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The composition `self ∘ rhs`, i.e. `i ↦ self(rhs(i))`.
    ///
    /// # Panics
    ///
    /// Panics if the permutations act on different numbers of points.
    pub fn compose(&self, rhs: &Self) -> Self {
        assert_eq!(self.len(), rhs.len());
        let mut images = Vec::with_capacity(rhs.len());
        for &point in &rhs.images {
            images.push(self.images[point]);
        }
        Self { images }
    }

    /// The adjacent transposition swapping `index` and `index + 1`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1 >= n`.
    pub fn adjacent(n: usize, index: usize) -> Self {
        let mut permutation = Self::identity(n);
        permutation.images.swap(index, index + 1);
        permutation
    }

    /// Sends `i ↦ i + 1` for `target <= i < n - 1` and `n - 1 ↦ target`, fixing
    /// everything else. The result is the identity when `target >= n - 1`.
    fn cycle_moving_last_to(n: usize, target: usize) -> Self {
        let mut images: Vec<usize> = (0..n).collect();
        let last = n - 1;
        if target < last {
            // [target, target+1, ..., last] becomes [target+1, ..., last, target].
            images[target..=last].rotate_left(1);
        }
        Self { images }
    }

    /// Extends to `n + 1` points by fixing the new last point `n`.
    fn embed_fixing_last(&self) -> Self {
        let fixed_point = self.images.len();
        let images = self
            .images
            .iter()
            .copied()
            .chain(std::iter::once(fixed_point))
            .collect();
        Self { images }
    }

    /// A word of adjacent-transposition indices whose left-to-right product
    /// under [`Permutation::compose`] equals `self`. It is built by
    /// bubble-sorting each required image into place from the left.
    fn adjacent_word(&self) -> Vec<usize> {
        let mut arrangement: Vec<usize> = (0..self.len()).collect();
        let mut word = Vec::new();
        for (slot, &wanted) in self.images.iter().enumerate() {
            let found = arrangement
                .iter()
                .position(|&value| value == wanted)
                .expect("permutation image");
            if found > slot {
                // Bubbling `wanted` from `found` down to `slot` uses the swaps
                // found-1, found-2, ..., slot.
                arrangement[slot..=found].rotate_right(1);
                word.extend((slot..found).rev());
            }
        }
        word
    }
}
/// The Fourier transform of an element of `F_p[S_n]`: one square matrix block
/// per partition of `n`, produced and consumed by `SymmetricFft`.
#[derive(Clone, Debug)]
pub struct FourierTransform {
    n: usize,
    modulus: u64,
    blocks: BTreeMap<Partition, Matrix>,
}

impl FourierTransform {
    /// Rank `n` of the symmetric group the transform belongs to.
    pub fn n(&self) -> usize {
        let Self { n, .. } = self;
        *n
    }

    /// Characteristic `p` of the coefficient field.
    pub fn modulus(&self) -> u64 {
        let Self { modulus, .. } = self;
        *modulus
    }

    /// All blocks, keyed by partition.
    pub fn blocks(&self) -> &BTreeMap<Partition, Matrix> {
        let Self { blocks, .. } = self;
        blocks
    }

    /// The block for `partition`, if present.
    pub fn block(&self, partition: &Partition) -> Option<&Matrix> {
        self.blocks().get(partition)
    }
}
/// Precomputed plan for Fourier analysis on the group algebra `F_p[S_n]`.
///
/// The plan holds one `Level` per rank `k = 0..=n` (`levels[0]` is empty).
/// Each level stores the partitions of `k`, every permutation of `k` points, and,
/// per partition, its standard tableaux, branching blocks and seminormal
/// generator matrices. A coefficient vector has one entry per permutation, in
/// the order of [`SymmetricFft::permutations`].
#[derive(Clone, Debug)]
pub struct SymmetricFft {
// Rank of the symmetric group.
n: usize,
// Coefficient field `F_p`; construction guarantees `p > n`.
field: PrimeField,
// `levels[k]` holds the precomputed data for `S_k`, `k = 0..=n`.
levels: Vec<Level>,
}
impl SymmetricFft {
/// Builds a plan for `S_n` over `F_modulus`.
///
/// Every rank `1..=n` is precomputed, including all permutations of each rank,
/// so memory and build time grow with `n!`.
///
/// # Errors
///
/// - [`FftError::RankZero`] if `n == 0`.
/// - [`FftError::FactorialOverflow`] if `n!` does not fit in `usize`.
/// - [`FftError::CompositeModulus`] if `modulus` is not prime.
/// - [`FftError::CharacteristicTooSmall`] if `modulus <= n`.
pub fn new(n: usize, modulus: u64) -> Result<Self, FftError> {
if n == 0 {
return Err(FftError::RankZero);
}
// Only the overflow check is needed; the factorial value is discarded.
checked_factorial(n)?;
let field = PrimeField::for_symmetric_group(n, modulus)?;
let mut levels = Vec::with_capacity(n + 1);
levels.push(Level::empty(field));
// Each rank is built from the previous one, which supplies the branching dimensions.
for k in 1..=n {
let previous = levels.last().expect("previous level");
levels.push(Level::new(k, field, previous));
}
Ok(Self { n, field, levels })
}
/// Rank `n` of the group.
pub fn n(&self) -> usize {
self.n
}
/// Coefficient field `F_p`.
pub fn field(&self) -> PrimeField {
self.field
}
/// Required length of a coefficient vector: one entry per permutation of `n` points.
pub fn input_len(&self) -> usize {
self.levels[self.n].permutations.len()
}
/// All permutations of `n` points. Coefficient `i` of every input and output
/// vector belongs to `permutations()[i]`.
pub fn permutations(&self) -> &[Permutation] {
&self.levels[self.n].permutations
}
/// Partitions of `n`, one per Fourier block.
pub fn partitions(&self) -> &[Partition] {
&self.levels[self.n].partitions
}
/// Standard tableaux indexing the basis of the block for `partition`, or
/// `None` if `partition` is not a partition of `n`.
pub fn standard_tableaux(&self, partition: &Partition) -> Option<&[Tableau]> {
self.levels[self.n]
.irreps
.get(partition)
.map(|irrep| irrep.tableaux.as_slice())
}
/// Fast Fourier transform: for each partition `λ` of `n`, computes
/// `Σ_i values[i] · ρ_λ(permutations()[i])`, with `ρ_λ` as in
/// [`SymmetricFft::representation_matrix`]. Values are reduced mod `p` first.
///
/// The result agrees with [`SymmetricFft::naive_dft`] (checked by the test
/// `fft_matches_naive_dft_for_small_ranks`).
///
/// # Errors
///
/// [`FftError::InputLength`] if `values.len() != self.input_len()`.
pub fn fft(&self, values: &[u64]) -> Result<FourierTransform, FftError> {
self.validate_input(values)?;
let values: Vec<_> = values
.iter()
.map(|value| self.field.normalize(*value))
.collect();
let blocks = self.fft_level(self.n, &values)?;
Ok(FourierTransform {
n: self.n,
modulus: self.field.modulus(),
blocks,
})
}
/// Inverse transform: recovers the coefficient vector from its Fourier blocks.
///
/// # Errors
///
/// [`FftError::TransformShape`] if `transform` has the wrong rank or modulus,
/// a missing or extra block, or a block of the wrong size.
pub fn ifft(&self, transform: &FourierTransform) -> Result<Vec<u64>, FftError> {
self.validate_transform(transform)?;
self.ifft_level(self.n, transform.blocks())
}
/// Product in `F_p[S_n]`, computed by multiplying Fourier blocks and
/// transforming back. The result agrees with [`SymmetricFft::naive_multiply`],
/// where `σ·τ` is `σ.compose(τ)`.
///
/// # Errors
///
/// [`FftError::InputLength`] if either input has the wrong length.
pub fn multiply(&self, lhs: &[u64], rhs: &[u64]) -> Result<Vec<u64>, FftError> {
self.validate_input(lhs)?;
self.validate_input(rhs)?;
let lhs_transform = self.fft(lhs)?;
let rhs_transform = self.fft(rhs)?;
let mut blocks = BTreeMap::new();
for partition in &self.levels[self.n].partitions {
let lhs_block = lhs_transform
.block(partition)
.ok_or(FftError::TransformShape)?;
let rhs_block = rhs_transform
.block(partition)
.ok_or(FftError::TransformShape)?;
blocks.insert(partition.clone(), lhs_block.mul(rhs_block)?);
}
self.ifft(&FourierTransform {
n: self.n,
modulus: self.field.modulus(),
blocks,
})
}
/// Multiplicative inverse in `F_p[S_n]`, computed by inverting every Fourier block.
///
/// # Errors
///
/// [`FftError::InputLength`] for a wrong-length input, and
/// [`FftError::NonInvertibleMatrix`] if some block is singular (the element
/// is not a unit).
pub fn invert(&self, values: &[u64]) -> Result<Vec<u64>, FftError> {
self.validate_input(values)?;
let transform = self.fft(values)?;
let inverse = self.invert_transform(&transform)?;
self.ifft(&inverse)
}
/// Blockwise matrix inverse of a transform.
///
/// # Errors
///
/// [`FftError::TransformShape`] for a malformed transform, and
/// [`FftError::NonInvertibleMatrix`] if any block is singular.
pub fn invert_transform(
&self,
transform: &FourierTransform,
) -> Result<FourierTransform, FftError> {
self.validate_transform(transform)?;
let mut blocks = BTreeMap::new();
for partition in &self.levels[self.n].partitions {
let block = transform.block(partition).ok_or(FftError::TransformShape)?;
blocks.insert(partition.clone(), block.inverse()?);
}
Ok(FourierTransform {
n: self.n,
modulus: self.field.modulus(),
blocks,
})
}
/// Reference product: for every pair of nonzero coefficients, adds
/// `lhs[σ]·rhs[τ]` to the coefficient of `σ.compose(τ)`. Quadratic in `n!`.
///
/// # Errors
///
/// [`FftError::InputLength`] if either input has the wrong length.
pub fn naive_multiply(&self, lhs: &[u64], rhs: &[u64]) -> Result<Vec<u64>, FftError> {
self.validate_input(lhs)?;
self.validate_input(rhs)?;
let level = &self.levels[self.n];
let mut out = vec![self.field.zero(); level.permutations.len()];
for (lhs_index, lhs_perm) in level.permutations.iter().enumerate() {
let lhs_value = self.field.normalize(lhs[lhs_index]);
if lhs_value == 0 {
continue;
}
for (rhs_index, rhs_perm) in level.permutations.iter().enumerate() {
let rhs_value = self.field.normalize(rhs[rhs_index]);
if rhs_value == 0 {
continue;
}
let product = lhs_perm.compose(rhs_perm);
let product_index = *level
.permutation_index
.get(product.images())
.expect("permutation product index");
out[product_index] = self
.field
.add(out[product_index], self.field.mul(lhs_value, rhs_value));
}
}
Ok(out)
}
/// Reference transform: sums `values[i] · representation_matrix(λ, σ_i)`
/// directly for every partition `λ`.
///
/// # Errors
///
/// [`FftError::InputLength`] if `values` has the wrong length.
pub fn naive_dft(&self, values: &[u64]) -> Result<FourierTransform, FftError> {
self.validate_input(values)?;
let level = &self.levels[self.n];
let mut blocks = BTreeMap::new();
for partition in &level.partitions {
let irrep = level.irreps.get(partition).expect("irrep data");
let mut block = Matrix::zero(irrep.dimension(), irrep.dimension(), self.field);
for (perm_index, perm) in level.permutations.iter().enumerate() {
let value = self.field.normalize(values[perm_index]);
if value == 0 {
continue;
}
let rho = self.representation_matrix(partition, perm)?;
block.add_scaled_assign(value, &rho)?;
}
blocks.insert(partition.clone(), block);
}
Ok(FourierTransform {
n: self.n,
modulus: self.field.modulus(),
blocks,
})
}
/// Dense matrix of the adjacent transposition swapping points
/// `adjacent_index` and `adjacent_index + 1` (zero-based), in the seminormal
/// basis of the representation for `partition`. The partition may have any
/// rank `1..=n`.
///
/// Returns `None` if the rank is 0 or exceeds `n`, if the partition is
/// unknown, or if `adjacent_index >= rank - 1`.
pub fn generator_matrix(&self, partition: &Partition, adjacent_index: usize) -> Option<Matrix> {
let level = self.levels.get(partition.n())?;
let irrep = level.irreps.get(partition)?;
let rows = irrep.generator_rows.get(adjacent_index)?;
Some(matrix_from_sparse_rows(rows, self.field))
}
/// `ρ_λ(permutation)`: the product, left to right, of the generator matrices
/// along `permutation.adjacent_word()`.
///
/// # Errors
///
/// [`FftError::MatrixShape`] if the permutation length differs from the
/// partition's rank, or if that rank or partition is not covered by the plan.
pub fn representation_matrix(
&self,
partition: &Partition,
permutation: &Permutation,
) -> Result<Matrix, FftError> {
if permutation.len() != partition.n() {
return Err(FftError::MatrixShape);
}
let level = self
.levels
.get(partition.n())
.ok_or(FftError::MatrixShape)?;
let irrep = level.irreps.get(partition).ok_or(FftError::MatrixShape)?;
let mut matrix = Matrix::identity(irrep.dimension(), self.field);
for adjacent_index in permutation.adjacent_word() {
let generator =
matrix_from_sparse_rows(&irrep.generator_rows[adjacent_index], self.field);
matrix = matrix.mul(&generator)?;
}
Ok(matrix)
}
/// Requires exactly one coefficient per permutation of `n` points.
fn validate_input(&self, values: &[u64]) -> Result<(), FftError> {
let expected = self.input_len();
if values.len() != expected {
return Err(FftError::InputLength {
expected,
got: values.len(),
});
}
Ok(())
}
/// Requires the same rank and modulus as the plan, exactly one block per
/// partition of `n`, and square blocks of the matching dimension and modulus.
fn validate_transform(&self, transform: &FourierTransform) -> Result<(), FftError> {
if transform.n() != self.n || transform.modulus() != self.field.modulus() {
return Err(FftError::TransformShape);
}
let level = &self.levels[self.n];
// Checking the count plus every expected key rules out extra blocks.
if transform.blocks().len() != level.partitions.len() {
return Err(FftError::TransformShape);
}
for partition in &level.partitions {
let irrep = level.irreps.get(partition).expect("irrep data");
let block = transform.block(partition).ok_or(FftError::TransformShape)?;
if block.rows() != irrep.dimension()
|| block.cols() != irrep.dimension()
|| block.modulus() != self.field.modulus()
{
return Err(FftError::TransformShape);
}
}
Ok(())
}
/// Recursive transform of a coefficient vector on `S_k`.
///
/// Each `σ` in `S_k` is written as `c_t ∘ embed(h)`, where `c_t` is
/// `Permutation::cycle_moving_last_to(k, t)` and `h` is in `S_{k-1}`. One
/// sub-transform over `S_{k-1}` is computed per coset `t`. Each is placed
/// block-diagonally along the branching blocks of every partition of `k`,
/// left-multiplied by the matrix of `c_t`, and the cosets are summed.
fn fft_level(&self, k: usize, values: &[u64]) -> Result<BTreeMap<Partition, Matrix>, FftError> {
let level = &self.levels[k];
// S_1 is trivial: its single 1x1 block is the lone coefficient.
if k == 1 {
let partition = Partition(vec![1]);
let matrix = Matrix::from_vec(1, 1, self.field, vec![values[0]])?;
return Ok(BTreeMap::from([(partition, matrix)]));
}
let previous = &self.levels[k - 1];
let mut sub_transforms = Vec::with_capacity(k);
for target in 0..k {
let coset_rep = Permutation::cycle_moving_last_to(k, target);
// Gather f(c_t ∘ embed(h)) for every h in S_{k-1}, in S_{k-1}'s order.
let mut sub_values = vec![self.field.zero(); previous.permutations.len()];
for (sub_index, sub_perm) in previous.permutations.iter().enumerate() {
let embedded = sub_perm.embed_fixing_last();
let perm = coset_rep.compose(&embedded);
let value_index = *level
.permutation_index
.get(perm.images())
.expect("coset permutation index");
sub_values[sub_index] = values[value_index];
}
sub_transforms.push(self.fft_level(k - 1, &sub_values)?);
}
let mut out = BTreeMap::new();
for partition in &level.partitions {
let irrep = level.irreps.get(partition).expect("irrep data");
let mut block = Matrix::zero(irrep.dimension(), irrep.dimension(), self.field);
for (target, transform) in sub_transforms.iter().enumerate() {
let mut embedded = embed_restricted_blocks(irrep, transform, self.field);
// c_t = s_t ∘ s_{t+1} ∘ ... ∘ s_{k-2}. Left-multiplying by
// G_{k-2} first and G_t last builds G_t···G_{k-2}·embedded, the matrix of c_t applied.
for adjacent_index in (target..k - 1).rev() {
embedded =
embedded.left_multiply_sparse_rows(&irrep.generator_rows[adjacent_index]);
}
block.add_assign(&embedded)?;
}
out.insert(partition.clone(), block);
}
Ok(out)
}
/// Recursive inverse of [`SymmetricFft::fft_level`].
///
/// For each coset `t`, every block is left-multiplied by
/// `G_{k-2}···G_t`. Each generator squares to the identity (checked by the
/// Coxeter-relation test), so this is the inverse of `c_t`'s matrix. Each
/// diagonal branching block `μ` is then accumulated into the `S_{k-1}`
/// transform with weight `dim(λ) / (k · dim(μ))`, i.e.
/// `|S_{k-1}|·dim λ / (|S_k|·dim μ)` from Fourier inversion. That transform
/// is inverted recursively and scattered back to the coset's positions.
fn ifft_level(
&self,
k: usize,
blocks: &BTreeMap<Partition, Matrix>,
) -> Result<Vec<u64>, FftError> {
let level = &self.levels[k];
if k == 1 {
let block = blocks
.get(&Partition(vec![1]))
.ok_or(FftError::TransformShape)?;
return Ok(vec![block.get(0, 0)]);
}
let previous = &self.levels[k - 1];
let mut values = vec![self.field.zero(); level.permutations.len()];
for target in 0..k {
// Zero accumulators for the S_{k-1} transform of this coset.
let mut sub_blocks = previous
.partitions
.iter()
.map(|partition| {
let irrep = previous.irreps.get(partition).expect("previous irrep");
(
partition.clone(),
Matrix::zero(irrep.dimension(), irrep.dimension(), self.field),
)
})
.collect::<BTreeMap<_, _>>();
for partition in &level.partitions {
let irrep = level.irreps.get(partition).expect("irrep data");
let mut shifted = blocks
.get(partition)
.ok_or(FftError::TransformShape)?
.clone();
// Ascending order here (descending in fft_level) yields G_{k-2}···G_t·block.
for adjacent_index in target..k - 1 {
shifted =
shifted.left_multiply_sparse_rows(&irrep.generator_rows[adjacent_index]);
}
let numerator = self.field.normalize(irrep.dimension() as u64);
for branch in &irrep.branches {
let denominator = self.field.mul(
self.field.normalize(k as u64),
self.field.normalize(branch.size as u64),
);
let scalar = self
.field
.mul(numerator, self.field.inv(denominator).expect("p > n"));
let projected =
shifted.submatrix(branch.start, branch.start, branch.size, branch.size);
sub_blocks
.get_mut(&branch.partition)
.expect("sub-block")
.add_scaled_assign(scalar, &projected)?;
}
}
let sub_values = self.ifft_level(k - 1, &sub_blocks)?;
// Scatter the recovered coset values back to their S_k indices.
let coset_rep = Permutation::cycle_moving_last_to(k, target);
for (sub_index, sub_perm) in previous.permutations.iter().enumerate() {
let embedded = sub_perm.embed_fixing_last();
let perm = coset_rep.compose(&embedded);
let value_index = *level
.permutation_index
.get(perm.images())
.expect("coset permutation index");
values[value_index] = sub_values[sub_index];
}
}
Ok(values)
}
}
/// Precomputed data for one rank `k`: the partitions of `k`, representation
/// data per partition, and every permutation of `k` points with a reverse index.
#[derive(Clone, Debug)]
struct Level {
    // Partitions of `k`, in the order produced by `partitions`.
    partitions: Vec<Partition>,
    // Representation data keyed by partition.
    irreps: BTreeMap<Partition, IrrepData>,
    // All permutations of `k` points, in the order produced by `all_permutations`.
    permutations: Vec<Permutation>,
    // One-line images -> index into `permutations`.
    permutation_index: BTreeMap<Vec<usize>, usize>,
}

impl Level {
    /// The rank-0 level. It holds no data and only seeds the construction of rank 1.
    fn empty(_field: PrimeField) -> Self {
        Self {
            partitions: Vec::new(),
            irreps: BTreeMap::default(),
            permutations: Vec::default(),
            permutation_index: BTreeMap::default(),
        }
    }

    /// Builds rank `n`. The already-built rank `n - 1` (`previous`) supplies
    /// the branching dimensions.
    fn new(n: usize, field: PrimeField, previous: &Self) -> Self {
        let partitions = partitions(n);
        let permutations = all_permutations(n);
        let mut permutation_index = BTreeMap::new();
        for (position, permutation) in permutations.iter().enumerate() {
            permutation_index.insert(permutation.images.clone(), position);
        }
        let irreps = partitions
            .iter()
            .map(|partition| {
                let data = IrrepData::new(partition.clone(), field, previous);
                (partition.clone(), data)
            })
            .collect();
        Self {
            partitions,
            irreps,
            permutations,
            permutation_index,
        }
    }
}
/// Representation data for one partition `λ` of rank `k`.
#[derive(Clone, Debug)]
struct IrrepData {
    // Standard tableaux of shape λ, in `standard_tableaux` order; they index the basis.
    tableaux: Vec<Tableau>,
    // Contiguous diagonal blocks, one per removable box of λ, in removable-row order.
    branches: Vec<BranchBlock>,
    // `generator_rows[i]` holds the sparse rows of the matrix of adjacent transposition `i`.
    generator_rows: Vec<Vec<Vec<(usize, u64)>>>,
}

impl IrrepData {
    /// Builds the tableau basis, the branching layout (for `k > 1`, sized
    /// from `previous`) and the `k - 1` generator matrices.
    fn new(partition: Partition, field: PrimeField, previous: &Level) -> Self {
        let tableaux = standard_tableaux(&partition);
        let mut tableau_index = BTreeMap::new();
        for (position, tableau) in tableaux.iter().enumerate() {
            tableau_index.insert(tableau.key(), position);
        }

        let mut branches = Vec::new();
        if partition.n() > 1 {
            let mut offset = 0;
            for row in partition.removable_rows() {
                let branch_partition = partition.remove_box(row);
                let size = previous
                    .irreps
                    .get(&branch_partition)
                    .expect("branch irrep")
                    .dimension();
                branches.push(BranchBlock {
                    partition: branch_partition,
                    start: offset,
                    size,
                });
                offset += size;
            }
        }

        let generator_count = partition.n().saturating_sub(1);
        let mut generator_rows = Vec::with_capacity(generator_count);
        for adjacent_index in 0..generator_count {
            generator_rows.push(seminormal_generator_rows(
                &tableaux,
                &tableau_index,
                adjacent_index,
                field,
            ));
        }

        Self {
            tableaux,
            branches,
            generator_rows,
        }
    }

    /// Dimension of the representation, i.e. the number of standard tableaux.
    fn dimension(&self) -> usize {
        self.tableaux.len()
    }
}

/// A diagonal block `start..start + size` of a rank-`k` basis that carries
/// the rank-`k - 1` representation for `partition`.
#[derive(Clone, Debug)]
struct BranchBlock {
    partition: Partition,
    start: usize,
    size: usize,
}
/// All partitions of `n`, largest parts first in reverse lexicographic order
/// (e.g. `(3), (2,1), (1,1,1)` for `n = 3`). `n = 0` yields the single empty partition.
pub fn partitions(n: usize) -> Vec<Partition> {
    // Extends `prefix` with every non-increasing tail summing to `remaining`
    // whose parts are at most `cap`.
    fn extend(prefix: &mut Vec<usize>, remaining: usize, cap: usize, acc: &mut Vec<Partition>) {
        if remaining == 0 {
            acc.push(Partition(prefix.to_vec()));
            return;
        }
        let largest = cap.min(remaining);
        for part in (1..=largest).rev() {
            prefix.push(part);
            extend(prefix, remaining - part, part, acc);
            prefix.pop();
        }
    }

    let mut acc = Vec::new();
    let mut prefix = Vec::with_capacity(n);
    extend(&mut prefix, n, n, &mut acc);
    acc
}
/// All standard Young tableaux of the given shape.
///
/// Tableaux are grouped by the removable row holding the largest entry `n`
/// (rows ascending). Within a group they follow the recursive order for the
/// smaller shape.
pub fn standard_tableaux(partition: &Partition) -> Vec<Tableau> {
    // Row lists for every standard filling of `shape`, built by placing the
    // largest entry in each removable corner.
    fn fillings(shape: &Partition) -> Vec<Vec<Vec<usize>>> {
        let largest = shape.n();
        if largest == 0 {
            return vec![Vec::new()];
        }
        shape
            .removable_rows()
            .into_iter()
            .flat_map(|row| {
                fillings(&shape.remove_box(row))
                    .into_iter()
                    .map(move |mut rows| {
                        if row == rows.len() {
                            rows.push(vec![largest]);
                        } else {
                            rows[row].push(largest);
                        }
                        rows
                    })
            })
            .collect()
    }

    fillings(partition)
        .into_iter()
        .map(|rows| tableau_from_rows(partition.clone(), rows))
        .collect()
}
/// Every permutation of `n` points in lexicographic order of their one-line
/// images, starting from the identity. For `n = 0` this is the single empty permutation.
pub fn all_permutations(n: usize) -> Vec<Permutation> {
    let mut current: Vec<usize> = (0..n).collect();
    let mut out = Vec::new();
    loop {
        out.push(Permutation {
            images: current.clone(),
        });
        if !next_permutation(&mut current) {
            break;
        }
    }
    out
}
/// Sparse rows of the matrix of adjacent transposition `adjacent_index`. The
/// transposition swaps tableau entries `adjacent_index + 1` and
/// `adjacent_index + 2`, and the basis is indexed by `tableaux`.
///
/// `rows[r]` lists the `(column, coefficient)` pairs of row `r`. For each tableau:
/// - both entries in the same row: diagonal entry `1`;
/// - both entries in the same column: diagonal entry `-1`;
/// - otherwise the tableau is paired with the one obtained by swapping the two
///   entries. The pair `(first, second)` (lower index first) gets the 2x2 block
///   `[[1/d, 1 - 1/d^2], [1, -1/d]]`, where `d` is the content of the larger
///   entry minus the content of the smaller one in the `first` tableau.
///
/// # Panics
///
/// Panics if the swapped tableau is missing from `tableau_index`, or if `d` is
/// not invertible in `field`. The `expect` message relies on `p > n` to rule
/// the latter out.
fn seminormal_generator_rows(
tableaux: &[Tableau],
tableau_index: &BTreeMap<Vec<usize>, usize>,
adjacent_index: usize,
field: PrimeField,
) -> Vec<Vec<(usize, u64)>> {
let dim = tableaux.len();
// Tableau entries are 1-based, while adjacent indices are 0-based.
let lhs = adjacent_index + 1;
let rhs = adjacent_index + 2;
let mut rows = vec![Vec::new(); dim];
// Marks rows already filled, including the partner row of a 2x2 pair.
let mut done = vec![false; dim];
for index in 0..dim {
if done[index] {
continue;
}
let tableau = &tableaux[index];
let lhs_pos = tableau.position(lhs);
let rhs_pos = tableau.position(rhs);
if lhs_pos.0 == rhs_pos.0 {
rows[index] = vec![(index, field.one())];
done[index] = true;
} else if lhs_pos.1 == rhs_pos.1 {
rows[index] = vec![(index, field.neg(field.one()))];
done[index] = true;
} else {
let swapped_key = tableau.swapped_key(lhs, rhs);
let pair_index = *tableau_index
.get(&swapped_key)
.expect("standard tableau adjacent swap");
// Both rows of the pair are written together, anchored on the lower index.
let first = index.min(pair_index);
let second = index.max(pair_index);
let first_tableau = &tableaux[first];
let distance = first_tableau.content(rhs) - first_tableau.content(lhs);
let distance = field.from_i64(distance);
let inv_distance = field
.inv(distance)
.expect("p > n keeps axial distance invertible");
let inv_distance_squared = field.mul(inv_distance, inv_distance);
rows[first] = vec![
(first, inv_distance),
(second, field.sub(field.one(), inv_distance_squared)),
];
rows[second] = vec![(first, field.one()), (second, field.neg(inv_distance))];
done[first] = true;
done[second] = true;
}
}
rows
}
fn embed_restricted_blocks(
irrep: &IrrepData,
transform: &BTreeMap<Partition, Matrix>,
field: PrimeField,
) -> Matrix {
let mut embedded = Matrix::zero(irrep.dimension(), irrep.dimension(), field);
for branch in &irrep.branches {
let block = transform
.get(&branch.partition)
.expect("restricted transform block");
debug_assert_eq!(block.rows(), branch.size);
debug_assert_eq!(block.cols(), branch.size);
for row in 0..branch.size {
for col in 0..branch.size {
embedded.set(branch.start + row, branch.start + col, block.get(row, col));
}
}
}
embedded
}
/// Expands sparse rows (`rows[r]` lists `(column, coefficient)` pairs) into a
/// dense square matrix. A later pair for the same cell overwrites an earlier one.
fn matrix_from_sparse_rows(rows: &[Vec<(usize, u64)>], field: PrimeField) -> Matrix {
    let size = rows.len();
    let mut matrix = Matrix::zero(size, size, field);
    rows.iter()
        .enumerate()
        .flat_map(|(row, terms)| terms.iter().map(move |&(col, coeff)| (row, col, coeff)))
        .for_each(|(row, col, coeff)| matrix.set(row, col, coeff));
    matrix
}
/// Wraps a row list as a [`Tableau`] and records each entry's `(row, col)`.
/// Slot 0 of the position table is unused and left as `(usize::MAX, usize::MAX)`.
fn tableau_from_rows(shape: Partition, rows: Vec<Vec<usize>>) -> Tableau {
    let mut positions = vec![(usize::MAX, usize::MAX); shape.n() + 1];
    rows.iter()
        .enumerate()
        .flat_map(|(row, entries)| {
            entries
                .iter()
                .enumerate()
                .map(move |(col, &entry)| (entry, (row, col)))
        })
        .for_each(|(entry, cell)| positions[entry] = cell);
    Tableau {
        shape,
        rows,
        positions,
    }
}
/// `n!` as a `usize`.
///
/// # Errors
///
/// Returns [`FftError::FactorialOverflow`] if any partial product overflows.
fn checked_factorial(n: usize) -> Result<usize, FftError> {
    (2..=n).try_fold(1usize, |product, factor| {
        product
            .checked_mul(factor)
            .ok_or(FftError::FactorialOverflow(n))
    })
}
/// Rearranges `values` into the next lexicographically greater permutation
/// and returns `true`. If `values` is already the last permutation, it is
/// reset to the first (ascending) arrangement and `false` is returned.
/// Slices shorter than two elements are left unchanged and return `false`.
fn next_permutation(values: &mut [usize]) -> bool {
    let len = values.len();
    if len < 2 {
        return false;
    }
    // Rightmost position whose element is smaller than its right neighbour.
    let ascent = (0..len - 1).rev().find(|&i| values[i] < values[i + 1]);
    let pivot = match ascent {
        Some(pivot) => pivot,
        None => {
            values.reverse();
            return false;
        }
    };
    // Rightmost element after the pivot that exceeds it. One always exists,
    // because `values[pivot + 1]` does.
    let pivot_value = values[pivot];
    let offset = values[pivot + 1..]
        .iter()
        .rposition(|&value| value > pivot_value)
        .unwrap_or(0);
    values.swap(pivot, pivot + 1 + offset);
    values[pivot + 1..].reverse();
    true
}
/// Deterministic primality test by trial division over odd divisors up to
/// `sqrt(value)`. Runtime grows with `sqrt(value)`.
fn is_prime(value: u64) -> bool {
    match value {
        0 | 1 => false,
        2 => true,
        even if even % 2 == 0 => false,
        // `divisor <= odd / divisor` bounds by sqrt without squaring, so it cannot overflow.
        odd => (3..)
            .step_by(2)
            .take_while(|&divisor| divisor <= odd / divisor)
            .all(|divisor| odd % divisor != 0),
    }
}
// Unit and property tests: field and matrix arithmetic, partition/tableau
// combinatorics, and agreement of the fast transform, its inverse and the fast
// product with their naive counterparts.
#[cfg(test)]
mod tests {
use super::*;
// Basic F_17 arithmetic, including reduction of a negative i64 and a Fermat inverse.
#[test]
fn field_arithmetic_works() {
let field = PrimeField::new(17).unwrap();
assert_eq!(field.add(16, 3), 2);
assert_eq!(field.sub(2, 5), 14);
assert_eq!(field.mul(6, 8), 14);
assert_eq!(field.from_i64(-3), 14);
assert_eq!(field.mul(5, field.inv(5).unwrap()), 1);
}
// A 2x2 inverse is a two-sided inverse over F_101.
#[test]
fn matrix_inverse_multiplies_to_identity() {
let field = PrimeField::new(101).unwrap();
let matrix = Matrix::from_vec(2, 2, field, vec![1, 2, 3, 5]).unwrap();
let inverse = matrix.inverse().unwrap();
let identity = Matrix::identity(2, field);
assert_eq!(matrix.mul(&inverse).unwrap(), identity);
assert_eq!(inverse.mul(&matrix).unwrap(), identity);
}
// Rows [1,2] and [2,4] are proportional, so elimination finds no pivot.
#[test]
fn matrix_inverse_rejects_singular_matrix() {
let field = PrimeField::new(101).unwrap();
let matrix = Matrix::from_vec(2, 2, field, vec![1, 2, 2, 4]).unwrap();
assert_eq!(matrix.inverse(), Err(FftError::NonInvertibleMatrix));
}
// p must be prime and strictly greater than n.
#[test]
fn rejects_bad_characteristics() {
assert!(matches!(
SymmetricFft::new(5, 5),
Err(FftError::CharacteristicTooSmall { .. })
));
assert!(matches!(
SymmetricFft::new(5, 9),
Err(FftError::CompositeModulus(9))
));
}
// Partition numbers p(1..=7).
#[test]
fn partition_counts_are_correct_for_small_n() {
let counts: Vec<_> = (1..=7).map(|n| partitions(n).len()).collect();
assert_eq!(counts, vec![1, 2, 3, 5, 7, 11, 15]);
}
// Tableaux are grouped by the row of the largest entry, top row first.
#[test]
fn tableaux_are_in_last_letter_order() {
let shape = Partition::new(vec![2, 1]).unwrap();
let tableaux = standard_tableaux(&shape);
assert_eq!(tableaux.len(), 2);
assert_eq!(tableaux[0].rows(), &[vec![1, 3], vec![2]]);
assert_eq!(tableaux[1].rows(), &[vec![1, 2], vec![3]]);
}
// Generators are involutions, satisfy the braid relation, and commute when
// they are far apart, for every partition of ranks 2..=5.
#[test]
fn seminormal_generators_satisfy_coxeter_relations() {
let plan = SymmetricFft::new(5, 101).unwrap();
for n in 2..=5 {
for partition in &plan.levels[n].partitions {
let irrep = plan.levels[n].irreps.get(partition).unwrap();
let identity = Matrix::identity(irrep.dimension(), plan.field);
for i in 0..n - 1 {
let generator = plan.generator_matrix(partition, i).unwrap();
assert_eq!(generator.mul(&generator).unwrap(), identity);
}
for i in 0..n.saturating_sub(2) {
let left = plan.generator_matrix(partition, i).unwrap();
let right = plan.generator_matrix(partition, i + 1).unwrap();
let lhs = left.mul(&right).unwrap().mul(&left).unwrap();
let rhs = right.mul(&left).unwrap().mul(&right).unwrap();
assert_eq!(lhs, rhs);
}
for i in 0..n - 1 {
for j in i + 2..n - 1 {
let left = plan.generator_matrix(partition, i).unwrap();
let right = plan.generator_matrix(partition, j).unwrap();
assert_eq!(left.mul(&right).unwrap(), right.mul(&left).unwrap());
}
}
}
}
}
// The recursive FFT equals the direct sum of representation matrices.
#[test]
fn fft_matches_naive_dft_for_small_ranks() {
for n in 1..=5 {
let plan = SymmetricFft::new(n, 101).unwrap();
let values: Vec<_> = (0..plan.input_len())
.map(|i| ((i * i + 3 * i + 7) % 101) as u64)
.collect();
let fast = plan.fft(&values).unwrap();
let naive = plan.naive_dft(&values).unwrap();
assert_eq!(fast.blocks(), naive.blocks());
}
}
// ifft(fft(x)) == x mod p. Inputs up to 210 also exercise reduction mod 101.
#[test]
fn inverse_fft_recovers_input_values() {
for n in 1..=6 {
let plan = SymmetricFft::new(n, 101).unwrap();
let values: Vec<_> = (0..plan.input_len())
.map(|i| ((7 * i * i + 11 * i + 103) % 211) as u64)
.collect();
let expected: Vec<_> = values
.iter()
.map(|value| plan.field.normalize(*value))
.collect();
let transform = plan.fft(&values).unwrap();
let recovered = plan.ifft(&transform).unwrap();
assert_eq!(recovered, expected, "failed roundtrip for S_{n}");
}
}
// fft(ifft(F)) == F for any F in the image of fft.
#[test]
fn inverse_fft_is_two_sided_on_transform_image() {
for n in 1..=5 {
let plan = SymmetricFft::new(n, 101).unwrap();
let values: Vec<_> = (0..plan.input_len())
.map(|i| ((13 * i * i + 5 * i + 19) % 101) as u64)
.collect();
let transform = plan.fft(&values).unwrap();
let recovered = plan.ifft(&transform).unwrap();
let transform_again = plan.fft(&recovered).unwrap();
assert_eq!(
transform_again.blocks(),
transform.blocks(),
"failed transform roundtrip for S_{n}"
);
}
}
// A transform with a missing block is rejected rather than mis-inverted.
#[test]
fn inverse_fft_rejects_malformed_transforms() {
let plan = SymmetricFft::new(3, 101).unwrap();
let mut transform = plan.fft(&vec![1; plan.input_len()]).unwrap();
transform.blocks.remove(&Partition::new(vec![3]).unwrap());
assert_eq!(plan.ifft(&transform), Err(FftError::TransformShape));
}
// The FFT-based product matches the quadratic convolution.
#[test]
fn group_algebra_multiply_matches_naive_convolution() {
for n in 1..=5 {
let plan = SymmetricFft::new(n, 101).unwrap();
let lhs: Vec<_> = (0..plan.input_len())
.map(|i| ((3 * i * i + 7 * i + 11) % 101) as u64)
.collect();
let rhs: Vec<_> = (0..plan.input_len())
.map(|i| ((5 * i * i + 13 * i + 17) % 101) as u64)
.collect();
let fast = plan.multiply(&lhs, &rhs).unwrap();
let naive = plan.naive_multiply(&lhs, &rhs).unwrap();
assert_eq!(fast, naive, "failed multiplication for S_{n}");
}
}
// A scaled group element 7·σ is a unit, and its computed inverse works on both sides.
#[test]
fn group_algebra_invert_inverts_group_basis_units() {
for n in 2..=5 {
let plan = SymmetricFft::new(n, 101).unwrap();
let mut values = vec![0; plan.input_len()];
let unit_index = (n + 1).min(plan.input_len() - 1);
values[unit_index] = 7;
let inverse = plan.invert(&values).unwrap();
let mut identity = vec![0; plan.input_len()];
let identity_images = Permutation::identity(n).images().to_vec();
let identity_index = plan
.permutations()
.iter()
.position(|permutation| permutation.images() == identity_images.as_slice())
.unwrap();
identity[identity_index] = 1;
assert_eq!(
plan.multiply(&values, &inverse).unwrap(),
identity,
"failed right inverse for S_{n}"
);
assert_eq!(
plan.multiply(&inverse, &values).unwrap(),
identity,
"failed left inverse for S_{n}"
);
}
}
// Zero has only zero blocks, so inversion fails.
#[test]
fn group_algebra_invert_rejects_zero_element() {
let plan = SymmetricFft::new(4, 101).unwrap();
let zero = vec![0; plan.input_len()];
assert_eq!(plan.invert(&zero), Err(FftError::NonInvertibleMatrix));
}
// The transform is multiplicative: fft(a*b) is the blockwise product fft(a)·fft(b).
#[test]
fn multiplication_transform_matches_block_products() {
for n in 1..=5 {
let plan = SymmetricFft::new(n, 101).unwrap();
let lhs: Vec<_> = (0..plan.input_len())
.map(|i| ((i * i + 2 * i + 3) % 101) as u64)
.collect();
let rhs: Vec<_> = (0..plan.input_len())
.map(|i| ((7 * i * i + 5 * i + 1) % 101) as u64)
.collect();
let product = plan.multiply(&lhs, &rhs).unwrap();
let product_transform = plan.fft(&product).unwrap();
let lhs_transform = plan.fft(&lhs).unwrap();
let rhs_transform = plan.fft(&rhs).unwrap();
for partition in plan.partitions() {
let expected = lhs_transform
.block(partition)
.unwrap()
.mul(rhs_transform.block(partition).unwrap())
.unwrap();
assert_eq!(
product_transform.block(partition).unwrap(),
&expected,
"failed block product for {partition}"
);
}
}
}
// Timing comparison for S_7; ignored by default because wall-clock results are noisy.
#[test]
#[ignore = "timing-dependent; run with `cargo test --release multiplication_is_faster_than_naive -- --ignored`"]
fn multiplication_is_faster_than_naive() {
let plan = SymmetricFft::new(7, 1_000_003).unwrap();
let lhs: Vec<_> = (0..plan.input_len())
.map(|i| ((3 * i * i + 7 * i + 11) as u64) % plan.field.modulus())
.collect();
let rhs: Vec<_> = (0..plan.input_len())
.map(|i| ((5 * i * i + 13 * i + 17) as u64) % plan.field.modulus())
.collect();
let start = Instant::now();
let fast = plan.multiply(&lhs, &rhs).unwrap();
let fast_elapsed = start.elapsed();
let start = Instant::now();
let naive = plan.naive_multiply(&lhs, &rhs).unwrap();
let naive_elapsed = start.elapsed();
assert_eq!(fast, naive);
assert!(
fast_elapsed < naive_elapsed,
"FFT multiplication took {fast_elapsed:?}, naive multiplication took {naive_elapsed:?}"
);
}
// Multiplying the adjacent transpositions of `adjacent_word` left to right rebuilds the permutation.
#[test]
fn permutation_words_match_composition_convention() {
for n in 1..=5 {
for permutation in all_permutations(n) {
let mut rebuilt = Permutation::identity(n);
for adjacent_index in permutation.adjacent_word() {
rebuilt = rebuilt.compose(&Permutation::adjacent(n, adjacent_index));
}
assert_eq!(rebuilt, permutation);
}
}
}
}