1pub mod decay;
20
21pub use decay::{DecayLane, DecayParams, decay_weight};
22
23use std::time::SystemTime;
24
25use crate::model::memory::MemoryRecord;
26
27pub trait ScoreLane: Send + Sync {
30 fn score(&self, mem: &MemoryRecord, ctx: &ScoreContext) -> f32;
32
33 fn name(&self) -> &'static str;
35}
36
/// Per-query inputs shared by every score lane while scoring candidates.
#[derive(Debug, Clone)]
pub struct ScoreContext {
    /// Reference "current" time for this scoring pass; presumably consumed by
    /// time-based lanes such as decay/recency — confirm against lane impls.
    pub now: SystemTime,
    /// Query text for this scoring pass.
    pub query_text: String,
    /// Mode flag, `false` unless set via [`ScoreContext::with_letta_mode`].
    /// NOTE(review): its effect is defined by the lanes that read it, which
    /// are not visible here.
    pub letta_mode: bool,
}

impl ScoreContext {
    /// Creates a context for `query_text` evaluated at `now`, with
    /// `letta_mode` disabled.
    #[must_use]
    pub fn new(now: SystemTime, query_text: impl Into<String>) -> Self {
        Self {
            now,
            query_text: query_text.into(),
            letta_mode: false,
        }
    }

    /// Returns this context with `letta_mode` set to `on` (builder style).
    ///
    /// Consumes `self`; `#[must_use]` because discarding the return value
    /// silently drops the whole context.
    #[must_use]
    pub fn with_letta_mode(mut self, on: bool) -> Self {
        self.letta_mode = on;
        self
    }
}
63
// Default per-lane weights used by `fuse_default`. The four values sum to 1.0
// (checked by `default_weights_sum_to_one` in the test module).

/// Default weight applied to the `vector` lane score.
pub const DEFAULT_VECTOR_WEIGHT: f32 = 0.55_f32;
/// Default weight applied to the `bm25` lane score.
pub const DEFAULT_BM25_WEIGHT: f32 = 0.20_f32;
/// Default weight applied to the `recency` lane score.
pub const DEFAULT_RECENCY_WEIGHT: f32 = 0.15_f32;
/// Default weight applied to the `decay` lane score.
pub const DEFAULT_DECAY_WEIGHT: f32 = 0.10_f32;
70
71pub fn fuse_default(vector: f32, bm25: f32, recency: f32, decay: f32) -> f32 {
75 fuse_weighted(
76 vector,
77 bm25,
78 recency,
79 decay,
80 DEFAULT_VECTOR_WEIGHT,
81 DEFAULT_BM25_WEIGHT,
82 DEFAULT_RECENCY_WEIGHT,
83 DEFAULT_DECAY_WEIGHT,
84 )
85}
86
/// Combines four lane scores into one score: a weighted sum clamped to `[0.0, 1.0]`.
///
/// Computes `vector*w_vector + bm25*w_bm25 + recency*w_recency + decay*w_decay`.
/// The weights are used as given (not normalized), so sums outside the unit
/// interval saturate at `0.0` or `1.0`.
///
/// If the weighted sum is NaN (a NaN lane score or weight, or `inf - inf`),
/// returns `0.0`. `f32::clamp` passes NaN through unchanged, which would let
/// it escape the unit interval and break `partial_cmp`-based ranking.
// Eight positional f32s (four lanes + their weights) are the established
// public signature; bundling them into a struct would break existing callers.
#[allow(clippy::too_many_arguments)]
#[must_use]
pub fn fuse_weighted(
    vector: f32,
    bm25: f32,
    recency: f32,
    decay: f32,
    w_vector: f32,
    w_bm25: f32,
    w_recency: f32,
    w_decay: f32,
) -> f32 {
    // Same left-to-right evaluation order as before, so finite results are bit-identical.
    let raw = vector * w_vector + bm25 * w_bm25 + recency * w_recency + decay * w_decay;
    if raw.is_nan() {
        0.0
    } else {
        raw.clamp(0.0, 1.0)
    }
}
100
#[cfg(test)]
mod tests {
    use super::*;

    // The default lane weights must form a convex combination.
    #[test]
    fn default_weights_sum_to_one() {
        let s: f32 = [
            DEFAULT_VECTOR_WEIGHT,
            DEFAULT_BM25_WEIGHT,
            DEFAULT_RECENCY_WEIGHT,
            DEFAULT_DECAY_WEIGHT,
        ]
        .iter()
        .sum();
        assert!((s - 1.0).abs() < 1e-6, "weights sum {s} should be 1.0");
    }

    // Out-of-range lane inputs saturate at the unit-interval bounds.
    #[test]
    fn fuse_clamps_to_unit_interval() {
        let cases = [([-1.0, 0.0, 0.0, 0.0], 0.0), ([2.0; 4], 1.0)];
        for ([v, b, r, d], expected) in cases {
            assert_eq!(fuse_default(v, b, r, d), expected);
        }
    }

    // Bumping any single lane from 0.5 to 0.6 must strictly raise the score.
    #[test]
    fn fuse_default_is_monotonic_in_each_lane() {
        let base = fuse_default(0.5, 0.5, 0.5, 0.5);
        for lane in 0..4 {
            let mut lanes = [0.5_f32; 4];
            lanes[lane] = 0.6;
            assert!(fuse_default(lanes[0], lanes[1], lanes[2], lanes[3]) > base);
        }
    }
}