use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use deadpool_postgres::Pool;
use tokio::runtime::Handle;
use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
use smooth_operator::access_control::{AccessContext, DocAcl};
use smooth_operator::embedding::{Embedder, InputType};
/// Smoothing constant `k` for Reciprocal Rank Fusion: a hit at 1-based rank `r`
/// in one result list contributes `1 / (RRF_K + r)` to its fused score.
const RRF_K: f32 = 60.0;

/// Postgres-backed [`KnowledgeBase`] that answers queries with a hybrid of
/// pgvector similarity search and Postgres full-text search.
#[derive(Clone)]
pub struct PgKnowledgeBase {
    /// Connection pool used for every read and write.
    pool: Pool,
    /// Produces embedding vectors for documents (on ingest) and queries.
    embedder: Arc<dyn Embedder>,
    /// Runtime handle on which the synchronous trait methods run their async work.
    handle: Handle,
    /// Organization written with every ingested row; when `Some`, queries are
    /// restricted to rows of this organization.
    organization_id: Option<String>,
    /// Caller identity used to filter query results by row ACL; `None` disables
    /// ACL filtering.
    access: Option<AccessContext>,
}
impl PgKnowledgeBase {
    /// Creates a knowledge base over `pool` that embeds text with `embedder`.
    ///
    /// `handle` is the runtime used by the synchronous trait methods to run the
    /// async work. `organization_id` is stored on ingested rows and, when `Some`,
    /// restricts queries to that organization. No access context is attached;
    /// use [`Self::with_access`] for ACL-filtered queries.
    pub(crate) fn new(
        pool: Pool,
        embedder: Arc<dyn Embedder>,
        handle: Handle,
        organization_id: Option<String>,
    ) -> Self {
        Self {
            pool,
            embedder,
            handle,
            organization_id,
            access: None,
        }
    }

    /// Returns a copy of this knowledge base whose queries are filtered by `access`.
    ///
    /// `self` is left unchanged.
    #[must_use]
    pub fn with_access(&self, access: AccessContext) -> Self {
        Self {
            access: Some(access),
            ..self.clone()
        }
    }

    /// Formats `v` as a bracketed, comma-separated text literal such as `[0.5,1,-2]`,
    /// which the SQL below casts with `::text::vector`.
    fn vector_literal(v: &[f32]) -> String {
        use std::fmt::Write as _;

        let mut s = String::with_capacity(v.len() * 8 + 2);
        s.push('[');
        for (i, x) in v.iter().enumerate() {
            if i > 0 {
                s.push(',');
            }
            // Same `Display` output as `x.to_string()`, without a temporary String
            // per element. Writing into a String cannot fail.
            let _ = write!(s, "{x}");
        }
        s.push(']');
        s
    }

    /// Embeds `doc.content` and upserts the document into `knowledge_vectors`.
    ///
    /// Both the row `id` and `document_id` are `doc.id`; an existing row with the
    /// same id is overwritten column by column. `acl` is the JSON form of
    /// `DocAcl::from_metadata(&doc.metadata)`, or NULL when that returns `None`.
    ///
    /// NOTE(review): `ON CONFLICT (id)` also overwrites `organization_id`, so if two
    /// organizations can ingest the same document id, the later ingest takes over
    /// the earlier organization's row — confirm ids are globally unique.
    ///
    /// # Errors
    /// Fails if embedding fails or yields no vector, if metadata/ACL cannot be
    /// serialized to JSON, or if acquiring a connection or the INSERT fails.
    async fn ingest_async(&self, doc: Document) -> Result<()> {
        let embeddings = self
            .embedder
            .embed(std::slice::from_ref(&doc.content), InputType::Document)
            .await?;
        let embedding = embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("embedder returned no vector"))?;
        let literal = Self::vector_literal(&embedding);
        let metadata = serde_json::to_value(&doc.metadata)?;
        let acl: Option<serde_json::Value> = DocAcl::from_metadata(&doc.metadata)
            .map(|a| serde_json::to_value(&a))
            .transpose()?;
        let row_id = doc.id.clone();
        let client = self.pool.get().await?;
        client
            .execute(
                "INSERT INTO knowledge_vectors
(id, document_id, organization_id, source, content, embedding, metadata, acl)
VALUES ($1, $2, $3, $4, $5, $6::text::vector, $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
organization_id = EXCLUDED.organization_id,
source = EXCLUDED.source,
content = EXCLUDED.content,
embedding = EXCLUDED.embedding,
metadata = EXCLUDED.metadata,
acl = EXCLUDED.acl",
                &[
                    &row_id,
                    &doc.id,
                    &self.organization_id,
                    &doc.source,
                    &doc.content,
                    &literal,
                    &metadata,
                    &acl,
                ],
            )
            .await?;
        Ok(())
    }

    /// Hybrid search: a dense (embedding distance) query and a sparse (full-text)
    /// query, combined with Reciprocal Rank Fusion.
    ///
    /// Each query fetches up to `max(4 * limit, 20)` candidates, restricted to
    /// `organization_id` when set. When an access context is attached, rows are
    /// visible only if their `acl` is NULL, marks them public, lists the caller's
    /// user id, or shares a group with the caller. The top `limit` fused hits are
    /// returned best first; equal scores are ordered by row id so results are
    /// deterministic.
    ///
    /// # Errors
    /// Fails if embedding fails or yields no vector, if acquiring a connection
    /// fails, or if either SQL query fails.
    async fn query_async(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
        // Nothing can be returned, so skip the embedding call and both queries.
        if limit == 0 {
            return Ok(Vec::new());
        }
        let embeddings = self
            .embedder
            .embed(&[query.to_string()], InputType::Query)
            .await?;
        let embedding = embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("embedder returned no query vector"))?;
        let literal = Self::vector_literal(&embedding);
        // Over-fetch per retriever so fusion has overlap to work with. Saturate
        // instead of overflowing on huge limits, and clamp into i64 for LIMIT.
        let candidate_n: i64 =
            i64::try_from(limit.saturating_mul(4).max(20)).unwrap_or(i64::MAX);
        let client = self.pool.get().await?;
        let acl_user: Option<String> = self.access.as_ref().and_then(|c| c.user_id.clone());
        let acl_groups: Vec<String> = self
            .access
            .as_ref()
            .map(|c| c.groups.clone())
            .unwrap_or_default();
        // $4/$5 exist only when an access context is attached; the params vectors
        // below get the matching extra entries under the same condition.
        let acl_predicate = if self.access.is_some() {
            "(acl IS NULL \
OR (acl->>'public')::boolean IS TRUE \
OR ($4::text IS NOT NULL AND acl->'users' ? $4) \
OR (acl->'groups' ?| $5::text[]))"
        } else {
            "TRUE"
        };
        let query_owned = query.to_string();

        let dense_sql = format!(
            "SELECT id, document_id, source, content
FROM knowledge_vectors
WHERE ($1::text IS NULL OR organization_id = $1)
AND {acl_predicate}
ORDER BY embedding <=> $2::text::vector
LIMIT $3"
        );
        let mut dense_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
            vec![&self.organization_id, &literal, &candidate_n];
        if self.access.is_some() {
            dense_params.push(&acl_user);
            dense_params.push(&acl_groups);
        }
        let dense_rows = client.query(&dense_sql, &dense_params).await?;

        let sparse_sql = format!(
            "SELECT id, document_id, source, content
FROM knowledge_vectors
WHERE ($1::text IS NULL OR organization_id = $1)
AND content_tsv @@ plainto_tsquery('english', $2)
AND {acl_predicate}
ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $2)) DESC
LIMIT $3"
        );
        let mut sparse_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
            vec![&self.organization_id, &query_owned, &candidate_n];
        if self.access.is_some() {
            sparse_params.push(&acl_user);
            sparse_params.push(&acl_groups);
        }
        let sparse_rows = client.query(&sparse_sql, &sparse_params).await?;

        Ok(Self::fuse_rrf(
            &[dense_rows.as_slice(), sparse_rows.as_slice()],
            limit,
        ))
    }

    /// Fuses ranked row lists with Reciprocal Rank Fusion and keeps the top `limit`.
    ///
    /// Every row must have text columns `(id, document_id, source, content)`.
    /// A row at 0-based position `rank` in a list contributes
    /// `1 / (RRF_K + rank + 1)`; contributions for the same `id` are summed.
    /// Output is ordered by descending score, ties broken by ascending `id`.
    fn fuse_rrf(lists: &[&[tokio_postgres::Row]], limit: usize) -> Vec<KnowledgeResult> {
        struct Hit {
            document_id: String,
            source: String,
            content: String,
            score: f32,
        }

        let capacity = lists.iter().map(|rows| rows.len()).sum();
        let mut fused: HashMap<String, Hit> = HashMap::with_capacity(capacity);
        for rows in lists {
            for (rank, row) in rows.iter().enumerate() {
                let id: String = row.get(0);
                // Precision loss only affects ranks above 2^24, which is harmless
                // for a relevance score.
                #[allow(clippy::cast_precision_loss)]
                let contribution = 1.0 / (RRF_K + (rank as f32) + 1.0);
                fused
                    .entry(id)
                    .and_modify(|h| h.score += contribution)
                    .or_insert_with(|| Hit {
                        document_id: row.get(1),
                        source: row.get(2),
                        content: row.get(3),
                        score: contribution,
                    });
            }
        }

        let mut ranked: Vec<(String, Hit)> = fused.into_iter().collect();
        // Equal scores are common under RRF (e.g. rank 1 in only one list), so
        // break ties by id to make ordering and truncation independent of
        // HashMap iteration order.
        ranked.sort_unstable_by(|(a_id, a), (b_id, b)| {
            b.score.total_cmp(&a.score).then_with(|| a_id.cmp(b_id))
        });
        ranked.truncate(limit);
        ranked
            .into_iter()
            .map(|(_, h)| KnowledgeResult {
                document_id: h.document_id,
                chunk: h.content,
                score: h.score,
                source: h.source,
            })
            .collect()
    }
}
impl PgKnowledgeBase {
    /// Runs `fut` on `self.handle` and blocks the calling thread until it finishes.
    ///
    /// This bridges the synchronous [`KnowledgeBase`] trait to the async
    /// implementation. `fut` is spawned on the runtime, and a small forwarding task
    /// on the same runtime awaits it and sends the outcome over a std channel. This
    /// avoids creating an extra OS thread and a whole new runtime on every call.
    ///
    /// NOTE: the calling thread is parked in `recv`. If that thread is the one
    /// driving a current-thread runtime behind `handle`, nothing can poll the
    /// spawned tasks and this call never returns.
    ///
    /// # Errors
    /// Returns the future's own error, or an error if the task panicked or was
    /// cancelled. Also returns an error if the result channel closed without a
    /// value, e.g. because the runtime shut down before the forwarding task ran.
    fn run_blocking<F, T>(&self, fut: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>> + Send + 'static,
        T: Send + 'static,
    {
        let join = self.handle.spawn(fut);
        let (tx, rx) = std::sync::mpsc::channel();
        self.handle.spawn(async move {
            let result = match join.await {
                Ok(output) => output,
                Err(e) => Err(anyhow!("knowledge task panicked or was cancelled: {e}")),
            };
            // The receiver is only gone if the caller went away; nothing to report.
            let _ = tx.send(result);
        });
        rx.recv()
            .map_err(|e| anyhow!("knowledge task channel closed: {e}"))?
    }
}
/// Synchronous [`KnowledgeBase`] facade over the async Postgres implementation.
///
/// `run_blocking` requires a `'static` future, so each call moves a clone of
/// `self` (and any borrowed input, as an owned copy) into the task.
impl KnowledgeBase for PgKnowledgeBase {
    /// Embeds and upserts `doc`, blocking until the write completes.
    fn ingest(&self, doc: Document) -> Result<()> {
        let kb = self.clone();
        let task = async move { kb.ingest_async(doc).await };
        self.run_blocking(task)
    }

    /// Runs a hybrid search for `query`, blocking until at most `limit` results are ready.
    fn query(&self, query: &str, limit: usize) -> Result<Vec<KnowledgeResult>> {
        let kb = self.clone();
        let owned_query = String::from(query);
        let task = async move { kb.query_async(owned_query.as_str(), limit).await };
        self.run_blocking(task)
    }
}