// tension/num.rs

1use num_traits as num;
2use num_complex::Complex;
3
/// Additive identity, mirroring `num_traits::Zero`, but also available for `bool`.
pub trait Zero {
    /// Returns the zero value of `Self`.
    fn zero() -> Self;
}
/// Multiplicative identity, mirroring `num_traits::One`, but also available for `bool`.
pub trait One {
    /// Returns the one value of `Self`.
    fn one() -> Self;
}
12
13/// Wrapper for `num_traits::Num`.
14pub trait Num: num::Num {}
15
16/// Wrapper for `num_traits::Float`.
17pub trait Float: Num + num::Float {}
18
19impl Num for u8 {}
20impl Num for u16 {}
21impl Num for u32 {}
22impl Num for u64 {}
23
24impl Num for i8 {}
25impl Num for i16 {}
26impl Num for i32 {}
27impl Num for i64 {}
28
29impl Num for usize {}
30impl Num for isize {}
31
32impl Num for f32 {}
33impl Num for f64 {}
34
35impl Float for f32 {}
36impl Float for f64 {}
37
38impl<T: Float> Num for Complex<T> {}
39
40
41impl<T: Num> Zero for T {
42    fn zero() -> Self {
43        <T as num::Zero>::zero()
44    }
45}
46impl Zero for bool {
47    fn zero() -> Self {
48        false
49    }
50}
51impl<T: Num> One for T {
52    fn one() -> Self {
53        <T as num::One>::one()
54    }
55}
56impl One for bool {
57    fn one() -> Self {
58        false
59    }
60}
61
62
63/// Type that could be put in tensor.
64pub trait Prm : Sized + Copy + PartialEq + Zero + One {}
65
66impl<T: Num + Copy> Prm for T {}
67
68impl Prm for bool {}
69
70
#[cfg(feature = "device")]
mod interop {
    use super::*;
    use num_complex_v01::Complex as ComplexV01;
    use ocl::{Buffer, OclPrm};
    use std::mem::{align_of, size_of};
    use std::slice;

    /// Types that can be transformed from host representation to device one and back.
    pub trait Interop: Copy {
        /// Representation of `Self` stored inside OpenCL buffers.
        type Dev: OclPrm + Copy;

        /// Transform from host to device type.
        fn to_dev(self) -> Self::Dev;
        /// Transform from device to host type.
        fn from_dev(x: Self::Dev) -> Self;

        /// Copy data from OpenCL buffer to host slice.
        ///
        /// The default implementation reads into a temporary `Vec<Self::Dev>` and
        /// converts every element with [`Interop::from_dev`].
        ///
        /// # Panics
        ///
        /// Panics if `dst` and `src` differ in length or if the read command fails.
        fn load_from_buffer(dst: &mut [Self], src: &Buffer<Self::Dev>) {
            assert_eq!(dst.len(), src.len());
            // `Buffer::read` fills the destination's existing elements rather than
            // growing a `Vec`, so the staging vector must already hold `src.len()`
            // elements; an empty one cannot receive the data. Its initial values
            // are placeholders that the read overwrites.
            let mut tmp: Vec<Self::Dev> = dst.iter().map(|&x| x.to_dev()).collect();
            src.read(&mut tmp).enq().unwrap();
            for (d, &s) in dst.iter_mut().zip(tmp.iter()) {
                *d = Self::from_dev(s);
            }
        }

        /// Copy data from host slice to OpenCL buffer.
        ///
        /// The default implementation converts every element with
        /// [`Interop::to_dev`] into a temporary `Vec` and writes it in one command.
        ///
        /// # Panics
        ///
        /// Panics if `dst` and `src` differ in length or if the write command fails.
        fn store_to_buffer(dst: &mut Buffer<Self::Dev>, src: &[Self]) {
            assert_eq!(dst.len(), src.len());
            let tmp: Vec<Self::Dev> = src.iter().map(|&x| x.to_dev()).collect();
            dst.write(&tmp).enq().unwrap();
        }
    }

    /// Type which representation remains the same for both host and device.
    pub trait IdentInterop: Interop<Dev = Self> + OclPrm {}

    /// Identity conversion; buffers are read/written directly without staging.
    impl<T: IdentInterop> Interop for T {
        type Dev = Self;

        fn to_dev(self) -> Self::Dev {
            self
        }
        fn from_dev(x: Self::Dev) -> Self {
            x
        }
        fn load_from_buffer(dst: &mut [Self], src: &Buffer<Self::Dev>) {
            assert_eq!(dst.len(), src.len());
            src.read(dst).enq().unwrap();
        }
        fn store_to_buffer(dst: &mut Buffer<Self::Dev>, src: &[Self]) {
            assert_eq!(dst.len(), src.len());
            dst.write(src).enq().unwrap();
        }
    }

    /// `bool` is stored on the device as a `u8`: `true` becomes `0xFF`,
    /// `false` becomes `0x00`, and any nonzero byte reads back as `true`.
    impl Interop for bool {
        type Dev = u8;
        fn to_dev(self) -> Self::Dev {
            if self {
                0xFF
            } else {
                0x00
            }
        }
        fn from_dev(x: Self::Dev) -> Self {
            x != 0
        }
    }

    impl IdentInterop for u8 {}
    impl IdentInterop for u16 {}
    impl IdentInterop for u32 {}
    impl IdentInterop for u64 {}

    impl IdentInterop for i8 {}
    impl IdentInterop for i16 {}
    impl IdentInterop for i32 {}
    impl IdentInterop for i64 {}

    impl IdentInterop for f32 {}
    impl IdentInterop for f64 {}

    /// `usize` is stored on the device as `u32`.
    ///
    /// NOTE(review): the `as` cast silently truncates values above `u32::MAX`
    /// on 64-bit hosts — presumably device code only uses 32-bit sizes; confirm.
    impl Interop for usize {
        type Dev = u32;
        fn to_dev(self) -> Self::Dev {
            self as Self::Dev
        }
        fn from_dev(x: Self::Dev) -> Self {
            x as Self
        }
    }
    /// `isize` is stored on the device as `i32`.
    ///
    /// NOTE(review): the `as` cast silently truncates values outside the `i32`
    /// range on 64-bit hosts — confirm this is acceptable for device code.
    impl Interop for isize {
        type Dev = i32;
        fn to_dev(self) -> Self::Dev {
            self as Self::Dev
        }
        fn from_dev(x: Self::Dev) -> Self {
            x as Self
        }
    }

    /// Panics unless the host and device complex types have identical size and
    /// alignment, which the slice reinterpretation in the `Complex` impl relies on.
    fn assert_same_complex_layout<T>() {
        assert_eq!(size_of::<Complex<T>>(), size_of::<ComplexV01<T>>());
        assert_eq!(align_of::<Complex<T>>(), align_of::<ComplexV01<T>>());
    }

    /// Host `Complex<T>` is mapped onto the `num_complex_v01` `Complex<T>`,
    /// which is the type that implements `OclPrm`.
    impl<T: Float> Interop for Complex<T>
    where
        ComplexV01<T>: OclPrm,
    {
        type Dev = ComplexV01<T>;

        fn to_dev(self) -> Self::Dev {
            Self::Dev::new(self.re, self.im)
        }
        fn from_dev(x: Self::Dev) -> Self {
            Self::new(x.re, x.im)
        }

        /// Reads straight into `dst`, viewed as a slice of device values.
        fn load_from_buffer(dst: &mut [Self], src: &Buffer<Self::Dev>) {
            assert_eq!(dst.len(), src.len());
            assert_same_complex_layout::<T>();
            // SAFETY: the check above guarantees equal size and alignment, and both
            // types are `{ re: T, im: T }` structs (this relies on both crates
            // declaring `Complex` as `#[repr(C)]`). The new slice spans exactly
            // the memory of `dst`, which stays exclusively borrowed while `dev` lives.
            let dev = unsafe {
                slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut Self::Dev, dst.len())
            };
            src.read(dev).enq().unwrap();
        }

        /// Writes straight from `src`, viewed as a slice of device values.
        fn store_to_buffer(dst: &mut Buffer<Self::Dev>, src: &[Self]) {
            assert_eq!(dst.len(), src.len());
            assert_same_complex_layout::<T>();
            // SAFETY: same layout argument as in `load_from_buffer`; the new slice
            // spans exactly the memory of `src` and shares its borrow.
            let dev = unsafe {
                slice::from_raw_parts(src.as_ptr() as *const Self::Dev, src.len())
            };
            dst.write(dev).enq().unwrap();
        }
    }
}
#[cfg(feature = "device")]
pub use interop::*;