1use crate::error::{CacheError, Result};
12use crate::multi_tier::CacheKey;
13use scirs2_core::ndarray::{Array1, Array2, Axis};
14use std::collections::{HashMap, VecDeque};
15
16fn rand_normal(mean: f64, std_dev: f64) -> f64 {
18 let u1 = fastrand::f64();
19 let u2 = fastrand::f64();
20 let u1 = if u1 < 1e-10 { 1e-10 } else { u1 };
22 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
23 mean + z0 * std_dev
24}
25
/// Attention-based next-key predictor: runs a single self-attention pass
/// over one-hot embeddings of a rolling window of recent cache accesses.
pub struct TransformerPredictor {
    /// Dimension of the one-hot key embeddings and all projection matrices.
    embedding_dim: usize,
    /// Configured head count; currently unused (the math is single-head).
    #[allow(dead_code)]
    num_heads: usize,
    /// Maximum number of recent accesses retained in `sequence`.
    seq_length: usize,
    /// Query projection; lazily initialized on the first vocabulary insert.
    w_query: Option<Array2<f64>>,
    /// Key projection; lazily initialized on the first vocabulary insert.
    w_key: Option<Array2<f64>>,
    /// Value projection; lazily initialized on the first vocabulary insert.
    w_value: Option<Array2<f64>>,
    /// Output projection; lazily initialized on the first vocabulary insert.
    w_output: Option<Array2<f64>>,
    /// Maps each seen cache key to its vocabulary index.
    key_to_idx: HashMap<CacheKey, usize>,
    /// Reverse lookup: vocabulary index -> cache key.
    idx_to_key: Vec<CacheKey>,
    /// Rolling window of vocabulary indices for recent accesses.
    sequence: VecDeque<usize>,
    /// Number of distinct keys seen so far.
    vocab_size: usize,
}
52
53impl TransformerPredictor {
54 pub fn new(embedding_dim: usize, num_heads: usize, seq_length: usize) -> Self {
56 Self {
57 embedding_dim,
58 num_heads,
59 seq_length,
60 w_query: None,
61 w_key: None,
62 w_value: None,
63 w_output: None,
64 key_to_idx: HashMap::new(),
65 idx_to_key: Vec::new(),
66 sequence: VecDeque::with_capacity(seq_length),
67 vocab_size: 0,
68 }
69 }
70
71 fn initialize_weights(&mut self) {
73 fastrand::seed(42);
75 let scale = (2.0 / self.embedding_dim as f64).sqrt();
76
77 let q_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
78 .map(|_| rand_normal(0.0, scale))
79 .collect();
80
81 let k_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
82 .map(|_| rand_normal(0.0, scale))
83 .collect();
84
85 let v_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
86 .map(|_| rand_normal(0.0, scale))
87 .collect();
88
89 let o_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
90 .map(|_| rand_normal(0.0, scale))
91 .collect();
92
93 self.w_query = Some(
94 Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), q_data)
95 .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
96 );
97
98 self.w_key = Some(
99 Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), k_data)
100 .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
101 );
102
103 self.w_value = Some(
104 Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), v_data)
105 .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
106 );
107
108 self.w_output = Some(
109 Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), o_data)
110 .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
111 );
112 }
113
114 fn add_to_vocab(&mut self, key: &CacheKey) -> usize {
116 if let Some(&idx) = self.key_to_idx.get(key) {
117 idx
118 } else {
119 let idx = self.vocab_size;
120 self.key_to_idx.insert(key.clone(), idx);
121 self.idx_to_key.push(key.clone());
122 self.vocab_size += 1;
123
124 if self.w_query.is_none() {
125 self.initialize_weights();
126 }
127
128 idx
129 }
130 }
131
132 fn attention(
134 &self,
135 query: &Array2<f64>,
136 key: &Array2<f64>,
137 value: &Array2<f64>,
138 ) -> Result<Array2<f64>> {
139 let w_q = self
140 .w_query
141 .as_ref()
142 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
143 let w_k = self
144 .w_key
145 .as_ref()
146 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
147 let w_v = self
148 .w_value
149 .as_ref()
150 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
151 let w_o = self
152 .w_output
153 .as_ref()
154 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
155
156 let q_proj = query.dot(w_q);
158 let k_proj = key.dot(w_k);
159 let v_proj = value.dot(w_v);
160
161 let scores = q_proj.dot(&k_proj.t()) / (self.embedding_dim as f64).sqrt();
163
164 let scores_exp = scores.mapv(|x| x.exp());
166 let scores_sum = scores_exp.sum_axis(Axis(1));
167 let attention_weights = &scores_exp / &scores_sum.insert_axis(Axis(1));
168
169 let attended = attention_weights.dot(&v_proj);
171
172 Ok(attended.dot(w_o))
174 }
175
176 pub fn record_access(&mut self, key: CacheKey) {
178 let idx = self.add_to_vocab(&key);
179
180 if self.sequence.len() >= self.seq_length {
181 self.sequence.pop_front();
182 }
183 self.sequence.push_back(idx);
184 }
185
186 pub fn predict(&self, top_n: usize) -> Result<Vec<(CacheKey, f64)>> {
188 if self.sequence.is_empty() {
189 return Ok(Vec::new());
190 }
191
192 let mut embeddings = Array2::zeros((self.sequence.len(), self.embedding_dim));
194 for (i, &idx) in self.sequence.iter().enumerate() {
195 if idx < self.embedding_dim {
197 embeddings[[i, idx]] = 1.0;
198 }
199 }
200
201 let output = self.attention(&embeddings, &embeddings, &embeddings)?;
203
204 let last_output = output.row(output.nrows() - 1);
206
207 let mut scores: Vec<(CacheKey, f64)> = self
209 .idx_to_key
210 .iter()
211 .enumerate()
212 .map(|(idx, key)| {
213 let score = if idx < last_output.len() {
214 last_output[idx]
215 } else {
216 0.0
217 };
218 (key.clone(), score)
219 })
220 .collect();
221
222 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
224 scores.truncate(top_n);
225
226 let sum: f64 = scores.iter().map(|(_, s)| s.exp()).sum();
228 if sum > 0.0 {
229 for (_, score) in &mut scores {
230 *score = score.exp() / sum;
231 }
232 }
233
234 Ok(scores)
235 }
236
237 pub fn clear(&mut self) {
239 self.sequence.clear();
240 self.key_to_idx.clear();
241 self.idx_to_key.clear();
242 self.vocab_size = 0;
243 self.w_query = None;
244 self.w_key = None;
245 self.w_value = None;
246 self.w_output = None;
247 }
248}
249
/// LSTM-based next-key predictor: one-hot encodes each access and runs a
/// single LSTM cell, scoring candidates from the resulting hidden state.
pub struct LSTMPredictor {
    /// Size of the hidden and cell state vectors.
    hidden_size: usize,
    /// Number of distinct keys seen so far (one-hot input width).
    vocab_size: usize,
    /// Forget-gate weights, shape (vocab_size + hidden_size, hidden_size).
    w_forget: Option<Array2<f64>>,
    /// Input-gate weights, same shape as the forget gate.
    w_input: Option<Array2<f64>>,
    /// Output-gate weights, same shape as the forget gate.
    w_output: Option<Array2<f64>>,
    /// Cell-candidate weights, same shape as the forget gate.
    w_cell: Option<Array2<f64>>,
    /// Current hidden state; `None` until weights are initialized.
    hidden_state: Option<Array1<f64>>,
    /// Current cell state; `None` until weights are initialized.
    cell_state: Option<Array1<f64>>,
    /// Maps each seen cache key to its vocabulary index.
    key_to_idx: HashMap<CacheKey, usize>,
    /// Reverse lookup: vocabulary index -> cache key.
    idx_to_key: Vec<CacheKey>,
}
273
274impl LSTMPredictor {
275 pub fn new(hidden_size: usize) -> Self {
277 Self {
278 hidden_size,
279 vocab_size: 0,
280 w_forget: None,
281 w_input: None,
282 w_output: None,
283 w_cell: None,
284 hidden_state: None,
285 cell_state: None,
286 key_to_idx: HashMap::new(),
287 idx_to_key: Vec::new(),
288 }
289 }
290
291 fn initialize_weights(&mut self) {
293 fastrand::seed(42);
295 let input_size = self.vocab_size + self.hidden_size;
296 let scale = (2.0 / input_size as f64).sqrt();
297
298 let wf_data: Vec<f64> = (0..input_size * self.hidden_size)
299 .map(|_| rand_normal(0.0, scale))
300 .collect();
301
302 let wi_data: Vec<f64> = (0..input_size * self.hidden_size)
303 .map(|_| rand_normal(0.0, scale))
304 .collect();
305
306 let wo_data: Vec<f64> = (0..input_size * self.hidden_size)
307 .map(|_| rand_normal(0.0, scale))
308 .collect();
309
310 let wc_data: Vec<f64> = (0..input_size * self.hidden_size)
311 .map(|_| rand_normal(0.0, scale))
312 .collect();
313
314 self.w_forget = Some(
315 Array2::from_shape_vec((input_size, self.hidden_size), wf_data)
316 .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
317 );
318
319 self.w_input = Some(
320 Array2::from_shape_vec((input_size, self.hidden_size), wi_data)
321 .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
322 );
323
324 self.w_output = Some(
325 Array2::from_shape_vec((input_size, self.hidden_size), wo_data)
326 .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
327 );
328
329 self.w_cell = Some(
330 Array2::from_shape_vec((input_size, self.hidden_size), wc_data)
331 .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
332 );
333
334 self.hidden_state = Some(Array1::zeros(self.hidden_size));
335 self.cell_state = Some(Array1::zeros(self.hidden_size));
336 }
337
338 fn add_to_vocab(&mut self, key: &CacheKey) -> usize {
340 if let Some(&idx) = self.key_to_idx.get(key) {
341 idx
342 } else {
343 let idx = self.vocab_size;
344 self.key_to_idx.insert(key.clone(), idx);
345 self.idx_to_key.push(key.clone());
346 self.vocab_size += 1;
347
348 self.initialize_weights();
350
351 idx
352 }
353 }
354
355 fn sigmoid(x: f64) -> f64 {
357 1.0 / (1.0 + (-x).exp())
358 }
359
360 fn forward(&mut self, input_idx: usize) -> Result<Array1<f64>> {
362 let w_f = self
363 .w_forget
364 .as_ref()
365 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
366 let w_i = self
367 .w_input
368 .as_ref()
369 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
370 let w_o = self
371 .w_output
372 .as_ref()
373 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
374 let w_c = self
375 .w_cell
376 .as_ref()
377 .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
378
379 let h_prev = self
380 .hidden_state
381 .as_ref()
382 .ok_or_else(|| CacheError::Prediction("Hidden state not initialized".to_string()))?;
383 let c_prev = self
384 .cell_state
385 .as_ref()
386 .ok_or_else(|| CacheError::Prediction("Cell state not initialized".to_string()))?;
387
388 let mut input = Array1::zeros(self.vocab_size);
390 if input_idx < self.vocab_size {
391 input[input_idx] = 1.0;
392 }
393
394 let mut combined = Array1::zeros(self.vocab_size + self.hidden_size);
396 for i in 0..self.vocab_size {
397 combined[i] = input[i];
398 }
399 for i in 0..self.hidden_size {
400 combined[self.vocab_size + i] = h_prev[i];
401 }
402
403 let forget_gate = w_f.t().dot(&combined).mapv(Self::sigmoid);
405 let input_gate = w_i.t().dot(&combined).mapv(Self::sigmoid);
406 let output_gate = w_o.t().dot(&combined).mapv(Self::sigmoid);
407 let cell_candidate = w_c.t().dot(&combined).mapv(|x| x.tanh());
408
409 let new_cell = &forget_gate * c_prev + &input_gate * &cell_candidate;
411
412 let new_hidden = &output_gate * &new_cell.mapv(|x| x.tanh());
414
415 self.cell_state = Some(new_cell);
417 self.hidden_state = Some(new_hidden.clone());
418
419 Ok(new_hidden)
420 }
421
422 pub fn record_access(&mut self, key: CacheKey) -> Result<()> {
424 let idx = self.add_to_vocab(&key);
425 self.forward(idx)?;
426 Ok(())
427 }
428
429 pub fn predict(&mut self, top_n: usize) -> Result<Vec<(CacheKey, f64)>> {
431 let hidden = self
432 .hidden_state
433 .as_ref()
434 .ok_or_else(|| CacheError::Prediction("Not trained".to_string()))?;
435
436 let mut scores: Vec<(CacheKey, f64)> = self
438 .idx_to_key
439 .iter()
440 .enumerate()
441 .map(|(idx, key)| {
442 let score = if idx < hidden.len() { hidden[idx] } else { 0.0 };
443 (key.clone(), score)
444 })
445 .collect();
446
447 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
448 scores.truncate(top_n);
449
450 let sum: f64 = scores.iter().map(|(_, s)| s.exp()).sum();
452 if sum > 0.0 {
453 for (_, score) in &mut scores {
454 *score = score.exp() / sum;
455 }
456 }
457
458 Ok(scores)
459 }
460
461 pub fn reset(&mut self) {
463 self.hidden_state = Some(Array1::zeros(self.hidden_size));
464 self.cell_state = Some(Array1::zeros(self.hidden_size));
465 }
466
467 pub fn clear(&mut self) {
469 self.key_to_idx.clear();
470 self.idx_to_key.clear();
471 self.vocab_size = 0;
472 self.w_forget = None;
473 self.w_input = None;
474 self.w_output = None;
475 self.w_cell = None;
476 self.hidden_state = None;
477 self.cell_state = None;
478 }
479}
480
/// Ensemble predictor blending transformer and LSTM predictions with
/// weights learned online from reported per-model accuracies.
pub struct HybridPredictor {
    /// Attention-based sub-predictor.
    transformer: TransformerPredictor,
    /// LSTM sub-predictor.
    lstm: LSTMPredictor,
    /// Ensemble weight per model name ("transformer" / "lstm").
    model_weights: HashMap<String, f64>,
    /// Bounded history of (model name, reported accuracy) pairs.
    performance_history: VecDeque<(String, f64)>,
    /// Maximum number of entries kept in `performance_history`.
    history_size: usize,
}
494
495impl HybridPredictor {
496 pub fn new(embedding_dim: usize, hidden_size: usize, seq_length: usize) -> Self {
498 let mut model_weights = HashMap::new();
499 model_weights.insert("transformer".to_string(), 0.5);
500 model_weights.insert("lstm".to_string(), 0.5);
501
502 Self {
503 transformer: TransformerPredictor::new(embedding_dim, 4, seq_length),
504 lstm: LSTMPredictor::new(hidden_size),
505 model_weights,
506 performance_history: VecDeque::with_capacity(100),
507 history_size: 100,
508 }
509 }
510
511 pub fn record_access(&mut self, key: CacheKey) -> Result<()> {
513 self.transformer.record_access(key.clone());
514 self.lstm.record_access(key)?;
515 Ok(())
516 }
517
518 fn update_weights(&mut self) {
520 if self.performance_history.len() < 10 {
521 return;
522 }
523
524 let mut model_scores: HashMap<String, f64> = HashMap::new();
525 let mut model_counts: HashMap<String, usize> = HashMap::new();
526
527 for (model, score) in &self.performance_history {
528 *model_scores.entry(model.clone()).or_insert(0.0) += score;
529 *model_counts.entry(model.clone()).or_insert(0) += 1;
530 }
531
532 let avg_scores: Vec<(String, f64)> = model_scores
534 .into_iter()
535 .map(|(model, total)| {
536 let count = model_counts.get(&model).copied().unwrap_or(1);
537 (model, total / count as f64)
538 })
539 .collect();
540
541 let sum: f64 = avg_scores.iter().map(|(_, s)| s.exp()).sum();
543 if sum > 0.0 {
544 for (model, score) in avg_scores {
545 self.model_weights.insert(model, score.exp() / sum);
546 }
547 }
548 }
549
550 pub fn predict(&mut self, top_n: usize) -> Result<Vec<(CacheKey, f64)>> {
552 let transformer_preds = self.transformer.predict(top_n)?;
554 let lstm_preds = self.lstm.predict(top_n)?;
555
556 let mut combined_scores: HashMap<CacheKey, f64> = HashMap::new();
558
559 let transformer_weight = self
560 .model_weights
561 .get("transformer")
562 .copied()
563 .unwrap_or(0.5);
564 let lstm_weight = self.model_weights.get("lstm").copied().unwrap_or(0.5);
565
566 for (key, score) in transformer_preds {
567 *combined_scores.entry(key).or_insert(0.0) += score * transformer_weight;
568 }
569
570 for (key, score) in lstm_preds {
571 *combined_scores.entry(key).or_insert(0.0) += score * lstm_weight;
572 }
573
574 let mut results: Vec<(CacheKey, f64)> = combined_scores.into_iter().collect();
576 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
577 results.truncate(top_n);
578
579 Ok(results)
580 }
581
582 pub fn report_accuracy(&mut self, model_name: &str, accuracy: f64) {
584 if self.performance_history.len() >= self.history_size {
585 self.performance_history.pop_front();
586 }
587 self.performance_history
588 .push_back((model_name.to_string(), accuracy));
589 self.update_weights();
590 }
591
592 pub fn get_weights(&self) -> &HashMap<String, f64> {
594 &self.model_weights
595 }
596
597 pub fn clear(&mut self) {
599 self.transformer.clear();
600 self.lstm.clear();
601 self.performance_history.clear();
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests pass `String` values directly to
    // `record_access`, which implies `CacheKey` is `String` (or an alias
    // thereof) — confirm against `crate::multi_tier`.

    /// Smoke test: the transformer accepts accesses and returns Ok from
    /// `predict` once the sequence is non-empty.
    #[test]
    fn test_transformer_predictor() {
        let mut predictor = TransformerPredictor::new(16, 2, 5);

        predictor.record_access("key1".to_string());
        predictor.record_access("key2".to_string());
        predictor.record_access("key3".to_string());

        let result = predictor.predict(3);
        assert!(result.is_ok());
    }

    /// Smoke test: the LSTM records accesses fallibly and can predict
    /// after at least one access.
    #[test]
    fn test_lstm_predictor() {
        let mut predictor = LSTMPredictor::new(32);

        let result = predictor.record_access("key1".to_string());
        assert!(result.is_ok());

        let result = predictor.record_access("key2".to_string());
        assert!(result.is_ok());

        let predictions = predictor.predict(3);
        assert!(predictions.is_ok());
    }

    /// Smoke test: the hybrid ensemble forwards accesses to both models
    /// and produces combined predictions.
    #[test]
    fn test_hybrid_predictor() {
        let mut predictor = HybridPredictor::new(16, 32, 5);

        let result = predictor.record_access("key1".to_string());
        assert!(result.is_ok());

        let result = predictor.record_access("key2".to_string());
        assert!(result.is_ok());

        let predictions = predictor.predict(3);
        assert!(predictions.is_ok());
    }

    /// After repeated accuracy reports (enough to pass the 10-entry
    /// threshold in `update_weights`), the better-performing model must
    /// end up with the larger ensemble weight.
    #[test]
    fn test_hybrid_online_learning() {
        let mut predictor = HybridPredictor::new(16, 32, 5);

        for _ in 0..10 {
            predictor.report_accuracy("transformer", 0.8);
            predictor.report_accuracy("lstm", 0.6);
        }

        let weights = predictor.get_weights();
        let transformer_weight = weights.get("transformer").copied().unwrap_or(0.0);
        let lstm_weight = weights.get("lstm").copied().unwrap_or(0.0);

        assert!(transformer_weight > lstm_weight);
    }
}