use crate::error::Result;
use crate::models::{
Memory, SearchParams, SearchResult, SearchResultWithMetadata, SearchStrategy, StorageStats,
};
use async_trait::async_trait;
use sqlx::{PgPool, Row};
use uuid::Uuid;
/// Persistence operations for memories.
///
/// `Storage` in this module is the PostgreSQL-backed implementation; callers can
/// depend on this trait (it is `Send + Sync`, so usable behind shared pointers across
/// tasks) instead of the concrete type.
#[async_trait]
pub trait StorageInterface: Send + Sync {
/// Stores a memory built from `content`, `context`, `summary` and optional `tags`,
/// returning the id of the persisted row.
async fn store(
&self,
content: &str,
context: String,
summary: String,
tags: Option<Vec<String>>,
) -> Result<Uuid>;
/// Stores one chunk of a larger piece of content and returns the chunk's id.
///
/// `chunk_index` / `total_chunks` position the chunk within the whole, and
/// `parent_id` links it to its parent memory.
async fn store_chunk(
&self,
content: &str,
context: String,
summary: String,
tags: Option<Vec<String>>,
chunk_index: i32,
total_chunks: i32,
parent_id: Uuid,
) -> Result<Uuid>;
/// Fetches a memory by id; `Ok(None)` when no such memory exists.
async fn get(&self, id: Uuid) -> Result<Option<Memory>>;
/// Deletes a memory by id; returns whether anything was deleted.
async fn delete(&self, id: Uuid) -> Result<bool>;
/// Searches memories according to `params`.
async fn search(&self, params: SearchParams) -> Result<Vec<SearchResult>>;
/// Returns aggregate statistics about the stored memories.
async fn stats(&self) -> Result<StorageStats>;
/// Returns up to `limit` memories (the `Storage` implementation orders newest first).
async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>>;
/// Returns the chunks whose parent is `parent_id`.
async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>>;
}
/// PostgreSQL-backed memory store; every query it issues runs against `pool`.
pub struct Storage {
/// Connection pool shared by all queries issued by this store.
pool: PgPool,
}
impl Storage {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
pub async fn store(
&self,
content: &str,
context: String,
summary: String,
tags: Option<Vec<String>>,
) -> Result<Uuid> {
let memory = Memory::new(content.to_string(), context, summary, tags);
let result: Uuid = sqlx::query_scalar(
r#"
INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (content_hash) DO UPDATE SET
context = EXCLUDED.context,
summary = EXCLUDED.summary,
tags = EXCLUDED.tags,
updated_at = EXCLUDED.updated_at
RETURNING id
"#
)
.bind(memory.id)
.bind(memory.content)
.bind(memory.content_hash)
.bind(&memory.tags)
.bind(&memory.context)
.bind(&memory.summary)
.bind(memory.chunk_index)
.bind(memory.total_chunks)
.bind(memory.parent_id)
.bind(memory.created_at)
.bind(memory.updated_at)
.fetch_one(&self.pool)
.await?;
Ok(result)
}
pub async fn store_chunk(
&self,
content: &str,
context: String,
summary: String,
tags: Option<Vec<String>>,
chunk_index: i32,
total_chunks: i32,
parent_id: Uuid,
) -> Result<Uuid> {
let memory = Memory::new_chunk(
content.to_string(),
context,
summary,
tags,
chunk_index,
total_chunks,
parent_id,
);
let result: Uuid = sqlx::query_scalar(
r#"
INSERT INTO memories (id, content, content_hash, tags, context, summary, chunk_index, total_chunks, parent_id, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id
"#
)
.bind(memory.id)
.bind(memory.content)
.bind(memory.content_hash)
.bind(&memory.tags)
.bind(&memory.context)
.bind(&memory.summary)
.bind(memory.chunk_index)
.bind(memory.total_chunks)
.bind(memory.parent_id)
.bind(memory.created_at)
.bind(memory.updated_at)
.fetch_one(&self.pool)
.await?;
Ok(result)
}
/// Fetches the memory whose primary key is `id`, or `Ok(None)` if there is none.
pub async fn get(&self, id: Uuid) -> Result<Option<Memory>> {
    const SELECT_BY_ID: &str = r#"
SELECT
id,
content,
content_hash,
tags,
context,
summary,
chunk_index,
total_chunks,
parent_id,
created_at,
updated_at
FROM memories
WHERE id = $1
"#;
    sqlx::query_as::<_, Memory>(SELECT_BY_ID)
        .bind(id)
        .fetch_optional(&self.pool)
        .await
        .map_err(Into::into)
}
/// Returns every memory whose `parent_id` is `parent_id`, ordered by ascending
/// `chunk_index` (empty when the parent has no chunks).
pub async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>> {
    const SELECT_CHUNKS: &str = r#"
SELECT
id,
content,
content_hash,
tags,
context,
summary,
chunk_index,
total_chunks,
parent_id,
created_at,
updated_at
FROM memories
WHERE parent_id = $1
ORDER BY chunk_index ASC
"#;
    sqlx::query_as::<_, Memory>(SELECT_CHUNKS)
        .bind(parent_id)
        .fetch_all(&self.pool)
        .await
        .map_err(Into::into)
}
/// Deletes the memory whose primary key is `id`; returns `true` if a row was removed.
///
/// NOTE(review): whether chunks pointing at `id` via `parent_id` are removed as well
/// depends on the schema's foreign-key rules, which are not visible here.
pub async fn delete(&self, id: Uuid) -> Result<bool> {
    let removed = sqlx::query("DELETE FROM memories WHERE id = $1")
        .bind(id)
        .execute(&self.pool)
        .await?
        .rows_affected();
    Ok(removed != 0)
}
/// Returns storage statistics for the `memories` table.
///
/// The statistics are:
/// - the total row count;
/// - the human-readable size reported by
///   `pg_size_pretty(pg_total_relation_size('memories'))`;
/// - the newest `created_at` (SQL `MAX`, so NULL on an empty table).
pub async fn stats(&self) -> Result<StorageStats> {
    const STATS_QUERY: &str = r#"
SELECT
COUNT(*) as total_memories,
pg_size_pretty(pg_total_relation_size('memories')) as table_size,
MAX(created_at) as last_memory_created
FROM memories
"#;
    let summary_row = sqlx::query(STATS_QUERY).fetch_one(&self.pool).await?;
    Ok(StorageStats {
        total_memories: summary_row.get("total_memories"),
        table_size: summary_row.get("table_size"),
        last_memory_created: summary_row.get("last_memory_created"),
    })
}
/// Returns up to `limit` memories (chunks included), newest `created_at` first.
pub async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>> {
    const SELECT_RECENT: &str = r#"
SELECT
id,
content,
content_hash,
tags,
context,
summary,
chunk_index,
total_chunks,
parent_id,
created_at,
updated_at
FROM memories
ORDER BY created_at DESC
LIMIT $1
"#;
    sqlx::query_as::<_, Memory>(SELECT_RECENT)
        .bind(limit)
        .fetch_all(&self.pool)
        .await
        .map_err(Into::into)
}
/// Returns up to `limit` memories whose `content_hash` equals `content_hash`,
/// newest first.
///
/// Despite the name, this is an exact hash match, not a similarity search.
pub async fn find_similar_content(
    &self,
    content_hash: &str,
    limit: i64,
) -> Result<Vec<Memory>> {
    const SELECT_BY_HASH: &str = r#"
SELECT
id,
content,
content_hash,
tags,
context,
summary,
chunk_index,
total_chunks,
parent_id,
created_at,
updated_at
FROM memories
WHERE content_hash = $1
ORDER BY created_at DESC
LIMIT $2
"#;
    sqlx::query_as::<_, Memory>(SELECT_BY_HASH)
        .bind(content_hash)
        .bind(limit)
        .fetch_all(&self.pool)
        .await
        .map_err(Into::into)
}
/// Reports whether any memory has the given `content_hash`.
pub async fn exists_with_content(&self, content_hash: &str) -> Result<bool> {
    let matching_rows: i64 =
        sqlx::query_scalar("SELECT COUNT(*) FROM memories WHERE content_hash = $1")
            .bind(content_hash)
            .fetch_one(&self.pool)
            .await?;
    // COUNT(*) is never negative, so "non-zero" and "positive" coincide.
    Ok(matching_rows != 0)
}
/// Returns up to 50 `(content_hash, count)` pairs for hashes that occur on more than
/// one row, most frequent first.
///
/// NOTE(review): `store` relies on `ON CONFLICT (content_hash)`, which requires a
/// unique index on `content_hash`; if that holds, this always returns an empty list.
pub async fn get_content_stats(&self) -> Result<Vec<(String, i64)>> {
    const DUPLICATE_HASHES: &str = r#"
SELECT
content_hash,
COUNT(*) as total_count
FROM memories
GROUP BY content_hash
HAVING COUNT(*) > 1
ORDER BY total_count DESC
LIMIT 50
"#;
    let rows = sqlx::query(DUPLICATE_HASHES).fetch_all(&self.pool).await?;
    let mut duplicates = Vec::with_capacity(rows.len());
    for row in &rows {
        let hash: String = row.get("content_hash");
        let occurrences: i64 = row.get("total_count");
        duplicates.push((hash, occurrences));
    }
    Ok(duplicates)
}
/// Public search entry point: runs the progressive (multi-stage) search and returns
/// only the results, without stage metadata.
pub async fn search_memories(&self, params: SearchParams) -> Result<Vec<SearchResult>> {
    let results = self.search_memories_progressive(params).await?;
    Ok(results)
}
/// Runs the progressive search and drops the stage metadata.
async fn search_memories_progressive(&self, params: SearchParams) -> Result<Vec<SearchResult>> {
    self.search_memories_progressive_with_metadata(params)
        .await
        .map(|with_metadata| with_metadata.results)
}
/// Runs a search in up to three progressively looser stages and reports which stage
/// produced the results.
///
/// 1. The caller's parameters unchanged.
/// 2. The similarity threshold lowered by 0.25, but not below 0.1. This stage is
///    skipped when that would not actually lower the caller's threshold (it could
///    only return a subset of stage 1's empty result).
/// 3. Content-only similarity: tag embeddings disabled, `ContentFirst` strategy, and a
///    threshold of 0.1 or the caller's threshold, whichever is lower.
///
/// The first stage with any results wins. Stage 3's results are returned even when
/// empty. Relaxation never raises the caller's threshold: previously a caller
/// threshold below 0.1 was tightened to 0.1 in stages 2 and 3.
pub async fn search_memories_progressive_with_metadata(
    &self,
    params: SearchParams,
) -> Result<SearchResultWithMetadata> {
    use crate::models::SearchMetadata;
    // Lowest threshold the automatic relaxation will drop to on its own.
    const THRESHOLD_FLOOR: f64 = 0.1;
    // How far stage 2 lowers the caller's threshold.
    const RELAXATION_STEP: f64 = 0.25;

    let original_threshold = params.similarity_threshold;

    // Stage 1: caller's parameters as given.
    let stage1_results = self.search_memories_internal(params.clone()).await?;
    if !stage1_results.is_empty() {
        let metadata = SearchMetadata {
            stage_used: 1,
            stage_description: "Original parameters".to_string(),
            threshold_used: original_threshold,
            total_results: stage1_results.len(),
        };
        return Ok(SearchResultWithMetadata {
            results: stage1_results,
            metadata,
        });
    }

    // Stage 2: relaxed threshold, only when it is genuinely lower than stage 1's.
    let relaxed_threshold = (original_threshold - RELAXATION_STEP).max(THRESHOLD_FLOOR);
    if relaxed_threshold < original_threshold {
        let mut relaxed_params = params.clone();
        relaxed_params.similarity_threshold = relaxed_threshold;
        let stage2_results = self.search_memories_internal(relaxed_params).await?;
        if !stage2_results.is_empty() {
            let metadata = SearchMetadata {
                stage_used: 2,
                stage_description: "Relaxed threshold".to_string(),
                threshold_used: relaxed_threshold,
                total_results: stage2_results.len(),
            };
            return Ok(SearchResultWithMetadata {
                results: stage2_results,
                metadata,
            });
        }
    }

    // Stage 3: content-embedding-only search at the floor (or lower caller) threshold.
    let content_threshold = THRESHOLD_FLOOR.min(original_threshold);
    let mut content_params = params;
    content_params.similarity_threshold = content_threshold;
    content_params.use_tag_embedding = false;
    content_params.search_strategy = SearchStrategy::ContentFirst;
    let stage3_results = self.search_memories_internal(content_params).await?;
    let metadata = SearchMetadata {
        stage_used: 3,
        stage_description: "Content-only similarity".to_string(),
        threshold_used: content_threshold,
        total_results: stage3_results.len(),
    };
    Ok(SearchResultWithMetadata {
        results: stage3_results,
        metadata,
    })
}
/// Single-stage search: a vector search seeded from a full-text match, falling back
/// to plain text matching when vectors cannot be used.
///
/// Falls back to `search_memories_fallback` when any of these holds:
/// - the embedding columns are missing;
/// - both embedding kinds are disabled in `params`;
/// - the full-text seed query errors;
/// - the seed query finds no memory with an `embedding_vector`.
///
/// Otherwise the ids of up to 5 text-matched memories become the seed for the
/// configured strategy. The strategy helpers use the first id's embeddings as the
/// query vector. An optional recency boost is then applied, and results are sorted
/// by descending `combined_score` and truncated to `max_results`.
async fn search_memories_internal(&self, params: SearchParams) -> Result<Vec<SearchResult>> {
    let has_embeddings = self.check_embedding_columns_exist().await?;
    if !has_embeddings || (!params.use_tag_embedding && !params.use_content_embedding) {
        return self.search_memories_fallback(params).await;
    }

    // At least one embedding kind is enabled past the guard above, so the seed lookup
    // always runs. Errors are deliberately not propagated: the text fallback is
    // preferred over failing the whole search.
    let seed_rows = sqlx::query(
        r#"
SELECT id, summary, content
FROM memories
WHERE to_tsvector('english', summary || ' ' || content) @@ plainto_tsquery('english', $1)
AND embedding_vector IS NOT NULL
LIMIT 5
"#,
    )
    .bind(&params.query)
    .fetch_all(&self.pool)
    .await;
    let query_memory_ids: Vec<Uuid> = match seed_rows {
        Ok(rows) if !rows.is_empty() => {
            rows.iter().map(|row| row.get::<Uuid, _>("id")).collect()
        }
        _ => return self.search_memories_fallback(params).await,
    };

    let mut results = match params.search_strategy {
        SearchStrategy::TagsFirst => self.search_tags_first(&params, &query_memory_ids).await?,
        SearchStrategy::ContentFirst => {
            self.search_content_first(&params, &query_memory_ids)
                .await?
        }
        SearchStrategy::Hybrid => self.search_hybrid(&params, &query_memory_ids).await?,
    };
    if params.boost_recent {
        self.apply_recency_boost(&mut results);
    }
    // Unlike `partial_cmp(..).unwrap()`, `total_cmp` cannot panic on a NaN score.
    results.sort_by(|a, b| b.combined_score.total_cmp(&a.combined_score));
    results.truncate(params.max_results);
    Ok(results)
}
async fn check_embedding_columns_exist(&self) -> Result<bool> {
let result = sqlx::query(
r#"
SELECT COUNT(*) as count
FROM information_schema.columns
WHERE table_name = 'memories'
AND column_name IN ('embedding_vector', 'tag_embedding')
"#,
)
.fetch_one(&self.pool)
.await;
match result {
Ok(row) => {
let count: i64 = row.get("count");
Ok(count >= 2) }
Err(_) => Ok(false),
}
}
/// Tag-embedding-first search.
///
/// Steps:
/// 1. Use the tag embedding of `query_ids[0]` as the query vector.
/// 2. Select up to `3 * max_results` memories whose `<=>` distance to it is at most
///    `1 - similarity_threshold`, optionally restricted to rows sharing a tag with
///    `tag_filter`.
/// 3. Attach content similarity via `enhance_with_content_similarity`, which also
///    drops rows below the threshold.
///
/// The 3x over-fetch presumably leaves headroom for that final filter. Returns no
/// results when `query_ids` is empty.
async fn search_tags_first(
    &self,
    params: &SearchParams,
    query_ids: &[Uuid],
) -> Result<Vec<SearchResult>> {
    if query_ids.is_empty() {
        return Ok(vec![]);
    }
    let tag_results = sqlx::query(
        r#"
WITH query_embedding AS (
SELECT tag_embedding as query_vector
FROM memories
WHERE id = $1 AND tag_embedding IS NOT NULL
LIMIT 1
)
SELECT m.*,
(m.tag_embedding <=> q.query_vector) as tag_similarity,
m.semantic_cluster
FROM memories m, query_embedding q
WHERE m.tag_embedding IS NOT NULL
AND ($2::text[] IS NULL OR m.tags && $2::text[])
AND (m.tag_embedding <=> q.query_vector) <= $3
ORDER BY m.tag_embedding <=> q.query_vector
LIMIT $4
"#,
    )
    .bind(query_ids[0])
    .bind(&params.tag_filter)
    // Similarity threshold expressed as a maximum distance.
    .bind(1.0 - params.similarity_threshold)
    .bind((params.max_results * 3) as i64)
    .fetch_all(&self.pool)
    .await?;
    self.enhance_with_content_similarity(tag_results, query_ids, params)
        .await
}
/// Content-embedding-first search.
///
/// Steps:
/// 1. Use the `embedding_vector` of `query_ids[0]` as the query vector.
/// 2. Select up to `2 * max_results` memories whose `<=>` distance to it is at most
///    `1 - similarity_threshold`, optionally restricted to rows sharing a tag with
///    `tag_filter`.
/// 3. Attach tag similarity via `enhance_with_tag_similarity`, which also drops rows
///    below the threshold.
///
/// Returns no results when `query_ids` is empty.
async fn search_content_first(
    &self,
    params: &SearchParams,
    query_ids: &[Uuid],
) -> Result<Vec<SearchResult>> {
    if query_ids.is_empty() {
        return Ok(vec![]);
    }
    let content_results = sqlx::query(
        r#"
WITH query_embedding AS (
SELECT embedding_vector as query_vector
FROM memories
WHERE id = $1 AND embedding_vector IS NOT NULL
LIMIT 1
)
SELECT m.*,
(m.embedding_vector <=> q.query_vector) as content_similarity,
m.semantic_cluster
FROM memories m, query_embedding q
WHERE m.embedding_vector IS NOT NULL
AND ($2::text[] IS NULL OR m.tags && $2::text[])
AND (m.embedding_vector <=> q.query_vector) <= $3
ORDER BY m.embedding_vector <=> q.query_vector
LIMIT $4
"#,
    )
    .bind(query_ids[0])
    .bind(&params.tag_filter)
    // Similarity threshold expressed as a maximum distance.
    .bind(1.0 - params.similarity_threshold)
    .bind((params.max_results * 2) as i64)
    .fetch_all(&self.pool)
    .await?;
    self.enhance_with_tag_similarity(content_results, query_ids, params)
        .await
}
/// Hybrid search over both embeddings in a single query.
///
/// Steps:
/// 1. Use both embeddings of `query_ids[0]` as query vectors (the seed must have
///    both).
/// 2. Select up to `2 * max_results` memories where either distance is at most
///    `1 - similarity_threshold`, ordered by the smaller of the two distances and
///    optionally restricted to rows sharing a tag with `tag_filter`.
/// 3. Score the rows via `rows_to_search_results`, which drops rows below the
///    threshold.
///
/// Returns no results when `query_ids` is empty.
async fn search_hybrid(
    &self,
    params: &SearchParams,
    query_ids: &[Uuid],
) -> Result<Vec<SearchResult>> {
    if query_ids.is_empty() {
        return Ok(vec![]);
    }
    let results = sqlx::query(
        r#"
WITH query_embeddings AS (
SELECT
embedding_vector as content_query_vector,
tag_embedding as tag_query_vector
FROM memories
WHERE id = $1
AND embedding_vector IS NOT NULL
AND tag_embedding IS NOT NULL
LIMIT 1
)
SELECT m.*,
(m.embedding_vector <=> q.content_query_vector) as content_similarity,
(m.tag_embedding <=> q.tag_query_vector) as tag_similarity,
m.semantic_cluster
FROM memories m, query_embeddings q
WHERE m.embedding_vector IS NOT NULL
AND m.tag_embedding IS NOT NULL
AND ($2::text[] IS NULL OR m.tags && $2::text[])
AND (
(m.embedding_vector <=> q.content_query_vector) <= $3 OR
(m.tag_embedding <=> q.tag_query_vector) <= $3
)
ORDER BY LEAST(
m.embedding_vector <=> q.content_query_vector,
m.tag_embedding <=> q.tag_query_vector
)
LIMIT $4
"#,
    )
    .bind(query_ids[0])
    .bind(&params.tag_filter)
    // Similarity threshold expressed as a maximum distance.
    .bind(1.0 - params.similarity_threshold)
    .bind((params.max_results * 2) as i64)
    .fetch_all(&self.pool)
    .await?;
    Ok(self.rows_to_search_results(results, params))
}
/// Turns tag-first rows into scored results, adding content similarity when enabled.
///
/// Each row must carry a `tag_similarity` distance column; it becomes a similarity
/// via `1 - distance`. When `params.use_content_embedding` is set, the content
/// distance to the `embedding_vector` of `query_ids[0]` is fetched for all row ids in
/// one query. Rows without an embedding get no content score. Results whose
/// `combined_score` is below `similarity_threshold` are dropped. Returns nothing when
/// either input is empty.
async fn enhance_with_content_similarity(
    &self,
    tag_results: Vec<sqlx::postgres::PgRow>,
    query_ids: &[Uuid],
    params: &SearchParams,
) -> Result<Vec<SearchResult>> {
    if tag_results.is_empty() || query_ids.is_empty() {
        return Ok(Vec::new());
    }
    let candidate_ids = tag_results
        .iter()
        .map(|row| row.get::<Uuid, _>("id"))
        .collect::<Vec<Uuid>>();
    let mut content_scores: std::collections::HashMap<Uuid, f64> =
        std::collections::HashMap::new();
    if params.use_content_embedding {
        let distance_rows = sqlx::query(
            r#"
WITH query_embedding AS (
SELECT embedding_vector as query_vector
FROM memories
WHERE id = $1 AND embedding_vector IS NOT NULL
LIMIT 1
)
SELECT m.id, (m.embedding_vector <=> q.query_vector) as content_similarity
FROM memories m, query_embedding q
WHERE m.id = ANY($2) AND m.embedding_vector IS NOT NULL
"#,
        )
        .bind(query_ids[0])
        .bind(&candidate_ids)
        .fetch_all(&self.pool)
        .await?;
        for row in distance_rows {
            let distance: f64 = row.get("content_similarity");
            content_scores.insert(row.get("id"), 1.0 - distance);
        }
    }
    let scored: Vec<SearchResult> = tag_results
        .into_iter()
        .map(|row| {
            let memory_id: Uuid = row.get("id");
            let tag_distance: f64 = row.get("tag_similarity");
            let content_score = content_scores.get(&memory_id).copied();
            let cluster = row.get("semantic_cluster");
            SearchResult::new(
                self.row_to_memory(&row),
                Some(1.0 - tag_distance),
                content_score,
                cluster,
                params.tag_weight,
                params.content_weight,
            )
        })
        .filter(|candidate| candidate.combined_score >= params.similarity_threshold)
        .collect();
    Ok(scored)
}
/// Turns content-first rows into scored results, adding tag similarity when enabled.
///
/// Each row must carry a `content_similarity` distance column; it becomes a
/// similarity via `1 - distance`. When `params.use_tag_embedding` is set, the tag
/// distance to the `tag_embedding` of `query_ids[0]` is fetched for all row ids in one
/// query. Rows without a tag embedding get no tag score. Results whose
/// `combined_score` is below `similarity_threshold` are dropped. Returns nothing when
/// either input is empty.
async fn enhance_with_tag_similarity(
    &self,
    content_results: Vec<sqlx::postgres::PgRow>,
    query_ids: &[Uuid],
    params: &SearchParams,
) -> Result<Vec<SearchResult>> {
    if content_results.is_empty() || query_ids.is_empty() {
        return Ok(Vec::new());
    }
    let candidate_ids = content_results
        .iter()
        .map(|row| row.get::<Uuid, _>("id"))
        .collect::<Vec<Uuid>>();
    let mut tag_scores: std::collections::HashMap<Uuid, f64> =
        std::collections::HashMap::new();
    if params.use_tag_embedding {
        let distance_rows = sqlx::query(
            r#"
WITH query_embedding AS (
SELECT tag_embedding as query_vector
FROM memories
WHERE id = $1 AND tag_embedding IS NOT NULL
LIMIT 1
)
SELECT m.id, (m.tag_embedding <=> q.query_vector) as tag_similarity
FROM memories m, query_embedding q
WHERE m.id = ANY($2) AND m.tag_embedding IS NOT NULL
"#,
        )
        .bind(query_ids[0])
        .bind(&candidate_ids)
        .fetch_all(&self.pool)
        .await?;
        for row in distance_rows {
            let distance: f64 = row.get("tag_similarity");
            tag_scores.insert(row.get("id"), 1.0 - distance);
        }
    }
    let scored: Vec<SearchResult> = content_results
        .into_iter()
        .map(|row| {
            let memory_id: Uuid = row.get("id");
            let content_distance: f64 = row.get("content_similarity");
            let tag_score = tag_scores.get(&memory_id).copied();
            let cluster = row.get("semantic_cluster");
            SearchResult::new(
                self.row_to_memory(&row),
                tag_score,
                Some(1.0 - content_distance),
                cluster,
                params.tag_weight,
                params.content_weight,
            )
        })
        .filter(|candidate| candidate.combined_score >= params.similarity_threshold)
        .collect();
    Ok(scored)
}
/// Converts rows carrying optional `tag_similarity` / `content_similarity` distance
/// columns into scored results.
///
/// Each distance present and readable as `f64` becomes a similarity via
/// `1 - distance`; a missing or unreadable column yields `None`. Results whose
/// `combined_score` is below `params.similarity_threshold` are dropped.
fn rows_to_search_results(
    &self,
    rows: Vec<sqlx::postgres::PgRow>,
    params: &SearchParams,
) -> Vec<SearchResult> {
    let mut kept = Vec::new();
    for row in rows {
        let tag_score = match row.try_get::<f64, _>("tag_similarity") {
            Ok(distance) => Some(1.0 - distance),
            Err(_) => None,
        };
        let content_score = match row.try_get::<f64, _>("content_similarity") {
            Ok(distance) => Some(1.0 - distance),
            Err(_) => None,
        };
        let cluster = row.get("semantic_cluster");
        let candidate = SearchResult::new(
            self.row_to_memory(&row),
            tag_score,
            content_score,
            cluster,
            params.tag_weight,
            params.content_weight,
        );
        if candidate.combined_score >= params.similarity_threshold {
            kept.push(candidate);
        }
    }
    kept
}
fn row_to_memory(&self, row: &sqlx::postgres::PgRow) -> Memory {
Memory {
id: row.get("id"),
content: row.get("content"),
content_hash: row.get("content_hash"),
tags: row.get("tags"),
context: row.get("context"),
summary: row.get("summary"),
chunk_index: row.get("chunk_index"),
total_chunks: row.get("total_chunks"),
parent_id: row.get("parent_id"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
}
}
/// Scales each result's `combined_score` by a recency factor.
///
/// The factor is `1 / (1 + age_days / 30)`, floored at 0.1. A memory created today
/// keeps its score, one 30 days old keeps half, and very old ones keep 10%. Age is
/// counted in whole days against the current UTC time.
///
/// Timestamps in the future (e.g. clock skew between writers) are treated as age 0.
/// Otherwise an age of exactly -30 days would divide by zero (an infinite score), and
/// other negative ages would push the factor above 1.
fn apply_recency_boost(&self, results: &mut [SearchResult]) {
    let now = chrono::Utc::now();
    for result in results.iter_mut() {
        let age_days = (now - result.memory.created_at).num_days().max(0) as f64;
        let recency_factor = (1.0 / (1.0 + age_days / 30.0)).max(0.1);
        result.combined_score *= recency_factor;
    }
}
/// Text-only search, used when vector search is unavailable or finds nothing to seed from.
///
/// Matching:
/// - `params.query` is matched as a case-insensitive substring (`ILIKE`) against
///   content, summary, context and individual tags.
/// - When `params.tag_filter` is set, results are limited to memories sharing at least
///   one tag with it.
/// - `%`, `_` and `\` in the query are escaped, so they match literally instead of
///   acting as LIKE wildcards.
///
/// Ranking: each hit gets a fixed rank of 1.0 (content and summary), 0.8 (either),
/// 0.6 (context) or 0.5 (a tag). The rank is used as both `content_similarity` and
/// `combined_score`.
///
/// Limits: at most `max_results` rows are fetched, best rank then newest first.
/// Ranks below `similarity_threshold` are dropped afterwards.
async fn search_memories_fallback(&self, params: SearchParams) -> Result<Vec<SearchResult>> {
    // Build `%<escaped query>%`; `\` is PostgreSQL's default LIKE escape character.
    let mut search_pattern = String::with_capacity(params.query.len() + 2);
    search_pattern.push('%');
    for ch in params.query.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            search_pattern.push('\\');
        }
        search_pattern.push(ch);
    }
    search_pattern.push('%');

    // A NULL `$2` disables the tag filter, the same idiom the vector queries use.
    let rows = sqlx::query(
        r#"
SELECT *,
CAST(CASE
WHEN content ILIKE $1 AND summary ILIKE $1 THEN 1.0
WHEN content ILIKE $1 OR summary ILIKE $1 THEN 0.8
WHEN context ILIKE $1 THEN 0.6
WHEN EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE $1) THEN 0.5
ELSE 0.4
END AS FLOAT8) as rank
FROM memories
WHERE (content ILIKE $1 OR summary ILIKE $1 OR context ILIKE $1
OR EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE $1))
AND ($2::text[] IS NULL OR tags && $2::text[])
ORDER BY rank DESC, created_at DESC
LIMIT $3
"#,
    )
    .bind(&search_pattern)
    .bind(&params.tag_filter)
    .bind(params.max_results as i64)
    .fetch_all(&self.pool)
    .await?;

    let results = rows
        .into_iter()
        .map(|row| {
            let rank: f64 = row.get("rank");
            let memory = self.row_to_memory(&row);
            SearchResult {
                memory,
                tag_similarity: None,
                content_similarity: Some(rank),
                combined_score: rank,
                semantic_cluster: None,
            }
        })
        .filter(|result| result.combined_score >= params.similarity_threshold)
        .collect();
    Ok(results)
}
}
#[async_trait]
impl StorageInterface for Storage {
async fn store(
&self,
content: &str,
context: String,
summary: String,
tags: Option<Vec<String>>,
) -> Result<Uuid> {
self.store(content, context, summary, tags).await
}
async fn store_chunk(
&self,
content: &str,
context: String,
summary: String,
tags: Option<Vec<String>>,
chunk_index: i32,
total_chunks: i32,
parent_id: Uuid,
) -> Result<Uuid> {
self.store_chunk(content, context, summary, tags, chunk_index, total_chunks, parent_id)
.await
}
async fn get(&self, id: Uuid) -> Result<Option<Memory>> {
self.get(id).await
}
async fn delete(&self, id: Uuid) -> Result<bool> {
self.delete(id).await
}
async fn search(&self, params: SearchParams) -> Result<Vec<SearchResult>> {
self.search_memories(params).await
}
async fn stats(&self) -> Result<StorageStats> {
self.stats().await
}
async fn list_recent(&self, limit: i64) -> Result<Vec<Memory>> {
self.list_recent(limit).await
}
async fn get_chunks(&self, parent_id: Uuid) -> Result<Vec<Memory>> {
self.get_chunks(parent_id).await
}
}