Skip to main content

frlearn_neighbor/
bruteforce.rs

1use frlearn_core::{FrResult, Matrix};
2use ndarray::Array2;
3
4use crate::distance::Metric;
5use crate::{NeighborIndex, pairwise_distance, validate_query_features};
6
7#[derive(Debug, Clone)]
8pub struct BruteForceIndex {
9    pub x_train: Matrix,
10    pub metric: Metric,
11}
12
13impl BruteForceIndex {
14    pub fn new(x_train: Matrix, metric: Metric) -> Self {
15        Self { x_train, metric }
16    }
17}
18
19impl NeighborIndex for BruteForceIndex {
20    fn query(&self, xq: &Matrix, k: usize) -> FrResult<(Array2<usize>, Matrix)> {
21        validate_query_features(&self.x_train, xq)?;
22
23        let n_query = xq.nrows();
24        let n_train = self.x_train.nrows();
25        let effective_k = k.min(n_train);
26
27        let mut indices = Array2::<usize>::zeros((n_query, effective_k));
28        let mut distances = Matrix::zeros((n_query, effective_k));
29
30        if n_query == 0 || effective_k == 0 {
31            return Ok((indices, distances));
32        }
33
34        for (query_idx, query_row) in xq.outer_iter().enumerate() {
35            let mut row_distances = self
36                .x_train
37                .outer_iter()
38                .enumerate()
39                .map(|(train_idx, train_row)| {
40                    let distance = pairwise_distance(query_row, train_row, self.metric)?;
41                    Ok((train_idx, distance))
42                })
43                .collect::<FrResult<Vec<_>>>()?;
44
45            row_distances.sort_by(|left, right| {
46                left.1
47                    .total_cmp(&right.1)
48                    .then_with(|| left.0.cmp(&right.0))
49            });
50
51            for neighbor_idx in 0..effective_k {
52                let (train_idx, distance) = row_distances[neighbor_idx];
53                indices[[query_idx, neighbor_idx]] = train_idx;
54                distances[[query_idx, neighbor_idx]] = distance;
55            }
56        }
57
58        Ok((indices, distances))
59    }
60}