use std::marker::PhantomData;
/// Compile-time description of a subset of the 16 tiles.
///
/// The subset is carried entirely by the associated constants, so
/// implementors need no runtime state.
pub trait TileSet: 'static + Copy {
    /// Human-readable name of the set.
    const NAME: &'static str;
    /// Membership bitmask: bit `i` is set when tile `i` belongs to the set.
    const MASK: u16;
}
/// Marker trait stating that `Self` and `Other` are complementary tile sets.
///
/// The trait has no items and performs no check on the masks; each
/// implementation is a promise made by whoever writes it.
pub trait TileComplement<Other>: TileSet
where
    Other: TileSet,
{
}
/// Marker for the set that contains every tile (all 16 mask bits set).
#[derive(Clone, Copy, Debug)]
pub struct AllTiles;

impl TileSet for AllTiles {
    const NAME: &'static str = "AllTiles";
    const MASK: u16 = u16::MAX;
}
/// Marker for the empty set (no mask bits set).
#[derive(Clone, Copy, Debug)]
pub struct NoTiles;

impl TileSet for NoTiles {
    const NAME: &'static str = "NoTiles";
    const MASK: u16 = 0;
}
/// Marker for tiles 0 through 7 (the low eight mask bits).
#[derive(Clone, Copy, Debug)]
pub struct LowerHalf;

impl TileSet for LowerHalf {
    const NAME: &'static str = "LowerHalf";
    const MASK: u16 = 0b0000_0000_1111_1111;
}
/// Marker for tiles 8 through 15 (the high eight mask bits).
#[derive(Clone, Copy, Debug)]
pub struct UpperHalf;

impl TileSet for UpperHalf {
    const NAME: &'static str = "UpperHalf";
    const MASK: u16 = 0b1111_1111_0000_0000;
}
/// Marker for the even-numbered tiles (mask bits 0, 2, ..., 14).
#[derive(Clone, Copy, Debug)]
pub struct EvenTiles;

impl TileSet for EvenTiles {
    const NAME: &'static str = "EvenTiles";
    const MASK: u16 = 0b0101_0101_0101_0101;
}
/// Marker for the odd-numbered tiles (mask bits 1, 3, ..., 15).
#[derive(Clone, Copy, Debug)]
pub struct OddTiles;

impl TileSet for OddTiles {
    const NAME: &'static str = "OddTiles";
    const MASK: u16 = 0b1010_1010_1010_1010;
}
impl TileComplement<OddTiles> for EvenTiles {}
impl TileComplement<EvenTiles> for OddTiles {}
impl TileComplement<UpperHalf> for LowerHalf {}
impl TileComplement<LowerHalf> for UpperHalf {}
impl TileComplement<NoTiles> for AllTiles {}
impl TileComplement<AllTiles> for NoTiles {}
/// Zero-sized token typed by the tile set `S`.
///
/// The only field is a `PhantomData`, so a `TileGroup` occupies no memory;
/// all information lives in the type parameter. The type is deliberately
/// neither `Copy` nor `Clone`: the diverge and merge operations in this file
/// take tokens by value.
pub struct TileGroup<S: TileSet> {
    _set: PhantomData<S>,
}

impl<S: TileSet> TileGroup<S> {
    /// Creates a token for the tile set `S`.
    ///
    /// Usable in `const` contexts.
    pub const fn new() -> Self {
        TileGroup { _set: PhantomData }
    }

    /// Returns the membership bitmask of `S` (that is, `S::MASK`).
    pub const fn active_mask(&self) -> u16 {
        S::MASK
    }
}

/// `Default` mirrors [`TileGroup::new`], as expected for a type whose
/// constructor takes no arguments.
impl<S: TileSet> Default for TileGroup<S> {
    fn default() -> Self {
        Self::new()
    }
}
/// Collective operations over 16 lanes of `u64` data. They are defined only
/// on `TileGroup<AllTiles>`, so they cannot be called through a token for a
/// partial tile set.
impl TileGroup<AllTiles> {
    /// Rotates the lanes by one position: the value in lane `i` moves to
    /// lane `(i + 1) % 16`.
    pub fn ring_pass(&self, data: &[u64; 16]) -> [u64; 16] {
        // Seen from the receiving side: lane `j` takes the value of its
        // predecessor, lane `(j + 15) % 16`.
        std::array::from_fn(|lane| data[(lane + 15) % 16])
    }

    /// Exchanges lanes pairwise: output lane `i` receives input lane
    /// `i ^ stride`.
    ///
    /// # Panics
    ///
    /// Panics if `stride` is `0` or `>= 16`.
    pub fn butterfly(&self, data: &[u64; 16], stride: usize) -> [u64; 16] {
        assert!(stride > 0 && stride < 16, "stride must be 1..15");
        // With `stride < 16`, `lane ^ stride` stays within 0..16.
        std::array::from_fn(|lane| data[lane ^ stride])
    }

    /// Returns a copy of `data`; the `_src` argument is currently ignored.
    pub fn scatter(&self, _src: usize, data: &[u64; 16]) -> [u64; 16] {
        *data
    }

    /// Returns a copy of `data`; the `_dst` argument is currently ignored.
    pub fn gather(&self, data: &[u64; 16], _dst: usize) -> [u64; 16] {
        *data
    }

    /// Sums all 16 lanes with wrapping addition.
    ///
    /// The sum is built from four butterfly rounds (strides 8, 4, 2, 1); in
    /// each round every lane adds its partner's value. Lane 0 of the final
    /// state is returned.
    pub fn reduce_sum(&self, data: &[u64; 16]) -> u64 {
        let mut lanes = *data;
        for stride in [8usize, 4, 2, 1] {
            let partner = self.butterfly(&lanes, stride);
            for (acc, incoming) in lanes.iter_mut().zip(partner.iter()) {
                *acc = acc.wrapping_add(*incoming);
            }
        }
        lanes[0]
    }
}
/// Splitting operations: each consumes the all-tiles token and hands back
/// one token per partial tile set.
impl TileGroup<AllTiles> {
    /// Consumes `self` and returns tokens for the lower and upper halves,
    /// in that order.
    pub fn diverge_halves(self) -> (TileGroup<LowerHalf>, TileGroup<UpperHalf>) {
        let lower = TileGroup::<LowerHalf>::new();
        let upper = TileGroup::<UpperHalf>::new();
        (lower, upper)
    }

    /// Consumes `self` and returns tokens for the even and odd tiles,
    /// in that order.
    pub fn diverge_parity(self) -> (TileGroup<EvenTiles>, TileGroup<OddTiles>) {
        let evens = TileGroup::<EvenTiles>::new();
        let odds = TileGroup::<OddTiles>::new();
        (evens, odds)
    }
}
/// Consumes two tokens whose tile sets are declared complementary and
/// returns a token for all tiles.
///
/// The `A: TileComplement<B>` bound is the only check performed; the masks
/// themselves are not inspected here.
pub fn merge_tiles<A, B>(_a: TileGroup<A>, _b: TileGroup<B>) -> TileGroup<AllTiles>
where
    A: TileSet + TileComplement<B>,
    B: TileSet,
{
    // Both inputs are moved into this call and dropped on return; only the
    // fresh all-tiles token leaves.
    let merged: TileGroup<AllTiles> = TileGroup::new();
    merged
}
/// The sending side of one tile: up to one pending `u64` per channel.
pub struct CrossbarPort {
    /// `channels[c]` is `Some(value)` once `send(c, value)` has been called.
    channels: [Option<u64>; 16],
    /// Identifier of the tile that owns this port.
    tile_id: usize,
}

impl CrossbarPort {
    /// Creates a port for `tile_id` with every channel empty.
    ///
    /// `tile_id` is stored as given; no range check happens here.
    pub fn new(tile_id: usize) -> Self {
        Self {
            channels: [None; 16],
            tile_id,
        }
    }

    /// Places `value` on `channel`, replacing anything already there.
    ///
    /// # Panics
    ///
    /// Panics if `channel` is 16 or greater.
    pub fn send(&mut self, channel: usize, value: u64) {
        assert!(channel < 16);
        let slot = &mut self.channels[channel];
        *slot = Some(value);
    }

    /// Reports whether a value has been placed on `channel`.
    ///
    /// # Panics
    ///
    /// Panics if `channel` is 16 or greater (out-of-bounds index).
    pub fn is_sending(&self, channel: usize) -> bool {
        matches!(self.channels[channel], Some(_))
    }

    /// Returns the tile identifier given at construction.
    pub fn tile_id(&self) -> usize {
        self.tile_id
    }
}
/// A 16x16 grid of latched values connecting every source tile to every
/// destination tile.
pub struct Crossbar {
    /// `pipe[src][dst]` holds the last value tile `src` sent towards `dst`,
    /// or `None` if nothing has ever been sent on that path.
    pipe: [[Option<u64>; 16]; 16],
}

impl Crossbar {
    /// Creates a crossbar with every path empty.
    pub fn new() -> Self {
        Crossbar {
            pipe: [[None; 16]; 16],
        }
    }

    /// Latches every value currently placed on the given ports.
    ///
    /// Only channels that carry a value are written. Paths a port is not
    /// driving keep their previous contents, so a later [`Crossbar::recv`]
    /// can still observe data from an earlier call.
    ///
    /// # Panics
    ///
    /// Panics if a port reports a `tile_id` of 16 or greater.
    pub fn clock(&mut self, ports: &[CrossbarPort]) {
        for port in ports {
            let row = &mut self.pipe[port.tile_id()];
            for (slot, outgoing) in row.iter_mut().zip(port.channels.iter()) {
                if outgoing.is_some() {
                    *slot = *outgoing;
                }
            }
        }
    }

    /// Reads the latched value on the path from `src` to `dst`.
    ///
    /// Note the argument order: destination first, source second.
    ///
    /// # Panics
    ///
    /// Panics if `src` or `dst` is 16 or greater.
    pub fn recv(&self, dst: usize, src: usize) -> Option<u64> {
        self.pipe[src][dst]
    }

    /// Reads the path from `src` to `dst`, but only if tile `src` is a
    /// member of the tile set `S` named by `_proof`.
    ///
    /// Returns `None` when `src` is not in `S`, when nothing is latched on
    /// the path, or when `src` or `dst` is outside `0..16`. The range check
    /// happens before the mask test, so an oversized `src` can never
    /// overflow the `1 << src` shift (a panic in debug builds and a wrapped
    /// bit position in release builds).
    pub fn recv_checked<S: TileSet>(
        &self,
        _proof: &TileGroup<S>,
        dst: usize,
        src: usize,
    ) -> Option<u64> {
        if src >= 16 || S::MASK & (1u16 << src) == 0 {
            return None;
        }
        // `src` is known to be in range; `dst` is checked by `get`.
        self.pipe[src].get(dst).copied().flatten()
    }
}

/// `Default` mirrors [`Crossbar::new`], as expected for a type whose
/// constructor takes no arguments.
impl Default for Crossbar {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Token with every tile active; the starting point of most tests.
    fn full_group() -> TileGroup<AllTiles> {
        TileGroup::new()
    }

    /// Lane data in which lane `i` holds `i * scale`.
    fn ramp(scale: u64) -> [u64; 16] {
        core::array::from_fn(|lane| lane as u64 * scale)
    }

    #[test]
    fn tile_sets_are_correct() {
        let expected: [(u16, u16); 6] = [
            (AllTiles::MASK, 0xFFFF),
            (NoTiles::MASK, 0x0000),
            (LowerHalf::MASK, 0x00FF),
            (UpperHalf::MASK, 0xFF00),
            (EvenTiles::MASK, 0x5555),
            (OddTiles::MASK, 0xAAAA),
        ];
        for (actual, wanted) in expected.iter() {
            assert_eq!(actual, wanted);
        }
    }

    #[test]
    fn tile_complements_cover_all() {
        let pairs: [(u16, u16); 2] = [
            (LowerHalf::MASK, UpperHalf::MASK),
            (EvenTiles::MASK, OddTiles::MASK),
        ];
        for &(left, right) in pairs.iter() {
            // Together the pair covers every tile...
            assert_eq!(left | right, AllTiles::MASK);
            // ...and the two sides never overlap.
            assert_eq!(left & right, 0);
        }
    }

    #[test]
    fn ring_pass_rotates() {
        let input = ramp(1);
        let rotated = full_group().ring_pass(&input);
        for (lane, &value) in input.iter().enumerate() {
            assert_eq!(rotated[(lane + 1) % 16], value);
        }
    }

    #[test]
    fn butterfly_exchange_swaps_pairs() {
        let group = full_group();
        let input = ramp(10);

        // Stride 1 swaps neighbouring lanes.
        let near = group.butterfly(&input, 1);
        let near_pairs: [(usize, usize); 4] = [(0, 1), (1, 0), (2, 3), (3, 2)];
        for &(dst, src) in near_pairs.iter() {
            assert_eq!(near[dst], input[src]);
        }

        // Stride 8 swaps lanes across the two halves.
        let far = group.butterfly(&input, 8);
        let far_pairs: [(usize, usize); 4] = [(0, 8), (8, 0), (7, 15), (15, 7)];
        for &(dst, src) in far_pairs.iter() {
            assert_eq!(far[dst], input[src]);
        }
    }

    #[test]
    fn butterfly_reduction_sums_all() {
        let ones = [1u64; 16];
        assert_eq!(full_group().reduce_sum(&ones), 16);
    }

    #[test]
    fn butterfly_reduction_distinct_values() {
        // 1 + 2 + ... + 16 = 136.
        let input: [u64; 16] = core::array::from_fn(|lane| (lane + 1) as u64);
        assert_eq!(full_group().reduce_sum(&input), 136);
    }

    #[test]
    fn diverge_merge_ring_pass() {
        let input = ramp(1);
        let (lower, upper) = full_group().diverge_halves();
        assert_eq!(lower.active_mask(), 0x00FF);
        assert_eq!(upper.active_mask(), 0xFF00);

        // Collectives become available again once the halves are merged.
        let rejoined = merge_tiles(lower, upper);
        let rotated = rejoined.ring_pass(&input);
        assert_eq!(rotated[1], 0);
    }

    #[test]
    fn diverge_parity_merge() {
        let (evens, odds) = full_group().diverge_parity();
        assert_eq!(evens.active_mask(), 0x5555);
        assert_eq!(odds.active_mask(), 0xAAAA);

        let rejoined = merge_tiles(evens, odds);
        let input = [42u64; 16];
        assert_eq!(rejoined.reduce_sum(&input), 42 * 16);
    }

    #[test]
    fn stale_data_bug_demonstration() {
        let mut xbar = Crossbar::new();
        let mut sender = CrossbarPort::new(0);
        sender.send(1, 0xDEAD);
        xbar.clock(&[sender]);
        assert_eq!(xbar.recv(1, 0), Some(0xDEAD));

        // A fresh, idle port for tile 0 sends nothing, yet the unchecked
        // read still returns the previously latched value.
        let idle = CrossbarPort::new(0);
        assert!(!idle.is_sending(1));
        let stale = xbar.recv(1, 0);
        assert_eq!(stale, Some(0xDEAD));
    }

    #[test]
    fn checked_recv_prevents_stale_read() {
        let mut xbar = Crossbar::new();
        let mut sender = CrossbarPort::new(0);
        sender.send(1, 0xBEEF);
        xbar.clock(&[sender]);

        // Tile 0 is in `AllTiles`, so the read goes through.
        let everyone: TileGroup<AllTiles> = TileGroup::new();
        assert_eq!(xbar.recv_checked(&everyone, 1, 0), Some(0xBEEF));

        // Tile 0 is not in `UpperHalf`, so the same path reads as empty.
        let upper_only: TileGroup<UpperHalf> = TileGroup::new();
        assert_eq!(xbar.recv_checked(&upper_only, 1, 0), None);
    }

    #[test]
    fn vcpud_ring_send_recv_pattern() {
        let mut xbar = Crossbar::new();

        // Every tile sends `id * 100` to its successor on the ring.
        let mut ports: Vec<CrossbarPort> = Vec::with_capacity(16);
        for id in 0..16usize {
            let mut port = CrossbarPort::new(id);
            port.send((id + 1) % 16, id as u64 * 100);
            ports.push(port);
        }
        xbar.clock(&ports);

        // Every tile then reads from its predecessor.
        for dst in 0..16usize {
            let src = (dst + 16 - 1) % 16;
            assert_eq!(xbar.recv(dst, src), Some(src as u64 * 100));
        }
    }

    #[test]
    fn zero_overhead_tile_group() {
        use std::mem::size_of;
        assert_eq!(size_of::<TileGroup<AllTiles>>(), 0);
        assert_eq!(size_of::<TileGroup<LowerHalf>>(), 0);
        assert_eq!(size_of::<TileGroup<EvenTiles>>(), 0);
    }
}