1use crate::backend::{BackendDevice, BackendStorage};
2use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
3use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
4use half::{bf16, f16};
5use rayon::prelude::*;
6
// Compile-time switches selecting the im2col-based lowering for 1d/2d
// convolutions — presumably consulted by the conv ops further down this file
// (not visible in this chunk).
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV2D: bool = true;
9
/// CPU-side tensor storage: one flat `Vec` of elements per supported dtype.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}
22
/// Unit marker type for the CPU device (its `BackendDevice` impl is not in
/// this chunk).
#[derive(Debug, Clone)]
pub struct CpuDevice;
25
/// Helper trait for unary operations whose output has the *same* dtype as
/// the input. Implementors supply the typed kernel `f`; `map` handles the
/// dtype dispatch over every `CpuStorage` variant.
pub trait Map1 {
    /// Typed kernel: apply the operation to `vs` as described by `layout`.
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    /// Unwrap the storage, run `f`, and re-wrap the result in the matching
    /// dtype variant.
    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
        }
    }
}
41
/// Like `Map1`, but the kernel chooses the output storage itself: `wrap`
/// re-wraps a `Vec<T>` into the input's dtype variant, and the kernel may
/// ignore it to return a different dtype (e.g. u32 indices for argmax).
pub trait Map1Any {
    /// Typed kernel; call `wrap` to produce an output of the input dtype,
    /// or build another `CpuStorage` variant directly.
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        vs: &[T],
        layout: &Layout,
        wrap: W,
    ) -> Result<CpuStorage>;

    /// Dtype dispatch: pass the matching variant constructor as `wrap`.
    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
            CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
            CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
            CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
            CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
            CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
        }
    }
}
62
/// Short alias to keep the dispatch tables below readable.
type C = CpuStorage;

/// Helper trait for binary operations where both operands and the result
/// share a dtype. `map` dispatches on the operand pair and reports a
/// `DTypeMismatchBinaryOp` (tagged with `OP`) when the dtypes differ.
pub trait Map2 {
    /// Operation name used in dtype-mismatch error messages.
    const OP: &'static str;
    /// Typed kernel, called once the two operand dtypes are known to match.
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}
92
/// Like `Map2`, but the result is always a `u8` storage regardless of the
/// operand dtype — used for predicates such as comparisons that produce
/// boolean masks.
pub trait Map2U8 {
    /// Operation name used in dtype-mismatch error messages.
    const OP: &'static str;
    /// Typed kernel producing the u8 mask for matching-dtype operands.
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}
121
122struct Cmp(CmpOp);
123impl Map2U8 for Cmp {
124 const OP: &'static str = "cmp";
125 #[inline(always)]
126 fn f<T: WithDType>(
127 &self,
128 lhs: &[T],
129 lhs_l: &Layout,
130 rhs: &[T],
131 rhs_l: &Layout,
132 ) -> Result<Vec<u8>> {
133 let dst = match self.0 {
134 CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
135 CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
136 CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
137 CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
138 CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
139 CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
140 };
141 Ok(dst)
142 }
143}
144
/// `where`-style element selection: for each position, pick from `t` when the
/// integer predicate is true, otherwise from `f`. Field 0 holds the predicate
/// values, field 1 their layout; the two branches arrive via `Map2`.
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<'a, I: IntDType> Map2 for WCond<'a, I> {
    const OP: &'static str = "where";
    #[inline(always)]
    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
        let vs = match (
            self.1.contiguous_offsets(),
            t_l.contiguous_offsets(),
            f_l.contiguous_offsets(),
        ) {
            // Fast path: all three operands are contiguous, so a straight
            // zip over the three slices suffices.
            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
                let pred = &self.0[o1..o2];
                let t = &t[o_t1..o_t2];
                let f = &f[o_f1..o_f2];
                pred.iter()
                    .zip(t.iter().zip(f.iter()))
                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
                    .collect::<Vec<_>>()
            }
            // Slow path: walk all three layouts through strided indices.
            _ => self
                .1
                .strided_index()
                .zip(t_l.strided_index().zip(f_l.strided_index()))
                .map(|(i_p, (i_t, i_f))| {
                    if self.0[i_p].is_true() {
                        t[i_t]
                    } else {
                        f[i_f]
                    }
                })
                .collect::<Vec<_>>(),
        };
        Ok(vs)
    }
}
181
/// Single-dimension reduction that keeps either the winning value (min/max)
/// or the winning position (argmin/argmax).
struct ReduceIndex {
    /// The dimension being reduced over.
    reduce_dim_index: usize,
    /// When true, a smaller candidate wins; otherwise a larger one does.
    use_min: bool,
    /// When true, output the winner's u32 index instead of its value.
    return_index: bool,
}
187
impl ReduceIndex {
    /// Shared reduction kernel.
    ///
    /// `f(current, candidate)` returns true when `candidate` should replace
    /// the current winner; `g(value, index)` turns the winner into the output
    /// element (the value for min/max, the position for argmin/argmax).
    #[inline(always)]
    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
    where
        T: Clone + Copy,
        U: Clone + Copy,
        F: Fn(T, T) -> bool,
        G: Fn(T, usize) -> U,
    {
        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
        // One output element per slice along the reduced dimension.
        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();
        // NOTE(review): this exposes uninitialized spare capacity as
        // `&mut [U]`. Soundness relies on every slot being written before
        // `set_len` below — consider `MaybeUninit::write` instead. TODO confirm.
        let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                if reduce_dim_stride == 1 {
                    // Reducing over the innermost (contiguous) dimension:
                    // each output element scans one contiguous chunk.
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let start_src_i = start_src_i * reduce_dim_size;
                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
                        let mut acc = 0;
                        let mut val = src[0];
                        for (src_i, &s) in src.iter().enumerate() {
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                } else {
                    // Contiguous storage but an inner dimension is reduced:
                    // decompose the output index into (outer, inner) around
                    // the reduced dimension, then stride through it.
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let (p, q) = (
                            start_src_i / reduce_dim_stride,
                            start_src_i % reduce_dim_stride,
                        );
                        // The start offset for the reduced slice.
                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
                        let src = &src[start_src_i..];
                        let mut acc = 0;
                        let mut val = src[0];
                        for src_i in 0..reduce_dim_size {
                            let s = src[src_i * reduce_dim_stride];
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                }
            }
            None => {
                // Non-contiguous source: iterate a layout narrowed to a single
                // element along the reduced dimension, then stride through
                // that dimension for each starting position.
                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
                for (unstr_index, src_index) in l.strided_index().enumerate() {
                    let src = &src[src_index..];
                    let mut acc = 0;
                    let mut val = src[0];
                    for src_i in 0..reduce_dim_size {
                        let s = src[src_i * reduce_dim_stride];
                        if f(val, s) {
                            acc = src_i;
                            val = s
                        }
                    }
                    dst_to_set[unstr_index] = g(val, acc)
                }
            }
        }
        // SAFETY: all `dst_len` slots were initialized by one of the branches
        // above — TODO confirm the None branch visits exactly `dst_len` indices.
        unsafe { dst.set_len(dst_len) };
        Ok(dst)
    }
}
264
265impl Map1Any for ReduceIndex {
266 #[inline(always)]
267 fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
268 &self,
269 src: &[T],
270 src_l: &Layout,
271 wrap: W,
272 ) -> Result<CpuStorage> {
273 if src_l.shape().elem_count() == 0 {
274 Err(Error::EmptyTensor { op: "reduce" }.bt())?
275 }
276 let dst = match (self.return_index, self.use_min) {
277 (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
278 (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
279 (true, true) => {
280 CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
281 }
282 (true, false) => {
283 CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
284 }
285 };
286 Ok(dst)
287 }
288}
289
/// Sum reduction over one or more dimensions.
struct ReduceSum<'a> {
    /// Shape of the reduced output.
    dst_shape: &'a Shape,
    /// The dimensions being summed over.
    reduce_dims: &'a [usize],
    /// (dimension size, stride) pairs driving the source-index to
    /// destination-index folding in `fold_impl`.
    reduce_dims_and_stride: Vec<(usize, usize)>,
}
295
impl<'a> ReduceSum<'a> {
    /// Sum `src` over `self.reduce_dims`, producing a vec shaped as
    /// `self.dst_shape`, with every accumulator starting at `start_elt`.
    #[inline(always)]
    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
    where
        T: WithDType,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                // When the reduced dims are exactly the trailing dims, each
                // output element sums one contiguous chunk of the source.
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if reduce_over_last_dims {
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                        let src_i = dst_i * reduce_sz;
                        // SAFETY: the chunk stays in bounds because dst holds
                        // elem_count / reduce_sz entries — TODO confirm the
                        // contract of `T::vec_reduce_sum`.
                        unsafe {
                            T::vec_reduce_sum(
                                src[src_i..src_i + reduce_sz].as_ptr(),
                                dst_v,
                                reduce_sz,
                            )
                        };
                    }
                    return Ok(dst);
                };
                // General contiguous case: collapse each reduced dimension out
                // of the running index to find the accumulator slot.
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // Remove the contribution of this reduced dimension.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                }
            }
            None => {
                // Strided source: same index folding, walking the layout.
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                }
            }
        }
        Ok(dst)
    }
}
360
impl<'a> Map1 for ReduceSum<'a> {
    /// Sum-reduce, starting every accumulator at the dtype's zero.
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        self.fold_impl(src, src_l, T::zero())
    }
}
367
/// Apply `f` to every element described by `layout`, returning a dense vec in
/// the logical (row-major) order of the shape.
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
) -> Vec<U> {
    match layout.strided_blocks() {
        // One contiguous block: map straight over the slice.
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
            [start_offset..start_offset + len]
            .iter()
            .map(|&v| f(v))
            .collect(),
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut result = Vec::with_capacity(layout.shape().elem_count());
            if block_len == 1 {
                // Fully strided: one element per block.
                // SAFETY: relies on `strided_blocks` yielding in-bounds
                // indices for this storage — TODO confirm Layout's invariant.
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
            } else {
                // Blocks of `block_len` contiguous elements each.
                for index in block_start_index {
                    for offset in 0..block_len {
                        let v = unsafe { vs.get_unchecked(index + offset) };
                        result.push(f(*v))
                    }
                }
            }
            result
        }
    }
}
402
/// Like `unary_map`, but with a vectorized kernel `f_vec` that processes a
/// whole contiguous chunk at once; the scalar `f` is used only when chunks
/// degenerate to single elements.
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            // One contiguous block: one vectorized call into spare capacity.
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            // NOTE(review): exposes uninitialized memory as `&mut [U]`;
            // sound only if `f_vec` writes every slot — TODO confirm.
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            if block_len == 1 {
                // Single-element blocks: vectorization buys nothing, fall
                // back to the scalar kernel.
                let mut result = Vec::with_capacity(el_count);
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                // Vectorize per block, writing each chunk into place.
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}
450
/// Element-wise binary map over two layouts, producing a dense row-major vec.
/// Fast paths: both operands contiguous; one contiguous with the other a
/// simple broadcast (described by `offsets_b`); otherwise a strided walk.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Both contiguous: straight zip.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // lhs contiguous, rhs broadcast: walk lhs linearly while cycling
            // through the rhs block, repeating each of its `ob.len` elements
            // `ob.right_broadcast` times before advancing.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            // SAFETY: `i_in_block < ob.len` is maintained by
                            // the wrap-around below; assumes
                            // `ob.start + ob.len <= rhs.len()` — TODO confirm
                            // `offsets_b`'s contract.
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // Mirror image: rhs contiguous, lhs broadcast.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        // Both strided: generic fallback.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}
530
/// Like `binary_map`, but same-typed (`T -> T`) with a vectorized kernel
/// `f_vec` for contiguous chunk pairs; the scalar `f` handles broadcast and
/// strided fallbacks.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Both contiguous: one vectorized call over the whole range.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            // NOTE(review): exposes uninitialized memory as `&mut [T]`;
            // sound only if `f_vec` writes every slot — TODO confirm.
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            // rhs broadcast with no inner repetition: the rhs block can be
            // reused as-is against successive lhs chunks.
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                unsafe { ys.set_len(el_count) };
                ys
            }
            // General rhs broadcast: start from a copy of lhs and fold each
            // rhs element into its repeated span in place.
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        // Mirror image: rhs contiguous, lhs broadcast.
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        // Both strided: scalar fallback.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}
636
637struct Affine(f64, f64);
638
639impl Map1 for Affine {
640 fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
641 let mul = T::from_f64(self.0);
642 let add = T::from_f64(self.1);
643 Ok(unary_map(vs, layout, |v| v * mul + add))
644 }
645}
646
/// 2d average pooling; field 0 is the kernel (h, w), field 1 the stride (h, w).
struct AvgPool2D((usize, usize), (usize, usize));

impl Map1 for AvgPool2D {
    // Input is interpreted as (batch, channel, h, w); output is dense
    // (b_sz, c, h_out, w_out) with h_out = (h - k_h) / s_h + 1 (no padding).
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        // Average = sum * (1 / window size), computed in the element dtype.
        let scale = 1f64 / (k_h * k_w) as f64;
        let scale = T::from_f64(scale);
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        // Sum the k_h x k_w window anchored at this output.
                        let mut sum = T::zero();
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                sum += src[src_index + m * stride_h + n * stride_w]
                            }
                        }
                        dst[h_idx * w_out + w_idx] = sum * scale;
                    }
                }
            }
        }
        Ok(dst)
    }
}
687
/// 2d max pooling; field 0 is the kernel (h, w), field 1 the stride (h, w).
struct MaxPool2D((usize, usize), (usize, usize));

impl Map1 for MaxPool2D {
    // Input is interpreted as (batch, channel, h, w); output is dense
    // (b_sz, c, h_out, w_out) with h_out = (h - k_h) / s_h + 1 (no padding).
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        // Seed with the window's top-left element, then scan
                        // the whole k_h x k_w window for a larger value.
                        let mut largest =
                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                if largest < src[src_index + m * stride_h + n * stride_w] {
                                    largest = src[src_index + m * stride_h + n * stride_w]
                                }
                            }
                        }
                        dst[h_idx * w_out + w_idx] = largest;
                    }
                }
            }
        }
        Ok(dst)
    }
}
729
/// Nearest-neighbor 1d upsampling to a target length (`.0`).
struct UpsampleNearest1D(usize);

impl Map1 for UpsampleNearest1D {
    // Input is (batch, channel, length); output is (batch, channel, dst_sz),
    // each output position reading the floor-scaled source position.
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let dst_sz = self.0;
        let (b_sz, c, src_sz) = layout.shape().dims3()?;
        let stride = layout.stride();
        let stride_sz = stride[2];
        let src_index = layout.start_offset();
        let scale_sz = src_sz as f64 / dst_sz as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
        // Precompute the source index for each output position (clamped to
        // the last valid element).
        let src_idxs = (0..dst_sz)
            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_sz..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_sz..];
                let src_index = src_index + c_idx * stride[1];
                for (idx, src_idx) in src_idxs.iter().enumerate() {
                    dst[idx] = src[src_index + src_idx * stride_sz]
                }
            }
        }
        Ok(dst)
    }
}
759
/// Nearest-neighbor 2d upsampling to a target (height `.0`, width `.1`).
struct UpsampleNearest2D(usize, usize);

impl Map1 for UpsampleNearest2D {
    // Input is (batch, channel, h, w); output is (batch, channel, dst_h,
    // dst_w), each output pixel reading the floor-scaled source pixel.
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (dst_h, dst_w) = (self.0, self.1);
        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let src_index = layout.start_offset();
        let scale_h = src_h as f64 / dst_h as f64;
        let scale_w = src_w as f64 / dst_w as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
        // Precompute clamped source coordinates per output row/column.
        let src_h_idxs = (0..dst_h)
            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
            .collect::<Vec<_>>();
        let src_w_idxs = (0..dst_w)
            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_h * dst_w..];
                let src_index = src_index + c_idx * stride[1];
                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
                        dst[h_idx * dst_w + w_idx] = src[src_index]
                    }
                }
            }
        }
        Ok(dst)
    }
}
796
/// Gather along `dim`: the output has the ids' shape and
/// `dst[.., i, ..] = src[.., ids[.., i, ..], ..]`.
struct Gather<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}

impl<'a, I: IntDType> Map1 for Gather<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        // Both the ids and the source must be contiguous.
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let src_dims = src_l.dims();
        // Split the output (= ids) shape into (left, dim, right) products.
        let dst_len: usize = ids_dims.iter().product();
        let dst_left_len: usize = ids_dims[..dim].iter().product();
        let dst_dim_len = ids_dims[dim];
        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();

        let src_dim_len = src_dims[dim];
        let src_right_len: usize = src_dims[dim + 1..].iter().product();

        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..dst_left_len {
            let start_src_idx = left_i * src_right_len * src_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..dst_dim_len {
                let start_dst_idx = start_dst_idx + i * dst_right_len;
                for right_i in 0..dst_right_len {
                    let dst_idx = start_dst_idx + right_i;
                    let index = ids[dst_idx].as_usize();
                    // Reject out-of-range ids with an error, not a panic.
                    if index >= src_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim_len,
                            op: "gather",
                        }
                        .bt())?
                    }
                    let src_idx = start_src_idx + index * src_right_len + right_i;
                    dst[dst_idx] = src[src_idx]
                }
            }
        }
        Ok(dst)
    }
}
849
/// Index-select along `dim` with a 1d ids tensor: the output replaces the
/// source's `dim` extent with `ids.len()` rows picked from the source.
struct IndexSelect<'a, T: IntDType> {
    ids: &'a [T],
    ids_l: &'a Layout,
    dim: usize,
}

impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // The source must be contiguous; the ids may be strided but must be 1d.
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
            [n_ids] => *n_ids,
            d => Err(Error::UnexpectedNumberOfDims {
                expected: 1,
                got: d.len(),
                shape: self.ids_l.shape().clone(),
            }
            .bt())?,
        };
        let stride_ids = self.ids_l.stride()[0];
        // Output shape = source shape with `dim` replaced by n_ids.
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
                // Reject out-of-range ids with an error, not a panic.
                if index >= src_dim {
                    Err(Error::InvalidIndex {
                        index,
                        size: src_dim,
                        op: "index-select",
                    }
                    .bt())?
                }
                // Copy one whole `right_len` row per id.
                let start_src_idx = start_src_idx + index * right_len;
                let start_dst_idx = start_dst_idx + i * right_len;
                dst[start_dst_idx..start_dst_idx + right_len]
                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
            }
        }
        Ok(dst)
    }
}
902
903struct ScatterAdd<'a, I: IntDType> {
904 ids: &'a [I],
905 ids_l: &'a Layout,
906 dim: usize,
907}
908
909impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
910 const OP: &'static str = "scatter-add";
911 fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
912 let dst_len = l1.shape().elem_count();
913 let mut dst = vec![T::zero(); dst_len];
914 copy_strided_src_(v1, &mut dst, 0, l1);
915 let src = match src_l.contiguous_offsets() {
916 None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
917 Some((o1, o2)) => &src[o1..o2],
918 };
919
920 let dim = self.dim;
921 let ids_dims = self.ids_l.dims();
922 let dst_dims = l1.dims();
923 let dst_dim_len = dst_dims[dim];
924 let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
925
926 let ids_left_len: usize = ids_dims[..dim].iter().product();
927 let ids_dim_len = ids_dims[dim];
928 let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
929
930 let ids = match self.ids_l.contiguous_offsets() {
931 Some((a, b)) => &self.ids[a..b],
932 None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
933 };
934 for left_i in 0..ids_left_len {
935 let start_ids_idx = left_i * ids_right_len * ids_dim_len;
936 let start_dst_idx = left_i * dst_right_len * dst_dim_len;
937 for i in 0..ids_dim_len {
938 let start_ids_idx = start_ids_idx + i * ids_right_len;
939 for right_i in 0..dst_right_len {
940 let ids_idx = start_ids_idx + right_i;
941 let index = ids[ids_idx].as_usize();
942 if index >= dst_dim_len {
943 Err(Error::InvalidIndex {
944 index,
945 size: dst_dim_len,
946 op: "gather",
947 }
948 .bt())?
949 }
950 let dst_idx = start_dst_idx + index * dst_right_len + right_i;
951 dst[dst_idx] += src[ids_idx]
952 }
953 }
954 }
955
956 Ok(dst)
957 }
958}
959
/// Index-add along `dim` with a 1d ids slice: for each source position i,
/// `dst[.., ids[i], ..] += src[.., i, ..]`.
struct IndexAdd<'a, I: IntDType> {
    ids: &'a [I],
    dim: usize,
}

impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
    const OP: &'static str = "index-add";
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        // Start from a dense copy of the destination operand.
        let dst_len = l1.shape().elem_count();
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
        // Bound for valid destination indices along `dim`.
        let max_idx = l1.dims()[dim];
        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
        let src_dim_sz = src_l.dims()[dim];
        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
        if dim == 0 {
            // Fast path: whole leading rows are added slice-to-slice.
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                let src_idx = src_idx * post_dim;
                let dst_idx = dst_idx * post_dim;
                let src = &src[src_idx..src_idx + post_dim];
                let dst = &mut dst[dst_idx..dst_idx + post_dim];
                for (d, &s) in dst.iter_mut().zip(src.iter()) {
                    *d += s
                }
            }
        } else {
            // General case: add one `post_dim` chunk per (pre index, id).
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                for pre_i in 0..pre_dim {
                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
                    let src = &src[pre_src_i..pre_src_i + post_dim];
                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
                        *d += s
                    }
                }
            }
        }
        Ok(dst)
    }
}
1024
/// Copy a possibly-strided source into `dst` starting at `dst_offset`,
/// writing elements in the logical row-major order of `src_l`. Writes are
/// truncated (never panic) when `dst` runs out of room.
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        // One contiguous block: a single (possibly truncated) memcpy.
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let to_copy = (dst.len() - dst_offset).min(len);
            dst[dst_offset..dst_offset + to_copy]
                .copy_from_slice(&src[start_offset..start_offset + to_copy])
        }
        // Fully strided: one element per block.
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len: 1,
        } => {
            for (dst_index, src_index) in block_start_index.enumerate() {
                let dst_index = dst_index + dst_offset;
                if dst_index >= dst.len() {
                    break;
                }
                dst[dst_index] = src[src_index]
            }
        }
        // General case: copy `block_len` contiguous elements per block.
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut dst_index = dst_offset;
            for src_index in block_start_index {
                let next_dst_index = dst_index + block_len;
                if dst_index >= dst.len() {
                    break;
                }
                // The final block may be clipped by the destination length.
                let to_copy = usize::min(block_len, dst.len() - dst_index);
                dst[dst_index..dst_index + to_copy]
                    .copy_from_slice(&src[src_index..src_index + to_copy]);
                dst_index = next_dst_index
            }
        }
    }
}
1062
/// Direct (non-im2col) 1d convolution parameterized by `ParamsConv1D`.
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl<'a> Map2 for Conv1D<'a> {
    const OP: &'static str = "conv1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        // Note: not `mut` — the parallel loop below writes through a raw
        // pointer instead (see the NOTE further down).
        let dst = vec![T::zero(); dst_elems];

        // Repack the input into a contiguous (batch, length, channel) buffer
        // so the inner dot product runs over adjacent channel values.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        // For each kernel tap, accumulate its contribution into dst,
        // parallelizing over output channels.
        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                // Kernel column for this (output channel, tap), gathered into
                // a contiguous buffer for the dot product.
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        // Positions that fall into the (virtual) padding
                        // contribute nothing.
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
                        // NOTE(review): mutation through a `*const -> *mut`
                        // cast of a shared `Vec`. Each thread owns a distinct
                        // `dst_c_idx`, so the written ranges are disjoint, but
                        // this pattern is UB-adjacent — consider unsafe-free
                        // chunking (e.g. par_chunks_mut). TODO confirm.
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}
1124
/// Parameters for unfolding a 1d input into im2col columns.
struct Im2Col1D {
    l_k: usize,      // kernel length
    stride: usize,   // step between consecutive output positions
    dilation: usize, // spacing between kernel taps
    padding: usize,  // zero-padding applied on both ends
}

impl Im2Col1D {
    /// Number of output positions for an input of length `l`, following the
    /// standard convolution output-size formula.
    fn l_out(&self, l: usize) -> usize {
        let effective_k = self.dilation * (self.l_k - 1) + 1;
        let padded = l + 2 * self.padding;
        (padded - effective_k) / self.stride + 1
    }
}
1137
impl Map1 for Im2Col1D {
    /// Unfolds a `(b, c, l)` input into a `(b, l_out, c * l_k)` buffer where
    /// each row holds the receptive field of one output position, so conv1d
    /// can then be computed as a single matmul. Padded taps are left at zero.
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            l_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, l) = layout.shape().dims3()?;
        let l_out = self.l_out(l);
        let src = &vs[layout.start_offset()..];
        // Zero-initialized so out-of-range (padded) taps keep the zero value.
        let mut dst = vec![T::zero(); b * l_out * c * l_k];
        let (src_s0, src_s1, src_s2) = {
            let s = layout.stride();
            (s[0], s[1], s[2])
        };
        // Each nesting level shadows src_idx / dst_idx with the partial
        // offset accumulated so far.
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * l_out * c * l_k;
            for l_idx in 0..l_out {
                let dst_idx = dst_idx + l_idx * c * l_k;
                for c_idx in 0..c {
                    let dst_idx = dst_idx + c_idx * l_k;
                    let src_idx = c_idx * src_s1 + src_idx;
                    for l_k_idx in 0..l_k {
                        let src_l = l_idx * stride + l_k_idx * dilation;
                        // Inside the padded border: keep the pre-filled zero.
                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
                            continue;
                        }
                        let src_l = src_l - padding;
                        let src_idx = src_idx + src_l * src_s2;
                        let dst_idx = dst_idx + l_k_idx;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}
1183
/// Parameters for unfolding a 2d input into im2col patches.
struct Im2Col {
    h_k: usize,      // kernel height
    w_k: usize,      // kernel width
    stride: usize,   // step between output positions (shared by h and w)
    dilation: usize, // spacing between kernel taps (shared by h and w)
    padding: usize,  // zero-padding on every side
}

impl Im2Col {
    /// Output spatial size `(h_out, w_out)` for an input of size `(h, w)`,
    /// applying the standard convolution output-size formula per axis.
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let out_len =
            |len: usize, k: usize| (len + 2 * self.padding - self.dilation * (k - 1) - 1) / self.stride + 1;
        (out_len(h, self.h_k), out_len(w, self.w_k))
    }
}
1199
impl Map1 for Im2Col {
    /// Unfolds a `(b, c, h, w)` input into a `(b, h_out * w_out, c * h_k * w_k)`
    /// buffer where each row is the receptive field of one output pixel, so
    /// conv2d can be computed as a single matmul. Padded taps stay at zero.
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let src = &vs[layout.start_offset()..];
        // Zero-initialized so out-of-range (padded) taps keep the zero value.
        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        // Each nesting level shadows src_idx / dst_idx with the partial
        // offset accumulated so far.
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            // Row falls in the padded border: keep zeros.
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                // Column falls in the padded border: keep zero.
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        Ok(dst)
    }
}
1258
/// Direct (non-im2col) 2d convolution over CPU buffers, parameterized by
/// `crate::conv::ParamsConv2D`. Dispatched over element types through `Map2`.
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

impl<'a> Map2 for Conv2D<'a> {
    const OP: &'static str = "conv2d";
    /// Convolves `inp` with kernel `k` and returns a contiguous
    /// `(b_size, c_out, out_h, out_w)` buffer. Shapes come from the params;
    /// the layouts only supply offsets and strides.
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Zero-initialized accumulator; written through raw pointers below.
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];

        // Repack the input as (batch, h, w, channel) so the inner dot product
        // over channels runs on contiguous memory.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        // Accumulate one kernel tap (offset_h, offset_w) at a time; output
        // channels are processed in parallel by rayon.
        for offset_h in 0..p.k_h {
            for offset_w in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let dst_idx = dst_c_idx * out_w * out_h;
                    // Gather this tap's weights for all input channels.
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[dst_c_idx * k_s0
                                + c_in_idx * k_s1
                                + offset_h * k_s2
                                + offset_w * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
                        for dst_h in 0..out_h {
                            let dst_idx = dst_idx + dst_h * out_w;
                            let src_h = p.stride * dst_h + offset_h * p.dilation;
                            // Row in the zero-padding: contributes nothing.
                            if src_h < p.padding || src_h >= p.i_h + p.padding {
                                continue;
                            }
                            let src_h = src_h - p.padding;
                            for dst_w in 0..out_w {
                                let dst_idx = dst_idx + dst_w;
                                let src_w = p.stride * dst_w + offset_w * p.dilation;
                                // Column in the zero-padding: contributes nothing.
                                if src_w < p.padding || src_w >= p.i_w + p.padding {
                                    continue;
                                }
                                let src_w = src_w - p.padding;
                                let inp_cont = &inp_cont
                                    [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
                                assert!(inp_cont.len() >= p.c_in);
                                assert!(k_cont.len() >= p.c_in);
                                let mut d = T::zero();
                                unsafe {
                                    T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                                }
                                let dst_p = dst.as_ptr();
                                // SAFETY(review): every dst_idx includes this
                                // task's dst_c_idx * out_w * out_h term, so
                                // parallel tasks write disjoint slots; the
                                // shared-to-mut cast still deserves an audit.
                                unsafe {
                                    let ptr = dst_p.add(dst_idx) as *mut T;
                                    *ptr += d
                                }
                            }
                        }
                    }
                });
            }
        }

        Ok(dst)
    }
}
1346
/// Transposed 2d convolution over CPU buffers, parameterized by
/// `crate::conv::ParamsConvTranspose2D`.
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);

impl<'a> Map2 for ConvTranspose2D<'a> {
    const OP: &'static str = "conv_transpose2d";
    /// Scatters each input pixel into the output (the transpose of conv2d's
    /// gather), returning a contiguous `(b_size, c_out, out_h, out_w)` buffer.
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        // Note the kernel is laid out (c_in, c_out, k_h, k_w) here.
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Zero-initialized accumulator; written through raw pointers below.
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
        let dst_s0 = p.c_out * out_h * out_w;
        let dst_s1 = out_h * out_w;
        let dst_s2 = out_w;
        let dst_s3 = 1;

        // Repack the input as (batch, h, w, channel) so the inner dot product
        // over channels runs on contiguous memory.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        // Accumulate one kernel tap (k_y, k_x) at a time; output channels are
        // processed in parallel by rayon.
        for k_y in 0..p.k_h {
            for k_x in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    // This tap's weights for all input channels, contiguous.
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        for inp_y in 0..p.i_h {
                            for inp_x in 0..p.i_w {
                                // Output position this input pixel scatters to.
                                let out_x = inp_x * p.stride + k_x * p.dilation;
                                let out_y = inp_y * p.stride + k_y * p.dilation;
                                if out_x < p.padding || out_y < p.padding {
                                    continue;
                                }
                                let out_x = out_x - p.padding;
                                let out_y = out_y - p.padding;
                                // Drop contributions falling outside the output.
                                if out_x < out_w && out_y < out_h {
                                    let inp_cont = &inp_cont
                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
                                    let dst_idx = b_idx * dst_s0
                                        + out_y * dst_s2
                                        + out_x * dst_s3
                                        + dst_c_idx * dst_s1;
                                    let mut d = T::zero();
                                    unsafe {
                                        T::vec_dot(
                                            inp_cont.as_ptr(),
                                            k_cont.as_ptr(),
                                            &mut d,
                                            p.c_in,
                                        )
                                    }
                                    let dst_p = dst.as_ptr();
                                    // SAFETY(review): dst_idx includes this
                                    // task's dst_c_idx * dst_s1 term, so
                                    // parallel tasks write disjoint slots; the
                                    // shared-to-mut cast deserves an audit.
                                    unsafe {
                                        let ptr = dst_p.add(dst_idx) as *mut T;
                                        *ptr += d
                                    }
                                }
                            }
                        }
                    }
                })
            }
        }
        Ok(dst)
    }
}
1436
1437struct MatMul((usize, usize, usize, usize));
1438
1439impl MatMul {
1440 fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
1441 Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
1442 lhs_l: lhs_l.clone(),
1443 rhs_l: rhs_l.clone(),
1444 bmnk: self.0,
1445 msg,
1446 }))
1447 .bt()
1448 }
1449}
1450
impl Map2 for MatMul {
    const OP: &'static str = "mat_mul";

    /// Default backend: the pure-Rust `gemm` crate. Handles f16/f32/f64 and
    /// arbitrary row/column strides on both operands.
    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        use gemm::{gemm, Parallelism};

        match T::DTYPE {
            DType::F16 | DType::F32 | DType::F64 => {}
            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
        }

        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        // Column / row strides of the innermost 2d matrices.
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];

        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

        // Per-batch stride: accept rank 2 (no batch dim), rank 3, or rank 4
        // where the two leading dims are contiguous with each other.
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];

        let mut dst = vec![T::zero(); b * m * n];
        let num_threads = crate::utils::get_num_threads();
        let parallelism = if num_threads > 1 {
            Parallelism::Rayon(num_threads)
        } else {
            Parallelism::None
        };
        for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            // SAFETY(review): gemm reads m*k / k*n elements through the given
            // strides and writes m*n into dst_p; relies on the skips computed
            // above keeping each batch slice in bounds.
            unsafe {
                gemm(
                    m,
                    n,
                    k,
                    dst_p.as_mut_ptr(),
                    dst_cs as isize,
                    dst_rs as isize,
                    false,
                    lhs_p.as_ptr(),
                    lhs_cs as isize,
                    lhs_rs as isize,
                    rhs_p.as_ptr(),
                    rhs_cs as isize,
                    rhs_rs as isize,
                    T::zero(),
                    T::one(),
                    false,
                    false,
                    false,
                    parallelism,
                )
            }
        }
        Ok(dst)
    }

    /// Apple Accelerate backend (f32/f64 only). The operands are passed to
    /// BLAS swapped (rhs first, dims n/m) so the column-major routine
    /// produces the row-major `lhs @ rhs` result directly.
    #[cfg(feature = "accelerate")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();

        // Per-batch strides, same acceptance rules as the gemm backend.
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

        // BLAS only handles plain (possibly transposed) matrices, so each
        // operand must be contiguous in one of the two orders.
        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                crate::bail!("the accelerate backend does not support f16 matmul")
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY(review): T::DTYPE == F32 so the pointer casts are
                    // sound; slice lengths use the batch skips. NOTE(review):
                    // a_skip/b_skip are strides, not matrix sizes - confirm
                    // they always cover the m*k / k*n elements read.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::sgemm(
                            transa, transb, n as i32, m as i32,
                            k as i32, 1., a,
                            lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY(review): same reasoning as the f32 branch.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::dgemm(
                            transa, transb, n as i32, m as i32,
                            k as i32, 1., a,
                            lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }

    /// Intel MKL backend (f16/f32/f64). Same operand-swapping trick as the
    /// accelerate variant.
    #[cfg(feature = "mkl")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();

        // Per-batch strides, same acceptance rules as the gemm backend.
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

        // Each operand must be contiguous in row- or column-major order.
        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY(review): T::DTYPE == F16 so the pointer casts are
                    // sound; slice lengths use the batch skips. NOTE(review):
                    // a_skip/b_skip are strides, not matrix sizes - confirm
                    // they always cover the m*k / k*n elements read.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f16;
                        let b = lhs_p.as_ptr() as *const f16;
                        let c = dst_p.as_mut_ptr() as *mut f16;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::hgemm(
                            transa,
                            transb,
                            n as i32,
                            m as i32,
                            k as i32,
                            f16::ONE,
                            a,
                            lda,
                            b,
                            ldb,
                            f16::ZERO,
                            c,
                            n as i32,
                        )
                    }
                }
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY(review): same reasoning as the f16 branch.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::sgemm(
                            transa, transb, n as i32, m as i32,
                            k as i32, 1., a,
                            lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY(review): same reasoning as the f16 branch.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::dgemm(
                            transa, transb, n as i32, m as i32,
                            k as i32, 1., a,
                            lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }
}
1772
1773fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
1774 if v.is_sign_positive() {
1775 v
1776 } else {
1777 (v.exp() - T::one()) * alpha
1778 }
1779}
1780
1781impl CpuStorage {
1782 pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
1783 D::cpu_storage_as_slice(self)
1784 }
1785
1786 pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
1787 let storage0 = &storages[0];
1788 let s = match storage0 {
1789 Self::U8(_) => {
1790 let storages = storages
1791 .iter()
1792 .map(|s| match s {
1793 Self::U8(s) => Ok(s.as_slice()),
1794 _ => crate::bail!("dtype mismatch"),
1795 })
1796 .collect::<Result<Vec<_>>>()?
1797 .concat();
1798 Self::U8(storages)
1799 }
1800 Self::U32(_) => {
1801 let storages = storages
1802 .iter()
1803 .map(|s| match s {
1804 Self::U32(s) => Ok(s.as_slice()),
1805 _ => crate::bail!("dtype mismatch"),
1806 })
1807 .collect::<Result<Vec<_>>>()?
1808 .concat();
1809 Self::U32(storages)
1810 }
1811 Self::I64(_) => {
1812 let storages = storages
1813 .iter()
1814 .map(|s| match s {
1815 Self::I64(s) => Ok(s.as_slice()),
1816 _ => crate::bail!("dtype mismatch"),
1817 })
1818 .collect::<Result<Vec<_>>>()?
1819 .concat();
1820 Self::I64(storages)
1821 }
1822 Self::BF16(_) => {
1823 let storages = storages
1824 .iter()
1825 .map(|s| match s {
1826 Self::BF16(s) => Ok(s.as_slice()),
1827 _ => crate::bail!("dtype mismatch"),
1828 })
1829 .collect::<Result<Vec<_>>>()?
1830 .concat();
1831 Self::BF16(storages)
1832 }
1833 Self::F16(_) => {
1834 let storages = storages
1835 .iter()
1836 .map(|s| match s {
1837 Self::F16(s) => Ok(s.as_slice()),
1838 _ => crate::bail!("dtype mismatch"),
1839 })
1840 .collect::<Result<Vec<_>>>()?
1841 .concat();
1842 Self::F16(storages)
1843 }
1844 Self::F32(_) => {
1845 let storages = storages
1846 .iter()
1847 .map(|s| match s {
1848 Self::F32(s) => Ok(s.as_slice()),
1849 _ => crate::bail!("dtype mismatch"),
1850 })
1851 .collect::<Result<Vec<_>>>()?
1852 .concat();
1853 Self::F32(storages)
1854 }
1855 Self::F64(_) => {
1856 let storages = storages
1857 .iter()
1858 .map(|s| match s {
1859 Self::F64(s) => Ok(s.as_slice()),
1860 _ => crate::bail!("dtype mismatch"),
1861 })
1862 .collect::<Result<Vec<_>>>()?
1863 .concat();
1864 Self::F64(storages)
1865 }
1866 };
1867 Ok(s)
1868 }
1869}
1870
1871impl BackendStorage for CpuStorage {
1872 type Device = CpuDevice;
1873
    /// Returns the dtype corresponding to the active storage variant.
    fn dtype(&self) -> DType {
        match self {
            Self::U8(_) => DType::U8,
            Self::U32(_) => DType::U32,
            Self::I64(_) => DType::I64,
            Self::BF16(_) => DType::BF16,
            Self::F16(_) => DType::F16,
            Self::F32(_) => DType::F32,
            Self::F64(_) => DType::F64,
        }
    }
1885
    /// Casts the storage to `dtype`, materializing a new contiguous buffer
    /// through `unary_map` (which honors `layout`'s offset/strides).
    ///
    /// The match is exhaustive over all 7x7 (source, target) dtype pairs;
    /// same-dtype pairs still copy via the identity closure so the result is
    /// always a fresh contiguous buffer. Integer <-> half conversions go
    /// through f32, float -> integer casts use Rust `as` truncation.
    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        match (self, dtype) {
            // Conversions into BF16.
            (Self::U8(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::U32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::I64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::BF16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::BF16(data))
            }
            (Self::F16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                Ok(Self::BF16(data))
            }
            (Self::F32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f32);
                Ok(Self::BF16(data))
            }
            (Self::F64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f64);
                Ok(Self::BF16(data))
            }
            // Conversions into F16.
            (Self::U8(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::U32(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::I64(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::BF16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                Ok(Self::F16(data))
            }
            (Self::F16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F16(data))
            }
            (Self::F32(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f32);
                Ok(Self::F16(data))
            }
            (Self::F64(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f64);
                Ok(Self::F16(data))
            }
            // Conversions into F32.
            (Self::U8(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::I64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::BF16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F32(data))
            }
            (Self::F64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            // Conversions into U8 (floats go through f32 then truncate).
            (Self::U8(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U8(data))
            }
            (Self::BF16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::F64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::I64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            // Conversions into U32.
            (Self::U8(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U32(data))
            }
            (Self::I64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::BF16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::F64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            // Conversions into I64.
            (Self::U8(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::I64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::I64(data))
            }
            (Self::BF16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::F64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            // Conversions into F64.
            (Self::U8(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::U32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::I64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::BF16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::F64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F64(data))
            }
        }
    }
2087
    /// Reduces over `reduce_dims`. Sum supports any number of dims; min/max
    /// and their arg variants only accept a single dim. Reduced dims are kept
    /// with size 1 in the output shape.
    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
        match op {
            ReduceOp::Sum => {
                let src_dims = layout.dims();
                let mut dst_dims = src_dims.to_vec();
                for &dim in reduce_dims.iter() {
                    dst_dims[dim] = 1;
                }
                let dst_shape = Shape::from(dst_dims);
                let mut reduce_dims = reduce_dims.to_vec();
                reduce_dims.sort();
                // (dim size, stride) pairs; the stride is derived from the
                // dims, i.e. it is the C-contiguous stride. NOTE(review):
                // presumably ReduceSum handles non-contiguous layouts itself -
                // confirm against its implementation.
                let reduce_dims_and_stride: Vec<_> = reduce_dims
                    .iter()
                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
                    .collect();
                ReduceSum {
                    dst_shape: &dst_shape,
                    reduce_dims: &reduce_dims,
                    reduce_dims_and_stride,
                }
                .map(self, layout)
            }
            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
                // These reductions are only implemented over one dimension.
                let reduce_dim_index = match reduce_dims {
                    [reduce_dim_index] => *reduce_dim_index,
                    _ => {
                        let op = match op {
                            ReduceOp::Min => "min",
                            ReduceOp::ArgMin => "argmin",
                            ReduceOp::Max => "max",
                            ReduceOp::ArgMax => "argmax",
                            _ => unreachable!(),
                        };
                        let dims = reduce_dims.to_vec();
                        Err(Error::OnlySingleDimension { op, dims })?
                    }
                };
                // Map the four ops onto the (min-vs-max, value-vs-index) flags
                // understood by ReduceIndex.
                let (use_min, return_index) = match op {
                    ReduceOp::Min => (true, false),
                    ReduceOp::ArgMin => (true, true),
                    ReduceOp::Max => (false, false),
                    ReduceOp::ArgMax => (false, true),
                    _ => unreachable!(),
                };
                ReduceIndex {
                    reduce_dim_index,
                    use_min,
                    return_index,
                }
                .map(self, layout)
            }
        }
    }
2143
    /// Element-wise comparison of `self` and `rhs`, delegating to the `Cmp`
    /// binary kernel.
    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
        Cmp(op).map(self, lhs_l, rhs, rhs_l)
    }
2147
    /// Computes `v * mul + add` element-wise via the `Affine` kernel.
    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        Affine(mul, add).map(self, layout)
    }
2151
    /// 2d average pooling with the given `(h, w)` kernel size and stride,
    /// delegating to the `AvgPool2D` kernel.
    fn avg_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        AvgPool2D(kernel_size, stride).map(self, layout)
    }
2160
    /// 2d max pooling with the given `(h, w)` kernel size and stride,
    /// delegating to the `MaxPool2D` kernel.
    fn max_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        MaxPool2D(kernel_size, stride).map(self, layout)
    }
2169
    /// Nearest-neighbor upsampling to length `sz`, via `UpsampleNearest1D`.
    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
        UpsampleNearest1D(sz).map(self, layout)
    }
2173
    /// Nearest-neighbor upsampling to `(h, w)`, via `UpsampleNearest2D`.
    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
        UpsampleNearest2D(h, w).map(self, layout)
    }
2177
2178 fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
2179 use num_traits::Float;
2180 match self {
2182 Self::BF16(storage) => {
2183 let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
2184 Ok(Self::BF16(data))
2185 }
2186 Self::F16(storage) => {
2187 let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
2188 Ok(Self::F16(data))
2189 }
2190 Self::F32(storage) => {
2191 let data = unary_map(storage, layout, |v| v.powf(e as f32));
2192 Ok(Self::F32(data))
2193 }
2194 Self::F64(storage) => {
2195 let data = unary_map(storage, layout, |v| v.powf(e));
2196 Ok(Self::F64(data))
2197 }
2198 Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2199 Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2200 Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2201 }
2202 }
2203
    /// Applies the exponential linear unit element-wise through the free
    /// `elu` helper defined above; `alpha` is converted to the element type
    /// first. Integer dtypes return an `UnsupportedDTypeForOp` error.
    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                // `from_f64` here resolves through the imported WithDType trait.
                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, alpha));
                Ok(Self::F64(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
        }
    }
2228
2229 fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
2230 match self {
2231 Self::BF16(storage) => {
2232 if B::BF16_VEC {
2233 let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
2234 Ok(Self::BF16(data))
2235 } else {
2236 let data = unary_map(storage, layout, B::bf16);
2237 Ok(Self::BF16(data))
2238 }
2239 }
2240 Self::F16(storage) => {
2241 if B::F16_VEC {
2242 let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
2243 Ok(Self::F16(data))
2244 } else {
2245 let data = unary_map(storage, layout, B::f16);
2246 Ok(Self::F16(data))
2247 }
2248 }
2249 Self::F32(storage) => {
2250 if B::F32_VEC {
2251 let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
2252 Ok(Self::F32(data))
2253 } else {
2254 let data = unary_map(storage, layout, B::f32);
2255 Ok(Self::F32(data))
2256 }
2257 }
2258 Self::F64(storage) => {
2259 if B::F64_VEC {
2260 let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
2261 Ok(Self::F64(data))
2262 } else {
2263 let data = unary_map(storage, layout, B::f64);
2264 Ok(Self::F64(data))
2265 }
2266 }
2267 Self::U8(storage) => {
2268 let data = unary_map(storage, layout, B::u8);
2269 Ok(Self::U8(data))
2270 }
2271 Self::U32(storage) => {
2272 let data = unary_map(storage, layout, B::u32);
2273 Ok(Self::U32(data))
2274 }
2275 Self::I64(storage) => {
2276 let data = unary_map(storage, layout, B::i64);
2277 Ok(Self::I64(data))
2278 }
2279 }
2280 }
2281
2282 fn binary_impl<B: BinaryOpT>(
2283 &self,
2284 rhs: &Self,
2285 lhs_l: &Layout,
2286 rhs_l: &Layout,
2287 ) -> Result<Self> {
2288 match (self, rhs) {
2289 (Self::BF16(lhs), Self::BF16(rhs)) => {
2290 let data = if B::BF16_VEC {
2291 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
2292 } else {
2293 binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
2294 };
2295 Ok(Self::BF16(data))
2296 }
2297 (Self::F16(lhs), Self::F16(rhs)) => {
2298 let data = if B::F16_VEC {
2299 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
2300 } else {
2301 binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
2302 };
2303 Ok(Self::F16(data))
2304 }
2305 (Self::F32(lhs), Self::F32(rhs)) => {
2306 let data = if B::F32_VEC {
2307 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
2308 } else {
2309 binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
2310 };
2311 Ok(Self::F32(data))
2312 }
2313 (Self::F64(lhs), Self::F64(rhs)) => {
2314 let data = if B::F64_VEC {
2315 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
2316 } else {
2317 binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
2318 };
2319 Ok(Self::F64(data))
2320 }
2321 (Self::U32(lhs), Self::U32(rhs)) => {
2322 let data = if B::U32_VEC {
2323 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
2324 } else {
2325 binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
2326 };
2327 Ok(Self::U32(data))
2328 }
2329 (Self::I64(lhs), Self::I64(rhs)) => {
2330 let data = if B::I64_VEC {
2331 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
2332 } else {
2333 binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
2334 };
2335 Ok(Self::I64(data))
2336 }
2337 (Self::U8(lhs), Self::U8(rhs)) => {
2338 let data = if B::U8_VEC {
2339 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
2340 } else {
2341 binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
2342 };
2343 Ok(Self::U8(data))
2344 }
2345 _ => {
2346 Err(Error::DTypeMismatchBinaryOp {
2348 lhs: self.dtype(),
2349 rhs: rhs.dtype(),
2350 op: B::NAME,
2351 }
2352 .bt())
2353 }
2354 }
2355 }
2356
2357 fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
2358 match (self, dst) {
2359 (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2360 (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2361 (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2362 (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2363 (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2364 (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2365 (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2366 (_, dst) => {
2367 return Err(Error::DTypeMismatchBinaryOp {
2369 lhs: self.dtype(),
2370 rhs: dst.dtype(),
2371 op: "copy_strided",
2372 }
2373 .bt());
2374 }
2375 }
2376 Ok(())
2377 }
2378
2379 fn where_cond(
2380 &self,
2381 layout: &Layout,
2382 t: &Self,
2383 t_l: &Layout,
2384 f: &Self,
2385 f_l: &Layout,
2386 ) -> Result<Self> {
2387 match self {
2388 Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2389 Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2390 Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2391 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
2392 }
2393 }
2394
2395 fn conv1d(
2396 &self,
2397 l: &Layout,
2398 kernel: &Self,
2399 kernel_l: &Layout,
2400 params: &crate::conv::ParamsConv1D,
2401 ) -> Result<Self> {
2402 if !USE_IM2COL_CONV1D {
2403 return Conv1D(params).map(self, l, kernel, kernel_l);
2404 }
2405 let op = Im2Col1D {
2406 l_k: params.k_size,
2407 padding: params.padding,
2408 stride: params.stride,
2409 dilation: params.dilation,
2410 };
2411 let col = op.map(self, l)?;
2412 let b = params.b_size;
2413 let n = params.c_out;
2414 let l_out = params.l_out();
2415 let k = op.l_k * params.c_in;
2416 let m = l_out;
2417 let col_l = Layout::contiguous((b, m, k));
2418 let res = if kernel_l.is_contiguous() {
2419 let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2420 .transpose(1, 2)?
2421 .broadcast_as((b, k, n))?;
2422 col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2423 } else {
2424 let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
2426 kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2427 let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2428 .transpose(1, 2)?
2429 .broadcast_as((b, k, n))?;
2430 col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2431 };
2432 let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
2433 let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
2434 res.copy_strided_src(&mut res_t, 0, &res_l)?;
2435 Ok(res_t)
2436 }
2437
2438 fn conv2d(
2439 &self,
2440 l: &Layout,
2441 kernel: &Self,
2442 kernel_l: &Layout,
2443 params: &crate::conv::ParamsConv2D,
2444 ) -> Result<Self> {
2445 if !USE_IM2COL_CONV2D {
2446 return Conv2D(params).map(self, l, kernel, kernel_l);
2447 }
2448 let op = Im2Col {
2449 h_k: params.k_h,
2450 w_k: params.k_w,
2451 padding: params.padding,
2452 stride: params.stride,
2453 dilation: params.dilation,
2454 };
2455 let col = op.map(self, l)?;
2456 let b = params.b_size;
2457 let n = params.c_out;
2458 let (h_out, w_out) = (params.out_h(), params.out_w());
2459 let k = op.h_k * op.w_k * params.c_in;
2460 let m = h_out * w_out;
2461 let col_l = Layout::contiguous((b, m, k));
2462 let res = if kernel_l.is_contiguous() {
2463 let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2464 .transpose(1, 2)?
2465 .broadcast_as((b, k, n))?;
2466 col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2467 } else {
2468 let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
2470 kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2471 let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2472 .transpose(1, 2)?
2473 .broadcast_as((b, k, n))?;
2474 col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2475 };
2476 let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
2477 .transpose(1, 2)?
2478 .transpose(1, 3)?;
2479 let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
2480 res.copy_strided_src(&mut res_t, 0, &res_l)?;
2481 Ok(res_t)
2482 }
2483
    /// 2D transposed convolution; delegates directly to the `ConvTranspose2D`
    /// implementation (no im2col fast path here).
    fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
    }
2493
2494 fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
2495 match ids {
2496 Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2497 Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2498 Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2499 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
2500 }
2501 }
2502
2503 fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
2504 match ids {
2505 Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
2506 Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
2507 Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
2508 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
2509 }
2510 }
2511
2512 fn scatter_add(
2513 &self,
2514 l: &Layout,
2515 ids: &Self,
2516 ids_l: &Layout,
2517 src: &Self,
2518 src_l: &Layout,
2519 dim: usize,
2520 ) -> Result<Self> {
2521 match ids {
2522 Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
2523 Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
2524 Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
2525 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
2526 }
2527 }
2528
2529 fn index_add(
2530 &self,
2531 l: &Layout,
2532 ids: &Self,
2533 ids_l: &Layout,
2534 src: &Self,
2535 src_l: &Layout,
2536 dim: usize,
2537 ) -> Result<Self> {
2538 match ids {
2539 Self::U8(ids) => {
2540 let ids = match ids_l.contiguous_offsets() {
2541 Some((a, b)) => &ids[a..b],
2542 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2543 };
2544 IndexAdd { ids, dim }.map(self, l, src, src_l)
2545 }
2546 Self::U32(ids) => {
2547 let ids = match ids_l.contiguous_offsets() {
2548 Some((a, b)) => &ids[a..b],
2549 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2550 };
2551 IndexAdd { ids, dim }.map(self, l, src, src_l)
2552 }
2553 Self::I64(ids) => {
2554 let ids = match ids_l.contiguous_offsets() {
2555 Some((a, b)) => &ids[a..b],
2556 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2557 };
2558 IndexAdd { ids, dim }.map(self, l, src, src_l)
2559 }
2560 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
2561 }
2562 }
2563
    /// Batched matrix multiplication; `bmnk` carries the problem dimensions
    /// (presumably batch, m, n, k given the name — confirm against the
    /// `MatMul` implementation). Delegates entirely to `MatMul`.
    fn matmul(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }
2573
    /// The CPU device is a zero-sized singleton; a `'static` promoted
    /// reference is returned.
    fn device(&self) -> &Self::Device {
        &CpuDevice
    }
2577
    /// Cloning CPU storage cannot fail; the layout argument is unused here.
    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }
2581
    /// Already on the CPU: a plain clone of the storage.
    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }
2585}
2586
2587impl BackendDevice for CpuDevice {
2588 type Storage = CpuStorage;
2589
2590 fn location(&self) -> crate::DeviceLocation {
2591 crate::DeviceLocation::Cpu
2592 }
2593
2594 fn same_device(&self, _: &Self) -> bool {
2595 true
2596 }
2597
2598 fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
2599 Ok(s.clone())
2600 }
2601
2602 fn new(_: usize) -> Result<Self> {
2603 Ok(Self)
2604 }
2605
2606 fn set_seed(&self, _seed: u64) -> Result<()> {
2607 crate::bail!("cannot seed the CPU rng with set_seed")
2608 }
2609
2610 fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
2611 use rand::prelude::*;
2612
2613 let elem_count = shape.elem_count();
2614 let mut rng = rand::thread_rng();
2615 match dtype {
2616 DType::U8 | DType::U32 | DType::I64 => {
2617 Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
2618 }
2619 DType::BF16 => {
2620 let mut data = Vec::with_capacity(elem_count);
2621 let uniform =
2622 rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
2623 for _i in 0..elem_count {
2624 data.push(rng.sample::<bf16, _>(uniform))
2625 }
2626 Ok(CpuStorage::BF16(data))
2627 }
2628 DType::F16 => {
2629 let mut data = Vec::with_capacity(elem_count);
2630 let uniform =
2631 rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
2632 for _i in 0..elem_count {
2633 data.push(rng.sample::<f16, _>(uniform))
2634 }
2635 Ok(CpuStorage::F16(data))
2636 }
2637 DType::F32 => {
2638 let mut data = Vec::with_capacity(elem_count);
2639 let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
2640 for _i in 0..elem_count {
2641 data.push(rng.sample::<f32, _>(uniform))
2642 }
2643 Ok(CpuStorage::F32(data))
2644 }
2645 DType::F64 => {
2646 let mut data = Vec::with_capacity(elem_count);
2647 let uniform = rand::distributions::Uniform::new(min, max);
2648 for _i in 0..elem_count {
2649 data.push(rng.sample::<f64, _>(uniform))
2650 }
2651 Ok(CpuStorage::F64(data))
2652 }
2653 }
2654 }
2655
2656 fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
2657 use rand::prelude::*;
2658
2659 let elem_count = shape.elem_count();
2660 let mut rng = rand::thread_rng();
2661 match dtype {
2662 DType::U8 | DType::U32 | DType::I64 => {
2663 Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
2664 }
2665 DType::BF16 => {
2666 let mut data = Vec::with_capacity(elem_count);
2667 let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
2668 .map_err(Error::wrap)?;
2669 for _i in 0..elem_count {
2670 data.push(normal.sample(&mut rng))
2671 }
2672 Ok(CpuStorage::BF16(data))
2673 }
2674 DType::F16 => {
2675 let mut data = Vec::with_capacity(elem_count);
2676 let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
2677 .map_err(Error::wrap)?;
2678 for _i in 0..elem_count {
2679 data.push(normal.sample(&mut rng))
2680 }
2681 Ok(CpuStorage::F16(data))
2682 }
2683 DType::F32 => {
2684 let mut data = Vec::with_capacity(elem_count);
2685 let normal =
2686 rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
2687 for _i in 0..elem_count {
2688 data.push(normal.sample(&mut rng))
2689 }
2690 Ok(CpuStorage::F32(data))
2691 }
2692 DType::F64 => {
2693 let mut data = Vec::with_capacity(elem_count);
2694 let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
2695 for _i in 0..elem_count {
2696 data.push(normal.sample(&mut rng))
2697 }
2698 Ok(CpuStorage::F64(data))
2699 }
2700 }
2701 }
2702
2703 fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
2704 let elem_count = shape.elem_count();
2705 let storage = match dtype {
2706 DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
2707 DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
2708 DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
2709 DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
2710 DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
2711 DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
2712 DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
2713 };
2714 Ok(storage)
2715 }
2716
2717 fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
2718 let elem_count = shape.elem_count();
2719 let storage = match dtype {
2720 DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
2721 DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
2722 DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
2723 DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
2724 DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
2725 DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
2726 DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
2727 };
2728 Ok(storage)
2729 }
2730}
2731
/// Dispatches `$fn` over a `CpuStorage` value: when its variant is one of the
/// listed `$dtypes`, applies `$fn` to the inner `Vec` and rewraps the result
/// in the same variant; any other variant produces an
/// `Error::UnsupportedDTypeForOp` (named `$name`) propagated with `?`, so the
/// macro must be used inside a function returning `Result`.
#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}