use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use arc_swap::ArcSwap;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use post_cortex_core::session::active_session::ActiveSession;
#[cfg(feature = "embeddings")]
use crate::content_vectorizer::{ContentVectorizer, ContentVectorizerConfig};
#[cfg(feature = "embeddings")]
use post_cortex_embeddings::EmbeddingConfig;
#[cfg(feature = "embeddings")]
use post_cortex_embeddings::VectorDbConfig;
use super::system::ConversationMemorySystem;
const MAX_VECTORIZATION_RETRIES: u32 = 3;
const VECTORIZATION_RETRY_DELAY_MS: u64 = 100;
const MAX_VECTORIZER_INIT_RETRIES: u32 = 10;
const MAX_PARALLEL_VECTORIZATION: usize = 4;
impl ConversationMemorySystem {
#[cfg(feature = "embeddings")]
pub async fn spawn_background_vectorization(
&self,
session_id: Uuid,
session_arc: Arc<ArcSwap<ActiveSession>>,
) {
if !self.config.enable_embeddings || !self.config.auto_vectorize_on_update {
return;
}
let vectorizer = match self.ensure_vectorizer_initialized().await {
Ok(v) => v,
Err(e) => {
debug!("Vectorizer init failed (non-fatal): {}", e);
return;
}
};
let storage_actor = self.storage_actor.clone();
tokio::spawn(async move {
let session = session_arc.load();
match vectorizer.vectorize_latest_update(&session).await {
Ok(count) if count > 0 => {
let _ = vectorizer.invalidate_session_cache(session_id).await;
storage_actor.persist_session_and_update_nowait((**session).clone(), vec![]);
debug!(
"Background vectorization: {} update(s) for session {}",
count, session_id
);
}
Ok(_) => {}
Err(e) => {
debug!(
"Background vectorization failed for session {}: {}",
session_id, e
);
}
}
});
}
#[cfg(feature = "embeddings")]
pub(crate) async fn ensure_vectorizer_initialized(
&self,
) -> Result<Arc<ContentVectorizer>, String> {
if let Some(vectorizer) = self.content_vectorizer.get() {
return Ok(Arc::clone(vectorizer));
}
let attempt = self
.embedding_config_holder
.init_attempt_count
.load(Ordering::Relaxed)
+ 1;
if attempt > MAX_VECTORIZER_INIT_RETRIES as u64 + 1 {
if let Some(last_error) = self.embedding_config_holder.last_init_error.read().as_ref() {
return Err(format!(
"Vectorizer initialization failed after {} attempts. Last error: {}",
attempt - 1,
last_error
));
}
return Err(format!(
"Vectorizer initialization failed after {} attempts",
attempt - 1
));
}
info!(
"Lazy-initializing content vectorizer (attempt {}/{})...",
attempt,
MAX_VECTORIZER_INIT_RETRIES + 1
);
let result: Result<&Arc<ContentVectorizer>, String> = self
.content_vectorizer
.get_or_try_init(|| async {
let embedding_config = EmbeddingConfig {
model_type: self.embedding_config_holder.model_type,
max_batch_size: 32,
..Default::default()
};
let vector_db_config = VectorDbConfig {
dimension: self.embedding_config_holder.vector_dimension,
max_vectors: self.embedding_config_holder.max_vectors_per_session,
..Default::default()
};
let vectorizer_config = ContentVectorizerConfig {
embedding_config,
vector_db_config,
enable_cross_session_search: self
.embedding_config_holder
.cross_session_search_enabled,
..Default::default()
};
let mut vectorizer = ContentVectorizer::new(vectorizer_config)
.await
.map_err(|e| format!("Failed to initialize content vectorizer: {}", e))?;
vectorizer.set_persistent_storage(self.vector_storage.clone());
match vectorizer.load_all_embeddings_from_storage().await {
Ok(count) => {
if count > 0 {
info!("Loaded {} persisted embeddings from storage during initialization", count);
}
}
Err(e) => {
warn!("Failed to load persisted embeddings (non-fatal, will re-vectorize on demand): {}", e);
}
}
Ok(Arc::new(vectorizer))
})
.await;
match result {
Ok(vectorizer) => {
info!(
"Content vectorizer initialized successfully on attempt {}",
attempt
);
*self.embedding_config_holder.last_init_error.write() = None;
Ok(Arc::clone(vectorizer))
}
Err(e) => {
let real_attempt = self
.embedding_config_holder
.init_attempt_count
.fetch_add(1, Ordering::Relaxed)
+ 1;
*self.embedding_config_holder.last_init_error.write() = Some(e.clone());
error!(
"Vectorizer initialization failed on attempt {}: {}",
real_attempt, e
);
Err(e)
}
}
}
#[cfg(feature = "embeddings")]
pub async fn ensure_semantic_engine_initialized(
&self,
) -> Result<Arc<crate::semantic_query_engine::SemanticQueryEngine>, String> {
if let Some(engine) = self.semantic_query_engine.get() {
return Ok(Arc::clone(engine));
}
let vectorizer = self.ensure_vectorizer_initialized().await?;
self.semantic_query_engine
.get_or_try_init(|| async {
info!("Lazy-initializing semantic query engine...");
use crate::semantic_query_engine::{SemanticQueryConfig, SemanticQueryEngine};
let config = SemanticQueryConfig {
cross_session_enabled: self
.embedding_config_holder
.cross_session_search_enabled,
similarity_threshold: self.config.semantic_search_threshold,
..Default::default()
};
let engine = SemanticQueryEngine::new((*vectorizer).clone(), config);
Ok(Arc::new(engine))
})
.await
.map(Arc::clone)
}
#[cfg(feature = "embeddings")]
pub async fn vectorize_session(&self, session_id: Uuid) -> Result<usize, String> {
let _timer = self.performance_monitor.start_timer("vectorize_session");
let vectorizer = self.ensure_vectorizer_initialized().await?;
let session_result = self.get_session(session_id).await?;
let session = session_result.load();
match vectorizer.vectorize_session(&session).await {
Ok(count) => {
info!("Vectorized {} items for session {}", count, session_id);
Ok(count)
}
Err(e) => {
warn!("Failed to vectorize session {session_id}: {e}");
Err(format!("Vectorization failed: {e}"))
}
}
}
#[cfg(feature = "embeddings")]
pub async fn auto_vectorize_if_enabled(&self, session_id: Uuid) -> Result<(), String> {
if !self.config.enable_embeddings || !self.config.auto_vectorize_on_update {
return Ok(());
}
let vectorizer = match self.ensure_vectorizer_initialized().await {
Ok(v) => v,
Err(e) => {
warn!("Failed to initialize vectorizer: {}", e);
return Ok(()); }
};
let session_arc = match self.get_session(session_id).await {
Ok(s) => s,
Err(e) => {
warn!(
"Failed to load session {} for vectorization: {}",
session_id, e
);
return Ok(()); }
};
let session = session_arc.load();
let mut last_error = None;
for attempt in 1..=MAX_VECTORIZATION_RETRIES {
match vectorizer.vectorize_latest_update(&session).await {
Ok(count) => {
info!(
"Incrementally vectorized {} update(s) for session {} (attempt {})",
count, session_id, attempt
);
if count > 0 {
if let Err(e) = vectorizer.invalidate_session_cache(session_id).await {
debug!(
"Cache invalidation for session {} (non-critical): {}",
session_id, e
);
}
self.storage_actor
.persist_session_and_update_nowait((**session).clone(), vec![]);
debug!("Session {} vectorization persist enqueued", session_id);
}
return Ok(());
}
Err(e) => {
last_error = Some(e.to_string());
if attempt < MAX_VECTORIZATION_RETRIES {
let delay_ms = VECTORIZATION_RETRY_DELAY_MS * (1 << (attempt - 1));
debug!(
"Vectorization attempt {} failed for session {}, retrying in {}ms: {}",
attempt, session_id, delay_ms, e
);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
}
}
if let Some(error) = last_error {
warn!(
"Incremental vectorization failed for session {} after {} retries: {}",
session_id, MAX_VECTORIZATION_RETRIES, error
);
}
Ok(()) }
#[cfg(feature = "embeddings")]
pub async fn vectorize_all_sessions(&self) -> Result<(usize, usize, usize), String> {
info!("Starting full vectorization of all sessions (parallel mode)");
let start_time = std::time::Instant::now();
let vectorizer = self.ensure_vectorizer_initialized().await?;
let session_ids = self.list_sessions().await?;
let total_sessions = session_ids.len();
if total_sessions == 0 {
info!("No sessions found to vectorize");
return Ok((0, 0, 0));
}
info!(
"Found {} sessions to vectorize (max {} parallel tasks)",
total_sessions, MAX_PARALLEL_VECTORIZATION
);
let total_vectorized = Arc::new(AtomicUsize::new(0));
let successful_sessions = Arc::new(AtomicUsize::new(0));
let failed_sessions = Arc::new(AtomicUsize::new(0));
let processed_count = Arc::new(AtomicUsize::new(0));
let semaphore = Arc::new(Semaphore::new(MAX_PARALLEL_VECTORIZATION));
let mut handles = Vec::with_capacity(total_sessions);
for session_id in session_ids {
let vectorizer = Arc::clone(&vectorizer);
let semaphore = Arc::clone(&semaphore);
let total_vectorized = Arc::clone(&total_vectorized);
let successful_sessions = Arc::clone(&successful_sessions);
let failed_sessions = Arc::clone(&failed_sessions);
let processed_count = Arc::clone(&processed_count);
let session_data = match self.get_session(session_id).await {
Ok(arc) => Some(arc.load().as_ref().clone()),
Err(e) => {
failed_sessions.fetch_add(1, Ordering::Relaxed);
let count = processed_count.fetch_add(1, Ordering::Relaxed) + 1;
warn!(
"[{}/{}] Failed to load session {}: {}",
count, total_sessions, session_id, e
);
None
}
};
if let Some(session) = session_data {
let handle = tokio::spawn(async move {
let _permit = semaphore.acquire().await.expect("Semaphore closed");
let already_vectorized = vectorizer.is_session_vectorized(session_id);
if already_vectorized {
let existing_count = vectorizer.count_session_embeddings(session_id);
debug!(
"Session {} already has {} embeddings, re-vectorizing...",
session_id, existing_count
);
}
match vectorizer.vectorize_session(&session).await {
Ok(count) => {
total_vectorized.fetch_add(count, Ordering::Relaxed);
successful_sessions.fetch_add(1, Ordering::Relaxed);
let processed = processed_count.fetch_add(1, Ordering::Relaxed) + 1;
info!(
"[{}/{}] Vectorized {} items for session {}",
processed, total_sessions, count, session_id
);
}
Err(e) => {
failed_sessions.fetch_add(1, Ordering::Relaxed);
let processed = processed_count.fetch_add(1, Ordering::Relaxed) + 1;
warn!(
"[{}/{}] Failed to vectorize session {}: {}",
processed, total_sessions, session_id, e
);
}
}
});
handles.push(handle);
}
}
for handle in handles {
let _ = handle.await;
}
if let Err(e) = vectorizer.clear_query_cache().await {
warn!(
"Failed to clear query cache after bulk vectorization: {}",
e
);
}
let elapsed = start_time.elapsed();
let total = total_vectorized.load(Ordering::Relaxed);
let success = successful_sessions.load(Ordering::Relaxed);
let failed = failed_sessions.load(Ordering::Relaxed);
info!(
"Bulk vectorization complete in {:.2}s: {} total items across {} successful sessions ({} failed)",
elapsed.as_secs_f64(),
total,
success,
failed
);
Ok((total, success, failed))
}
#[cfg(feature = "embeddings")]
pub async fn semantic_search_global(
&self,
query: &str,
limit: Option<usize>,
date_range: Option<(chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>,
recency_bias: Option<f32>,
) -> Result<Vec<crate::content_vectorizer::SemanticSearchResult>, String> {
let _timer = self
.performance_monitor
.start_timer("semantic_search_global");
let vectorizer = self.ensure_vectorizer_initialized().await?;
let options = crate::content_vectorizer::SearchOptions {
limit: Some(limit.unwrap_or(20)),
date_range,
recency_bias,
};
match vectorizer
.semantic_search(query, limit.unwrap_or(20), None, options)
.await
{
Ok(results) => Ok(results),
Err(e) => Err(format!("Semantic search failed: {e}")),
}
}
#[cfg(feature = "embeddings")]
pub async fn semantic_search_session(
&self,
session_id: Uuid,
query: &str,
limit: Option<usize>,
date_range: Option<(chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>,
recency_bias: Option<f32>,
) -> Result<Vec<crate::content_vectorizer::SemanticSearchResult>, String> {
let _timer = self
.performance_monitor
.start_timer("semantic_search_session");
let vectorizer = self.ensure_vectorizer_initialized().await?;
let session_arc = self.get_session(session_id).await?;
if !vectorizer.is_session_vectorized(session_id) {
info!(
"Session {} not vectorized, auto-vectorizing before search",
session_id
);
if let Err(e) = self.vectorize_session(session_id).await {
warn!(
"Auto-vectorization failed for session {}: {}",
session_id, e
);
}
}
let options = crate::content_vectorizer::SearchOptions {
limit: Some(limit.unwrap_or(20)),
date_range,
recency_bias,
};
match vectorizer
.semantic_search(query, limit.unwrap_or(20), Some(session_id), options)
.await
{
Ok(results) => {
let session = session_arc.load();
Ok(enrich_results_with_graph(&session, query, results))
}
Err(e) => Err(format!("Session semantic search failed: {e}")),
}
}
#[cfg(feature = "embeddings")]
pub async fn find_related_content(
&self,
session_id: Uuid,
topic: &str,
limit: Option<usize>,
) -> Result<Vec<crate::content_vectorizer::SemanticSearchResult>, String> {
let _timer = self.performance_monitor.start_timer("find_related_content");
let vectorizer = self.ensure_vectorizer_initialized().await?;
let session_result = self.get_session(session_id).await?;
let session = session_result.load();
if !vectorizer.is_session_vectorized(session_id) {
info!(
"Session {} not vectorized, auto-vectorizing before related content search",
session_id
);
if let Err(e) = self.vectorize_session(session_id).await {
warn!(
"Auto-vectorization failed for session {}: {}",
session_id, e
);
}
}
match vectorizer
.find_related_content(&session, topic, limit.unwrap_or(10))
.await
{
Ok(results) => Ok(results),
Err(e) => Err(format!("Related content search failed: {e}")),
}
}
#[cfg(feature = "embeddings")]
pub async fn semantic_search_multisession(
&self,
session_ids: &[Uuid],
query: &str,
limit: Option<usize>,
date_range: Option<(chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>,
recency_bias: Option<f32>,
) -> Result<Vec<crate::content_vectorizer::SemanticSearchResult>, String> {
let _timer = self
.performance_monitor
.start_timer("semantic_search_multisession");
let vectorizer = self.ensure_vectorizer_initialized().await?;
let options = crate::content_vectorizer::SearchOptions {
limit: Some(limit.unwrap_or(20)),
date_range,
recency_bias,
};
match vectorizer
.semantic_search_multisession(query, limit.unwrap_or(20), session_ids, options)
.await
{
Ok(results) => Ok(results),
Err(e) => Err(format!("Multisession semantic search failed: {e}")),
}
}
#[cfg(feature = "embeddings")]
pub fn get_vectorization_stats(
&self,
) -> Result<std::collections::HashMap<String, usize>, String> {
if let Some(vectorizer) = self.content_vectorizer.get() {
Ok(vectorizer.get_vectorization_stats())
} else {
Err("Embeddings not initialized yet (call any vectorization method first)".to_string())
}
}
pub fn embeddings_enabled(&self) -> bool {
self.config.enable_embeddings && cfg!(feature = "embeddings") && {
#[cfg(feature = "embeddings")]
{
self.content_vectorizer.get().is_some()
}
#[cfg(not(feature = "embeddings"))]
{
false
}
}
}
pub async fn enable_embeddings_config(&mut self) -> Result<(), String> {
if !cfg!(feature = "embeddings") {
return Err("Embeddings feature not compiled in".to_string());
}
self.config.enable_embeddings = true;
Ok(())
}
pub async fn set_embedding_model(&mut self, model_type: String) -> Result<(), String> {
self.config.embeddings_model_type = model_type;
Ok(())
}
pub async fn invalidate_and_rebuild_entity_graph(
&self,
session_id: Uuid,
file_path: &str,
) -> Result<(u32, usize), String> {
let entries_invalidated = self.storage_actor.invalidate_source(file_path).await?;
let session_arc = self
.session_manager
.get_or_create_session(session_id)
.await?;
let current = session_arc.load();
let mut new_session = (**current).clone();
let removed = new_session.remove_updates_for_file(file_path);
if removed > 0 {
match new_session.rebuild_entity_graph_from_updates().await {
Ok((before, after)) => {
info!(
"Invalidate+rebuild for {}: {} source refs, {} updates removed, entities {} -> {}",
file_path, entries_invalidated, removed, before, after,
);
let entities_after = after;
let new_arc = Arc::new(new_session);
let prev = session_arc.compare_and_swap(¤t, Arc::clone(&new_arc));
if Arc::ptr_eq(&prev, ¤t) {
self.storage_actor
.persist_session_and_update_nowait((*new_arc).clone(), vec![]);
} else {
warn!(
"CAS failed during invalidate+rebuild for session {}",
session_id
);
}
Ok((entries_invalidated, entities_after))
}
Err(e) => {
warn!("Entity graph rebuild failed after invalidation: {}", e);
Ok((entries_invalidated, 0))
}
}
} else {
debug!(
"No updates reference file {}, skipping entity graph rebuild",
file_path
);
Ok((entries_invalidated, new_session.entity_graph.entity_count()))
}
}
pub async fn clear_query_cache(&self) -> Result<(), String> {
#[cfg(feature = "embeddings")]
{
if let Some(vectorizer) = self.content_vectorizer.get() {
vectorizer
.clear_query_cache()
.await
.map_err(|e| format!("Failed to clear query cache: {}", e))?;
info!("Query cache cleared successfully");
}
}
Ok(())
}
}
#[cfg(feature = "embeddings")]
/// Heuristic entity extraction: splits `text` on whitespace and strips surrounding
/// punctuation from each word.
///
/// Characters other than alphanumerics, `_` and `-` are trimmed from both ends of each
/// word. Words that then have more than three *characters* (not UTF-8 bytes) are kept,
/// lowercased and in their original order. Duplicates are not removed.
fn extract_entities(text: &str) -> Vec<String> {
    text.split_whitespace()
        .filter_map(|word| {
            let clean =
                word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_' && c != '-');
            // Count chars, not bytes: `len()` would let 3-letter non-ASCII words
            // such as "été" (5 bytes) through the length filter.
            (clean.chars().count() > 3).then(|| clean.to_lowercase())
        })
        .collect()
}
#[cfg(feature = "embeddings")]
/// Graph-RAG post-processing: annotates semantic search results with context from the
/// session's entity graph.
///
/// The result text is annotated in three ways:
/// 1. For each distinct entity extracted from `query` that has related entities in the
///    graph, a "[System Knowledge Map]" line is collected. These lines are in sorted
///    entity order, so output is deterministic.
/// 2. Each result's text gets a "(Graph expansion: ...)" suffix. It lists up to 5
///    relations for each of at most two of the result's own entities, skipping entities
///    already covered by the query map.
/// 3. If the first entities of the top two results differ and the graph has a path
///    between them longer than two nodes, a "[Structural Insight]" line is added.
///
/// All collected insights are prepended to the first result's text.
fn enrich_results_with_graph(
    session: &post_cortex_core::session::active_session::ActiveSession,
    query: &str,
    results: Vec<crate::content_vectorizer::SemanticSearchResult>,
) -> Vec<crate::content_vectorizer::SemanticSearchResult> {
    use std::fmt::Write as _;

    let entity_graph = &session.entity_graph;
    tracing::debug!(
        "Graph-RAG: enrichment for {} results (graph has {} entities)",
        results.len(),
        entity_graph.entity_count()
    );

    // Deduplicate query entities so each is looked up once. Use a BTreeMap (not HashMap)
    // so the knowledge-map lines are emitted in a stable order for identical inputs.
    let query_entities: std::collections::BTreeSet<String> =
        extract_entities(query).into_iter().collect();
    let mut global_graph_map: std::collections::BTreeMap<String, Vec<String>> =
        std::collections::BTreeMap::new();
    for q_entity in query_entities {
        let relations = entity_graph.find_related_entities(&q_entity);
        if !relations.is_empty() {
            global_graph_map.insert(q_entity, relations);
        }
    }

    let mut graph_insights = String::new();
    if !global_graph_map.is_empty() {
        graph_insights.push_str("\n[System Knowledge Map]:\n");
        for (entity, rels) in &global_graph_map {
            // Writing to a String cannot fail.
            let _ = writeln!(
                graph_insights,
                "- {} is central to: {}",
                entity,
                rels.join(", ")
            );
        }
    }

    let mut enriched: Vec<_> = results
        .into_iter()
        .map(|mut result| {
            let mut chunk_entities = extract_entities(&result.text_content);
            chunk_entities.sort();
            chunk_entities.dedup();
            // NOTE(review): after sort+dedup this inspects the two lexicographically
            // smallest entities, not the first two in the text — confirm intent.
            let mut local_rels = Vec::new();
            for entity in chunk_entities.iter().take(2) {
                if global_graph_map.contains_key(entity) {
                    continue;
                }
                let relations = entity_graph.find_related_entities(entity);
                if !relations.is_empty() {
                    let limited: Vec<_> = relations.iter().take(5).cloned().collect();
                    local_rels.push(format!("{}: {}", entity, limited.join(", ")));
                }
            }
            if !local_rels.is_empty() {
                result.text_content = format!(
                    "{}\n(Graph expansion: {})",
                    result.text_content,
                    local_rels.join(" | ")
                );
            }
            result
        })
        .collect();

    if enriched.len() >= 2 {
        let top1 = extract_entities(&enriched[0].text_content);
        let top2 = extract_entities(&enriched[1].text_content);
        if let (Some(e1), Some(e2)) = (top1.first(), top2.first())
            && e1 != e2
            && let Some(path) = entity_graph.find_shortest_path(e1, e2)
            && path.len() > 2
        {
            let _ = writeln!(
                graph_insights,
                "\n[Structural Insight]: Found connection: {}",
                path.join(" -> ")
            );
        }
    }

    if !graph_insights.is_empty() && !enriched.is_empty() {
        enriched[0].text_content = format!("{}{}\n---\n", graph_insights, enriched[0].text_content);
    }
    enriched
}