use crate::VectorType;
use crate::database::config::SearchConfig;
use crate::error::Result;
use crate::hook::{HookContext, SearchHook};
use crate::index::brute_force;
use crate::node::{NodeId, SearchHit};
use crate::storage::memtable::MemTable;
use std::sync::{Arc, Mutex, MutexGuard};
/// Locks `mutex` and returns its guard.
///
/// If a previous holder panicked and poisoned the mutex, a warning is logged
/// and the inner guard is returned anyway instead of propagating the poison.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            tracing::warn!("Mutex was poisoned (pipeline), recovering...");
            poisoned.into_inner()
        }
    }
}
/// Executes one search request end-to-end while holding the memtable lock.
///
/// Stages, in order: query validation, config sanitising, the `on_pre_search`
/// hook, recall (the hook's custom recall, or the built-in text / vector /
/// residual recall), seed aggregation, the `on_post_recall` and
/// `on_pre_graph_expand` hooks, graph expansion, the `on_rerank` hook, optional
/// DPP selection, and finally the `on_post_search` hook. The duration of every
/// hook call and of graph expansion is recorded on `ctx`.
///
/// Returns an empty list when the pre-search hook sets `ctx.abort` or when no
/// anchor survives recall.
///
/// # Errors
/// * `DimensionMismatch` when `query_vector` does not have the memtable's dimension.
/// * `InvalidVector` when any query component is NaN or infinite.
///
/// NOTE(review): the config is clamped and the query validated *before*
/// `on_pre_search` runs, so whatever the hook writes is used unchecked —
/// confirm that is intended.
/// NOTE(review): the mutex guard is held for the whole call, hooks included.
pub(crate) fn execute_pipeline<T: VectorType>(
    memtable: &Arc<Mutex<MemTable<T>>>,
    hook: &Arc<dyn SearchHook>,
    query_text: Option<&str>,
    query_vector: Option<&[T]>,
    config: &SearchConfig,
    ctx: &mut HookContext,
) -> Result<Vec<SearchHit>> {
    use std::collections::HashMap;
    use std::time::Instant;

    // `ensure_vectors_cache` needs mutable access on the built-in recall path.
    #[allow(unused_mut)]
    let mut mt = lock_or_recover(memtable);

    // Reject malformed queries before any work is done.
    let expected_dim = mt.dim();
    if let Some(raw_query) = query_vector {
        let got = raw_query.len();
        if got != expected_dim {
            return Err(crate::error::TriviumError::DimensionMismatch {
                expected: expected_dim,
                got,
            });
        }
        let has_non_finite = raw_query
            .iter()
            .any(|component| !component.to_f32().is_finite());
        if has_non_finite {
            return Err(crate::error::TriviumError::InvalidVector {
                reason: "Query vector contains NaN or Infinity".to_string(),
            });
        }
    }

    // Work on a private copy of the config with numeric knobs forced into range.
    let mut safe_cfg = config.clone();
    safe_cfg.top_k = safe_cfg.top_k.max(1);
    safe_cfg.fista_lambda = safe_cfg.fista_lambda.clamp(1e-5, 100.0);
    safe_cfg.teleport_alpha = safe_cfg.teleport_alpha.clamp(0.0, 1.0);
    safe_cfg.dpp_quality_weight = safe_cfg.dpp_quality_weight.clamp(0.0, 10.0);
    safe_cfg.fista_threshold = safe_cfg.fista_threshold.clamp(0.0, f32::MAX);
    safe_cfg.bq_candidate_ratio = safe_cfg.bq_candidate_ratio.clamp(0.0, 1.0);

    // Hooks see the query as f32; an absent query becomes an empty vector.
    let mut query_vec_f32: Vec<f32> = match query_vector {
        Some(raw_query) => raw_query.iter().map(|x| x.to_f32()).collect(),
        None => Vec::new(),
    };

    let pre_search_started = Instant::now();
    hook.on_pre_search(&mut query_vec_f32, &mut safe_cfg, ctx);
    ctx.record_timing("hook_pre_search", pre_search_started.elapsed());
    if ctx.abort {
        return Ok(Vec::new());
    }

    // From here on use the (possibly hook-modified) query and config. The
    // hooked query is only used when the caller supplied a query vector.
    let hooked_query: Vec<T> = query_vec_f32.iter().map(|&x| T::from_f32(x)).collect();
    let query_vector: Option<&[T]> = query_vector.map(|_| hooked_query.as_slice());
    let config = &safe_cfg;

    let custom_recall_started = Instant::now();
    let custom_recall_result = hook.on_custom_recall(&query_vec_f32, config, ctx);
    ctx.record_timing("hook_custom_recall", custom_recall_started.elapsed());

    // Accumulate per-node seed scores, either from the hook or the built-in recall.
    let mut seed_map: HashMap<NodeId, f32> = HashMap::new();
    match custom_recall_result {
        Some(custom_hits) => {
            for hit in custom_hits {
                *seed_map.entry(hit.id).or_insert(0.0) += hit.score;
            }
        }
        None => {
            mt.ensure_vectors_cache();
            recall_text(&mt, config, query_text, &mut seed_map);
            recall_vector(&mt, config, query_vector, &mut seed_map);
            recall_residual(&mt, config, query_vector, &mut seed_map);
        }
    }

    let mut anchor_hits: Vec<SearchHit> = Vec::new();
    aggregate_seeds(&mt, config, &seed_map, &mut anchor_hits);

    let post_recall_started = Instant::now();
    hook.on_post_recall(&mut anchor_hits, ctx);
    ctx.record_timing("hook_post_recall", post_recall_started.elapsed());
    if anchor_hits.is_empty() {
        return Ok(Vec::new());
    }

    // Refresh payloads from the memtable; anchors without a payload are dropped.
    let mut seeds = Vec::with_capacity(anchor_hits.len());
    seeds.extend(anchor_hits.into_iter().filter_map(|mut hit| {
        let payload = mt.get_payload(hit.id)?.clone();
        hit.payload = payload;
        Some(hit)
    }));

    let pre_expand_started = Instant::now();
    hook.on_pre_graph_expand(&mut seeds, ctx);
    ctx.record_timing("hook_pre_graph_expand", pre_expand_started.elapsed());

    let graph_started = Instant::now();
    let mut expanded = crate::graph::traversal::expand_graph(
        &mt,
        seeds,
        config.expand_depth,
        config.teleport_alpha,
        config.enable_inverse_inhibition,
        config.lateral_inhibition_threshold,
        config.enable_refractory_fatigue,
    );
    ctx.record_timing("graph_expand", graph_started.elapsed());

    let rerank_started = Instant::now();
    if let Some(reranked) = hook.on_rerank(&mut expanded, ctx) {
        expanded = reranked;
    }
    ctx.record_timing("hook_rerank", rerank_started.elapsed());

    // DPP selection is attempted only when enabled and there is a surplus of
    // hits; otherwise (or when it declines) the list is simply cut to `top_k`.
    let dpp_eligible = config.enable_advanced_pipeline
        && config.enable_dpp
        && expanded.len() > config.top_k;
    let diversified = if dpp_eligible {
        apply_dpp(&mt, config, &expanded)
    } else {
        None
    };
    let mut final_results = match diversified {
        Some(picked) => picked,
        None => {
            expanded.truncate(config.top_k);
            expanded
        }
    };

    let post_search_started = Instant::now();
    hook.on_post_search(&mut final_results, ctx);
    ctx.record_timing("hook_post_search", post_search_started.elapsed());

    Ok(final_results)
}
/// Text recall: adds keyword (AC) and BM25 scores for `query_text` to `seed_map`.
///
/// Does nothing when hybrid text search is disabled or no text was supplied.
/// AC scores are multiplied by `config.text_boost`; BM25 scores are divided by
/// 10, clamped to `[0, 1]` and then multiplied by the same boost.
fn recall_text<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_text: Option<&str>,
    seed_map: &mut std::collections::HashMap<NodeId, f32>,
) {
    let txt = match query_text {
        Some(txt) if config.enable_text_hybrid_search => txt,
        _ => return,
    };

    let boost = config.text_boost;
    let text_engine = mt.text_engine();

    // Keyword matches first, then BM25, so both accumulate on the same entry.
    for (id, score) in text_engine.search_ac(txt) {
        *seed_map.entry(id).or_insert(0.0) += score * boost;
    }
    for (id, score) in text_engine.search_bm25(txt, config.bm25_k1, config.bm25_b) {
        let normalized_score = (score / 10.0).clamp(0.0, 1.0) * boost;
        *seed_map.entry(id).or_insert(0.0) += normalized_score;
    }
}
/// Vector recall: scores stored vectors against `query_vector` and adds the
/// resulting hit scores to `seed_map`.
///
/// The BQ pipeline is used when `config.enable_bq_coarse_search` is set or the
/// table holds more than 20 000 nodes; above 100 000 nodes the int8 stage is
/// requested as well. Otherwise a brute-force scan is used. A missing query
/// vector makes this a no-op.
fn recall_vector<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: Option<&[T]>,
    seed_map: &mut std::collections::HashMap<NodeId, f32>,
) {
    let Some(query_vector) = query_vector else {
        return;
    };

    let dim = mt.dim();
    let vectors = mt.flat_vectors();

    // A node passes when there is no filter, or when it has a payload that matches.
    let filter_ref = config.payload_filter.as_ref();
    let passes_filter = |id: NodeId| -> bool {
        filter_ref.map_or(true, |f| mt.get_payload(id).is_some_and(|p| f.matches(p)))
    };

    let total_nodes = mt.node_count();
    let use_int8_rocket = total_nodes > 100_000;
    let use_bq = config.enable_bq_coarse_search || total_nodes > 20_000;

    let vector_hits: Vec<SearchHit> = match use_bq {
        true => bq_pipeline(
            mt,
            config,
            query_vector,
            vectors,
            dim,
            use_int8_rocket,
            &passes_filter,
        ),
        false => brute_force_pipeline(mt, config, query_vector, vectors, dim, &passes_filter),
    };

    vector_hits.into_iter().for_each(|hit| {
        *seed_map.entry(hit.id).or_insert(0.0) += hit.score;
    });
}
/// Approximate vector recall: binary-quantised coarse scan, optional int8
/// re-scoring, then exact similarity on the survivors.
///
/// 1. Every slot with a non-zero id that passes the bloom-tag mask and
///    `passes_filter` is ranked by Hamming distance between its BQ signature
///    and the query's; the closest candidates are kept
///    (`ceil(node_count * bq_candidate_ratio)`, at least `top_k`).
/// 2. When `use_int8_rocket` is set and the memtable exposes an int8 pool, the
///    survivors are re-ranked by int8 dot score, keeping `max(top_k * 10, 50)`.
/// 3. The remaining slots are scored with `T::similarity`, filtered by
///    `config.min_score`, sorted by descending score and truncated to
///    `config.top_k`; payloads are attached last.
///
/// Heap capacities are bounded by the number of slots that can actually be
/// pushed, so a very large `top_k` cannot overflow the arithmetic or request a
/// huge allocation, and the scan never indexes past the id map.
fn bq_pipeline<T: VectorType + Sync>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: &[T],
    vectors: &[T],
    dim: usize,
    use_int8_rocket: bool,
    passes_filter: &(dyn Fn(NodeId) -> bool + Sync),
) -> Vec<SearchHit> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    /// Best-effort cache prefetch of the line at `ptr`; a no-op on
    /// architectures other than x86_64 and aarch64.
    #[inline]
    fn prefetch_read(ptr: *const u8) {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};
            // SAFETY: a prefetch is only a hint to the cache hierarchy; the
            // pointer is never dereferenced by Rust code.
            unsafe {
                _mm_prefetch(ptr as *const i8, _MM_HINT_T0);
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: `prfm` is only a hint to the cache hierarchy; the
            // pointer is never dereferenced by Rust code.
            unsafe {
                std::arch::asm!("prfm pldl1keep, [{ptr}]", ptr = in(reg) ptr, options(nostack, preserves_flags));
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let _ = ptr;
        }
    }

    let q_bq = crate::index::bq::BqSignature::from_vector(query_vector);
    let bq_sigs = mt.bq_signatures_slice();
    let id_map = mt.internal_indices();
    let fast_tags = mt.fast_tags_slice();

    // Only scan slots that exist in every per-slot array, the id map included.
    let scan_len = mt
        .internal_slot_count()
        .min(bq_sigs.len())
        .min(fast_tags.len())
        .min(id_map.len());

    // No more than `scan_len` slots can ever be pushed, so capping the target
    // there changes no result but keeps the heap allocation bounded.
    let wanted_candidates =
        (((mt.node_count() as f32) * config.bq_candidate_ratio).ceil() as usize).max(config.top_k);
    let candidate_cnt = wanted_candidates.min(scan_len);

    let has_filter = config.payload_filter.is_some();
    let bloom_mask = config
        .payload_filter
        .as_ref()
        .map(|f| f.extract_must_have_mask())
        .unwrap_or(0);

    // Stage 1: max-heap on Hamming distance, so the worst kept candidate is on top.
    let mut heap: BinaryHeap<(u32, usize)> = BinaryHeap::with_capacity(candidate_cnt);
    for i in 0..scan_len {
        let node_id = id_map[i];
        if node_id == 0 {
            continue;
        }
        // Cheap tag test first; the full payload filter only runs on slots that pass it.
        if bloom_mask != 0 && (fast_tags[i] & bloom_mask) != bloom_mask {
            continue;
        }
        if has_filter && !passes_filter(node_id) {
            continue;
        }
        let dist = bq_sigs[i].hamming_distance(&q_bq);
        if heap.len() < candidate_cnt {
            heap.push((dist, i));
        } else if heap.peek().is_some_and(|&(worst_dist, _)| dist < worst_dist) {
            heap.pop();
            heap.push((dist, i));
        }
    }

    // Ascending slot order keeps the following passes memory-sequential.
    let mut bq_winners: Vec<usize> = heap.into_iter().map(|(_, idx)| idx).collect();
    bq_winners.sort_unstable();

    // Stage 2 (optional): re-rank the coarse winners by int8 dot score.
    let final_candidates: Vec<usize> = match mt.int8_pool() {
        Some(i8pool) if use_int8_rocket => {
            let query_i8 = i8pool.quantize_query(query_vector);
            let int8_top_n = config.top_k.saturating_mul(10).max(50);
            // Min-heap (via `Reverse`) so the lowest kept score is on top.
            let mut i8_heap: BinaryHeap<Reverse<(i32, usize)>> =
                BinaryHeap::with_capacity(int8_top_n.min(bq_winners.len()));

            for (iter_idx, &slot_idx) in bq_winners.iter().enumerate() {
                if !i8pool.is_valid_index(slot_idx) {
                    continue;
                }
                // Prefetch the int8 row two candidates ahead.
                if let Some(&prefetch_idx) = bq_winners.get(iter_idx + 2) {
                    if i8pool.is_valid_index(prefetch_idx) {
                        let prefetch_offset = prefetch_idx * dim;
                        prefetch_read(
                            i8pool.data.as_ptr().wrapping_add(prefetch_offset) as *const u8
                        );
                    }
                }
                let i8_score = i8pool.dot_score(slot_idx, &query_i8);
                if i8_heap.len() < int8_top_n {
                    i8_heap.push(Reverse((i8_score, slot_idx)));
                } else if i8_heap
                    .peek()
                    .is_some_and(|&Reverse((worst_score, _))| i8_score > worst_score)
                {
                    i8_heap.pop();
                    i8_heap.push(Reverse((i8_score, slot_idx)));
                }
            }

            let mut elites: Vec<usize> = i8_heap
                .into_iter()
                .map(|Reverse((_, idx))| idx)
                .collect();
            elites.sort_unstable();
            elites
        }
        _ => bq_winners,
    };

    // Stage 3: exact similarity on the surviving slots.
    let mut refined = Vec::with_capacity(final_candidates.len());
    for (iter_idx, &slot) in final_candidates.iter().enumerate() {
        let offset = slot * dim;
        if offset + dim > vectors.len() {
            continue;
        }
        // Prefetch the next candidate's vector while this one is scored.
        if let Some(&next_slot) = final_candidates.get(iter_idx + 1) {
            let next_offset = next_slot * dim;
            if next_offset + dim <= vectors.len() {
                prefetch_read(vectors.as_ptr().wrapping_add(next_offset) as *const u8);
            }
        }
        let score = T::similarity(query_vector, &vectors[offset..offset + dim]);
        if score >= config.min_score {
            refined.push(SearchHit {
                id: mt.get_id_by_index(slot),
                score,
                payload: serde_json::Value::Null,
            });
        }
    }

    refined.sort_unstable_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    refined.truncate(config.top_k);

    // Payloads are cloned only for the hits that are actually returned.
    for hit in &mut refined {
        if let Some(p) = mt.get_payload(hit.id) {
            hit.payload = p.clone();
        }
    }
    refined
}
/// Exact vector recall: delegates to `brute_force::search` over all stored
/// vectors, passing a resolver that turns a slot index into a node id.
///
/// The resolver returns `0` for slots rejected by the bloom-tag mask or by
/// `passes_filter`, and the node id otherwise.
fn brute_force_pipeline<T: VectorType + Sync>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: &[T],
    vectors: &[T],
    dim: usize,
    passes_filter: &(dyn Fn(NodeId) -> bool + Sync),
) -> Vec<SearchHit> {
    let fast_tags = mt.fast_tags_slice();
    // A mask of 0 means "no tag requirement".
    let bloom_mask = config
        .payload_filter
        .as_ref()
        .map_or(0, |f| f.extract_must_have_mask());

    brute_force::search(
        query_vector,
        vectors,
        dim,
        config.top_k,
        config.min_score,
        |idx| {
            let id = mt.get_id_by_index(idx);
            // Slots beyond the tag array are not rejected by the mask.
            let rejected_by_tags = bloom_mask != 0
                && fast_tags
                    .get(idx)
                    .is_some_and(|&tags| (tags & bloom_mask) != bloom_mask);
            if rejected_by_tags || !passes_filter(id) {
                0
            } else {
                id
            }
        },
    )
}
/// Residual recall: runs a second "shadow" search with the part of the query
/// that the current seeds do not explain.
///
/// Active only when both `enable_advanced_pipeline` and `enable_sparse_residual`
/// are set, seeds already exist and a query vector was given. The query is
/// passed to `fista_solve` together with the seed vectors (80 iterations); if
/// the returned residual norm exceeds `config.fista_threshold`, the residual is
/// searched brute-force and each hit is added to `seed_map` at 0.8 of its score.
fn recall_residual<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: Option<&[T]>,
    seed_map: &mut std::collections::HashMap<NodeId, f32>,
) {
    let stage_enabled = config.enable_advanced_pipeline && config.enable_sparse_residual;
    if !stage_enabled || seed_map.is_empty() {
        return;
    }
    let Some(query_vector) = query_vector else {
        return;
    };

    let q_f32: Vec<f32> = query_vector.iter().map(|&x| x.to_f32()).collect();
    // Seeds without a stored vector are skipped.
    let entity_vecs: Vec<Vec<f32>> = seed_map
        .keys()
        .filter_map(|&id| {
            let stored = mt.get_vector(id)?;
            Some(stored.iter().map(|&x| x.to_f32()).collect::<Vec<f32>>())
        })
        .collect();

    let (_, residual, residual_norm) =
        crate::cognitive::fista_solve(&q_f32, &entity_vecs, config.fista_lambda, 80);

    if residual_norm > config.fista_threshold {
        tracing::debug!(
            "FISTA 残差较高 ({} > {}),触发影子查询",
            residual_norm,
            config.fista_threshold
        );
        let shadow_query: Vec<T> = residual.iter().map(|&x| T::from_f32(x)).collect();
        let shadow_hits = brute_force::search(
            &shadow_query,
            mt.flat_vectors(),
            mt.dim(),
            config.top_k,
            config.min_score,
            |idx| mt.get_id_by_index(idx),
        );
        for shadow in shadow_hits {
            // Shadow hits are down-weighted relative to direct recall.
            *seed_map.entry(shadow.id).or_insert(0.0) += shadow.score * 0.8;
        }
    }
}
/// Turns the accumulated seed scores into a ranked list of anchor hits.
///
/// A seed becomes an anchor when its score is at least `config.min_score` and
/// it either matches `config.payload_filter` or, without a filter, is still
/// contained in the memtable. Anchors are appended to `anchor_hits`, which is
/// then sorted by descending score and truncated to `max(top_k, 15)` entries.
///
/// Equal scores are ordered by ascending node id so that the result, and
/// which anchors survive truncation, does not depend on `HashMap` iteration
/// order.
fn aggregate_seeds<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    seed_map: &std::collections::HashMap<NodeId, f32>,
    anchor_hits: &mut Vec<SearchHit>,
) {
    let filter_ref = config.payload_filter.as_ref();
    for (&id, &score) in seed_map {
        // `>=` is false for NaN, so NaN scores never become anchors.
        if score >= config.min_score {
            let payload = match filter_ref {
                None => {
                    if mt.contains(id) {
                        Some(
                            mt.get_payload(id)
                                .cloned()
                                .unwrap_or(serde_json::Value::Null),
                        )
                    } else {
                        None
                    }
                }
                // One lookup serves both the filter test and the clone.
                Some(f) => match mt.get_payload(id) {
                    Some(p) if f.matches(p) => Some(p.clone()),
                    _ => None,
                },
            };
            if let Some(payload) = payload {
                anchor_hits.push(SearchHit { id, score, payload });
            }
        }
    }

    anchor_hits.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id))
    });
    anchor_hits.truncate(config.top_k.max(15));
}
/// Picks `config.top_k` hits from the head of `expanded` using `dpp_greedy`.
///
/// The candidate pool is the first `min(expanded.len(), top_k * 3)` hits that
/// still have a stored vector. Returns `None` when that pool holds no more
/// than `top_k` entries; otherwise returns the selected hits sorted by
/// descending score.
fn apply_dpp<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    expanded: &[SearchHit],
) -> Option<Vec<SearchHit>> {
    let limit = config.top_k;
    let dpp_pool_size = expanded.len().min(limit * 3);

    // Parallel arrays: vector, score and the hit itself for every pool member.
    let mut pool_vecs = Vec::with_capacity(dpp_pool_size);
    let mut pool_scores = Vec::with_capacity(dpp_pool_size);
    let mut pool_valid = Vec::with_capacity(dpp_pool_size);
    for hit in expanded.iter().take(dpp_pool_size) {
        if let Some(v) = mt.get_vector(hit.id) {
            pool_vecs.push(v.iter().map(|&x| x.to_f32()).collect());
            pool_scores.push(hit.score);
            pool_valid.push(hit.clone());
        }
    }
    if pool_valid.len() <= limit {
        return None;
    }

    let selected_idx =
        crate::cognitive::dpp_greedy(&pool_vecs, &pool_scores, limit, config.dpp_quality_weight);

    let mut final_results: Vec<SearchHit> = selected_idx
        .iter()
        .map(|&idx| pool_valid[idx].clone())
        .collect();
    final_results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    Some(final_results)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::database::config::SearchConfig;
use crate::filter::Filter;
use crate::hook::{HookContext, NoopHook, SearchHook};
use crate::node::SearchHit;
use crate::storage::memtable::MemTable;
use std::sync::{Arc, Mutex};
/// Builds an `f32` memtable of dimension `dim` pre-populated with `nodes`
/// (id, vector, payload). Panics if any insert fails.
fn make_memtable(dim: usize, nodes: &[(u64, Vec<f32>, serde_json::Value)]) -> MemTable<f32> {
    let mut table = MemTable::new(dim);
    nodes.iter().for_each(|(id, vector, payload)| {
        table.insert_with_id(*id, vector, payload.clone()).unwrap();
    });
    table
}
/// Wraps a memtable in the `Arc<Mutex<_>>` shape `execute_pipeline` expects.
fn wrap(mt: MemTable<f32>) -> Arc<Mutex<MemTable<f32>>> {
    let guarded = Mutex::new(mt);
    Arc::new(guarded)
}
fn default_config() -> SearchConfig {
SearchConfig {
top_k: 5,
min_score: 0.0,
expand_depth: 0,
..Default::default()
}
}
/// `aggregate_seeds` must return anchors in descending score order and cap the
/// list at `max(top_k, 15)`.
#[test]
fn test_aggregate_seeds_sorts_descending_and_truncates() {
    // 20 seeds with distinct scores: more than the cap, so truncation is exercised.
    let nodes: Vec<(u64, Vec<f32>, serde_json::Value)> = (1..=20u64)
        .map(|i| (i, vec![1.0, i as f32], serde_json::json!({"a": i})))
        .collect();
    let mt = make_memtable(2, &nodes);
    let cfg = SearchConfig {
        top_k: 2,
        min_score: 0.0,
        ..Default::default()
    };
    let seed_map: std::collections::HashMap<u64, f32> =
        (1..=20u64).map(|i| (i, i as f32 * 0.01)).collect();
    let mut hits = Vec::new();
    aggregate_seeds(&mt, &cfg, &seed_map, &mut hits);
    // top_k is 2, so the cap is max(2, 15) = 15.
    assert_eq!(hits.len(), 15);
    // Node 20 carries the highest score and must come first.
    assert_eq!(hits[0].id, 20);
    for w in hits.windows(2) {
        assert!(w[0].score >= w[1].score, "应按分数降序");
    }
}
/// Seeds scoring below `min_score` must not become anchors.
#[test]
fn test_aggregate_seeds_filters_by_min_score() {
    let table = make_memtable(
        2,
        &[
            (1, vec![1.0, 0.0], serde_json::json!({})),
            (2, vec![0.0, 1.0], serde_json::json!({})),
        ],
    );
    let strict_cfg = SearchConfig {
        min_score: 0.8,
        top_k: 10,
        ..Default::default()
    };
    // Node 1 is above the floor, node 2 is below it.
    let seeds = std::collections::HashMap::from([(1u64, 0.9f32), (2, 0.3)]);
    let mut anchors = Vec::new();
    aggregate_seeds(&table, &strict_cfg, &seeds, &mut anchors);
    assert_eq!(anchors.len(), 1);
    assert_eq!(anchors[0].id, 1);
}
/// Only seeds whose payload matches `payload_filter` become anchors.
#[test]
fn test_aggregate_seeds_with_payload_filter() {
    let table = make_memtable(
        2,
        &[
            (1, vec![1.0, 0.0], serde_json::json!({"role": "admin"})),
            (2, vec![0.0, 1.0], serde_json::json!({"role": "user"})),
        ],
    );
    let admin_only = Filter::eq("role", serde_json::json!("admin"));
    let filtered_cfg = SearchConfig {
        payload_filter: Some(admin_only),
        min_score: 0.0,
        top_k: 10,
        ..Default::default()
    };
    let seeds = std::collections::HashMap::from([(1u64, 0.9f32), (2, 0.8)]);
    let mut anchors = Vec::new();
    aggregate_seeds(&table, &filtered_cfg, &seeds, &mut anchors);
    assert_eq!(anchors.len(), 1);
    assert_eq!(anchors[0].id, 1);
}
/// An empty seed map yields no anchors.
#[test]
fn test_aggregate_seeds_empty_map() {
    let table = make_memtable(2, &[(1, vec![1.0, 0.0], serde_json::json!({}))]);
    let no_seeds = std::collections::HashMap::<u64, f32>::new();
    let mut anchors = Vec::new();
    aggregate_seeds(&table, &default_config(), &no_seeds, &mut anchors);
    assert!(anchors.is_empty());
}
/// With three orthogonal unit vectors, the node identical to the query must
/// receive the highest seed score.
#[test]
fn test_recall_vector_basic() {
    let mut table = make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
            (2, vec![0.0, 1.0, 0.0], serde_json::json!({})),
            (3, vec![0.0, 0.0, 1.0], serde_json::json!({})),
        ],
    );
    table.ensure_vectors_cache();
    let cfg = SearchConfig {
        min_score: 0.0,
        top_k: 2,
        ..Default::default()
    };
    let query: Vec<f32> = vec![1.0, 0.0, 0.0];
    let mut seed_map = std::collections::HashMap::new();
    recall_vector(&table, &cfg, Some(&query), &mut seed_map);
    assert!(!seed_map.is_empty(), "应召回至少一个节点");
    let (best_id, _) = seed_map
        .iter()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .unwrap();
    assert_eq!(*best_id, 1);
}
/// Without a query vector, vector recall must leave the seed map untouched.
#[test]
fn test_recall_vector_none_query_is_noop() {
    let mut table = make_memtable(3, &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))]);
    table.ensure_vectors_cache();
    let mut seed_map = std::collections::HashMap::new();
    recall_vector(&table, &default_config(), None, &mut seed_map);
    assert!(seed_map.is_empty());
}
/// Vector recall must skip nodes whose payload fails `payload_filter`, even
/// when they are close to the query.
#[test]
fn test_recall_vector_with_payload_filter() {
    let mut table = make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({"tag": "yes"})),
            (2, vec![0.9, 0.1, 0.0], serde_json::json!({"tag": "no"})),
        ],
    );
    table.ensure_vectors_cache();
    let tagged_yes = Filter::eq("tag", serde_json::json!("yes"));
    let cfg = SearchConfig {
        payload_filter: Some(tagged_yes),
        min_score: 0.0,
        top_k: 5,
        ..Default::default()
    };
    let query = vec![1.0, 0.0, 0.0];
    let mut seed_map = std::collections::HashMap::new();
    recall_vector(&table, &cfg, Some(&query), &mut seed_map);
    assert!(seed_map.contains_key(&1));
    assert!(
        !seed_map.contains_key(&2),
        "node 2 应被 payload_filter 过滤"
    );
}
/// With hybrid text search switched off, text recall must add nothing.
#[test]
fn test_recall_text_disabled_is_noop() {
    let table = make_memtable(
        2,
        &[(1, vec![1.0, 0.0], serde_json::json!({"text": "hello"}))],
    );
    let text_off = SearchConfig {
        enable_text_hybrid_search: false,
        ..Default::default()
    };
    let mut seed_map = std::collections::HashMap::new();
    recall_text(&table, &text_off, Some("hello"), &mut seed_map);
    assert!(seed_map.is_empty());
}
/// With hybrid text search on but no query text, text recall must add nothing.
#[test]
fn test_recall_text_none_query_is_noop() {
    let table = make_memtable(
        2,
        &[(1, vec![1.0, 0.0], serde_json::json!({"text": "hello"}))],
    );
    let text_on = SearchConfig {
        enable_text_hybrid_search: true,
        ..Default::default()
    };
    let mut seed_map = std::collections::HashMap::new();
    recall_text(&table, &text_on, None, &mut seed_map);
    assert!(seed_map.is_empty());
}
/// With the advanced pipeline disabled, residual recall must not modify the seeds.
#[test]
fn test_recall_residual_disabled_is_noop() {
    let mut table = make_memtable(3, &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))]);
    table.ensure_vectors_cache();
    let advanced_off = SearchConfig {
        enable_advanced_pipeline: false,
        ..Default::default()
    };
    let query = vec![1.0, 0.0, 0.0];
    let mut seed_map = std::collections::HashMap::from([(1u64, 0.9f32)]);
    let before = seed_map.clone();
    recall_residual(&table, &advanced_off, Some(&query), &mut seed_map);
    assert_eq!(seed_map, before, "disabled 时 seed_map 不应变化");
}
/// Residual recall needs existing seeds; with none it must add nothing, even
/// when fully enabled.
#[test]
fn test_recall_residual_empty_seeds_is_noop() {
    let mut table = make_memtable(3, &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))]);
    table.ensure_vectors_cache();
    let residual_on = SearchConfig {
        enable_sparse_residual: true,
        enable_advanced_pipeline: true,
        ..Default::default()
    };
    let query = vec![1.0, 0.0, 0.0];
    let mut seed_map = std::collections::HashMap::new();
    recall_residual(&table, &residual_on, Some(&query), &mut seed_map);
    assert!(seed_map.is_empty());
}
/// DPP must decline (return `None`) when the pool is not larger than `top_k`.
#[test]
fn test_apply_dpp_returns_none_when_pool_too_small() {
    let table = make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
            (2, vec![0.0, 1.0, 0.0], serde_json::json!({})),
        ],
    );
    let cfg = SearchConfig {
        dpp_quality_weight: 1.0,
        enable_dpp: true,
        top_k: 5,
        ..Default::default()
    };
    let hit = |id: u64, score: f32| SearchHit {
        id,
        score,
        payload: serde_json::json!({}),
    };
    // Two candidates against top_k = 5.
    let expanded = vec![hit(1, 0.9), hit(2, 0.5)];
    assert!(apply_dpp(&table, &cfg, &expanded).is_none());
}
/// With two near-duplicate top hits, DPP must keep the best one and prefer a
/// dissimilar node over its near-duplicate.
#[test]
fn test_apply_dpp_selects_diverse_subset() {
    let table = make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
            (2, vec![0.99, 0.01, 0.0], serde_json::json!({})),
            (3, vec![0.0, 1.0, 0.0], serde_json::json!({})),
            (4, vec![0.0, 0.0, 1.0], serde_json::json!({})),
        ],
    );
    let cfg = SearchConfig {
        dpp_quality_weight: 1.0,
        enable_dpp: true,
        top_k: 2,
        ..Default::default()
    };
    let hit = |id: u64, score: f32| SearchHit {
        id,
        score,
        payload: serde_json::json!({}),
    };
    let expanded = vec![hit(1, 1.0), hit(2, 0.95), hit(3, 0.8), hit(4, 0.7)];

    let outcome = apply_dpp(&table, &cfg, &expanded);
    assert!(outcome.is_some());
    let selected = outcome.unwrap();
    assert_eq!(selected.len(), 2);
    let picked_ids: Vec<u64> = selected.iter().map(|h| h.id).collect();
    assert!(picked_ids.contains(&1), "最高分节点应被选中");
    assert!(!picked_ids.contains(&2), "DPP 应优先选择多样化的节点而非相似节点");
}
/// A query whose length differs from the table dimension must be rejected with
/// `DimensionMismatch` carrying both sizes.
#[test]
fn test_execute_pipeline_dimension_mismatch() {
    let mt = wrap(make_memtable(
        3,
        &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = default_config();
    // Two components against a three-dimensional table.
    let bad_query = vec![1.0, 0.0];
    let mut ctx = HookContext::new();
    let result = execute_pipeline(&mt, &hook, None, Some(&bad_query), &cfg, &mut ctx);
    // Pin the variant and its fields, not just "some error".
    assert!(
        matches!(
            result,
            Err(crate::error::TriviumError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ),
        "维度不匹配应返回错误"
    );
}
/// A query containing NaN must be rejected with `InvalidVector`.
#[test]
fn test_execute_pipeline_nan_query_rejected() {
    let mt = wrap(make_memtable(
        3,
        &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = default_config();
    let nan_query = vec![f32::NAN, 0.0, 0.0];
    let mut ctx = HookContext::new();
    let result = execute_pipeline(&mt, &hook, None, Some(&nan_query), &cfg, &mut ctx);
    // Pin the variant so an unrelated failure cannot satisfy the test.
    assert!(
        matches!(
            result,
            Err(crate::error::TriviumError::InvalidVector { .. })
        ),
        "NaN 查询向量应被拒绝"
    );
}
/// A query containing an infinite component must be rejected with `InvalidVector`.
#[test]
fn test_execute_pipeline_inf_query_rejected() {
    let mt = wrap(make_memtable(
        3,
        &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = default_config();
    let inf_query = vec![f32::INFINITY, 0.0, 0.0];
    let mut ctx = HookContext::new();
    let result = execute_pipeline(&mt, &hook, None, Some(&inf_query), &cfg, &mut ctx);
    // Pin the variant so an unrelated failure cannot satisfy the test.
    assert!(
        matches!(
            result,
            Err(crate::error::TriviumError::InvalidVector { .. })
        ),
        "Infinity 查询向量应被拒绝"
    );
}
/// Searching an empty table succeeds and returns no hits.
#[test]
fn test_execute_pipeline_empty_db() {
    let empty_table = wrap(MemTable::<f32>::new(3));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = default_config();
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&empty_table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert!(hits.is_empty(), "空库应返回空结果");
}
/// End-to-end vector search: the node identical to the query must rank first.
#[test]
fn test_execute_pipeline_basic_vector_search() {
    let table = wrap(make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({"name": "a"})),
            (2, vec![0.0, 1.0, 0.0], serde_json::json!({"name": "b"})),
            (3, vec![0.0, 0.0, 1.0], serde_json::json!({"name": "c"})),
        ],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = SearchConfig {
        expand_depth: 0,
        min_score: 0.0,
        top_k: 2,
        ..Default::default()
    };
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].id, 1, "最相似节点应排第一");
}
/// With ten candidate nodes and `top_k = 3`, at most three hits may be returned.
#[test]
fn test_execute_pipeline_respects_top_k() {
    let nodes: Vec<(u64, Vec<f32>, serde_json::Value)> = (1u64..=10)
        .map(|id| {
            let vector = vec![1.0, id as f32 * 0.01, 0.0];
            (id, vector, serde_json::json!({}))
        })
        .collect();
    let table = wrap(make_memtable(3, &nodes));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = SearchConfig {
        expand_depth: 0,
        min_score: 0.0,
        top_k: 3,
        ..Default::default()
    };
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert!(hits.len() <= 3, "结果数不应超过 top_k");
}
/// A successful run must record stage timings, including the pre- and
/// post-search hook stages.
#[test]
fn test_execute_pipeline_records_timings() {
    let table = wrap(make_memtable(
        3,
        &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = default_config();
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let _ = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert!(!ctx.stage_timings.is_empty(), "管线应记录阶段计时");
    let recorded = |stage: &str| {
        ctx.stage_timings
            .iter()
            .any(|(name, _)| name.as_str() == stage)
    };
    assert!(recorded("hook_pre_search"));
    assert!(recorded("hook_post_search"));
}
/// A pre-search hook that sets `ctx.abort` must short-circuit the pipeline
/// into an empty, successful result.
#[test]
fn test_hook_abort_returns_empty() {
    struct AbortHook;
    impl SearchHook for AbortHook {
        fn on_pre_search(&self, _: &mut Vec<f32>, _: &mut SearchConfig, ctx: &mut HookContext) {
            ctx.abort = true;
        }
    }

    let table = wrap(make_memtable(
        3,
        &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(AbortHook);
    let cfg = default_config();
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert!(hits.is_empty(), "abort=true 时应返回空结果");
}
/// When the hook supplies its own recall result, the built-in recall is
/// skipped: only the hook's node may be returned, even though another node
/// matches the query exactly.
#[test]
fn test_hook_custom_recall_overrides_builtin() {
    struct FixedRecallHook;
    impl SearchHook for FixedRecallHook {
        fn on_custom_recall(
            &self,
            _: &[f32],
            _: &SearchConfig,
            _: &mut HookContext,
        ) -> Option<Vec<SearchHit>> {
            let forced = SearchHit {
                id: 999,
                score: 1.0,
                payload: serde_json::Value::Null,
            };
            Some(vec![forced])
        }
    }

    let table = wrap(make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
            (
                999,
                vec![0.0, 0.0, 1.0],
                serde_json::json!({"custom": true}),
            ),
        ],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(FixedRecallHook);
    let cfg = SearchConfig {
        expand_depth: 0,
        min_score: 0.0,
        top_k: 5,
        ..Default::default()
    };
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].id, 999, "自定义召回应覆盖内置召回");
}
/// Hits removed by the post-recall hook must not reappear in the final result.
#[test]
fn test_hook_post_recall_filters() {
    struct FilterLowScoreHook;
    impl SearchHook for FilterLowScoreHook {
        fn on_post_recall(&self, hits: &mut Vec<SearchHit>, _: &mut HookContext) {
            hits.retain(|h| h.score > 0.5);
        }
    }

    let table = wrap(make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
            (2, vec![0.0, 1.0, 0.0], serde_json::json!({})),
            (3, vec![0.0, 0.0, 1.0], serde_json::json!({})),
        ],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(FilterLowScoreHook);
    let cfg = SearchConfig {
        expand_depth: 0,
        min_score: 0.0,
        top_k: 10,
        ..Default::default()
    };
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    for hit in hits.iter() {
        assert!(
            hit.score > 0.5,
            "Hook 过滤后不应有低分结果: score={}",
            hit.score
        );
    }
}
/// The list returned by the rerank hook replaces the pipeline's own ordering:
/// reversing it must put the weaker match first.
#[test]
fn test_hook_rerank_reverses_order() {
    struct ReverseRerankHook;
    impl SearchHook for ReverseRerankHook {
        fn on_rerank(
            &self,
            hits: &mut Vec<SearchHit>,
            _: &mut HookContext,
        ) -> Option<Vec<SearchHit>> {
            let flipped: Vec<SearchHit> = hits.iter().rev().cloned().collect();
            Some(flipped)
        }
    }

    let table = wrap(make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
            (2, vec![0.7, 0.7, 0.0], serde_json::json!({})),
        ],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(ReverseRerankHook);
    let cfg = SearchConfig {
        expand_depth: 0,
        min_score: 0.0,
        top_k: 5,
        ..Default::default()
    };
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    assert!(hits.len() >= 2);
    assert_eq!(hits[0].id, 2, "rerank 反转后 node 2 应排第一");
}
/// `top_k = 0` must be sanitised rather than cause a failure: the pipeline
/// raises it to 1, so the call succeeds and returns at most one hit.
#[test]
fn test_pipeline_clamps_extreme_config() {
    let mt = wrap(make_memtable(
        3,
        &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
    ));
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = SearchConfig {
        top_k: 0,
        min_score: 0.0,
        expand_depth: 0,
        ..Default::default()
    };
    let query = vec![1.0, 0.0, 0.0];
    let mut ctx = HookContext::new();
    let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx);
    assert!(results.is_ok(), "极端参数不应 panic");
    // The effective top_k is 1 after clamping, so the result is cut to one hit.
    assert!(results.unwrap().len() <= 1);
}
/// With `expand_depth = 1`, a node linked from the seed must appear in the
/// results even though its vector is orthogonal to the query.
#[test]
fn test_pipeline_with_graph_expansion() {
    let mut graph_table = make_memtable(
        3,
        &[
            (1, vec![1.0, 0.0, 0.0], serde_json::json!({"name": "seed"})),
            (
                2,
                vec![0.0, 1.0, 0.0],
                serde_json::json!({"name": "neighbor"}),
            ),
        ],
    );
    // Edge seed -> neighbor.
    graph_table.link(1, 2, "related".to_string(), 0.8).unwrap();
    let table = wrap(graph_table);
    let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
    let cfg = SearchConfig {
        expand_depth: 1,
        min_score: 0.0,
        top_k: 5,
        ..Default::default()
    };
    let mut ctx = HookContext::new();
    let query = vec![1.0, 0.0, 0.0];
    let hits = execute_pipeline(&table, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
    let neighbor_found = hits.iter().any(|h| h.id == 2);
    assert!(neighbor_found, "图扩散应将邻居节点 2 纳入结果");
}
}