use std::collections::BTreeMap;
use std::fmt::Write;
use miette::{ensure, Result};
use crate::data::value::DataValue;
/// Settings for the optional Maximal Marginal Relevance stage of a hybrid search.
#[derive(Clone, Debug)]
pub struct MmrParams {
    /// Value emitted as the `lambda:` option of `MaximalMarginalRelevance`.
    /// The query builder clamps it to `[0.0, 1.0]` before emitting it.
    pub lambda: f64,
    /// Value emitted as the `k:` option of `MaximalMarginalRelevance`.
    pub k: usize,
    /// Column of the searched relation that is bound as the embedding (`__emb`)
    /// of each fused candidate.
    pub embedding_col: String,
}

impl Default for MmrParams {
    /// Returns `lambda = 0.5`, `k = 10` and the embedding column `"emb"`.
    fn default() -> Self {
        let embedding_col = String::from("emb");
        Self {
            embedding_col,
            k: 10,
            lambda: 0.5,
        }
    }
}
/// An extra ranked list that is merged into the `combined` relation alongside
/// the built-in `'semantic'` and `'text'` lists before fusion.
#[derive(Clone, Debug)]
pub struct HybridList {
/// Identifier of the list. It must be a bare identifier and is emitted as the
/// `__lid` string literal for every row this list contributes.
pub label: String,
/// Rule body spliced verbatim after `combined[__lid, id, score] :=`.
/// NOTE(review): presumably it has to bind `id` and `score` itself — confirm
/// with callers; the text is not parsed or escaped by the query builder.
pub rule_body: String,
}
/// Description of a hybrid search: a vector index lookup and a full-text index
/// lookup over one relation, fused with `ReciprocalRankFusion` and optionally
/// re-ranked with `MaximalMarginalRelevance`.
#[derive(Clone, Debug)]
pub struct HybridSearch {
/// Name of the relation that owns both indices; must be a bare identifier.
pub relation: String,
/// Column of `relation` that is bound to `id` in the generated rules; must be
/// a bare identifier.
pub id_col: String,
/// Name of the vector index, referenced as `~relation:vector_index`.
pub vector_index: String,
/// Query embedding. It is passed as the `$qv` parameter (each component
/// widened to `f64`) and must not be empty.
pub query_vector: Vec<f32>,
/// When `true` the query vector is wrapped as `vec($qv, 'F64')`, otherwise as
/// `vec($qv)`.
pub vector_f64: bool,
/// `k:` option of the vector index lookup; must be greater than zero.
pub vector_k: usize,
/// `ef:` option of the vector index lookup; must be greater than zero.
pub ef: usize,
/// Name of the full-text index, referenced as `~relation:fts_index`.
pub fts_index: String,
/// Full-text query, passed as the `$qt` parameter.
pub query_text: String,
/// `k:` option of the full-text index lookup; must be greater than zero.
pub fts_k: usize,
/// Additional ranked lists merged into the fusion input.
pub extra_lists: Vec<HybridList>,
/// `k:` option of `ReciprocalRankFusion`; must be finite and non-negative.
pub rrf_k: f64,
/// When set, the fused candidates are re-ranked with
/// `MaximalMarginalRelevance` using these settings.
pub mmr: Option<MmrParams>,
/// Value of the `:limit` directive. It is only emitted when `mmr` is `None`.
pub limit: usize,
}
impl Default for HybridSearch {
fn default() -> Self {
HybridSearch {
relation: String::new(),
id_col: "id".into(),
vector_index: String::new(),
query_vector: Vec::new(),
vector_f64: false,
vector_k: 10,
ef: 50,
fts_index: String::new(),
query_text: String::new(),
fts_k: 10,
extra_lists: Vec::new(),
rrf_k: 60.0,
mmr: None,
limit: 10,
}
}
}
/// Checks that `s` is safe to splice into the generated script as a bare name.
///
/// A bare identifier starts with an ASCII letter or `_` and continues with
/// ASCII letters, digits or `_`. The empty string is rejected.
///
/// # Errors
///
/// Returns an error naming `what` and quoting `s` when the check fails.
fn validate_ident(s: &str, what: &str) -> Result<()> {
    let is_head = |c: char| c == '_' || c.is_ascii_alphabetic();
    let is_tail = |c: char| c == '_' || c.is_ascii_alphanumeric();

    let mut rest = s.chars();
    // The first character has stricter rules (no digits); `rest` then holds
    // only the remaining characters.
    let well_formed = matches!(rest.next(), Some(c) if is_head(c)) && rest.all(is_tail);

    ensure!(
        well_formed,
        "hybrid_search: {what} must be a bare identifier (got {s:?})"
    );
    Ok(())
}
/// Renders `x` as text for embedding in the generated script.
///
/// `Debug` formatting is used on purpose: unlike `Display`, it keeps the
/// fractional part on integral values (`60.0` rather than `60`), so the
/// emitted literal always reads as a float.
fn fmt_f64(x: f64) -> String {
    format!("{x:?}")
}
/// Builds the query script and its parameter map for a hybrid search.
///
/// The generated script:
///
/// 1. `sem`: looks up `vector_index` with `$qv` and scores rows by negated
///    distance.
/// 2. `txt`: looks up `fts_index` with `$qt` and uses the bound score.
/// 3. `combined`: unions both under the list ids `'semantic'` and `'text'`,
///    plus one rule per entry of `extra_lists` under its own label.
/// 4. Fuses `combined` with `ReciprocalRankFusion(k: rrf_k)`.
/// 5. Without `mmr`: orders by descending score and applies `:limit`.
///    With `mmr`: joins each fused id with its embedding column, runs
///    `MaximalMarginalRelevance` (lambda clamped to `[0, 1]`) and orders by
///    rank. No `:limit` is emitted in this mode.
///
/// Returns the script together with the parameters `qv` (the query vector as
/// a list of `f64`) and `qt` (the query text).
///
/// # Errors
///
/// Fails when
/// - `relation`, `id_col`, `vector_index`, `fts_index`, an extra list label or
///   `mmr.embedding_col` is not a bare identifier;
/// - `vector_k`, `ef`, `fts_k` or `mmr.k` is zero;
/// - `query_vector` is empty;
/// - `rrf_k` is negative or not finite, or `mmr.lambda` is not finite;
/// - an extra list has a blank `rule_body`.
pub fn build_hybrid_query(q: &HybridSearch) -> Result<(String, BTreeMap<String, DataValue>)> {
    // `fmt::Write` for `String` never reports an error, so the writes below
    // cannot fail.
    const INFALLIBLE: &str = "writing to a String cannot fail";

    validate_ident(&q.relation, "relation")?;
    validate_ident(&q.id_col, "id_col")?;
    validate_ident(&q.vector_index, "vector_index")?;
    validate_ident(&q.fts_index, "fts_index")?;
    ensure!(q.vector_k > 0, "hybrid_search: vector_k must be > 0");
    ensure!(q.ef > 0, "hybrid_search: ef must be > 0");
    ensure!(q.fts_k > 0, "hybrid_search: fts_k must be > 0");
    ensure!(
        !q.query_vector.is_empty(),
        "hybrid_search: query_vector is empty"
    );
    ensure!(
        q.rrf_k.is_finite() && q.rrf_k >= 0.0,
        "hybrid_search: rrf_k must be finite and >= 0"
    );
    for l in &q.extra_lists {
        validate_ident(&l.label, "extra_lists.label")?;
        // A blank body would be spliced in as `... := , __lid = '...'`, which
        // is not a well-formed rule; report it here with the offending label.
        let label = &l.label;
        ensure!(
            !l.rule_body.trim().is_empty(),
            "hybrid_search: extra_lists.rule_body for {label:?} must not be empty"
        );
    }
    if let Some(m) = &q.mmr {
        validate_ident(&m.embedding_col, "mmr.embedding_col")?;
        ensure!(m.lambda.is_finite(), "hybrid_search: mmr.lambda must be finite");
        // Same contract as the other `k` knobs above.
        ensure!(m.k > 0, "hybrid_search: mmr.k must be > 0");
    }

    let rel = &q.relation;
    let idc = &q.id_col;
    let vec_call = if q.vector_f64 {
        "vec($qv, 'F64')"
    } else {
        "vec($qv)"
    };

    let mut s = String::new();

    // Vector side: negate the distance so that a larger score means closer.
    writeln!(
        s,
        "sem[id, score] := ~{rel}:{vidx}{{ {idc}: id | query: {vec_call}, k: {vk}, ef: {ef}, bind_distance: __dist }}, score = -__dist",
        vidx = q.vector_index,
        vk = q.vector_k,
        ef = q.ef,
    )
    .expect(INFALLIBLE);

    // Full-text side.
    writeln!(
        s,
        "txt[id, score] := ~{rel}:{fidx}{{ {idc}: id | query: $qt, k: {fk}, bind_score: score }}",
        fidx = q.fts_index,
        fk = q.fts_k,
    )
    .expect(INFALLIBLE);

    // Tag every row with the id of the list it came from.
    writeln!(s, "combined[__lid, id, score] := sem[id, score], __lid = 'semantic'")
        .expect(INFALLIBLE);
    writeln!(s, "combined[__lid, id, score] := txt[id, score], __lid = 'text'")
        .expect(INFALLIBLE);
    for l in &q.extra_lists {
        // The label was validated as a bare identifier above, so it cannot
        // break out of the quoted literal; the rule body is spliced verbatim.
        writeln!(
            s,
            "combined[__lid, id, score] := {body}, __lid = '{label}'",
            body = l.rule_body,
            label = l.label,
        )
        .expect(INFALLIBLE);
    }

    let rrf_k = fmt_f64(q.rrf_k);
    match &q.mmr {
        None => {
            writeln!(
                s,
                "?[id, score] <~ ReciprocalRankFusion(combined[__lid, id, score], k: {rrf_k})"
            )
            .expect(INFALLIBLE);
            writeln!(s, ":order -score").expect(INFALLIBLE);
            writeln!(s, ":limit {}", q.limit).expect(INFALLIBLE);
        }
        Some(m) => {
            writeln!(
                s,
                "fused[id, score] <~ ReciprocalRankFusion(combined[__lid, id, score], k: {rrf_k})"
            )
            .expect(INFALLIBLE);
            // Attach each fused candidate's stored embedding for the MMR step.
            writeln!(
                s,
                "cand[id, score, __emb] := fused[id, score], *{rel}{{ {idc}: id, {emb}: __emb }}",
                emb = m.embedding_col,
            )
            .expect(INFALLIBLE);
            writeln!(
                s,
                "?[id, rank] <~ MaximalMarginalRelevance(cand[id, score, __emb], lambda: {lambda}, k: {k})",
                lambda = fmt_f64(m.lambda.clamp(0.0, 1.0)),
                k = m.k,
            )
            .expect(INFALLIBLE);
            writeln!(s, ":order rank").expect(INFALLIBLE);
        }
    }

    let mut params = BTreeMap::new();
    // `f64::from` widens each `f32` component losslessly.
    let query_vector = q
        .query_vector
        .iter()
        .copied()
        .map(|component| DataValue::from(f64::from(component)))
        .collect();
    params.insert("qv".to_string(), DataValue::List(query_vector));
    params.insert("qt".to_string(), DataValue::from(q.query_text.as_str()));
    Ok((s, params))
}