use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use crate::shared_array::SharedArray;
use super::shared::SharedExpr;
/// Opaque identifier under which an expression's evaluated result is cached.
///
/// Fresh identifiers come from a process-wide atomic counter, so every call to
/// [`ExprId::new`] yields a value larger than all earlier ones in the process
/// (until the `u64` counter wraps around).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExprId(u64);

impl ExprId {
    /// Allocates an identifier not previously returned by this function.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Shared by every caller in the process; `fetch_add` hands out each value once.
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let issued = NEXT_ID.fetch_add(1, Ordering::SeqCst);
        ExprId(issued)
    }

    /// Wraps an existing raw value without touching the global counter.
    pub fn from_raw(id: u64) -> Self {
        ExprId(id)
    }

    /// Returns the underlying numeric value.
    pub fn raw(&self) -> u64 {
        let ExprId(value) = *self;
        value
    }
}

impl Default for ExprId {
    /// Same as [`ExprId::new`]: every default value is a fresh identifier.
    fn default() -> Self {
        ExprId::new()
    }
}
/// Thread-safe store of evaluated expression results keyed by [`ExprId`].
///
/// The map lives behind an `Arc<RwLock<..>>`, and [`Clone`] shares that `Arc`, so every
/// clone of a cache reads and writes the same entries.
///
/// Lock poisoning (a thread panicking while holding the lock) is recovered from instead of
/// silently turning the cache into a no-op. The map only holds independent, whole entries,
/// so its contents stay usable after a panic elsewhere.
pub struct ExprCache<T: Clone> {
    cache: Arc<RwLock<HashMap<ExprId, SharedArray<T>>>>,
}

impl<T: Clone> ExprCache<T> {
    /// Creates an empty cache with its own, unshared storage.
    pub fn new() -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Takes the shared lock, recovering the guard if the lock was poisoned.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, HashMap<ExprId, SharedArray<T>>> {
        self.cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes the exclusive lock, recovering the guard if the lock was poisoned.
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<ExprId, SharedArray<T>>> {
        self.cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `value` under `id`, replacing any previous entry for that id.
    pub fn insert(&self, id: ExprId, value: SharedArray<T>) {
        self.write_map().insert(id, value);
    }

    /// Returns a clone of the entry stored under `id`, if any.
    pub fn get(&self, id: &ExprId) -> Option<SharedArray<T>> {
        self.read_map().get(id).cloned()
    }

    /// Returns `true` if an entry is stored under `id`.
    pub fn contains(&self, id: &ExprId) -> bool {
        self.read_map().contains_key(id)
    }

    /// Removes the entry stored under `id`, returning it if it was present.
    pub fn remove(&self, id: &ExprId) -> Option<SharedArray<T>> {
        self.write_map().remove(id)
    }

    /// Removes every entry. All clones of this cache observe the change.
    pub fn clear(&self) {
        self.write_map().clear();
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.read_map().len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T: Clone> Default for ExprCache<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Clone> Clone for ExprCache<T> {
    /// Returns a handle to the *same* storage (an `Arc` clone), not a copy of the entries.
    fn clone(&self) -> Self {
        Self {
            cache: Arc::clone(&self.cache),
        }
    }
}

/// Expression wrapper whose [`SharedExpr::eval`] result is memoised in an [`ExprCache`]
/// under this wrapper's [`ExprId`].
///
/// Wrappers that share both a cache and an id share one result. Whichever evaluates first
/// populates the entry, and the others read it back.
#[derive(Clone)]
pub struct CachedExpr<T: Clone, E: SharedExpr<T>> {
    expr: E,
    id: ExprId,
    cache: ExprCache<T>,
}

impl<T: Clone, E: SharedExpr<T>> CachedExpr<T, E> {
    /// Wraps `expr` under a freshly allocated [`ExprId`].
    pub fn new(expr: E, cache: ExprCache<T>) -> Self {
        Self {
            expr,
            id: ExprId::new(),
            cache,
        }
    }

    /// Wraps `expr` under a caller-chosen `id`, e.g. one shared with an equal expression.
    pub fn with_id(expr: E, id: ExprId, cache: ExprCache<T>) -> Self {
        Self { expr, id, cache }
    }

    /// Identifier this wrapper's result is cached under.
    pub fn id(&self) -> ExprId {
        self.id
    }

    /// The cache this wrapper reads from and writes to.
    pub fn cache(&self) -> &ExprCache<T> {
        &self.cache
    }

    /// Drops this expression's cached result, so the next `eval` recomputes it.
    pub fn invalidate(&self) {
        self.cache.remove(&self.id);
    }
}

impl<T: Clone, E: SharedExpr<T>> SharedExpr<T> for CachedExpr<T, E> {
    /// Returns element `index` of the (cached) full evaluation.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds for the evaluated array.
    // NOTE(review): every call copies the whole array through `to_vec()`, so element-wise
    // consumers pay O(len) per element; an indexed accessor on `SharedArray` would avoid it.
    fn eval_at(&self, index: usize) -> T {
        let array = self.eval();
        array.to_vec()[index].clone()
    }

    fn size(&self) -> usize {
        self.expr.size()
    }

    fn shape(&self) -> Vec<usize> {
        self.expr.shape()
    }

    /// Returns the cached result if present, otherwise evaluates the inner expression
    /// and caches the result.
    fn eval(&self) -> SharedArray<T> {
        if let Some(cached) = self.cache.get(&self.id) {
            return cached;
        }
        // The lookup and the insert take the lock separately. Concurrent callers that both
        // miss will both compute, and the later insert overwrites the earlier one.
        let result = self.expr.eval();
        self.cache.insert(self.id, result.clone());
        result
    }
}
/// Structural key describing an expression tree for common-subexpression elimination.
///
/// Two keys compare (and hash) equal exactly when they describe the same operations over
/// the same leaves. This lets equal subexpressions be mapped to one cached result.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExprKey {
    /// A leaf array, identified by a caller-assigned numeric id.
    Array(u64),
    /// Binary operation `op` applied to two sub-keys.
    Binary {
        op: &'static str,
        left: Box<ExprKey>,
        right: Box<ExprKey>,
    },
    /// Unary operation `op` applied to one sub-key.
    Unary {
        op: &'static str,
        operand: Box<ExprKey>,
    },
    /// Operation `op` combining a sub-key with a scalar, represented by `scalar_hash`.
    Scalar {
        op: &'static str,
        operand: Box<ExprKey>,
        scalar_hash: u64,
    },
}

impl ExprKey {
    /// Leaf key for the array with the given id.
    pub fn array(id: u64) -> Self {
        ExprKey::Array(id)
    }

    /// Key for `left op right`; boxes both children.
    pub fn binary(op: &'static str, left: ExprKey, right: ExprKey) -> Self {
        let (left, right) = (Box::new(left), Box::new(right));
        ExprKey::Binary { op, left, right }
    }

    /// Key for `op(operand)`; boxes the child.
    pub fn unary(op: &'static str, operand: ExprKey) -> Self {
        let operand = Box::new(operand);
        ExprKey::Unary { op, operand }
    }

    /// Key for `operand op scalar`, where the scalar is represented by `scalar_hash`.
    pub fn scalar(op: &'static str, operand: ExprKey, scalar_hash: u64) -> Self {
        ExprKey::Scalar {
            op,
            scalar_hash,
            operand: Box::new(operand),
        }
    }
}
/// Hashes an `f64` by its exact IEEE-754 bit pattern, for use as [`ExprKey`] scalar data.
///
/// Hashing the bits (rather than the value) means `0.0` and `-0.0` are hashed as
/// different inputs, and each NaN payload is hashed as its own input.
pub fn hash_f64(value: f64) -> u64 {
    let bits: u64 = value.to_bits();
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(&bits, &mut state);
    state.finish()
}
/// Deduplicates expressions by structural key and hands out cache-backed wrappers.
///
/// Equal [`ExprKey`]s receive the same [`ExprId`] (until [`CSEOptimizer::clear`]). Every
/// wrapper produced by [`CSEOptimizer::cache_expr`] shares this optimizer's
/// [`ExprCache`], so wrappers created for equal keys share one cached result.
pub struct CSEOptimizer<T: Clone> {
    /// Structural key -> identifier assigned the first time the key was seen.
    key_to_id: HashMap<ExprKey, ExprId>,
    /// Result store shared (via `Arc`) with every wrapper handed out.
    cache: ExprCache<T>,
    /// Counter backing [`CSEOptimizer::next_array_id`].
    next_array_id: u64,
}

impl<T: Clone> CSEOptimizer<T> {
    /// Creates an optimizer with no known keys and an empty cache.
    pub fn new() -> Self {
        CSEOptimizer {
            cache: ExprCache::new(),
            key_to_id: HashMap::new(),
            next_array_id: 0,
        }
    }

    /// Returns the id already assigned to `key`, or assigns and records a fresh one.
    ///
    /// The key is cloned only when it is seen for the first time.
    pub fn get_or_create_id(&mut self, key: &ExprKey) -> ExprId {
        if let Some(&known) = self.key_to_id.get(key) {
            return known;
        }
        let fresh = ExprId::new();
        self.key_to_id.insert(key.clone(), fresh);
        fresh
    }

    /// Returns the next leaf-array id (0, 1, 2, ...) and advances the counter.
    pub fn next_array_id(&mut self) -> u64 {
        let issued = self.next_array_id;
        self.next_array_id = issued + 1;
        issued
    }

    /// The shared result cache.
    pub fn cache(&self) -> &ExprCache<T> {
        &self.cache
    }

    /// Wraps `expr` under `id`, backed by this optimizer's shared cache.
    pub fn cache_expr<E: SharedExpr<T>>(&self, expr: E, id: ExprId) -> CachedExpr<T, E> {
        let shared = ExprCache::clone(&self.cache);
        CachedExpr::with_id(expr, id, shared)
    }

    /// Reports how many keys are known and how many results are cached.
    pub fn stats(&self) -> CSEStats {
        let unique_expressions = self.key_to_id.len();
        let cached_results = self.cache.len();
        CSEStats {
            unique_expressions,
            cached_results,
        }
    }

    /// Forgets all key assignments, empties the cache, and restarts array ids at 0.
    ///
    /// The cache is shared, so this also discards results cached by wrappers handed
    /// out earlier.
    pub fn clear(&mut self) {
        self.next_array_id = 0;
        self.key_to_id.clear();
        self.cache.clear();
    }
}

impl<T: Clone> Default for CSEOptimizer<T> {
    fn default() -> Self {
        CSEOptimizer::new()
    }
}

/// Snapshot of a [`CSEOptimizer`]'s bookkeeping.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CSEStats {
    /// Number of distinct keys that have been assigned an id.
    pub unique_expressions: usize,
    /// Number of entries currently held in the shared result cache.
    pub cached_results: usize,
}
/// Convenience front-end over a [`CSEOptimizer`] for building cache-backed expressions.
pub struct CSEExprBuilder<T: Clone> {
    optimizer: CSEOptimizer<T>,
}

impl<T: Clone> CSEExprBuilder<T> {
    /// Creates a builder with a fresh optimizer and an empty cache.
    pub fn new() -> Self {
        let optimizer = CSEOptimizer::new();
        CSEExprBuilder { optimizer }
    }

    /// Wraps `expr` under the id assigned to `key`.
    ///
    /// Structurally equal keys get the same id, so their wrappers share a cached result.
    pub fn wrap<E: SharedExpr<T>>(&mut self, expr: E, key: ExprKey) -> CachedExpr<T, E> {
        let optimizer = &mut self.optimizer;
        let id = optimizer.get_or_create_id(&key);
        optimizer.cache_expr(expr, id)
    }

    /// Returns `array` unchanged (identity).
    pub fn eval_array(&self, array: SharedArray<T>) -> SharedArray<T> {
        array
    }

    /// Statistics of the underlying optimizer.
    pub fn stats(&self) -> CSEStats {
        self.optimizer.stats()
    }

    /// Resets the underlying optimizer (key assignments and cached results).
    pub fn clear(&mut self) {
        self.optimizer.clear()
    }
}

impl<T: Clone> Default for CSEExprBuilder<T> {
    fn default() -> Self {
        CSEExprBuilder::new()
    }
}
/// Extension trait that lets any [`SharedExpr`] be wrapped in a [`CachedExpr`] with
/// method syntax.
pub trait CSESupport<T: Clone>: SharedExpr<T> + Sized {
    /// Wraps `self` in `cache` under a freshly allocated [`ExprId`].
    fn with_cache(self, cache: ExprCache<T>) -> CachedExpr<T, Self> {
        CachedExpr::<T, Self>::new(self, cache)
    }

    /// Wraps `self` in `cache` under the given `id`.
    fn with_cache_id(self, id: ExprId, cache: ExprCache<T>) -> CachedExpr<T, Self> {
        CachedExpr::<T, Self>::with_id(self, id, cache)
    }
}

/// Blanket implementation: every `SharedExpr<T>` gets the caching helpers.
impl<E, T> CSESupport<T> for E
where
    T: Clone,
    E: SharedExpr<T>,
{
}
/// Summary of how many expression keys were analysed and how many were repeats.
#[derive(Debug, Clone)]
pub struct CSEAnalysisResult {
    /// Number of keys analysed.
    pub total_nodes: usize,
    /// Number of redundant occurrences (occurrences beyond the first, summed over keys).
    pub common_subexpressions: usize,
    /// `common_subexpressions / total_nodes`, or `0.0` when `total_nodes` is zero.
    pub savings_ratio: f64,
    /// Per-key occurrence counts, keyed by a textual form of the key.
    pub occurrence_counts: HashMap<String, usize>,
}

impl CSEAnalysisResult {
    /// An empty result: zero nodes, zero repeats, ratio `0.0`, no counts.
    pub fn new() -> Self {
        Self {
            total_nodes: 0,
            common_subexpressions: 0,
            savings_ratio: 0.0,
            occurrence_counts: HashMap::new(),
        }
    }

    /// Recomputes `savings_ratio` from the current counts.
    ///
    /// With zero `total_nodes`, the ratio is reset to `0.0` instead of keeping whatever an
    /// earlier call computed. This way the field always reflects the current counts.
    pub fn calculate_savings(&mut self) {
        self.savings_ratio = if self.total_nodes == 0 {
            0.0
        } else {
            self.common_subexpressions as f64 / self.total_nodes as f64
        };
    }
}

impl Default for CSEAnalysisResult {
    fn default() -> Self {
        Self::new()
    }
}
/// Counts exact repeats in a flat list of expression keys.
///
/// Every entry in `keys` counts as one node. A key that appears `c > 1` times contributes
/// `c - 1` to `common_subexpressions` (the evaluations a cache would save). It is also
/// recorded in `occurrence_counts` under its `Debug` representation. Only whole keys
/// are compared: subexpressions nested inside a key are not counted separately.
pub fn analyze_cse(keys: &[ExprKey]) -> CSEAnalysisResult {
    let mut result = CSEAnalysisResult::new();
    result.total_nodes = keys.len();

    // Group by structural equality. `ExprKey` is Hash + Eq, so there is no need to
    // allocate a Debug string for every key just to count it.
    let mut key_counts: HashMap<&ExprKey, usize> = HashMap::with_capacity(keys.len());
    for key in keys {
        *key_counts.entry(key).or_insert(0) += 1;
    }

    for (key, count) in key_counts {
        if count > 1 {
            result.common_subexpressions += count - 1;
            // Only repeated keys are formatted, keeping the public String-keyed report.
            result
                .occurrence_counts
                .insert(format!("{:?}", key), count);
        }
    }

    result.calculate_savings();
    result
}
/// A node that memoises its result both locally and in a shared [`ExprCache`].
///
/// Once the local memo is set, it is returned without consulting the shared cache. Clearing
/// or invalidating the cache therefore does not affect a node that has already produced a
/// result.
#[derive(Clone)]
pub struct OptimizedExprNode<T: Clone> {
    id: ExprId,
    key: ExprKey,
    cache: ExprCache<T>,
    /// Local memo of the last result produced or fetched by `get_or_compute`.
    result: Option<SharedArray<T>>,
}

impl<T: Clone> OptimizedExprNode<T> {
    /// Creates a node with no local result yet.
    pub fn new(id: ExprId, key: ExprKey, cache: ExprCache<T>) -> Self {
        OptimizedExprNode {
            result: None,
            id,
            key,
            cache,
        }
    }

    /// Identifier this node caches under.
    pub fn id(&self) -> ExprId {
        self.id
    }

    /// Structural key describing this node.
    pub fn key(&self) -> &ExprKey {
        &self.key
    }

    /// `true` if a result is held locally or present in the shared cache.
    pub fn is_cached(&self) -> bool {
        if self.result.is_some() {
            return true;
        }
        self.cache.contains(&self.id)
    }

    /// Returns the memoised result.
    ///
    /// The local memo is checked first, then the shared cache. Only when both miss is
    /// `compute` run; its output is stored in both places.
    pub fn get_or_compute<F>(&mut self, compute: F) -> SharedArray<T>
    where
        F: FnOnce() -> SharedArray<T>,
    {
        if let Some(local) = &self.result {
            return local.clone();
        }
        let value = match self.cache.get(&self.id) {
            Some(shared) => shared,
            None => {
                let fresh = compute();
                self.cache.insert(self.id, fresh.clone());
                fresh
            }
        };
        self.result = Some(value.clone());
        value
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::shared::{SharedArrayExpr, SharedBinaryExpr};

    #[test]
    fn test_expr_id_uniqueness() {
        let id1 = ExprId::new();
        let id2 = ExprId::new();
        let id3 = ExprId::new();
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_expr_cache_basic() {
        let cache: ExprCache<f64> = ExprCache::new();
        let id = ExprId::new();
        let array = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(cache.is_empty());
        assert!(!cache.contains(&id));
        cache.insert(id, array.clone());
        assert!(!cache.is_empty());
        assert!(cache.contains(&id));
        assert_eq!(cache.len(), 1);
        let cached = cache.get(&id);
        assert!(cached.is_some());
        assert_eq!(
            cached.expect("Cached value should exist").to_vec(),
            vec![1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn test_expr_cache_multiple_entries() {
        let cache: ExprCache<f64> = ExprCache::new();
        let id1 = ExprId::new();
        let id2 = ExprId::new();
        let id3 = ExprId::new();
        cache.insert(id1, SharedArray::from_vec(vec![1.0]));
        cache.insert(id2, SharedArray::from_vec(vec![2.0]));
        cache.insert(id3, SharedArray::from_vec(vec![3.0]));
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&id1).expect("Should exist").to_vec(), vec![1.0]);
        assert_eq!(cache.get(&id2).expect("Should exist").to_vec(), vec![2.0]);
        assert_eq!(cache.get(&id3).expect("Should exist").to_vec(), vec![3.0]);
    }

    #[test]
    fn test_expr_cache_clear() {
        let cache: ExprCache<f64> = ExprCache::new();
        let id = ExprId::new();
        cache.insert(id, SharedArray::from_vec(vec![1.0, 2.0]));
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert!(cache.is_empty());
        assert!(!cache.contains(&id));
    }

    #[test]
    fn test_cached_expr_basic() {
        let arr = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let expr = SharedArrayExpr::new(arr);
        let cache: ExprCache<f64> = ExprCache::new();
        let cached = CachedExpr::new(expr, cache.clone());
        let result1 = cached.eval();
        assert_eq!(result1.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(cache.len(), 1);
        // The second evaluation is served from the existing entry, not a new one.
        let result2 = cached.eval();
        assert_eq!(result2.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cached_expr_shared_cache() {
        let cache: ExprCache<f64> = ExprCache::new();
        let arr1 = SharedArray::from_vec(vec![1.0, 2.0]);
        let arr2 = SharedArray::from_vec(vec![3.0, 4.0]);
        let expr1 = SharedArrayExpr::new(arr1);
        let expr2 = SharedArrayExpr::new(arr2);
        let cached1 = CachedExpr::new(expr1, cache.clone());
        let cached2 = CachedExpr::new(expr2, cache.clone());
        cached1.eval();
        cached2.eval();
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_expr_key_array() {
        let key1 = ExprKey::array(0);
        let key2 = ExprKey::array(0);
        let key3 = ExprKey::array(1);
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_expr_key_binary() {
        let key_a = ExprKey::array(0);
        let key_b = ExprKey::array(1);
        let add1 = ExprKey::binary("add", key_a.clone(), key_b.clone());
        let add2 = ExprKey::binary("add", key_a.clone(), key_b.clone());
        let mul1 = ExprKey::binary("mul", key_a.clone(), key_b.clone());
        assert_eq!(add1, add2);
        assert_ne!(add1, mul1);
    }

    #[test]
    fn test_expr_key_unary() {
        let key_a = ExprKey::array(0);
        let sqrt1 = ExprKey::unary("sqrt", key_a.clone());
        let sqrt2 = ExprKey::unary("sqrt", key_a.clone());
        let neg1 = ExprKey::unary("neg", key_a.clone());
        assert_eq!(sqrt1, sqrt2);
        assert_ne!(sqrt1, neg1);
    }

    #[test]
    fn test_expr_key_scalar() {
        let key_a = ExprKey::array(0);
        let add10_1 = ExprKey::scalar("add", key_a.clone(), hash_f64(10.0));
        let add10_2 = ExprKey::scalar("add", key_a.clone(), hash_f64(10.0));
        let add20 = ExprKey::scalar("add", key_a.clone(), hash_f64(20.0));
        assert_eq!(add10_1, add10_2);
        assert_ne!(add10_1, add20);
    }

    #[test]
    fn test_cse_optimizer_basic() {
        let mut optimizer: CSEOptimizer<f64> = CSEOptimizer::new();
        let key_a = ExprKey::array(0);
        let key_b = ExprKey::array(1);
        let key_sum = ExprKey::binary("add", key_a.clone(), key_b.clone());
        let id1 = optimizer.get_or_create_id(&key_sum);
        let id2 = optimizer.get_or_create_id(&key_sum);
        assert_eq!(id1, id2);
        assert_eq!(optimizer.stats().unique_expressions, 1);
    }

    #[test]
    fn test_cse_optimizer_multiple_keys() {
        let mut optimizer: CSEOptimizer<f64> = CSEOptimizer::new();
        let key_a = ExprKey::array(0);
        let key_b = ExprKey::array(1);
        let key_sum = ExprKey::binary("add", key_a.clone(), key_b.clone());
        let key_prod = ExprKey::binary("mul", key_a.clone(), key_b.clone());
        let id_sum = optimizer.get_or_create_id(&key_sum);
        let id_prod = optimizer.get_or_create_id(&key_prod);
        assert_ne!(id_sum, id_prod);
        assert_eq!(optimizer.stats().unique_expressions, 2);
    }

    #[test]
    fn test_cse_analysis() {
        let key_a = ExprKey::array(0);
        let key_b = ExprKey::array(1);
        let key_sum = ExprKey::binary("add", key_a.clone(), key_b.clone());
        let keys = vec![
            key_a.clone(),
            key_b.clone(),
            key_sum.clone(),
            key_sum.clone(),
            ExprKey::binary("mul", key_sum.clone(), key_sum.clone()),
        ];
        let analysis = analyze_cse(&keys);
        assert_eq!(analysis.total_nodes, 5);
        // Only `key_sum` repeats at the top level (twice), saving one evaluation.
        assert_eq!(analysis.common_subexpressions, 1);
        assert!((analysis.savings_ratio - 0.2).abs() < 1e-12);
        assert_eq!(analysis.occurrence_counts.len(), 1);
        assert_eq!(
            analysis.occurrence_counts.get(&format!("{:?}", key_sum)),
            Some(&2)
        );
    }

    #[test]
    fn test_cse_support_trait() {
        let arr = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let expr = SharedArrayExpr::new(arr);
        let cache: ExprCache<f64> = ExprCache::new();
        let cached = expr.with_cache(cache.clone());
        let result = cached.eval();
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cse_expr_builder() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let builder: CSEExprBuilder<f64> = CSEExprBuilder::new();
        let result = builder.eval_array(a);
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_cse_expr_builder_wrap_shares_ids() {
        let a = SharedArray::from_vec(vec![1.0, 2.0]);
        let mut builder: CSEExprBuilder<f64> = CSEExprBuilder::new();
        let first = builder.wrap(SharedArrayExpr::new(a.clone()), ExprKey::array(0));
        let second = builder.wrap(SharedArrayExpr::new(a), ExprKey::array(0));
        assert_eq!(first.id(), second.id());
        assert_eq!(builder.stats().unique_expressions, 1);
        assert_eq!(first.eval().to_vec(), vec![1.0, 2.0]);
        assert_eq!(builder.stats().cached_results, 1);
        // `second` shares the id, so it reuses the entry `first` populated.
        assert_eq!(second.eval().to_vec(), vec![1.0, 2.0]);
        assert_eq!(builder.stats().cached_results, 1);
    }

    #[test]
    fn test_optimized_expr_node() {
        let cache: ExprCache<f64> = ExprCache::new();
        let id = ExprId::new();
        let key = ExprKey::array(0);
        let mut node = OptimizedExprNode::new(id, key.clone(), cache.clone());
        assert!(!node.is_cached());
        let result = node.get_or_compute(|| SharedArray::from_vec(vec![1.0, 2.0, 3.0]));
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0]);
        assert!(node.is_cached());
        // A cached node must not run the second closure.
        let result2 = node.get_or_compute(|| SharedArray::from_vec(vec![9.0, 9.0, 9.0]));
        assert_eq!(result2.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_cse_shared_computation() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let b = SharedArray::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
        let cache: ExprCache<f64> = ExprCache::new();
        let expr_a = SharedArrayExpr::new(a);
        let expr_b = SharedArrayExpr::new(b);
        let sum = SharedBinaryExpr::new(expr_a, expr_b, |x, y| x + y)
            .expect("Binary expression creation should succeed");
        let cached_sum = CachedExpr::new(sum, cache.clone());
        let sum_result = cached_sum.eval();
        assert_eq!(sum_result.to_vec(), vec![3.0, 5.0, 7.0, 9.0]);
        assert_eq!(cache.len(), 1);
        let sum_squared =
            SharedBinaryExpr::new(cached_sum.clone(), cached_sum, |x: f64, y: f64| x * y)
                .expect("Binary expression creation should succeed");
        let result = sum_squared.eval();
        assert_eq!(result.to_vec(), vec![9.0, 25.0, 49.0, 81.0]);
    }

    #[test]
    fn test_cached_expr_invalidate() {
        let arr = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let expr = SharedArrayExpr::new(arr);
        let cache: ExprCache<f64> = ExprCache::new();
        let cached = CachedExpr::new(expr, cache.clone());
        cached.eval();
        assert_eq!(cache.len(), 1);
        cached.invalidate();
        assert!(cache.is_empty());
        assert!(!cache.contains(&cached.id()));
        // Re-evaluating repopulates the entry.
        assert_eq!(cached.eval().to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_hash_f64() {
        let h1 = hash_f64(10.0);
        let h2 = hash_f64(10.0);
        let h3 = hash_f64(20.0);
        assert_eq!(h1, h2);
        assert_ne!(h1, h3);
    }
}