use crate::database::SearchConfig;
use crate::node::SearchHit;
#[derive(Debug, Clone)]
pub struct HookContext {
pub custom_data: serde_json::Value,
pub stage_timings: Vec<(String, std::time::Duration)>,
pub abort: bool,
}
impl Default for HookContext {
fn default() -> Self {
Self::new()
}
}
impl HookContext {
pub fn new() -> Self {
Self {
custom_data: serde_json::Value::Null,
stage_timings: Vec::new(),
abort: false,
}
}
pub fn record_timing(&mut self, stage: impl Into<String>, elapsed: std::time::Duration) {
self.stage_timings.push((stage.into(), elapsed));
}
}
/// Extension points invoked at the stages of a search.
///
/// Every method has a no-op default, so implementors override only the stages
/// they care about. All callbacks share the same [`HookContext`].
pub trait SearchHook: Send + Sync {
    /// Called before searching; may rewrite the query vector and/or the config.
    fn on_pre_search(
        &self,
        _query_vector: &mut Vec<f32>,
        _config: &mut SearchConfig,
        _ctx: &mut HookContext,
    ) {
    }

    /// Optionally produces a custom candidate list. The default returns `None`
    /// (no custom recall); how `Some` is used is up to the calling pipeline.
    fn on_custom_recall(
        &self,
        _query_vector: &[f32],
        _config: &SearchConfig,
        _ctx: &mut HookContext,
    ) -> Option<Vec<SearchHit>> {
        None
    }

    /// Called with the recalled hits; may edit, filter or rescore them in place.
    fn on_post_recall(&self, _hits: &mut Vec<SearchHit>, _ctx: &mut HookContext) {}

    /// Called with the seed hits before graph expansion; may edit them in place.
    fn on_pre_graph_expand(&self, _seeds: &mut Vec<SearchHit>, _ctx: &mut HookContext) {}

    /// Optionally returns a reranked hit list. The default returns `None`.
    fn on_rerank(
        &self,
        _hits: &mut Vec<SearchHit>,
        _ctx: &mut HookContext,
    ) -> Option<Vec<SearchHit>> {
        None
    }

    /// Called with the final results; may edit them in place.
    fn on_post_search(&self, _results: &mut Vec<SearchHit>, _ctx: &mut HookContext) {}
}
/// A hook that does nothing at any stage: every [`SearchHook`] method uses its
/// no-op default.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopHook;

impl SearchHook for NoopHook {}
/// Chains several hooks and presents them as a single [`SearchHook`].
///
/// Hooks are invoked in the order they were added.
#[derive(Default)]
pub struct CompositeHook {
    /// Registered hooks, in insertion order.
    hooks: Vec<Box<dyn SearchHook>>,
}

impl CompositeHook {
    /// Creates a composite with no hooks registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `hook` to the end of the chain.
    pub fn add(&mut self, hook: impl SearchHook + 'static) {
        let boxed: Box<dyn SearchHook> = Box::new(hook);
        self.hooks.push(boxed);
    }
}
/// Fans each stage out to the registered hooks, in insertion order.
///
/// * `on_pre_search` stops after the first hook that leaves `ctx.abort` set.
/// * `on_custom_recall` / `on_rerank` return the first `Some` produced; the
///   hooks after it are not called.
/// * The remaining stages call every hook and do not consult `ctx.abort`.
///   NOTE(review): confirm this asymmetry is intentional.
impl SearchHook for CompositeHook {
    fn on_pre_search(
        &self,
        query_vector: &mut Vec<f32>,
        config: &mut SearchConfig,
        ctx: &mut HookContext,
    ) {
        for hook in self.hooks.iter() {
            hook.on_pre_search(query_vector, config, ctx);
            // Checked *after* each call, so the first hook runs even if `abort`
            // was already set on entry.
            if ctx.abort {
                break;
            }
        }
    }

    fn on_custom_recall(
        &self,
        query_vector: &[f32],
        config: &SearchConfig,
        ctx: &mut HookContext,
    ) -> Option<Vec<SearchHit>> {
        self.hooks
            .iter()
            .find_map(|hook| hook.on_custom_recall(query_vector, config, ctx))
    }

    fn on_post_recall(&self, hits: &mut Vec<SearchHit>, ctx: &mut HookContext) {
        self.hooks
            .iter()
            .for_each(|hook| hook.on_post_recall(hits, ctx));
    }

    fn on_pre_graph_expand(&self, seeds: &mut Vec<SearchHit>, ctx: &mut HookContext) {
        self.hooks
            .iter()
            .for_each(|hook| hook.on_pre_graph_expand(seeds, ctx));
    }

    fn on_rerank(
        &self,
        hits: &mut Vec<SearchHit>,
        ctx: &mut HookContext,
    ) -> Option<Vec<SearchHit>> {
        self.hooks
            .iter()
            .find_map(|hook| hook.on_rerank(hits, ctx))
    }

    fn on_post_search(&self, results: &mut Vec<SearchHit>, ctx: &mut HookContext) {
        self.hooks
            .iter()
            .for_each(|hook| hook.on_post_search(results, ctx));
    }
}
/// C-ABI representation of a search hit exchanged with external hook libraries.
///
/// `#[repr(C)]` fixes the layout as `id` (u64) followed by `score` (f32), plus
/// trailing padding. Plugins depend on this layout, so the field order must not change.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct FfiSearchHit {
    /// Hit identifier. `FfiHook`'s recall path discards entries whose id is `0`.
    pub id: u64,
    /// Relevance score.
    pub score: f32,
}
/// Signature of the optional `trivium_recall` export of a hook library.
///
/// `FfiHook` passes the query as a `(pointer, length)` pair plus `top_k`, and an
/// output buffer it allocated with room for `2 * top_k` entries. The buffer
/// capacity itself is not passed. After the call it reads back `*out_count`
/// entries, clamped to that capacity. A return value other than `0` is treated
/// as failure.
pub type FfiRecallFn = unsafe extern "C" fn(
    query_ptr: *const f32,
    query_len: usize,
    top_k: usize,
    out_hits: *mut FfiSearchHit,
    out_count: *mut usize,
) -> i32;

/// Signature of the optional `trivium_rerank` export of a hook library.
///
/// Receives a mutable array of `hits_count` hits that it may modify in place.
/// A return value other than `0` is treated as failure.
pub type FfiRerankFn =
    unsafe extern "C" fn(hits_ptr: *mut FfiSearchHit, hits_count: usize) -> i32;
/// A `SearchHook` backed by a dynamically loaded C library.
///
/// Only the recall and rerank stages are forwarded to the library; every other
/// stage uses the trait's no-op defaults.
pub struct FfiHook {
// Keeps the shared library loaded; the function pointers below point into it
// and are only valid while this handle is alive.
_lib: libloading::Library,
// The library's `trivium_recall` export, if it has one.
recall_fn: Option<FfiRecallFn>,
// The library's `trivium_rerank` export, if it has one.
rerank_fn: Option<FfiRerankFn>,
}
// SAFETY: these impls allow the plugin's exported functions to be called from
// several threads at once. Whether that is sound depends on the plugin, not on
// the Rust types involved. NOTE(review): the plugin ABI should require
// `trivium_recall`/`trivium_rerank` to be thread-safe; confirm this is
// documented for plugin authors.
unsafe impl Send for FfiHook {}
unsafe impl Sync for FfiHook {}
impl FfiHook {
    /// Loads a hook library from `path` and resolves its optional exports.
    ///
    /// Looks up `trivium_recall` ([`FfiRecallFn`]) and `trivium_rerank`
    /// ([`FfiRerankFn`]). A missing symbol simply leaves that stage disabled.
    ///
    /// # Errors
    /// Returns `TriviumError::HookLoadError` if the library cannot be opened.
    ///
    /// Loading runs the library's initialisers, and the resolved symbols are
    /// trusted to have the declared signatures, so only load trusted plugins.
    pub fn load(path: &str) -> crate::error::Result<Self> {
        // SAFETY: opening the library runs arbitrary initialisation code from
        // `path`; the caller is trusted to supply a well-behaved plugin.
        let opened = unsafe { libloading::Library::new(path) };
        let lib = opened.map_err(|e| {
            crate::error::TriviumError::HookLoadError(format!(
                "无法加载外置 Hook 动态库 '{}': {}",
                path, e
            ))
        })?;

        // SAFETY: the plugin ABI declares these exports with exactly the
        // `FfiRecallFn` / `FfiRerankFn` signatures. The copied function pointers
        // stay valid because `lib` is stored alongside them in `Self`.
        let recall_fn: Option<FfiRecallFn> =
            unsafe { lib.get::<FfiRecallFn>(b"trivium_recall") }
                .ok()
                .map(|sym| *sym);
        let rerank_fn: Option<FfiRerankFn> =
            unsafe { lib.get::<FfiRerankFn>(b"trivium_rerank") }
                .ok()
                .map(|sym| *sym);

        tracing::info!(
            "已加载外置 Hook 模块: {} (recall={}, rerank={})",
            path,
            recall_fn.is_some(),
            rerank_fn.is_some()
        );

        Ok(Self {
            _lib: lib,
            recall_fn,
            rerank_fn,
        })
    }
}
/// Forwards the recall and rerank stages to the loaded plugin's exports.
impl SearchHook for FfiHook {
    /// Delegates recall to the plugin's `trivium_recall` export.
    ///
    /// Returns `None` if the export is absent or the plugin reports a nonzero
    /// status. Otherwise returns the hits the plugin wrote, with `id == 0` entries
    /// discarded and a `Null` payload on each hit.
    fn on_custom_recall(
        &self,
        query_vector: &[f32],
        config: &SearchConfig,
        _ctx: &mut HookContext,
    ) -> Option<Vec<SearchHit>> {
        let recall_fn = self.recall_fn?;
        // The plugin is told `top_k` but not the buffer capacity, so leave 2x
        // headroom in case it writes a few extra entries.
        let buf_size = config.top_k * 2;
        let mut buf = vec![FfiSearchHit { id: 0, score: 0.0 }; buf_size];
        let mut count: usize = 0;
        // SAFETY: `query_vector` and `buf` are live allocations for the duration of
        // the call. The plugin must not write more than `buf_size` entries into
        // `out_hits` and must not retain any of the pointers.
        let ret = unsafe {
            (recall_fn)(
                query_vector.as_ptr(),
                query_vector.len(),
                config.top_k,
                buf.as_mut_ptr(),
                &mut count,
            )
        };
        if ret != 0 {
            tracing::warn!("FFI recall 函数返回错误码: {}", ret);
            return None;
        }
        // Never read beyond what we allocated, whatever count the plugin reports.
        let written = count.min(buf_size);
        let hits: Vec<SearchHit> = buf[..written]
            .iter()
            // The buffer is zero-initialised, so `id == 0` is treated as an unfilled slot.
            .filter(|h| h.id != 0)
            .map(|h| SearchHit {
                id: h.id,
                score: h.score,
                payload: serde_json::Value::Null,
            })
            .collect();
        Some(hits)
    }

    /// Delegates reranking to the plugin's `trivium_rerank` export.
    ///
    /// The plugin receives a mutable `(id, score)` array and may rescore and/or
    /// reorder it in place. Payloads are re-attached **by id**, not by position,
    /// so an in-place sort by the plugin keeps each score with its own document.
    /// Entries whose id was not in the input are dropped, with a warning. The
    /// result is sorted by descending score. Returns `None` if the export is
    /// absent or the plugin reports a nonzero status.
    fn on_rerank(
        &self,
        hits: &mut Vec<SearchHit>,
        _ctx: &mut HookContext,
    ) -> Option<Vec<SearchHit>> {
        let rerank_fn = self.rerank_fn?;
        let mut ffi_hits: Vec<FfiSearchHit> = hits
            .iter()
            .map(|h| FfiSearchHit {
                id: h.id,
                score: h.score,
            })
            .collect();
        // SAFETY: `ffi_hits` is a live buffer of exactly `len()` initialised
        // elements. The plugin must stay within it and must not retain the pointer.
        let ret = unsafe { (rerank_fn)(ffi_hits.as_mut_ptr(), ffi_hits.len()) };
        if ret != 0 {
            tracing::warn!("FFI rerank 函数返回错误码: {}", ret);
            return None;
        }

        // Positional pairing would mis-assign scores if the plugin reordered the
        // array, so look each returned id back up in the original input.
        let payload_by_id: std::collections::HashMap<u64, &serde_json::Value> =
            hits.iter().map(|h| (h.id, &h.payload)).collect();
        let mut dropped = 0usize;
        let mut reranked: Vec<SearchHit> = ffi_hits
            .iter()
            .filter_map(|ffi| match payload_by_id.get(&ffi.id) {
                Some(&payload) => Some(SearchHit {
                    id: ffi.id,
                    score: ffi.score,
                    payload: payload.clone(),
                }),
                None => {
                    dropped += 1;
                    None
                }
            })
            .collect();
        if dropped > 0 {
            tracing::warn!("FFI rerank 函数返回了 {} 个未知 id，已忽略", dropped);
        }

        // Descending by score; incomparable (NaN) pairs are treated as equal.
        reranked.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Some(reranked)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit with a `Null` payload.
    fn hit(id: u64, score: f32) -> SearchHit {
        SearchHit {
            id,
            score,
            payload: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_noop_hook_is_default() {
        let hook = NoopHook;
        let mut ctx = HookContext::new();
        let mut query = vec![1.0, 2.0, 3.0];
        let mut config = SearchConfig::default();
        hook.on_pre_search(&mut query, &mut config, &mut ctx);
        assert_eq!(query, vec![1.0, 2.0, 3.0]);
        assert!(!ctx.abort);
        assert!(hook.on_custom_recall(&query, &config, &mut ctx).is_none());
        let mut hits = vec![];
        hook.on_post_recall(&mut hits, &mut ctx);
        hook.on_pre_graph_expand(&mut hits, &mut ctx);
        assert!(hook.on_rerank(&mut hits, &mut ctx).is_none());
        hook.on_post_search(&mut hits, &mut ctx);
    }

    #[test]
    fn test_hook_context() {
        let mut ctx = HookContext::new();
        assert!(ctx.custom_data.is_null());
        assert!(ctx.stage_timings.is_empty());
        assert!(!ctx.abort);
        ctx.custom_data = serde_json::json!({"user_id": "u_123"});
        ctx.record_timing("recall", std::time::Duration::from_millis(5));
        assert_eq!(ctx.custom_data["user_id"], "u_123");
        assert_eq!(ctx.stage_timings.len(), 1);
        assert_eq!(ctx.stage_timings[0].0, "recall");
    }

    struct TimeDecayHook {
        decay_rate: f32,
    }

    impl SearchHook for TimeDecayHook {
        fn on_post_recall(&self, hits: &mut Vec<SearchHit>, _ctx: &mut HookContext) {
            for hit in hits.iter_mut() {
                hit.score *= self.decay_rate;
            }
        }
    }

    #[test]
    fn test_custom_hook() {
        let hook = TimeDecayHook { decay_rate: 0.8 };
        let mut ctx = HookContext::new();
        let mut hits = vec![hit(1, 1.0), hit(2, 0.5)];
        hook.on_post_recall(&mut hits, &mut ctx);
        assert!((hits[0].score - 0.8).abs() < 1e-6);
        assert!((hits[1].score - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_composite_hook() {
        struct BoostHook;
        impl SearchHook for BoostHook {
            fn on_post_recall(&self, hits: &mut Vec<SearchHit>, _ctx: &mut HookContext) {
                for hit in hits.iter_mut() {
                    hit.score *= 2.0;
                }
            }
        }
        struct FilterHook;
        impl SearchHook for FilterHook {
            fn on_post_recall(&self, hits: &mut Vec<SearchHit>, _ctx: &mut HookContext) {
                hits.retain(|h| h.score > 0.5);
            }
        }
        let mut composite = CompositeHook::new();
        composite.add(BoostHook);
        composite.add(FilterHook);
        let mut ctx = HookContext::new();
        let mut hits = vec![hit(1, 0.3), hit(2, 0.2)];
        composite.on_post_recall(&mut hits, &mut ctx);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);
    }

    /// A hook that sets `abort` in pre-search must stop the hooks after it from running.
    #[test]
    fn test_composite_pre_search_stops_after_abort() {
        struct AppendHook;
        impl SearchHook for AppendHook {
            fn on_pre_search(
                &self,
                query_vector: &mut Vec<f32>,
                _config: &mut SearchConfig,
                _ctx: &mut HookContext,
            ) {
                query_vector.push(9.0);
            }
        }
        struct AbortHook;
        impl SearchHook for AbortHook {
            fn on_pre_search(
                &self,
                _query_vector: &mut Vec<f32>,
                _config: &mut SearchConfig,
                ctx: &mut HookContext,
            ) {
                ctx.abort = true;
            }
        }
        let mut composite = CompositeHook::new();
        composite.add(AppendHook);
        composite.add(AbortHook);
        composite.add(AppendHook);
        let mut ctx = HookContext::new();
        let mut query = vec![1.0];
        let mut config = SearchConfig::default();
        composite.on_pre_search(&mut query, &mut config, &mut ctx);
        assert!(ctx.abort);
        // Only the hook registered before the aborting one ran.
        assert_eq!(query, vec![1.0, 9.0]);
    }

    /// `on_custom_recall` and `on_rerank` return the first `Some` in insertion order.
    #[test]
    fn test_composite_first_some_wins() {
        struct FixedHook(u64);
        impl SearchHook for FixedHook {
            fn on_custom_recall(
                &self,
                _query_vector: &[f32],
                _config: &SearchConfig,
                _ctx: &mut HookContext,
            ) -> Option<Vec<SearchHit>> {
                Some(vec![hit(self.0, 1.0)])
            }
            fn on_rerank(
                &self,
                _hits: &mut Vec<SearchHit>,
                _ctx: &mut HookContext,
            ) -> Option<Vec<SearchHit>> {
                Some(vec![hit(self.0, 1.0)])
            }
        }
        let mut composite = CompositeHook::new();
        composite.add(NoopHook);
        composite.add(FixedHook(7));
        composite.add(FixedHook(8));
        let mut ctx = HookContext::new();
        let config = SearchConfig::default();

        let recalled = composite
            .on_custom_recall(&[0.5], &config, &mut ctx)
            .expect("FixedHook(7) returns Some");
        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].id, 7);

        let mut hits = vec![hit(1, 0.1)];
        let reranked = composite
            .on_rerank(&mut hits, &mut ctx)
            .expect("FixedHook(7) returns Some");
        assert_eq!(reranked.len(), 1);
        assert_eq!(reranked[0].id, 7);
    }
}