burn_rmexp_dyntensor/
clone_box.rs

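//! # Clone Box Trait
//!
//! Cloneable, type-erased values: anything stored as a `Box<dyn CloneBox>` can be cloned
//! and downcast back to its concrete type.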
2use std::any::Any;
3use std::fmt::Debug;
4
/// A trait for cloning values into a boxed form.
///
/// `CloneBox` is object-safe, so values of different concrete types can be stored as
/// `Box<dyn CloneBox>` and still be cloned (via the `Clone` impl for `Box<dyn CloneBox>`
/// below) and downcast back to their concrete type.
///
/// Every `'static + Debug + Clone + Send + Sync` type implements this trait through the
/// blanket impl below.
pub trait CloneBox: 'static + Any + Debug + Send + Sync {
    /// Clones `self` into a new heap allocation, erasing its concrete type.
    fn clone_box(&self) -> Box<dyn CloneBox>;
}

impl dyn CloneBox {
    /// Returns `true` if the boxed value is of type `T`.
    ///
    /// See: `Any::is`.
    pub fn is<T: Any>(&self) -> bool {
        (self as &dyn Any).is::<T>()
    }

    /// Downcasts the boxed value to a specific type.
    ///
    /// Returns `None` if the value is not of type `T`.
    ///
    /// See: `Any::downcast_ref`.
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        // The trait-upcasting coercion to `&dyn Any` (stable since Rust 1.86) dispatches
        // `type_id` to the stored value, so the check compares against its concrete type.
        (self as &dyn Any).downcast_ref::<T>()
    }

    /// Mutably downcasts the boxed value to a specific type.
    ///
    /// Returns `None` if the value is not of type `T`.
    ///
    /// See: `Any::downcast_mut`.
    pub fn downcast_mut<T: Any>(&mut self) -> Option<&mut T> {
        (self as &mut dyn Any).downcast_mut::<T>()
    }
}

/// Blanket impl: every type meeting the bounds is boxed by cloning it.
///
/// Note that this also covers `Box<dyn CloneBox>` itself, because it is `Clone` through the
/// impl below.
impl<T: 'static + Any + Debug + Clone + Send + Sync> CloneBox for T {
    fn clone_box(&self) -> Box<dyn CloneBox> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn CloneBox> {
    /// Clones the boxed value by dispatching to the inner value's `clone_box`.
    fn clone(&self) -> Self {
        // `**self` is the inner `dyn CloneBox`. Writing `self.clone_box()` would instead pick
        // the blanket impl for `Box<dyn CloneBox>` itself, whose `self.clone()` calls back
        // into this method and recurses forever.
        (**self).clone_box()
    }
}
30
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::tensor::Distribution;

    /// Compile-time check that `T` can be sent to and shared between threads.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_clone_box_tensor() {
        type B = Wgpu;
        let device = Default::default();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Distribution::Default, &device);

        let boxed: Box<dyn CloneBox> = Box::new(source.clone());

        assert_send_sync::<Box<dyn CloneBox>>();

        let cloned_box = boxed.clone();

        let clone = cloned_box
            .downcast_ref::<Tensor<B, 2>>()
            .expect("clone should downcast to the original tensor type");

        clone.to_data().assert_eq(&source.to_data(), true);
    }

    // The tests below use only std types, so they exercise `CloneBox` without depending on
    // the `Wgpu` backend being available at test time.

    #[test]
    fn test_clone_box_is_independent_copy() {
        let original: Box<dyn CloneBox> = Box::new(vec![1_i32, 2, 3]);
        let copy = original.clone();

        let original_vec = original.downcast_ref::<Vec<i32>>().unwrap();
        let copy_vec = copy.downcast_ref::<Vec<i32>>().unwrap();

        assert_eq!(copy_vec, &vec![1, 2, 3]);
        // The clone must live in its own allocation rather than alias the original.
        assert!(!std::ptr::eq(original_vec, copy_vec));
    }

    #[test]
    fn test_clone_does_not_double_box() {
        let boxed: Box<dyn CloneBox> = Box::new(42_u64);
        let copy = boxed.clone();

        // Cloning must clone the inner value, not wrap the `Box<dyn CloneBox>` in another box.
        assert_eq!(copy.downcast_ref::<u64>(), Some(&42));
        assert!(copy.downcast_ref::<Box<dyn CloneBox>>().is_none());
    }

    #[test]
    fn test_downcast_ref_wrong_type_is_none() {
        let boxed: Box<dyn CloneBox> = Box::new(String::from("hello"));

        assert!(boxed.downcast_ref::<i32>().is_none());
        assert_eq!(
            boxed.downcast_ref::<String>().map(String::as_str),
            Some("hello")
        );
    }
}