use parking_lot::RwLock;
use std::sync::Arc;
/// A lazily-populated value that is shared between cloned handles.
///
/// Every clone of a `LazyTensor` refers to the same storage slot (an `Arc`
/// around a lock protecting an `Option<T>`). A value stored through one handle
/// is therefore visible through all of them. The shape hint and the name are
/// per-handle metadata that are copied, not shared, when a handle is cloned.
///
/// The `T: Clone` requirement is placed on the `impl` blocks that actually need
/// it rather than on the type definition itself.
pub struct LazyTensor<T> {
    /// Shared storage slot; holds `None` while no value is stored.
    inner: Arc<RwLock<Option<T>>>,
    /// Optional dimensions used for size estimation. They are not validated
    /// against whatever value is eventually stored.
    shape_hint: Option<Vec<usize>>,
    /// Optional label for this handle; copied (not shared) when the handle is cloned.
    pub name: Option<String>,
}
impl<T: Clone> LazyTensor<T> {
    /// Bytes per element assumed by [`memory_estimate_bytes`](Self::memory_estimate_bytes).
    ///
    /// `LazyTensor` stores an opaque `T` and cannot know the real element type.
    /// Eight bytes corresponds to 64-bit scalars such as `f64` or `i64`.
    const DEFAULT_ELEM_SIZE_BYTES: usize = 8;

    /// Creates a handle whose value has not been computed yet.
    ///
    /// `shape_hint` is optional metadata. It is used by
    /// [`shape_hint`](Self::shape_hint) and the memory estimate, and it is not
    /// checked against the value later stored with [`set`](Self::set).
    pub fn pending(shape_hint: Option<Vec<usize>>) -> Self {
        Self {
            inner: Arc::new(RwLock::new(None)),
            shape_hint,
            name: None,
        }
    }

    /// Creates a handle that already holds `value`. No shape hint is recorded.
    pub fn eager(value: T) -> Self {
        Self {
            inner: Arc::new(RwLock::new(Some(value))),
            shape_hint: None,
            name: None,
        }
    }

    /// Returns `true` if a value is currently stored.
    ///
    /// Other handles sharing this slot may call `set` or `take` at any time, so
    /// the answer can be stale by the time it is acted on. Prefer matching on
    /// [`get`](Self::get) when the value itself is needed.
    pub fn is_computed(&self) -> bool {
        self.inner.read().is_some()
    }

    /// Returns a clone of the stored value, or `None` if nothing is stored.
    ///
    /// The clone is made while the read lock is held.
    pub fn get(&self) -> Option<T> {
        self.inner.read().clone()
    }

    /// Stores `value`, replacing (and dropping) any previously stored value.
    ///
    /// The new value is visible through every clone of this handle.
    pub fn set(&self, value: T) {
        *self.inner.write() = Some(value);
    }

    /// Removes and returns the stored value, leaving the shared slot empty.
    pub fn take(&self) -> Option<T> {
        self.inner.write().take()
    }

    /// Returns the shape hint given to [`pending`](Self::pending), if any.
    pub fn shape_hint(&self) -> Option<&[usize]> {
        self.shape_hint.as_deref()
    }

    /// Estimates the memory footprint implied by the shape hint, assuming
    /// 8 bytes per element.
    ///
    /// The result has the following properties:
    /// - It is `0` when there is no shape hint.
    /// - An empty shape counts as a single element.
    /// - It saturates at `usize::MAX` instead of overflowing.
    pub fn memory_estimate_bytes(&self) -> usize {
        self.memory_estimate_bytes_with_elem_size(Self::DEFAULT_ELEM_SIZE_BYTES)
    }

    /// Estimates the memory footprint implied by the shape hint, given the size
    /// of one element in bytes.
    ///
    /// The result is the product of all dimensions multiplied by `elem_size`,
    /// with these cases:
    /// - It is `0` when there is no shape hint or any dimension is zero.
    /// - An empty shape counts as a single element.
    /// - It saturates at `usize::MAX` if the exact value would overflow.
    pub fn memory_estimate_bytes_with_elem_size(&self, elem_size: usize) -> usize {
        match self.shape_hint.as_deref() {
            None => 0,
            // A zero-length axis means zero elements. This is checked up front
            // so that overflow in the other axes cannot mask it.
            Some(shape) if shape.contains(&0) => 0,
            Some(shape) => shape
                .iter()
                .try_fold(elem_size, |acc, &dim| acc.checked_mul(dim))
                .unwrap_or(usize::MAX),
        }
    }
}
impl<T: Clone> Clone for LazyTensor<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
shape_hint: self.shape_hint.clone(),
name: self.name.clone(),
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `LazyTensor`: construction, the set/take lifecycle, the
    //! shape-hint memory estimate, and slot sharing between clones.
    use super::*;

    #[test]
    fn test_lazy_tensor_pending_not_computed() {
        let pending: LazyTensor<i32> = LazyTensor::pending(Some(vec![3, 4]));
        assert!(!pending.is_computed());
        assert_eq!(pending.get(), None);
    }

    #[test]
    fn test_lazy_tensor_set_and_get() {
        let slot = LazyTensor::<i32>::pending(None);
        slot.set(42);
        assert!(slot.is_computed());
        assert_eq!(slot.get(), Some(42));
    }

    #[test]
    fn test_lazy_tensor_eager_is_computed() {
        let ready = LazyTensor::<i32>::eager(99);
        assert!(ready.is_computed());
        assert_eq!(ready.get(), Some(99));
    }

    #[test]
    fn test_lazy_tensor_take_clears() {
        let ready = LazyTensor::<i32>::eager(7);
        assert_eq!(ready.take(), Some(7));
        // Once taken, the slot reads as empty again.
        assert!(!ready.is_computed());
        assert_eq!(ready.get(), None);
    }

    #[test]
    fn test_lazy_tensor_memory_estimate_with_hint() {
        let hinted = LazyTensor::<i32>::pending(Some(vec![2, 3, 4]));
        // 2 * 3 * 4 = 24 elements at 8 bytes each.
        assert_eq!(hinted.memory_estimate_bytes(), 192);
    }

    #[test]
    fn test_lazy_tensor_memory_estimate_no_hint() {
        let unhinted = LazyTensor::<i32>::pending(None);
        assert_eq!(unhinted.memory_estimate_bytes(), 0);
    }

    #[test]
    fn test_lazy_tensor_clone_shares_inner() {
        let original = LazyTensor::<i32>::pending(Some(vec![2, 2]));
        let shared = original.clone();
        original.set(123);
        assert_eq!(shared.get(), Some(123));
        let expected: &[usize] = &[2, 2];
        assert_eq!(original.shape_hint(), Some(expected));
    }
}