candle_core/cpu/
kernels.rs1pub trait VecOps: num_traits::NumAssign + Copy {
2 fn min(self, rhs: Self) -> Self;
3 fn max(self, rhs: Self) -> Self;
4
5 #[inline(always)]
12 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
13 *res = Self::zero();
14 for i in 0..len {
15 *res += *lhs.add(i) * *rhs.add(i)
16 }
17 }
18
19 #[inline(always)]
26 unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
27 *res = Self::zero();
28 for i in 0..len {
29 *res += *xs.add(i)
30 }
31 }
32
33 #[inline(always)]
40 unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
41 *res = *xs;
42 for i in 1..len {
43 *res = (*res).max(*xs.add(i))
44 }
45 }
46
47 #[inline(always)]
54 unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
55 *res = *xs;
56 for i in 1..len {
57 *res = (*res).min(*xs.add(i))
58 }
59 }
60}
61
62impl VecOps for f32 {
63 #[inline(always)]
64 fn min(self, other: Self) -> Self {
65 Self::min(self, other)
66 }
67
68 #[inline(always)]
69 fn max(self, other: Self) -> Self {
70 Self::max(self, other)
71 }
72
73 #[inline(always)]
74 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
75 super::vec_dot_f32(lhs, rhs, res, len)
76 }
77
78 #[inline(always)]
79 unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
80 super::vec_sum(xs, res, len)
81 }
82}
83
84impl VecOps for half::f16 {
85 #[inline(always)]
86 fn min(self, other: Self) -> Self {
87 Self::min(self, other)
88 }
89
90 #[inline(always)]
91 fn max(self, other: Self) -> Self {
92 Self::max(self, other)
93 }
94
95 #[inline(always)]
96 unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
97 let mut res_f32 = 0f32;
98 super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
99 *res = half::f16::from_f32(res_f32);
100 }
101}
102
103impl VecOps for f64 {
104 #[inline(always)]
105 fn min(self, other: Self) -> Self {
106 Self::min(self, other)
107 }
108
109 #[inline(always)]
110 fn max(self, other: Self) -> Self {
111 Self::max(self, other)
112 }
113}
114impl VecOps for half::bf16 {
115 #[inline(always)]
116 fn min(self, other: Self) -> Self {
117 Self::min(self, other)
118 }
119
120 #[inline(always)]
121 fn max(self, other: Self) -> Self {
122 Self::max(self, other)
123 }
124}
125impl VecOps for u8 {
126 #[inline(always)]
127 fn min(self, other: Self) -> Self {
128 <Self as Ord>::min(self, other)
129 }
130
131 #[inline(always)]
132 fn max(self, other: Self) -> Self {
133 <Self as Ord>::max(self, other)
134 }
135}
136impl VecOps for u32 {
137 #[inline(always)]
138 fn min(self, other: Self) -> Self {
139 <Self as Ord>::min(self, other)
140 }
141
142 #[inline(always)]
143 fn max(self, other: Self) -> Self {
144 <Self as Ord>::max(self, other)
145 }
146}
147impl VecOps for i64 {
148 #[inline(always)]
149 fn min(self, other: Self) -> Self {
150 <Self as Ord>::min(self, other)
151 }
152
153 #[inline(always)]
154 fn max(self, other: Self) -> Self {
155 <Self as Ord>::max(self, other)
156 }
157}
158
159#[inline(always)]
160pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
161 if n_threads == 1 {
162 func(0)
163 } else {
164 rayon::scope(|s| {
165 for thread_idx in 0..n_threads {
166 let func = &func;
167 s.spawn(move |_| func(thread_idx));
168 }
169 })
170 }
171}
172
173#[inline(always)]
174pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
175 if n_threads == 1 {
176 for i in lo..up {
177 func(i)
178 }
179 } else {
180 rayon::scope(|s| {
181 for thread_idx in 0..n_threads {
182 let func = &func;
183 s.spawn(move |_| {
184 for i in (thread_idx..up).step_by(n_threads) {
185 func(i)
186 }
187 });
188 }
189 })
190 }
191}