use semantic_memory::search::{cosine_similarity, sanitize_fts_query, source_dedup_key};
use semantic_memory::SearchSource;
use semantic_memory::{MemoryConfig, MemoryStore, MockEmbedder, SearchConfig, SearchSourceType};
use tempfile::TempDir;
/// Opens a [`MemoryStore`] rooted in a fresh temporary directory, using a
/// `MockEmbedder::new(768)` embedder and otherwise default configuration.
///
/// The [`TempDir`] handle is returned together with the store so the caller
/// controls how long the backing directory stays around.
fn test_store() -> (MemoryStore, TempDir) {
    let dir = TempDir::new().unwrap();
    let base_dir = dir.path().to_path_buf();
    let config = MemoryConfig {
        base_dir,
        ..Default::default()
    };
    let store = MemoryStore::open_with_embedder(config, Box::new(MockEmbedder::new(768))).unwrap();
    (store, dir)
}
/// A vector compared against itself is expected to have a cosine similarity
/// within 0.001 of 1.0.
#[test]
fn cosine_identical_vectors() {
    let vector = vec![1.0, 2.0, 3.0];
    let similarity = cosine_similarity(&vector, &vector);
    let deviation = (similarity - 1.0).abs();
    assert!(
        deviation < 0.001,
        "Identical vectors should have similarity ~1.0, got {}",
        similarity
    );
}
/// Two unit vectors along different axes are expected to have a cosine
/// similarity within 0.001 of 0.0.
#[test]
fn cosine_orthogonal_vectors() {
    let x_axis = vec![1.0, 0.0, 0.0];
    let y_axis = vec![0.0, 1.0, 0.0];
    let similarity = cosine_similarity(&x_axis, &y_axis);
    let magnitude = similarity.abs();
    assert!(
        magnitude < 0.001,
        "Orthogonal vectors should have similarity ~0.0, got {}",
        similarity
    );
}
/// A vector and its negation are expected to have a cosine similarity within
/// 0.001 of -1.0.
#[test]
fn cosine_opposite_vectors() {
    let forward = vec![1.0, 2.0, 3.0];
    let backward = vec![-1.0, -2.0, -3.0];
    let similarity = cosine_similarity(&forward, &backward);
    let deviation = (similarity + 1.0).abs();
    assert!(
        deviation < 0.001,
        "Opposite vectors should have similarity ~-1.0, got {}",
        similarity
    );
}
/// Comparing against an all-zero vector is expected to return exactly 0.0.
#[test]
fn cosine_zero_vector() {
    let nonzero = vec![1.0, 2.0, 3.0];
    let zero = vec![0.0; 3];
    assert_eq!(
        cosine_similarity(&nonzero, &zero),
        0.0,
        "Zero vector should return 0.0 similarity"
    );
}
/// Embedded quotes and a bare `+` are expected to disappear, leaving each
/// remaining word quoted and joined with `OR`.
#[test]
fn sanitize_strips_fts_operators() {
    let expected = Some(String::from(r#""hello" OR "world" OR "test""#));
    assert_eq!(sanitize_fts_query("hello \"world\" + test"), expected);
}
/// A query made only of the characters `"*+-()^{}~:` is expected to sanitize
/// to `None`.
#[test]
fn sanitize_empty_after_stripping() {
    assert_eq!(sanitize_fts_query("\"*+-()^{}~:"), None);
}
/// A plain two-word query is expected to come back with each word quoted and
/// joined with `OR`.
#[test]
fn sanitize_normal_query_unchanged() {
    let sanitized = sanitize_fts_query("hello world");
    assert_eq!(sanitized.as_deref(), Some("\"hello\" OR \"world\""));
}
/// Non-ASCII (CJK) words are expected to survive sanitization unchanged, each
/// quoted and joined with `OR`.
#[test]
fn sanitize_unicode_preserved() {
    let sanitized = sanitize_fts_query("中文 搜索");
    assert_eq!(sanitized.as_deref(), Some("\"中文\" OR \"搜索\""));
}
/// The empty string is expected to sanitize to `None`.
#[test]
fn sanitize_empty_string() {
    let sanitized = sanitize_fts_query("");
    assert_eq!(sanitized, None);
}
/// A whitespace-only query is expected to sanitize to `None`.
#[test]
fn sanitize_only_whitespace() {
    let sanitized = sanitize_fts_query(" ");
    assert_eq!(sanitized, None);
}
/// A trailing `?` is expected to be dropped from a conversational query.
#[test]
fn sanitize_question_mark_in_chat() {
    let sanitized = sanitize_fts_query("how are you?");
    assert_eq!(
        sanitized.as_deref(),
        Some("\"how\" OR \"are\" OR \"you\"")
    );
}
/// A longer question ending in `?` is expected to yield every word, quoted and
/// OR-joined, with the question mark removed.
#[test]
fn sanitize_question_mark_mid_sentence() {
    let expected = Some(String::from(
        r#""what" OR "did" OR "i" OR "say" OR "about" OR "rust""#,
    ));
    assert_eq!(sanitize_fts_query("what did i say about rust?"), expected);
}
/// A dot inside a token is expected to split it: `llama3.1` becomes the two
/// terms `llama3` and `1`.
#[test]
fn sanitize_version_number_with_dot() {
    let sanitized = sanitize_fts_query("llama3.1");
    assert_eq!(sanitized.as_deref(), Some("\"llama3\" OR \"1\""));
}
/// Double quotes embedded in the query are expected to be removed while the
/// quoted word itself is kept as a term.
#[test]
fn sanitize_quotes() {
    let sanitized = sanitize_fts_query(r#"he said "hello" to me"#);
    assert_eq!(
        sanitized.as_deref(),
        Some("\"he\" OR \"said\" OR \"hello\" OR \"to\" OR \"me\"")
    );
}
/// Parentheses and commas are expected to act as separators, leaving the
/// function name and both arguments as terms.
#[test]
fn sanitize_parentheses() {
    let sanitized = sanitize_fts_query("function(arg1, arg2)");
    assert_eq!(
        sanitized.as_deref(),
        Some("\"function\" OR \"arg1\" OR \"arg2\"")
    );
}
/// Colons and dashes are expected to split tokens into separate terms.
#[test]
fn sanitize_colons_and_dashes() {
    let expected = Some(String::from(r#""key" OR "value" OR "foo" OR "bar""#));
    assert_eq!(sanitize_fts_query("key:value foo-bar"), expected);
}
/// Forward slashes are expected to split a path-like query into its segments.
#[test]
fn sanitize_slashes() {
    let sanitized = sanitize_fts_query("path/to/file");
    assert_eq!(
        sanitized.as_deref(),
        Some("\"path\" OR \"to\" OR \"file\"")
    );
}
/// Ellipses, `?!` and parentheses mixed into one query are expected to be
/// stripped, leaving only the three words.
#[test]
fn sanitize_mixed_punctuation() {
    let sanitized = sanitize_fts_query("wait... what?! (really?)");
    assert_eq!(
        sanitized.as_deref(),
        Some("\"wait\" OR \"what\" OR \"really\"")
    );
}
/// A query consisting solely of punctuation is expected to sanitize to `None`.
#[test]
fn sanitize_only_punctuation() {
    let sanitized = sanitize_fts_query("?!@#$%^&*()");
    assert_eq!(sanitized, None);
}
/// Underscores are expected to be kept inside a term, so an identifier stays
/// a single quoted token.
#[test]
fn sanitize_underscores_preserved() {
    let sanitized = sanitize_fts_query("my_variable");
    assert_eq!(sanitized.as_deref(), Some("\"my_variable\""));
}
/// Two message sources sharing `message_id` 7 and role `"user"` but belonging
/// to different sessions are expected to produce different dedup keys.
#[test]
fn message_dedup_key_includes_session_scope() {
    // Only the session id varies between the two sources.
    let message_in = |session: &str| SearchSource::Message {
        message_id: 7,
        session_id: session.to_string(),
        role: "user".to_string(),
    };
    let first = message_in("session-a");
    let second = message_in("session-b");
    assert_ne!(source_dedup_key(&first), source_dedup_key(&second));
}
/// Smoke test: an FTS-only search for a query ending in `?` must return `Ok`.
/// The contents of the result set are deliberately not inspected.
#[tokio::test]
async fn fts_search_with_question_mark_does_not_crash() {
    let (store, _tmp) = test_store();
    store
        .add_fact("general", "I am doing well today", None, None)
        .await
        .unwrap();

    // Only success of the call matters; the hits themselves are unused.
    let _hits = store
        .search_fts_only("how are you?", None, None, None)
        .await
        .unwrap();
}
/// Runs a batch of punctuation-heavy queries through FTS-only search and
/// requires each one to return `Ok`; result contents are not checked.
#[tokio::test]
async fn fts_search_with_assorted_punctuation() {
    let (store, _tmp) = test_store();
    store
        .add_fact("general", "Rust is a systems language", None, None)
        .await
        .unwrap();

    for query in [
        "what did i say about rust?",
        "llama3.1",
        r#"he said "hello""#,
        "function(arg)",
        "key:value",
        "path/to/file",
        "wait... what?! (really?)",
    ] {
        let outcome = store.search_fts_only(query, None, None, None).await;
        assert!(
            outcome.is_ok(),
            "Query {:?} should not error: {:?}",
            query,
            outcome.err()
        );
    }
}
/// Checks the ordering produced by `rrf_fuse` for two overlapping hit lists.
///
/// BM25 hits are supplied in the order A, B, C and vector hits in the order
/// B, D, A. With the default `SearchConfig` and a limit of 10, the test
/// expects four fused results ordered B, A, D, C, with the first result
/// scoring strictly higher than the second.
#[test]
fn rrf_fusion_order() {
    use semantic_memory::search::{rrf_fuse, Bm25Hit, VectorHit};
    use semantic_memory::{SearchConfig, SearchSource};

    // Every hit in this test is a fact in the "test" namespace.
    let fact_source = |id: &str| SearchSource::Fact {
        fact_id: id.to_string(),
        namespace: "test".to_string(),
    };
    // Lexical hit whose content is "content <id>".
    let bm25 = |id: &str, raw_score| Bm25Hit {
        id: id.to_string(),
        content: format!("content {}", id),
        source: fact_source(id),
        raw_score,
        updated_at: None,
    };
    // Vector hit whose content is "content <id>"; never reranked from f32.
    let vector = |id: &str, rank, similarity, source_similarity| VectorHit {
        id: id.to_string(),
        content: format!("content {}", id),
        source: fact_source(id),
        similarity,
        updated_at: None,
        source_rank: Some(rank),
        source_similarity: Some(source_similarity),
        reranked_from_f32: false,
    };

    let bm25_hits = vec![bm25("A", 0.1), bm25("B", 0.2), bm25("C", 0.3)];
    let vector_hits = vec![
        vector("B", 1, 0.9, 0.9),
        vector("D", 2, 0.8, 0.8),
        vector("A", 3, 0.7, 0.7),
    ];

    let config = SearchConfig::default();
    let fused = rrf_fuse(&bm25_hits, &vector_hits, &config, 10);
    assert_eq!(fused.len(), 4);

    // Reduce each fused result to the identifier carried by its source.
    let fused_ids: Vec<String> = fused
        .iter()
        .map(|hit| match &hit.source {
            SearchSource::Fact { fact_id, .. } => fact_id.clone(),
            SearchSource::Chunk { chunk_id, .. } => chunk_id.clone(),
            SearchSource::Message { message_id, .. } => message_id.to_string(),
            SearchSource::Episode { document_id, .. } => document_id.clone(),
            SearchSource::Projection { projection_id, .. } => projection_id.clone(),
        })
        .collect();
    assert_eq!(fused_ids, ["B", "A", "D", "C"]);
    assert!(fused[0].score > fused[1].score);
}
/// Stores three facts and expects the hybrid `search` entry point to return
/// at least one result for "systems programming".
#[tokio::test]
async fn hybrid_search_finds_facts() {
    let (store, _tmp) = test_store();
    for fact in [
        "Rust is a systems programming language",
        "Python is great for data science",
        "JavaScript runs in browsers",
    ] {
        store.add_fact("general", fact, None, None).await.unwrap();
    }

    let hits = store
        .search("systems programming", None, None, None)
        .await
        .unwrap();
    assert!(!hits.is_empty(), "Hybrid search should return results");
}
/// Stores two facts and expects an FTS-only search for "Rust systems" to
/// return results whose first entry mentions "Rust".
#[tokio::test]
async fn fts_only_search() {
    let (store, _tmp) = test_store();
    for fact in [
        "Rust is a systems programming language",
        "Python is great for data science",
    ] {
        store.add_fact("general", fact, None, None).await.unwrap();
    }

    let hits = store
        .search_fts_only("Rust systems", None, None, None)
        .await
        .unwrap();
    assert!(!hits.is_empty());
    assert!(hits[0].content.contains("Rust"));
}
/// Stores one "dogs" fact in each of `ns_a` and `ns_b`, then searches with the
/// namespace filter set to `ns_a` only.
///
/// Exactly one result is expected, and it must be the `ns_a` fact. Checking
/// the content as well as the count ensures the test cannot pass when the
/// filter returns the fact from the wrong namespace.
#[tokio::test]
async fn search_with_namespace_filter() {
    let (store, _tmp) = test_store();
    store
        .add_fact("ns_a", "Fact in namespace A about dogs", None, None)
        .await
        .unwrap();
    store
        .add_fact("ns_b", "Fact in namespace B about dogs", None, None)
        .await
        .unwrap();
    let results = store
        .search_fts_only("dogs", None, Some(&["ns_a"]), None)
        .await
        .unwrap();
    assert_eq!(results.len(), 1, "Should only find fact in namespace A");
    // A count of one is not enough: confirm the surviving hit is the ns_a fact.
    assert!(
        results[0].content.contains("namespace A"),
        "Filtered result should be the namespace A fact, got {:?}",
        results[0].content
    );
}
/// Stores a single fact and searches for it twice: restricted to
/// `SearchSourceType::Facts` (expects a hit) and restricted to
/// `SearchSourceType::Chunks` (expects nothing).
#[tokio::test]
async fn search_with_source_type_filter() {
    let (store, _tmp) = test_store();
    store
        .add_fact("general", "This is a fact about quantum physics", None, None)
        .await
        .unwrap();

    let fact_hits = store
        .search_fts_only("quantum physics", None, None, Some(&[SearchSourceType::Facts]))
        .await
        .unwrap();
    assert!(!fact_hits.is_empty());

    let chunk_hits = store
        .search_fts_only("quantum physics", None, None, Some(&[SearchSourceType::Chunks]))
        .await
        .unwrap();
    assert!(chunk_hits.is_empty());
}
/// An empty query string is expected to succeed and return no results, even
/// though the store contains a fact.
#[tokio::test]
async fn empty_query_returns_empty_results() {
    let (store, _tmp) = test_store();
    store
        .add_fact("general", "Some content", None, None)
        .await
        .unwrap();

    let hits = store
        .search_fts_only("", None, None, None)
        .await
        .unwrap();
    assert!(hits.is_empty());
}
/// A query made only of the characters `"*+-()^{}~:` is expected to succeed
/// and return no results.
#[tokio::test]
async fn special_chars_only_query_returns_empty() {
    let (store, _tmp) = test_store();
    store
        .add_fact("general", "Some content", None, None)
        .await
        .unwrap();

    let query = "\"*+-()^{}~:";
    let hits = store
        .search_fts_only(query, None, None, None)
        .await
        .unwrap();
    assert!(hits.is_empty());
}
/// Exercises namespace filtering with a namespace that contains a single
/// quote (`it's-a-test`).
///
/// Three "cats" facts are stored in three namespaces. Filtering by the quoted
/// namespace must return only its fact, and filtering by `safe` must return
/// only the `safe` fact (not the one in `also-safe`).
#[tokio::test]
async fn parameterized_namespace_adversarial() {
    let (store, _tmp) = test_store();
    for (namespace, content) in [
        ("safe", "Safe fact about cats"),
        ("also-safe", "Also safe fact about cats"),
        ("it's-a-test", "Adversarial namespace fact about cats"),
    ] {
        store
            .add_fact(namespace, content, None, None)
            .await
            .unwrap();
    }

    let adversarial_hits = store
        .search_fts_only("cats", None, Some(&["it's-a-test"]), None)
        .await
        .unwrap();
    assert_eq!(
        adversarial_hits.len(),
        1,
        "Should find fact in adversarial namespace"
    );
    assert!(adversarial_hits[0].content.contains("Adversarial"));

    let safe_hits = store
        .search_fts_only("cats", None, Some(&["safe"]), None)
        .await
        .unwrap();
    assert_eq!(
        safe_hits.len(),
        1,
        "Should only find fact in 'safe' namespace"
    );
    assert!(safe_hits[0].content.contains("Safe fact"));
}
/// Stores the same text once as a fact and once as an ingested document, then
/// runs a hybrid search.
///
/// Despite the test's name, the assertion expects both results to be kept:
/// identical content coming from different source types is not collapsed.
#[tokio::test]
async fn dedup_removes_duplicate_content() {
    let (store, _tmp) = test_store();
    let content = "Rust was released in 2015";
    store
        .add_fact("general", content, None, None)
        .await
        .unwrap();
    store
        .ingest_document("Rust History", content, "general", None, None)
        .await
        .unwrap();

    let hits = store
        .search("Rust released", None, None, None)
        .await
        .unwrap();
    assert_eq!(
        hits.len(),
        2,
        "Should keep results from different source types even with identical content"
    );
}
/// Stores two facts with similar but distinct wording and expects an FTS-only
/// search matching both to return both.
#[tokio::test]
async fn dedup_keeps_different_content() {
    let (store, _tmp) = test_store();
    for fact in [
        "Rust was released as a language in 2015",
        "Go was released as a language in 2009",
    ] {
        store.add_fact("general", fact, None, None).await.unwrap();
    }

    let hits = store
        .search_fts_only("released language", None, None, None)
        .await
        .unwrap();
    assert_eq!(
        hits.len(),
        2,
        "Should keep both results since content is different"
    );
}
/// Opens a [`MemoryStore`] in a fresh temporary directory with the given
/// recency settings applied to its `SearchConfig`; every other setting keeps
/// its default value.
///
/// * `half_life` is stored as `recency_half_life_days`.
/// * `recency_weight` is stored as `recency_weight`.
///
/// The store uses a `MockEmbedder::new(768)` embedder. The [`TempDir`] handle
/// is returned so the caller controls the directory's lifetime.
fn test_store_with_recency(half_life: Option<f64>, recency_weight: f64) -> (MemoryStore, TempDir) {
    let dir = TempDir::new().unwrap();
    let search = SearchConfig {
        recency_half_life_days: half_life,
        recency_weight,
        ..Default::default()
    };
    let config = MemoryConfig {
        base_dir: dir.path().to_path_buf(),
        search,
        ..Default::default()
    };
    let store = MemoryStore::open_with_embedder(config, Box::new(MockEmbedder::new(768))).unwrap();
    (store, dir)
}
/// With no recency half-life configured (`None`), a non-zero recency weight
/// is expected to have no influence: the top FTS-only result must score
/// `1 / (60 + 1)` within a tolerance of 0.0001.
#[tokio::test]
async fn recency_disabled_no_effect() {
    let (store, _tmp) = test_store_with_recency(None, 0.5);
    store
        .add_fact("general", "Recency test fact alpha", None, None)
        .await
        .unwrap();

    let hits = store
        .search_fts_only("Recency test fact", None, None, None)
        .await
        .unwrap();
    assert!(!hits.is_empty());

    // NOTE(review): presumably 60 is the RRF constant and 1 the rank of the
    // only hit; confirm against the default `SearchConfig`.
    let expected_score = 1.0 / (60.0 + 1.0);
    let top_score = hits[0].score;
    let delta = (top_score - expected_score).abs();
    assert!(
        delta < 0.0001,
        "Score should be pure BM25 RRF score without recency, got {} expected {}",
        top_score,
        expected_score
    );
}
/// With a 30-day recency half-life and a recency weight of 0.5, a freshly
/// stored fact is expected to outscore a similar fact whose `updated_at` has
/// been moved 60 days into the past.
///
/// Only compiled with the `testing` feature, which this test relies on for
/// `raw_execute`.
#[cfg(feature = "testing")]
#[tokio::test]
async fn recency_boosts_recent_facts() {
    let (store, _tmp) = test_store_with_recency(Some(30.0), 0.5);
    let fact_a_id = store
        .add_fact(
            "general",
            "Recency quantum computing breakthrough",
            None,
            None,
        )
        .await
        .unwrap();
    let fact_b_id = store
        .add_fact("general", "Recency quantum computing discovery", None, None)
        .await
        .unwrap();

    // Backdate fact B directly in the `facts` table so it looks 60 days old.
    let sixty_days_ago = (chrono::Utc::now() - chrono::Duration::days(60))
        .format("%Y-%m-%d %H:%M:%S")
        .to_string();
    store
        .raw_execute(
            "UPDATE facts SET updated_at = ?1 WHERE id = ?2",
            vec![sixty_days_ago, fact_b_id.clone()],
        )
        .await
        .unwrap();

    let results = store
        .search("quantum computing", None, None, None)
        .await
        .unwrap();
    assert!(results.len() >= 2, "Should find both facts");

    // Score of the fact with the given id, if it appears in the results.
    let score_of = |wanted: &str| {
        results
            .iter()
            .find(|r| matches!(&r.source, SearchSource::Fact { fact_id, .. } if fact_id == wanted))
            .map(|r| r.score)
    };
    // Name the missing fact on failure instead of a bare `unwrap` panic.
    let score_a = score_of(fact_a_id.as_str())
        .expect("recent fact A should be present in the search results");
    let score_b = score_of(fact_b_id.as_str())
        .expect("backdated fact B should be present in the search results");

    assert!(
        score_a > score_b,
        "Recent fact A ({}) should score higher than old fact B ({})",
        score_a,
        score_b
    );
}
/// Opening a store with `recency_half_life_days` set to `Some(0.0)` is
/// expected to fail with an error whose `kind()` is `"invalid_config"`.
#[tokio::test]
async fn recency_zero_half_life_is_rejected() {
    let tmp = TempDir::new().unwrap();
    let search = SearchConfig {
        recency_half_life_days: Some(0.0),
        recency_weight: 0.5,
        ..Default::default()
    };
    let config = MemoryConfig {
        base_dir: tmp.path().to_path_buf(),
        search,
        ..Default::default()
    };
    let embedder = Box::new(MockEmbedder::new(768));

    let outcome = MemoryStore::open_with_embedder(config, embedder);
    let err = if let Err(e) = outcome {
        e
    } else {
        panic!("zero recency half-life should be rejected")
    };
    assert_eq!(err.kind(), "invalid_config");
}
/// Opening a store whose embedding config has `ollama_url` set to
/// `"not a url"` is expected to fail with an error whose `kind()` is
/// `"invalid_config"`.
#[tokio::test]
async fn invalid_ollama_url_is_rejected() {
    let tmp = TempDir::new().unwrap();
    let embedding = semantic_memory::EmbeddingConfig {
        ollama_url: "not a url".to_string(),
        ..Default::default()
    };
    let config = MemoryConfig {
        base_dir: tmp.path().to_path_buf(),
        embedding,
        ..Default::default()
    };
    let embedder = Box::new(MockEmbedder::new(768));

    let outcome = MemoryStore::open_with_embedder(config, embedder);
    let err = if let Err(e) = outcome {
        e
    } else {
        panic!("invalid Ollama URL should be rejected")
    };
    assert_eq!(err.kind(), "invalid_config");
}
/// Stores 100 similar facts, runs a hybrid search, and sanity-checks the
/// result set: it must be non-empty, every score must be finite and
/// non-negative, and scores must be in non-increasing order.
#[tokio::test]
async fn test_vector_search_buffer_reuse_correctness() {
    let (store, _tmp) = test_store();
    for i in 0..100 {
        let content = format!("Buffer reuse test fact number {}", i);
        store
            .add_fact("general", &content, None, None)
            .await
            .unwrap();
    }

    let hits = store
        .search("Buffer reuse test fact", None, None, None)
        .await
        .unwrap();
    assert!(!hits.is_empty(), "Should find facts with buffer reuse");

    for hit in hits.iter() {
        assert!(
            hit.score.is_finite(),
            "Score should be finite, got {}",
            hit.score
        );
        assert!(
            hit.score >= 0.0,
            "Score should be non-negative, got {}",
            hit.score
        );
    }

    // Compare each result with its successor to verify descending order.
    for (earlier, later) in hits.iter().zip(hits.iter().skip(1)) {
        assert!(
            earlier.score >= later.score,
            "Results should be ordered by score descending: {} < {}",
            earlier.score,
            later.score
        );
    }
}
/// Stores 100 similar facts and expects a hybrid search over them to succeed
/// and return at least one result.
#[tokio::test]
async fn test_vector_search_completes_with_many_rows() {
    let (store, _tmp) = test_store();
    for i in 0..100 {
        let content = format!("Row count test fact number {}", i);
        store
            .add_fact("general", &content, None, None)
            .await
            .unwrap();
    }

    let hits = store
        .search("Row count test fact", None, None, None)
        .await
        .unwrap();
    assert!(
        !hits.is_empty(),
        "Search should complete successfully with many rows"
    );
}