frlearn_neighbor/
bruteforce.rs1use frlearn_core::{FrResult, Matrix};
2use ndarray::Array2;
3
4use crate::distance::Metric;
5use crate::{NeighborIndex, pairwise_distance, validate_query_features};
6
7#[derive(Debug, Clone)]
8pub struct BruteForceIndex {
9 pub x_train: Matrix,
10 pub metric: Metric,
11}
12
13impl BruteForceIndex {
14 pub fn new(x_train: Matrix, metric: Metric) -> Self {
15 Self { x_train, metric }
16 }
17}
18
19impl NeighborIndex for BruteForceIndex {
20 fn query(&self, xq: &Matrix, k: usize) -> FrResult<(Array2<usize>, Matrix)> {
21 validate_query_features(&self.x_train, xq)?;
22
23 let n_query = xq.nrows();
24 let n_train = self.x_train.nrows();
25 let effective_k = k.min(n_train);
26
27 let mut indices = Array2::<usize>::zeros((n_query, effective_k));
28 let mut distances = Matrix::zeros((n_query, effective_k));
29
30 if n_query == 0 || effective_k == 0 {
31 return Ok((indices, distances));
32 }
33
34 for (query_idx, query_row) in xq.outer_iter().enumerate() {
35 let mut row_distances = self
36 .x_train
37 .outer_iter()
38 .enumerate()
39 .map(|(train_idx, train_row)| {
40 let distance = pairwise_distance(query_row, train_row, self.metric)?;
41 Ok((train_idx, distance))
42 })
43 .collect::<FrResult<Vec<_>>>()?;
44
45 row_distances.sort_by(|left, right| {
46 left.1
47 .total_cmp(&right.1)
48 .then_with(|| left.0.cmp(&right.0))
49 });
50
51 for neighbor_idx in 0..effective_k {
52 let (train_idx, distance) = row_distances[neighbor_idx];
53 indices[[query_idx, neighbor_idx]] = train_idx;
54 distances[[query_idx, neighbor_idx]] = distance;
55 }
56 }
57
58 Ok((indices, distances))
59 }
60}