use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use turbovault_core::prelude::*;
use turbovault_parser::to_plain_text;
use turbovault_vault::VaultManager;
use crate::search_engine::is_stopword;
/// Precomputed TF-IDF vector for one indexed vault document.
struct DocumentVector {
    // Vault-relative path of the source file (prefix stripped in `new`).
    path: PathBuf,
    // First heading's text, or the file stem when the note has no headings.
    title: String,
    // First line of the note's plain text, truncated to 200 chars.
    preview: String,
    // Sparse TF-IDF weights keyed by token (unigrams and "_"-joined bigrams).
    tfidf: HashMap<String, f64>,
    // Cached L2 norm of `tfidf`, used as the denominator in cosine similarity.
    norm: f64,
}
/// One scored match returned by `semantic_search` / `find_similar_notes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// Vault-relative path of the matched note (lossy UTF-8 conversion).
    pub path: String,
    /// First heading of the note, or its file stem when no heading exists.
    pub title: String,
    /// Cosine similarity score, rounded to 4 decimal places.
    pub score: f64,
    /// Up to 10 terms shared with the query, highest document weight first.
    pub shared_terms: Vec<String>,
    /// First line of the note's plain text (max 200 chars).
    pub preview: String,
}
/// TF-IDF similarity engine over a whole vault.
///
/// All vectors are built once in `new`; every query afterwards is a pure
/// in-memory ranking over `documents`.
pub struct SimilarityEngine {
    // Unused after construction (hence the allow); retained so the engine
    // keeps the vault alive and for possible future re-indexing.
    #[allow(dead_code)]
    manager: Arc<VaultManager>,
    // One precomputed vector per successfully parsed, non-empty document.
    documents: Vec<DocumentVector>,
    // Inverse document frequency per term; shared by document and query vectors.
    idf: HashMap<String, f64>,
    // Corpus size used when computing `idf`; currently only kept for reference.
    #[allow(dead_code)]
    doc_count: usize,
}
impl SimilarityEngine {
    /// Builds the engine: scans the vault, tokenizes every parseable file,
    /// computes per-term inverse document frequencies, and precomputes a
    /// TF-IDF vector (with cached L2 norm) for each document.
    ///
    /// # Errors
    /// Returns an error only if the initial vault scan fails. Files that
    /// fail to parse are skipped (best-effort indexing).
    pub async fn new(manager: Arc<VaultManager>) -> Result<Self> {
        let files = manager.scan_vault().await?;
        let vault_path = manager.vault_path().clone();
        // Document frequency: in how many parsed documents each term appears.
        let mut doc_freq: HashMap<String, usize> = HashMap::new();
        // (relative path, title, preview, raw term counts) per parsed file.
        let mut parsed_docs: Vec<(PathBuf, String, String, HashMap<String, usize>)> =
            Vec::with_capacity(files.len());
        for file_path in &files {
            // Best-effort: unparseable files are excluded from the corpus.
            if let Ok(vault_file) = manager.parse_file(file_path).await {
                let plain = to_plain_text(&vault_file.content);
                let mut term_counts: HashMap<String, usize> = HashMap::new();
                // Consume the token vector so each token String is moved,
                // not cloned, into the counts map.
                for token in tokenize(&plain) {
                    *term_counts.entry(token).or_insert(0) += 1;
                }
                for term in term_counts.keys() {
                    *doc_freq.entry(term.clone()).or_insert(0) += 1;
                }
                // Store paths relative to the vault root so results are portable.
                let rel_path = file_path.strip_prefix(&vault_path).unwrap_or(file_path);
                // Title: first heading if present, otherwise the file stem.
                let title = vault_file
                    .headings
                    .first()
                    .map(|h| h.text.clone())
                    .unwrap_or_else(|| {
                        rel_path
                            .file_stem()
                            .unwrap_or_default()
                            .to_string_lossy()
                            .to_string()
                    });
                // Preview: first line of the plain text, capped at 200 chars.
                let preview = plain
                    .lines()
                    .next()
                    .unwrap_or("")
                    .chars()
                    .take(200)
                    .collect();
                parsed_docs.push((rel_path.to_path_buf(), title, preview, term_counts));
            }
        }
        // FIX: the IDF corpus size must be the number of documents that were
        // actually indexed, not the raw scan count. The old code used
        // `files.len()`, which includes files that failed to parse and
        // therefore uniformly inflated every IDF value.
        let doc_count = parsed_docs.len().max(1);
        let idf: HashMap<String, f64> = doc_freq
            .into_iter()
            .map(|(term, count)| (term, (doc_count as f64 / count as f64).ln()))
            .collect();
        let mut documents = Vec::with_capacity(parsed_docs.len());
        for (path, title, preview, term_counts) in parsed_docs {
            let total_terms: usize = term_counts.values().sum();
            // An empty document has no vector and can never match anything.
            if total_terms == 0 {
                continue;
            }
            let mut tfidf: HashMap<String, f64> = HashMap::new();
            let mut norm_sq = 0.0f64;
            for (term, count) in &term_counts {
                let tf = *count as f64 / total_terms as f64;
                let idf_val = idf.get(term).copied().unwrap_or(0.0);
                let weight = tf * idf_val;
                // Terms present in every document (idf == 0) carry no signal;
                // dropping them keeps the vectors sparse.
                if weight > 0.0 {
                    tfidf.insert(term.clone(), weight);
                    norm_sq += weight * weight;
                }
            }
            documents.push(DocumentVector {
                path,
                title,
                preview,
                tfidf,
                norm: norm_sq.sqrt(),
            });
        }
        Ok(Self {
            manager,
            documents,
            idf,
            doc_count,
        })
    }

    /// Ranks all indexed documents by cosine similarity against a free-text
    /// query, returning at most `limit` results.
    ///
    /// Returns an empty vector when the query yields no usable tokens or no
    /// query term exists in the corpus.
    pub fn semantic_search(&self, query: &str, limit: usize) -> Vec<SimilarityResult> {
        let query_tokens = tokenize(query);
        if query_tokens.is_empty() {
            return vec![];
        }
        let total_terms = query_tokens.len();
        let mut query_counts: HashMap<String, usize> = HashMap::new();
        // Consume the tokens: move each String into the map instead of cloning.
        for token in query_tokens {
            *query_counts.entry(token).or_insert(0) += 1;
        }
        // Build the query's TF-IDF vector using the corpus IDF table.
        let mut query_tfidf: HashMap<String, f64> = HashMap::new();
        let mut query_norm_sq = 0.0f64;
        for (term, count) in &query_counts {
            let tf = *count as f64 / total_terms as f64;
            let idf_val = self.idf.get(term).copied().unwrap_or(0.0);
            let weight = tf * idf_val;
            if weight > 0.0 {
                query_tfidf.insert(term.clone(), weight);
                query_norm_sq += weight * weight;
            }
        }
        let query_norm = query_norm_sq.sqrt();
        // Every query term was unknown to the corpus -> nothing can match.
        if query_norm < f64::EPSILON {
            return vec![];
        }
        self.rank_by_similarity(&query_tfidf, query_norm, None, limit)
    }

    /// Finds the notes most similar to the indexed note at `path`
    /// (vault-relative). Returns an empty vector when `path` is not indexed.
    pub fn find_similar_notes(&self, path: &str, limit: usize) -> Vec<SimilarityResult> {
        let target_path = PathBuf::from(path);
        self.documents
            .iter()
            .find(|d| d.path == target_path)
            .map(|doc| self.rank_by_similarity(&doc.tfidf, doc.norm, Some(path), limit))
            .unwrap_or_default()
    }

    /// Scores every indexed document against a prepared TF-IDF vector and
    /// returns the top `limit` matches by cosine similarity.
    ///
    /// `exclude_path` removes the query document itself when ranking
    /// "similar notes" for an existing note.
    fn rank_by_similarity(
        &self,
        query_tfidf: &HashMap<String, f64>,
        query_norm: f64,
        exclude_path: Option<&str>,
        limit: usize,
    ) -> Vec<SimilarityResult> {
        let mut results: Vec<(f64, Vec<String>, &DocumentVector)> = Vec::new();
        for doc in &self.documents {
            // FIX: compare as paths (component-wise), not as strings, so the
            // exclusion matches the same way the PathBuf lookup in
            // `find_similar_notes` does. String comparison could miss paths
            // that are equal as paths but differ textually (e.g. redundant
            // separators), returning the note itself with score 1.0.
            if let Some(excl) = exclude_path
                && doc.path.as_path() == std::path::Path::new(excl)
            {
                continue;
            }
            // Guard the division below; zero-norm vectors are never stored,
            // but stay defensive.
            if doc.norm < f64::EPSILON {
                continue;
            }
            // Sparse dot product, iterating the (usually smaller) query vector.
            let mut dot_product = 0.0f64;
            let mut shared_terms = Vec::new();
            for (term, query_weight) in query_tfidf {
                if let Some(doc_weight) = doc.tfidf.get(term) {
                    dot_product += query_weight * doc_weight;
                    shared_terms.push(term.clone());
                }
            }
            if dot_product > 0.0 {
                let cosine_sim = dot_product / (query_norm * doc.norm);
                // Keep only the 10 most significant shared terms, ordered by
                // descending document-side weight.
                shared_terms.sort_by(|a, b| {
                    let wa = doc.tfidf.get(a).unwrap_or(&0.0);
                    let wb = doc.tfidf.get(b).unwrap_or(&0.0);
                    wb.partial_cmp(wa).unwrap_or(std::cmp::Ordering::Equal)
                });
                shared_terms.truncate(10);
                results.push((cosine_sim, shared_terms, doc));
            }
        }
        // Highest similarity first. partial_cmp is safe here: all scores are
        // finite products/quotients of finite weights.
        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(limit);
        results
            .into_iter()
            .map(|(score, shared_terms, doc)| SimilarityResult {
                path: doc.path.to_string_lossy().to_string(),
                title: doc.title.clone(),
                // Round to 4 decimal places for stable display/serialization.
                score: (score * 10000.0).round() / 10000.0,
                shared_terms,
                preview: doc.preview.clone(),
            })
            .collect()
    }

    /// Number of documents with a non-empty TF-IDF vector in the index.
    pub fn document_count(&self) -> usize {
        self.documents.len()
    }
}
/// Splits `text` into lowercase unigram and bigram tokens.
///
/// Words are split on whitespace and ASCII punctuation; anything shorter
/// than 3 bytes (pre-lowercasing) or matching the stopword list is dropped.
/// Bigrams are formed from adjacent surviving words as `"first_second"`,
/// letting phrase-like queries (e.g. "machine learning") match more
/// precisely than unigrams alone. Unigrams come first in the output,
/// followed by bigrams, matching the original ordering.
pub(crate) fn tokenize(text: &str) -> Vec<String> {
    let words: Vec<String> = text
        .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
        .filter(|w| w.len() >= 3)
        .map(|w| w.to_lowercase())
        .filter(|w| !is_stopword(w))
        .collect();
    // FIX(perf): build the bigrams first, then move `words` into the output
    // instead of cloning every unigram String as the original did.
    let bigrams: Vec<String> = words
        .windows(2)
        .map(|pair| format!("{}_{}", pair[0], pair[1]))
        .collect();
    let mut tokens = words;
    tokens.extend(bigrams);
    tokens
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `tokens` contains `word`.
    fn has(tokens: &[String], word: &str) -> bool {
        tokens.iter().any(|t| t == word)
    }

    #[test]
    fn test_tokenize_basic() {
        let tokens = tokenize("The quick brown fox jumps over the lazy dog");
        for expected in ["quick", "brown", "fox", "over"] {
            assert!(has(&tokens, expected), "missing token: {expected}");
        }
        // "the" is a stopword and must not survive tokenization.
        assert!(!has(&tokens, "the"));
    }

    #[test]
    fn test_tokenize_bigrams() {
        let tokens = tokenize("machine learning algorithms");
        // Both the unigrams and the adjacent-word bigrams must be present.
        for expected in [
            "machine",
            "learning",
            "algorithms",
            "machine_learning",
            "learning_algorithms",
        ] {
            assert!(has(&tokens, expected), "missing token: {expected}");
        }
    }

    #[test]
    fn test_tokenize_filters_short_words() {
        let tokens = tokenize("I am a ok fine yes no do go");
        // Words under 3 bytes are dropped; longer non-stopwords survive.
        assert!(!has(&tokens, "am"));
        assert!(!has(&tokens, "ok"));
        assert!(has(&tokens, "fine"));
    }

    #[test]
    fn test_tokenize_empty() {
        assert!(tokenize("").is_empty());
    }
}