use anyhow::Result;
use chrono::Utc;
use super::git_utils::GitUtils;
use super::store::MemoryStore;
use super::types::{
Memory, MemoryConfig, MemoryMetadata, MemoryQuery, MemoryRelationship, MemorySearchResult,
MemoryType, RelationshipType,
};
use crate::config::Config;
use crate::embedding::{create_embedding_provider_from_parts, parse_provider_model};
pub struct MemoryManager {
store: MemoryStore,
config: MemoryConfig,
}
impl MemoryManager {
pub async fn new(config: &Config) -> Result<Self> {
let memory_config = MemoryConfig::default();
let current_dir = std::env::current_dir()?;
let db_path = crate::storage::get_project_database_path(¤t_dir)?;
let db_path_str = db_path.to_string_lossy().to_string();
let model_string = &config.embedding.text_model;
let (provider, model) = parse_provider_model(model_string);
let embedding_provider = create_embedding_provider_from_parts(&provider, &model).await?;
let store = MemoryStore::new(
&db_path_str,
embedding_provider,
memory_config.clone(),
config.clone(),
)
.await?;
Ok(Self {
store,
config: memory_config,
})
}
pub async fn with_config(config: &Config, memory_config: MemoryConfig) -> Result<Self> {
let current_dir = std::env::current_dir()?;
let db_path = crate::storage::get_project_database_path(¤t_dir)?;
let db_path_str = db_path.to_string_lossy().to_string();
let model_string = &config.embedding.text_model;
let (provider, model) = parse_provider_model(model_string);
let embedding_provider = create_embedding_provider_from_parts(&provider, &model).await?;
let store = MemoryStore::new(
&db_path_str,
embedding_provider,
memory_config.clone(),
config.clone(),
)
.await?;
Ok(Self {
store,
config: memory_config,
})
}
pub async fn memorize(
&mut self,
memory_type: MemoryType,
title: String,
content: String,
importance: Option<f32>,
tags: Option<Vec<String>>,
related_files: Option<Vec<String>>,
) -> Result<Memory> {
let mut metadata = MemoryMetadata {
git_commit: GitUtils::get_current_commit(),
importance: importance.unwrap_or(self.config.default_importance),
tags: tags.unwrap_or_default(),
related_files: Vec::new(), ..Default::default()
};
if let Some(files) = related_files {
metadata.related_files = files
.into_iter()
.map(|file| GitUtils::get_relative_path(&file).unwrap_or(file))
.collect();
}
if metadata.related_files.is_empty() {
if let Ok(modified_files) = GitUtils::get_modified_files() {
metadata.related_files = modified_files.into_iter().take(5).collect();
}
}
let memory = Memory::new(memory_type, title, content, Some(metadata));
self.store.store_memory(&memory).await?;
if self.config.auto_relationships {
self.create_automatic_relationships(&memory).await?;
}
Ok(memory)
}
pub async fn remember(
&self,
query: &str,
filters: Option<MemoryQuery>,
) -> Result<Vec<MemorySearchResult>> {
let mut search_query = filters.unwrap_or_default();
search_query.query_text = Some(query.to_string());
self.store.search_memories(&search_query).await
}
pub async fn remember_multi(
&self,
queries: &[String],
filters: Option<MemoryQuery>,
) -> Result<Vec<MemorySearchResult>> {
if queries.is_empty() {
return Ok(Vec::new());
}
if queries.len() == 1 {
return self.remember(&queries[0], filters).await;
}
let base_filters = filters.unwrap_or_default();
let mut all_results: std::collections::HashMap<String, MemorySearchResult> =
std::collections::HashMap::new();
let mut query_count: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for query in queries {
let mut search_query = base_filters.clone();
search_query.query_text = Some(query.clone());
let results = self.store.search_memories(&search_query).await?;
for result in results {
let memory_id = result.memory.id.clone();
*query_count.entry(memory_id.clone()).or_insert(0) += 1;
match all_results.get(&memory_id) {
Some(existing) if existing.relevance_score >= result.relevance_score => {
}
_ => {
all_results.insert(memory_id, result);
}
}
}
}
let mut final_results: Vec<MemorySearchResult> = all_results
.into_iter()
.map(|(memory_id, mut result)| {
let matches = query_count.get(&memory_id).unwrap_or(&1);
if *matches > 1 {
let boost_factor = 1.0 + ((*matches as f32 - 1.0) * 0.1); result.relevance_score = (result.relevance_score * boost_factor).min(1.0);
result.selection_reason = format!(
"Matched {} of {} queries: {}",
matches,
queries.len(),
result.selection_reason
);
}
result
})
.collect();
final_results.sort_by(|a, b| {
b.relevance_score
.partial_cmp(&a.relevance_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
if let Some(limit) = base_filters.limit {
final_results.truncate(limit);
}
Ok(final_results)
}
pub async fn remember_advanced(&self, query: MemoryQuery) -> Result<Vec<MemorySearchResult>> {
self.store.search_memories(&query).await
}
pub async fn forget(&mut self, memory_id: &str) -> Result<()> {
self.store.delete_memory(memory_id).await
}
pub async fn forget_matching(&mut self, query: MemoryQuery) -> Result<usize> {
let search_results = self.store.search_memories(&query).await?;
let mut deleted_count = 0;
for result in search_results {
self.store.delete_memory(&result.memory.id).await?;
deleted_count += 1;
}
Ok(deleted_count)
}
pub async fn update_memory(
&mut self,
memory_id: &str,
title: Option<String>,
content: Option<String>,
metadata_updates: Option<MemoryMetadata>,
) -> Result<Option<Memory>> {
if let Some(mut memory) = self.store.get_memory(memory_id).await? {
let current_commit = GitUtils::get_current_commit();
if let Some(mut meta) = metadata_updates {
meta.git_commit = current_commit.clone();
memory.update(title, content, Some(meta));
} else if let Some(commit) = current_commit {
memory.metadata.git_commit = Some(commit);
memory.update(title, content, None);
} else {
memory.update(title, content, None);
}
self.store.update_memory(&memory).await?;
if self.config.auto_relationships {
self.update_automatic_relationships(&memory).await?;
}
Ok(Some(memory))
} else {
Ok(None)
}
}
pub async fn get_memory(&self, memory_id: &str) -> Result<Option<Memory>> {
self.store.get_memory(memory_id).await
}
pub async fn get_recent_memories(&self, limit: usize) -> Result<Vec<Memory>> {
let query = MemoryQuery {
limit: Some(limit),
sort_by: Some(super::types::MemorySortBy::CreatedAt),
sort_order: Some(super::types::SortOrder::Descending),
..Default::default()
};
let results = self.store.search_memories(&query).await?;
Ok(results.into_iter().map(|r| r.memory).collect())
}
pub async fn get_memories_by_type(
&self,
memory_type: MemoryType,
limit: Option<usize>,
) -> Result<Vec<Memory>> {
let query = MemoryQuery {
memory_types: Some(vec![memory_type]),
limit,
sort_by: Some(super::types::MemorySortBy::CreatedAt),
sort_order: Some(super::types::SortOrder::Descending),
..Default::default()
};
let results = self.store.search_memories(&query).await?;
Ok(results.into_iter().map(|r| r.memory).collect())
}
pub async fn get_memories_for_files(
&self,
file_paths: Vec<String>,
) -> Result<Vec<MemorySearchResult>> {
let relative_paths: Vec<String> = file_paths
.into_iter()
.map(|path| GitUtils::get_relative_path(&path).unwrap_or(path))
.collect();
let query = MemoryQuery {
related_files: Some(relative_paths),
sort_by: Some(super::types::MemorySortBy::Importance),
sort_order: Some(super::types::SortOrder::Descending),
..Default::default()
};
self.store.search_memories(&query).await
}
pub async fn get_memories_for_current_commit(&self) -> Result<Vec<Memory>> {
if let Some(commit) = GitUtils::get_current_commit() {
let query = MemoryQuery {
git_commit: Some(commit),
sort_by: Some(super::types::MemorySortBy::CreatedAt),
sort_order: Some(super::types::SortOrder::Descending),
..Default::default()
};
let results = self.store.search_memories(&query).await?;
Ok(results.into_iter().map(|r| r.memory).collect())
} else {
Ok(Vec::new())
}
}
pub async fn get_memories_by_tags(&self, tags: Vec<String>) -> Result<Vec<MemorySearchResult>> {
let query = MemoryQuery {
tags: Some(tags),
sort_by: Some(super::types::MemorySortBy::Importance),
sort_order: Some(super::types::SortOrder::Descending),
..Default::default()
};
self.store.search_memories(&query).await
}
pub async fn get_memory_stats(&self) -> Result<MemoryStats> {
let total_count = self.store.get_memory_count().await?;
let recent_memories = self.get_recent_memories(100).await?;
let mut type_counts = std::collections::HashMap::new();
for memory in &recent_memories {
*type_counts
.entry(memory.memory_type.to_string())
.or_insert(0) += 1;
}
Ok(MemoryStats {
total_memories: total_count,
type_counts,
recent_count: recent_memories.len().min(10),
git_commit: GitUtils::get_current_commit(),
})
}
pub async fn create_relationship(
&mut self,
source_id: String,
target_id: String,
relationship_type: RelationshipType,
strength: f32,
description: String,
) -> Result<MemoryRelationship> {
let relationship = MemoryRelationship {
id: uuid::Uuid::new_v4().to_string(),
source_id,
target_id,
relationship_type,
strength,
description,
created_at: Utc::now(),
};
self.store.store_relationship(&relationship).await?;
Ok(relationship)
}
pub async fn get_relationships(&self, memory_id: &str) -> Result<Vec<MemoryRelationship>> {
self.store.get_memory_relationships(memory_id).await
}
pub async fn get_related_memories(&self, memory_id: &str) -> Result<Vec<Memory>> {
let relationships = self.get_relationships(memory_id).await?;
let mut related_memories = Vec::new();
for rel in relationships {
let related_id = if rel.source_id == memory_id {
rel.target_id
} else {
rel.source_id
};
if let Some(memory) = self.store.get_memory(&related_id).await? {
related_memories.push(memory);
}
}
Ok(related_memories)
}
pub async fn cleanup(&mut self) -> Result<usize> {
self.store.cleanup_old_memories().await
}
pub async fn clear_all(&mut self) -> Result<usize> {
self.store.clear_all_memory_data().await
}
async fn create_automatic_relationships(&mut self, memory: &Memory) -> Result<()> {
let similar_query = MemoryQuery {
query_text: Some(memory.get_searchable_text()),
memory_types: Some(vec![memory.memory_type.clone()]),
limit: Some(5),
min_relevance: Some(self.config.relationship_threshold),
..Default::default()
};
let similar_memories = self.store.search_memories(&similar_query).await?;
for result in similar_memories {
if result.memory.id != memory.id
&& result.relevance_score >= self.config.relationship_threshold
{
let relationship_type = if result.relevance_score > 0.9 {
RelationshipType::Similar
} else {
RelationshipType::RelatedTo
};
let _ = self
.create_relationship(
memory.id.clone(),
result.memory.id,
relationship_type,
result.relevance_score,
format!(
"Auto-detected relationship (similarity: {:.2})",
result.relevance_score
),
)
.await;
}
}
if !memory.metadata.related_files.is_empty() {
let file_query = MemoryQuery {
related_files: Some(memory.metadata.related_files.clone()),
limit: Some(10),
..Default::default()
};
let file_related = self.store.search_memories(&file_query).await?;
for result in file_related {
if result.memory.id != memory.id {
let _ = self
.create_relationship(
memory.id.clone(),
result.memory.id,
RelationshipType::RelatedTo,
0.7, "Shares related files".to_string(),
)
.await;
}
}
}
Ok(())
}
async fn update_automatic_relationships(&mut self, memory: &Memory) -> Result<()> {
let existing_relationships = self.get_relationships(&memory.id).await?;
for rel in existing_relationships {
if rel.description.contains("Auto-detected")
|| rel.description.contains("Shares related files")
{
}
}
self.create_automatic_relationships(memory).await
}
pub async fn add_tag(&mut self, memory_id: &str, tag: String) -> Result<bool> {
if let Some(mut memory) = self.store.get_memory(memory_id).await? {
memory.add_tag(tag);
self.store.update_memory(&memory).await?;
Ok(true)
} else {
Ok(false)
}
}
pub async fn remove_tag(&mut self, memory_id: &str, tag: &str) -> Result<bool> {
if let Some(mut memory) = self.store.get_memory(memory_id).await? {
memory.remove_tag(tag);
self.store.update_memory(&memory).await?;
Ok(true)
} else {
Ok(false)
}
}
pub async fn add_related_file(&mut self, memory_id: &str, file_path: String) -> Result<bool> {
if let Some(mut memory) = self.store.get_memory(memory_id).await? {
let relative_path = GitUtils::get_relative_path(&file_path).unwrap_or(file_path);
memory.add_related_file(relative_path);
self.store.update_memory(&memory).await?;
Ok(true)
} else {
Ok(false)
}
}
pub async fn remove_related_file(&mut self, memory_id: &str, file_path: &str) -> Result<bool> {
if let Some(mut memory) = self.store.get_memory(memory_id).await? {
memory.remove_related_file(file_path);
self.store.update_memory(&memory).await?;
Ok(true)
} else {
Ok(false)
}
}
}
#[derive(Debug, Clone)]
pub struct MemoryStats {
pub total_memories: usize,
pub type_counts: std::collections::HashMap<String, usize>,
pub recent_count: usize,
pub git_commit: Option<String>,
}
impl MemoryStats {
pub fn format(&self) -> String {
let mut output = "Memory Statistics:\n".to_string();
output.push_str(&format!(" Total memories: {}\n", self.total_memories));
output.push_str(&format!(" Recent memories: {}\n", self.recent_count));
if let Some(ref commit) = self.git_commit {
output.push_str(&format!(" Current commit: {}\n", commit));
}
if !self.type_counts.is_empty() {
output.push_str(" Memory types:\n");
for (memory_type, count) in &self.type_counts {
output.push_str(&format!(" {}: {}\n", memory_type, count));
}
}
output
}
}