use std::marker::PhantomData;
/// A permutation of lane indices, encoded entirely at the type level.
///
/// Implementors provide the mapping as associated functions, so no value needs to be
/// consulted to evaluate it.
pub trait Permutation: Copy + Clone {
    /// Destination index for index `i`.
    fn forward(i: u32) -> u32;

    /// Source index for index `i`; expected to undo [`Permutation::forward`].
    fn inverse(i: u32) -> u32;

    /// Returns `true` when `forward` and `inverse` agree on every lane in `0..32`,
    /// i.e. the permutation is its own inverse on those lanes.
    ///
    /// The default is a brute-force scan; implementors may override it with a closed form.
    fn is_self_dual() -> bool {
        // No lane may disagree (De Morgan form of "all lanes agree").
        !(0..32u32).any(|lane| Self::forward(lane) != Self::inverse(lane))
    }
}
/// A permutation that names its dual permutation at the type level.
///
/// The impls visible alongside this trait all pair a permutation with its inverse
/// (e.g. a downward rotation with the matching upward rotation).
pub trait HasDual
where
    Self: Permutation,
{
    /// The dual permutation; presumably the inverse of `Self`.
    type Dual: Permutation;
}
/// Permutes lanes by XOR-ing the index with `MASK`, keeping only the low five bits
/// of the result.
///
/// XOR with a fixed mask is an involution, so this permutation is its own inverse
/// and its own dual.
#[derive(Copy, Clone, Debug)]
pub struct Xor<const MASK: u32>;

impl<const MASK: u32> Permutation for Xor<MASK> {
    fn forward(i: u32) -> u32 {
        let flipped = i ^ MASK;
        flipped & 0x1F
    }

    fn inverse(i: u32) -> u32 {
        // Applying the same mask a second time restores the original index.
        Self::forward(i)
    }

    fn is_self_dual() -> bool {
        true
    }
}

impl<const MASK: u32> HasDual for Xor<MASK> {
    type Dual = Self;
}
/// Rotates lane indices down by `DELTA` modulo 32: index `i` maps to `(i - DELTA) mod 32`.
///
/// Only the low five bits of `DELTA` are significant, so `RotateDown<32>` acts like
/// `RotateDown<0>` and `RotateDown<48>` like `RotateDown<16>`. Its dual is
/// `RotateUp<DELTA>`.
#[derive(Copy, Clone, Debug)]
pub struct RotateDown<const DELTA: u32>;

/// Rotates lane indices up by `DELTA` modulo 32: index `i` maps to `(i + DELTA) mod 32`.
///
/// Only the low five bits of `DELTA` are significant. Its dual is `RotateDown<DELTA>`.
#[derive(Copy, Clone, Debug)]
pub struct RotateUp<const DELTA: u32>;

// Both rotations use wrapping arithmetic, so no input index can trigger an overflow panic
// in debug builds. Because 2^32 is a multiple of 32, wrapping never changes the low five
// bits, so the masked result is exactly the modular rotation.

impl<const DELTA: u32> Permutation for RotateDown<DELTA> {
    fn forward(i: u32) -> u32 {
        i.wrapping_sub(DELTA & 0x1F) & 0x1F
    }

    fn inverse(i: u32) -> u32 {
        i.wrapping_add(DELTA & 0x1F) & 0x1F
    }

    fn is_self_dual() -> bool {
        // A rotation by d equals its inverse (rotation by -d) iff 2*d % 32 == 0.
        // Mask first so this agrees with `forward`/`inverse`, which also mask DELTA.
        matches!(DELTA & 0x1F, 0 | 16)
    }
}

impl<const DELTA: u32> Permutation for RotateUp<DELTA> {
    fn forward(i: u32) -> u32 {
        i.wrapping_add(DELTA & 0x1F) & 0x1F
    }

    fn inverse(i: u32) -> u32 {
        i.wrapping_sub(DELTA & 0x1F) & 0x1F
    }

    fn is_self_dual() -> bool {
        // Same closed form as `RotateDown`; mask DELTA for consistency with forward/inverse.
        matches!(DELTA & 0x1F, 0 | 16)
    }
}

impl<const DELTA: u32> HasDual for RotateDown<DELTA> {
    type Dual = RotateUp<DELTA>;
}

impl<const DELTA: u32> HasDual for RotateUp<DELTA> {
    type Dual = RotateDown<DELTA>;
}
/// The identity permutation: every index maps to itself, with no masking applied.
#[derive(Copy, Clone, Debug)]
pub struct Identity;

impl Permutation for Identity {
    fn forward(lane: u32) -> u32 {
        lane
    }

    fn inverse(lane: u32) -> u32 {
        // The identity is trivially its own inverse.
        Self::forward(lane)
    }

    fn is_self_dual() -> bool {
        true
    }
}

impl HasDual for Identity {
    type Dual = Self;
}
#[derive(Copy, Clone, Debug)]
pub struct Compose<P1: Permutation, P2: Permutation>(PhantomData<(P1, P2)>);
impl<P1: Permutation, P2: Permutation> Permutation for Compose<P1, P2> {
fn forward(i: u32) -> u32 {
P2::forward(P1::forward(i))
}
fn inverse(i: u32) -> u32 {
P1::inverse(P2::inverse(i))
}
}
/// Butterfly stage 0: swaps lanes whose indices differ only in bit 0.
pub type ButterflyStage0 = Xor<1>;
/// Butterfly stage 1: swaps lanes whose indices differ only in bit 1.
pub type ButterflyStage1 = Xor<2>;
/// Butterfly stage 2: swaps lanes whose indices differ only in bit 2.
pub type ButterflyStage2 = Xor<4>;
/// Butterfly stage 3: swaps lanes whose indices differ only in bit 3.
pub type ButterflyStage3 = Xor<8>;
/// Butterfly stage 4: swaps lanes whose indices differ only in bit 4.
pub type ButterflyStage4 = Xor<16>;

/// All five butterfly stages applied in ascending order (stage 0 first).
///
/// On lanes `0..32` the net effect is `i ^ 31`.
pub type FullButterfly = Compose<
    Compose<
        Compose<Compose<ButterflyStage0, ButterflyStage1>, ButterflyStage2>,
        ButterflyStage3,
    >,
    ButterflyStage4,
>;

/// All five butterfly stages applied in descending order (stage 4 first).
///
/// XOR masks commute, so on lanes `0..32` this yields the same mapping as
/// `FullButterfly` (`i ^ 31`), which is also its own inverse.
pub type InverseButterfly = Compose<
    Compose<
        Compose<Compose<ButterflyStage4, ButterflyStage3>, ButterflyStage2>,
        ButterflyStage1,
    >,
    ButterflyStage0,
>;
/// A payload tagged at the type level with the permutation `P` associated with it.
#[derive(Copy, Clone, Debug)]
pub struct Shuffled<T, P: Permutation> {
    /// The wrapped payload.
    pub data: T,
    // Zero-sized tag recording `P` without storing a value of it.
    _perm: PhantomData<P>,
}

impl<T, P: Permutation> Shuffled<T, P> {
    /// Wraps `data` with the permutation tag `P`; the data itself is not reordered.
    pub fn new(data: T) -> Self {
        let _perm = PhantomData;
        Self { data, _perm }
    }
}
/// Reorders 32 values according to the permutation `P` and tags the result with `P`.
///
/// Output slot `i` is filled from `values[P::inverse(i)]`. Assuming `P::inverse` undoes
/// `P::forward`, the value at index `j` therefore lands at `P::forward(j)`. The `_perm`
/// argument only selects `P`; its value is never read.
///
/// # Panics
///
/// Panics if `P::inverse` returns an index `>= 32` for some slot in `0..32`.
pub fn shuffle<T: Copy, P: Permutation>(values: [T; 32], _perm: P) -> Shuffled<[T; 32], P> {
    let permuted: [T; 32] = core::array::from_fn(|slot| {
        let source = P::inverse(slot as u32) as usize;
        values[source]
    });
    Shuffled::new(permuted)
}
/// Reverses a [`shuffle`] by `P`, reading each slot `i` from `data[P::Dual::inverse(i)]`.
///
/// When `P::Dual` is the inverse of `P`, this restores the original order. The result is
/// tagged with `P::Dual`, not left untagged.
///
/// # Panics
///
/// Panics if `P::Dual::inverse` returns an index `>= 32` for some slot in `0..32`.
pub fn unshuffle<T: Copy, P: HasDual>(shuffled: Shuffled<[T; 32], P>) -> Shuffled<[T; 32], P::Dual>
where
    P::Dual: Permutation,
{
    let source = shuffled.data;
    let restored: [T; 32] = core::array::from_fn(|slot| {
        source[<P::Dual as Permutation>::inverse(slot as u32) as usize]
    });
    Shuffled::new(restored)
}
/// Reorders 32 values by a permutation whose dual is itself, returning the plain array.
///
/// Output slot `i` is filled from `values[P::inverse(i)]`. The `HasDual<Dual = P>` bound
/// restricts callers to permutations declared as their own dual.
///
/// # Panics
///
/// Panics if `P::inverse` returns an index `>= 32` for some slot in `0..32`.
pub fn shuffle_involution<T: Copy, P: Permutation + HasDual<Dual = P>>(
    values: [T; 32],
    _perm: P,
) -> [T; 32] {
    core::array::from_fn(|slot| values[P::inverse(slot as u32) as usize])
}
/// One lane's view of a permutation: whom it sends to and whom it receives from.
#[derive(Copy, Clone, Debug)]
pub struct LaneView {
    /// The lane this view describes.
    pub my_lane: u32,
    /// `P::forward(my_lane)`.
    pub i_send_to: u32,
    /// `P::inverse(my_lane)`.
    pub i_receive_from: u32,
}

impl LaneView {
    /// Builds the view of `lane` under the permutation `P`.
    pub fn for_lane<P: Permutation>(lane: u32) -> Self {
        let i_send_to = P::forward(lane);
        let i_receive_from = P::inverse(lane);
        Self {
            my_lane: lane,
            i_send_to,
            i_receive_from,
        }
    }

    /// Returns `true` when this lane sends to the same lane it receives from.
    pub fn is_symmetric(&self) -> bool {
        let Self {
            i_send_to,
            i_receive_from,
            ..
        } = *self;
        i_send_to == i_receive_from
    }
}
/// Returns `true` when every lane in `0..32` has a symmetric [`LaneView`] under `P`.
///
/// This evaluates `forward`/`inverse` directly and does not consult `P::is_self_dual`,
/// which an implementor may override.
pub fn all_symmetric<P: Permutation>() -> bool {
    for lane in 0..32u32 {
        if !LaneView::for_lane::<P>(lane).is_symmetric() {
            return false;
        }
    }
    true
}
/// Per-lane projections of permutation-based exchanges.
pub mod session_view {
    use super::*;

    /// Zero-sized marker pairing a permutation `P` with a payload type `T`.
    #[derive(Copy, Clone, Debug)]
    pub struct Exchange<P: Permutation, T> {
        // Carries both type parameters without storing any data.
        _marker: PhantomData<(P, T)>,
    }

    /// Projects the permutation `P` onto a single lane; delegates to [`LaneView::for_lane`].
    pub fn project_to_lane<P: Permutation>(lane: u32) -> LaneView {
        let view = LaneView::for_lane::<P>(lane);
        view
    }

    /// Always returns `true`.
    ///
    /// NOTE(review): `P` is never inspected, so this is a placeholder rather than a real
    /// isomorphism check — confirm the intended semantics before relying on it.
    pub fn projections_isomorphic<P: Permutation>() -> bool {
        true
    }
}
/// Group operations on the parameters of the XOR and rotation permutations.
pub mod algebra {
    /// Combines two XOR masks by bitwise XOR.
    ///
    /// Applying mask `a` and then mask `b` to an index equals applying `a ^ b` once.
    pub fn xor_group_composition(a: u32, b: u32) -> u32 {
        a ^ b
    }

    /// Combines two rotation offsets, returning their sum modulo 32.
    ///
    /// Uses `wrapping_add` so that large inputs cannot overflow-panic in debug builds.
    /// Because 2^32 is a multiple of 32, wrapping leaves the low five bits, and thus the
    /// result, unchanged.
    pub fn rotate_group_composition(a: u32, b: u32) -> u32 {
        a.wrapping_add(b) & 0x1F
    }

    /// Number of distinct 5-bit XOR masks (2^5).
    pub const XOR_SUBGROUP_SIZE: u64 = 32;

    /// Number of distinct rotation offsets modulo 32.
    pub const ROTATE_SUBGROUP_SIZE: u64 = 32;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Applying any 5-bit XOR mask twice returns every lane to its start.
    #[test]
    fn test_xor_self_dual() {
        for mask in 0..32u32 {
            (0..32u32).for_each(|lane| {
                let once = (lane ^ mask) & 0x1F;
                let twice = (once ^ mask) & 0x1F;
                assert_eq!(twice, lane, "XOR<{}> is not involution", mask);
            });
        }
    }

    /// The `Xor` type reports itself as self-dual for a range of masks.
    #[test]
    fn test_xor_is_self_dual_type() {
        assert!(Xor::<5>::is_self_dual());
        assert!(Xor::<0>::is_self_dual());
        assert!(Xor::<31>::is_self_dual());
    }

    /// Rotating down and then up by the same amount is the identity.
    #[test]
    fn test_rotate_duality() {
        for lane in 0..32u32 {
            let down = RotateDown::<1>::forward(lane);
            assert_eq!(RotateUp::<1>::forward(down), lane);
        }
    }

    /// `unshuffle` undoes `shuffle` for an XOR permutation.
    #[test]
    fn test_shuffle_unshuffle_roundtrip() {
        let original: [i32; 32] = core::array::from_fn(|i| i as i32);
        let roundtrip = unshuffle(shuffle(original, Xor::<5>));
        assert_eq!(roundtrip.data, original);
    }

    /// XOR permutations give every lane a symmetric view.
    #[test]
    fn test_lane_view_symmetric_for_xor() {
        assert!(all_symmetric::<Xor<5>>());
        assert!(all_symmetric::<Xor<16>>());
        assert!(all_symmetric::<Xor<0>>());
    }

    /// A non-trivial rotation sends and receives on different lanes.
    #[test]
    fn test_lane_view_asymmetric_for_rotate() {
        let LaneView {
            i_send_to,
            i_receive_from,
            ..
        } = LaneView::for_lane::<RotateDown<5>>(0);
        assert_ne!(i_send_to, i_receive_from);
        assert_eq!(i_send_to, 27);
        assert_eq!(i_receive_from, 5);
    }

    /// XOR composition has identity 0, self-inverses, commutativity and associativity.
    #[test]
    fn test_xor_group_structure() {
        let compose = algebra::xor_group_composition;
        assert_eq!(compose(5, 3), 6);
        assert_eq!(compose(5, 0), 5);
        assert_eq!(compose(5, 5), 0);
        assert_eq!(compose(5, 3), compose(3, 5));
        assert_eq!(compose(compose(5, 3), 7), compose(5, compose(3, 7)));
    }

    /// Chaining all five butterfly stages maps each lane `i` to `i ^ 31`.
    #[test]
    fn test_butterfly_permutation() {
        let mut lane = 0u32;
        lane = Xor::<1>::forward(lane);
        lane = Xor::<2>::forward(lane);
        lane = Xor::<4>::forward(lane);
        lane = Xor::<8>::forward(lane);
        lane = Xor::<16>::forward(lane);
        assert_eq!(lane, 31, "Full butterfly maps 0 -> 31");
        for i in 0..32u32 {
            assert_eq!(FullButterfly::forward(i), i ^ 31);
        }
    }

    /// The reversed butterfly undoes the forward butterfly on every lane.
    #[test]
    fn test_butterfly_inverse() {
        for i in 0..32u32 {
            let forward = FullButterfly::forward(i);
            assert_eq!(InverseButterfly::forward(forward), i);
        }
    }
}