burn_backend/tensor/
container.rs1use alloc::boxed::Box;
2use core::any::Any;
3
4#[cfg(not(feature = "std"))]
5use alloc::vec::Vec;
6#[cfg(not(feature = "std"))]
7use hashbrown::HashMap;
8
9#[cfg(feature = "std")]
10use std::collections::HashMap;
11
12use crate::{TensorPrimitive, backend::Backend};
13
/// A type-erased map from `ID` to tensor primitives.
///
/// Values are stored boxed as `dyn Any + Send`, so the concrete
/// `TensorPrimitive<B>` type is not fixed at the container level; it is
/// recovered by downcasting when a tensor is read back or removed.
#[derive(Debug)]
pub struct TensorContainer<ID> {
    /// Type-erased storage, keyed by tensor id.
    tensors: HashMap<ID, Box<dyn core::any::Any + Send>>,
}
19
20impl<ID> Default for TensorContainer<ID>
21where
22 ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
23{
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl<ID> TensorContainer<ID>
30where
31 ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
32{
33 pub fn new() -> Self {
35 Self {
36 tensors: HashMap::new(),
37 }
38 }
39
40 pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
42 where
43 B: Backend,
44 {
45 let grad = self.tensors.get(id)?;
46
47 let tensor = grad
48 .downcast_ref::<TensorPrimitive<B>>()
49 .unwrap();
51
52 Some(tensor.clone())
53 }
54
55 pub fn get_mut_ref<B>(&mut self, id: &ID) -> Option<&mut TensorPrimitive<B>>
57 where
58 B: Backend,
59 {
60 let grad = self.tensors.get_mut(id)?;
61
62 let tensor = grad.downcast_mut::<TensorPrimitive<B>>().unwrap();
63
64 Some(tensor)
65 }
66
67 pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
73 where
74 B: Backend,
75 {
76 self.tensors.insert(id, Box::new(value));
77 }
78
79 pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
81 where
82 B: Backend,
83 {
84 self.tensors
85 .remove(id)
86 .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())
87 }
89
90 pub fn len(&self) -> usize {
92 self.tensors.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.len() == 0
98 }
99
100 pub fn ids(&self) -> Vec<&ID> {
102 self.tensors.keys().collect()
103 }
104}