1use std::ops::Range;
5
6use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, UInt32Array};
7use deepsize::DeepSizeOf;
8use itertools::Itertools;
9use lance_arrow::FixedSizeListArrayExt;
10use lance_core::{Error, Result};
11use lance_file::previous::{
12 reader::FileReader as PreviousFileReader, writer::FileWriter as PreviousFileWriter,
13};
14use lance_io::{traits::WriteExt, utils::read_message};
15use lance_linalg::distance::DistanceType;
16use lance_table::io::manifest::ManifestDescribing;
17use log::debug;
18use serde::{Deserialize, Serialize};
19use snafu::location;
20
21use crate::pb::Ivf as PbIvf;
22
/// Schema metadata key under which the JSON-encoded IVF metadata (the position
/// of the protobuf IVF message in the file) is stored.
pub const IVF_METADATA_KEY: &str = "lance:ivf";
// NOTE(review): not referenced in this file; presumably a metadata key for
// per-partition IVF data — confirm against its users.
pub const IVF_PARTITION_KEY: &str = "lance:ivf:partition";
25
26#[derive(Debug, Clone, PartialEq)]
28pub struct IvfModel {
29 pub centroids: Option<FixedSizeListArray>,
33
34 pub offsets: Vec<usize>,
36
37 pub lengths: Vec<u32>,
39
40 pub loss: Option<f64>,
42}
43
44impl DeepSizeOf for IvfModel {
45 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
46 self.centroids
47 .as_ref()
48 .map(|centroids| centroids.get_array_memory_size())
49 .unwrap_or_default()
50 + self.lengths.deep_size_of_children(context)
51 + self.offsets.deep_size_of_children(context)
52 }
53}
54
55impl IvfModel {
56 pub fn empty() -> Self {
57 Self {
58 centroids: None,
59 offsets: vec![],
60 lengths: vec![],
61 loss: None,
62 }
63 }
64
65 pub fn new(centroids: FixedSizeListArray, loss: Option<f64>) -> Self {
66 Self {
67 centroids: Some(centroids),
68 offsets: vec![],
69 lengths: vec![],
70 loss,
71 }
72 }
73
74 pub fn centroid(&self, partition: usize) -> Option<ArrayRef> {
75 self.centroids.as_ref().map(|c| c.value(partition))
76 }
77
78 pub fn dimension(&self) -> usize {
80 self.centroids
81 .as_ref()
82 .map(|c| c.value_length() as usize)
83 .unwrap_or(0)
84 }
85
86 pub fn num_partitions(&self) -> usize {
88 self.centroids
89 .as_ref()
90 .map(|c| c.len())
91 .unwrap_or_else(|| self.offsets.len())
92 }
93
94 pub fn partition_size(&self, part: usize) -> usize {
95 self.lengths.get(part).copied().unwrap_or_default() as usize
96 }
97
98 pub fn num_rows(&self) -> u64 {
99 self.lengths.iter().map(|x| *x as u64).sum()
100 }
101
102 pub fn loss(&self) -> Option<f64> {
103 self.loss
104 }
105
106 pub fn find_partitions(
108 &self,
109 query: &dyn Array,
110 nprobes: usize,
111 distance_type: DistanceType,
112 ) -> Result<(UInt32Array, Float32Array)> {
113 let internal = crate::vector::ivf::new_ivf_transformer(
114 self.centroids.clone().unwrap(),
115 distance_type,
116 vec![],
117 );
118 internal.find_partitions(query, nprobes)
119 }
120
121 pub fn add_partition(&mut self, len: u32) {
123 self.offsets.push(
124 self.offsets.last().cloned().unwrap_or_default()
125 + self.lengths.last().cloned().unwrap_or_default() as usize,
126 );
127 self.lengths.push(len);
128 }
129
130 pub fn add_partition_with_offset(&mut self, offset: usize, len: u32) {
133 self.offsets.push(offset);
134 self.lengths.push(len);
135 }
136
137 pub fn centroids_array(&self) -> Option<&FixedSizeListArray> {
141 self.centroids.as_ref()
142 }
143
144 pub fn row_range(&self, partition: usize) -> Range<usize> {
145 let start = self.offsets[partition];
146 let end = start + self.lengths[partition] as usize;
147 start..end
148 }
149
150 pub async fn load(reader: &PreviousFileReader) -> Result<Self> {
151 let schema = reader.schema();
152 let meta_str = schema.metadata.get(IVF_METADATA_KEY).ok_or(Error::Index {
153 message: format!("{} not found during search", IVF_METADATA_KEY),
154 location: location!(),
155 })?;
156 let ivf_metadata: IvfMetadata =
157 serde_json::from_str(meta_str).map_err(|e| Error::Index {
158 message: format!("Failed to parse IVF metadata: {}", e),
159 location: location!(),
160 })?;
161
162 let pb: PbIvf = read_message(
163 reader.object_reader.as_ref(),
164 ivf_metadata.pb_position as usize,
165 )
166 .await?;
167 Self::try_from(pb)
168 }
169
170 pub async fn write(&self, writer: &mut PreviousFileWriter<ManifestDescribing>) -> Result<()> {
172 let pb = PbIvf::try_from(self)?;
173 let pos = writer.object_writer.write_protobuf(&pb).await?;
174 let ivf_metadata = IvfMetadata { pb_position: pos };
175 writer.add_metadata(IVF_METADATA_KEY, &serde_json::to_string(&ivf_metadata)?);
176 Ok(())
177 }
178}
179
180impl TryFrom<&IvfModel> for PbIvf {
182 type Error = Error;
183
184 fn try_from(ivf: &IvfModel) -> Result<Self> {
185 let lengths = ivf.lengths.clone();
186
187 Ok(Self {
188 centroids: vec![], lengths,
190 offsets: ivf.offsets.iter().map(|x| *x as u64).collect(),
191 centroids_tensor: ivf.centroids.as_ref().map(|c| c.try_into()).transpose()?,
192 loss: ivf.loss,
193 })
194 }
195}
196
197impl TryFrom<PbIvf> for IvfModel {
199 type Error = Error;
200
201 fn try_from(proto: PbIvf) -> Result<Self> {
202 let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() {
203 debug!("Ivf: loading IVF centroids from index format v2");
205 Some(FixedSizeListArray::try_from(tensor)?)
206 } else if !proto.centroids.is_empty() {
207 debug!("Ivf: loading IVF centroids from index format v1");
209 let f32_centroids = Float32Array::from(proto.centroids.clone());
210 let dimension = f32_centroids.len() / proto.lengths.len();
211 Some(FixedSizeListArray::try_new_from_values(
212 f32_centroids,
213 dimension as i32,
214 )?)
215 } else {
216 None
219 };
220 let offsets = match proto.offsets.len() {
225 0 => proto
226 .lengths
227 .iter()
228 .scan(0_usize, |state, &x| {
229 let old = *state;
230 *state += x as usize;
231 Some(old)
232 })
233 .collect_vec(),
234 _ => proto.offsets.iter().map(|x| *x as usize).collect(),
235 };
236 assert_eq!(offsets.len(), proto.lengths.len());
237 Ok(Self {
238 centroids,
239 offsets,
240 lengths: proto.lengths,
241 loss: proto.loss,
242 })
243 }
244}
245
246#[derive(Serialize, Deserialize, Debug)]
248struct IvfMetadata {
249 pb_position: usize,
251}
252
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Float32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_core::datatypes::Schema;
    use lance_io::object_store::ObjectStore;
    use lance_table::format::SelfDescribingFileReader;
    use object_store::path::Path;

    use crate::pb;

    use super::*;

    /// Partitions added back to back map to contiguous row ranges.
    #[test]
    fn test_ivf_find_rows() {
        let mut model = IvfModel::empty();
        for len in [20, 50] {
            model.add_partition(len);
        }

        assert_eq!(model.row_range(0), 0..20);
        assert_eq!(model.row_range(1), 20..70);
    }

    /// A model written to a file is read back unchanged.
    #[tokio::test]
    async fn test_write_and_load() {
        let mut original = IvfModel::empty();
        original.add_partition(20);
        original.add_partition(50);

        let store = ObjectStore::memory();
        let file_path = Path::from("/foo");
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", DataType::Float32, true)]);
        let schema = Schema::try_from(&arrow_schema).unwrap();

        // Scope the writer so it is dropped before the file is reopened.
        {
            let mut writer =
                PreviousFileWriter::try_new(&store, &file_path, schema.clone(), &Default::default())
                    .await
                    .unwrap();
            let single_row = RecordBatch::try_new(
                Arc::new(arrow_schema),
                vec![Arc::new(Float32Array::from(vec![Some(1.0)]))],
            )
            .unwrap();
            writer.write(&[single_row]).await.unwrap();
            original.write(&mut writer).await.unwrap();
            writer.finish().await.unwrap();
        }

        let reader = PreviousFileReader::try_new_self_described(&store, &file_path, None)
            .await
            .unwrap();
        assert!(reader.schema().metadata.contains_key(IVF_METADATA_KEY));

        let loaded = IvfModel::load(&reader).await.unwrap();
        assert_eq!(original, loaded);
        assert_eq!(loaded.num_partitions(), 2);
    }

    /// Flat (v1) centroids are reshaped using the number of partitions.
    #[test]
    fn test_load_v1_format_ivf() {
        let legacy = pb::Ivf {
            centroids: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            lengths: vec![2, 2],
            offsets: vec![0, 2],
            centroids_tensor: None,
            loss: None,
        };

        let model = IvfModel::try_from(legacy).unwrap();
        assert_eq!(model.num_partitions(), 2);
        assert_eq!(model.dimension(), 3);

        let centroids = model.centroids.as_ref().unwrap();
        assert_eq!(centroids.len(), 2);
        assert_eq!(centroids.value_length(), 3);
    }

    /// The centroid accessors expose the array the model was built from.
    #[test]
    fn test_centroids_array_getter() {
        let flat = Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]);
        let expected = FixedSizeListArray::try_new_from_values(flat, 2).unwrap();
        let model = IvfModel::new(expected.clone(), None);

        let exposed = model.centroids_array().unwrap();
        assert_eq!(exposed.len(), expected.len());
        assert_eq!(exposed.value_length(), expected.value_length());

        let head = model.centroid(0).unwrap();
        let head_values = head.as_any().downcast_ref::<Float32Array>().unwrap();
        assert_eq!(head_values.len(), 2);
        assert_eq!(head_values.value(0), 1.0);
        assert_eq!(head_values.value(1), 2.0);
    }
}