lance_index/vector/
residual.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::ops::{AddAssign, DivAssign};
5use std::sync::Arc;
6use std::{iter, ops::MulAssign};
7
8use crate::vector::kmeans::{compute_partitions, KMeansAlgoFloat};
9use arrow_array::ArrowNumericType;
10use arrow_array::{
11    cast::AsArray,
12    types::{Float16Type, Float32Type, Float64Type, UInt32Type},
13    Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array,
14};
15use arrow_schema::DataType;
16use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
17use lance_core::{Error, Result};
18use lance_linalg::distance::{DistanceType, Dot, L2};
19use lance_table::utils::LanceIteratorExtension;
20use num_traits::{Float, FromPrimitive, Num};
21use snafu::location;
22use tracing::instrument;
23
24use super::{transform::Transformer, PQ_CODE_COLUMN};
25
26/// Compute the residual vector of a Vector Matrix to their centroids.
27///
28/// The residual vector is the difference between the original vector and the centroid.
29///
30#[derive(Clone)]
31pub struct ResidualTransform {
32    /// Flattened centroids.
33    centroids: FixedSizeListArray,
34
35    /// Partition Column
36    part_col: String,
37
38    /// Vector Column
39    vec_col: String,
40}
41
42impl std::fmt::Debug for ResidualTransform {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "ResidualTransform")
45    }
46}
47
48impl ResidualTransform {
49    pub fn new(centroids: FixedSizeListArray, part_col: &str, column: &str) -> Self {
50        Self {
51            centroids,
52            part_col: part_col.to_owned(),
53            vec_col: column.to_owned(),
54        }
55    }
56}
57
58fn do_compute_residual<T: ArrowNumericType>(
59    centroids: &FixedSizeListArray,
60    vectors: &FixedSizeListArray,
61    distance_type: Option<DistanceType>,
62    partitions: Option<&UInt32Array>,
63) -> Result<FixedSizeListArray>
64where
65    T::Native: Num + Float + L2 + Dot + MulAssign + DivAssign + AddAssign + FromPrimitive,
66    PrimitiveArray<T>: From<Vec<T::Native>>,
67{
68    let dimension = centroids.value_length() as usize;
69    let centroids = centroids.values().as_primitive::<T>();
70    let vectors = vectors.values().as_primitive::<T>();
71
72    let part_ids = partitions.cloned().unwrap_or_else(|| {
73        compute_partitions::<T, KMeansAlgoFloat<T>>(
74            centroids,
75            vectors,
76            dimension,
77            distance_type.expect("provide either partitions or distance type"),
78        )
79        .0
80        .into()
81    });
82    let part_ids = part_ids.values();
83
84    let vectors_slice = vectors.values();
85    let centroids_slice = centroids.values();
86    let residuals = vectors_slice
87        .chunks_exact(dimension)
88        .enumerate()
89        .flat_map(|(idx, vector)| {
90            let part_id = part_ids[idx] as usize;
91            let c = &centroids_slice[part_id * dimension..(part_id + 1) * dimension];
92            iter::zip(vector, c).map(|(v, cent)| *v - *cent)
93        })
94        .exact_size(vectors.len())
95        .collect::<Vec<_>>();
96    let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
97    debug_assert_eq!(residual_arr.len(), vectors.len());
98    Ok(FixedSizeListArray::try_new_from_values(
99        residual_arr,
100        dimension as i32,
101    )?)
102}
103
104/// Compute residual vectors from the original vectors and centroids.
105///
106/// ## Parameter
107/// - `centroids`: The KMeans centroids.
108/// - `vectors`: The original vectors to compute residual vectors.
109/// - `distance_type`: The distance type to compute the residual vector.
110/// - `partitions`: The partition ID for each vector, if present.
111pub(crate) fn compute_residual(
112    centroids: &FixedSizeListArray,
113    vectors: &FixedSizeListArray,
114    distance_type: Option<DistanceType>,
115    partitions: Option<&UInt32Array>,
116) -> Result<FixedSizeListArray> {
117    if centroids.value_length() != vectors.value_length() {
118        return Err(Error::Index {
119            message: format!(
120                "Compute residual vector: centroid and vector length mismatch: centroid: {}, vector: {}",
121                centroids.value_length(),
122                vectors.value_length(),
123            ),
124            location: location!(),
125        });
126    }
127    // TODO: Bf16 is not supported yet.
128    match (centroids.value_type(), vectors.value_type()) {
129        (DataType::Float16, DataType::Float16) => {
130            do_compute_residual::<Float16Type>(centroids, vectors, distance_type, partitions)
131        }
132        (DataType::Float32, DataType::Float32) => {
133            do_compute_residual::<Float32Type>(centroids, vectors, distance_type, partitions)
134        }
135        (DataType::Float64, DataType::Float64) => {
136            do_compute_residual::<Float64Type>(centroids, vectors, distance_type, partitions)
137        }
138        (DataType::Float32, DataType::Int8) => {
139            do_compute_residual::<Float32Type>(
140                centroids,
141                &vectors.convert_to_floating_point()?,
142                distance_type,
143                partitions)
144        }
145        _ => Err(Error::Index {
146            message: format!(
147                "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}",
148                centroids.value_type(),
149                vectors.value_type(),
150            ),
151            location: location!(),
152        })
153    }
154}
155
156impl Transformer for ResidualTransform {
157    /// Replace the original vector in the [`RecordBatch`] to residual vectors.
158    ///
159    /// The new [`RecordBatch`] will have a new column named [`RESIDUAL_COLUMN`].
160    #[instrument(name = "ResidualTransform::transform", level = "debug", skip_all)]
161    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
162        if batch.column_by_name(PQ_CODE_COLUMN).is_some() {
163            // If the PQ code column is present, we don't need to compute residual vectors.
164            return Ok(batch.clone());
165        }
166
167        let part_ids = batch.column_by_name(&self.part_col).ok_or(Error::Index {
168            message: format!(
169                "Compute residual vector: partition id column not found: {}",
170                self.part_col
171            ),
172            location: location!(),
173        })?;
174        let original = batch.column_by_name(&self.vec_col).ok_or(Error::Index {
175            message: format!(
176                "Compute residual vector: original vector column {} not found in batch {}",
177                self.vec_col,
178                batch.schema(),
179            ),
180            location: location!(),
181        })?;
182        let original_vectors = original.as_fixed_size_list_opt().ok_or(Error::Index {
183            message: format!(
184                "Compute residual vector: original vector column {} is not fixed size list: {}",
185                self.vec_col,
186                original.data_type(),
187            ),
188            location: location!(),
189        })?;
190
191        let part_ids_ref = part_ids.as_primitive::<UInt32Type>();
192        let residual_arr =
193            compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?;
194
195        let batch = if residual_arr.data_type() != original.data_type() {
196            batch.replace_column_schema_by_name(
197                &self.vec_col,
198                residual_arr.data_type().clone(),
199                Arc::new(residual_arr),
200            )?
201        } else {
202            batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?
203        };
204
205        Ok(batch)
206    }
207}