// converge_knowledge/learning/mod.rs
1//! Self-learning engine using GNN-inspired approaches.
2//!
3//! This module implements adaptive learning mechanisms inspired by Graph Neural Networks
4//! to improve search results over time based on user interactions.
5//!
6//! # Components
7//!
8//! - **LearningEngine**: Core learning with GNN message passing
9//! - **FeedbackCollector**: Implicit signal capture from user interactions
10//! - **BatchScheduler**: Background jobs for pattern detection and enrichment
11//! - **InsightStore**: Storage for discovered patterns and relationships
12//!
13//! # Learning Flow
14//!
15//! ```text
16//! User Interactions
17//!        │
18//!        ▼
19//! ┌─────────────────┐
20//! │FeedbackCollector│ ──► Implicit signals (view, select, dwell)
21//! └────────┬────────┘
22//!          │
23//!          ▼
24//! ┌─────────────────┐
25//! │FeedbackProcessor│ ──► ProcessedFeedback (relevance deltas)
26//! └────────┬────────┘
27//!          │
28//!          ▼
29//! ┌─────────────────┐
30//! │ LearningEngine  │ ──► Update weights, GNN propagation
31//! └────────┬────────┘
32//!          │
33//!          ▼
34//! ┌─────────────────┐
35//! │ BatchScheduler  │ ──► Patterns, gaps, classifications
36//! └────────┬────────┘
37//!          │
38//!          ▼
39//! ┌─────────────────┐
40//! │  InsightStore   │ ──► Published insights for retrieval
41//! └─────────────────┘
42//! ```
43
44pub mod batch;
45pub mod feedback;
46mod gnn;
47mod replay;
48
49pub use batch::{
50    BatchInput, BatchJob, BatchScheduler, EntryMetadata, Insight, InsightStore, JobRun, JobStatus,
51    JobType, KnowledgeClass, RelationshipType, Trend,
52};
53pub use feedback::{
54    FeedbackCollector, FeedbackConfig, FeedbackProcessor, FeedbackSignal, ProcessedFeedback,
55    QueryId, SessionId, SignalType,
56};
57pub use gnn::GnnLayer;
58pub use replay::ReplayBuffer;
59
60use crate::core::SearchResult;
61use dashmap::DashMap;
62use serde::{Deserialize, Serialize};
63use std::collections::VecDeque;
64use uuid::Uuid;
65
66/// Learning engine that improves search quality over time.
67///
68/// The engine uses several mechanisms inspired by the ruvector-gnn crate:
69/// - Experience replay for stable learning
70/// - GNN-style message passing for relevance propagation
71/// - Elastic weight consolidation to prevent catastrophic forgetting
72pub struct LearningEngine {
73    /// Embedding dimensions.
74    dimensions: usize,
75
76    /// Learning rate.
77    learning_rate: f32,
78
79    /// Query-result relevance weights.
80    relevance_weights: Vec<f32>,
81
82    /// GNN layer for embedding transformation.
83    gnn_layer: GnnLayer,
84
85    /// Experience replay buffer.
86    replay_buffer: ReplayBuffer<Experience>,
87
88    /// Query patterns for learning (reserved for future pattern-based learning).
89    #[allow(dead_code)]
90    query_patterns: VecDeque<QueryPattern>,
91
92    /// Entry relevance scores learned from feedback.
93    entry_scores: DashMap<Uuid, f32>,
94
95    /// Fisher information for EWC.
96    fisher_diagonal: Vec<f32>,
97
98    /// Total queries processed.
99    query_count: u64,
100}
101
/// A recorded query pattern for learning.
///
/// Currently only stored (the owning field is marked `dead_code`);
/// reserved for future pattern-based learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QueryPattern {
    /// Query embedding.
    query_embedding: Vec<f32>,

    /// Result embeddings paired with a feedback value.
    result_embeddings: Vec<(Vec<f32>, f32)>,

    /// Timestamp of when the pattern was recorded.
    timestamp: u64,
}
114
115impl LearningEngine {
116    /// Create a new learning engine.
117    pub fn new(dimensions: usize, learning_rate: f32) -> Self {
118        Self {
119            dimensions,
120            learning_rate,
121            relevance_weights: vec![1.0; dimensions],
122            gnn_layer: GnnLayer::new(dimensions, dimensions * 2, 4),
123            replay_buffer: ReplayBuffer::new(10000),
124            query_patterns: VecDeque::with_capacity(1000),
125            entry_scores: DashMap::new(),
126            fisher_diagonal: vec![0.0; dimensions],
127            query_count: 0,
128        }
129    }
130
131    /// Re-rank candidates based on learned patterns.
132    pub fn rerank(
133        &self,
134        query_embedding: &[f32],
135        mut candidates: Vec<(Uuid, f32)>,
136        vectors: &DashMap<Uuid, Vec<f32>>,
137    ) -> Vec<(Uuid, f32)> {
138        // Transform query through GNN layer
139        let neighbors: Vec<Vec<f32>> = candidates
140            .iter()
141            .take(10)
142            .filter_map(|(id, _)| vectors.get(id).map(|v| v.clone()))
143            .collect();
144
145        let edge_weights: Vec<f32> = candidates
146            .iter()
147            .take(10)
148            .map(|(_, d)| 1.0 - d.min(1.0))
149            .collect();
150
151        let transformed_query = self
152            .gnn_layer
153            .forward(query_embedding, &neighbors, &edge_weights);
154
155        // Re-compute distances with transformed query
156        for (id, distance) in candidates.iter_mut() {
157            if let Some(vector) = vectors.get(id) {
158                // Apply learned relevance weights
159                let weighted_distance = self.weighted_distance(&transformed_query, &vector);
160
161                // Apply entry-specific learned score
162                let entry_boost = self.entry_scores.get(id).map(|s| *s).unwrap_or(1.0);
163
164                *distance = weighted_distance / entry_boost;
165            }
166        }
167
168        // Re-sort
169        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
170
171        candidates
172    }
173
174    /// Compute weighted distance using learned weights.
175    fn weighted_distance(&self, a: &[f32], b: &[f32]) -> f32 {
176        let mut weighted_dot = 0.0f32;
177        let mut weighted_norm_a = 0.0f32;
178        let mut weighted_norm_b = 0.0f32;
179
180        for i in 0..a.len().min(b.len()).min(self.dimensions) {
181            let w = self.relevance_weights[i];
182            weighted_dot += a[i] * b[i] * w;
183            weighted_norm_a += a[i] * a[i] * w;
184            weighted_norm_b += b[i] * b[i] * w;
185        }
186
187        let norm = (weighted_norm_a * weighted_norm_b).sqrt();
188        if norm > 0.0 {
189            1.0 - (weighted_dot / norm)
190        } else {
191            1.0
192        }
193    }
194
195    /// Record a query and its results for learning.
196    pub fn record_query(&mut self, query_embedding: &[f32], results: &[SearchResult]) {
197        if results.is_empty() {
198            return;
199        }
200
201        self.query_count += 1;
202
203        // Store in replay buffer
204        for (rank, result) in results.iter().enumerate() {
205            self.replay_buffer.add(Experience {
206                query: query_embedding.to_vec(),
207                result_id: result.entry.id,
208                rank: rank as u32,
209                score: result.score,
210            });
211        }
212
213        // Periodic learning from replay buffer
214        if self.query_count % 10 == 0 {
215            self.learn_from_replay();
216        }
217    }
218
219    /// Record user feedback on a search result.
220    pub fn record_feedback(&mut self, result_embedding: &[f32], positive: bool) {
221        let adjustment = if positive {
222            self.learning_rate
223        } else {
224            -self.learning_rate * 0.5
225        };
226
227        // Adjust relevance weights based on feedback
228        for (i, &val) in result_embedding.iter().enumerate() {
229            if i < self.dimensions {
230                // Apply adjustment proportional to embedding value
231                let delta = adjustment * val.abs();
232
233                // EWC regularization: smaller updates for important weights
234                let ewc_factor = 1.0 / (1.0 + self.fisher_diagonal[i]);
235
236                self.relevance_weights[i] =
237                    (self.relevance_weights[i] + delta * ewc_factor).clamp(0.1, 10.0);
238            }
239        }
240
241        // Update Fisher information estimate
242        self.update_fisher(result_embedding);
243    }
244
245    /// Update Fisher information diagonal for EWC.
246    fn update_fisher(&mut self, embedding: &[f32]) {
247        for (i, &val) in embedding.iter().enumerate() {
248            if i < self.dimensions {
249                // Exponential moving average of squared gradients
250                self.fisher_diagonal[i] = 0.99 * self.fisher_diagonal[i] + 0.01 * val * val;
251            }
252        }
253    }
254
255    /// Learn from replay buffer samples.
256    fn learn_from_replay(&mut self) {
257        let samples = self.replay_buffer.sample(32);
258
259        for experience in samples {
260            // Higher-ranked results should have boosted scores
261            let target_boost = 1.0 + (1.0 / (1.0 + experience.rank as f32));
262
263            // Update entry score
264            self.entry_scores
265                .entry(experience.result_id)
266                .and_modify(|s| {
267                    *s = (*s + target_boost) / 2.0;
268                })
269                .or_insert(target_boost);
270
271            // Update GNN layer weights (simplified)
272            self.gnn_layer
273                .update(&experience.query, target_boost, self.learning_rate);
274        }
275    }
276
277    /// Get the current query count.
278    pub fn query_count(&self) -> u64 {
279        self.query_count
280    }
281
282    /// Get learning statistics.
283    pub fn stats(&self) -> LearningStats {
284        let avg_weight: f32 = self.relevance_weights.iter().sum::<f32>() / self.dimensions as f32;
285        let weight_variance: f32 = self
286            .relevance_weights
287            .iter()
288            .map(|w| (w - avg_weight).powi(2))
289            .sum::<f32>()
290            / self.dimensions as f32;
291
292        LearningStats {
293            query_count: self.query_count,
294            replay_buffer_size: self.replay_buffer.len(),
295            learned_entries: self.entry_scores.len(),
296            avg_relevance_weight: avg_weight,
297            weight_variance,
298        }
299    }
300}
301
/// Experience for the replay buffer.
///
/// One (query, result) pair captured by `LearningEngine::record_query`.
#[derive(Debug, Clone)]
struct Experience {
    /// Query embedding at record time.
    query: Vec<f32>,
    /// Identifier of the returned entry.
    result_id: Uuid,
    /// Zero-based rank of the result within the returned list.
    rank: u32,
    /// Raw score at record time (reserved for future reward-weighted learning).
    #[allow(dead_code)]
    score: f32,
}
312
/// Learning statistics snapshot, as produced by `LearningEngine::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
    /// Total number of queries processed.
    pub query_count: u64,
    /// Current number of experiences in the replay buffer.
    pub replay_buffer_size: usize,
    /// Number of entries with learned relevance scores.
    pub learned_entries: usize,
    /// Mean relevance weight across all embedding dimensions.
    pub avg_relevance_weight: f32,
    /// Variance of relevance weights (measures how specialised learning has become).
    pub weight_variance: f32,
}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed engine starts with the requested dimensions
    /// and no processed queries.
    #[test]
    fn test_learning_engine_creation() {
        let engine = LearningEngine::new(128, 0.01);
        assert_eq!(engine.dimensions, 128);
        assert_eq!(engine.query_count, 0);
    }

    /// Positive feedback must move the relevance weights away from their
    /// initial values.
    #[test]
    fn test_feedback_updates_weights() {
        let mut engine = LearningEngine::new(64, 0.1);
        let before = engine.relevance_weights.clone();

        engine.record_feedback(&[0.5_f32; 64], true);

        assert_ne!(engine.relevance_weights, before);
    }
}