lance_index/vector/
residual.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::ops::{AddAssign, DivAssign};
5use std::sync::Arc;
6use std::{iter, ops::MulAssign};
7
8use crate::vector::kmeans::{compute_partitions, KMeansAlgoFloat};
9use arrow_array::ArrowNumericType;
10use arrow_array::{
11    cast::AsArray,
12    types::{Float16Type, Float32Type, Float64Type, UInt32Type},
13    Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array,
14};
15use arrow_schema::DataType;
16use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
17use lance_core::{Error, Result};
18use lance_linalg::distance::{DistanceType, Dot, L2};
19use lance_table::utils::LanceIteratorExtension;
20use num_traits::{Float, FromPrimitive, Num};
21use snafu::location;
22use tracing::instrument;
23
24use super::{transform::Transformer, PQ_CODE_COLUMN};
25
/// Column name reserved for residual vectors.
///
/// NOTE(review): nothing in this file reads this constant; `ResidualTransform`
/// writes residuals back into the original vector column instead. Confirm the
/// intended consumers before relying on it.
pub const RESIDUAL_COLUMN: &str = "__residual_vector";
27
28/// Compute the residual vector of a Vector Matrix to their centroids.
29///
30/// The residual vector is the difference between the original vector and the centroid.
31///
32#[derive(Clone)]
33pub struct ResidualTransform {
34    /// Flattened centroids.
35    centroids: FixedSizeListArray,
36
37    /// Partition Column
38    part_col: String,
39
40    /// Vector Column
41    vec_col: String,
42}
43
44impl std::fmt::Debug for ResidualTransform {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "ResidualTransform")
47    }
48}
49
50impl ResidualTransform {
51    pub fn new(centroids: FixedSizeListArray, part_col: &str, column: &str) -> Self {
52        Self {
53            centroids,
54            part_col: part_col.to_owned(),
55            vec_col: column.to_owned(),
56        }
57    }
58}
59
60fn do_compute_residual<T: ArrowNumericType>(
61    centroids: &FixedSizeListArray,
62    vectors: &FixedSizeListArray,
63    distance_type: Option<DistanceType>,
64    partitions: Option<&UInt32Array>,
65) -> Result<FixedSizeListArray>
66where
67    T::Native: Num + Float + L2 + Dot + MulAssign + DivAssign + AddAssign + FromPrimitive,
68    PrimitiveArray<T>: From<Vec<T::Native>>,
69{
70    let dimension = centroids.value_length() as usize;
71    let centroids = centroids.values().as_primitive::<T>();
72    let vectors = vectors.values().as_primitive::<T>();
73
74    let part_ids = partitions.cloned().unwrap_or_else(|| {
75        compute_partitions::<T, KMeansAlgoFloat<T>>(
76            centroids,
77            vectors,
78            dimension,
79            distance_type.expect("provide either partitions or distance type"),
80        )
81        .0
82        .into()
83    });
84    let part_ids = part_ids.values();
85
86    let vectors_slice = vectors.values();
87    let centroids_slice = centroids.values();
88    let residuals = vectors_slice
89        .chunks_exact(dimension)
90        .enumerate()
91        .flat_map(|(idx, vector)| {
92            let part_id = part_ids[idx] as usize;
93            let c = &centroids_slice[part_id * dimension..(part_id + 1) * dimension];
94            iter::zip(vector, c).map(|(v, cent)| *v - *cent)
95        })
96        .exact_size(vectors.len())
97        .collect::<Vec<_>>();
98    let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
99    debug_assert_eq!(residual_arr.len(), vectors.len());
100    Ok(FixedSizeListArray::try_new_from_values(
101        residual_arr,
102        dimension as i32,
103    )?)
104}
105
106/// Compute residual vectors from the original vectors and centroids.
107///
108/// ## Parameter
109/// - `centroids`: The KMeans centroids.
110/// - `vectors`: The original vectors to compute residual vectors.
111/// - `distance_type`: The distance type to compute the residual vector.
112/// - `partitions`: The partition ID for each vector, if present.
113pub(crate) fn compute_residual(
114    centroids: &FixedSizeListArray,
115    vectors: &FixedSizeListArray,
116    distance_type: Option<DistanceType>,
117    partitions: Option<&UInt32Array>,
118) -> Result<FixedSizeListArray> {
119    if centroids.value_length() != vectors.value_length() {
120        return Err(Error::Index {
121            message: format!(
122                "Compute residual vector: centroid and vector length mismatch: centroid: {}, vector: {}",
123                centroids.value_length(),
124                vectors.value_length(),
125            ),
126            location: location!(),
127        });
128    }
129    // TODO: Bf16 is not supported yet.
130    match (centroids.value_type(), vectors.value_type()) {
131        (DataType::Float16, DataType::Float16) => {
132            do_compute_residual::<Float16Type>(centroids, vectors, distance_type, partitions)
133        }
134        (DataType::Float32, DataType::Float32) => {
135            do_compute_residual::<Float32Type>(centroids, vectors, distance_type, partitions)
136        }
137        (DataType::Float64, DataType::Float64) => {
138            do_compute_residual::<Float64Type>(centroids, vectors, distance_type, partitions)
139        }
140        (DataType::Float32, DataType::Int8) => {
141            do_compute_residual::<Float32Type>(
142                centroids,
143                &vectors.convert_to_floating_point()?,
144                distance_type,
145                partitions)
146        }
147        _ => Err(Error::Index {
148            message: format!(
149                "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}",
150                centroids.value_type(),
151                vectors.value_type(),
152            ),
153            location: location!(),
154        })
155    }
156}
157
158impl Transformer for ResidualTransform {
159    /// Replace the original vector in the [`RecordBatch`] to residual vectors.
160    ///
161    /// The new [`RecordBatch`] will have a new column named [`RESIDUAL_COLUMN`].
162    #[instrument(name = "ResidualTransform::transform", level = "debug", skip_all)]
163    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
164        if batch.column_by_name(PQ_CODE_COLUMN).is_some() {
165            // If the PQ code column is present, we don't need to compute residual vectors.
166            return Ok(batch.clone());
167        }
168
169        let part_ids = batch.column_by_name(&self.part_col).ok_or(Error::Index {
170            message: format!(
171                "Compute residual vector: partition id column not found: {}",
172                self.part_col
173            ),
174            location: location!(),
175        })?;
176        let original = batch.column_by_name(&self.vec_col).ok_or(Error::Index {
177            message: format!(
178                "Compute residual vector: original vector column not found: {}",
179                self.vec_col
180            ),
181            location: location!(),
182        })?;
183        let original_vectors = original.as_fixed_size_list_opt().ok_or(Error::Index {
184            message: format!(
185                "Compute residual vector: original vector column {} is not fixed size list: {}",
186                self.vec_col,
187                original.data_type(),
188            ),
189            location: location!(),
190        })?;
191
192        let part_ids_ref = part_ids.as_primitive::<UInt32Type>();
193        let residual_arr =
194            compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?;
195
196        // Replace original column with residual column.
197        let batch = if residual_arr.data_type() != original.data_type() {
198            batch.replace_column_schema_by_name(
199                &self.vec_col,
200                residual_arr.data_type().clone(),
201                Arc::new(residual_arr),
202            )?
203        } else {
204            batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?
205        };
206
207        Ok(batch)
208    }
209}