use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use crate::core::{ScopeFilter, ScopeIdentifiers, StoredFact};
use crate::embeddings::EmbeddingBase;
use crate::search::{MemoryKind, MemorySearch, ResultReranker, SearchRequest, SearchResult};
use crate::vector_store::{OutputData, VectorStoreBase};
use crate::Result;
/// Semantic memory search: embeds the query text, looks it up in a vector
/// store, and optionally reranks the hits.
pub struct SemanticSearch {
/// Produces the embedding vector for the query text.
embedder: Arc<dyn EmbeddingBase>,
/// Store that is queried with the query embedding and the merged filters.
vector_store: Arc<dyn VectorStoreBase>,
/// Optional reranker; only consulted when the request asks for reranking.
reranker: Option<Arc<dyn ResultReranker>>,
}
impl SemanticSearch {
    /// Creates a semantic search over the given embedder and vector store.
    ///
    /// `reranker` is optional; without one, hits are never reranked even if a
    /// request asks for it.
    pub fn new(
        embedder: Arc<dyn EmbeddingBase>,
        vector_store: Arc<dyn VectorStoreBase>,
        reranker: Option<Arc<dyn ResultReranker>>,
    ) -> Self {
        Self {
            embedder,
            vector_store,
            reranker,
        }
    }

    /// Builds the filter map sent to the vector store.
    ///
    /// Starts from the filters derived from `scope` and then applies `extra`
    /// on top. Extras are applied last, so on a key collision the extra value
    /// replaces the scope-derived one.
    ///
    /// NOTE(review): this means a caller-supplied filter using a scope key can
    /// override the scope restriction — confirm this is intended.
    fn merge_filters(
        &self,
        scope: &ScopeIdentifiers,
        extra: &Option<HashMap<String, serde_json::Value>>,
    ) -> HashMap<String, serde_json::Value> {
        let mut filters = ScopeFilter::from_scope(scope).to_map();
        if let Some(extra_filters) = extra {
            filters.extend(
                extra_filters
                    .iter()
                    .map(|(key, value)| (key.clone(), value.clone())),
            );
        }
        filters
    }

    /// Converts a raw vector-store hit into a [`StoredFact`].
    ///
    /// * `data` and `hash` fall back to their default value when absent.
    /// * `created_at` / `updated_at` are parsed as RFC 3339 and normalised to
    ///   UTC. An absent or unparseable `created_at` falls back to the current
    ///   time; an absent or unparseable `updated_at` falls back to
    ///   `created_at`, so a record that was never updated does not look
    ///   freshly modified on every search.
    /// * Every payload entry that is not one of the core keys is copied into
    ///   `metadata`.
    /// * `embedding` is always `None`; the hit's score becomes
    ///   `relevance_score`.
    ///
    /// The function awaits nothing; it stays `async` because its caller
    /// awaits it.
    async fn output_to_fact(&self, output: &OutputData) -> StoredFact {
        // Payload keys that are lifted into dedicated `StoredFact` fields and
        // therefore must not be duplicated into `metadata`.
        const CORE_KEYS: [&str; 7] = [
            "data",
            "hash",
            "created_at",
            "updated_at",
            "user_id",
            "agent_id",
            "session_id",
        ];

        // `None` when the key is absent or the value is not valid RFC 3339.
        let parse_timestamp = |key: &str| {
            let raw = output.get_string(key).unwrap_or_default();
            DateTime::parse_from_rfc3339(&raw)
                .ok()
                .map(|parsed| parsed.with_timezone(&Utc))
        };
        let created_at = parse_timestamp("created_at").unwrap_or_else(Utc::now);
        let updated_at = parse_timestamp("updated_at").unwrap_or(created_at);

        let metadata = output
            .payload
            .iter()
            .filter(|(key, _)| !CORE_KEYS.contains(&key.as_str()))
            .map(|(key, value)| (key.clone(), value.clone()))
            .collect();

        StoredFact {
            id: output.id.clone(),
            content: output.get_data().unwrap_or_default(),
            scope: ScopeIdentifiers {
                user: output.get_string("user_id"),
                agent: output.get_string("agent_id"),
                session: output.get_string("session_id"),
            },
            embedding: None,
            created_at,
            updated_at,
            content_hash: output.get_string("hash").unwrap_or_default(),
            metadata,
            relevance_score: output.score,
        }
    }
}
#[async_trait]
impl MemorySearch for SemanticSearch {
    /// Runs a semantic (embedding-based) search for `request`.
    ///
    /// Returns an empty result without touching the embedder or the vector
    /// store when the request is not for [`MemoryKind::Semantic`] or when
    /// `request.limit` is zero.
    ///
    /// Otherwise the query is embedded, the vector store is queried with the
    /// scope filters merged with `request.filters`, and each hit is converted
    /// into a `StoredFact`. If the request asks for reranking and a reranker
    /// is configured, twice `request.limit` hits are fetched so the reranker
    /// has extra candidates to choose from. The result is always truncated to
    /// `request.limit`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the embedder, the vector store, and the
    /// reranker.
    async fn search(&self, request: SearchRequest) -> Result<SearchResult> {
        // Nothing can be returned for a zero limit, so skip the embedding
        // call and the store round-trip entirely.
        if request.kind != MemoryKind::Semantic || request.limit == 0 {
            return Ok(SearchResult::new(Vec::new()));
        }

        // The reranker only runs when the request opts in AND one is configured.
        let reranker = if request.rerank {
            self.reranker.as_ref()
        } else {
            None
        };

        // Over-fetching only pays off when a reranker can reorder the extra
        // candidates; otherwise they would just be truncated away.
        let fetch_limit = if reranker.is_some() {
            request.limit.saturating_mul(2)
        } else {
            request.limit
        };

        let query_embedding = self.embedder.embed(&request.query).await?;
        let filters = self.merge_filters(&request.scope, &request.filters);
        let outputs = self
            .vector_store
            .search("", &query_embedding, fetch_limit, Some(filters))
            .await?;

        let mut facts = Vec::with_capacity(outputs.len());
        for output in &outputs {
            facts.push(self.output_to_fact(output).await);
        }

        if let Some(reranker) = reranker {
            facts = reranker
                .rerank(&request.query, facts, Some(request.limit))
                .await?;
        }

        // Enforce the limit even if the store or reranker returned more.
        facts.truncate(request.limit);
        Ok(SearchResult::new(facts))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::ResultReranker;
    use crate::vector_store::base::Filters;
    use parking_lot::Mutex;
    use tokio_test::block_on;

    /// Embedder stub: returns the same three-element vector for any text.
    struct FakeEmbedder;

    #[async_trait]
    impl EmbeddingBase for FakeEmbedder {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let fixed: Vec<f32> = vec![0.1, 0.2, 0.3];
            Ok(fixed)
        }

        fn embedding_dims(&self) -> usize {
            3
        }
    }

    /// Vector-store stub: serves canned hits and remembers the filters of the
    /// most recent `search` call. Every other operation is a no-op.
    #[derive(Clone)]
    struct FakeVectorStore {
        outputs: Vec<OutputData>,
        recorded_filters: Arc<Mutex<Option<Filters>>>,
    }

    impl FakeVectorStore {
        fn new(outputs: Vec<OutputData>) -> Self {
            let recorded_filters = Arc::new(Mutex::new(None));
            Self {
                outputs,
                recorded_filters,
            }
        }
    }

    #[async_trait]
    impl VectorStoreBase for FakeVectorStore {
        async fn create_collection(&self, _name: &str) -> Result<()> {
            Ok(())
        }

        async fn insert(
            &self,
            _vectors: Vec<Vec<f32>>,
            _payloads: Option<Vec<HashMap<String, serde_json::Value>>>,
            _ids: Option<Vec<String>>,
        ) -> Result<()> {
            Ok(())
        }

        /// Records `filters` and returns at most `limit` canned hits, in the
        /// order they were supplied to the constructor.
        async fn search(
            &self,
            _query: &str,
            _vectors: &[f32],
            limit: usize,
            filters: Option<Filters>,
        ) -> Result<Vec<OutputData>> {
            *self.recorded_filters.lock() = filters;
            let hits: Vec<OutputData> = self.outputs.iter().take(limit).cloned().collect();
            Ok(hits)
        }

        async fn delete(&self, _vector_id: &str) -> Result<()> {
            Ok(())
        }

        async fn update(
            &self,
            _vector_id: &str,
            _vector: Option<Vec<f32>>,
            _payload: Option<HashMap<String, serde_json::Value>>,
        ) -> Result<()> {
            Ok(())
        }

        async fn get(&self, _vector_id: &str) -> Result<Option<OutputData>> {
            Ok(None)
        }

        async fn list_collections(&self) -> Result<Vec<String>> {
            Ok(Vec::new())
        }

        async fn delete_collection(&self) -> Result<()> {
            Ok(())
        }

        async fn collection_info(&self) -> Result<serde_json::Value> {
            Ok(serde_json::json!({}))
        }

        async fn list(&self, _filters: Option<Filters>, _limit: usize) -> Result<Vec<OutputData>> {
            Ok(Vec::new())
        }

        async fn reset(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Reranker stub: flags that it was invoked and reverses the candidates.
    struct FakeReranker {
        called: Arc<Mutex<bool>>,
    }

    impl FakeReranker {
        fn new() -> Self {
            let called = Arc::new(Mutex::new(false));
            Self { called }
        }
    }

    #[async_trait]
    impl ResultReranker for FakeReranker {
        async fn rerank(
            &self,
            _query: &str,
            facts: Vec<StoredFact>,
            limit: Option<usize>,
        ) -> Result<Vec<StoredFact>> {
            *self.called.lock() = true;
            // Reverse first, then cap: with no limit every fact is kept.
            let keep = limit.unwrap_or(facts.len());
            let reordered: Vec<StoredFact> = facts.into_iter().rev().take(keep).collect();
            Ok(reordered)
        }
    }

    /// Builds a hit owned by `user1` / `session1` with the given id, content
    /// and score; both timestamps are set to the current time.
    fn make_output(id: &str, content: &str, score: f32) -> OutputData {
        let timestamp = Utc::now().to_rfc3339();
        let payload: HashMap<String, serde_json::Value> = HashMap::from([
            ("data".to_string(), serde_json::json!(content)),
            ("hash".to_string(), serde_json::json!("hash")),
            ("created_at".to_string(), serde_json::json!(timestamp)),
            ("updated_at".to_string(), serde_json::json!(timestamp)),
            ("user_id".to_string(), serde_json::json!("user1")),
            ("session_id".to_string(), serde_json::json!("session1")),
        ]);
        OutputData::new(id.to_string(), Some(score), payload)
    }

    #[test]
    fn semantic_search_merges_filters_and_reranks() {
        let canned_hits = vec![
            make_output("1", "first", 0.5),
            make_output("2", "second", 0.9),
        ];
        let embedder = Arc::new(FakeEmbedder);
        let vector_store = Arc::new(FakeVectorStore::new(canned_hits));
        let reranker = Arc::new(FakeReranker::new());
        let engine = SemanticSearch::new(embedder, vector_store.clone(), Some(reranker.clone()));

        let scope = ScopeIdentifiers::for_user("user1").with_session("session1");
        let extra_filters = HashMap::from([("topic".to_string(), serde_json::json!("rust"))]);
        let request = SearchRequest::new("hello", scope, 1)
            .with_filters(extra_filters)
            .with_rerank(true);

        let result = block_on(engine.search(request)).expect("search works");

        // The reranker reverses the two fetched hits, so "2" survives the limit of 1.
        assert_eq!(result.facts.len(), 1);
        assert_eq!(result.facts[0].id, "2");
        assert!(*reranker.called.lock());

        // Scope-derived and caller-supplied filters must both reach the store.
        let guard = vector_store.recorded_filters.lock();
        let seen = guard.as_ref().expect("filters recorded");
        for (key, expected) in [
            ("user_id", "user1"),
            ("session_id", "session1"),
            ("topic", "rust"),
        ] {
            assert_eq!(seen.get(key), Some(&serde_json::json!(expected)));
        }
    }
}