// mnemo_core/score/mod.rs

//! v0.4.0 (P1-4) — pluggable score lanes for hybrid recall fusion.
//!
//! The default Mnemo recall fuses four signals: vector similarity,
//! BM25 lexical, recency, and (new in v0.4.0) Ebbinghaus-style decay
//! with reinforcement. Each signal is a `ScoreLane` that maps a
//! candidate memory to an `f32` in `[0.0, 1.0]`. Fusion takes a weighted
//! sum of the lanes and clamps it to `[0.0, 1.0]`; the default weights
//! (operators may override them) are:
//!
//! ```text
//! score = 0.55 * vector
//!       + 0.20 * bm25
//!       + 0.15 * recency
//!       + 0.10 * decay
//! ```
//!
//! Letta-protocol mode (`LettaProtocolMode`, exposed to lanes as
//! `ScoreContext::letta_mode`) skips the decay lane so parity with
//! Letta's published numbers is preserved.
18
19pub mod decay;
20
21pub use decay::{DecayLane, DecayParams, decay_weight};
22
23use std::time::SystemTime;
24
25use crate::model::memory::MemoryRecord;
26
27/// One scoring signal. Implementations are stateless or hold cheap
28/// configuration; the recall path holds them as `Arc<dyn ScoreLane>`.
29pub trait ScoreLane: Send + Sync {
30    /// Bounded score in `[0.0, 1.0]`. Higher is better.
31    fn score(&self, mem: &MemoryRecord, ctx: &ScoreContext) -> f32;
32
33    /// Stable name — used in audit-log explanations + debug output.
34    fn name(&self) -> &'static str;
35}
36
/// Per-query state the recall path threads through to every lane.
///
/// Bundling this once per query lets each lane read the query time,
/// query text, and recall mode without re-querying storage.
#[derive(Debug, Clone)]
pub struct ScoreContext {
    /// Current time for this query.
    pub now: SystemTime,
    /// Text of the recall query.
    pub query_text: String,
    /// `true` when the recall request set `mode = Letta`. Lanes that
    /// must be bypassed for parity check this flag.
    pub letta_mode: bool,
}

impl ScoreContext {
    /// Creates a context for a query evaluated at `now`, with Letta
    /// mode turned off.
    #[must_use]
    pub fn new(now: SystemTime, query_text: impl Into<String>) -> Self {
        Self {
            now,
            query_text: query_text.into(),
            letta_mode: false,
        }
    }

    /// Returns this context with `letta_mode` set to `on`.
    ///
    /// Builder-style: `self` is consumed and the updated context is the
    /// return value, so discarding the result loses the change.
    #[must_use]
    pub fn with_letta_mode(mut self, on: bool) -> Self {
        self.letta_mode = on;
        self
    }
}
63
// Default v0.4.0 fusion weights, tuned against the bundled
// LongMemEval_M sample so the bench gate doesn't regress. The four
// values are expected to sum to 1.0 (see `default_weights_sum_to_one`).

/// Default weight of the vector-similarity lane.
pub const DEFAULT_VECTOR_WEIGHT: f32 = 0.55;
/// Default weight of the BM25 lexical lane.
pub const DEFAULT_BM25_WEIGHT: f32 = 0.20;
/// Default weight of the recency lane.
pub const DEFAULT_RECENCY_WEIGHT: f32 = 0.15;
/// Default weight of the decay lane.
pub const DEFAULT_DECAY_WEIGHT: f32 = 0.10;
70
71/// Sum the four lane signals using the v0.4.0 default weights.
72/// Operators that override weights via `RecallRequest.hybrid_weights`
73/// build their own `fuse_weighted` call site.
74pub fn fuse_default(vector: f32, bm25: f32, recency: f32, decay: f32) -> f32 {
75    fuse_weighted(
76        vector,
77        bm25,
78        recency,
79        decay,
80        DEFAULT_VECTOR_WEIGHT,
81        DEFAULT_BM25_WEIGHT,
82        DEFAULT_RECENCY_WEIGHT,
83        DEFAULT_DECAY_WEIGHT,
84    )
85}
86
/// Fuses the four lane signals as a weighted sum clamped to `[0.0, 1.0]`.
///
/// The sum is `vector * w_vector + bm25 * w_bm25 + recency * w_recency
/// + decay * w_decay`, evaluated left to right. For finite inputs the
/// result is exactly that sum clamped to the unit interval.
///
/// NaN handling: `f32::clamp` passes NaN through unchanged, which would
/// let the fused score escape `[0.0, 1.0]` (and NaN scores cannot be
/// ordered with `partial_cmp`). To keep the bound:
/// - a term whose product is NaN contributes `0.0`. This covers a lane
///   that returned NaN, a NaN weight, and `±inf * 0.0` from a lane
///   disabled with a zero weight;
/// - a sum that is still NaN (an `+inf` term cancelling a `-inf` term)
///   fuses to `0.0`.
// Four scores plus four weights; kept as flat scalars to preserve the
// existing public signature.
#[allow(clippy::too_many_arguments)]
pub fn fuse_weighted(
    vector: f32,
    bm25: f32,
    recency: f32,
    decay: f32,
    w_vector: f32,
    w_bm25: f32,
    w_recency: f32,
    w_decay: f32,
) -> f32 {
    // One lane times its weight, with NaN neutralised so it can't poison the sum.
    let term = |score: f32, weight: f32| {
        let product = score * weight;
        if product.is_nan() {
            0.0
        } else {
            product
        }
    };

    let sum = term(vector, w_vector)
        + term(bm25, w_bm25)
        + term(recency, w_recency)
        + term(decay, w_decay);

    if sum.is_nan() {
        0.0
    } else {
        sum.clamp(0.0, 1.0)
    }
}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_weights_sum_to_one() {
        let s: f32 = [
            DEFAULT_VECTOR_WEIGHT,
            DEFAULT_BM25_WEIGHT,
            DEFAULT_RECENCY_WEIGHT,
            DEFAULT_DECAY_WEIGHT,
        ]
        .iter()
        .sum();
        assert!((s - 1.0).abs() < 1e-6, "weights sum {s} should be 1.0");
    }

    #[test]
    fn fuse_clamps_to_unit_interval() {
        // A negative lane cannot drag the fused score below 0.
        let floor = fuse_default(-1.0, 0.0, 0.0, 0.0);
        assert_eq!(floor, 0.0);

        // Lanes above 1.0 cannot lift the fused score above 1.
        let ceiling = fuse_default(2.0, 2.0, 2.0, 2.0);
        assert_eq!(ceiling, 1.0);
    }

    #[test]
    fn fuse_default_is_monotonic_in_each_lane() {
        let mid = [0.5_f32; 4];
        let fuse = |lanes: [f32; 4]| fuse_default(lanes[0], lanes[1], lanes[2], lanes[3]);
        let base = fuse(mid);

        // Raising any single lane from 0.5 to 0.6 must raise the fused score.
        for lane in 0..mid.len() {
            let mut bumped = mid;
            bumped[lane] = 0.6;
            assert!(fuse(bumped) > base);
        }
    }
}