1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
//! Feedback-driven weight learning for retrieval fusion.
//!
//! When a user marks a recalled memory as relevant or irrelevant, the feedback
//! is used to nudge the 6 fusion weights via online gradient descent.
//! Weights are stored per-namespace so different use cases converge to
//! different optimal retrieval strategies.
use serde::{Deserialize, Serialize};
/// Learning rate for online gradient descent. Small enough that a single
/// feedback sample only nudges the weights, so no single user action can
/// dominate the learned strategy.
const DEFAULT_LEARNING_RATE: f32 = 0.01;

/// Minimum weight value (prevents any signal from being zeroed out).
/// Guarantees every one of the 6 fusion signals retains at least a small
/// influence on the final score, no matter how much negative feedback it gets.
const MIN_WEIGHT: f32 = 0.01;
/// Feedback signal for a recalled memory.
///
/// Produced when a user marks a recalled memory as relevant or irrelevant;
/// consumed by [`WeightLearner::update`] to adjust the fusion weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalFeedback {
    /// The 6 raw signal scores for this candidate: vector, BM25, activation, spread, intent, confidence.
    // NOTE(review): assumed to be in [0, 1] like the other scores in this
    // module — TODO confirm against the producer of these values.
    pub signals: [f32; 6],
    /// Whether the user found this memory relevant (true) or not (false).
    pub relevant: bool,
}
/// Online weight learner using gradient descent.
///
/// Serializable so learned weights can be persisted per-namespace (see the
/// module docs). `update` keeps `weights` normalized to sum to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLearner {
    /// Current learned weights, one per fusion signal
    /// (vector, BM25, activation, spread, intent, confidence).
    pub weights: [f32; 6],
    /// Learning rate used for each gradient step.
    pub learning_rate: f32,
    /// Number of feedback samples received (incremented once per `update`).
    pub sample_count: u64,
}
impl Default for WeightLearner {
    fn default() -> Self {
        // Base RRF signal weights [1.0, 0.8, 1.0, 0.8, 0.5, 0.5], each divided
        // by their total (4.6) so the initial weights sum to 1.0.
        const TOTAL: f32 = 4.6;
        Self {
            weights: [1.0f32, 0.8, 1.0, 0.8, 0.5, 0.5].map(|base| base / TOTAL),
            learning_rate: DEFAULT_LEARNING_RATE,
            sample_count: 0,
        }
    }
}
impl WeightLearner {
    /// Create a learner initialized with specific weights.
    ///
    /// Learning rate and sample count take their default values. The given
    /// weights are used as-is; they are renormalized on the first `update`.
    pub fn with_weights(weights: [f32; 6]) -> Self {
        Self {
            weights,
            ..Default::default()
        }
    }

    /// Apply a single feedback sample using online gradient descent.
    ///
    /// For a relevant memory, we want the weighted score to be high, so we
    /// increase weights proportional to the signal values. For irrelevant
    /// memories, we decrease weights.
    ///
    /// After the gradient step the weights are projected back onto the
    /// constrained simplex: every weight is at least `MIN_WEIGHT` and all
    /// weights sum to 1.0. (Clamping before a plain renormalization is not
    /// enough: when the post-step sum exceeds 1.0 — any positive feedback —
    /// dividing by it pushes untouched weights back below `MIN_WEIGHT`,
    /// violating the documented floor.)
    pub fn update(&mut self, feedback: &RetrievalFeedback) {
        let direction: f32 = if feedback.relevant { 1.0 } else { -1.0 };
        for (w, &signal) in self.weights.iter_mut().zip(&feedback.signals) {
            // Gradient step, clamped so no weight is driven to (or below) zero.
            *w = (*w + self.learning_rate * direction * signal).max(MIN_WEIGHT);
        }
        self.renormalize();
        self.sample_count += 1;
    }

    /// Rescale weights so they sum to 1.0 while each stays >= MIN_WEIGHT.
    ///
    /// Reserves MIN_WEIGHT for every weight and distributes the remaining
    /// mass proportionally to each weight's excess above the floor.
    fn renormalize(&mut self) {
        let n = self.weights.len() as f32;
        // Mass left to distribute once every weight holds its floor.
        let free_mass = 1.0 - n * MIN_WEIGHT;
        let excess_sum: f32 = self.weights.iter().map(|w| w - MIN_WEIGHT).sum();
        if excess_sum > 0.0 {
            for w in &mut self.weights {
                *w = MIN_WEIGHT + (*w - MIN_WEIGHT) / excess_sum * free_mass;
            }
        } else {
            // Every weight sits exactly at the floor: spread the mass evenly.
            for w in &mut self.weights {
                *w = 1.0 / n;
            }
        }
    }

    /// Apply a batch of feedback samples, in order.
    pub fn update_batch(&mut self, feedback: &[RetrievalFeedback]) {
        for f in feedback {
            self.update(f);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a feedback sample.
    fn feedback_of(signals: [f32; 6], relevant: bool) -> RetrievalFeedback {
        RetrievalFeedback { signals, relevant }
    }

    #[test]
    fn test_default_weights_sum_to_one() {
        let total: f32 = WeightLearner::default().weights.iter().sum();
        assert!((total - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_positive_feedback_increases_active_weights() {
        let mut learner = WeightLearner::default();
        let before = learner.weights[0];
        // Only the vector signal fired, and the memory was relevant.
        learner.update(&feedback_of([1.0, 0.0, 0.0, 0.0, 0.0, 0.0], true));
        // After normalization, w0 should be higher relative to others.
        assert!(
            learner.weights[0] > before,
            "vector weight should increase"
        );
    }

    #[test]
    fn test_negative_feedback_decreases_active_weights() {
        let mut learner = WeightLearner::default();
        let before = learner.weights[0];
        learner.update(&feedback_of([1.0, 0.0, 0.0, 0.0, 0.0, 0.0], false));
        assert!(
            learner.weights[0] < before,
            "vector weight should decrease"
        );
    }

    #[test]
    fn test_weights_always_positive() {
        let mut learner = WeightLearner::default();
        // A long stream of purely negative feedback must not push any
        // weight below the floor.
        let batch = vec![feedback_of([1.0; 6], false); 1000];
        learner.update_batch(&batch);
        for w in &learner.weights {
            assert!(*w >= MIN_WEIGHT, "weight {} should be >= MIN_WEIGHT", w);
        }
    }

    #[test]
    fn test_weights_normalized() {
        let mut learner = WeightLearner::default();
        learner.update(&feedback_of([0.9, 0.1, 0.5, 0.3, 0.7, 0.2], true));
        let sum: f32 = learner.weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 0.001,
            "weights should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn test_sample_count_increments() {
        let mut learner = WeightLearner::default();
        let fb = feedback_of([0.5; 6], true);
        for expected in 0..=2u64 {
            assert_eq!(learner.sample_count, expected);
            learner.update(&fb);
        }
    }
}