use super::{Expression, Number};
use std::rc::Rc;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
/// Snapshot of counters describing a [`MemoryManager`].
///
/// The values are refreshed by `MemoryManager::update_stats` (also called by
/// `MemoryManager::get_stats`); between refreshes they can lag behind the
/// manager's real state.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Value of the manager's creation counter, which `create_shared` increments
    /// on every call; nothing in this module decrements it.
    pub active_expressions: usize,
    /// Number of entries currently held in the deduplication pool.
    pub shared_expressions: usize,
    /// Lookups answered from the pool (`create_shared`) or hash cache (`get_hash`).
    pub cache_hits: usize,
    /// Lookups that had to fall back to fresh work.
    pub cache_misses: usize,
    /// Copy-on-write counter. NOTE(review): nothing in this module updates it.
    pub cow_triggers: usize,
    /// Rough size estimate of the manager's caches (not a measurement).
    pub estimated_memory_usage: usize,
    /// When this snapshot was last refreshed.
    pub last_updated: Instant,
}

impl Default for MemoryStats {
    /// All counters at zero, time-stamped with the current instant.
    fn default() -> Self {
        let now = Instant::now();
        MemoryStats {
            last_updated: now,
            active_expressions: 0,
            shared_expressions: 0,
            cache_hits: 0,
            cache_misses: 0,
            cow_triggers: 0,
            estimated_memory_usage: 0,
        }
    }
}
/// Tuning knobs for a [`MemoryManager`].
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// When `false`, `create_shared` bypasses the deduplication pool entirely.
    pub enable_sharing: bool,
    /// Copy-on-write toggle. NOTE(review): not read anywhere in this module.
    pub enable_cow: bool,
    /// Upper bound on entries inserted into the manager's hash cache.
    pub max_hash_cache_size: usize,
    /// Upper bound on entries inserted into the deduplication pool.
    pub max_expression_cache_size: usize,
    /// Estimated cache usage (approximate bytes) at or above which a stats
    /// refresh triggers cleanup.
    pub cleanup_threshold: usize,
    /// Time since the last cleanup after which a stats refresh triggers cleanup.
    pub cleanup_interval: Duration,
}

impl Default for MemoryConfig {
    /// Sharing and COW on, 10 000 hash-cache / 5 000 pool entries, 100 MiB
    /// cleanup threshold, and a 60-second cleanup interval.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        MemoryConfig {
            enable_sharing: true,
            enable_cow: true,
            max_hash_cache_size: 10_000,
            max_expression_cache_size: 5_000,
            cleanup_threshold: 100 * MIB,
            cleanup_interval: Duration::from_secs(60),
        }
    }
}
/// Reference-counted handle to an [`Expression`] with a lazily cached hash.
///
/// Cloning (or [`SharedExpression::clone_shared`]) only bumps the reference
/// count. Mutable access goes through [`SharedExpression::make_mut`] or
/// [`SharedExpression::get_mut`], both of which invalidate the cached hash.
#[derive(Debug, Clone)]
pub struct SharedExpression {
    inner: Rc<Expression>,
    /// Structural hash, filled by `get_hash` and cleared whenever mutable
    /// access is handed out.
    hash: Option<u64>,
}

impl SharedExpression {
    /// Wraps `expr` in a fresh, uniquely owned allocation.
    pub fn new(expr: Expression) -> Self {
        Self::from_rc(Rc::new(expr))
    }

    /// Wraps an existing `Rc`, sharing its allocation; the hash is computed lazily.
    pub fn from_rc(rc: Rc<Expression>) -> Self {
        Self {
            inner: rc,
            hash: None,
        }
    }

    /// Borrows the wrapped expression.
    pub fn as_ref(&self) -> &Expression {
        &self.inner
    }

    /// Number of strong references to the allocation, including this handle.
    pub fn ref_count(&self) -> usize {
        Rc::strong_count(&self.inner)
    }

    /// `true` when no other strong reference shares the allocation.
    pub fn is_unique(&self) -> bool {
        Rc::strong_count(&self.inner) == 1
    }

    /// Returns a mutable reference, cloning the expression first if the
    /// allocation is shared with other handles.
    ///
    /// The cached hash is always invalidated because the caller may mutate
    /// the expression through the returned reference — including when this
    /// handle was already unique.
    pub fn make_mut(&mut self) -> &mut Expression {
        self.hash = None;
        // `Rc::make_mut` clones when other strong refs exist and detaches any
        // outstanding `Weak` refs, so unlike `Rc::get_mut(..).unwrap()` it
        // cannot panic when a `Weak` to this allocation is alive.
        Rc::make_mut(&mut self.inner)
    }

    /// Returns a mutable reference only if this handle is the sole owner
    /// (no other strong or weak references); otherwise `None`.
    ///
    /// The cached hash is invalidated only when a reference is handed out.
    pub fn get_mut(&mut self) -> Option<&mut Expression> {
        let expr = Rc::get_mut(&mut self.inner)?;
        self.hash = None;
        Some(expr)
    }

    /// Returns the structural hash, computing and caching it on first use.
    pub fn get_hash(&mut self) -> u64 {
        if let Some(hash) = self.hash {
            return hash;
        }
        let hash = calculate_expression_hash(&self.inner);
        self.hash = Some(hash);
        hash
    }

    /// Another handle to the same allocation, carrying over the cached hash.
    pub fn clone_shared(&self) -> Self {
        Self {
            inner: Rc::clone(&self.inner),
            hash: self.hash,
        }
    }

    /// Extracts the expression, moving it out when this is the only strong
    /// reference and cloning it otherwise.
    pub fn into_owned(self) -> Expression {
        Rc::try_unwrap(self.inner).unwrap_or_else(|rc| (*rc).clone())
    }
}

impl PartialEq for SharedExpression {
    /// Handles to the same allocation are equal without inspecting contents;
    /// otherwise the wrapped expressions are compared structurally.
    fn eq(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.inner, &other.inner) || *self.inner == *other.inner
    }
}

impl Hash for SharedExpression {
    /// Feeds the structural hash (cached value if present) into `state`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.hash
            .unwrap_or_else(|| calculate_expression_hash(&self.inner))
            .hash(state);
    }
}
/// Deduplicating factory for [`SharedExpression`]s plus a hash cache and stats.
///
/// Structurally equal expressions passed to [`MemoryManager::create_shared`]
/// are served from a pool keyed by their structural hash, so they share one
/// allocation. Usage counters are kept in a [`MemoryStats`].
pub struct MemoryManager {
    config: MemoryConfig,
    stats: MemoryStats,
    /// Address -> hash cache used by `get_hash`. An entry is only trusted while
    /// its address belongs to an `Rc` held in `expression_pool`.
    hash_cache: HashMap<*const Expression, u64>,
    /// Structural hash -> canonical shared instance.
    expression_pool: HashMap<u64, Rc<Expression>>,
    /// Number of `create_shared` calls since construction or the last `clear_all`.
    expression_counter: AtomicUsize,
    /// When `cleanup` last ran (or the manager was created / cleared).
    last_cleanup: Instant,
}

impl MemoryManager {
    /// Creates a manager using [`MemoryConfig::default`].
    pub fn new() -> Self {
        Self::with_config(MemoryConfig::default())
    }

    /// Creates a manager with the given configuration and empty pool/caches.
    pub fn with_config(config: MemoryConfig) -> Self {
        Self {
            config,
            stats: MemoryStats::default(),
            hash_cache: HashMap::new(),
            expression_pool: HashMap::new(),
            expression_counter: AtomicUsize::new(0),
            last_cleanup: Instant::now(),
        }
    }

    /// Wraps `expr` in a [`SharedExpression`], reusing the pooled allocation
    /// when an equal expression was pooled earlier.
    ///
    /// With sharing disabled the expression is wrapped directly and no
    /// hit/miss statistics are recorded. New allocations are pooled only while
    /// the pool is below `max_expression_cache_size`; in that case a hash
    /// collision with a different expression replaces the pooled entry.
    pub fn create_shared(&mut self, expr: Expression) -> SharedExpression {
        self.expression_counter.fetch_add(1, Ordering::Relaxed);
        if !self.config.enable_sharing {
            return SharedExpression::new(expr);
        }
        let hash = calculate_expression_hash(&expr);
        if let Some(existing) = self.expression_pool.get(&hash) {
            // Equal hashes do not prove equality; confirm structurally.
            if **existing == expr {
                self.stats.cache_hits += 1;
                return SharedExpression::from_rc(Rc::clone(existing));
            }
        }
        self.stats.cache_misses += 1;
        let rc = Rc::new(expr);
        if self.expression_pool.len() < self.config.max_expression_cache_size {
            self.expression_pool.insert(hash, Rc::clone(&rc));
        }
        SharedExpression::from_rc(rc)
    }

    /// Returns the structural hash of `expr`, using the cache only when sound.
    ///
    /// A cached hash is reused only when `expr` lives inside an `Rc` the pool
    /// still owns: while pooled, that allocation cannot be freed, and it cannot
    /// be mutated through a `SharedExpression` because the pool holds a strong
    /// reference (forcing copy-on-write). Any other address may have been freed
    /// and reused by a different expression, or mutated in place through
    /// `&mut`, so its hash is recomputed rather than trusted.
    pub fn get_hash(&mut self, expr: &Expression) -> u64 {
        let ptr = expr as *const Expression;
        if let Some(&hash) = self.hash_cache.get(&ptr) {
            if self.is_pooled_at(hash, ptr) {
                self.stats.cache_hits += 1;
                return hash;
            }
            // The address no longer refers to the pooled value: drop the stale entry.
            self.hash_cache.remove(&ptr);
        }
        self.stats.cache_misses += 1;
        let hash = calculate_expression_hash(expr);
        if self.is_pooled_at(hash, ptr) && self.hash_cache.len() < self.config.max_hash_cache_size {
            self.hash_cache.insert(ptr, hash);
        }
        hash
    }

    /// `true` if the pool entry stored under `hash` is the allocation at `ptr`.
    fn is_pooled_at(&self, hash: u64, ptr: *const Expression) -> bool {
        matches!(self.expression_pool.get(&hash), Some(rc) if Rc::as_ptr(rc) == ptr)
    }

    /// Runs cleanup if it is due, then refreshes the statistics snapshot.
    ///
    /// Cleanup runs first so that the snapshot describes the post-cleanup state.
    pub fn update_stats(&mut self) {
        if self.should_cleanup() {
            self.cleanup();
        }
        self.stats.active_expressions = self.expression_counter.load(Ordering::Relaxed);
        self.stats.shared_expressions = self.expression_pool.len();
        self.stats.estimated_memory_usage = self.estimate_memory_usage();
        self.stats.last_updated = Instant::now();
    }

    /// Refreshes (possibly triggering cleanup) and returns the statistics.
    pub fn get_stats(&mut self) -> &MemoryStats {
        self.update_stats();
        &self.stats
    }

    /// Rough footprint of the caches: key/value sizes plus a fixed guess of
    /// 100 bytes per pooled expression (not a measurement).
    fn estimate_memory_usage(&self) -> usize {
        let hash_cache_size = self.hash_cache.len()
            * (std::mem::size_of::<*const Expression>() + std::mem::size_of::<u64>());
        let pool_size = self.expression_pool.len() * std::mem::size_of::<Rc<Expression>>();
        let estimated_expr_size = self.expression_pool.len() * 100;
        hash_cache_size + pool_size + estimated_expr_size
    }

    /// Cleanup is due once `cleanup_interval` has elapsed since the last one or
    /// the estimated usage reaches `cleanup_threshold`.
    fn should_cleanup(&self) -> bool {
        let time_elapsed = self.last_cleanup.elapsed() >= self.config.cleanup_interval;
        let memory_threshold = self.estimate_memory_usage() >= self.config.cleanup_threshold;
        time_elapsed || memory_threshold
    }

    /// Empties the hash cache, drops pool entries that nothing outside the pool
    /// references, and prints a one-line summary to stdout.
    pub fn cleanup(&mut self) {
        let initial_hash_cache_size = self.hash_cache.len();
        let initial_pool_size = self.expression_pool.len();
        self.hash_cache.clear();
        // A strong count of 1 means the pool is the only remaining owner.
        self.expression_pool.retain(|_, rc| Rc::strong_count(rc) > 1);
        self.last_cleanup = Instant::now();
        // NOTE(review): printing from library code; consider a logging facade.
        println!("内存清理完成: 哈希缓存 {} -> {}, 表达式池 {} -> {}",
            initial_hash_cache_size, self.hash_cache.len(),
            initial_pool_size, self.expression_pool.len());
    }

    /// Empties all caches and resets statistics and the creation counter.
    pub fn clear_all(&mut self) {
        self.hash_cache.clear();
        self.expression_pool.clear();
        // Reset the counter too; otherwise the next stats refresh would restore
        // the pre-clear `active_expressions` value despite the stats reset.
        self.expression_counter.store(0, Ordering::Relaxed);
        self.stats = MemoryStats::default();
        self.last_cleanup = Instant::now();
    }

    /// Current configuration.
    pub fn config(&self) -> &MemoryConfig {
        &self.config
    }

    /// Replaces the configuration; existing cache contents are not trimmed.
    pub fn set_config(&mut self, config: MemoryConfig) {
        self.config = config;
    }
}
impl Default for MemoryManager {
fn default() -> Self {
Self::new()
}
}
/// Computes a structural 64-bit hash of `expr`.
///
/// A fresh standard-library `DefaultHasher` is used per call, so the same
/// structure always yields the same value within one build; concrete values
/// are not guaranteed to stay the same across Rust releases.
pub fn calculate_expression_hash(expr: &Expression) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::default();
    hash_expression(expr, &mut state);
    state.finish()
}
/// Feeds a structural encoding of `expr` into `hasher`.
///
/// Every variant writes a distinct one-byte tag first, and sequences write
/// their lengths before their elements so nesting boundaries are encoded.
/// `Constant` and operator kinds contribute only their enum discriminant, so
/// any payload they carry is ignored (collisions are possible, mismatches are not).
///
/// NOTE(review): `Set` elements are hashed in iteration order; if `Set` is
/// backed by an unordered collection, equal sets could hash differently —
/// confirm the backing type.
fn hash_expression<H: Hasher>(expr: &Expression, hasher: &mut H) {
    match expr {
        Expression::Number(n) => {
            0u8.hash(hasher);
            hash_number(n, hasher);
        }
        Expression::Variable(name) => {
            1u8.hash(hasher);
            name.hash(hasher);
        }
        Expression::Constant(c) => {
            2u8.hash(hasher);
            std::mem::discriminant(c).hash(hasher);
        }
        Expression::BinaryOp { op, left, right } => {
            3u8.hash(hasher);
            std::mem::discriminant(op).hash(hasher);
            hash_expression(left, hasher);
            hash_expression(right, hasher);
        }
        Expression::UnaryOp { op, operand } => {
            4u8.hash(hasher);
            std::mem::discriminant(op).hash(hasher);
            hash_expression(operand, hasher);
        }
        Expression::Function { name, args } => {
            5u8.hash(hasher);
            name.hash(hasher);
            args.len().hash(hasher);
            for arg in args {
                hash_expression(arg, hasher);
            }
        }
        Expression::Matrix(rows) => {
            6u8.hash(hasher);
            rows.len().hash(hasher);
            // Hash every row's length, not just the first, so ragged matrices
            // with the same flattened elements (e.g. [[a], [b, c], [d]] vs
            // [[a], [b], [c, d]]) do not collide.
            for row in rows {
                row.len().hash(hasher);
                for elem in row {
                    hash_expression(elem, hasher);
                }
            }
        }
        Expression::Vector(elements) => {
            7u8.hash(hasher);
            elements.len().hash(hasher);
            for elem in elements {
                hash_expression(elem, hasher);
            }
        }
        Expression::Set(elements) => {
            8u8.hash(hasher);
            elements.len().hash(hasher);
            for elem in elements {
                hash_expression(elem, hasher);
            }
        }
        Expression::Interval { start, end, start_inclusive, end_inclusive } => {
            9u8.hash(hasher);
            hash_expression(start, hasher);
            hash_expression(end, hasher);
            start_inclusive.hash(hasher);
            end_inclusive.hash(hasher);
        }
    }
}
/// Feeds a structural encoding of `number` into `hasher`.
///
/// Every variant writes a distinct one-byte tag first. Integer, rational and
/// real components are hashed through their `to_string()` form (one string
/// allocation per component).
///
/// NOTE(review): for `Real`, hashing the printed form assumes equal values
/// always print identically; if the backing type can represent one value with
/// different scales (e.g. `1.0` vs `1.00`) and treats them as equal, equal
/// numbers would hash differently — confirm against the `Real` type.
fn hash_number<H: Hasher>(number: &Number, hasher: &mut H) {
    match number {
        Number::Integer(i) => {
            0u8.hash(hasher);
            i.to_string().hash(hasher);
        }
        Number::Rational(r) => {
            1u8.hash(hasher);
            r.numer().to_string().hash(hasher);
            r.denom().to_string().hash(hasher);
        }
        Number::Real(r) => {
            2u8.hash(hasher);
            r.to_string().hash(hasher);
        }
        Number::Complex { real, imaginary } => {
            3u8.hash(hasher);
            hash_number(real, hasher);
            hash_number(imaginary, hasher);
        }
        Number::Symbolic(expr) => {
            4u8.hash(hasher);
            hash_expression(expr, hasher);
        }
        Number::Float(f) => {
            5u8.hash(hasher);
            // `0.0 == -0.0` under IEEE-754 comparison, but their bit patterns
            // differ. Hash both zeros as +0.0's bits (0) so that equal values
            // keep equal hashes. Every other value, NaN included, uses its raw bits.
            let bits = if *f == 0.0 { 0 } else { f.to_bits() };
            bits.hash(hasher);
        }
        Number::Constant(c) => {
            6u8.hash(hasher);
            std::mem::discriminant(c).hash(hasher);
        }
    }
}
pub struct ExpressionComparator {
hash_cache: HashMap<*const Expression, u64>,
}
impl ExpressionComparator {
pub fn new() -> Self {
Self {
hash_cache: HashMap::new(),
}
}
pub fn fast_eq(&mut self, left: &Expression, right: &Expression) -> bool {
if std::ptr::eq(left, right) {
return true;
}
let left_hash = self.get_cached_hash(left);
let right_hash = self.get_cached_hash(right);
if left_hash != right_hash {
return false;
}
left == right
}
fn get_cached_hash(&mut self, expr: &Expression) -> u64 {
let ptr = expr as *const Expression;
if let Some(&hash) = self.hash_cache.get(&ptr) {
return hash;
}
let hash = calculate_expression_hash(expr);
self.hash_cache.insert(ptr, hash);
hash
}
pub fn clear_cache(&mut self) {
self.hash_cache.clear();
}
}
impl Default for ExpressionComparator {
fn default() -> Self {
Self::new()
}
}
/// Copy-on-write wrapper around a [`SharedExpression`] that records whether
/// mutable access has ever been requested.
#[derive(Debug, Clone)]
pub struct CowExpression {
    inner: SharedExpression,
    /// Set by `as_mut`; never reset.
    modified: bool,
}

impl CowExpression {
    /// Wraps `expr` in a fresh, unmodified handle.
    pub fn new(expr: Expression) -> Self {
        Self::from_shared(SharedExpression::new(expr))
    }

    /// Wraps an existing shared handle; the result starts unmodified.
    pub fn from_shared(shared: SharedExpression) -> Self {
        CowExpression {
            inner: shared,
            modified: false,
        }
    }

    /// Borrows the wrapped expression.
    pub fn as_ref(&self) -> &Expression {
        self.inner.as_ref()
    }

    /// Marks the wrapper as modified and returns a mutable reference, cloning
    /// the expression first if its allocation is shared.
    pub fn as_mut(&mut self) -> &mut Expression {
        let CowExpression { inner, modified } = self;
        *modified = true;
        inner.make_mut()
    }

    /// Whether `as_mut` has been called on this wrapper.
    pub fn is_modified(&self) -> bool {
        self.modified
    }

    /// Strong reference count of the underlying allocation.
    pub fn ref_count(&self) -> usize {
        self.inner.ref_count()
    }

    /// Extracts the expression, cloning it only if the allocation is shared.
    pub fn into_owned(self) -> Expression {
        let CowExpression { inner, .. } = self;
        inner.into_owned()
    }

    /// Unwraps into the underlying shared handle, discarding the modified flag.
    pub fn into_shared(self) -> SharedExpression {
        let CowExpression { inner, .. } = self;
        inner
    }
}

impl PartialEq for CowExpression {
    /// Compares only the wrapped expressions; the modified flag is ignored.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.inner, &other.inner)
    }
}
/// Periodic sampler that owns a [`MemoryManager`].
///
/// While enabled, [`MemoryMonitor::check`] yields fresh statistics at most
/// once per `interval` (30 seconds by default).
pub struct MemoryMonitor {
    manager: MemoryManager,
    enabled: bool,
    interval: Duration,
    last_check: Instant,
}

impl MemoryMonitor {
    /// Creates an enabled monitor with a default-configured manager.
    pub fn new() -> Self {
        let default_interval = Duration::from_secs(30);
        MemoryMonitor {
            manager: MemoryManager::new(),
            enabled: true,
            interval: default_interval,
            last_check: Instant::now(),
        }
    }

    /// Turns periodic checks on.
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Turns periodic checks off; `check` then always returns `None`.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Sets the minimum spacing between successful `check` calls.
    pub fn set_interval(&mut self, interval: Duration) {
        self.interval = interval;
    }

    /// Returns refreshed statistics if enabled and at least `interval` has
    /// passed since the last successful check; otherwise `None`.
    ///
    /// Refreshing may trigger the manager's cleanup.
    pub fn check(&mut self) -> Option<&MemoryStats> {
        if !self.enabled {
            return None;
        }
        let now = Instant::now();
        if now.duration_since(self.last_check) < self.interval {
            return None;
        }
        self.last_check = now;
        Some(self.manager.get_stats())
    }

    /// Mutable access to the owned manager.
    pub fn manager(&mut self) -> &mut MemoryManager {
        &mut self.manager
    }

    /// Refreshes and returns statistics, regardless of `enabled` or `interval`.
    pub fn stats(&mut self) -> &MemoryStats {
        self.manager.get_stats()
    }

    /// Forces a cleanup on the owned manager.
    pub fn cleanup(&mut self) {
        self.manager.cleanup();
    }
}

impl Default for MemoryMonitor {
    fn default() -> Self {
        MemoryMonitor::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Expression, Number};
    use num_bigint::BigInt;

    /// Shorthand for an integer literal expression.
    fn int(value: i64) -> Expression {
        Expression::Number(Number::Integer(BigInt::from(value)))
    }

    #[test]
    fn test_shared_expression_creation() {
        let expr = int(42);
        let mut shared = SharedExpression::new(expr.clone());
        assert_eq!(shared.as_ref(), &expr);
        assert_eq!(shared.ref_count(), 1);
        assert!(shared.is_unique());
        assert_eq!(shared.get_hash(), calculate_expression_hash(&expr));
    }

    #[test]
    fn test_shared_expression_cloning() {
        let shared1 = SharedExpression::new(int(42));
        let shared2 = shared1.clone_shared();
        assert_eq!(shared1.ref_count(), 2);
        assert_eq!(shared2.ref_count(), 2);
        assert!(!shared1.is_unique());
        assert!(!shared2.is_unique());
        assert_eq!(shared1, shared2);
    }

    #[test]
    fn test_cow_functionality() {
        let mut shared = SharedExpression::new(int(42));
        let shared_clone = shared.clone_shared();
        assert_eq!(shared.ref_count(), 2);
        let _mutable_ref = shared.make_mut();
        // Mutable access on a shared handle must detach it from the other handle.
        assert_eq!(shared.ref_count(), 1);
        assert_eq!(shared_clone.ref_count(), 1);
        assert!(shared.is_unique());
    }

    #[test]
    fn test_expression_hashing() {
        let hash1 = calculate_expression_hash(&int(42));
        let hash2 = calculate_expression_hash(&int(42));
        let hash3 = calculate_expression_hash(&int(43));
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_memory_manager() {
        let mut manager = MemoryManager::new();
        let shared1 = manager.create_shared(int(42));
        let shared2 = manager.create_shared(int(42));
        assert_eq!(shared1, shared2);
        // The pool entry plus both handles share a single allocation.
        assert_eq!(shared1.ref_count(), 3);
        let stats = manager.get_stats();
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_expression_comparator() {
        let mut comparator = ExpressionComparator::new();
        let expr1 = int(42);
        let expr2 = int(42);
        let expr3 = int(43);
        assert!(comparator.fast_eq(&expr1, &expr1));
        assert!(comparator.fast_eq(&expr1, &expr2));
        assert!(!comparator.fast_eq(&expr1, &expr3));
    }

    #[test]
    fn test_cow_expression() {
        let mut cow = CowExpression::new(int(42));
        assert!(!cow.is_modified());
        assert_eq!(cow.ref_count(), 1);
        let _mutable_ref = cow.as_mut();
        assert!(cow.is_modified());
    }

    #[test]
    fn test_cow_does_not_affect_other_handles() {
        let original = int(42);
        let shared = SharedExpression::new(original.clone());
        let mut cow = CowExpression::from_shared(shared.clone_shared());
        assert_eq!(cow.ref_count(), 2);
        *cow.as_mut() = int(7);
        assert_eq!(shared.as_ref(), &original);
        assert_eq!(cow.as_ref(), &int(7));
        assert_eq!(cow.ref_count(), 1);
        assert!(shared.is_unique());
    }

    #[test]
    fn test_memory_cleanup() {
        let mut manager = MemoryManager::new();
        for i in 0..100 {
            // Each handle is dropped at the end of the iteration, leaving the
            // pool as the sole owner of every entry.
            let _shared = manager.create_shared(int(i));
        }
        let stats_before = manager.get_stats().clone();
        manager.cleanup();
        let stats_after = manager.get_stats().clone();
        println!("清理前: {:?}", stats_before);
        println!("清理后: {:?}", stats_after);
        assert_eq!(stats_before.shared_expressions, 100);
        assert_eq!(stats_after.shared_expressions, 0);
        assert_eq!(stats_after.estimated_memory_usage, 0);
    }
}