use diskann::utils::IntoUsize;
use diskann_utils::views;
use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
/// One disagreement between two PQ code assignments at the same `(row, chunk)` position,
/// together with the values needed to diagnose which assignment is closer.
///
/// Records of this type are produced by `compare_pq`. Every vector field holds only the
/// dimensions covered by `chunk`, not the full-dimensional vector.
#[derive(Debug, Clone, PartialEq)]
pub struct MismatchRecord {
    /// Row index, shared by the source data matrix and both code matrices.
    pub row: usize,
    /// Chunk (column) index within the code matrices.
    pub chunk: usize,
    /// Pivot row index selected by the first code matrix (`a`).
    pub a_assignment: usize,
    /// The chunk's slice of pivot row `a_assignment`.
    pub a_pivot: Vec<f32>,
    /// Pivot row index selected by the second code matrix (`b`).
    pub b_assignment: usize,
    /// The chunk's slice of pivot row `b_assignment`.
    pub b_pivot: Vec<f32>,
    /// The chunk's slice of the source data row, converted to `f32` (not centered).
    pub data: Vec<f32>,
    /// The chunk's slice of the centering vector.
    pub center: Vec<f32>,
    /// Squared L2 distance between `data - center` and `a_pivot`.
    pub squared_l2_a: f32,
    /// Squared L2 distance between `data - center` and `b_pivot`.
    pub squared_l2_b: f32,
}
impl std::fmt::Display for MismatchRecord {
    /// Writes a multi-line, human-readable report of the mismatch.
    ///
    /// The report lists, one item per line: the position, both assignments, the data slice,
    /// the center slice, both pivot slices, and the two squared distances. Vectors use their
    /// `Debug` form and distances their `Display` form. Every line, including the last, is
    /// terminated by a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self {
            row,
            chunk,
            a_assignment,
            a_pivot,
            b_assignment,
            b_pivot,
            data,
            center,
            squared_l2_a,
            squared_l2_b,
        } = self;

        writeln!(f, "mismatch on row {} and chunk {}", row, chunk)?;
        writeln!(
            f,
            "argument A had assignment {} but B had assignment {}",
            a_assignment, b_assignment
        )?;
        writeln!(f, "data = {:?}", data)?;
        writeln!(f, "center = {:?}", center)?;
        writeln!(f, "pivot_a = {:?}", a_pivot)?;
        writeln!(f, "pivot_b = {:?}", b_pivot)?;
        writeln!(f, "distance from a = {}", squared_l2_a)?;
        writeln!(f, "distance from b = {}", squared_l2_b)
    }
}
/// Compares two PQ code matrices element-wise and describes every disagreement.
///
/// `a` and `b` are walked in lockstep. Row `r`, column `c` of `a` is compared with row `r`,
/// column `c` of `b`, and column `c` is treated as chunk `c` of `schema`. For each position
/// whose two codes differ, a [`MismatchRecord`] is built containing:
///
/// * the chunk's slice of `data.row(r)`, converted to `f32`;
/// * the chunk's slice of `center`;
/// * the chunk's slice of the pivot row selected by each code;
/// * the squared L2 distance from the centered data slice (`data - center`) to each pivot
///   slice.
///
/// Records are returned in row-major order (by row, then by chunk). Because the walk uses
/// `zip`, extra rows or columns in the larger of `a` / `b` are silently ignored.
///
/// # Panics
///
/// Slicing panics if a chunk's range does not fit inside the corresponding `data` row,
/// `center`, or pivot row. The behavior for a chunk index outside `schema`, or for a code that
/// is not a valid row of `pivots`, depends on `ChunkOffsetsView::at` / `MatrixView::row`. It is
/// presumably a panic — confirm against those APIs.
pub fn compare_pq<T, U>(
    data: views::MatrixView<'_, T>,
    schema: diskann_quantization::views::ChunkOffsetsView<'_>,
    pivots: views::MatrixView<'_, f32>,
    center: &[f32],
    a: views::MatrixView<'_, U>,
    b: views::MatrixView<'_, U>,
) -> Vec<MismatchRecord>
where
    T: Copy + Into<f32>,
    U: Copy + IntoUsize,
{
    let mut mismatches = Vec::new();

    for (row, (codes_a, codes_b)) in std::iter::zip(a.row_iter(), b.row_iter()).enumerate() {
        let code_pairs = std::iter::zip(codes_a.iter(), codes_b.iter());
        for (chunk, (code_a, code_b)) in code_pairs.enumerate() {
            let index_a: usize = code_a.into_usize();
            let index_b: usize = code_b.into_usize();
            // Agreement is the common case; only disagreements are recorded.
            if index_a == index_b {
                continue;
            }

            let dims = schema.at(chunk);
            let data_chunk: Vec<f32> = data.row(row)[dims.clone()]
                .iter()
                .copied()
                .map(|value| value.into())
                .collect();
            let center_chunk = center[dims.clone()].to_vec();
            let a_pivot = pivots.row(index_a)[dims.clone()].to_vec();
            let b_pivot = pivots.row(index_b)[dims.clone()].to_vec();

            // The pivots are compared against the data with `center` subtracted.
            let centered: Vec<f32> = data_chunk
                .iter()
                .zip(center_chunk.iter())
                .map(|(&x, &c)| x - c)
                .collect();
            let squared_l2_a = SquaredL2::evaluate(centered.as_slice(), a_pivot.as_slice());
            let squared_l2_b = SquaredL2::evaluate(centered.as_slice(), b_pivot.as_slice());

            mismatches.push(MismatchRecord {
                row,
                chunk,
                a_assignment: index_a,
                a_pivot,
                b_assignment: index_b,
                b_pivot,
                data: data_chunk,
                center: center_chunk,
                squared_l2_a,
                squared_l2_b,
            });
        }
    }

    mismatches
}