use super::super::statement::GraphDirection;
use super::helpers::{
array_floats_after, float_pair_after, float_triple_after, quoted_after, usize_after, word_after,
};
use super::tokenizer::{Tok, tokenize};
/// Keyword spellings that `FusionParams::extract` looks for in a statement.
///
/// Each field holds the literal keyword that introduces one clause, which lets
/// statement families with different surface syntax share a single extractor.
/// All fields are `&'static str`, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionKeywords {
    /// Keyword preceding the integer stored in `FusionParams::vector_top_k`.
    pub vector_top_k: &'static str,
    /// Keyword preceding the integer stored in `FusionParams::expansion_depth`.
    pub expansion_depth: &'static str,
    /// Keyword preceding the quoted literal stored in `FusionParams::edge_label`.
    pub edge_label: &'static str,
    /// Keyword preceding the integer stored in `FusionParams::final_top_k`.
    pub final_top_k: &'static str,
    /// Keyword preceding the RRF constants (read as a triple first, then as a pair).
    pub rrf_k: &'static str,
    /// Keyword preceding the quoted literal stored in `FusionParams::vector_field`.
    pub vector_field: &'static str,
    /// Keyword preceding the direction word (`IN`, `OUT` or `BOTH`, any ASCII case).
    pub direction: &'static str,
    /// Keyword preceding the integer stored in `FusionParams::max_visited`.
    pub max_visited: &'static str,
    /// Anchor searched for in the raw SQL text to locate the query vector.
    pub query_anchor: &'static str,
    /// Keyword preceding the quoted BM25 query; an empty string disables
    /// extraction of both BM25 fields.
    pub bm25_query: &'static str,
    /// Keyword preceding the quoted BM25 field name.
    pub bm25_field: &'static str,
}
/// Keyword spellings for the long-form fusion syntax (`EXPANSION_DEPTH`,
/// `EDGE_LABEL`, `FINAL_TOP_K`), with the query vector anchored at `QUERY`.
///
/// NOTE(review): no consumer is visible in this part of the file; the name
/// suggests RAG-style statements — confirm against the caller.
pub const RAG_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
    // Vector stage.
    query_anchor: "QUERY",
    vector_field: "VECTOR_FIELD",
    vector_top_k: "VECTOR_TOP_K",
    // BM25 stage.
    bm25_query: "BM25",
    bm25_field: "ON",
    // Graph expansion stage.
    expansion_depth: "EXPANSION_DEPTH",
    edge_label: "EDGE_LABEL",
    direction: "DIRECTION",
    max_visited: "MAX_VISITED",
    // Final ranking.
    final_top_k: "FINAL_TOP_K",
    rrf_k: "RRF_K",
};
/// Keyword spellings used by `parse_search_using_fusion` for
/// `SEARCH <collection> USING FUSION(...)` statements.
///
/// Uses the short clause names (`DEPTH`, `LABEL`, `TOP`) and anchors the query
/// vector at `ARRAY`.
pub const SEARCH_FUSION_KEYWORDS: FusionKeywords = FusionKeywords {
    // Vector stage.
    query_anchor: "ARRAY",
    vector_field: "VECTOR_FIELD",
    vector_top_k: "VECTOR_TOP_K",
    // BM25 stage.
    bm25_query: "BM25",
    bm25_field: "ON",
    // Graph expansion stage.
    expansion_depth: "DEPTH",
    edge_label: "LABEL",
    direction: "DIRECTION",
    max_visited: "MAX_VISITED",
    // Final ranking.
    final_top_k: "TOP",
    rrf_k: "RRF_K",
};
/// Clause values collected from a fusion statement by `FusionParams::extract`.
///
/// Every field is optional. `None` presumably means the clause was absent or
/// could not be read by the corresponding helper — confirm against `helpers`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct FusionParams {
/// Query vector located in the raw SQL via the `query_anchor` keyword.
pub query_vector: Option<Vec<f32>>,
/// Integer following the `vector_top_k` keyword.
pub vector_top_k: Option<usize>,
/// Integer following the `expansion_depth` keyword.
pub expansion_depth: Option<usize>,
/// Quoted literal following the `edge_label` keyword.
pub edge_label: Option<String>,
/// Integer following the `final_top_k` keyword.
pub final_top_k: Option<usize>,
/// Two-value RRF constants; always `None` when `rrf_k_triple` is `Some`.
pub rrf_k: Option<(f64, f64)>,
/// Three-value RRF constants; takes precedence over `rrf_k`.
pub rrf_k_triple: Option<(f64, f64, f64)>,
/// Quoted literal following the `vector_field` keyword.
pub vector_field: Option<String>,
/// Direction parsed from `IN`, `OUT` or `BOTH` (ASCII case-insensitive);
/// any other word yields `None`.
pub direction: Option<GraphDirection>,
/// Integer following the `max_visited` keyword.
pub max_visited: Option<usize>,
/// Quoted literal following the `bm25_query` keyword.
pub bm25_query: Option<String>,
/// Quoted literal following the `bm25_field` keyword.
pub bm25_field: Option<String>,
}
impl FusionParams {
    /// Collects every fusion clause from an already tokenized statement.
    ///
    /// * `toks` — token stream of the statement, used for keyword lookups.
    /// * `sql` — raw statement text, used only to locate the query vector.
    /// * `kw` — keyword spellings for the statement family being parsed.
    ///
    /// Clauses that the helpers do not find are left as `None`.
    pub(super) fn extract(toks: &[Tok<'_>], sql: &str, kw: &FusionKeywords) -> Self {
        // The direction word is matched without regard to ASCII case; any
        // other word is treated as if no direction had been given.
        let direction = word_after(toks, kw.direction)
            .as_deref()
            .and_then(|word: &str| {
                if word.eq_ignore_ascii_case("IN") {
                    Some(GraphDirection::In)
                } else if word.eq_ignore_ascii_case("BOTH") {
                    Some(GraphDirection::Both)
                } else if word.eq_ignore_ascii_case("OUT") {
                    Some(GraphDirection::Out)
                } else {
                    None
                }
            });

        // The three-value form wins: the pair is only looked up when no
        // triple was found, so at most one of the two is ever `Some`.
        let rrf_k_triple = float_triple_after(toks, kw.rrf_k);
        let rrf_k = match rrf_k_triple {
            Some(_) => None,
            None => float_pair_after(toks, kw.rrf_k),
        };

        // An empty BM25 query keyword switches off both BM25 lookups.
        let bm25_enabled = !kw.bm25_query.is_empty();
        let bm25_query = if bm25_enabled {
            quoted_after(toks, kw.bm25_query)
        } else {
            None
        };
        let bm25_field = if bm25_enabled {
            quoted_after(toks, kw.bm25_field)
        } else {
            None
        };

        let query_vector = array_floats_after(sql, kw.query_anchor);
        let vector_top_k = usize_after(toks, kw.vector_top_k);
        let expansion_depth = usize_after(toks, kw.expansion_depth);
        let edge_label = quoted_after(toks, kw.edge_label);
        let final_top_k = usize_after(toks, kw.final_top_k);
        let vector_field = quoted_after(toks, kw.vector_field);
        let max_visited = usize_after(toks, kw.max_visited);

        Self {
            query_vector,
            vector_top_k,
            expansion_depth,
            edge_label,
            final_top_k,
            rrf_k,
            rrf_k_triple,
            vector_field,
            direction,
            max_visited,
            bm25_query,
            bm25_field,
        }
    }
}
/// Parses a `SEARCH <collection> USING FUSION ...` statement.
///
/// Returns the collection name together with the extracted fusion clauses, or
/// `None` when the statement does not start with the four words
/// `SEARCH`, a collection name, `USING`, `FUSION` (keywords compared without
/// regard to ASCII case).
pub fn parse_search_using_fusion(sql: &str) -> Option<(String, FusionParams)> {
    let toks = tokenize(sql);

    let collection = if let [Tok::Word(verb), Tok::Word(target), Tok::Word(using), Tok::Word(method), ..] =
        toks.as_slice()
    {
        let is_fusion_search = verb.eq_ignore_ascii_case("SEARCH")
            && using.eq_ignore_ascii_case("USING")
            && method.eq_ignore_ascii_case("FUSION");
        if !is_fusion_search {
            return None;
        }
        (*target).to_string()
    } else {
        // Fewer than four tokens, or one of the leading tokens is not a word.
        return None;
    };

    let params = FusionParams::extract(&toks, sql, &SEARCH_FUSION_KEYWORDS);
    Some((collection, params))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and panics when it is not recognised as a fusion search.
    fn parse_ok(sql: &str) -> (String, FusionParams) {
        parse_search_using_fusion(sql).unwrap()
    }

    /// Two-source statement: every clause is read and the RRF pair is used.
    #[test]
    fn search_fusion_full_surface_parses() {
        let (collection, params) = parse_ok(
            "SEARCH mycol USING FUSION(ARRAY[0.1, 0.2] VECTOR_TOP_K 5 DEPTH 2 \
LABEL 'related' TOP 10 RRF_K (60.0, 35.0))",
        );
        assert_eq!(collection, "mycol");
        assert_eq!(params.query_vector.as_ref().map(Vec::len), Some(2));
        assert_eq!(params.vector_top_k, Some(5));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label.as_deref(), Some("related"));
        assert_eq!(params.final_top_k, Some(10));
        assert_eq!(params.rrf_k, Some((60.0, 35.0)));
        assert_eq!(params.rrf_k_triple, None);
    }

    /// Three-source statement: the RRF triple replaces the pair and the BM25
    /// clauses are picked up.
    #[test]
    fn search_fusion_three_source_parses() {
        let (collection, params) = parse_ok(
            "SEARCH entities USING FUSION(ARRAY[0.1, 0.3] VECTOR_FIELD 'embedding' \
VECTOR_TOP_K 50 BM25 'transformer attention' ON 'body' \
DEPTH 2 LABEL 'related_to' TOP 10 RRF_K (60.0, 35.0, 50.0))",
        );
        assert_eq!(collection, "entities");
        assert_eq!(params.rrf_k, None);
        assert_eq!(params.rrf_k_triple, Some((60.0, 35.0, 50.0)));
        assert_eq!(params.bm25_query.as_deref(), Some("transformer attention"));
        assert_eq!(params.bm25_field.as_deref(), Some("body"));
        assert_eq!(params.expansion_depth, Some(2));
        assert_eq!(params.edge_label.as_deref(), Some("related_to"));
        assert_eq!(params.final_top_k, Some(10));
    }

    /// A quoted label that starts with `TOP` must not be mistaken for the
    /// `TOP` keyword.
    #[test]
    fn search_fusion_label_literal_that_shadows_top_keyword() {
        let (_, params) =
            parse_ok("SEARCH c USING FUSION(ARRAY[0.5] LABEL 'TOP_SECRET' TOP 7)");
        assert_eq!(params.edge_label.as_deref(), Some("TOP_SECRET"));
        assert_eq!(params.final_top_k, Some(7));
    }

    /// Statements with another verb or another `USING` method are rejected.
    #[test]
    fn search_fusion_rejects_wrong_prefix() {
        let rejected = [
            "INSERT INTO x VALUES (1)",
            "SEARCH x USING VECTOR(ARRAY[1.0])",
        ];
        for sql in rejected {
            assert!(parse_search_using_fusion(sql).is_none());
        }
    }
}