burn_rmexp_dyntensor/
clone_box.rs1use std::any::Any;
3use std::fmt::Debug;
4
/// A type-erased value that can be duplicated through a trait object.
///
/// `Clone` itself is not object-safe (it returns `Self`), so this trait exposes
/// cloning as `clone_box`, which hands back a fresh `Box<dyn CloneBox>`.
/// `Any` is listed as a supertrait so a `&dyn CloneBox` can be upcast to
/// `&dyn Any` for downcasting; `Debug`, `Send` and `Sync` are required so the
/// erased value can be printed and shared across threads.
///
/// Pitfall: a blanket impl makes every `'static + Debug + Clone + Send + Sync`
/// type a `CloneBox` — including `Box<dyn CloneBox>` itself. Calling
/// `clone_box` on a `&Box<dyn CloneBox>` therefore boxes the box; call it on
/// the trait object instead (`(*boxed).clone_box()`).
pub trait CloneBox: Any + Debug + Send + Sync + 'static {
    /// Clones the concrete value behind `self` into a new heap allocation,
    /// erased as `dyn CloneBox`.
    fn clone_box(&self) -> Box<dyn CloneBox>;
}
9
10impl dyn CloneBox {
11 pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
15 (self as &dyn Any).downcast_ref::<T>()
16 }
17}
18
19impl<T: 'static + Any + Debug + Clone + Send + Sync> CloneBox for T {
20 fn clone_box(&self) -> Box<dyn CloneBox> {
21 Box::new(self.clone())
22 }
23}
24
25impl Clone for Box<dyn CloneBox> {
26 fn clone(&self) -> Self {
27 (**self).clone_box()
28 }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::tensor::Distribution;

    /// Compiles only if `T: Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only if `T: Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn test_boxed_clone_box_is_send_and_sync() {
        assert_send::<Box<dyn CloneBox>>();
        assert_sync::<Box<dyn CloneBox>>();
    }

    /// Backend-independent round trip: clone a boxed std value and downcast
    /// both the source and the clone.
    #[test]
    fn test_clone_box_std_value() {
        let boxed: Box<dyn CloneBox> = Box::new(vec![1_i32, 2, 3]);
        let cloned = boxed.clone();

        let src = boxed
            .downcast_ref::<Vec<i32>>()
            .expect("source keeps its concrete type");
        let dst = cloned
            .downcast_ref::<Vec<i32>>()
            .expect("clone keeps its concrete type");
        assert_eq!(src, dst);
        // Deep copy: the clone owns its own buffer instead of aliasing the source.
        assert_ne!(src.as_ptr(), dst.as_ptr());

        // Cloning a clone must not nest boxes: the inner type stays directly visible.
        let twice = cloned.clone();
        assert_eq!(twice.downcast_ref::<Vec<i32>>(), Some(&vec![1, 2, 3]));

        // `Debug` on the box is forwarded to the erased value.
        assert_eq!(format!("{:?}", twice), "[1, 2, 3]");
    }

    #[test]
    fn test_downcast_ref_wrong_type_is_none() {
        let boxed: Box<dyn CloneBox> = Box::new(42_u8);
        assert_eq!(boxed.downcast_ref::<u8>(), Some(&42));
        assert!(boxed.downcast_ref::<u16>().is_none());
        assert!(boxed.downcast_ref::<Box<dyn CloneBox>>().is_none());
    }

    /// Round trip through a real burn tensor.
    ///
    /// NOTE(review): uses the `Wgpu` backend, so this presumably needs a usable
    /// wgpu adapter at runtime — confirm in CI. The std-only tests above cover
    /// the same clone/downcast paths without one.
    #[test]
    fn test_clone_box_tensor() {
        type B = Wgpu;
        let device = Default::default();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);

        let boxed: Box<dyn CloneBox> = Box::new(source.clone());

        assert_send::<Box<dyn CloneBox>>();

        let cloned_box = boxed.clone();

        let clone = cloned_box.downcast_ref::<Tensor<B, 2>>().unwrap();

        clone.to_data().assert_eq(&source.to_data(), true);
    }
}