use std::alloc::Global;
use feanor_math::algorithms::unity_root::is_prim_root_of_unity;
use feanor_math::divisibility::DivisibilityRingStore;
use feanor_math::homomorphism::Homomorphism;
use feanor_math::integer::IntegerRingStore;
use feanor_math::primitive_int::StaticRing;
use feanor_math::ring::*;
use feanor_math::rings::extension::*;
use feanor_math::rings::zn::zn_64::*;
use feanor_math::rings::zn::*;
use feanor_math::assert_el_eq;
use feanor_math::seq::VectorView;
use tracing::instrument;
use crate::circuit::PlaintextCircuit;
use crate::cyclotomic::*;
use crate::lintransform::matmul::*;
use crate::lintransform::trace::trace_circuit;
use crate::lintransform::PowerTable;
use crate::number_ring::hypercube::isomorphism::*;
use crate::number_ring::quotient::NumberRingQuotientBase;
use crate::number_ring::*;
/// The ring of `i64` integers; used in this file for `abs_log2_ceil` when checking that sizes are powers of two.
const ZZ: StaticRing<i64> = StaticRing::<i64>::RING;
/// Builds a single butterfly stage along hypercube dimension `dim_index`, acting on
/// blocks of `l` consecutive indices of that dimension.
///
/// The result is a linear combination of three shifts along `dim_index`
/// (all other dimensions are left unshifted):
///
///  - shift `0`: coefficient `1` on the lower half of every block and `-zeta^e`
///    on the upper half,
///  - shift `l / 2`: coefficient `1` on the upper half, zero elsewhere,
///  - shift `-l / 2`: coefficient `zeta^e` on the lower half, zero elsewhere.
///
/// Here `zeta` is `root_of_unity_4l` (by its name expected to be a primitive `4l`-th
/// root of unity; this is not checked here) and the exponent `e` of a slot is the Galois
/// group element `g^k * row_autos(idxs)` converted to a ring element, where
/// `g = H.hypercube().map_1d(dim_index, -1)` and `k` is the position of the slot within
/// its half-block.
///
/// # Panics
///
/// Panics if the length `m` of dimension `dim_index` is not a power of two, if `g^l` is
/// not `1` modulo `4l`, if `l <= 1` or if `l` does not divide `m`.
fn pow2_bitreversed_dwt_butterfly<G, NumberRing>(H: &DefaultHypercube<NumberRing>, dim_index: usize, l: usize, root_of_unity_4l: El<SlotRingOver<Zn>>, row_autos: G) -> MatmulTransform<NumberRing>
    where G: Fn(&[usize]) -> CyclotomicGaloisGroupEl,
        NumberRing: HECyclotomicNumberRing + Clone
{
    let m = H.hypercube().m(dim_index);
    let log2_m = ZZ.abs_log2_ceil(&(m as i64)).unwrap();
    assert_eq!(m, 1 << log2_m);

    // the generator used along this dimension must have order dividing `l` modulo `4l`
    let g = H.hypercube().map_1d(dim_index, -1);
    let index_ring_mod_4l = Zn::new(4 * l as u64);
    let Gal = H.galois_group();
    let reduce_mod_4l = ZnReductionMap::new(Gal.underlying_ring(), &index_ring_mod_4l).unwrap();
    let g_pow_l = index_ring_mod_4l.pow(reduce_mod_4l.map(Gal.to_ring_el(g)), l);
    assert_el_eq!(&index_ring_mod_4l, &index_ring_mod_4l.one(), &g_pow_l);

    assert!(l > 1);
    assert!(m % l == 0);

    let zeta_powers = PowerTable::new(H.slot_ring(), root_of_unity_4l, 4 * l);

    /// Symbolic description of a slot value, resolved through the power table.
    enum Twiddle {
        Zero,
        PosPowerZeta(ZnEl),
        NegPowerZeta(ZnEl)
    }

    // looks up `zeta^exponent`, with the exponent given as element of the Galois group's underlying ring
    let zeta_to = |exponent: ZnEl| H.slot_ring().clone_el(&zeta_powers.get_power(Gal.underlying_ring().smallest_positive_lift(exponent)));
    let eval_twiddle = |factor: Twiddle| match factor {
        Twiddle::Zero => H.slot_ring().zero(),
        Twiddle::PosPowerZeta(exponent) => zeta_to(exponent),
        Twiddle::NegPowerZeta(exponent) => H.slot_ring().negate(zeta_to(exponent))
    };

    let half = l / 2;

    // coefficient of the shift by `l / 2`: one on the upper half of each block
    let forward_mask = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| {
        let pos = idxs[dim_index] % l;
        if pos < half {
            Twiddle::Zero
        } else {
            Twiddle::PosPowerZeta(Gal.underlying_ring().zero())
        }
    }).map(&eval_twiddle));

    // coefficient of the unshifted input: one on the lower half, negated twiddle on the upper half
    let diagonal_mask = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| {
        let pos = idxs[dim_index] % l;
        if pos < half {
            Twiddle::PosPowerZeta(Gal.underlying_ring().zero())
        } else {
            let exponent = Gal.mul(Gal.pow(g, (pos - half) as i64), row_autos(&idxs));
            Twiddle::NegPowerZeta(Gal.to_ring_el(exponent))
        }
    }).map(&eval_twiddle));

    // coefficient of the shift by `-l / 2`: twiddle on the lower half of each block
    let backward_mask = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| {
        let pos = idxs[dim_index] % l;
        if pos >= half {
            Twiddle::Zero
        } else {
            let exponent = Gal.mul(Gal.pow(g, pos as i64), row_autos(&idxs));
            Twiddle::PosPowerZeta(Gal.to_ring_el(exponent))
        }
    }).map(&eval_twiddle));

    // shift vector that moves by `amount` along `dim_index` and not at all along the other dimensions
    let dim_count = H.hypercube().dim_count();
    let shift_along_dim = |amount: i64| (0..dim_count).map(|i| if i == dim_index { amount } else { 0 }).collect::<Vec<_>>();
    let half_shift = l as i64 / 2;

    let terms = [
        (shift_along_dim(0), diagonal_mask),
        (shift_along_dim(half_shift), forward_mask),
        (shift_along_dim(-half_shift), backward_mask)
    ];
    let result = MatmulTransform::linear_combine_shifts(H, terms.iter().map(|(shift, coeff)| (shift.copy_els(), H.ring().clone_el(coeff))));
    result
}
/// Builds the inverse of the butterfly stage created by [`pow2_bitreversed_dwt_butterfly`]
/// for the same arguments (in debug builds, this relationship is checked explicitly).
///
/// The result is a linear combination of three shifts along `dim_index`, where every
/// coefficient is additionally multiplied by the inverse of `2` in the base ring:
///
///  - shift `0`: coefficient `1` on the lower half of every block of `l` indices and
///    `-zeta^e` on the upper half,
///  - shift `l / 2`: coefficient `zeta^e` on the upper half, zero elsewhere,
///  - shift `-l / 2`: coefficient `1` on the lower half, zero elsewhere.
///
/// Here `zeta` is `root_of_unity_4l` and the exponent `e` of a slot is the Galois group
/// element `negate(g^k) * row_autos(idxs)` converted to a ring element, with
/// `g = H.hypercube().map_1d(dim_index, -1)` and `k` the position of the slot within the
/// upper half-block.
///
/// # Panics
///
/// Panics if the length `m` of dimension `dim_index` is not a power of two, if `g^l` is
/// not `1` modulo `4l`, if `l <= 1`, if `l` does not divide `m`, or if `2` is not
/// invertible in the base ring.
fn pow2_bitreversed_inv_dwt_butterfly<G, NumberRing>(H: &DefaultHypercube<NumberRing>, dim_index: usize, l: usize, root_of_unity_4l: El<SlotRingOver<Zn>>, row_autos: G) -> MatmulTransform<NumberRing>
    where G: Fn(&[usize]) -> CyclotomicGaloisGroupEl,
        NumberRing: HECyclotomicNumberRing + Clone
{
    let m = H.hypercube().m(dim_index);
    let log2_m = ZZ.abs_log2_ceil(&(m as i64)).unwrap();
    assert_eq!(m, 1 << log2_m);

    // the generator used along this dimension must have order dividing `l` modulo `4l`
    let g = H.hypercube().map_1d(dim_index, -1);
    let index_ring_mod_4l = Zn::new(4 * l as u64);
    let Gal = H.galois_group();
    let reduce_mod_4l = ZnReductionMap::new(Gal.underlying_ring(), &index_ring_mod_4l).unwrap();
    let g_pow_l = index_ring_mod_4l.pow(reduce_mod_4l.map(Gal.to_ring_el(g)), l);
    assert_el_eq!(&index_ring_mod_4l, &index_ring_mod_4l.one(), &g_pow_l);

    assert!(l > 1);
    assert!(m % l == 0);

    let zeta_powers = PowerTable::new(H.slot_ring(), root_of_unity_4l, 4 * l);

    /// Symbolic description of a slot value, resolved through the power table.
    enum Twiddle {
        Zero,
        PosPowerZeta(ZnEl),
        NegPowerZeta(ZnEl)
    }

    // looks up `zeta^exponent`, with the exponent given as element of the Galois group's underlying ring
    let zeta_to = |exponent: ZnEl| H.slot_ring().clone_el(&zeta_powers.get_power(Gal.underlying_ring().smallest_positive_lift(exponent)));
    let eval_twiddle = |factor: Twiddle| match factor {
        Twiddle::Zero => H.slot_ring().zero(),
        Twiddle::PosPowerZeta(exponent) => zeta_to(exponent),
        Twiddle::NegPowerZeta(exponent) => H.slot_ring().negate(zeta_to(exponent))
    };

    // every coefficient of the inverse butterfly carries a factor 1/2
    let inv_2 = H.ring().base_ring().invert(&H.ring().base_ring().int_hom().map(2)).unwrap();
    let half = l / 2;

    // coefficient of the shift by `l / 2`: twiddle on the upper half of each block
    let mut forward_mask = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| {
        let pos = idxs[dim_index] % l;
        if pos < half {
            Twiddle::Zero
        } else {
            let exponent = Gal.mul(Gal.negate(Gal.pow(g, (pos - half) as i64)), row_autos(&idxs));
            Twiddle::PosPowerZeta(Gal.to_ring_el(exponent))
        }
    }).map(&eval_twiddle));
    H.ring().inclusion().mul_assign_ref_map(&mut forward_mask, &inv_2);

    // coefficient of the unshifted input: one on the lower half, negated twiddle on the upper half
    let mut diagonal_mask = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| {
        let pos = idxs[dim_index] % l;
        if pos < half {
            Twiddle::PosPowerZeta(Gal.underlying_ring().zero())
        } else {
            let exponent = Gal.mul(Gal.negate(Gal.pow(g, (pos - half) as i64)), row_autos(&idxs));
            Twiddle::NegPowerZeta(Gal.to_ring_el(exponent))
        }
    }).map(&eval_twiddle));
    H.ring().inclusion().mul_assign_ref_map(&mut diagonal_mask, &inv_2);

    // coefficient of the shift by `-l / 2`: one on the lower half of each block
    let mut backward_mask = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| {
        let pos = idxs[dim_index] % l;
        if pos >= half {
            Twiddle::Zero
        } else {
            Twiddle::PosPowerZeta(Gal.underlying_ring().zero())
        }
    }).map(&eval_twiddle));
    H.ring().inclusion().mul_assign_ref_map(&mut backward_mask, &inv_2);

    // shift vector that moves by `amount` along `dim_index` and not at all along the other dimensions
    let dim_count = H.hypercube().dim_count();
    let shift_along_dim = |amount: i64| (0..dim_count).map(|i| if i == dim_index { amount } else { 0 }).collect::<Vec<_>>();
    let half_shift = l as i64 / 2;

    let terms = [
        (shift_along_dim(0), diagonal_mask),
        (shift_along_dim(half_shift), forward_mask),
        (shift_along_dim(-half_shift), backward_mask)
    ];
    let result = MatmulTransform::linear_combine_shifts(H, terms.iter().map(|(shift, coeff)| (shift.copy_els(), H.ring().clone_el(coeff))));

    // `root_of_unity_4l` was moved into the power table, so it is read back as the first power
    #[cfg(debug_assertions)] {
        let expected = pow2_bitreversed_dwt_butterfly(H, dim_index, l, H.slot_ring().clone_el(&zeta_powers.get_power(1)), row_autos).inverse(&H);
        debug_assert!(result.eq(&expected, H));
    }
    result
}
/// Produces the sequence of butterfly stages along hypercube dimension `dim_index`,
/// one [`pow2_bitreversed_dwt_butterfly`] for each block length `l = 2, 4, ..., m`
/// (in this order), where `m` is the length of that dimension.
///
/// The stage for block length `l` uses the root of unity `zeta^(m / l)`, where `zeta`
/// is the canonical generator of the slot ring raised to the power `n / m / 4`.
/// `row_autos` is passed on unchanged to every stage.
///
/// # Panics
///
/// Panics if `m` is not a power of two or if `n / m` is not divisible by `4`.
fn pow2_bitreversed_dwt<G, NumberRing>(H: &DefaultHypercube<NumberRing>, dim_index: usize, row_autos: G) -> Vec<MatmulTransform<NumberRing>>
    where G: Fn(&[usize]) -> CyclotomicGaloisGroupEl,
        NumberRing: HECyclotomicNumberRing + Clone
{
    let m = H.hypercube().m(dim_index);
    let log2_m = ZZ.abs_log2_ceil(&(m as i64)).unwrap();
    assert_eq!(m, 1 << log2_m);
    assert!((H.ring().n() / m) % 4 == 0, "pow2_bitreversed_dwt() only possible if there is a 4m-th primitive root of unity");

    let zeta = H.slot_ring().pow(H.slot_ring().canonical_gen(), H.ring().n() as usize / m / 4);
    debug_assert!(is_prim_root_of_unity(H.slot_ring(), &H.slot_ring().canonical_gen(), H.ring().n() as usize));
    debug_assert!(is_prim_root_of_unity(H.slot_ring(), &zeta, 4 * m));

    // smallest blocks first
    let stages: Vec<_> = (1..=log2_m).map(|log2_l| {
        let l: usize = 1 << log2_l;
        let root_of_unity_4l = H.slot_ring().pow(H.slot_ring().clone_el(&zeta), m / l);
        pow2_bitreversed_dwt_butterfly(H, dim_index, l, root_of_unity_4l, &row_autos)
    }).collect();
    stages
}
/// Produces the sequence of inverse butterfly stages along hypercube dimension
/// `dim_index`, one [`pow2_bitreversed_inv_dwt_butterfly`] for each block length
/// `l = m, m / 2, ..., 2` (in this order), where `m` is the length of that dimension.
///
/// The stage for block length `l` uses the root of unity `zeta^(m / l)`, where `zeta`
/// is the canonical generator of the slot ring raised to the power `n / m / 4`.
/// `row_autos` is passed on unchanged to every stage.
///
/// # Panics
///
/// Panics if `m` is not a power of two or if `n / m` is not divisible by `4`.
#[instrument(skip_all)]
fn pow2_bitreversed_inv_dwt<G, NumberRing>(H: &DefaultHypercube<NumberRing>, dim_index: usize, row_autos: G) -> Vec<MatmulTransform<NumberRing>>
    where G: Fn(&[usize]) -> CyclotomicGaloisGroupEl,
        NumberRing: HECyclotomicNumberRing + Clone
{
    let m = H.hypercube().m(dim_index);
    let log2_m = ZZ.abs_log2_ceil(&(m as i64)).unwrap();
    assert_eq!(m, 1 << log2_m);
    assert!((H.ring().n() / m) % 4 == 0, "pow2_bitreversed_dwt() only possible if there is a 4m-th primitive root of unity");

    let zeta = H.slot_ring().pow(H.slot_ring().canonical_gen(), H.ring().n() as usize / m / 4);
    debug_assert!(is_prim_root_of_unity(H.slot_ring(), &H.slot_ring().canonical_gen(), H.ring().n() as usize));
    debug_assert!(is_prim_root_of_unity(H.slot_ring(), &zeta, 4 * m));

    // largest blocks first, i.e. the forward stages in reverse order
    let stages: Vec<_> = (1..=log2_m).rev().map(|log2_l| {
        let l: usize = 1 << log2_l;
        let root_of_unity_4l = H.slot_ring().pow(H.slot_ring().clone_el(&zeta), m / l);
        pow2_bitreversed_inv_dwt_butterfly(H, dim_index, l, root_of_unity_4l, &row_autos)
    }).collect();
    stages
}
/// Returns the list of linear transforms that, applied one after the other, make up the
/// "thin" slots-to-coefficients map for a power-of-two cyclotomic ring.
///
/// Two hypercube shapes are supported:
///
///  - one dimension: the result is exactly [`pow2_bitreversed_dwt`] along dimension `0`
///    with the identity as row automorphism;
///  - two dimensions, the second one of length `2`: a first transform combines the two
///    entries along dimension `1` using a fourth root of unity `zeta4` (the canonical
///    generator of the slot ring raised to `n / 4`), followed by the stages of
///    [`pow2_bitreversed_dwt`] along dimension `0`, where the row automorphism is the
///    identity for index `0` and its negation for index `1` of dimension `1`.
///
/// # Panics
///
/// Panics if `n` is not a power of two or if the hypercube has none of the shapes above.
#[instrument(skip_all)]
pub fn slots_to_coeffs_thin<NumberRing>(H: &DefaultHypercube<NumberRing>) -> Vec<MatmulTransform<NumberRing>>
    where NumberRing: HECyclotomicNumberRing + Clone
{
    let n = H.ring().get_ring().n();
    let log2_n = ZZ.abs_log2_ceil(&(n as i64)).unwrap();
    assert!(n == 1 << log2_n);

    if H.hypercube().dim_count() != 2 {
        assert_eq!(1, H.hypercube().dim_count());
        return pow2_bitreversed_dwt(H, 0, |_idxs| H.galois_group().identity());
    }

    assert_eq!(2, H.hypercube().m(1));
    let zeta4 = H.slot_ring().pow(H.slot_ring().canonical_gen(), H.ring().n() as usize / 4);
    let dim_count = H.hypercube().dim_count();

    // coefficient of the unshifted input
    let diagonal_coeff = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| match idxs[1] {
        0 => H.slot_ring().one(),
        _ => {
            debug_assert!(idxs[1] == 1);
            H.slot_ring().negate(H.slot_ring().clone_el(&zeta4))
        }
    }));
    // coefficient of the input shifted by one along dimension 1
    let shifted_coeff = H.from_slot_values(H.hypercube().hypercube_iter(|idxs| match idxs[1] {
        0 => H.slot_ring().clone_el(&zeta4),
        _ => {
            debug_assert!(idxs[1] == 1);
            H.slot_ring().one()
        }
    }));

    let terms = [
        ((0..dim_count).map(|_| 0).collect::<Vec<_>>(), diagonal_coeff),
        ((0..dim_count).map(|i| if i == 1 { 1 } else { 0 }).collect::<Vec<_>>(), shifted_coeff)
    ];
    let first_stage = MatmulTransform::linear_combine_shifts(H, terms.iter().map(|(shift, coeff)| (shift.copy_els(), H.ring().clone_el(coeff))));

    let mut result = vec![first_stage];
    result.extend(pow2_bitreversed_dwt(H, 0, |idxs| match idxs[1] {
        0 => H.galois_group().identity(),
        _ => {
            debug_assert!(idxs[1] == 1);
            H.galois_group().negate(H.galois_group().identity())
        }
    }));
    result
}
/// Returns the list of linear transforms that, applied one after the other, undo
/// [`slots_to_coeffs_thin`]; the unit tests in this file check that the `i`-th transform
/// returned here is the inverse of the `i`-th last transform returned there.
///
/// Two hypercube shapes are supported:
///
///  - one dimension: the result is exactly [`pow2_bitreversed_inv_dwt`] along dimension
///    `0` with the identity as row automorphism;
///  - two dimensions, the second one of length `2`: the stages of
///    [`pow2_bitreversed_inv_dwt`] along dimension `0` come first, followed by one
///    transform that combines the two entries along dimension `1`, using the canonical
///    generator of the slot ring raised to `3n / 4` and a factor `1/2`.
///
/// # Panics
///
/// Panics if `n` is not a power of two, if the hypercube has none of the shapes above,
/// or (in the two-dimensional case) if `2` is not invertible in the base ring.
#[instrument(skip_all)]
pub fn slots_to_coeffs_thin_inv<NumberRing>(H: &DefaultHypercube<NumberRing>) -> Vec<MatmulTransform<NumberRing>>
    where NumberRing: HECyclotomicNumberRing + Clone
{
    let n = H.ring().get_ring().n();
    let log2_n = ZZ.abs_log2_ceil(&(n as i64)).unwrap();
    assert!(n == 1 << log2_n);

    if H.hypercube().dim_count() != 2 {
        assert_eq!(1, H.hypercube().dim_count());
        return pow2_bitreversed_inv_dwt(H, 0, |_idxs| H.galois_group().identity());
    }

    assert_eq!(2, H.hypercube().m(1));
    let zeta4_inv = H.slot_ring().pow(H.slot_ring().canonical_gen(), 3 * H.ring().n() as usize / 4);
    let two_inv = H.ring().base_ring().invert(&H.slot_ring().base_ring().int_hom().map(2)).unwrap();

    let mut result = pow2_bitreversed_inv_dwt(H, 0, |idxs| match idxs[1] {
        0 => H.galois_group().identity(),
        _ => {
            debug_assert!(idxs[1] == 1);
            H.galois_group().negate(H.galois_group().identity())
        }
    });

    let dim_count = H.hypercube().dim_count();

    // coefficient of the unshifted input, scaled by 1/2
    let diagonal_coeff = H.ring().inclusion().mul_map(
        H.from_slot_values(H.hypercube().hypercube_iter(|idxs| match idxs[1] {
            0 => H.slot_ring().one(),
            _ => {
                debug_assert!(idxs[1] == 1);
                H.slot_ring().negate(H.slot_ring().clone_el(&zeta4_inv))
            }
        })),
        H.ring().base_ring().clone_el(&two_inv)
    );
    // coefficient of the input shifted by one along dimension 1, scaled by 1/2
    let shifted_coeff = H.ring().inclusion().mul_map(
        H.from_slot_values(H.hypercube().hypercube_iter(|idxs| match idxs[1] {
            0 => H.slot_ring().one(),
            _ => {
                debug_assert!(idxs[1] == 1);
                H.slot_ring().clone_el(&zeta4_inv)
            }
        })),
        two_inv
    );

    let terms = [
        ((0..dim_count).map(|_| 0).collect::<Vec<_>>(), diagonal_coeff),
        ((0..dim_count).map(|i| if i == 1 { 1 } else { 0 }).collect::<Vec<_>>(), shifted_coeff)
    ];
    result.push(MatmulTransform::linear_combine_shifts(H, terms.iter().map(|(shift, coeff)| (shift.copy_els(), H.ring().clone_el(coeff)))));
    result
}
/// Builds the plaintext circuit for the "thin" coefficients-to-slots map.
///
/// The circuit consists of the transforms of [`slots_to_coeffs_thin_inv`], where the last
/// transform is additionally composed with the slot-wise multiplication by the inverse of
/// the slot ring rank, composed with the trace circuit built by `trace_circuit` from
/// `H.hypercube().frobenius(1)` and the slot ring rank.
///
/// # Panics
///
/// Panics if `H.hypercube().n()` is not a power of two, if the slot ring rank is not
/// invertible in the base ring, or if [`slots_to_coeffs_thin_inv`] panics.
#[instrument(skip_all)]
pub fn coeffs_to_slots_thin<NumberRing>(H: &DefaultHypercube<NumberRing>) -> PlaintextCircuit<NumberRingQuotientBase<NumberRing, Zn, Global>>
    where NumberRing: HECyclotomicNumberRing + Clone
{
    let log2_n = ZZ.abs_log2_ceil(&(H.hypercube().n() as i64)).unwrap();
    assert_eq!(H.hypercube().n(), 1 << log2_n);

    let mut result = slots_to_coeffs_thin_inv(H);

    // fold the factor 1/rank into the last transform instead of adding a separate one
    let rank_inv = H.slot_ring().base_ring().invert(&H.slot_ring().base_ring().int_hom().map(H.slot_ring().rank() as i32)).unwrap();
    let last = MatmulTransform::mult_scalar_slots(H, &H.slot_ring().inclusion().map(rank_inv));
    let composed_last = result.last().unwrap().compose(&last, H);
    *result.last_mut().unwrap() = composed_last;

    let trace = trace_circuit(H.ring().get_ring(), &H.galois_group(), H.hypercube().frobenius(1), H.slot_ring().rank());
    let result_circuit = MatmulTransform::to_circuit_many(result, H);
    let full_circuit = trace.compose(result_circuit, H.ring());
    full_circuit
}
#[cfg(test)]
use crate::ring_literal;
#[cfg(test)]
use crate::number_ring::pow2_cyclotomic::Pow2CyclotomicNumberRing;
#[cfg(test)]
use feanor_math::algorithms::fft::cooley_tuckey::bitreverse;
#[cfg(test)]
use crate::number_ring::hypercube::structure::HypercubeStructure;
/// Applies the transforms of [`slots_to_coeffs_thin`] in order to an element given by its
/// slot values and checks that the values end up in the expected (bit-reversed)
/// coefficient positions, for the moduli 97 and 23 in the 64-th cyclotomic ring.
#[test]
fn test_slots_to_coeffs_thin() {
    // modulus 97, slot values 1, ..., 16
    let ring = NumberRingQuotientBase::new(Pow2CyclotomicNumberRing::new(64), Zn::new(97));
    let hypercube = HypercubeStructure::halevi_shoup_hypercube(CyclotomicGaloisGroup::new(64), 97);
    let H = HypercubeIsomorphism::new::<false>(&ring, hypercube);

    let mut current = H.from_slot_values((1..17).map(|n| H.slot_ring().int_hom().map(n)));
    for T in slots_to_coeffs_thin(&H) {
        current = ring.get_ring().compute_linear_transform(&H, &current, &T);
    }

    let mut expected = [0; 32];
    for i in 0..8 {
        for j in 0..2 {
            expected[bitreverse(i, 3) * 2 + j * 16] = (i * 2 + j + 1) as i32;
        }
    }
    assert_el_eq!(&ring, &ring_literal(&ring, &expected), &current);

    // modulus 23, slot values 1, ..., 4
    let ring = NumberRingQuotientBase::new(Pow2CyclotomicNumberRing::new(64), Zn::new(23));
    let hypercube = HypercubeStructure::halevi_shoup_hypercube(CyclotomicGaloisGroup::new(64), 23);
    let H = HypercubeIsomorphism::new::<false>(&ring, hypercube);

    let mut current = H.from_slot_values([1, 2, 3, 4].into_iter().map(|n| H.slot_ring().int_hom().map(n)));
    for T in slots_to_coeffs_thin(&H) {
        current = ring.get_ring().compute_linear_transform(&H, &current, &T);
    }

    let mut expected = [0; 32];
    for i in 0..4 {
        expected[bitreverse(i, 2) * 4] = (i + 1) as i32;
    }
    assert_el_eq!(&ring, &ring_literal(&ring, &expected), &current);
}
/// Checks, for the moduli 23 and 97 in the 64-th cyclotomic ring, that the transforms of
/// [`slots_to_coeffs_thin_inv`] are the inverses of the transforms of
/// [`slots_to_coeffs_thin`], taken in reverse order.
#[test]
fn test_slots_to_coeffs_thin_inv() {
    // the same prime is needed once as modulus of the base ring and once as parameter
    // of the hypercube; these may be different integer types, hence the pairs
    for (modulus, p) in [(23, 23), (97, 97)] {
        let ring = NumberRingQuotientBase::new(Pow2CyclotomicNumberRing::new(64), Zn::new(modulus));
        let hypercube = HypercubeStructure::halevi_shoup_hypercube(CyclotomicGaloisGroup::new(64), p);
        let H = HypercubeIsomorphism::new::<false>(&ring, hypercube);

        let forward = slots_to_coeffs_thin(&H);
        let backward = slots_to_coeffs_thin_inv(&H);
        for (transform, actual) in forward.into_iter().rev().zip(backward) {
            let expected = transform.inverse(&H);
            assert!(expected.eq(&actual, &H));
        }
    }
}
/// Evaluates the circuit of [`coeffs_to_slots_thin`] on elements given by their
/// coefficients and checks that the result has the expected slot values, for the moduli
/// 97 and 23 in the 64-th cyclotomic ring.
#[test]
fn test_coeffs_to_slots_thin() {
    // modulus 97; the coefficients at odd positions (values 17, ..., 32) must not show up in the result
    let ring = NumberRingQuotientBase::new(Pow2CyclotomicNumberRing::new(64), Zn::new(97));
    let hypercube = HypercubeStructure::halevi_shoup_hypercube(CyclotomicGaloisGroup::new(64), 97);
    let H = HypercubeIsomorphism::new::<false>(&ring, hypercube);

    let mut input = [0; 32];
    for i in 0..8 {
        for j in 0..2 {
            input[bitreverse(i, 3) * 2 + j * 16] = (i * 2 + j + 1) as i32;
            input[bitreverse(i, 3) * 2 + j * 16 + 1] = (i * 2 + j + 1 + 16) as i32;
        }
    }
    let current = ring_literal(&ring, &input);
    let circuit = coeffs_to_slots_thin(&H);
    let actual = circuit.evaluate(std::slice::from_ref(&current), ring.identity()).pop().unwrap();
    let expected = H.from_slot_values((1..17).map(|n| H.slot_ring().int_hom().map(n)));
    assert_el_eq!(&ring, &expected, &actual);

    // modulus 23; only the coefficient at position 4 is expected to contribute (to the third slot)
    let ring = NumberRingQuotientBase::new(Pow2CyclotomicNumberRing::new(64), Zn::new(23));
    let hypercube = HypercubeStructure::halevi_shoup_hypercube(CyclotomicGaloisGroup::new(64), 23);
    let H = HypercubeIsomorphism::new::<false>(&ring, hypercube);

    let mut input = [0; 32];
    input[4] = 1;
    input[16] = 1;
    let current = ring_literal(&ring, &input);
    let circuit = coeffs_to_slots_thin(&H);
    let actual = circuit.evaluate(std::slice::from_ref(&current), ring.identity()).pop().unwrap();
    let expected = H.from_slot_values([0, 0, 1, 0].into_iter().map(|n| H.slot_ring().int_hom().map(n)));
    assert_el_eq!(&ring, &expected, &actual);

    // modulus 23; all coefficients are set, but only those at positions 0, 4, 8, 12
    // (values 1, ..., 4) are expected to end up in the slots
    let mut input = [0; 32];
    for i in 0..4 {
        input[bitreverse(i, 2) * 4] = (i + 1) as i32;
        for k in 1..4 {
            input[bitreverse(i, 2) * 4 + k] = (i + 1 + 4 * k) as i32;
        }
        for k in 0..4 {
            input[bitreverse(i, 2) * 4 + k + 16] = (i + 1 + 4 * k + 16) as i32;
        }
    }
    let current = ring_literal(&ring, &input);
    let circuit = coeffs_to_slots_thin(&H);
    let actual = circuit.evaluate(std::slice::from_ref(&current), ring.identity()).pop().unwrap();
    let expected = H.from_slot_values([1, 2, 3, 4].into_iter().map(|n| H.slot_ring().int_hom().map(n)));
    assert_el_eq!(&ring, &expected, &actual);
}