use async_trait::async_trait;
use sqlx::postgres::{PgPool, PgPoolOptions};
use sqlx::Row;
use super::error::{StorageError, StorageResult};
use super::vector::{VectorBackend, VectorSearchResult};
use crate::constants::EMBEDDING_DIMENSIONS_COUNT;
/// Name of the PostgreSQL table that holds the embeddings; it is interpolated
/// into the table, index and query SQL built in this module.
const TABLE_NAME: &str = "embeddings";
/// Vector storage backend that keeps embeddings in a PostgreSQL table using
/// the `vector` extension type.
pub struct PostgresVectorBackend {
// Connection pool that every query issued by this backend runs against.
pool: PgPool,
}
impl std::fmt::Debug for PostgresVectorBackend {
    /// Writes the type name followed by `{ .. }`.
    ///
    /// No field is printed, so the connection pool never ends up in debug
    /// output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("PostgresVectorBackend");
        repr.finish_non_exhaustive()
    }
}
impl PostgresVectorBackend {
    /// Opens a connection pool (at most 10 connections) against
    /// `database_url` and makes sure the extension, table and index used by
    /// this backend exist.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::ConnectionFailed` when the pool cannot be
    /// established, or the error produced by the schema initialisation.
    pub async fn connect(database_url: &str) -> StorageResult<Self> {
        let connecting = PgPoolOptions::new()
            .max_connections(10)
            .connect(database_url);
        let pool = connecting
            .await
            .map_err(|err| StorageError::ConnectionFailed(err.to_string()))?;

        let backend = Self { pool };
        backend.init_table().await?;
        Ok(backend)
    }

    /// Creates the `vector` extension, the embeddings table and its ivfflat
    /// cosine index. Every statement uses `IF NOT EXISTS`, so running it
    /// against an already-initialised database is harmless.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::WriteFailed` naming the step that failed.
    async fn init_table(&self) -> StorageResult<()> {
        let pool = &self.pool;

        // Build both DDL statements up front; they only depend on constants.
        let table_ddl = format!(
            r#"
CREATE TABLE IF NOT EXISTS {} (
id TEXT PRIMARY KEY,
embedding vector({}) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"#,
            TABLE_NAME, EMBEDDING_DIMENSIONS_COUNT
        );
        let index_ddl = format!(
            r#"
CREATE INDEX IF NOT EXISTS idx_{}_vector
ON {} USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100)
"#,
            TABLE_NAME, TABLE_NAME
        );

        // The extension must exist before the `vector(n)` column type can be used.
        let extension_outcome = sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
            .execute(pool)
            .await;
        extension_outcome.map_err(|err| {
            StorageError::WriteFailed(format!("Failed to create vector extension: {}", err))
        })?;

        let table_outcome = sqlx::query(&table_ddl).execute(pool).await;
        table_outcome
            .map_err(|err| StorageError::WriteFailed(format!("Failed to create table: {}", err)))?;

        let index_outcome = sqlx::query(&index_ddl).execute(pool).await;
        index_outcome
            .map_err(|err| StorageError::WriteFailed(format!("Failed to create index: {}", err)))?;

        Ok(())
    }
}
#[async_trait]
impl VectorBackend for PostgresVectorBackend {
    /// Inserts the embedding for `id`, replacing the stored embedding when
    /// the id already exists (`ON CONFLICT (id) DO UPDATE`).
    ///
    /// # Panics
    ///
    /// Panics if `id` is empty or if `embedding` does not have exactly
    /// `EMBEDDING_DIMENSIONS_COUNT` elements.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::WriteFailed` when the statement fails.
    async fn store(&self, id: &str, embedding: &[f32]) -> StorageResult<()> {
        assert!(!id.is_empty(), "id must not be empty");
        assert_eq!(
            embedding.len(),
            EMBEDDING_DIMENSIONS_COUNT,
            "embedding must have {} dimensions, got {}",
            EMBEDDING_DIMENSIONS_COUNT,
            embedding.len()
        );
        let embedding_str = embedding_to_pgvector_literal(embedding);
        let sql = format!(
            r#"
INSERT INTO {} (id, embedding)
VALUES ($1, $2::vector)
ON CONFLICT (id)
DO UPDATE SET embedding = EXCLUDED.embedding
"#,
            TABLE_NAME
        );
        sqlx::query(&sql)
            .bind(id)
            .bind(&embedding_str)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::WriteFailed(e.to_string()))?;
        Ok(())
    }

    /// Returns up to `limit` stored ids ordered by increasing `<=>` distance
    /// to `embedding`; each result's score is `1 - distance`.
    ///
    /// NOTE(review): the table carries an ivfflat index, so this is
    /// presumably an approximate search and may return fewer than `limit`
    /// rows — confirm against the pgvector documentation.
    ///
    /// # Panics
    ///
    /// Panics if `embedding` does not have exactly
    /// `EMBEDDING_DIMENSIONS_COUNT` elements or if `limit` is zero.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::ReadFailed` when the query fails or a result
    /// column cannot be decoded.
    async fn search(
        &self,
        embedding: &[f32],
        limit: usize,
    ) -> StorageResult<Vec<VectorSearchResult>> {
        assert_eq!(
            embedding.len(),
            EMBEDDING_DIMENSIONS_COUNT,
            "query embedding must have {} dimensions, got {}",
            EMBEDDING_DIMENSIONS_COUNT,
            embedding.len()
        );
        assert!(limit > 0, "limit must be positive");
        let embedding_str = embedding_to_pgvector_literal(embedding);
        // Saturate instead of wrapping: a plain `as i64` would turn a huge
        // `usize` into a negative LIMIT, which PostgreSQL rejects.
        let limit_param = limit.min(i64::MAX as usize) as i64;
        let sql = format!(
            r#"
SELECT id, 1 - (embedding <=> $1::vector) as score
FROM {}
WHERE embedding IS NOT NULL
ORDER BY embedding <=> $1::vector
LIMIT $2
"#,
            TABLE_NAME
        );
        let rows = sqlx::query(&sql)
            .bind(&embedding_str)
            .bind(limit_param)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| StorageError::ReadFailed(e.to_string()))?;

        // `try_get` reports decode problems as errors; `get` would panic.
        let mut results = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row
                .try_get("id")
                .map_err(|e| StorageError::ReadFailed(e.to_string()))?;
            let score: f64 = row
                .try_get("score")
                .map_err(|e| StorageError::ReadFailed(e.to_string()))?;
            results.push(VectorSearchResult {
                id,
                score: score as f32,
            });
        }
        Ok(results)
    }

    /// Deletes the row for `id`. The number of affected rows is not checked,
    /// so deleting an id that is not stored still returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is empty.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::WriteFailed` when the statement fails.
    async fn delete(&self, id: &str) -> StorageResult<()> {
        assert!(!id.is_empty(), "id must not be empty");
        let sql = format!("DELETE FROM {} WHERE id = $1", TABLE_NAME);
        sqlx::query(&sql)
            .bind(id)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::WriteFailed(e.to_string()))?;
        Ok(())
    }

    /// Reports whether a row with the given `id` is stored.
    ///
    /// # Panics
    ///
    /// Panics if `id` is empty.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::ReadFailed` when the query fails.
    async fn exists(&self, id: &str) -> StorageResult<bool> {
        assert!(!id.is_empty(), "id must not be empty");
        let sql = format!("SELECT EXISTS(SELECT 1 FROM {} WHERE id = $1)", TABLE_NAME);
        let exists: bool = sqlx::query_scalar(&sql)
            .bind(id)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::ReadFailed(e.to_string()))?;
        Ok(exists)
    }

    /// Fetches the embedding stored for `id`, or `None` when no such row
    /// exists.
    ///
    /// # Panics
    ///
    /// Panics if `id` is empty.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::ReadFailed` when the query fails and
    /// `StorageError::DeserializationError` when the stored value cannot be
    /// decoded or parsed.
    async fn get(&self, id: &str) -> StorageResult<Option<Vec<f32>>> {
        assert!(!id.is_empty(), "id must not be empty");
        // The `vector` column type cannot be decoded as `String` directly, so
        // let PostgreSQL render it as text (e.g. `[1,2,3]`) before decoding.
        let sql = format!(
            "SELECT embedding::text AS embedding FROM {} WHERE id = $1",
            TABLE_NAME
        );
        let row = sqlx::query(&sql)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| StorageError::ReadFailed(e.to_string()))?;
        let row = match row {
            Some(row) => row,
            None => return Ok(None),
        };
        let embedding_str: String = row
            .try_get("embedding")
            .map_err(|e| StorageError::DeserializationError(e.to_string()))?;
        let embedding = parse_pgvector_string(&embedding_str)
            .map_err(StorageError::DeserializationError)?;
        Ok(Some(embedding))
    }

    /// Returns the number of rows in the embeddings table.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::ReadFailed` when the query fails.
    async fn count(&self) -> StorageResult<usize> {
        let sql = format!("SELECT COUNT(*) FROM {}", TABLE_NAME);
        let count: i64 = sqlx::query_scalar(&sql)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::ReadFailed(e.to_string()))?;
        // COUNT(*) is never negative, so the cast cannot change the sign.
        Ok(count as usize)
    }
}

/// Renders an embedding as a bracketed, comma-separated text literal
/// (`[v1,v2,...]`) suitable for binding to a `$n::vector` parameter.
fn embedding_to_pgvector_literal(embedding: &[f32]) -> String {
    let components: Vec<String> = embedding.iter().map(|value| value.to_string()).collect();
    format!("[{}]", components.join(","))
}
/// Parses a bracketed, comma-separated list of floats such as `[1.0, 2.5]`
/// into a `Vec<f32>`.
///
/// Whitespace around the whole literal and around each component is ignored.
///
/// # Errors
///
/// Returns a message when the input is not wrapped in `[` and `]`, or when
/// any component (including an empty one) is not a valid `f32`.
fn parse_pgvector_string(s: &str) -> Result<Vec<f32>, String> {
    let trimmed = s.trim();

    // Both brackets must be present; stripping them yields the component list.
    let body = match trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
    {
        Some(body) => body,
        None => return Err(format!("Invalid pgvector format: {}", trimmed)),
    };

    let mut parsed = Vec::new();
    for token in body.split(',') {
        match token.trim().parse::<f32>() {
            Ok(value) => parsed.push(value),
            // Stop at the first bad component, reporting its parse error.
            Err(err) => return Err(format!("Failed to parse pgvector values: {}", err)),
        }
    }
    Ok(parsed)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed literals parse to the expected values regardless of the
    /// spacing around each component.
    #[test]
    fn test_parse_pgvector_string() {
        let cases: Vec<(&str, Vec<f32>)> = vec![
            ("[1.0, 2.0, 3.0]", vec![1.0, 2.0, 3.0]),
            ("[ 1.5 , 2.5 , 3.5 ]", vec![1.5, 2.5, 3.5]),
        ];
        for (input, expected) in cases.iter() {
            let parsed = parse_pgvector_string(input).unwrap();
            assert_eq!(parsed, *expected);
        }
    }

    /// Input without brackets, or with a non-numeric component, is rejected.
    #[test]
    fn test_parse_pgvector_string_invalid() {
        let rejected = ["1.0, 2.0, 3.0", "[1.0, 2.0, abc]"];
        for input in rejected.iter() {
            assert!(parse_pgvector_string(input).is_err());
        }
    }
}