// candle_core/cpu/kernels.rs
kernels.rs1pub trait VecOps: num_traits::NumAssign + Copy {
2 fn min(self, rhs: Self) -> Self;
3 fn max(self, rhs: Self) -> Self;
4
5 #[inline(always)]
12 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
13 *res = Self::zero();
14 for i in 0..len {
15 *res += *lhs.add(i) * *rhs.add(i)
16 }
17 }
18
19 #[inline(always)]
26 unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
27 *res = Self::zero();
28 for i in 0..len {
29 *res += *xs.add(i)
30 }
31 }
32
33 #[inline(always)]
40 unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
41 *res = *xs;
42 for i in 1..len {
43 *res = (*res).max(*xs.add(i))
44 }
45 }
46
47 #[inline(always)]
54 unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
55 *res = *xs;
56 for i in 1..len {
57 *res = (*res).min(*xs.add(i))
58 }
59 }
60}
61
62impl VecOps for f32 {
63 #[inline(always)]
64 fn min(self, other: Self) -> Self {
65 Self::min(self, other)
66 }
67
68 #[inline(always)]
69 fn max(self, other: Self) -> Self {
70 Self::max(self, other)
71 }
72
73 #[inline(always)]
74 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
75 super::vec_dot_f32(lhs, rhs, res, len)
76 }
77
78 #[inline(always)]
79 unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
80 super::vec_sum(xs, res, len)
81 }
82}
83
84impl VecOps for half::f16 {
85 #[inline(always)]
86 fn min(self, other: Self) -> Self {
87 Self::min(self, other)
88 }
89
90 #[inline(always)]
91 fn max(self, other: Self) -> Self {
92 Self::max(self, other)
93 }
94
95 #[inline(always)]
96 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
97 let mut res_f32 = 0f32;
98 super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
99 *res = half::f16::from_f32(res_f32);
100 }
101}
102
103impl VecOps for f64 {
104 #[inline(always)]
105 fn min(self, other: Self) -> Self {
106 Self::min(self, other)
107 }
108
109 #[inline(always)]
110 fn max(self, other: Self) -> Self {
111 Self::max(self, other)
112 }
113}
114impl VecOps for half::bf16 {
115 #[inline(always)]
116 fn min(self, other: Self) -> Self {
117 Self::min(self, other)
118 }
119
120 #[inline(always)]
121 fn max(self, other: Self) -> Self {
122 Self::max(self, other)
123 }
124
125 #[inline(always)]
126 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
127 let mut res_f32 = 0f32;
128 super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
129 *res = half::bf16::from_f32(res_f32);
130 }
131}
132impl VecOps for u8 {
133 #[inline(always)]
134 fn min(self, other: Self) -> Self {
135 <Self as Ord>::min(self, other)
136 }
137
138 #[inline(always)]
139 fn max(self, other: Self) -> Self {
140 <Self as Ord>::max(self, other)
141 }
142}
143impl VecOps for u32 {
144 #[inline(always)]
145 fn min(self, other: Self) -> Self {
146 <Self as Ord>::min(self, other)
147 }
148
149 #[inline(always)]
150 fn max(self, other: Self) -> Self {
151 <Self as Ord>::max(self, other)
152 }
153}
154impl VecOps for i16 {
155 #[inline(always)]
156 fn min(self, other: Self) -> Self {
157 <Self as Ord>::min(self, other)
158 }
159
160 #[inline(always)]
161 fn max(self, other: Self) -> Self {
162 <Self as Ord>::max(self, other)
163 }
164}
165impl VecOps for i32 {
166 #[inline(always)]
167 fn min(self, other: Self) -> Self {
168 <Self as Ord>::min(self, other)
169 }
170
171 #[inline(always)]
172 fn max(self, other: Self) -> Self {
173 <Self as Ord>::max(self, other)
174 }
175}
176impl VecOps for i64 {
177 #[inline(always)]
178 fn min(self, other: Self) -> Self {
179 <Self as Ord>::min(self, other)
180 }
181
182 #[inline(always)]
183 fn max(self, other: Self) -> Self {
184 <Self as Ord>::max(self, other)
185 }
186}
187
188impl VecOps for float8::F8E4M3 {
189 #[inline(always)]
190 fn min(self, other: Self) -> Self {
191 Self::min(self, other)
192 }
193
194 #[inline(always)]
195 fn max(self, other: Self) -> Self {
196 Self::max(self, other)
197 }
198}
199
200#[inline(always)]
201pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
202 if n_threads == 1 {
203 func(0)
204 } else {
205 rayon::scope(|s| {
206 for thread_idx in 0..n_threads {
207 let func = &func;
208 s.spawn(move |_| func(thread_idx));
209 }
210 })
211 }
212}
213
214#[inline(always)]
215pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
216 if n_threads == 1 {
217 for i in lo..up {
218 func(i)
219 }
220 } else {
221 rayon::scope(|s| {
222 for thread_idx in 0..n_threads {
223 let func = &func;
224 s.spawn(move |_| {
225 for i in (thread_idx..up).step_by(n_threads) {
226 func(i)
227 }
228 });
229 }
230 })
231 }
232}