use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A generated set of Pythagorean triples together with the 2-D points derived from them.
///
/// Built by [`CachedLattice::new`]; all fields are public and are filled in once at
/// construction time.
#[derive(Clone, Debug)]
pub struct CachedLattice {
/// Generated triples `(a, b, c)`, in generation order.
pub triples: Vec<(u32, u32, u32)>,
/// Points `[x, y]`: five sign/swap variants of `(a / c, b / c)` per triple, followed by
/// the four axis unit vectors.
pub vectors: Vec<[f64; 2]>,
/// Largest `c` found in `triples`, or 0 when no triple was generated.
pub max_hypotenuse: u32,
/// The `density` argument this lattice was generated with.
pub density: usize,
}
impl CachedLattice {
    /// Generates the lattice for `density` and records the largest hypotenuse.
    ///
    /// `max_hypotenuse` is the largest third component among the generated triples,
    /// or 0 when none were generated.
    pub fn new(density: usize) -> Self {
        let (triples, vectors) = generate_pythagorean_lattice(density);
        let max_hypotenuse = triples
            .iter()
            .map(|triple| triple.2)
            .max()
            .unwrap_or_default();
        Self {
            triples,
            vectors,
            max_hypotenuse,
            density,
        }
    }

    /// Number of points stored in `vectors`.
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// Returns `true` when the lattice holds no points.
    pub fn is_empty(&self) -> bool {
        self.as_slice().is_empty()
    }

    /// Finds the stored point closest to `point` by a linear scan.
    ///
    /// Returns `(closest_point, its_index, squared_distance)`. When several points are
    /// equally close, the one with the lowest index wins. If there are no points at all,
    /// the result is `([0.0, 0.0], 0, f64::MAX)`.
    pub fn nearest(&self, point: [f64; 2]) -> ([f64; 2], usize, f64) {
        let [px, py] = point;
        let seed: ([f64; 2], usize, f64) = ([0.0, 0.0], 0, f64::MAX);
        self.vectors
            .iter()
            .enumerate()
            .fold(seed, |best, (idx, &candidate)| {
                let dx = candidate[0] - px;
                let dy = candidate[1] - py;
                let dist_sq = dx * dx + dy * dy;
                // Strict comparison keeps the earliest candidate on ties.
                if dist_sq < best.2 {
                    (candidate, idx, dist_sq)
                } else {
                    best
                }
            })
    }

    /// Borrows all points as a slice.
    pub fn as_slice(&self) -> &[[f64; 2]] {
        self.vectors.as_slice()
    }
}
/// Generates Pythagorean triples via Euclid's formula and the 2-D points derived from them.
///
/// For every `m` in `2..density` and `n` in `1..m` with `m - n` odd and `gcd(m, n) == 1`,
/// the triple `(m² - n², 2mn, m² + n²)` is recorded. Each triple `(a, b, c)` contributes
/// five points, in this order: `[a/c, b/c]`, `[b/c, a/c]`, `[-a/c, b/c]`, `[a/c, -b/c]`,
/// `[-a/c, -b/c]`. Note that the swapped pair `[b/c, a/c]` is emitted without its sign
/// variants. The four axis unit vectors are appended at the very end.
///
/// Returns `(triples, vectors)`. Triple components are narrowed to `u32` with `as`.
fn generate_pythagorean_lattice(density: usize) -> (Vec<(u32, u32, u32)>, Vec<[f64; 2]>) {
    let mut triples = Vec::new();
    let mut vectors = Vec::new();

    for m in 2..density {
        // Opposite parity plus coprimality is what makes the resulting triple primitive.
        let generators = (1..m).filter(|&n| (m - n) % 2 == 1 && gcd(m as u32, n as u32) == 1);
        for n in generators {
            let (m_sq, n_sq) = (m * m, n * n);
            let leg_a = (m_sq - n_sq) as u32;
            let leg_b = (2 * m * n) as u32;
            let hyp = (m_sq + n_sq) as u32;
            triples.push((leg_a, leg_b, hyp));

            // Dividing by the hypotenuse projects the triple onto the unit circle.
            let x = f64::from(leg_a) / f64::from(hyp);
            let y = f64::from(leg_b) / f64::from(hyp);
            vectors.extend([[x, y], [y, x], [-x, y], [x, -y], [-x, -y]]);
        }
    }

    vectors.extend([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]);
    (triples, vectors)
}
/// Greatest common divisor of `a` and `b`, computed with shifts and subtractions only.
///
/// `gcd(0, x)` and `gcd(x, 0)` are `x`; `gcd(0, 0)` is 0.
fn gcd(a: u32, b: u32) -> u32 {
    // When either operand is zero the answer is the other one, which is exactly `a | b`.
    if a == 0 || b == 0 {
        return a | b;
    }

    // Powers of two shared by both operands are factored out and restored at the end.
    let common_twos = (a | b).trailing_zeros();
    let mut larger = a >> a.trailing_zeros();
    let mut smaller = b >> b.trailing_zeros();

    // Both values are odd here, so their difference is even and can be reduced again.
    while larger != smaller {
        if larger < smaller {
            std::mem::swap(&mut larger, &mut smaller);
        }
        larger -= smaller;
        larger >>= larger.trailing_zeros();
    }

    larger << common_twos
}
/// A lock-protected map of [`CachedLattice`] values keyed by density.
///
/// The map lives behind an `Arc`, so cloning a `LatticeCache` yields another handle to
/// the same underlying storage rather than a separate copy.
#[derive(Clone, Debug)]
pub struct LatticeCache {
// Density -> lattice generated for that density.
cache: Arc<RwLock<HashMap<usize, CachedLattice>>>,
// Size threshold consulted by `get_or_compute` when deciding whether to evict an
// entry before inserting a new one; also used to pre-size the map in `new`.
capacity: usize,
}
impl LatticeCache {
    /// Creates an empty cache that retains at most `capacity` lattices.
    ///
    /// With a `capacity` of 0 nothing is ever stored: every lookup recomputes the lattice.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::with_capacity(capacity))),
            capacity,
        }
    }

    /// Creates a cache with room for 32 lattices.
    pub fn with_default_capacity() -> Self {
        Self::new(32)
    }

    /// Returns the lattice for `density`, generating and storing it on a miss.
    ///
    /// The returned value is a clone of the stored lattice. When the cache is full, one
    /// existing entry is evicted first; the victim is whichever key the map's iterator
    /// yields first, which is arbitrary (neither least-recently-used nor oldest).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_or_compute(&self, density: usize) -> CachedLattice {
        // A zero-capacity cache must never hold an entry, so skip the map entirely.
        if self.capacity == 0 {
            return CachedLattice::new(density);
        }

        // Fast path: a shared read lock is enough for a hit. The guard is scoped so it
        // is released before the write lock is requested.
        {
            let cache = self.cache.read().unwrap();
            if let Some(lattice) = cache.get(&density) {
                return lattice.clone();
            }
        }

        let mut cache = self.cache.write().unwrap();
        // Another thread may have inserted the entry between the two lock acquisitions.
        if let Some(lattice) = cache.get(&density) {
            return lattice.clone();
        }

        // Entries are only ever added here, one at a time, so removing a single entry
        // is enough to stay within `capacity`.
        if cache.len() >= self.capacity {
            if let Some(evicted_key) = cache.keys().next().copied() {
                cache.remove(&evicted_key);
            }
        }

        let lattice = CachedLattice::new(density);
        cache.insert(density, lattice.clone());
        lattice
    }

    /// Returns `true` if a lattice for `density` is currently stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn contains(&self, density: usize) -> bool {
        let cache = self.cache.read().unwrap();
        cache.contains_key(&density)
    }

    /// Number of lattices currently stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn len(&self) -> usize {
        let cache = self.cache.read().unwrap();
        cache.len()
    }

    /// Returns `true` when no lattice is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every stored lattice.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn clear(&self) {
        let mut cache = self.cache.write().unwrap();
        cache.clear();
    }

    /// Ensures a lattice has been generated for each of `densities`, in order.
    ///
    /// If more distinct densities are requested than the cache can hold, later ones
    /// may evict earlier ones.
    pub fn precompute(&self, densities: &[usize]) {
        for &density in densities {
            self.get_or_compute(density);
        }
    }
}
impl Default for LatticeCache {
fn default() -> Self {
Self::with_default_capacity()
}
}
/// Storage for the process-wide cache; it stays unset until `global_cache` is first called.
static GLOBAL_CACHE: std::sync::OnceLock<LatticeCache> = std::sync::OnceLock::new();

/// Returns the process-wide [`LatticeCache`].
///
/// The cache is created with the default capacity on the first call; every later call
/// returns a reference to that same instance.
pub fn global_cache() -> &'static LatticeCache {
    GLOBAL_CACHE.get_or_init(LatticeCache::with_default_capacity)
}
/// Empties the process-wide cache if it has already been created.
///
/// Does nothing when the global cache has never been initialized; in particular it does
/// not create it.
pub fn clear_global_cache() {
    let Some(cache) = GLOBAL_CACHE.get() else {
        return;
    };
    cache.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    // A generated lattice is non-empty and remembers the density it was built from.
    #[test]
    fn test_cached_lattice_generation() {
        let lattice = CachedLattice::new(50);
        assert!(!lattice.is_empty());
        assert!(lattice.max_hypotenuse > 0);
        assert_eq!(lattice.density, 50);
    }

    // (3, 4, 5) comes from m = 2, n = 1, so (0.6, 0.8) is itself a lattice point.
    #[test]
    fn test_lattice_nearest() {
        let lattice = CachedLattice::new(100);
        let (nearest, idx, dist_sq) = lattice.nearest([0.6, 0.8]);
        assert!((nearest[0] - 0.6).abs() < 0.01);
        assert!((nearest[1] - 0.8).abs() < 0.01);
        assert!(dist_sq < 0.001);
        // The reported index must refer to the reported point.
        assert_eq!(lattice.as_slice()[idx], nearest);
    }

    // A second lookup for the same density returns the same data without adding an entry.
    #[test]
    fn test_cache_get_or_compute() {
        let cache = LatticeCache::new(10);
        let lattice1 = cache.get_or_compute(100);
        assert!(cache.contains(100));
        let lattice2 = cache.get_or_compute(100);
        assert_eq!(lattice1.len(), lattice2.len());
        assert_eq!(lattice1.triples, lattice2.triples);
        assert_eq!(cache.len(), 1);
    }

    // Inserting into a full cache keeps the size bounded and retains the new entry.
    #[test]
    fn test_cache_eviction() {
        let cache = LatticeCache::new(3);
        cache.get_or_compute(10);
        cache.get_or_compute(20);
        cache.get_or_compute(30);
        assert_eq!(cache.len(), 3);
        cache.get_or_compute(40);
        assert_eq!(cache.len(), 3);
        assert!(cache.contains(40));
    }

    #[test]
    fn test_cache_precompute() {
        let cache = LatticeCache::new(10);
        cache.precompute(&[50, 100, 200]);
        assert!(cache.contains(50));
        assert!(cache.contains(100));
        assert!(cache.contains(200));
        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_cache_clear() {
        let cache = LatticeCache::new(4);
        assert!(cache.is_empty());
        cache.precompute(&[5, 6]);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
        assert!(!cache.contains(5));
    }

    #[test]
    fn test_global_cache() {
        let cache = global_cache();
        let lattice = cache.get_or_compute(150);
        assert!(!lattice.is_empty());
        // Repeated calls hand out the same instance.
        assert!(std::ptr::eq(cache, global_cache()));
    }

    #[test]
    fn test_gcd() {
        assert_eq!(gcd(12, 8), 4);
        assert_eq!(gcd(17, 13), 1);
        assert_eq!(gcd(100, 50), 50);
        // Zero and equal-operand edge cases.
        assert_eq!(gcd(0, 0), 0);
        assert_eq!(gcd(0, 9), 9);
        assert_eq!(gcd(9, 0), 9);
        assert_eq!(gcd(7, 7), 7);
    }

    // Clones of the cache can be used from several threads at once.
    #[test]
    fn test_thread_safety() {
        use std::thread;
        let cache = LatticeCache::new(10);
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let cache = cache.clone();
                thread::spawn(move || {
                    let density = 50 + i * 50;
                    let lattice = cache.get_or_compute(density);
                    lattice.len()
                })
            })
            .collect();
        for handle in handles {
            let result = handle.join().unwrap();
            assert!(result > 0);
        }
    }
}