candle_core/cpu/kernels.rs

1pub trait VecOps: num_traits::NumAssign + Copy {
2    fn min(self, rhs: Self) -> Self;
3    fn max(self, rhs: Self) -> Self;
4
5    /// Dot-product of two vectors.
6    ///
7    /// # Safety
8    ///
9    /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
10    /// element.
11    #[inline(always)]
12    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
13        *res = Self::zero();
14        for i in 0..len {
15            *res += *lhs.add(i) * *rhs.add(i)
16        }
17    }
18
19    /// Sum of all elements in a vector.
20    ///
21    /// # Safety
22    ///
23    /// The length of `xs` must be at least `len`. `res` has to point to a valid
24    /// element.
25    #[inline(always)]
26    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
27        *res = Self::zero();
28        for i in 0..len {
29            *res += *xs.add(i)
30        }
31    }
32
33    /// Maximum element in a non-empty vector.
34    ///
35    /// # Safety
36    ///
37    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
38    /// element.
39    #[inline(always)]
40    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
41        *res = *xs;
42        for i in 1..len {
43            *res = (*res).max(*xs.add(i))
44        }
45    }
46
47    /// Minimum element in a non-empty vector.
48    ///
49    /// # Safety
50    ///
51    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
52    /// element.
53    #[inline(always)]
54    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
55        *res = *xs;
56        for i in 1..len {
57            *res = (*res).min(*xs.add(i))
58        }
59    }
60}
61
62impl VecOps for f32 {
63    #[inline(always)]
64    fn min(self, other: Self) -> Self {
65        Self::min(self, other)
66    }
67
68    #[inline(always)]
69    fn max(self, other: Self) -> Self {
70        Self::max(self, other)
71    }
72
73    #[inline(always)]
74    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
75        super::vec_dot_f32(lhs, rhs, res, len)
76    }
77
78    #[inline(always)]
79    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
80        super::vec_sum(xs, res, len)
81    }
82}
83
84impl VecOps for half::f16 {
85    #[inline(always)]
86    fn min(self, other: Self) -> Self {
87        Self::min(self, other)
88    }
89
90    #[inline(always)]
91    fn max(self, other: Self) -> Self {
92        Self::max(self, other)
93    }
94
95    #[inline(always)]
96    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
97        let mut res_f32 = 0f32;
98        super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
99        *res = half::f16::from_f32(res_f32);
100    }
101}
102
103impl VecOps for f64 {
104    #[inline(always)]
105    fn min(self, other: Self) -> Self {
106        Self::min(self, other)
107    }
108
109    #[inline(always)]
110    fn max(self, other: Self) -> Self {
111        Self::max(self, other)
112    }
113}
114impl VecOps for half::bf16 {
115    #[inline(always)]
116    fn min(self, other: Self) -> Self {
117        Self::min(self, other)
118    }
119
120    #[inline(always)]
121    fn max(self, other: Self) -> Self {
122        Self::max(self, other)
123    }
124}
125impl VecOps for u8 {
126    #[inline(always)]
127    fn min(self, other: Self) -> Self {
128        <Self as Ord>::min(self, other)
129    }
130
131    #[inline(always)]
132    fn max(self, other: Self) -> Self {
133        <Self as Ord>::max(self, other)
134    }
135}
136impl VecOps for u32 {
137    #[inline(always)]
138    fn min(self, other: Self) -> Self {
139        <Self as Ord>::min(self, other)
140    }
141
142    #[inline(always)]
143    fn max(self, other: Self) -> Self {
144        <Self as Ord>::max(self, other)
145    }
146}
147impl VecOps for i64 {
148    #[inline(always)]
149    fn min(self, other: Self) -> Self {
150        <Self as Ord>::min(self, other)
151    }
152
153    #[inline(always)]
154    fn max(self, other: Self) -> Self {
155        <Self as Ord>::max(self, other)
156    }
157}
158
159#[inline(always)]
160pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
161    if n_threads == 1 {
162        func(0)
163    } else {
164        rayon::scope(|s| {
165            for thread_idx in 0..n_threads {
166                let func = &func;
167                s.spawn(move |_| func(thread_idx));
168            }
169        })
170    }
171}
172
173#[inline(always)]
174pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
175    if n_threads == 1 {
176        for i in lo..up {
177            func(i)
178        }
179    } else {
180        rayon::scope(|s| {
181            for thread_idx in 0..n_threads {
182                let func = &func;
183                s.spawn(move |_| {
184                    for i in (thread_idx..up).step_by(n_threads) {
185                        func(i)
186                    }
187                });
188            }
189        })
190    }
191}