1use std::ops::{AddAssign, DivAssign};
5use std::sync::Arc;
6use std::{iter, ops::MulAssign};
7
8use crate::vector::kmeans::{compute_partitions, KMeansAlgoFloat};
9use arrow_array::ArrowNumericType;
10use arrow_array::{
11 cast::AsArray,
12 types::{Float16Type, Float32Type, Float64Type, UInt32Type},
13 Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array,
14};
15use arrow_schema::DataType;
16use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
17use lance_core::{Error, Result};
18use lance_linalg::distance::{DistanceType, Dot, L2};
19use lance_table::utils::LanceIteratorExtension;
20use num_traits::{Float, FromPrimitive, Num};
21use snafu::location;
22use tracing::instrument;
23
24use super::{transform::Transformer, PQ_CODE_COLUMN};
25
26#[derive(Clone)]
31pub struct ResidualTransform {
32 centroids: FixedSizeListArray,
34
35 part_col: String,
37
38 vec_col: String,
40}
41
42impl std::fmt::Debug for ResidualTransform {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "ResidualTransform")
45 }
46}
47
48impl ResidualTransform {
49 pub fn new(centroids: FixedSizeListArray, part_col: &str, column: &str) -> Self {
50 Self {
51 centroids,
52 part_col: part_col.to_owned(),
53 vec_col: column.to_owned(),
54 }
55 }
56}
57
58fn do_compute_residual<T: ArrowNumericType>(
59 centroids: &FixedSizeListArray,
60 vectors: &FixedSizeListArray,
61 distance_type: Option<DistanceType>,
62 partitions: Option<&UInt32Array>,
63) -> Result<FixedSizeListArray>
64where
65 T::Native: Num + Float + L2 + Dot + MulAssign + DivAssign + AddAssign + FromPrimitive,
66 PrimitiveArray<T>: From<Vec<T::Native>>,
67{
68 let dimension = centroids.value_length() as usize;
69 let centroids = centroids.values().as_primitive::<T>();
70 let vectors = vectors.values().as_primitive::<T>();
71
72 let part_ids = partitions.cloned().unwrap_or_else(|| {
73 compute_partitions::<T, KMeansAlgoFloat<T>>(
74 centroids,
75 vectors,
76 dimension,
77 distance_type.expect("provide either partitions or distance type"),
78 )
79 .0
80 .into()
81 });
82 let part_ids = part_ids.values();
83
84 let vectors_slice = vectors.values();
85 let centroids_slice = centroids.values();
86 let residuals = vectors_slice
87 .chunks_exact(dimension)
88 .enumerate()
89 .flat_map(|(idx, vector)| {
90 let part_id = part_ids[idx] as usize;
91 let c = ¢roids_slice[part_id * dimension..(part_id + 1) * dimension];
92 iter::zip(vector, c).map(|(v, cent)| *v - *cent)
93 })
94 .exact_size(vectors.len())
95 .collect::<Vec<_>>();
96 let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
97 debug_assert_eq!(residual_arr.len(), vectors.len());
98 Ok(FixedSizeListArray::try_new_from_values(
99 residual_arr,
100 dimension as i32,
101 )?)
102}
103
104pub(crate) fn compute_residual(
112 centroids: &FixedSizeListArray,
113 vectors: &FixedSizeListArray,
114 distance_type: Option<DistanceType>,
115 partitions: Option<&UInt32Array>,
116) -> Result<FixedSizeListArray> {
117 if centroids.value_length() != vectors.value_length() {
118 return Err(Error::Index {
119 message: format!(
120 "Compute residual vector: centroid and vector length mismatch: centroid: {}, vector: {}",
121 centroids.value_length(),
122 vectors.value_length(),
123 ),
124 location: location!(),
125 });
126 }
127 match (centroids.value_type(), vectors.value_type()) {
129 (DataType::Float16, DataType::Float16) => {
130 do_compute_residual::<Float16Type>(centroids, vectors, distance_type, partitions)
131 }
132 (DataType::Float32, DataType::Float32) => {
133 do_compute_residual::<Float32Type>(centroids, vectors, distance_type, partitions)
134 }
135 (DataType::Float64, DataType::Float64) => {
136 do_compute_residual::<Float64Type>(centroids, vectors, distance_type, partitions)
137 }
138 (DataType::Float32, DataType::Int8) => {
139 do_compute_residual::<Float32Type>(
140 centroids,
141 &vectors.convert_to_floating_point()?,
142 distance_type,
143 partitions)
144 }
145 _ => Err(Error::Index {
146 message: format!(
147 "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}",
148 centroids.value_type(),
149 vectors.value_type(),
150 ),
151 location: location!(),
152 })
153 }
154}
155
156impl Transformer for ResidualTransform {
157 #[instrument(name = "ResidualTransform::transform", level = "debug", skip_all)]
161 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
162 if batch.column_by_name(PQ_CODE_COLUMN).is_some() {
163 return Ok(batch.clone());
165 }
166
167 let part_ids = batch.column_by_name(&self.part_col).ok_or(Error::Index {
168 message: format!(
169 "Compute residual vector: partition id column not found: {}",
170 self.part_col
171 ),
172 location: location!(),
173 })?;
174 let original = batch.column_by_name(&self.vec_col).ok_or(Error::Index {
175 message: format!(
176 "Compute residual vector: original vector column {} not found in batch {}",
177 self.vec_col,
178 batch.schema(),
179 ),
180 location: location!(),
181 })?;
182 let original_vectors = original.as_fixed_size_list_opt().ok_or(Error::Index {
183 message: format!(
184 "Compute residual vector: original vector column {} is not fixed size list: {}",
185 self.vec_col,
186 original.data_type(),
187 ),
188 location: location!(),
189 })?;
190
191 let part_ids_ref = part_ids.as_primitive::<UInt32Type>();
192 let residual_arr =
193 compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?;
194
195 let batch = if residual_arr.data_type() != original.data_type() {
196 batch.replace_column_schema_by_name(
197 &self.vec_col,
198 residual_arr.data_type().clone(),
199 Arc::new(residual_arr),
200 )?
201 } else {
202 batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?
203 };
204
205 Ok(batch)
206 }
207}