Skip to main content

lance_index/vector/
residual.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::ops::{AddAssign, DivAssign};
5use std::sync::Arc;
6use std::{iter, ops::MulAssign};
7
8use crate::vector::kmeans::{KMeansAlgoFloat, compute_partitions};
9use arrow_array::ArrowNumericType;
10use arrow_array::{
11    Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array,
12    cast::AsArray,
13    types::{Float16Type, Float32Type, Float64Type, UInt32Type},
14};
15use arrow_schema::DataType;
16use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
17use lance_core::{Error, Result};
18use lance_linalg::distance::{DistanceType, Dot, L2};
19use num_traits::{Float, FromPrimitive, Num};
20use tracing::instrument;
21
22use super::{PQ_CODE_COLUMN, transform::Transformer};
23
24/// Compute the residual vector of a Vector Matrix to their centroids.
25///
26/// The residual vector is the difference between the original vector and the centroid.
27///
28#[derive(Clone)]
29pub struct ResidualTransform {
30    /// Flattened centroids.
31    centroids: FixedSizeListArray,
32
33    /// Partition Column
34    part_col: String,
35
36    /// Vector Column
37    vec_col: String,
38}
39
40impl std::fmt::Debug for ResidualTransform {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(f, "ResidualTransform")
43    }
44}
45
46impl ResidualTransform {
47    pub fn new(centroids: FixedSizeListArray, part_col: &str, column: &str) -> Self {
48        Self {
49            centroids,
50            part_col: part_col.to_owned(),
51            vec_col: column.to_owned(),
52        }
53    }
54}
55
56fn do_compute_residual<T: ArrowNumericType>(
57    centroids: &FixedSizeListArray,
58    vectors: &FixedSizeListArray,
59    distance_type: Option<DistanceType>,
60    partitions: Option<&UInt32Array>,
61) -> Result<FixedSizeListArray>
62where
63    T::Native: Num + Float + L2 + Dot + MulAssign + DivAssign + AddAssign + FromPrimitive,
64    PrimitiveArray<T>: From<Vec<T::Native>>,
65{
66    let dimension = centroids.value_length() as usize;
67    let centroids = centroids.values().as_primitive::<T>();
68    let vectors = vectors.values().as_primitive::<T>();
69
70    let part_ids = partitions.cloned().unwrap_or_else(|| {
71        compute_partitions::<T, KMeansAlgoFloat<T>>(
72            centroids,
73            vectors,
74            dimension,
75            distance_type.expect("provide either partitions or distance type"),
76        )
77        .0
78        .into()
79    });
80    let part_ids = part_ids.values();
81
82    let vectors_slice = vectors.values();
83    let centroids_slice = centroids.values();
84    let mut residuals = Vec::with_capacity(vectors.len());
85    for (idx, vector) in vectors_slice.chunks_exact(dimension).enumerate() {
86        let part_id = part_ids[idx] as usize;
87        let c = &centroids_slice[part_id * dimension..(part_id + 1) * dimension];
88        residuals.extend(iter::zip(vector, c).map(|(v, cent)| *v - *cent));
89    }
90    debug_assert_eq!(residuals.len(), vectors.len());
91    let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
92    debug_assert_eq!(residual_arr.len(), vectors.len());
93    Ok(FixedSizeListArray::try_new_from_values(
94        residual_arr,
95        dimension as i32,
96    )?)
97}
98
99/// Compute residual vectors from the original vectors and centroids.
100///
101/// ## Parameter
102/// - `centroids`: The KMeans centroids.
103/// - `vectors`: The original vectors to compute residual vectors.
104/// - `distance_type`: The distance type to compute the residual vector.
105/// - `partitions`: The partition ID for each vector, if present.
106pub(crate) fn compute_residual(
107    centroids: &FixedSizeListArray,
108    vectors: &FixedSizeListArray,
109    distance_type: Option<DistanceType>,
110    partitions: Option<&UInt32Array>,
111) -> Result<FixedSizeListArray> {
112    if centroids.value_length() != vectors.value_length() {
113        return Err(Error::index(format!(
114            "Compute residual vector: centroid and vector length mismatch: centroid: {}, vector: {}",
115            centroids.value_length(),
116            vectors.value_length(),
117        )));
118    }
119    // TODO: Bf16 is not supported yet.
120    match (centroids.value_type(), vectors.value_type()) {
121        (DataType::Float16, DataType::Float16) => {
122            do_compute_residual::<Float16Type>(centroids, vectors, distance_type, partitions)
123        }
124        (DataType::Float32, DataType::Float32) => {
125            do_compute_residual::<Float32Type>(centroids, vectors, distance_type, partitions)
126        }
127        (DataType::Float64, DataType::Float64) => {
128            do_compute_residual::<Float64Type>(centroids, vectors, distance_type, partitions)
129        }
130        (DataType::Float32, DataType::Int8) => do_compute_residual::<Float32Type>(
131            centroids,
132            &vectors.convert_to_floating_point()?,
133            distance_type,
134            partitions,
135        ),
136        _ => Err(Error::index(format!(
137            "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}",
138            centroids.value_type(),
139            vectors.value_type(),
140        ))),
141    }
142}
143
144impl Transformer for ResidualTransform {
145    /// Replace the original vector in the [`RecordBatch`] to residual vectors.
146    ///
147    /// The new [`RecordBatch`] will have a new column named `RESIDUAL_COLUMN`.
148    #[instrument(name = "ResidualTransform::transform", level = "debug", skip_all)]
149    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
150        if batch.column_by_name(PQ_CODE_COLUMN).is_some() {
151            // If the PQ code column is present, we don't need to compute residual vectors.
152            return Ok(batch.clone());
153        }
154
155        let part_ids = batch
156            .column_by_name(&self.part_col)
157            .ok_or(Error::index(format!(
158                "Compute residual vector: partition id column not found: {}",
159                self.part_col
160            )))?;
161        let original = batch
162            .column_by_name(&self.vec_col)
163            .ok_or(Error::index(format!(
164                "Compute residual vector: original vector column {} not found in batch {}",
165                self.vec_col,
166                batch.schema(),
167            )))?;
168        let original_vectors = original
169            .as_fixed_size_list_opt()
170            .ok_or(Error::index(format!(
171                "Compute residual vector: original vector column {} is not fixed size list: {}",
172                self.vec_col,
173                original.data_type(),
174            )))?;
175
176        let part_ids_ref = part_ids.as_primitive::<UInt32Type>();
177        let residual_arr =
178            compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?;
179
180        let batch = if residual_arr.data_type() != original.data_type() {
181            batch.replace_column_schema_by_name(
182                &self.vec_col,
183                residual_arr.data_type().clone(),
184                Arc::new(residual_arr),
185            )?
186        } else {
187            batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?
188        };
189
190        Ok(batch)
191    }
192}