use deadpool_postgres::{Config, Pool, Runtime};
use pgvector::Vector;
use tokio_postgres::NoTls;
use serde::{Deserialize, Serialize};
use super::{RagConfig, RagError, RagResult};
/// A filter over the JSONB `metadata` column, rendered to a SQL boolean
/// expression by [`MetadataFilter::to_sql`].
///
/// Serialized internally tagged by an `"op"` key with snake_case variant
/// names, e.g. `{"op": "eq", "field": "lang", "value": "en"}` or
/// `{"op": "not_exists", "field": "draft"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum MetadataFilter {
    /// The value at top-level key `field` equals `value`.
    Eq { field: String, value: serde_json::Value },
    /// The value at top-level key `field` differs from `value`.
    Ne { field: String, value: serde_json::Value },
    /// `field` is numerically greater than `value` (both sides cast to `numeric`).
    Gt { field: String, value: serde_json::Value },
    /// `field` is numerically greater than or equal to `value`.
    Gte { field: String, value: serde_json::Value },
    /// `field` is numerically less than `value`.
    Lt { field: String, value: serde_json::Value },
    /// `field` is numerically less than or equal to `value`.
    Lte { field: String, value: serde_json::Value },
    /// Top-level key `field` is present (`metadata ? field`).
    Exists { field: String },
    /// Top-level key `field` is absent.
    NotExists { field: String },
    /// The text at `field` contains `value`, case-insensitively (`ILIKE`).
    Contains { field: String, value: String },
    /// The text at `field` starts with `value` (case-sensitive `LIKE`).
    StartsWith { field: String, value: String },
    /// The text at `field` ends with `value` (case-sensitive `LIKE`).
    EndsWith { field: String, value: String },
    /// The JSON value at `field` contains the string `value` as an array
    /// element or object key (jsonb `?` operator).
    InArray { field: String, value: String },
    /// The text at `field` equals one of `values`; an empty list matches nothing.
    In { field: String, values: Vec<serde_json::Value> },
    /// The text at `field` equals none of `values`; an empty list matches everything.
    NotIn { field: String, values: Vec<serde_json::Value> },
    /// The SQL/JSON path `path` returns at least one item for the metadata (`@?`).
    JsonPath { path: String },
    /// Every sub-filter holds; an empty list matches everything.
    And { filters: Vec<MetadataFilter> },
    /// At least one sub-filter holds; an empty list matches nothing.
    Or { filters: Vec<MetadataFilter> },
    /// The wrapped filter does not hold.
    Not { filter: Box<MetadataFilter> },
}
impl MetadataFilter {
    /// `field = value`, compared as text; a JSON `null` value matches a JSON `null`.
    pub fn eq(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        Self::Eq { field: field.into(), value: value.into() }
    }

    /// `field != value`, compared as text; a JSON `null` value matches any present non-null value.
    pub fn ne(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        Self::Ne { field: field.into(), value: value.into() }
    }

    /// `field > value`, compared numerically.
    pub fn gt(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        Self::Gt { field: field.into(), value: value.into() }
    }

    /// `field >= value`, compared numerically.
    pub fn gte(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        Self::Gte { field: field.into(), value: value.into() }
    }

    /// `field < value`, compared numerically.
    pub fn lt(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        Self::Lt { field: field.into(), value: value.into() }
    }

    /// `field <= value`, compared numerically.
    pub fn lte(field: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        Self::Lte { field: field.into(), value: value.into() }
    }

    /// Top-level key `field` is present.
    pub fn exists(field: impl Into<String>) -> Self {
        Self::Exists { field: field.into() }
    }

    /// Top-level key `field` is absent.
    pub fn not_exists(field: impl Into<String>) -> Self {
        Self::NotExists { field: field.into() }
    }

    /// `field` contains the literal text `value` (case-insensitive).
    pub fn contains(field: impl Into<String>, value: impl Into<String>) -> Self {
        Self::Contains { field: field.into(), value: value.into() }
    }

    /// `field` starts with the literal text `value` (case-sensitive).
    pub fn starts_with(field: impl Into<String>, value: impl Into<String>) -> Self {
        Self::StartsWith { field: field.into(), value: value.into() }
    }

    /// `field` ends with the literal text `value` (case-sensitive).
    pub fn ends_with(field: impl Into<String>, value: impl Into<String>) -> Self {
        Self::EndsWith { field: field.into(), value: value.into() }
    }

    /// The JSON array (or object) at `field` contains the string `value` as an element (or key).
    pub fn in_array(field: impl Into<String>, value: impl Into<String>) -> Self {
        Self::InArray { field: field.into(), value: value.into() }
    }

    /// `field` equals one of `values` (compared as text).
    pub fn in_values(field: impl Into<String>, values: Vec<serde_json::Value>) -> Self {
        Self::In { field: field.into(), values }
    }

    /// `field` equals none of `values` (compared as text).
    pub fn not_in(field: impl Into<String>, values: Vec<serde_json::Value>) -> Self {
        Self::NotIn { field: field.into(), values }
    }

    /// The SQL/JSON path expression `path` matches the metadata.
    pub fn json_path(path: impl Into<String>) -> Self {
        Self::JsonPath { path: path.into() }
    }

    /// All of `filters` hold (an empty list is always true).
    pub fn and(filters: Vec<MetadataFilter>) -> Self {
        Self::And { filters }
    }

    /// Any of `filters` holds (an empty list is always false).
    pub fn or(filters: Vec<MetadataFilter>) -> Self {
        Self::Or { filters }
    }

    /// `filter` does not hold.
    pub fn not(filter: MetadataFilter) -> Self {
        Self::Not { filter: Box::new(filter) }
    }

    /// Renders this filter as a SQL boolean expression over the `metadata` column.
    ///
    /// Placeholders are numbered from `$param_offset + 1` upward, so the fragment can
    /// be appended to a statement that already binds `param_offset` parameters. The
    /// returned vector holds the text values to bind, in placeholder order.
    pub fn to_sql(&self, param_offset: usize) -> (String, Vec<String>) {
        let mut params = Vec::new();
        let sql = self.to_sql_inner(param_offset, &mut params);
        (sql, params)
    }

    /// Recursive worker for [`MetadataFilter::to_sql`].
    ///
    /// `param_offset` is the number of parameters bound *before* the whole filter
    /// tree and stays constant through the recursion: every placeholder index is
    /// derived from `params.len()`, which already counts parameters pushed by
    /// earlier siblings. (Adding `params.len()` to the offset as well would
    /// double-count and produce gaps in the numbering.)
    fn to_sql_inner(&self, param_offset: usize, params: &mut Vec<String>) -> String {
        match self {
            // `->>` turns a JSON null into SQL NULL, so a text comparison could never
            // match it; compare the jsonb value itself instead.
            Self::Eq { field, value: serde_json::Value::Null } => {
                format!("metadata->'{}' = 'null'::jsonb", escape_field(field))
            }
            Self::Eq { field, value } => {
                let p = Self::bind(param_offset, params, json_value_to_string(value));
                format!("metadata->>'{}' = {}", escape_field(field), p)
            }
            Self::Ne { field, value: serde_json::Value::Null } => {
                format!("metadata->'{}' != 'null'::jsonb", escape_field(field))
            }
            Self::Ne { field, value } => {
                let p = Self::bind(param_offset, params, json_value_to_string(value));
                format!("metadata->>'{}' != {}", escape_field(field), p)
            }
            Self::Gt { field, value } => Self::numeric_cmp(field, ">", value, param_offset, params),
            Self::Gte { field, value } => Self::numeric_cmp(field, ">=", value, param_offset, params),
            Self::Lt { field, value } => Self::numeric_cmp(field, "<", value, param_offset, params),
            Self::Lte { field, value } => Self::numeric_cmp(field, "<=", value, param_offset, params),
            Self::Exists { field } => format!("metadata ? '{}'", escape_field(field)),
            Self::NotExists { field } => format!("NOT (metadata ? '{}')", escape_field(field)),
            Self::Contains { field, value } => {
                let pattern = format!("%{}%", Self::escape_like(value));
                Self::like_cmp(field, "ILIKE", pattern, param_offset, params)
            }
            Self::StartsWith { field, value } => {
                let pattern = format!("{}%", Self::escape_like(value));
                Self::like_cmp(field, "LIKE", pattern, param_offset, params)
            }
            Self::EndsWith { field, value } => {
                let pattern = format!("%{}", Self::escape_like(value));
                Self::like_cmp(field, "LIKE", pattern, param_offset, params)
            }
            Self::InArray { field, value } => {
                let p = Self::bind(param_offset, params, value.clone());
                format!("metadata->'{}' ? {}", escape_field(field), p)
            }
            Self::In { field, values } => {
                Self::membership(field, "IN", "FALSE", values, param_offset, params)
            }
            Self::NotIn { field, values } => {
                Self::membership(field, "NOT IN", "TRUE", values, param_offset, params)
            }
            Self::JsonPath { path } => format!("metadata @? '{}'", path.replace('\'', "''")),
            Self::And { filters } => {
                Self::join_filters(filters, " AND ", "TRUE", param_offset, params)
            }
            Self::Or { filters } => {
                Self::join_filters(filters, " OR ", "FALSE", param_offset, params)
            }
            Self::Not { filter } => {
                format!("NOT ({})", filter.to_sql_inner(param_offset, params))
            }
        }
    }

    /// Pushes `value` onto `params` and returns its `$n` placeholder.
    fn bind(param_offset: usize, params: &mut Vec<String>, value: String) -> String {
        params.push(value);
        format!("${}", param_offset + params.len())
    }

    /// Renders `(metadata->>field)::numeric <op> $n`.
    ///
    /// The placeholder is cast through `text` so PostgreSQL infers the parameter as
    /// `text`, matching the `String` bound for it; a bare `$n::numeric` would make
    /// the server expect a `numeric` parameter, which a string cannot be bound to.
    fn numeric_cmp(
        field: &str,
        op: &str,
        value: &serde_json::Value,
        param_offset: usize,
        params: &mut Vec<String>,
    ) -> String {
        let p = Self::bind(param_offset, params, json_value_to_string(value));
        format!("(metadata->>'{}')::numeric {} {}::text::numeric", escape_field(field), op, p)
    }

    /// Renders `metadata->>field <op> $n` for a ready-made LIKE/ILIKE `pattern`.
    fn like_cmp(
        field: &str,
        op: &str,
        pattern: String,
        param_offset: usize,
        params: &mut Vec<String>,
    ) -> String {
        let p = Self::bind(param_offset, params, pattern);
        format!("metadata->>'{}' {} {}", escape_field(field), op, p)
    }

    /// Escapes LIKE metacharacters (`\`, `%`, `_`) with PostgreSQL's default escape
    /// character (backslash) so user text is matched literally.
    fn escape_like(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    /// Renders `metadata->>field <op> ($a, $b, ...)`, or `if_empty` when `values` is empty
    /// (an empty SQL `IN ()` list is a syntax error).
    fn membership(
        field: &str,
        op: &str,
        if_empty: &str,
        values: &[serde_json::Value],
        param_offset: usize,
        params: &mut Vec<String>,
    ) -> String {
        if values.is_empty() {
            return if_empty.to_string();
        }
        let placeholders: Vec<String> = values
            .iter()
            .map(|v| Self::bind(param_offset, params, json_value_to_string(v)))
            .collect();
        format!("metadata->>'{}' {} ({})", escape_field(field), op, placeholders.join(", "))
    }

    /// Joins sub-filters with `separator` inside parentheses, or returns `if_empty`
    /// for an empty list.
    fn join_filters(
        filters: &[MetadataFilter],
        separator: &str,
        if_empty: &str,
        param_offset: usize,
        params: &mut Vec<String>,
    ) -> String {
        if filters.is_empty() {
            return if_empty.to_string();
        }
        let parts: Vec<String> = filters
            .iter()
            .map(|f| f.to_sql_inner(param_offset, params))
            .collect();
        format!("({})", parts.join(separator))
    }

    /// Parses a single filter expression.
    ///
    /// Syntax (whitespace around field names and values is ignored):
    /// - `field?` — key exists; `!field?` — key does not exist
    /// - `field=value`, `field!=value` — text equality / inequality
    /// - `field>value`, `field>=value`, `field<value`, `field<=value` — numeric comparison
    /// - `field~value` — case-insensitive substring; `field^value` — prefix;
    ///   `field$value` — suffix
    ///
    /// For `=`, `!=` and the numeric operators, values that parse as integers, finite
    /// floats, `true`, `false` or `null` become the matching JSON type; anything else
    /// is a string. The operator is the left-most one in the expression (longest match
    /// at that position), so values may contain operator characters, e.g.
    /// `url^http://x?a=1`. Any expression ending in `?` is an existence check.
    ///
    /// # Errors
    /// Returns `Err` if no operator is found or the field name is empty.
    pub fn parse(s: &str) -> Result<Self, String> {
        // Two-character operators precede their one-character prefixes so that they
        // win when both start at the same position.
        const OPERATORS: [&str; 9] = [">=", "<=", "!=", "=", ">", "<", "~", "^", "$"];

        let s = s.trim();
        let invalid = || format!("Invalid filter syntax: {}", s);

        if let Some(body) = s.strip_suffix('?') {
            let (negated, field) = match body.strip_prefix('!') {
                Some(rest) => (true, rest.trim()),
                None => (false, body.trim()),
            };
            if field.is_empty() {
                return Err(invalid());
            }
            return Ok(if negated { Self::not_exists(field) } else { Self::exists(field) });
        }

        let (pos, op) = s
            .char_indices()
            .find_map(|(i, _)| {
                OPERATORS
                    .iter()
                    .find(|candidate| s[i..].starts_with(**candidate))
                    .map(|candidate| (i, *candidate))
            })
            .ok_or_else(invalid)?;

        let field = s[..pos].trim();
        let value = s[pos + op.len()..].trim();
        if field.is_empty() {
            return Err(invalid());
        }

        Ok(match op {
            "=" => Self::eq(field, Self::parse_scalar(value)),
            "!=" => Self::ne(field, Self::parse_scalar(value)),
            ">" => Self::gt(field, Self::parse_scalar(value)),
            ">=" => Self::gte(field, Self::parse_scalar(value)),
            "<" => Self::lt(field, Self::parse_scalar(value)),
            "<=" => Self::lte(field, Self::parse_scalar(value)),
            "~" => Self::contains(field, value),
            "^" => Self::starts_with(field, value),
            "$" => Self::ends_with(field, value),
            _ => unreachable!("every entry in OPERATORS has a match arm"),
        })
    }

    /// Interprets a raw value from the filter mini-language as JSON: integers, finite
    /// floats, `true`, `false` and `null` get their JSON types; everything else
    /// (including NaN/infinity) is kept as a string.
    fn parse_scalar(raw: &str) -> serde_json::Value {
        if let Ok(n) = raw.parse::<i64>() {
            serde_json::Value::Number(n.into())
        } else if let Ok(n) = raw.parse::<f64>() {
            serde_json::Number::from_f64(n)
                .map(serde_json::Value::Number)
                .unwrap_or_else(|| serde_json::Value::String(raw.to_string()))
        } else if raw == "true" {
            serde_json::Value::Bool(true)
        } else if raw == "false" {
            serde_json::Value::Bool(false)
        } else if raw == "null" {
            serde_json::Value::Null
        } else {
            serde_json::Value::String(raw.to_string())
        }
    }

    /// Parses several expressions separated by `;` or newlines and combines them
    /// with AND; a single expression is returned unwrapped.
    ///
    /// # Errors
    /// Returns `Err` if any expression fails to parse or if none is given.
    pub fn parse_many(s: &str) -> Result<Self, String> {
        let mut filters = s
            .split(|c: char| c == ';' || c == '\n')
            .map(str::trim)
            .filter(|p| !p.is_empty())
            .map(Self::parse)
            .collect::<Result<Vec<_>, _>>()?;
        match filters.len() {
            0 => Err("No filters provided".to_string()),
            1 => Ok(filters.remove(0)),
            _ => Ok(Self::and(filters)),
        }
    }
}
/// Escapes `field` for embedding inside a single-quoted SQL string literal,
/// e.g. `metadata->>'field'`.
///
/// Inside a standard-conforming string literal only the single quote is special
/// and is escaped by doubling it. Double quotes are ordinary characters there, so
/// they are left untouched: doubling them would make the query look up a
/// different JSON key.
/// NOTE(review): assumes `standard_conforming_strings = on` (the PostgreSQL default).
fn escape_field(field: &str) -> String {
    field.replace('\'', "''")
}
fn json_value_to_string(value: &serde_json::Value) -> String {
match value {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Null => String::new(),
_ => value.to_string(),
}
}
/// A document row read back from the store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Value of the table's `id` column.
    pub id: i64,
    /// The stored text.
    pub content: String,
    /// The `metadata` JSONB column; `None` when it is SQL NULL.
    pub metadata: Option<serde_json::Value>,
    /// Similarity score (higher means more similar) set by similarity searches;
    /// `None` for plain lookups by id. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score: Option<f32>,
}
/// A document to be inserted into the store.
#[derive(Debug, Clone)]
pub struct NewDocument {
    /// Text to store.
    pub content: String,
    /// Embedding vector; its length must equal the configured embedding
    /// dimension, otherwise the insert is rejected with `RagError::DimensionMismatch`.
    pub embedding: Vec<f32>,
    /// Optional JSON metadata, stored in the JSONB `metadata` column.
    pub metadata: Option<serde_json::Value>,
}
/// Document store backed by a PostgreSQL table with a pgvector `embedding` column.
pub struct RagStore {
    /// Pool of PostgreSQL connections (deadpool-postgres).
    pool: Pool,
    /// Configuration supplying the table name, embedding dimension, distance
    /// metric and search defaults used to build queries.
    config: RagConfig,
}
impl RagStore {
    /// Opens a connection pool for `config` and verifies that the pgvector
    /// extension is installed in the target database.
    ///
    /// `config.connection_string()` must be a URL such as
    /// `postgres://user:password@host:5432/dbname`. Percent-encoded user name,
    /// password and database name are decoded. URL query parameters are not used,
    /// and connections are made without TLS (`NoTls`).
    ///
    /// # Errors
    /// - any error returned by `config.validate()`;
    /// - `RagError::ConfigError` if the connection string is not a valid URL;
    /// - `RagError::ConnectionFailed` in any of these cases: the pool cannot be
    ///   built, no connection can be obtained, the extension lookup fails, or
    ///   pgvector is not installed.
    pub async fn connect(config: RagConfig) -> RagResult<Self> {
        config.validate()?;
        let url = url::Url::parse(config.connection_string())
            .map_err(|e| RagError::ConfigError(format!("Invalid connection string: {}", e)))?;

        let mut pg_config = Config::new();
        // `host_str` keeps the brackets around IPv6 literals; the driver wants the bare address.
        pg_config.host = url
            .host_str()
            .map(|host| host.trim_start_matches('[').trim_end_matches(']').to_string());
        pg_config.port = url.port();
        // `Url` returns the user name, password and path still percent-encoded, so a
        // password written as `p%40ss` in the URL must be sent to the server as `p@ss`.
        pg_config.user = if url.username().is_empty() {
            None
        } else {
            Some(Self::percent_decode(url.username()))
        };
        pg_config.password = url.password().map(Self::percent_decode);
        pg_config.dbname = Some(Self::percent_decode(url.path().trim_start_matches('/')));

        let pool = pg_config
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;

        let client = pool
            .get()
            .await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        // Report a failing lookup separately from "extension missing" so the real
        // cause (permissions, network, ...) is not masked.
        let extension = client
            .query_opt("SELECT extversion FROM pg_extension WHERE extname = 'vector'", &[])
            .await
            .map_err(|e| {
                RagError::ConnectionFailed(format!("Failed to check for pgvector extension: {}", e))
            })?;
        if extension.is_none() {
            return Err(RagError::ConnectionFailed(
                "pgvector extension not installed. Run: CREATE EXTENSION vector;".into(),
            ));
        }
        drop(client);

        Ok(Self { pool, config })
    }

    /// Loads a configuration with `RagConfig::load(config_path)` and connects with it.
    pub async fn connect_with_config(config_path: Option<&str>) -> RagResult<Self> {
        let config = RagConfig::load(config_path)?;
        Self::connect(config).await
    }

    /// Creates the document table if it does not exist. Then, on a best-effort
    /// basis, it creates an IVFFlat index on `embedding` using the configured
    /// metric's operator class.
    ///
    /// # Errors
    /// - `RagError::ConnectionFailed` if no connection can be obtained;
    /// - `RagError::QueryFailed` if `CREATE TABLE` fails.
    ///
    /// Errors from index creation are ignored.
    pub async fn create_table(&self) -> RagResult<()> {
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let create_table = format!(
            r#"
            CREATE TABLE IF NOT EXISTS {} (
                id BIGSERIAL PRIMARY KEY,
                content TEXT NOT NULL,
                embedding vector({}) NOT NULL,
                metadata JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW()
            )
            "#,
            self.config.table_name(),
            self.config.embedding_dim()
        );
        client.execute(&create_table, &[]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        let create_index = format!(
            r#"
            CREATE INDEX IF NOT EXISTS {}_embedding_idx
            ON {} USING ivfflat (embedding {})
            "#,
            self.config.table_name(),
            self.config.table_name(),
            self.config.distance_metric().index_ops()
        );
        // Deliberately best-effort: queries work without the index, so failing to
        // build it should not fail table setup.
        let _ = client.execute(&create_index, &[]).await;
        Ok(())
    }

    /// Inserts one document and returns its generated `id`.
    ///
    /// # Errors
    /// - `RagError::DimensionMismatch` if the embedding length differs from the
    ///   configured dimension;
    /// - `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn insert(&self, doc: NewDocument) -> RagResult<i64> {
        if doc.embedding.len() != self.config.embedding_dim() {
            return Err(RagError::DimensionMismatch {
                expected: self.config.embedding_dim(),
                actual: doc.embedding.len(),
            });
        }
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let embedding = Vector::from(doc.embedding);
        let query = format!(
            "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3) RETURNING id",
            self.config.table_name()
        );
        let row = client.query_one(&query, &[&doc.content, &embedding, &doc.metadata]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(row.get(0))
    }

    /// Inserts all `docs` atomically and returns their ids in input order.
    ///
    /// Every embedding dimension is checked before anything is written. The rows
    /// are then inserted in a single transaction on one connection, so either
    /// every document is stored or none is.
    ///
    /// # Errors
    /// - `RagError::DimensionMismatch` for the first document with a wrong
    ///   embedding length;
    /// - `RagError::ConnectionFailed` or `RagError::QueryFailed` on database
    ///   errors, in which case the transaction is rolled back.
    pub async fn insert_batch(&self, docs: Vec<NewDocument>) -> RagResult<Vec<i64>> {
        if docs.is_empty() {
            return Ok(Vec::new());
        }
        let expected = self.config.embedding_dim();
        if let Some(bad) = docs.iter().find(|d| d.embedding.len() != expected) {
            return Err(RagError::DimensionMismatch {
                expected,
                actual: bad.embedding.len(),
            });
        }

        let mut client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let pg_client: &mut tokio_postgres::Client = &mut client;
        // Dropping the transaction without `commit` (e.g. via `?` below) rolls it back.
        let tx = pg_client.transaction().await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        let query = format!(
            "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3) RETURNING id",
            self.config.table_name()
        );
        let stmt = tx.prepare(&query).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;

        let mut ids = Vec::with_capacity(docs.len());
        for doc in docs {
            let embedding = Vector::from(doc.embedding);
            let row = tx.query_one(&stmt, &[&doc.content, &embedding, &doc.metadata]).await
                .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
            ids.push(row.get(0));
        }
        tx.commit().await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(ids)
    }

    /// Unfiltered similarity search; see [`RagStore::search_with_filter`].
    pub async fn search(&self, query_embedding: &[f32], limit: Option<usize>) -> RagResult<Vec<Document>> {
        self.search_with_filter(query_embedding, limit, None).await
    }

    /// Returns up to `limit` (default: the configured max results) documents closest
    /// to `query_embedding`, nearest first.
    ///
    /// Only documents whose score is at least the configured minimum similarity,
    /// and which match `filter` if one is given, are returned. Scores are
    /// oriented so that higher is more similar for every distance metric.
    ///
    /// # Errors
    /// - `RagError::DimensionMismatch` if the query embedding has the wrong length;
    /// - `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn search_with_filter(
        &self,
        query_embedding: &[f32],
        limit: Option<usize>,
        filter: Option<MetadataFilter>,
    ) -> RagResult<Vec<Document>> {
        use tokio_postgres::types::ToSql;

        if query_embedding.len() != self.config.embedding_dim() {
            return Err(RagError::DimensionMismatch {
                expected: self.config.embedding_dim(),
                actual: query_embedding.len(),
            });
        }
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let embedding = Vector::from(query_embedding.to_vec());
        // Clamp so a huge "unlimited" value cannot wrap to a negative LIMIT.
        let limit = limit.unwrap_or(self.config.max_results()).min(i64::MAX as usize) as i64;
        // The score expression is float8 in SQL, so the threshold is bound as f64.
        let min_sim = f64::from(self.config.min_similarity());
        let operator = self.config.distance_metric().operator();
        let score_expr = match self.config.distance_metric() {
            super::DistanceMetric::Cosine => format!("1 - (embedding {} $1)", operator),
            super::DistanceMetric::L2 => format!("1 / (1 + (embedding {} $1))", operator),
            super::DistanceMetric::InnerProduct => format!("-(embedding {} $1)", operator),
        };
        // $1 = embedding, $2 = min_sim, $3 = limit, so filter placeholders start at $4.
        let (filter_clause, filter_params) = match filter {
            Some(f) => {
                let (sql, params) = f.to_sql(3);
                (format!(" AND {}", sql), params)
            }
            None => (String::new(), Vec::new()),
        };
        let query = format!(
            r#"
            SELECT id, content, metadata, ({})::float8 as score
            FROM {}
            WHERE {} >= $2{}
            ORDER BY embedding {} $1
            LIMIT $3
            "#,
            score_expr,
            self.config.table_name(),
            score_expr,
            filter_clause,
            operator
        );

        let mut params: Vec<&(dyn ToSql + Sync)> = vec![&embedding, &min_sim, &limit];
        params.extend(filter_params.iter().map(|p| p as &(dyn ToSql + Sync)));

        let rows = client.query(&query, &params).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        let docs = rows
            .iter()
            .map(|row| {
                // The score column is float8; `Document::score` stores it as f32.
                let score: f64 = row.get(3);
                Document {
                    id: row.get(0),
                    content: row.get(1),
                    metadata: row.get(2),
                    score: Some(score as f32),
                }
            })
            .collect();
        Ok(docs)
    }

    /// Counts the documents matching `filter`, or all documents when it is `None`.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn count_with_filter(&self, filter: Option<MetadataFilter>) -> RagResult<i64> {
        use tokio_postgres::types::ToSql;

        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let (filter_clause, filter_params) = match filter {
            Some(f) => {
                let (sql, params) = f.to_sql(0);
                (format!(" WHERE {}", sql), params)
            }
            None => (String::new(), Vec::new()),
        };
        let query = format!(
            "SELECT COUNT(*) FROM {}{}",
            self.config.table_name(),
            filter_clause
        );
        let params: Vec<&(dyn ToSql + Sync)> = filter_params
            .iter()
            .map(|p| p as &(dyn ToSql + Sync))
            .collect();
        let row = client.query_one(&query, &params).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(row.get(0))
    }

    /// Deletes the documents matching `filter` and returns how many were removed.
    ///
    /// Note that a filter rendering as `TRUE` (e.g. an empty AND list) deletes every row.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn delete_with_filter(&self, filter: MetadataFilter) -> RagResult<u64> {
        use tokio_postgres::types::ToSql;

        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let (filter_sql, filter_params) = filter.to_sql(0);
        let query = format!(
            "DELETE FROM {} WHERE {}",
            self.config.table_name(),
            filter_sql
        );
        let params: Vec<&(dyn ToSql + Sync)> = filter_params
            .iter()
            .map(|p| p as &(dyn ToSql + Sync))
            .collect();
        let affected = client.execute(&query, &params).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(affected)
    }

    /// Lists up to `limit` (default 100) distinct, sorted, non-null text values of
    /// the top-level metadata key `field`.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn list_metadata_values(&self, field: &str, limit: Option<usize>) -> RagResult<Vec<String>> {
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let limit = limit.unwrap_or(100).min(i64::MAX as usize) as i64;
        let query = format!(
            "SELECT DISTINCT metadata->>'{}' as val FROM {} WHERE metadata ? '{}' ORDER BY val LIMIT $1",
            escape_field(field),
            self.config.table_name(),
            escape_field(field)
        );
        let rows = client.query(&query, &[&limit]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        let values = rows.iter()
            .filter_map(|row| row.get::<_, Option<String>>(0))
            .collect();
        Ok(values)
    }

    /// Fetches the document with the given `id`, or `None` if there is none.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn get(&self, id: i64) -> RagResult<Option<Document>> {
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let query = format!(
            "SELECT id, content, metadata FROM {} WHERE id = $1",
            self.config.table_name()
        );
        let row = client.query_opt(&query, &[&id]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(row.map(|r| Document {
            id: r.get(0),
            content: r.get(1),
            metadata: r.get(2),
            score: None,
        }))
    }

    /// Deletes the document with the given `id`; returns whether a row was removed.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn delete(&self, id: i64) -> RagResult<bool> {
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let query = format!("DELETE FROM {} WHERE id = $1", self.config.table_name());
        let affected = client.execute(&query, &[&id]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(affected > 0)
    }

    /// Returns the total number of stored documents.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn count(&self) -> RagResult<i64> {
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let query = format!("SELECT COUNT(*) FROM {}", self.config.table_name());
        let row = client.query_one(&query, &[]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(row.get(0))
    }

    /// Deletes every document and returns how many rows were removed.
    ///
    /// # Errors
    /// `RagError::ConnectionFailed` or `RagError::QueryFailed` on database errors.
    pub async fn clear(&self) -> RagResult<u64> {
        let client = self.pool.get().await
            .map_err(|e| RagError::ConnectionFailed(format!("{}", e)))?;
        let query = format!("DELETE FROM {}", self.config.table_name());
        let affected = client.execute(&query, &[]).await
            .map_err(|e| RagError::QueryFailed(format!("{}", e)))?;
        Ok(affected)
    }

    /// The configuration this store was created with.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Decodes `%XX` escapes in a URL component.
    ///
    /// The decoded bytes are interpreted as UTF-8, with invalid sequences
    /// replaced by U+FFFD. A `%` not followed by two hex digits is kept literally.
    fn percent_decode(encoded: &str) -> String {
        fn hex_value(digit: u8) -> Option<u8> {
            match digit {
                b'0'..=b'9' => Some(digit - b'0'),
                b'a'..=b'f' => Some(digit - b'a' + 10),
                b'A'..=b'F' => Some(digit - b'A' + 10),
                _ => None,
            }
        }

        let bytes = encoded.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' && i + 2 < bytes.len() {
                if let (Some(hi), Some(lo)) = (hex_value(bytes[i + 1]), hex_value(bytes[i + 2])) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
            }
            decoded.push(bytes[i]);
            i += 1;
        }
        String::from_utf8_lossy(&decoded).into_owned()
    }
}
/// Assembles retrieved documents into a context string for an LLM prompt.
pub struct RagContextBuilder {
    /// Documents to render, in output order.
    docs: Vec<Document>,
    /// Text placed between consecutive documents.
    separator: String,
    /// Optional size budget in estimated tokens; `None` means no limit.
    max_tokens: Option<usize>,
    /// Whether to prefix each document that has a score with that score.
    include_scores: bool,
}
impl RagContextBuilder {
    /// Creates a builder over `docs`, rendered in the given order.
    ///
    /// Defaults: a blank-line (`"\n\n"`) separator, no size budget, scores hidden.
    pub fn new(docs: Vec<Document>) -> Self {
        Self {
            docs,
            separator: "\n\n".to_string(),
            max_tokens: None,
            include_scores: false,
        }
    }

    /// Sets the text placed between consecutive documents.
    pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
        self.separator = sep.into();
        self
    }

    /// Limits the rendered context to roughly `max` tokens, estimated as 4 bytes per token.
    pub fn with_max_tokens(mut self, max: usize) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// When `include` is true, prefixes every document that has a score with
    /// `[score] ` (two decimal places).
    pub fn with_scores(mut self, include: bool) -> Self {
        self.include_scores = include;
        self
    }

    /// Renders the documents joined by the separator.
    ///
    /// With a token budget, documents are taken in order until the next one would
    /// push the output (separators included) past the byte budget. That document
    /// and all later ones are dropped, so the result is empty if the first
    /// document alone is too large.
    pub fn build(self) -> String {
        // Saturating so that a budget such as `usize::MAX` ("unlimited") cannot overflow.
        let max_bytes = self.max_tokens.map(|tokens| tokens.saturating_mul(4));
        let separator_len = self.separator.len();
        let mut parts: Vec<String> = Vec::with_capacity(self.docs.len());
        let mut used = 0usize;

        for doc in self.docs {
            let part = match doc.score {
                Some(score) if self.include_scores => format!("[{:.2}] {}", score, doc.content),
                _ => doc.content,
            };
            if let Some(max) = max_bytes {
                // `used` already includes one separator per accepted part, i.e. the
                // separator that would precede `part` in the output.
                if used.saturating_add(part.len()) > max {
                    break;
                }
            }
            used = used.saturating_add(part.len() + separator_len);
            parts.push(part);
        }
        parts.join(&self.separator)
    }

    /// Wraps the rendered context and `question` in a simple question-answering prompt.
    pub fn build_prompt(self, question: &str) -> String {
        let context = self.build();
        format!(
            "Use the following context to answer the question.\n\n\
             Context:\n{}\n\n\
             Question: {}\n\n\
             Answer:",
            context,
            question
        )
    }
}