1use crate::prelude_dev::*;
2
3#[allow(clippy::type_complexity)]
4pub trait TensorDeviceChangeAPI<'l, BOut>
5where
6 BOut: DeviceRawAPI<Self::Type>,
7{
8 type Repr;
9 type ReprTo;
10 type Type;
11 type Dim: DimAPI;
12
13 fn change_device_f(self, device: &BOut) -> Result<TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>>;
14 fn into_device_f(self, device: &BOut) -> Result<TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>>;
15 fn to_device_f(&'l self, device: &BOut) -> Result<TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim>>;
16
17 fn change_device(self, device: &BOut) -> TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>
18 where
19 Self: Sized,
20 {
21 self.change_device_f(device).rstsr_unwrap()
22 }
23
24 fn into_device(self, device: &BOut) -> TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>
25 where
26 Self: Sized,
27 {
28 self.into_device_f(device).rstsr_unwrap()
29 }
30
31 fn to_device(&'l self, device: &BOut) -> TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim> {
32 self.to_device_f(device).rstsr_unwrap()
33 }
34}
35
36impl<'a, R, T, B, D, BOut> TensorDeviceChangeAPI<'a, BOut> for TensorAny<R, T, B, D>
37where
38 B: DeviceRawAPI<T> + DeviceChangeAPI<'a, BOut, R, T, D>,
39 BOut: DeviceRawAPI<T>,
40 D: DimAPI,
41 R: DataAPI<Data = B::Raw>,
42{
43 type Repr = B::Repr;
44 type ReprTo = B::ReprTo;
45 type Type = T;
46 type Dim = D;
47
48 fn change_device_f(self, device: &BOut) -> Result<TensorAny<B::Repr, T, BOut, D>> {
49 B::change_device(self, device)
50 }
51
52 fn into_device_f(self, device: &BOut) -> Result<Tensor<T, BOut, D>> {
53 B::into_device(self, device)
54 }
55
56 fn to_device_f(&'a self, device: &BOut) -> Result<TensorAny<B::ReprTo, Self::Type, BOut, Self::Dim>> {
57 B::to_device(self, device)
58 }
59}
60
61#[allow(clippy::type_complexity)]
62pub trait TensorChangeFromDevice<'l, BOut>
63where
64 BOut: DeviceRawAPI<Self::Type>,
65{
66 type Repr;
67 type ReprTo;
68 type Type;
69 type Dim: DimAPI;
70
71 fn change_device_f(self, device: &BOut) -> Result<TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>>;
72 fn into_device_f(self, device: &BOut) -> Result<TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>>;
73 fn to_device_f(&'l self, device: &BOut) -> Result<TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim>>;
74
75 fn change_device(self, device: &BOut) -> TensorAny<Self::Repr, Self::Type, BOut, Self::Dim>
76 where
77 Self: Sized,
78 {
79 self.change_device_f(device).rstsr_unwrap()
80 }
81
82 fn into_device(self, device: &BOut) -> TensorAny<DataOwned<BOut::Raw>, Self::Type, BOut, Self::Dim>
83 where
84 Self: Sized,
85 {
86 self.into_device_f(device).rstsr_unwrap()
87 }
88
89 fn to_device(&'l self, device: &BOut) -> TensorAny<Self::ReprTo, Self::Type, BOut, Self::Dim> {
90 self.to_device_f(device).rstsr_unwrap()
91 }
92}