Skip to main content

candle_core/cpu/
kernels.rs

1pub trait VecOps: num_traits::NumAssign + Copy {
2    fn min(self, rhs: Self) -> Self;
3    fn max(self, rhs: Self) -> Self;
4
5    /// Dot-product of two vectors.
6    ///
7    /// # Safety
8    ///
9    /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
10    /// element.
11    #[inline(always)]
12    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
13        *res = Self::zero();
14        for i in 0..len {
15            *res += *lhs.add(i) * *rhs.add(i)
16        }
17    }
18
19    /// Sum of all elements in a vector.
20    ///
21    /// # Safety
22    ///
23    /// The length of `xs` must be at least `len`. `res` has to point to a valid
24    /// element.
25    #[inline(always)]
26    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
27        *res = Self::zero();
28        for i in 0..len {
29            *res += *xs.add(i)
30        }
31    }
32
33    /// Maximum element in a non-empty vector.
34    ///
35    /// # Safety
36    ///
37    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
38    /// element.
39    #[inline(always)]
40    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
41        *res = *xs;
42        for i in 1..len {
43            *res = (*res).max(*xs.add(i))
44        }
45    }
46
47    /// Minimum element in a non-empty vector.
48    ///
49    /// # Safety
50    ///
51    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
52    /// element.
53    #[inline(always)]
54    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
55        *res = *xs;
56        for i in 1..len {
57            *res = (*res).min(*xs.add(i))
58        }
59    }
60}
61
62impl VecOps for f32 {
63    #[inline(always)]
64    fn min(self, other: Self) -> Self {
65        Self::min(self, other)
66    }
67
68    #[inline(always)]
69    fn max(self, other: Self) -> Self {
70        Self::max(self, other)
71    }
72
73    #[inline(always)]
74    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
75        super::vec_dot_f32(lhs, rhs, res, len)
76    }
77
78    #[inline(always)]
79    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
80        super::vec_sum(xs, res, len)
81    }
82}
83
84impl VecOps for half::f16 {
85    #[inline(always)]
86    fn min(self, other: Self) -> Self {
87        Self::min(self, other)
88    }
89
90    #[inline(always)]
91    fn max(self, other: Self) -> Self {
92        Self::max(self, other)
93    }
94
95    #[inline(always)]
96    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
97        let mut res_f32 = 0f32;
98        super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
99        *res = half::f16::from_f32(res_f32);
100    }
101}
102
103impl VecOps for f64 {
104    #[inline(always)]
105    fn min(self, other: Self) -> Self {
106        Self::min(self, other)
107    }
108
109    #[inline(always)]
110    fn max(self, other: Self) -> Self {
111        Self::max(self, other)
112    }
113}
114impl VecOps for half::bf16 {
115    #[inline(always)]
116    fn min(self, other: Self) -> Self {
117        Self::min(self, other)
118    }
119
120    #[inline(always)]
121    fn max(self, other: Self) -> Self {
122        Self::max(self, other)
123    }
124
125    #[inline(always)]
126    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
127        let mut res_f32 = 0f32;
128        super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
129        *res = half::bf16::from_f32(res_f32);
130    }
131}
132impl VecOps for u8 {
133    #[inline(always)]
134    fn min(self, other: Self) -> Self {
135        <Self as Ord>::min(self, other)
136    }
137
138    #[inline(always)]
139    fn max(self, other: Self) -> Self {
140        <Self as Ord>::max(self, other)
141    }
142}
143impl VecOps for u32 {
144    #[inline(always)]
145    fn min(self, other: Self) -> Self {
146        <Self as Ord>::min(self, other)
147    }
148
149    #[inline(always)]
150    fn max(self, other: Self) -> Self {
151        <Self as Ord>::max(self, other)
152    }
153}
154impl VecOps for i16 {
155    #[inline(always)]
156    fn min(self, other: Self) -> Self {
157        <Self as Ord>::min(self, other)
158    }
159
160    #[inline(always)]
161    fn max(self, other: Self) -> Self {
162        <Self as Ord>::max(self, other)
163    }
164}
165impl VecOps for i32 {
166    #[inline(always)]
167    fn min(self, other: Self) -> Self {
168        <Self as Ord>::min(self, other)
169    }
170
171    #[inline(always)]
172    fn max(self, other: Self) -> Self {
173        <Self as Ord>::max(self, other)
174    }
175}
176impl VecOps for i64 {
177    #[inline(always)]
178    fn min(self, other: Self) -> Self {
179        <Self as Ord>::min(self, other)
180    }
181
182    #[inline(always)]
183    fn max(self, other: Self) -> Self {
184        <Self as Ord>::max(self, other)
185    }
186}
187
188impl VecOps for float8::F8E4M3 {
189    #[inline(always)]
190    fn min(self, other: Self) -> Self {
191        Self::min(self, other)
192    }
193
194    #[inline(always)]
195    fn max(self, other: Self) -> Self {
196        Self::max(self, other)
197    }
198}
199
200#[inline(always)]
201pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
202    if n_threads == 1 {
203        func(0)
204    } else {
205        rayon::scope(|s| {
206            for thread_idx in 0..n_threads {
207                let func = &func;
208                s.spawn(move |_| func(thread_idx));
209            }
210        })
211    }
212}
213
214#[inline(always)]
215pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
216    if n_threads == 1 {
217        for i in lo..up {
218            func(i)
219        }
220    } else {
221        rayon::scope(|s| {
222            for thread_idx in 0..n_threads {
223                let func = &func;
224                s.spawn(move |_| {
225                    for i in (thread_idx..up).step_by(n_threads) {
226                        func(i)
227                    }
228                });
229            }
230        })
231    }
232}