use crate::domain::bookmark::Bookmark;
use crate::domain::tag::Tag;
use std::collections::{HashMap, HashSet};
/// Parameters of a semantic search request.
#[derive(Clone, Debug)]
pub struct SemanticSearch {
    /// The search query text.
    pub query: String,
    /// Optional cap on the number of results; `None` means no limit was specified.
    pub limit: Option<usize>,
}
/// A bookmark paired with its similarity score for a semantic query.
#[derive(Clone, Debug)]
pub struct SemanticSearchResult {
    /// The matched bookmark.
    pub bookmark: Bookmark,
    /// Similarity score. It is multiplied by 100 when rendered as a percentage,
    /// so it is presumably a fraction in `0.0..=1.0` (confirm with the producer).
    pub similarity: f64,
}
impl SemanticSearch {
    /// Builds a semantic search from anything convertible into a `String`
    /// plus an optional result limit.
    pub fn new(query: impl Into<String>, limit: Option<usize>) -> Self {
        let query = query.into();
        Self { query, limit }
    }
}
impl SemanticSearchResult {
    /// Pairs `bookmark` with its `similarity` score.
    pub fn new(bookmark: Bookmark, similarity: f64) -> Self {
        Self {
            bookmark,
            similarity,
        }
    }

    /// Renders the similarity as a percentage with one decimal place,
    /// e.g. `0.756` becomes `"75.6%"`.
    pub fn similarity_percentage(&self) -> String {
        format!("{:.1}%", self.similarity * 100.0)
    }

    /// Renders the result as a single line:
    /// `<id>: <title> <<url>> (<similarity>%) (default)[ [<tags>]]`.
    ///
    /// A bookmark without an id is shown with id `0`. The tag section is
    /// omitted when the formatted tags are empty after trimming surrounding
    /// commas.
    pub fn display(&self) -> String {
        let id = self.bookmark.id.unwrap_or(0);
        let title = &self.bookmark.title;
        let url = &self.bookmark.url;

        let formatted_tags = self.bookmark.formatted_tags();
        let tags_str = formatted_tags.trim_matches(',');
        let tags_display = if tags_str.is_empty() {
            String::new()
        } else {
            format!(" [{}]", tags_str)
        };

        // `similarity_percentage` already appends the `%` sign, so the
        // template must not add a second one.
        let similarity = self.similarity_percentage();

        // NOTE(review): the literal "(default)" marker is emitted
        // unconditionally; its meaning is not evident from this file — confirm.
        format!(
            "{}: {} <{}> ({}) (default){}",
            id, title, url, similarity, tags_display
        )
    }
}
/// Strategy selector for a [`HybridSearch`].
///
/// NOTE(review): the precise meaning of each variant is decided by the code
/// that consumes it, which is not part of this file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SearchMode {
    /// The default mode; presumably fuses full-text and semantic rankings (confirm).
    Hybrid,
    /// Presumably restricts the search to exact matching (confirm).
    Exact,
}

impl Default for SearchMode {
    /// Returns [`SearchMode::Hybrid`].
    fn default() -> Self {
        SearchMode::Hybrid
    }
}
/// One entry of a ranked result list, used as input to [`RrfFusion::fuse`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RankedResult {
    /// Identifier of the ranked bookmark.
    pub bookmark_id: i32,
    /// Position within its list; lower ranks contribute higher fused scores.
    pub rank: usize,
}
/// Parameters of a hybrid search: a query plus optional tag filters.
///
/// Each tag filter is optional. `apply_tag_filters` treats a filter that is
/// unset or holds an empty set as inactive.
#[derive(Clone, Debug)]
pub struct HybridSearch {
    /// The search query text.
    pub query: String,
    /// Keep only bookmarks for which `matches_all_tags` holds.
    pub tags_all: Option<HashSet<Tag>>,
    /// Drop bookmarks for which `matches_all_tags` holds.
    pub tags_all_not: Option<HashSet<Tag>>,
    /// Keep only bookmarks for which `matches_any_tag` holds.
    pub tags_any: Option<HashSet<Tag>>,
    /// Drop bookmarks for which `matches_any_tag` holds.
    pub tags_any_not: Option<HashSet<Tag>>,
    /// Keep only bookmarks for which `matches_exact_tags` holds.
    pub tags_exact: Option<HashSet<Tag>>,
    /// Keep only bookmarks with at least one tag whose value starts with one
    /// of these prefixes.
    pub tags_prefix: Option<HashSet<Tag>>,
    /// Optional result cap; `effective_limit` falls back to 10 when unset.
    pub limit: Option<usize>,
    /// Search strategy to use.
    pub mode: SearchMode,
}
impl HybridSearch {
pub fn new(query: impl Into<String>) -> Self {
Self {
query: query.into(),
tags_all: None,
tags_all_not: None,
tags_any: None,
tags_any_not: None,
tags_exact: None,
tags_prefix: None,
limit: None,
mode: SearchMode::default(),
}
}
pub fn has_tag_filters(&self) -> bool {
self.tags_all.is_some()
|| self.tags_all_not.is_some()
|| self.tags_any.is_some()
|| self.tags_any_not.is_some()
|| self.tags_exact.is_some()
|| self.tags_prefix.is_some()
}
pub fn effective_limit(&self) -> usize {
self.limit.unwrap_or(10)
}
pub fn apply_tag_filters<'a>(&self, bookmarks: &'a [Bookmark]) -> Vec<&'a Bookmark> {
let mut filtered: Vec<&Bookmark> = bookmarks.iter().collect();
if let Some(tags) = &self.tags_exact {
if !tags.is_empty() {
filtered.retain(|b| b.matches_exact_tags(tags));
}
}
if let Some(tags) = &self.tags_all {
if !tags.is_empty() {
filtered.retain(|b| b.matches_all_tags(tags));
}
}
if let Some(tags) = &self.tags_all_not {
if !tags.is_empty() {
filtered.retain(|b| !b.matches_all_tags(tags));
}
}
if let Some(tags) = &self.tags_any {
if !tags.is_empty() {
filtered.retain(|b| b.matches_any_tag(tags));
}
}
if let Some(tags) = &self.tags_any_not {
if !tags.is_empty() {
filtered.retain(|b| !b.matches_any_tag(tags));
}
}
if let Some(prefixes) = &self.tags_prefix {
if !prefixes.is_empty() {
filtered.retain(|b| {
prefixes.iter().any(|prefix| {
let prefix_str = prefix.value();
b.tags.iter().any(|tag| tag.value().starts_with(prefix_str))
})
});
}
}
filtered
}
}
/// A bookmark paired with the score it received from a hybrid search.
#[derive(Clone, Debug)]
pub struct HybridSearchResult {
    /// The matched bookmark.
    pub bookmark: Bookmark,
    /// Score attached to the bookmark; presumably the fused value produced by
    /// `RrfFusion::fuse` (confirm with the caller).
    pub rrf_score: f64,
}

impl HybridSearchResult {
    /// Pairs `bookmark` with its `rrf_score`.
    pub fn new(bookmark: Bookmark, rrf_score: f64) -> Self {
        HybridSearchResult {
            bookmark,
            rrf_score,
        }
    }
}
/// Combines two ranked result lists into a single scored list.
pub struct RrfFusion;

impl RrfFusion {
    /// Fuses the full-text (`fts_results`) and semantic (`sem_results`)
    /// rankings into one list of `(bookmark_id, score)` pairs.
    ///
    /// Every appearance of a bookmark contributes `1 / (k + rank + 1)` to its
    /// score, so a rank of `0` is worth `1 / (k + 1)` and a bookmark present
    /// in both lists receives the sum of both contributions. A larger `k`
    /// flattens the difference between ranks.
    ///
    /// The result is sorted by descending score and truncated to `limit`
    /// entries. Bookmarks with equal scores are ordered by ascending
    /// `bookmark_id`, so the output, and which tied entries survive the
    /// truncation, is the same on every call.
    pub fn fuse(
        fts_results: &[RankedResult],
        sem_results: &[RankedResult],
        k: f64,
        limit: usize,
    ) -> Vec<(i32, f64)> {
        let mut scores: HashMap<i32, f64> = HashMap::new();
        // Both lists are scored with the same formula; chaining keeps the
        // full-text contribution ahead of the semantic one for shared ids.
        for result in fts_results.iter().chain(sem_results) {
            *scores.entry(result.bookmark_id).or_default() +=
                1.0 / (k + result.rank as f64 + 1.0);
        }

        let mut scored: Vec<(i32, f64)> = scores.into_iter().collect();
        // `HashMap` iteration order is unspecified, so the id tie-break is
        // what makes the ordering deterministic. Ids are unique map keys, so
        // no two entries compare equal and an unstable sort is sufficient.
        scored.sort_unstable_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        scored.truncate(limit);
        scored
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::tag::Tag;
    use crate::util::testing::init_test_env;
    use std::collections::HashSet;

    /// Absolute tolerance used when comparing fused scores.
    const EPSILON: f64 = 1e-10;

    /// Builds a ranked list from `(bookmark_id, rank)` pairs.
    fn ranked(entries: &[(i32, usize)]) -> Vec<RankedResult> {
        entries
            .iter()
            .map(|&(bookmark_id, rank)| RankedResult { bookmark_id, rank })
            .collect()
    }

    /// Asserts that two scores differ by less than [`EPSILON`].
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {} to be within {} of {}",
            actual,
            EPSILON,
            expected
        );
    }

    /// Looks up the fused score of `id`, panicking when it is absent.
    fn score_of(results: &[(i32, f64)], id: i32) -> f64 {
        results
            .iter()
            .find(|(candidate, _)| *candidate == id)
            .unwrap()
            .1
    }

    /// Creates a bookmark tagged `test` with the given embeddable flag.
    fn create_test_bookmark(title: &str, content: &str, has_embedding: bool) -> Bookmark {
        let tags: HashSet<Tag> = std::iter::once(Tag::new("test").unwrap()).collect();
        let mut bookmark =
            Bookmark::new("https://example.com", title, content, tags).unwrap();
        bookmark.set_embeddable(has_embedding);
        bookmark
    }

    #[test]
    fn given_two_ranked_lists_when_fuse_then_boosted_score() {
        let fts = ranked(&[(1, 0), (2, 1)]);
        let sem = ranked(&[(1, 0), (3, 1)]);

        let results = RrfFusion::fuse(&fts, &sem, 60.0, 10);

        // Bookmark 1 leads both lists, so it collects two contributions.
        assert_eq!(results[0].0, 1);
        assert_close(results[0].1, 2.0 / 61.0);

        // Bookmarks 2 and 3 each appear once at rank 1.
        let doc2_score = score_of(&results, 2);
        let doc3_score = score_of(&results, 3);
        assert_close(doc2_score, doc3_score);
        assert_close(doc2_score, 1.0 / 62.0);
    }

    #[test]
    fn given_one_empty_list_when_fuse_then_single_engine_scores() {
        let fts = ranked(&[(1, 0), (2, 1)]);
        let sem = ranked(&[]);

        let results = RrfFusion::fuse(&fts, &sem, 60.0, 10);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1);
        assert_close(results[0].1, 1.0 / 61.0);
        assert_eq!(results[1].0, 2);
        assert_close(results[1].1, 1.0 / 62.0);
    }

    #[test]
    fn given_tied_ranks_when_fuse_then_correct_scores() {
        let fts = ranked(&[(1, 0)]);
        let sem = ranked(&[(2, 0)]);

        let results = RrfFusion::fuse(&fts, &sem, 60.0, 10);

        assert_eq!(results.len(), 2);
        assert_close(results[0].1, results[1].1);
        assert_close(results[0].1, 1.0 / 61.0);
    }

    #[test]
    fn given_k_constant_when_fuse_then_dampening_applied() {
        let fts = ranked(&[(1, 0)]);
        let sem = ranked(&[]);

        let results_k60 = RrfFusion::fuse(&fts, &sem, 60.0, 10);
        let results_k1 = RrfFusion::fuse(&fts, &sem, 1.0, 10);

        // A smaller k yields a larger score for the same rank.
        assert!(results_k1[0].1 > results_k60[0].1);
        assert_close(results_k60[0].1, 1.0 / 61.0);
        assert_close(results_k1[0].1, 1.0 / 2.0);
    }

    #[test]
    fn given_limit_when_fuse_then_truncated() {
        let fts = ranked(&[(1, 0), (2, 1), (3, 2)]);
        let sem = ranked(&[]);

        let results = RrfFusion::fuse(&fts, &sem, 60.0, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1);
        assert_eq!(results[1].0, 2);
    }

    #[test]
    fn given_semantic_search_when_new_then_stores_query_and_limit() {
        let search = SemanticSearch::new("test query", Some(5));

        assert_eq!(search.query, "test query");
        assert_eq!(search.limit, Some(5));
    }

    #[test]
    fn given_semantic_search_when_no_limit_then_limit_is_none() {
        let search = SemanticSearch::new("test query", None);

        assert_eq!(search.query, "test query");
        assert_eq!(search.limit, None);
    }

    #[test]
    fn given_similarity_score_when_format_percentage_then_returns_correct_format() {
        let _ = init_test_env();
        let bookmark = create_test_bookmark("Test", "Content", true);

        let result = SemanticSearchResult::new(bookmark, 0.756);

        assert_eq!(result.similarity_percentage(), "75.6%");
    }
}