lance_index/vector/utils.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow::{
5    array::AsArray,
6    datatypes::{Float16Type, Float32Type, Float64Type},
7};
8use arrow_array::{Array, ArrayRef, BooleanArray, FixedSizeListArray};
9use arrow_schema::{DataType, Field};
10use lance_arrow::FixedSizeListArrayExt;
11use lance_core::{Error, Result};
12use lance_io::encodings::plain::bytes_to_array;
13use lance_linalg::distance::DistanceType;
14use prost::bytes;
15use snafu::location;
16use std::sync::LazyLock;
17use std::{ops::Range, sync::Arc};
18
19use super::pb;
20use crate::pb::Tensor;
21use crate::vector::flat::storage::FlatFloatStorage;
22use crate::vector::hnsw::builder::{HnswBuildParams, HnswQueryParams};
23use crate::vector::hnsw::HNSW;
24use crate::vector::v3::subindex::IvfSubIndex;
25
/// Tri-state switch controlling whether the HNSW centroid speedup index is
/// built; resolved once from the `LANCE_USE_HNSW_SPEEDUP_INDEXING` env var
/// (see `USE_HNSW_SPEEDUP_INDEXING`).
enum SimpleIndexStatus {
    // Decide automatically based on the size of the centroid matrix.
    Auto,
    // Always build the speedup index.
    Enabled,
    // Never build the speedup index.
    Disabled,
}
31
32static USE_HNSW_SPEEDUP_INDEXING: LazyLock<SimpleIndexStatus> = LazyLock::new(|| {
33    if let Ok(v) = std::env::var("LANCE_USE_HNSW_SPEEDUP_INDEXING") {
34        if v == "enabled" {
35            SimpleIndexStatus::Enabled
36        } else if v == "disabled" {
37            SimpleIndexStatus::Disabled
38        } else {
39            SimpleIndexStatus::Auto
40        }
41    } else {
42        SimpleIndexStatus::Auto
43    }
44});
45
/// A small in-memory HNSW index over a flat set of vectors, used to speed up
/// nearest-centroid lookup (see `SimpleIndex::may_train_index`).
#[derive(Debug)]
pub struct SimpleIndex {
    // Flat float storage holding the vectors the graph was built over;
    // also used for distance computation at query time.
    store: FlatFloatStorage,
    // HNSW graph built over `store`.
    index: HNSW,
}
51
impl SimpleIndex {
    /// Build an HNSW graph over all vectors in `store`.
    ///
    /// Uses small fixed build parameters (`ef_construction = 15`,
    /// `num_edges = 12`): the index only needs to be good enough to
    /// accelerate centroid lookup, not to serve user-facing queries.
    pub fn try_new(store: FlatFloatStorage) -> Result<Self> {
        let hnsw = HNSW::index_vectors(
            &store,
            HnswBuildParams::default().ef_construction(15).num_edges(12),
        )?;
        Ok(Self { store, index: hnsw })
    }

    // train HNSW over the centroids to speed up finding the nearest clusters,
    // only train if all conditions are met:
    //  - the centroids are float32s
    //    (NOTE(review): this comment previously also claimed uint8 support,
    //    but only `DataType::Float32` is handled below)
    //  - `num_centroids * dimension >= 1_000_000`
    //      we benchmarked that it's 2x faster in the case of 1024 centroids and 1024 dimensions,
    //      so set the threshold to 1_000_000.
    //
    // `centroids` is the *flattened* values array, so `centroids.len()` is
    // `num_centroids * dimension`.
    //
    // Returns `Ok(None)` when the speedup index is disabled or not worthwhile.
    pub fn may_train_index(
        centroids: ArrayRef,
        dimension: usize,
        distance_type: DistanceType,
    ) -> Result<Option<Self>> {
        // `LANCE_USE_HNSW_SPEEDUP_INDEXING` can force-enable or force-disable;
        // in `Auto` mode only build above the benchmarked size threshold.
        match *USE_HNSW_SPEEDUP_INDEXING {
            SimpleIndexStatus::Auto => {
                if centroids.len() < 1_000_000 {
                    return Ok(None);
                }
            }
            SimpleIndexStatus::Disabled => return Ok(None),
            _ => {}
        }

        match centroids.data_type() {
            DataType::Float32 => {
                let fsl =
                    FixedSizeListArray::try_new_from_values(centroids.clone(), dimension as i32)?;
                let store = FlatFloatStorage::new(fsl, distance_type);
                Self::try_new(store).map(Some)
            }
            // Unsupported element types: silently skip the speedup index.
            _ => Ok(None),
        }
    }

    /// Return the `(id, distance)` of the (approximate) nearest vector to
    /// `query`. `ef = 15` matches the small build parameters in `try_new`.
    pub(crate) fn search(&self, query: ArrayRef) -> Result<(u32, f32)> {
        let res = self.index.search_basic(
            query,
            1,
            &HnswQueryParams {
                ef: 15,
                lower_bound: None,
                upper_bound: None,
                dist_q_c: 0.0,
            },
            None,
            &self.store,
        )?;
        // k = 1, so a successful search always yields at least one result.
        Ok((res[0].id, res[0].dist.0))
    }
}
109
110#[inline]
111#[allow(dead_code)]
112pub(crate) fn prefetch_arrow_array(array: &dyn Array) -> Result<()> {
113    match array.data_type() {
114        DataType::FixedSizeList(_, _) => {
115            let array = array.as_fixed_size_list();
116            return prefetch_arrow_array(array.values());
117        }
118        DataType::Float16 => {
119            let array = array.as_primitive::<Float16Type>();
120            do_prefetch(array.values().as_ptr_range())
121        }
122        DataType::Float32 => {
123            let array = array.as_primitive::<Float32Type>();
124            do_prefetch(array.values().as_ptr_range())
125        }
126        DataType::Float64 => {
127            let array = array.as_primitive::<Float64Type>();
128            do_prefetch(array.values().as_ptr_range())
129        }
130        _ => {
131            return Err(Error::io(
132                format!("unsupported prefetch on {} type", array.data_type()),
133                location!(),
134            ));
135        }
136    }
137
138    Ok(())
139}
140
/// Issue a cache prefetch hint for every cache line in the byte range
/// covered by `ptrs`.
///
/// On non-x86_64 targets this is a no-op. Prefetching is purely a hint:
/// it never faults, so a trailing partial cache line is harmless.
#[inline]
pub(crate) fn do_prefetch<T>(ptrs: Range<*const T>) {
    // TODO use rust intrinsics instead of x86 intrinsics
    // Assumed cache-line stride for the prefetch hints.
    const CACHE_LINE_SIZE: usize = 64;

    // Only x86_64: the original cfg also matched 32-bit `x86`, but then
    // imported from `core::arch::x86_64`, which does not exist on that
    // target and broke the build there.
    #[cfg(target_arch = "x86_64")]
    unsafe {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        let end_ptr = ptrs.end as *const i8;
        let mut current_ptr = ptrs.start as *const i8;
        while current_ptr < end_ptr {
            // SAFETY: `_mm_prefetch` is a hint and cannot fault even on an
            // invalid address.
            _mm_prefetch(current_ptr, _MM_HINT_T0);
            // `wrapping_add`: the last step may land past one-past-the-end,
            // which plain `ptr::add` would make UB; the wrapped value is only
            // compared, never dereferenced.
            current_ptr = current_ptr.wrapping_add(CACHE_LINE_SIZE);
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    let _ = ptrs;
}
159
impl From<pb::tensor::DataType> for DataType {
    /// Map a protobuf tensor dtype to the corresponding Arrow `DataType`.
    ///
    /// NOTE: panics on `Bfloat16`, which has no Arrow mapping here; callers
    /// must not pass it (a `From` impl cannot report the error).
    fn from(dt: pb::tensor::DataType) -> Self {
        match dt {
            pb::tensor::DataType::Uint8 => Self::UInt8,
            pb::tensor::DataType::Uint16 => Self::UInt16,
            pb::tensor::DataType::Uint32 => Self::UInt32,
            pb::tensor::DataType::Uint64 => Self::UInt64,
            pb::tensor::DataType::Float16 => Self::Float16,
            pb::tensor::DataType::Float32 => Self::Float32,
            pb::tensor::DataType::Float64 => Self::Float64,
            pb::tensor::DataType::Bfloat16 => unimplemented!(),
        }
    }
}
174
175impl TryFrom<&DataType> for pb::tensor::DataType {
176    type Error = Error;
177
178    fn try_from(dt: &DataType) -> Result<Self> {
179        match dt {
180            DataType::UInt8 => Ok(Self::Uint8),
181            DataType::UInt16 => Ok(Self::Uint16),
182            DataType::UInt32 => Ok(Self::Uint32),
183            DataType::UInt64 => Ok(Self::Uint64),
184            DataType::Float16 => Ok(Self::Float16),
185            DataType::Float32 => Ok(Self::Float32),
186            DataType::Float64 => Ok(Self::Float64),
187            _ => Err(Error::Index {
188                message: format!("pb tensor type not supported: {:?}", dt),
189                location: location!(),
190            }),
191        }
192    }
193}
194
195impl TryFrom<DataType> for pb::tensor::DataType {
196    type Error = Error;
197
198    fn try_from(dt: DataType) -> Result<Self> {
199        (&dt).try_into()
200    }
201}
202
203impl TryFrom<&FixedSizeListArray> for pb::Tensor {
204    type Error = Error;
205
206    fn try_from(array: &FixedSizeListArray) -> Result<Self> {
207        let mut tensor = Self::default();
208        tensor.data_type = pb::tensor::DataType::try_from(array.value_type())? as i32;
209        tensor.shape = vec![array.len() as u32, array.value_length() as u32];
210        let flat_array = array.values();
211        tensor.data = flat_array.into_data().buffers()[0].to_vec();
212        Ok(tensor)
213    }
214}
215
216impl TryFrom<&pb::Tensor> for FixedSizeListArray {
217    type Error = Error;
218
219    fn try_from(tensor: &Tensor) -> Result<Self> {
220        if tensor.shape.len() != 2 {
221            return Err(Error::Index {
222                message: format!("only accept 2-D tensor shape, got: {:?}", tensor.shape),
223                location: location!(),
224            });
225        }
226        let dim = tensor.shape[1] as usize;
227        let num_rows = tensor.shape[0] as usize;
228
229        let data = bytes::Bytes::from(tensor.data.clone());
230        let flat_array = bytes_to_array(
231            &DataType::from(pb::tensor::DataType::try_from(tensor.data_type).unwrap()),
232            data,
233            dim * num_rows,
234            0,
235        )?;
236
237        if flat_array.len() != dim * num_rows {
238            return Err(Error::Index {
239                message: format!(
240                    "Tensor shape {:?} does not match to data len: {}",
241                    tensor.shape,
242                    flat_array.len()
243                ),
244                location: location!(),
245            });
246        }
247
248        let field = Field::new("item", flat_array.data_type().clone(), true);
249        Ok(Self::try_new(
250            Arc::new(field),
251            dim as i32,
252            flat_array,
253            None,
254        )?)
255    }
256}
257
258/// Check if all vectors in the FixedSizeListArray are finite
259/// null values are considered as not finite
260/// returns a BooleanArray
261/// with the same length as the FixedSizeListArray
262/// with true for finite values and false for non-finite values
263pub fn is_finite(fsl: &FixedSizeListArray) -> BooleanArray {
264    let is_finite = fsl
265        .iter()
266        .map(|v| match v {
267            Some(v) => match v.data_type() {
268                DataType::Float16 => {
269                    let v = v.as_primitive::<Float16Type>();
270                    v.null_count() == 0 && v.values().iter().all(|v| v.is_finite())
271                }
272                DataType::Float32 => {
273                    let v = v.as_primitive::<Float32Type>();
274                    v.null_count() == 0 && v.values().iter().all(|v| v.is_finite())
275                }
276                DataType::Float64 => {
277                    let v = v.as_primitive::<Float64Type>();
278                    v.null_count() == 0 && v.values().iter().all(|v| v.is_finite())
279                }
280                _ => v.null_count() == 0,
281            },
282            None => false,
283        })
284        .collect::<Vec<_>>();
285    BooleanArray::from(is_finite)
286}
287
288#[cfg(test)]
289mod tests {
290    use super::*;
291
292    use arrow_array::{Float16Array, Float32Array, Float64Array};
293    use half::f16;
294    use lance_arrow::FixedSizeListArrayExt;
295    use num_traits::identities::Zero;
296
    // Round-trips each supported float width through pb::Tensor and checks
    // the dtype tag, the [rows, dim] shape, and the raw byte length
    // (20 elements x element size).
    #[test]
    fn test_fsl_to_tensor() {
        // f16: 4 rows x 5 dims, 2 bytes per element.
        let fsl =
            FixedSizeListArray::try_new_from_values(Float16Array::from(vec![f16::zero(); 20]), 5)
                .unwrap();
        let tensor = pb::Tensor::try_from(&fsl).unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float16 as i32);
        assert_eq!(tensor.shape, vec![4, 5]);
        assert_eq!(tensor.data.len(), 20 * 2);

        // f32: 4 bytes per element.
        let fsl =
            FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0; 20]), 5).unwrap();
        let tensor = pb::Tensor::try_from(&fsl).unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float32 as i32);
        assert_eq!(tensor.shape, vec![4, 5]);
        assert_eq!(tensor.data.len(), 20 * 4);

        // f64: 8 bytes per element.
        let fsl =
            FixedSizeListArray::try_new_from_values(Float64Array::from(vec![0.0; 20]), 5).unwrap();
        let tensor = pb::Tensor::try_from(&fsl).unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float64 as i32);
        assert_eq!(tensor.shape, vec![4, 5]);
        assert_eq!(tensor.data.len(), 20 * 8);
    }
321}