1use serde::{Deserialize, Serialize};
12
/// Rank setting for a LoRA adapter; each variant carries the numeric rank `r`.
///
/// The variant names are tiers by convention (`RuntimeAdaptation` uses
/// `Micro(2)` for the fast adapter and `Base(8)` for the slower one); the
/// adapter itself only reads the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoRARank {
    /// Small rank for fast, low-capacity adaptation (e.g. r = 2).
    Micro(usize),
    /// Standard rank (e.g. r = 8).
    Base(usize),
    /// Caller-chosen rank outside the named tiers.
    Custom(usize),
}
23
24impl Default for LoRARank {
25 fn default() -> Self {
26 Self::Base(8)
27 }
28}
29
/// Hyper-parameters for a single LoRA adapter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoRAConfig {
    /// Low-rank dimension (wrapped in a tier enum).
    pub rank: LoRARank,
    /// Scaling numerator; the effective output scale is `alpha / rank`.
    pub alpha: f64,
    /// Dropout probability. NOTE(review): never read by `LoRAAdapter` in this
    /// file — confirm whether it is consumed elsewhere or dead.
    pub dropout: f64,
    /// Width of the adapted vectors (both input and output dimension).
    pub dim: usize,
    /// SGD step size used by `LoRAAdapter::update`.
    pub learning_rate: f64,
}
44
45impl Default for LoRAConfig {
46 fn default() -> Self {
47 Self {
48 rank: LoRARank::Base(8),
49 alpha: 16.0,
50 dropout: 0.0,
51 dim: 256,
52 learning_rate: 0.001,
53 }
54 }
55}
56
/// A low-rank adapter over a `dim`-wide vector space.
///
/// Before `merge()`, the delta is computed factored as scale * B^T (A^T x);
/// after `merge()`, a dense `dim x dim` matrix is used instead.
#[derive(Debug, Clone)]
pub struct LoRAAdapter {
    config: LoRAConfig,
    // Down-projection, dim x rank rows; small deterministic init.
    matrix_a: Vec<Vec<f64>>,
    // Up-projection, rank x dim rows; zero-initialized so the initial delta is 0.
    matrix_b: Vec<Vec<f64>>,
    // Dense dim x dim merged delta, populated by merge().
    delta_weights: Vec<Vec<f64>>,
    // When true, apply() uses delta_weights instead of the A/B factors.
    is_merged: bool,
    // Number of update() calls performed.
    update_count: u64,
}
73
74impl LoRAAdapter {
75 pub fn new(config: LoRAConfig) -> Self {
77 let rank = match config.rank {
78 LoRARank::Micro(r) => r,
79 LoRARank::Base(r) => r,
80 LoRARank::Custom(r) => r,
81 };
82
83 let matrix_a: Vec<Vec<f64>> = (0..config.dim)
85 .map(|i| (0..rank).map(|j| ((i * j) as f64 * 0.001).sin() * 0.01).collect())
86 .collect();
87
88 let matrix_b: Vec<Vec<f64>> = (0..rank)
89 .map(|_| vec![0.0; config.dim])
90 .collect();
91
92 let delta_weights = vec![vec![0.0; config.dim]; config.dim];
93
94 Self {
95 config,
96 matrix_a,
97 matrix_b,
98 delta_weights,
99 is_merged: false,
100 update_count: 0,
101 }
102 }
103
104 pub fn apply(&self, input: &[f64]) -> Vec<f64> {
106 if self.is_merged {
107 return self.apply_merged(input);
109 }
110
111 let rank = self.matrix_a.first().map_or(0, |r| r.len());
113 let mut intermediate = vec![0.0; rank];
114 for (i, row) in self.matrix_a.iter().enumerate() {
115 if i < input.len() {
116 for (j, &a) in row.iter().enumerate() {
117 intermediate[j] += a * input[i];
118 }
119 }
120 }
121
122 let mut output = vec![0.0; self.config.dim];
124 for (i, row) in self.matrix_b.iter().enumerate() {
125 for (j, &b) in row.iter().enumerate() {
126 if j < output.len() {
127 output[j] += b * intermediate[i];
128 }
129 }
130 }
131
132 let rank_val = match self.config.rank {
134 LoRARank::Micro(r) => r,
135 LoRARank::Base(r) => r,
136 LoRARank::Custom(r) => r,
137 } as f64;
138 let scale = self.config.alpha / rank_val;
139
140 for v in &mut output {
141 *v *= scale;
142 }
143
144 output
145 }
146
147 fn apply_merged(&self, input: &[f64]) -> Vec<f64> {
149 let mut output = vec![0.0; self.config.dim];
150 for (i, row) in self.delta_weights.iter().enumerate() {
151 for (j, &w) in row.iter().enumerate() {
152 if j < input.len() {
153 output[i] += w * input[j];
154 }
155 }
156 }
157 output
158 }
159
160 pub fn update(&mut self, input: &[f64], target: &[f64]) {
162 let output = self.apply(input);
164
165 let error: Vec<f64> = target
167 .iter()
168 .zip(output.iter())
169 .map(|(&t, &o)| t - o)
170 .collect();
171
172 let rank = self.matrix_a.first().map(|r| r.len()).unwrap_or(0);
173
174 for (i, row) in self.matrix_b.iter_mut().enumerate() {
176 if i < rank {
177 for (j, b) in row.iter_mut().enumerate() {
178 if j < error.len() {
179 let grad = error[j] * self.matrix_a.first().and_then(|r| r.get(i)).copied().unwrap_or(0.0);
181 *b += self.config.learning_rate * grad;
182 }
183 }
184 }
185 }
186
187 for (i, row) in self.matrix_a.iter_mut().enumerate() {
189 if i < input.len() {
190 for (j, a) in row.iter_mut().enumerate() {
191 if j < rank {
192 let grad = error.get(i).copied().unwrap_or(0.0) * input[i];
193 *a += self.config.learning_rate * grad * 0.1;
194 }
195 }
196 }
197 }
198
199 self.update_count += 1;
200 }
201
202 pub fn merge(&mut self) {
204 if self.is_merged {
205 return;
206 }
207
208 let rank = self.matrix_a.first().map(|r| r.len()).unwrap_or(0);
209 let rank_val = match self.config.rank {
210 LoRARank::Micro(r) => r,
211 LoRARank::Base(r) => r,
212 LoRARank::Custom(r) => r,
213 } as f64;
214 let scale = self.config.alpha / rank_val;
215
216 for i in 0..self.config.dim {
218 for j in 0..self.config.dim {
219 let mut sum = 0.0;
220 for k in 0..rank {
221 let a_val = self.matrix_a.get(j).and_then(|r| r.get(k)).copied().unwrap_or(0.0);
222 let b_val = self.matrix_b.get(k).and_then(|r| r.get(i)).copied().unwrap_or(0.0);
223 sum += b_val * a_val;
224 }
225 self.delta_weights[i][j] = sum * scale;
226 }
227 }
228
229 self.is_merged = true;
230 }
231
232 pub fn update_count(&self) -> u64 {
234 self.update_count
235 }
236}
237
/// Online Elastic Weight Consolidation ("EWC++"): keeps a running diagonal
/// Fisher estimate and penalizes drift from stored anchor weights.
#[derive(Debug, Clone)]
pub struct EWCPlusPlus {
    // Diagonal Fisher information estimate (EMA of squared gradients).
    fisher: Vec<f64>,
    // Anchor weights captured by store_optimal().
    optimal_weights: Vec<f64>,
    // Penalty strength multiplier.
    lambda: f64,
    // EMA decay for the Fisher update (set to 0.9 in new()).
    gamma: f64,
    // Number of update_fisher() calls (read by RuntimeAdaptation::stats).
    sample_count: u64,
}
252
253impl EWCPlusPlus {
254 pub fn new(dim: usize, lambda: f64) -> Self {
256 Self {
257 fisher: vec![0.0; dim],
258 optimal_weights: vec![0.0; dim],
259 lambda,
260 gamma: 0.9,
261 sample_count: 0,
262 }
263 }
264
265 pub fn update_fisher(&mut self, gradients: &[f64]) {
267 for (i, &g) in gradients.iter().enumerate() {
268 if i < self.fisher.len() {
269 self.fisher[i] = self.gamma * self.fisher[i] + (1.0 - self.gamma) * g * g;
271 }
272 }
273 self.sample_count += 1;
274 }
275
276 pub fn store_optimal(&mut self, weights: &[f64]) {
278 for (i, &w) in weights.iter().enumerate() {
279 if i < self.optimal_weights.len() {
280 self.optimal_weights[i] = w;
281 }
282 }
283 }
284
285 pub fn penalty(&self, current_weights: &[f64]) -> f64 {
287 let mut penalty = 0.0;
288 for (i, (¤t_w, &optimal_w)) in current_weights.iter().zip(self.optimal_weights.iter()).enumerate().take(self.fisher.len()) {
289 let diff = current_w - optimal_w;
290 penalty += self.fisher[i] * diff * diff;
291 }
292 0.5 * self.lambda * penalty
293 }
294
295 pub fn regularize_gradient(&self, gradients: &mut [f64], current_weights: &[f64]) {
297 for i in 0..gradients.len().min(self.fisher.len()) {
298 let ewc_grad = self.lambda * self.fisher[i] * (current_weights[i] - self.optimal_weights[i]);
299 gradients[i] += ewc_grad;
300 }
301 }
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
/// One stored input/output episode, retrievable by input similarity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningPattern {
    /// Identifier (RuntimeAdaptation generates "pattern_<n>").
    pub id: String,
    /// Input vector the pattern was observed with; used as the retrieval key.
    pub input: Vec<f64>,
    /// Output vector associated with the input.
    pub output: Vec<f64>,
    /// Quality score; lowest-scored pattern is evicted when the bank is full.
    pub score: f64,
    /// Times this pattern has been used.
    pub usage_count: u64,
    /// Index of the k-means cluster assigned by ReasoningBank::update_clusters.
    pub cluster_id: usize,
}
320
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 when the lengths differ or when either vector has zero norm.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    let norm_a: f64 = a.iter().map(|&x| x * x).sum();
    let norm_b: f64 = b.iter().map(|&y| y * y).sum();
    let denom = (norm_a * norm_b).sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
337
/// Euclidean (L2) distance computed over the overlapping prefix of `a` and `b`
/// (zip truncates to the shorter slice).
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let mut sum_sq = 0.0;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let d = x - y;
        sum_sq += d * d;
    }
    sum_sq.sqrt()
}
346
/// Capacity-bounded store of reasoning patterns with k-means clustering over
/// their input vectors.
#[derive(Debug, Clone)]
pub struct ReasoningBank {
    // Stored patterns; lowest score is evicted once max_patterns is reached.
    patterns: Vec<ReasoningPattern>,
    // Cluster centers in input space, maintained by update_clusters().
    centroids: Vec<Vec<f64>>,
    // Target number of clusters.
    num_clusters: usize,
    // Hard cap on stored patterns.
    max_patterns: usize,
}
359
360impl ReasoningBank {
361 pub fn new(num_clusters: usize, max_patterns: usize) -> Self {
363 Self {
364 patterns: Vec::new(),
365 centroids: Vec::new(),
366 num_clusters,
367 max_patterns,
368 }
369 }
370
371 pub fn store(&mut self, pattern: ReasoningPattern) {
373 if self.patterns.len() >= self.max_patterns {
374 if let Some(min_idx) = self
376 .patterns
377 .iter()
378 .enumerate()
379 .min_by(|(_, a), (_, b)| a.score.partial_cmp(&b.score).unwrap())
380 .map(|(i, _)| i)
381 {
382 self.patterns.remove(min_idx);
383 }
384 }
385
386 self.patterns.push(pattern);
387 }
388
389 pub fn retrieve(&self, query: &[f64], k: usize) -> Vec<&ReasoningPattern> {
391 let mut scored: Vec<(f64, &ReasoningPattern)> = self
392 .patterns
393 .iter()
394 .map(|p| {
395 let sim = cosine_similarity(query, &p.input);
396 (sim, p)
397 })
398 .collect();
399
400 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
401 scored.into_iter().take(k).map(|(_, p)| p).collect()
402 }
403
404 pub fn update_clusters(&mut self) {
406 if self.patterns.is_empty() {
407 return;
408 }
409
410 let dim = self.patterns[0].input.len();
411
412 if self.centroids.is_empty() || self.centroids.len() != self.num_clusters {
414 self.initialize_centroids_kmeans_pp(dim);
415 }
416
417 let centroids_snapshot = self.centroids.clone();
419
420 let mut cluster_sums: Vec<Vec<f64>> = vec![vec![0.0; dim]; self.num_clusters];
422 let mut cluster_counts: Vec<usize> = vec![0; self.num_clusters];
423
424 for pattern in &mut self.patterns {
425 let nearest = centroids_snapshot
427 .iter()
428 .enumerate()
429 .map(|(i, c)| (i, euclidean_distance(&pattern.input, c)))
430 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
431 .map(|(i, _)| i)
432 .unwrap_or(0);
433
434 pattern.cluster_id = nearest;
435
436 for (j, &v) in pattern.input.iter().enumerate() {
438 if j < dim {
439 cluster_sums[nearest][j] += v;
440 }
441 }
442 cluster_counts[nearest] += 1;
443 }
444
445 for (i, centroid) in self.centroids.iter_mut().enumerate() {
447 if cluster_counts[i] > 0 {
448 for (j, c) in centroid.iter_mut().enumerate() {
449 *c = cluster_sums[i][j] / cluster_counts[i] as f64;
450 }
451 }
452 }
453 }
454
455 fn initialize_centroids_kmeans_pp(&mut self, dim: usize) {
456 self.centroids = Vec::with_capacity(self.num_clusters);
457
458 if self.patterns.is_empty() {
459 for _ in 0..self.num_clusters {
460 self.centroids.push(vec![0.0; dim]);
461 }
462 return;
463 }
464
465 self.centroids.push(self.patterns[0].input.clone());
467
468 for _ in 1..self.num_clusters {
470 let distances: Vec<f64> = self
471 .patterns
472 .iter()
473 .map(|p| {
474 self.centroids
475 .iter()
476 .map(|c| euclidean_distance(&p.input, c))
477 .fold(f64::INFINITY, f64::min)
478 })
479 .collect();
480
481 let total: f64 = distances.iter().map(|d| d * d).sum();
482 if total > 0.0 {
483 let max_idx = distances
485 .iter()
486 .enumerate()
487 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
488 .map(|(i, _)| i)
489 .unwrap_or(0);
490 self.centroids.push(self.patterns[max_idx].input.clone());
491 } else {
492 self.centroids.push(vec![0.0; dim]);
493 }
494 }
495 }
496
497 pub fn len(&self) -> usize {
499 self.patterns.len()
500 }
501
502 pub fn is_empty(&self) -> bool {
504 self.patterns.is_empty()
505 }
506}
507
/// Facade combining two LoRA adapters (fast micro + slow base), an EWC++
/// regularizer, and a reasoning-pattern bank.
pub struct RuntimeAdaptation {
    // Rank-2 adapter updated on every adapt() call.
    micro_lora: LoRAAdapter,
    // Rank-8 adapter updated every 10th adapt() call.
    base_lora: LoRAAdapter,
    // Catastrophic-forgetting regularizer refreshed in consolidate().
    ewc: EWCPlusPlus,
    // Episode store re-clustered in consolidate().
    reasoning_bank: ReasoningBank,
    // Weight snapshot handed to the EWC anchor. NOTE(review): never written
    // after construction in this file — always zeros; confirm intent.
    current_weights: Vec<f64>,
    // Total adapt() calls; also drives the base-LoRA update cadence.
    adaptation_count: u64,
}
523
524impl RuntimeAdaptation {
525 pub fn new(dim: usize) -> Self {
527 let micro_config = LoRAConfig {
528 rank: LoRARank::Micro(2),
529 alpha: 4.0,
530 learning_rate: 0.01,
531 dim,
532 ..Default::default()
533 };
534
535 let base_config = LoRAConfig {
536 rank: LoRARank::Base(8),
537 alpha: 16.0,
538 learning_rate: 0.001,
539 dim,
540 ..Default::default()
541 };
542
543 Self {
544 micro_lora: LoRAAdapter::new(micro_config),
545 base_lora: LoRAAdapter::new(base_config),
546 ewc: EWCPlusPlus::new(dim, 1000.0),
547 reasoning_bank: ReasoningBank::new(10, 1000),
548 current_weights: vec![0.0; dim],
549 adaptation_count: 0,
550 }
551 }
552
553 pub fn adapt(&mut self, input: &[f64], output: &[f64]) {
555 self.micro_lora.update(input, output);
557
558 if self.adaptation_count % 10 == 0 {
560 self.base_lora.update(input, output);
561 }
562
563 let pattern = ReasoningPattern {
565 id: format!("pattern_{}", self.adaptation_count),
566 input: input.to_vec(),
567 output: output.to_vec(),
568 score: 1.0,
569 usage_count: 1,
570 cluster_id: 0,
571 };
572 self.reasoning_bank.store(pattern);
573
574 self.adaptation_count += 1;
575 }
576
577 pub fn apply(&self, input: &[f64]) -> Vec<f64> {
579 let micro_out = self.micro_lora.apply(input);
580 let base_out = self.base_lora.apply(input);
581
582 micro_out
584 .iter()
585 .zip(base_out.iter())
586 .zip(input.iter())
587 .map(|((&m, &b), &i)| i + 0.6 * m + 0.4 * b)
588 .collect()
589 }
590
591 pub fn consolidate(&mut self) {
593 self.micro_lora.merge();
595
596 let gradients = vec![0.01; self.current_weights.len()]; self.ewc.update_fisher(&gradients);
599
600 self.ewc.store_optimal(&self.current_weights);
602
603 self.reasoning_bank.update_clusters();
605 }
606
607 pub fn stats(&self) -> AdaptationStats {
609 AdaptationStats {
610 adaptation_count: self.adaptation_count,
611 micro_lora_updates: self.micro_lora.update_count(),
612 base_lora_updates: self.base_lora.update_count(),
613 reasoning_patterns: self.reasoning_bank.len(),
614 ewc_samples: self.ewc.sample_count,
615 }
616 }
617}
618
/// Counter snapshot produced by `RuntimeAdaptation::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationStats {
    /// Total adapt() calls.
    pub adaptation_count: u64,
    /// Micro-LoRA update() calls.
    pub micro_lora_updates: u64,
    /// Base-LoRA update() calls.
    pub base_lora_updates: u64,
    /// Patterns currently stored in the reasoning bank.
    pub reasoning_patterns: usize,
    /// EWC update_fisher() calls.
    pub ewc_samples: u64,
}
633
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_adapter() {
        let config = LoRAConfig {
            dim: 16,
            rank: LoRARank::Base(4),
            ..Default::default()
        };
        let mut adapter = LoRAAdapter::new(config);

        let input = vec![0.5; 16];
        let output = adapter.apply(&input);
        assert_eq!(output.len(), 16);

        let target = vec![0.3; 16];
        adapter.update(&input, &target);
        assert_eq!(adapter.update_count(), 1);
    }

    #[test]
    fn test_ewc() {
        let mut ewc = EWCPlusPlus::new(8, 100.0);

        let gradients = vec![0.1; 8];
        ewc.update_fisher(&gradients);

        let weights = vec![0.5; 8];
        ewc.store_optimal(&weights);

        let current = vec![0.6; 8];
        // Fix: this reference was garbled by an encoding error (`¤t`,
        // a mangled `&current`) and the test did not compile.
        let penalty = ewc.penalty(&current);
        assert!(penalty > 0.0);
    }

    #[test]
    fn test_reasoning_bank() {
        let mut bank = ReasoningBank::new(5, 100);

        let pattern = ReasoningPattern {
            id: "test".to_string(),
            input: vec![0.5; 8],
            output: vec![0.3; 8],
            score: 1.0,
            usage_count: 1,
            cluster_id: 0,
        };
        bank.store(pattern);

        let query = vec![0.5; 8];
        let results = bank.retrieve(&query, 1);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_runtime_adaptation() {
        let mut adapter = RuntimeAdaptation::new(16);

        let input = vec![0.5; 16];
        let output = vec![0.3; 16];
        adapter.adapt(&input, &output);

        let result = adapter.apply(&input);
        assert_eq!(result.len(), 16);

        let stats = adapter.stats();
        assert_eq!(stats.adaptation_count, 1);
    }
}