use std::collections::BTreeMap;
use std::fmt::Write;
use miette::{ensure, Result};
use crate::data::value::DataValue;
/// Settings for the optional Maximal Marginal Relevance re-ranking stage that
/// `build_hybrid_query` appends after rank fusion.
#[derive(Clone, Debug)]
pub struct MmrParams {
    /// Value passed as `lambda:` to `MaximalMarginalRelevance`; must be finite
    /// and is clamped to `[0.0, 1.0]` when the query is built.
    pub lambda: f64,
    /// Value passed as `k:` to `MaximalMarginalRelevance`.
    pub k: usize,
    /// Column of the searched relation that holds each row's embedding.
    pub embedding_col: String,
}

impl Default for MmrParams {
    /// `lambda = 0.5`, `k = 10`, `embedding_col = "emb"`.
    fn default() -> Self {
        Self {
            lambda: 0.5,
            k: 10,
            embedding_col: String::from("emb"),
        }
    }
}
/// A caller-supplied ranked list that is fused together with the built-in
/// semantic and text lists.
#[derive(Clone, Debug)]
pub struct HybridList {
    /// Label identifying this list during fusion; must be a bare identifier.
    pub label: String,
    /// CozoScript rule body spliced verbatim into
    /// `combined[__lid, id, score] := <rule_body>, __lid = '<label>'`,
    /// so it must bind the variables `id` and `score`. It is neither
    /// validated nor escaped — only pass trusted text.
    pub rule_body: String,
}
/// A graph-expansion ranked list: rows reachable from `seeds` by following the
/// edges stored in `edge_relation`, scored by negated hop distance (fewer hops
/// rank higher).
#[derive(Clone, Debug)]
pub struct GraphLeg {
    /// Label identifying this list during fusion; must be a bare identifier.
    pub label: String,
    /// Stored relation holding the edges; must be a bare identifier.
    pub edge_relation: String,
    /// Edge column holding the source node; must be a bare identifier.
    pub from_col: String,
    /// Edge column holding the destination node; must be a bare identifier.
    pub to_col: String,
    /// Start nodes of the expansion, passed as query parameters; must be non-empty.
    pub seeds: Vec<DataValue>,
    /// Largest hop distance explored from a seed; must be at least 1.
    pub max_hops: usize,
    /// When set, edges are additionally followed from `to_col` back to `from_col`.
    pub undirected: bool,
}

impl Default for GraphLeg {
    /// Label `"graph"`, columns `"from"` / `"to"`, two hops, directed, no seeds.
    /// `edge_relation` starts empty and must be filled in before building a query.
    fn default() -> Self {
        Self {
            label: String::from("graph"),
            edge_relation: String::new(),
            from_col: String::from("from"),
            to_col: String::from("to"),
            seeds: Vec::new(),
            max_hops: 2,
            undirected: false,
        }
    }
}
/// Full description of a hybrid search over one stored relation: a vector-index
/// list, a full-text list, optional extra and graph lists, reciprocal rank
/// fusion, and optional MMR re-ranking. Turned into a script by
/// `build_hybrid_query`.
#[derive(Clone, Debug)]
pub struct HybridSearch {
    /// Stored relation to search; must be a bare identifier.
    pub relation: String,
    /// Column of `relation` whose value is returned as `id`.
    pub id_col: String,
    /// Name of the vector index on `relation`.
    pub vector_index: String,
    /// Query embedding, sent as the `$qv` parameter; must be non-empty.
    pub query_vector: Vec<f32>,
    /// Selects `vec($qv, 'F64')` instead of `vec($qv)` for the vector query.
    pub vector_f64: bool,
    /// `k:` for the vector-index search; must be > 0.
    pub vector_k: usize,
    /// `ef:` for the vector-index search; must be > 0.
    pub ef: usize,
    /// Name of the full-text index on `relation`.
    pub fts_index: String,
    /// Full-text query, sent as the `$qt` parameter.
    pub query_text: String,
    /// `k:` for the full-text search; must be > 0.
    pub fts_k: usize,
    /// Additional caller-defined ranked lists.
    pub extra_lists: Vec<HybridList>,
    /// Graph-expansion ranked lists.
    pub graph_legs: Vec<GraphLeg>,
    /// `k:` passed to `ReciprocalRankFusion`; must be finite and >= 0.
    pub rrf_k: f64,
    /// Optional MMR re-ranking of the fused candidates.
    pub mmr: Option<MmrParams>,
    /// Result cap emitted as `:limit`; only used when `mmr` is `None`.
    pub limit: usize,
}

impl Default for HybridSearch {
    /// `id_col = "id"`, `vector_k = 10`, `ef = 50`, `fts_k = 10`,
    /// `rrf_k = 60.0`, `limit = 10`, F32 vectors, no extra lists, no graph
    /// legs, no MMR. `relation`, `vector_index`, `fts_index` and
    /// `query_vector` start empty and must be filled in before building.
    fn default() -> Self {
        Self {
            relation: String::new(),
            id_col: String::from("id"),
            vector_index: String::new(),
            query_vector: Vec::new(),
            vector_f64: false,
            vector_k: 10,
            ef: 50,
            fts_index: String::new(),
            query_text: String::new(),
            fts_k: 10,
            extra_lists: Vec::new(),
            graph_legs: Vec::new(),
            rrf_k: 60.0,
            mmr: None,
            limit: 10,
        }
    }
}
/// Accepts `s` only if it is a plain ASCII identifier: an ASCII letter or `_`
/// followed by any number of ASCII letters, digits or `_`.
///
/// Names are interpolated directly into the generated script text, so anything
/// else — including the empty string — is rejected. `what` names the offending
/// field in the error message.
fn validate_ident(s: &str, what: &str) -> Result<()> {
    let head_ok = matches!(s.chars().next(), Some(c) if c == '_' || c.is_ascii_alphabetic());
    let tail_ok = s
        .chars()
        .skip(1)
        .all(|c| c == '_' || c.is_ascii_alphanumeric());
    ensure!(
        head_ok && tail_ok,
        "hybrid_search: {what} must be a bare identifier (got {s:?})"
    );
    Ok(())
}
/// Renders `x` as a float literal for the generated script.
///
/// `Debug` formatting is used because, unlike `Display`, it always keeps a
/// fractional part for whole numbers (`2.0` rather than `2`), so the literal
/// reads as a float. Very large or very small magnitudes come out in exponent
/// form (e.g. `1e20`).
///
/// `x` must be finite: non-finite values would render as `inf` / `NaN`, which
/// are presumably not valid script literals. The callers in this module
/// validate or construct finite values before calling.
fn fmt_f64(x: f64) -> String {
    debug_assert!(x.is_finite(), "fmt_f64 called with non-finite value {x}");
    format!("{x:?}")
}
/// Builds a CozoScript program for a hybrid search and the parameter map it
/// references.
///
/// Every ranked list contributes rows to `combined[__lid, id, score]`, tagged
/// with its label (`'semantic'` for the vector-index list, `'text'` for the
/// full-text list, or the label of an extra list / graph leg). The lists are
/// fused with `ReciprocalRankFusion` and, when `q.mmr` is set, re-ranked with
/// `MaximalMarginalRelevance`.
///
/// The returned parameters are `qv` (the query vector as a list of floats),
/// `qt` (the query text) and one `hg{i}_seed{j}` entry per graph-leg seed.
///
/// # Errors
///
/// Returns an error if:
/// - any relation, column, index or label name is not a bare identifier;
/// - `vector_k`, `ef` or `fts_k` is zero;
/// - `query_vector` is empty or contains a non-finite component;
/// - `rrf_k` is non-finite or negative;
/// - two lists share a label, or a list reuses `semantic` / `text`;
/// - a graph leg has `max_hops == 0` or no seeds;
/// - `mmr.lambda` is non-finite.
///
/// # Notes
///
/// `extra_lists[..].rule_body` is spliced into the script verbatim and is not
/// validated — only pass trusted text. When `mmr` is set, `limit` is not
/// emitted; the result count is governed by `mmr.k`.
pub fn build_hybrid_query(q: &HybridSearch) -> Result<(String, BTreeMap<String, DataValue>)> {
    validate_ident(&q.relation, "relation")?;
    validate_ident(&q.id_col, "id_col")?;
    validate_ident(&q.vector_index, "vector_index")?;
    validate_ident(&q.fts_index, "fts_index")?;
    ensure!(q.vector_k > 0, "hybrid_search: vector_k must be > 0");
    ensure!(q.ef > 0, "hybrid_search: ef must be > 0");
    ensure!(q.fts_k > 0, "hybrid_search: fts_k must be > 0");
    ensure!(
        !q.query_vector.is_empty(),
        "hybrid_search: query_vector is empty"
    );
    // NaN/inf components can only produce non-finite distances; reject them up front.
    ensure!(
        q.query_vector.iter().all(|x| x.is_finite()),
        "hybrid_search: query_vector must contain only finite values"
    );
    ensure!(
        q.rrf_k.is_finite() && q.rrf_k >= 0.0,
        "hybrid_search: rrf_k must be finite and >= 0"
    );
    for l in &q.extra_lists {
        validate_ident(&l.label, "extra_lists.label")?;
    }
    for g in &q.graph_legs {
        validate_ident(&g.label, "graph_legs.label")?;
        validate_ident(&g.edge_relation, "graph_legs.edge_relation")?;
        validate_ident(&g.from_col, "graph_legs.from_col")?;
        validate_ident(&g.to_col, "graph_legs.to_col")?;
        ensure!(
            g.max_hops >= 1,
            "hybrid_search: graph_legs.max_hops must be >= 1"
        );
        ensure!(
            !g.seeds.is_empty(),
            "hybrid_search: graph_legs.seeds is empty"
        );
    }
    // In `combined`, rows are distinguished by nothing but their list label, so
    // two lists sharing a label would be silently merged. Examples are two graph
    // legs left at the default "graph", or an extra list named "semantic".
    let mut seen_labels: std::collections::BTreeSet<&str> =
        ["semantic", "text"].iter().copied().collect();
    for label in q
        .extra_lists
        .iter()
        .map(|l| l.label.as_str())
        .chain(q.graph_legs.iter().map(|g| g.label.as_str()))
    {
        ensure!(
            seen_labels.insert(label),
            "hybrid_search: duplicate list label {label:?} (labels must be unique and not 'semantic' or 'text')"
        );
    }
    if let Some(m) = &q.mmr {
        validate_ident(&m.embedding_col, "mmr.embedding_col")?;
        ensure!(m.lambda.is_finite(), "hybrid_search: mmr.lambda must be finite");
    }

    let rel = &q.relation;
    let idc = &q.id_col;
    let vec_call = if q.vector_f64 {
        "vec($qv, 'F64')"
    } else {
        "vec($qv)"
    };
    // `fmt::Write` into a `String` never fails, so the `unwrap`s below cannot fire.
    let mut s = String::new();

    // Semantic list: negate the distance so that a higher score is a better match.
    writeln!(
        s,
        "sem[id, score] := ~{rel}:{vidx}{{ {idc}: id | query: {vec_call}, k: {vk}, ef: {ef}, bind_distance: __dist }}, score = -__dist",
        vidx = q.vector_index,
        vk = q.vector_k,
        ef = q.ef,
    )
    .unwrap();
    writeln!(
        s,
        "txt[id, score] := ~{rel}:{fidx}{{ {idc}: id | query: $qt, k: {fk}, bind_score: score }}",
        fidx = q.fts_index,
        fk = q.fts_k,
    )
    .unwrap();

    let mut params = BTreeMap::new();
    params.insert(
        "qv".to_string(),
        DataValue::List(q.query_vector.iter().map(|f| DataValue::from(*f as f64)).collect()),
    );
    params.insert("qt".to_string(), DataValue::from(q.query_text.as_str()));

    // Graph legs: `hg{i}_reach` keeps, per node, the minimum hop distance from
    // any seed. Expansion continues only from nodes whose distance is below
    // `max_hops`, which caps the distance at `max_hops`.
    for (i, g) in q.graph_legs.iter().enumerate() {
        let er = &g.edge_relation;
        let fc = &g.from_col;
        let tc = &g.to_col;
        // Seeds are passed as parameters (never inlined) so arbitrary values are safe.
        for (j, seed) in g.seeds.iter().enumerate() {
            let pname = format!("hg{i}_seed{j}");
            writeln!(s, "hg{i}_seed[__s] := __s = ${pname}").unwrap();
            params.insert(pname, seed.clone());
        }
        writeln!(
            s,
            "hg{i}_reach[__to, min(__d)] := hg{i}_seed[__s], *{er}{{ {fc}: __s, {tc}: __to }}, __d = 1.0"
        )
        .unwrap();
        if g.undirected {
            writeln!(
                s,
                "hg{i}_reach[__to, min(__d)] := hg{i}_seed[__s], *{er}{{ {fc}: __to, {tc}: __s }}, __d = 1.0"
            )
            .unwrap();
        }
        let bound = fmt_f64(g.max_hops as f64);
        writeln!(
            s,
            "hg{i}_reach[__to, min(__d)] := hg{i}_reach[__mid, __pd], __pd < {bound}, *{er}{{ {fc}: __mid, {tc}: __to }}, __d = __pd + 1.0"
        )
        .unwrap();
        if g.undirected {
            writeln!(
                s,
                "hg{i}_reach[__to, min(__d)] := hg{i}_reach[__mid, __pd], __pd < {bound}, *{er}{{ {fc}: __to, {tc}: __mid }}, __d = __pd + 1.0"
            )
            .unwrap();
        }
    }

    // Union every ranked list into `combined`, tagging each row with its list label.
    writeln!(s, "combined[__lid, id, score] := sem[id, score], __lid = 'semantic'").unwrap();
    writeln!(s, "combined[__lid, id, score] := txt[id, score], __lid = 'text'").unwrap();
    for l in &q.extra_lists {
        writeln!(
            s,
            "combined[__lid, id, score] := {body}, __lid = '{label}'",
            body = l.rule_body,
            label = l.label,
        )
        .unwrap();
    }
    for (i, g) in q.graph_legs.iter().enumerate() {
        // Fewer hops => higher score.
        writeln!(
            s,
            "combined[__lid, id, score] := hg{i}_reach[id, __gd], score = -__gd, __lid = '{label}'",
            label = g.label,
        )
        .unwrap();
    }

    let rrf_k = fmt_f64(q.rrf_k);
    match &q.mmr {
        None => {
            writeln!(
                s,
                "?[id, score] <~ ReciprocalRankFusion(combined[__lid, id, score], k: {rrf_k})"
            )
            .unwrap();
            writeln!(s, ":order -score").unwrap();
            writeln!(s, ":limit {}", q.limit).unwrap();
        }
        Some(m) => {
            writeln!(
                s,
                "fused[id, score] <~ ReciprocalRankFusion(combined[__lid, id, score], k: {rrf_k})"
            )
            .unwrap();
            // Join the fused candidates back to the relation to fetch their embeddings.
            writeln!(
                s,
                "cand[id, score, __emb] := fused[id, score], *{rel}{{ {idc}: id, {emb}: __emb }}",
                emb = m.embedding_col,
            )
            .unwrap();
            writeln!(
                s,
                "?[id, rank] <~ MaximalMarginalRelevance(cand[id, score, __emb], lambda: {lambda}, k: {k})",
                lambda = fmt_f64(m.lambda.clamp(0.0, 1.0)),
                k = m.k,
            )
            .unwrap();
            writeln!(s, ":order rank").unwrap();
        }
    }
    Ok((s, params))
}