use std::sync::Arc;
use anyhow::{anyhow, Result};
use deadpool_postgres::Pool;
use tokio::runtime::Handle;
use smooth_operator_core::{Memory, MemoryEntry, MemoryType};
use smooth_operator::embedding::{Embedder, InputType};
/// Postgres-backed [`Memory`] implementation that stores every entry together
/// with an embedding of its content.
///
/// Each instance is bound to one `organization_id` and an optional `user_id`;
/// those values are written on store and used as filters on recall and forget.
#[derive(Clone)]
pub struct PgMemory {
/// Connection pool; a client is checked out for every statement.
pool: Pool,
/// Produces the vectors for stored content and for recall queries.
embedder: Arc<dyn Embedder>,
/// Runtime handle on which `run_blocking` spawns the async work.
handle: Handle,
/// Organization written to, and required of, every row this store touches.
organization_id: String,
/// Optional user; compared with `IS NOT DISTINCT FROM`, so `None` matches
/// only rows whose `user_id` is NULL.
user_id: Option<String>,
}
impl PgMemory {
/// Builds a store bound to one organization and, optionally, one user.
///
/// `organization_id` accepts anything convertible into a `String`; all other
/// arguments are stored exactly as given.
pub(crate) fn new(
    pool: Pool,
    embedder: Arc<dyn Embedder>,
    handle: Handle,
    organization_id: impl Into<String>,
    user_id: Option<String>,
) -> Self {
    // Convert up front so the struct literal can use field-init shorthand.
    let organization_id = organization_id.into();
    Self { pool, embedder, handle, organization_id, user_id }
}
/// Renders an embedding as the bracketed text form that the queries in this
/// file cast to `vector`, e.g. `[0.5,-1,2]`.
///
/// Components use `f32`'s `Display` formatting and are separated by commas
/// with no whitespace; an empty slice yields `[]`.
fn vector_literal(v: &[f32]) -> String {
    use std::fmt::Write;

    let mut out = String::with_capacity(v.len() * 8 + 2);
    out.push('[');
    for (i, x) in v.iter().enumerate() {
        if i > 0 {
            out.push(',');
        }
        // Format straight into the buffer instead of allocating a temporary
        // `String` for every component.
        write!(out, "{x}").expect("writing to a String cannot fail");
    }
    out.push(']');
    out
}
/// Converts a [`MemoryType`] into the bare (unquoted) text stored in the
/// `memory_type` column.
///
/// The value is serialized with serde and must come out as a JSON string.
///
/// # Errors
/// Returns an error if serialization fails or produces anything other than a
/// JSON string.
fn memory_type_to_str(mt: MemoryType) -> Result<String> {
    let value = serde_json::to_value(&mt)?;
    // Deserializing into `String` rejects non-string encodings and yields the
    // unescaped text, neither of which stripping quote characters from the
    // serialized JSON can guarantee.
    Ok(serde_json::from_value::<String>(value)?)
}
/// Parses the text stored in the `memory_type` column back into a
/// [`MemoryType`].
///
/// # Errors
/// Returns an error if `s` does not deserialize into a `MemoryType`.
fn memory_type_from_str(s: &str) -> Result<MemoryType> {
    // Hand the text to serde as an already-parsed string value. Splicing it
    // between quotes and re-parsing would let a quote or backslash in `s`
    // change what gets parsed.
    Ok(serde_json::from_value(serde_json::Value::String(
        s.to_owned(),
    ))?)
}
/// Embeds `entry.content` and upserts the entry under this store's
/// organization and user.
///
/// An existing row with the same `id` is overwritten only when it already
/// belongs to the same organization and user. On overwrite the existing
/// row's `created_at` is left untouched.
///
/// # Errors
/// Returns an error if embedding fails or yields no vector, if the metadata
/// or memory type cannot be serialized, if no pooled connection can be
/// obtained, if the statement fails, or if `id` is already used by a row
/// outside this store's organization/user scope.
async fn store_async(&self, entry: MemoryEntry) -> Result<()> {
    let embedding = self
        .embedder
        .embed(std::slice::from_ref(&entry.content), InputType::Document)
        .await?
        .into_iter()
        .next()
        .ok_or_else(|| anyhow!("embedder returned no vector"))?;
    let literal = Self::vector_literal(&embedding);
    let metadata = serde_json::to_value(&entry.metadata)?;
    let memory_type = Self::memory_type_to_str(entry.memory_type)?;
    // The connection is checked out only after the embedding call so it is
    // not held while waiting on the embedder.
    let client = self.pool.get().await?;
    // The WHERE on DO UPDATE keeps a colliding id from rewriting (and
    // re-homing) a row owned by a different organization or user, matching
    // the scope checks in recall and forget.
    let written = client
        .execute(
            "INSERT INTO memories
                 (id, organization_id, user_id, content, memory_type, relevance,
                  metadata, embedding, created_at, last_accessed)
             VALUES ($1, $2, $3, $4, $5, $6, $7, $8::text::vector, $9, $10)
             ON CONFLICT (id) DO UPDATE SET
                 content = EXCLUDED.content,
                 memory_type = EXCLUDED.memory_type,
                 relevance = EXCLUDED.relevance,
                 metadata = EXCLUDED.metadata,
                 embedding = EXCLUDED.embedding,
                 last_accessed = EXCLUDED.last_accessed
             WHERE memories.organization_id = EXCLUDED.organization_id
               AND memories.user_id IS NOT DISTINCT FROM EXCLUDED.user_id",
            &[
                &entry.id,
                &self.organization_id,
                &self.user_id,
                &entry.content,
                &memory_type,
                &entry.relevance,
                &metadata,
                &literal,
                &entry.created_at,
                &entry.last_accessed,
            ],
        )
        .await?;
    // Zero affected rows means the id conflicted with a row the WHERE clause
    // refused to update; report it instead of pretending the store succeeded.
    if written == 0 {
        return Err(anyhow!(
            "memory id {} already exists outside this organization/user scope",
            entry.id
        ));
    }
    Ok(())
}
/// Returns up to `limit` memories of this store's organization and user,
/// ordered by the `<=>` distance between their embedding and the embedding
/// of `query` (closest first).
///
/// Each returned entry's `relevance` is the computed `1 - distance` value
/// for this query, not the relevance that was stored with the row.
///
/// # Errors
/// Returns an error if embedding fails or yields no vector, if no pooled
/// connection can be obtained, if the query fails, or if any row has a
/// column that cannot be read or decoded (NULL, wrong type, unknown memory
/// type, malformed metadata).
async fn recall_async(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
    // Nothing can be returned, so skip the embedding call and the round trip.
    if limit == 0 {
        return Ok(Vec::new());
    }
    let embedding = self
        .embedder
        .embed(&[query.to_string()], InputType::Query)
        .await?
        .into_iter()
        .next()
        .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
    let literal = Self::vector_literal(&embedding);
    // LIMIT is bound as a bigint; saturate rather than fail on huge values.
    let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
    let client = self.pool.get().await?;
    let rows = client
        .query(
            "SELECT id, content, memory_type, metadata, created_at, last_accessed,
                    1 - (embedding <=> $3::text::vector) AS similarity
             FROM memories
             WHERE organization_id = $1
               AND user_id IS NOT DISTINCT FROM $2
             ORDER BY embedding <=> $3::text::vector
             LIMIT $4",
            &[&self.organization_id, &self.user_id, &literal, &limit_i64],
        )
        .await?;
    rows.iter()
        .map(|row| -> Result<MemoryEntry> {
            // `try_get` turns a NULL or mistyped column into an error;
            // `get` would panic inside the spawned task instead.
            let memory_type =
                Self::memory_type_from_str(&row.try_get::<_, String>("memory_type")?)?;
            let metadata =
                serde_json::from_value(row.try_get::<_, serde_json::Value>("metadata")?)?;
            // The score is read as f64 but `relevance` is f32; losing
            // precision on a similarity score is acceptable.
            #[allow(clippy::cast_possible_truncation)]
            let similarity = row.try_get::<_, f64>("similarity")? as f32;
            Ok(MemoryEntry {
                id: row.try_get("id")?,
                content: row.try_get("content")?,
                memory_type,
                relevance: similarity,
                metadata,
                created_at: row.try_get("created_at")?,
                last_accessed: row.try_get("last_accessed")?,
            })
        })
        .collect()
}
/// Deletes the memory with the given `id`, provided it belongs to this
/// store's organization and user.
///
/// Succeeds whether or not a row matched; the affected-row count is
/// discarded.
///
/// # Errors
/// Returns an error if no pooled connection can be obtained or the
/// statement fails.
async fn forget_async(&self, id: &str) -> Result<()> {
    let conn = self.pool.get().await?;
    let statement = "DELETE FROM memories
WHERE id = $1
AND organization_id = $2
AND user_id IS NOT DISTINCT FROM $3";
    let _deleted = conn
        .execute(statement, &[&id, &self.organization_id, &self.user_id])
        .await?;
    Ok(())
}
fn run_blocking<F, T>(&self, fut: F) -> Result<T>
where
F: std::future::Future<Output = Result<T>> + Send + 'static,
T: Send + 'static,
{
let join = self.handle.spawn(fut);
let (tx, rx) = std::sync::mpsc::channel();
std::thread::spawn(move || {
let result = (|| -> Result<T> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let joined = rt.block_on(join);
joined.map_err(|e| anyhow!("memory task panicked or was cancelled: {e}"))?
})();
let _ = tx.send(result);
});
rx.recv()
.map_err(|e| anyhow!("memory task channel closed: {e}"))?
}
}
/// Blocking [`Memory`] facade over the async operations.
///
/// Each method clones the store, moves owned copies of its arguments into an
/// async block and drives that block to completion through `run_blocking`.
impl Memory for PgMemory {
    /// Stores `entry`; see `store_async`.
    fn store(&self, entry: MemoryEntry) -> Result<()> {
        let store = Self::clone(self);
        self.run_blocking(async move { store.store_async(entry).await })
    }

    /// Recalls up to `limit` entries for `query`; see `recall_async`.
    fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
        // The future must be 'static, so it gets its own copy of the query.
        let owned_query = String::from(query);
        let store = Self::clone(self);
        self.run_blocking(async move { store.recall_async(&owned_query, limit).await })
    }

    /// Forgets the entry with the given `id`; see `forget_async`.
    fn forget(&self, id: &str) -> Result<()> {
        // The future must be 'static, so it gets its own copy of the id.
        let owned_id = String::from(id);
        let store = Self::clone(self);
        self.run_blocking(async move { store.forget_async(&owned_id).await })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every memory type must survive the trip to its column text and back,
    /// and the column text must carry no JSON quotes.
    #[test]
    fn memory_type_round_trips_through_text() {
        let all_types = [
            MemoryType::ShortTerm,
            MemoryType::LongTerm,
            MemoryType::Entity,
            MemoryType::User,
            MemoryType::Feedback,
            MemoryType::Project,
            MemoryType::Reference,
        ];
        for mt in all_types {
            let s = PgMemory::memory_type_to_str(mt).expect("to_str");
            let has_quote = s.contains('"');
            assert!(
                !has_quote,
                "stored memory_type must be unquoted: {s:?}"
            );
            let round_tripped = PgMemory::memory_type_from_str(&s).expect("from_str");
            assert_eq!(round_tripped, mt);
        }
    }

    /// The literal is bracketed, comma-separated and free of whitespace;
    /// an empty embedding renders as `[]`.
    #[test]
    fn vector_literal_shape() {
        let rendered = PgMemory::vector_literal(&[0.5, -1.0, 2.0]);
        assert_eq!(rendered, "[0.5,-1,2]");

        let empty = PgMemory::vector_literal(&[]);
        assert_eq!(empty, "[]");
    }
}