use sqlparser::ast;
use super::super::helpers::{
extract_float, extract_float_array, extract_func_args, extract_string_literal,
};
use crate::error::{Result, SqlError};
use crate::types::SqlPlan;
pub(super) fn plan_hybrid_from_sort(
args: &[ast::Expr],
collection: &str,
plan: &SqlPlan,
score_alias: Option<&str>,
) -> Result<Option<SqlPlan>> {
if args.len() < 2 {
return Err(no_args_rrf_score_error());
}
let limit = match plan {
SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
_ => 10,
};
let third_is_graph_score = args.get(2).is_some_and(is_function_call);
if third_is_graph_score {
plan_hybrid_triple(args, collection, limit, score_alias)
} else {
plan_hybrid_two_source(args, collection, limit, score_alias)
}
}
/// Builds a `SqlPlan::HybridSearch` from the two-source form
/// `rrf_score(rank1, rank2, k1?, k2?)`.
///
/// * `args` - the `rrf_score` arguments: two source expressions followed by
///   up to two optional k-constants. A missing k-constant defaults to 60.
/// * `collection` - stored in the plan's `collection` field.
/// * `limit` - used as `top_k`; `ef_search` is twice this value, saturating
///   instead of overflowing.
/// * `score_alias` - optional alias copied into the plan.
///
/// The vector weight is `k2 / (k1 + k2)`, computed in `f32`.
///
/// # Errors
///
/// Returns `SqlError::InvalidFunction` when:
/// * more than four arguments are given,
/// * fewer than two arguments are given,
/// * a supplied k-constant cannot be read as a number (it is no longer
///   silently replaced by the default),
/// * the k-constants are negative, non-finite, or sum to zero, which would
///   make the vector weight NaN, infinite, or outside `[0, 1]`.
///
/// Errors from `extract_vector_arg` / `extract_text_arg` are propagated.
fn plan_hybrid_two_source(
    args: &[ast::Expr],
    collection: &str,
    limit: usize,
    score_alias: Option<&str>,
) -> Result<Option<SqlPlan>> {
    if args.len() > 4 {
        return Err(SqlError::InvalidFunction {
            detail: format!(
                "rrf_score() two-source form accepts at most 4 arguments \
                 (rank1, rank2, k1?, k2?); got {}. \
                 For three-source fusion use rrf_score(vector_distance(...), \
                 bm25_score(...), graph_score(...), k1?, k2?, k3?).",
                args.len()
            ),
        });
    }

    // Destructure instead of indexing so a short slice yields an error, not a panic.
    let [vector_expr, text_expr, ..] = args else {
        return Err(SqlError::InvalidFunction {
            detail: "rrf_score() two-source form requires vector_distance(...) and \
                     bm25_score(...) as its first two arguments"
                .into(),
        });
    };
    let vector = extract_vector_arg(vector_expr)?;
    let text = extract_text_arg(text_expr)?;

    // An absent k-constant takes the default; a present but unreadable one is
    // rejected rather than quietly replaced.
    let k_constant = |position: usize| match args.get(position) {
        Some(expr) => extract_float(expr).map_err(|_| SqlError::InvalidFunction {
            detail: format!(
                "rrf_score() k{} (argument {}) must be a numeric constant.",
                position - 1,
                position + 1
            ),
        }),
        None => Ok(60.0),
    };
    let k1 = k_constant(2)?;
    let k2 = k_constant(3)?;

    // NaN fails every comparison below, so it is rejected as well.
    let k_sum = k1 as f32 + k2 as f32;
    let weights_usable = k1 >= 0.0 && k2 >= 0.0 && k_sum > 0.0 && k_sum.is_finite();
    if !weights_usable {
        return Err(SqlError::InvalidFunction {
            detail: format!(
                "rrf_score() k-constants must be finite, non-negative and not both zero; \
                 got k1={k1}, k2={k2}."
            ),
        });
    }
    let vector_weight = k2 as f32 / k_sum;

    Ok(Some(SqlPlan::HybridSearch {
        collection: collection.into(),
        query_vector: vector,
        query_text: text,
        top_k: limit,
        ef_search: limit.saturating_mul(2),
        vector_weight,
        fuzzy: true,
        score_alias: score_alias.map(|s| s.to_string()),
    }))
}
/// Builds a `SqlPlan::HybridSearchTriple` from the three-source form
/// `rrf_score(rank1, rank2, rank3, k1?, k2?, k3?)`.
///
/// * `args` - the `rrf_score` arguments: three source expressions followed by
///   either zero or three k-constants. With no k-constants each defaults
///   to 60.
/// * `collection` - stored in the plan's `collection` field.
/// * `limit` - used as `top_k`; `ef_search` is twice this value, saturating
///   instead of overflowing.
/// * `score_alias` - optional alias copied into the plan.
///
/// # Errors
///
/// Returns `SqlError::InvalidFunction` when:
/// * exactly one or two k-constants follow the three sources,
/// * more than six arguments are given,
/// * fewer than three arguments are given,
/// * a supplied k-constant cannot be read as a number (it is no longer
///   silently replaced by the default).
///
/// Errors from `extract_vector_arg`, `extract_text_arg` and
/// `extract_graph_score_args` are propagated.
fn plan_hybrid_triple(
    args: &[ast::Expr],
    collection: &str,
    limit: usize,
    score_alias: Option<&str>,
) -> Result<Option<SqlPlan>> {
    let k_count = args.len().saturating_sub(3);
    if k_count == 1 || k_count == 2 {
        return Err(SqlError::InvalidFunction {
            detail: format!(
                "rrf_score() three-source form requires 0 or 3 k-constants \
                 after the three source arguments, not {k_count}. \
                 Use rrf_score(v, t, g) or rrf_score(v, t, g, k1, k2, k3)."
            ),
        });
    }
    if args.len() > 6 {
        return Err(SqlError::InvalidFunction {
            detail: format!(
                "rrf_score() accepts at most 6 arguments in the three-source form \
                 (rank1, rank2, rank3, k1?, k2?, k3?); got {}.",
                args.len()
            ),
        });
    }

    // Destructure instead of indexing so a short slice yields an error, not a panic.
    let [vector_expr, text_expr, graph_expr, k_exprs @ ..] = args else {
        return Err(SqlError::InvalidFunction {
            detail: "rrf_score() three-source form requires vector_distance(...), \
                     bm25_score(...) and graph_score(...) arguments"
                .into(),
        });
    };
    let vector = extract_vector_arg(vector_expr)?;
    let text = extract_text_arg(text_expr)?;
    let (graph_seed_id, graph_depth, graph_edge_label) = extract_graph_score_args(graph_expr)?;

    // After the checks above `k_exprs` holds either zero or three expressions.
    // An absent k-constant takes the default; a present but unreadable one is
    // rejected rather than quietly replaced.
    let k_constant = |slot: usize| match k_exprs.get(slot) {
        Some(expr) => extract_float(expr).map_err(|_| SqlError::InvalidFunction {
            detail: format!(
                "rrf_score() k{} (argument {}) must be a numeric constant.",
                slot + 1,
                slot + 4
            ),
        }),
        None => Ok(60.0),
    };
    let k1 = k_constant(0)?;
    let k2 = k_constant(1)?;
    let k3 = k_constant(2)?;

    Ok(Some(SqlPlan::HybridSearchTriple {
        collection: collection.into(),
        query_vector: vector,
        query_text: text,
        graph_seed_id,
        graph_depth,
        graph_edge_label,
        top_k: limit,
        ef_search: limit.saturating_mul(2),
        fuzzy: true,
        rrf_k: (k1, k2, k3),
        score_alias: score_alias.map(|s| s.to_string()),
    }))
}
/// Reads the query vector from the second argument of a function-call
/// expression.
///
/// Yields an empty vector when `expr` is not a function call, when the call
/// has fewer than two arguments, or when the second argument cannot be read
/// as a float array.
///
/// # Errors
///
/// Propagates any error from `extract_func_args`.
fn extract_vector_arg(expr: &ast::Expr) -> Result<Vec<f32>> {
    let ast::Expr::Function(call) = expr else {
        return Ok(Vec::new());
    };
    let call_args = extract_func_args(call)?;
    let query_vector = call_args
        .get(1)
        .and_then(|array_expr| extract_float_array(array_expr).ok())
        .unwrap_or_default();
    Ok(query_vector)
}
/// Reads the query text from the second argument of a function-call
/// expression.
///
/// Yields an empty string when `expr` is not a function call, when the call
/// has fewer than two arguments, or when the second argument cannot be read
/// as a string literal.
///
/// # Errors
///
/// Propagates any error from `extract_func_args`.
fn extract_text_arg(expr: &ast::Expr) -> Result<String> {
    let ast::Expr::Function(call) = expr else {
        return Ok(String::new());
    };
    let call_args = extract_func_args(call)?;
    let query_text = match call_args.get(1) {
        Some(text_expr) => extract_string_literal(text_expr).unwrap_or_default(),
        None => String::new(),
    };
    Ok(query_text)
}
/// Reads `(seed_id, depth, edge_label)` from a function-call expression.
///
/// * The seed id is the second call argument read as a string literal; it is
///   empty when that argument is missing or not a string literal.
/// * Arguments from the third onward are inspected only when they are named
///   expressions. The name is compared case-insensitively (ASCII):
///   - `depth` replaces the default depth of 1 when its value can be read as a
///     number; the value is converted with an `as usize` cast.
///   - `label` sets the edge label, or clears it when the value is not a
///     string literal.
///   - any other name is ignored.
///
/// When `expr` is not a function call the result is `("", 1, None)`.
///
/// # Errors
///
/// Propagates any error from `extract_func_args`.
fn extract_graph_score_args(expr: &ast::Expr) -> Result<(String, usize, Option<String>)> {
    let call = match expr {
        ast::Expr::Function(call) => call,
        _ => return Ok((String::new(), 1, None)),
    };
    let call_args = extract_func_args(call)?;

    let seed = match call_args.get(1) {
        Some(seed_expr) => extract_string_literal(seed_expr).unwrap_or_default(),
        None => String::new(),
    };

    let mut hops: usize = 1;
    let mut label: Option<String> = None;
    for option in call_args.iter().skip(2) {
        // Positional extras are skipped; only named options are recognised.
        let ast::Expr::Named { name, expr } = option else {
            continue;
        };
        if name.value.eq_ignore_ascii_case("depth") {
            if let Ok(raw) = extract_float(expr) {
                hops = raw as usize;
            }
        } else if name.value.eq_ignore_ascii_case("label") {
            label = extract_string_literal(expr).ok();
        }
    }

    Ok((seed, hops, label))
}
/// Reports whether `expr` is the `ast::Expr::Function` variant.
fn is_function_call(expr: &ast::Expr) -> bool {
    matches!(*expr, ast::Expr::Function(..))
}
/// Builds the `SqlError::InvalidFunction` reported when `rrf_score()` is
/// called without its two mandatory source arguments.
pub(super) fn no_args_rrf_score_error() -> SqlError {
    const DETAIL: &str = "rrf_score() requires at least vector_distance(...) and bm25_score(...) \
                          arguments; e.g. rrf_score(vector_distance(emb, ARRAY[...]), \
                          bm25_score(content, 'query'))";
    SqlError::InvalidFunction {
        detail: DETAIL.into(),
    }
}