Skip to main content

lance_index/vector/
utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow::compute::cast;
5use arrow_array::types::{Float16Type, Float32Type, Float64Type};
6use arrow_array::{Array, ArrayRef, BooleanArray, FixedSizeListArray, cast::AsArray};
7use arrow_schema::{DataType, Field};
8use lance_arrow::FixedSizeListArrayExt;
9use lance_core::{Error, Result};
10use lance_io::encodings::plain::bytes_to_array;
11use lance_linalg::distance::DistanceType;
12use prost::bytes;
13use std::sync::LazyLock;
14use std::{ops::Range, sync::Arc};
15
16use super::pb;
17use crate::pb::Tensor;
18use crate::vector::flat::storage::FlatBinStorage;
19use crate::vector::flat::storage::FlatFloatStorage;
20use crate::vector::hnsw::HNSW;
21use crate::vector::hnsw::builder::{HnswBuildParams, HnswQueryParams};
22use crate::vector::v3::subindex::IvfSubIndex;
23
/// Controls whether an HNSW graph is trained over the centroids.
///
/// The value is selected through the `LANCE_USE_HNSW_SPEEDUP_INDEXING`
/// environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimpleIndexStatus {
    /// Train only when the centroids are large enough (the default).
    Auto,
    /// Skip the size check and train whenever the data type is supported.
    Enabled,
    /// Never train.
    Disabled,
}
29
30static USE_HNSW_SPEEDUP_INDEXING: LazyLock<SimpleIndexStatus> = LazyLock::new(|| {
31    if let Ok(v) = std::env::var("LANCE_USE_HNSW_SPEEDUP_INDEXING") {
32        if v == "enabled" {
33            SimpleIndexStatus::Enabled
34        } else if v == "disabled" {
35            SimpleIndexStatus::Disabled
36        } else {
37            SimpleIndexStatus::Auto
38        }
39    } else {
40        SimpleIndexStatus::Auto
41    }
42});
43
44#[derive(Debug)]
45pub struct SimpleIndex {
46    store: SimpleStore,
47    index: HNSW,
48}
49
50#[derive(Debug)]
51enum SimpleStore {
52    Float(FlatFloatStorage),
53    Binary(FlatBinStorage),
54}
55
56impl SimpleIndex {
57    fn try_new(store: SimpleStore) -> Result<Self> {
58        let hnsw = match &store {
59            SimpleStore::Float(store) => HNSW::index_vectors(
60                store,
61                HnswBuildParams::default().ef_construction(15).num_edges(12),
62            )?,
63            SimpleStore::Binary(store) => HNSW::index_vectors(
64                store,
65                HnswBuildParams::default().ef_construction(15).num_edges(12),
66            )?,
67        };
68        Ok(Self { store, index: hnsw })
69    }
70
71    // train HNSW over the centroids to speed up finding the nearest clusters,
72    // only train if all conditions are met:
73    //  - the centroids are float16/float32 or uint8 with hamming distance
74    //  - `num_centroids * dimension >= 1_000_000`
75    //      we benchmarked that it's 2x faster in the case of 1024 centroids and 1024 dimensions,
76    //      so set the threshold to 1_000_000.
77    pub fn may_train_index(
78        centroids: ArrayRef,
79        dimension: usize,
80        distance_type: DistanceType,
81    ) -> Result<Option<Self>> {
82        match *USE_HNSW_SPEEDUP_INDEXING {
83            SimpleIndexStatus::Auto => {
84                if centroids.len() < 1_000_000 {
85                    return Ok(None);
86                }
87            }
88            SimpleIndexStatus::Disabled => return Ok(None),
89            _ => {}
90        }
91
92        let store = match (centroids.data_type(), distance_type) {
93            (DataType::Float16 | DataType::Float32 | DataType::Float64, _) => {
94                let fsl = FixedSizeListArray::try_new_from_values(centroids, dimension as i32)?;
95                SimpleStore::Float(FlatFloatStorage::new(fsl, distance_type))
96            }
97            (DataType::UInt8, DistanceType::Hamming) => {
98                let fsl = FixedSizeListArray::try_new_from_values(centroids, dimension as i32)?;
99                SimpleStore::Binary(FlatBinStorage::new(fsl, distance_type))
100            }
101            _ => return Ok(None),
102        };
103        Self::try_new(store).map(Some)
104    }
105
106    pub(crate) fn search(&self, query: ArrayRef) -> Result<(u32, f32)> {
107        let params = HnswQueryParams {
108            ef: 15,
109            lower_bound: None,
110            upper_bound: None,
111            dist_q_c: 0.0,
112        };
113        let res = match &self.store {
114            SimpleStore::Float(store) => self.index.search_basic(query, 1, &params, None, store)?,
115            SimpleStore::Binary(store) => {
116                let query = if query.data_type() == &DataType::UInt8 {
117                    query
118                } else {
119                    cast(&query, &DataType::UInt8).map_err(|e| Error::index(e.to_string()))?
120                };
121                self.index.search_basic(query, 1, &params, None, store)?
122            }
123        };
124        Ok((res[0].id, res[0].dist.0))
125    }
126}
127
/// Issues a cache prefetch hint for every 64-byte step of the pointer range.
///
/// The range is walked from `ptrs.start` in steps of 64 bytes while the
/// address is below `ptrs.end`. An empty or reversed range does nothing.
/// On targets other than x86/x86_64 no prefetch instruction is issued.
#[inline]
pub(crate) fn do_prefetch<T>(ptrs: Range<*const T>) {
    // TODO use rust intrinsics instead of x86 intrinsics
    // TODO finish this
    const CACHE_LINE_SIZE: usize = 64;

    let base = ptrs.start.cast::<i8>();
    let len_bytes = (ptrs.end as usize).saturating_sub(ptrs.start as usize);

    for offset in (0..len_bytes).step_by(CACHE_LINE_SIZE) {
        // `offset < len_bytes`, so the address stays inside the given range.
        // `wrapping_add` keeps the arithmetic itself free of any requirement
        // on the validity of the range.
        let line = base.wrapping_add(offset);

        #[cfg(target_arch = "x86_64")]
        {
            use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};
            // SAFETY: the pointer is only passed to a prefetch hint and is never
            // dereferenced by this function; SSE is part of the x86_64 baseline.
            unsafe { _mm_prefetch(line, _MM_HINT_T0) };
        }
        #[cfg(all(target_arch = "x86", target_feature = "sse"))]
        {
            // 32-bit x86 exposes the intrinsic under `core::arch::x86`.
            use core::arch::x86::{_MM_HINT_T0, _mm_prefetch};
            // SAFETY: as above; the `cfg` guarantees SSE is available.
            unsafe { _mm_prefetch(line, _MM_HINT_T0) };
        }
        #[cfg(not(any(
            target_arch = "x86_64",
            all(target_arch = "x86", target_feature = "sse")
        )))]
        let _ = line;
    }
}
146
147impl From<pb::tensor::DataType> for DataType {
148    fn from(dt: pb::tensor::DataType) -> Self {
149        match dt {
150            pb::tensor::DataType::Uint8 => Self::UInt8,
151            pb::tensor::DataType::Uint16 => Self::UInt16,
152            pb::tensor::DataType::Uint32 => Self::UInt32,
153            pb::tensor::DataType::Uint64 => Self::UInt64,
154            pb::tensor::DataType::Float16 => Self::Float16,
155            pb::tensor::DataType::Float32 => Self::Float32,
156            pb::tensor::DataType::Float64 => Self::Float64,
157            pb::tensor::DataType::Bfloat16 => unimplemented!(),
158        }
159    }
160}
161
162impl TryFrom<&DataType> for pb::tensor::DataType {
163    type Error = Error;
164
165    fn try_from(dt: &DataType) -> Result<Self> {
166        match dt {
167            DataType::UInt8 => Ok(Self::Uint8),
168            DataType::UInt16 => Ok(Self::Uint16),
169            DataType::UInt32 => Ok(Self::Uint32),
170            DataType::UInt64 => Ok(Self::Uint64),
171            DataType::Float16 => Ok(Self::Float16),
172            DataType::Float32 => Ok(Self::Float32),
173            DataType::Float64 => Ok(Self::Float64),
174            _ => Err(Error::index(format!(
175                "pb tensor type not supported: {:?}",
176                dt
177            ))),
178        }
179    }
180}
181
182impl TryFrom<DataType> for pb::tensor::DataType {
183    type Error = Error;
184
185    fn try_from(dt: DataType) -> Result<Self> {
186        (&dt).try_into()
187    }
188}
189
190impl TryFrom<&FixedSizeListArray> for pb::Tensor {
191    type Error = Error;
192
193    fn try_from(array: &FixedSizeListArray) -> Result<Self> {
194        let mut tensor = Self::default();
195        tensor.data_type = pb::tensor::DataType::try_from(array.value_type())? as i32;
196        tensor.shape = vec![Array::len(array) as u32, array.value_length() as u32];
197        let flat_array = array.values();
198        tensor.data = flat_array.into_data().buffers()[0].to_vec();
199        Ok(tensor)
200    }
201}
202
203impl TryFrom<&pb::Tensor> for FixedSizeListArray {
204    type Error = Error;
205
206    fn try_from(tensor: &Tensor) -> Result<Self> {
207        if tensor.shape.len() != 2 {
208            return Err(Error::index(format!(
209                "only accept 2-D tensor shape, got: {:?}",
210                tensor.shape
211            )));
212        }
213        let dim = tensor.shape[1] as usize;
214        let num_rows = tensor.shape[0] as usize;
215
216        let data = bytes::Bytes::from(tensor.data.clone());
217        let flat_array = bytes_to_array(
218            &DataType::from(pb::tensor::DataType::try_from(tensor.data_type).unwrap()),
219            data,
220            dim * num_rows,
221            0,
222        )?;
223
224        if flat_array.len() != dim * num_rows {
225            return Err(Error::index(format!(
226                "Tensor shape {:?} does not match to data len: {}",
227                tensor.shape,
228                flat_array.len()
229            )));
230        }
231
232        let field = Field::new("item", flat_array.data_type().clone(), true);
233        Ok(Self::try_new(
234            Arc::new(field),
235            dim as i32,
236            flat_array,
237            None,
238        )?)
239    }
240}
241
242/// Check if all vectors in the FixedSizeListArray are finite
243/// null values are considered as not finite
244/// returns a BooleanArray
245/// with the same length as the FixedSizeListArray
246/// with true for finite values and false for non-finite values
247pub fn is_finite(fsl: &FixedSizeListArray) -> BooleanArray {
248    let is_finite = fsl
249        .iter()
250        .map(|v| match v {
251            Some(v) => match v.data_type() {
252                DataType::Float16 => {
253                    let v = v.as_primitive::<Float16Type>();
254                    Array::null_count(v) == 0 && v.values().iter().all(|v| v.is_finite())
255                }
256                DataType::Float32 => {
257                    let v = v.as_primitive::<Float32Type>();
258                    Array::null_count(v) == 0 && v.values().iter().all(|v| v.is_finite())
259                }
260                DataType::Float64 => {
261                    let v = v.as_primitive::<Float64Type>();
262                    Array::null_count(v) == 0 && v.values().iter().all(|v| v.is_finite())
263                }
264                _ => Array::null_count(&v) == 0,
265            },
266            None => false,
267        })
268        .collect::<Vec<_>>();
269    BooleanArray::from(is_finite)
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{Float16Array, Float32Array, Float64Array, UInt8Array};
    use half::f16;
    use lance_arrow::FixedSizeListArrayExt;
    use num_traits::identities::Zero;

    use arrow::compute::cast;
    use rstest::rstest;

    /// Builds an L2 float index; the centroids are cast to `Float32` first.
    fn build_index(centroids: ArrayRef, dim: usize) -> SimpleIndex {
        let values = cast(&centroids, &DataType::Float32).unwrap();
        let vectors = FixedSizeListArray::try_new_from_values(values, dim as i32).unwrap();
        SimpleIndex::try_new(SimpleStore::Float(FlatFloatStorage::new(
            vectors,
            DistanceType::L2,
        )))
        .unwrap()
    }

    /// Builds a hamming binary index; non-`UInt8` centroids are cast to `UInt8`.
    fn build_binary_index(centroids: ArrayRef, dim: usize) -> SimpleIndex {
        let values = if centroids.data_type() != &DataType::UInt8 {
            cast(&centroids, &DataType::UInt8).unwrap()
        } else {
            centroids
        };
        let vectors = FixedSizeListArray::try_new_from_values(values, dim as i32).unwrap();
        SimpleIndex::try_new(SimpleStore::Binary(FlatBinStorage::new(
            vectors,
            DistanceType::Hamming,
        )))
        .unwrap()
    }

    // 100 centroids of dimension 16; centroid `i` has every component equal to `i`.
    #[rstest]
    #[case::f16(Arc::new(Float16Array::from(
        (0..100).flat_map(|row| std::iter::repeat_n(f16::from_f32(row as f32), 16)).collect::<Vec<_>>(),
    )) as ArrayRef)]
    #[case::f32(Arc::new(Float32Array::from(
        (0..100).flat_map(|row| std::iter::repeat_n(row as f32, 16)).collect::<Vec<_>>(),
    )) as ArrayRef)]
    fn test_simple_index_nearest_centroid(#[case] centroids: ArrayRef) {
        let query: ArrayRef = Arc::new(Float32Array::from(vec![42.1f32; 16]));
        let (id, _) = build_index(centroids, 16).search(query).unwrap();
        assert_eq!(id, 42);
    }

    #[test]
    fn test_simple_index_nearest_centroid_binary() {
        let flat_values = (0..100)
            .flat_map(|row| std::iter::repeat_n(row as u8, 16))
            .collect::<Vec<_>>();
        let centroids: ArrayRef = Arc::new(UInt8Array::from(flat_values));
        let query: ArrayRef = Arc::new(UInt8Array::from(vec![42u8; 16]));

        let (id, dist) = build_binary_index(centroids, 16).search(query).unwrap();
        assert_eq!(id, 42);
        assert_eq!(dist, 0.0);
    }

    // NOTE(review): with the default `Auto` setting, 1600 values is below the
    // 1_000_000 threshold, so `may_train_index` returns `None` before it looks
    // at the data type.
    #[test]
    fn test_simple_index_rejects_f64() {
        let centroids: ArrayRef = Arc::new(Float64Array::from(vec![0.0; 1600]));
        let trained = SimpleIndex::may_train_index(centroids, 16, DistanceType::L2).unwrap();
        assert!(trained.is_none());
    }

    // NOTE(review): same caveat as above regarding the size threshold.
    #[test]
    fn test_simple_index_rejects_uint8_non_hamming() {
        let centroids: ArrayRef = Arc::new(UInt8Array::from(vec![0u8; 1600]));
        let trained = SimpleIndex::may_train_index(centroids, 16, DistanceType::L2).unwrap();
        assert!(trained.is_none());
    }

    #[test]
    fn test_fsl_to_tensor() {
        // Every case is 20 values split into 4 lists of 5.
        fn assert_tensor(
            fsl: &FixedSizeListArray,
            expected_type: pb::tensor::DataType,
            value_width: usize,
        ) {
            let tensor = pb::Tensor::try_from(fsl).unwrap();
            assert_eq!(tensor.data_type, expected_type as i32);
            assert_eq!(tensor.shape, vec![4, 5]);
            assert_eq!(tensor.data.len(), 20 * value_width);
        }

        let f16_fsl =
            FixedSizeListArray::try_new_from_values(Float16Array::from(vec![f16::zero(); 20]), 5)
                .unwrap();
        assert_tensor(&f16_fsl, pb::tensor::DataType::Float16, 2);

        let f32_fsl =
            FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0; 20]), 5).unwrap();
        assert_tensor(&f32_fsl, pb::tensor::DataType::Float32, 4);

        let f64_fsl =
            FixedSizeListArray::try_new_from_values(Float64Array::from(vec![0.0; 20]), 5).unwrap();
        assert_tensor(&f64_fsl, pb::tensor::DataType::Float64, 8);
    }
}