use deadpool_postgres::{Config, Pool, Runtime};
use pgvector::Vector;
use tokio_postgres::NoTls;
use serde::{Deserialize, Serialize};
use super::{RagConfig, RagError, RagResult};
/// A composable metadata filter that compiles to a parameterized SQL
/// predicate over the `metadata` JSONB column (see `to_sql`).
///
/// Serialized with an external `op` tag in snake_case, e.g.
/// `{"op":"eq","field":"source","value":"docs"}` — renaming variants or
/// fields changes this wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum MetadataFilter {
// Scalar comparisons on metadata->>'field'; Gt/Gte/Lt/Lte cast both sides
// to numeric in the generated SQL.
Eq { field: String, value: serde_json::Value },
Ne { field: String, value: serde_json::Value },
Gt { field: String, value: serde_json::Value },
Gte { field: String, value: serde_json::Value },
Lt { field: String, value: serde_json::Value },
Lte { field: String, value: serde_json::Value },
// Key presence / absence checks (JSONB `?` operator).
Exists { field: String },
NotExists { field: String },
// Substring / prefix / suffix matching; Contains compiles to ILIKE
// (case-insensitive), StartsWith/EndsWith to LIKE.
Contains { field: String, value: String },
StartsWith { field: String, value: String },
EndsWith { field: String, value: String },
// True when the JSON array at `field` contains `value` as an element.
InArray { field: String, value: String },
// Set membership over the text form of metadata->>'field'.
In { field: String, values: Vec<serde_json::Value> },
NotIn { field: String, values: Vec<serde_json::Value> },
// Raw jsonpath predicate (metadata @? '...'); path is only quote-escaped.
JsonPath { path: String },
// Boolean combinators.
And { filters: Vec<MetadataFilter> },
Or { filters: Vec<MetadataFilter> },
Not { filter: Box<MetadataFilter> },
}
impl MetadataFilter {
// Convenience constructors so callers avoid struct-variant syntax.
pub fn eq(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Eq { field, value }
}
pub fn ne(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Ne { field, value }
}
pub fn gt(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Gt { field, value }
}
pub fn gte(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Gte { field, value }
}
pub fn lt(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Lt { field, value }
}
pub fn lte(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Lte { field, value }
}
pub fn exists(field: impl Into<String>) -> Self {
    let field = field.into();
    Self::Exists { field }
}
pub fn not_exists(field: impl Into<String>) -> Self {
    let field = field.into();
    Self::NotExists { field }
}
pub fn contains(field: impl Into<String>, value: impl Into<String>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::Contains { field, value }
}
pub fn starts_with(field: impl Into<String>, value: impl Into<String>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::StartsWith { field, value }
}
pub fn ends_with(field: impl Into<String>, value: impl Into<String>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::EndsWith { field, value }
}
pub fn in_array(field: impl Into<String>, value: impl Into<String>) -> Self {
    let (field, value) = (field.into(), value.into());
    Self::InArray { field, value }
}
pub fn in_values(field: impl Into<String>, values: Vec<serde_json::Value>) -> Self {
    let field = field.into();
    Self::In { field, values }
}
pub fn not_in(field: impl Into<String>, values: Vec<serde_json::Value>) -> Self {
    let field = field.into();
    Self::NotIn { field, values }
}
pub fn json_path(path: impl Into<String>) -> Self {
    let path = path.into();
    Self::JsonPath { path }
}
pub fn and(filters: Vec<MetadataFilter>) -> Self {
    Self::And { filters }
}
pub fn or(filters: Vec<MetadataFilter>) -> Self {
    Self::Or { filters }
}
pub fn not(filter: MetadataFilter) -> Self {
    Self::Not { filter: Box::new(filter) }
}
/// Compile this filter tree into a SQL predicate plus its bind values.
///
/// `param_offset` is how many placeholders the caller has already used;
/// the generated SQL starts at `$(param_offset + 1)`. All bind values are
/// returned as text.
pub fn to_sql(&self, param_offset: usize) -> RagResult<(String, Vec<String>)> {
    let mut collected = Vec::new();
    let clause = self.to_sql_inner(param_offset, &mut collected)?;
    Ok((clause, collected))
}
/// Recursive worker for [`Self::to_sql`]: appends bind values to `params`
/// and returns the SQL fragment. Placeholder numbers are computed from
/// `param_offset + params.len()`, so correctness depends on every arm
/// pushing its params in the same order its placeholders appear.
fn to_sql_inner(&self, param_offset: usize, params: &mut Vec<String>) -> RagResult<String> {
    // Escape LIKE/ILIKE metacharacters so Contains/StartsWith/EndsWith match
    // the caller's value literally (previously a value containing `%` or `_`
    // acted as a wildcard pattern). Backslash is Postgres' default LIKE
    // escape character, so it must be escaped first.
    fn escape_like(s: &str) -> String {
        s.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_")
    }
    match self {
        Self::Eq { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(json_value_to_string(value));
            Ok(format!("metadata->>'{}' = ${}", field, param_idx))
        }
        Self::Ne { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(json_value_to_string(value));
            Ok(format!("metadata->>'{}' != ${}", field, param_idx))
        }
        // Ordering comparisons cast both sides to numeric.
        Self::Gt { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(json_value_to_string(value));
            Ok(format!("(metadata->>'{}')::numeric > ${}::numeric", field, param_idx))
        }
        Self::Gte { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(json_value_to_string(value));
            Ok(format!("(metadata->>'{}')::numeric >= ${}::numeric", field, param_idx))
        }
        Self::Lt { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(json_value_to_string(value));
            Ok(format!("(metadata->>'{}')::numeric < ${}::numeric", field, param_idx))
        }
        Self::Lte { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(json_value_to_string(value));
            Ok(format!("(metadata->>'{}')::numeric <= ${}::numeric", field, param_idx))
        }
        // Field names are interpolated directly; validate_field_name rejects
        // quotes and metacharacters, which keeps this injection-safe.
        Self::Exists { field } => {
            let field = validate_field_name(field)?;
            Ok(format!("metadata ? '{}'", field))
        }
        Self::NotExists { field } => {
            let field = validate_field_name(field)?;
            Ok(format!("NOT (metadata ? '{}')", field))
        }
        Self::Contains { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(format!("%{}%", escape_like(value)));
            Ok(format!("metadata->>'{}' ILIKE ${}", field, param_idx))
        }
        Self::StartsWith { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(format!("{}%", escape_like(value)));
            Ok(format!("metadata->>'{}' LIKE ${}", field, param_idx))
        }
        Self::EndsWith { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(format!("%{}", escape_like(value)));
            Ok(format!("metadata->>'{}' LIKE ${}", field, param_idx))
        }
        // JSONB `?` on metadata->'field': element membership in a JSON array.
        Self::InArray { field, value } => {
            let field = validate_field_name(field)?;
            let param_idx = param_offset + params.len() + 1;
            params.push(value.clone());
            Ok(format!("metadata->'{}' ? ${}", field, param_idx))
        }
        Self::In { field, values } => {
            let field = validate_field_name(field)?;
            // An empty IN list can never match.
            if values.is_empty() {
                return Ok("FALSE".to_string());
            }
            // Placeholders are numbered before the values are pushed, so the
            // `+ i` accounts for the pending pushes.
            let placeholders: Vec<String> = values.iter().enumerate().map(|(i, _v)| {
                let param_idx = param_offset + params.len() + 1 + i;
                format!("${}", param_idx)
            }).collect();
            for v in values {
                params.push(json_value_to_string(v));
            }
            Ok(format!("metadata->>'{}' IN ({})", field, placeholders.join(", ")))
        }
        Self::NotIn { field, values } => {
            let field = validate_field_name(field)?;
            // An empty NOT IN list excludes nothing.
            if values.is_empty() {
                return Ok("TRUE".to_string());
            }
            let placeholders: Vec<String> = values.iter().enumerate().map(|(i, _)| {
                let param_idx = param_offset + params.len() + 1 + i;
                format!("${}", param_idx)
            }).collect();
            for v in values {
                params.push(json_value_to_string(v));
            }
            Ok(format!("metadata->>'{}' NOT IN ({})", field, placeholders.join(", ")))
        }
        // Raw jsonpath predicate; only single quotes are escaped, so the
        // path text itself is trusted to be a valid jsonpath expression.
        Self::JsonPath { path } => {
            Ok(format!("metadata @? '{}'", path.replace('\'', "''")))
        }
        // Empty AND is the identity (TRUE); empty OR is the absorbing FALSE.
        Self::And { filters } => {
            if filters.is_empty() {
                return Ok("TRUE".to_string());
            }
            let mut parts = Vec::new();
            for f in filters {
                parts.push(f.to_sql_inner(param_offset, params)?);
            }
            Ok(format!("({})", parts.join(" AND ")))
        }
        Self::Or { filters } => {
            if filters.is_empty() {
                return Ok("FALSE".to_string());
            }
            let mut parts = Vec::new();
            for f in filters {
                parts.push(f.to_sql_inner(param_offset, params)?);
            }
            Ok(format!("({})", parts.join(" OR ")))
        }
        Self::Not { filter } => {
            let inner = filter.to_sql_inner(param_offset, params)?;
            Ok(format!("NOT ({})", inner))
        }
    }
}
/// Parse one filter from a compact text syntax:
/// `field?` (exists), `!field?` (not exists), or `field OP value` with OP
/// one of `>=`, `<=`, `!=`, `=`, `>`, `<`, `~` (contains), `^` (starts
/// with), `$` (ends with). Values are coerced to i64/f64/bool/null when
/// they parse as such, otherwise kept as strings.
pub fn parse(s: &str) -> Result<Self, String> {
    // Coerce the textual value into the closest JSON scalar.
    fn coerce(value: &str) -> serde_json::Value {
        if let Ok(n) = value.parse::<i64>() {
            return serde_json::Value::Number(n.into());
        }
        if let Ok(n) = value.parse::<f64>() {
            // NaN/inf have no JSON representation; fall back to a string.
            return serde_json::Number::from_f64(n)
                .map(serde_json::Value::Number)
                .unwrap_or_else(|| serde_json::Value::String(value.to_string()));
        }
        match value {
            "true" => serde_json::Value::Bool(true),
            "false" => serde_json::Value::Bool(false),
            "null" => serde_json::Value::Null,
            _ => serde_json::Value::String(value.to_string()),
        }
    }
    let s = s.trim();
    // Existence checks end in '?'; a leading '!' negates.
    if let Some(rest) = s.strip_suffix('?') {
        return Ok(match rest.strip_prefix('!') {
            Some(field) => Self::not_exists(field),
            None => Self::exists(rest),
        });
    }
    // Two-character operators come first so "a>=5" is not split at ">".
    for op in [">=", "<=", "!=", "=", ">", "<", "~", "^", "$"] {
        let Some(pos) = s.find(op) else { continue };
        let field = s[..pos].trim();
        let value = s[pos + op.len()..].trim();
        return Ok(match op {
            "=" => Self::eq(field, coerce(value)),
            "!=" => Self::ne(field, coerce(value)),
            ">" => Self::gt(field, coerce(value)),
            ">=" => Self::gte(field, coerce(value)),
            "<" => Self::lt(field, coerce(value)),
            "<=" => Self::lte(field, coerce(value)),
            "~" => Self::contains(field, value),
            "^" => Self::starts_with(field, value),
            "$" => Self::ends_with(field, value),
            _ => unreachable!(),
        });
    }
    Err(format!("Invalid filter syntax: {}", s))
}
/// Parse several `;`- or newline-separated filters; more than one filter
/// is combined with AND. Errors if no non-empty filter text is present.
pub fn parse_many(s: &str) -> Result<Self, String> {
    let mut filters = Vec::new();
    for piece in s.split([';', '\n']) {
        let piece = piece.trim();
        if !piece.is_empty() {
            filters.push(Self::parse(piece)?);
        }
    }
    match filters.len() {
        0 => Err("No filters provided".to_string()),
        1 => Ok(filters.pop().expect("length checked above")),
        _ => Ok(Self::and(filters)),
    }
}
}
/// Ensure `field` is safe to interpolate inside single quotes in SQL.
///
/// Accepts non-empty names of at most 128 bytes made of alphanumerics
/// (Unicode), underscores, and dots; quotes, spaces, and SQL
/// metacharacters are rejected, which is what makes the
/// `metadata->>'{field}'` interpolation in this module injection-safe.
fn validate_field_name(field: &str) -> Result<&str, RagError> {
    if field.is_empty() {
        return Err(RagError::QueryFailed("Empty field name".into()));
    }
    if field.len() > 128 {
        return Err(RagError::QueryFailed("Field name too long".into()));
    }
    let allowed = |c: char| c.is_alphanumeric() || matches!(c, '_' | '.');
    if field.chars().all(allowed) {
        Ok(field)
    } else {
        Err(RagError::QueryFailed(
            format!("Invalid field name '{}': only alphanumeric, underscore, and dot allowed", field)
        ))
    }
}
fn json_value_to_string(value: &serde_json::Value) -> String {
match value {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Null => String::new(),
_ => value.to_string(),
}
}
/// A stored document row; `score` is populated only by the search paths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
pub id: i64,
pub content: String,
pub metadata: Option<serde_json::Value>,
// Similarity / rank score; omitted from serialized JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub score: Option<f32>,
}
/// Input payload for inserts/upserts; `embedding` must match the configured
/// embedding dimension (checked by the insert methods).
#[derive(Debug, Clone)]
pub struct NewDocument {
pub content: String,
pub embedding: Vec<f32>,
pub metadata: Option<serde_json::Value>,
}
/// Document store backed by Postgres + pgvector, holding a deadpool
/// connection pool and the `RagConfig` that drives table/index/search SQL.
pub struct RagStore {
pool: Pool,
config: RagConfig,
}
impl RagStore {
/// Build a connection pool from the config's connection URL and verify the
/// server has the pgvector extension installed.
///
/// # Errors
/// `ConfigError` for an unparsable URL, `ConnectionFailed` when the pool
/// cannot be created/used or pgvector is missing.
pub async fn connect(config: RagConfig) -> RagResult<Self> {
    config.validate()?;
    let url = url::Url::parse(config.connection_string())
        .map_err(|e| RagError::ConfigError(format!("Invalid connection string: {}", e)))?;
    let mut pg = Config::new();
    pg.host = url.host_str().map(String::from);
    pg.port = url.port();
    // An empty username in the URL means "not provided".
    pg.user = Some(url.username())
        .filter(|u| !u.is_empty())
        .map(str::to_string);
    pg.password = url.password().map(String::from);
    pg.dbname = Some(url.path().trim_start_matches('/').to_string());
    let pool = pg
        .create_pool(Some(Runtime::Tokio1), NoTls)
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    // Probe one connection up front so misconfiguration fails fast.
    let client = pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    client
        .query_one("SELECT extversion FROM pg_extension WHERE extname = 'vector'", &[])
        .await
        .map_err(|_| RagError::ConnectionFailed(
            "pgvector extension not installed. Run: CREATE EXTENSION vector;".into()
        ))?;
    Ok(Self { pool, config })
}
/// Load `RagConfig` from an optional path and connect with it.
pub async fn connect_with_config(config_path: Option<&str>) -> RagResult<Self> {
    Self::connect(RagConfig::load(config_path)?).await
}
/// Create the documents table (and its indexes) if missing. In hybrid
/// search mode a generated tsvector column is added for full-text search.
pub async fn create_table(&self) -> RagResult<()> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    // The generated column bakes the configured text-search language into
    // the DDL; it only exists for hybrid search.
    let hybrid = self.config.search_type() == super::SearchType::Hybrid;
    let tsv_column = if hybrid {
        format!(
            ", content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('{}', content)) STORED",
            self.config.text_search_language()
        )
    } else {
        String::new()
    };
    let ddl = format!(
        r#"
CREATE TABLE IF NOT EXISTS {} (
id BIGSERIAL PRIMARY KEY,
content TEXT NOT NULL,
embedding vector({}) NOT NULL,
metadata JSONB,
created_at TIMESTAMPTZ DEFAULT NOW(){}
)
"#,
        self.config.table_name(),
        self.config.embedding_dim(),
        tsv_column,
    );
    client.execute(&ddl, &[]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    self.create_index_inner(&client).await?;
    Ok(())
}
/// Create the configured vector (and, for hybrid mode, GIN) indexes.
pub async fn create_index(&self) -> RagResult<()> {
let client = self.pool.get().await
.map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
self.create_index_inner(&client).await
}
/// Index creation on an already-checked-out connection, shared by
/// `create_table` and `create_index`.
async fn create_index_inner(&self, client: &deadpool_postgres::Object) -> RagResult<()> {
let index_type = self.config.index_type();
let ops = self.config.distance_metric().index_ops();
// index_sql yields an empty method for index types that need no index.
let (method, ops_class, with_clause) = index_type.index_sql(ops);
if !method.is_empty() {
let create_vec_idx = format!(
"CREATE INDEX IF NOT EXISTS {table}_embedding_idx ON {table} USING {method} (embedding {ops_class}) {with_clause}",
table = self.config.table_name(),
method = method,
ops_class = ops_class,
with_clause = with_clause,
);
// Best-effort: index-creation failures are deliberately swallowed.
// NOTE(review): consider surfacing/logging these errors — a silently
// missing index degrades search to a sequential scan.
let _ = client.execute(&create_vec_idx, &[]).await;
}
if self.config.search_type() == super::SearchType::Hybrid {
let create_gin_idx = format!(
"CREATE INDEX IF NOT EXISTS {}_content_tsv_idx ON {} USING gin (content_tsv)",
self.config.table_name(),
self.config.table_name(),
);
// Best-effort, same rationale as above.
let _ = client.execute(&create_gin_idx, &[]).await;
}
Ok(())
}
/// Set pgvector's `hnsw.ef_search` for this session.
///
/// `SET` does not accept bind parameters; interpolating is safe here
/// because `ef_search` is a plain `u16`.
pub async fn set_hnsw_ef_search(&self, ef_search: u16) -> RagResult<()> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let statement = format!("SET hnsw.ef_search = {}", ef_search);
    client.execute(&statement, &[]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(())
}
/// Insert one document and return its generated id.
///
/// # Errors
/// `DimensionMismatch` when the embedding length differs from the
/// configured dimension; connection/query errors otherwise.
pub async fn insert(&self, doc: NewDocument) -> RagResult<i64> {
    let expected = self.config.embedding_dim();
    let actual = doc.embedding.len();
    if actual != expected {
        return Err(RagError::DimensionMismatch { expected, actual });
    }
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let embedding = Vector::from(doc.embedding);
    let sql = format!(
        "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3) RETURNING id",
        self.config.table_name()
    );
    let row = client.query_one(&sql, &[&doc.content, &embedding, &doc.metadata]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(row.get(0))
}
/// Insert a document, or insert-or-replace at an explicit id.
///
/// With `id = Some(..)` an `ON CONFLICT (id) DO UPDATE` overwrites the
/// existing row's content/embedding/metadata; with `None` this behaves
/// like `insert`. Returns the row id either way.
///
/// NOTE(review): inserting with an explicit id does not advance the
/// BIGSERIAL sequence, so a later id-less insert could collide with a
/// manually chosen id — confirm whether callers ever mix the two modes.
pub async fn upsert(&self, id: Option<i64>, doc: NewDocument) -> RagResult<i64> {
// Reject embeddings of the wrong length before touching the database.
if doc.embedding.len() != self.config.embedding_dim() {
return Err(RagError::DimensionMismatch {
expected: self.config.embedding_dim(),
actual: doc.embedding.len(),
});
}
let client = self.pool.get().await
.map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
let embedding = Vector::from(doc.embedding);
if let Some(id) = id {
let query = format!(
r#"INSERT INTO {} (id, content, embedding, metadata)
VALUES ($1, $2, $3, $4)
ON CONFLICT (id) DO UPDATE
SET content = EXCLUDED.content,
embedding = EXCLUDED.embedding,
metadata = EXCLUDED.metadata
RETURNING id"#,
self.config.table_name()
);
let row = client.query_one(&query, &[&id, &doc.content, &embedding, &doc.metadata]).await
.map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
Ok(row.get(0))
} else {
// No id supplied: plain insert with a generated id.
let query = format!(
"INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3) RETURNING id",
self.config.table_name()
);
let row = client.query_one(&query, &[&doc.content, &embedding, &doc.metadata]).await
.map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
Ok(row.get(0))
}
}
/// Insert documents in chunks of 100, one transaction per chunk, and
/// return the generated ids in input order.
///
/// All embeddings are validated up front so a bad document fails the call
/// before anything is written. Atomicity is per chunk: if a later chunk
/// fails, chunks already committed stay in the database.
///
/// # Errors
/// `DimensionMismatch` for any wrongly-sized embedding; connection/query
/// errors otherwise.
pub async fn insert_batch(&self, docs: Vec<NewDocument>) -> RagResult<Vec<i64>> {
    // Validate everything before touching the database.
    for doc in &docs {
        if doc.embedding.len() != self.config.embedding_dim() {
            return Err(RagError::DimensionMismatch {
                expected: self.config.embedding_dim(),
                actual: doc.embedding.len(),
            });
        }
    }
    let mut ids = Vec::with_capacity(docs.len());
    let insert_sql = format!(
        "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3) RETURNING id",
        self.config.table_name()
    );
    const BATCH_SIZE: usize = 100;
    for chunk in docs.chunks(BATCH_SIZE) {
        let mut client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let tx = client.transaction().await
            .map_err(|e| RagError::QueryFailed(format!("begin transaction: {}", e)))?;
        // One prepared statement reused for every row in the chunk.
        let stmt = tx.prepare(&insert_sql).await
            .map_err(|e| RagError::QueryFailed(format!("prepare: {}", e)))?;
        for doc in chunk {
            // Vector::from takes ownership, so clone the borrowed embedding.
            let embedding = Vector::from(doc.embedding.clone());
            let row = tx.query_one(&stmt, &[&doc.content, &embedding, &doc.metadata]).await
                .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
            ids.push(row.get::<_, i64>(0));
        }
        tx.commit().await
            .map_err(|e| RagError::QueryFailed(format!("commit: {}", e)))?;
    }
    Ok(ids)
}
/// Vector similarity search with no metadata filter; `limit` defaults to
/// the configured maximum result count.
pub async fn search(&self, query_embedding: &[f32], limit: Option<usize>) -> RagResult<Vec<Document>> {
self.search_with_filter(query_embedding, limit, None).await
}
/// Vector similarity search restricted by an optional metadata filter.
pub async fn search_with_filter(
&self,
query_embedding: &[f32],
limit: Option<usize>,
filter: Option<MetadataFilter>,
) -> RagResult<Vec<Document>> {
self.search_vector_inner(query_embedding, limit, filter).await
}
/// Core vector search: scores every row with the configured distance
/// metric, applies the minimum-similarity cutoff and optional metadata
/// filter, and returns the top `limit` documents with scores attached.
///
/// Placeholder layout is fixed: $1 = query embedding, $2 = minimum
/// similarity, $3 = limit; the metadata filter is compiled with
/// `to_sql(3)` so its placeholders start at $4. The order values are
/// pushed into `params` below must match this numbering exactly.
async fn search_vector_inner(
&self,
query_embedding: &[f32],
limit: Option<usize>,
filter: Option<MetadataFilter>,
) -> RagResult<Vec<Document>> {
if query_embedding.len() != self.config.embedding_dim() {
return Err(RagError::DimensionMismatch {
expected: self.config.embedding_dim(),
actual: query_embedding.len(),
});
}
let client = self.pool.get().await
.map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
let embedding = Vector::from(query_embedding.to_vec());
let limit = limit.unwrap_or(self.config.max_results()) as i64;
let operator = self.config.distance_metric().operator();
// Convert the metric's raw distance into a "higher is better" score.
let score_expr = match self.config.distance_metric() {
super::DistanceMetric::Cosine => format!("1 - (embedding {} $1)", operator),
super::DistanceMetric::L2 => format!("1 / (1 + (embedding {} $1))", operator),
super::DistanceMetric::InnerProduct => format!("-(embedding {} $1)", operator),
};
let min_sim = self.config.min_similarity() as f64;
// Offset 3 accounts for the three fixed placeholders above.
let (filter_clause, filter_params) = if let Some(f) = filter {
let (sql, params) = f.to_sql(3)?; (format!(" AND {}", sql), params)
} else {
(String::new(), Vec::new())
};
let query = format!(
r#"
SELECT id, content, metadata, ({})::float4 as score
FROM {}
WHERE {} >= $2{}
ORDER BY embedding {} $1
LIMIT $3
"#,
score_expr,
self.config.table_name(),
score_expr,
filter_clause,
operator
);
use tokio_postgres::types::ToSql;
// Fixed params first ($1..$3), then filter params ($4..) in push order.
let mut params: Vec<&(dyn ToSql + Sync)> = vec![&embedding, &min_sim, &limit];
let filter_param_refs: Vec<&str> = filter_params.iter().map(|s| s.as_str()).collect();
for p in &filter_param_refs {
params.push(p);
}
let rows = client.query(&query, &params).await
.map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
let docs = rows.iter().map(|row| {
Document {
id: row.get(0),
content: row.get(1),
metadata: row.get(2),
score: Some(row.get(3)),
}
}).collect();
Ok(docs)
}
/// Full-text search over the generated `content_tsv` column, returning
/// `(id, ts_rank)` pairs ordered by rank.
///
/// Placeholder layout: $1 = query text, $2 = limit; the metadata filter
/// is compiled with `to_sql(2)` so its placeholders start at $3.
/// Requires the hybrid-mode tsvector column to exist.
pub async fn search_keyword(
&self,
query_text: &str,
limit: usize,
filter: Option<MetadataFilter>,
) -> RagResult<Vec<(i64, f32)>> {
let client = self.pool.get().await
.map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
// The language is interpolated into the SQL; it comes from RagConfig,
// not from user input.
let lang = self.config.text_search_language();
let (filter_clause, filter_params) = if let Some(f) = filter {
let (sql, params) = f.to_sql(2)?; (format!(" AND {}", sql), params)
} else {
(String::new(), Vec::new())
};
let limit_i64 = limit as i64;
let query = format!(
r#"
SELECT id, ts_rank(content_tsv, plainto_tsquery('{lang}', $1)) as rank
FROM {table}
WHERE content_tsv @@ plainto_tsquery('{lang}', $1){filter}
ORDER BY rank DESC
LIMIT $2
"#,
lang = lang,
table = self.config.table_name(),
filter = filter_clause,
);
use tokio_postgres::types::ToSql;
// Fixed params first ($1, $2), then filter params ($3..) in push order.
let mut params: Vec<&(dyn ToSql + Sync)> = vec![&query_text, &limit_i64];
let filter_param_refs: Vec<&str> = filter_params.iter().map(|s| s.as_str()).collect();
for p in &filter_param_refs {
params.push(p);
}
let rows = client.query(&query, &params).await
.map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
let results = rows.iter().map(|row| {
let id: i64 = row.get(0);
let score: f32 = row.get(1);
(id, score)
}).collect();
Ok(results)
}
/// Hybrid search: run oversampled vector and keyword searches, fuse the
/// two rankings with Reciprocal Rank Fusion, and return the top `limit`
/// documents with their RRF scores (replacing the raw similarity scores).
pub async fn search_hybrid(
&self,
query_embedding: &[f32],
query_text: &str,
limit: Option<usize>,
filter: Option<MetadataFilter>,
) -> RagResult<Vec<Document>> {
let limit = limit.unwrap_or(self.config.max_results());
// Fetch more than `limit` from each leg so fusion has overlap to work with.
let oversampled = limit * self.config.hybrid_oversampling() as usize;
let vec_docs = self.search_vector_inner(query_embedding, Some(oversampled), filter.clone()).await?;
let vector_results: Vec<(i64, f32)> = vec_docs.iter().map(|d| (d.id, d.score.unwrap_or(0.0))).collect();
let keyword_results = self.search_keyword(query_text, oversampled, filter).await?;
let fused = rrf_fuse(&vector_results, &keyword_results, self.config.rrf_k(), limit);
// Rows already fetched by the vector leg, keyed by id.
let mut doc_map: std::collections::HashMap<i64, Document> = vec_docs.into_iter().map(|d| (d.id, d)).collect();
// Keyword-only hits still need their content/metadata loaded.
let missing_ids: Vec<i64> = fused.iter()
.filter(|(id, _)| !doc_map.contains_key(id))
.map(|(id, _)| *id)
.collect();
if !missing_ids.is_empty() {
let client = self.pool.get().await
.map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
let placeholders: Vec<String> = (1..=missing_ids.len()).map(|i| format!("${}", i)).collect();
let query = format!(
"SELECT id, content, metadata FROM {} WHERE id IN ({})",
self.config.table_name(),
placeholders.join(", ")
);
use tokio_postgres::types::ToSql;
let params: Vec<&(dyn ToSql + Sync)> = missing_ids.iter().map(|id| id as &(dyn ToSql + Sync)).collect();
let rows = client.query(&query, &params).await
.map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
for row in &rows {
let doc = Document {
id: row.get(0),
content: row.get(1),
metadata: row.get(2),
score: None,
};
doc_map.insert(doc.id, doc);
}
}
// Emit in fused rank order; each document's score becomes its RRF score.
let results = fused.into_iter().filter_map(|(id, score)| {
doc_map.remove(&id).map(|mut doc| {
doc.score = Some(score);
doc
})
}).collect();
Ok(results)
}
/// Count documents, optionally restricted by a metadata filter.
pub async fn count_with_filter(&self, filter: Option<MetadataFilter>) -> RagResult<i64> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    // Compile the optional filter to a WHERE clause plus its bind values.
    let (where_clause, owned_params) = match filter {
        Some(f) => {
            let (sql, params) = f.to_sql(0)?;
            (format!(" WHERE {}", sql), params)
        }
        None => (String::new(), Vec::new()),
    };
    let query = format!(
        "SELECT COUNT(*) FROM {}{}",
        self.config.table_name(),
        where_clause
    );
    use tokio_postgres::types::ToSql;
    let param_refs: Vec<&str> = owned_params.iter().map(|s| s.as_str()).collect();
    let params: Vec<&(dyn ToSql + Sync)> = param_refs.iter().map(|p| p as &(dyn ToSql + Sync)).collect();
    let row = client.query_one(&query, &params).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(row.get(0))
}
/// Delete every document matching `filter`; returns the number removed.
pub async fn delete_with_filter(&self, filter: MetadataFilter) -> RagResult<u64> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let (predicate, owned_params) = filter.to_sql(0)?;
    let query = format!(
        "DELETE FROM {} WHERE {}",
        self.config.table_name(),
        predicate
    );
    use tokio_postgres::types::ToSql;
    let param_refs: Vec<&str> = owned_params.iter().map(|s| s.as_str()).collect();
    let params: Vec<&(dyn ToSql + Sync)> = param_refs.iter().map(|p| p as &(dyn ToSql + Sync)).collect();
    let removed = client.execute(&query, &params).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(removed)
}
/// List the distinct values stored under `field` across all documents,
/// sorted ascending, up to `limit` (default 100). Rows where the value is
/// JSON null are skipped.
pub async fn list_metadata_values(&self, field: &str, limit: Option<usize>) -> RagResult<Vec<String>> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let limit = limit.unwrap_or(100) as i64;
    // Validation makes the direct interpolation below injection-safe.
    let field = validate_field_name(field)?;
    let query = format!(
        "SELECT DISTINCT metadata->>'{}' as val FROM {} WHERE metadata ? '{}' ORDER BY val LIMIT $1",
        field,
        self.config.table_name(),
        field
    );
    let rows = client.query(&query, &[&limit]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(rows
        .iter()
        .filter_map(|row| row.get::<_, Option<String>>(0))
        .collect())
}
/// Fetch a single document by id; `score` is always `None` here.
pub async fn get(&self, id: i64) -> RagResult<Option<Document>> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let query = format!(
        "SELECT id, content, metadata FROM {} WHERE id = $1",
        self.config.table_name()
    );
    let maybe_row = client.query_opt(&query, &[&id]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(maybe_row.map(|r| Document {
        id: r.get(0),
        content: r.get(1),
        metadata: r.get(2),
        score: None,
    }))
}
/// Delete one document by id; returns whether a row was actually removed.
pub async fn delete(&self, id: i64) -> RagResult<bool> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let sql = format!("DELETE FROM {} WHERE id = $1", self.config.table_name());
    let removed = client.execute(&sql, &[&id]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(removed > 0)
}
/// Total number of stored documents.
pub async fn count(&self) -> RagResult<i64> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let sql = format!("SELECT COUNT(*) FROM {}", self.config.table_name());
    let row = client.query_one(&sql, &[]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(row.get(0))
}
/// Delete every document; returns the number of rows removed.
pub async fn clear(&self) -> RagResult<u64> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    let sql = format!("DELETE FROM {}", self.config.table_name());
    let removed = client.execute(&sql, &[]).await
        .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
    Ok(removed)
}
/// Read-only access to this store's configuration.
pub fn config(&self) -> &RagConfig {
    &self.config
}
/// Round-trip a trivial query to verify the pool and server are reachable.
pub async fn health_check(&self) -> RagResult<()> {
    let client = self.pool.get().await
        .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
    client.query_one("SELECT 1", &[]).await
        .map_err(|e| RagError::QueryFailed(format!("health check failed: {}", e)))?;
    Ok(())
}
}
/// Builder that concatenates retrieved documents into a prompt context,
/// with an optional rough token budget and optional score prefixes.
pub struct RagContextBuilder {
docs: Vec<Document>,
// Inserted between documents; defaults to a blank line.
separator: String,
// Soft budget; converted to characters at ~4 chars/token in build().
max_tokens: Option<usize>,
include_scores: bool,
}
impl RagContextBuilder {
pub fn new(docs: Vec<Document>) -> Self {
Self {
docs,
separator: "\n\n".to_string(),
max_tokens: None,
include_scores: false,
}
}
pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
self.separator = sep.into();
self
}
pub fn with_max_tokens(mut self, max: usize) -> Self {
self.max_tokens = Some(max);
self
}
pub fn with_scores(mut self, include: bool) -> Self {
self.include_scores = include;
self
}
pub fn build(self) -> String {
let mut parts = Vec::new();
let mut total_chars = 0;
let max_chars = self.max_tokens.map(|t| t * 4);
for doc in &self.docs {
let part = if self.include_scores {
if let Some(score) = doc.score {
format!("[{:.2}] {}", score, doc.content)
} else {
doc.content.clone()
}
} else {
doc.content.clone()
};
if let Some(max) = max_chars
&& total_chars + part.len() > max {
break;
}
total_chars += part.len() + self.separator.len();
parts.push(part);
}
parts.join(&self.separator)
}
pub fn build_prompt(self, question: &str) -> String {
let context = self.build();
format!(
"Use the following context to answer the question.\n\n\
Context:\n{}\n\n\
Question: {}\n\n\
Answer:",
context,
question
)
}
}
/// Reciprocal Rank Fusion over two ranked id lists.
///
/// Each list contributes `1 / (k + rank + 1)` per id (input scores are
/// ignored — only rank order matters); contributions are summed per id,
/// sorted descending, and truncated to `limit`.
pub(crate) fn rrf_fuse(
    vector_results: &[(i64, f32)],
    keyword_results: &[(i64, f32)],
    k: u32,
    limit: usize,
) -> Vec<(i64, f32)> {
    use std::collections::HashMap;
    let mut scores: HashMap<i64, f32> = HashMap::new();
    // Both lists use the identical formula, so iterate them uniformly
    // (vector first, matching the accumulation order of the original).
    for list in [vector_results, keyword_results] {
        for (rank, (id, _)) in list.iter().enumerate() {
            *scores.entry(*id).or_insert(0.0) += 1.0 / (k as f32 + rank as f32 + 1.0);
        }
    }
    let mut fused: Vec<(i64, f32)> = scores.into_iter().collect();
    // NaN cannot occur here, but partial_cmp still demands a fallback.
    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    fused.truncate(limit);
    fused
}
#[cfg(test)]
mod tests {
use super::*;
// Accepts the documented character set: alphanumerics, '_', '.'.
#[test]
fn test_validate_field_name_valid() {
assert!(validate_field_name("source").is_ok());
assert!(validate_field_name("my_field").is_ok());
assert!(validate_field_name("field123").is_ok());
assert!(validate_field_name("a.b").is_ok());
}
// Quotes, spaces, semicolons, and empty names must all be rejected,
// since field names are interpolated directly into SQL.
#[test]
fn test_validate_field_name_rejects_sql_injection() {
assert!(validate_field_name("'; DROP TABLE --").is_err());
assert!(validate_field_name("field; DELETE").is_err());
assert!(validate_field_name("").is_err());
assert!(validate_field_name("a\"b").is_err());
}
// Id 2 appears near the top of both lists, so it must fuse to rank 1.
#[test]
fn test_rrf_fusion_basic() {
let vector_results = vec![(1i64, 0.95f32), (2, 0.85), (3, 0.75)];
let keyword_results = vec![(2i64, 0.9f32), (3, 0.8), (4, 0.7)];
let fused = rrf_fuse(&vector_results, &keyword_results, 60, 3);
assert_eq!(fused[0].0, 2);
assert!(fused.iter().all(|(_, score)| *score > 0.0));
assert!(fused.len() <= 3);
}
// Non-overlapping lists: both ids survive fusion.
#[test]
fn test_rrf_fusion_disjoint() {
let vector_results = vec![(1i64, 0.9f32)];
let keyword_results = vec![(2i64, 0.9f32)];
let fused = rrf_fuse(&vector_results, &keyword_results, 60, 10);
assert_eq!(fused.len(), 2);
}
#[test]
fn test_rrf_fusion_empty() {
let fused = rrf_fuse(&[], &[], 60, 5);
assert!(fused.is_empty());
}
// to_sql must succeed for valid field names and refuse injection attempts.
#[test]
fn test_metadata_filter_to_sql_validates_fields() {
let filter = MetadataFilter::eq("valid_field", "value");
assert!(filter.to_sql(0).is_ok());
let bad_filter = MetadataFilter::eq("'; DROP TABLE", "value");
assert!(bad_filter.to_sql(0).is_err());
}
}