// omega_brain/runtime_adaptation.rs
1//! Runtime Adaptation System
2//!
3//! Implements continuous learning without catastrophic forgetting:
4//! - MicroLoRA (rank 1-2): Instant adaptation for immediate context
5//! - BaseLoRA (rank 4-16): Long-term learning and skill acquisition
6//! - EWC++ (Elastic Weight Consolidation): Prevents forgetting
7//! - ReasoningBank: Stores successful reasoning patterns
8//!
9//! Inspired by ruvector-sona architecture.
10
11use serde::{Deserialize, Serialize};
12
/// LoRA rank configuration.
///
/// The wrapped `usize` is the rank of the low-rank decomposition; the
/// variant records which adapter tier the rank is intended for.
/// NOTE(review): the wrapped value is not validated against the intended
/// ranges below — confirm callers pass sensible ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoRARank {
    /// MicroLoRA: intended rank 1-2, for instant adaptation.
    Micro(usize),
    /// BaseLoRA: intended rank 4-16, for long-term learning.
    Base(usize),
    /// Custom rank with no intended range.
    Custom(usize),
}
23
24impl Default for LoRARank {
25    fn default() -> Self {
26        Self::Base(8)
27    }
28}
29
/// LoRA adapter configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoRAConfig {
    /// Rank of the low-rank decomposition (A is dim x r, B is r x dim).
    pub rank: LoRARank,
    /// Alpha scaling factor; adapter output is scaled by alpha / rank.
    pub alpha: f64,
    /// Dropout rate. NOTE(review): currently never read by `LoRAAdapter`.
    pub dropout: f64,
    /// Input/output dimensionality of the adapted layer.
    pub dim: usize,
    /// Learning rate used by `LoRAAdapter::update`.
    pub learning_rate: f64,
}
44
45impl Default for LoRAConfig {
46    fn default() -> Self {
47        Self {
48            rank: LoRARank::Base(8),
49            alpha: 16.0,
50            dropout: 0.0,
51            dim: 256,
52            learning_rate: 0.001,
53        }
54    }
55}
56
/// LoRA adapter matrices: a low-rank delta W = B * A applied on top of a
/// base transformation.
#[derive(Debug, Clone)]
pub struct LoRAAdapter {
    /// Configuration (rank, alpha, dim, learning rate).
    config: LoRAConfig,
    /// Matrix A, shape (dim x rank): projects the input into rank space.
    matrix_a: Vec<Vec<f64>>,
    /// Matrix B, shape (rank x dim): projects back to output space.
    matrix_b: Vec<Vec<f64>>,
    /// Dense (dim x dim) merged delta, populated by `merge`.
    delta_weights: Vec<Vec<f64>>,
    /// When true, `apply` uses `delta_weights` instead of the factored path.
    is_merged: bool,
    /// Number of `update` calls performed.
    update_count: u64,
}
73
74impl LoRAAdapter {
75    /// Create new LoRA adapter
76    pub fn new(config: LoRAConfig) -> Self {
77        let rank = match config.rank {
78            LoRARank::Micro(r) => r,
79            LoRARank::Base(r) => r,
80            LoRARank::Custom(r) => r,
81        };
82
83        // Initialize A with small random values, B with zeros
84        let matrix_a: Vec<Vec<f64>> = (0..config.dim)
85            .map(|i| (0..rank).map(|j| ((i * j) as f64 * 0.001).sin() * 0.01).collect())
86            .collect();
87
88        let matrix_b: Vec<Vec<f64>> = (0..rank)
89            .map(|_| vec![0.0; config.dim])
90            .collect();
91
92        let delta_weights = vec![vec![0.0; config.dim]; config.dim];
93
94        Self {
95            config,
96            matrix_a,
97            matrix_b,
98            delta_weights,
99            is_merged: false,
100            update_count: 0,
101        }
102    }
103
104    /// Apply LoRA transformation
105    pub fn apply(&self, input: &[f64]) -> Vec<f64> {
106        if self.is_merged {
107            // Use merged weights
108            return self.apply_merged(input);
109        }
110
111        // Compute A * x
112        let rank = self.matrix_a.first().map_or(0, |r| r.len());
113        let mut intermediate = vec![0.0; rank];
114        for (i, row) in self.matrix_a.iter().enumerate() {
115            if i < input.len() {
116                for (j, &a) in row.iter().enumerate() {
117                    intermediate[j] += a * input[i];
118                }
119            }
120        }
121
122        // Compute B * (A * x)
123        let mut output = vec![0.0; self.config.dim];
124        for (i, row) in self.matrix_b.iter().enumerate() {
125            for (j, &b) in row.iter().enumerate() {
126                if j < output.len() {
127                    output[j] += b * intermediate[i];
128                }
129            }
130        }
131
132        // Scale by alpha / rank
133        let rank_val = match self.config.rank {
134            LoRARank::Micro(r) => r,
135            LoRARank::Base(r) => r,
136            LoRARank::Custom(r) => r,
137        } as f64;
138        let scale = self.config.alpha / rank_val;
139
140        for v in &mut output {
141            *v *= scale;
142        }
143
144        output
145    }
146
147    /// Apply using merged weights
148    fn apply_merged(&self, input: &[f64]) -> Vec<f64> {
149        let mut output = vec![0.0; self.config.dim];
150        for (i, row) in self.delta_weights.iter().enumerate() {
151            for (j, &w) in row.iter().enumerate() {
152                if j < input.len() {
153                    output[i] += w * input[j];
154                }
155            }
156        }
157        output
158    }
159
160    /// Update adapter weights
161    pub fn update(&mut self, input: &[f64], target: &[f64]) {
162        // Compute current output
163        let output = self.apply(input);
164
165        // Compute error
166        let error: Vec<f64> = target
167            .iter()
168            .zip(output.iter())
169            .map(|(&t, &o)| t - o)
170            .collect();
171
172        let rank = self.matrix_a.first().map(|r| r.len()).unwrap_or(0);
173
174        // Update B (gradient descent)
175        for (i, row) in self.matrix_b.iter_mut().enumerate() {
176            if i < rank {
177                for (j, b) in row.iter_mut().enumerate() {
178                    if j < error.len() {
179                        // Simplified gradient
180                        let grad = error[j] * self.matrix_a.first().and_then(|r| r.get(i)).copied().unwrap_or(0.0);
181                        *b += self.config.learning_rate * grad;
182                    }
183                }
184            }
185        }
186
187        // Update A
188        for (i, row) in self.matrix_a.iter_mut().enumerate() {
189            if i < input.len() {
190                for (j, a) in row.iter_mut().enumerate() {
191                    if j < rank {
192                        let grad = error.get(i).copied().unwrap_or(0.0) * input[i];
193                        *a += self.config.learning_rate * grad * 0.1;
194                    }
195                }
196            }
197        }
198
199        self.update_count += 1;
200    }
201
202    /// Merge LoRA weights into delta
203    pub fn merge(&mut self) {
204        if self.is_merged {
205            return;
206        }
207
208        let rank = self.matrix_a.first().map(|r| r.len()).unwrap_or(0);
209        let rank_val = match self.config.rank {
210            LoRARank::Micro(r) => r,
211            LoRARank::Base(r) => r,
212            LoRARank::Custom(r) => r,
213        } as f64;
214        let scale = self.config.alpha / rank_val;
215
216        // Compute BA
217        for i in 0..self.config.dim {
218            for j in 0..self.config.dim {
219                let mut sum = 0.0;
220                for k in 0..rank {
221                    let a_val = self.matrix_a.get(j).and_then(|r| r.get(k)).copied().unwrap_or(0.0);
222                    let b_val = self.matrix_b.get(k).and_then(|r| r.get(i)).copied().unwrap_or(0.0);
223                    sum += b_val * a_val;
224                }
225                self.delta_weights[i][j] = sum * scale;
226            }
227        }
228
229        self.is_merged = true;
230    }
231
232    /// Get update count
233    pub fn update_count(&self) -> u64 {
234        self.update_count
235    }
236}
237
/// EWC++ (online Elastic Weight Consolidation) for preventing catastrophic
/// forgetting: maintains a running diagonal Fisher-information estimate and
/// penalizes movement of weights that were important for earlier tasks.
#[derive(Debug, Clone)]
pub struct EWCPlusPlus {
    /// Diagonal Fisher information estimate (per-weight importance).
    fisher: Vec<f64>,
    /// Weight snapshot taken at the end of the previous task.
    optimal_weights: Vec<f64>,
    /// Penalty strength (importance weight).
    lambda: f64,
    /// Exponential-decay rate of the online Fisher update.
    gamma: f64,
    /// Number of gradient samples folded into `fisher`.
    sample_count: u64,
}

impl EWCPlusPlus {
    /// Create EWC++ state for `dim` weights with penalty strength `lambda`.
    /// Fisher and the optimal snapshot start at zero; gamma defaults to 0.9.
    pub fn new(dim: usize, lambda: f64) -> Self {
        Self {
            fisher: vec![0.0; dim],
            optimal_weights: vec![0.0; dim],
            lambda,
            gamma: 0.9,
            sample_count: 0,
        }
    }

    /// Fold one gradient sample into the Fisher estimate, elementwise:
    /// F_new = gamma * F_old + (1 - gamma) * g^2.
    /// Gradient entries beyond `fisher.len()` are ignored.
    pub fn update_fisher(&mut self, gradients: &[f64]) {
        for (f, &g) in self.fisher.iter_mut().zip(gradients.iter()) {
            *f = self.gamma * *f + (1.0 - self.gamma) * g * g;
        }
        self.sample_count += 1;
    }

    /// Snapshot `weights` as the reference point the penalty pulls toward.
    /// Entries beyond the stored dimension are ignored.
    pub fn store_optimal(&mut self, weights: &[f64]) {
        for (o, &w) in self.optimal_weights.iter_mut().zip(weights.iter()) {
            *o = w;
        }
    }

    /// Quadratic EWC penalty: 0.5 * lambda * sum_i F_i * (w_i - w*_i)^2.
    pub fn penalty(&self, current_weights: &[f64]) -> f64 {
        let mut penalty = 0.0;
        for (i, (&current_w, &optimal_w)) in current_weights
            .iter()
            .zip(self.optimal_weights.iter())
            .enumerate()
            .take(self.fisher.len())
        {
            let diff = current_w - optimal_w;
            penalty += self.fisher[i] * diff * diff;
        }
        0.5 * self.lambda * penalty
    }

    /// Add the EWC penalty gradient, lambda * F_i * (w_i - w*_i), to
    /// `gradients` in place.
    ///
    /// Fix: the loop bound now also honors `current_weights.len()`; the
    /// previous version indexed `current_weights[i]` bounded only by
    /// `gradients`/`fisher` and panicked on a shorter weight slice.
    pub fn regularize_gradient(&self, gradients: &mut [f64], current_weights: &[f64]) {
        let n = gradients
            .len()
            .min(self.fisher.len())
            .min(current_weights.len());
        for i in 0..n {
            gradients[i] += self.lambda * self.fisher[i] * (current_weights[i] - self.optimal_weights[i]);
        }
    }
}
303
/// Reasoning pattern stored in the ReasoningBank.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningPattern {
    /// Pattern ID (e.g. "pattern_<n>" as generated by `RuntimeAdaptation::adapt`).
    pub id: String,
    /// Input embedding; used for similarity retrieval and clustering.
    pub input: Vec<f64>,
    /// Output embedding associated with the input.
    pub output: Vec<f64>,
    /// Success score; the lowest-scoring pattern is evicted at capacity.
    pub score: f64,
    /// Usage count. NOTE(review): set once at creation and never
    /// incremented anywhere in this file — confirm intended.
    pub usage_count: u64,
    /// Cluster index assigned by `ReasoningBank::update_clusters`.
    pub cluster_id: usize,
}
320
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 on length mismatch or when either vector has zero norm
/// (free function to avoid borrow conflicts with `&mut self` callers).
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0, 0.0, 0.0),
        |(d, na, nb), (&x, &y)| (d + x * y, na + x * x, nb + y * y),
    );

    let denom = (norm_a * norm_b).sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
337
/// Euclidean (L2) distance between two vectors; `zip` truncates to the
/// shorter length (free function to avoid borrow conflicts).
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let mut sum_sq = 0.0;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let d = x - y;
        sum_sq += d * d;
    }
    sum_sq.sqrt()
}
346
/// ReasoningBank: bounded pattern store clustered with a K-means variant
/// (farthest-point seeding stands in for true K-means++ sampling — see
/// `initialize_centroids_kmeans_pp`).
#[derive(Debug, Clone)]
pub struct ReasoningBank {
    /// Stored patterns, at most `max_patterns`.
    patterns: Vec<ReasoningPattern>,
    /// Cluster centroids, one per cluster once initialized.
    centroids: Vec<Vec<f64>>,
    /// Number of clusters K.
    num_clusters: usize,
    /// Capacity; the lowest-scoring pattern is evicted when full.
    max_patterns: usize,
}
359
360impl ReasoningBank {
361    /// Create new ReasoningBank
362    pub fn new(num_clusters: usize, max_patterns: usize) -> Self {
363        Self {
364            patterns: Vec::new(),
365            centroids: Vec::new(),
366            num_clusters,
367            max_patterns,
368        }
369    }
370
371    /// Store a reasoning pattern
372    pub fn store(&mut self, pattern: ReasoningPattern) {
373        if self.patterns.len() >= self.max_patterns {
374            // Remove lowest scoring pattern
375            if let Some(min_idx) = self
376                .patterns
377                .iter()
378                .enumerate()
379                .min_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap())
380                .map(|(i, _)| i)
381            {
382                self.patterns.remove(min_idx);
383            }
384        }
385
386        self.patterns.push(pattern);
387    }
388
389    /// Retrieve similar patterns
390    pub fn retrieve(&self, query: &[f64], k: usize) -> Vec<&ReasoningPattern> {
391        let mut scored: Vec<(f64, &ReasoningPattern)> = self
392            .patterns
393            .iter()
394            .map(|p| {
395                let sim = cosine_similarity(query, &p.input);
396                (sim, p)
397            })
398            .collect();
399
400        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
401        scored.into_iter().take(k).map(|(_, p)| p).collect()
402    }
403
404    /// Update centroids using K-means++
405    pub fn update_clusters(&mut self) {
406        if self.patterns.is_empty() {
407            return;
408        }
409
410        let dim = self.patterns[0].input.len();
411
412        // Initialize centroids if needed
413        if self.centroids.is_empty() || self.centroids.len() != self.num_clusters {
414            self.initialize_centroids_kmeans_pp(dim);
415        }
416
417        // Clone centroids to avoid borrow conflicts
418        let centroids_snapshot = self.centroids.clone();
419
420        // Assign patterns to clusters and update centroids
421        let mut cluster_sums: Vec<Vec<f64>> = vec![vec![0.0; dim]; self.num_clusters];
422        let mut cluster_counts: Vec<usize> = vec![0; self.num_clusters];
423
424        for pattern in &mut self.patterns {
425            // Find nearest centroid
426            let nearest = centroids_snapshot
427                .iter()
428                .enumerate()
429                .map(|(i, c)| (i, euclidean_distance(&pattern.input, c)))
430                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
431                .map(|(i, _)| i)
432                .unwrap_or(0);
433
434            pattern.cluster_id = nearest;
435
436            // Accumulate for centroid update
437            for (j, &v) in pattern.input.iter().enumerate() {
438                if j < dim {
439                    cluster_sums[nearest][j] += v;
440                }
441            }
442            cluster_counts[nearest] += 1;
443        }
444
445        // Update centroids
446        for (i, centroid) in self.centroids.iter_mut().enumerate() {
447            if cluster_counts[i] > 0 {
448                for (j, c) in centroid.iter_mut().enumerate() {
449                    *c = cluster_sums[i][j] / cluster_counts[i] as f64;
450                }
451            }
452        }
453    }
454
455    fn initialize_centroids_kmeans_pp(&mut self, dim: usize) {
456        self.centroids = Vec::with_capacity(self.num_clusters);
457
458        if self.patterns.is_empty() {
459            for _ in 0..self.num_clusters {
460                self.centroids.push(vec![0.0; dim]);
461            }
462            return;
463        }
464
465        // First centroid: random pattern
466        self.centroids.push(self.patterns[0].input.clone());
467
468        // Remaining centroids: weighted by distance
469        for _ in 1..self.num_clusters {
470            let distances: Vec<f64> = self
471                .patterns
472                .iter()
473                .map(|p| {
474                    self.centroids
475                        .iter()
476                        .map(|c| euclidean_distance(&p.input, c))
477                        .fold(f64::INFINITY, f64::min)
478                })
479                .collect();
480
481            let total: f64 = distances.iter().map(|d| d * d).sum();
482            if total > 0.0 {
483                // Simplified: pick pattern with maximum distance
484                let max_idx = distances
485                    .iter()
486                    .enumerate()
487                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
488                    .map(|(i, _)| i)
489                    .unwrap_or(0);
490                self.centroids.push(self.patterns[max_idx].input.clone());
491            } else {
492                self.centroids.push(vec![0.0; dim]);
493            }
494        }
495    }
496
497    /// Get pattern count
498    pub fn len(&self) -> usize {
499        self.patterns.len()
500    }
501
502    /// Check if empty
503    pub fn is_empty(&self) -> bool {
504        self.patterns.is_empty()
505    }
506}
507
/// Complete runtime adaptation system combining fast/slow LoRA adapters,
/// EWC++ regularization state, and a ReasoningBank pattern archive.
pub struct RuntimeAdaptation {
    /// MicroLoRA (rank 2, lr 0.01): instant adaptation, updated every `adapt`.
    micro_lora: LoRAAdapter,
    /// BaseLoRA (rank 8, lr 0.001): long-term learning, updated every 10th `adapt`.
    base_lora: LoRAAdapter,
    /// EWC++ state for forgetting prevention.
    ewc: EWCPlusPlus,
    /// Archive of successful reasoning patterns.
    reasoning_bank: ReasoningBank,
    /// Current weights. NOTE(review): initialized to zeros and never written
    /// anywhere in this file, so `consolidate` anchors EWC to zeros — confirm
    /// whether a caller is supposed to maintain this.
    current_weights: Vec<f64>,
    /// Number of `adapt` calls performed.
    adaptation_count: u64,
}
523
524impl RuntimeAdaptation {
525    /// Create new runtime adaptation system
526    pub fn new(dim: usize) -> Self {
527        let micro_config = LoRAConfig {
528            rank: LoRARank::Micro(2),
529            alpha: 4.0,
530            learning_rate: 0.01,
531            dim,
532            ..Default::default()
533        };
534
535        let base_config = LoRAConfig {
536            rank: LoRARank::Base(8),
537            alpha: 16.0,
538            learning_rate: 0.001,
539            dim,
540            ..Default::default()
541        };
542
543        Self {
544            micro_lora: LoRAAdapter::new(micro_config),
545            base_lora: LoRAAdapter::new(base_config),
546            ewc: EWCPlusPlus::new(dim, 1000.0),
547            reasoning_bank: ReasoningBank::new(10, 1000),
548            current_weights: vec![0.0; dim],
549            adaptation_count: 0,
550        }
551    }
552
553    /// Adapt to new input-output pair
554    pub fn adapt(&mut self, input: &[f64], output: &[f64]) {
555        // Quick adaptation with MicroLoRA
556        self.micro_lora.update(input, output);
557
558        // Slow adaptation with BaseLoRA
559        if self.adaptation_count % 10 == 0 {
560            self.base_lora.update(input, output);
561        }
562
563        // Store successful pattern
564        let pattern = ReasoningPattern {
565            id: format!("pattern_{}", self.adaptation_count),
566            input: input.to_vec(),
567            output: output.to_vec(),
568            score: 1.0,
569            usage_count: 1,
570            cluster_id: 0,
571        };
572        self.reasoning_bank.store(pattern);
573
574        self.adaptation_count += 1;
575    }
576
577    /// Apply adaptation to input
578    pub fn apply(&self, input: &[f64]) -> Vec<f64> {
579        let micro_out = self.micro_lora.apply(input);
580        let base_out = self.base_lora.apply(input);
581
582        // Combine outputs
583        micro_out
584            .iter()
585            .zip(base_out.iter())
586            .zip(input.iter())
587            .map(|((&m, &b), &i)| i + 0.6 * m + 0.4 * b)
588            .collect()
589    }
590
591    /// Consolidate learning (call periodically)
592    pub fn consolidate(&mut self) {
593        // Merge MicroLoRA into BaseLoRA
594        self.micro_lora.merge();
595
596        // Update EWC Fisher information
597        let gradients = vec![0.01; self.current_weights.len()]; // Placeholder
598        self.ewc.update_fisher(&gradients);
599
600        // Store current weights as optimal
601        self.ewc.store_optimal(&self.current_weights);
602
603        // Update reasoning bank clusters
604        self.reasoning_bank.update_clusters();
605    }
606
607    /// Get adaptation stats
608    pub fn stats(&self) -> AdaptationStats {
609        AdaptationStats {
610            adaptation_count: self.adaptation_count,
611            micro_lora_updates: self.micro_lora.update_count(),
612            base_lora_updates: self.base_lora.update_count(),
613            reasoning_patterns: self.reasoning_bank.len(),
614            ewc_samples: self.ewc.sample_count,
615        }
616    }
617}
618
/// Adaptation statistics snapshot returned by `RuntimeAdaptation::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationStats {
    /// Total number of `adapt` calls.
    pub adaptation_count: u64,
    /// MicroLoRA update count (incremented on every `adapt`).
    pub micro_lora_updates: u64,
    /// BaseLoRA update count (incremented on every 10th `adapt`).
    pub base_lora_updates: u64,
    /// Patterns currently stored in the ReasoningBank.
    pub reasoning_patterns: usize,
    /// Gradient samples folded into the EWC Fisher estimate.
    pub ewc_samples: u64,
}
633
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: apply preserves dimensionality; update bumps the counter.
    #[test]
    fn test_lora_adapter() {
        let config = LoRAConfig {
            dim: 16,
            rank: LoRARank::Base(4),
            ..Default::default()
        };
        let mut adapter = LoRAAdapter::new(config);

        let input = vec![0.5; 16];
        let output = adapter.apply(&input);
        assert_eq!(output.len(), 16);

        let target = vec![0.3; 16];
        adapter.update(&input, &target);
        assert_eq!(adapter.update_count(), 1);
    }

    // Penalty is strictly positive once weights drift from the stored optimum.
    #[test]
    fn test_ewc() {
        let mut ewc = EWCPlusPlus::new(8, 100.0);

        let gradients = vec![0.1; 8];
        ewc.update_fisher(&gradients);

        let weights = vec![0.5; 8];
        ewc.store_optimal(&weights);

        let current = vec![0.6; 8];
        let penalty = ewc.penalty(&current);
        assert!(penalty > 0.0);
    }

    // Store one pattern, then retrieve it by cosine similarity.
    #[test]
    fn test_reasoning_bank() {
        let mut bank = ReasoningBank::new(5, 100);

        let pattern = ReasoningPattern {
            id: "test".to_string(),
            input: vec![0.5; 8],
            output: vec![0.3; 8],
            score: 1.0,
            usage_count: 1,
            cluster_id: 0,
        };
        bank.store(pattern);

        let query = vec![0.5; 8];
        let results = bank.retrieve(&query, 1);
        assert_eq!(results.len(), 1);
    }

    // End-to-end: one adapt call updates the stats and apply keeps dimensions.
    #[test]
    fn test_runtime_adaptation() {
        let mut adapter = RuntimeAdaptation::new(16);

        let input = vec![0.5; 16];
        let output = vec![0.3; 16];
        adapter.adapt(&input, &output);

        let result = adapter.apply(&input);
        assert_eq!(result.len(), 16);

        let stats = adapter.stats();
        assert_eq!(stats.adaptation_count, 1);
    }
}