use async_trait::async_trait;
use sqlx::PgPool;
use crate::error::{AgentError, Result};
use crate::memory::{MemoryRecord, MemoryStore};
#[cfg(feature = "fastembed")]
use fastembed::{InitOptions, TextEmbedding};
#[cfg(feature = "fastembed")]
use tokio::sync::OnceCell;
/// Memory store backed by a PostgreSQL connection pool.
///
/// Records are kept in the `memories` table that [`PostgresStore::new`]
/// creates on start-up.
pub struct PostgresStore {
/// Connection pool that every query in this store runs against.
pool: PgPool,
/// Embedding model, initialised on the first `embed` call rather than at
/// construction. Only present when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
embedder: OnceCell<TextEmbedding>,
}
impl PostgresStore {
    /// Connects to PostgreSQL at `database_url` and makes sure the `memories`
    /// schema (table plus the session and timestamp indexes) exists.
    ///
    /// The pgvector extension is enabled first because the `embedding` column
    /// is declared with its `vector` type. When the extension is already
    /// installed the statement is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the connection cannot be
    /// established or when any of the schema statements fails.
    pub async fn new(database_url: &str) -> Result<Self> {
        // Executed one at a time: `sqlx::query` uses the prepared-statement
        // protocol, and PostgreSQL refuses a prepared statement that contains
        // more than one command.
        const SCHEMA_STATEMENTS: &[&str] = &[
            r#"
CREATE TABLE IF NOT EXISTS memories (
id UUID PRIMARY KEY,
session_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
importance REAL NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
metadata JSONB,
embedding vector(384)
)
"#,
            "CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id)",
            "CREATE INDEX IF NOT EXISTS idx_memories_timestamp ON memories(timestamp DESC)",
        ];

        let pool = PgPool::connect(database_url).await.map_err(|e| {
            AgentError::MemoryError(format!("Failed to connect to PostgreSQL: {}", e))
        })?;

        // The `vector` column type below does not exist until pgvector is enabled.
        sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
            .execute(&pool)
            .await
            .map_err(|e| {
                AgentError::MemoryError(format!("Failed to enable pgvector extension: {}", e))
            })?;

        for &statement in SCHEMA_STATEMENTS {
            sqlx::query(statement)
                .execute(&pool)
                .await
                .map_err(|e| AgentError::MemoryError(format!("Failed to create table: {}", e)))?;
        }

        Ok(Self {
            pool,
            #[cfg(feature = "fastembed")]
            embedder: OnceCell::new(),
        })
    }

    /// Creates the IVFFlat cosine-distance index on `memories.embedding` if it
    /// does not exist yet.
    ///
    /// This is not done by [`PostgresStore::new`]; callers invoke it explicitly.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the statement fails.
    pub async fn create_embedding_index(&self) -> Result<()> {
        sqlx::query(
            r#"
CREATE INDEX IF NOT EXISTS idx_memories_embedding
ON memories USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
"#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| AgentError::MemoryError(format!("Failed to create index: {}", e)))?;
        Ok(())
    }
}
/// Column tuple produced by the `SELECT` statements in `retrieve` and `search`,
/// in the order the columns are selected.
type MemoryRow = (
    uuid::Uuid,
    String,
    String,
    String,
    f32,
    chrono::DateTime<chrono::Utc>,
    Option<serde_json::Value>,
    Option<Vec<f32>>,
);

/// Builds a [`MemoryRecord`] from one fetched row.
fn record_from_row(row: MemoryRow) -> MemoryRecord {
    let (id, session_id, role, content, importance, timestamp, metadata, embedding) = row;
    MemoryRecord {
        id,
        session_id,
        role,
        content,
        importance,
        timestamp,
        // Best-effort: metadata that no longer deserializes is dropped instead
        // of failing the whole read.
        metadata: metadata.and_then(|value| serde_json::from_value(value).ok()),
        embedding,
    }
}

/// Converts a caller-supplied row limit into the `i64` bound to `LIMIT`,
/// saturating instead of wrapping when the value does not fit.
fn sql_limit(limit: usize) -> i64 {
    i64::try_from(limit).unwrap_or(i64::MAX)
}

#[async_trait]
impl MemoryStore for PostgresStore {
    /// Inserts `record`, or updates `content`, `importance`, `metadata` and
    /// `embedding` when a row with the same `id` already exists.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the metadata cannot be
    /// serialized to JSON or when the statement fails.
    async fn store(&self, record: MemoryRecord) -> Result<()> {
        // pgvector rejects zero-dimension vectors, so an empty embedding (what
        // `embed` returns without the `fastembed` feature) is stored as NULL.
        let embedding = record.embedding.filter(|values| !values.is_empty());

        // Surface serialization failures instead of silently writing NULL.
        let metadata_json = record
            .metadata
            .as_ref()
            .map(|m| serde_json::to_value(m))
            .transpose()
            .map_err(|e| {
                AgentError::MemoryError(format!("Failed to serialize metadata: {}", e))
            })?;

        sqlx::query(
            r#"
INSERT INTO memories (id, session_id, role, content, importance, timestamp, metadata, embedding)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (id) DO UPDATE SET
content = EXCLUDED.content,
importance = EXCLUDED.importance,
metadata = EXCLUDED.metadata,
embedding = EXCLUDED.embedding
"#,
        )
        .bind(record.id)
        .bind(&record.session_id)
        .bind(&record.role)
        .bind(&record.content)
        .bind(record.importance)
        .bind(record.timestamp)
        .bind(metadata_json)
        .bind(embedding)
        .execute(&self.pool)
        .await
        .map_err(|e| AgentError::MemoryError(format!("Failed to store memory: {}", e)))?;
        Ok(())
    }

    /// Returns up to `limit` records for `session_id`, newest first.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the query fails.
    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
        // `embedding` is cast to `real[]` so it decodes into `Vec<f32>`; the
        // raw `vector` type is not one sqlx can map to that Rust type.
        let rows = sqlx::query_as::<_, MemoryRow>(
            r#"SELECT id, session_id, role, content, importance, timestamp, metadata, embedding::real[]
FROM memories
WHERE session_id = $1
ORDER BY timestamp DESC
LIMIT $2"#,
        )
        .bind(session_id)
        .bind(sql_limit(limit))
        .fetch_all(&self.pool)
        .await
        .map_err(|e| AgentError::MemoryError(format!("Failed to retrieve memories: {}", e)))?;

        Ok(rows.into_iter().map(record_from_row).collect())
    }

    /// Computes an embedding for the given text.
    ///
    /// With the `fastembed` feature the model is initialised on first use and
    /// reused afterwards. Without it an empty vector is returned.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the model cannot be
    /// initialised, embedding fails, or the model returns no vector.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        #[cfg(feature = "fastembed")]
        {
            let embedder = self
                .embedder
                .get_or_try_init(|| async {
                    TextEmbedding::try_new(InitOptions::default())
                        .map_err(|e| AgentError::MemoryError(e.to_string()))
                })
                .await?;
            let embeddings = embedder
                .embed(vec![_text], None)
                .map_err(|e| AgentError::MemoryError(e.to_string()))?;
            // Take the first vector without indexing so an empty result is an
            // error rather than a panic.
            embeddings.into_iter().next().ok_or_else(|| {
                AgentError::MemoryError("Embedding model returned no vectors".to_string())
            })
        }
        #[cfg(not(feature = "fastembed"))]
        Ok(Vec::new())
    }

    /// Returns up to `limit` records for `session_id` that have an embedding,
    /// ordered by cosine distance (`<=>`) to `query_embedding`, closest first.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the query fails.
    async fn search(
        &self,
        session_id: &str,
        query_embedding: Vec<f32>,
        limit: usize,
    ) -> Result<Vec<MemoryRecord>> {
        // `$2` is bound as `real[]`, so it needs an explicit cast to `vector`
        // for the `<=>` operator to resolve. The selected `embedding` is cast
        // the other way so it decodes into `Vec<f32>`.
        let rows = sqlx::query_as::<_, MemoryRow>(
            r#"SELECT id, session_id, role, content, importance, timestamp, metadata, embedding::real[]
FROM memories
WHERE session_id = $1 AND embedding IS NOT NULL
ORDER BY memories.embedding <=> $2::vector
LIMIT $3"#,
        )
        .bind(session_id)
        .bind(&query_embedding)
        .bind(sql_limit(limit))
        .fetch_all(&self.pool)
        .await
        .map_err(|e| AgentError::MemoryError(format!("Failed to search memories: {}", e)))?;

        Ok(rows.into_iter().map(record_from_row).collect())
    }

    /// No-op: `store` writes straight to the database, so nothing is buffered.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }
}