use super::offsets::hypercubic;
/// A periodic lattice with precomputed forward/backward neighbor tables.
///
/// Sites are numbered in row-major order (the last axis varies fastest), so the
/// site with coordinates `(c_0, ..., c_{n-1})` has flat index
/// `sum_d c_d * strides[d]`. Neighbor directions are integer offset vectors fixed
/// at construction; every coordinate wraps modulo its extent.
#[derive(Debug, Clone)]
pub struct Lattice {
    /// Extent of the lattice along each axis.
    pub shape: Vec<usize>,
    /// Row-major strides: one step along axis `d` changes the flat index by `strides[d]`.
    pub strides: Vec<usize>,
    /// Total number of sites, i.e. the product of `shape`.
    pub n_spins: usize,
    /// Number of axes, i.e. `shape.len()`.
    pub n_dims: usize,
    /// Number of offset vectors (neighbor directions) per site.
    pub n_neighbors: usize,
    /// Entry `i * n_neighbors + k`: the site reached from `i` by adding offset `k`.
    fwd_neighbors: Vec<u32>,
    /// Entry `i * n_neighbors + k`: the site reached from `i` by subtracting offset `k`.
    bwd_neighbors: Vec<u32>,
}

impl Lattice {
    /// Builds a periodic lattice using the offsets returned by `hypercubic(shape.len())`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`Lattice::with_offsets`].
    pub fn new(shape: Vec<usize>) -> Self {
        let n_dims = shape.len();
        Self::with_offsets(shape, hypercubic(n_dims))
    }

    /// Builds a periodic lattice whose neighbor directions are the given offset vectors.
    ///
    /// For every site `i` and offset index `k`, the neighbor tables record the site
    /// reached from `i` by adding `offsets[k]` (see [`Lattice::neighbor_fwd`]) and by
    /// subtracting it (see [`Lattice::neighbor_bwd`]). Each coordinate is wrapped with
    /// `rem_euclid`, so negative offsets and offsets longer than an extent are allowed.
    ///
    /// # Panics
    ///
    /// - if any offset's length differs from `shape.len()`;
    /// - if the number of sites, or the neighbor-table length, overflows `usize`;
    /// - if the largest site index does not fit in the `u32` entries of the
    ///   neighbor tables (i.e. the lattice has more than `2^32` sites).
    pub fn with_offsets(shape: Vec<usize>, offsets: Vec<Vec<isize>>) -> Self {
        let n_dims = shape.len();
        let n_neighbors = offsets.len();
        for (idx, off) in offsets.iter().enumerate() {
            assert_eq!(
                off.len(),
                n_dims,
                "offset {idx} has length {}, expected {n_dims}",
                off.len(),
            );
        }

        // Checked product: a plain `product()` would wrap silently in release builds.
        let n_spins = shape
            .iter()
            .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
            .unwrap_or_else(|| {
                panic!("lattice shape {:?} has more sites than fit in usize", shape)
            });
        // Table entries are `u32`: reject lattices whose largest index (n_spins - 1)
        // would be truncated into a wrong neighbor. When usize is at most 32 bits
        // `u32::MAX as usize` is usize::MAX, so this check always passes there.
        assert!(
            n_spins.saturating_sub(1) <= u32::MAX as usize,
            "lattice with {} sites exceeds the u32 index range of the neighbor tables",
            n_spins,
        );
        let table_len = n_spins
            .checked_mul(n_neighbors)
            .expect("neighbor table length overflows usize");

        let mut strides = vec![1usize; n_dims];
        for d in (0..n_dims.saturating_sub(1)).rev() {
            strides[d] = strides[d + 1] * shape[d + 1];
        }

        let mut fwd_neighbors = vec![0u32; table_len];
        let mut bwd_neighbors = vec![0u32; table_len];
        // Coordinates of the current site; one buffer reused for every site.
        let mut coords = vec![0usize; n_dims];
        for i in 0..n_spins {
            for (d, c) in coords.iter_mut().enumerate() {
                *c = (i / strides[d]) % shape[d];
            }
            for (k, off) in offsets.iter().enumerate() {
                for (sign, table) in [(1isize, &mut fwd_neighbors), (-1isize, &mut bwd_neighbors)] {
                    let mut flat = 0usize;
                    for dim in 0..n_dims {
                        let c = (coords[dim] as isize + sign * off[dim])
                            .rem_euclid(shape[dim] as isize) as usize;
                        flat += c * strides[dim];
                    }
                    // Lossless: flat < n_spins, which was checked above to fit in u32.
                    table[i * n_neighbors + k] = flat as u32;
                }
            }
        }

        Self {
            shape,
            strides,
            n_spins,
            n_dims,
            n_neighbors,
            fwd_neighbors,
            bwd_neighbors,
        }
    }

    /// Returns the site reached from `flat_idx` by adding offset number `dim`,
    /// with periodic wrap-around.
    ///
    /// `dim` indexes the offset list given at construction. It is an offset index,
    /// not necessarily an axis.
    ///
    /// # Panics
    ///
    /// Panics if `flat_idx >= n_spins`. Debug builds also panic if
    /// `dim >= n_neighbors`. Release builds skip that check, so an out-of-range
    /// `dim` may return another site's entry.
    #[inline]
    pub fn neighbor_fwd(&self, flat_idx: usize, dim: usize) -> usize {
        self.fwd_neighbors[self.table_index(flat_idx, dim)] as usize
    }

    /// Returns the site reached from `flat_idx` by subtracting offset number `dim`,
    /// with periodic wrap-around.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Lattice::neighbor_fwd`].
    #[inline]
    pub fn neighbor_bwd(&self, flat_idx: usize, dim: usize) -> usize {
        self.bwd_neighbors[self.table_index(flat_idx, dim)] as usize
    }

    /// Position of neighbor `dim` of site `flat_idx` in the flat neighbor tables.
    #[inline]
    fn table_index(&self, flat_idx: usize, dim: usize) -> usize {
        // Without this check, `dim >= n_neighbors` silently reads another site's
        // entries. The check is debug-only to keep the hot lookup path lean.
        debug_assert!(
            dim < self.n_neighbors,
            "neighbor index {} out of range for a lattice with {} neighbors per site",
            dim,
            self.n_neighbors,
        );
        flat_idx * self.n_neighbors + dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_2d_neighbors() {
        let lat = Lattice::new(vec![3, 4]);
        assert_eq!(lat.n_spins, 12);
        assert_eq!(lat.strides, vec![4, 1]);
        assert_eq!(lat.neighbor_fwd(0, 0), 4);
        assert_eq!(lat.neighbor_fwd(0, 1), 1);
        assert_eq!(lat.neighbor_bwd(0, 0), 8);
        assert_eq!(lat.neighbor_bwd(0, 1), 3);
        assert_eq!(lat.neighbor_fwd(11, 0), 3);
        assert_eq!(lat.neighbor_fwd(11, 1), 8);
    }

    #[test]
    fn test_3d_neighbors() {
        let lat = Lattice::new(vec![2, 3, 4]);
        assert_eq!(lat.n_spins, 24);
        assert_eq!(lat.strides, vec![12, 4, 1]);
        assert_eq!(lat.neighbor_fwd(0, 0), 12);
        assert_eq!(lat.neighbor_fwd(0, 1), 4);
        assert_eq!(lat.neighbor_fwd(0, 2), 1);
        // Backward steps from the origin wrap to the far face of each axis.
        assert_eq!(lat.neighbor_bwd(0, 0), 12);
        assert_eq!(lat.neighbor_bwd(0, 1), 8);
        assert_eq!(lat.neighbor_bwd(0, 2), 3);
        // Forward steps from the last site (1, 2, 3) wrap back to 0 on each axis.
        assert_eq!(lat.neighbor_fwd(23, 0), 11);
        assert_eq!(lat.neighbor_fwd(23, 1), 15);
        assert_eq!(lat.neighbor_fwd(23, 2), 20);
    }

    #[test]
    fn test_triangular_neighbors() {
        use super::super::offsets::triangular;
        let lat = Lattice::with_offsets(vec![4, 4], triangular());
        assert_eq!(lat.n_neighbors, 3);
        assert_eq!(lat.n_spins, 16);
        assert_eq!(lat.neighbor_fwd(0, 0), 4);
        assert_eq!(lat.neighbor_fwd(0, 1), 1);
        assert_eq!(lat.neighbor_fwd(0, 2), 7);
        assert_eq!(lat.neighbor_bwd(0, 0), 12);
        assert_eq!(lat.neighbor_bwd(0, 1), 3);
        assert_eq!(lat.neighbor_bwd(0, 2), 13);
        assert_eq!(lat.neighbor_fwd(5, 2), 8);
        assert_eq!(lat.neighbor_bwd(5, 2), 2);
        assert_eq!(lat.neighbor_fwd(15, 0), 3);
        assert_eq!(lat.neighbor_fwd(15, 1), 12);
        assert_eq!(lat.neighbor_fwd(15, 2), 2);
    }

    /// Stepping forward then backward along the same offset (and vice versa)
    /// must return to the starting site on every lattice.
    #[test]
    fn test_fwd_bwd_round_trip() {
        use super::super::offsets::triangular;
        let lattices = [
            Lattice::new(vec![3, 4]),
            Lattice::new(vec![2, 3, 4]),
            Lattice::with_offsets(vec![4, 4], triangular()),
        ];
        for lat in &lattices {
            for i in 0..lat.n_spins {
                for k in 0..lat.n_neighbors {
                    assert_eq!(lat.neighbor_bwd(lat.neighbor_fwd(i, k), k), i);
                    assert_eq!(lat.neighbor_fwd(lat.neighbor_bwd(i, k), k), i);
                }
            }
        }
    }

    #[test]
    fn test_zero_extent_is_empty() {
        let lat = Lattice::new(vec![0, 5]);
        assert_eq!(lat.n_spins, 0);
        assert_eq!(lat.strides, vec![5, 1]);
    }
}