use std::collections::HashMap;
use tensorlogic_infer::{ElemOp, ExecutorError, ReduceOp, TlAutodiff, TlExecutor};
use tensorlogic_ir::EinsumGraph;
use crate::autodiff::ForwardTape;
use crate::{Scirs2Exec, Scirs2Tensor};
/// Counters describing the cache behaviour of a `LazyExecutor`.
///
/// All counters start at zero (via `Default`) and only ever increase.
#[derive(Debug, Default, Clone)]
pub struct LazyStats {
    /// Lookups (or forward-pass snapshots) that found a tensor already cached.
    pub cache_hits: usize,
    /// Lookups that found no cached tensor for the requested node.
    pub cache_misses: usize,
    /// Cached tensors explicitly evicted via `invalidate_node`.
    pub tensors_recomputed: usize,
    /// Largest total cache footprint observed, in bytes, assuming `f64` elements.
    pub peak_memory_estimate_bytes: usize,
}
/// Executor that wraps a `Scirs2Exec` and memoizes per-node result tensors.
///
/// Primitive ops pass straight through to `inner`; the cache is keyed by
/// graph node index and populated during the autodiff forward pass.
pub struct LazyExecutor {
    // Underlying executor that performs the actual tensor computations.
    inner: Scirs2Exec,
    // node index -> materialized output tensor for that node.
    cache: HashMap<usize, Scirs2Tensor>,
    // Hit/miss/eviction/memory counters; see `LazyStats`.
    stats: LazyStats,
}
impl LazyExecutor {
    /// Creates a lazy executor with an empty result cache.
    pub fn new() -> Self {
        Self {
            inner: Scirs2Exec::new(),
            cache: HashMap::new(),
            stats: LazyStats::default(),
        }
    }

    /// Creates a lazy executor whose cache is pre-sized for `capacity` entries.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Scirs2Exec::new(),
            cache: HashMap::with_capacity(capacity),
            stats: LazyStats::default(),
        }
    }

    /// Drops every cached tensor; subsequent lookups will miss.
    /// Does not touch the statistics counters.
    pub fn invalidate_cache(&mut self) {
        self.cache.clear();
    }

    /// Evicts the cached tensor for `node_id`, if any, counting the eviction
    /// as a forced recompute in the stats. A no-op for uncached nodes.
    pub fn invalidate_node(&mut self, node_id: usize) {
        if self.cache.remove(&node_id).is_some() {
            self.stats.tensors_recomputed += 1;
        }
    }

    /// Read-only view of the accumulated cache statistics.
    pub fn stats(&self) -> &LazyStats {
        &self.stats
    }

    /// Rough memory estimate for evaluating `graph`: the average byte size of
    /// the currently cached tensors (assuming `f64` elements) multiplied by
    /// the graph's node count.
    ///
    /// Returns 0 when either the graph or the cache is empty, since there is
    /// no basis for an average in that case.
    pub fn memory_estimate_for(&self, graph: &EinsumGraph) -> usize {
        if graph.nodes.is_empty() || self.cache.is_empty() {
            return 0;
        }
        let total_cached_bytes: usize = self
            .cache
            .values()
            .map(|t| t.len() * std::mem::size_of::<f64>())
            .sum();
        let avg_bytes = total_cached_bytes / self.cache.len();
        avg_bytes * graph.nodes.len()
    }

    /// Number of tensors currently held in the cache.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    /// Looks up a cached tensor for `node_id`, recording a hit or a miss.
    /// Returns a clone so the cache retains its copy.
    fn cache_get(&mut self, node_id: usize) -> Option<Scirs2Tensor> {
        if let Some(t) = self.cache.get(&node_id) {
            self.stats.cache_hits += 1;
            Some(t.clone())
        } else {
            self.stats.cache_misses += 1;
            None
        }
    }

    /// Inserts (or overwrites) a cached tensor and refreshes the peak-memory
    /// estimate.
    fn cache_insert(&mut self, node_id: usize, tensor: Scirs2Tensor) {
        self.cache.insert(node_id, tensor);
        // Recompute the current footprint from scratch. This is O(cache size)
        // per insert, but it keeps the estimate exact even when an insert
        // overwrites an existing entry of a different size.
        let current_bytes: usize = self
            .cache
            .values()
            .map(|t| t.len() * std::mem::size_of::<f64>())
            .sum();
        self.stats.peak_memory_estimate_bytes =
            self.stats.peak_memory_estimate_bytes.max(current_bytes);
    }
}
impl Default for LazyExecutor {
fn default() -> Self {
Self::new()
}
}
/// Pass-through `TlExecutor` implementation: every primitive tensor op is
/// delegated unchanged to the wrapped `Scirs2Exec`. No caching happens at
/// this level; the cache is only populated during the autodiff forward pass.
impl TlExecutor for LazyExecutor {
    type Tensor = Scirs2Tensor;
    type Error = ExecutorError;
    /// Einsum contraction over `inputs` per `spec`; delegated to `inner`.
    fn einsum(&mut self, spec: &str, inputs: &[Self::Tensor]) -> Result<Self::Tensor, Self::Error> {
        self.inner.einsum(spec, inputs)
    }
    /// Unary elementwise op; delegated to `inner`.
    fn elem_op(&mut self, op: ElemOp, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error> {
        self.inner.elem_op(op, x)
    }
    /// Binary elementwise op; delegated to `inner`.
    fn elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Self::Tensor,
        y: &Self::Tensor,
    ) -> Result<Self::Tensor, Self::Error> {
        self.inner.elem_op_binary(op, x, y)
    }
    /// Reduction over `axes`; delegated to `inner`.
    fn reduce(
        &mut self,
        op: ReduceOp,
        x: &Self::Tensor,
        axes: &[usize],
    ) -> Result<Self::Tensor, Self::Error> {
        self.inner.reduce(op, x, axes)
    }
}
impl TlAutodiff for LazyExecutor {
    type Tape = ForwardTape;

    /// Runs the inner forward pass, then snapshots each graph node's first
    /// output tensor from the tape into the cache.
    ///
    /// For nodes that are already cached this records a cache hit and keeps
    /// the existing entry; only uncached tensors are cloned out of the tape.
    /// (The original implementation cloned every tape tensor and discarded
    /// the clone for cached nodes — a wasted full-tensor copy per node.)
    fn forward(&mut self, graph: &EinsumGraph) -> Result<Self::Tensor, Self::Error> {
        let result = self.inner.forward(graph)?;
        // Collect the tensors that need inserting while the tape is borrowed;
        // `inner`, `cache`, and `stats` are disjoint fields, so reading the
        // cache and bumping the hit counter here is fine. Inserting requires
        // `&mut self` (via `cache_insert`), so that happens after the borrow
        // of the tape ends.
        let mut fresh: Vec<(usize, Scirs2Tensor)> = Vec::new();
        if let Some(tape) = &self.inner.tape {
            for (node_idx, node) in graph.nodes.iter().enumerate() {
                let tensor = node.outputs.first().and_then(|&tensor_idx| {
                    tape.tensors.get(tensor_idx).and_then(|opt| opt.as_ref())
                });
                let Some(tensor) = tensor else { continue };
                if self.cache.contains_key(&node_idx) {
                    // Already materialized from a previous forward pass.
                    self.stats.cache_hits += 1;
                } else {
                    fresh.push((node_idx, tensor.clone()));
                }
            }
        }
        for (node_idx, tensor) in fresh {
            self.cache_insert(node_idx, tensor);
        }
        Ok(result)
    }

    /// Delegates the backward pass to the wrapped executor; the cache is not
    /// consulted or modified here.
    fn backward(
        &mut self,
        graph: &EinsumGraph,
        loss: &Self::Tensor,
    ) -> Result<Self::Tape, Self::Error> {
        self.inner.backward(graph, loss)
    }
}
impl LazyExecutor {
    /// Public cache lookup; records a hit or miss in the stats and returns a
    /// clone of the cached tensor, if any.
    pub fn get_cached(&mut self, node_id: usize) -> Option<Scirs2Tensor> {
        self.cache_get(node_id)
    }
    /// Public cache insert; also refreshes the peak-memory estimate.
    pub fn put_cached(&mut self, node_id: usize, tensor: Scirs2Tensor) {
        self.cache_insert(node_id, tensor);
    }
    /// Mutable access to the wrapped executor, e.g. to inspect or reset its
    /// tape. Bypasses the cache entirely.
    pub fn inner_mut(&mut self) -> &mut Scirs2Exec {
        &mut self.inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumGraph;

    #[test]
    fn test_lazy_executor_default() {
        // A defaulted executor starts with nothing cached.
        assert_eq!(LazyExecutor::default().cached_count(), 0);
    }

    #[test]
    fn test_lazy_executor_cached_count_starts_zero() {
        // `new` likewise begins with an empty cache.
        assert_eq!(LazyExecutor::new().cached_count(), 0);
    }

    #[test]
    fn test_lazy_executor_invalidate_cache() {
        use scirs2_core::ndarray::ArrayD;
        let mut lazy = LazyExecutor::with_capacity(4);
        let tensor: Scirs2Tensor = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[2, 2]));
        lazy.put_cached(0, tensor);
        assert_eq!(lazy.cached_count(), 1);
        // Clearing the cache drops every entry.
        lazy.invalidate_cache();
        assert_eq!(lazy.cached_count(), 0);
    }

    #[test]
    fn test_lazy_stats_default() {
        // All counters must start at zero.
        let LazyStats {
            cache_hits,
            cache_misses,
            tensors_recomputed,
            peak_memory_estimate_bytes,
        } = LazyStats::default();
        assert_eq!(cache_hits, 0);
        assert_eq!(cache_misses, 0);
        assert_eq!(tensors_recomputed, 0);
        assert_eq!(peak_memory_estimate_bytes, 0);
    }

    #[test]
    fn test_lazy_executor_memory_estimate_for_empty_graph() {
        // An empty graph yields a zero estimate regardless of cache state.
        let executor = LazyExecutor::new();
        assert_eq!(executor.memory_estimate_for(&EinsumGraph::new()), 0);
    }
}