use crate::cut_and_project::{TileCoord, TileType};
/// A single tile carrying a dense `f32` tensor derived from its source coordinates.
#[derive(Debug, Clone)]
pub struct TensorTile {
    /// 2-D position of the tile; `TensorTiling` compares these to detect adjacency.
    pub position: [f64; 2],
    /// Orientation value supplied at construction. Not read by any method here.
    /// NOTE(review): unit/convention (radians?) is not established in this file.
    pub orientation: f64,
    /// Kind of tile; selects the tensor shape in [`TensorTile::new`].
    pub tile_type: TileType,
    /// Five integer coordinates that seed [`TensorTile::fill_from_source`]
    /// (presumably the lattice point the tile was projected from — confirm).
    pub source_coords: [i32; 5],
    /// Tensor values in row-major order: element `(row, col)` lives at `row * cols + col`.
    pub tensor: Vec<f32>,
    /// `(rows, cols)` of [`TensorTile::tensor`].
    pub tensor_shape: (usize, usize),
}

impl TensorTile {
    /// Creates a tile whose tensor is zero-filled and sized by `tile_type`.
    ///
    /// Shapes: `Thick` → 5×5, `Thin` → 3×8, `Rejected` → 1×1.
    pub fn new(
        source_coords: [i32; 5],
        tile_type: TileType,
        orientation: f64,
        position: [f64; 2],
    ) -> Self {
        let tensor_shape = match tile_type {
            TileType::Thick => (5, 5),
            TileType::Thin => (3, 8),
            TileType::Rejected => (1, 1),
        };
        Self {
            position,
            orientation,
            tile_type,
            source_coords,
            tensor: vec![0.0f32; tensor_shape.0 * tensor_shape.1],
            tensor_shape,
        }
    }

    /// Overwrites every tensor element with a deterministic function of
    /// `source_coords` and the element's `(row, col)` position.
    ///
    /// Each element is the sum, in this order, of:
    /// - a constant offset derived from `source_coords[0]`,
    /// - a linear ramp down the rows scaled by `source_coords[1]`,
    /// - a sine across the columns with frequency `max(|source_coords[2]|, 1)`,
    /// - a pseudo-random value in `[0, 1]` hashed from `(row, col, |source_coords[3]|)`,
    /// - a sine down the rows (same frequency) phase-shifted by `source_coords[4] / 10`.
    ///
    /// # Panics
    /// Panics if `tensor` is shorter than `rows * cols` of `tensor_shape`
    /// (only possible when the public fields were modified inconsistently).
    pub fn fill_from_source(&mut self) {
        let (rows, cols) = self.tensor_shape;
        let src = self.source_coords;

        let abs0 = src[0].unsigned_abs() as f32;
        let a = abs0.max(1.0) / 100.0;
        let b = src[1] as f32;
        let c = (src[2].unsigned_abs() as f32).max(1.0);
        // `i32::unsigned_abs` already returns `u32`; no further cast needed.
        let seed = src[3].unsigned_abs();
        let e = src[4] as f32;

        let m = rows as f32;
        let n = cols as f32;

        // Loop-invariant terms, hoisted. Operation order is kept identical to the
        // per-element formula so the f32 results do not change (TAU == 2.0 * PI in f32).
        let mode_a = a * (abs0 + 1.0) / 10.0;
        let angular = std::f32::consts::TAU * c;
        let phase = e / 10.0;

        for i in 0..rows {
            let i_f = i as f32;
            // These two terms depend only on the row.
            let mode_b = b * i_f / m;
            let mode_e = (angular * i_f / m + phase).sin();
            for j in 0..cols {
                let mode_c = (angular * j as f32 / n).sin();
                let hash = Self::simple_hash(i as u32, j as u32, seed);
                let mode_d = hash as f32 / u32::MAX as f32;
                self.tensor[i * cols + j] = mode_a + mode_b + mode_c + mode_d + mode_e;
            }
        }
    }

    /// Mixes `(i, j, seed)` into a 32-bit value using xor-then-multiply steps with
    /// the 32-bit FNV offset basis and prime. Deterministic; not cryptographic.
    fn simple_hash(i: u32, j: u32, seed: u32) -> u32 {
        const FNV_OFFSET_BASIS: u32 = 2_166_136_261;
        const FNV_PRIME: u32 = 16_777_619;
        let words = [
            i.wrapping_mul(seed.wrapping_add(1)),
            j.wrapping_mul(seed.wrapping_add(7)),
            seed,
        ];
        words
            .iter()
            .fold(FNV_OFFSET_BASIS, |h, &w| (h ^ w).wrapping_mul(FNV_PRIME))
    }

    /// Sets every element strictly below `threshold` to `0.0`.
    ///
    /// Elements equal to or above `threshold` are untouched; NaN elements are
    /// also untouched because `NaN < threshold` is false.
    pub fn apply_threshold(&mut self, threshold: f32) {
        for v in self.tensor.iter_mut().filter(|v| **v < threshold) {
            *v = 0.0;
        }
    }

    /// Returns the element at `(row, col)` of the row-major tensor.
    ///
    /// # Panics
    /// Panics with `"index out of bounds"` if `row >= rows` or `col >= cols`.
    pub fn tensor_at(&self, row: usize, col: usize) -> f32 {
        let (rows, cols) = self.tensor_shape;
        assert!(row < rows && col < cols, "index out of bounds");
        self.tensor[row * cols + col]
    }

    /// Sum of absolute values of all tensor elements.
    pub fn l1_norm(&self) -> f32 {
        self.tensor.iter().map(|v| v.abs()).sum()
    }

    /// Euclidean norm (square root of the sum of squares) of all tensor elements.
    pub fn l2_norm(&self) -> f32 {
        self.tensor.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Number of elements implied by `tensor_shape` (`rows * cols`).
    pub fn tensor_len(&self) -> usize {
        self.tensor_shape.0 * self.tensor_shape.1
    }
}
/// A collection of [`TensorTile`]s plus the proximity graph between them.
#[derive(Debug, Clone)]
pub struct TensorTiling {
    /// The tiles, in the order they were supplied.
    pub tiles: Vec<TensorTile>,
    /// Adjacency edges `(i, j, angle)` with `i < j`, indexing into `tiles`.
    /// `angle` is `dy.atan2(dx)` (radians) of `tiles[i].position - tiles[j].position`.
    pub adjacency: Vec<(usize, usize, f64)>,
}

impl TensorTiling {
    /// Center-distance threshold used by [`TensorTiling::new`]: tiles strictly
    /// closer than this are considered adjacent.
    pub const DEFAULT_ADJACENCY_THRESHOLD: f64 = 2.0;

    /// Builds a tiling, connecting every pair of tiles whose positions are strictly
    /// closer than [`TensorTiling::DEFAULT_ADJACENCY_THRESHOLD`].
    pub fn new(tiles: Vec<TensorTile>) -> Self {
        let adjacency = Self::detect_adjacency(&tiles);
        Self { tiles, adjacency }
    }

    /// Like [`TensorTiling::new`], but with a caller-chosen distance threshold.
    ///
    /// A non-positive or NaN `threshold` produces no adjacency edges.
    pub fn with_adjacency_threshold(tiles: Vec<TensorTile>, threshold: f64) -> Self {
        let adjacency = Self::detect_adjacency_within(&tiles, threshold);
        Self { tiles, adjacency }
    }

    /// Adjacency using [`TensorTiling::DEFAULT_ADJACENCY_THRESHOLD`].
    fn detect_adjacency(tiles: &[TensorTile]) -> Vec<(usize, usize, f64)> {
        Self::detect_adjacency_within(tiles, Self::DEFAULT_ADJACENCY_THRESHOLD)
    }

    /// Checks every unordered pair of tiles (O(n²)) and records an edge when the
    /// Euclidean distance between their positions is strictly below `threshold`.
    fn detect_adjacency_within(tiles: &[TensorTile], threshold: f64) -> Vec<(usize, usize, f64)> {
        // A negative threshold would otherwise square into a positive radius.
        if threshold.is_nan() || threshold <= 0.0 {
            return Vec::new();
        }
        let threshold_sq = threshold * threshold;
        let mut edges = Vec::new();
        for (i, a) in tiles.iter().enumerate() {
            for (j, b) in tiles.iter().enumerate().skip(i + 1) {
                let dx = a.position[0] - b.position[0];
                let dy = a.position[1] - b.position[1];
                // Compare squared distances to avoid a sqrt per pair.
                if dx * dx + dy * dy < threshold_sq {
                    edges.push((i, j, dy.atan2(dx)));
                }
            }
        }
        edges
    }

    /// Applies `f` to every tile in order, mutating each in place.
    ///
    /// The bound is `FnMut` (every `Fn` closure also qualifies), so kernels may
    /// carry state such as counters or accumulators.
    pub fn apply_kernel<F>(&mut self, f: F)
    where
        F: FnMut(&mut TensorTile),
    {
        self.tiles.iter_mut().for_each(f);
    }

    /// Sums absolute border differences over every adjacency edge `(i, j, _)`.
    ///
    /// For each edge, tile `i`'s last row is compared with tile `j`'s first row over
    /// the columns both share. Tile `i`'s last column is compared with tile `j`'s
    /// first column over the rows both share. The stored edge angle is not used, so
    /// both comparisons are made for every edge regardless of relative placement.
    /// Edges involving a tile with a zero-sized dimension contribute nothing.
    pub fn constraint_check(&self) -> f32 {
        let mut total_mismatch = 0.0f32;
        for &(i, j, _angle) in &self.adjacency {
            let (a, b) = (&self.tiles[i], &self.tiles[j]);
            let (rows_a, cols_a) = a.tensor_shape;
            let (rows_b, cols_b) = b.tensor_shape;
            // An empty dimension has no border; skipping also avoids `0 - 1` underflow.
            if rows_a == 0 || cols_a == 0 || rows_b == 0 || cols_b == 0 {
                continue;
            }
            for col in 0..cols_a.min(cols_b) {
                total_mismatch += (a.tensor_at(rows_a - 1, col) - b.tensor_at(0, col)).abs();
            }
            for row in 0..rows_a.min(rows_b) {
                total_mismatch += (a.tensor_at(row, cols_a - 1) - b.tensor_at(row, 0)).abs();
            }
        }
        total_mismatch
    }
}
pub fn generate_tensor_tiling(
_lattice_points: &[[i32; 5]],
baseline_tiles: &[TileCoord],
) -> TensorTiling {
let mut tensor_tiles = Vec::with_capacity(baseline_tiles.len());
for tc in baseline_tiles {
let mut coords = [0i32; 5];
for (k, &v) in tc.source_coords.iter().enumerate().take(5) {
coords[k] = v;
}
let mut tt = TensorTile::new(
coords,
tc.tile_type,
0.0, [tc.x, tc.y],
);
tt.fill_from_source();
tensor_tiles.push(tt);
}
TensorTiling::new(tensor_tiles)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cut_and_project::{CutAndProjectCompiler, TileType};

    #[test]
    fn test_tensor_tile_creation() {
        let tile = TensorTile::new(
            [1, 2, 3, 4, 5],
            TileType::Thick,
            0.0,
            [1.0, 2.0],
        );
        assert_eq!(tile.tensor_shape, (5, 5));
        assert_eq!(tile.tensor.len(), 25);
        assert!(tile.tensor.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_tensor_tile_creation_thin() {
        let tile = TensorTile::new(
            [1, 2, 3, 4, 5],
            TileType::Thin,
            0.0,
            [0.0, 0.0],
        );
        assert_eq!(tile.tensor_shape, (3, 8));
        assert_eq!(tile.tensor.len(), 24);
    }

    #[test]
    fn test_tensor_tile_fill_modes() {
        let bases: [[i32; 5]; 5] = [
            [10, 0, 0, 0, 0],
            [0, 10, 0, 0, 0],
            [0, 0, 10, 0, 0],
            [0, 0, 0, 10, 0],
            [0, 0, 0, 0, 10],
        ];
        let filled: Vec<Vec<f32>> = bases
            .iter()
            .map(|&coords| {
                let mut tile =
                    TensorTile::new(coords, TileType::Thick, 0.0, [0.0, 0.0]);
                tile.fill_from_source();
                tile.tensor.clone()
            })
            .collect();
        for i in 0..filled.len() {
            for j in (i + 1)..filled.len() {
                assert_ne!(
                    filled[i], filled[j],
                    "Fill modes {} and {} produced identical tensors",
                    i, j
                );
            }
        }
    }

    #[test]
    fn test_threshold_filter() {
        let mut tile = TensorTile::new(
            [5, 5, 5, 5, 5],
            TileType::Thick,
            0.0,
            [0.0, 0.0],
        );
        tile.fill_from_source();
        let before = tile.tensor.clone();
        tile.apply_threshold(0.5);
        for (idx, &v) in tile.tensor.iter().enumerate() {
            if before[idx] < 0.5 {
                assert_eq!(v, 0.0, "value {} was below threshold but not zeroed", before[idx]);
            } else {
                assert_eq!(v, before[idx], "value {} was above threshold but changed", before[idx]);
            }
        }
    }

    /// `tensor_at(row, col)` must agree with row-major indexing into `tensor`.
    #[test]
    fn test_tensor_at_is_row_major() {
        let mut tile = TensorTile::new([1, 2, 3, 4, 5], TileType::Thin, 0.0, [0.0, 0.0]);
        tile.fill_from_source();
        let (rows, cols) = tile.tensor_shape;
        for row in 0..rows {
            for col in 0..cols {
                assert_eq!(tile.tensor_at(row, col), tile.tensor[row * cols + col]);
            }
        }
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_tensor_at_out_of_bounds_panics() {
        let tile = TensorTile::new([0; 5], TileType::Thin, 0.0, [0.0, 0.0]);
        let _ = tile.tensor_at(3, 0);
    }

    #[test]
    fn test_tiling_adjacency() {
        let tiles = vec![
            TensorTile::new([1, 0, 0, 0, 0], TileType::Thick, 0.0, [0.0, 0.0]),
            TensorTile::new([0, 1, 0, 0, 0], TileType::Thick, 0.0, [0.5, 0.5]),
        ];
        let tiling = TensorTiling::new(tiles);
        assert!(
            !tiling.adjacency.is_empty(),
            "Two close tiles should detect adjacency"
        );
    }

    #[test]
    fn test_tiling_no_adjacency_far_tiles() {
        let tiles = vec![
            TensorTile::new([1, 0, 0, 0, 0], TileType::Thick, 0.0, [0.0, 0.0]),
            TensorTile::new([0, 1, 0, 0, 0], TileType::Thick, 0.0, [100.0, 100.0]),
        ];
        let tiling = TensorTiling::new(tiles);
        assert!(
            tiling.adjacency.is_empty(),
            "Far tiles should not be adjacent"
        );
    }

    /// Compares a tile paired with its own clone against the same tile paired with a
    /// differently seeded tile. Note the "identical" case is not zero-mismatch:
    /// `constraint_check` compares a tile's last row/column with the other tile's
    /// first row/column. See `test_constraint_check_exact_values` for exact values.
    #[test]
    fn test_constraint_check_identical_vs_different() {
        let mut tile_a =
            TensorTile::new([2, 2, 2, 2, 2], TileType::Thick, 0.0, [0.0, 0.0]);
        tile_a.fill_from_source();
        let mut tile_b_identical = tile_a.clone();
        tile_b_identical.position = [0.5, 0.5];
        let tiling_identical = TensorTiling::new(vec![tile_a.clone(), tile_b_identical]);
        let mismatch_identical = tiling_identical.constraint_check();
        // Position does not affect `fill_from_source`, so build tile_d in place.
        let mut tile_d =
            TensorTile::new([50, 50, 50, 50, 50], TileType::Thick, 0.0, [0.5, 0.5]);
        tile_d.fill_from_source();
        let tiling_different = TensorTiling::new(vec![tile_a, tile_d]);
        let mismatch_different = tiling_different.constraint_check();
        assert!(
            mismatch_different >= mismatch_identical,
            "Different tiles (mismatch={}) should have >= mismatch than identical ({})",
            mismatch_different,
            mismatch_identical,
        );
    }

    /// Pins exact `constraint_check` values using constant tensors, so the result
    /// does not depend on the fill formula.
    #[test]
    fn test_constraint_check_exact_values() {
        let mut tile_a = TensorTile::new([0; 5], TileType::Thick, 0.0, [0.0, 0.0]);
        tile_a.tensor.iter_mut().for_each(|v| *v = 1.0);
        let mut tile_b = tile_a.clone();
        tile_b.position = [0.5, 0.5];
        let matching = TensorTiling::new(vec![tile_a.clone(), tile_b.clone()]);
        assert_eq!(matching.adjacency.len(), 1);
        assert_eq!(matching.constraint_check(), 0.0);
        tile_b.tensor.iter_mut().for_each(|v| *v = 3.0);
        let mismatched = TensorTiling::new(vec![tile_a, tile_b]);
        // 5 bottom/top pairs + 5 right/left pairs, each differing by 2.0.
        assert_eq!(mismatched.constraint_check(), 20.0);
    }

    #[test]
    fn test_norms() {
        let mut tile =
            TensorTile::new([3, 3, 3, 3, 3], TileType::Thick, 0.0, [0.0, 0.0]);
        tile.fill_from_source();
        let l1 = tile.l1_norm();
        let l2 = tile.l2_norm();
        assert!(l1 > 0.0, "L1 norm should be positive after fill");
        assert!(l2 > 0.0, "L2 norm should be positive after fill");
        assert!(l1 >= l2, "L1 ({}) should be >= L2 ({})", l1, l2);
    }

    #[test]
    fn test_generate_tensor_tiling() {
        let compiler = CutAndProjectCompiler::new(5, 2).with_golden_projection();
        let baseline = compiler.compile(2);
        // NOTE(review): this silently passes when the compiler yields no tiles;
        // confirm whether an empty baseline is a legitimate outcome here.
        if baseline.is_empty() {
            return;
        }
        let lattice: Vec<[i32; 5]> = baseline
            .iter()
            .map(|tc| {
                let mut c = [0i32; 5];
                for (k, &v) in tc.source_coords.iter().enumerate().take(5) {
                    c[k] = v;
                }
                c
            })
            .collect();
        let tiling = generate_tensor_tiling(&lattice, &baseline);
        assert_eq!(tiling.tiles.len(), baseline.len());
        let any_filled = tiling
            .tiles
            .iter()
            .any(|t| t.tensor.iter().any(|&v| v != 0.0));
        assert!(any_filled, "At least one tensor should have non-zero values");
    }
}