#![cfg(feature = "neural-rerank")]
use std::path::PathBuf;
use crate::error::{EngramError, Result};
use crate::types::Memory;
/// A memory paired with the scores used to rank it.
#[derive(Debug, Clone)]
pub struct RerankCandidate {
    /// The memory being ranked.
    pub memory: Memory,
    /// Score the candidate carried before any reranking took place.
    pub original_score: f32,
    /// Score assigned by a reranker; `None` until one has been set.
    pub rerank_score: Option<f32>,
}

impl RerankCandidate {
    /// Wraps `memory` together with its `original_score`.
    ///
    /// The rerank score starts out unset.
    pub fn new(memory: Memory, original_score: f32) -> Self {
        let rerank_score = None;
        Self {
            memory,
            original_score,
            rerank_score,
        }
    }

    /// Returns the rerank score when one is present, otherwise the original
    /// score.
    pub fn effective_score(&self) -> f32 {
        match self.rerank_score {
            Some(reranked) => reranked,
            None => self.original_score,
        }
    }
}
/// Re-scores and re-orders candidates for a query.
///
/// Implementors must be `Send + Sync`.
pub trait Reranker: Send + Sync {
/// Consumes `candidates` and returns the reranked list, or an error.
///
/// The implementations in this file set `rerank_score` on every returned
/// candidate, drop candidates scoring below their threshold and sort the
/// rest by descending score; the trait itself does not enforce any of that.
fn rerank(&self, query: &str, candidates: Vec<RerankCandidate>)
-> Result<Vec<RerankCandidate>>;
}
/// Settings for the cross-encoder reranker.
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Path of the ONNX model file to load.
    pub model_path: PathBuf,
    /// Maximum number of characters kept in each model input string.
    pub max_length: usize,
    /// Number of inputs scored per batch.
    pub batch_size: usize,
    /// Candidates whose normalised score is below this value are dropped.
    pub threshold: f32,
}

impl Default for CrossEncoderConfig {
    /// Defaults: `model.onnx`, 512 characters, batches of 32, threshold 0.0.
    fn default() -> Self {
        let model_path: PathBuf = "model.onnx".into();
        Self {
            model_path,
            max_length: 512,
            batch_size: 32,
            threshold: 0.0,
        }
    }
}
/// Reranker that holds an ONNX session for a cross-encoder model.
pub struct CrossEncoderReranker {
    /// Settings supplied at construction time.
    config: CrossEncoderConfig,
    /// Loaded ONNX session; `new` always stores `Some`.
    session: Option<ort::session::Session>,
}

impl CrossEncoderReranker {
    /// Builds a reranker by loading the model named in `config`.
    ///
    /// # Errors
    ///
    /// Returns `EngramError::Config` when the session builder cannot be
    /// created or the model file cannot be loaded.
    pub fn new(config: CrossEncoderConfig) -> Result<Self> {
        Self::load_session(&config).map(|session| Self {
            config,
            session: Some(session),
        })
    }

    /// Creates an ONNX session from `config.model_path`, mapping both failure
    /// points to `EngramError::Config`.
    fn load_session(config: &CrossEncoderConfig) -> Result<ort::session::Session> {
        let model_path = &config.model_path;
        ort::session::Session::builder()
            .map_err(|e| EngramError::Config(format!("Failed to create ONNX session builder: {e}")))?
            .commit_from_file(model_path)
            .map_err(|e| {
                EngramError::Config(format!(
                    "Failed to load ONNX model from {:?}: {e}",
                    model_path
                ))
            })
    }

    /// Formats a query/content pair as `[CLS] query [SEP] content [SEP]` and
    /// keeps at most `max_length` characters of the result.
    ///
    /// The limit is applied to the whole formatted string, so a long pair
    /// loses the tail of its content and the trailing `[SEP]`.
    fn build_input(query: &str, content: &str, max_length: usize) -> String {
        let mut text = format!("[CLS] {query} [SEP] {content} [SEP]");
        // Byte offset of the first character past the limit, if there is one.
        let cut = text
            .char_indices()
            .nth(max_length)
            .map(|(byte_offset, _)| byte_offset);
        if let Some(byte_offset) = cut {
            text.truncate(byte_offset);
        }
        text
    }

    /// Returns one score per input.
    ///
    /// NOTE: this is currently a placeholder. It only checks that a session
    /// is present and then yields `0.0` for every input; no inference runs.
    ///
    /// # Errors
    ///
    /// Returns `EngramError::Config` when no session is present.
    fn score_batch(&self, inputs: &[String]) -> Result<Vec<f32>> {
        if self.session.is_none() {
            return Err(EngramError::Config(
                "ONNX session not initialised".to_string(),
            ));
        }
        Ok(std::iter::repeat(0.0_f32).take(inputs.len()).collect())
    }

    /// Min-max normalises `scores` in place.
    ///
    /// An empty slice is left untouched. When the spread between the largest
    /// and smallest score is below `f32::EPSILON`, every score becomes `1.0`.
    fn normalize(scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }
        let (lowest, highest) = scores.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(low, high), &score| (low.min(score), high.max(score)),
        );
        let span = highest - lowest;
        if span.abs() < f32::EPSILON {
            scores.fill(1.0);
        } else {
            for score in scores.iter_mut() {
                *score = (*score - lowest) / span;
            }
        }
    }
}
impl Reranker for CrossEncoderReranker {
    /// Scores every candidate against `query`, then filters and orders them.
    ///
    /// Steps:
    /// 1. build one input string per candidate,
    /// 2. score the inputs in batches of `config.batch_size`,
    /// 3. min-max normalise all scores together,
    /// 4. drop candidates whose normalised score is below `config.threshold`,
    /// 5. sort the survivors by descending score (stable, so ties keep their
    ///    incoming order).
    ///
    /// A `batch_size` of zero is treated as one.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while scoring a batch.
    fn rerank(
        &self,
        query: &str,
        candidates: Vec<RerankCandidate>,
    ) -> Result<Vec<RerankCandidate>> {
        if candidates.is_empty() {
            return Ok(Vec::new());
        }

        let inputs: Vec<String> = candidates
            .iter()
            .map(|c| Self::build_input(query, &c.memory.content, self.config.max_length))
            .collect();

        // `slice::chunks` panics on a chunk size of zero, and `batch_size` is a
        // public config field, so clamp it instead of trusting the caller.
        let batch_size = self.config.batch_size.max(1);

        let mut raw_scores: Vec<f32> = Vec::with_capacity(inputs.len());
        for chunk in inputs.chunks(batch_size) {
            // `chunk` is already a `&[String]`; no need to clone it into a Vec.
            raw_scores.extend(self.score_batch(chunk)?);
        }

        Self::normalize(&mut raw_scores);

        let threshold = self.config.threshold;
        let mut scored: Vec<RerankCandidate> = candidates
            .into_iter()
            .zip(raw_scores)
            .filter_map(|(mut candidate, score)| {
                if score < threshold {
                    None
                } else {
                    candidate.rerank_score = Some(score);
                    Some(candidate)
                }
            })
            .collect();

        scored.sort_by(|a, b| {
            b.rerank_score
                .unwrap_or(0.0)
                .partial_cmp(&a.rerank_score.unwrap_or(0.0))
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(scored)
    }
}
/// Owns a boxed [`Reranker`] and forwards rerank requests to it.
pub struct RerankerPipeline {
    /// The reranker every call to `run` is delegated to.
    inner: Box<dyn Reranker>,
}

impl RerankerPipeline {
    /// Boxes `reranker` and stores it as the pipeline's only stage.
    pub fn new(reranker: impl Reranker + 'static) -> Self {
        let inner: Box<dyn Reranker> = Box::new(reranker);
        Self { inner }
    }

    /// Passes `query` and `candidates` to the wrapped reranker and returns its
    /// result unchanged.
    pub fn run(
        &self,
        query: &str,
        candidates: Vec<RerankCandidate>,
    ) -> Result<Vec<RerankCandidate>> {
        let reranker: &dyn Reranker = &*self.inner;
        reranker.rerank(query, candidates)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{LifecycleState, MemoryScope, MemoryTier, MemoryType, Visibility};
    use chrono::Utc;
    use std::collections::HashMap;

    /// Test double that scores candidates by how many query words appear in
    /// their content, so no ONNX model is needed.
    struct MockReranker {
        threshold: f32,
    }

    impl MockReranker {
        fn new(threshold: f32) -> Self {
            Self { threshold }
        }

        /// Counts the whitespace-separated query words found in `text`,
        /// ignoring case.
        fn keyword_overlap(query: &str, text: &str) -> usize {
            let haystack = text.to_lowercase();
            let mut hits = 0;
            for word in query.split_whitespace() {
                if haystack.contains(&word.to_lowercase()) {
                    hits += 1;
                }
            }
            hits
        }
    }

    impl Reranker for MockReranker {
        fn rerank(
            &self,
            query: &str,
            candidates: Vec<RerankCandidate>,
        ) -> Result<Vec<RerankCandidate>> {
            if candidates.is_empty() {
                return Ok(Vec::new());
            }

            let mut overlaps: Vec<f32> = Vec::with_capacity(candidates.len());
            for candidate in &candidates {
                overlaps.push(Self::keyword_overlap(query, &candidate.memory.content) as f32);
            }
            CrossEncoderReranker::normalize(&mut overlaps);

            let mut kept: Vec<RerankCandidate> = Vec::new();
            for (mut candidate, score) in candidates.into_iter().zip(overlaps) {
                if score < self.threshold {
                    continue;
                }
                candidate.rerank_score = Some(score);
                kept.push(candidate);
            }

            kept.sort_by(|left, right| {
                let left_score = left.rerank_score.unwrap_or(0.0);
                let right_score = right.rerank_score.unwrap_or(0.0);
                right_score
                    .partial_cmp(&left_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            Ok(kept)
        }
    }

    /// True when `actual` is within `f32::EPSILON` of `expected`.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    /// Builds a minimal note-type memory with the given id and content.
    fn make_memory(id: i64, content: &str) -> Memory {
        Memory {
            id,
            content: content.to_owned(),
            memory_type: MemoryType::Note,
            tags: Vec::new(),
            metadata: HashMap::new(),
            importance: 0.5,
            access_count: 0,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Visibility::Private,
            scope: MemoryScope::Global,
            workspace: String::from("default"),
            tier: MemoryTier::Permanent,
            version: 1,
            has_embedding: false,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: LifecycleState::Active,
        }
    }

    /// Builds a candidate around a fresh memory.
    fn make_candidate(id: i64, content: &str, original_score: f32) -> RerankCandidate {
        let memory = make_memory(id, content);
        RerankCandidate::new(memory, original_score)
    }

    #[test]
    fn test_rerank_candidate_fields() {
        let candidate = RerankCandidate::new(make_memory(42, "some content"), 0.75);

        assert_eq!(candidate.memory.id, 42);
        assert_eq!(candidate.memory.content, "some content");
        assert!(approx_eq(candidate.original_score, 0.75));
        assert!(candidate.rerank_score.is_none());
    }

    #[test]
    fn test_effective_score_falls_back_to_original() {
        let candidate = make_candidate(1, "hello", 0.6);
        assert!(approx_eq(candidate.effective_score(), 0.6));
    }

    #[test]
    fn test_effective_score_prefers_rerank_score() {
        let mut candidate = make_candidate(1, "hello", 0.6);
        candidate.rerank_score = Some(0.9);
        assert!(approx_eq(candidate.effective_score(), 0.9));
    }

    #[test]
    fn test_basic_reranking_changes_order() {
        let candidates = vec![
            make_candidate(1, "python web framework", 0.9),
            make_candidate(2, "go concurrency patterns", 0.8),
            make_candidate(3, "rust memory management tips", 0.7),
        ];

        let ranked = MockReranker::new(0.0)
            .rerank("rust memory", candidates)
            .unwrap();

        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].memory.id, 3);
        assert!(ranked.iter().all(|c| c.rerank_score.is_some()));
    }

    #[test]
    fn test_empty_candidates_returns_empty() {
        let ranked = MockReranker::new(0.0).rerank("query", Vec::new()).unwrap();
        assert!(ranked.is_empty());
    }

    #[test]
    fn test_threshold_filtering_removes_low_scores() {
        let candidates = vec![
            make_candidate(1, "rust memory management", 0.9),
            make_candidate(2, "completely unrelated topic xyz", 0.8),
            make_candidate(3, "memory allocator in rust", 0.7),
        ];

        let ranked = MockReranker::new(0.01)
            .rerank("rust memory", candidates)
            .unwrap();

        assert_eq!(ranked.len(), 2);
        for candidate in &ranked {
            assert!(candidate.rerank_score.unwrap() >= 0.01);
        }
    }

    #[test]
    fn test_score_normalization_to_0_1_range() {
        let candidates = vec![
            make_candidate(1, "rust memory management best practices", 0.5),
            make_candidate(2, "rust systems programming", 0.5),
            make_candidate(3, "python scripting guide", 0.5),
        ];

        let ranked = MockReranker::new(0.0)
            .rerank("rust memory", candidates)
            .unwrap();

        for candidate in &ranked {
            let score = candidate.rerank_score.unwrap();
            assert!(
                (0.0..=1.0).contains(&score),
                "Score {score} is outside [0, 1]"
            );
        }

        let best = ranked.first().unwrap().rerank_score.unwrap();
        assert!(
            approx_eq(best, 1.0),
            "Best score should be 1.0 after normalisation"
        );
    }

    #[test]
    fn test_batch_processing_handles_many_candidates() {
        let mut candidates: Vec<RerankCandidate> = Vec::with_capacity(50);
        for i in 0..50 {
            let content = if i % 3 == 0 {
                format!("document about rust memory topic {i}")
            } else {
                format!("unrelated document number {i}")
            };
            candidates.push(make_candidate(i, &content, 1.0 - i as f32 * 0.01));
        }

        let ranked = MockReranker::new(0.0)
            .rerank("rust memory", candidates)
            .unwrap();
        assert_eq!(ranked.len(), 50);

        let mut previous = f32::INFINITY;
        for candidate in &ranked {
            let score = candidate.rerank_score.unwrap();
            assert!(score <= previous, "Results not sorted descending");
            previous = score;
        }
    }

    #[test]
    fn test_single_candidate_normalizes_to_1() {
        let candidates = vec![make_candidate(1, "some content", 0.5)];

        let ranked = MockReranker::new(0.0).rerank("query", candidates).unwrap();

        assert_eq!(ranked.len(), 1);
        assert!(approx_eq(ranked[0].rerank_score.unwrap(), 1.0));
    }

    #[test]
    fn test_normalize_helper_all_equal_scores() {
        let mut scores = vec![0.5_f32, 0.5, 0.5];
        CrossEncoderReranker::normalize(&mut scores);
        assert!(scores.iter().all(|&score| approx_eq(score, 1.0)));
    }

    #[test]
    fn test_normalize_helper_empty_slice() {
        // Passing an empty slice must simply not panic.
        let mut scores: Vec<f32> = Vec::new();
        CrossEncoderReranker::normalize(&mut scores);
    }

    #[test]
    fn test_normalize_helper_distinct_scores() {
        let mut scores = vec![0.0_f32, 5.0, 10.0];
        CrossEncoderReranker::normalize(&mut scores);

        assert!(approx_eq(scores[0], 0.0));
        assert!(approx_eq(scores[1], 0.5));
        assert!(approx_eq(scores[2], 1.0));
    }

    #[test]
    fn test_build_input_format() {
        let input = CrossEncoderReranker::build_input("my query", "document text", 512);

        assert!(input.starts_with("[CLS]"));
        for fragment in ["my query", "[SEP]", "document text"] {
            assert!(input.contains(fragment));
        }
    }

    #[test]
    fn test_build_input_truncation() {
        let input = CrossEncoderReranker::build_input("query", "very long document content", 10);
        assert!(input.chars().count() <= 10);
    }

    #[test]
    fn test_pipeline_delegates_to_reranker() {
        let candidates = vec![
            make_candidate(1, "python web development", 0.9),
            make_candidate(2, "rust systems programming", 0.8),
        ];

        let pipeline = RerankerPipeline::new(MockReranker::new(0.0));
        let ranked = pipeline.run("rust", candidates).unwrap();

        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].memory.id, 2);
    }

    #[test]
    fn test_cross_encoder_config_defaults() {
        let config = CrossEncoderConfig::default();

        assert_eq!(config.max_length, 512);
        assert_eq!(config.batch_size, 32);
        assert!(approx_eq(config.threshold, 0.0));
    }
}