converge_knowledge/learning/
mod.rs1pub mod batch;
45pub mod feedback;
46mod gnn;
47mod replay;
48
49pub use batch::{
50 BatchInput, BatchJob, BatchScheduler, EntryMetadata, Insight, InsightStore, JobRun, JobStatus,
51 JobType, KnowledgeClass, RelationshipType, Trend,
52};
53pub use feedback::{
54 FeedbackCollector, FeedbackConfig, FeedbackProcessor, FeedbackSignal, ProcessedFeedback,
55 QueryId, SessionId, SignalType,
56};
57pub use gnn::GnnLayer;
58pub use replay::ReplayBuffer;
59
60use crate::core::SearchResult;
61use dashmap::DashMap;
62use serde::{Deserialize, Serialize};
63use std::collections::VecDeque;
64use uuid::Uuid;
65
66pub struct LearningEngine {
73 dimensions: usize,
75
76 learning_rate: f32,
78
79 relevance_weights: Vec<f32>,
81
82 gnn_layer: GnnLayer,
84
85 replay_buffer: ReplayBuffer<Experience>,
87
88 #[allow(dead_code)]
90 query_patterns: VecDeque<QueryPattern>,
91
92 entry_scores: DashMap<Uuid, f32>,
94
95 fisher_diagonal: Vec<f32>,
97
98 query_count: u64,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104struct QueryPattern {
105 query_embedding: Vec<f32>,
107
108 result_embeddings: Vec<(Vec<f32>, f32)>,
110
111 timestamp: u64,
113}
114
115impl LearningEngine {
116 pub fn new(dimensions: usize, learning_rate: f32) -> Self {
118 Self {
119 dimensions,
120 learning_rate,
121 relevance_weights: vec![1.0; dimensions],
122 gnn_layer: GnnLayer::new(dimensions, dimensions * 2, 4),
123 replay_buffer: ReplayBuffer::new(10000),
124 query_patterns: VecDeque::with_capacity(1000),
125 entry_scores: DashMap::new(),
126 fisher_diagonal: vec![0.0; dimensions],
127 query_count: 0,
128 }
129 }
130
131 pub fn rerank(
133 &self,
134 query_embedding: &[f32],
135 mut candidates: Vec<(Uuid, f32)>,
136 vectors: &DashMap<Uuid, Vec<f32>>,
137 ) -> Vec<(Uuid, f32)> {
138 let neighbors: Vec<Vec<f32>> = candidates
140 .iter()
141 .take(10)
142 .filter_map(|(id, _)| vectors.get(id).map(|v| v.clone()))
143 .collect();
144
145 let edge_weights: Vec<f32> = candidates
146 .iter()
147 .take(10)
148 .map(|(_, d)| 1.0 - d.min(1.0))
149 .collect();
150
151 let transformed_query = self
152 .gnn_layer
153 .forward(query_embedding, &neighbors, &edge_weights);
154
155 for (id, distance) in candidates.iter_mut() {
157 if let Some(vector) = vectors.get(id) {
158 let weighted_distance = self.weighted_distance(&transformed_query, &vector);
160
161 let entry_boost = self.entry_scores.get(id).map(|s| *s).unwrap_or(1.0);
163
164 *distance = weighted_distance / entry_boost;
165 }
166 }
167
168 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
170
171 candidates
172 }
173
174 fn weighted_distance(&self, a: &[f32], b: &[f32]) -> f32 {
176 let mut weighted_dot = 0.0f32;
177 let mut weighted_norm_a = 0.0f32;
178 let mut weighted_norm_b = 0.0f32;
179
180 for i in 0..a.len().min(b.len()).min(self.dimensions) {
181 let w = self.relevance_weights[i];
182 weighted_dot += a[i] * b[i] * w;
183 weighted_norm_a += a[i] * a[i] * w;
184 weighted_norm_b += b[i] * b[i] * w;
185 }
186
187 let norm = (weighted_norm_a * weighted_norm_b).sqrt();
188 if norm > 0.0 {
189 1.0 - (weighted_dot / norm)
190 } else {
191 1.0
192 }
193 }
194
195 pub fn record_query(&mut self, query_embedding: &[f32], results: &[SearchResult]) {
197 if results.is_empty() {
198 return;
199 }
200
201 self.query_count += 1;
202
203 for (rank, result) in results.iter().enumerate() {
205 self.replay_buffer.add(Experience {
206 query: query_embedding.to_vec(),
207 result_id: result.entry.id,
208 rank: rank as u32,
209 score: result.score,
210 });
211 }
212
213 if self.query_count % 10 == 0 {
215 self.learn_from_replay();
216 }
217 }
218
219 pub fn record_feedback(&mut self, result_embedding: &[f32], positive: bool) {
221 let adjustment = if positive {
222 self.learning_rate
223 } else {
224 -self.learning_rate * 0.5
225 };
226
227 for (i, &val) in result_embedding.iter().enumerate() {
229 if i < self.dimensions {
230 let delta = adjustment * val.abs();
232
233 let ewc_factor = 1.0 / (1.0 + self.fisher_diagonal[i]);
235
236 self.relevance_weights[i] =
237 (self.relevance_weights[i] + delta * ewc_factor).clamp(0.1, 10.0);
238 }
239 }
240
241 self.update_fisher(result_embedding);
243 }
244
245 fn update_fisher(&mut self, embedding: &[f32]) {
247 for (i, &val) in embedding.iter().enumerate() {
248 if i < self.dimensions {
249 self.fisher_diagonal[i] = 0.99 * self.fisher_diagonal[i] + 0.01 * val * val;
251 }
252 }
253 }
254
255 fn learn_from_replay(&mut self) {
257 let samples = self.replay_buffer.sample(32);
258
259 for experience in samples {
260 let target_boost = 1.0 + (1.0 / (1.0 + experience.rank as f32));
262
263 self.entry_scores
265 .entry(experience.result_id)
266 .and_modify(|s| {
267 *s = (*s + target_boost) / 2.0;
268 })
269 .or_insert(target_boost);
270
271 self.gnn_layer
273 .update(&experience.query, target_boost, self.learning_rate);
274 }
275 }
276
277 pub fn query_count(&self) -> u64 {
279 self.query_count
280 }
281
282 pub fn stats(&self) -> LearningStats {
284 let avg_weight: f32 = self.relevance_weights.iter().sum::<f32>() / self.dimensions as f32;
285 let weight_variance: f32 = self
286 .relevance_weights
287 .iter()
288 .map(|w| (w - avg_weight).powi(2))
289 .sum::<f32>()
290 / self.dimensions as f32;
291
292 LearningStats {
293 query_count: self.query_count,
294 replay_buffer_size: self.replay_buffer.len(),
295 learned_entries: self.entry_scores.len(),
296 avg_relevance_weight: avg_weight,
297 weight_variance,
298 }
299 }
300}
301
302#[derive(Debug, Clone)]
304struct Experience {
305 query: Vec<f32>,
306 result_id: Uuid,
307 rank: u32,
308 #[allow(dead_code)]
310 score: f32,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct LearningStats {
316 pub query_count: u64,
318 pub replay_buffer_size: usize,
320 pub learned_entries: usize,
322 pub avg_relevance_weight: f32,
324 pub weight_variance: f32,
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every value in `values` equals `expected` within float tolerance.
    fn assert_all_close(values: &[f32], expected: f32) {
        for (i, &v) in values.iter().enumerate() {
            assert!(
                (v - expected).abs() < 1e-6,
                "value {} = {}, expected {}",
                i,
                v,
                expected
            );
        }
    }

    #[test]
    fn test_learning_engine_creation() {
        let engine = LearningEngine::new(128, 0.01);
        assert_eq!(engine.dimensions, 128);
        assert_eq!(engine.query_count, 0);
        assert_eq!(engine.query_count(), 0);
        assert_eq!(engine.relevance_weights.len(), 128);
        assert_eq!(engine.fisher_diagonal.len(), 128);
        assert_all_close(&engine.relevance_weights, 1.0);
        assert_all_close(&engine.fisher_diagonal, 0.0);
    }

    #[test]
    fn test_feedback_updates_weights() {
        let mut engine = LearningEngine::new(64, 0.1);
        let initial_weights = engine.relevance_weights.clone();

        let embedding = vec![0.5; 64];
        engine.record_feedback(&embedding, true);

        assert_ne!(engine.relevance_weights, initial_weights);
        // delta = lr * |0.5| = 0.05; Fisher is still zero, so no damping.
        assert_all_close(&engine.relevance_weights, 1.05);
        // Fisher EMA after one update: 0.01 * 0.5^2.
        assert_all_close(&engine.fisher_diagonal, 0.0025);
    }

    #[test]
    fn test_negative_feedback_applies_half_step() {
        let mut engine = LearningEngine::new(16, 0.1);
        engine.record_feedback(&[0.5; 16], false);
        // delta = -(lr * 0.5) * |0.5| = -0.025
        assert_all_close(&engine.relevance_weights, 0.975);
    }

    #[test]
    fn test_feedback_weights_are_clamped() {
        let mut boosted = LearningEngine::new(4, 100.0);
        boosted.record_feedback(&[1.0; 4], true);
        assert_all_close(&boosted.relevance_weights, 10.0);

        let mut suppressed = LearningEngine::new(4, 100.0);
        suppressed.record_feedback(&[1.0; 4], false);
        assert_all_close(&suppressed.relevance_weights, 0.1);
    }

    #[test]
    fn test_feedback_respects_dimension_bounds() {
        // Extra components beyond `dimensions` are ignored.
        let mut engine = LearningEngine::new(2, 0.1);
        engine.record_feedback(&[0.5, 0.5, 0.5, 0.5], true);
        assert_eq!(engine.relevance_weights.len(), 2);
        assert_all_close(&engine.relevance_weights, 1.05);

        // A shorter embedding only touches the leading dimensions.
        let mut engine = LearningEngine::new(3, 0.1);
        engine.record_feedback(&[0.5], true);
        assert_all_close(&engine.relevance_weights[..1], 1.05);
        assert_all_close(&engine.relevance_weights[1..], 1.0);
    }

    #[test]
    fn test_weighted_distance_is_cosine_with_unit_weights() {
        let engine = LearningEngine::new(2, 0.01);
        assert!(engine.weighted_distance(&[1.0, 0.0], &[1.0, 0.0]).abs() < 1e-6);
        assert!((engine.weighted_distance(&[1.0, 0.0], &[0.0, 1.0]) - 1.0).abs() < 1e-6);
        // A zero vector has no direction; distance defaults to 1.0.
        assert!((engine.weighted_distance(&[0.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_results_are_not_counted() {
        let mut engine = LearningEngine::new(4, 0.01);
        engine.record_query(&[0.0; 4], &[]);
        assert_eq!(engine.query_count(), 0);
    }

    #[test]
    fn test_stats_on_fresh_engine() {
        let engine = LearningEngine::new(8, 0.01);
        let stats = engine.stats();
        assert_eq!(stats.query_count, 0);
        assert_eq!(stats.learned_entries, 0);
        assert!((stats.avg_relevance_weight - 1.0).abs() < 1e-6);
        assert!(stats.weight_variance.abs() < 1e-6);
    }
}