use regex::Regex;
use std::collections::HashMap;
use std::sync::OnceLock;
/// Offset added to each 1-based rank in the `1 / (RRF_K + rank)` score that
/// `rrf_scores` assigns to every ranked entry.
pub const RRF_K: u32 = 60;
/// Alpha returned by `resolve_alpha` when no explicit alpha is given and the
/// query is classified as a symbol query by `is_symbol_query`.
// NOTE(review): alpha is presumably the semantic weight later passed to
// `combine` — confirm at the call sites.
const ALPHA_SYMBOL: f32 = 0.3;
/// Alpha returned by `resolve_alpha` when no explicit alpha is given and the
/// query is not classified as a symbol query.
const ALPHA_NL: f32 = 0.5;
/// Returns the lazily compiled, process-wide regex that recognises
/// "symbol-like" queries.
///
/// The pattern is anchored at both ends and accepts any one of:
/// 1. two or more identifiers joined by `::`, `\`, `->` or `.`;
/// 2. an identifier starting with an underscore;
/// 3. an identifier starting with a letter that later contains an uppercase
///    letter or an underscore (camelCase / snake_case);
/// 4. an alphanumeric word starting with an uppercase letter.
///
/// # Panics
///
/// Panics on first use if the pattern fails to compile.
fn symbol_query_re() -> &'static Regex {
    static RE: OnceLock<Regex> = OnceLock::new();

    fn build() -> Regex {
        // One entry per accepted shape, in the order they are tried.
        let branches = [
            r"[A-Za-z_][A-Za-z0-9_]*(?:(?:::|\\|->|\.)[A-Za-z_][A-Za-z0-9_]*)+",
            r"_[A-Za-z0-9_]*",
            r"[A-Za-z][A-Za-z0-9]*[A-Z_][A-Za-z0-9_]*",
            r"[A-Z][A-Za-z0-9]*",
        ];
        // Wrap the alternation in a non-capturing group so the anchors apply
        // to every branch rather than only the first and last.
        let pattern = ["^(?:", branches.join("|").as_str(), ")$"].concat();
        Regex::new(&pattern).expect("symbol_query_re")
    }

    RE.get_or_init(build)
}
/// Reports whether `query`, with surrounding whitespace removed, matches the
/// symbol pattern from `symbol_query_re` in its entirety.
pub fn is_symbol_query(query: &str) -> bool {
    let candidate = query.trim();
    let re = symbol_query_re();
    re.is_match(candidate)
}
/// Picks the alpha to use for `query`.
///
/// An explicit `alpha` is returned unchanged. Otherwise the query is
/// classified with `is_symbol_query`: symbol queries get `ALPHA_SYMBOL`,
/// everything else gets `ALPHA_NL`. The query is only inspected when no
/// explicit value was supplied.
pub fn resolve_alpha(query: &str, alpha: Option<f32>) -> f32 {
    match alpha {
        Some(explicit) => explicit,
        None if is_symbol_query(query) => ALPHA_SYMBOL,
        None => ALPHA_NL,
    }
}
/// Converts raw `(id, score)` pairs into reciprocal-rank scores.
///
/// Entries are ranked by descending score (rank 1 is the highest score) and
/// each id is mapped to `1 / (RRF_K + rank)`. Entries with equal scores keep
/// their input order, because the sort is stable.
///
/// Edge cases:
/// * An empty slice yields an empty map.
/// * `NaN` scores are ranked after every real score. They are ordered
///   explicitly so the comparator stays a total order; an inconsistent
///   comparator makes `sort_by` produce an unspecified order and may panic.
/// * If an id appears more than once, its best (lowest) rank is kept. The
///   duplicate still occupies its own rank position, so later entries are not
///   shifted up.
pub fn rrf_scores(scored: &[(u32, f32)]) -> HashMap<u32, f32> {
    use std::cmp::Ordering;

    let mut order: Vec<usize> = (0..scored.len()).collect();
    order.sort_by(|&a, &b| {
        let (left, right) = (scored[a].1, scored[b].1);
        match (left.is_nan(), right.is_nan()) {
            (true, true) => Ordering::Equal,
            // NaN sorts after any real score.
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always returns `Some`;
            // operands are swapped to sort in descending order.
            (false, false) => right.partial_cmp(&left).unwrap_or(Ordering::Equal),
        }
    });

    let mut out = HashMap::with_capacity(scored.len());
    for (rank0, i) in order.into_iter().enumerate() {
        // Computed in u64 so neither the usize -> integer conversion nor the
        // addition can truncate or overflow.
        let denominator = u64::from(RRF_K) + rank0 as u64 + 1;
        // Iteration is best-rank-first, so the first value seen for an id is
        // its best one; never overwrite it with a worse rank.
        out.entry(scored[i].0).or_insert(1.0 / (denominator as f32));
    }
    out
}
/// Fuses two per-id score maps into one weighted sum.
///
/// Every id present in either map appears in the result with the value
/// `alpha * semantic[id] + (1 - alpha) * bm25[id]`, where a missing entry
/// contributes nothing. `alpha` is used as given; it is not clamped.
pub fn combine(
    semantic: &HashMap<u32, f32>,
    bm25: &HashMap<u32, f32>,
    alpha: f32,
) -> HashMap<u32, f32> {
    let bm25_weight = 1.0 - alpha;
    let mut fused: HashMap<u32, f32> = HashMap::with_capacity(semantic.len() + bm25.len());

    // Seed the result with the weighted semantic scores.
    fused.extend(semantic.iter().map(|(&id, &score)| (id, alpha * score)));

    // Add the weighted BM25 scores, creating entries for ids seen only here.
    for (&id, &score) in bm25 {
        *fused.entry(id).or_insert(0.0) += bm25_weight * score;
    }

    fused
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every entry in `queries` is classified as a symbol query.
    fn assert_all_symbols(queries: &[&str]) {
        for &q in queries {
            assert!(is_symbol_query(q), "expected a symbol query: {:?}", q);
        }
    }

    /// True when `actual` is within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    // --- is_symbol_query -------------------------------------------------

    #[test]
    fn symbol_query_pascal() {
        assert_all_symbols(&["HandlerStack", "Client"]);
    }

    #[test]
    fn symbol_query_namespaced() {
        assert_all_symbols(&["Sinatra::Base", "app.use", "Foo->bar", r"My\Namespace\Class"]);
    }

    #[test]
    fn symbol_query_dunder_and_camel() {
        assert_all_symbols(&["_internal", "getUserById", "snake_case_thing"]);
    }

    #[test]
    fn nl_query_lowercase_word() {
        for q in ["session", "how do i do x", "authentication"] {
            assert!(!is_symbol_query(q), "expected a natural-language query: {:?}", q);
        }
    }

    // --- resolve_alpha ---------------------------------------------------

    #[test]
    fn resolve_alpha_explicit_wins() {
        // An explicit alpha is returned regardless of how the query looks.
        assert_eq!(resolve_alpha("HandlerStack", Some(0.9)), 0.9);
        assert_eq!(resolve_alpha("hello world", Some(0.1)), 0.1);
    }

    #[test]
    fn resolve_alpha_auto() {
        assert_eq!(resolve_alpha("HandlerStack", None), ALPHA_SYMBOL);
        assert_eq!(resolve_alpha("how to do x", None), ALPHA_NL);
    }

    // --- rrf_scores ------------------------------------------------------

    #[test]
    fn rrf_scores_empty() {
        let out = rrf_scores(&[]);
        assert!(out.is_empty());
    }

    #[test]
    fn rrf_scores_basic() {
        let out = rrf_scores(&[(10, 0.9), (20, 0.5), (30, 0.1)]);
        assert_eq!(out.len(), 3);

        // (id, expected 1-based rank)
        let expected_ranks: [(u32, u32); 3] = [(10, 1), (20, 2), (30, 3)];
        for (id, rank) in expected_ranks {
            let want = 1.0 / (RRF_K + rank) as f32;
            assert!(close(out[&id], want, 1e-7));
        }
    }

    #[test]
    fn rrf_scores_unsorted_input_is_resorted() {
        let out = rrf_scores(&[(20, 0.5), (30, 0.1), (10, 0.9)]);
        let (best, middle, worst) = (out[&10], out[&20], out[&30]);
        assert!(best > middle);
        assert!(middle > worst);
    }

    // --- combine ---------------------------------------------------------

    #[test]
    fn combine_weights_correctly() {
        let semantic = HashMap::from([(1, 1.0), (2, 0.5)]);
        let bm25 = HashMap::from([(2, 0.4), (3, 0.6)]);

        let fused = combine(&semantic, &bm25, 0.5);
        assert_eq!(fused.len(), 3);

        let expected: [(u32, f32); 3] = [
            (1, 0.5 * 1.0),
            (2, 0.5 * 0.5 + 0.5 * 0.4),
            (3, 0.5 * 0.6),
        ];
        for (id, want) in expected {
            assert!(close(fused[&id], want, 1e-6));
        }
    }

    #[test]
    fn combine_alpha_one_is_pure_semantic() {
        let semantic = HashMap::from([(1, 1.0)]);
        let bm25 = HashMap::from([(1, 0.5), (2, 0.5)]);

        let fused = combine(&semantic, &bm25, 1.0);
        assert_eq!(fused.len(), 2);
        assert!(close(fused[&1], 1.0, 1e-6));
        // Ids seen only by BM25 are still present, with zero weight.
        assert!(fused[&2].abs() < 1e-6);
    }
}