use crate::{Result, StorageError};
use roaring::RoaringBitmap;
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
pub type TermId = u32;
pub type DocId = u64;
pub type Position = u32;
#[derive(Debug, Clone)]
pub struct Token {
pub text: String,
pub position: Position,
}
pub trait Tokenizer: Send + Sync {
fn tokenize(&self, text: &str) -> Vec<Token>;
fn name(&self) -> &str;
}
#[derive(Debug, Clone)]
pub struct WhitespaceTokenizer {
pub case_sensitive: bool,
pub min_len: usize,
pub max_len: usize,
}
impl Default for WhitespaceTokenizer {
fn default() -> Self {
Self {
case_sensitive: false,
min_len: 1,
max_len: 64,
}
}
}
impl Tokenizer for WhitespaceTokenizer {
fn tokenize(&self, text: &str) -> Vec<Token> {
let normalized = if self.case_sensitive {
text.to_string()
} else {
text.to_lowercase()
};
normalized
.split(|c: char| !c.is_alphanumeric() && c != '_')
.enumerate()
.filter_map(|(i, s)| {
if s.len() >= self.min_len && s.len() <= self.max_len {
Some(Token {
text: s.to_string(),
position: i as Position,
})
} else {
None
}
})
.collect()
}
fn name(&self) -> &str {
"whitespace"
}
}
#[derive(Debug, Clone)]
pub struct NgramTokenizer {
pub n: usize,
pub case_sensitive: bool,
}
impl NgramTokenizer {
pub fn new(n: usize) -> Self {
Self {
n,
case_sensitive: false,
}
}
}
impl Tokenizer for NgramTokenizer {
fn tokenize(&self, text: &str) -> Vec<Token> {
let normalized = if self.case_sensitive {
text.to_string()
} else {
text.to_lowercase()
};
let chars: Vec<char> = normalized.chars().collect();
if chars.len() < self.n {
return vec![];
}
chars
.windows(self.n)
.enumerate()
.map(|(i, window)| Token {
text: window.iter().collect(),
position: i as Position,
})
.collect()
}
fn name(&self) -> &str {
"ngram"
}
}
#[derive(Debug, Clone, Copy)]
pub struct BM25Config {
pub k1: f32,
pub b: f32,
}
impl Default for BM25Config {
fn default() -> Self {
Self {
k1: 1.5,
b: 0.75,
}
}
}
#[derive(Debug, Clone)]
pub struct PostingList {
doc_ids: RoaringBitmap,
doc_freqs: Vec<u16>,
positions: Option<HashMap<DocId, Vec<Position>>>,
}
impl Serialize for PostingList {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("PostingList", 3)?;
let doc_ids: Vec<u32> = self.doc_ids.iter().collect();
state.serialize_field("doc_ids", &doc_ids)?;
state.serialize_field("doc_freqs", &self.doc_freqs)?;
state.serialize_field("positions", &self.positions)?;
state.end()
}
}
impl<'de> Deserialize<'de> for PostingList {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct Helper {
doc_ids: Vec<u32>,
doc_freqs: Vec<u16>,
positions: Option<HashMap<DocId, Vec<Position>>>,
}
let helper = Helper::deserialize(deserializer)?;
let doc_ids = RoaringBitmap::from_iter(helper.doc_ids);
Ok(PostingList {
doc_ids,
doc_freqs: helper.doc_freqs,
positions: helper.positions,
})
}
}
impl Default for PostingList {
fn default() -> Self {
Self::new()
}
}
impl PostingList {
pub fn new() -> Self {
Self {
doc_ids: RoaringBitmap::new(),
doc_freqs: Vec::new(),
positions: Some(HashMap::new()),
}
}
pub fn new_without_positions(disable_positions: bool) -> Self {
Self {
doc_ids: RoaringBitmap::new(),
doc_freqs: Vec::new(),
positions: if disable_positions { None } else { Some(HashMap::new()) },
}
}
pub fn serialize_compact(&self) -> Result<Vec<u8>> {
let mut buf = Vec::new();
self.doc_ids.serialize_into(&mut buf)
.map_err(|e| StorageError::Serialization(format!("RoaringBitmap serialize error: {}", e)))?;
let doc_freqs: Vec<u16> = if let Some(ref pos_map) = self.positions {
self.doc_ids.iter()
.map(|id| pos_map.get(&(id as u64)).map(|v| v.len() as u16).unwrap_or(1))
.collect()
} else {
vec![1u16; self.doc_ids.len() as usize]
};
buf.extend_from_slice(&(doc_freqs.len() as u32).to_le_bytes());
for &freq in &doc_freqs {
buf.extend_from_slice(&freq.to_le_bytes());
}
Ok(buf)
}
pub fn deserialize_compact(buf: &[u8]) -> Result<Self> {
use std::io::Cursor;
if buf.is_empty() {
return Err(StorageError::InvalidData("Empty buffer".into()));
}
let mut cursor = Cursor::new(buf);
let doc_ids = RoaringBitmap::deserialize_from(&mut cursor)
.map_err(|e| StorageError::Serialization(
format!("RoaringBitmap deserialize error (buf_len={}): {}", buf.len(), e)
))?;
let offset = cursor.position() as usize;
if offset + 4 > buf.len() {
return Err(StorageError::InvalidData(
format!("Buffer too small for doc_freqs count: offset={}, buf_len={}", offset, buf.len())
));
}
let count = u32::from_le_bytes([
buf[offset], buf[offset+1], buf[offset+2], buf[offset+3]
]) as usize;
let mut offset = offset + 4;
let mut doc_freqs = Vec::with_capacity(count);
for _ in 0..count {
if offset + 2 > buf.len() {
return Err(StorageError::InvalidData("Buffer too small for doc_freqs".into()));
}
let freq = u16::from_le_bytes([buf[offset], buf[offset+1]]);
doc_freqs.push(freq);
offset += 2;
}
Ok(PostingList {
doc_ids,
doc_freqs,
positions: None, })
}
pub fn add(&mut self, doc_id: DocId, position: Option<Position>) {
self.doc_ids.insert(doc_id as u32);
if let Some(pos) = position {
if let Some(ref mut pos_map) = self.positions {
pos_map.entry(doc_id).or_default().push(pos);
}
}
}
pub fn add_multiple(&mut self, doc_id: DocId, _count: u16, positions: Option<Vec<Position>>) {
self.doc_ids.insert(doc_id as u32);
if let Some(pos_vec) = positions {
if let Some(ref mut pos_map) = self.positions {
pos_map.entry(doc_id).or_default().extend(pos_vec);
}
}
}
fn rebuild_doc_freqs_array(&mut self) {
let old_freqs_map: HashMap<u64, u16> = self.doc_ids.iter()
.zip(self.doc_freqs.iter())
.map(|(id, &freq)| (id as u64, freq))
.collect();
self.doc_freqs = self.doc_ids.iter()
.map(|id| *old_freqs_map.get(&(id as u64)).unwrap_or(&0))
.collect();
}
pub fn merge(&mut self, other: &PostingList) {
let mut freq_map: HashMap<u64, u16> = self.doc_ids.iter()
.zip(self.doc_freqs.iter())
.map(|(id, &freq)| (id as u64, freq))
.collect();
for (id, &freq) in other.doc_ids.iter().zip(other.doc_freqs.iter()) {
*freq_map.entry(id as u64).or_insert(0) += freq;
}
self.doc_ids |= &other.doc_ids;
self.doc_freqs = self.doc_ids.iter()
.map(|id| *freq_map.get(&(id as u64)).unwrap_or(&0))
.collect();
if let (Some(ref mut self_pos), Some(ref other_pos)) = (&mut self.positions, &other.positions) {
for (doc_id, positions) in other_pos {
self_pos.entry(*doc_id).or_default().extend(positions);
}
}
}
pub fn doc_ids(&self) -> Vec<DocId> {
self.doc_ids.iter().map(|id| id as DocId).collect()
}
pub fn doc_count(&self) -> u64 {
self.doc_ids.len()
}
pub fn term_frequency(&self, doc_id: DocId) -> u16 {
if let Some(ref pos_map) = self.positions {
pos_map.get(&doc_id).map(|v| v.len() as u16).unwrap_or(1)
} else {
if self.doc_ids.contains(doc_id as u32) { 1 } else { 0 }
}
}
pub fn contains(&self, doc_id: DocId) -> bool {
self.doc_ids.contains(doc_id as u32)
}
pub fn get_positions(&self, doc_id: DocId) -> Option<&[Position]> {
self.positions.as_ref()?.get(&doc_id).map(|v| v.as_slice())
}
pub fn remove(&mut self, doc_id: DocId) {
if !self.doc_ids.contains(doc_id as u32) {
return;
}
self.doc_ids.remove(doc_id as u32);
if let Some(ref mut pos_map) = self.positions {
pos_map.remove(&doc_id);
}
self.rebuild_doc_freqs_array();
}
pub fn is_empty(&self) -> bool {
self.doc_ids.is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_whitespace_tokenizer() {
let tokenizer = WhitespaceTokenizer::default();
let tokens = tokenizer.tokenize("Hello, World!");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "hello");
assert_eq!(tokens[1].text, "world");
}
#[test]
fn test_ngram_tokenizer() {
let tokenizer = NgramTokenizer::new(2);
let tokens = tokenizer.tokenize("rust");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[0].text, "ru");
assert_eq!(tokens[1].text, "us");
assert_eq!(tokens[2].text, "st");
}
#[test]
fn test_posting_list() {
let mut posting = PostingList::new();
posting.add(1, Some(0));
posting.add(1, Some(5));
posting.add(2, Some(3));
assert_eq!(posting.doc_count(), 2);
assert_eq!(posting.term_frequency(1), 2);
assert_eq!(posting.term_frequency(2), 1);
}
}