#![allow(dead_code)]
use std::collections::{HashMap, HashSet};
/// A document to be indexed: an identifier plus a set of named text fields.
#[derive(Clone, Debug)]
pub struct IndexDocument {
    /// Identifier of the document within an index.
    pub id: String,
    /// Field name -> raw field text. Each field is tokenized separately when indexed.
    pub fields: HashMap<String, String>,
}

impl IndexDocument {
    /// Creates a document with the given `id` and no fields.
    pub fn new(id: &str) -> Self {
        let fields = HashMap::new();
        Self {
            id: String::from(id),
            fields,
        }
    }

    /// Builder-style setter: stores `value` under `name` and returns the document.
    ///
    /// Setting the same field name twice keeps only the last value.
    pub fn with_field(self, name: &str, value: &str) -> Self {
        let mut doc = self;
        doc.fields.insert(String::from(name), String::from(value));
        doc
    }
}
/// One result returned by a search.
#[derive(Clone, Debug)]
pub struct SearchHit {
    /// Id of the matching document.
    pub doc_id: String,
    /// Accumulated relevance score; larger means more relevant.
    pub score: f64,
    /// Names of the fields that contributed at least one matching posting.
    pub matched_fields: Vec<String>,
}
/// How the per-term result sets of a multi-term query are combined.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BoolOp {
    /// A document must match every query term.
    And,
    /// A document must match at least one query term.
    Or,
}
/// A parsed query plus its matching options, built with chained `with_*` calls.
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// Tokens produced from the query text by the same tokenizer the index uses.
    pub terms: Vec<String>,
    /// How matches for individual terms are combined.
    pub bool_op: BoolOp,
    /// When set, only postings from this field are considered.
    pub field: Option<String>,
    /// Maximum number of hits returned.
    pub limit: usize,
}

impl SearchQuery {
    /// Tokenizes `text` into query terms.
    ///
    /// Defaults: [`BoolOp::And`], no field restriction, and a limit of 100 hits.
    pub fn new(text: &str) -> Self {
        Self {
            limit: 100,
            field: None,
            bool_op: BoolOp::And,
            terms: tokenize(text),
        }
    }

    /// Returns the query with its boolean combinator replaced by `op`.
    pub fn with_op(self, op: BoolOp) -> Self {
        Self { bool_op: op, ..self }
    }

    /// Returns the query restricted to postings from the field named `field`.
    pub fn with_field(self, field: &str) -> Self {
        Self {
            field: Some(field.to_owned()),
            ..self
        }
    }

    /// Returns the query with its maximum hit count set to `limit`.
    pub fn with_limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }
}
/// One occurrence record: term `count` times in `field` of document `doc_id`.
///
/// The term itself is the key of the posting list this entry lives in.
#[derive(Clone, Debug)]
struct PostingEntry {
    /// Id of the document containing the term.
    doc_id: String,
    /// Name of the field the term was found in.
    field: String,
    /// Number of times the term occurs in that field.
    count: u32,
}
/// An in-memory inverted index over [`IndexDocument`]s.
#[derive(Debug)]
pub struct SearchIndex {
    /// Term -> one entry per (document, field) pair in which the term occurs.
    postings: HashMap<String, Vec<PostingEntry>>,
    /// Number of indexed documents; kept in step with `doc_ids`.
    doc_count: usize,
    /// Document id -> total token count summed over all of its fields.
    doc_lengths: HashMap<String, usize>,
    /// Ids of all indexed documents; used to reject duplicate adds.
    doc_ids: HashSet<String>,
}
impl SearchIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            postings: HashMap::new(),
            doc_count: 0,
            doc_lengths: HashMap::new(),
            doc_ids: HashSet::new(),
        }
    }

    /// Number of documents currently indexed.
    pub fn doc_count(&self) -> usize {
        self.doc_count
    }

    /// Number of distinct terms that still have at least one posting.
    pub fn term_count(&self) -> usize {
        self.postings.len()
    }

    /// Tokenizes and indexes every field of `doc`.
    ///
    /// If a document with the same id is already indexed, the call is a no-op (the
    /// existing content is kept); call [`SearchIndex::remove`] first to re-index.
    pub fn add(&mut self, doc: &IndexDocument) {
        // `insert` returns false when the id was already present.
        if !self.doc_ids.insert(doc.id.clone()) {
            return;
        }
        self.doc_count += 1;

        let mut total_tokens = 0usize;
        for (field_name, field_value) in &doc.fields {
            let tokens = tokenize(field_value);
            total_tokens += tokens.len();

            // Per-field term frequencies: one posting per (term, field) pair.
            let mut tf: HashMap<String, u32> = HashMap::new();
            for token in tokens {
                *tf.entry(token).or_insert(0) += 1;
            }
            for (term, count) in tf {
                self.postings.entry(term).or_default().push(PostingEntry {
                    doc_id: doc.id.clone(),
                    field: field_name.clone(),
                    count,
                });
            }
        }
        self.doc_lengths.insert(doc.id.clone(), total_tokens);
    }

    /// Removes the document with id `doc_id` and all of its postings.
    ///
    /// Terms left without any postings are dropped from the index. Unknown ids are
    /// ignored. This scans every posting list, so it is linear in index size.
    pub fn remove(&mut self, doc_id: &str) {
        if !self.doc_ids.remove(doc_id) {
            return;
        }
        self.doc_count -= 1;
        self.doc_lengths.remove(doc_id);
        // Purge the document's postings and drop emptied terms in a single pass.
        self.postings.retain(|_, list| {
            list.retain(|p| p.doc_id != doc_id);
            !list.is_empty()
        });
    }

    /// Runs `query` and returns up to `query.limit` hits.
    ///
    /// Scoring is a TF-IDF variant. For each query term, every posting contributes
    /// `(count / doc_len) * idf` to its document. Here `doc_len` is the document's
    /// token count over all fields, and `idf = max(0, ln(N / df))`. `df` counts
    /// distinct documents containing the term in any field, even when a field filter
    /// is set. Duplicate query terms contribute once per occurrence.
    ///
    /// With [`BoolOp::And`] a document must match every term; with [`BoolOp::Or`],
    /// any. Hits are ordered by descending score, with ties broken by ascending doc id
    /// so results are deterministic. `matched_fields` is sorted by name.
    #[allow(clippy::cast_precision_loss)] // counts would need > 2^53 to lose precision
    pub fn search(&self, query: &SearchQuery) -> Vec<SearchHit> {
        if query.terms.is_empty() || self.doc_count == 0 {
            return Vec::new();
        }
        let field_filter = query.field.as_deref();

        // Keys borrow from `self.postings`; owned strings are only built for the hits.
        let mut doc_scores: HashMap<&str, (f64, HashSet<&str>)> = HashMap::new();
        let mut term_doc_sets: Vec<HashSet<&str>> = Vec::with_capacity(query.terms.len());

        for term in &query.terms {
            let mut term_docs: HashSet<&str> = HashSet::new();
            if let Some(postings) = self.postings.get(term) {
                // Document frequency counts distinct documents, not (doc, field) postings.
                let df = postings
                    .iter()
                    .map(|p| p.doc_id.as_str())
                    .collect::<HashSet<_>>()
                    .len();
                // A term present in every document gets idf 0; clamp guards negatives.
                let idf = (self.doc_count as f64 / df.max(1) as f64).ln().max(0.0);

                for posting in postings {
                    if matches!(field_filter, Some(f) if f != posting.field) {
                        continue;
                    }
                    let doc_len = self.doc_lengths.get(&posting.doc_id).copied().unwrap_or(1);
                    let tf = f64::from(posting.count) / doc_len.max(1) as f64;

                    let entry = doc_scores
                        .entry(posting.doc_id.as_str())
                        .or_insert_with(|| (0.0, HashSet::new()));
                    entry.0 += tf * idf;
                    entry.1.insert(posting.field.as_str());
                    term_docs.insert(posting.doc_id.as_str());
                }
            }
            term_doc_sets.push(term_docs);
        }

        let candidate_ids: HashSet<&str> = match query.bool_op {
            BoolOp::And => {
                // `query.terms` is non-empty here, so there is at least one set.
                let mut sets = term_doc_sets.into_iter();
                let first = sets.next().unwrap_or_default();
                sets.fold(first, |mut acc, s| {
                    acc.retain(|id| s.contains(id));
                    acc
                })
            }
            BoolOp::Or => term_doc_sets.into_iter().flatten().collect(),
        };

        let mut hits: Vec<SearchHit> = doc_scores
            .into_iter()
            .filter(|(id, _)| candidate_ids.contains(id))
            .map(|(id, (score, fields))| {
                // HashSet order is arbitrary; sort so callers see a stable field list.
                let mut matched_fields: Vec<String> =
                    fields.into_iter().map(String::from).collect();
                matched_fields.sort_unstable();
                SearchHit {
                    doc_id: String::from(id),
                    score,
                    matched_fields,
                }
            })
            .collect();

        // Highest score first; equal scores fall back to doc id so the output does not
        // depend on HashMap iteration order.
        hits.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        hits.truncate(query.limit);
        hits
    }

    /// Returns `true` if a document with id `doc_id` is currently indexed.
    pub fn contains(&self, doc_id: &str) -> bool {
        self.doc_ids.contains(doc_id)
    }

    /// Mean number of tokens per indexed document, or `0.0` for an empty index.
    #[allow(clippy::cast_precision_loss)] // counts would need > 2^53 to lose precision
    pub fn avg_doc_length(&self) -> f64 {
        if self.doc_count == 0 {
            return 0.0;
        }
        let total: usize = self.doc_lengths.values().sum();
        total as f64 / self.doc_count as f64
    }
}
impl Default for SearchIndex {
fn default() -> Self {
Self::new()
}
}
/// Splits `text` into lowercase tokens made of maximal runs of alphanumeric chars.
///
/// The whole input is lowercased before splitting. Every non-alphanumeric character
/// (whitespace, punctuation, hyphens, ...) acts as a separator, and empty tokens are
/// never produced.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in lowered.chars() {
        if ch.is_alphanumeric() {
            current.push(ch);
        } else if !current.is_empty() {
            tokens.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-document fixture shared by most tests.
    ///
    /// Every document has 7 tokens. The distinct terms are doc1's `sunset time lapse
    /// 4k nature timelapse`, doc2's `city night drone aerial`, and doc3's
    /// `ocean sunset cinematic waves` (`sunset` is shared), for 13 in total.
    fn build_index() -> SearchIndex {
        let mut idx = SearchIndex::new();
        idx.add(
            &IndexDocument::new("doc1")
                .with_field("title", "Sunset Time-lapse 4K")
                .with_field("tags", "nature sunset timelapse"),
        );
        idx.add(
            &IndexDocument::new("doc2")
                .with_field("title", "City Night Drone")
                .with_field("tags", "city night aerial drone"),
        );
        idx.add(
            &IndexDocument::new("doc3")
                .with_field("title", "Ocean Sunset Cinematic")
                .with_field("tags", "ocean sunset cinematic waves"),
        );
        idx
    }

    #[test]
    fn test_index_doc_count() {
        let idx = build_index();
        assert_eq!(idx.doc_count(), 3);
    }

    #[test]
    fn test_term_count_positive() {
        let idx = build_index();
        assert!(idx.term_count() > 0);
        assert_eq!(idx.term_count(), 13);
    }

    #[test]
    fn test_search_single_term() {
        let idx = build_index();
        let hits = idx.search(&SearchQuery::new("sunset"));
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_search_and_operator() {
        let idx = build_index();
        let q = SearchQuery::new("sunset ocean").with_op(BoolOp::And);
        let hits = idx.search(&q);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "doc3");
    }

    #[test]
    fn test_search_or_operator() {
        let idx = build_index();
        let q = SearchQuery::new("sunset drone").with_op(BoolOp::Or);
        let hits = idx.search(&q);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_search_and_with_missing_term_is_empty() {
        let idx = build_index();
        let q = SearchQuery::new("sunset nonexistent").with_op(BoolOp::And);
        assert!(idx.search(&q).is_empty());
        let q = SearchQuery::new("sunset nonexistent").with_op(BoolOp::Or);
        assert_eq!(idx.search(&q).len(), 2);
    }

    #[test]
    fn test_search_field_scoped() {
        let idx = build_index();
        let q = SearchQuery::new("sunset").with_field("tags");
        let hits = idx.search(&q);
        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert_eq!(hit.matched_fields, vec!["tags".to_string()]);
        }
    }

    #[test]
    fn test_search_field_scoped_excludes_other_fields() {
        let idx = build_index();
        // "nature" appears only in doc1's tags, so a title-scoped query must miss it.
        assert!(idx
            .search(&SearchQuery::new("nature").with_field("title"))
            .is_empty());
        let hits = idx.search(&SearchQuery::new("nature").with_field("tags"));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "doc1");
    }

    #[test]
    fn test_search_no_match() {
        let idx = build_index();
        let hits = idx.search(&SearchQuery::new("nonexistent"));
        assert!(hits.is_empty());
    }

    #[test]
    fn test_search_limit() {
        let idx = build_index();
        let q = SearchQuery::new("sunset").with_limit(1);
        let hits = idx.search(&q);
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_remove_document() {
        let mut idx = build_index();
        idx.remove("doc1");
        assert_eq!(idx.doc_count(), 2);
        assert!(!idx.contains("doc1"));
    }

    #[test]
    fn test_remove_document_purges_postings() {
        let mut idx = build_index();
        idx.remove("doc1");
        // time, lapse, 4k, nature and timelapse occurred only in doc1.
        assert_eq!(idx.term_count(), 8);
        assert!(idx.search(&SearchQuery::new("timelapse")).is_empty());
        let hits = idx.search(&SearchQuery::new("sunset"));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "doc3");
    }

    #[test]
    fn test_remove_unknown_is_noop() {
        let mut idx = build_index();
        idx.remove("doc999");
        assert_eq!(idx.doc_count(), 3);
        assert_eq!(idx.term_count(), 13);
    }

    #[test]
    fn test_readd_after_remove() {
        let mut idx = build_index();
        idx.remove("doc2");
        idx.add(&IndexDocument::new("doc2").with_field("title", "Drone"));
        assert_eq!(idx.doc_count(), 3);
        assert!(idx.search(&SearchQuery::new("aerial")).is_empty());
        let hits = idx.search(&SearchQuery::new("drone"));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "doc2");
    }

    #[test]
    fn test_contains() {
        let idx = build_index();
        assert!(idx.contains("doc1"));
        assert!(!idx.contains("doc999"));
    }

    #[test]
    fn test_duplicate_add_ignored() {
        let mut idx = build_index();
        let doc = IndexDocument::new("doc1").with_field("title", "Duplicate");
        idx.add(&doc);
        assert_eq!(idx.doc_count(), 3);
        // The ignored duplicate's text must not leak into the index.
        assert!(idx.search(&SearchQuery::new("duplicate")).is_empty());
    }

    #[test]
    fn test_avg_doc_length() {
        let idx = build_index();
        let avg = idx.avg_doc_length();
        assert!(avg > 0.0);
        assert!((avg - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_empty_query() {
        let idx = build_index();
        let hits = idx.search(&SearchQuery::new(""));
        assert!(hits.is_empty());
    }

    #[test]
    fn test_empty_index_search() {
        let idx = SearchIndex::new();
        let hits = idx.search(&SearchQuery::new("hello"));
        assert!(hits.is_empty());
        assert_eq!(idx.avg_doc_length(), 0.0);
    }

    #[test]
    fn test_search_hits_have_scores() {
        let idx = build_index();
        let hits = idx.search(&SearchQuery::new("sunset"));
        for hit in &hits {
            assert!(hit.score >= 0.0);
        }
    }

    #[test]
    fn test_search_hits_sorted_by_score() {
        let idx = build_index();
        let hits = idx.search(&SearchQuery::new("sunset"));
        for w in hits.windows(2) {
            assert!(w[0].score >= w[1].score);
        }
    }
}