1use std::ops::{AddAssign, DivAssign};
5use std::sync::Arc;
6use std::{iter, ops::MulAssign};
7
8use crate::vector::kmeans::{KMeansAlgoFloat, compute_partitions};
9use arrow_array::ArrowNumericType;
10use arrow_array::{
11 Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array,
12 cast::AsArray,
13 types::{Float16Type, Float32Type, Float64Type, UInt32Type},
14};
15use arrow_schema::DataType;
16use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
17use lance_core::{Error, Result};
18use lance_linalg::distance::{DistanceType, Dot, L2};
19use num_traits::{Float, FromPrimitive, Num};
20use tracing::instrument;
21
22use super::{PQ_CODE_COLUMN, transform::Transformer};
23
/// Transformer that replaces the vectors of a batch with their residuals,
/// i.e. each vector minus the centroid of the partition it belongs to.
#[derive(Clone)]
pub struct ResidualTransform {
    /// Centroids, indexed by partition id.
    centroids: FixedSizeListArray,

    /// Name of the batch column that holds the partition id of each row.
    part_col: String,

    /// Name of the batch column that holds the original vectors.
    vec_col: String,
}
39
40impl std::fmt::Debug for ResidualTransform {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(f, "ResidualTransform")
43 }
44}
45
46impl ResidualTransform {
47 pub fn new(centroids: FixedSizeListArray, part_col: &str, column: &str) -> Self {
48 Self {
49 centroids,
50 part_col: part_col.to_owned(),
51 vec_col: column.to_owned(),
52 }
53 }
54}
55
56fn do_compute_residual<T: ArrowNumericType>(
57 centroids: &FixedSizeListArray,
58 vectors: &FixedSizeListArray,
59 distance_type: Option<DistanceType>,
60 partitions: Option<&UInt32Array>,
61) -> Result<FixedSizeListArray>
62where
63 T::Native: Num + Float + L2 + Dot + MulAssign + DivAssign + AddAssign + FromPrimitive,
64 PrimitiveArray<T>: From<Vec<T::Native>>,
65{
66 let dimension = centroids.value_length() as usize;
67 let centroids = centroids.values().as_primitive::<T>();
68 let vectors = vectors.values().as_primitive::<T>();
69
70 let part_ids = partitions.cloned().unwrap_or_else(|| {
71 compute_partitions::<T, KMeansAlgoFloat<T>>(
72 centroids,
73 vectors,
74 dimension,
75 distance_type.expect("provide either partitions or distance type"),
76 )
77 .0
78 .into()
79 });
80 let part_ids = part_ids.values();
81
82 let vectors_slice = vectors.values();
83 let centroids_slice = centroids.values();
84 let mut residuals = Vec::with_capacity(vectors.len());
85 for (idx, vector) in vectors_slice.chunks_exact(dimension).enumerate() {
86 let part_id = part_ids[idx] as usize;
87 let c = ¢roids_slice[part_id * dimension..(part_id + 1) * dimension];
88 residuals.extend(iter::zip(vector, c).map(|(v, cent)| *v - *cent));
89 }
90 debug_assert_eq!(residuals.len(), vectors.len());
91 let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
92 debug_assert_eq!(residual_arr.len(), vectors.len());
93 Ok(FixedSizeListArray::try_new_from_values(
94 residual_arr,
95 dimension as i32,
96 )?)
97}
98
99pub(crate) fn compute_residual(
107 centroids: &FixedSizeListArray,
108 vectors: &FixedSizeListArray,
109 distance_type: Option<DistanceType>,
110 partitions: Option<&UInt32Array>,
111) -> Result<FixedSizeListArray> {
112 if centroids.value_length() != vectors.value_length() {
113 return Err(Error::index(format!(
114 "Compute residual vector: centroid and vector length mismatch: centroid: {}, vector: {}",
115 centroids.value_length(),
116 vectors.value_length(),
117 )));
118 }
119 match (centroids.value_type(), vectors.value_type()) {
121 (DataType::Float16, DataType::Float16) => {
122 do_compute_residual::<Float16Type>(centroids, vectors, distance_type, partitions)
123 }
124 (DataType::Float32, DataType::Float32) => {
125 do_compute_residual::<Float32Type>(centroids, vectors, distance_type, partitions)
126 }
127 (DataType::Float64, DataType::Float64) => {
128 do_compute_residual::<Float64Type>(centroids, vectors, distance_type, partitions)
129 }
130 (DataType::Float32, DataType::Int8) => do_compute_residual::<Float32Type>(
131 centroids,
132 &vectors.convert_to_floating_point()?,
133 distance_type,
134 partitions,
135 ),
136 _ => Err(Error::index(format!(
137 "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}",
138 centroids.value_type(),
139 vectors.value_type(),
140 ))),
141 }
142}
143
144impl Transformer for ResidualTransform {
145 #[instrument(name = "ResidualTransform::transform", level = "debug", skip_all)]
149 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
150 if batch.column_by_name(PQ_CODE_COLUMN).is_some() {
151 return Ok(batch.clone());
153 }
154
155 let part_ids = batch
156 .column_by_name(&self.part_col)
157 .ok_or(Error::index(format!(
158 "Compute residual vector: partition id column not found: {}",
159 self.part_col
160 )))?;
161 let original = batch
162 .column_by_name(&self.vec_col)
163 .ok_or(Error::index(format!(
164 "Compute residual vector: original vector column {} not found in batch {}",
165 self.vec_col,
166 batch.schema(),
167 )))?;
168 let original_vectors = original
169 .as_fixed_size_list_opt()
170 .ok_or(Error::index(format!(
171 "Compute residual vector: original vector column {} is not fixed size list: {}",
172 self.vec_col,
173 original.data_type(),
174 )))?;
175
176 let part_ids_ref = part_ids.as_primitive::<UInt32Type>();
177 let residual_arr =
178 compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?;
179
180 let batch = if residual_arr.data_type() != original.data_type() {
181 batch.replace_column_schema_by_name(
182 &self.vec_col,
183 residual_arr.data_type().clone(),
184 Arc::new(residual_arr),
185 )?
186 } else {
187 batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?
188 };
189
190 Ok(batch)
191 }
192}