// burn_backend/tensor/container.rs
use alloc::boxed::Box;
use core::any::Any;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;

#[cfg(feature = "std")]
use std::collections::HashMap;

use crate::{TensorPrimitive, backend::Backend};

/// Type-erased storage for tensor primitives, keyed by `ID`.
///
/// Each value is boxed as `dyn Any + Send`, so one container can hold tensors
/// without the container type itself being generic over a backend. The
/// accessors downcast entries back to `TensorPrimitive<B>` for the backend `B`
/// named by the caller.
#[derive(Debug)]
pub struct TensorContainer<ID> {
    // Every value is expected to be a `TensorPrimitive<B>` for some backend `B`
    // (insertion only happens through `register`).
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}

20impl<ID> Default for TensorContainer<ID>
21where
22 ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
23{
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl<ID> TensorContainer<ID>
30where
31 ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
32{
33 pub fn new() -> Self {
35 Self {
36 tensors: HashMap::new(),
37 }
38 }
39
40 pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
42 where
43 B: Backend,
44 {
45 let grad = self.tensors.get(id)?;
46
47 let tensor = grad
48 .downcast_ref::<TensorPrimitive<B>>()
49 .unwrap();
51
52 Some(tensor.clone())
53 }
54
55 pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
61 where
62 B: Backend,
63 {
64 self.tensors.insert(id, Box::new(value));
65 }
66
67 pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
69 where
70 B: Backend,
71 {
72 self.tensors
73 .remove(id)
74 .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())
75 }
77
78 pub fn len(&self) -> usize {
80 self.tensors.len()
81 }
82
83 pub fn is_empty(&self) -> bool {
85 self.len() == 0
86 }
87
88 pub fn ids(&self) -> Vec<&ID> {
90 self.tensors.keys().collect()
91 }
92}