use crate::error::Error;
use pgvector::Vector;
use sqlx::PgPool;
/// One hit from a nearest-neighbour search.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Neighbor {
    /// Value of the row's `id` column.
    pub id: i64,
    /// Similarity score computed by the query; see `PgVectorStore::nearest`.
    pub score: f32,
}
impl sqlx::FromRow<'_, sqlx::postgres::PgRow> for Neighbor {
fn from_row(row: &sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
use sqlx::Row;
Ok(Self {
id: row.try_get("id")?,
score: row.try_get::<f32, _>("score")?,
})
}
}
/// Stores and queries embedding vectors kept in one column of a Postgres table
/// via the pgvector extension.
#[derive(Debug, Clone)]
pub struct PgVectorStore {
    /// Table name, spliced verbatim into generated SQL.
    table: String,
    /// Vector column name, spliced verbatim into generated SQL.
    column: String,
}
impl PgVectorStore {
    /// Creates a store that reads and writes vectors in `column` of `table`.
    ///
    /// Both names are interpolated verbatim into every generated SQL statement
    /// (identifiers cannot be bound as query parameters), so they must come
    /// from trusted configuration and never from user input.
    pub fn new(table: &str, column: &str) -> Self {
        Self {
            table: table.to_string(),
            column: column.to_string(),
        }
    }

    /// Upserts `embedding` for the row identified by `id`.
    ///
    /// Inserts a new row, or overwrites the stored vector when a row with the
    /// same `id` already exists. `ON CONFLICT (id)` requires a unique index or
    /// constraint on `id`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Sqlx` carrying the driver's message if the statement fails.
    pub async fn store(&self, pool: &PgPool, id: i64, embedding: &[f32]) -> Result<(), Error> {
        let vec = Vector::from(embedding.to_vec());
        let sql = format!(
            "INSERT INTO {} (id, {}) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET {} = $2",
            self.table, self.column, self.column
        );
        sqlx::query(&sql)
            .bind(id)
            .bind(vec)
            .execute(pool)
            .await
            .map_err(|e| Error::Sqlx(e.to_string()))?;
        Ok(())
    }

    /// Returns up to `k` rows whose stored vector is closest to `query` under
    /// pgvector's `<=>` operator (cosine distance), nearest first.
    ///
    /// Each hit's `score` is `1 - distance`, so larger means more similar.
    /// Rows whose vector column is NULL are skipped.
    ///
    /// # Errors
    ///
    /// Returns `Error::Sqlx` carrying the driver's message if the query or row
    /// decoding fails.
    pub async fn nearest(
        &self,
        pool: &PgPool,
        query: &[f32],
        k: u32,
    ) -> Result<Vec<Neighbor>, Error> {
        let vec = Vector::from(query.to_vec());
        // NULL vectors have a NULL distance, which sorts last under ascending
        // ORDER BY. When fewer than `k` non-NULL rows exist, those rows would be
        // returned and decoding their NULL `score` into `f32` would fail the
        // whole query. Filter them out explicitly.
        // The ORDER BY keeps the raw `<=>` expression (rather than `score`) so
        // that a pgvector index on the column can still serve the query.
        let sql = format!(
            "SELECT id, (1.0 - ({col} <=> $1))::float4 AS score FROM {table} \
             WHERE {col} IS NOT NULL ORDER BY {col} <=> $1 LIMIT $2",
            col = self.column,
            table = self.table,
        );
        sqlx::query_as::<_, Neighbor>(&sql)
            .bind(vec)
            .bind(i64::from(k))
            .fetch_all(pool)
            .await
            .map_err(|e| Error::Sqlx(e.to_string()))
    }
}