1use chrono::{DateTime, Utc};
36use serde::{Deserialize, Serialize};
37use sqlx::postgres::types::PgInterval;
38use std::time::Instant;
39use thiserror::Error;
40
pub mod constants {
    /// Euler's number, re-exported from `std::f64::consts::E`.
    pub const E: f64 = std::f64::consts::E;

    /// Default absolute tolerance used when comparing an expected and an actual value.
    pub const MATHEMATICAL_TOLERANCE: f64 = 0.001;

    /// Default recall probability below which a warm-tier memory should move to cold.
    pub const COLD_MIGRATION_THRESHOLD: f64 = 0.86;
    /// Default recall probability below which a cold-tier memory should move to frozen.
    pub const FROZEN_MIGRATION_THRESHOLD: f64 = 0.3;

    /// Consolidation strength given to a memory by default.
    pub const DEFAULT_CONSOLIDATION_STRENGTH: f64 = 1.0;
    /// Base decay rate before access, importance and age adjustments.
    pub const DEFAULT_DECAY_RATE: f64 = 1.0;
    /// Default upper bound for consolidation strength.
    pub const MAX_CONSOLIDATION_STRENGTH: f64 = 10.0;
    /// Default lower bound for consolidation strength.
    pub const MIN_CONSOLIDATION_STRENGTH: f64 = 0.1;

    /// Default per-calculation time budget, in milliseconds.
    pub const MAX_CALCULATION_TIME_MS: u64 = 10;

    /// Seconds in one hour.
    pub const SECONDS_PER_HOUR: f64 = 3600.0;
    /// Microseconds in one hour (3.6e9, exactly representable as `f64`).
    pub const MICROSECONDS_PER_HOUR: f64 = SECONDS_PER_HOUR * 1_000_000.0;
}
66
67#[derive(Error, Debug, Clone, PartialEq)]
69pub enum MathEngineError {
70    #[error("Invalid parameter: {parameter} = {value}, expected {constraint}")]
71    InvalidParameter {
72        parameter: String,
73        value: f64,
74        constraint: String,
75    },
76
77    #[error("Mathematical overflow in calculation: {operation}")]
78    MathematicalOverflow { operation: String },
79
80    #[error("Calculation accuracy exceeded tolerance: expected {expected}, got {actual}, tolerance {tolerance}")]
81    AccuracyError {
82        expected: f64,
83        actual: f64,
84        tolerance: f64,
85    },
86
87    #[error("Performance target exceeded: {duration_ms}ms > {target_ms}ms")]
88    PerformanceError { duration_ms: u64, target_ms: u64 },
89
90    #[error("Batch processing error: {message}")]
91    BatchProcessingError { message: String },
92}
93
94pub type Result<T> = std::result::Result<T, MathEngineError>;
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct MathEngineConfig {
99    pub cold_threshold: f64,
101
102    pub frozen_threshold: f64,
104
105    pub max_consolidation_strength: f64,
107
108    pub min_consolidation_strength: f64,
110
111    pub tolerance: f64,
113
114    pub performance_target_ms: u64,
116
117    pub enable_batch_processing: bool,
119}
120
121impl Default for MathEngineConfig {
122    fn default() -> Self {
123        Self {
124            cold_threshold: constants::COLD_MIGRATION_THRESHOLD,
125            frozen_threshold: constants::FROZEN_MIGRATION_THRESHOLD,
126            max_consolidation_strength: constants::MAX_CONSOLIDATION_STRENGTH,
127            min_consolidation_strength: constants::MIN_CONSOLIDATION_STRENGTH,
128            tolerance: constants::MATHEMATICAL_TOLERANCE,
129            performance_target_ms: constants::MAX_CALCULATION_TIME_MS,
130            enable_batch_processing: true,
131        }
132    }
133}
134
135#[derive(Debug, Clone)]
137pub struct MemoryParameters {
138    pub consolidation_strength: f64,
139    pub decay_rate: f64,
140    pub last_accessed_at: Option<DateTime<Utc>>,
141    pub created_at: DateTime<Utc>,
142    pub access_count: i32,
143    pub importance_score: f64,
144}
145
/// Outcome of a single recall-probability calculation.
#[derive(Clone, Debug, PartialEq)]
pub struct RecallCalculationResult {
    /// Estimated probability of recall.
    pub recall_probability: f64,
    /// Hours since the last access, or since creation for a never-accessed memory.
    pub time_since_access_hours: f64,
    /// Elapsed hours after normalization by consolidation strength.
    pub normalized_time: f64,
    /// Wall-clock time the calculation took, in whole milliseconds.
    pub calculation_time_ms: u64,
}
154
/// Outcome of updating a memory's consolidation strength after a recall.
#[derive(Clone, Debug, PartialEq)]
pub struct ConsolidationUpdateResult {
    /// Strength after the increment has been applied and bounded.
    pub new_consolidation_strength: f64,
    /// Raw increment computed from the recall interval, before bounding.
    pub strength_increment: f64,
    /// The recall interval, converted to hours.
    pub recall_interval_hours: f64,
    /// Wall-clock time the calculation took, in whole milliseconds.
    pub calculation_time_ms: u64,
}
163
164#[derive(Debug, Clone)]
166pub struct BatchProcessingResult {
167    pub processed_count: usize,
168    pub total_time_ms: u64,
169    pub average_time_per_memory_ms: f64,
170    pub results: Vec<RecallCalculationResult>,
171    pub errors: Vec<(usize, MathEngineError)>,
172}
173
174#[derive(Debug, Clone)]
176pub struct MathEngine {
177    config: MathEngineConfig,
178}
179
180impl MathEngine {
181    pub fn new() -> Self {
183        Self {
184            config: MathEngineConfig::default(),
185        }
186    }
187
188    pub fn with_config(config: MathEngineConfig) -> Self {
190        Self { config }
191    }
192
193    pub fn config(&self) -> &MathEngineConfig {
195        &self.config
196    }
197
198    pub fn update_config(&mut self, config: MathEngineConfig) {
200        self.config = config;
201    }
202
203    pub fn calculate_recall_probability(
216        &self,
217        params: &MemoryParameters,
218    ) -> Result<RecallCalculationResult> {
219        let start_time = Instant::now();
220
221        self.validate_parameters(params)?;
223
224        let last_access = match params.last_accessed_at {
226            Some(access_time) => access_time,
227            None => {
228                let time_since_creation = (Utc::now() - params.created_at).num_seconds() as f64
230                    / constants::SECONDS_PER_HOUR;
231                let probability =
232                    self.calculate_new_memory_probability(time_since_creation, params)?;
233                let calculation_time = start_time.elapsed().as_millis() as u64;
234
235                return Ok(RecallCalculationResult {
236                    recall_probability: probability,
237                    time_since_access_hours: time_since_creation,
238                    normalized_time: time_since_creation
239                        / params
240                            .consolidation_strength
241                            .max(self.config.min_consolidation_strength),
242                    calculation_time_ms: calculation_time,
243                });
244            }
245        };
246
247        let time_since_access =
249            (Utc::now() - last_access).num_seconds() as f64 / constants::SECONDS_PER_HOUR;
250
251        let consolidation_strength = params
256            .consolidation_strength
257            .max(self.config.min_consolidation_strength);
258        let normalized_time = time_since_access / consolidation_strength;
259
260        let probability = self.forgetting_curve_formula(normalized_time, params.decay_rate)?;
262
263        let calculation_time = start_time.elapsed().as_millis() as u64;
264
265        if calculation_time > self.config.performance_target_ms {
267            return Err(MathEngineError::PerformanceError {
268                duration_ms: calculation_time,
269                target_ms: self.config.performance_target_ms,
270            });
271        }
272
273        Ok(RecallCalculationResult {
274            recall_probability: probability,
275            time_since_access_hours: time_since_access,
276            normalized_time,
277            calculation_time_ms: calculation_time,
278        })
279    }
280
281    pub fn update_consolidation_strength(
292        &self,
293        current_strength: f64,
294        recall_interval: PgInterval,
295    ) -> Result<ConsolidationUpdateResult> {
296        let start_time = Instant::now();
297
298        if current_strength < 0.0 || current_strength > self.config.max_consolidation_strength * 2.0
300        {
301            return Err(MathEngineError::InvalidParameter {
302                parameter: "current_strength".to_string(),
303                value: current_strength,
304                constraint: format!(
305                    "0.0 <= value <= {}",
306                    self.config.max_consolidation_strength * 2.0
307                ),
308            });
309        }
310
311        let recall_interval_hours =
313            recall_interval.microseconds as f64 / constants::MICROSECONDS_PER_HOUR;
314
315        if recall_interval_hours < 1.0 / 60.0 {
317            let calculation_time = start_time.elapsed().as_millis() as u64;
318            return Ok(ConsolidationUpdateResult {
319                new_consolidation_strength: current_strength,
320                strength_increment: 0.0,
321                recall_interval_hours,
322                calculation_time_ms: calculation_time,
323            });
324        }
325
326        let strength_increment = self.consolidation_strength_formula(recall_interval_hours)?;
328
329        let new_strength = (current_strength + strength_increment)
331            .min(self.config.max_consolidation_strength)
332            .max(self.config.min_consolidation_strength);
333
334        let calculation_time = start_time.elapsed().as_millis() as u64;
335
336        if calculation_time > self.config.performance_target_ms {
338            return Err(MathEngineError::PerformanceError {
339                duration_ms: calculation_time,
340                target_ms: self.config.performance_target_ms,
341            });
342        }
343
344        Ok(ConsolidationUpdateResult {
345            new_consolidation_strength: new_strength,
346            strength_increment,
347            recall_interval_hours,
348            calculation_time_ms: calculation_time,
349        })
350    }
351
352    pub fn calculate_decay_rate(&self, params: &MemoryParameters) -> Result<f64> {
364        if params.access_count < 0 {
366            return Err(MathEngineError::InvalidParameter {
367                parameter: "access_count".to_string(),
368                value: params.access_count as f64,
369                constraint: "access_count >= 0".to_string(),
370            });
371        }
372
373        if !(0.0..=1.0).contains(¶ms.importance_score) {
374            return Err(MathEngineError::InvalidParameter {
375                parameter: "importance_score".to_string(),
376                value: params.importance_score,
377                constraint: "0.0 <= importance_score <= 1.0".to_string(),
378            });
379        }
380
381        let mut decay_rate = constants::DEFAULT_DECAY_RATE;
383
384        let access_factor = if params.access_count > 0 {
386            1.0 / (1.0 + (params.access_count as f64).ln())
387        } else {
388            1.0
389        };
390
391        let importance_factor = 1.0 - (params.importance_score * 0.5);
393
394        let age_days = (Utc::now() - params.created_at).num_days() as f64;
396        let age_factor = if age_days > 0.0 {
397            1.0 + (age_days / 30.0).min(2.0) } else {
399            1.0
400        };
401
402        decay_rate *= access_factor * importance_factor * age_factor;
404
405        Ok(decay_rate.max(0.1).min(5.0))
407    }
408
409    pub fn batch_calculate_recall_probability(
420        &self,
421        memory_params: &[MemoryParameters],
422    ) -> Result<BatchProcessingResult> {
423        if !self.config.enable_batch_processing {
424            return Err(MathEngineError::BatchProcessingError {
425                message: "Batch processing is disabled".to_string(),
426            });
427        }
428
429        let start_time = Instant::now();
430        let mut results = Vec::with_capacity(memory_params.len());
431        let mut errors = Vec::new();
432
433        for (index, params) in memory_params.iter().enumerate() {
434            match self.calculate_recall_probability(params) {
435                Ok(result) => results.push(result),
436                Err(error) => {
437                    errors.push((index, error));
438                    results.push(RecallCalculationResult {
440                        recall_probability: 0.0,
441                        time_since_access_hours: 0.0,
442                        normalized_time: 0.0,
443                        calculation_time_ms: 0,
444                    });
445                }
446            }
447        }
448
449        let total_time = start_time.elapsed().as_millis() as u64;
450        let average_time = if !results.is_empty() {
451            total_time as f64 / results.len() as f64
452        } else {
453            0.0
454        };
455
456        Ok(BatchProcessingResult {
457            processed_count: memory_params.len(),
458            total_time_ms: total_time,
459            average_time_per_memory_ms: average_time,
460            results,
461            errors,
462        })
463    }
464
465    pub fn should_migrate(&self, recall_probability: f64, current_tier: &str) -> bool {
474        match current_tier.to_lowercase().as_str() {
475            "working" => recall_probability < 0.7,
476            "warm" => recall_probability < self.config.cold_threshold,
477            "cold" => recall_probability < self.config.frozen_threshold,
478            "frozen" => false,
479            _ => false,
480        }
481    }
482
483    pub fn validate_accuracy(&self, expected: f64, actual: f64) -> Result<()> {
492        let difference = (expected - actual).abs();
493        if difference > self.config.tolerance {
494            return Err(MathEngineError::AccuracyError {
495                expected,
496                actual,
497                tolerance: self.config.tolerance,
498            });
499        }
500        Ok(())
501    }
502
503    fn forgetting_curve_formula(&self, normalized_time: f64, decay_rate: f64) -> Result<f64> {
508        if normalized_time < 0.0 {
510            return Err(MathEngineError::InvalidParameter {
511                parameter: "normalized_time".to_string(),
512                value: normalized_time,
513                constraint: "normalized_time >= 0.0".to_string(),
514            });
515        }
516
517        if decay_rate <= 0.0 {
518            return Err(MathEngineError::InvalidParameter {
519                parameter: "decay_rate".to_string(),
520                value: decay_rate,
521                constraint: "decay_rate > 0.0".to_string(),
522            });
523        }
524
525        let exp_neg_t = (-normalized_time).exp();
527        if !exp_neg_t.is_finite() {
528            return Err(MathEngineError::MathematicalOverflow {
529                operation: "exp(-t) calculation".to_string(),
530            });
531        }
532
533        let exponent = -decay_rate * exp_neg_t;
534        if !exponent.is_finite() {
535            return Err(MathEngineError::MathematicalOverflow {
536                operation: "-r * e^(-t) calculation".to_string(),
537            });
538        }
539
540        let numerator = 1.0 - exponent.exp();
541        let denominator = 1.0 - (-1.0_f64).exp();
542
543        if !numerator.is_finite() || !denominator.is_finite() || denominator.abs() < f64::EPSILON {
544            return Err(MathEngineError::MathematicalOverflow {
545                operation: "forgetting curve probability calculation".to_string(),
546            });
547        }
548
549        let probability = numerator / denominator;
550
551        Ok(probability.max(0.0).min(1.0))
553    }
554
555    fn consolidation_strength_formula(&self, time_hours: f64) -> Result<f64> {
558        if time_hours < 0.0 {
559            return Err(MathEngineError::InvalidParameter {
560                parameter: "time_hours".to_string(),
561                value: time_hours,
562                constraint: "time_hours >= 0.0".to_string(),
563            });
564        }
565
566        let exp_neg_t = (-time_hours).exp();
567        if !exp_neg_t.is_finite() {
568            return Err(MathEngineError::MathematicalOverflow {
569                operation: "exp(-t) in consolidation formula".to_string(),
570            });
571        }
572
573        let numerator = 1.0 - exp_neg_t;
574        let denominator = 1.0 + exp_neg_t;
575
576        if denominator.abs() < f64::EPSILON {
577            return Err(MathEngineError::MathematicalOverflow {
578                operation: "division by zero in consolidation formula".to_string(),
579            });
580        }
581
582        Ok(numerator / denominator)
583    }
584
585    fn calculate_new_memory_probability(
587        &self,
588        time_since_creation: f64,
589        params: &MemoryParameters,
590    ) -> Result<f64> {
591        let adjusted_consolidation = params.consolidation_strength * params.importance_score;
594        let normalized_time = time_since_creation / adjusted_consolidation.max(0.1);
595        self.forgetting_curve_formula(normalized_time, params.decay_rate)
596    }
597
598    fn validate_parameters(&self, params: &MemoryParameters) -> Result<()> {
600        if params.consolidation_strength < 0.0 {
601            return Err(MathEngineError::InvalidParameter {
602                parameter: "consolidation_strength".to_string(),
603                value: params.consolidation_strength,
604                constraint: "consolidation_strength >= 0.0".to_string(),
605            });
606        }
607
608        if params.decay_rate <= 0.0 {
609            return Err(MathEngineError::InvalidParameter {
610                parameter: "decay_rate".to_string(),
611                value: params.decay_rate,
612                constraint: "decay_rate > 0.0".to_string(),
613            });
614        }
615
616        if !(0.0..=1.0).contains(¶ms.importance_score) {
617            return Err(MathEngineError::InvalidParameter {
618                parameter: "importance_score".to_string(),
619                value: params.importance_score,
620                constraint: "0.0 <= importance_score <= 1.0".to_string(),
621            });
622        }
623
624        Ok(())
625    }
626}
627
628impl Default for MathEngine {
629    fn default() -> Self {
630        Self::new()
631    }
632}
633
634pub mod benchmarks {
636    use super::*;
637    use std::time::Instant;
638
639    pub fn benchmark_single_calculation(
641        engine: &MathEngine,
642        params: &MemoryParameters,
643        iterations: usize,
644    ) -> (f64, f64, f64) {
645        let mut times = Vec::with_capacity(iterations);
646
647        for _ in 0..iterations {
648            let start = Instant::now();
649            let _ = engine.calculate_recall_probability(params);
650            times.push(start.elapsed().as_nanos() as f64 / 1_000_000.0); }
652
653        let sum: f64 = times.iter().sum();
654        let avg = sum / times.len() as f64;
655
656        times.sort_by(|a, b| a.partial_cmp(b).unwrap());
657        let median = if times.len() % 2 == 0 {
658            (times[times.len() / 2 - 1] + times[times.len() / 2]) / 2.0
659        } else {
660            times[times.len() / 2]
661        };
662
663        let p99_index = ((times.len() as f64) * 0.99) as usize;
664        let p99 = times[p99_index.min(times.len() - 1)];
665
666        (avg, median, p99)
667    }
668
669    pub fn benchmark_batch_processing(
671        engine: &MathEngine,
672        batch_sizes: &[usize],
673    ) -> Vec<(usize, f64, f64)> {
674        let mut results = Vec::new();
675
676        for &batch_size in batch_sizes {
677            let params = vec![
678                MemoryParameters {
679                    consolidation_strength: 1.0,
680                    decay_rate: 1.0,
681                    last_accessed_at: Some(Utc::now() - chrono::Duration::hours(1)),
682                    created_at: Utc::now() - chrono::Duration::days(1),
683                    access_count: 5,
684                    importance_score: 0.5,
685                };
686                batch_size
687            ];
688
689            let start = Instant::now();
690            let result = engine.batch_calculate_recall_probability(¶ms);
691            let total_time = start.elapsed().as_millis() as f64;
692
693            if let Ok(_batch_result) = result {
694                let throughput = batch_size as f64 / (total_time / 1000.0); results.push((batch_size, total_time, throughput));
696            }
697        }
698
699        results
700    }
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use proptest::prelude::*;

    /// Unit strength and decay, accessed one hour ago, created one day ago.
    fn create_test_params() -> MemoryParameters {
        MemoryParameters {
            consolidation_strength: 1.0,
            decay_rate: 1.0,
            last_accessed_at: Some(Utc::now() - Duration::hours(1)),
            created_at: Utc::now() - Duration::days(1),
            access_count: 5,
            importance_score: 0.5,
        }
    }

    #[test]
    fn test_recall_probability_calculation() {
        let engine = MathEngine::new();
        let params = create_test_params();

        let result = engine
            .calculate_recall_probability(&params)
            .expect("valid parameters should yield a recall probability");

        assert!(result.recall_probability >= 0.0);
        assert!(result.recall_probability <= 1.0);
        assert!(result.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
    }

    #[test]
    fn test_consolidation_strength_update() {
        let engine = MathEngine::new();
        let interval = PgInterval {
            months: 0,
            days: 0,
            microseconds: (2.0 * constants::MICROSECONDS_PER_HOUR) as i64,
        };

        let result = engine
            .update_consolidation_strength(1.0, interval)
            .expect("a two-hour interval should update the strength");

        assert!(result.new_consolidation_strength > 1.0);
        assert!(result.new_consolidation_strength <= constants::MAX_CONSOLIDATION_STRENGTH);
        assert!(result.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
    }

    #[test]
    fn test_decay_rate_calculation() {
        let engine = MathEngine::new();
        let params = create_test_params();

        let decay_rate = engine
            .calculate_decay_rate(&params)
            .expect("valid parameters should yield a decay rate");

        assert!(decay_rate > 0.0);
        assert!(decay_rate <= 5.0);
    }

    #[test]
    fn test_edge_case_never_accessed() {
        let engine = MathEngine::new();
        let mut params = create_test_params();
        params.last_accessed_at = None;

        let result = engine
            .calculate_recall_probability(&params)
            .expect("a never-accessed memory should still yield a probability");

        assert!(result.recall_probability >= 0.0);
        assert!(result.recall_probability <= 1.0);
    }

    #[test]
    fn test_edge_case_very_recent_access() {
        let engine = MathEngine::new();
        let mut params = create_test_params();
        params.last_accessed_at = Some(Utc::now() - Duration::seconds(30));

        let result = engine
            .calculate_recall_probability(&params)
            .expect("a recent access should yield a probability");

        assert!(
            result.recall_probability > 0.99,
            "Very recent access should have >99% recall probability, got {}",
            result.recall_probability
        );
        assert!(result.recall_probability <= 1.0);
    }

    #[test]
    fn test_batch_processing() {
        let engine = MathEngine::new();
        let params = vec![create_test_params(); 100];

        let result = engine
            .batch_calculate_recall_probability(&params)
            .expect("batch processing is enabled by default");

        assert_eq!(result.processed_count, 100);
        assert_eq!(result.results.len(), 100);
        assert!(result.average_time_per_memory_ms < constants::MAX_CALCULATION_TIME_MS as f64);
    }

    #[test]
    fn test_accuracy_validation() {
        let engine = MathEngine::new();

        assert!(engine.validate_accuracy(0.5, 0.5001).is_ok());
        assert!(engine.validate_accuracy(0.5, 0.6).is_err());
    }

    proptest! {
        #[test]
        fn test_recall_probability_properties(
            consolidation_strength in 0.1f64..10.0,
            decay_rate in 0.1f64..5.0,
            hours_ago in 0.1f64..168.0,
            importance_score in 0.0f64..1.0,
            access_count in 0i32..1000,
        ) {
            let engine = MathEngine::new();
            let params = MemoryParameters {
                consolidation_strength,
                decay_rate,
                last_accessed_at: Some(Utc::now() - Duration::seconds((hours_ago * 3600.0) as i64)),
                created_at: Utc::now() - Duration::days(1),
                access_count,
                importance_score,
            };

            // Runs that return Err (e.g. a PerformanceError on a loaded machine) are skipped.
            if let Ok(calculation) = engine.calculate_recall_probability(&params) {
                prop_assert!(calculation.recall_probability >= 0.0);
                prop_assert!(calculation.recall_probability <= 1.0);
                prop_assert!(calculation.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
            }
        }

        #[test]
        fn test_consolidation_strength_properties(
            initial_strength in 0.1f64..10.0,
            recall_interval_hours in 0.1f64..168.0,
        ) {
            let engine = MathEngine::new();
            let interval = PgInterval {
                months: 0,
                days: 0,
                microseconds: (recall_interval_hours * constants::MICROSECONDS_PER_HOUR) as i64,
            };

            if let Ok(update) = engine.update_consolidation_strength(initial_strength, interval) {
                prop_assert!(update.new_consolidation_strength >= initial_strength);
                prop_assert!(update.new_consolidation_strength <= constants::MAX_CONSOLIDATION_STRENGTH);
                prop_assert!(update.calculation_time_ms <= constants::MAX_CALCULATION_TIME_MS);
            }
        }
    }
}