use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::core::{Expression, Number, MathConstant, BinaryOperator};
use crate::api::CacheConfig;
use super::{ComputeEngine, ComputeError};
use super::compute::BasicComputeEngine;
use super::cache::{CacheManager, FastCacheKey, ExactCacheKey, SymbolicCacheKey};
use num_bigint::BigInt;
/// Compute engine that pairs a [`BasicComputeEngine`] with a shared cache manager.
///
/// The cache manager sits behind `Arc<Mutex<_>>`, so every cache access made by
/// this type first acquires that mutex.
pub struct CachedComputeEngine {
    /// Engine that performs the actual computations.
    base_engine: BasicComputeEngine,
    /// Mutex-guarded cache manager consulted before, and updated after, computations.
    cache_manager: Arc<Mutex<CacheManager>>,
}
impl CachedComputeEngine {
pub fn new(cache_config: CacheConfig) -> Self {
Self {
base_engine: BasicComputeEngine::new(),
cache_manager: Arc::new(Mutex::new(CacheManager::new(cache_config))),
}
}
pub fn get_cache_stats(&self) -> Result<super::cache::CacheStats, ComputeError> {
self.cache_manager.lock()
.map_err(|_| ComputeError::internal("无法获取缓存管理器锁"))?
.cache()
.get_stats()
.pipe(Ok)
}
pub fn get_cache_usage(&self) -> Result<super::cache::CacheUsageInfo, ComputeError> {
self.cache_manager.lock()
.map_err(|_| ComputeError::internal("无法获取缓存管理器锁"))?
.cache()
.get_usage_info()
.pipe(Ok)
}
pub fn cleanup_cache(&self) -> Result<(), ComputeError> {
self.cache_manager.lock()
.map_err(|_| ComputeError::internal("无法获取缓存管理器锁"))?
.force_cleanup();
Ok(())
}
pub fn clear_cache(&self) -> Result<(), ComputeError> {
self.cache_manager.lock()
.map_err(|_| ComputeError::internal("无法获取缓存管理器锁"))?
.cache()
.clear_all();
Ok(())
}
fn try_fast_binary_op(&self, left: &Number, right: &Number, op: &BinaryOperator) -> Option<Number> {
if let (Number::Integer(l), Number::Integer(r)) = (left, right) {
if let (Ok(l_i64), Ok(r_i64)) = (l.try_into(), r.try_into()) {
let l_i64: i64 = l_i64;
let r_i64: i64 = r_i64;
if l_i64.abs() < 1_000_000 && r_i64.abs() < 1_000_000 {
let key = FastCacheKey::BinaryOp(l_i64, r_i64, op.clone());
if let Ok(cache_manager) = self.cache_manager.lock() {
if let Some(result) = cache_manager.cache().get_fast(&key) {
return Some(Number::Integer(BigInt::from(result)));
}
}
}
}
}
None
}
fn cache_fast_binary_op(&self, left: &Number, right: &Number, op: &BinaryOperator, result: &Number) {
if let (Number::Integer(l), Number::Integer(r), Number::Integer(res)) = (left, right, result) {
if let (Ok(l_i64), Ok(r_i64), Ok(res_i64)) = (l.try_into(), r.try_into(), res.try_into()) {
let l_i64: i64 = l_i64;
let r_i64: i64 = r_i64;
let res_i64: i64 = res_i64;
if l_i64.abs() < 1_000_000 && r_i64.abs() < 1_000_000 && res_i64.abs() < 10_000_000 {
let key = FastCacheKey::BinaryOp(l_i64, r_i64, op.clone());
if let Ok(cache_manager) = self.cache_manager.lock() {
let cost = match op {
BinaryOperator::Add | BinaryOperator::Subtract => 1,
BinaryOperator::Multiply => 2,
BinaryOperator::Divide => 5,
BinaryOperator::Power => 10,
_ => 3,
};
cache_manager.cache().put_fast(key, res_i64, cost);
}
}
}
}
}
fn try_exact_cache(&self, operand1: &Number, operand2: Option<&Number>, operation: &str) -> Option<Number> {
let key = ExactCacheKey {
operand1: operand1.clone(),
operand2: operand2.cloned(),
operation: operation.to_string(),
};
if let Ok(cache_manager) = self.cache_manager.lock() {
cache_manager.cache().get_exact(&key)
} else {
None
}
}
fn cache_exact_result(&self, operand1: &Number, operand2: Option<&Number>, operation: &str, result: &Number, cost: u32) {
let key = ExactCacheKey {
operand1: operand1.clone(),
operand2: operand2.cloned(),
operation: operation.to_string(),
};
if let Ok(cache_manager) = self.cache_manager.lock() {
cache_manager.cache().put_exact(key, result.clone(), cost);
}
}
fn try_symbolic_cache(&self, expr: &Expression, operation: &str, variable: Option<&str>) -> Option<Expression> {
let key = SymbolicCacheKey {
expression: expr.clone(),
operation: operation.to_string(),
variable: variable.map(|s| s.to_string()),
};
if let Ok(cache_manager) = self.cache_manager.lock() {
cache_manager.cache().get_symbolic(&key)
} else {
None
}
}
fn cache_symbolic_result(&self, expr: &Expression, operation: &str, variable: Option<&str>, result: &Expression, cost: u32) {
let key = SymbolicCacheKey {
expression: expr.clone(),
operation: operation.to_string(),
variable: variable.map(|s| s.to_string()),
};
if let Ok(cache_manager) = self.cache_manager.lock() {
cache_manager.cache().put_symbolic(key, result.clone(), cost);
}
}
fn periodic_cleanup(&self) {
if let Ok(mut cache_manager) = self.cache_manager.lock() {
cache_manager.periodic_cleanup();
}
}
fn compute_complexity(&self, expr: &Expression) -> u32 {
match expr {
Expression::Number(_) | Expression::Variable(_) | Expression::Constant(_) => 1,
Expression::UnaryOp { operand, .. } => 1 + self.compute_complexity(operand),
Expression::BinaryOp { left, right, .. } => 1 + self.compute_complexity(left) + self.compute_complexity(right),
Expression::Function { args, .. } => 5 + args.iter().map(|arg| self.compute_complexity(arg)).sum::<u32>(),
Expression::Matrix(rows) => {
10 + rows.iter().flat_map(|row| row.iter()).map(|elem| self.compute_complexity(elem)).sum::<u32>()
}
Expression::Vector(elements) => {
5 + elements.iter().map(|elem| self.compute_complexity(elem)).sum::<u32>()
}
Expression::Set(elements) => {
3 + elements.iter().map(|elem| self.compute_complexity(elem)).sum::<u32>()
}
Expression::Interval { start, end, .. } => {
2 + self.compute_complexity(start) + self.compute_complexity(end)
}
}
}
}
impl CachedComputeEngine {
    /// Runs a symbolic operation through the symbolic cache.
    ///
    /// Triggers the periodic cleanup, returns the cached result for
    /// `(expr, operation, variable)` if one exists, and otherwise calls
    /// `compute`. A successful result is stored with a cost of the expression's
    /// complexity multiplied by `cost_factor`; errors are propagated uncached.
    fn with_symbolic_cache<F>(
        &self,
        expr: &Expression,
        operation: &str,
        variable: Option<&str>,
        cost_factor: u32,
        compute: F,
    ) -> Result<Expression, ComputeError>
    where
        F: FnOnce() -> Result<Expression, ComputeError>,
    {
        self.periodic_cleanup();
        if let Some(cached_result) = self.try_symbolic_cache(expr, operation, variable) {
            return Ok(cached_result);
        }
        let result = compute()?;
        let cost = self.compute_complexity(expr) * cost_factor;
        self.cache_symbolic_result(expr, operation, variable, &result, cost);
        Ok(result)
    }
}

/// `ComputeEngine` implementation that caches symbolic operations and literal
/// binary evaluations, and forwards everything else to the base engine.
impl ComputeEngine for CachedComputeEngine {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Simplifies `expr`, cached under the `"simplify"` operation.
    fn simplify(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.with_symbolic_cache(expr, "simplify", None, 1, || self.base_engine.simplify(expr))
    }

    /// Evaluates `expr`.
    ///
    /// Only a binary operation whose two operands are literal numbers goes
    /// through the caches: the fast integer cache is tried first, then the exact
    /// cache, and a freshly computed result is offered to both. Any other
    /// expression is forwarded to the base engine uncached.
    fn evaluate(&self, expr: &Expression, vars: &HashMap<String, Number>) -> Result<Number, ComputeError> {
        let (op, lhs, rhs) = match expr {
            Expression::BinaryOp { op, left, right } => match (left.as_ref(), right.as_ref()) {
                (Expression::Number(lhs), Expression::Number(rhs)) => (op, lhs, rhs),
                _ => return self.base_engine.evaluate(expr, vars),
            },
            _ => return self.base_engine.evaluate(expr, vars),
        };

        if let Some(cached_result) = self.try_fast_binary_op(lhs, rhs, op) {
            return Ok(cached_result);
        }

        let operation = format!("evaluate_{:?}", op);
        if let Some(cached_result) = self.try_exact_cache(lhs, Some(rhs), &operation) {
            return Ok(cached_result);
        }

        let result = self.base_engine.evaluate(expr, vars)?;
        self.cache_fast_binary_op(lhs, rhs, op, &result);
        self.cache_exact_result(lhs, Some(rhs), &operation, &result, 5);
        Ok(result)
    }

    /// Differentiates `expr` with respect to `var`; cached with cost factor 2.
    fn differentiate(&self, expr: &Expression, var: &str) -> Result<Expression, ComputeError> {
        self.with_symbolic_cache(expr, "differentiate", Some(var), 2, || {
            self.base_engine.differentiate(expr, var)
        })
    }

    /// Integrates `expr` with respect to `var`; cached with cost factor 5.
    fn integrate(&self, expr: &Expression, var: &str) -> Result<Expression, ComputeError> {
        self.with_symbolic_cache(expr, "integrate", Some(var), 5, || {
            self.base_engine.integrate(expr, var)
        })
    }

    /// Expands `expr`; cached with cost factor 3.
    fn expand(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.with_symbolic_cache(expr, "expand", None, 3, || self.base_engine.expand(expr))
    }

    /// Factors `expr`; cached with cost factor 4.
    fn factor(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.with_symbolic_cache(expr, "factor", None, 4, || self.base_engine.factor(expr))
    }

    /// Collects terms of `expr` in `var`; cached with cost factor 2.
    fn collect(&self, expr: &Expression, var: &str) -> Result<Expression, ComputeError> {
        self.with_symbolic_cache(expr, "collect", Some(var), 2, || {
            self.base_engine.collect(expr, var)
        })
    }

    // Everything below is forwarded to the base engine without caching.

    fn limit(&self, expr: &Expression, var: &str, point: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.limit(expr, var, point)
    }

    fn series(&self, expr: &Expression, var: &str, point: &Expression, order: usize) -> Result<Expression, ComputeError> {
        self.base_engine.series(expr, var, point, order)
    }

    fn numerical_evaluate(&self, expr: &Expression, vars: &HashMap<String, f64>) -> Result<f64, ComputeError> {
        self.base_engine.numerical_evaluate(expr, vars)
    }

    fn constant_to_number(&self, constant: &MathConstant) -> Result<Number, ComputeError> {
        self.base_engine.constant_to_number(constant)
    }

    fn simplify_constants(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.simplify_constants(expr)
    }

    fn polynomial_divide(&self, dividend: &Expression, divisor: &Expression) -> Result<(Expression, Expression), ComputeError> {
        self.base_engine.polynomial_divide(dividend, divisor)
    }

    fn polynomial_gcd(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.polynomial_gcd(a, b)
    }

    fn gcd(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.gcd(a, b)
    }

    fn lcm(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.lcm(a, b)
    }

    fn is_prime(&self, n: &Expression) -> Result<bool, ComputeError> {
        self.base_engine.is_prime(n)
    }

    fn prime_factors(&self, n: &Expression) -> Result<Vec<Expression>, ComputeError> {
        self.base_engine.prime_factors(n)
    }

    fn binomial(&self, n: &Expression, k: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.binomial(n, k)
    }

    fn permutation(&self, n: &Expression, k: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.permutation(n, k)
    }

    fn mean(&self, values: &[Expression]) -> Result<Expression, ComputeError> {
        self.base_engine.mean(values)
    }

    fn variance(&self, values: &[Expression]) -> Result<Expression, ComputeError> {
        self.base_engine.variance(values)
    }

    fn standard_deviation(&self, values: &[Expression]) -> Result<Expression, ComputeError> {
        self.base_engine.standard_deviation(values)
    }

    fn solve(&self, equation: &Expression, var: &str) -> Result<Vec<Expression>, ComputeError> {
        self.base_engine.solve(equation, var)
    }

    fn solve_system(&self, equations: &[Expression], vars: &[String]) -> Result<Vec<HashMap<String, Expression>>, ComputeError> {
        self.base_engine.solve_system(equations, vars)
    }

    fn matrix_add(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.matrix_add(a, b)
    }

    fn matrix_multiply(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.matrix_multiply(a, b)
    }

    fn matrix_determinant(&self, matrix: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.matrix_determinant(matrix)
    }

    fn matrix_inverse(&self, matrix: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.matrix_inverse(matrix)
    }

    fn complex_conjugate(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.complex_conjugate(expr)
    }

    fn complex_modulus(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.complex_modulus(expr)
    }

    fn complex_argument(&self, expr: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.complex_argument(expr)
    }

    fn vector_dot(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.vector_dot(a, b)
    }

    fn vector_cross(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.vector_cross(a, b)
    }

    fn vector_norm(&self, v: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.vector_norm(v)
    }

    fn set_union(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.set_union(a, b)
    }

    fn set_intersection(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.set_intersection(a, b)
    }

    fn set_difference(&self, a: &Expression, b: &Expression) -> Result<Expression, ComputeError> {
        self.base_engine.set_difference(a, b)
    }
}
/// Postfix function application: `value.pipe(f)` evaluates to `f(value)`.
///
/// Lets a function or constructor be applied at the end of a method chain.
trait Pipe<T> {
    /// Consumes `self`, hands it to `f`, and returns whatever `f` produces.
    fn pipe<U, F: FnOnce(T) -> U>(self, f: F) -> U;
}

/// Blanket implementation: every sized value can be piped into a function that
/// accepts it by value.
impl<T> Pipe<T> for T {
    fn pipe<U, F: FnOnce(T) -> U>(self, f: F) -> U {
        let value = self;
        f(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Expression, Number, BinaryOperator};
    use crate::api::CacheConfig;
    use num_bigint::BigInt;

    /// Engine built from the default cache configuration.
    fn default_engine() -> CachedComputeEngine {
        CachedComputeEngine::new(CacheConfig::default())
    }

    /// Integer literal expression.
    fn int(value: i32) -> Expression {
        Expression::Number(Number::Integer(BigInt::from(value)))
    }

    /// Variable expression named `name`.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.to_string())
    }

    /// Binary operation node over two sub-expressions.
    fn binary(op: BinaryOperator, left: Expression, right: Expression) -> Expression {
        Expression::BinaryOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// A new engine reports zero hit rate and zero usage.
    #[test]
    fn test_cached_engine_creation() {
        let engine = default_engine();

        let stats = engine.get_cache_stats().unwrap();
        assert_eq!(stats.total_hit_rate(), 0.0);

        let usage = engine.get_cache_usage().unwrap();
        assert_eq!(usage.total_usage_rate(), 0.0);
    }

    /// Evaluating the same literal addition twice produces a fast or exact hit.
    #[test]
    fn test_fast_cache_integration() {
        let engine = default_engine();
        let sum = binary(BinaryOperator::Add, int(2), int(3));
        let vars = HashMap::new();

        let first = engine.evaluate(&sum, &vars);
        let second = engine.evaluate(&sum, &vars);
        assert_eq!(first.is_ok(), second.is_ok());

        let stats = engine.get_cache_stats().unwrap();
        assert!(stats.fast_hits > 0 || stats.exact_hits > 0);
    }

    /// Simplifying the same expression twice produces a symbolic hit.
    #[test]
    fn test_symbolic_cache_integration() {
        let engine = default_engine();
        let doubled = binary(BinaryOperator::Add, var("x"), var("x"));

        let first = engine.simplify(&doubled);
        let second = engine.simplify(&doubled);
        assert_eq!(first.is_ok(), second.is_ok());

        let stats = engine.get_cache_stats().unwrap();
        assert!(stats.symbolic_hits > 0);
    }

    /// Cleanup and clear both succeed, and usage is zero afterwards.
    #[test]
    fn test_cache_cleanup() {
        let engine = default_engine();
        let _ = engine.simplify(&var("x"));

        assert!(engine.cleanup_cache().is_ok());
        assert!(engine.clear_cache().is_ok());

        let usage = engine.get_cache_usage().unwrap();
        assert_eq!(usage.total_usage_rate(), 0.0);
    }

    /// A leaf has complexity 1; a nested binary expression scores higher.
    #[test]
    fn test_complexity_calculation() {
        let engine = default_engine();

        let simple = var("x");
        assert_eq!(engine.compute_complexity(&simple), 1);

        let product = binary(BinaryOperator::Multiply, var("y"), int(2));
        let complex = binary(BinaryOperator::Add, var("x"), product);
        assert!(engine.compute_complexity(&complex) > 1);
    }
}