1use chrono::{DateTime, Utc};
36use serde::{Deserialize, Serialize};
37use sqlx::postgres::types::PgInterval;
38use std::time::Instant;
39use thiserror::Error;
40
/// Numeric constants shared by the memory math engine.
pub mod constants {
    /// Euler's number, re-exported for callers working with the recall curve.
    pub const E: f64 = std::f64::consts::E;

    /// Default absolute tolerance used when validating calculation accuracy.
    pub const MATHEMATICAL_TOLERANCE: f64 = 0.001;

    // Tier-migration thresholds (recall probabilities).

    /// Default threshold below which a warm-tier memory should migrate.
    pub const COLD_MIGRATION_THRESHOLD: f64 = 0.5;
    /// Default threshold below which a cold-tier memory should migrate.
    pub const FROZEN_MIGRATION_THRESHOLD: f64 = 0.2;

    // Consolidation / decay parameters.

    /// Default consolidation strength value (not referenced by the engine code itself).
    pub const DEFAULT_CONSOLIDATION_STRENGTH: f64 = 1.0;
    /// Baseline decay rate before access/importance/age adjustments.
    pub const DEFAULT_DECAY_RATE: f64 = 1.0;
    /// Default upper bound for consolidation strength.
    pub const MAX_CONSOLIDATION_STRENGTH: f64 = 10.0;
    /// Default lower bound (floor) for consolidation strength.
    pub const MIN_CONSOLIDATION_STRENGTH: f64 = 0.1;

    /// Default per-calculation time budget, in milliseconds.
    pub const MAX_CALCULATION_TIME_MS: u64 = 10;

    // Time-unit conversions.

    /// Seconds in one hour.
    pub const SECONDS_PER_HOUR: f64 = 3600.0;
    /// Microseconds in one hour (exactly 3.6e9).
    pub const MICROSECONDS_PER_HOUR: f64 = SECONDS_PER_HOUR * 1_000_000.0;
}
66
67#[derive(Error, Debug, Clone, PartialEq)]
69pub enum MathEngineError {
70 #[error("Invalid parameter: {parameter} = {value}, expected {constraint}")]
71 InvalidParameter {
72 parameter: String,
73 value: f64,
74 constraint: String,
75 },
76
77 #[error("Mathematical overflow in calculation: {operation}")]
78 MathematicalOverflow { operation: String },
79
80 #[error("Calculation accuracy exceeded tolerance: expected {expected}, got {actual}, tolerance {tolerance}")]
81 AccuracyError {
82 expected: f64,
83 actual: f64,
84 tolerance: f64,
85 },
86
87 #[error("Performance target exceeded: {duration_ms}ms > {target_ms}ms")]
88 PerformanceError { duration_ms: u64, target_ms: u64 },
89
90 #[error("Batch processing error: {message}")]
91 BatchProcessingError { message: String },
92}
93
94pub type Result<T> = std::result::Result<T, MathEngineError>;
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct MathEngineConfig {
99 pub cold_threshold: f64,
101
102 pub frozen_threshold: f64,
104
105 pub max_consolidation_strength: f64,
107
108 pub min_consolidation_strength: f64,
110
111 pub tolerance: f64,
113
114 pub performance_target_ms: u64,
116
117 pub enable_batch_processing: bool,
119}
120
121impl Default for MathEngineConfig {
122 fn default() -> Self {
123 Self {
124 cold_threshold: constants::COLD_MIGRATION_THRESHOLD,
125 frozen_threshold: constants::FROZEN_MIGRATION_THRESHOLD,
126 max_consolidation_strength: constants::MAX_CONSOLIDATION_STRENGTH,
127 min_consolidation_strength: constants::MIN_CONSOLIDATION_STRENGTH,
128 tolerance: constants::MATHEMATICAL_TOLERANCE,
129 performance_target_ms: constants::MAX_CALCULATION_TIME_MS,
130 enable_batch_processing: true,
131 }
132 }
133}
134
135#[derive(Debug, Clone)]
137pub struct MemoryParameters {
138 pub consolidation_strength: f64,
139 pub decay_rate: f64,
140 pub last_accessed_at: Option<DateTime<Utc>>,
141 pub created_at: DateTime<Utc>,
142 pub access_count: i32,
143 pub importance_score: f64,
144}
145
/// Outcome of a single recall-probability calculation.
#[derive(Clone, Debug, PartialEq)]
pub struct RecallCalculationResult {
    /// Probability in `[0, 1]` that the memory can be recalled.
    pub recall_probability: f64,
    /// Hours elapsed since the last access (or since creation if never accessed).
    pub time_since_access_hours: f64,
    /// Elapsed hours divided by the effective consolidation strength.
    pub normalized_time: f64,
    /// Wall-clock time spent on the calculation, in milliseconds.
    pub calculation_time_ms: u64,
}
154
/// Outcome of a consolidation-strength update.
#[derive(Clone, Debug, PartialEq)]
pub struct ConsolidationUpdateResult {
    /// Strength after applying (and bounding) the increment.
    pub new_consolidation_strength: f64,
    /// Raw increment produced by the consolidation formula (0.0 if skipped).
    pub strength_increment: f64,
    /// Recall interval converted to hours.
    pub recall_interval_hours: f64,
    /// Wall-clock time spent on the update, in milliseconds.
    pub calculation_time_ms: u64,
}
163
164#[derive(Debug, Clone)]
166pub struct BatchProcessingResult {
167 pub processed_count: usize,
168 pub total_time_ms: u64,
169 pub average_time_per_memory_ms: f64,
170 pub results: Vec<RecallCalculationResult>,
171 pub errors: Vec<(usize, MathEngineError)>,
172}
173
174#[derive(Debug, Clone)]
176pub struct MathEngine {
177 config: MathEngineConfig,
178}
179
180impl MathEngine {
181 pub fn new() -> Self {
183 Self {
184 config: MathEngineConfig::default(),
185 }
186 }
187
188 pub fn with_config(config: MathEngineConfig) -> Self {
190 Self { config }
191 }
192
193 pub fn config(&self) -> &MathEngineConfig {
195 &self.config
196 }
197
198 pub fn update_config(&mut self, config: MathEngineConfig) {
200 self.config = config;
201 }
202
203 pub fn calculate_recall_probability(
216 &self,
217 params: &MemoryParameters,
218 ) -> Result<RecallCalculationResult> {
219 let start_time = Instant::now();
220
221 self.validate_parameters(params)?;
223
224 let last_access = match params.last_accessed_at {
226 Some(access_time) => access_time,
227 None => {
228 let time_since_creation = (Utc::now() - params.created_at).num_seconds() as f64
230 / constants::SECONDS_PER_HOUR;
231 let probability =
232 self.calculate_new_memory_probability(time_since_creation, params)?;
233 let calculation_time = start_time.elapsed().as_millis() as u64;
234
235 return Ok(RecallCalculationResult {
236 recall_probability: probability,
237 time_since_access_hours: time_since_creation,
238 normalized_time: time_since_creation
239 / params
240 .consolidation_strength
241 .max(self.config.min_consolidation_strength),
242 calculation_time_ms: calculation_time,
243 });
244 }
245 };
246
247 let time_since_access =
249 (Utc::now() - last_access).num_seconds() as f64 / constants::SECONDS_PER_HOUR;
250
251 let consolidation_strength = params
256 .consolidation_strength
257 .max(self.config.min_consolidation_strength);
258 let normalized_time = time_since_access / consolidation_strength;
259
260 let probability = self.forgetting_curve_formula(normalized_time, params.decay_rate)?;
262
263 let calculation_time = start_time.elapsed().as_millis() as u64;
264
265 if calculation_time > self.config.performance_target_ms {
267 return Err(MathEngineError::PerformanceError {
268 duration_ms: calculation_time,
269 target_ms: self.config.performance_target_ms,
270 });
271 }
272
273 Ok(RecallCalculationResult {
274 recall_probability: probability,
275 time_since_access_hours: time_since_access,
276 normalized_time,
277 calculation_time_ms: calculation_time,
278 })
279 }
280
281 pub fn update_consolidation_strength(
292 &self,
293 current_strength: f64,
294 recall_interval: PgInterval,
295 ) -> Result<ConsolidationUpdateResult> {
296 let start_time = Instant::now();
297
298 if current_strength < 0.0 || current_strength > self.config.max_consolidation_strength * 2.0
300 {
301 return Err(MathEngineError::InvalidParameter {
302 parameter: "current_strength".to_string(),
303 value: current_strength,
304 constraint: format!(
305 "0.0 <= value <= {}",
306 self.config.max_consolidation_strength * 2.0
307 ),
308 });
309 }
310
311 let recall_interval_hours =
313 recall_interval.microseconds as f64 / constants::MICROSECONDS_PER_HOUR;
314
315 if recall_interval_hours < 1.0 / 60.0 {
317 let calculation_time = start_time.elapsed().as_millis() as u64;
318 return Ok(ConsolidationUpdateResult {
319 new_consolidation_strength: current_strength,
320 strength_increment: 0.0,
321 recall_interval_hours,
322 calculation_time_ms: calculation_time,
323 });
324 }
325
326 let strength_increment = self.consolidation_strength_formula(recall_interval_hours)?;
328
329 let new_strength = (current_strength + strength_increment)
331 .min(self.config.max_consolidation_strength)
332 .max(self.config.min_consolidation_strength);
333
334 let calculation_time = start_time.elapsed().as_millis() as u64;
335
336 if calculation_time > self.config.performance_target_ms {
338 return Err(MathEngineError::PerformanceError {
339 duration_ms: calculation_time,
340 target_ms: self.config.performance_target_ms,
341 });
342 }
343
344 Ok(ConsolidationUpdateResult {
345 new_consolidation_strength: new_strength,
346 strength_increment,
347 recall_interval_hours,
348 calculation_time_ms: calculation_time,
349 })
350 }
351
352 pub fn calculate_decay_rate(&self, params: &MemoryParameters) -> Result<f64> {
364 if params.access_count < 0 {
366 return Err(MathEngineError::InvalidParameter {
367 parameter: "access_count".to_string(),
368 value: params.access_count as f64,
369 constraint: "access_count >= 0".to_string(),
370 });
371 }
372
373 if !(0.0..=1.0).contains(¶ms.importance_score) {
374 return Err(MathEngineError::InvalidParameter {
375 parameter: "importance_score".to_string(),
376 value: params.importance_score,
377 constraint: "0.0 <= importance_score <= 1.0".to_string(),
378 });
379 }
380
381 let mut decay_rate = constants::DEFAULT_DECAY_RATE;
383
384 let access_factor = if params.access_count > 0 {
386 1.0 / (1.0 + (params.access_count as f64).ln())
387 } else {
388 1.0
389 };
390
391 let importance_factor = 1.0 - (params.importance_score * 0.5);
393
394 let age_days = (Utc::now() - params.created_at).num_days() as f64;
396 let age_factor = if age_days > 0.0 {
397 1.0 + (age_days / 30.0).min(2.0) } else {
399 1.0
400 };
401
402 decay_rate *= access_factor * importance_factor * age_factor;
404
405 Ok(decay_rate.max(0.1).min(5.0))
407 }
408
409 pub fn batch_calculate_recall_probability(
420 &self,
421 memory_params: &[MemoryParameters],
422 ) -> Result<BatchProcessingResult> {
423 if !self.config.enable_batch_processing {
424 return Err(MathEngineError::BatchProcessingError {
425 message: "Batch processing is disabled".to_string(),
426 });
427 }
428
429 let start_time = Instant::now();
430 let mut results = Vec::with_capacity(memory_params.len());
431 let mut errors = Vec::new();
432
433 for (index, params) in memory_params.iter().enumerate() {
434 match self.calculate_recall_probability(params) {
435 Ok(result) => results.push(result),
436 Err(error) => {
437 errors.push((index, error));
438 results.push(RecallCalculationResult {
440 recall_probability: 0.0,
441 time_since_access_hours: 0.0,
442 normalized_time: 0.0,
443 calculation_time_ms: 0,
444 });
445 }
446 }
447 }
448
449 let total_time = start_time.elapsed().as_millis() as u64;
450 let average_time = if !results.is_empty() {
451 total_time as f64 / results.len() as f64
452 } else {
453 0.0
454 };
455
456 Ok(BatchProcessingResult {
457 processed_count: memory_params.len(),
458 total_time_ms: total_time,
459 average_time_per_memory_ms: average_time,
460 results,
461 errors,
462 })
463 }
464
465 pub fn should_migrate(&self, recall_probability: f64, current_tier: &str) -> bool {
474 match current_tier.to_lowercase().as_str() {
475 "working" => recall_probability < 0.7,
476 "warm" => recall_probability < self.config.cold_threshold,
477 "cold" => recall_probability < self.config.frozen_threshold,
478 "frozen" => false,
479 _ => false,
480 }
481 }
482
483 pub fn validate_accuracy(&self, expected: f64, actual: f64) -> Result<()> {
492 let difference = (expected - actual).abs();
493 if difference > self.config.tolerance {
494 return Err(MathEngineError::AccuracyError {
495 expected,
496 actual,
497 tolerance: self.config.tolerance,
498 });
499 }
500 Ok(())
501 }
502
503 fn forgetting_curve_formula(&self, normalized_time: f64, decay_rate: f64) -> Result<f64> {
508 if normalized_time < 0.0 {
510 return Err(MathEngineError::InvalidParameter {
511 parameter: "normalized_time".to_string(),
512 value: normalized_time,
513 constraint: "normalized_time >= 0.0".to_string(),
514 });
515 }
516
517 if decay_rate <= 0.0 {
518 return Err(MathEngineError::InvalidParameter {
519 parameter: "decay_rate".to_string(),
520 value: decay_rate,
521 constraint: "decay_rate > 0.0".to_string(),
522 });
523 }
524
525 let exp_neg_t = (-normalized_time).exp();
527 if !exp_neg_t.is_finite() {
528 return Err(MathEngineError::MathematicalOverflow {
529 operation: "exp(-t) calculation".to_string(),
530 });
531 }
532
533 let exponent = -decay_rate * exp_neg_t;
534 if !exponent.is_finite() {
535 return Err(MathEngineError::MathematicalOverflow {
536 operation: "-r * e^(-t) calculation".to_string(),
537 });
538 }
539
540 let numerator = 1.0 - exponent.exp();
541 let denominator = 1.0 - (-1.0_f64).exp();
542
543 if !numerator.is_finite() || !denominator.is_finite() || denominator.abs() < f64::EPSILON {
544 return Err(MathEngineError::MathematicalOverflow {
545 operation: "forgetting curve probability calculation".to_string(),
546 });
547 }
548
549 let probability = numerator / denominator;
550
551 Ok(probability.max(0.0).min(1.0))
553 }
554
555 fn consolidation_strength_formula(&self, time_hours: f64) -> Result<f64> {
558 if time_hours < 0.0 {
559 return Err(MathEngineError::InvalidParameter {
560 parameter: "time_hours".to_string(),
561 value: time_hours,
562 constraint: "time_hours >= 0.0".to_string(),
563 });
564 }
565
566 let exp_neg_t = (-time_hours).exp();
567 if !exp_neg_t.is_finite() {
568 return Err(MathEngineError::MathematicalOverflow {
569 operation: "exp(-t) in consolidation formula".to_string(),
570 });
571 }
572
573 let numerator = 1.0 - exp_neg_t;
574 let denominator = 1.0 + exp_neg_t;
575
576 if denominator.abs() < f64::EPSILON {
577 return Err(MathEngineError::MathematicalOverflow {
578 operation: "division by zero in consolidation formula".to_string(),
579 });
580 }
581
582 Ok(numerator / denominator)
583 }
584
585 fn calculate_new_memory_probability(
587 &self,
588 time_since_creation: f64,
589 params: &MemoryParameters,
590 ) -> Result<f64> {
591 let adjusted_consolidation = params.consolidation_strength * params.importance_score;
594 let normalized_time = time_since_creation / adjusted_consolidation.max(0.1);
595 self.forgetting_curve_formula(normalized_time, params.decay_rate)
596 }
597
598 fn validate_parameters(&self, params: &MemoryParameters) -> Result<()> {
600 if params.consolidation_strength < 0.0 {
601 return Err(MathEngineError::InvalidParameter {
602 parameter: "consolidation_strength".to_string(),
603 value: params.consolidation_strength,
604 constraint: "consolidation_strength >= 0.0".to_string(),
605 });
606 }
607
608 if params.decay_rate <= 0.0 {
609 return Err(MathEngineError::InvalidParameter {
610 parameter: "decay_rate".to_string(),
611 value: params.decay_rate,
612 constraint: "decay_rate > 0.0".to_string(),
613 });
614 }
615
616 if !(0.0..=1.0).contains(¶ms.importance_score) {
617 return Err(MathEngineError::InvalidParameter {
618 parameter: "importance_score".to_string(),
619 value: params.importance_score,
620 constraint: "0.0 <= importance_score <= 1.0".to_string(),
621 });
622 }
623
624 Ok(())
625 }
626}
627
628impl Default for MathEngine {
629 fn default() -> Self {
630 Self::new()
631 }
632}
633
634pub mod benchmarks {
636 use super::*;
637 use std::time::Instant;
638
639 pub fn benchmark_single_calculation(
641 engine: &MathEngine,
642 params: &MemoryParameters,
643 iterations: usize,
644 ) -> (f64, f64, f64) {
645 let mut times = Vec::with_capacity(iterations);
646
647 for _ in 0..iterations {
648 let start = Instant::now();
649 let _ = engine.calculate_recall_probability(params);
650 times.push(start.elapsed().as_nanos() as f64 / 1_000_000.0); }
652
653 let sum: f64 = times.iter().sum();
654 let avg = sum / times.len() as f64;
655
656 times.sort_by(|a, b| a.partial_cmp(b).unwrap());
657 let median = if times.len() % 2 == 0 {
658 (times[times.len() / 2 - 1] + times[times.len() / 2]) / 2.0
659 } else {
660 times[times.len() / 2]
661 };
662
663 let p99_index = ((times.len() as f64) * 0.99) as usize;
664 let p99 = times[p99_index.min(times.len() - 1)];
665
666 (avg, median, p99)
667 }
668
669 pub fn benchmark_batch_processing(
671 engine: &MathEngine,
672 batch_sizes: &[usize],
673 ) -> Vec<(usize, f64, f64)> {
674 let mut results = Vec::new();
675
676 for &batch_size in batch_sizes {
677 let params = vec![
678 MemoryParameters {
679 consolidation_strength: 1.0,
680 decay_rate: 1.0,
681 last_accessed_at: Some(Utc::now() - chrono::Duration::hours(1)),
682 created_at: Utc::now() - chrono::Duration::days(1),
683 access_count: 5,
684 importance_score: 0.5,
685 };
686 batch_size
687 ];
688
689 let start = Instant::now();
690 let result = engine.batch_calculate_recall_probability(¶ms);
691 let total_time = start.elapsed().as_millis() as f64;
692
693 if let Ok(_batch_result) = result {
694 let throughput = batch_size as f64 / (total_time / 1000.0); results.push((batch_size, total_time, throughput));
696 }
697 }
698
699 results
700 }
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use proptest::prelude::*;

    /// Baseline parameters: accessed one hour ago, created one day ago.
    fn create_test_params() -> MemoryParameters {
        MemoryParameters {
            consolidation_strength: 1.0,
            decay_rate: 1.0,
            last_accessed_at: Some(Utc::now() - Duration::hours(1)),
            created_at: Utc::now() - Duration::days(1),
            access_count: 5,
            importance_score: 0.5,
        }
    }

    /// Asserts that `value` is a probability in `[0, 1]` (NaN fails).
    fn assert_probability(value: f64) {
        assert!(
            (0.0..=1.0).contains(&value),
            "expected a probability in [0, 1], got {}",
            value
        );
    }

    #[test]
    fn test_recall_probability_calculation() {
        let outcome = MathEngine::new()
            .calculate_recall_probability(&create_test_params())
            .unwrap();

        assert_probability(outcome.recall_probability);
        assert!(outcome.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
    }

    #[test]
    fn test_consolidation_strength_update() {
        let two_hours = PgInterval {
            months: 0,
            days: 0,
            microseconds: (2.0 * constants::MICROSECONDS_PER_HOUR) as i64,
        };

        let update = MathEngine::new()
            .update_consolidation_strength(1.0, two_hours)
            .unwrap();

        assert!(update.new_consolidation_strength > 1.0);
        assert!(update.new_consolidation_strength <= constants::MAX_CONSOLIDATION_STRENGTH);
        assert!(update.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
    }

    #[test]
    fn test_decay_rate_calculation() {
        let rate = MathEngine::new()
            .calculate_decay_rate(&create_test_params())
            .unwrap();

        assert!(rate > 0.0);
        assert!(rate <= 5.0);
    }

    #[test]
    fn test_edge_case_never_accessed() {
        let params = MemoryParameters {
            last_accessed_at: None,
            ..create_test_params()
        };

        let outcome = MathEngine::new()
            .calculate_recall_probability(&params)
            .unwrap();

        assert_probability(outcome.recall_probability);
    }

    #[test]
    fn test_edge_case_very_recent_access() {
        let params = MemoryParameters {
            last_accessed_at: Some(Utc::now() - Duration::seconds(30)),
            ..create_test_params()
        };

        let outcome = MathEngine::new()
            .calculate_recall_probability(&params)
            .unwrap();

        assert!(
            outcome.recall_probability > 0.99,
            "Very recent access should have >99% recall probability, got {}",
            outcome.recall_probability
        );
        assert!(outcome.recall_probability <= 1.0);
    }

    #[test]
    fn test_batch_processing() {
        let batch = vec![create_test_params(); 100];

        let summary = MathEngine::new()
            .batch_calculate_recall_probability(&batch)
            .unwrap();

        assert_eq!(summary.processed_count, 100);
        assert_eq!(summary.results.len(), 100);
        assert!(summary.average_time_per_memory_ms < constants::MAX_CALCULATION_TIME_MS as f64);
    }

    #[test]
    fn test_accuracy_validation() {
        let engine = MathEngine::new();

        // Inside the default tolerance.
        assert!(engine.validate_accuracy(0.5, 0.5001).is_ok());
        // Well outside it.
        assert!(engine.validate_accuracy(0.5, 0.6).is_err());
    }

    proptest! {
        #[test]
        fn test_recall_probability_properties(
            consolidation_strength in 0.1f64..10.0,
            decay_rate in 0.1f64..5.0,
            hours_ago in 0.1f64..168.0,
            importance_score in 0.0f64..1.0,
            access_count in 0i32..1000,
        ) {
            let seconds_ago = (hours_ago * 3600.0) as i64;
            let params = MemoryParameters {
                consolidation_strength,
                decay_rate,
                last_accessed_at: Some(Utc::now() - Duration::seconds(seconds_ago)),
                created_at: Utc::now() - Duration::days(1),
                access_count,
                importance_score,
            };

            // Only successful calculations are checked; errors are tolerated.
            if let Ok(calculation) = MathEngine::new().calculate_recall_probability(&params) {
                assert_probability(calculation.recall_probability);
                assert!(calculation.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
            }
        }

        #[test]
        fn test_consolidation_strength_properties(
            initial_strength in 0.1f64..10.0,
            recall_interval_hours in 0.1f64..168.0,
        ) {
            let interval = PgInterval {
                months: 0,
                days: 0,
                microseconds: (recall_interval_hours * constants::MICROSECONDS_PER_HOUR) as i64,
            };

            let outcome = MathEngine::new().update_consolidation_strength(initial_strength, interval);

            if let Ok(update) = outcome {
                assert!(update.new_consolidation_strength >= initial_strength);
                assert!(update.new_consolidation_strength <= constants::MAX_CONSOLIDATION_STRENGTH);
                assert!(update.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
            }
        }
    }
}