use std::collections::HashMap;
use std::time::{Duration, Instant};
use crate::core::{Expression, Number, BinaryOperator, UnaryOperator};
use super::{ComputeError, ComputeEngine};
use num_bigint::BigInt;
use num_traits::Signed;
/// Tunable limits applied before an expression is handed to a compute engine.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Score above which an expression is considered too complex to evaluate.
    pub max_complexity_threshold: usize,
    /// Time budget for a single simplification, in milliseconds.
    pub max_compute_time_ms: u64,
    /// Upper bound for integer exponents appearing in powers.
    pub max_exponent: i64,
    /// Maximum length of an integer's decimal string before it is treated as unbounded.
    pub max_integer_digits: usize,
    /// Whether the complexity score check is performed at all.
    pub enable_complexity_check: bool,
    /// Whether the compute-time check is performed at all.
    pub enable_time_limit: bool,
}

impl Default for RuntimeConfig {
    /// Conservative defaults: both checks enabled, exponents capped at 10 000,
    /// integers capped at 1 000 characters, and a 5 second time budget.
    fn default() -> Self {
        RuntimeConfig {
            max_complexity_threshold: 10_000,
            max_compute_time_ms: 5_000,
            max_exponent: 10_000,
            max_integer_digits: 1_000,
            enable_complexity_check: true,
            enable_time_limit: true,
        }
    }
}
pub struct ComplexityAnalyzer {
config: RuntimeConfig,
}
impl ComplexityAnalyzer {
pub fn new(config: RuntimeConfig) -> Self {
Self { config }
}
pub fn calculate_complexity(&self, expr: &Expression) -> usize {
match expr {
Expression::Number(num) => self.number_complexity(num),
Expression::Variable(_) => 1,
Expression::Constant(_) => 1,
Expression::BinaryOp { op, left, right } => {
let left_complexity = self.calculate_complexity(left);
let right_complexity = self.calculate_complexity(right);
match op {
BinaryOperator::Power => {
let base_complexity = left_complexity;
let exp_complexity = right_complexity;
if let Expression::Number(Number::Integer(exp_int)) = right.as_ref() {
if exp_int.bits() > 64 || exp_int > &BigInt::from(self.config.max_exponent) {
return usize::MAX; }
}
base_complexity * exp_complexity * 10 + 100
}
BinaryOperator::Multiply | BinaryOperator::Divide => {
left_complexity + right_complexity + 5
}
_ => left_complexity + right_complexity + 1
}
}
Expression::UnaryOp { op, operand } => {
let operand_complexity = self.calculate_complexity(operand);
match op {
UnaryOperator::Factorial => operand_complexity * 20 + 50,
UnaryOperator::Exp => operand_complexity * 10 + 20,
_ => operand_complexity + 2
}
}
Expression::Function { args, .. } => {
args.iter().map(|arg| self.calculate_complexity(arg)).sum::<usize>() + 10
}
Expression::Matrix(rows) => {
rows.iter()
.flat_map(|row| row.iter())
.map(|elem| self.calculate_complexity(elem))
.sum::<usize>() + rows.len() * rows.get(0).map_or(0, |r| r.len()) * 2
}
Expression::Vector(elements) => {
elements.iter().map(|elem| self.calculate_complexity(elem)).sum::<usize>() + elements.len()
}
Expression::Set(elements) => {
elements.iter().map(|elem| self.calculate_complexity(elem)).sum::<usize>() + elements.len()
}
Expression::Interval { start, end, .. } => {
self.calculate_complexity(start) + self.calculate_complexity(end) + 2
}
}
}
fn number_complexity(&self, num: &Number) -> usize {
match num {
Number::Integer(int) => {
let digits = int.to_string().len();
if digits > self.config.max_integer_digits {
usize::MAX } else {
(digits / 10).max(1)
}
}
Number::Rational(rat) => {
let numer_digits = rat.numer().to_string().len();
let denom_digits = rat.denom().to_string().len();
let total_digits = numer_digits + denom_digits;
if total_digits > self.config.max_integer_digits {
usize::MAX
} else {
(total_digits / 10).max(1)
}
}
Number::Real(_) => 5,
Number::Complex { real, imaginary } => {
self.number_complexity(real) + self.number_complexity(imaginary)
}
Number::Symbolic(_) => 10,
Number::Float(_) => 1,
Number::Constant(_) => 2, }
}
pub fn is_too_complex(&self, expr: &Expression) -> bool {
if !self.config.enable_complexity_check {
return false;
}
let complexity = self.calculate_complexity(expr);
complexity == usize::MAX || complexity > self.config.max_complexity_threshold
}
pub fn is_safe_power(&self, base: &Expression, exponent: &Expression) -> bool {
if let Expression::Number(Number::Integer(exp_int)) = exponent {
if exp_int > &BigInt::from(self.config.max_exponent) {
return false;
}
if let Expression::Number(Number::Integer(base_int)) = base {
let base_abs = base_int.abs();
if exp_int > &BigInt::from(100000) {
return false;
}
if base_abs >= BigInt::from(10) && exp_int > &BigInt::from(1000) {
return false;
}
if base_abs >= BigInt::from(2) && exp_int > &BigInt::from(10000) {
return false;
}
}
}
true
}
}
#[derive(Debug, Clone)]
pub struct VariableManager {
variables: HashMap<String, Expression>,
numeric_variables: HashMap<String, Number>,
}
impl VariableManager {
pub fn new() -> Self {
Self {
variables: HashMap::new(),
numeric_variables: HashMap::new(),
}
}
pub fn set_variable(&mut self, name: String, value: Expression) -> Result<(), ComputeError> {
if !self.is_valid_variable_name(&name) {
return Err(ComputeError::domain_error(
format!("无效的变量名: {}", name)
));
}
if let Expression::Number(num) = &value {
self.numeric_variables.insert(name.clone(), num.clone());
} else {
self.numeric_variables.remove(&name);
}
self.variables.insert(name, value);
Ok(())
}
pub fn get_variable(&self, name: &str) -> Option<&Expression> {
self.variables.get(name)
}
pub fn get_numeric_variable(&self, name: &str) -> Option<&Number> {
self.numeric_variables.get(name)
}
pub fn get_all_variables(&self) -> &HashMap<String, Expression> {
&self.variables
}
pub fn get_all_numeric_variables(&self) -> &HashMap<String, Number> {
&self.numeric_variables
}
pub fn clear(&mut self) {
self.variables.clear();
self.numeric_variables.clear();
}
pub fn remove_variable(&mut self, name: &str) -> bool {
let removed_expr = self.variables.remove(name).is_some();
let removed_num = self.numeric_variables.remove(name).is_some();
removed_expr || removed_num
}
pub fn substitute_variables(&self, expr: &Expression) -> Expression {
match expr {
Expression::Variable(name) => {
if let Some(value) = self.get_variable(name) {
self.substitute_variables(value)
} else {
expr.clone()
}
}
Expression::BinaryOp { op, left, right } => {
Expression::BinaryOp {
op: op.clone(),
left: Box::new(self.substitute_variables(left)),
right: Box::new(self.substitute_variables(right)),
}
}
Expression::UnaryOp { op, operand } => {
Expression::UnaryOp {
op: op.clone(),
operand: Box::new(self.substitute_variables(operand)),
}
}
Expression::Function { name, args } => {
Expression::Function {
name: name.clone(),
args: args.iter().map(|arg| self.substitute_variables(arg)).collect(),
}
}
Expression::Matrix(rows) => {
Expression::Matrix(
rows.iter()
.map(|row| row.iter().map(|elem| self.substitute_variables(elem)).collect())
.collect()
)
}
Expression::Vector(elements) => {
Expression::Vector(
elements.iter().map(|elem| self.substitute_variables(elem)).collect()
)
}
Expression::Set(elements) => {
Expression::Set(
elements.iter().map(|elem| self.substitute_variables(elem)).collect()
)
}
Expression::Interval { start, end, start_inclusive, end_inclusive } => {
Expression::Interval {
start: Box::new(self.substitute_variables(start)),
end: Box::new(self.substitute_variables(end)),
start_inclusive: *start_inclusive,
end_inclusive: *end_inclusive,
}
}
_ => expr.clone(),
}
}
fn is_valid_variable_name(&self, name: &str) -> bool {
if name.is_empty() {
return false;
}
let first_char = name.chars().next().unwrap();
if !first_char.is_alphabetic() && first_char != '_' {
return false;
}
name.chars().all(|c| c.is_alphanumeric() || c == '_')
}
pub fn has_circular_reference(&self, name: &str) -> bool {
let mut visited = std::collections::HashSet::new();
self.check_circular_reference_recursive(name, &mut visited)
}
fn check_circular_reference_recursive(&self, name: &str, visited: &mut std::collections::HashSet<String>) -> bool {
if visited.contains(name) {
return true; }
if let Some(expr) = self.get_variable(name) {
visited.insert(name.to_string());
let has_cycle = self.check_expression_for_circular_reference(expr, visited);
visited.remove(name);
has_cycle
} else {
false
}
}
fn check_expression_for_circular_reference(&self, expr: &Expression, visited: &mut std::collections::HashSet<String>) -> bool {
match expr {
Expression::Variable(var_name) => {
self.check_circular_reference_recursive(var_name, visited)
}
Expression::BinaryOp { left, right, .. } => {
self.check_expression_for_circular_reference(left, visited) ||
self.check_expression_for_circular_reference(right, visited)
}
Expression::UnaryOp { operand, .. } => {
self.check_expression_for_circular_reference(operand, visited)
}
Expression::Function { args, .. } => {
args.iter().any(|arg| self.check_expression_for_circular_reference(arg, visited))
}
Expression::Matrix(rows) => {
rows.iter().any(|row|
row.iter().any(|elem| self.check_expression_for_circular_reference(elem, visited))
)
}
Expression::Vector(elements) => {
elements.iter().any(|elem| self.check_expression_for_circular_reference(elem, visited))
}
Expression::Set(elements) => {
elements.iter().any(|elem| self.check_expression_for_circular_reference(elem, visited))
}
Expression::Interval { start, end, .. } => {
self.check_expression_for_circular_reference(start, visited) ||
self.check_expression_for_circular_reference(end, visited)
}
_ => false,
}
}
}
pub struct RuntimeEnhancer {
complexity_analyzer: ComplexityAnalyzer,
variable_manager: VariableManager,
config: RuntimeConfig,
}
impl RuntimeEnhancer {
pub fn new(config: RuntimeConfig) -> Self {
Self {
complexity_analyzer: ComplexityAnalyzer::new(config.clone()),
variable_manager: VariableManager::new(),
config,
}
}
pub fn variable_manager_mut(&mut self) -> &mut VariableManager {
&mut self.variable_manager
}
pub fn variable_manager(&self) -> &VariableManager {
&self.variable_manager
}
pub fn safe_compute<E: ComputeEngine>(&self, expr: &Expression, engine: &E) -> Result<Expression, ComputeError> {
let substituted = self.variable_manager.substitute_variables(expr);
if self.complexity_analyzer.is_too_complex(&substituted) {
return Ok(substituted); }
if let Some(safe_expr) = self.check_and_handle_special_cases(&substituted)? {
return Ok(safe_expr);
}
if self.config.enable_time_limit {
self.compute_with_timeout(&substituted, engine)
} else {
engine.simplify(&substituted)
}
}
fn check_and_handle_special_cases(&self, expr: &Expression) -> Result<Option<Expression>, ComputeError> {
match expr {
Expression::BinaryOp { op: BinaryOperator::Power, left, right } => {
if !self.complexity_analyzer.is_safe_power(left, right) {
return Ok(Some(expr.clone()));
}
}
Expression::UnaryOp { op: UnaryOperator::Factorial, operand } => {
if let Expression::Number(Number::Integer(n)) = operand.as_ref() {
if n > &BigInt::from(1000) {
return Ok(Some(expr.clone()));
}
}
}
_ => {}
}
Ok(None)
}
fn compute_with_timeout<E: ComputeEngine>(&self, expr: &Expression, engine: &E) -> Result<Expression, ComputeError> {
let start_time = Instant::now();
let timeout = Duration::from_millis(self.config.max_compute_time_ms);
let result = engine.simplify(expr)?;
if start_time.elapsed() > timeout {
Ok(expr.clone())
} else {
Ok(result)
}
}
pub fn update_config(&mut self, config: RuntimeConfig) {
self.config = config.clone();
self.complexity_analyzer = ComplexityAnalyzer::new(config);
}
pub fn get_config(&self) -> &RuntimeConfig {
&self.config
}
}
impl Default for VariableManager {
fn default() -> Self {
Self::new()
}
}
impl Default for RuntimeEnhancer {
fn default() -> Self {
Self::new(RuntimeConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Expression;

    /// Basic scores, plus the analyzer's own rejection of an over-cap exponent.
    #[test]
    fn test_complexity_analyzer() {
        let config = RuntimeConfig::default();
        let analyzer = ComplexityAnalyzer::new(config);
        let simple_expr = Expression::variable("x");
        assert_eq!(analyzer.calculate_complexity(&simple_expr), 1);
        // `10^10000` sits exactly at the default `max_exponent` (10000), so the
        // analyzer scores it as cheap (1 * 1 * 10 + 100) and the old assertion
        // failed; that case is rejected by `is_safe_power` instead (see
        // `test_safe_power_check`). The analyzer itself only flags exponents
        // strictly above the cap.
        let complex_expr = Expression::binary_op(
            BinaryOperator::Power,
            Expression::number(Number::from(10)),
            Expression::number(Number::from(10001))
        );
        assert!(analyzer.is_too_complex(&complex_expr));
    }

    /// A numeric variable is stored and substituted into an expression.
    #[test]
    fn test_variable_manager() {
        let mut manager = VariableManager::new();
        let x_value = Expression::number(Number::from(10));
        manager.set_variable("x".to_string(), x_value.clone()).unwrap();
        assert_eq!(manager.get_variable("x"), Some(&x_value));
        let expr = Expression::binary_op(
            BinaryOperator::Add,
            Expression::variable("x"),
            Expression::number(Number::from(5))
        );
        let substituted = manager.substitute_variables(&expr);
        if let Expression::BinaryOp { left, .. } = substituted {
            match left.as_ref() {
                Expression::Number(_) => {
                    println!("变量替换成功");
                }
                Expression::Variable(_) => {
                    panic!("变量替换失败");
                }
                _ => {
                    panic!("意外的表达式类型");
                }
            }
        } else {
            panic!("Expected binary operation");
        }
    }

    /// A two-variable cycle is reported from either end.
    #[test]
    fn test_circular_reference_detection() {
        let mut manager = VariableManager::new();
        manager.set_variable("x".to_string(), Expression::variable("y")).unwrap();
        manager.set_variable("y".to_string(), Expression::variable("x")).unwrap();
        assert!(manager.has_circular_reference("x"));
        assert!(manager.has_circular_reference("y"));
    }

    /// Small powers are safe; a large base with a large exponent is not.
    #[test]
    fn test_safe_power_check() {
        let config = RuntimeConfig::default();
        let analyzer = ComplexityAnalyzer::new(config);
        let safe_base = Expression::number(Number::from(2));
        let safe_exp = Expression::number(Number::from(10));
        assert!(analyzer.is_safe_power(&safe_base, &safe_exp));
        let unsafe_base = Expression::number(Number::from(10));
        let unsafe_exp = Expression::number(Number::from(10000));
        assert!(!analyzer.is_safe_power(&unsafe_base, &unsafe_exp));
    }
}