1use std::ops::Range;
5
6use arrow::datatypes::Float64Type;
7use arrow::{compute::concat_batches, datatypes::Float16Type};
8use arrow_array::{
9 cast::AsArray,
10 types::{Float32Type, UInt64Type, UInt8Type},
11 ArrayRef, RecordBatch, UInt64Array, UInt8Array,
12};
13use arrow_schema::{DataType, SchemaRef};
14use async_trait::async_trait;
15use deepsize::DeepSizeOf;
16use lance_core::{Error, Result, ROW_ID};
17use lance_file::previous::reader::FileReader as PreviousFileReader;
18use lance_io::object_store::ObjectStore;
19use lance_linalg::distance::{dot_distance, l2_distance_uint_scalar, DistanceType};
20use lance_table::format::SelfDescribingFileReader;
21use object_store::path::Path;
22use serde::{Deserialize, Serialize};
23use snafu::location;
24use std::sync::Arc;
25
26use super::{inverse_scalar_dist, scale_to_u8, ScalarQuantizer};
27use crate::frag_reuse::FragReuseIndex;
28use crate::{
29 vector::{
30 quantizer::{QuantizerMetadata, QuantizerStorage},
31 storage::{DistCalculator, VectorStore},
32 transform::Transformer,
33 SQ_CODE_COLUMN,
34 },
35 IndexMetadata, INDEX_METADATA_SCHEMA_KEY,
36};
37
/// Schema-metadata key under which the serialized [`ScalarQuantizationMetadata`] is stored.
pub const SQ_METADATA_KEY: &str = "lance:sq";
39
/// Serializable metadata describing a scalar-quantized (SQ) vector index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizationMetadata {
    /// Dimensionality of the quantized vectors.
    pub dim: usize,
    /// Number of bits per quantized component.
    pub num_bits: u16,
    /// Value range the original floating-point values were scaled from.
    pub bounds: Range<f64>,
}
46
impl DeepSizeOf for ScalarQuantizationMetadata {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // All fields (usize, u16, Range<f64>) are stored inline with no heap
        // allocations, so there are no children to account for.
        0
    }
}
52
53#[async_trait]
54impl QuantizerMetadata for ScalarQuantizationMetadata {
55 async fn load(reader: &PreviousFileReader) -> Result<Self> {
56 let metadata_str = reader
57 .schema()
58 .metadata
59 .get(SQ_METADATA_KEY)
60 .ok_or(Error::Index {
61 message: format!(
62 "Reading SQ metadata: metadata key {} not found",
63 SQ_METADATA_KEY
64 ),
65 location: location!(),
66 })?;
67 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
68 message: format!("Failed to parse index metadata: {}", metadata_str),
69 location: location!(),
70 })
71 }
72}
73
/// A contiguous run of SQ-coded vectors backed by one [`RecordBatch`].
///
/// `row_ids` and `sq_codes` are cached (cheap, buffer-sharing) views into
/// `batch` so per-row accessors skip repeated column lookup and downcasting.
#[derive(Debug, Clone)]
struct SQStorageChunk {
    batch: RecordBatch,

    // Vector dimension, i.e. the value length of the fixed-size-list SQ column.
    dim: usize,

    // Cached ROW_ID column from `batch`.
    row_ids: UInt64Array,
    // Cached flattened u8 SQ codes (the values of the fixed-size-list column).
    sq_codes: UInt8Array,
}
87
88impl SQStorageChunk {
89 fn new(batch: RecordBatch) -> Result<Self> {
91 let row_ids = batch
92 .column_by_name(ROW_ID)
93 .ok_or(Error::Index {
94 message: "Row ID column not found in the batch".to_owned(),
95 location: location!(),
96 })?
97 .as_primitive::<UInt64Type>()
98 .clone();
99 let fsl = batch
100 .column_by_name(SQ_CODE_COLUMN)
101 .ok_or(Error::Index {
102 message: "SQ code column not found in the batch".to_owned(),
103 location: location!(),
104 })?
105 .as_fixed_size_list();
106 let dim = fsl.value_length() as usize;
107 let sq_codes = fsl
108 .values()
109 .as_primitive_opt::<UInt8Type>()
110 .ok_or(Error::Index {
111 message: "SQ code column is not FixedSizeList<u8>".to_owned(),
112 location: location!(),
113 })?
114 .clone();
115 Ok(Self {
116 batch,
117 dim,
118 row_ids,
119 sq_codes,
120 })
121 }
122
123 fn dim(&self) -> usize {
125 self.dim
126 }
127
128 fn len(&self) -> usize {
129 self.row_ids.len()
130 }
131
132 fn schema(&self) -> &SchemaRef {
133 self.batch.schema_ref()
134 }
135
136 #[inline]
137 fn row_id(&self, id: u32) -> u64 {
138 self.row_ids.value(id as usize)
139 }
140
141 #[inline]
143 fn sq_code_slice(&self, id: u32) -> &[u8] {
144 &self.sq_codes.values()[id as usize * self.dim..(id + 1) as usize * self.dim]
146 }
147}
148
impl DeepSizeOf for SQStorageChunk {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // `row_ids` and `sq_codes` share buffers with `batch`, so counting the
        // batch's array memory once covers the cached views as well.
        self.batch.get_array_memory_size()
    }
}
154
/// In-memory storage of scalar-quantized vectors, organized as chunks.
#[derive(Debug, Clone)]
pub struct ScalarQuantizationStorage {
    quantizer: ScalarQuantizer,

    distance_type: DistanceType,

    // Prefix sums of chunk lengths: offsets[i] is the global row index of the
    // first row in chunks[i]; offsets.last() is the total row count.
    offsets: Vec<u32>,
    chunks: Vec<SQStorageChunk>,
}
165
166impl DeepSizeOf for ScalarQuantizationStorage {
167 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
168 self.chunks
169 .iter()
170 .map(|c| c.deep_size_of_children(context))
171 .sum()
172 }
173}
174
/// Row-count threshold above which `optimize` compacts all chunks into a
/// single batch; also used as the initial capacity hint for the chunk and
/// offset vectors in `try_new`.
const SQ_CHUNK_CAPACITY: usize = 1024;
176
177impl ScalarQuantizationStorage {
178 pub fn try_new(
179 num_bits: u16,
180 distance_type: DistanceType,
181 bounds: Range<f64>,
182 batches: impl IntoIterator<Item = RecordBatch>,
183 frag_reuse_index: Option<Arc<FragReuseIndex>>,
184 ) -> Result<Self> {
185 let mut chunks = Vec::with_capacity(SQ_CHUNK_CAPACITY);
186 let mut offsets = Vec::with_capacity(SQ_CHUNK_CAPACITY + 1);
187 offsets.push(0);
188 for mut batch in batches.into_iter() {
189 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
190 batch = frag_reuse_index_ref.remap_row_ids_record_batch(batch, 0)?
191 }
192 offsets.push(offsets.last().unwrap() + batch.num_rows() as u32);
193 let chunk = SQStorageChunk::new(batch)?;
194 chunks.push(chunk);
195 }
196 let quantizer = ScalarQuantizer::with_bounds(num_bits, chunks[0].dim(), bounds);
197
198 Ok(Self {
199 quantizer,
200 distance_type,
201 offsets,
202 chunks,
203 })
204 }
205
206 fn chunk(&self, id: u32) -> (u32, &SQStorageChunk) {
214 match self.offsets.binary_search(&id) {
215 Ok(o) => (self.offsets[o], &self.chunks[o]),
216 Err(o) => (self.offsets[o - 1], &self.chunks[o - 1]),
217 }
218 }
219
220 pub async fn load(
221 object_store: &ObjectStore,
222 path: &Path,
223 frag_reuse_index: Option<Arc<FragReuseIndex>>,
224 ) -> Result<Self> {
225 let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
226 let schema = reader.schema();
227
228 let metadata_str = schema
229 .metadata
230 .get(INDEX_METADATA_SCHEMA_KEY)
231 .ok_or(Error::Index {
232 message: format!(
233 "Reading SQ storage: index key {} not found",
234 INDEX_METADATA_SCHEMA_KEY
235 ),
236 location: location!(),
237 })?;
238 let index_metadata: IndexMetadata =
239 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
240 message: format!("Failed to parse index metadata: {}", metadata_str),
241 location: location!(),
242 })?;
243 let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?;
244 let metadata = ScalarQuantizationMetadata::load(&reader).await?;
245
246 Self::load_partition(
247 &reader,
248 0..reader.len(),
249 distance_type,
250 &metadata,
251 frag_reuse_index,
252 )
253 .await
254 }
255
256 fn optimize(self) -> Result<Self> {
257 if self.len() <= SQ_CHUNK_CAPACITY {
258 Ok(self)
259 } else {
260 let mut new = self.clone();
261 let batch = concat_batches(
262 self.chunks[0].schema(),
263 self.chunks.iter().map(|c| &c.batch),
264 )?;
265 new.offsets = vec![0, batch.num_rows() as u32];
266 new.chunks = vec![SQStorageChunk::new(batch)?];
267 Ok(new)
268 }
269 }
270}
271
#[async_trait]
impl QuantizerStorage for ScalarQuantizationStorage {
    type Metadata = ScalarQuantizationMetadata;

    /// Build single-chunk storage from one pre-quantized batch plus metadata.
    fn try_from_batch(
        batch: RecordBatch,
        metadata: &Self::Metadata,
        distance_type: DistanceType,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self>
    where
        Self: Sized,
    {
        Self::try_new(
            metadata.num_bits,
            distance_type,
            metadata.bounds.clone(),
            [batch],
            frag_reuse_index,
        )
    }

    fn metadata(&self) -> &Self::Metadata {
        &self.quantizer.metadata
    }

    /// Read a row `range` of a legacy index file into single-chunk storage.
    ///
    /// # Errors
    /// Fails if the file read fails or the batch lacks the SQ columns.
    async fn load_partition(
        reader: &PreviousFileReader,
        range: std::ops::Range<usize>,
        distance_type: DistanceType,
        metadata: &Self::Metadata,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self> {
        let schema = reader.schema();
        let batch = reader.read_range(range, schema).await?;

        Self::try_new(
            metadata.num_bits,
            distance_type,
            metadata.bounds.clone(),
            [batch],
            frag_reuse_index,
        )
    }
}
325
impl VectorStore for ScalarQuantizationStorage {
    type DistanceCalculator<'a> = SQDistCalculator<'a>;

    /// Yield the backing batch of each chunk (cheap, buffer-sharing clones).
    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
        Ok(self.chunks.iter().map(|c| c.batch.clone()))
    }

    /// Quantize `vector_column` of `batch` into SQ codes and append the
    /// result as a new chunk, then compact via `optimize` if the total row
    /// count now exceeds `SQ_CHUNK_CAPACITY`. Returns a new storage; `self`
    /// is unchanged.
    fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result<Self> {
        let transformer = super::transform::SQTransformer::new(
            self.quantizer.clone(),
            vector_column.to_string(),
            SQ_CODE_COLUMN.to_string(),
        );

        let new_batch = transformer.transform(&batch)?;

        let mut storage = self.clone();
        let offset = self.len() as u32;
        let new_chunk = SQStorageChunk::new(new_batch)?;
        storage.offsets.push(offset + new_chunk.len() as u32);
        storage.chunks.push(new_chunk);

        storage.optimize()
    }

    // Safe: `try_new` guarantees at least one chunk exists.
    fn schema(&self) -> &SchemaRef {
        self.chunks[0].schema()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Total number of vectors across all chunks (last prefix-sum offset).
    fn len(&self) -> usize {
        *self.offsets.last().unwrap() as usize
    }

    fn distance_type(&self) -> DistanceType {
        self.distance_type
    }

    /// Row id of the vector at global index `id`.
    fn row_id(&self, id: u32) -> u64 {
        let (offset, chunk) = self.chunk(id);
        chunk.row_id(id - offset)
    }

    /// All row ids, in chunk order.
    fn row_ids(&self) -> impl Iterator<Item = &u64> {
        self.chunks.iter().flat_map(|c| c.row_ids.values())
    }

    /// Build a calculator from a float query vector (f16/f32/f64).
    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
        SQDistCalculator::new(query, self, self.quantizer.bounds())
    }

    /// Build a calculator using a stored vector (by global index) as query.
    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
        let (offset, chunk) = self.chunk(id);
        let query_sq_code = chunk.sq_code_slice(id - offset).to_vec();
        SQDistCalculator {
            query_sq_code,
            bounds: self.quantizer.bounds(),
            storage: self,
        }
    }
}
397
/// Distance calculator comparing a quantized query against stored SQ codes.
pub struct SQDistCalculator<'a> {
    // Query vector quantized to u8 with the same bounds as the storage.
    query_sq_code: Vec<u8>,
    bounds: Range<f64>,
    storage: &'a ScalarQuantizationStorage,
}
403
404impl<'a> SQDistCalculator<'a> {
405 fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range<f64>) -> Self {
406 let query_sq_code = match query.data_type() {
411 DataType::Float16 => {
412 scale_to_u8::<Float16Type>(query.as_primitive::<Float16Type>().values(), &bounds)
413 }
414 DataType::Float32 => {
415 scale_to_u8::<Float32Type>(query.as_primitive::<Float32Type>().values(), &bounds)
416 }
417 DataType::Float64 => {
418 scale_to_u8::<Float64Type>(query.as_primitive::<Float64Type>().values(), &bounds)
419 }
420 _ => {
421 panic!("Unsupported data type for ScalarQuantizationStorage");
422 }
423 };
424 Self {
425 query_sq_code,
426 bounds,
427 storage,
428 }
429 }
430}
431
432impl DistCalculator for SQDistCalculator<'_> {
433 fn distance(&self, id: u32) -> f32 {
434 let (offset, chunk) = self.storage.chunk(id);
435 let sq_code = chunk.sq_code_slice(id - offset);
436 let dist = match self.storage.distance_type {
437 DistanceType::L2 | DistanceType::Cosine => {
438 l2_distance_uint_scalar(sq_code, &self.query_sq_code)
439 }
440 DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code),
441 _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
442 };
443 inverse_scalar_dist(std::iter::once(dist), &self.bounds)[0]
444 }
445
446 fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
447 match self.storage.distance_type {
448 DistanceType::L2 | DistanceType::Cosine => inverse_scalar_dist(
449 self.storage.chunks.iter().flat_map(|c| {
450 c.sq_codes
451 .values()
452 .chunks_exact(c.dim())
453 .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code))
454 }),
455 &self.bounds,
456 ),
457 DistanceType::Dot => inverse_scalar_dist(
458 self.storage.chunks.iter().flat_map(|c| {
459 c.sq_codes
460 .values()
461 .chunks_exact(c.dim())
462 .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code))
463 }),
464 &self.bounds,
465 ),
466 _ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
467 }
468 }
469
470 #[allow(unused_variables)]
471 fn prefetch(&self, id: u32) {
472 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
473 {
474 const CACHE_LINE_SIZE: usize = 64;
475
476 let (offset, chunk) = self.storage.chunk(id);
477 let dim = chunk.dim();
478 let base_ptr = chunk.sq_code_slice(id - offset).as_ptr();
479
480 unsafe {
481 for offset in (0..dim).step_by(CACHE_LINE_SIZE) {
483 {
484 use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
485 _mm_prefetch(base_ptr.add(offset) as *const i8, _MM_HINT_T0);
486 }
487 }
488 }
489 }
490 }
491}
492
493#[cfg(test)]
494mod tests {
495 use super::*;
496
497 use std::iter::repeat_with;
498 use std::sync::Arc;
499
500 use arrow_array::FixedSizeListArray;
501 use arrow_schema::{DataType, Field, Schema};
502 use lance_arrow::FixedSizeListArrayExt;
503 use lance_testing::datagen::generate_random_array;
504 use rand::prelude::*;
505
    /// Build a test batch with the given row ids and random 64-dim u8 SQ
    /// codes, matching the schema expected by `SQStorageChunk::new`.
    fn create_record_batch(row_ids: Range<u64>) -> RecordBatch {
        const DIM: usize = 64;

        let mut rng = rand::rng();
        let row_ids = UInt64Array::from_iter_values(row_ids);
        let sq_code = UInt8Array::from_iter_values(
            repeat_with(|| rng.random::<u8>()).take(row_ids.len() * DIM),
        );
        let code_arr = FixedSizeListArray::try_new_from_values(sq_code, DIM as i32).unwrap();

        let schema = Arc::new(Schema::new(vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new(
                SQ_CODE_COLUMN,
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::UInt8, true)),
                    DIM as i32,
                ),
                false,
            ),
        ]));
        RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(code_arr)]).unwrap()
    }
529
    #[test]
    fn test_get_chunks() {
        const DIM: usize = 64;

        // Four pre-quantized chunks of 100 rows each, row ids 0..400.
        let storage = ScalarQuantizationStorage::try_new(
            8,
            DistanceType::L2,
            -0.7..0.7,
            (0..4).map(|start| create_record_batch(start * 100..(start + 1) * 100)),
            None,
        )
        .unwrap();

        assert_eq!(storage.len(), 400);

        // Global index 0 maps to the first chunk at offset 0.
        let (offset, chunk) = storage.chunk(0);
        assert_eq!(offset, 0);
        assert_eq!(chunk.row_id(20), 20);

        // An index inside the first chunk still resolves to offset 0.
        let (offset, _) = storage.chunk(50);
        assert_eq!(offset, 0);

        // Append 150 unquantized float vectors (row ids 100..250); the
        // SQTransformer quantizes them into a fifth chunk.
        let row_ids = UInt64Array::from_iter_values(100..250);
        let vector_data = generate_random_array(row_ids.len() * DIM);
        let fsl = FixedSizeListArray::try_new_from_values(vector_data, DIM as i32).unwrap();

        let schema = Arc::new(Schema::new(vec![
            Field::new(ROW_ID, DataType::UInt64, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    DIM as i32,
                ),
                false,
            ),
        ]));

        let second_batch =
            RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(fsl)]).unwrap();
        let storage = storage.append_batch(second_batch, "vector").unwrap();

        assert_eq!(storage.len(), 550);
        // Index 112 is in the second original chunk (offset 100).
        let (offset, chunk) = storage.chunk(112);
        assert_eq!(offset, 100);
        assert_eq!(chunk.row_id(10), 110);

        // Index 432 lands in the appended chunk (offset 400), whose row ids
        // start at 100.
        let (offset, chunk) = storage.chunk(432);
        assert_eq!(offset, 400);
        assert_eq!(chunk.row_id(5), 105);
    }
581}