burn_tensor/tensor/
container.rs1use alloc::boxed::Box;
2use core::any::Any;
3
4#[cfg(not(feature = "std"))]
5use hashbrown::HashMap;
6
7#[cfg(feature = "std")]
8use std::collections::HashMap;
9
10use crate::{TensorPrimitive, backend::Backend};
11
/// Type-erased store of tensor primitives, keyed by `ID`.
///
/// Each value is held as a `Box<dyn Any + Send>`, so the map itself does not
/// name a backend type. The concrete type of an entry is recovered by
/// downcasting when the entry is read back.
#[derive(Debug)]
pub struct TensorContainer<ID> {
    /// Boxed, type-erased entries indexed by their identifier.
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}
17
18impl<ID> Default for TensorContainer<ID>
19where
20 ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
21{
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl<ID> TensorContainer<ID>
28where
29 ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,
30{
31 pub fn new() -> Self {
33 Self {
34 tensors: HashMap::new(),
35 }
36 }
37
38 pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>
40 where
41 B: Backend,
42 {
43 let grad = self.tensors.get(id)?;
44
45 let tensor = grad
46 .downcast_ref::<TensorPrimitive<B>>()
47 .unwrap();
49
50 Some(tensor.clone())
51 }
52
53 pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)
59 where
60 B: Backend,
61 {
62 self.tensors.insert(id, Box::new(value));
63 }
64
65 pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>
67 where
68 B: Backend,
69 {
70 self.tensors
71 .remove(id)
72 .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())
73 }
75
76 pub fn len(&self) -> usize {
78 self.tensors.len()
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.len() == 0
84 }
85}