use crate::types::DataMatrix;
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
/// A neighborhood structure built over the rows of a [`DataMatrix`].
///
/// Implementors set up their internal state in [`initialize`](Self::initialize)
/// and afterwards answer [`neighbors`](Self::neighbors) lookups by point index.
pub trait NeighborhoodGraph {
    /// (Re)builds the graph from `data`.
    ///
    /// The method returns `()`, so implementations cannot report failure here.
    fn initialize(&mut self, data: &DataMatrix);

    /// Returns the indices of the points treated as neighbors of point `index`.
    fn neighbors(&self, index: usize) -> &[usize];
}
/// Buckets points by their first two coordinates into a uniform 2-D grid; the
/// points sharing a cell are reported as each other's neighbors.
#[derive(Debug, Clone)]
pub struct GridNeighborhoodGraph {
    /// Flat cell index (`cell_y * cells_per_axis + cell_x`) -> rows stored in that cell.
    grid: std::collections::HashMap<usize, Vec<usize>>,
    /// Cell index recorded for each row by the last `initialize` call.
    point_to_cell: Vec<usize>,
    /// Cell width along the first data column.
    cell_size_x: f64,
    /// Cell height along the second data column.
    cell_size_y: f64,
    /// Number of cells per axis; coordinates past the last cell are clamped into it.
    cells_per_axis: usize,
    /// Backing storage for the empty slice returned when no neighbors exist.
    empty: Vec<usize>,
}
impl GridNeighborhoodGraph {
    /// Marker stored in `point_to_cell` for rows that could not be placed in
    /// any cell. No real cell uses this index, so a lookup for it finds nothing.
    const UNASSIGNED: usize = usize::MAX;

    /// Creates an empty grid with the given cell dimensions and `cells_per_axis`
    /// cells along each axis. Call `initialize` before querying neighbors.
    pub fn new(cell_size_x: f64, cell_size_y: f64, cells_per_axis: usize) -> Self {
        Self {
            grid: std::collections::HashMap::new(),
            point_to_cell: Vec::new(),
            cell_size_x,
            cell_size_y,
            cells_per_axis,
            empty: Vec::new(),
        }
    }

    /// Maps a coordinate pair to its flat cell index.
    ///
    /// Returns `None` for points that cannot be placed: a negative or NaN
    /// coordinate, or an index that does not fit in `usize`.
    fn cell_index(&self, x: f64, y: f64) -> Option<usize> {
        // NaN must be rejected explicitly: it fails `< 0.0`, and `NaN as usize`
        // would silently become 0 and land the point in the first cell.
        if x.is_nan() || y.is_nan() || x < 0.0 || y < 0.0 {
            return None;
        }
        let last = self.cells_per_axis.saturating_sub(1);
        // Float-to-int `as` casts saturate, so very large quotients are clamped
        // into the last cell instead of wrapping.
        let cell_x = ((x / self.cell_size_x).floor() as usize).min(last);
        let cell_y = ((y / self.cell_size_y).floor() as usize).min(last);
        cell_y
            .checked_mul(self.cells_per_axis)
            .and_then(|row_start| row_start.checked_add(cell_x))
            .filter(|&idx| idx != Self::UNASSIGNED)
    }

    /// Rebuilds the grid from the first two columns of `data`.
    ///
    /// Points with a negative or NaN coordinate are left out: they get no cell
    /// and therefore no neighbors. All state from a previous call is discarded
    /// first, including when this call fails.
    ///
    /// Returns `false` if `data` has fewer than two columns or no point could
    /// be placed, `true` otherwise.
    pub fn initialize(&mut self, data: &DataMatrix) -> bool {
        // Reset before validating so a failed call never serves stale results.
        self.grid.clear();
        self.point_to_cell.clear();
        if data.ncols() < 2 {
            return false;
        }
        self.point_to_cell.resize(data.nrows(), Self::UNASSIGNED);
        for row in 0..data.nrows() {
            if let Some(cell_idx) = self.cell_index(data[(row, 0)], data[(row, 1)]) {
                self.grid.entry(cell_idx).or_default().push(row);
                self.point_to_cell[row] = cell_idx;
            }
        }
        !self.grid.is_empty()
    }
}
impl NeighborhoodGraph for GridNeighborhoodGraph {
    /// Returns the rows stored in the cell recorded for `index`.
    ///
    /// Yields an empty slice when `index` is outside `point_to_cell` or when
    /// the recorded cell has no entry in the grid.
    fn neighbors(&self, index: usize) -> &[usize] {
        self.point_to_cell
            .get(index)
            .and_then(|cell| self.grid.get(cell))
            .map_or(self.empty.as_slice(), Vec::as_slice)
    }

    /// Delegates to the inherent `GridNeighborhoodGraph::initialize`. Its
    /// success flag is dropped because the trait method returns `()`.
    fn initialize(&mut self, data: &DataMatrix) {
        let _ = GridNeighborhoodGraph::initialize(self, data);
    }
}
/// Nearest-neighbor graph whose neighbor lists are computed with a `usearch`
/// [`Index`] and cached per point.
pub struct UsearchNeighborhoodGraph {
    // --- configuration ---
    /// Maximum number of neighbors kept for each point.
    k_neighbors: usize,
    /// Number of leading data columns indexed; 0 means "use every column".
    dimensions: usize,

    // --- built state ---
    /// Index retained from the last successful build, if any.
    index: Option<Index>,
    /// Cached neighbor row indices, one list per point.
    neighbors: Vec<Vec<usize>>,
    /// Backing storage for the empty slice returned for unknown indices.
    empty: Vec<usize>,
}
impl UsearchNeighborhoodGraph {
    /// Creates an empty graph that keeps at most `k_neighbors` neighbors per point.
    ///
    /// `dimensions` is the number of leading data columns to index. 0 means
    /// "use every column of the matrix being indexed".
    pub fn new(k_neighbors: usize, dimensions: usize) -> Self {
        Self {
            index: None,
            neighbors: Vec::new(),
            k_neighbors,
            dimensions,
            empty: Vec::new(),
        }
    }

    /// Copies the first `dims` columns of `row` into `buf` as `f32`, reusing
    /// `buf`'s allocation.
    fn load_row(buf: &mut Vec<f32>, data: &DataMatrix, row: usize, dims: usize) {
        buf.clear();
        buf.extend((0..dims).map(|col| data[(row, col)] as f32));
    }

    /// Builds a squared-L2 usearch index over the rows of `data`. It then
    /// caches, for each row, up to `k_neighbors` nearest keys, excluding the
    /// row itself.
    ///
    /// Previous results are discarded up front. On any failure the graph is
    /// left empty and `false` is returned. Failure cases:
    /// - `data` has no rows;
    /// - the effective dimensionality is 0 or exceeds `data.ncols()`;
    /// - usearch fails to create, reserve, or populate the index.
    ///
    /// A failed per-row search is tolerated (best effort): that row simply gets
    /// no neighbors.
    fn build_neighbors(&mut self, data: &DataMatrix) -> bool {
        // Clear first so a failure below cannot leave stale lists or a stale index.
        self.index = None;
        self.neighbors.clear();

        let n = data.nrows();
        let dims = if self.dimensions > 0 {
            self.dimensions
        } else {
            data.ncols()
        };
        // Reading `dims` columns from a narrower matrix would index out of bounds.
        if n == 0 || dims == 0 || dims > data.ncols() {
            return false;
        }

        let options = IndexOptions {
            dimensions: dims,
            metric: MetricKind::L2sq,
            quantization: ScalarKind::F32,
            ..Default::default()
        };
        let index = match Index::new(&options) {
            Ok(idx) => idx,
            Err(_) => return false,
        };
        if index.reserve(n).is_err() {
            return false;
        }

        // One buffer is reused for every row conversion, both when adding and when querying.
        let mut row_buf = Vec::<f32>::with_capacity(dims);
        for i in 0..n {
            Self::load_row(&mut row_buf, data, i, dims);
            if index.add(i as u64, &row_buf).is_err() {
                return false;
            }
        }

        // Request one extra hit because a query normally matches its own point.
        let k = self.k_neighbors.saturating_add(1).min(n);
        let mut neighbors = Vec::with_capacity(n);
        for i in 0..n {
            Self::load_row(&mut row_buf, data, i, dims);
            let row_neighbors: Vec<usize> = match index.search(&row_buf, k) {
                Ok(results) => results
                    .keys
                    .iter()
                    .map(|&key| key as usize)
                    .filter(|&key| key != i)
                    // `take` also makes `k_neighbors == 0` yield no neighbors.
                    .take(self.k_neighbors)
                    .collect(),
                Err(_) => Vec::new(),
            };
            neighbors.push(row_neighbors);
        }

        self.neighbors = neighbors;
        self.index = Some(index);
        true
    }
}
impl NeighborhoodGraph for UsearchNeighborhoodGraph {
    /// Returns the cached neighbor list for `index`, or an empty slice when
    /// `index` is beyond the stored lists.
    fn neighbors(&self, index: usize) -> &[usize] {
        self.neighbors
            .get(index)
            .map_or(self.empty.as_slice(), Vec::as_slice)
    }

    /// Rebuilds the index and neighbor lists for `data`.
    ///
    /// A configured `dimensions` of 0 is resolved against each matrix by
    /// `build_neighbors`. Writing `data.ncols()` back into `self.dimensions`
    /// would pin it to the first matrix's width, so a later, narrower matrix
    /// would be indexed out of bounds and a wider one silently truncated.
    /// Build failures are not reported because the trait method returns `()`.
    fn initialize(&mut self, data: &DataMatrix) {
        let _ = self.build_neighbors(data);
    }
}
/// Index-window neighborhood that ignores the data entirely: each point's
/// neighbors are the indices within a fixed distance of its own index.
#[derive(Debug, Clone)]
pub struct DummyNeighborhood {
    /// Precomputed neighbor indices, one list per point.
    neighbors: Vec<Vec<usize>>,
}
impl DummyNeighborhood {
    /// Builds a sliding-window neighborhood over `num_points` indices.
    ///
    /// Point `i` gets every index in `[i - window, i + window]`, clamped to
    /// `0..num_points`; the point itself is included. The upper bound uses
    /// saturating arithmetic, so a huge `window` cannot overflow.
    pub fn new(num_points: usize, window: usize) -> Self {
        let last = num_points.saturating_sub(1);
        let neighbors: Vec<Vec<usize>> = (0..num_points)
            .map(|i| {
                let start = i.saturating_sub(window);
                let end = i.saturating_add(window).min(last);
                (start..=end).collect()
            })
            .collect();
        Self { neighbors }
    }
}
impl NeighborhoodGraph for DummyNeighborhood {
    /// Returns the precomputed window for `index`.
    ///
    /// Yields an empty slice, rather than panicking, when `index` is outside
    /// the range given to `new`.
    fn neighbors(&self, index: usize) -> &[usize] {
        self.neighbors
            .get(index)
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// No-op: the windows depend only on the point count fixed in `new`, so
    /// `_data` is ignored.
    fn initialize(&mut self, _data: &DataMatrix) {}
}