1use std::ops::Range;
5
6use arrow::datatypes::Float64Type;
7use arrow::{compute::concat_batches, datatypes::Float16Type};
8use arrow_array::{
9 ArrayRef, RecordBatch, UInt8Array, UInt64Array,
10 cast::AsArray,
11 types::{Float32Type, UInt8Type, UInt64Type},
12};
13use arrow_schema::{DataType, SchemaRef};
14use async_trait::async_trait;
15use deepsize::DeepSizeOf;
16use lance_core::{Error, ROW_ID, Result};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::object_store::ObjectStore;
19use lance_linalg::distance::{DistanceType, dot_distance, l2_distance_uint_scalar};
20use lance_table::format::SelfDescribingFileReader;
21use object_store::path::Path;
22use serde::{Deserialize, Serialize};
23use std::sync::Arc;
24
25use super::{ScalarQuantizer, scale_to_u8};
26use crate::frag_reuse::FragReuseIndex;
27use crate::{
28 INDEX_METADATA_SCHEMA_KEY, IndexMetadata,
29 vector::{
30 SQ_CODE_COLUMN,
31 quantizer::{QuantizerMetadata, QuantizerStorage},
32 storage::{DistCalculator, VectorStore},
33 transform::Transformer,
34 },
35};
36
37pub const SQ_METADATA_KEY: &str = "lance:sq";
38
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
40pub struct ScalarQuantizationMetadata {
41 pub dim: usize,
42 pub num_bits: u16,
43 pub bounds: Range<f64>,
44}
45
46impl DeepSizeOf for ScalarQuantizationMetadata {
47 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
48 0
49 }
50}
51
52#[async_trait]
53impl QuantizerMetadata for ScalarQuantizationMetadata {
54 async fn load(reader: &PreviousFileReader) -> Result<Self> {
55 let metadata_str = reader
56 .schema()
57 .metadata
58 .get(SQ_METADATA_KEY)
59 .ok_or(Error::index(format!(
60 "Reading SQ metadata: metadata key {} not found",
61 SQ_METADATA_KEY
62 )))?;
63 serde_json::from_str(metadata_str)
64 .map_err(|_| Error::index(format!("Failed to parse index metadata: {}", metadata_str)))
65 }
66}
67
68#[derive(Debug, Clone)]
70struct SQStorageChunk {
71 batch: RecordBatch,
72
73 dim: usize,
74
75 row_ids: UInt64Array,
79 sq_codes: UInt8Array,
80}
81
82impl SQStorageChunk {
83 fn new(batch: RecordBatch) -> Result<Self> {
85 let row_ids = batch
86 .column_by_name(ROW_ID)
87 .ok_or(Error::index(
88 "Row ID column not found in the batch".to_owned(),
89 ))?
90 .as_primitive::<UInt64Type>()
91 .clone();
92 let fsl = batch
93 .column_by_name(SQ_CODE_COLUMN)
94 .ok_or(Error::index(
95 "SQ code column not found in the batch".to_owned(),
96 ))?
97 .as_fixed_size_list();
98 let dim = fsl.value_length() as usize;
99 let sq_codes = fsl
100 .values()
101 .as_primitive_opt::<UInt8Type>()
102 .ok_or(Error::index(
103 "SQ code column is not FixedSizeList<u8>".to_owned(),
104 ))?
105 .clone();
106 Ok(Self {
107 batch,
108 dim,
109 row_ids,
110 sq_codes,
111 })
112 }
113
114 fn dim(&self) -> usize {
116 self.dim
117 }
118
119 fn len(&self) -> usize {
120 self.row_ids.len()
121 }
122
123 fn schema(&self) -> &SchemaRef {
124 self.batch.schema_ref()
125 }
126
127 #[inline]
128 fn row_id(&self, id: u32) -> u64 {
129 self.row_ids.value(id as usize)
130 }
131
132 #[inline]
134 fn sq_code_slice(&self, id: u32) -> &[u8] {
135 &self.sq_codes.values()[id as usize * self.dim..(id + 1) as usize * self.dim]
137 }
138}
139
140impl DeepSizeOf for SQStorageChunk {
141 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
142 self.batch.get_array_memory_size()
143 }
144}
145
146#[derive(Debug, Clone)]
147pub struct ScalarQuantizationStorage {
148 quantizer: ScalarQuantizer,
149
150 distance_type: DistanceType,
151
152 offsets: Vec<u32>,
154 chunks: Vec<SQStorageChunk>,
155}
156
157impl DeepSizeOf for ScalarQuantizationStorage {
158 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
159 self.chunks
160 .iter()
161 .map(|c| c.deep_size_of_children(context))
162 .sum()
163 }
164}
165
/// Sizing constant of [`ScalarQuantizationStorage`]: it bounds the chunk
/// capacity reserved up front by `try_new`, and `optimize` leaves a storage
/// untouched as long as it holds at most this many rows.
const SQ_CHUNK_CAPACITY: usize = 1 << 10;
167
168impl ScalarQuantizationStorage {
169 pub fn try_new(
170 num_bits: u16,
171 distance_type: DistanceType,
172 bounds: Range<f64>,
173 batches: impl IntoIterator<Item = RecordBatch>,
174 frag_reuse_index: Option<Arc<FragReuseIndex>>,
175 ) -> Result<Self> {
176 let mut chunks = Vec::with_capacity(SQ_CHUNK_CAPACITY);
177 let mut offsets = Vec::with_capacity(SQ_CHUNK_CAPACITY + 1);
178 offsets.push(0);
179 for mut batch in batches.into_iter() {
180 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
181 batch = frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
182 }
183 offsets.push(offsets.last().unwrap() + batch.num_rows() as u32);
184 let chunk = SQStorageChunk::new(batch)?;
185 chunks.push(chunk);
186 }
187 let quantizer = ScalarQuantizer::with_bounds(num_bits, chunks[0].dim(), bounds);
188
189 Ok(Self {
190 quantizer,
191 distance_type,
192 offsets,
193 chunks,
194 })
195 }
196
197 fn chunk(&self, id: u32) -> (u32, &SQStorageChunk) {
205 match self.offsets.binary_search(&id) {
206 Ok(o) => (self.offsets[o], &self.chunks[o]),
207 Err(o) => (self.offsets[o - 1], &self.chunks[o - 1]),
208 }
209 }
210
211 pub async fn load(
212 object_store: &ObjectStore,
213 path: &Path,
214 frag_reuse_index: Option<Arc<FragReuseIndex>>,
215 ) -> Result<Self> {
216 let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
217 let schema = reader.schema();
218
219 let metadata_str = schema
220 .metadata
221 .get(INDEX_METADATA_SCHEMA_KEY)
222 .ok_or(Error::index(format!(
223 "Reading SQ storage: index key {} not found",
224 INDEX_METADATA_SCHEMA_KEY
225 )))?;
226 let index_metadata: IndexMetadata = serde_json::from_str(metadata_str).map_err(|_| {
227 Error::index(format!("Failed to parse index metadata: {}", metadata_str))
228 })?;
229 let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
230 let metadata = ScalarQuantizationMetadata::load(&reader).await?;
231
232 Self::load_partition(
233 &reader,
234 0..reader.len(),
235 distance_type,
236 &metadata,
237 frag_reuse_index,
238 )
239 .await
240 }
241
242 fn optimize(self) -> Result<Self> {
243 if self.len() <= SQ_CHUNK_CAPACITY {
244 Ok(self)
245 } else {
246 let mut new = self.clone();
247 let batch = concat_batches(
248 self.chunks[0].schema(),
249 self.chunks.iter().map(|c| &c.batch),
250 )?;
251 new.offsets = vec![0, batch.num_rows() as u32];
252 new.chunks = vec![SQStorageChunk::new(batch)?];
253 Ok(new)
254 }
255 }
256}
257
258#[async_trait]
259impl QuantizerStorage for ScalarQuantizationStorage {
260 type Metadata = ScalarQuantizationMetadata;
261
262 fn try_from_batch(
263 batch: RecordBatch,
264 metadata: &Self::Metadata,
265 distance_type: DistanceType,
266 frag_reuse_index: Option<Arc<FragReuseIndex>>,
267 ) -> Result<Self>
268 where
269 Self: Sized,
270 {
271 Self::try_new(
272 metadata.num_bits,
273 distance_type,
274 metadata.bounds.clone(),
275 [batch],
276 frag_reuse_index,
277 )
278 }
279
280 fn metadata(&self) -> &Self::Metadata {
281 &self.quantizer.metadata
282 }
283
284 async fn load_partition(
293 reader: &PreviousFileReader,
294 range: std::ops::Range<usize>,
295 distance_type: DistanceType,
296 metadata: &Self::Metadata,
297 frag_reuse_index: Option<Arc<FragReuseIndex>>,
298 ) -> Result<Self> {
299 let schema = reader.schema();
300 let batch = reader.read_range(range, schema).await?;
301
302 Self::try_new(
303 metadata.num_bits,
304 distance_type,
305 metadata.bounds.clone(),
306 [batch],
307 frag_reuse_index,
308 )
309 }
310}
311
312impl VectorStore for ScalarQuantizationStorage {
313 type DistanceCalculator<'a> = SQDistCalculator<'a>;
314
315 fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
316 Ok(self.chunks.iter().map(|c| c.batch.clone()))
317 }
318
319 fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result<Self> {
320 let transformer = super::transform::SQTransformer::new(
322 self.quantizer.clone(),
323 vector_column.to_string(),
324 SQ_CODE_COLUMN.to_string(),
325 );
326
327 let new_batch = transformer.transform(&batch)?;
328
329 let mut storage = self.clone();
331 let offset = self.len() as u32;
332 let new_chunk = SQStorageChunk::new(new_batch)?;
333 storage.offsets.push(offset + new_chunk.len() as u32);
334 storage.chunks.push(new_chunk);
335
336 storage.optimize()
337 }
338
339 fn schema(&self) -> &SchemaRef {
340 self.chunks[0].schema()
341 }
342
343 fn as_any(&self) -> &dyn std::any::Any {
344 self
345 }
346
347 fn len(&self) -> usize {
348 *self.offsets.last().unwrap() as usize
349 }
350
351 fn distance_type(&self) -> DistanceType {
353 self.distance_type
354 }
355
356 fn row_id(&self, id: u32) -> u64 {
357 let (offset, chunk) = self.chunk(id);
358 chunk.row_id(id - offset)
359 }
360
361 fn row_ids(&self) -> impl Iterator<Item = &u64> {
362 self.chunks.iter().flat_map(|c| c.row_ids.values())
363 }
364
365 fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
370 SQDistCalculator::new(query, self, self.quantizer.bounds())
371 }
372
373 fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
374 let (offset, chunk) = self.chunk(id);
375 let query_sq_code = chunk.sq_code_slice(id - offset).to_vec();
376 let bounds = self.quantizer.bounds();
377 SQDistCalculator {
378 query_sq_code,
379 scale: sq_distance_scale(&bounds),
380 storage: self,
381 }
382 }
383}
384
/// Factor applied to distances computed on the `u8` codes:
/// `(bounds.end - bounds.start)^2 / 255^2`, evaluated in `f32`.
#[inline]
fn sq_distance_scale(bounds: &Range<f64>) -> f32 {
    const MAX_CODE: f32 = 255.0;
    let width = (bounds.end - bounds.start) as f32;
    (width * width) / (MAX_CODE * MAX_CODE)
}
390
391pub struct SQDistCalculator<'a> {
392 query_sq_code: Vec<u8>,
393 scale: f32,
394 storage: &'a ScalarQuantizationStorage,
395}
396
397impl<'a> SQDistCalculator<'a> {
398 fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range<f64>) -> Self {
399 let query_sq_code = match query.data_type() {
404 DataType::Float16 => {
405 scale_to_u8::<Float16Type>(query.as_primitive::<Float16Type>().values(), &bounds)
406 }
407 DataType::Float32 => {
408 scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds)
409 }
410 DataType::Float64 => {
411 scale_to_u8::<Float64Type>(query.as_primitive::<Float64Type>().values(), &bounds)
412 }
413 _ => {
414 panic!("Unsupported data type for ScalarQuantizationStorage");
415 }
416 };
417 Self {
418 query_sq_code,
419 scale: sq_distance_scale(&bounds),
420 storage,
421 }
422 }
423}
424
425impl DistCalculator for SQDistCalculator<'_> {
426 fn distance(&self, id: u32) -> f32 {
427 let (offset, chunk) = self.storage.chunk(id);
428 let sq_code = chunk.sq_code_slice(id - offset);
429 let dist = match self.storage.distance_type {
430 DistanceType::L2 | DistanceType::Cosine => {
431 l2_distance_uint_scalar(sq_code, &self.query_sq_code)
432 }
433 DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code),
434 _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
435 };
436 dist * self.scale
437 }
438
439 fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
440 match self.storage.distance_type {
441 DistanceType::L2 | DistanceType::Cosine => self
442 .storage
443 .chunks
444 .iter()
445 .flat_map(|c| {
446 c.sq_codes
447 .values()
448 .chunks_exact(c.dim())
449 .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code))
450 })
451 .map(|dist| dist * self.scale)
452 .collect(),
453 DistanceType::Dot => self
454 .storage
455 .chunks
456 .iter()
457 .flat_map(|c| {
458 c.sq_codes
459 .values()
460 .chunks_exact(c.dim())
461 .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code))
462 })
463 .map(|dist| dist * self.scale)
464 .collect(),
465 _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
466 }
467 }
468
469 #[allow(unused_variables)]
470 fn prefetch(&self, id: u32) {
471 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
472 {
473 const CACHE_LINE_SIZE: usize = 64;
474
475 let (offset, chunk) = self.storage.chunk(id);
476 let dim = chunk.dim();
477 let base_ptr = chunk.sq_code_slice(id - offset).as_ptr();
478
479 unsafe {
480 for offset in (0..dim).step_by(CACHE_LINE_SIZE) {
482 {
483 use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};
484 _mm_prefetch(base_ptr.add(offset) as *const i8, _MM_HINT_T0);
485 }
486 }
487 }
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_with;
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::{DataType, Field, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_testing::datagen::generate_random_array;
    use rand::prelude::*;

    /// Vector / code dimension used by every fixture in this module.
    const DIM: usize = 64;

    /// Schema with a `ROW_ID` column and a fixed-size-list column named
    /// `column` holding `DIM` items of `item_type`.
    fn fsl_schema(column: &str, item_type: DataType) -> Arc<Schema> {
        let list_type =
            DataType::FixedSizeList(Arc::new(Field::new("item", item_type, true)), DIM as i32);
        Arc::new(Schema::new(vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new(column, list_type, false),
        ]))
    }

    /// Batch of random SQ codes whose row ids are exactly `row_ids`.
    fn create_record_batch(row_ids: Range<u64>) -> RecordBatch {
        let mut rng = rand::rng();
        let ids = UInt64Array::from_iter_values(row_ids);
        let codes =
            UInt8Array::from_iter_values(repeat_with(|| rng.random::<u8>()).take(ids.len() * DIM));
        let code_column = FixedSizeListArray::try_new_from_values(codes, DIM as i32).unwrap();

        RecordBatch::try_new(
            fsl_schema(SQ_CODE_COLUMN, DataType::UInt8),
            vec![Arc::new(ids), Arc::new(code_column)],
        )
        .unwrap()
    }

    #[test]
    fn test_get_chunks() {
        // Four chunks of 100 rows each, with row ids 0..400.
        let four_batches = (0..4).map(|start| create_record_batch(start * 100..(start + 1) * 100));
        let storage =
            ScalarQuantizationStorage::try_new(8, DistanceType::L2, -0.7..0.7, four_batches, None)
                .unwrap();

        assert_eq!(storage.len(), 400);

        let (offset, chunk) = storage.chunk(0);
        assert_eq!(offset, 0);
        assert_eq!(chunk.row_id(20), 20);

        let (offset, _) = storage.chunk(50);
        assert_eq!(offset, 0);

        // Append 150 un-quantized float vectors carrying row ids 100..250.
        let appended_ids = UInt64Array::from_iter_values(100..250);
        let vectors = generate_random_array(appended_ids.len() * DIM);
        let vector_column = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
        let second_batch = RecordBatch::try_new(
            fsl_schema("vector", DataType::Float32),
            vec![Arc::new(appended_ids), Arc::new(vector_column)],
        )
        .unwrap();
        let storage = storage.append_batch(second_batch, "vector").unwrap();

        assert_eq!(storage.len(), 550);
        let (offset, chunk) = storage.chunk(112);
        assert_eq!(offset, 100);
        assert_eq!(chunk.row_id(10), 110);

        // The appended rows live in their own chunk starting at index 400.
        let (offset, chunk) = storage.chunk(432);
        assert_eq!(offset, 400);
        assert_eq!(chunk.row_id(5), 105);
    }
}