1use arrow::compute::cast;
5use arrow_array::types::{Float16Type, Float32Type, Float64Type};
6use arrow_array::{Array, ArrayRef, BooleanArray, FixedSizeListArray, cast::AsArray};
7use arrow_schema::{DataType, Field};
8use lance_arrow::FixedSizeListArrayExt;
9use lance_core::{Error, Result};
10use lance_io::encodings::plain::bytes_to_array;
11use lance_linalg::distance::DistanceType;
12use prost::bytes;
13use std::sync::LazyLock;
14use std::{ops::Range, sync::Arc};
15
16use super::pb;
17use crate::pb::Tensor;
18use crate::vector::flat::storage::FlatBinStorage;
19use crate::vector::flat::storage::FlatFloatStorage;
20use crate::vector::hnsw::HNSW;
21use crate::vector::hnsw::builder::{HnswBuildParams, HnswQueryParams};
22use crate::vector::v3::subindex::IvfSubIndex;
23
/// Tri-state switch controlling whether an HNSW graph is built over the
/// centroids to speed up nearest-centroid lookups.
enum SimpleIndexStatus {
    /// Decide automatically from the amount of centroid data.
    Auto,
    /// Build the helper index whenever the centroid type is supported.
    Enabled,
    /// Never build the helper index.
    Disabled,
}

/// Process-wide setting, read once from the `LANCE_USE_HNSW_SPEEDUP_INDEXING`
/// environment variable on first access.
///
/// The exact strings `"enabled"` and `"disabled"` select the matching variant;
/// an unset variable, a non-unicode value, or any other string selects `Auto`.
static USE_HNSW_SPEEDUP_INDEXING: LazyLock<SimpleIndexStatus> = LazyLock::new(|| {
    match std::env::var("LANCE_USE_HNSW_SPEEDUP_INDEXING").as_deref() {
        Ok("enabled") => SimpleIndexStatus::Enabled,
        Ok("disabled") => SimpleIndexStatus::Disabled,
        _ => SimpleIndexStatus::Auto,
    }
});
43
44#[derive(Debug)]
45pub struct SimpleIndex {
46 store: SimpleStore,
47 index: HNSW,
48}
49
50#[derive(Debug)]
51enum SimpleStore {
52 Float(FlatFloatStorage),
53 Binary(FlatBinStorage),
54}
55
56impl SimpleIndex {
57 fn try_new(store: SimpleStore) -> Result<Self> {
58 let hnsw = match &store {
59 SimpleStore::Float(store) => HNSW::index_vectors(
60 store,
61 HnswBuildParams::default().ef_construction(15).num_edges(12),
62 )?,
63 SimpleStore::Binary(store) => HNSW::index_vectors(
64 store,
65 HnswBuildParams::default().ef_construction(15).num_edges(12),
66 )?,
67 };
68 Ok(Self { store, index: hnsw })
69 }
70
71 pub fn may_train_index(
78 centroids: ArrayRef,
79 dimension: usize,
80 distance_type: DistanceType,
81 ) -> Result<Option<Self>> {
82 match *USE_HNSW_SPEEDUP_INDEXING {
83 SimpleIndexStatus::Auto => {
84 if centroids.len() < 1_000_000 {
85 return Ok(None);
86 }
87 }
88 SimpleIndexStatus::Disabled => return Ok(None),
89 _ => {}
90 }
91
92 let store = match (centroids.data_type(), distance_type) {
93 (DataType::Float16 | DataType::Float32 | DataType::Float64, _) => {
94 let fsl = FixedSizeListArray::try_new_from_values(centroids, dimension as i32)?;
95 SimpleStore::Float(FlatFloatStorage::new(fsl, distance_type))
96 }
97 (DataType::UInt8, DistanceType::Hamming) => {
98 let fsl = FixedSizeListArray::try_new_from_values(centroids, dimension as i32)?;
99 SimpleStore::Binary(FlatBinStorage::new(fsl, distance_type))
100 }
101 _ => return Ok(None),
102 };
103 Self::try_new(store).map(Some)
104 }
105
106 pub(crate) fn search(&self, query: ArrayRef) -> Result<(u32, f32)> {
107 let params = HnswQueryParams {
108 ef: 15,
109 lower_bound: None,
110 upper_bound: None,
111 dist_q_c: 0.0,
112 };
113 let res = match &self.store {
114 SimpleStore::Float(store) => self.index.search_basic(query, 1, ¶ms, None, store)?,
115 SimpleStore::Binary(store) => {
116 let query = if query.data_type() == &DataType::UInt8 {
117 query
118 } else {
119 cast(&query, &DataType::UInt8).map_err(|e| Error::index(e.to_string()))?
120 };
121 self.index.search_basic(query, 1, ¶ms, None, store)?
122 }
123 };
124 Ok((res[0].id, res[0].dist.0))
125 }
126}
127
/// Issues best-effort cache prefetch hints for the memory spanned by `ptrs`.
///
/// One hint is emitted per 64-byte step, starting at `ptrs.start` and stopping
/// before `ptrs.end`. An empty or inverted range does nothing. On targets
/// other than x86 / x86_64 no instruction is emitted at all.
#[inline]
pub(crate) fn do_prefetch<T>(ptrs: Range<*const T>) {
    // Pick the intrinsic from the module matching the target architecture;
    // `core::arch::x86_64` does not exist on 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::{_MM_HINT_T0, _mm_prefetch};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};

    const CACHE_LINE_SIZE: usize = 64;

    let start = ptrs.start.cast::<i8>();
    // Saturating so that an inverted range is treated as empty.
    let byte_len = (ptrs.end as usize).saturating_sub(ptrs.start as usize);

    for offset in (0..byte_len).step_by(CACHE_LINE_SIZE) {
        // `wrapping_add` never leaves the range here (offset < byte_len) and,
        // unlike `add`, carries no in-bounds requirement of its own.
        let line = start.wrapping_add(offset);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            // SAFETY: the prefetch intrinsic is only a hint to the CPU; the
            // pointer is not dereferenced by this code.
            unsafe { _mm_prefetch(line, _MM_HINT_T0) };
        }

        // Keep `line` used on targets without a prefetch intrinsic.
        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        let _ = line;
    }
}
146
147impl From<pb::tensor::DataType> for DataType {
148 fn from(dt: pb::tensor::DataType) -> Self {
149 match dt {
150 pb::tensor::DataType::Uint8 => Self::UInt8,
151 pb::tensor::DataType::Uint16 => Self::UInt16,
152 pb::tensor::DataType::Uint32 => Self::UInt32,
153 pb::tensor::DataType::Uint64 => Self::UInt64,
154 pb::tensor::DataType::Float16 => Self::Float16,
155 pb::tensor::DataType::Float32 => Self::Float32,
156 pb::tensor::DataType::Float64 => Self::Float64,
157 pb::tensor::DataType::Bfloat16 => unimplemented!(),
158 }
159 }
160}
161
162impl TryFrom<&DataType> for pb::tensor::DataType {
163 type Error = Error;
164
165 fn try_from(dt: &DataType) -> Result<Self> {
166 match dt {
167 DataType::UInt8 => Ok(Self::Uint8),
168 DataType::UInt16 => Ok(Self::Uint16),
169 DataType::UInt32 => Ok(Self::Uint32),
170 DataType::UInt64 => Ok(Self::Uint64),
171 DataType::Float16 => Ok(Self::Float16),
172 DataType::Float32 => Ok(Self::Float32),
173 DataType::Float64 => Ok(Self::Float64),
174 _ => Err(Error::index(format!(
175 "pb tensor type not supported: {:?}",
176 dt
177 ))),
178 }
179 }
180}
181
182impl TryFrom<DataType> for pb::tensor::DataType {
183 type Error = Error;
184
185 fn try_from(dt: DataType) -> Result<Self> {
186 (&dt).try_into()
187 }
188}
189
190impl TryFrom<&FixedSizeListArray> for pb::Tensor {
191 type Error = Error;
192
193 fn try_from(array: &FixedSizeListArray) -> Result<Self> {
194 let mut tensor = Self::default();
195 tensor.data_type = pb::tensor::DataType::try_from(array.value_type())? as i32;
196 tensor.shape = vec![Array::len(array) as u32, array.value_length() as u32];
197 let flat_array = array.values();
198 tensor.data = flat_array.into_data().buffers()[0].to_vec();
199 Ok(tensor)
200 }
201}
202
203impl TryFrom<&pb::Tensor> for FixedSizeListArray {
204 type Error = Error;
205
206 fn try_from(tensor: &Tensor) -> Result<Self> {
207 if tensor.shape.len() != 2 {
208 return Err(Error::index(format!(
209 "only accept 2-D tensor shape, got: {:?}",
210 tensor.shape
211 )));
212 }
213 let dim = tensor.shape[1] as usize;
214 let num_rows = tensor.shape[0] as usize;
215
216 let data = bytes::Bytes::from(tensor.data.clone());
217 let flat_array = bytes_to_array(
218 &DataType::from(pb::tensor::DataType::try_from(tensor.data_type).unwrap()),
219 data,
220 dim * num_rows,
221 0,
222 )?;
223
224 if flat_array.len() != dim * num_rows {
225 return Err(Error::index(format!(
226 "Tensor shape {:?} does not match to data len: {}",
227 tensor.shape,
228 flat_array.len()
229 )));
230 }
231
232 let field = Field::new("item", flat_array.data_type().clone(), true);
233 Ok(Self::try_new(
234 Arc::new(field),
235 dim as i32,
236 flat_array,
237 None,
238 )?)
239 }
240}
241
242pub fn is_finite(fsl: &FixedSizeListArray) -> BooleanArray {
248 let is_finite = fsl
249 .iter()
250 .map(|v| match v {
251 Some(v) => match v.data_type() {
252 DataType::Float16 => {
253 let v = v.as_primitive::<Float16Type>();
254 Array::null_count(v) == 0 && v.values().iter().all(|v| v.is_finite())
255 }
256 DataType::Float32 => {
257 let v = v.as_primitive::<Float32Type>();
258 Array::null_count(v) == 0 && v.values().iter().all(|v| v.is_finite())
259 }
260 DataType::Float64 => {
261 let v = v.as_primitive::<Float64Type>();
262 Array::null_count(v) == 0 && v.values().iter().all(|v| v.is_finite())
263 }
264 _ => Array::null_count(&v) == 0,
265 },
266 None => false,
267 })
268 .collect::<Vec<_>>();
269 BooleanArray::from(is_finite)
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{Float16Array, Float32Array, Float64Array, UInt8Array};
    use half::f16;
    use lance_arrow::FixedSizeListArrayExt;
    use num_traits::identities::Zero;

    use arrow::compute::cast;
    use rstest::rstest;

    /// Builds a float index directly (bypassing `may_train_index` and its
    /// size threshold) after casting the centroids to Float32.
    fn build_index(centroids: ArrayRef, dim: usize) -> SimpleIndex {
        let f32_centroids = cast(&centroids, &DataType::Float32).unwrap();
        let fsl = FixedSizeListArray::try_new_from_values(f32_centroids, dim as i32).unwrap();
        let store = SimpleStore::Float(FlatFloatStorage::new(fsl, DistanceType::L2));
        SimpleIndex::try_new(store).unwrap()
    }

    /// Builds a binary (Hamming) index directly, casting the centroids to
    /// UInt8 when they are not already.
    fn build_binary_index(centroids: ArrayRef, dim: usize) -> SimpleIndex {
        let u8_centroids = if centroids.data_type() == &DataType::UInt8 {
            centroids
        } else {
            cast(&centroids, &DataType::UInt8).unwrap()
        };
        let fsl = FixedSizeListArray::try_new_from_values(u8_centroids, dim as i32).unwrap();
        let store = SimpleStore::Binary(FlatBinStorage::new(fsl, DistanceType::Hamming));
        SimpleIndex::try_new(store).unwrap()
    }

    // 100 centroids of dimension 16; centroid `i` has every component equal to `i`.
    #[rstest]
    #[case::f16(Arc::new(Float16Array::from(
        (0..100).flat_map(|i| std::iter::repeat_n(f16::from_f32(i as f32), 16)).collect::<Vec<_>>(),
    )) as ArrayRef)]
    #[case::f32(Arc::new(Float32Array::from(
        (0..100).flat_map(|i| std::iter::repeat_n(i as f32, 16)).collect::<Vec<_>>(),
    )) as ArrayRef)]
    fn test_simple_index_nearest_centroid(#[case] centroids: ArrayRef) {
        let index = build_index(centroids, 16);
        let query: ArrayRef = Arc::new(Float32Array::from(vec![42.1f32; 16]));
        let (id, _) = index.search(query).unwrap();
        assert_eq!(id, 42);
    }

    #[test]
    fn test_simple_index_nearest_centroid_binary() {
        let centroids: ArrayRef = Arc::new(UInt8Array::from(
            (0..100)
                .flat_map(|i| std::iter::repeat_n(i as u8, 16))
                .collect::<Vec<_>>(),
        ));
        let index = build_binary_index(centroids, 16);
        let query: ArrayRef = Arc::new(UInt8Array::from(vec![42u8; 16]));
        let (id, dist) = index.search(query).unwrap();
        assert_eq!(id, 42);
        assert_eq!(dist, 0.0);
    }

    // NOTE: with the default (`Auto`) status, 1600 values is below the
    // 1_000_000 threshold in `may_train_index`, so `None` is returned before
    // the data type is inspected. The two "rejects" tests below only reach the
    // type dispatch when LANCE_USE_HNSW_SPEEDUP_INDEXING=enabled.
    #[test]
    fn test_simple_index_rejects_f64() {
        let centroids: ArrayRef = Arc::new(Float64Array::from(vec![0.0; 1600]));
        let result = SimpleIndex::may_train_index(centroids, 16, DistanceType::L2).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_simple_index_rejects_uint8_non_hamming() {
        let centroids: ArrayRef = Arc::new(UInt8Array::from(vec![0u8; 1600]));
        let result = SimpleIndex::may_train_index(centroids, 16, DistanceType::L2).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_fsl_to_tensor() {
        // 20 values with list size 5 -> shape [4, 5]; data length is 20 * element width.
        let fsl =
            FixedSizeListArray::try_new_from_values(Float16Array::from(vec![f16::zero(); 20]), 5)
                .unwrap();
        let tensor = pb::Tensor::try_from(&fsl).unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float16 as i32);
        assert_eq!(tensor.shape, vec![4, 5]);
        assert_eq!(tensor.data.len(), 20 * 2);

        let fsl =
            FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0; 20]), 5).unwrap();
        let tensor = pb::Tensor::try_from(&fsl).unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float32 as i32);
        assert_eq!(tensor.shape, vec![4, 5]);
        assert_eq!(tensor.data.len(), 20 * 4);

        let fsl =
            FixedSizeListArray::try_new_from_values(Float64Array::from(vec![0.0; 20]), 5).unwrap();
        let tensor = pb::Tensor::try_from(&fsl).unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float64 as i32);
        assert_eq!(tensor.shape, vec![4, 5]);
        assert_eq!(tensor.data.len(), 20 * 8);
    }
}