#![cfg(feature = "postgres-live")]
use std::collections::BTreeMap;
use postgres::types::ToSql;
use serde_json::Value;
use crate::pg_exec::PgConn;
use crate::postgres::quote_ident_pub as quote_ident;
use crate::search::{merge_row, SearchConfig, SearchQuery, SearchResult};
use crate::StorageError;
/// Returns the quoted identifier of the full-text side table for `entity`,
/// i.e. `_fts_<entity>` passed through `quote_ident`.
fn fts_table_name(entity: &str) -> String {
    let raw = format!("_fts_{entity}");
    quote_ident(&raw)
}
/// Returns the quoted identifier of the GIN index that covers the full-text
/// side table of `entity`, i.e. `<entity>_fts_gin` passed through `quote_ident`.
fn fts_gin_index_name(entity: &str) -> String {
    let raw = format!("{entity}_fts_gin");
    quote_ident(&raw)
}
/// Returns the quoted identifier of the index backing facet `field` on
/// `entity`, i.e. `<entity>_facet_<field>` passed through `quote_ident`.
fn facet_index_name(entity: &str, field: &str) -> String {
    let raw = format!("{entity}_facet_{field}");
    quote_ident(&raw)
}
/// Returns the quoted identifier of the index backing sortable `field` on
/// `entity`, i.e. `<entity>_sort_<field>` passed through `quote_ident`.
fn sort_index_name(entity: &str, field: &str) -> String {
    let raw = format!("{entity}_sort_{field}");
    quote_ident(&raw)
}
/// Builds the DDL statements that set up search support for `entity`.
///
/// The returned statements are, in order:
/// 1. when `config.text` is non-empty, the `_fts_<entity>` side table and its
///    GIN index over the `tsv` column;
/// 2. one index per entry in `config.facets`;
/// 3. one index per entry in `config.sortable`.
///
/// Every statement uses `IF NOT EXISTS`. Nothing is executed here; the caller
/// decides how to run the SQL.
pub fn create_search_index_sql(entity: &str, config: &SearchConfig) -> Vec<String> {
    let entity_quoted = quote_ident(entity);
    let mut stmts = Vec::new();

    if !config.text.is_empty() {
        let fts = fts_table_name(entity);
        let gin = fts_gin_index_name(entity);
        stmts.push(format!(
            "CREATE TABLE IF NOT EXISTS {fts} (\
             entity_id text PRIMARY KEY REFERENCES {entity_quoted}(id) ON DELETE CASCADE, \
             tsv tsvector NOT NULL)"
        ));
        stmts.push(format!(
            "CREATE INDEX IF NOT EXISTS {gin} ON {fts} USING GIN (tsv)"
        ));
    }

    // Facet indexes first, then sort indexes, matching the config order.
    stmts.extend(config.facets.iter().map(|field| {
        format!(
            "CREATE INDEX IF NOT EXISTS {idx} ON {entity_quoted} ({col})",
            idx = facet_index_name(entity, field),
            col = quote_ident(field),
        )
    }));
    stmts.extend(config.sortable.iter().map(|field| {
        format!(
            "CREATE INDEX IF NOT EXISTS {idx} ON {entity_quoted} ({col})",
            idx = sort_index_name(entity, field),
            col = quote_ident(field),
        )
    }));

    stmts
}
/// Builds the DDL statements that tear down what `create_search_index_sql`
/// sets up for `entity`.
///
/// When `config.text` is non-empty the GIN index is dropped, then the
/// `_fts_<entity>` table (with `CASCADE`). After that comes one `DROP INDEX`
/// per facet and per sortable field. Every statement uses `IF EXISTS`.
pub fn remove_search_index_sql(entity: &str, config: &SearchConfig) -> Vec<String> {
    let drop_index = |name: String| format!("DROP INDEX IF EXISTS {name}");
    let mut stmts = Vec::new();

    if !config.text.is_empty() {
        stmts.push(drop_index(fts_gin_index_name(entity)));
        stmts.push(format!(
            "DROP TABLE IF EXISTS {} CASCADE",
            fts_table_name(entity)
        ));
    }

    stmts.extend(
        config
            .facets
            .iter()
            .map(|field| drop_index(facet_index_name(entity, field))),
    );
    stmts.extend(
        config
            .sortable
            .iter()
            .map(|field| drop_index(sort_index_name(entity, field))),
    );

    stmts
}
fn collect_text(data: &Value, config: &SearchConfig) -> String {
config
.text
.iter()
.map(|f| {
data.get(f)
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string()
})
.collect::<Vec<_>>()
.join(" ")
}
/// Returns the text-search configuration name as a single-quoted SQL literal.
///
/// The value from `config.language_or_default()` is used only if it is 1 to 63
/// bytes long and made solely of ASCII letters, digits and `_`; anything else
/// falls back to `'english'`. Because the result is spliced directly into SQL
/// text, this whitelist is what keeps the literal injection-safe.
// NOTE(review): the 63 cap presumably mirrors PostgreSQL's identifier length
// limit — confirm.
fn lang_literal(config: &SearchConfig) -> String {
    let requested = config.language_or_default();
    let is_safe_char = |c: char| c == '_' || c.is_ascii_alphanumeric();
    let usable = (1..=63).contains(&requested.len()) && requested.chars().all(is_safe_char);
    let name = if usable { requested } else { "english" };
    format!("'{name}'")
}
/// Writes the full-text row for a newly inserted entity.
///
/// Does nothing when the entity has no configured text fields. Otherwise the
/// text fields of `data` are concatenated and upserted into the
/// `_fts_<entity>` table keyed by `id`; an existing row for the same `id` has
/// its `tsv` replaced.
///
/// # Errors
/// Returns `PG_FTS_INSERT_FAILED` carrying the driver's message if the
/// statement fails.
pub fn apply_insert<C: PgConn>(
    conn: &mut C,
    entity: &str,
    id: &str,
    data: &Value,
    config: &SearchConfig,
) -> Result<(), StorageError> {
    if config.text.is_empty() {
        return Ok(());
    }

    let document = collect_text(data, config);
    let lang = lang_literal(config);
    let fts = fts_table_name(entity);
    let upsert = format!(
        "INSERT INTO {fts} (entity_id, tsv) \
         VALUES ($1, to_tsvector({lang}, $2)) \
         ON CONFLICT (entity_id) DO UPDATE SET tsv = EXCLUDED.tsv"
    );

    match conn.execute(&upsert, &[&id, &document]) {
        Ok(_) => Ok(()),
        Err(e) => Err(StorageError::new("PG_FTS_INSERT_FAILED", &e.to_string())),
    }
}
/// Refreshes the full-text row of an entity after a partial update.
///
/// The row is only rewritten when `patch` contains at least one configured
/// text field (which also covers the case of no text fields at all). The new
/// document text is computed from `old_row` merged with `patch`, then upserted
/// into the `_fts_<entity>` table keyed by `id`.
///
/// # Errors
/// Returns `PG_FTS_UPDATE_FAILED` carrying the driver's message if the
/// statement fails.
pub fn apply_update<C: PgConn>(
    conn: &mut C,
    entity: &str,
    id: &str,
    old_row: &Value,
    patch: &Value,
    config: &SearchConfig,
) -> Result<(), StorageError> {
    // `any` over an empty field list is false, so this single guard also
    // handles entities without text search.
    let patch_touches_text = config.text.iter().any(|f| patch.get(f).is_some());
    if !patch_touches_text {
        return Ok(());
    }

    let merged = merge_row(old_row, patch);
    let document = collect_text(&merged, config);
    let lang = lang_literal(config);
    let fts = fts_table_name(entity);
    let upsert = format!(
        "INSERT INTO {fts} (entity_id, tsv) \
         VALUES ($1, to_tsvector({lang}, $2)) \
         ON CONFLICT (entity_id) DO UPDATE SET tsv = EXCLUDED.tsv"
    );

    match conn.execute(&upsert, &[&id, &document]) {
        Ok(_) => Ok(()),
        Err(e) => Err(StorageError::new("PG_FTS_UPDATE_FAILED", &e.to_string())),
    }
}
/// Removes the full-text row belonging to entity `id`.
///
/// Does nothing when the entity has no configured text fields.
///
/// # Errors
/// Returns `PG_FTS_DELETE_FAILED` carrying the driver's message if the
/// statement fails.
pub fn apply_delete<C: PgConn>(
    conn: &mut C,
    entity: &str,
    id: &str,
    config: &SearchConfig,
) -> Result<(), StorageError> {
    if config.text.is_empty() {
        return Ok(());
    }

    let fts = fts_table_name(entity);
    let delete = format!("DELETE FROM {fts} WHERE entity_id = $1");

    match conn.execute(&delete, &[&id]) {
        Ok(_) => Ok(()),
        Err(e) => Err(StorageError::new("PG_FTS_DELETE_FAILED", &e.to_string())),
    }
}
/// Runs a search over `entity` and returns one page of hits, the total match
/// count and per-facet value counts.
///
/// Behaviour, in the order the code applies it:
/// * A requested sort field must be listed in `config.sortable`.
/// * Full-text matching is applied only when the query string is non-blank
///   and the entity has configured text fields; the entity table is then
///   joined to its `_fts_<entity>` table.
/// * Filters on fields that are not configured facets are ignored. A filter
///   whose value `stringify_facet` cannot render short-circuits to an empty
///   result.
/// * The page size is clamped to `1..=100`; the offset is `page * page_size`
///   (saturating).
/// * Facet counts are computed for `query.facets` (restricted to configured
///   facets), or for every configured facet when none are requested. Facets
///   with no counts are omitted from the result.
///
/// # Errors
/// * `INVALID_SORT_FIELD` if the sort field is not sortable.
/// * `PG_SEARCH_TOTAL_FAILED` / `PG_SEARCH_HITS_FAILED` if the respective
///   query fails, plus whatever `run_facet_count` and `build_order_clause`
///   return.
pub fn run_search<C: PgConn>(
    conn: &mut C,
    entity: &str,
    config: &SearchConfig,
    query: &SearchQuery,
) -> Result<SearchResult, StorageError> {
    let t0 = std::time::Instant::now();
    let entity_quoted = quote_ident(entity);
    let fts = fts_table_name(entity);
    let lang = lang_literal(config);

    if let Some((field, _)) = &query.sort {
        if !config.sortable.iter().any(|s| s == field) {
            return Err(StorageError::new(
                "INVALID_SORT_FIELD",
                &format!(
                    "sort field \"{field}\" is not in the entity's `sortable` config"
                ),
            ));
        }
    }

    let mut clauses: Vec<String> = Vec::new();
    let mut params: Vec<Box<dyn ToSql + Sync>> = Vec::new();
    let valid_facet = |f: &str| config.facets.iter().any(|cf| cf == f);

    // The query text, when present, is always the first bound parameter; the
    // `ts_rank` expression further down relies on it being `$1`.
    let has_query = !query.query.trim().is_empty() && !config.text.is_empty();
    if has_query {
        params.push(Box::new(query.query.clone()));
        clauses.push(format!(
            "f.tsv @@ plainto_tsquery({lang}, ${})",
            params.len()
        ));
    }

    // Remember the accepted (field, value) filters so each facet count can
    // re-apply all of them except its own.
    let mut filter_pairs: Vec<(String, String)> = Vec::new();
    for (field, value) in &query.filters {
        if !valid_facet(field) {
            continue;
        }
        let value_str = match crate::search::stringify_facet(value) {
            Some(s) => s,
            None => return Ok(empty_result(t0)),
        };
        filter_pairs.push((field.clone(), value_str.clone()));
        params.push(Box::new(value_str));
        clauses.push(format!(
            "{}::text = ${}",
            qualified_column(&entity_quoted, field),
            params.len()
        ));
    }

    let where_clause = if clauses.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", clauses.join(" AND "))
    };
    let join_clause = if has_query {
        format!(" JOIN {fts} f ON f.entity_id = e.id")
    } else {
        String::new()
    };

    let total_sql = format!(
        "SELECT COUNT(*) FROM {entity_quoted} e{join_clause}{where_clause}"
    );
    let total: i64 = {
        let pg_params = box_params(&params);
        let row = conn
            .query(&total_sql, &pg_params)
            .map_err(|e| StorageError::new("PG_SEARCH_TOTAL_FAILED", &e.to_string()))?;
        row.first().map(|r| r.get::<_, i64>(0)).unwrap_or(0)
    };
    let total = total as u64;

    let order_clause = build_order_clause(query, has_query, &entity_quoted)?;
    let limit = query.page_size.clamp(1, 100);
    let offset = query.page.saturating_mul(limit);
    let select_cols = if has_query {
        format!(
            "e.*, ts_rank(f.tsv, plainto_tsquery({lang}, $1)) AS _rank"
        )
    } else {
        "e.*".to_string()
    };
    let hits_sql = format!(
        "SELECT {select_cols} FROM {entity_quoted} e{join_clause}{where_clause}{order_clause} \
         LIMIT {limit} OFFSET {offset}"
    );
    let hits: Vec<Value> = {
        let pg_params = box_params(&params);
        let rows = conn
            .query(&hits_sql, &pg_params)
            .map_err(|e| StorageError::new("PG_SEARCH_HITS_FAILED", &e.to_string()))?;
        rows.iter().map(crate::postgres::row_to_json_pub).collect()
    };

    let wanted_facets: Vec<&String> = if query.facets.is_empty() {
        config.facets.iter().collect()
    } else {
        query.facets.iter().filter(|f| valid_facet(f)).collect()
    };
    let mut facet_counts: BTreeMap<String, BTreeMap<String, u64>> = BTreeMap::new();
    for facet in wanted_facets {
        let counts = run_facet_count(
            conn,
            &entity_quoted,
            &fts,
            facet,
            has_query,
            &query.query,
            &lang,
            &filter_pairs,
        )?;
        if !counts.is_empty() {
            facet_counts.insert(facet.clone(), counts);
        }
    }

    Ok(SearchResult {
        hits,
        facet_counts,
        total,
        took_ms: t0.elapsed().as_millis() as u64,
    })
}
/// Counts rows per distinct value of `facet` for the current search.
///
/// The count query applies the full-text condition (when `has_query` is true)
/// and every entry of `filter_pairs` except the one on `facet` itself, so the
/// counts for a facet are not narrowed by the caller's own selection on that
/// facet. At most the 100 most frequent values are returned; rows whose facet
/// value is SQL `NULL` are skipped.
///
/// `entity_quoted`, `fts` and `lang` are spliced into the SQL text verbatim
/// and are expected to be already quoted by the caller; `query_text` and the
/// filter values are bound as parameters.
///
/// # Errors
/// Returns `PG_SEARCH_FACET_FAILED` carrying the driver's message if the
/// query fails.
fn run_facet_count<C: PgConn>(
    conn: &mut C,
    entity_quoted: &str,
    fts: &str,
    facet: &str,
    has_query: bool,
    query_text: &str,
    lang: &str,
    filter_pairs: &[(String, String)],
) -> Result<BTreeMap<String, u64>, StorageError> {
    let facet_col = qualified_column(entity_quoted, facet);
    let mut clauses: Vec<String> = Vec::new();
    let mut params: Vec<Box<dyn ToSql + Sync>> = Vec::new();

    if has_query {
        params.push(Box::new(query_text.to_string()));
        clauses.push(format!(
            "f.tsv @@ plainto_tsquery({lang}, ${})",
            params.len()
        ));
    }

    for (field, value) in filter_pairs {
        // Leave out the filter on the facet being counted.
        if field == facet {
            continue;
        }
        params.push(Box::new(value.clone()));
        clauses.push(format!(
            "{}::text = ${}",
            qualified_column(entity_quoted, field),
            params.len()
        ));
    }

    let where_clause = if clauses.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", clauses.join(" AND "))
    };
    let join_clause = if has_query {
        format!(" JOIN {fts} f ON f.entity_id = e.id")
    } else {
        String::new()
    };
    let sql = format!(
        "SELECT {facet_col}::text AS value, COUNT(*) AS cnt \
         FROM {entity_quoted} e{join_clause}{where_clause} \
         GROUP BY {facet_col} ORDER BY cnt DESC LIMIT 100"
    );

    let pg_params = box_params(&params);
    let rows = conn
        .query(&sql, &pg_params)
        .map_err(|e| StorageError::new("PG_SEARCH_FACET_FAILED", &e.to_string()))?;

    let mut counts: BTreeMap<String, u64> = BTreeMap::new();
    for row in &rows {
        let val: Option<String> = row.get(0);
        let cnt: i64 = row.get(1);
        if let Some(v) = val {
            counts.insert(v, cnt as u64);
        }
    }
    Ok(counts)
}
/// Returns `col` as a quoted column reference on the `e` table alias used by
/// the search queries. The first argument is currently unused.
fn qualified_column(_entity_quoted: &str, col: &str) -> String {
    let column = quote_ident(col);
    format!("e.{column}")
}
/// Builds the ` ORDER BY …` clause (with leading space) for the hit query.
///
/// * With an explicit sort: order by that column, `DESC` when the direction
///   equals `desc` case-insensitively, otherwise `ASC`.
/// * Without a sort but with a text query: order by `_rank DESC`.
/// * Otherwise: order by `e.id`.
///
/// In the first two cases `e.id` is appended as a secondary key. The hit
/// query is paged with `LIMIT`/`OFFSET`, and without a unique tie-breaker
/// rows with equal sort values or equal rank may appear on several pages or
/// on none.
///
/// The second and third parameters other than `has_query` are kept for
/// signature compatibility; `_entity_quoted` is unused. The function
/// currently never returns an error.
fn build_order_clause(
    query: &SearchQuery,
    has_query: bool,
    _entity_quoted: &str,
) -> Result<String, StorageError> {
    let clause = match &query.sort {
        Some((field, dir)) => {
            let dir = if dir.eq_ignore_ascii_case("desc") {
                "DESC"
            } else {
                "ASC"
            };
            format!(" ORDER BY e.{} {dir}, e.id", quote_ident(field))
        }
        None if has_query => " ORDER BY _rank DESC, e.id".to_string(),
        None => " ORDER BY e.id".to_string(),
    };
    Ok(clause)
}
/// Builds a `SearchResult` with no hits, no facet counts and a total of zero,
/// stamping it with the milliseconds elapsed since `t0`.
fn empty_result(t0: std::time::Instant) -> SearchResult {
    let took_ms = t0.elapsed().as_millis() as u64;
    SearchResult {
        hits: Vec::new(),
        facet_counts: BTreeMap::new(),
        total: 0,
        took_ms,
    }
}
/// Borrows every boxed parameter as a `&(dyn ToSql + Sync)` trait object,
/// preserving order, so the list can be handed to the connection's
/// `query` call.
fn box_params(boxed: &[Box<dyn ToSql + Sync>]) -> Vec<&(dyn ToSql + Sync)> {
    let mut refs: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(boxed.len());
    for param in boxed {
        refs.push(&**param);
    }
    refs
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with text, facet and sortable fields all populated and no
    /// language override.
    fn cfg() -> SearchConfig {
        SearchConfig {
            text: vec!["name".into(), "description".into()],
            facets: vec!["brand".into(), "category".into()],
            sortable: vec!["price".into(), "createdAt".into()],
            language: None,
        }
    }

    /// Config with a single text field and the given language override.
    fn cfg_with_language(language: &str) -> SearchConfig {
        SearchConfig {
            text: vec!["name".into()],
            facets: vec![],
            sortable: vec![],
            language: Some(language.to_string()),
        }
    }

    #[test]
    fn create_emits_fts_table_gin_and_indexes() {
        let sql = create_search_index_sql("Product", &cfg()).join("\n");
        let expected = [
            "CREATE TABLE IF NOT EXISTS \"_fts_Product\"",
            "tsv tsvector NOT NULL",
            "USING GIN (tsv)",
            "\"Product_facet_brand\"",
            "\"Product_facet_category\"",
            "\"Product_sort_price\"",
            "\"Product_sort_createdAt\"",
        ];
        for needle in expected {
            assert!(sql.contains(needle), "missing fragment: {needle}");
        }
    }

    #[test]
    fn create_skips_fts_when_no_text_fields() {
        let facets_only = SearchConfig {
            text: vec![],
            facets: vec!["brand".into()],
            sortable: vec![],
            language: None,
        };
        let sql = create_search_index_sql("Product", &facets_only).join("\n");
        assert!(!sql.contains("_fts_Product"));
        assert!(sql.contains("\"Product_facet_brand\""));
    }

    #[test]
    fn remove_drops_fts_and_indexes_when_text_present() {
        let sql = remove_search_index_sql("Product", &cfg()).join("\n");
        let expected = [
            "DROP TABLE IF EXISTS \"_fts_Product\"",
            "DROP INDEX IF EXISTS \"Product_fts_gin\"",
            "DROP INDEX IF EXISTS \"Product_facet_brand\"",
        ];
        for needle in expected {
            assert!(sql.contains(needle), "missing fragment: {needle}");
        }
    }

    #[test]
    fn lang_literal_falls_back_to_english_for_invalid_input() {
        let hostile = cfg_with_language("english'; DROP TABLE x; --");
        assert_eq!(lang_literal(&hostile), "'english'");
    }

    #[test]
    fn lang_literal_passes_through_known_postgres_configs() {
        for cfg_lang in ["english", "spanish", "french", "german", "simple"] {
            let config = cfg_with_language(cfg_lang);
            assert_eq!(lang_literal(&config), format!("'{cfg_lang}'"));
        }
    }

    #[test]
    fn entity_name_with_double_quote_is_neutralized() {
        let sql = create_search_index_sql("Foo\"; DROP TABLE bar; --", &cfg()).join("\n");
        // The embedded quote must come out doubled, never as a lone
        // identifier terminator.
        assert!(!sql.contains("Foo\"; DROP"));
        assert!(sql.contains("Foo\"\""));
    }
}