1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6use std::io::{self, Read, Write};
7
8use crate::DocId;
9
/// One entry of a posting list: a document id paired with a term frequency.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Posting {
    /// Id of the document this entry refers to.
    pub doc_id: DocId,
    /// Term frequency recorded for that document.
    pub term_freq: u32,
}
16
17#[derive(Debug, Clone, Default)]
19pub struct PostingList {
20 postings: Vec<Posting>,
21}
22
23impl PostingList {
24 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn with_capacity(capacity: usize) -> Self {
29 Self {
30 postings: Vec::with_capacity(capacity),
31 }
32 }
33
34 pub fn push(&mut self, doc_id: DocId, term_freq: u32) {
36 debug_assert!(
37 self.postings.is_empty() || self.postings.last().unwrap().doc_id < doc_id,
38 "Postings must be added in sorted order"
39 );
40 self.postings.push(Posting { doc_id, term_freq });
41 }
42
43 pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
45 if let Some(last) = self.postings.last_mut()
46 && last.doc_id == doc_id
47 {
48 last.term_freq += term_freq;
49 return;
50 }
51 self.postings.push(Posting { doc_id, term_freq });
52 }
53
54 pub fn doc_count(&self) -> u32 {
56 self.postings.len() as u32
57 }
58
59 pub fn len(&self) -> usize {
60 self.postings.len()
61 }
62
63 pub fn is_empty(&self) -> bool {
64 self.postings.is_empty()
65 }
66
67 pub fn iter(&self) -> impl Iterator<Item = &Posting> {
68 self.postings.iter()
69 }
70
71 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
73 write_vint(writer, self.postings.len() as u64)?;
75
76 let mut prev_doc_id = 0u32;
77 for posting in &self.postings {
78 let delta = posting.doc_id - prev_doc_id;
80 write_vint(writer, delta as u64)?;
81 write_vint(writer, posting.term_freq as u64)?;
82 prev_doc_id = posting.doc_id;
83 }
84
85 Ok(())
86 }
87
88 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
90 let count = read_vint(reader)? as usize;
91 let mut postings = Vec::with_capacity(count);
92
93 let mut prev_doc_id = 0u32;
94 for _ in 0..count {
95 let delta = read_vint(reader)? as u32;
96 let term_freq = read_vint(reader)? as u32;
97 let doc_id = prev_doc_id + delta;
98 postings.push(Posting { doc_id, term_freq });
99 prev_doc_id = doc_id;
100 }
101
102 Ok(Self { postings })
103 }
104}
105
106pub struct PostingListIterator<'a> {
108 postings: &'a [Posting],
109 position: usize,
110}
111
112impl<'a> PostingListIterator<'a> {
113 pub fn new(posting_list: &'a PostingList) -> Self {
114 Self {
115 postings: &posting_list.postings,
116 position: 0,
117 }
118 }
119
120 pub fn doc(&self) -> DocId {
122 if self.position < self.postings.len() {
123 self.postings[self.position].doc_id
124 } else {
125 TERMINATED
126 }
127 }
128
129 pub fn term_freq(&self) -> u32 {
131 if self.position < self.postings.len() {
132 self.postings[self.position].term_freq
133 } else {
134 0
135 }
136 }
137
138 pub fn advance(&mut self) -> DocId {
140 self.position += 1;
141 self.doc()
142 }
143
144 pub fn seek(&mut self, target: DocId) -> DocId {
146 while self.position < self.postings.len() {
148 if self.postings[self.position].doc_id >= target {
149 return self.postings[self.position].doc_id;
150 }
151 self.position += 1;
152 }
153 TERMINATED
154 }
155
156 pub fn size_hint(&self) -> usize {
158 self.postings.len().saturating_sub(self.position)
159 }
160}
161
/// Sentinel doc id (`DocId::MAX`) that the iterators in this file return once they have no
/// more postings.
pub const TERMINATED: DocId = DocId::MAX;
164
/// Writes `value` to `writer` as a little-endian base-128 varint.
///
/// Each output byte carries 7 payload bits, least-significant group first; the high bit is
/// set on every byte except the last. The value is encoded into a stack buffer and handed
/// to the writer with a single `write_all`, so an unbuffered writer sees one write request
/// per value rather than one per byte.
///
/// # Errors
///
/// Propagates any error returned by `writer`.
fn write_vint<W: Write>(writer: &mut W, mut value: u64) -> io::Result<()> {
    // A u64 needs at most ceil(64 / 7) = 10 groups of 7 bits.
    let mut encoded = [0u8; 10];
    let mut len = 0;

    loop {
        let low_bits = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            // Final group: continuation bit stays clear.
            encoded[len] = low_bits;
            len += 1;
            break;
        }
        encoded[len] = low_bits | 0x80;
        len += 1;
    }

    writer.write_all(&encoded[..len])
}
178
/// Reads one little-endian base-128 varint from `reader`.
///
/// Bytes are consumed until one without the continuation bit (`0x80`) is seen; nothing past
/// that byte is read.
///
/// # Errors
///
/// * `InvalidData` if the encoding is longer than ten bytes, or if the payload carries bits
///   that do not fit in a `u64` (previously such bits were silently dropped, so an overlong
///   encoding decoded to a wrong value).
/// * Any error from `reader`, including `UnexpectedEof` when the input ends mid-value.
fn read_vint<R: Read>(reader: &mut R) -> io::Result<u64> {
    let mut result = 0u64;
    let mut shift = 0u32;

    loop {
        let mut buf = [0u8; 1];
        reader.read_exact(&mut buf)?;
        let byte = buf[0];

        let payload = u64::from(byte & 0x7F);
        // Shifting back must reproduce the payload; otherwise high bits fell off the top
        // of the u64 (only possible for the tenth byte, at shift 63).
        if (payload << shift) >> shift != payload {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint overflows u64",
            ));
        }
        result |= payload << shift;

        if byte & 0x80 == 0 {
            return Ok(result);
        }

        shift += 7;
        if shift >= 64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint too long",
            ));
        }
    }
}
199
200#[allow(dead_code)]
202#[derive(Debug, Clone)]
203pub struct CompactPostingList {
204 data: Vec<u8>,
205 doc_count: u32,
206}
207
208#[allow(dead_code)]
209impl CompactPostingList {
210 pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
212 let mut data = Vec::new();
213 list.serialize(&mut data)?;
214 Ok(Self {
215 doc_count: list.len() as u32,
216 data,
217 })
218 }
219
220 pub fn as_bytes(&self) -> &[u8] {
222 &self.data
223 }
224
225 pub fn doc_count(&self) -> u32 {
227 self.doc_count
228 }
229
230 pub fn to_posting_list(&self) -> io::Result<PostingList> {
232 PostingList::deserialize(&mut &self.data[..])
233 }
234}
235
/// Maximum number of postings that `BlockPostingList::from_posting_list` puts in one block.
pub const BLOCK_SIZE: usize = 128;
239
240#[derive(Debug, Clone)]
241pub struct BlockPostingList {
242 skip_list: Vec<(DocId, u32)>,
244 data: Vec<u8>,
246 doc_count: u32,
248}
249
250impl BlockPostingList {
251 pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
253 let mut skip_list = Vec::new();
254 let mut data = Vec::new();
255
256 let postings = &list.postings;
257 let mut i = 0;
258
259 while i < postings.len() {
260 let block_start = data.len() as u32;
261 let block_end = (i + BLOCK_SIZE).min(postings.len());
262 let block = &postings[i..block_end];
263
264 let last_doc_id = block.last().unwrap().doc_id;
266 skip_list.push((last_doc_id, block_start));
267
268 let mut prev_doc_id = if i == 0 { 0 } else { postings[i - 1].doc_id };
270 write_vint(&mut data, block.len() as u64)?;
271
272 for posting in block {
273 let delta = posting.doc_id - prev_doc_id;
274 write_vint(&mut data, delta as u64)?;
275 write_vint(&mut data, posting.term_freq as u64)?;
276 prev_doc_id = posting.doc_id;
277 }
278
279 i = block_end;
280 }
281
282 Ok(Self {
283 skip_list,
284 data,
285 doc_count: postings.len() as u32,
286 })
287 }
288
289 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
291 writer.write_u32::<LittleEndian>(self.doc_count)?;
293
294 writer.write_u32::<LittleEndian>(self.skip_list.len() as u32)?;
296 for (doc_id, offset) in &self.skip_list {
297 writer.write_u32::<LittleEndian>(*doc_id)?;
298 writer.write_u32::<LittleEndian>(*offset)?;
299 }
300
301 writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
303 writer.write_all(&self.data)?;
304
305 Ok(())
306 }
307
308 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
310 let doc_count = reader.read_u32::<LittleEndian>()?;
311
312 let skip_count = reader.read_u32::<LittleEndian>()? as usize;
313 let mut skip_list = Vec::with_capacity(skip_count);
314 for _ in 0..skip_count {
315 let doc_id = reader.read_u32::<LittleEndian>()?;
316 let offset = reader.read_u32::<LittleEndian>()?;
317 skip_list.push((doc_id, offset));
318 }
319
320 let data_len = reader.read_u32::<LittleEndian>()? as usize;
321 let mut data = vec![0u8; data_len];
322 reader.read_exact(&mut data)?;
323
324 Ok(Self {
325 skip_list,
326 data,
327 doc_count,
328 })
329 }
330
331 pub fn doc_count(&self) -> u32 {
332 self.doc_count
333 }
334
335 pub fn iterator(&self) -> BlockPostingIterator<'_> {
337 BlockPostingIterator::new(self)
338 }
339
340 pub fn into_iterator(self) -> BlockPostingIterator<'static> {
342 BlockPostingIterator::owned(self)
343 }
344}
345
346pub struct BlockPostingIterator<'a> {
349 block_list: std::borrow::Cow<'a, BlockPostingList>,
350 current_block: usize,
351 block_postings: Vec<Posting>,
352 position_in_block: usize,
353 exhausted: bool,
354}
355
356#[allow(dead_code)]
358pub type OwnedBlockPostingIterator = BlockPostingIterator<'static>;
359
360impl<'a> BlockPostingIterator<'a> {
361 fn new(block_list: &'a BlockPostingList) -> Self {
362 let exhausted = block_list.skip_list.is_empty();
363 let mut iter = Self {
364 block_list: std::borrow::Cow::Borrowed(block_list),
365 current_block: 0,
366 block_postings: Vec::new(),
367 position_in_block: 0,
368 exhausted,
369 };
370 if !iter.exhausted {
371 iter.load_block(0);
372 }
373 iter
374 }
375
376 fn owned(block_list: BlockPostingList) -> BlockPostingIterator<'static> {
377 let exhausted = block_list.skip_list.is_empty();
378 let mut iter = BlockPostingIterator {
379 block_list: std::borrow::Cow::Owned(block_list),
380 current_block: 0,
381 block_postings: Vec::new(),
382 position_in_block: 0,
383 exhausted,
384 };
385 if !iter.exhausted {
386 iter.load_block(0);
387 }
388 iter
389 }
390
391 fn load_block(&mut self, block_idx: usize) {
392 if block_idx >= self.block_list.skip_list.len() {
393 self.exhausted = true;
394 return;
395 }
396
397 self.current_block = block_idx;
398 self.position_in_block = 0;
399
400 let offset = self.block_list.skip_list[block_idx].1 as usize;
401 let mut reader = &self.block_list.data[offset..];
402
403 let count = read_vint(&mut reader).unwrap_or(0) as usize;
404 self.block_postings.clear();
405 self.block_postings.reserve(count);
406
407 let mut prev_doc_id = if block_idx == 0 {
408 0
409 } else {
410 self.block_list.skip_list[block_idx - 1].0
411 };
412
413 for _ in 0..count {
414 if let (Ok(delta), Ok(tf)) = (read_vint(&mut reader), read_vint(&mut reader)) {
415 let doc_id = prev_doc_id + delta as u32;
416 self.block_postings.push(Posting {
417 doc_id,
418 term_freq: tf as u32,
419 });
420 prev_doc_id = doc_id;
421 }
422 }
423 }
424
425 pub fn doc(&self) -> DocId {
426 if self.exhausted {
427 TERMINATED
428 } else if self.position_in_block < self.block_postings.len() {
429 self.block_postings[self.position_in_block].doc_id
430 } else {
431 TERMINATED
432 }
433 }
434
435 pub fn term_freq(&self) -> u32 {
436 if self.exhausted || self.position_in_block >= self.block_postings.len() {
437 0
438 } else {
439 self.block_postings[self.position_in_block].term_freq
440 }
441 }
442
443 pub fn advance(&mut self) -> DocId {
444 if self.exhausted {
445 return TERMINATED;
446 }
447
448 self.position_in_block += 1;
449 if self.position_in_block >= self.block_postings.len() {
450 self.load_block(self.current_block + 1);
451 }
452 self.doc()
453 }
454
455 pub fn seek(&mut self, target: DocId) -> DocId {
456 if self.exhausted {
457 return TERMINATED;
458 }
459
460 let target_block = self
461 .block_list
462 .skip_list
463 .iter()
464 .position(|(last_doc, _)| *last_doc >= target);
465
466 if let Some(block_idx) = target_block {
467 if block_idx != self.current_block {
468 self.load_block(block_idx);
469 }
470
471 while self.position_in_block < self.block_postings.len() {
472 if self.block_postings[self.position_in_block].doc_id >= target {
473 return self.doc();
474 }
475 self.position_in_block += 1;
476 }
477
478 self.load_block(self.current_block + 1);
479 self.seek(target)
480 } else {
481 self.exhausted = true;
482 TERMINATED
483 }
484 }
485}
486
// Unit tests for the plain and block-encoded posting lists and their iterators.
#[cfg(test)]
mod tests {
    use super::*;

    // Walks a three-entry list with `advance` and checks doc ids, term frequencies and the
    // TERMINATED sentinel at the end.
    #[test]
    fn test_posting_list_basic() {
        let mut list = PostingList::new();
        list.push(1, 2);
        list.push(5, 1);
        list.push(10, 3);

        assert_eq!(list.len(), 3);

        let mut iter = PostingListIterator::new(&list);
        assert_eq!(iter.doc(), 1);
        assert_eq!(iter.term_freq(), 2);

        assert_eq!(iter.advance(), 5);
        assert_eq!(iter.term_freq(), 1);

        assert_eq!(iter.advance(), 10);
        assert_eq!(iter.term_freq(), 3);

        assert_eq!(iter.advance(), TERMINATED);
    }

    // Serializes 100 postings and checks that deserializing yields the same entries.
    #[test]
    fn test_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 3, (i % 5) + 1);
        }

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let deserialized = PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(deserialized.len(), list.len());

        for (a, b) in list.iter().zip(deserialized.iter()) {
            assert_eq!(a, b);
        }
    }

    // The list holds the even doc ids 0..=198: seek must land on an exact match, on the
    // next larger id when the target is absent, and on TERMINATED past the end.
    #[test]
    fn test_posting_list_seek() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 2, 1);
        }

        let mut iter = PostingListIterator::new(&list);

        assert_eq!(iter.seek(50), 50);
        assert_eq!(iter.seek(51), 52);
        assert_eq!(iter.seek(200), TERMINATED);
    }

    // 500 postings at BLOCK_SIZE = 128 span four blocks, so the seeks below cross block
    // boundaries.
    #[test]
    fn test_block_posting_list() {
        let mut list = PostingList::new();
        for i in 0..500 {
            list.push(i * 2, (i % 10) + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
        assert_eq!(block_list.doc_count(), 500);

        let mut iter = block_list.iterator();
        assert_eq!(iter.doc(), 0);
        assert_eq!(iter.term_freq(), 1);

        // Doc id 500 is not in the first block.
        assert_eq!(iter.seek(500), 500);
        // 998 is the last doc id in the list; 1000 is past it.
        assert_eq!(iter.seek(998), 998);
        assert_eq!(iter.seek(1000), TERMINATED);
    }

    // Round-trips a three-block list through serialize/deserialize and compares the two
    // iterators entry by entry.
    #[test]
    fn test_block_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..300 {
            list.push(i * 3, i + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();

        let mut buffer = Vec::new();
        block_list.serialize(&mut buffer).unwrap();

        let deserialized = BlockPostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(deserialized.doc_count(), block_list.doc_count());

        let mut iter1 = block_list.iterator();
        let mut iter2 = deserialized.iterator();

        while iter1.doc() != TERMINATED {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        // Both iterators must run out at the same time.
        assert_eq!(iter2.doc(), TERMINATED);
    }
}