1use std::ops::Range;
5
6use arrow::datatypes::Float64Type;
7use arrow::{compute::concat_batches, datatypes::Float16Type};
8use arrow_array::{
9 cast::AsArray,
10 types::{Float32Type, UInt64Type, UInt8Type},
11 ArrayRef, RecordBatch, UInt64Array, UInt8Array,
12};
13use arrow_schema::{DataType, SchemaRef};
14use async_trait::async_trait;
15use deepsize::DeepSizeOf;
16use lance_core::{Error, Result, ROW_ID};
17use lance_file::reader::FileReader;
18use lance_io::object_store::ObjectStore;
19use lance_linalg::distance::{dot_distance, l2_distance_uint_scalar, DistanceType};
20use lance_table::format::SelfDescribingFileReader;
21use object_store::path::Path;
22use serde::{Deserialize, Serialize};
23use snafu::location;
24use std::sync::Arc;
25
26use super::{inverse_scalar_dist, scale_to_u8, ScalarQuantizer};
27use crate::frag_reuse::FragReuseIndex;
28use crate::{
29 vector::{
30 quantizer::{QuantizerMetadata, QuantizerStorage},
31 storage::{DistCalculator, VectorStore},
32 transform::Transformer,
33 SQ_CODE_COLUMN,
34 },
35 IndexMetadata, INDEX_METADATA_SCHEMA_KEY,
36};
37
/// Key in the file schema metadata under which the JSON-encoded
/// `ScalarQuantizationMetadata` is looked up when loading.
pub const SQ_METADATA_KEY: &str = "lance:sq";
39
/// Parameters describing a scalar quantizer.
///
/// The struct is (de)serializable with serde; `QuantizerMetadata::load` parses
/// it from the JSON string stored under `SQ_METADATA_KEY` in the file schema
/// metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizationMetadata {
    /// Vector dimension.
    /// NOTE(review): presumably the dimension of the quantized vectors — confirm.
    pub dim: usize,
    /// Number of bits per quantized value; forwarded to
    /// `ScalarQuantizer::with_bounds` when a storage is built.
    pub num_bits: u16,
    /// Value range forwarded to `ScalarQuantizer::with_bounds` as the
    /// quantization bounds.
    pub bounds: Range<f64>,
}
46
impl DeepSizeOf for ScalarQuantizationMetadata {
    /// Reports zero bytes for children: every field (`usize`, `u16`,
    /// `Range<f64>`) is stored inline in the struct.
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        0
    }
}
52
53#[async_trait]
54impl QuantizerMetadata for ScalarQuantizationMetadata {
55 async fn load(reader: &FileReader) -> Result<Self> {
56 let metadata_str = reader
57 .schema()
58 .metadata
59 .get(SQ_METADATA_KEY)
60 .ok_or(Error::Index {
61 message: format!(
62 "Reading SQ metadata: metadata key {} not found",
63 SQ_METADATA_KEY
64 ),
65 location: location!(),
66 })?;
67 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
68 message: format!("Failed to parse index metadata: {}", metadata_str),
69 location: location!(),
70 })
71 }
72}
73
/// One record batch of scalar-quantized vectors together with typed views of
/// the columns needed for distance computation.
#[derive(Debug, Clone)]
struct SQStorageChunk {
    /// The record batch this chunk was built from.
    batch: RecordBatch,

    /// Number of codes per vector: the value length of the fixed-size-list
    /// `SQ_CODE_COLUMN` column.
    dim: usize,

    /// The `ROW_ID` column of `batch`.
    row_ids: UInt64Array,
    /// Flattened `u8` values of the `SQ_CODE_COLUMN` column; the codes of the
    /// vector at index `i` occupy `[i * dim, (i + 1) * dim)`.
    sq_codes: UInt8Array,
}
87
88impl SQStorageChunk {
89 fn new(batch: RecordBatch) -> Result<Self> {
91 let row_ids = batch
92 .column_by_name(ROW_ID)
93 .ok_or(Error::Index {
94 message: "Row ID column not found in the batch".to_owned(),
95 location: location!(),
96 })?
97 .as_primitive::<UInt64Type>()
98 .clone();
99 let fsl = batch
100 .column_by_name(SQ_CODE_COLUMN)
101 .ok_or(Error::Index {
102 message: "SQ code column not found in the batch".to_owned(),
103 location: location!(),
104 })?
105 .as_fixed_size_list();
106 let dim = fsl.value_length() as usize;
107 let sq_codes = fsl
108 .values()
109 .as_primitive_opt::<UInt8Type>()
110 .ok_or(Error::Index {
111 message: "SQ code column is not FixedSizeList<u8>".to_owned(),
112 location: location!(),
113 })?
114 .clone();
115 Ok(Self {
116 batch,
117 dim,
118 row_ids,
119 sq_codes,
120 })
121 }
122
123 fn dim(&self) -> usize {
125 self.dim
126 }
127
128 fn len(&self) -> usize {
129 self.row_ids.len()
130 }
131
132 fn schema(&self) -> &SchemaRef {
133 self.batch.schema_ref()
134 }
135
136 #[inline]
137 fn row_id(&self, id: u32) -> u64 {
138 self.row_ids.value(id as usize)
139 }
140
141 #[inline]
143 fn sq_code_slice(&self, id: u32) -> &[u8] {
144 &self.sq_codes.values()[id as usize * self.dim..(id + 1) as usize * self.dim]
146 }
147}
148
impl DeepSizeOf for SQStorageChunk {
    /// Reports the array memory size of the underlying record batch.
    ///
    /// `row_ids` and `sq_codes` are not added separately.
    /// NOTE(review): presumably because they share buffers with `batch` — confirm.
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        self.batch.get_array_memory_size()
    }
}
154
/// Storage of scalar-quantized vectors, kept as a list of chunks addressed by
/// a storage-wide `u32` id.
#[derive(Debug, Clone)]
pub struct ScalarQuantizationStorage {
    /// Quantizer built from the number of bits, the chunk dimension and the bounds.
    quantizer: ScalarQuantizer,

    /// Distance type used by the distance calculators created from this storage.
    distance_type: DistanceType,

    /// Cumulative row counts: `offsets[i]` is the id of the first row of
    /// `chunks[i]`, and the last entry is the total number of rows.
    offsets: Vec<u32>,
    /// The chunks holding the quantized vectors, in id order.
    chunks: Vec<SQStorageChunk>,
}
165
166impl DeepSizeOf for ScalarQuantizationStorage {
167 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
168 self.chunks
169 .iter()
170 .map(|c| c.deep_size_of_children(context))
171 .sum()
172 }
173}
174
/// Initial capacity reserved for the chunk list in `try_new`, and the
/// threshold that `optimize` compares `len()` (the total row count) against
/// before merging all chunks into one.
const SQ_CHUNK_CAPACITY: usize = 1024;
176
177impl ScalarQuantizationStorage {
178 pub fn try_new(
179 num_bits: u16,
180 distance_type: DistanceType,
181 bounds: Range<f64>,
182 batches: impl IntoIterator<Item = RecordBatch>,
183 fri: Option<Arc<FragReuseIndex>>,
184 ) -> Result<Self> {
185 let mut chunks = Vec::with_capacity(SQ_CHUNK_CAPACITY);
186 let mut offsets = Vec::with_capacity(SQ_CHUNK_CAPACITY + 1);
187 offsets.push(0);
188 for mut batch in batches.into_iter() {
189 if let Some(fri_ref) = fri.as_ref() {
190 batch = fri_ref.remap_row_ids_record_batch(batch, 0)?
191 }
192 offsets.push(offsets.last().unwrap() + batch.num_rows() as u32);
193 let chunk = SQStorageChunk::new(batch)?;
194 chunks.push(chunk);
195 }
196 let quantizer = ScalarQuantizer::with_bounds(num_bits, chunks[0].dim(), bounds);
197
198 Ok(Self {
199 quantizer,
200 distance_type,
201 offsets,
202 chunks,
203 })
204 }
205
206 fn chunk(&self, id: u32) -> (u32, &SQStorageChunk) {
214 match self.offsets.binary_search(&id) {
215 Ok(o) => (self.offsets[o], &self.chunks[o]),
216 Err(o) => (self.offsets[o - 1], &self.chunks[o - 1]),
217 }
218 }
219
220 pub async fn load(
221 object_store: &ObjectStore,
222 path: &Path,
223 fri: Option<Arc<FragReuseIndex>>,
224 ) -> Result<Self> {
225 let reader = FileReader::try_new_self_described(object_store, path, None).await?;
226 let schema = reader.schema();
227
228 let metadata_str = schema
229 .metadata
230 .get(INDEX_METADATA_SCHEMA_KEY)
231 .ok_or(Error::Index {
232 message: format!(
233 "Reading SQ storage: index key {} not found",
234 INDEX_METADATA_SCHEMA_KEY
235 ),
236 location: location!(),
237 })?;
238 let index_metadata: IndexMetadata =
239 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
240 message: format!("Failed to parse index metadata: {}", metadata_str),
241 location: location!(),
242 })?;
243 let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
244 let metadata = ScalarQuantizationMetadata::load(&reader).await?;
245
246 Self::load_partition(&reader, 0..reader.len(), distance_type, &metadata, fri).await
247 }
248
249 fn optimize(self) -> Result<Self> {
250 if self.len() <= SQ_CHUNK_CAPACITY {
251 Ok(self)
252 } else {
253 let mut new = self.clone();
254 let batch = concat_batches(
255 self.chunks[0].schema(),
256 self.chunks.iter().map(|c| &c.batch),
257 )?;
258 new.offsets = vec![0, batch.num_rows() as u32];
259 new.chunks = vec![SQStorageChunk::new(batch)?];
260 Ok(new)
261 }
262 }
263}
264
265#[async_trait]
266impl QuantizerStorage for ScalarQuantizationStorage {
267 type Metadata = ScalarQuantizationMetadata;
268
269 fn try_from_batch(
270 batch: RecordBatch,
271 metadata: &Self::Metadata,
272 distance_type: DistanceType,
273 fri: Option<Arc<FragReuseIndex>>,
274 ) -> Result<Self>
275 where
276 Self: Sized,
277 {
278 Self::try_new(
279 metadata.num_bits,
280 distance_type,
281 metadata.bounds.clone(),
282 [batch],
283 fri,
284 )
285 }
286
287 fn metadata(&self) -> &Self::Metadata {
288 &self.quantizer.metadata
289 }
290
291 async fn load_partition(
300 reader: &FileReader,
301 range: std::ops::Range<usize>,
302 distance_type: DistanceType,
303 metadata: &Self::Metadata,
304 fri: Option<Arc<FragReuseIndex>>,
305 ) -> Result<Self> {
306 let schema = reader.schema();
307 let batch = reader.read_range(range, schema).await?;
308
309 Self::try_new(
310 metadata.num_bits,
311 distance_type,
312 metadata.bounds.clone(),
313 [batch],
314 fri,
315 )
316 }
317}
318
319impl VectorStore for ScalarQuantizationStorage {
320 type DistanceCalculator<'a> = SQDistCalculator<'a>;
321
322 fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
323 Ok(self.chunks.iter().map(|c| c.batch.clone()))
324 }
325
326 fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result<Self> {
327 let transformer = super::transform::SQTransformer::new(
329 self.quantizer.clone(),
330 vector_column.to_string(),
331 SQ_CODE_COLUMN.to_string(),
332 );
333
334 let new_batch = transformer.transform(&batch)?;
335
336 let mut storage = self.clone();
338 let offset = self.len() as u32;
339 let new_chunk = SQStorageChunk::new(new_batch)?;
340 storage.offsets.push(offset + new_chunk.len() as u32);
341 storage.chunks.push(new_chunk);
342
343 storage.optimize()
344 }
345
346 fn schema(&self) -> &SchemaRef {
347 self.chunks[0].schema()
348 }
349
350 fn as_any(&self) -> &dyn std::any::Any {
351 self
352 }
353
354 fn len(&self) -> usize {
355 *self.offsets.last().unwrap() as usize
356 }
357
358 fn distance_type(&self) -> DistanceType {
360 self.distance_type
361 }
362
363 fn row_id(&self, id: u32) -> u64 {
364 let (offset, chunk) = self.chunk(id);
365 chunk.row_id(id - offset)
366 }
367
368 fn row_ids(&self) -> impl Iterator<Item = &u64> {
369 self.chunks.iter().flat_map(|c| c.row_ids.values())
370 }
371
372 fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> {
377 SQDistCalculator::new(query, self, self.quantizer.bounds())
378 }
379
380 fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
381 let (offset, chunk) = self.chunk(id);
382 let query_sq_code = chunk.sq_code_slice(id - offset).to_vec();
383 SQDistCalculator {
384 query_sq_code,
385 bounds: self.quantizer.bounds(),
386 storage: self,
387 }
388 }
389}
390
/// Computes distances between one quantized query and the vectors held by a
/// `ScalarQuantizationStorage`.
pub struct SQDistCalculator<'a> {
    /// Quantized (`u8`) codes of the query vector.
    query_sq_code: Vec<u8>,
    /// Quantization bounds, passed to `inverse_scalar_dist` when converting
    /// code-space distances.
    bounds: Range<f64>,
    /// The storage whose vectors the query is compared against.
    storage: &'a ScalarQuantizationStorage,
}
396
397impl<'a> SQDistCalculator<'a> {
398 fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range<f64>) -> Self {
399 let query_sq_code = match query.data_type() {
404 DataType::Float16 => {
405 scale_to_u8::<Float16Type>(query.as_primitive::<Float16Type>().values(), &bounds)
406 }
407 DataType::Float32 => {
408 scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds)
409 }
410 DataType::Float64 => {
411 scale_to_u8::<Float64Type>(query.as_primitive::<Float64Type>().values(), &bounds)
412 }
413 _ => {
414 panic!("Unsupported data type for ScalarQuantizationStorage");
415 }
416 };
417 Self {
418 query_sq_code,
419 bounds,
420 storage,
421 }
422 }
423}
424
425impl DistCalculator for SQDistCalculator<'_> {
426 fn distance(&self, id: u32) -> f32 {
427 let (offset, chunk) = self.storage.chunk(id);
428 let sq_code = chunk.sq_code_slice(id - offset);
429 let dist = match self.storage.distance_type {
430 DistanceType::L2 | DistanceType::Cosine => {
431 l2_distance_uint_scalar(sq_code, &self.query_sq_code)
432 }
433 DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code),
434 _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
435 };
436 inverse_scalar_dist(std::iter::once(dist), &self.bounds)[0]
437 }
438
439 fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
440 match self.storage.distance_type {
441 DistanceType::L2 | DistanceType::Cosine => inverse_scalar_dist(
442 self.storage.chunks.iter().flat_map(|c| {
443 c.sq_codes
444 .values()
445 .chunks_exact(c.dim())
446 .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code))
447 }),
448 &self.bounds,
449 ),
450 DistanceType::Dot => inverse_scalar_dist(
451 self.storage.chunks.iter().flat_map(|c| {
452 c.sq_codes
453 .values()
454 .chunks_exact(c.dim())
455 .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code))
456 }),
457 &self.bounds,
458 ),
459 _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
460 }
461 }
462
463 #[allow(unused_variables)]
464 fn prefetch(&self, id: u32) {
465 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
466 {
467 const CACHE_LINE_SIZE: usize = 64;
468
469 let (offset, chunk) = self.storage.chunk(id);
470 let dim = chunk.dim();
471 let base_ptr = chunk.sq_code_slice(id - offset).as_ptr();
472
473 unsafe {
474 for offset in (0..dim).step_by(CACHE_LINE_SIZE) {
476 {
477 use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
478 _mm_prefetch(base_ptr.add(offset) as *const i8, _MM_HINT_T0);
479 }
480 }
481 }
482 }
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_with;
    use std::sync::Arc;

    use arrow_array::FixedSizeListArray;
    use arrow_schema::{DataType, Field, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_testing::datagen::generate_random_array;
    use rand::prelude::*;

    /// Builds a batch of random 64-dimensional SQ codes, one row per id in `row_ids`.
    fn create_record_batch(row_ids: Range<u64>) -> RecordBatch {
        const DIM: usize = 64;

        let mut rng = rand::thread_rng();
        let ids = UInt64Array::from_iter_values(row_ids);
        let num_codes = ids.len() * DIM;
        let codes = UInt8Array::from_iter_values(repeat_with(|| rng.gen::<u8>()).take(num_codes));
        let code_column = FixedSizeListArray::try_new_from_values(codes, DIM as i32).unwrap();

        let code_type = DataType::FixedSizeList(
            Arc::new(Field::new("item", DataType::UInt8, true)),
            DIM as i32,
        );
        let fields = vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new(SQ_CODE_COLUMN, code_type, false),
        ];
        let schema = Arc::new(Schema::new(fields));
        RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(code_column)]).unwrap()
    }

    /// Checks that ids resolve to the right chunk, both for a freshly built
    /// storage and after appending another batch.
    #[test]
    fn test_get_chunks() {
        const DIM: usize = 64;

        // Four chunks of 100 rows each, with row ids 0..400.
        let initial_batches = (0..4).map(|start| create_record_batch(start * 100..(start + 1) * 100));
        let storage = ScalarQuantizationStorage::try_new(
            8,
            DistanceType::L2,
            -0.7..0.7,
            initial_batches,
            None,
        )
        .unwrap();

        assert_eq!(storage.len(), 400);

        let (first_offset, first_chunk) = storage.chunk(0);
        assert_eq!(first_offset, 0);
        assert_eq!(first_chunk.row_id(20), 20);

        let (mid_offset, _) = storage.chunk(50);
        assert_eq!(mid_offset, 0);

        // Append 150 un-quantized float vectors with row ids 100..250.
        let extra_ids = UInt64Array::from_iter_values(100..250);
        let extra_values = generate_random_array(extra_ids.len() * DIM);
        let extra_vectors =
            FixedSizeListArray::try_new_from_values(extra_values, DIM as i32).unwrap();

        let vector_type = DataType::FixedSizeList(
            Arc::new(Field::new("item", DataType::Float32, true)),
            DIM as i32,
        );
        let extra_schema = Arc::new(Schema::new(vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new("vector", vector_type, false),
        ]));

        let second_batch = RecordBatch::try_new(
            extra_schema,
            vec![Arc::new(extra_ids), Arc::new(extra_vectors)],
        )
        .unwrap();
        let storage = storage.append_batch(second_batch, "vector").unwrap();

        assert_eq!(storage.len(), 550);
        let (offset, chunk) = storage.chunk(112);
        assert_eq!(offset, 100);
        assert_eq!(chunk.row_id(10), 110);

        let (offset, chunk) = storage.chunk(432);
        assert_eq!(offset, 400);
        assert_eq!(chunk.row_id(5), 105);
    }
}
573}