use crate::VectorType;
use crate::database::config::SearchConfig;
use crate::error::Result;
use crate::hook::{HookContext, SearchHook};
use crate::index::brute_force;
use crate::index::quiver::QuIVerSearchConfig;
use crate::node::{NodeId, SearchHit};
use crate::storage::memtable::MemTable;
use std::sync::{Arc, Mutex};
use super::lock_or_recover;
pub(crate) fn execute_pipeline<T: VectorType>(
memtable: &Arc<Mutex<MemTable<T>>>,
hook: &Arc<dyn SearchHook>,
query_text: Option<&str>,
query_vector: Option<&[T]>,
config: &SearchConfig,
ctx: &mut HookContext,
) -> Result<Vec<SearchHit>> {
#[allow(unused_mut)]
let mut mt = lock_or_recover(memtable);
let dim = mt.dim();
if let Some(qv) = query_vector {
if qv.len() != dim {
return Err(crate::error::TriviumError::DimensionMismatch {
expected: dim,
got: qv.len(),
});
}
for item in qv {
let f = item.to_f32();
if f.is_nan() || f.is_infinite() {
return Err(crate::error::TriviumError::InvalidVector {
reason: "Query vector contains NaN or Infinity".to_string(),
});
}
}
}
let mut safe_cfg = config.clone();
safe_cfg.top_k = safe_cfg.top_k.max(1);
safe_cfg.fista_lambda = safe_cfg.fista_lambda.clamp(1e-5, 100.0);
safe_cfg.teleport_alpha = safe_cfg.teleport_alpha.clamp(0.0, 1.0);
safe_cfg.dpp_quality_weight = safe_cfg.dpp_quality_weight.clamp(0.0, 10.0);
safe_cfg.fista_threshold = safe_cfg.fista_threshold.clamp(0.0, f32::MAX);
let mut query_vec_f32: Vec<f32> = query_vector
.map(|qv| qv.iter().map(|x| x.to_f32()).collect())
.unwrap_or_default();
{
let t0 = std::time::Instant::now();
hook.on_pre_search(&mut query_vec_f32, &mut safe_cfg, ctx);
ctx.record_timing("hook_pre_search", t0.elapsed());
}
if ctx.abort {
return Ok(vec![]);
}
let hooked_query: Vec<T> = query_vec_f32.iter().map(|&x| T::from_f32(x)).collect();
let query_vector: Option<&[T]> = if query_vector.is_some() {
Some(&hooked_query)
} else {
None
};
let config = &safe_cfg;
let custom_recall_result = {
let t0 = std::time::Instant::now();
let result = hook.on_custom_recall(&query_vec_f32, config, ctx);
ctx.record_timing("hook_custom_recall", t0.elapsed());
result
};
let mut anchor_hits: Vec<SearchHit> = Vec::new();
let mut seed_map: std::collections::HashMap<NodeId, f32> = std::collections::HashMap::new();
if let Some(custom_hits) = custom_recall_result {
for hit in custom_hits {
*seed_map.entry(hit.id).or_insert(0.0) += hit.score;
}
} else {
mt.ensure_vectors_cache();
recall_text(&mt, config, query_text, &mut seed_map);
recall_vector(&mt, config, query_vector, &mut seed_map);
recall_residual(&mt, config, query_vector, &mut seed_map);
}
aggregate_seeds(&mt, config, &seed_map, &mut anchor_hits);
{
let t0 = std::time::Instant::now();
hook.on_post_recall(&mut anchor_hits, ctx);
ctx.record_timing("hook_post_recall", t0.elapsed());
}
if anchor_hits.is_empty() {
return Ok(vec![]);
}
let mut seeds = Vec::with_capacity(anchor_hits.len());
for mut hit in anchor_hits {
if let Some(payload) = mt.get_payload(hit.id) {
hit.payload = payload.clone();
seeds.push(hit);
}
}
{
let t0 = std::time::Instant::now();
hook.on_pre_graph_expand(&mut seeds, ctx);
ctx.record_timing("hook_pre_graph_expand", t0.elapsed());
}
let t_graph = std::time::Instant::now();
let mut expanded = crate::graph::traversal::expand_graph(
&mt,
seeds,
config.expand_depth,
config.teleport_alpha,
config.enable_inverse_inhibition,
config.lateral_inhibition_threshold,
config.enable_refractory_fatigue,
);
ctx.record_timing("graph_expand", t_graph.elapsed());
{
let t0 = std::time::Instant::now();
if let Some(reranked) = hook.on_rerank(&mut expanded, ctx) {
expanded = reranked;
}
ctx.record_timing("hook_rerank", t0.elapsed());
}
if config.enable_advanced_pipeline
&& config.enable_dpp
&& expanded.len() > config.top_k
&& let Some(mut final_results) = apply_dpp(&mt, config, &expanded)
{
{
let t0 = std::time::Instant::now();
hook.on_post_search(&mut final_results, ctx);
ctx.record_timing("hook_post_search", t0.elapsed());
}
return Ok(final_results);
}
expanded.truncate(config.top_k);
{
let t0 = std::time::Instant::now();
hook.on_post_search(&mut expanded, ctx);
ctx.record_timing("hook_post_search", t0.elapsed());
}
Ok(expanded)
}
/// Text recall stage: adds Aho–Corasick and BM25 scores for `query_text` into `seed_map`.
///
/// Does nothing when hybrid text search is disabled or there is no query text.
/// Both score sources are multiplied by `config.text_boost`. BM25 scores are first
/// divided by 10 and clamped to `[0, 1]`.
fn recall_text<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_text: Option<&str>,
    seed_map: &mut std::collections::HashMap<NodeId, f32>,
) {
    let Some(txt) = query_text.filter(|_| config.enable_text_hybrid_search) else {
        return;
    };
    let engine = mt.text_engine();
    let boost = config.text_boost;

    for (id, score) in engine.search_ac(txt) {
        *seed_map.entry(id).or_insert(0.0) += score * boost;
    }
    for (id, raw) in engine.search_bm25(txt, config.bm25_k1, config.bm25_b) {
        *seed_map.entry(id).or_insert(0.0) += (raw / 10.0).clamp(0.0, 1.0) * boost;
    }
}
fn recall_vector<T: VectorType>(
mt: &MemTable<T>,
config: &SearchConfig,
query_vector: Option<&[T]>,
seed_map: &mut std::collections::HashMap<NodeId, f32>,
) {
let query_vector = match query_vector {
Some(qv) => qv,
None => return,
};
let dim = mt.dim();
let vectors = mt.flat_vectors();
let filter_ref = config.payload_filter.as_ref();
let passes_filter = |id: NodeId| -> bool {
match filter_ref {
None => true,
Some(f) => mt.get_payload(id).is_some_and(|p| f.matches(p)),
}
};
let vector_hits: Vec<SearchHit> = if !config.force_brute_force && mt.quiver().is_some() {
quiver_pipeline(mt, config, query_vector, &passes_filter)
} else {
brute_force_pipeline(mt, config, query_vector, vectors, dim, &passes_filter)
};
for hit in vector_hits {
*seed_map.entry(hit.id).or_insert(0.0) += hit.score;
}
}
/// Exhaustive vector recall over `vectors` (`dim` components per row).
///
/// A row is rejected early when the payload filter has a non-zero "must have"
/// bloom mask and the row's fast-tag bits do not contain all of it. The full
/// `passes_filter` check runs only on rows that survive the bloom test.
fn brute_force_pipeline<T: VectorType + Sync>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: &[T],
    vectors: &[T],
    dim: usize,
    passes_filter: &(dyn Fn(NodeId) -> bool + Sync),
) -> Vec<SearchHit> {
    let must_have = config
        .payload_filter
        .as_ref()
        .map_or(0, |f| f.extract_must_have_mask());
    let tags = mt.fast_tags_slice();

    // The id-mapping closure returns 0 to reject a row.
    // NOTE(review): assumes `brute_force::search` treats id 0 as "skip" — confirm.
    let map_id = |idx| {
        let id = mt.get_id_by_index(idx);
        let bloom_rejects =
            must_have != 0 && idx < tags.len() && (tags[idx] & must_have) != must_have;
        if bloom_rejects || !passes_filter(id) {
            0
        } else {
            id
        }
    };

    brute_force::search(
        query_vector,
        vectors,
        dim,
        config.top_k,
        config.min_score,
        map_id,
    )
}
/// Approximate vector recall through the QuIVer index.
///
/// The index is asked for `2 * top_k` candidates (`ef_search = 4 * top_k`) so that
/// enough survive the `min_score` and `passes_filter` checks. Survivors are sorted by
/// descending score and truncated to `top_k`. Candidate multipliers saturate, so a
/// huge `top_k` cannot overflow.
///
/// # Panics
/// Panics if the memtable has no QuIVer index; callers check `mt.quiver().is_some()` first.
fn quiver_pipeline<T: VectorType + Sync>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: &[T],
    passes_filter: &(dyn Fn(NodeId) -> bool + Sync),
) -> Vec<SearchHit> {
    let quiver = mt
        .quiver()
        .expect("quiver_pipeline is only called when a QuIVer index exists");
    let q_f32: Vec<f32> = query_vector.iter().map(|x| x.to_f32()).collect();
    // NOTE(review): this converts the entire flat vector store to f32 on every query.
    // It is required by `quiver.search`'s `&[f32]` input, but costs O(N * dim) per call.
    let ext_vectors: Vec<f32> = mt.flat_vectors().iter().map(|x| x.to_f32()).collect();

    let k = config.top_k.max(1);
    let search_cfg = QuIVerSearchConfig {
        top_k: k.saturating_mul(2),
        ef_search: k.saturating_mul(4),
    };
    let raw_results = quiver.search(&q_f32, &ext_vectors, &search_cfg);

    let mut hits: Vec<SearchHit> = raw_results
        .into_iter()
        .filter(|&(id, score)| score >= config.min_score && passes_filter(id))
        .map(|(id, score)| SearchHit {
            id,
            score,
            payload: mt
                .get_payload(id)
                .cloned()
                .unwrap_or(serde_json::Value::Null),
        })
        .collect();
    hits.sort_unstable_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    hits.truncate(config.top_k);
    hits
}
/// Sparse-residual recall stage.
///
/// Solves a FISTA sparse reconstruction of the query over the vectors of the
/// current seeds. When the residual norm exceeds `config.fista_threshold`, the
/// residual itself is used as a "shadow" brute-force query. Its hits are added to
/// `seed_map` with a reduced weight.
///
/// Runs only when both the advanced pipeline and sparse residual are enabled,
/// there is a query vector, and at least one seed exists.
fn recall_residual<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    query_vector: Option<&[T]>,
    seed_map: &mut std::collections::HashMap<NodeId, f32>,
) {
    /// Down-weighting applied to scores from the residual shadow query.
    const SHADOW_HIT_WEIGHT: f32 = 0.8;

    if !config.enable_advanced_pipeline || !config.enable_sparse_residual || seed_map.is_empty() {
        return;
    }
    let Some(query_vector) = query_vector else {
        return;
    };

    // `HashMap` iteration order is randomized per map instance. Sort the ids so
    // FISTA sees the dictionary in a deterministic order and results are
    // reproducible across runs.
    let mut seed_ids: Vec<NodeId> = seed_map.keys().copied().collect();
    seed_ids.sort_unstable();
    let entity_vecs: Vec<Vec<f32>> = seed_ids
        .iter()
        .filter_map(|&id| {
            mt.get_vector(id)
                .map(|v| v.iter().map(|&x| x.to_f32()).collect())
        })
        .collect();

    let q_f32: Vec<f32> = query_vector.iter().map(|&x| x.to_f32()).collect();
    let (_, residual, residual_norm) =
        crate::cognitive::fista_solve(&q_f32, &entity_vecs, config.fista_lambda, 80);
    if residual_norm > config.fista_threshold {
        tracing::debug!(
            "FISTA 残差较高 ({} > {}),触发影子查询",
            residual_norm,
            config.fista_threshold
        );
        let r_orig: Vec<T> = residual.iter().map(|&x| T::from_f32(x)).collect();
        let dim = mt.dim();
        // The payload filter is not applied here; `aggregate_seeds` enforces it later.
        let shadow_hits = brute_force::search(
            &r_orig,
            mt.flat_vectors(),
            dim,
            config.top_k,
            config.min_score,
            |idx| mt.get_id_by_index(idx),
        );
        for sh in shadow_hits {
            *seed_map.entry(sh.id).or_insert(0.0) += sh.score * SHADOW_HIT_WEIGHT;
        }
    }
}
/// Turns accumulated seed scores into anchor hits.
///
/// A seed is kept when its score is at least `config.min_score` and:
/// - with no payload filter, the node exists in the memtable (payload defaults to
///   `Null` if absent);
/// - with a payload filter, the node has a payload that matches it.
///
/// Kept hits are appended to `anchor_hits`. The result is sorted by descending
/// score, with ties broken by ascending id so the outcome does not depend on
/// `HashMap` iteration order, then truncated to `max(top_k, 15)`.
fn aggregate_seeds<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    seed_map: &std::collections::HashMap<NodeId, f32>,
    anchor_hits: &mut Vec<SearchHit>,
) {
    let filter_ref = config.payload_filter.as_ref();
    for (&id, &score) in seed_map {
        // NaN scores fail this comparison and are dropped.
        if !(score >= config.min_score) {
            continue;
        }
        let payload = match filter_ref {
            None => {
                if !mt.contains(id) {
                    continue;
                }
                mt.get_payload(id)
                    .cloned()
                    .unwrap_or(serde_json::Value::Null)
            }
            Some(f) => match mt.get_payload(id) {
                Some(p) if f.matches(p) => p.clone(),
                _ => continue,
            },
        };
        anchor_hits.push(SearchHit { id, score, payload });
    }
    anchor_hits.sort_unstable_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id))
    });
    // Keep at least 15 anchors so graph expansion has enough seeds for small `top_k`.
    anchor_hits.truncate(config.top_k.max(15));
}
/// Diversifies the top of `expanded` with greedy DPP selection.
///
/// The candidate pool is the first `3 * top_k` hits, saturating on overflow and
/// capped at `expanded.len()`. Hits whose node has no stored vector are skipped.
/// Returns `None` when the remaining pool has no more than `top_k` entries, since
/// there is nothing to choose between. Otherwise returns the `top_k` hits picked by
/// `dpp_greedy`, re-sorted by descending score.
fn apply_dpp<T: VectorType>(
    mt: &MemTable<T>,
    config: &SearchConfig,
    expanded: &[SearchHit],
) -> Option<Vec<SearchHit>> {
    let limit = config.top_k;
    let dpp_pool_size = expanded.len().min(limit.saturating_mul(3));
    let mut pool_vecs = Vec::with_capacity(dpp_pool_size);
    let mut pool_scores = Vec::with_capacity(dpp_pool_size);
    // Borrow pool members; only the selected ones are cloned at the end.
    let mut pool_hits: Vec<&SearchHit> = Vec::with_capacity(dpp_pool_size);
    for hit in &expanded[..dpp_pool_size] {
        if let Some(v) = mt.get_vector(hit.id) {
            pool_vecs.push(v.iter().map(|&x| x.to_f32()).collect());
            pool_scores.push(hit.score);
            pool_hits.push(hit);
        }
    }
    if pool_hits.len() <= limit {
        return None;
    }
    let selected_idx =
        crate::cognitive::dpp_greedy(&pool_vecs, &pool_scores, limit, config.dpp_quality_weight);
    let mut final_results = Vec::with_capacity(limit);
    for &idx in &selected_idx {
        final_results.push(pool_hits[idx].clone());
    }
    final_results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    Some(final_results)
}
#[cfg(test)]
mod tests {
    // Unit tests for the individual pipeline stages and for `execute_pipeline`
    // end-to-end, including the hook integration points.
    use super::*;
    use crate::database::config::SearchConfig;
    use crate::filter::Filter;
    use crate::hook::{HookContext, NoopHook, SearchHook};
    use crate::node::SearchHit;
    use crate::storage::memtable::MemTable;
    use std::sync::{Arc, Mutex};

    /// Builds an f32 memtable of dimension `dim` holding `nodes` as (id, vector, payload).
    fn make_memtable(dim: usize, nodes: &[(u64, Vec<f32>, serde_json::Value)]) -> MemTable<f32> {
        let mut mt = MemTable::new(dim);
        for (id, vec, payload) in nodes {
            mt.insert_with_id(*id, vec, payload.clone()).unwrap();
        }
        mt
    }

    fn wrap(mt: MemTable<f32>) -> Arc<Mutex<MemTable<f32>>> {
        Arc::new(Mutex::new(mt))
    }

    /// Small config with no score threshold and no graph expansion.
    fn default_config() -> SearchConfig {
        SearchConfig {
            top_k: 5,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        }
    }

    #[test]
    fn test_aggregate_seeds_sorts_descending_and_truncates() {
        let mt = make_memtable(
            2,
            &[
                (1, vec![1.0, 0.0], serde_json::json!({"a": 1})),
                (2, vec![0.0, 1.0], serde_json::json!({"a": 2})),
                (3, vec![0.5, 0.5], serde_json::json!({"a": 3})),
            ],
        );
        let cfg = SearchConfig {
            top_k: 2,
            min_score: 0.0,
            ..Default::default()
        };
        let mut seed_map = std::collections::HashMap::new();
        seed_map.insert(1u64, 0.9f32);
        seed_map.insert(2, 0.5);
        seed_map.insert(3, 0.7);
        let mut hits = Vec::new();
        aggregate_seeds(&mt, &cfg, &seed_map, &mut hits);
        assert!(hits.len() <= 15);
        for w in hits.windows(2) {
            assert!(w[0].score >= w[1].score, "应按分数降序");
        }
        // The truncation floor is max(top_k, 15), so all three seeds survive, best first.
        let ids: Vec<u64> = hits.iter().map(|h| h.id).collect();
        assert_eq!(ids, vec![1, 3, 2]);

        // With more seeds than the floor, the result is cut to 15 highest scores.
        let nodes: Vec<(u64, Vec<f32>, serde_json::Value)> = (1..=20u64)
            .map(|i| (i, vec![1.0, i as f32], serde_json::json!({})))
            .collect();
        let big = make_memtable(2, &nodes);
        let big_seeds: std::collections::HashMap<u64, f32> =
            (1..=20u64).map(|i| (i, i as f32 / 20.0)).collect();
        let mut big_hits = Vec::new();
        aggregate_seeds(&big, &cfg, &big_seeds, &mut big_hits);
        assert_eq!(big_hits.len(), 15);
        assert_eq!(big_hits[0].id, 20);
        assert_eq!(big_hits[14].id, 6);
    }

    #[test]
    fn test_aggregate_seeds_filters_by_min_score() {
        let mt = make_memtable(
            2,
            &[
                (1, vec![1.0, 0.0], serde_json::json!({})),
                (2, vec![0.0, 1.0], serde_json::json!({})),
            ],
        );
        let cfg = SearchConfig {
            top_k: 10,
            min_score: 0.8,
            ..Default::default()
        };
        let mut seed_map = std::collections::HashMap::new();
        seed_map.insert(1u64, 0.9f32);
        seed_map.insert(2, 0.3);
        let mut hits = Vec::new();
        aggregate_seeds(&mt, &cfg, &seed_map, &mut hits);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_aggregate_seeds_with_payload_filter() {
        let mt = make_memtable(
            2,
            &[
                (1, vec![1.0, 0.0], serde_json::json!({"role": "admin"})),
                (2, vec![0.0, 1.0], serde_json::json!({"role": "user"})),
            ],
        );
        let cfg = SearchConfig {
            top_k: 10,
            min_score: 0.0,
            payload_filter: Some(Filter::eq("role", serde_json::json!("admin"))),
            ..Default::default()
        };
        let mut seed_map = std::collections::HashMap::new();
        seed_map.insert(1u64, 0.9f32);
        seed_map.insert(2, 0.8);
        let mut hits = Vec::new();
        aggregate_seeds(&mt, &cfg, &seed_map, &mut hits);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_aggregate_seeds_empty_map() {
        let mt = make_memtable(2, &[(1, vec![1.0, 0.0], serde_json::json!({}))]);
        let cfg = default_config();
        let seed_map = std::collections::HashMap::new();
        let mut hits = Vec::new();
        aggregate_seeds(&mt, &cfg, &seed_map, &mut hits);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_recall_vector_basic() {
        let mut mt = make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
                (2, vec![0.0, 1.0, 0.0], serde_json::json!({})),
                (3, vec![0.0, 0.0, 1.0], serde_json::json!({})),
            ],
        );
        mt.ensure_vectors_cache();
        let cfg = SearchConfig {
            top_k: 2,
            min_score: 0.0,
            ..Default::default()
        };
        let query: Vec<f32> = vec![1.0, 0.0, 0.0];
        let mut seed_map = std::collections::HashMap::new();
        recall_vector(&mt, &cfg, Some(&query), &mut seed_map);
        assert!(!seed_map.is_empty(), "应召回至少一个节点");
        let best_id = seed_map
            .iter()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(*best_id, 1);
    }

    #[test]
    fn test_recall_vector_none_query_is_noop() {
        let mut mt = make_memtable(3, &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))]);
        mt.ensure_vectors_cache();
        let cfg = default_config();
        let mut seed_map = std::collections::HashMap::new();
        recall_vector(&mt, &cfg, None, &mut seed_map);
        assert!(seed_map.is_empty());
    }

    #[test]
    fn test_recall_vector_with_payload_filter() {
        let mut mt = make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({"tag": "yes"})),
                (2, vec![0.9, 0.1, 0.0], serde_json::json!({"tag": "no"})),
            ],
        );
        mt.ensure_vectors_cache();
        let cfg = SearchConfig {
            top_k: 5,
            min_score: 0.0,
            payload_filter: Some(Filter::eq("tag", serde_json::json!("yes"))),
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut seed_map = std::collections::HashMap::new();
        recall_vector(&mt, &cfg, Some(&query), &mut seed_map);
        assert!(seed_map.contains_key(&1));
        assert!(
            !seed_map.contains_key(&2),
            "node 2 应被 payload_filter 过滤"
        );
    }

    #[test]
    fn test_recall_text_disabled_is_noop() {
        let mt = make_memtable(
            2,
            &[(1, vec![1.0, 0.0], serde_json::json!({"text": "hello"}))],
        );
        let cfg = SearchConfig {
            enable_text_hybrid_search: false,
            ..Default::default()
        };
        let mut seed_map = std::collections::HashMap::new();
        recall_text(&mt, &cfg, Some("hello"), &mut seed_map);
        assert!(seed_map.is_empty());
    }

    #[test]
    fn test_recall_text_none_query_is_noop() {
        let mt = make_memtable(
            2,
            &[(1, vec![1.0, 0.0], serde_json::json!({"text": "hello"}))],
        );
        let cfg = SearchConfig {
            enable_text_hybrid_search: true,
            ..Default::default()
        };
        let mut seed_map = std::collections::HashMap::new();
        recall_text(&mt, &cfg, None, &mut seed_map);
        assert!(seed_map.is_empty());
    }

    #[test]
    fn test_recall_residual_disabled_is_noop() {
        let mut mt = make_memtable(3, &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))]);
        mt.ensure_vectors_cache();
        let cfg = SearchConfig {
            enable_advanced_pipeline: false,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut seed_map = std::collections::HashMap::new();
        seed_map.insert(1u64, 0.9f32);
        let before = seed_map.clone();
        recall_residual(&mt, &cfg, Some(&query), &mut seed_map);
        assert_eq!(seed_map, before, "disabled 时 seed_map 不应变化");
    }

    #[test]
    fn test_recall_residual_empty_seeds_is_noop() {
        let mut mt = make_memtable(3, &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))]);
        mt.ensure_vectors_cache();
        let cfg = SearchConfig {
            enable_advanced_pipeline: true,
            enable_sparse_residual: true,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut seed_map = std::collections::HashMap::new();
        recall_residual(&mt, &cfg, Some(&query), &mut seed_map);
        assert!(seed_map.is_empty());
    }

    #[test]
    fn test_apply_dpp_returns_none_when_pool_too_small() {
        let mt = make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
                (2, vec![0.0, 1.0, 0.0], serde_json::json!({})),
            ],
        );
        let cfg = SearchConfig {
            top_k: 5,
            enable_dpp: true,
            dpp_quality_weight: 1.0,
            ..Default::default()
        };
        let expanded = vec![
            SearchHit {
                id: 1,
                score: 0.9,
                payload: serde_json::json!({}),
            },
            SearchHit {
                id: 2,
                score: 0.5,
                payload: serde_json::json!({}),
            },
        ];
        assert!(apply_dpp(&mt, &cfg, &expanded).is_none());
    }

    #[test]
    fn test_apply_dpp_selects_diverse_subset() {
        let mt = make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
                (2, vec![0.99, 0.01, 0.0], serde_json::json!({})),
                (3, vec![0.0, 1.0, 0.0], serde_json::json!({})),
                (4, vec![0.0, 0.0, 1.0], serde_json::json!({})),
            ],
        );
        let cfg = SearchConfig {
            top_k: 2,
            enable_dpp: true,
            dpp_quality_weight: 1.0,
            ..Default::default()
        };
        let expanded = vec![
            SearchHit {
                id: 1,
                score: 1.0,
                payload: serde_json::json!({}),
            },
            SearchHit {
                id: 2,
                score: 0.95,
                payload: serde_json::json!({}),
            },
            SearchHit {
                id: 3,
                score: 0.8,
                payload: serde_json::json!({}),
            },
            SearchHit {
                id: 4,
                score: 0.7,
                payload: serde_json::json!({}),
            },
        ];
        let result = apply_dpp(&mt, &cfg, &expanded);
        assert!(result.is_some());
        let selected = result.unwrap();
        assert_eq!(selected.len(), 2);
        let ids: Vec<u64> = selected.iter().map(|h| h.id).collect();
        assert!(ids.contains(&1), "最高分节点应被选中");
        assert!(!ids.contains(&2), "DPP 应优先选择多样化的节点而非相似节点");
    }

    #[test]
    fn test_execute_pipeline_dimension_mismatch() {
        let mt = wrap(make_memtable(
            3,
            &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = default_config();
        let bad_query = vec![1.0, 0.0];
        let mut ctx = HookContext::new();
        let result = execute_pipeline(&mt, &hook, None, Some(&bad_query), &cfg, &mut ctx);
        assert!(result.is_err(), "维度不匹配应返回错误");
    }

    #[test]
    fn test_execute_pipeline_nan_query_rejected() {
        let mt = wrap(make_memtable(
            3,
            &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = default_config();
        let nan_query = vec![f32::NAN, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let result = execute_pipeline(&mt, &hook, None, Some(&nan_query), &cfg, &mut ctx);
        assert!(result.is_err(), "NaN 查询向量应被拒绝");
    }

    #[test]
    fn test_execute_pipeline_inf_query_rejected() {
        let mt = wrap(make_memtable(
            3,
            &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = default_config();
        let inf_query = vec![f32::INFINITY, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let result = execute_pipeline(&mt, &hook, None, Some(&inf_query), &cfg, &mut ctx);
        assert!(result.is_err(), "Infinity 查询向量应被拒绝");
    }

    #[test]
    fn test_execute_pipeline_empty_db() {
        let mt = wrap(MemTable::<f32>::new(3));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = default_config();
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert!(results.is_empty(), "空库应返回空结果");
    }

    #[test]
    fn test_execute_pipeline_basic_vector_search() {
        let mt = wrap(make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({"name": "a"})),
                (2, vec![0.0, 1.0, 0.0], serde_json::json!({"name": "b"})),
                (3, vec![0.0, 0.0, 1.0], serde_json::json!({"name": "c"})),
            ],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = SearchConfig {
            top_k: 2,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 1, "最相似节点应排第一");
    }

    #[test]
    fn test_execute_pipeline_respects_top_k() {
        let nodes: Vec<(u64, Vec<f32>, serde_json::Value)> = (1..=10)
            .map(|i| {
                (
                    i as u64,
                    vec![1.0, i as f32 * 0.01, 0.0],
                    serde_json::json!({}),
                )
            })
            .collect();
        let mt = wrap(make_memtable(3, &nodes));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = SearchConfig {
            top_k: 3,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert!(results.len() <= 3, "结果数不应超过 top_k");
    }

    #[test]
    fn test_execute_pipeline_records_timings() {
        let mt = wrap(make_memtable(
            3,
            &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = default_config();
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let _ = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert!(!ctx.stage_timings.is_empty(), "管线应记录阶段计时");
        let stage_names: Vec<&str> = ctx.stage_timings.iter().map(|(n, _)| n.as_str()).collect();
        assert!(stage_names.contains(&"hook_pre_search"));
        assert!(stage_names.contains(&"hook_post_search"));
    }

    #[test]
    fn test_hook_abort_returns_empty() {
        // A hook that aborts the search before any recall happens.
        struct AbortHook;
        impl SearchHook for AbortHook {
            fn on_pre_search(&self, _: &mut Vec<f32>, _: &mut SearchConfig, ctx: &mut HookContext) {
                ctx.abort = true;
            }
        }
        let mt = wrap(make_memtable(
            3,
            &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(AbortHook);
        let cfg = default_config();
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert!(results.is_empty(), "abort=true 时应返回空结果");
    }

    #[test]
    fn test_hook_custom_recall_overrides_builtin() {
        // A hook whose custom recall always returns node 999, bypassing built-in recall.
        struct FixedRecallHook;
        impl SearchHook for FixedRecallHook {
            fn on_custom_recall(
                &self,
                _: &[f32],
                _: &SearchConfig,
                _: &mut HookContext,
            ) -> Option<Vec<SearchHit>> {
                Some(vec![SearchHit {
                    id: 999,
                    score: 1.0,
                    payload: serde_json::Value::Null,
                }])
            }
        }
        let mt = wrap(make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
                (
                    999,
                    vec![0.0, 0.0, 1.0],
                    serde_json::json!({"custom": true}),
                ),
            ],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(FixedRecallHook);
        let cfg = SearchConfig {
            top_k: 5,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 999, "自定义召回应覆盖内置召回");
    }

    #[test]
    fn test_hook_post_recall_filters() {
        // A hook that drops low-scoring anchors after recall.
        struct FilterLowScoreHook;
        impl SearchHook for FilterLowScoreHook {
            fn on_post_recall(&self, hits: &mut Vec<SearchHit>, _: &mut HookContext) {
                hits.retain(|h| h.score > 0.5);
            }
        }
        let mt = wrap(make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
                (2, vec![0.0, 1.0, 0.0], serde_json::json!({})),
                (3, vec![0.0, 0.0, 1.0], serde_json::json!({})),
            ],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(FilterLowScoreHook);
        let cfg = SearchConfig {
            top_k: 10,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        for r in &results {
            assert!(
                r.score > 0.5,
                "Hook 过滤后不应有低分结果: score={}",
                r.score
            );
        }
    }

    #[test]
    fn test_hook_rerank_reverses_order() {
        // A hook that replaces the ranking with its reverse.
        struct ReverseRerankHook;
        impl SearchHook for ReverseRerankHook {
            fn on_rerank(
                &self,
                hits: &mut Vec<SearchHit>,
                _: &mut HookContext,
            ) -> Option<Vec<SearchHit>> {
                let mut reversed = hits.clone();
                reversed.reverse();
                Some(reversed)
            }
        }
        let mt = wrap(make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({})),
                (2, vec![0.7, 0.7, 0.0], serde_json::json!({})),
            ],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(ReverseRerankHook);
        let cfg = SearchConfig {
            top_k: 5,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        assert!(results.len() >= 2);
        assert_eq!(results[0].id, 2, "rerank 反转后 node 2 应排第一");
    }

    #[test]
    fn test_pipeline_clamps_extreme_config() {
        let mt = wrap(make_memtable(
            3,
            &[(1, vec![1.0, 0.0, 0.0], serde_json::json!({}))],
        ));
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = SearchConfig {
            top_k: 0,
            min_score: 0.0,
            expand_depth: 0,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx);
        assert!(results.is_ok(), "极端参数不应 panic");
    }

    #[test]
    fn test_pipeline_with_graph_expansion() {
        let mut mt = make_memtable(
            3,
            &[
                (1, vec![1.0, 0.0, 0.0], serde_json::json!({"name": "seed"})),
                (
                    2,
                    vec![0.0, 1.0, 0.0],
                    serde_json::json!({"name": "neighbor"}),
                ),
            ],
        );
        mt.link(1, 2, "related".to_string(), 0.8).unwrap();
        let mt = wrap(mt);
        let hook: Arc<dyn SearchHook> = Arc::new(NoopHook);
        let cfg = SearchConfig {
            top_k: 5,
            min_score: 0.0,
            expand_depth: 1,
            ..Default::default()
        };
        let query = vec![1.0, 0.0, 0.0];
        let mut ctx = HookContext::new();
        let results = execute_pipeline(&mt, &hook, None, Some(&query), &cfg, &mut ctx).unwrap();
        let ids: Vec<u64> = results.iter().map(|h| h.id).collect();
        assert!(ids.contains(&2), "图扩散应将邻居节点 2 纳入结果");
    }
}