burn_tensor/tensor/
container.rsuse alloc::boxed::Box;
use core::any::Any;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[cfg(feature = "std")]
use std::collections::HashMap;
use crate::{backend::Backend, TensorPrimitive};
/// A type-erased store of tensor primitives keyed by `ID`.
///
/// Values are boxed as `dyn Any + Send`, so a single container can hold tensors of any
/// backend; the concrete backend type `B` is named again by the caller when a tensor is
/// read back (`get`) or taken out (`remove`).
#[derive(Debug)]
pub struct TensorContainer<ID> {
    // Each value is a boxed `TensorPrimitive<B>` for the backend `B` it was registered with.
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}
impl<ID> Default for TensorContainer<ID>
where
ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
{
fn default() -> Self {
Self::new()
}
}
impl<ID> TensorContainer<ID>
where
    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
{
    /// Creates an empty container.
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
        }
    }

    /// Returns a clone of the tensor registered under `id`, or `None` if nothing is
    /// registered under that id.
    ///
    /// # Panics
    ///
    /// Panics if the value stored under `id` is not a `TensorPrimitive<B>`, i.e. it was
    /// registered with a different backend type.
    pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
    where
        B: Backend,
    {
        let stored = self.tensors.get(id)?;
        match stored.downcast_ref::<TensorPrimitive<B>>() {
            Some(tensor) => Some(tensor.clone()),
            None => Self::type_mismatch::<B>(id),
        }
    }

    /// Stores `value` under `id`, replacing (and dropping) any value previously
    /// registered under the same id.
    pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
    where
        B: Backend,
    {
        self.tensors.insert(id, Box::new(value));
    }

    /// Removes and returns the tensor registered under `id`, or `None` if nothing is
    /// registered under that id.
    ///
    /// # Panics
    ///
    /// Panics if the value stored under `id` is not a `TensorPrimitive<B>`. In that case
    /// the stored value is put back first, so the container is left unchanged.
    pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
    where
        B: Backend,
    {
        // `remove_entry` hands back the owned key so the value can be reinserted on a
        // type mismatch instead of being silently dropped.
        let (key, stored) = self.tensors.remove_entry(id)?;
        match stored.downcast::<TensorPrimitive<B>>() {
            Ok(tensor) => Some(*tensor),
            Err(stored) => {
                self.tensors.insert(key, stored);
                Self::type_mismatch::<B>(id)
            }
        }
    }

    /// Number of tensors currently registered.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` when no tensors are registered.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Reports a lookup whose stored value does not match the requested backend `B`.
    #[cold]
    fn type_mismatch<B>(id: &ID) -> ! {
        panic!(
            "TensorContainer: value registered under id {:?} is not a TensorPrimitive of backend `{}`",
            id,
            core::any::type_name::<B>()
        )
    }
}