// frlearn_neighbor 0.1.0
//
// Neighbour search abstractions and brute-force backend for FuzzyRough.
// Documentation
use frlearn_core::{FrResult, Matrix};
use ndarray::Array2;

use crate::distance::Metric;
use crate::{NeighborIndex, pairwise_distance, validate_query_features};

/// Exhaustive nearest-neighbour index.
///
/// Holds the training matrix together with the metric used to compare its
/// rows against query rows. Nothing is precomputed: the fields are stored
/// exactly as supplied.
#[derive(Clone, Debug)]
pub struct BruteForceIndex {
    /// Training data; each row is one candidate neighbour.
    pub x_train: Matrix,
    /// Metric passed to `pairwise_distance` for every query/training row pair.
    pub metric: Metric,
}

impl BruteForceIndex {
    pub fn new(x_train: Matrix, metric: Metric) -> Self {
        Self { x_train, metric }
    }
}

impl NeighborIndex for BruteForceIndex {
    /// Finds the `k` nearest training rows for every row of `xq`.
    ///
    /// Returns `(indices, distances)`, both shaped
    /// `(xq.nrows(), min(k, x_train.nrows()))`. Within each output row the
    /// neighbours are ordered by ascending distance (compared with
    /// `total_cmp`), and equal distances are ordered by ascending training
    /// row index, so the result is deterministic.
    ///
    /// When there are no query rows, no training rows, or `k == 0`, the
    /// returned arrays are empty along the corresponding axis.
    ///
    /// # Errors
    ///
    /// Propagates any error from `validate_query_features` or from
    /// `pairwise_distance`.
    fn query(&self, xq: &Matrix, k: usize) -> FrResult<(Array2<usize>, Matrix)> {
        validate_query_features(&self.x_train, xq)?;

        let n_query = xq.nrows();
        let n_train = self.x_train.nrows();
        // Never ask for more neighbours than there are training rows.
        let effective_k = k.min(n_train);

        let mut indices = Array2::<usize>::zeros((n_query, effective_k));
        let mut distances = Matrix::zeros((n_query, effective_k));

        if n_query == 0 || effective_k == 0 {
            return Ok((indices, distances));
        }

        for (query_idx, query_row) in xq.outer_iter().enumerate() {
            let mut candidates = self
                .x_train
                .outer_iter()
                .enumerate()
                .map(|(train_idx, train_row)| {
                    let distance = pairwise_distance(query_row, train_row, self.metric)?;
                    Ok((train_idx, distance))
                })
                .collect::<FrResult<Vec<_>>>()?;

            // Only the `effective_k` closest candidates are needed, so partition
            // first (linear on average) instead of fully sorting every candidate.
            // `effective_k >= 1` here, and the guard keeps the pivot in bounds.
            if effective_k < candidates.len() {
                candidates.select_nth_unstable_by(effective_k - 1, |left, right| {
                    left.1
                        .total_cmp(&right.1)
                        .then_with(|| left.0.cmp(&right.0))
                });
                candidates.truncate(effective_k);
            }

            // The training index tie-break makes every pair strictly ordered, so
            // an unstable sort yields exactly the order a stable full sort would.
            candidates.sort_unstable_by(|left, right| {
                left.1
                    .total_cmp(&right.1)
                    .then_with(|| left.0.cmp(&right.0))
            });

            for (neighbor_idx, &(train_idx, distance)) in candidates.iter().enumerate() {
                indices[[query_idx, neighbor_idx]] = train_idx;
                distances[[query_idx, neighbor_idx]] = distance;
            }
        }

        Ok((indices, distances))
    }
}