1use std::collections::{BinaryHeap, HashMap};
8use std::sync::Arc;
9
10use arrow::array::AsArray;
11use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use deepsize::DeepSizeOf;
14use lance_core::{Error, Result, ROW_ID_FIELD};
15use lance_file::previous::reader::FileReader as PreviousFileReader;
16use lance_linalg::distance::DistanceType;
17use serde::{Deserialize, Serialize};
18use snafu::location;
19
20use crate::{
21 metrics::MetricsCollector,
22 prefilter::PreFilter,
23 vector::{
24 graph::OrderedNode,
25 quantizer::{Quantization, QuantizationType, Quantizer, QuantizerMetadata},
26 storage::{DistCalculator, VectorStore},
27 v3::subindex::IvfSubIndex,
28 Query, DIST_COL,
29 },
30};
31
32use super::storage::{FlatBinStorage, FlatFloatStorage, FLAT_COLUMN};
33
34#[derive(Debug, Clone, Default, DeepSizeOf)]
37pub struct FlatIndex {}
38
39use std::sync::LazyLock;
40
41static ANN_SEARCH_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
42 Schema::new(vec![
43 Field::new(DIST_COL, DataType::Float32, true),
44 ROW_ID_FIELD.clone(),
45 ])
46 .into()
47});
48
/// Per-query parameters consumed by the flat sub-index search.
///
/// Setting either bound turns the search into a range query.
#[derive(Default)]
pub struct FlatQueryParams {
    /// Inclusive lower distance bound; `None` falls back to `f32::MIN` in range queries.
    lower_bound: Option<f32>,
    /// Exclusive upper distance bound; `None` falls back to `f32::MAX` in range queries.
    upper_bound: Option<f32>,
    /// Forwarded unchanged to `VectorStore::dist_calculator`.
    dist_q_c: f32,
}
55
56impl From<&Query> for FlatQueryParams {
57 fn from(q: &Query) -> Self {
58 Self {
59 lower_bound: q.lower_bound,
60 upper_bound: q.upper_bound,
61 dist_q_c: q.dist_q_c,
62 }
63 }
64}
65
66impl IvfSubIndex for FlatIndex {
67 type QueryParams = FlatQueryParams;
68 type BuildParams = ();
69
70 fn name() -> &'static str {
71 "FLAT"
72 }
73
74 fn metadata_key() -> &'static str {
75 "lance:flat"
76 }
77
78 fn schema() -> arrow_schema::SchemaRef {
79 Schema::new(vec![Field::new("__flat_marker", DataType::UInt64, false)]).into()
80 }
81
82 fn search(
83 &self,
84 query: ArrayRef,
85 k: usize,
86 params: Self::QueryParams,
87 storage: &impl VectorStore,
88 prefilter: Arc<dyn PreFilter>,
89 metrics: &dyn MetricsCollector,
90 ) -> Result<RecordBatch> {
91 let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some();
92 let row_ids = storage.row_ids();
93 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
94 let mut res = BinaryHeap::with_capacity(k);
95 metrics.record_comparisons(storage.len());
96
97 match prefilter.is_empty() {
98 true => {
99 let dists = dist_calc.distance_all(k);
100
101 if is_range_query {
102 let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
103 let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
104
105 for (&row_id, dist) in row_ids.zip(dists) {
106 let dist = dist.into();
107 if dist < lower_bound || dist >= upper_bound {
108 continue;
109 }
110 if res.len() < k {
111 res.push(OrderedNode::new(row_id, dist));
112 } else if res.peek().unwrap().dist > dist {
113 res.pop();
114 res.push(OrderedNode::new(row_id, dist));
115 }
116 }
117 } else {
118 for (&row_id, dist) in row_ids.zip(dists) {
119 let dist = dist.into();
120 if res.len() < k {
121 res.push(OrderedNode::new(row_id, dist));
122 } else if res.peek().unwrap().dist > dist {
123 res.pop();
124 res.push(OrderedNode::new(row_id, dist));
125 }
126 }
127 }
128 }
129 false => {
130 let row_id_mask = prefilter.mask();
131 if is_range_query {
132 let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
133 let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
134 for (id, &row_id) in row_ids.enumerate() {
135 if !row_id_mask.selected(row_id) {
136 continue;
137 }
138 let dist = dist_calc.distance(id as u32).into();
139 if dist < lower_bound || dist >= upper_bound {
140 continue;
141 }
142
143 if res.len() < k {
144 res.push(OrderedNode::new(row_id, dist));
145 } else if res.peek().unwrap().dist > dist {
146 res.pop();
147 res.push(OrderedNode::new(row_id, dist));
148 }
149 }
150 } else {
151 for (id, &row_id) in row_ids.enumerate() {
152 if !row_id_mask.selected(row_id) {
153 continue;
154 }
155
156 let dist = dist_calc.distance(id as u32).into();
157 if res.len() < k {
158 res.push(OrderedNode::new(row_id, dist));
159 } else if res.peek().unwrap().dist > dist {
160 res.pop();
161 res.push(OrderedNode::new(row_id, dist));
162 }
163 }
164 }
165 }
166 };
167
168 let (row_ids, dists): (Vec<_>, Vec<_>) = res.into_iter().map(|r| (r.id, r.dist.0)).unzip();
171 let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists));
172
173 Ok(RecordBatch::try_new(
174 ANN_SEARCH_SCHEMA.clone(),
175 vec![Arc::new(dists), Arc::new(row_ids)],
176 )?)
177 }
178
179 fn load(_: RecordBatch) -> Result<Self> {
180 Ok(Self {})
181 }
182
183 fn index_vectors(_: &impl VectorStore, _: Self::BuildParams) -> Result<Self>
184 where
185 Self: Sized,
186 {
187 Ok(Self {})
188 }
189
190 fn remap(&self, _: &HashMap<u64, Option<u64>>, _: &impl VectorStore) -> Result<Self> {
191 Ok(self.clone())
192 }
193
194 fn to_batch(&self) -> Result<RecordBatch> {
195 Ok(RecordBatch::new_empty(Schema::empty().into()))
196 }
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
200pub struct FlatMetadata {
201 pub dim: usize,
202}
203
204#[async_trait::async_trait]
205impl QuantizerMetadata for FlatMetadata {
206 async fn load(_: &PreviousFileReader) -> Result<Self> {
207 unimplemented!("Flat will be used in new index builder which doesn't require this")
208 }
209}
210
211#[derive(Debug, Clone, DeepSizeOf)]
212pub struct FlatQuantizer {
213 dim: usize,
214 distance_type: DistanceType,
215}
216
217impl FlatQuantizer {
218 pub fn new(dim: usize, distance_type: DistanceType) -> Self {
219 Self { dim, distance_type }
220 }
221}
222
223impl Quantization for FlatQuantizer {
224 type BuildParams = ();
225 type Metadata = FlatMetadata;
226 type Storage = FlatFloatStorage;
227
228 fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
229 let dim = data.as_fixed_size_list().value_length();
230 Ok(Self::new(dim as usize, distance_type))
231 }
232
233 fn retrain(&mut self, _: &dyn Array) -> Result<()> {
234 Ok(())
235 }
236
237 fn code_dim(&self) -> usize {
238 self.dim
239 }
240
241 fn column(&self) -> &'static str {
242 FLAT_COLUMN
243 }
244
245 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
246 Ok(Quantizer::Flat(Self {
247 dim: metadata.dim,
248 distance_type,
249 }))
250 }
251
252 fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
253 FlatMetadata { dim: self.dim }
254 }
255
256 fn metadata_key() -> &'static str {
257 "flat"
258 }
259
260 fn quantization_type() -> QuantizationType {
261 QuantizationType::Flat
262 }
263
264 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
265 Ok(vectors.slice(0, vectors.len()))
266 }
267
268 fn field(&self) -> Field {
269 Field::new(
270 FLAT_COLUMN,
271 DataType::FixedSizeList(
272 Arc::new(Field::new("item", DataType::Float32, true)),
273 self.dim as i32,
274 ),
275 true,
276 )
277 }
278}
279
280impl From<FlatQuantizer> for Quantizer {
281 fn from(value: FlatQuantizer) -> Self {
282 Self::Flat(value)
283 }
284}
285
286impl TryFrom<Quantizer> for FlatQuantizer {
287 type Error = Error;
288
289 fn try_from(value: Quantizer) -> Result<Self> {
290 match value {
291 Quantizer::Flat(quantizer) => Ok(quantizer),
292 _ => Err(Error::invalid_input(
293 "quantizer is not FlatQuantizer",
294 location!(),
295 )),
296 }
297 }
298}
299
300#[derive(Debug, Clone, DeepSizeOf)]
301pub struct FlatBinQuantizer {
302 dim: usize,
303 distance_type: DistanceType,
304}
305
306impl FlatBinQuantizer {
307 pub fn new(dim: usize, distance_type: DistanceType) -> Self {
308 Self { dim, distance_type }
309 }
310}
311
312impl Quantization for FlatBinQuantizer {
313 type BuildParams = ();
314 type Metadata = FlatMetadata;
315 type Storage = FlatBinStorage;
316
317 fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
318 let dim = data.as_fixed_size_list().value_length();
319 Ok(Self::new(dim as usize, distance_type))
320 }
321
322 fn retrain(&mut self, _: &dyn Array) -> Result<()> {
323 Ok(())
324 }
325
326 fn code_dim(&self) -> usize {
327 self.dim
328 }
329
330 fn column(&self) -> &'static str {
331 FLAT_COLUMN
332 }
333
334 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
335 Ok(Quantizer::FlatBin(Self {
336 dim: metadata.dim,
337 distance_type,
338 }))
339 }
340
341 fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
342 FlatMetadata { dim: self.dim }
343 }
344
345 fn metadata_key() -> &'static str {
346 "flat"
347 }
348
349 fn quantization_type() -> QuantizationType {
350 QuantizationType::Flat
351 }
352
353 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
354 Ok(vectors.slice(0, vectors.len()))
355 }
356
357 fn field(&self) -> Field {
358 Field::new(
359 FLAT_COLUMN,
360 DataType::FixedSizeList(
361 Arc::new(Field::new("item", DataType::UInt8, true)),
362 self.dim as i32,
363 ),
364 true,
365 )
366 }
367}
368
369impl From<FlatBinQuantizer> for Quantizer {
370 fn from(value: FlatBinQuantizer) -> Self {
371 Self::FlatBin(value)
372 }
373}
374
375impl TryFrom<Quantizer> for FlatBinQuantizer {
376 type Error = Error;
377
378 fn try_from(value: Quantizer) -> Result<Self> {
379 match value {
380 Quantizer::FlatBin(quantizer) => Ok(quantizer),
381 _ => Err(Error::invalid_input(
382 "quantizer is not FlatBinQuantizer",
383 location!(),
384 )),
385 }
386 }
387}