1use std::ops::Range;
5
6use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, UInt32Array};
7use deepsize::DeepSizeOf;
8use itertools::Itertools;
9use lance_arrow::FixedSizeListArrayExt;
10use lance_core::{Error, Result};
11use lance_file::previous::{
12 reader::FileReader as PreviousFileReader, writer::FileWriter as PreviousFileWriter,
13};
14use lance_io::{traits::WriteExt, utils::read_message};
15use lance_linalg::distance::DistanceType;
16use lance_table::io::manifest::ManifestDescribing;
17use log::debug;
18use serde::{Deserialize, Serialize};
19
20use crate::pb::Ivf as PbIvf;
21
/// Schema-metadata key under which [`IvfModel::write`] stores the JSON
/// locator of the serialized IVF protobuf, and which [`IvfModel::load`]
/// looks up to find it again.
pub const IVF_METADATA_KEY: &str = "lance:ivf";
/// Metadata key for IVF partition information.
///
/// NOTE(review): not referenced in this part of the file; presumably consumed
/// by partition-level readers/writers elsewhere — confirm against callers.
pub const IVF_PARTITION_KEY: &str = "lance:ivf:partition";
24
25#[derive(Debug, Clone, PartialEq)]
27pub struct IvfModel {
28 pub centroids: Option<FixedSizeListArray>,
32
33 pub offsets: Vec<usize>,
35
36 pub lengths: Vec<u32>,
38
39 pub loss: Option<f64>,
41}
42
43impl DeepSizeOf for IvfModel {
44 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
45 self.centroids
46 .as_ref()
47 .map(|centroids| centroids.get_array_memory_size())
48 .unwrap_or_default()
49 + self.lengths.deep_size_of_children(context)
50 + self.offsets.deep_size_of_children(context)
51 }
52}
53
54impl IvfModel {
55 pub fn empty() -> Self {
56 Self {
57 centroids: None,
58 offsets: vec![],
59 lengths: vec![],
60 loss: None,
61 }
62 }
63
64 pub fn new(centroids: FixedSizeListArray, loss: Option<f64>) -> Self {
65 Self {
66 centroids: Some(centroids),
67 offsets: vec![],
68 lengths: vec![],
69 loss,
70 }
71 }
72
73 pub fn centroid(&self, partition: usize) -> Option<ArrayRef> {
74 self.centroids.as_ref().map(|c| c.value(partition))
75 }
76
77 pub fn dimension(&self) -> usize {
79 self.centroids
80 .as_ref()
81 .map(|c| c.value_length() as usize)
82 .unwrap_or(0)
83 }
84
85 pub fn num_partitions(&self) -> usize {
87 self.centroids
88 .as_ref()
89 .map(|c| c.len())
90 .unwrap_or_else(|| self.offsets.len())
91 }
92
93 pub fn partition_size(&self, part: usize) -> usize {
94 self.lengths.get(part).copied().unwrap_or_default() as usize
95 }
96
97 pub fn num_rows(&self) -> u64 {
98 self.lengths.iter().map(|x| *x as u64).sum()
99 }
100
101 pub fn loss(&self) -> Option<f64> {
102 self.loss
103 }
104
105 pub fn find_partitions(
107 &self,
108 query: &dyn Array,
109 nprobes: usize,
110 distance_type: DistanceType,
111 ) -> Result<(UInt32Array, Float32Array)> {
112 let internal = crate::vector::ivf::new_ivf_transformer(
113 self.centroids.clone().unwrap(),
114 distance_type,
115 vec![],
116 );
117 internal.find_partitions(query, nprobes)
118 }
119
120 pub fn add_partition(&mut self, len: u32) {
122 self.offsets.push(
123 self.offsets.last().cloned().unwrap_or_default()
124 + self.lengths.last().cloned().unwrap_or_default() as usize,
125 );
126 self.lengths.push(len);
127 }
128
129 pub fn add_partition_with_offset(&mut self, offset: usize, len: u32) {
132 self.offsets.push(offset);
133 self.lengths.push(len);
134 }
135
136 pub fn centroids_array(&self) -> Option<&FixedSizeListArray> {
140 self.centroids.as_ref()
141 }
142
143 pub fn row_range(&self, partition: usize) -> Range<usize> {
144 let start = self.offsets[partition];
145 let end = start + self.lengths[partition] as usize;
146 start..end
147 }
148
149 pub async fn load(reader: &PreviousFileReader) -> Result<Self> {
150 let schema = reader.schema();
151 let meta_str = schema
152 .metadata
153 .get(IVF_METADATA_KEY)
154 .ok_or(Error::index(format!(
155 "{} not found during search",
156 IVF_METADATA_KEY
157 )))?;
158 let ivf_metadata: IvfMetadata = serde_json::from_str(meta_str)
159 .map_err(|e| Error::index(format!("Failed to parse IVF metadata: {}", e)))?;
160
161 let pb: PbIvf = read_message(
162 reader.object_reader.as_ref(),
163 ivf_metadata.pb_position as usize,
164 )
165 .await?;
166 Self::try_from(pb)
167 }
168
169 pub async fn write(&self, writer: &mut PreviousFileWriter<ManifestDescribing>) -> Result<()> {
171 let pb = PbIvf::try_from(self)?;
172 let pos = writer.object_writer.write_protobuf(&pb).await?;
173 let ivf_metadata = IvfMetadata { pb_position: pos };
174 writer.add_metadata(IVF_METADATA_KEY, &serde_json::to_string(&ivf_metadata)?);
175 Ok(())
176 }
177}
178
179impl TryFrom<&IvfModel> for PbIvf {
181 type Error = Error;
182
183 fn try_from(ivf: &IvfModel) -> Result<Self> {
184 let lengths = ivf.lengths.clone();
185
186 Ok(Self {
187 centroids: vec![], lengths,
189 offsets: ivf.offsets.iter().map(|x| *x as u64).collect(),
190 centroids_tensor: ivf.centroids.as_ref().map(|c| c.try_into()).transpose()?,
191 loss: ivf.loss,
192 })
193 }
194}
195
196impl TryFrom<PbIvf> for IvfModel {
198 type Error = Error;
199
200 fn try_from(proto: PbIvf) -> Result<Self> {
201 let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() {
202 debug!("Ivf: loading IVF centroids from index format v2");
204 Some(FixedSizeListArray::try_from(tensor)?)
205 } else if !proto.centroids.is_empty() {
206 debug!("Ivf: loading IVF centroids from index format v1");
208 let f32_centroids = Float32Array::from(proto.centroids.clone());
209 let dimension = f32_centroids.len() / proto.lengths.len();
210 Some(FixedSizeListArray::try_new_from_values(
211 f32_centroids,
212 dimension as i32,
213 )?)
214 } else {
215 None
218 };
219 let offsets = match proto.offsets.len() {
224 0 => proto
225 .lengths
226 .iter()
227 .scan(0_usize, |state, &x| {
228 let old = *state;
229 *state += x as usize;
230 Some(old)
231 })
232 .collect_vec(),
233 _ => proto.offsets.iter().map(|x| *x as usize).collect(),
234 };
235 assert_eq!(offsets.len(), proto.lengths.len());
236 Ok(Self {
237 centroids,
238 offsets,
239 lengths: proto.lengths,
240 loss: proto.loss,
241 })
242 }
243}
244
245#[derive(Serialize, Deserialize, Debug)]
247struct IvfMetadata {
248 pb_position: usize,
250}
251
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Float32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_core::datatypes::Schema;
    use lance_io::object_store::ObjectStore;
    use lance_table::format::SelfDescribingFileReader;
    use object_store::path::Path;

    use crate::pb;

    use super::*;

    /// Centroid-less model with two back-to-back partitions of 20 and 50 rows.
    fn two_partition_model() -> IvfModel {
        let mut model = IvfModel::empty();
        for len in [20, 50] {
            model.add_partition(len);
        }
        model
    }

    /// Partitions added sequentially occupy contiguous row ranges.
    #[test]
    fn test_ivf_find_rows() {
        let model = two_partition_model();

        assert_eq!(model.row_range(0), 0..20);
        assert_eq!(model.row_range(1), 20..70);
    }

    /// A model written to a file is read back unchanged.
    #[tokio::test]
    async fn test_write_and_load() {
        let original = two_partition_model();

        let store = ObjectStore::memory();
        let file_path = Path::from("/foo");
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", DataType::Float32, true)]);
        let schema = Schema::try_from(&arrow_schema).unwrap();

        // Scope the writer so it is dropped before the file is reopened.
        {
            let mut writer = PreviousFileWriter::try_new(
                &store,
                &file_path,
                schema.clone(),
                &Default::default(),
            )
            .await
            .unwrap();
            let batch = RecordBatch::try_new(
                Arc::new(arrow_schema),
                vec![Arc::new(Float32Array::from(vec![Some(1.0)]))],
            )
            .unwrap();
            writer.write(&[batch]).await.unwrap();
            original.write(&mut writer).await.unwrap();
            writer.finish().await.unwrap();
        }

        let reader = PreviousFileReader::try_new_self_described(&store, &file_path, None)
            .await
            .unwrap();
        assert!(reader.schema().metadata.contains_key(IVF_METADATA_KEY));

        let reloaded = IvfModel::load(&reader).await.unwrap();
        assert_eq!(original, reloaded);
        assert_eq!(reloaded.num_partitions(), 2);
    }

    /// Flat (v1) centroid values are split evenly across the partitions.
    #[test]
    fn test_load_v1_format_ivf() {
        let legacy_proto = pb::Ivf {
            centroids: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            lengths: vec![2, 2],
            offsets: vec![0, 2],
            centroids_tensor: None,
            loss: None,
        };

        let model = IvfModel::try_from(legacy_proto).unwrap();
        assert_eq!(model.num_partitions(), 2);
        assert_eq!(model.dimension(), 3);

        let centroids = model.centroids.as_ref().unwrap();
        assert_eq!(centroids.len(), 2);
        assert_eq!(centroids.value_length(), 3);
    }

    /// `centroids_array` and `centroid` expose the centroids given to `new`.
    #[test]
    fn test_centroids_array_getter() {
        let flat_values = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let expected = FixedSizeListArray::try_new_from_values(flat_values, 2).unwrap();
        let model = IvfModel::new(expected.clone(), None);
        let exposed = model.centroids_array().unwrap();

        assert_eq!(exposed.len(), expected.len());
        assert_eq!(exposed.value_length(), expected.value_length());

        let first_centroid = model.centroid(0).unwrap();
        let first_values = first_centroid
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert_eq!(first_values.len(), 2);
        assert_eq!(first_values.value(0), 1.0);
        assert_eq!(first_values.value(1), 2.0);
    }
}