use anyhow::{Context, Result};
use chrono::Utc;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::Ordering;
use tracing::{debug, warn};
use uuid::Uuid;
use post_cortex_core::session::active_session::ActiveSession;
use post_cortex_embeddings::SearchMatch;
use super::types::{ContentType, SearchOptions, SemanticSearchResult};
use super::vectorizer::ContentVectorizer;
impl ContentVectorizer {
/// Runs an embedding-based search for `query` and returns the scored matches.
///
/// The query is embedded, the optional query cache is consulted, and on a
/// cache miss the vector database is searched with a filter chosen from the
/// arguments:
///
/// * session + date range: entries whose `source` equals the session id and
///   whose timestamp lies inside the inclusive range;
/// * session only: entries of that session;
/// * no session: a search across all entries (optionally restricted to the
///   date range), but only when `config.enable_cross_session_search` is set.
///   Otherwise an empty list is returned.
///
/// `options.recency_bias` overrides `config.recency_bias` when present.
///
/// # Errors
///
/// Propagates failures from encoding the query, from the vector database
/// and from post-processing the matches. A failure to store the results in
/// the cache is only logged as a warning.
pub async fn semantic_search(
    &self,
    query: &str,
    limit: usize,
    session_filter: Option<Uuid>,
    options: SearchOptions,
) -> Result<Vec<SemanticSearchResult>> {
    let recency_bias = options.recency_bias.unwrap_or(self.config.recency_bias);
    let date_range = options.date_range;
    debug!("Performing semantic search for: '{}'", query);

    // Fingerprint of every parameter that influences the result set; handed
    // to the cache together with the query and its embedding.
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    query.hash(&mut hasher);
    limit.hash(&mut hasher);
    if let Some(ref session_id) = session_filter {
        session_id.hash(&mut hasher);
    }
    if let Some(ref range) = date_range {
        range.0.timestamp().hash(&mut hasher);
        range.1.timestamp().hash(&mut hasher);
    }
    if recency_bias != 0.0 {
        recency_bias.to_bits().hash(&mut hasher);
    }
    let params_hash = hasher.finish();

    // The embedding is needed both for the cache lookup and for the actual
    // search, so it is computed exactly once up front.
    let query_embedding = self.embedding_engine.encode_text(query).await?;
    if let Some(ref cache) = self.query_cache
        && let Some(cached_results) = cache.search(query, &query_embedding, params_hash)
    {
        debug!("Cache hit for semantic search query: '{}'", query);
        return Ok(cached_results);
    }

    let search_results = match (session_filter, date_range) {
        (Some(session_id), range) => {
            // Render the session id once instead of allocating a fresh
            // string for every candidate the filter closure inspects.
            let session_source = session_id.to_string();
            match range {
                Some((start, end)) => self.vector_db.search_with_filter(
                    &query_embedding,
                    limit,
                    |metadata| {
                        metadata.source == session_source
                            && metadata.timestamp >= start
                            && metadata.timestamp <= end
                    },
                )?,
                None => {
                    self.vector_db
                        .search_in_source(&query_embedding, limit, &session_source)?
                }
            }
        }
        (None, Some((start, end))) if self.config.enable_cross_session_search => self
            .vector_db
            .search_with_filter(&query_embedding, limit, |metadata| {
                metadata.timestamp >= start && metadata.timestamp <= end
            })?,
        (None, None) if self.config.enable_cross_session_search => {
            self.vector_db.search(&query_embedding, limit)?
        }
        // No session given and cross-session search disabled: nothing to search.
        _ => return Ok(Vec::new()),
    };

    let results = self.process_search_results(search_results, Some(recency_bias))?;

    // Caching is best effort: a failure must not fail the search itself.
    if let Some(ref cache) = self.query_cache
        && let Err(e) = cache.cache_results(
            query.to_string(),
            query_embedding,
            results.clone(),
            params_hash,
            session_filter,
        )
    {
        warn!("Failed to cache search results: {}", e);
    }
    Ok(results)
}
/// Runs an embedding-based search restricted to the sessions listed in
/// `allowed_sessions`.
///
/// Only entries whose `source` equals the string form of one of the allowed
/// session ids are considered; when `options.date_range` is set the entry's
/// timestamp must additionally lie inside that inclusive range.
/// `options.recency_bias` overrides `config.recency_bias` when present.
///
/// An empty `allowed_sessions` slice yields an empty result without
/// touching the embedding engine, the cache or the vector database.
///
/// # Errors
///
/// Propagates failures from encoding the query, from the vector database
/// and from post-processing the matches. A failure to store the results in
/// the cache is only logged as a warning.
pub async fn semantic_search_multisession(
    &self,
    query: &str,
    limit: usize,
    allowed_sessions: &[Uuid],
    options: SearchOptions,
) -> Result<Vec<SemanticSearchResult>> {
    let recency_bias = options.recency_bias.unwrap_or(self.config.recency_bias);
    let date_range = options.date_range;
    debug!(
        "Performing multisession semantic search with recency_bias={:?} for: '{}'",
        recency_bias, query
    );

    // An empty allow-list can never match an entry. It must also stay away
    // from the cache: with no session ids hashed, the parameter hash below
    // is identical to the one `semantic_search` computes for an unfiltered
    // query, so (assuming the cache keys on that hash) a lookup could hand
    // back cross-session hits and a store could plant an empty list under
    // the unfiltered query's key.
    if allowed_sessions.is_empty() {
        return Ok(Vec::new());
    }

    // Fingerprint of every parameter that influences the result set.
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    query.hash(&mut hasher);
    limit.hash(&mut hasher);
    for session_id in allowed_sessions {
        session_id.hash(&mut hasher);
    }
    if let Some(ref range) = date_range {
        range.0.timestamp().hash(&mut hasher);
        range.1.timestamp().hash(&mut hasher);
    }
    if recency_bias != 0.0 {
        recency_bias.to_bits().hash(&mut hasher);
    }
    let params_hash = hasher.finish();

    // The embedding is needed both for the cache lookup and for the actual
    // search, so it is computed exactly once up front.
    let query_embedding = self.embedding_engine.encode_text(query).await?;
    if let Some(ref cache) = self.query_cache
        && let Some(cached_results) = cache.search(query, &query_embedding, params_hash)
    {
        debug!(
            "Cache hit for multisession semantic search query: '{}'",
            query
        );
        return Ok(cached_results);
    }

    // String forms of the allowed ids, built once so the filter closure can
    // test membership without allocating per candidate.
    let valid_sessions: std::collections::HashSet<String> =
        allowed_sessions.iter().map(|id| id.to_string()).collect();

    let search_results = match date_range {
        Some((start, end)) => {
            self.vector_db
                .search_with_filter(&query_embedding, limit, |metadata| {
                    valid_sessions.contains(&metadata.source)
                        && metadata.timestamp >= start
                        && metadata.timestamp <= end
                })?
        }
        None => self
            .vector_db
            .search_with_filter(&query_embedding, limit, |metadata| {
                valid_sessions.contains(&metadata.source)
            })?,
    };

    let results = self.process_search_results(search_results, Some(recency_bias))?;

    // Caching is best effort: a failure must not fail the search itself.
    // No single session is associated with a multi-session entry.
    if let Some(ref cache) = self.query_cache
        && let Err(e) = cache.cache_results(
            query.to_string(),
            query_embedding,
            results.clone(),
            params_hash,
            None,
        )
    {
        warn!("Failed to cache search results: {}", e);
    }
    Ok(results)
}
/// Turns raw vector-database matches into scored, deduplicated and ordered
/// [`SemanticSearchResult`]s.
///
/// For every match the base score is `similarity * 0.7 + importance * 0.3`,
/// where the importance comes from the match's content type. When the
/// effective recency bias (the override, else `config.recency_bias`,
/// clamped to `0.0..=10.0`) is positive, the base score is additionally
/// passed through a `TemporalDecayAdjuster` anchored at the current time.
///
/// Matches sharing a `content_id` are collapsed to the one with the higher
/// combined score (newer timestamp on equal scores). The result is ordered
/// by descending combined score; ties are broken by newer timestamp and then
/// by `content_id`, and NaN scores are placed last, so the order does not
/// depend on hash-map iteration order.
///
/// # Errors
///
/// Fails if a match's `source` is not a valid UUID.
fn process_search_results(
    &self,
    search_results: Vec<SearchMatch>,
    recency_bias_override: Option<f32>,
) -> Result<Vec<SemanticSearchResult>> {
    let now = Utc::now();
    let recency_bias = recency_bias_override
        .unwrap_or(self.config.recency_bias)
        .clamp(0.0, 10.0);
    // Only build an adjuster when recency actually influences the score.
    let score_adjuster: Option<Box<dyn crate::scoring::ScoreAdjuster>> = if recency_bias > 0.0 {
        Some(Box::new(crate::scoring::TemporalDecayAdjuster::new(
            recency_bias,
            now,
        )))
    } else {
        None
    };

    let decay_start = std::time::Instant::now();
    let mut decay_count = 0;
    let mut results = Vec::with_capacity(search_results.len());
    for result in search_results {
        let content_type = parse_content_type(&result.metadata.content_type);
        let session_id = Uuid::parse_str(&result.metadata.source)
            .context("Invalid session ID in metadata")?;
        let importance_weight = content_type.importance_weight();
        let base_score = result.similarity.mul_add(0.7, importance_weight * 0.3);
        let combined_score = if let Some(ref adjuster) = score_adjuster {
            decay_count += 1;
            adjuster.adjust(base_score, &result.metadata)
        } else {
            base_score
        };
        results.push(SemanticSearchResult {
            content_id: result.metadata.id,
            session_id,
            content_type,
            text_content: result.metadata.text,
            similarity_score: result.similarity,
            importance_score: importance_weight,
            timestamp: result.metadata.timestamp,
            combined_score,
        });
    }

    // Update the recency-bias counters only when an adjustment was applied.
    if decay_count > 0 {
        // Saturate rather than silently truncate the u128 nanosecond count.
        let decay_duration_ns =
            u64::try_from(decay_start.elapsed().as_nanos()).unwrap_or(u64::MAX);
        self.recency_bias_total_duration_ns
            .fetch_add(decay_duration_ns, Ordering::Relaxed);
        self.recency_bias_total_results
            .fetch_add(decay_count, Ordering::Relaxed);
        self.recency_bias_calculation_count
            .fetch_add(1, Ordering::Relaxed);
    }

    // Collapse duplicates by content id, keeping the better-scored entry and
    // preferring the newer one when the scores are equal.
    let original_count = results.len();
    let mut dedup_map: HashMap<String, SemanticSearchResult> =
        HashMap::with_capacity(results.len());
    for result in results {
        let content_id = result.content_id.clone();
        match dedup_map.entry(content_id) {
            std::collections::hash_map::Entry::Occupied(mut entry) => {
                let existing = entry.get();
                if result.combined_score > existing.combined_score
                    || (result.combined_score == existing.combined_score
                        && result.timestamp > existing.timestamp)
                {
                    entry.insert(result);
                }
            }
            std::collections::hash_map::Entry::Vacant(entry) => {
                entry.insert(result);
            }
        }
    }
    let mut results: Vec<SemanticSearchResult> = dedup_map.into_values().collect();
    debug!(
        "Deduplicated {} results to {} unique",
        original_count,
        results.len()
    );

    // The map hands the entries back in arbitrary order, so the comparator
    // has to be total and fully tie-broken to give a reproducible ranking:
    // NaN scores last, then best score first, then newest, then content id.
    results.sort_by(|a, b| {
        a.combined_score
            .is_nan()
            .cmp(&b.combined_score.is_nan())
            .then_with(|| b.combined_score.total_cmp(&a.combined_score))
            .then_with(|| b.timestamp.cmp(&a.timestamp))
            .then_with(|| a.content_id.cmp(&b.content_id))
    });
    debug!("Found {} semantic search results", results.len());
    Ok(results)
}
/// Finds content related to `topic` in sessions other than `session`.
///
/// Performs an unfiltered semantic search for `topic`, asking for twice
/// `limit` matches so that discarding hits from `session` itself still
/// leaves candidates, then returns at most `limit` of the remaining hits in
/// the order the search produced them.
///
/// # Errors
///
/// Propagates any error from the underlying semantic search.
pub async fn find_related_content(
    &self,
    session: &ActiveSession,
    topic: &str,
    limit: usize,
) -> Result<Vec<SemanticSearchResult>> {
    // Saturate instead of overflowing (a panic in debug builds, a wrapped
    // and therefore far too small value in release builds) for huge limits.
    let fetch_limit = limit.saturating_mul(2);
    let results = self
        .semantic_search(topic, fetch_limit, None, SearchOptions::default())
        .await?;

    // Read the id once rather than once per candidate.
    let current_session = session.id();
    let filtered: Vec<_> = results
        .into_iter()
        .filter(|r| r.session_id != current_session)
        .take(limit)
        .collect();
    debug!(
        "Found {} related content items across sessions for topic: '{}'",
        filtered.len(),
        topic
    );
    Ok(filtered)
}
/// Computes the cosine similarity between the embeddings of two texts.
///
/// Both texts are encoded with the embedding engine, `text1` first. The
/// result is the dot product of the two embeddings divided by the product
/// of their Euclidean norms, or `0.0` when either norm is zero.
///
/// # Errors
///
/// Propagates a failure to encode either text.
pub async fn compute_text_similarity(&self, text1: &str, text2: &str) -> Result<f32> {
    let lhs = self.embedding_engine.encode_text(text1).await?;
    let rhs = self.embedding_engine.encode_text(text2).await?;

    let lhs_norm = lhs.iter().map(|v| v * v).sum::<f32>().sqrt();
    let rhs_norm = rhs.iter().map(|v| v * v).sum::<f32>().sqrt();
    let dot: f32 = lhs.iter().zip(rhs.iter()).map(|(l, r)| l * r).sum();

    // A zero-length vector has no direction; report "no similarity" rather
    // than dividing by zero.
    let similarity = if lhs_norm != 0.0 && rhs_norm != 0.0 {
        dot / (lhs_norm * rhs_norm)
    } else {
        0.0
    };
    Ok(similarity)
}
/// Returns a snapshot of vectorization statistics keyed by name.
///
/// The map contains three entries:
/// * `"total_vectors"`: the vector database's `total_vectors` figure;
/// * `"memory_usage_mb"`: the database's `memory_usage_bytes` divided down
///   to whole units of 1024 * 1024 bytes;
/// * `"embedding_dimension"`: the embedding engine's reported dimension.
pub fn get_vectorization_stats(&self) -> HashMap<String, usize> {
    const BYTES_PER_MB: usize = 1024 * 1024;

    let db_stats = self.vector_db.get_stats();
    let memory_usage_mb = db_stats.memory_usage_bytes / BYTES_PER_MB;
    let embedding_dimension = self.embedding_engine.embedding_dimension();

    HashMap::from([
        ("total_vectors".to_string(), db_stats.total_vectors),
        ("memory_usage_mb".to_string(), memory_usage_mb),
        ("embedding_dimension".to_string(), embedding_dimension),
    ])
}
/// Reports whether the vector database holds session-update embeddings for
/// `session_id`.
///
/// The session is looked up by the string form of its UUID; the answer is
/// whatever the database's `has_session_update_embeddings` returns.
pub fn is_session_vectorized(&self, session_id: Uuid) -> bool {
    let session_key = session_id.to_string();
    self.vector_db.has_session_update_embeddings(&session_key)
}
/// Returns the number of embeddings the vector database reports for
/// `session_id`.
///
/// The session is looked up by the string form of its UUID.
pub fn count_session_embeddings(&self, session_id: Uuid) -> usize {
    let session_key = session_id.to_string();
    self.vector_db.count_session_embeddings(&session_key)
}
}
/// Maps the content-type label stored in vector metadata to a [`ContentType`].
///
/// Six labels are recognised by exact, case-sensitive match; every other
/// label falls back to [`ContentType::UpdateContent`].
fn parse_content_type(content_type_str: &str) -> ContentType {
    let recognised = match content_type_str {
        "EntityDescription" => Some(ContentType::EntityDescription),
        "UserMessage" => Some(ContentType::UserMessage),
        "DecisionPoint" => Some(ContentType::DecisionPoint),
        "CodeSnippet" => Some(ContentType::CodeSnippet),
        "ProblemSolution" => Some(ContentType::ProblemSolution),
        "SessionMetadata" => Some(ContentType::SessionMetadata),
        _ => None,
    };
    // Unknown labels are treated as plain update content.
    recognised.unwrap_or(ContentType::UpdateContent)
}