burn_backend/tensor/
container.rs

1use alloc::boxed::Box;
2use core::any::Any;
3
4#[cfg(not(feature = "std"))]
5use alloc::vec::Vec;
6#[cfg(not(feature = "std"))]
7use hashbrown::HashMap;
8
9#[cfg(feature = "std")]
10use std::collections::HashMap;
11
12use crate::{TensorPrimitive, backend::Backend};
13
/// A keyed collection of tensors of arbitrary dimension.
///
/// Every entry is stored type-erased behind a `Box<dyn Any + Send>`, which is why the container
/// itself is generic only over the identifier type `ID` and not over the stored value type.
#[derive(Debug)]
pub struct TensorContainer<ID> {
    /// Type-erased values, indexed by their identifier.
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}
19
20impl<ID> Default for TensorContainer<ID>
21where
22    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
23{
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl<ID> TensorContainer<ID>
30where
31    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
32{
33    /// Create an empty container.
34    pub fn new() -> Self {
35        Self {
36            tensors: HashMap::new(),
37        }
38    }
39
40    /// Get a tensor with the given ID.
41    pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
42    where
43        B: Backend,
44    {
45        let grad = self.tensors.get(id)?;
46
47        let tensor = grad
48            .downcast_ref::<TensorPrimitive<B>>()
49            // .map(|primitive| Tensor::<B, D>::from_primitive(primitive.clone()))
50            .unwrap();
51
52        Some(tensor.clone())
53    }
54
55    /// Register a new tensor for the given ID.
56    ///
57    /// # Notes
58    ///
59    /// If a tensor is already registered for the given ID, it will be replaced.
60    pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
61    where
62        B: Backend,
63    {
64        self.tensors.insert(id, Box::new(value));
65    }
66
67    /// Remove a tensor for the given ID and returns it.
68    pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
69    where
70        B: Backend,
71    {
72        self.tensors
73            .remove(id)
74            .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())
75        // .map(|primitive| Tensor::from_primitive(*primitive))
76    }
77
78    /// The number of tensors registered.
79    pub fn len(&self) -> usize {
80        self.tensors.len()
81    }
82
83    /// If any tensor is contained.
84    pub fn is_empty(&self) -> bool {
85        self.len() == 0
86    }
87
88    /// Get id of every tensor in the container
89    pub fn ids(&self) -> Vec<&ID> {
90        self.tensors.keys().collect()
91    }
92}