use crate::search::SearchResult;
use crate::traits::MemoryMeta;
/// A pluggable strategy that produces a multiplicative factor for a record's score.
///
/// Implementors must be `Send + Sync`. The trait has a single `&self` method with no
/// generics, so it can be used behind `dyn ScoringStrategy`.
pub trait ScoringStrategy: Send + Sync {
    /// Returns the factor this strategy assigns to `record`.
    ///
    /// * `record` - metadata of the memory being scored.
    /// * `query` - text the record is being scored against (presumably the user's
    ///   search query; confirm at call sites).
    /// * `base_score` - score of the record before this strategy's factor is applied.
    ///
    /// How the returned value is consumed is up to the caller; it is presumably fed to
    /// `ScoredResult::apply_scoring` (TODO confirm).
    fn score_multiplier(&self, record: &MemoryMeta, query: &str, base_score: f32) -> f32;
}
/// A search hit paired with the bookkeeping needed to rescale its score.
///
/// The methods on this type keep `score == raw_score * score_multiplier`. All fields
/// are public, so code that writes them directly is responsible for that relation.
#[derive(Debug, Clone)]
pub struct ScoredResult {
    /// Identifier copied from the originating `SearchResult`.
    pub memory_id: i64,
    /// Current score, i.e. the raw score scaled by the accumulated multiplier.
    pub score: f32,
    /// Score as it was when this value was built; never modified by `apply_scoring`.
    pub raw_score: f32,
    /// Product of every multiplier applied so far; starts at `1.0`.
    pub score_multiplier: f32,
}

impl ScoredResult {
    /// Builds a `ScoredResult` from a search hit with no scaling applied yet.
    ///
    /// Both `score` and `raw_score` start at the hit's score, and the accumulated
    /// multiplier starts at `1.0`.
    pub fn from_search_result(result: &SearchResult) -> Self {
        let base = result.score;
        let memory_id = result.memory_id;
        Self {
            memory_id,
            score: base,
            raw_score: base,
            score_multiplier: 1.0,
        }
    }

    /// Folds `multiplier` into the accumulated multiplier and refreshes `score`.
    ///
    /// Successive calls compound: the score is always recomputed from `raw_score`
    /// and the running product, not from the previous `score`.
    pub fn apply_scoring(&mut self, multiplier: f32) {
        let combined = self.score_multiplier * multiplier;
        self.score_multiplier = combined;
        self.score = self.raw_score * combined;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::MemoryType;
    use chrono::Utc;

    /// Test strategy that ignores every input and always returns the wrapped factor.
    struct ConstantScorer(f32);

    impl ScoringStrategy for ConstantScorer {
        fn score_multiplier(&self, _record: &MemoryMeta, _query: &str, _base: f32) -> f32 {
            self.0
        }
    }

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test line.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Minimal metadata record used where a `MemoryMeta` is needed but not inspected.
    fn sample_meta() -> MemoryMeta {
        MemoryMeta {
            id: Some(1),
            searchable_text: "test".into(),
            memory_type: MemoryType::Semantic,
            importance: 5,
            category: None,
            created_at: Utc::now(),
            metadata: Default::default(),
        }
    }

    #[test]
    fn scored_result_from_search() {
        let hit = SearchResult {
            memory_id: 42,
            score: 0.75,
        };

        let scored = ScoredResult::from_search_result(&hit);

        assert_eq!(scored.memory_id, 42);
        assert_close(scored.score, 0.75, f32::EPSILON);
        assert_close(scored.raw_score, 0.75, f32::EPSILON);
        assert_close(scored.score_multiplier, 1.0, f32::EPSILON);
    }

    #[test]
    fn apply_scoring_multiplier() {
        let hit = SearchResult {
            memory_id: 1,
            score: 1.0,
        };
        let mut scored = ScoredResult::from_search_result(&hit);

        // First factor: the score doubles while the raw score stays put.
        scored.apply_scoring(2.0);
        assert_close(scored.score, 2.0, f32::EPSILON);
        assert_close(scored.score_multiplier, 2.0, f32::EPSILON);
        assert_close(scored.raw_score, 1.0, f32::EPSILON);

        // Second factor compounds with the first: 2.0 * 0.5 == 1.0.
        scored.apply_scoring(0.5);
        // This check uses a looser 0.01 tolerance than the others.
        assert_close(scored.score, 1.0, 0.01);
        assert_close(scored.score_multiplier, 1.0, f32::EPSILON);
    }

    #[test]
    fn scoring_strategy_trait_object() {
        let strategy: Box<dyn ScoringStrategy> = Box::new(ConstantScorer(1.5));
        let meta = sample_meta();

        let factor = strategy.score_multiplier(&meta, "test", 1.0);

        assert_close(factor, 1.5, f32::EPSILON);
    }
}