1use std::ops::{AddAssign, DivAssign};
5use std::sync::Arc;
6use std::{iter, ops::MulAssign};
7
8use crate::vector::kmeans::{compute_partitions, KMeansAlgoFloat};
9use arrow_array::ArrowNumericType;
10use arrow_array::{
11 cast::AsArray,
12 types::{Float16Type, Float32Type, Float64Type, UInt32Type},
13 Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array,
14};
15use arrow_schema::DataType;
16use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
17use lance_core::{Error, Result};
18use lance_linalg::distance::{DistanceType, Dot, L2};
19use lance_table::utils::LanceIteratorExtension;
20use num_traits::{Float, FromPrimitive, Num};
21use snafu::location;
22use tracing::instrument;
23
24use super::{transform::Transformer, PQ_CODE_COLUMN};
25
/// Column name reserved for residual vectors.
///
/// NOTE(review): nothing in the visible part of this file reads this constant;
/// `ResidualTransform::transform` overwrites the original vector column in
/// place instead. Confirm the intended use against callers elsewhere.
pub const RESIDUAL_COLUMN: &str = "__residual_vector";
27
/// Transformer that rewrites a vector column as residual vectors, i.e. each
/// vector minus the centroid of the partition it is assigned to.
#[derive(Clone)]
pub struct ResidualTransform {
    /// Centroids stored as one fixed-size list; the centroid for partition `i`
    /// is the `i`-th list entry.
    centroids: FixedSizeListArray,

    /// Name of the batch column that holds the partition id of every row
    /// (read as `UInt32`).
    part_col: String,

    /// Name of the batch column that holds the original vectors; it is the
    /// column replaced by the residuals.
    vec_col: String,
}
43
44impl std::fmt::Debug for ResidualTransform {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "ResidualTransform")
47 }
48}
49
50impl ResidualTransform {
51 pub fn new(centroids: FixedSizeListArray, part_col: &str, column: &str) -> Self {
52 Self {
53 centroids,
54 part_col: part_col.to_owned(),
55 vec_col: column.to_owned(),
56 }
57 }
58}
59
60fn do_compute_residual<T: ArrowNumericType>(
61 centroids: &FixedSizeListArray,
62 vectors: &FixedSizeListArray,
63 distance_type: Option<DistanceType>,
64 partitions: Option<&UInt32Array>,
65) -> Result<FixedSizeListArray>
66where
67 T::Native: Num + Float + L2 + Dot + MulAssign + DivAssign + AddAssign + FromPrimitive,
68 PrimitiveArray<T>: From<Vec<T::Native>>,
69{
70 let dimension = centroids.value_length() as usize;
71 let centroids = centroids.values().as_primitive::<T>();
72 let vectors = vectors.values().as_primitive::<T>();
73
74 let part_ids = partitions.cloned().unwrap_or_else(|| {
75 compute_partitions::<T, KMeansAlgoFloat<T>>(
76 centroids,
77 vectors,
78 dimension,
79 distance_type.expect("provide either partitions or distance type"),
80 )
81 .0
82 .into()
83 });
84 let part_ids = part_ids.values();
85
86 let vectors_slice = vectors.values();
87 let centroids_slice = centroids.values();
88 let residuals = vectors_slice
89 .chunks_exact(dimension)
90 .enumerate()
91 .flat_map(|(idx, vector)| {
92 let part_id = part_ids[idx] as usize;
93 let c = ¢roids_slice[part_id * dimension..(part_id + 1) * dimension];
94 iter::zip(vector, c).map(|(v, cent)| *v - *cent)
95 })
96 .exact_size(vectors.len())
97 .collect::<Vec<_>>();
98 let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
99 debug_assert_eq!(residual_arr.len(), vectors.len());
100 Ok(FixedSizeListArray::try_new_from_values(
101 residual_arr,
102 dimension as i32,
103 )?)
104}
105
106pub(crate) fn compute_residual(
114 centroids: &FixedSizeListArray,
115 vectors: &FixedSizeListArray,
116 distance_type: Option<DistanceType>,
117 partitions: Option<&UInt32Array>,
118) -> Result<FixedSizeListArray> {
119 if centroids.value_length() != vectors.value_length() {
120 return Err(Error::Index {
121 message: format!(
122 "Compute residual vector: centroid and vector length mismatch: centroid: {}, vector: {}",
123 centroids.value_length(),
124 vectors.value_length(),
125 ),
126 location: location!(),
127 });
128 }
129 match (centroids.value_type(), vectors.value_type()) {
131 (DataType::Float16, DataType::Float16) => {
132 do_compute_residual::<Float16Type>(centroids, vectors, distance_type, partitions)
133 }
134 (DataType::Float32, DataType::Float32) => {
135 do_compute_residual::<Float32Type>(centroids, vectors, distance_type, partitions)
136 }
137 (DataType::Float64, DataType::Float64) => {
138 do_compute_residual::<Float64Type>(centroids, vectors, distance_type, partitions)
139 }
140 (DataType::Float32, DataType::Int8) => {
141 do_compute_residual::<Float32Type>(
142 centroids,
143 &vectors.convert_to_floating_point()?,
144 distance_type,
145 partitions)
146 }
147 _ => Err(Error::Index {
148 message: format!(
149 "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}",
150 centroids.value_type(),
151 vectors.value_type(),
152 ),
153 location: location!(),
154 })
155 }
156}
157
158impl Transformer for ResidualTransform {
159 #[instrument(name = "ResidualTransform::transform", level = "debug", skip_all)]
163 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
164 if batch.column_by_name(PQ_CODE_COLUMN).is_some() {
165 return Ok(batch.clone());
167 }
168
169 let part_ids = batch.column_by_name(&self.part_col).ok_or(Error::Index {
170 message: format!(
171 "Compute residual vector: partition id column not found: {}",
172 self.part_col
173 ),
174 location: location!(),
175 })?;
176 let original = batch.column_by_name(&self.vec_col).ok_or(Error::Index {
177 message: format!(
178 "Compute residual vector: original vector column not found: {}",
179 self.vec_col
180 ),
181 location: location!(),
182 })?;
183 let original_vectors = original.as_fixed_size_list_opt().ok_or(Error::Index {
184 message: format!(
185 "Compute residual vector: original vector column {} is not fixed size list: {}",
186 self.vec_col,
187 original.data_type(),
188 ),
189 location: location!(),
190 })?;
191
192 let part_ids_ref = part_ids.as_primitive::<UInt32Type>();
193 let residual_arr =
194 compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?;
195
196 let batch = if residual_arr.data_type() != original.data_type() {
198 batch.replace_column_schema_by_name(
199 &self.vec_col,
200 residual_arr.data_type().clone(),
201 Arc::new(residual_arr),
202 )?
203 } else {
204 batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?
205 };
206
207 Ok(batch)
208 }
209}