use crate::{error::AutogradError, Float, Result};
use std::collections::{HashMap, HashSet};
/// Structural key used to recognise equivalent expressions during common-subexpression
/// elimination.
///
/// Two fingerprints are equal (and hash equally) exactly when their operation name,
/// input list (including its order) and attribute list all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExprFingerprint {
    /// Name of the operation, e.g. `"add"` or `"mul"`.
    pub op_name: String,
    /// Identifiers of the operation's inputs; order is significant for equality.
    pub inputs: Vec<usize>,
    /// Additional attributes that distinguish otherwise-identical operations.
    pub attributes: Vec<String>,
}

impl ExprFingerprint {
    /// Builds a fingerprint that keeps `inputs` exactly in the order given.
    pub fn new(op_name: String, inputs: Vec<usize>, attributes: Vec<String>) -> Self {
        Self {
            op_name,
            inputs,
            attributes,
        }
    }

    /// Builds a fingerprint for an operation whose inputs may be freely reordered.
    ///
    /// `inputs` is sorted ascending first, so `add(2, 1)` and `add(1, 2)` yield equal
    /// fingerprints.
    pub fn commutative(op_name: String, mut inputs: Vec<usize>, attributes: Vec<String>) -> Self {
        inputs.sort_unstable();
        Self::new(op_name, inputs, attributes)
    }
}
/// Bookkeeping for a common-subexpression-elimination pass.
///
/// Maps each recorded [`ExprFingerprint`] to the id stored for it and tracks which ids
/// have been marked as replaced.
#[derive(Debug)]
pub struct CSEOptimizer {
    /// Id recorded for each fingerprint seen so far.
    fingerprint_map: HashMap<ExprFingerprint, usize>,
    /// Ids that have been marked as replaced.
    replaced: HashSet<usize>,
    /// Number of distinct ids marked as replaced (always equals `replaced.len()`).
    num_eliminations: usize,
}

impl CSEOptimizer {
    /// Creates an optimizer with no recorded fingerprints and no replacements.
    pub fn new() -> Self {
        Self {
            fingerprint_map: HashMap::new(),
            replaced: HashSet::new(),
            num_eliminations: 0,
        }
    }

    /// Returns the id previously recorded for `fingerprint`, or `None` if it has not
    /// been recorded.
    pub fn find_duplicate(&self, fingerprint: &ExprFingerprint) -> Option<usize> {
        self.fingerprint_map.get(fingerprint).copied()
    }

    /// Records `id` as the expression for `fingerprint`.
    ///
    /// If `fingerprint` was already recorded, its previous id is overwritten.
    pub fn record(&mut self, fingerprint: ExprFingerprint, id: usize) {
        self.fingerprint_map.insert(fingerprint, id);
    }

    /// Marks `id` as replaced and counts it as one elimination.
    ///
    /// Marking an id that is already marked has no effect, so
    /// [`num_eliminations`](Self::num_eliminations) counts each id at most once.
    pub fn mark_replaced(&mut self, id: usize) {
        // `insert` returns `false` when the id was already present; counting it again
        // would make the elimination count disagree with the `replaced` set.
        if self.replaced.insert(id) {
            self.num_eliminations += 1;
        }
    }

    /// Returns `true` if `id` has been marked via [`mark_replaced`](Self::mark_replaced).
    pub fn is_replaced(&self, id: usize) -> bool {
        self.replaced.contains(&id)
    }

    /// Returns the number of distinct ids marked as replaced.
    pub fn num_eliminations(&self) -> usize {
        self.num_eliminations
    }

    /// Forgets all recorded fingerprints and replacements and resets the count to zero.
    pub fn clear(&mut self) {
        self.fingerprint_map.clear();
        self.replaced.clear();
        self.num_eliminations = 0;
    }
}

impl Default for CSEOptimizer {
    /// Equivalent to [`CSEOptimizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Summary statistics produced by a common-subexpression-elimination pass.
#[derive(Debug, Clone)]
pub struct CSEResult {
    /// Number of expressions eliminated as duplicates.
    pub eliminations: usize,
    /// Estimated memory saved; the unit is whatever the producer of the result uses.
    pub memory_saved: usize,
    /// Estimated compute cost saved; the unit is whatever the producer of the result uses.
    pub cost_saved: usize,
}

impl CSEResult {
    /// Bundles the three statistics into a result value.
    pub fn new(eliminations: usize, memory_saved: usize, cost_saved: usize) -> Self {
        Self {
            eliminations,
            memory_saved,
            cost_saved,
        }
    }

    /// Returns `true` when at least one expression was eliminated.
    pub fn has_changes(&self) -> bool {
        self.eliminations != 0
    }
}
/// Counts the common subexpressions in `graph`.
///
/// Expressions are visited in order, and each one's position in `graph` is used as its
/// id. The first expression with a given fingerprint is recorded. Every later
/// expression with an equal fingerprint is counted as one elimination.
///
/// The graph is only read: the `&mut` borrow is not currently used for mutation. The
/// type parameter `T` is also currently unused.
///
/// `memory_saved` and `cost_saved` in the result are fixed per-elimination estimates
/// (1024 and 100 respectively), not measurements.
///
/// NOTE(review): replacements are not propagated into the `inputs` of later
/// expressions. Expressions that would only become equal after substituting an earlier
/// duplicate are therefore not detected. Confirm whether `inputs` hold graph indices
/// before adding such remapping.
///
/// # Errors
///
/// This function currently never returns an error.
pub fn eliminate_common_subexpressions<T: Float>(
    graph: &mut Vec<ExprFingerprint>,
) -> Result<CSEResult> {
    // Heuristic per-elimination savings; placeholders until real cost data exists.
    const MEMORY_PER_ELIMINATION: usize = 1024;
    const COST_PER_ELIMINATION: usize = 100;

    let mut optimizer = CSEOptimizer::new();
    for (id, expr) in graph.iter().enumerate() {
        if optimizer.find_duplicate(expr).is_some() {
            // A structurally identical expression was already seen; this one is redundant.
            optimizer.mark_replaced(id);
        } else {
            optimizer.record(expr.clone(), id);
        }
    }

    let eliminations = optimizer.num_eliminations();
    Ok(CSEResult::new(
        eliminations,
        eliminations.saturating_mul(MEMORY_PER_ELIMINATION),
        eliminations.saturating_mul(COST_PER_ELIMINATION),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an attribute-free, order-sensitive fingerprint.
    fn plain(op: &str, inputs: Vec<usize>) -> ExprFingerprint {
        ExprFingerprint::new(op.to_string(), inputs, Vec::new())
    }

    #[test]
    fn test_expression_fingerprint() {
        assert_eq!(plain("add", vec![1, 2]), plain("add", vec![1, 2]));
    }

    #[test]
    fn test_commutative_fingerprint() {
        let swapped = ExprFingerprint::commutative("add".to_string(), vec![2, 1], Vec::new());
        let ordered = ExprFingerprint::commutative("add".to_string(), vec![1, 2], Vec::new());
        assert_eq!(swapped, ordered);
    }

    #[test]
    fn test_cse_optimizer() {
        let mut opt = CSEOptimizer::new();
        let key = plain("mul", vec![1, 2]);

        opt.record(key.clone(), 100);
        assert_eq!(opt.find_duplicate(&key), Some(100));

        opt.mark_replaced(101);
        assert!(opt.is_replaced(101));
        assert_eq!(opt.num_eliminations(), 1);
    }

    #[test]
    fn test_cse_result() {
        let result = CSEResult::new(5, 5120, 500);
        assert!(result.has_changes());

        let CSEResult {
            eliminations,
            memory_saved,
            ..
        } = result;
        assert_eq!(eliminations, 5);
        assert_eq!(memory_saved, 5120);
    }
}