use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
use std::sync::{Arc, OnceLock};
use std::time::Instant;

use memmap2::{Mmap, MmapOptions};
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use regex::Regex;
use serde::{Deserialize, Serialize};

use crate::index::{FtsConfig, Index, IndexConfig, IndexStats};
use crate::traits::{DictError, Result};
/// Safety cap on the byte size of a serialized FTS index file accepted by `load`.
const MAX_FTS_BYTES: u64 = 64 * 1024 * 1024;
/// Safety cap on the number of documents accepted from a loaded index.
const MAX_DOCS: usize = 500_000;
/// Identifier assigned to each indexed document (monotonically increasing from 1).
type DocId = u32;
/// Identifier assigned to each distinct term (monotonically increasing from 1).
type TermId = u32;
/// A single analyzed token.
///
/// NOTE(review): nothing in this file constructs or reads `Token` —
/// `tokenize` returns `(String, u32)` pairs instead. Possibly retained for
/// an external caller or a serialized format; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Token {
    text: String,     // lowercased token text
    term_id: TermId,  // id of the term this token belongs to
    position: u32,    // token position within its document
    doc_freq: u32,    // number of documents containing the term
}
/// One ranked hit returned by a full-text search.
///
/// NOTE(review): `search` in this file returns `(String, f32)` pairs rather
/// than this type; `highlights` is never populated here — presumably filled
/// in by a higher layer. Confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FtsSearchResult {
    pub doc_id: DocId,                 // matched document id
    pub key: String,                   // the document's key
    pub score: f32,                    // relevance score (higher is better)
    pub highlights: Vec<(usize, usize)>, // (start, end) spans to highlight
}
/// Inverted-index record for one distinct term.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct InvertedIndexEntry {
    term_id: TermId,        // id assigned when the term was first indexed
    term: String,           // the term text (also the map key in `inverted_index`)
    postings: Vec<Posting>, // one posting per document containing the term
    doc_freq: u32,          // number of documents containing the term (== postings.len())
    term_freq: u32,         // total occurrences of the term across all documents
}
/// Occurrences of a term within a single document.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Posting {
    doc_id: DocId,       // document containing the term
    term_freq: u32,      // number of occurrences in that document
    positions: Vec<u32>, // token positions of each occurrence
}
/// A stored document: raw bytes plus indexing metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Document {
    doc_id: DocId,    // id assigned at add time
    key: String,      // caller-supplied key, returned from searches
    content: Vec<u8>, // raw content; interpreted as (lossy) UTF-8 when tokenized
    doc_length: u32,  // number of tokens produced from `content`
}
/// In-memory full-text search index with TF-IDF ranking.
///
/// Built from `(key, content)` pairs and serialized/deserialized with
/// bincode. Methods serialize access through `lock`; note the lock guards
/// nothing by type (`RwLock<()>`), so it is purely advisory — `&mut self`
/// methods are already exclusive via the borrow checker.
pub struct FtsIndex {
    inverted_index: HashMap<String, InvertedIndexEntry>, // term -> postings
    documents: HashMap<DocId, Document>,                 // doc id -> stored document
    next_term_id: TermId, // next id to assign to a newly-seen term
    next_doc_id: DocId,   // next id to assign to a newly-added document
    stop_words: HashSet<String>, // tokens dropped during tokenization
    term_stats: HashMap<String, u32>, // persisted with the index; never written in this file
    config: FtsConfig,    // active tokenization settings
    stats: IndexStats,    // build/size statistics exposed via `stats()`
    lock: Arc<RwLock<()>>, // advisory guard shared by read/write paths
}
impl FtsIndex {
pub fn new() -> Self {
Self {
inverted_index: HashMap::new(),
documents: HashMap::new(),
next_term_id: 1,
next_doc_id: 1,
stop_words: HashSet::new(),
term_stats: HashMap::new(),
config: FtsConfig::default(),
stats: IndexStats {
entries: 0,
size: 0,
build_time: 0,
version: "1.0".to_string(),
config: IndexConfig::default(),
},
lock: Arc::new(RwLock::new(())),
}
}
/// Create an empty FTS index using `config`.
///
/// The configured stop words are materialized into a `HashSet` up front so
/// tokenization can test membership in O(1).
pub fn with_config(config: FtsConfig) -> Self {
    let stop_words: HashSet<String> = config.stop_words.iter().cloned().collect();
    Self {
        inverted_index: HashMap::new(),
        documents: HashMap::new(),
        next_term_id: 1,
        next_doc_id: 1,
        stop_words,
        term_stats: HashMap::new(),
        config,
        stats: IndexStats {
            entries: 0,
            size: 0,
            build_time: 0,
            version: String::from("1.0"),
            config: IndexConfig::default(),
        },
        lock: Arc::new(RwLock::new(())),
    }
}
fn tokenize(&self, text: &str) -> Vec<(String, u32)> {
let word_pattern = Regex::new(r"\p{L}+|\p{N}+").expect("Failed to create word regex");
let mut tokens = Vec::new();
let mut positions = Vec::new();
let mut position = 0u32;
for mat in word_pattern.find_iter(text) {
let token = mat.as_str().to_lowercase();
if self.stop_words.contains(&token) {
continue;
}
if token.len() >= self.config.min_token_len && token.len() <= self.config.max_token_len
{
tokens.push(token);
positions.push(position);
position += 1;
}
}
tokens.into_iter().zip(positions).collect()
}
/// Tokenize `content` and add it to the index under a fresh document id.
///
/// Stores the document and, for every distinct term, appends a posting
/// carrying the in-document term frequency and token positions. Returns the
/// id assigned to the new document.
fn add_document(&mut self, key: String, content: &[u8]) -> Result<DocId> {
    // `&mut self` already guarantees exclusivity; the advisory lock is kept
    // for symmetry with the read paths sharing `self.lock`.
    let _guard = self.lock.write();
    let content_str = String::from_utf8_lossy(content);
    let doc_id = self.next_doc_id;
    self.next_doc_id += 1;
    let tokens = self.tokenize(&content_str);
    // Per-term occurrence counts within this document.
    let mut term_freqs: HashMap<String, u32> = HashMap::new();
    for (term, _position) in &tokens {
        *term_freqs.entry(term.clone()).or_insert(0) += 1;
    }
    self.documents.insert(
        doc_id,
        Document {
            doc_id,
            key,
            content: content.to_vec(),
            doc_length: tokens.len() as u32,
        },
    );
    // Borrow the id counter as its own field so the `or_insert_with`
    // closure does not capture `self` while `self.inverted_index` is
    // mutably borrowed — capturing `self` there is a borrow-check error
    // (E0500) in the original code.
    let next_term_id = &mut self.next_term_id;
    for (term, term_freq) in term_freqs {
        let positions: Vec<u32> = tokens
            .iter()
            .filter_map(|(t, pos)| if *t == term { Some(*pos) } else { None })
            .collect();
        let entry = self.inverted_index.entry(term.clone()).or_insert_with(|| {
            let term_id = *next_term_id;
            *next_term_id = next_term_id.saturating_add(1);
            InvertedIndexEntry {
                term_id,
                term,
                postings: Vec::new(),
                doc_freq: 0,
                term_freq: 0,
            }
        });
        entry.postings.push(Posting {
            doc_id,
            term_freq,
            positions,
        });
        entry.doc_freq = entry.doc_freq.saturating_add(1);
        entry.term_freq = entry.term_freq.saturating_add(term_freq);
    }
    Ok(doc_id)
}
/// Rank documents against `query` using TF-IDF scoring.
///
/// Returns `(key, score)` pairs sorted best-first; a query that tokenizes
/// to nothing yields an empty result set.
pub fn search(&self, query: &str) -> Result<Vec<(String, f32)>> {
    let _guard = self.lock.read();
    let query_tokens = self.tokenize(query);
    if query_tokens.is_empty() {
        return Ok(Vec::new());
    }
    let total_docs = self.documents.len() as f32;
    // Smoothed inverse document frequency for each distinct query term that
    // actually occurs in the index.
    let mut query_term_weights = HashMap::new();
    for (term, _) in &query_tokens {
        if let Some(entry) = self.inverted_index.get(term) {
            let idf = (total_docs / (entry.doc_freq as f32 + 1.0)).ln() + 1.0;
            query_term_weights.insert(term, idf);
        }
    }
    // Accumulate each term's TF-IDF contribution per matching document.
    let mut doc_scores: HashMap<DocId, f32> = HashMap::new();
    for (term, idf) in query_term_weights {
        if let Some(entry) = self.inverted_index.get(term.as_str()) {
            for posting in &entry.postings {
                let tf = 1.0 + (posting.term_freq as f32).ln();
                *doc_scores.entry(posting.doc_id).or_insert(0.0) += idf * tf;
            }
        }
    }
    // Resolve doc ids back to their keys and order by descending score.
    let mut results: Vec<(String, f32)> = doc_scores
        .into_iter()
        .filter_map(|(doc_id, score)| {
            let doc = self.documents.get(&doc_id)?;
            Some((doc.key.clone(), score))
        })
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(results)
}
/// Return every indexed term starting with `prefix`, sorted ascending.
pub fn prefix_search(&self, prefix: &str) -> Result<Vec<String>> {
    let _guard = self.lock.read();
    let mut matching: Vec<String> = self
        .inverted_index
        .keys()
        .filter(|term| term.starts_with(prefix))
        .cloned()
        .collect();
    matching.sort();
    Ok(matching)
}
/// Total occurrences of `term` across all documents (0 if not indexed).
pub fn term_frequency(&self, term: &str) -> u32 {
    match self.inverted_index.get(term) {
        Some(entry) => entry.term_freq,
        None => 0,
    }
}
/// Number of documents containing `term` (0 if not indexed).
pub fn document_frequency(&self, term: &str) -> u32 {
    self.inverted_index.get(term).map_or(0, |entry| entry.doc_freq)
}
/// Number of distinct terms currently in the inverted index.
pub fn vocabulary_size(&self) -> usize {
    self.inverted_index.len()
}
/// Mean token count per indexed document; 0.0 when the index is empty.
pub fn avg_doc_length(&self) -> f32 {
    match self.documents.len() {
        0 => 0.0,
        count => {
            let total: u32 = self.documents.values().map(|doc| doc.doc_length).sum();
            total as f32 / count as f32
        }
    }
}
/// Suggest up to 10 indexed terms within edit distance 2 of `query`.
///
/// Results are ordered by ascending edit distance with alphabetical
/// tie-breaking, making the output deterministic. The original sorted by
/// distance alone, so ties followed `HashMap` iteration order and the
/// suggestion list could differ between runs for the same index.
pub fn suggest_spelling(&self, query: &str) -> Result<Vec<String>> {
    let _guard = self.lock.read();
    let mut suggestions: Vec<(String, u32)> = Vec::new();
    for term in self.inverted_index.keys() {
        let distance = self.edit_distance(query, term);
        if distance <= 2 {
            suggestions.push((term.clone(), distance));
        }
    }
    // Sort by (distance, term) so equal-distance suggestions come out in a
    // stable, reproducible order.
    suggestions.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
    Ok(suggestions
        .into_iter()
        .take(10)
        .map(|(term, _)| term)
        .collect())
}
/// Levenshtein distance between `s1` and `s2`, measured in Unicode scalar
/// values.
///
/// The previous version sized the DP table by *byte* length while comparing
/// characters via `chars().nth(..)`: for multi-byte input the two index
/// spaces diverge, and out-of-range `nth` calls return `None`, which
/// compares equal to `None` and silently undercounts edits. `nth` inside
/// the doubly-nested loop was also O(n) per lookup. Collecting the chars
/// once fixes both.
fn edit_distance(&self, s1: &str, s2: &str) -> u32 {
    let a: Vec<char> = s1.chars().collect();
    let b: Vec<char> = s2.chars().collect();
    let (m, n) = (a.len(), b.len());
    let mut dp = vec![vec![0u32; n + 1]; m + 1];
    for i in 0..=m {
        dp[i][0] = i as u32; // deleting i chars from s1
    }
    for j in 0..=n {
        dp[0][j] = j as u32; // inserting j chars of s2
    }
    for i in 1..=m {
        for j in 1..=n {
            dp[i][j] = if a[i - 1] == b[j - 1] {
                dp[i - 1][j - 1]
            } else {
                1 + dp[i - 1][j].min(dp[i][j - 1]).min(dp[i - 1][j - 1])
            };
        }
    }
    dp[m][n]
}
/// Extract roughly `max_length` bytes of context around the first query-term
/// occurrence in the document, or `None` if no term matches.
pub fn get_snippet(&self, doc_id: DocId, query: &str, max_length: usize) -> Option<String> {
    let _guard = self.lock.read();
    let doc = self.documents.get(&doc_id)?;
    let content = String::from_utf8_lossy(&doc.content);
    // Lowercase once, outside the loop; the original recomputed
    // `content.to_lowercase()` for every query token.
    // NOTE(review): byte offsets in the lowered text are assumed to match
    // `content`; a handful of Unicode chars change length when lowercased —
    // confirm whether non-ASCII content matters here.
    let lowered = content.to_lowercase();
    let first_pos = self
        .tokenize(query)
        .into_iter()
        .filter_map(|(token, _)| lowered.find(&token))
        .min()?; // None when no query token occurs in the document
    let mut start = first_pos.saturating_sub(max_length / 2).min(content.len());
    let mut end = (first_pos + max_length).min(content.len());
    // Snap both edges to char boundaries: slicing a &str mid-codepoint
    // panics, and the computed byte offsets can land inside a UTF-8 char.
    while start > 0 && !content.is_char_boundary(start) {
        start -= 1;
    }
    while end < content.len() && !content.is_char_boundary(end) {
        end += 1;
    }
    Some(content[start..end].to_string())
}
/// Cross-check stored documents against the inverted index.
///
/// Verifies that (1) every term re-tokenized from a stored document has a
/// posting for that document, and (2) each entry's `doc_freq` equals its
/// posting count. Returns `Ok(false)` on the first inconsistency found.
pub fn validate(&self) -> Result<bool> {
    let _guard = self.lock.read();
    for (doc_id, doc) in &self.documents {
        let content = String::from_utf8_lossy(&doc.content);
        for (term, _) in self.tokenize(&content) {
            if let Some(entry) = self.inverted_index.get(&term) {
                if !entry.postings.iter().any(|p| p.doc_id == *doc_id) {
                    return Ok(false);
                }
            }
        }
    }
    let freqs_consistent = self
        .inverted_index
        .values()
        .all(|entry| entry.postings.len() as u32 == entry.doc_freq);
    Ok(freqs_consistent)
}
/// Read-only view of the index statistics (entry count, size, build time).
pub fn get_stats(&self) -> &IndexStats {
    &self.stats
}
}
impl Index for FtsIndex {
const INDEX_TYPE: &'static str = "fts";
/// Rebuild the index from scratch over `entries` (`(key, content)` pairs).
///
/// Clears all previous state, adopts the FTS settings from `config` —
/// including its stop-word list, which the old code failed to refresh, so a
/// rebuild kept tokenizing with whatever stop words the constructor set —
/// indexes every entry, and validates the result before accepting it.
fn build(&mut self, entries: &[(String, Vec<u8>)], config: &IndexConfig) -> Result<()> {
    let start_time = Instant::now();
    self.inverted_index.clear();
    self.documents.clear();
    self.next_term_id = 1;
    self.next_doc_id = 1;
    self.term_stats.clear();
    self.config = config.fts_config.clone();
    // Keep the stop-word set in sync with the newly-adopted configuration.
    self.stop_words = self.config.stop_words.iter().cloned().collect();
    if let Some(max_mem) = config.max_memory {
        // Coarse pre-flight guard: refuse inputs whose raw size is wildly
        // above the configured memory budget before doing any work.
        let estimated: u64 = entries
            .iter()
            .map(|(k, v)| (k.len() + v.len()) as u64)
            .sum();
        if estimated > max_mem.saturating_mul(4) {
            return Err(DictError::IndexError(
                "FTS index build aborted: estimated input too large for configured max_memory"
                    .to_string(),
            ));
        }
    }
    for (key, content) in entries {
        self.add_document(key.clone(), content)?;
    }
    self.stats.entries = entries.len() as u64;
    self.stats.build_time = start_time.elapsed().as_millis() as u64;
    // Rough in-memory size estimate; no serialized form exists until `save`.
    self.stats.size = self.inverted_index.len() as u64 * 100;
    if !self.validate()? {
        self.clear();
        return Err(DictError::IndexError(
            "FTS index validation failed; index discarded".to_string(),
        ));
    }
    Ok(())
}
/// Load a serialized FTS index from `path`, replacing the current contents.
///
/// The file is size-checked against `MAX_FTS_BYTES`, memory-mapped, and
/// deserialized with bincode. All sanity checks — including the `MAX_DOCS`
/// document cap — now run *before* any state is committed; the original
/// assigned the deserialized maps into `self` first and could leave the
/// index half-loaded when rejecting an oversized file.
fn load(&mut self, path: &Path) -> Result<()> {
    let meta = std::fs::metadata(path).map_err(|e| DictError::IoError(e.to_string()))?;
    if meta.len() > MAX_FTS_BYTES {
        return Err(DictError::IndexError(format!(
            "FTS index {} exceeds safety limit ({} bytes)",
            path.display(),
            meta.len()
        )));
    }
    let file = File::open(path).map_err(|e| DictError::IoError(e.to_string()))?;
    // SAFETY: the mapping is read-only and dropped before this function
    // returns; we assume no other process truncates the file while mapped.
    let mmap = unsafe {
        MmapOptions::new()
            .map(&file)
            .map_err(|e| DictError::MmapError(e.to_string()))?
    };
    let (inverted_index, documents, term_stats): (
        HashMap<String, InvertedIndexEntry>,
        HashMap<DocId, Document>,
        HashMap<String, u32>,
    ) = bincode::deserialize(&mmap[..])
        .map_err(|e| DictError::SerializationError(e.to_string()))?;
    // Reject before mutating `self` so a bad file leaves the index intact.
    if documents.len() > MAX_DOCS {
        return Err(DictError::IndexError(format!(
            "FTS index {} has too many documents ({})",
            path.display(),
            documents.len()
        )));
    }
    self.inverted_index = inverted_index;
    self.documents = documents;
    self.term_stats = term_stats;
    // Reuse the metadata fetched above instead of stat-ing the file again.
    self.stats.size = meta.len();
    Ok(())
}
/// Serialize the inverted index, documents, and term statistics to `path`
/// with bincode, overwriting any existing file.
fn save(&self, path: &Path) -> Result<()> {
    // Create first (as before) so error ordering is preserved: an I/O
    // failure surfaces before a serialization failure.
    let mut file = File::create(path).map_err(|e| DictError::IoError(e.to_string()))?;
    let payload = (&self.inverted_index, &self.documents, &self.term_stats);
    let bytes = bincode::serialize(&payload)
        .map_err(|e| DictError::SerializationError(e.to_string()))?;
    file.write_all(&bytes)
        .map_err(|e| DictError::IoError(e.to_string()))?;
    Ok(())
}
/// Trait accessor for the index statistics.
fn stats(&self) -> &IndexStats {
    &self.stats
}
/// True once the index holds at least one term and one document.
fn is_built(&self) -> bool {
    !self.inverted_index.is_empty() && !self.documents.is_empty()
}
/// Reset the index to an empty state and zero out the statistics.
///
/// NOTE(review): `stop_words` and `config` are deliberately left in place
/// (matching the original), so a subsequent `build` reuses the active
/// configuration — confirm this is the intended contract.
fn clear(&mut self) {
    self.inverted_index.clear();
    self.documents.clear();
    self.term_stats.clear();
    self.next_term_id = 1;
    self.next_doc_id = 1;
    self.stats = IndexStats {
        entries: 0,
        size: 0,
        build_time: 0,
        version: String::from("1.0"),
        config: IndexConfig::default(),
    };
}
/// Trait-level integrity check; delegates to the inherent `validate`.
fn verify(&self) -> Result<bool> {
    self.validate()
}
}
/// `Default` delegates to `new()`, i.e. an empty index with default config.
impl Default for FtsIndex {
    fn default() -> Self {
        Self::new()
    }
}