use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::analysis::analyzer::analyzer::Analyzer;
use crate::analysis::analyzer::standard::StandardAnalyzer;
use crate::error::Result;
use crate::lexical::core::field::FieldValue;
use crate::lexical::query::Query;
use crate::lexical::query::matcher::Matcher;
use crate::lexical::query::scorer::Scorer;
use crate::lexical::reader::LexicalIndexReader;
/// Tuning knobs for similarity search and `MoreLikeThisQuery`.
#[derive(Debug, Clone)]
pub struct SimilarityConfig {
    /// Similarity measure that `calculate_similarity` dispatches on.
    pub algorithm: SimilarityAlgorithm,
    /// Candidates scoring below this threshold are dropped from the results.
    pub min_similarity: f32,
    /// Maximum number of results returned by a search.
    pub max_results: usize,
    /// Document fields whose text contributes to a document's feature vector.
    pub similarity_fields: Vec<String>,
    /// When true, freshly built document vectors are scaled to unit length.
    pub normalize_vectors: bool,
    /// NOTE(review): despite the name, this value is used as the overall
    /// query boost by `MoreLikeThisQuery` (`boost`/`set_boost`), not as a
    /// separate exact-match bonus — confirm intent before renaming.
    pub exact_match_boost: f32,
}
impl Default for SimilarityConfig {
    /// Sensible defaults: cosine similarity over the `content` field, unit
    /// normalization on, at most 20 hits at or above a 0.1 threshold.
    fn default() -> Self {
        Self {
            min_similarity: 0.1,
            max_results: 20,
            normalize_vectors: true,
            exact_match_boost: 1.5,
            similarity_fields: vec!["content".to_string()],
            algorithm: SimilarityAlgorithm::Cosine,
        }
    }
}
/// Similarity measures supported by `SimilaritySearchEngine`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SimilarityAlgorithm {
    /// Cosine of the angle between feature vectors.
    Cosine,
    /// Set overlap of terms: |A ∩ B| / |A ∪ B| (weights ignored).
    Jaccard,
    /// 1 / (1 + L2 distance) between feature vectors.
    Euclidean,
    /// 1 / (1 + L1 distance) between feature vectors.
    Manhattan,
    /// Weighted overlap: sum of min weights over sum of max weights.
    TermFrequency,
    /// BM25-style term-frequency saturation over shared terms (no IDF term).
    BM25,
}
/// Sparse term-weight vector representing one document.
#[derive(Debug, Clone)]
pub struct DocumentVector {
    /// Id of the document this vector was built from (0 for ad-hoc text).
    pub doc_id: u32,
    /// Term -> weight map (term frequencies, or TF-IDF in the unused path).
    pub features: HashMap<String, f32>,
    /// Cached L2 norm; 0.0 until `calculate_magnitude` is called, 1.0 after
    /// a successful `normalize`.
    pub magnitude: f32,
}
impl DocumentVector {
pub fn new(doc_id: u32) -> Self {
DocumentVector {
doc_id,
features: HashMap::new(),
magnitude: 0.0,
}
}
pub fn set_feature(&mut self, term: String, weight: f32) {
self.features.insert(term, weight);
}
pub fn calculate_magnitude(&mut self) {
self.magnitude = self.features.values().map(|w| w * w).sum::<f32>().sqrt();
}
pub fn normalize(&mut self) {
if self.magnitude == 0.0 {
self.calculate_magnitude();
}
if self.magnitude > 0.0 {
for weight in self.features.values_mut() {
*weight /= self.magnitude;
}
self.magnitude = 1.0;
}
}
pub fn terms(&self) -> HashSet<String> {
self.features.keys().cloned().collect()
}
}
/// One hit from a similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// Matching document id.
    pub doc_id: u32,
    /// Similarity score produced by the configured algorithm.
    pub similarity: f32,
    /// Optional human-readable explanation of the score.
    pub explanation: Option<String>,
}
impl SimilarityResult {
pub fn new(doc_id: u32, similarity: f32) -> Self {
SimilarityResult {
doc_id,
similarity,
explanation: None,
}
}
pub fn with_explanation(mut self, explanation: String) -> Self {
self.explanation = Some(explanation);
self
}
}
/// Engine that builds per-document feature vectors and ranks documents by
/// similarity to a target document or free-form text.
pub struct SimilaritySearchEngine {
    // Search configuration (algorithm, thresholds, fields).
    config: SimilarityConfig,
    // Analyzer used to tokenize field text when building vectors.
    analyzer: Box<dyn Analyzer>,
    // Cache of already-built vectors, keyed by document id.
    document_vectors: HashMap<u32, DocumentVector>,
}
impl std::fmt::Debug for SimilaritySearchEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Box<dyn Analyzer>` has no Debug impl, so print a placeholder.
        let mut dbg = f.debug_struct("SimilaritySearchEngine");
        dbg.field("config", &self.config);
        dbg.field("analyzer", &"<dyn Analyzer>");
        dbg.field("document_vectors", &self.document_vectors);
        dbg.finish()
    }
}
impl SimilaritySearchEngine {
    /// Creates an engine with `config` and the default `StandardAnalyzer`.
    ///
    /// NOTE(review): `StandardAnalyzer::new().unwrap()` panics if analyzer
    /// construction fails — consider a fallible constructor instead.
    pub fn new(config: SimilarityConfig) -> Self {
        SimilaritySearchEngine {
            config,
            analyzer: Box::new(StandardAnalyzer::new().unwrap()),
            document_vectors: HashMap::new(),
        }
    }

    /// Creates an engine that tokenizes text with the supplied `analyzer`.
    pub fn with_analyzer(config: SimilarityConfig, analyzer: Box<dyn Analyzer>) -> Self {
        SimilaritySearchEngine {
            config,
            analyzer,
            document_vectors: HashMap::new(),
        }
    }

    /// Finds documents similar to `target_doc_id`.
    ///
    /// Builds (and caches) feature vectors for the target and every candidate,
    /// scores each candidate with the configured algorithm, keeps scores at or
    /// above `min_similarity`, and returns the top `max_results` sorted by
    /// descending similarity. The target document itself is skipped.
    pub fn find_similar(
        &mut self,
        target_doc_id: u32,
        reader: &dyn LexicalIndexReader,
    ) -> Result<Vec<SimilarityResult>> {
        let target_vector = self.get_or_create_document_vector(target_doc_id, reader)?;
        let candidate_docs = self.get_candidate_documents(reader)?;
        let mut results = Vec::new();
        for candidate_doc_id in candidate_docs {
            // Never report the target as similar to itself.
            if candidate_doc_id == target_doc_id {
                continue;
            }
            let candidate_vector = self.get_or_create_document_vector(candidate_doc_id, reader)?;
            let similarity = self.calculate_similarity(&target_vector, &candidate_vector)?;
            if similarity >= self.config.min_similarity {
                results.push(SimilarityResult::new(candidate_doc_id, similarity));
            }
        }
        // Highest similarity first; NaN scores (if any) compare as equal.
        results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(Ordering::Equal)
        });
        results.truncate(self.config.max_results);
        Ok(results)
    }

    /// Finds documents similar to free-form `text`.
    ///
    /// The query text is vectorized with doc id 0 (not cached), then scored
    /// against every candidate the same way as `find_similar`.
    pub fn find_similar_to_text(
        &mut self,
        text: &str,
        reader: &dyn LexicalIndexReader,
    ) -> Result<Vec<SimilarityResult>> {
        let query_vector = self.create_vector_from_text(text, 0)?;
        let candidate_docs = self.get_candidate_documents(reader)?;
        let mut results = Vec::new();
        for candidate_doc_id in candidate_docs {
            let candidate_vector = self.get_or_create_document_vector(candidate_doc_id, reader)?;
            let similarity = self.calculate_similarity(&query_vector, &candidate_vector)?;
            if similarity >= self.config.min_similarity {
                results.push(SimilarityResult::new(candidate_doc_id, similarity));
            }
        }
        results.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(Ordering::Equal)
        });
        results.truncate(self.config.max_results);
        Ok(results)
    }

    /// Returns the cached vector for `doc_id`, building and caching it first
    /// if necessary. Returns a clone because the cache keeps ownership.
    fn get_or_create_document_vector(
        &mut self,
        doc_id: u32,
        reader: &dyn LexicalIndexReader,
    ) -> Result<DocumentVector> {
        if let Some(vector) = self.document_vectors.get(&doc_id) {
            return Ok(vector.clone());
        }
        let vector = self.create_document_vector(doc_id, reader)?;
        self.document_vectors.insert(doc_id, vector.clone());
        Ok(vector)
    }

    /// Builds a term-frequency vector from the configured fields of `doc_id`.
    /// Fields that fail to load are silently skipped.
    fn create_document_vector(
        &self,
        doc_id: u32,
        reader: &dyn LexicalIndexReader,
    ) -> Result<DocumentVector> {
        let mut vector = DocumentVector::new(doc_id);
        for field_name in &self.config.similarity_fields {
            if let Ok(field_text) = self.get_document_field_text(doc_id, field_name, reader) {
                self.add_text_to_vector(&mut vector, &field_text)?;
            }
        }
        vector.calculate_magnitude();
        if self.config.normalize_vectors {
            vector.normalize();
        }
        Ok(vector)
    }

    /// Builds a term-frequency vector directly from `text` (used for query
    /// text; `doc_id` is just a label and is 0 in practice).
    fn create_vector_from_text(&self, text: &str, doc_id: u32) -> Result<DocumentVector> {
        let mut vector = DocumentVector::new(doc_id);
        self.add_text_to_vector(&mut vector, text)?;
        vector.calculate_magnitude();
        if self.config.normalize_vectors {
            vector.normalize();
        }
        Ok(vector)
    }

    /// Tokenizes `text`, lowercases each token, and adds the per-term counts
    /// onto any weights already present in `vector`.
    fn add_text_to_vector(&self, vector: &mut DocumentVector, text: &str) -> Result<()> {
        let tokens = self.analyzer.analyze(text)?;
        let mut term_counts: HashMap<String, f32> = HashMap::new();
        for token in tokens {
            *term_counts.entry(token.text.to_lowercase()).or_insert(0.0) += 1.0;
        }
        for (term, count) in term_counts {
            // Accumulate across multiple fields rather than overwrite.
            let current_weight = vector.features.get(&term).unwrap_or(&0.0);
            vector.set_feature(term, current_weight + count);
        }
        Ok(())
    }

    /// Alternative vector builder that weights terms by TF-IDF instead of raw
    /// counts. Currently unused (kept for experimentation).
    #[allow(dead_code)]
    fn create_tfidf_vector(
        &self,
        doc_id: u32,
        reader: &dyn LexicalIndexReader,
    ) -> Result<DocumentVector> {
        let mut vector = DocumentVector::new(doc_id);
        for field_name in &self.config.similarity_fields {
            if let Ok(field_text) = self.get_document_field_text(doc_id, field_name, reader) {
                self.add_tfidf_terms_to_vector(&mut vector, &field_text, reader)?;
            }
        }
        vector.calculate_magnitude();
        if self.config.normalize_vectors {
            vector.normalize();
        }
        Ok(vector)
    }

    /// Adds TF-IDF weights for the terms of `text` onto `vector`.
    ///
    /// NOTE(review): "doc length" here is the number of DISTINCT terms, not
    /// the total token count — confirm that is the intended TF normalization.
    /// IDF uses the heuristic `estimate_document_frequency`, not real index
    /// statistics.
    #[allow(dead_code)]
    fn add_tfidf_terms_to_vector(
        &self,
        vector: &mut DocumentVector,
        text: &str,
        reader: &dyn LexicalIndexReader,
    ) -> Result<()> {
        let tokens = self.analyzer.analyze(text)?;
        let mut term_counts: HashMap<String, f32> = HashMap::new();
        for token in tokens {
            *term_counts.entry(token.text.to_lowercase()).or_insert(0.0) += 1.0;
        }
        let total_doc_count = reader.doc_count() as f32;
        let doc_length = term_counts.len() as f32;
        for (term, tf) in term_counts {
            let tf_normalized = tf / doc_length;
            let df = self
                .estimate_document_frequency(&term, reader)
                .unwrap_or(1.0);
            let idf = if total_doc_count > 0.0 {
                (total_doc_count / df).ln() + 1.0
            } else {
                1.0
            };
            let tfidf_weight = tf_normalized * idf;
            let current_weight = vector.features.get(&term).unwrap_or(&0.0);
            vector.set_feature(term, current_weight + tfidf_weight);
        }
        Ok(())
    }

    /// Heuristic document frequency: shorter terms are assumed more common,
    /// and known stop words are treated as very common. The reader is not
    /// consulted (hence `_reader`).
    #[allow(dead_code)]
    fn estimate_document_frequency(
        &self,
        term: &str,
        _reader: &dyn LexicalIndexReader,
    ) -> Result<f32> {
        let df = match term.len() {
            1..=3 => 100.0,
            4..=6 => 50.0,
            7..=10 => 20.0,
            _ => 5.0,
        };
        let stop_words = [
            "the", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        ];
        if stop_words.contains(&term) {
            Ok(500.0)
        } else {
            Ok(df)
        }
    }

    /// Loads `field_name` of document `doc_id` and stringifies its value.
    ///
    /// Missing documents/fields yield an empty string. Bytes are decoded only
    /// when tagged `text/plain`.
    /// NOTE(review): the `Err` arm swallows the reader error and fabricates
    /// placeholder content — confirm this is intentional (it will be tokenized
    /// into the vector as if it were real field text).
    fn get_document_field_text(
        &self,
        doc_id: u32,
        field_name: &str,
        reader: &dyn LexicalIndexReader,
    ) -> Result<String> {
        match reader.document(doc_id as u64) {
            Ok(Some(document)) => {
                if let Some(field_value) = document.get_field(field_name) {
                    match field_value {
                        FieldValue::Text(text) => Ok(text.clone()),
                        FieldValue::Int64(value) => Ok(value.to_string()),
                        FieldValue::Float64(value) => Ok(value.to_string()),
                        FieldValue::Bool(value) => Ok(value.to_string()),
                        FieldValue::Bytes(data, mime) => {
                            if mime.as_deref() == Some("text/plain") {
                                Ok(String::from_utf8_lossy(data).to_string())
                            } else {
                                Ok(String::new())
                            }
                        }
                        FieldValue::DateTime(dt) => Ok(dt.to_rfc3339()),
                        FieldValue::Geo(lat, lon) => Ok(format!("{},{}", lat, lon)),
                        FieldValue::Null => Ok(String::new()),
                        FieldValue::Vector(v) => Ok(format!("[vector: dim={}]", v.len())),
                    }
                } else {
                    Ok(String::new())
                }
            }
            Ok(None) => {
                Ok(String::new())
            }
            Err(_) => {
                Ok(format!("Document {doc_id} field {field_name}"))
            }
        }
    }

    /// Enumerates candidate doc ids, capped at 10_000 to bound work.
    /// NOTE(review): assumes doc ids are dense and 1-based — TODO confirm
    /// against the reader's id space (id 0 and deleted docs are not handled).
    fn get_candidate_documents(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<u32>> {
        let doc_count = reader.doc_count() as u32;
        let max_candidates = std::cmp::min(doc_count, 10000);
        Ok((1..=max_candidates).collect())
    }

    /// Dispatches to the measure selected in the config.
    fn calculate_similarity(
        &self,
        vector1: &DocumentVector,
        vector2: &DocumentVector,
    ) -> Result<f32> {
        match self.config.algorithm {
            SimilarityAlgorithm::Cosine => self.cosine_similarity(vector1, vector2),
            SimilarityAlgorithm::Jaccard => self.jaccard_similarity(vector1, vector2),
            SimilarityAlgorithm::Euclidean => self.euclidean_similarity(vector1, vector2),
            SimilarityAlgorithm::Manhattan => self.manhattan_similarity(vector1, vector2),
            SimilarityAlgorithm::TermFrequency => self.term_frequency_similarity(vector1, vector2),
            SimilarityAlgorithm::BM25 => self.bm25_similarity(vector1, vector2),
        }
    }

    /// Cosine similarity. Magnitudes are recomputed here rather than read
    /// from the cached `magnitude` field, so the result is correct whether or
    /// not the vectors were pre-normalized.
    fn cosine_similarity(&self, vector1: &DocumentVector, vector2: &DocumentVector) -> Result<f32> {
        if vector1.features.is_empty() || vector2.features.is_empty() {
            return Ok(0.0);
        }
        let mut dot_product = 0.0;
        let mut magnitude1 = 0.0;
        let mut magnitude2 = 0.0;
        let all_terms: HashSet<_> = vector1
            .features
            .keys()
            .chain(vector2.features.keys())
            .collect();
        for term in all_terms {
            let weight1 = vector1.features.get(term).unwrap_or(&0.0);
            let weight2 = vector2.features.get(term).unwrap_or(&0.0);
            dot_product += weight1 * weight2;
            magnitude1 += weight1 * weight1;
            magnitude2 += weight2 * weight2;
        }
        if magnitude1 == 0.0 || magnitude2 == 0.0 {
            return Ok(0.0);
        }
        Ok(dot_product / (magnitude1.sqrt() * magnitude2.sqrt()))
    }

    /// Jaccard index over term SETS; weights are ignored.
    fn jaccard_similarity(
        &self,
        vector1: &DocumentVector,
        vector2: &DocumentVector,
    ) -> Result<f32> {
        let terms1 = vector1.terms();
        let terms2 = vector2.terms();
        let intersection = terms1.intersection(&terms2).count();
        let union = terms1.union(&terms2).count();
        if union == 0 {
            Ok(0.0)
        } else {
            Ok(intersection as f32 / union as f32)
        }
    }

    /// L2 distance mapped into (0, 1] via 1 / (1 + distance).
    fn euclidean_similarity(
        &self,
        vector1: &DocumentVector,
        vector2: &DocumentVector,
    ) -> Result<f32> {
        let all_terms: HashSet<_> = vector1
            .features
            .keys()
            .chain(vector2.features.keys())
            .collect();
        let mut sum_squared_diff = 0.0;
        for term in all_terms {
            let weight1 = vector1.features.get(term).unwrap_or(&0.0);
            let weight2 = vector2.features.get(term).unwrap_or(&0.0);
            let diff = weight1 - weight2;
            sum_squared_diff += diff * diff;
        }
        let distance = sum_squared_diff.sqrt();
        Ok(1.0 / (1.0 + distance))
    }

    /// L1 distance mapped into (0, 1] via 1 / (1 + distance).
    fn manhattan_similarity(
        &self,
        vector1: &DocumentVector,
        vector2: &DocumentVector,
    ) -> Result<f32> {
        let all_terms: HashSet<_> = vector1
            .features
            .keys()
            .chain(vector2.features.keys())
            .collect();
        let mut sum_abs_diff = 0.0;
        for term in all_terms {
            let weight1 = vector1.features.get(term).unwrap_or(&0.0);
            let weight2 = vector2.features.get(term).unwrap_or(&0.0);
            sum_abs_diff += (weight1 - weight2).abs();
        }
        Ok(1.0 / (1.0 + sum_abs_diff))
    }

    /// Generalized (weighted) Jaccard: sum of per-term min weights divided by
    /// sum of per-term max weights.
    fn term_frequency_similarity(
        &self,
        vector1: &DocumentVector,
        vector2: &DocumentVector,
    ) -> Result<f32> {
        let mut shared_weight = 0.0;
        let mut total_weight = 0.0;
        let all_terms: HashSet<_> = vector1
            .features
            .keys()
            .chain(vector2.features.keys())
            .collect();
        for term in all_terms {
            let weight1 = vector1.features.get(term).unwrap_or(&0.0);
            let weight2 = vector2.features.get(term).unwrap_or(&0.0);
            shared_weight += weight1.min(*weight2);
            total_weight += weight1.max(*weight2);
        }
        if total_weight == 0.0 {
            Ok(0.0)
        } else {
            Ok(shared_weight / total_weight)
        }
    }

    /// Symmetric BM25-style score: each shared term contributes the product
    /// of its saturated term frequencies in both vectors. There is no IDF
    /// component, so this is not full BM25.
    fn bm25_similarity(&self, vector1: &DocumentVector, vector2: &DocumentVector) -> Result<f32> {
        // Standard BM25 constants: k1 controls TF saturation, b controls
        // length normalization.
        let k1 = 1.2;
        let b = 0.75;
        let mut score = 0.0;
        // "Length" here is the sum of feature weights of each vector.
        let doc_len1 = vector1.features.values().sum::<f32>();
        let doc_len2 = vector2.features.values().sum::<f32>();
        let avg_doc_len = (doc_len1 + doc_len2) / 2.0;
        if avg_doc_len == 0.0 {
            return Ok(0.0);
        }
        for (term, &tf1) in &vector1.features {
            if let Some(&tf2) = vector2.features.get(term) {
                let norm_tf1 = tf1 / (tf1 + k1 * (1.0 - b + b * doc_len1 / avg_doc_len));
                let norm_tf2 = tf2 / (tf2 + k1 * (1.0 - b + b * doc_len2 / avg_doc_len));
                score += norm_tf1 * norm_tf2;
            }
        }
        Ok(score)
    }

    /// Drops every cached document vector.
    pub fn clear_cache(&mut self) {
        self.document_vectors.clear();
    }

    /// Returns `(cached document count, total cached feature count)`.
    pub fn cache_stats(&self) -> (usize, usize) {
        let cached_docs = self.document_vectors.len();
        let total_features: usize = self
            .document_vectors
            .values()
            .map(|v| v.features.len())
            .sum();
        (cached_docs, total_features)
    }
}
/// Query that ranks documents by similarity to a seed (text or document(s)).
#[derive(Debug)]
pub struct MoreLikeThisQuery {
    // Similarity configuration; `exact_match_boost` doubles as the query boost.
    config: SimilarityConfig,
    // What the search is seeded with.
    input: MoreLikeThisInput,
}
/// Seed for a `MoreLikeThisQuery`.
#[derive(Debug, Clone)]
pub enum MoreLikeThisInput {
    /// Free-form text to vectorize and compare against the index.
    Text(String),
    /// An existing document to use as the similarity target.
    DocumentId(u32),
    /// Several documents; `execute` currently uses only the first one.
    DocumentIds(Vec<u32>),
}
impl MoreLikeThisQuery {
    /// Builds a query seeded with free-form `text`.
    pub fn from_text(text: String, config: SimilarityConfig) -> Self {
        Self {
            input: MoreLikeThisInput::Text(text),
            config,
        }
    }

    /// Builds a query seeded with an existing document.
    pub fn from_document(doc_id: u32, config: SimilarityConfig) -> Self {
        Self {
            input: MoreLikeThisInput::DocumentId(doc_id),
            config,
        }
    }

    /// Builds a query seeded with several documents (only the first is used).
    pub fn from_documents(doc_ids: Vec<u32>, config: SimilarityConfig) -> Self {
        Self {
            input: MoreLikeThisInput::DocumentIds(doc_ids),
            config,
        }
    }

    /// Runs the similarity search against `reader` using a fresh engine.
    pub fn execute(&self, reader: &dyn LexicalIndexReader) -> Result<Vec<SimilarityResult>> {
        let mut engine = SimilaritySearchEngine::new(self.config.clone());
        match &self.input {
            MoreLikeThisInput::Text(text) => engine.find_similar_to_text(text, reader),
            MoreLikeThisInput::DocumentId(doc_id) => engine.find_similar(*doc_id, reader),
            MoreLikeThisInput::DocumentIds(doc_ids) => match doc_ids.first() {
                Some(&first_doc) => engine.find_similar(first_doc, reader),
                None => Ok(Vec::new()),
            },
        }
    }

    /// Sets the boost factor (builder style).
    pub fn with_boost(mut self, boost: f32) -> Self {
        self.config.exact_match_boost = boost;
        self
    }

    /// Sets the minimum similarity threshold (builder style).
    pub fn min_similarity(mut self, min_similarity: f32) -> Self {
        self.config.min_similarity = min_similarity;
        self
    }

    /// Sets the maximum number of results (builder style).
    pub fn max_results(mut self, max_results: usize) -> Self {
        self.config.max_results = max_results;
        self
    }

    /// Sets the fields used when building document vectors (builder style).
    pub fn similarity_fields(mut self, fields: Vec<String>) -> Self {
        self.config.similarity_fields = fields;
        self
    }

    /// Sets the similarity algorithm (builder style).
    pub fn algorithm(mut self, algorithm: SimilarityAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }
}
impl Query for MoreLikeThisQuery {
    /// Executes the similarity search and exposes the hits as a matcher.
    fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
        let results = self.execute(reader)?;
        Ok(Box::new(MoreLikeThisMatcher::new(results)))
    }

    /// Executes the similarity search and exposes the hits as a scorer.
    fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
        let results = self.execute(reader)?;
        Ok(Box::new(MoreLikeThisScorer::new(results)))
    }

    fn boost(&self) -> f32 {
        self.config.exact_match_boost
    }

    fn set_boost(&mut self, boost: f32) {
        self.config.exact_match_boost = boost;
    }

    fn clone_box(&self) -> Box<dyn Query> {
        Box::new(MoreLikeThisQuery {
            config: self.config.clone(),
            input: self.input.clone(),
        })
    }

    /// Human-readable description of the query.
    ///
    /// Bug fix: the `Text` and `DocumentId` arms previously emitted an extra
    /// trailing `)` (e.g. `MoreLikeThis(doc: 42))`), producing unbalanced
    /// parentheses.
    fn description(&self) -> String {
        match &self.input {
            MoreLikeThisInput::Text(text) => format!("MoreLikeThis(text: \"{text}\")"),
            MoreLikeThisInput::DocumentId(doc_id) => format!("MoreLikeThis(doc: {doc_id})"),
            MoreLikeThisInput::DocumentIds(doc_ids) => format!("MoreLikeThis(docs: {doc_ids:?})"),
        }
    }

    /// A query is empty when its seed carries no information.
    fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
        match &self.input {
            MoreLikeThisInput::Text(text) => Ok(text.trim().is_empty()),
            MoreLikeThisInput::DocumentId(_) => Ok(false),
            MoreLikeThisInput::DocumentIds(doc_ids) => Ok(doc_ids.is_empty()),
        }
    }

    /// Rough cost estimate: every document may need to be vectorized and
    /// scored. `saturating_mul` guards against overflow on pathological
    /// document counts.
    fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
        let doc_count = reader.doc_count();
        Ok(doc_count.saturating_mul(10))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Matcher that iterates similarity hits in ascending doc-id order.
#[derive(Debug)]
pub struct MoreLikeThisMatcher {
    // Hits sorted by doc id in `new`.
    results: Vec<SimilarityResult>,
    // Index of the next result to visit.
    current_index: usize,
    // Doc id returned by `doc_id()`; 0 until the first successful `next`.
    current_doc_id: u64,
}
impl MoreLikeThisMatcher {
pub fn new(mut results: Vec<SimilarityResult>) -> Self {
results.sort_by_key(|r| r.doc_id);
MoreLikeThisMatcher {
results,
current_index: 0,
current_doc_id: 0,
}
}
}
impl Matcher for MoreLikeThisMatcher {
    fn doc_id(&self) -> u64 {
        self.current_doc_id
    }

    /// Advances to the next hit; returns false once all hits are consumed.
    fn next(&mut self) -> Result<bool> {
        match self.results.get(self.current_index) {
            Some(hit) => {
                self.current_doc_id = hit.doc_id as u64;
                self.current_index += 1;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Advances to the first hit with doc id >= `target`.
    fn skip_to(&mut self, target: u64) -> Result<bool> {
        while let Some(hit) = self.results.get(self.current_index) {
            let candidate = hit.doc_id as u64;
            self.current_index += 1;
            if candidate >= target {
                self.current_doc_id = candidate;
                return Ok(true);
            }
        }
        Ok(false)
    }

    fn cost(&self) -> u64 {
        self.results.len() as u64
    }

    fn is_exhausted(&self) -> bool {
        self.current_index >= self.results.len()
    }
}
/// Scorer that looks up precomputed similarity scores per document.
#[derive(Debug)]
pub struct MoreLikeThisScorer {
    // doc id -> similarity score from the executed search.
    similarity_scores: HashMap<u32, f32>,
    // Multiplier applied to every score (default 1.0).
    boost: f32,
}
impl MoreLikeThisScorer {
pub fn new(results: Vec<SimilarityResult>) -> Self {
let mut similarity_scores = HashMap::new();
for result in results {
similarity_scores.insert(result.doc_id, result.similarity);
}
MoreLikeThisScorer {
similarity_scores,
boost: 1.0,
}
}
pub fn set_boost(&mut self, boost: f32) {
self.boost = boost;
}
}
impl Scorer for MoreLikeThisScorer {
    /// Returns the boosted similarity for `doc_id`, or 0.0 for unknown docs.
    /// Term frequency and field length are ignored: scores are precomputed.
    fn score(&self, doc_id: u64, _term_freq: f32, _field_length: Option<f32>) -> f32 {
        let base = self
            .similarity_scores
            .get(&(doc_id as u32))
            .copied()
            .unwrap_or(0.0);
        base * self.boost
    }

    fn boost(&self) -> f32 {
        self.boost
    }

    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Largest boosted score; 0.0 when there are no hits.
    fn max_score(&self) -> f32 {
        let mut best = 0.0_f32;
        for &score in self.similarity_scores.values() {
            best = best.max(score);
        }
        best * self.boost
    }

    fn name(&self) -> &'static str {
        "MoreLikeThisScorer"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config carries the documented values.
    #[test]
    fn test_similarity_config() {
        let config = SimilarityConfig::default();
        assert_eq!(config.algorithm, SimilarityAlgorithm::Cosine);
        assert_eq!(config.min_similarity, 0.1);
        assert_eq!(config.max_results, 20);
        assert!(config.normalize_vectors);
    }

    // Magnitude is the L2 norm; normalize scales it to 1 and keeps terms.
    #[test]
    fn test_document_vector() {
        let mut vector = DocumentVector::new(1);
        vector.set_feature("term1".to_string(), 2.0);
        vector.set_feature("term2".to_string(), 3.0);
        vector.calculate_magnitude();
        assert!((vector.magnitude - (4.0 + 9.0_f32).sqrt()).abs() < 1e-6);
        vector.normalize();
        assert!((vector.magnitude - 1.0).abs() < 1e-6);
        let terms = vector.terms();
        assert_eq!(terms.len(), 2);
        assert!(terms.contains("term1"));
        assert!(terms.contains("term2"));
    }

    // Builder attaches the optional explanation.
    #[test]
    fn test_similarity_result() {
        let result =
            SimilarityResult::new(123, 0.85).with_explanation("High cosine similarity".to_string());
        assert_eq!(result.doc_id, 123);
        assert_eq!(result.similarity, 0.85);
        assert_eq!(
            result.explanation,
            Some("High cosine similarity".to_string())
        );
    }

    // Identical vectors score 1.0; orthogonal vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let config = SimilarityConfig::default();
        let engine = SimilaritySearchEngine::new(config);
        let mut vector1 = DocumentVector::new(1);
        vector1.set_feature("term1".to_string(), 1.0);
        vector1.set_feature("term2".to_string(), 0.0);
        let mut vector2 = DocumentVector::new(2);
        vector2.set_feature("term1".to_string(), 1.0);
        vector2.set_feature("term2".to_string(), 0.0);
        let similarity = engine.cosine_similarity(&vector1, &vector2).unwrap();
        assert!((similarity - 1.0).abs() < 1e-6);
        vector2.set_feature("term1".to_string(), 0.0);
        vector2.set_feature("term2".to_string(), 1.0);
        let similarity = engine.cosine_similarity(&vector1, &vector2).unwrap();
        assert!((similarity - 0.0).abs() < 1e-6);
    }

    // One shared term out of three total -> 1/3.
    #[test]
    fn test_jaccard_similarity() {
        let config = SimilarityConfig::default();
        let engine = SimilaritySearchEngine::new(config);
        let mut vector1 = DocumentVector::new(1);
        vector1.set_feature("term1".to_string(), 1.0);
        vector1.set_feature("term2".to_string(), 1.0);
        let mut vector2 = DocumentVector::new(2);
        vector2.set_feature("term1".to_string(), 1.0);
        vector2.set_feature("term3".to_string(), 1.0);
        let similarity = engine.jaccard_similarity(&vector1, &vector2).unwrap();
        assert!((similarity - 1.0 / 3.0).abs() < 1e-6);
    }

    // Each constructor stores the matching input variant.
    #[test]
    fn test_more_like_this_query() {
        let config = SimilarityConfig::default();
        let text_query = MoreLikeThisQuery::from_text("sample text".to_string(), config.clone());
        match text_query.input {
            MoreLikeThisInput::Text(ref text) => assert_eq!(text, "sample text"),
            _ => panic!("Expected text input"),
        }
        let doc_query = MoreLikeThisQuery::from_document(42, config.clone());
        match doc_query.input {
            MoreLikeThisInput::DocumentId(doc_id) => assert_eq!(doc_id, 42),
            _ => panic!("Expected document ID input"),
        }
        let docs_query = MoreLikeThisQuery::from_documents(vec![1, 2, 3], config);
        match docs_query.input {
            MoreLikeThisInput::DocumentIds(ref doc_ids) => assert_eq!(doc_ids, &vec![1, 2, 3]),
            _ => panic!("Expected document IDs input"),
        }
    }

    // Builder methods write through to the config.
    #[test]
    fn test_more_like_this_query_builder() {
        let config = SimilarityConfig::default();
        let query = MoreLikeThisQuery::from_text("test text".to_string(), config)
            .with_boost(2.0)
            .min_similarity(0.3)
            .max_results(10)
            .algorithm(SimilarityAlgorithm::Jaccard);
        assert_eq!(query.boost(), 2.0);
        assert_eq!(query.config.min_similarity, 0.3);
        assert_eq!(query.config.max_results, 10);
        assert_eq!(query.config.algorithm, SimilarityAlgorithm::Jaccard);
    }

    // Matcher yields hits sorted by doc id regardless of input order.
    #[test]
    fn test_more_like_this_matcher() {
        let results = vec![
            SimilarityResult::new(3, 0.8),
            SimilarityResult::new(1, 0.9),
            SimilarityResult::new(5, 0.7),
        ];
        let mut matcher = MoreLikeThisMatcher::new(results);
        assert!(matcher.next().unwrap());
        assert_eq!(matcher.doc_id(), 1);
        assert!(matcher.next().unwrap());
        assert_eq!(matcher.doc_id(), 3);
        assert!(matcher.next().unwrap());
        assert_eq!(matcher.doc_id(), 5);
        assert!(!matcher.next().unwrap());
    }

    // Scores are multiplied by the boost; unknown docs score 0.
    #[test]
    fn test_more_like_this_scorer() {
        let results = vec![
            SimilarityResult::new(1, 0.9),
            SimilarityResult::new(2, 0.8),
            SimilarityResult::new(3, 0.7),
        ];
        let mut scorer = MoreLikeThisScorer::new(results);
        scorer.set_boost(2.0);
        assert_eq!(scorer.score(1, 1.0, None), 0.9 * 2.0);
        assert_eq!(scorer.score(2, 1.0, None), 0.8 * 2.0);
        assert_eq!(scorer.score(3, 1.0, None), 0.7 * 2.0);
        assert_eq!(scorer.score(999, 1.0, None), 0.0);
        assert_eq!(scorer.max_score(), 0.9 * 2.0);
    }

    // Heuristic DF: stop words very common, long rare words very uncommon.
    #[test]
    fn test_tfidf_estimation() {
        let config = SimilarityConfig::default();
        let engine = SimilaritySearchEngine::new(config);
        assert!(
            engine
                .estimate_document_frequency("the", &MockIndexReader)
                .unwrap()
                > 100.0
        );
        assert!(
            engine
                .estimate_document_frequency("specialized", &MockIndexReader)
                .unwrap()
                < 50.0
        );
        assert!(
            engine
                .estimate_document_frequency("antidisestablishmentarianism", &MockIndexReader)
                .unwrap()
                < 10.0
        );
    }

    // Minimal reader stub: fixed doc count, no documents, no postings.
    #[derive(Debug)]
    struct MockIndexReader;
    impl LexicalIndexReader for MockIndexReader {
        fn doc_count(&self) -> u64 {
            1000
        }
        fn max_doc(&self) -> u64 {
            1000
        }
        fn is_deleted(&self, _doc_id: u64) -> bool {
            false
        }
        fn document(
            &self,
            _doc_id: u64,
        ) -> Result<Option<crate::lexical::core::document::Document>> {
            Ok(None)
        }
        fn term_info(
            &self,
            _field: &str,
            _term: &str,
        ) -> Result<Option<crate::lexical::reader::ReaderTermInfo>> {
            Ok(None)
        }
        fn postings(
            &self,
            _field: &str,
            _term: &str,
        ) -> Result<Option<Box<dyn crate::lexical::reader::PostingIterator>>> {
            Ok(None)
        }
        fn field_stats(&self, _field: &str) -> Result<Option<crate::lexical::reader::FieldStats>> {
            Ok(None)
        }
        fn is_closed(&self) -> bool {
            false
        }
        fn close(&mut self) -> Result<()> {
            Ok(())
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }
}