burn_tensor/tensor/
container.rs

1use alloc::boxed::Box;
2use core::any::Any;
3
4#[cfg(not(feature = "std"))]
5use hashbrown::HashMap;
6
7#[cfg(feature = "std")]
8use std::collections::HashMap;
9
10use crate::{TensorPrimitive, backend::Backend};
11
/// A map from IDs to tensor primitives of arbitrary dimension and backend.
///
/// Each tensor is stored type-erased (as `Box<dyn Any + Send>`), so a single
/// container can hold tensors registered with different backend types; callers
/// name the backend again when reading a tensor back out.
#[derive(Debug)]
pub struct TensorContainer<ID> {
    /// Type-erased tensors keyed by their ID.
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}
17
18impl<ID> Default for TensorContainer<ID>
19where
20    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
21{
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl<ID> TensorContainer<ID>
28where
29    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
30{
31    /// Create an empty container.
32    pub fn new() -> Self {
33        Self {
34            tensors: HashMap::new(),
35        }
36    }
37
38    /// Get a tensor with the given ID.
39    pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
40    where
41        B: Backend,
42    {
43        let grad = self.tensors.get(id)?;
44
45        let tensor = grad
46            .downcast_ref::<TensorPrimitive<B>>()
47            // .map(|primitive| Tensor::<B, D>::from_primitive(primitive.clone()))
48            .unwrap();
49
50        Some(tensor.clone())
51    }
52
53    /// Register a new tensor for the given ID.
54    ///
55    /// # Notes
56    ///
57    /// If a tensor is already registered for the given ID, it will be replaced.
58    pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
59    where
60        B: Backend,
61    {
62        self.tensors.insert(id, Box::new(value));
63    }
64
65    /// Remove a tensor for the given ID and returns it.
66    pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
67    where
68        B: Backend,
69    {
70        self.tensors
71            .remove(id)
72            .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())
73        // .map(|primitive| Tensor::from_primitive(*primitive))
74    }
75
76    /// The number of tensors registered.
77    pub fn len(&self) -> usize {
78        self.tensors.len()
79    }
80
81    /// If any tensor is contained.
82    pub fn is_empty(&self) -> bool {
83        self.len() == 0
84    }
85}