use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;

mod utils;
pub use utils::{
    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};

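// When enabled, the corresponding convolution ops are lowered to an im2col
// (resp. col2im for the transposed case) transform followed by a matmul,
// instead of the direct loop-based kernels defined below.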
const USE_IM2COL_CONV1D: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;

#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    U8(&'a [u8]),
    U32(&'a [u32]),
    I64(&'a [i64]),
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    F64(&'a [f64]),
}

#[derive(Debug, Clone)]
pub struct CpuDevice;

struct Cmp(CmpOp);
impl Map2U8 for Cmp {
    const OP: &'static str = "cmp";
    #[inline(always)]
    fn f<T: WithDType>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<u8>> {
        let dst = match self.0 {
            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
        };
        Ok(dst)
    }
}

struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<I: IntDType> Map2 for WCond<'_, I> {
    const OP: &'static str = "where";
    #[inline(always)]
    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
        let vs = match (
            self.1.contiguous_offsets(),
            t_l.contiguous_offsets(),
            f_l.contiguous_offsets(),
        ) {
            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
                let pred = &self.0[o1..o2];
                let t = &t[o_t1..o_t2];
                let f = &f[o_f1..o_f2];
                pred.iter()
                    .zip(t.iter().zip(f.iter()))
                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
                    .collect::<Vec<_>>()
            }
            _ => self
                .1
                .strided_index()
                .zip(t_l.strided_index().zip(f_l.strided_index()))
                .map(|(i_p, (i_t, i_f))| {
                    if self.0[i_p].is_true() {
                        t[i_t]
                    } else {
                        f[i_f]
                    }
                })
                .collect::<Vec<_>>(),
        };
        Ok(vs)
    }
}

struct ReduceIndex {
    reduce_dim_index: usize,
    use_min: bool,
    return_index: bool,
}

impl ReduceIndex {
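    // Folds over the reduced dimension while tracking both the running value and
    // its position: `f` decides whether a candidate replaces the current best,
    // and `g` projects (value, index) to the output element (the value for
    // min/max, the index for argmin/argmax).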
    #[inline(always)]
    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
    where
        T: Clone + Copy,
        U: Clone + Copy,
        F: Fn(T, T) -> bool,
        G: Fn(T, usize) -> U,
    {
        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();
        let dst_to_set =
            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
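        // SAFETY: `MaybeUninit<U>` has the same layout as `U`, and every slot is
        // written exactly once below before `set_len` makes them visible.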
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                if reduce_dim_stride == 1 {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let start_src_i = start_src_i * reduce_dim_size;
                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
                        let mut acc = 0;
                        let mut val = src[0];
                        for (src_i, &s) in src.iter().enumerate() {
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                } else {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let (p, q) = (
                            start_src_i / reduce_dim_stride,
                            start_src_i % reduce_dim_stride,
                        );
                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
                        let src = &src[start_src_i..];
                        let mut acc = 0;
                        let mut val = src[0];
                        for src_i in 0..reduce_dim_size {
                            let s = src[src_i * reduce_dim_stride];
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                }
            }
            None => {
                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
                for (unstr_index, src_index) in l.strided_index().enumerate() {
                    let src = &src[src_index..];
                    let mut acc = 0;
                    let mut val = src[0];
                    for src_i in 0..reduce_dim_size {
                        let s = src[src_i * reduce_dim_stride];
                        if f(val, s) {
                            acc = src_i;
                            val = s
                        }
                    }
                    dst_to_set[unstr_index] = g(val, acc)
                }
            }
        }
        unsafe { dst.set_len(dst_len) };
        Ok(dst)
    }
}

impl Map1Any for ReduceIndex {
    #[inline(always)]
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        src: &[T],
        src_l: &Layout,
        wrap: W,
    ) -> Result<CpuStorage> {
        if src_l.shape().elem_count() == 0 {
            Err(Error::EmptyTensor { op: "reduce" }.bt())?
        }
        let dst = match (self.return_index, self.use_min) {
            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
            (true, true) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
            }
            (true, false) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
            }
        };
        Ok(dst)
    }
}

struct ReduceSum<'a> {
    dst_shape: &'a Shape,
    reduce_dims: &'a [usize],
    reduce_dims_and_stride: Vec<(usize, usize)>,
}

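// Sums over the reduced dimensions. When the source is contiguous and the
// reduction covers the trailing dims, each output element is a vectorized sum
// over one dense chunk; otherwise every source element's linear index is
// collapsed onto the destination by dividing out each reduced dimension.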
impl ReduceSum<'_> {
    #[inline(always)]
    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
    where
        T: WithDType,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if reduce_over_last_dims {
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                        let src_i = dst_i * reduce_sz;
                        unsafe {
                            T::vec_reduce_sum(
                                src[src_i..src_i + reduce_sz].as_ptr(),
                                dst_v,
                                reduce_sz,
                            )
                        };
                    }
                    return Ok(dst);
                };
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                }
            }
            None => {
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                }
            }
        }
        Ok(dst)
    }
}

impl Map1 for ReduceSum<'_> {
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        self.fold_impl(src, src_l, T::zero())
    }
}

struct Affine(f64, f64);

impl Map1 for Affine {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let mul = T::from_f64(self.0);
        let add = T::from_f64(self.1);
        Ok(unary_map(vs, layout, |v| v * mul + add))
    }
}

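// 2D average pooling: the first tuple is the kernel size (h, w), the second the
// stride (h, w). There is no padding, so the output size is
// floor((in - k) / s) + 1 on each spatial dim.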
struct AvgPool2D((usize, usize), (usize, usize));

impl Map1 for AvgPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        let scale = 1f64 / (k_h * k_w) as f64;
        let scale = T::from_f64(scale);
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut sum = T::zero();
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                sum += src[src_index + m * stride_h + n * stride_w]
                            }
                        }
                        dst[h_idx * w_out + w_idx] = sum * scale;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct MaxPool2D((usize, usize), (usize, usize));

impl Map1 for MaxPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut largest =
                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                if largest < src[src_index + m * stride_h + n * stride_w] {
                                    largest = src[src_index + m * stride_h + n * stride_w]
                                }
                            }
                        }
                        dst[h_idx * w_out + w_idx] = largest;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest1D(usize);

impl Map1 for UpsampleNearest1D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let dst_sz = self.0;
        let (b_sz, c, src_sz) = layout.shape().dims3()?;
        let stride = layout.stride();
        let stride_sz = stride[2];
        let src_index = layout.start_offset();
        let scale_sz = src_sz as f64 / dst_sz as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
        let src_idxs = (0..dst_sz)
            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_sz..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_sz..];
                let src_index = src_index + c_idx * stride[1];
                for (idx, src_idx) in src_idxs.iter().enumerate() {
                    dst[idx] = src[src_index + src_idx * stride_sz]
                }
            }
        }
        Ok(dst)
    }
}

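// Nearest-neighbor 2D upsampling to a fixed (dst_h, dst_w): the source row and
// column for each output position are precomputed as floor(idx * src / dst),
// clamped to the last valid index.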
struct UpsampleNearest2D(usize, usize);

impl Map1 for UpsampleNearest2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (dst_h, dst_w) = (self.0, self.1);
        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let src_index = layout.start_offset();
        let scale_h = src_h as f64 / dst_h as f64;
        let scale_w = src_w as f64 / dst_w as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
        let src_h_idxs = (0..dst_h)
            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
            .collect::<Vec<_>>();
        let src_w_idxs = (0..dst_w)
            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_h * dst_w..];
                let src_index = src_index + c_idx * stride[1];
                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
                        dst[h_idx * dst_w + w_idx] = src[src_index]
                    }
                }
            }
        }
        Ok(dst)
    }
}

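// Gather along `dim`: the output has the shape of `ids`, and each element is
// read from the source at the position given by the corresponding id on `dim`
// (out-of-range ids are reported as InvalidIndex).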
struct Gather<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}

impl<I: IntDType> Map1 for Gather<'_, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let src_dims = src_l.dims();
        let dst_len: usize = ids_dims.iter().product();
        let dst_left_len: usize = ids_dims[..dim].iter().product();
        let dst_dim_len = ids_dims[dim];
        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();

        let src_dim_len = src_dims[dim];
        let src_right_len: usize = src_dims[dim + 1..].iter().product();

        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..dst_left_len {
            let start_src_idx = left_i * src_right_len * src_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..dst_dim_len {
                let start_dst_idx = start_dst_idx + i * dst_right_len;
                for right_i in 0..dst_right_len {
                    let dst_idx = start_dst_idx + right_i;
                    let index = ids[dst_idx].as_usize();
                    if index >= src_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim_len,
                            op: "gather",
                        }
                        .bt())?
                    }
                    let src_idx = start_src_idx + index * src_right_len + right_i;
                    dst[dst_idx] = src[src_idx]
                }
            }
        }
        Ok(dst)
    }
}

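// index_select along `dim` with a 1d ids tensor: every output slice along `dim`
// is a copy of the source slice picked by the corresponding id.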
struct IndexSelect<'a, T: IntDType> {
    ids: &'a [T],
    ids_l: &'a Layout,
    dim: usize,
}

impl<I: IntDType> Map1 for IndexSelect<'_, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
            [n_ids] => *n_ids,
            d => Err(Error::UnexpectedNumberOfDims {
                expected: 1,
                got: d.len(),
                shape: self.ids_l.shape().clone(),
            }
            .bt())?,
        };
        let stride_ids = self.ids_l.stride()[0];
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
                if index >= src_dim {
                    Err(Error::InvalidIndex {
                        index,
                        size: src_dim,
                        op: "index-select",
                    }
                    .bt())?
                }
                let start_src_idx = start_src_idx + index * right_len;
                let start_dst_idx = start_dst_idx + i * right_len;
                dst[start_dst_idx..start_dst_idx + right_len]
                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
            }
        }
        Ok(dst)
    }
}

trait ElemUpdate {
    fn f<T: WithDType>(dst: &mut T, src: T);
}

struct Set;
struct Add;

impl ElemUpdate for Set {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst += src
    }
}

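// Scatter along `dim`, parameterized by the element update rule: `Set`
// overwrites the destination element (scatter) while `Add` accumulates into it
// (scatter-add).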
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
    _phantom: std::marker::PhantomData<M>,
}

impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
        Self {
            ids,
            ids_l,
            dim,
            _phantom: Default::default(),
        }
    }
}

impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
    const OP: &'static str = "scatter";
    fn f<T: WithDType>(
        &self,
        dst: &mut [T],
        dst_l: &Layout,
        src: &[T],
        src_l: &Layout,
    ) -> Result<()> {
        let dst = match dst_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
            Some((o1, o2)) => &mut dst[o1..o2],
        };

        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };

        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let dst_dims = dst_l.dims();
        let dst_dim_len = dst_dims[dim];
        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();

        let ids_left_len: usize = ids_dims[..dim].iter().product();
        let ids_dim_len = ids_dims[dim];
        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..ids_dim_len {
                let start_ids_idx = start_ids_idx + i * ids_right_len;
                for right_i in 0..dst_right_len {
                    let ids_idx = start_ids_idx + right_i;
                    let index = ids[ids_idx].as_usize();
                    if index >= dst_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: dst_dim_len,
                            op: "scatter",
                        }
                        .bt())?
                    }
                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
                    M::f(&mut dst[dst_idx], src[ids_idx])
                }
            }
        }

        Ok(())
    }
}

struct IndexAdd<'a, I: IntDType> {
    ids: &'a [I],
    dim: usize,
}

impl<I: IntDType> Map2 for IndexAdd<'_, I> {
    const OP: &'static str = "index-add";
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let dst_len = l1.shape().elem_count();
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
        let max_idx = l1.dims()[dim];
        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
        let src_dim_sz = src_l.dims()[dim];
        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
        if dim == 0 {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                let src_idx = src_idx * post_dim;
                let dst_idx = dst_idx * post_dim;
                let src = &src[src_idx..src_idx + post_dim];
                let dst = &mut dst[dst_idx..dst_idx + post_dim];
                for (d, &s) in dst.iter_mut().zip(src.iter()) {
                    *d += s
                }
            }
        } else {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                for pre_i in 0..pre_dim {
                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
                    let src = &src[pre_src_i..pre_src_i + post_dim];
                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
                        *d += s
                    }
                }
            }
        }
        Ok(dst)
    }
}

#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for i1 in 0..d1 {
        let dst_idx = i1 * dst_stride1 + dst_offset;
        let src_idx = i1 * src_stride1 + src_offset;
        let dst = &mut dst[dst_idx..dst_idx + d2];
        let src = &src[src_idx..src_idx + d2];
        dst.copy_from_slice(src)
    }
}

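// Copies a possibly strided source into a contiguous destination starting at
// `dst_offset`, following the layout's block decomposition: one memcpy for a
// single contiguous block, per-element copies for unit-length blocks, and a
// memcpy per block otherwise.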
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let to_copy = (dst.len() - dst_offset).min(len);
            dst[dst_offset..dst_offset + to_copy]
                .copy_from_slice(&src[start_offset..start_offset + to_copy])
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len: 1,
        } => {
            for (dst_index, src_index) in block_start_index.enumerate() {
                let dst_index = dst_index + dst_offset;
                if dst_index >= dst.len() {
                    break;
                }
                dst[dst_index] = src[src_index]
            }
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut dst_index = dst_offset;
            for src_index in block_start_index {
                let next_dst_index = dst_index + block_len;
                if dst_index >= dst.len() {
                    break;
                }
                let to_copy = usize::min(block_len, dst.len() - dst_index);
                dst[dst_index..dst_index + to_copy]
                    .copy_from_slice(&src[src_index..src_index + to_copy]);
                dst_index = next_dst_index
            }
        }
    }
}

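// Direct conv1d. The input is first repacked into a contiguous channels-last
// buffer so each (batch, position) slice is a dense c_in vector; then, for
// every kernel offset, the output channels are computed in parallel with one
// dot product over c_in per output position.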
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl Map2 for Conv1D<'_> {
    const OP: &'static str = "conv1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
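                        // SAFETY: each rayon task owns a distinct dst_c_idx, so
                        // the indices written through this shared pointer never
                        // overlap across threads.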
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

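// im2col for conv1d: unfolds the (b, c, l) input into a (b, l_out, c * l_k)
// buffer so the convolution becomes a single matmul against the kernel viewed
// as a (c_out, c * l_k) matrix. For instance a (1, 2, 10) input with l_k = 3,
// stride 1 and no padding unfolds to (1, 8, 6).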
struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    fn l_out(&self, l: usize) -> usize {
        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
    }
}

impl Map1 for Im2Col1D {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            l_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, l) = layout.shape().dims3()?;
        let l_out = self.l_out(l);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * l_out * c * l_k];
        let (src_s0, src_s1, src_s2) = {
            let s = layout.stride();
            (s[0], s[1], s[2])
        };
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * l_out * c * l_k;
            for l_idx in 0..l_out {
                let dst_idx = dst_idx + l_idx * c * l_k;
                for c_idx in 0..c {
                    let dst_idx = dst_idx + c_idx * l_k;
                    let src_idx = c_idx * src_s1 + src_idx;
                    for l_k_idx in 0..l_k {
                        let src_l = l_idx * stride + l_k_idx * dilation;
                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
                            continue;
                        }
                        let src_l = src_l - padding;
                        let src_idx = src_idx + src_l * src_s2;
                        let dst_idx = dst_idx + l_k_idx;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}

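// im2col for conv2d: unfolds the (b, c, h, w) input into a
// (b, h_out * w_out, c * h_k * w_k) buffer, again turning the convolution into
// a matmul. For instance a (1, 3, 8, 8) input with a 3x3 kernel, stride 1 and
// no padding unfolds to (1, 36, 27). Padded positions are skipped and keep
// their zero initialization.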
struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
        (h_out, w_out)
    }
}

impl Map1 for Im2Col {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        Ok(dst)
    }
}

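// col2im for the transposed conv1d: scatters the (b, l_in, c_out, k_size)
// columns back into a (b, c_out, l_out) image with
// l_out = (l_in - 1) * stride + k_size, accumulating where windows overlap.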
struct Col2Im1D {
    stride: usize,
}

impl Map1 for Col2Im1D {
    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
        let stride = self.stride;
        let l_out = (l_in - 1) * stride + k_size;
        let mut im = vec![T::zero(); b_size * c_out * l_out];
        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
        for l_in_i in 0..l_in {
            for k_i in 0..k_size {
                let l_out_i = l_in_i * stride + k_i;
                for b_i in 0..b_size {
                    for c_i in 0..c_out {
                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
                        im[dst_idx] += col[src_idx]
                    }
                }
            }
        }
        Ok(im)
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl Map2 for ConvTranspose1D<'_> {
    const OP: &'static str = "conv_transpose1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();

        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];
        let dst_s0 = p.c_out * l_out;
        let dst_s1 = l_out;
        let dst_s2 = 1;

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        let cont_s0 = p.l_in * p.c_in;
        let cont_s1 = p.c_in;
        for b_idx in 0..p.b_size {
            for l_idx in 0..p.l_in {
                for c_idx in 0..p.c_in {
                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
                    inp_cont[dst_idx] = inp[src_idx]
                }
            }
        }

        for k_idx in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    for l_idx in 0..p.l_in {
                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
                        if out_idx < p.padding {
                            continue;
                        }
                        let out_idx = out_idx - p.padding;
                        if out_idx < l_out {
                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
                            let mut d = T::zero();
                            unsafe {
                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                            }
                            let dst_p = dst.as_ptr();
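                            // SAFETY: writes through this pointer are disjoint
                            // across threads since each task owns its dst_c_idx.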
                            unsafe {
                                let ptr = dst_p.add(dst_idx) as *mut T;
                                *ptr += d
                            }
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

impl Map2 for Conv2D<'_> {
    const OP: &'static str = "conv2d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        for offset_h in 0..p.k_h {
            for offset_w in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let dst_idx = dst_c_idx * out_w * out_h;
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[dst_c_idx * k_s0
                                + c_in_idx * k_s1
                                + offset_h * k_s2
                                + offset_w * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
                        for dst_h in 0..out_h {
                            let dst_idx = dst_idx + dst_h * out_w;
                            let src_h = p.stride * dst_h + offset_h * p.dilation;
                            if src_h < p.padding || src_h >= p.i_h + p.padding {
                                continue;
                            }
                            let src_h = src_h - p.padding;
                            for dst_w in 0..out_w {
                                let dst_idx = dst_idx + dst_w;
                                let src_w = p.stride * dst_w + offset_w * p.dilation;
                                if src_w < p.padding || src_w >= p.i_w + p.padding {
                                    continue;
                                }
                                let src_w = src_w - p.padding;
                                let inp_cont = &inp_cont
                                    [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
                                assert!(inp_cont.len() >= p.c_in);
                                assert!(k_cont.len() >= p.c_in);
                                let mut d = T::zero();
                                unsafe {
                                    T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                                }
                                let dst_p = dst.as_ptr();
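                                // SAFETY: each thread writes only to the dst_idx
                                // range of its own dst_c_idx.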
                                unsafe {
                                    let ptr = dst_p.add(dst_idx) as *mut T;
                                    *ptr += d
                                }
                            }
                        }
                    }
                });
            }
        }

        Ok(dst)
    }
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);

impl Map2 for ConvTranspose2D<'_> {
    const OP: &'static str = "conv_transpose2d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
        let dst_s0 = p.c_out * out_h * out_w;
        let dst_s1 = out_h * out_w;
        let dst_s2 = out_w;
        let dst_s3 = 1;

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        for k_y in 0..p.k_h {
            for k_x in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        for inp_y in 0..p.i_h {
                            for inp_x in 0..p.i_w {
                                let out_x = inp_x * p.stride + k_x * p.dilation;
                                let out_y = inp_y * p.stride + k_y * p.dilation;
                                if out_x < p.padding || out_y < p.padding {
                                    continue;
                                }
                                let out_x = out_x - p.padding;
                                let out_y = out_y - p.padding;
                                if out_x < out_w && out_y < out_h {
                                    let inp_cont = &inp_cont
                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
                                    let dst_idx = b_idx * dst_s0
                                        + out_y * dst_s2
                                        + out_x * dst_s3
                                        + dst_c_idx * dst_s1;
                                    let mut d = T::zero();
                                    unsafe {
                                        T::vec_dot(
                                            inp_cont.as_ptr(),
                                            k_cont.as_ptr(),
                                            &mut d,
                                            p.c_in,
                                        )
                                    }
                                    let dst_p = dst.as_ptr();
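                                    // SAFETY: per-thread writes are disjoint as
                                    // each task owns its dst_c_idx.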
                                    unsafe {
                                        let ptr = dst_p.add(dst_idx) as *mut T;
                                        *ptr += d
                                    }
                                }
                            }
                        }
                    }
                })
            }
        }
        Ok(dst)
    }
}

struct MatMul((usize, usize, usize, usize));

impl MatMul {
    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
            lhs_l: lhs_l.clone(),
            rhs_l: rhs_l.clone(),
            bmnk: self.0,
            msg,
        }))
        .bt()
    }

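    // Deduces the per-batch strides (a_skip for lhs, b_skip for rhs) from the
    // layouts, accepting contiguous batch dims, broadcast batches (a leading dim
    // of size 1), and the rank-2 case where the whole matrix is one batch.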
    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        let (_b, m, n, k) = self.0;
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [_, stride] if lhs_l.dims()[0] == 1 => stride,
            [stride, _] if lhs_l.dims()[1] == 1 => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [_, stride] if rhs_l.dims()[0] == 1 => stride,
            [stride, _] if rhs_l.dims()[1] == 1 => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        Ok((a_skip, b_skip))
    }
}

impl Map2 for MatMul {
    const OP: &'static str = "mat_mul";

    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        use gemm::{gemm, Parallelism};

        match T::DTYPE {
            DType::F16 | DType::F32 | DType::F64 => {}
            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
        }

        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];

        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];

        let mut dst = vec![T::zero(); b * m * n];
        let num_threads = crate::utils::get_num_threads();
        let parallelism = if num_threads > 1 {
            Parallelism::Rayon(num_threads)
        } else {
            Parallelism::None
        };
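        // If one side is broadcast across the batch, fold the batch dimension
        // into m (or n) so a single larger gemm covers all batches at once.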
        let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
            (1, b * m, n, k)
        } else if a_skip == 0 && b_skip == n * k {
            (1, m, b * n, k)
        } else {
            (b, m, n, k)
        };
        for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            unsafe {
                gemm(
                    m,
                    n,
                    k,
                    dst_p.as_mut_ptr(),
                    dst_cs as isize,
                    dst_rs as isize,
                    false,
                    lhs_p.as_ptr(),
                    lhs_cs as isize,
                    lhs_rs as isize,
                    rhs_p.as_ptr(),
                    rhs_cs as isize,
                    rhs_rs as isize,
                    T::zero(),
                    T::one(),
                    false,
                    false,
                    false,
                    parallelism,
                )
            }
        }
        Ok(dst)
    }

    #[cfg(feature = "accelerate")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();

        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

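        // BLAS gemm is column-major while these buffers are row-major, so the
        // call is set up to compute C^T = B^T * A^T: the rhs is passed as the
        // `a` operand, the lhs as `b`, and m/n are swapped in the calls below.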
        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                crate::bail!("the accelerate backend does not support f16 matmul")
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::sgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::dgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }

    #[cfg(feature = "mkl")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();

        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

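        // Same row-major/column-major swap as in the accelerate path above.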
        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f16;
                        let b = lhs_p.as_ptr() as *const f16;
                        let c = dst_p.as_mut_ptr() as *mut f16;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::hgemm(
                            transa,
                            transb,
                            n as i32,
                            m as i32,
                            k as i32,
                            f16::ONE,
                            a,
                            lda,
                            b,
                            ldb,
                            f16::ZERO,
                            c,
                            n as i32,
                        )
                    }
                }
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::sgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::dgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }
}

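// ELU activation: identity for non-negative values, alpha * (exp(v) - 1) for
// negative ones, e.g. elu(-1.0, 1.0) = exp(-1) - 1 ≈ -0.632.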
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
    if v.is_sign_positive() {
        v
    } else {
        (v.exp() - T::one()) * alpha
    }
}

impl CpuStorage {
    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
        D::cpu_storage_as_slice(self)
    }

    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
        let storage0 = &storages[0];
        let s = match storage0 {
            Self::U8(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::U8(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::U8(storages)
            }
            Self::U32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::U32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::U32(storages)
            }
            Self::I64(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::I64(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::I64(storages)
            }
            Self::BF16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::BF16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::BF16(storages)
            }
            Self::F16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F16(storages)
            }
            Self::F32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F32(storages)
            }
            Self::F64(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F64(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F64(storages)
            }
        };
        Ok(s)
    }
}

impl BackendStorage for CpuStorage {
    type Device = CpuDevice;

    fn dtype(&self) -> DType {
        match self {
            Self::U8(_) => DType::U8,
            Self::U32(_) => DType::U32,
            Self::I64(_) => DType::I64,
            Self::BF16(_) => DType::BF16,
            Self::F16(_) => DType::F16,
            Self::F32(_) => DType::F32,
            Self::F64(_) => DType::F64,
        }
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        match (self, dtype) {
            (Self::U8(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::U32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::I64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::BF16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::BF16(data))
            }
            (Self::F16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                Ok(Self::BF16(data))
            }
            (Self::F32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f32);
                Ok(Self::BF16(data))
            }
            (Self::F64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f64);
                Ok(Self::BF16(data))
            }
            (Self::U8(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::U32(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::I64(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::BF16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                Ok(Self::F16(data))
            }
            (Self::F16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F16(data))
            }
            (Self::F32(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f32);
                Ok(Self::F16(data))
            }
            (Self::F64(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f64);
                Ok(Self::F16(data))
            }
            (Self::U8(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::I64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::BF16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F32(data))
            }
            (Self::F64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U8(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U8(data))
            }
            (Self::BF16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::F64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::I64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U8(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U32(data))
            }
            (Self::I64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::BF16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::F64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U8(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::I64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::I64(data))
            }
            (Self::BF16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::F64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U8(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::U32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::I64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::BF16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::F64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F64(data))
            }
        }
    }

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
        match op {
            ReduceOp::Sum => {
                let src_dims = layout.dims();
                let mut dst_dims = src_dims.to_vec();
                for &dim in reduce_dims.iter() {
                    dst_dims[dim] = 1;
                }
                let dst_shape = Shape::from(dst_dims);
                let mut reduce_dims = reduce_dims.to_vec();
                reduce_dims.sort();
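                // The reduce dims are processed from left to right when remapping
                // indexes, hence the sort; each entry pairs a dim's size with its
                // contiguous stride in the source.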
                let reduce_dims_and_stride: Vec<_> = reduce_dims
                    .iter()
                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
                    .collect();
                ReduceSum {
                    dst_shape: &dst_shape,
                    reduce_dims: &reduce_dims,
                    reduce_dims_and_stride,
                }
                .map(self, layout)
            }
            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
                let reduce_dim_index = match reduce_dims {
                    [reduce_dim_index] => *reduce_dim_index,
                    _ => {
                        let op = match op {
                            ReduceOp::Min => "min",
                            ReduceOp::ArgMin => "argmin",
                            ReduceOp::Max => "max",
                            ReduceOp::ArgMax => "argmax",
                            _ => unreachable!(),
                        };
                        let dims = reduce_dims.to_vec();
                        Err(Error::OnlySingleDimension { op, dims })?
                    }
                };
                let (use_min, return_index) = match op {
                    ReduceOp::Min => (true, false),
                    ReduceOp::ArgMin => (true, true),
                    ReduceOp::Max => (false, false),
                    ReduceOp::ArgMax => (false, true),
                    _ => unreachable!(),
                };
                ReduceIndex {
                    reduce_dim_index,
                    use_min,
                    return_index,
                }
                .map(self, layout)
            }
        }
    }

    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
        Cmp(op).map(self, lhs_l, rhs, rhs_l)
    }

    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        Affine(mul, add).map(self, layout)
    }

    fn avg_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        AvgPool2D(kernel_size, stride).map(self, layout)
    }

    fn max_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        MaxPool2D(kernel_size, stride).map(self, layout)
    }

    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
        UpsampleNearest1D(sz).map(self, layout)
    }

    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
        UpsampleNearest2D(h, w).map(self, layout)
    }

    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
        use num_traits::Float;
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(e as f32));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(e));
                Ok(Self::F64(data))
            }
2006 Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2007 Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2008 Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2009 }
2010 }
2011
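    /// Applies ELU elementwise: `elu(v, alpha) = v` for `v > 0`, and
    /// `alpha * (exp(v) - 1)` otherwise. Like `powf`, this is only defined for
    /// floating point dtypes.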
    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, alpha));
                Ok(Self::F64(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
        }
    }

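    /// Applies the unary op `B` elementwise. When the op advertises a
    /// vectorized kernel for the dtype (e.g. `B::F32_VEC`), the buffer is
    /// handed to `unary_map_vec` so the op can process contiguous slices in
    /// one call; otherwise `unary_map` applies the scalar function per element.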
    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        match self {
            Self::BF16(storage) => {
                if B::BF16_VEC {
                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
                    Ok(Self::BF16(data))
                } else {
                    let data = unary_map(storage, layout, B::bf16);
                    Ok(Self::BF16(data))
                }
            }
            Self::F16(storage) => {
                if B::F16_VEC {
                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
                    Ok(Self::F16(data))
                } else {
                    let data = unary_map(storage, layout, B::f16);
                    Ok(Self::F16(data))
                }
            }
            Self::F32(storage) => {
                if B::F32_VEC {
                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
                    Ok(Self::F32(data))
                } else {
                    let data = unary_map(storage, layout, B::f32);
                    Ok(Self::F32(data))
                }
            }
            Self::F64(storage) => {
                if B::F64_VEC {
                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
                    Ok(Self::F64(data))
                } else {
                    let data = unary_map(storage, layout, B::f64);
                    Ok(Self::F64(data))
                }
            }
            Self::U8(storage) => {
                let data = unary_map(storage, layout, B::u8);
                Ok(Self::U8(data))
            }
            Self::U32(storage) => {
                let data = unary_map(storage, layout, B::u32);
                Ok(Self::U32(data))
            }
            Self::I64(storage) => {
                let data = unary_map(storage, layout, B::i64);
                Ok(Self::I64(data))
            }
        }
    }

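    /// Applies the binary op `B` elementwise. Both operands must share the
    /// same dtype; mismatches surface as `DTypeMismatchBinaryOp`. As with
    /// `unary_impl`, a vectorized kernel is preferred when the op provides one.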
    fn binary_impl<B: BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        match (self, rhs) {
            (Self::BF16(lhs), Self::BF16(rhs)) => {
                let data = if B::BF16_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
                };
                Ok(Self::BF16(data))
            }
            (Self::F16(lhs), Self::F16(rhs)) => {
                let data = if B::F16_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
                };
                Ok(Self::F16(data))
            }
            (Self::F32(lhs), Self::F32(rhs)) => {
                let data = if B::F32_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
                };
                Ok(Self::F32(data))
            }
            (Self::F64(lhs), Self::F64(rhs)) => {
                let data = if B::F64_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
                };
                Ok(Self::F64(data))
            }
            (Self::U32(lhs), Self::U32(rhs)) => {
                let data = if B::U32_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
                };
                Ok(Self::U32(data))
            }
            (Self::I64(lhs), Self::I64(rhs)) => {
                let data = if B::I64_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
                };
                Ok(Self::I64(data))
            }
            (Self::U8(lhs), Self::U8(rhs)) => {
                let data = if B::U8_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
                };
                Ok(Self::U8(data))
            }
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: rhs.dtype(),
                op: B::NAME,
            }
            .bt()),
        }
    }

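    /// Copies a `d1 x d2` block: for each row `i < d1` and column `j < d2`,
    /// `dst[dst_o + i * dst_s + j] = src[src_o + i * src_s + j]`, i.e. both
    /// sides use per-row strides (`src_s`, `dst_s`) with contiguous rows.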
    fn copy2d(
        &self,
        dst: &mut Self,
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::U32(src), Self::U32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I64(src), Self::I64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::BF16(src), Self::BF16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F16(src), Self::F16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F32(src), Self::F32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F64(src), Self::F64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (_, dst) => {
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy2d",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (_, dst) => {
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy_strided",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn where_cond(
        &self,
        layout: &Layout,
        t: &Self,
        t_l: &Layout,
        f: &Self,
        f_l: &Layout,
    ) -> Result<Self> {
        match self {
            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond").bt()),
        }
    }

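    /// 1d convolution. With `USE_IM2COL_CONV1D` enabled, the input is first
    /// unfolded into a `(b, l_out, k_size * c_in)` matrix (im2col) so that the
    /// convolution reduces to a batched matmul against the kernel reshaped to
    /// `(k_size * c_in, c_out)`, followed by a transpose back to
    /// `(b, c_out, l_out)`.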
    fn conv1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        if !USE_IM2COL_CONV1D {
            return Conv1D(params).map(self, l, kernel, kernel_l);
        }
        let op = Im2Col1D {
            l_k: params.k_size,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        };
        let col = op.map(self, l)?;
        let b = params.b_size;
        let n = params.c_out;
        let l_out = params.l_out();
        let k = op.l_k * params.c_in;
        let m = l_out;
        let col_l = Layout::contiguous((b, m, k));
        let res = if kernel_l.is_contiguous() {
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make a contiguous copy of the kernel and multiply against that
            // copy (offset 0) rather than the original strided storage.
            let mut kernel_c = unsafe {
                self.device()
                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
            };
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            let kernel_l = Layout::contiguous((1, n, k))
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }

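    /// Transposed 1d convolution. When the kernel is contiguous and there is
    /// no dilation, padding, or output padding, the input is multiplied with
    /// the kernel viewed as a `(c_in, c_out * k_size)` matrix and the result
    /// is folded back with col2im; otherwise the direct implementation is used.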
    fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        let can_use_col2im = kernel_l.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        if USE_COL2IM_CONV1D_TR && can_use_col2im {
            let (b_size, c_in, l_in) = l.shape().dims3()?;
            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
            if !kernel_l.is_contiguous() {
                crate::bail!(
                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
                )
            }
            if c_in != c_in2 {
                crate::bail!(
                    "convtr1d: shape mismatch on c_in {:?} {:?}",
                    l.shape(),
                    kernel_l.shape()
                )
            }
            let col = {
                // View the kernel as a (b_size, c_in, k_size * c_out) matrix,
                // broadcast over the batch dimension via a zero stride.
                let kernel_l_mm = Layout::new(
                    (b_size, c_in, k_size * c_out).into(),
                    vec![0, k_size * c_out, 1],
                    kernel_l.start_offset(),
                );
                self.matmul(
                    kernel,
                    (b_size, l_in, c_out * k_size, c_in),
                    &l.transpose(1, 2)?,
                    &kernel_l_mm,
                )?
            };
            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
            Col2Im1D {
                stride: params.stride,
            }
            .map(&col, &col_l)
        } else {
            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
        }
    }

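    /// 2d convolution via im2col: the input is unfolded into a
    /// `(b, h_out * w_out, h_k * w_k * c_in)` matrix, multiplied with the
    /// kernel reshaped to `(h_k * w_k * c_in, c_out)`, then transposed back to
    /// the `(b, c_out, h_out, w_out)` layout.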
    fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        if !USE_IM2COL_CONV2D {
            return Conv2D(params).map(self, l, kernel, kernel_l);
        }
        let op = Im2Col {
            h_k: params.k_h,
            w_k: params.k_w,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        };
        let col = op.map(self, l)?;
        let b = params.b_size;
        let n = params.c_out;
        let (h_out, w_out) = (params.out_h(), params.out_w());
        let k = op.h_k * op.w_k * params.c_in;
        let m = h_out * w_out;
        let col_l = Layout::contiguous((b, m, k));
        let res = if kernel_l.is_contiguous() {
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make a contiguous copy of the kernel and multiply against that
            // copy (offset 0) rather than the original strided storage.
            let mut kernel_c = unsafe {
                self.device()
                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
            };
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            let kernel_l = Layout::contiguous((1, n, k))
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
            .transpose(1, 2)?
            .transpose(1, 3)?;
        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }

    fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
    }

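    /// Selects rows of `self` along `dim` using an index tensor. The indices
    /// must have an integer dtype (`U8`, `U32` or `I64`); `index_select` takes
    /// a 1d index vector, whereas `gather` below expects indices with the same
    /// rank as the source.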
    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(ids.dtype(), "index-select").bt()),
        }
    }

    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(ids.dtype(), "gather").bt()),
        }
    }

    fn scatter_set(
        &mut self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<()> {
        match ids {
            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(ids.dtype(), "scatter").bt()),
        }
    }

    fn scatter_add_set(
        &mut self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<()> {
        match ids {
            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(ids.dtype(), "scatter-add").bt()),
        }
    }

    fn index_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(ids.dtype(), "index-add").bt()),
        }
    }

    fn matmul(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }

    fn device(&self) -> &Self::Device {
        &CpuDevice
    }

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }

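    /// Fills the positions described by `l` with the scalar `s`, using
    /// `strided_blocks` to fill whole contiguous runs at once when possible.
    /// The scalar's dtype must match the storage dtype.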
    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
        use crate::scalar::Scalar;
        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
            match l.strided_blocks() {
                crate::StridedBlocks::SingleBlock { start_offset, len } => {
                    src[start_offset..start_offset + len].fill(s)
                }
                crate::StridedBlocks::MultipleBlocks {
                    block_start_index,
                    block_len: 1,
                } => {
                    for src_index in block_start_index {
                        src[src_index] = s
                    }
                }
                crate::StridedBlocks::MultipleBlocks {
                    block_start_index,
                    block_len,
                } => {
                    for src_index in block_start_index {
                        src[src_index..src_index + block_len].fill(s)
                    }
                }
            }
        }
        match (self, s) {
            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
            (st, s) => crate::bail!(
                "const_set dtype mismatch, expected {:?} but got {:?}",
                st.dtype(),
                s
            ),
        }
        Ok(())
    }
}

impl BackendDevice for CpuDevice {
    type Storage = CpuStorage;

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cpu
    }

    fn same_device(&self, _: &Self) -> bool {
        true
    }

    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
        Ok(T::to_cpu_storage(s))
    }

    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
        Ok(s.clone())
    }

    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
        Ok(s)
    }

    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::rng();
        match dtype {
            DType::U8 | DType::U32 | DType::I64 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
            }
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(uniform))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(uniform))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::rng();
        match dtype {
            DType::U8 | DType::U32 | DType::I64 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
            }
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal =
                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

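    /// # Safety
    ///
    /// The returned storage is uninitialized: callers must fully overwrite it
    /// (e.g. via `copy_strided_src` or `const_set`) before reading from it.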
    #[allow(clippy::uninit_vec)]
    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::U8(v)
            }
            DType::U32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::U32(v)
            }
            DType::I64 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I64(v)
            }
            DType::BF16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::BF16(v)
            }
            DType::F16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F16(v)
            }
            DType::F32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F32(v)
            }
            DType::F64 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F64(v)
            }
        };
        Ok(storage)
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
        };
        Ok(storage)
    }

    fn synchronize(&self) -> Result<()> {
        Ok(())
    }
}

#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}
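
// A hedged usage sketch for `map_dtype!` (not part of the original source;
// `reversed` and `reverse_storage` are hypothetical helpers). The macro
// expands to a match over the listed `CpuStorage` variants and must be used
// where `?` can propagate the dtype error, i.e. in a function returning
// `Result<_>`:
//
// fn reversed<T: Copy>(v: Vec<T>) -> Vec<T> {
//     v.into_iter().rev().collect()
// }
//
// fn reverse_storage(storage: CpuStorage) -> Result<CpuStorage> {
//     let out = map_dtype!("reverse", storage, reversed, (F32, F64));
//     Ok(out)
// }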