use crate::db::DbPool;
use crate::orm::{FieldSpec, Model, SqlType};
/// A [`Model`] that can take part in cross-table search.
///
/// Every method has a default derived from the model's field metadata, so an
/// empty `impl Searchable for MyModel {}` is enough to opt a model in.
pub trait Searchable: Model {
    /// Label written into [`SearchHit::kind`]; defaults to the model's table name.
    fn kind() -> &'static str {
        Self::TABLE
    }

    /// Column whose value identifies a hit (emitted as text in [`SearchHit::pk`]);
    /// defaults to the primary-key column, see [`default_pk_column`].
    fn ident() -> &'static str {
        default_pk_column::<Self>()
    }

    /// Column used as the hit's title; defaults to [`default_title`].
    fn title() -> &'static str {
        default_title::<Self>()
    }

    /// Columns matched against the query and concatenated into the snippet;
    /// defaults to [`default_body`].
    fn body() -> Vec<&'static str> {
        default_body::<Self>()
    }

    /// Optional raw SQL predicate, wrapped in parentheses and ANDed into the
    /// branch's WHERE clause. It is interpolated verbatim, so it must be trusted SQL.
    fn filter_sql() -> Option<&'static str> {
        None
    }
}
/// True for plain `Text` fields that have neither a `text_format` nor a set of `choices`.
fn is_content_text(f: &FieldSpec) -> bool {
    if !matches!(f.ty, SqlType::Text) {
        return false;
    }
    // NOTE(review): presumably formatted text and enumerated choices are excluded
    // because they are not free-form prose worth searching — confirm intent.
    f.choices.is_empty() && f.text_format.is_none()
}
/// Names of every field of `T` accepted by `is_content_text`, in declaration order.
pub fn default_body<T: Model>() -> Vec<&'static str> {
    let mut names = Vec::new();
    for field in T::FIELDS.iter() {
        if is_content_text(field) {
            names.push(field.name);
        }
    }
    names
}
/// Picks the default title column for `T`.
///
/// Preference order among the [`default_body`] columns:
/// 1. a column named `title` (ASCII case-insensitive),
/// 2. a column named `name` (ASCII case-insensitive),
/// 3. the first body column.
///
/// When `T` has no body columns at all, the primary-key column is used
/// (see [`default_pk_column`]).
pub fn default_title<T: Model>() -> &'static str {
    let candidates = default_body::<T>();
    ["title", "name"]
        .iter()
        .find_map(|wanted| {
            candidates
                .iter()
                .copied()
                .find(|col| col.eq_ignore_ascii_case(wanted))
        })
        .or_else(|| candidates.first().copied())
        .unwrap_or_else(default_pk_column::<T>)
}
/// Name of the first field of `T` flagged `primary_key`, or `"id"` if none is flagged.
pub fn default_pk_column<T: Model>() -> &'static str {
    match T::FIELDS.iter().find(|spec| spec.primary_key) {
        Some(spec) => spec.name,
        None => "id",
    }
}
/// One result row, as produced by a [`branch_sql`] `SELECT` branch.
#[derive(Clone, Debug, sqlx::FromRow)]
pub struct SearchHit {
    /// `Searchable::kind()` of the model the row came from.
    pub kind: String,
    /// The row's `Searchable::ident()` column, cast to text.
    pub pk: String,
    /// Value of the model's title column.
    pub title: String,
    /// At most the first 200 characters of the concatenated body columns.
    pub snippet: String,
    /// Relevance score; results are ordered by descending rank.
    pub rank: f64,
}
/// SQL dialect a search branch is rendered for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Backend {
    /// PostgreSQL: full-text search via `to_tsvector` / `websearch_to_tsquery`.
    Postgres,
    /// SQLite: `LIKE`-based substring matching.
    Sqlite,
}
/// Wraps `name` in double quotes as an SQL identifier, doubling any embedded `"`.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Builds an SQL expression joining `cols` with single spaces, treating NULL as `''`.
///
/// Each column is cast to text before `coalesce`, so non-text columns work too.
/// One such case is the integer/uuid primary key that `branch_sql` falls back to
/// when a model has no text columns. Without the cast, Postgres resolves
/// `coalesce(int_col, '')` to integer and fails parsing `''`. For text columns the
/// cast is a no-op. On SQLite it matches what `||` would do implicitly anyway.
///
/// Example: `["a", "b"]` yields
/// `coalesce(CAST("a" AS text), '') || ' ' || coalesce(CAST("b" AS text), '')`.
/// An empty slice yields an empty string.
fn concat_coalesce(cols: &[&str]) -> String {
    cols.iter()
        .map(|c| format!("coalesce(CAST({} AS text), '')", quote_ident(c)))
        .collect::<Vec<_>>()
        .join(" || ' ' || ")
}
/// Renders one `SELECT` branch of the cross-model search query for model `T`.
///
/// Every branch yields the columns `kind, pk, title, snippet, rank`, matching
/// [`SearchHit`], so branches for different models can be joined with `UNION ALL`.
///
/// Placeholders the caller must bind:
/// * Postgres: `$1`, the raw query text, parsed with `websearch_to_tsquery`.
/// * SQLite: `?1`, a substring LIKE pattern, and `?2`, a prefix LIKE pattern.
///   Both must escape wildcards with `\`, which is the declared LIKE escape character.
///
/// Rows are matched against `T::body()`, or against `T::title()` when the body is
/// empty. A title column that is not a body column only influences the rank.
/// Rows are further restricted by `T::filter_sql()` and, for soft-deleting models,
/// by `"deleted_at" IS NULL`.
pub fn branch_sql<T: Searchable>(backend: Backend) -> String {
    let table = quote_ident(T::TABLE);
    // `kind` is emitted as an SQL string literal, so double any single quotes.
    let kind = T::kind().replace('\'', "''");
    let ident = quote_ident(T::ident());
    let title_col = T::title();
    let title = quote_ident(title_col);
    let body = T::body();
    // Always search something: fall back to the title column when there is no body.
    let body_cols: Vec<&str> = if body.is_empty() {
        vec![title_col]
    } else {
        body
    };
    let body_concat = concat_coalesce(&body_cols);
    let scope_and = scope_clause::<T>();

    match backend {
        Backend::Postgres => {
            // The title column may be NULL or non-text (e.g. the primary-key fallback).
            // Coerce it to non-NULL text so it decodes into `SearchHit::title: String`
            // and is a valid `to_tsvector` argument.
            let title_text = format!("coalesce(CAST({title} AS text), '')");
            let title_vec = format!("setweight(to_tsvector('english', {title_text}), 'A')");
            // Body columns other than the title contribute unweighted terms to the rank.
            let rest: Vec<&str> = body_cols
                .iter()
                .copied()
                .filter(|c| *c != title_col)
                .collect();
            let rest_vec = if rest.is_empty() {
                String::new()
            } else {
                format!(" || to_tsvector('english', {})", concat_coalesce(&rest))
            };
            format!(
                "SELECT '{kind}' AS kind, \
                 CAST({ident} AS text) AS pk, \
                 {title_text} AS title, \
                 left({body_concat}, 200) AS snippet, \
                 ts_rank({title_vec}{rest_vec}, websearch_to_tsquery('english', $1))::float8 AS rank \
                 FROM {table} \
                 WHERE to_tsvector('english', {body_concat}) @@ websearch_to_tsquery('english', $1){scope_and}"
            )
        }
        Backend::Sqlite => {
            // Same NULL / non-text protection for the decoded title as on Postgres.
            let title_text = format!("coalesce(CAST({title} AS TEXT), '')");
            let like_terms: Vec<String> = body_cols
                .iter()
                .map(|c| format!("{} LIKE ?1 ESCAPE '\\'", quote_ident(c)))
                .collect();
            let where_like = like_terms.join(" OR ");
            // One point per body column containing the query. `0.0` (not `0`) keeps
            // every term REAL so `rank` always decodes as `f64`.
            let body_terms = like_terms
                .iter()
                .map(|term| format!("(CASE WHEN {term} THEN 1.0 ELSE 0.0 END)"))
                .collect::<Vec<_>>()
                .join(" + ");
            // Rank = 2 for a title substring match, + 1 per matching body column,
            // + 1 for a title prefix match.
            format!(
                "SELECT '{kind}' AS kind, \
                 CAST({ident} AS TEXT) AS pk, \
                 {title_text} AS title, \
                 substr({body_concat}, 1, 200) AS snippet, \
                 ( (CASE WHEN {title} LIKE ?1 ESCAPE '\\' THEN 2.0 ELSE 0.0 END) \
                 + {body_terms} \
                 + (CASE WHEN {title} LIKE ?2 ESCAPE '\\' THEN 1.0 ELSE 0.0 END) ) AS rank \
                 FROM {table} \
                 WHERE ({where_like}){scope_and}"
            )
        }
    }
}

/// Builds the `" AND …"` suffix restricting `T`'s rows, or `""` when `T` has no
/// restrictions. The suffix combines the parenthesized `filter_sql()` predicate with
/// `"deleted_at" IS NULL` when `T::SOFT_DELETE` is set.
fn scope_clause<T: Searchable>() -> String {
    let mut scope: Vec<String> = Vec::new();
    if let Some(f) = T::filter_sql() {
        scope.push(format!("({f})"));
    }
    if T::SOFT_DELETE {
        scope.push(format!("{} IS NULL", quote_ident("deleted_at")));
    }
    if scope.is_empty() {
        String::new()
    } else {
        format!(" AND {}", scope.join(" AND "))
    }
}
/// A fixed set of [`Searchable`] models that [`Search::across`] queries together.
///
/// Implemented for tuples of one to six `Searchable` types.
pub trait SearchSources {
    /// One `SELECT` branch per model, in order, ready to be joined with `UNION ALL`.
    fn branches(backend: Backend) -> Vec<String>;
}
macro_rules! impl_search_sources {
($($T:ident),+) => {
impl<$($T: Searchable),+> SearchSources for ($($T,)+) {
fn branches(backend: Backend) -> Vec<String> {
vec![$( branch_sql::<$T>(backend) ),+]
}
}
};
}
impl_search_sources!(A);
impl_search_sources!(A, B);
impl_search_sources!(A, B, C);
impl_search_sources!(A, B, C, D);
impl_search_sources!(A, B, C, D, E);
impl_search_sources!(A, B, C, D, E, F);
/// Escapes LIKE wildcards so `q` matches literally under `ESCAPE '\'`.
///
/// Each `\`, `%` and `_` is prefixed with a backslash; all other characters
/// are copied unchanged.
fn escape_like(q: &str) -> String {
    let mut escaped = String::with_capacity(q.len());
    for ch in q.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Entry point for searching several [`Searchable`] models in one query.
pub struct Search;

impl Search {
    /// Searches every model in `S` for `query` and returns at most `limit` hits,
    /// highest `rank` first.
    ///
    /// The query is trimmed. A blank query returns an empty list without touching
    /// the database. On Postgres the text goes to `websearch_to_tsquery`. On SQLite
    /// it becomes escaped `LIKE` substring and prefix patterns.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while executing the query or decoding rows.
    pub async fn across<S: SearchSources>(
        query: &str,
        limit: u64,
    ) -> Result<Vec<SearchHit>, sqlx::Error> {
        let q = query.trim();
        if q.is_empty() {
            return Ok(Vec::new());
        }
        // LIMIT is bound as a signed 64-bit integer. A plain `as` cast would wrap
        // values above i64::MAX to negatives, which Postgres rejects ("LIMIT must not
        // be negative"). Clamp instead; i64::MAX is effectively "no limit".
        let limit = limit.min(i64::MAX as u64) as i64;
        match crate::db::pool_dispatched() {
            DbPool::Postgres(pool) => {
                let sql = format!(
                    "{} ORDER BY rank DESC LIMIT $2",
                    S::branches(Backend::Postgres).join("\nUNION ALL\n")
                );
                sqlx::query_as::<_, SearchHit>(&sql)
                    .bind(q)
                    .bind(limit)
                    .fetch_all(pool)
                    .await
            }
            DbPool::Sqlite(pool) => {
                let sql = format!(
                    "{} ORDER BY rank DESC LIMIT ?3",
                    S::branches(Backend::Sqlite).join("\nUNION ALL\n")
                );
                // ?1 = substring pattern, ?2 = prefix pattern, both with wildcards escaped.
                let escaped = escape_like(q);
                let like = format!("%{escaped}%");
                let prefix = format!("{escaped}%");
                sqlx::query_as::<_, SearchHit>(&sql)
                    .bind(like)
                    .bind(prefix)
                    .bind(limit)
                    .fetch_all(pool)
                    .await
            }
        }
    }
}