1use std::io::Write;
17use std::sync::Arc;
18
19use arrow_array::cast::AsArray;
20use arrow_array::types::{Float32Type, UInt32Type, UInt64Type};
21use arrow_array::{
22 Array, Float32Array, LargeBinaryArray, ListArray, RecordBatch, UInt32Array, UInt64Array,
23};
24use arrow_schema::{DataType, Field, Schema};
25use bytes::Bytes;
26use lance_arrow::ipc::{
27 read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream,
28 write_len_prefixed_bytes,
29};
30use lance_core::cache::CacheCodecImpl;
31use lance_core::{Error, Result};
32use serde::{Deserialize, Serialize};
33
34use super::index::{
35 CompressedPositionStorage, CompressedPostingList, PlainPostingList, PositionStreamCodec,
36 Positions, PostingList, PostingTailCodec, SharedPositionStream,
37};
38
/// First byte of a serialized `PostingList`: the rest of the entry is a plain posting list.
const POSTING_VARIANT_PLAIN: u8 = 0;
/// First byte of a serialized `PostingList`: the rest of the entry is a compressed posting list.
const POSTING_VARIANT_COMPRESSED: u8 = 1;
45
/// Positions tag written by the posting-list serializers when no positions are attached.
const POSITIONS_TAG_NONE: u8 = 0;
/// Positions tag that precedes a `CompressedPositionStorage::LegacyPerDoc` payload.
const POSITIONS_TAG_LEGACY: u8 = 1;
/// Positions tag that precedes a `CompressedPositionStorage::SharedStream` payload.
const POSITIONS_TAG_SHARED: u8 = 2;
49
/// Serialized id of `PostingTailCodec::Fixed32` (stored in the compressed posting header).
const POSTING_TAIL_CODEC_FIXED32: u8 = 0;
/// Serialized id of `PostingTailCodec::VarintDelta` (stored in the compressed posting header).
const POSTING_TAIL_CODEC_VARINT_DELTA: u8 = 1;
52
/// Serialized id of `PositionStreamCodec::VarintDocDelta` (stored in the shared positions header).
const POSITION_STREAM_CODEC_VARINT_DOC_DELTA: u8 = 0;
/// Serialized id of `PositionStreamCodec::PackedDelta` (stored in the shared positions header).
const POSITION_STREAM_CODEC_PACKED_DELTA: u8 = 1;
55
56fn posting_tail_codec_to_u8(c: PostingTailCodec) -> u8 {
61 match c {
62 PostingTailCodec::Fixed32 => POSTING_TAIL_CODEC_FIXED32,
63 PostingTailCodec::VarintDelta => POSTING_TAIL_CODEC_VARINT_DELTA,
64 }
65}
66
67fn u8_to_posting_tail_codec(v: u8) -> Result<PostingTailCodec> {
68 match v {
69 POSTING_TAIL_CODEC_FIXED32 => Ok(PostingTailCodec::Fixed32),
70 POSTING_TAIL_CODEC_VARINT_DELTA => Ok(PostingTailCodec::VarintDelta),
71 _ => Err(Error::io(format!("unknown posting tail codec: {v}"))),
72 }
73}
74
75fn position_stream_codec_to_u8(c: PositionStreamCodec) -> u8 {
76 match c {
77 PositionStreamCodec::VarintDocDelta => POSITION_STREAM_CODEC_VARINT_DOC_DELTA,
78 PositionStreamCodec::PackedDelta => POSITION_STREAM_CODEC_PACKED_DELTA,
79 }
80}
81
82fn u8_to_position_stream_codec(v: u8) -> Result<PositionStreamCodec> {
83 match v {
84 POSITION_STREAM_CODEC_VARINT_DOC_DELTA => Ok(PositionStreamCodec::VarintDocDelta),
85 POSITION_STREAM_CODEC_PACKED_DELTA => Ok(PositionStreamCodec::PackedDelta),
86 _ => Err(Error::io(format!("unknown position stream codec: {v}"))),
87 }
88}
89
90fn write_json_header(writer: &mut dyn Write, header: &impl Serialize) -> Result<()> {
95 let bytes = serde_json::to_vec(header)?;
96 write_len_prefixed_bytes(writer, &bytes)?;
97 Ok(())
98}
99
100fn read_json_header<T: serde::de::DeserializeOwned>(data: &Bytes, offset: &mut usize) -> Result<T> {
101 let bytes = read_len_prefixed_bytes_at(data, offset).map_err(|e| Error::io(e.to_string()))?;
102 serde_json::from_slice(&bytes)
103 .map_err(|e| Error::io(format!("failed to deserialize cache header: {e}")))
104}
105
106fn write_u8(writer: &mut dyn Write, value: u8) -> Result<()> {
107 writer
108 .write_all(&[value])
109 .map_err(|e| Error::io(format!("failed to write tag byte: {e}")))
110}
111
112fn read_u8(data: &Bytes, offset: &mut usize) -> Result<u8> {
113 let bytes = data.as_ref();
114 if *offset >= bytes.len() {
115 return Err(Error::io(
116 "truncated cache entry: missing tag byte".to_string(),
117 ));
118 }
119 let v = bytes[*offset];
120 *offset += 1;
121 Ok(v)
122}
123
// Field names given to the single-purpose Arrow IPC batches written by this module. The readers
// in this file fetch columns by index rather than by these names.

/// Column holding the legacy per-document position lists.
const POSITION_LIST_COLUMN: &str = "position_list";
/// Column holding the block offsets of a shared position stream.
const BLOCK_OFFSETS_COLUMN: &str = "block_offsets";
/// Column holding the row ids of a plain posting list.
const ROW_IDS_COLUMN: &str = "row_ids";
/// Column holding the frequencies of a plain posting list.
const FREQUENCIES_COLUMN: &str = "frequencies";
/// Column holding the binary blocks of a compressed posting list.
const BLOCKS_COLUMN: &str = "blocks";
133
/// JSON header written right after `POSITIONS_TAG_SHARED` for a shared position stream.
#[derive(Serialize, Deserialize)]
struct SharedPositionsHeader {
    /// Numeric id of the stream's `PositionStreamCodec`, as produced by
    /// `position_stream_codec_to_u8`.
    codec: u8,
}
138
139fn write_position_storage(
140 writer: &mut dyn Write,
141 storage: &CompressedPositionStorage,
142) -> Result<()> {
143 match storage {
144 CompressedPositionStorage::LegacyPerDoc(list) => {
145 write_u8(writer, POSITIONS_TAG_LEGACY)?;
146 let schema = Arc::new(Schema::new(vec![Field::new(
147 POSITION_LIST_COLUMN,
148 list.data_type().clone(),
149 list.is_nullable(),
150 )]));
151 let batch = RecordBatch::try_new(schema, vec![Arc::new(list.clone())])?;
152 write_ipc_stream(&batch, writer)?;
153 }
154 CompressedPositionStorage::SharedStream(stream) => {
155 write_u8(writer, POSITIONS_TAG_SHARED)?;
156 let header = SharedPositionsHeader {
157 codec: position_stream_codec_to_u8(stream.codec()),
158 };
159 write_json_header(writer, &header)?;
160
161 let offsets = UInt32Array::from(stream.block_offsets().to_vec());
162 let schema = Arc::new(Schema::new(vec![Field::new(
163 BLOCK_OFFSETS_COLUMN,
164 DataType::UInt32,
165 false,
166 )]));
167 let batch = RecordBatch::try_new(schema, vec![Arc::new(offsets)])?;
168 write_ipc_stream(&batch, writer)?;
169
170 write_len_prefixed_bytes(writer, stream.bytes())?;
171 }
172 }
173 Ok(())
174}
175
176fn read_position_storage(
177 data: &Bytes,
178 offset: &mut usize,
179 tag: u8,
180) -> Result<CompressedPositionStorage> {
181 match tag {
182 POSITIONS_TAG_LEGACY => {
183 let batch =
184 read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?;
185 let list = batch
186 .column(0)
187 .as_any()
188 .downcast_ref::<ListArray>()
189 .ok_or_else(|| Error::io("legacy position column is not a ListArray".to_string()))?
190 .clone();
191 Ok(CompressedPositionStorage::LegacyPerDoc(list))
192 }
193 POSITIONS_TAG_SHARED => {
194 let header: SharedPositionsHeader = read_json_header(data, offset)?;
195 let codec = u8_to_position_stream_codec(header.codec)?;
196
197 let batch =
198 read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?;
199 let block_offsets = batch
200 .column(0)
201 .as_primitive_opt::<UInt32Type>()
202 .ok_or_else(|| Error::io("block_offsets column is not UInt32".to_string()))?
203 .values()
204 .to_vec();
205
206 let bytes =
210 read_len_prefixed_bytes_at(data, offset).map_err(|e| Error::io(e.to_string()))?;
211
212 Ok(CompressedPositionStorage::SharedStream(
213 SharedPositionStream::new(codec, block_offsets, bytes),
214 ))
215 }
216 other => Err(Error::io(format!("unknown positions tag: {other}"))),
217 }
218}
219
/// JSON header written at the start of a serialized plain posting list body.
#[derive(Serialize, Deserialize)]
struct PlainPostingHeader {
    /// Copy of `PlainPostingList::max_score`.
    max_score: Option<f32>,
}
228
/// JSON header written at the start of a serialized compressed posting list body.
#[derive(Serialize, Deserialize)]
struct CompressedPostingHeader {
    /// Copy of `CompressedPostingList::max_score`.
    max_score: f32,
    /// Copy of `CompressedPostingList::length`.
    length: u32,
    /// Numeric id of the posting tail codec, as produced by `posting_tail_codec_to_u8`.
    posting_tail_codec: u8,
}
235
236impl CacheCodecImpl for PostingList {
237 fn serialize(&self, writer: &mut dyn Write) -> Result<()> {
238 match self {
239 Self::Plain(plain) => {
240 write_u8(writer, POSTING_VARIANT_PLAIN)?;
241 serialize_plain(writer, plain)
242 }
243 Self::Compressed(compressed) => {
244 write_u8(writer, POSTING_VARIANT_COMPRESSED)?;
245 serialize_compressed(writer, compressed)
246 }
247 }
248 }
249
250 fn deserialize(data: &Bytes) -> Result<Self> {
251 let mut offset = 0;
252 let variant = read_u8(data, &mut offset)?;
253 match variant {
254 POSTING_VARIANT_PLAIN => Ok(Self::Plain(deserialize_plain(data, &mut offset)?)),
255 POSTING_VARIANT_COMPRESSED => {
256 Ok(Self::Compressed(deserialize_compressed(data, &mut offset)?))
257 }
258 other => Err(Error::io(format!("unknown PostingList variant: {other}"))),
259 }
260 }
261}
262
263fn serialize_plain(writer: &mut dyn Write, plain: &PlainPostingList) -> Result<()> {
264 let header = PlainPostingHeader {
265 max_score: plain.max_score,
266 };
267 write_json_header(writer, &header)?;
268
269 let row_ids = UInt64Array::new(plain.row_ids.clone(), None);
270 let frequencies = Float32Array::new(plain.frequencies.clone(), None);
271 let schema = Arc::new(Schema::new(vec![
272 Field::new(ROW_IDS_COLUMN, DataType::UInt64, false),
273 Field::new(FREQUENCIES_COLUMN, DataType::Float32, false),
274 ]));
275 let batch = RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(frequencies)])?;
276 write_ipc_stream(&batch, writer)?;
277
278 match &plain.positions {
279 Some(list) => {
280 write_position_storage(
283 writer,
284 &CompressedPositionStorage::LegacyPerDoc(list.clone()),
285 )?;
286 }
287 None => write_u8(writer, POSITIONS_TAG_NONE)?,
288 }
289 Ok(())
290}
291
292fn deserialize_plain(data: &Bytes, offset: &mut usize) -> Result<PlainPostingList> {
293 let header: PlainPostingHeader = read_json_header(data, offset)?;
294
295 let batch = read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?;
296 let row_ids = batch
297 .column(0)
298 .as_primitive_opt::<UInt64Type>()
299 .ok_or_else(|| Error::io("row_ids column is not UInt64".to_string()))?
300 .values()
301 .clone();
302 let frequencies = batch
303 .column(1)
304 .as_primitive_opt::<Float32Type>()
305 .ok_or_else(|| Error::io("frequencies column is not Float32".to_string()))?
306 .values()
307 .clone();
308
309 let positions_tag = read_u8(data, offset)?;
310 let positions = match positions_tag {
311 POSITIONS_TAG_NONE => None,
312 POSITIONS_TAG_LEGACY => match read_position_storage(data, offset, positions_tag)? {
313 CompressedPositionStorage::LegacyPerDoc(list) => Some(list),
314 CompressedPositionStorage::SharedStream(_) => {
315 unreachable!("shared stream tag was read as legacy variant (this is a bug)")
316 }
317 },
318 other => {
319 return Err(Error::io(format!(
320 "Plain posting list cannot have positions tag {other}"
321 )));
322 }
323 };
324
325 Ok(PlainPostingList::new(
326 row_ids,
327 frequencies,
328 header.max_score,
329 positions,
330 ))
331}
332
333fn serialize_compressed(writer: &mut dyn Write, posting: &CompressedPostingList) -> Result<()> {
334 let header = CompressedPostingHeader {
335 max_score: posting.max_score,
336 length: posting.length,
337 posting_tail_codec: posting_tail_codec_to_u8(posting.posting_tail_codec),
338 };
339 write_json_header(writer, &header)?;
340
341 let schema = Arc::new(Schema::new(vec![Field::new(
342 BLOCKS_COLUMN,
343 DataType::LargeBinary,
344 false,
345 )]));
346 let batch = RecordBatch::try_new(schema, vec![Arc::new(posting.blocks.clone())])?;
347 write_ipc_stream(&batch, writer)?;
348
349 match &posting.positions {
350 Some(storage) => write_position_storage(writer, storage)?,
351 None => write_u8(writer, POSITIONS_TAG_NONE)?,
352 }
353 Ok(())
354}
355
356fn deserialize_compressed(data: &Bytes, offset: &mut usize) -> Result<CompressedPostingList> {
357 let header: CompressedPostingHeader = read_json_header(data, offset)?;
358 let posting_tail_codec = u8_to_posting_tail_codec(header.posting_tail_codec)?;
359
360 let batch = read_ipc_stream_single_at(data, offset).map_err(|e| Error::io(e.to_string()))?;
361 let blocks = batch
362 .column(0)
363 .as_any()
364 .downcast_ref::<LargeBinaryArray>()
365 .ok_or_else(|| Error::io("blocks column is not a LargeBinaryArray".to_string()))?
366 .clone();
367
368 let positions_tag = read_u8(data, offset)?;
369 let positions = if positions_tag == POSITIONS_TAG_NONE {
370 None
371 } else {
372 Some(read_position_storage(data, offset, positions_tag)?)
373 };
374
375 Ok(CompressedPostingList::new(
376 blocks,
377 header.max_score,
378 header.length,
379 posting_tail_codec,
380 positions,
381 ))
382}
383
384impl CacheCodecImpl for Positions {
389 fn serialize(&self, writer: &mut dyn Write) -> Result<()> {
390 write_position_storage(writer, &self.0)
391 }
392
393 fn deserialize(data: &Bytes) -> Result<Self> {
394 let mut offset = 0;
395 let tag = read_u8(data, &mut offset)?;
396 if tag == POSITIONS_TAG_NONE {
397 return Err(Error::io(
398 "Positions cache entry cannot encode the None variant".to_string(),
399 ));
400 }
401 let storage = read_position_storage(data, &mut offset, tag)?;
402 Ok(Self(storage))
403 }
404}
405
#[cfg(test)]
mod tests {
    //! Round-trip and failure-path tests for the posting list and positions cache codecs.

    use arrow::buffer::ScalarBuffer;
    use arrow_array::LargeBinaryArray;
    use arrow_array::builder::{Int32Builder, ListBuilder};
    use bytes::Bytes;
    use lance_core::cache::CacheCodecImpl;

    use super::super::index::{
        CompressedPositionStorage, CompressedPostingList, PlainPostingList, PositionStreamCodec,
        Positions, PostingList, PostingTailCodec, SharedPositionStream,
    };

    /// Builds an `Int32` list array with one non-null list per entry of `rows`.
    fn legacy_positions(rows: &[&[i32]]) -> arrow_array::ListArray {
        let mut lists = ListBuilder::new(Int32Builder::new());
        for positions in rows.iter().copied() {
            for &position in positions {
                lists.values().append_value(position);
            }
            lists.append(true);
        }
        lists.finish()
    }

    /// Asserts that two plain posting lists agree field by field.
    fn assert_plain_eq(a: &PlainPostingList, b: &PlainPostingList) {
        assert_eq!(a.row_ids.as_ref(), b.row_ids.as_ref());
        assert_eq!(a.frequencies.as_ref(), b.frequencies.as_ref());
        assert_eq!(a.max_score, b.max_score);
        match (a.positions.as_ref(), b.positions.as_ref()) {
            (Some(left), Some(right)) => assert_eq!(left, right),
            (None, None) => {}
            _ => panic!("positions mismatch"),
        }
    }

    /// Asserts that two position storages are the same variant with the same contents.
    fn assert_position_storage_eq(a: &CompressedPositionStorage, b: &CompressedPositionStorage) {
        use CompressedPositionStorage::{LegacyPerDoc, SharedStream};

        match (a, b) {
            (SharedStream(left), SharedStream(right)) => {
                assert_eq!(left.codec(), right.codec());
                assert_eq!(left.block_offsets(), right.block_offsets());
                assert_eq!(left.bytes(), right.bytes());
            }
            (LegacyPerDoc(left), LegacyPerDoc(right)) => assert_eq!(left, right),
            _ => panic!("position storage variant mismatch"),
        }
    }

    /// Serializes `entry` into a fresh buffer and decodes it again.
    fn roundtrip_posting_list(entry: &PostingList) -> PostingList {
        let mut encoded = Vec::new();
        entry.serialize(&mut encoded).unwrap();
        let encoded = Bytes::from(encoded);
        PostingList::deserialize(&encoded).unwrap()
    }

    /// Serializes `entry` into a fresh buffer and decodes it again.
    fn roundtrip_positions(entry: &Positions) -> Positions {
        let mut encoded = Vec::new();
        entry.serialize(&mut encoded).unwrap();
        let encoded = Bytes::from(encoded);
        Positions::deserialize(&encoded).unwrap()
    }

    /// Asserts that the memory range of `slice` lies entirely inside the range of `bytes`.
    fn assert_slice_points_into_bytes(slice: &[u8], bytes: &Bytes) {
        let backing: &[u8] = bytes;
        let slice_range = slice.as_ptr_range();
        let bytes_range = backing.as_ptr_range();
        let (slice_start, slice_end) = (slice_range.start as usize, slice_range.end as usize);
        let (bytes_start, bytes_end) = (bytes_range.start as usize, bytes_range.end as usize);
        assert!(
            slice_start >= bytes_start && slice_end <= bytes_end,
            "slice [{slice_start:#x}, {slice_end:#x}) should point into bytes \
             [{bytes_start:#x}, {bytes_end:#x})",
        );
    }

    #[test]
    fn plain_posting_list_no_positions_roundtrip() {
        let plain = PlainPostingList::new(
            ScalarBuffer::from(vec![10u64, 20, 30]),
            ScalarBuffer::from(vec![0.5f32, 1.0, 1.5]),
            Some(2.0),
            None,
        );
        let entry = PostingList::Plain(plain.clone());
        let PostingList::Plain(restored) = roundtrip_posting_list(&entry) else {
            panic!("expected Plain variant");
        };
        assert_plain_eq(&plain, &restored);
    }

    #[test]
    fn plain_posting_list_with_positions_roundtrip() {
        let plain = PlainPostingList::new(
            ScalarBuffer::from(vec![1u64, 2]),
            ScalarBuffer::from(vec![1.0f32, 1.0]),
            None,
            Some(legacy_positions(&[&[3, 7], &[1, 4, 9]])),
        );
        let entry = PostingList::Plain(plain.clone());
        let PostingList::Plain(restored) = roundtrip_posting_list(&entry) else {
            panic!("expected Plain variant");
        };
        assert_plain_eq(&plain, &restored);
    }

    #[test]
    fn compressed_posting_list_no_positions_roundtrip() {
        let blocks = LargeBinaryArray::from_opt_vec(vec![
            Some(&[1u8, 2, 3, 4, 5][..]),
            Some(&[6, 7, 8, 9, 10][..]),
        ]);
        let posting =
            CompressedPostingList::new(blocks, 3.5, 42, PostingTailCodec::VarintDelta, None);
        let entry = PostingList::Compressed(posting.clone());
        let PostingList::Compressed(restored) = roundtrip_posting_list(&entry) else {
            panic!("expected Compressed variant");
        };
        assert_eq!(restored.max_score, posting.max_score);
        assert_eq!(restored.length, posting.length);
        assert_eq!(restored.posting_tail_codec, posting.posting_tail_codec);
        assert_eq!(restored.blocks, posting.blocks);
        assert!(restored.positions.is_none());
    }

    #[test]
    fn compressed_posting_list_legacy_positions_roundtrip() {
        let blocks = LargeBinaryArray::from_opt_vec(vec![Some(&[1u8, 2, 3][..])]);
        let legacy = CompressedPositionStorage::LegacyPerDoc(legacy_positions(&[&[0, 4, 8]]));
        let posting =
            CompressedPostingList::new(blocks, 1.25, 5, PostingTailCodec::Fixed32, Some(legacy));
        let entry = PostingList::Compressed(posting.clone());
        let PostingList::Compressed(restored) = roundtrip_posting_list(&entry) else {
            panic!("expected Compressed variant");
        };
        assert_eq!(restored.posting_tail_codec, posting.posting_tail_codec);
        assert_position_storage_eq(
            restored.positions.as_ref().unwrap(),
            posting.positions.as_ref().unwrap(),
        );
    }

    #[test]
    fn compressed_posting_list_shared_stream_roundtrip() {
        // Every position stream codec must survive the round trip.
        let codecs = [
            PositionStreamCodec::VarintDocDelta,
            PositionStreamCodec::PackedDelta,
        ];
        for codec in codecs {
            let blocks = LargeBinaryArray::from_opt_vec(vec![Some(&[9u8; 16][..])]);
            let stream = SharedPositionStream::new(
                codec,
                vec![0u32, 4, 11],
                Bytes::from((0u8..32).collect::<Vec<_>>()),
            );
            let posting = CompressedPostingList::new(
                blocks,
                7.0,
                3,
                PostingTailCodec::VarintDelta,
                Some(CompressedPositionStorage::SharedStream(stream)),
            );
            let entry = PostingList::Compressed(posting.clone());
            let PostingList::Compressed(restored) = roundtrip_posting_list(&entry) else {
                panic!("expected Compressed variant");
            };
            assert_position_storage_eq(
                restored.positions.as_ref().unwrap(),
                posting.positions.as_ref().unwrap(),
            );
        }
    }

    #[test]
    fn shared_stream_deserialize_borrows_from_input_bytes() {
        let blocks = LargeBinaryArray::from_opt_vec(vec![Some(&[9u8; 16][..])]);
        let expected_stream = SharedPositionStream::new(
            PositionStreamCodec::PackedDelta,
            vec![0u32, 4, 11],
            Bytes::from((0u8..32).collect::<Vec<_>>()),
        );
        let posting = CompressedPostingList::new(
            blocks,
            7.0,
            3,
            PostingTailCodec::VarintDelta,
            Some(CompressedPositionStorage::SharedStream(
                expected_stream.clone(),
            )),
        );
        let mut encoded = Vec::new();
        PostingList::Compressed(posting)
            .serialize(&mut encoded)
            .unwrap();
        let serialized = Bytes::from(encoded);

        let PostingList::Compressed(restored) = PostingList::deserialize(&serialized).unwrap()
        else {
            panic!("expected Compressed variant");
        };
        let Some(CompressedPositionStorage::SharedStream(stream)) = restored.positions else {
            panic!("expected shared-stream positions");
        };

        assert_eq!(stream.codec(), expected_stream.codec());
        assert_eq!(stream.block_offsets(), expected_stream.block_offsets());
        assert_eq!(stream.bytes(), expected_stream.bytes());
        // The decoded stream bytes must lie inside the serialized buffer.
        assert_slice_points_into_bytes(stream.bytes(), &serialized);
    }

    #[test]
    fn positions_legacy_roundtrip() {
        let lists = legacy_positions(&[&[1, 2, 3], &[], &[10]]);
        let positions = Positions(CompressedPositionStorage::LegacyPerDoc(lists));
        let restored = roundtrip_positions(&positions);
        assert_position_storage_eq(&positions.0, &restored.0);
    }

    #[test]
    fn positions_shared_stream_roundtrip() {
        let stream = SharedPositionStream::new(
            PositionStreamCodec::PackedDelta,
            vec![0u32, 8],
            Bytes::from(vec![0xAAu8; 24]),
        );
        let positions = Positions(CompressedPositionStorage::SharedStream(stream));
        let restored = roundtrip_positions(&positions);
        assert_position_storage_eq(&positions.0, &restored.0);
    }

    #[test]
    fn truncated_data_errors() {
        let plain = PlainPostingList::new(
            ScalarBuffer::from(vec![1u64]),
            ScalarBuffer::from(vec![1.0f32]),
            None,
            None,
        );
        let mut encoded = Vec::new();
        PostingList::Plain(plain).serialize(&mut encoded).unwrap();
        // Keep only the first half of the entry; decoding must fail rather than succeed.
        let half = encoded.len() / 2;
        encoded.truncate(half);
        let outcome = PostingList::deserialize(&Bytes::from(encoded));
        assert!(outcome.is_err());
    }
}