rstsr_core/tensor/
device_conversion.rs

1use crate::prelude_dev::*;
2
3#[allow(clippy::type_complexity)]
4pub trait TensorDeviceChangeAPI<'l, BOut>
5where
6    BOut: DeviceRawAPI<Self::Type>,
7{
8    type Repr;
9    type ReprTo;
10    type Type;
11    type Dim: DimAPI;
12
13    fn change_device_f(self, device: &BOut) -> Result<TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>>;
14    fn into_device_f(self, device: &BOut) -> Result<TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>>;
15    fn to_device_f(&'l self, device: &BOut) -> Result<TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim>>;
16
17    fn change_device(self, device: &BOut) -> TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>
18    where
19        Self: Sized,
20    {
21        self.change_device_f(device).rstsr_unwrap()
22    }
23
24    fn into_device(self, device: &BOut) -> TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>
25    where
26        Self: Sized,
27    {
28        self.into_device_f(device).rstsr_unwrap()
29    }
30
31    fn to_device(&'l self, device: &BOut) -> TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim> {
32        self.to_device_f(device).rstsr_unwrap()
33    }
34}
35
36impl<'a, R, T, B, D, BOut> TensorDeviceChangeAPI<'a, BOut> for TensorAny<R, T, B, D>
37where
38    B: DeviceRawAPI<T> + DeviceChangeAPI<'a, BOut, R, T, D>,
39    BOut: DeviceRawAPI<T>,
40    D: DimAPI,
41    R: DataAPI<Data = B::Raw>,
42{
43    type Repr = B::Repr;
44    type ReprTo = B::ReprTo;
45    type Type = T;
46    type Dim = D;
47
48    fn change_device_f(self, device: &BOut) -> Result<TensorAny<B::Repr, T, BOut, D>> {
49        B::change_device(self, device)
50    }
51
52    fn into_device_f(self, device: &BOut) -> Result<Tensor<T, BOut, D>> {
53        B::into_device(self, device)
54    }
55
56    fn to_device_f(&'a self, device: &BOut) -> Result<TensorAny<B::ReprTo, Self::Type, BOut, Self::Dim>> {
57        B::to_device(self, device)
58    }
59}
60
61#[allow(clippy::type_complexity)]
62pub trait TensorChangeFromDevice<'l, BOut>
63where
64    BOut: DeviceRawAPI<Self::Type>,
65{
66    type Repr;
67    type ReprTo;
68    type Type;
69    type Dim: DimAPI;
70
71    fn change_device_f(self, device: &BOut) -> Result<TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>>;
72    fn into_device_f(self, device: &BOut) -> Result<TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>>;
73    fn to_device_f(&'l self, device: &BOut) -> Result<TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim>>;
74
75    fn change_device(self, device: &BOut) -> TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>
76    where
77        Self: Sized,
78    {
79        self.change_device_f(device).rstsr_unwrap()
80    }
81
82    fn into_device(self, device: &BOut) -> TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>
83    where
84        Self: Sized,
85    {
86        self.into_device_f(device).rstsr_unwrap()
87    }
88
89    fn to_device(&'l self, device: &BOut) -> TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim> {
90        self.to_device_f(device).rstsr_unwrap()
91    }
92}