use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;

mod utils;
pub use utils::{
    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};
mod conv2d;
use conv2d::Conv2D;

const USE_IM2COL_CONV1D: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;

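/// Owned CPU-side buffers, with one variant per supported dtype. The narrow
/// float formats without a native Rust representation (F6E2M3, F6E3M2, F4,
/// F8E8M0) are stored as raw `u8` bytes.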
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I16(Vec<i16>),
    I32(Vec<i32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
    F8E4M3(Vec<F8E4M3>),
    F6E2M3(Vec<u8>),
    F6E3M2(Vec<u8>),
    F4(Vec<u8>),
    F8E8M0(Vec<u8>),
}

#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    U8(&'a [u8]),
    U32(&'a [u32]),
    I16(&'a [i16]),
    I32(&'a [i32]),
    I64(&'a [i64]),
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    F64(&'a [f64]),
    F8E4M3(&'a [F8E4M3]),
    F6E2M3(&'a [u8]),
    F6E3M2(&'a [u8]),
    F4(&'a [u8]),
    F8E8M0(&'a [u8]),
}

#[derive(Debug, Clone)]
pub struct CpuDevice;

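// Comparison ops lower to a `binary_map` yielding a 0/1 `u8` mask, matching
// the U8 dtype this backend uses for boolean tensors.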
struct Cmp(CmpOp);
impl Map2U8 for Cmp {
    const OP: &'static str = "cmp";
    #[inline(always)]
    fn f<T: WithDType>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<u8>> {
        let dst = match self.0 {
            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
        };
        Ok(dst)
    }
}

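// `where_cond`: `self.0` is the integer predicate and `self.1` its layout.
// Each output element takes `t` where the predicate is true and `f`
// otherwise; the all-contiguous case is a straight zip over the three slices,
// everything else goes through strided indexing.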
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<I: IntDType> Map2 for WCond<'_, I> {
    const OP: &'static str = "where";
    #[inline(always)]
    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
        let vs = match (
            self.1.contiguous_offsets(),
            t_l.contiguous_offsets(),
            f_l.contiguous_offsets(),
        ) {
            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
                let pred = &self.0[o1..o2];
                let t = &t[o_t1..o_t2];
                let f = &f[o_f1..o_f2];
                pred.iter()
                    .zip(t.iter().zip(f.iter()))
                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
                    .collect::<Vec<_>>()
            }
            _ => self
                .1
                .strided_index()
                .zip(t_l.strided_index().zip(f_l.strided_index()))
                .map(|(i_p, (i_t, i_f))| {
                    if self.0[i_p].is_true() {
                        t[i_t]
                    } else {
                        f[i_f]
                    }
                })
                .collect::<Vec<_>>(),
        };
        Ok(vs)
    }
}

struct ReduceIndex {
    reduce_dim_index: usize,
    use_min: bool,
    return_index: bool,
}

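// Shared kernel for min/max/argmin/argmax over a single dimension. `f`
// decides whether a candidate replaces the running value (the comparison is
// strict, so the first occurrence wins on ties) and `g` picks what to emit:
// the value itself for min/max, or its index within the reduced dimension for
// argmin/argmax.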
impl ReduceIndex {
    #[inline(always)]
    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
    where
        T: Clone + Copy,
        U: Clone + Copy,
        F: Fn(T, T) -> bool,
        G: Fn(T, usize) -> U,
    {
        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();
        let dst_to_set =
            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                if reduce_dim_stride == 1 {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let start_src_i = start_src_i * reduce_dim_size;
                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
                        let mut acc = 0;
                        let mut val = src[0];
                        for (src_i, &s) in src.iter().enumerate() {
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                } else {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let (p, q) = (
                            start_src_i / reduce_dim_stride,
                            start_src_i % reduce_dim_stride,
                        );
                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
                        let src = &src[start_src_i..];
                        let mut acc = 0;
                        let mut val = src[0];
                        for src_i in 0..reduce_dim_size {
                            let s = src[src_i * reduce_dim_stride];
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                }
            }
            None => {
                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
                for (unstr_index, src_index) in l.strided_index().enumerate() {
                    let src = &src[src_index..];
                    let mut acc = 0;
                    let mut val = src[0];
                    for src_i in 0..reduce_dim_size {
                        let s = src[src_i * reduce_dim_stride];
                        if f(val, s) {
                            acc = src_i;
                            val = s
                        }
                    }
                    dst_to_set[unstr_index] = g(val, acc)
                }
            }
        }
        unsafe { dst.set_len(dst_len) };
        Ok(dst)
    }
}

impl Map1Any for ReduceIndex {
    #[inline(always)]
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        src: &[T],
        src_l: &Layout,
        wrap: W,
    ) -> Result<CpuStorage> {
        if src_l.shape().elem_count() == 0 {
            Err(Error::EmptyTensor { op: "reduce" }.bt())?
        }
        let dst = match (self.return_index, self.use_min) {
            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
            (true, true) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
            }
            (true, false) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
            }
        };
        Ok(dst)
    }
}

struct ReduceSum<'a> {
    dst_shape: &'a Shape,
    reduce_dims: &'a [usize],
    reduce_dims_and_stride: Vec<(usize, usize)>,
}

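// Sum over the dims in `reduce_dims`. For each source element the destination
// index is obtained by collapsing every reduced dimension: with `(dim,
// stride)` the unstrided index `i` maps to `(i / stride / dim) * stride +
// i % stride`, so all elements along a reduced dimension accumulate into the
// same slot. A fast path uses `vec_reduce_sum` when the reduction covers the
// trailing dimensions of a contiguous source.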
impl ReduceSum<'_> {
    #[inline(always)]
    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
    where
        T: WithDType,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if reduce_over_last_dims {
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                        let src_i = dst_i * reduce_sz;
                        unsafe {
                            T::vec_reduce_sum(
                                src[src_i..src_i + reduce_sz].as_ptr(),
                                dst_v,
                                reduce_sz,
                            )
                        };
                    }
                    return Ok(dst);
                };
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                }
            }
            None => {
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                }
            }
        }
        Ok(dst)
    }
}

impl Map1 for ReduceSum<'_> {
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        self.fold_impl(src, src_l, T::zero())
    }
}

struct Affine(f64, f64);

impl Map1 for Affine {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let mul = T::from_f64(self.0);
        let add = T::from_f64(self.1);
        Ok(unary_map(vs, layout, |v| v * mul + add))
    }
}

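// 2d average pooling with kernel `self.0 = (k_h, k_w)` and stride
// `self.1 = (s_h, s_w)`. No padding is applied, so the output is
// `(h - k_h) / s_h + 1` by `(w - k_w) / s_w + 1` and trailing rows/columns
// that do not fit a full window are dropped.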
struct AvgPool2D((usize, usize), (usize, usize));

impl Map1 for AvgPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        let scale = 1f64 / (k_h * k_w) as f64;
        let scale = T::from_f64(scale);
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut sum = T::zero();
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                sum += src[src_index + m * stride_h + n * stride_w]
                            }
                        }
                        dst[h_idx * w_out + w_idx] = sum * scale;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct MaxPool2D((usize, usize), (usize, usize));

impl Map1 for MaxPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut largest =
                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                if largest < src[src_index + m * stride_h + n * stride_w] {
                                    largest = src[src_index + m * stride_h + n * stride_w]
                                }
                            }
                        }
                        dst[h_idx * w_out + w_idx] = largest;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest1D(usize);

impl Map1 for UpsampleNearest1D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let dst_sz = self.0;
        let (b_sz, c, src_sz) = layout.shape().dims3()?;
        let stride = layout.stride();
        let stride_sz = stride[2];
        let src_index = layout.start_offset();
        let scale_sz = src_sz as f64 / dst_sz as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
        let src_idxs = (0..dst_sz)
            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_sz..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_sz..];
                let src_index = src_index + c_idx * stride[1];
                for (idx, src_idx) in src_idxs.iter().enumerate() {
                    dst[idx] = src[src_index + src_idx * stride_sz]
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest2D(usize, usize);

impl Map1 for UpsampleNearest2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (dst_h, dst_w) = (self.0, self.1);
        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let src_index = layout.start_offset();
        let scale_h = src_h as f64 / dst_h as f64;
        let scale_w = src_w as f64 / dst_w as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
        let src_h_idxs = (0..dst_h)
            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
            .collect::<Vec<_>>();
        let src_w_idxs = (0..dst_w)
            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_h * dst_w..];
                let src_index = src_index + c_idx * stride[1];
                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
                        dst[h_idx * dst_w + w_idx] = src[src_index]
                    }
                }
            }
        }
        Ok(dst)
    }
}

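// Bilinear 2d upsampling to `(target_h, target_w)`. With `align_corners`,
// input and output corner pixels line up exactly; otherwise sampling uses
// half-pixel centers (`scale * (i + 0.5) - 0.5`), the usual PyTorch
// convention. Explicit scale factors, when set, override the size-derived
// ratio. Note the size-preserving early return below copies the raw buffer
// and therefore assumes a contiguous input starting at offset 0.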
struct UpsampleBilinear2D {
    target_h: usize,
    target_w: usize,
    align_corners: bool,
    scale_h_factor: Option<f64>,
    scale_w_factor: Option<f64>,
}

impl Map1 for UpsampleBilinear2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (batch, channels, height_in, width_in) = layout.shape().dims4()?;
        let height_out = self.target_h;
        let width_out = self.target_w;

        if height_in == height_out && width_in == width_out {
            return Ok(src.to_vec());
        }

        let stride = layout.stride();
        let src_offset = layout.start_offset();

        let scale_h = if self.align_corners {
            if height_out > 1 {
                (height_in - 1) as f64 / (height_out - 1) as f64
            } else {
                0.0
            }
        } else if let Some(scale_factor) = self.scale_h_factor {
            1.0 / scale_factor
        } else {
            height_in as f64 / height_out as f64
        };

        let scale_w = if self.align_corners {
            if width_out > 1 {
                (width_in - 1) as f64 / (width_out - 1) as f64
            } else {
                0.0
            }
        } else if let Some(scale_factor) = self.scale_w_factor {
            1.0 / scale_factor
        } else {
            width_in as f64 / width_out as f64
        };

        let mut h_indices = Vec::with_capacity(height_out);
        for h_out in 0..height_out {
            let src_h = if self.align_corners {
                scale_h * h_out as f64
            } else {
                scale_h * (h_out as f64 + 0.5) - 0.5
            };
            let src_h_clamped = src_h.max(0.0);
            let h0 = src_h_clamped.floor() as usize;
            let h1 = (h0 + 1).min(height_in - 1);
            let weight_h = (src_h_clamped - h0 as f64).clamp(0.0, 1.0);
            h_indices.push((h0, h1, weight_h));
        }

        let mut w_indices = Vec::with_capacity(width_out);
        for w_out in 0..width_out {
            let src_w = if self.align_corners {
                scale_w * w_out as f64
            } else {
                scale_w * (w_out as f64 + 0.5) - 0.5
            };
            let src_w_clamped = src_w.max(0.0);
            let w0 = src_w_clamped.floor() as usize;
            let w1 = (w0 + 1).min(width_in - 1);
            let weight_w = (src_w_clamped - w0 as f64).clamp(0.0, 1.0);
            w_indices.push((w0, w1, weight_w));
        }

        let mut dst = vec![T::zero(); batch * channels * height_out * width_out];

        for b in 0..batch {
            for c in 0..channels {
                let base_idx = src_offset + b * stride[0] + c * stride[1];
                let dst_base = (b * channels + c) * height_out * width_out;

                for (h_out, &(h0, h1, weight_h)) in h_indices.iter().enumerate() {
                    for (w_out, &(w0, w1, weight_w)) in w_indices.iter().enumerate() {
                        let idx_00 = base_idx + h0 * stride[2] + w0 * stride[3];
                        let idx_10 = base_idx + h0 * stride[2] + w1 * stride[3];
                        let idx_01 = base_idx + h1 * stride[2] + w0 * stride[3];
                        let idx_11 = base_idx + h1 * stride[2] + w1 * stride[3];

                        let v00 = src[idx_00].to_f64();
                        let v10 = src[idx_10].to_f64();
                        let v01 = src[idx_01].to_f64();
                        let v11 = src[idx_11].to_f64();

                        let v_top = v00 * (1.0 - weight_w) + v10 * weight_w;
                        let v_bottom = v01 * (1.0 - weight_w) + v11 * weight_w;
                        let value = v_top * (1.0 - weight_h) + v_bottom * weight_h;

                        dst[dst_base + h_out * width_out + w_out] = T::from_f64(value);
                    }
                }
            }
        }

        Ok(dst)
    }
}

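// `gather` along `dim`: the output has the shape of `ids` and
// `dst[.., i, ..] = src[.., ids[.., i, ..], ..]`. An id equal to
// `I::max_value()` acts as a sentinel for "no index" and yields zero; any
// other out-of-bounds id is reported as an error.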
struct Gather<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}

impl<I: IntDType> Map1 for Gather<'_, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let src_dims = src_l.dims();
        let dst_len: usize = ids_dims.iter().product();
        let dst_left_len: usize = ids_dims[..dim].iter().product();
        let dst_dim_len = ids_dims[dim];
        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();

        let src_dim_len = src_dims[dim];
        let src_right_len: usize = src_dims[dim + 1..].iter().product();

        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..dst_left_len {
            let start_src_idx = left_i * src_right_len * src_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..dst_dim_len {
                let start_dst_idx = start_dst_idx + i * dst_right_len;
                for right_i in 0..dst_right_len {
                    let dst_idx = start_dst_idx + right_i;
                    let index = ids[dst_idx];
                    if index == I::max_value() {
                        dst[dst_idx] = T::zero();
                    } else {
                        let index = index.as_usize();
                        if index >= src_dim_len {
                            Err(Error::InvalidIndex {
                                index,
                                size: src_dim_len,
                                op: "gather",
                            }
                            .bt())?
                        }
                        let src_idx = start_src_idx + index * src_right_len + right_i;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct IndexSelect<'a, T: IntDType> {
    ids: &'a [T],
    ids_l: &'a Layout,
    dim: usize,
}

impl<I: IntDType> Map1 for IndexSelect<'_, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
            [n_ids] => *n_ids,
            d => Err(Error::UnexpectedNumberOfDims {
                expected: 1,
                got: d.len(),
                shape: self.ids_l.shape().clone(),
            }
            .bt())?,
        };
        let stride_ids = self.ids_l.stride()[0];
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                let start_dst_idx = start_dst_idx + i * right_len;
                let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
                if index == I::max_value() {
                    dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
                } else {
                    let index = index.as_usize();
                    if index >= src_dim {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim,
                            op: "index-select",
                        }
                        .bt())?
                    }
                    let start_src_idx = start_src_idx + index * right_len;
                    dst[start_dst_idx..start_dst_idx + right_len]
                        .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
                }
            }
        }
        Ok(dst)
    }
}

trait ElemUpdate {
    fn f<T: WithDType>(dst: &mut T, src: T);
}

struct Set;
struct Add;

impl ElemUpdate for Set {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst += src
    }
}

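// `scatter` along `dim`, generic over the element update `M`: `Set`
// overwrites (scatter) while `Add` accumulates (scatter-add), i.e.
// `dst[.., ids[.., i, ..], ..] op= src[.., i, ..]`. As with gather,
// `I::max_value()` ids are skipped.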
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
    _phantom: std::marker::PhantomData<M>,
}

impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
        Self {
            ids,
            ids_l,
            dim,
            _phantom: Default::default(),
        }
    }
}

impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
    const OP: &'static str = "scatter";
    fn f<T: WithDType>(
        &self,
        dst: &mut [T],
        dst_l: &Layout,
        src: &[T],
        src_l: &Layout,
    ) -> Result<()> {
        let dst = match dst_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
            Some((o1, o2)) => &mut dst[o1..o2],
        };

        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };

        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let dst_dims = dst_l.dims();
        let dst_dim_len = dst_dims[dim];
        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();

        let ids_left_len: usize = ids_dims[..dim].iter().product();
        let ids_dim_len = ids_dims[dim];
        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..ids_dim_len {
                let start_ids_idx = start_ids_idx + i * ids_right_len;
                for right_i in 0..dst_right_len {
                    let ids_idx = start_ids_idx + right_i;
                    let index = ids[ids_idx];
                    if index == I::max_value() {
                        continue;
                    }
                    let index = index.as_usize();
                    if index >= dst_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: dst_dim_len,
                            op: "scatter",
                        }
                        .bt())?
                    }
                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
                    M::f(&mut dst[dst_idx], src[ids_idx])
                }
            }
        }

        Ok(())
    }
}

struct IndexAdd<'a, I: IntDType> {
    ids: &'a [I],
    dim: usize,
}

impl<I: IntDType> Map2 for IndexAdd<'_, I> {
    const OP: &'static str = "index-add";
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let dst_len = l1.shape().elem_count();
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
        let max_idx = l1.dims()[dim];
        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
        let src_dim_sz = src_l.dims()[dim];
        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
        if dim == 0 {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                if *dst_idx == I::max_value() {
                    continue;
                }
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                let src_idx = src_idx * post_dim;
                let dst_idx = dst_idx * post_dim;
                let src = &src[src_idx..src_idx + post_dim];
                let dst = &mut dst[dst_idx..dst_idx + post_dim];
                for (d, &s) in dst.iter_mut().zip(src.iter()) {
                    *d += s
                }
            }
        } else {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                if *dst_idx == I::max_value() {
                    continue;
                }
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                for pre_i in 0..pre_dim {
                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
                    let src = &src[pre_src_i..pre_src_i + post_dim];
                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
                        *d += s
                    }
                }
            }
        }
        Ok(dst)
    }
}

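// Copies a `d1` x `d2` block row by row: row `i1` starts at
// `src_offset + i1 * src_stride1` in `src` and lands at
// `dst_offset + i1 * dst_stride1` in `dst`, with each row of `d2` elements
// assumed contiguous on both sides.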
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for i1 in 0..d1 {
        let dst_idx = i1 * dst_stride1 + dst_offset;
        let src_idx = i1 * src_stride1 + src_offset;
        let dst = &mut dst[dst_idx..dst_idx + d2];
        let src = &src[src_idx..src_idx + d2];
        dst.copy_from_slice(src)
    }
}

fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let to_copy = (dst.len() - dst_offset).min(len);
            dst[dst_offset..dst_offset + to_copy]
                .copy_from_slice(&src[start_offset..start_offset + to_copy])
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len: 1,
        } => {
            for (dst_index, src_index) in block_start_index.enumerate() {
                let dst_index = dst_index + dst_offset;
                if dst_index >= dst.len() {
                    break;
                }
                dst[dst_index] = src[src_index]
            }
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut dst_index = dst_offset;
            for src_index in block_start_index {
                let next_dst_index = dst_index + block_len;
                if dst_index >= dst.len() {
                    break;
                }
                let to_copy = usize::min(block_len, dst.len() - dst_index);
                dst[dst_index..dst_index + to_copy]
                    .copy_from_slice(&src[src_index..src_index + to_copy]);
                dst_index = next_dst_index
            }
        }
    }
}

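// Direct 1d convolution, used when the im2col path is disabled. The input is
// first repacked into a contiguous (batch, length, channel) buffer so the
// inner reduction over input channels can use `T::vec_dot`; the loop over
// output channels is parallelized with rayon. Writing through `dst.as_ptr()`
// cast to `*mut T` is sound here because each rayon task owns a distinct
// `dst_c_idx` and therefore a disjoint set of destination indices.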
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl Map2 for Conv1D<'_> {
    const OP: &'static str = "conv1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

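// Output length of the 1d convolution lowered through im2col:
//   l_out = (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1
// The column buffer is shaped (b, l_out, c * l_k) so the convolution itself
// reduces to a matmul against the flattened kernel.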
impl Im2Col1D {
    fn l_out(&self, l: usize) -> usize {
        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
    }
}

impl Map1 for Im2Col1D {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            l_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, l) = layout.shape().dims3()?;
        let l_out = self.l_out(l);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * l_out * c * l_k];
        let (src_s0, src_s1, src_s2) = {
            let s = layout.stride();
            (s[0], s[1], s[2])
        };
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * l_out * c * l_k;
            for l_idx in 0..l_out {
                let dst_idx = dst_idx + l_idx * c * l_k;
                for c_idx in 0..c {
                    let dst_idx = dst_idx + c_idx * l_k;
                    let src_idx = c_idx * src_s1 + src_idx;
                    for l_k_idx in 0..l_k {
                        let src_l = l_idx * stride + l_k_idx * dilation;
                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
                            continue;
                        }
                        let src_l = src_l - padding;
                        let src_idx = src_idx + src_l * src_s2;
                        let dst_idx = dst_idx + l_k_idx;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
        (h_out, w_out)
    }
}

impl Map1 for Im2Col {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        Ok(dst)
    }
}

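// Adjoint of im2col used by the transposed 1d convolution path: every column
// entry is accumulated back into the image at `l_out_i = l_in_i * stride + k_i`,
// giving `l_out = (l_in - 1) * stride + k_size` (this kernel only models the
// stride).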
struct Col2Im1D {
    stride: usize,
}

impl Map1 for Col2Im1D {
    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
        let stride = self.stride;
        let l_out = (l_in - 1) * stride + k_size;
        let mut im = vec![T::zero(); b_size * c_out * l_out];
        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
        for l_in_i in 0..l_in {
            for k_i in 0..k_size {
                let l_out_i = l_in_i * stride + k_i;
                for b_i in 0..b_size {
                    for c_i in 0..c_out {
                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
                        im[dst_idx] += col[src_idx]
                    }
                }
            }
        }
        Ok(im)
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl Map2 for ConvTranspose1D<'_> {
    const OP: &'static str = "conv_transpose1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();

        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];
        let dst_s0 = p.c_out * l_out;
        let dst_s1 = l_out;
        let dst_s2 = 1;

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        let cont_s0 = p.l_in * p.c_in;
        let cont_s1 = p.c_in;
        for b_idx in 0..p.b_size {
            for l_idx in 0..p.l_in {
                for c_idx in 0..p.c_in {
                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
                    inp_cont[dst_idx] = inp[src_idx]
                }
            }
        }

        for k_idx in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    for l_idx in 0..p.l_in {
                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
                        if out_idx < p.padding {
                            continue;
                        }
                        let out_idx = out_idx - p.padding;
                        if out_idx < l_out {
                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
                            let mut d = T::zero();
                            unsafe {
                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                            }
                            let dst_p = dst.as_ptr();
                            unsafe {
                                let ptr = dst_p.add(dst_idx) as *mut T;
                                *ptr += d
                            }
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);

impl Map2 for ConvTranspose2D<'_> {
    const OP: &'static str = "conv_transpose2d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
        let dst_s0 = p.c_out * out_h * out_w;
        let dst_s1 = out_h * out_w;
        let dst_s2 = out_w;
        let dst_s3 = 1;

        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        for k_y in 0..p.k_h {
            for k_x in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        for inp_y in 0..p.i_h {
                            for inp_x in 0..p.i_w {
                                let out_x = inp_x * p.stride + k_x * p.dilation;
                                let out_y = inp_y * p.stride + k_y * p.dilation;
                                if out_x < p.padding || out_y < p.padding {
                                    continue;
                                }
                                let out_x = out_x - p.padding;
                                let out_y = out_y - p.padding;
                                if out_x < out_w && out_y < out_h {
                                    let inp_cont = &inp_cont
                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
                                    let dst_idx = b_idx * dst_s0
                                        + out_y * dst_s2
                                        + out_x * dst_s3
                                        + dst_c_idx * dst_s1;
                                    let mut d = T::zero();
                                    unsafe {
                                        T::vec_dot(
                                            inp_cont.as_ptr(),
                                            k_cont.as_ptr(),
                                            &mut d,
                                            p.c_in,
                                        )
                                    }
                                    let dst_p = dst.as_ptr();
                                    unsafe {
                                        let ptr = dst_p.add(dst_idx) as *mut T;
                                        *ptr += d
                                    }
                                }
                            }
                        }
                    }
                })
            }
        }
        Ok(dst)
    }
}

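// Batched matmul with `self.0 = (b, m, n, k)`: lhs is (b, m, k), rhs is
// (b, k, n). `ab_skip` derives the per-matrix stride of the batch dimension,
// accepting a few broadcast-friendly layouts, and the kernel is the pure-Rust
// `gemm` crate by default, or Accelerate/MKL when the matching feature is
// enabled.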
struct MatMul((usize, usize, usize, usize));

impl MatMul {
    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
            lhs_l: lhs_l.clone(),
            rhs_l: rhs_l.clone(),
            bmnk: self.0,
            msg,
        }))
        .bt()
    }

    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        let (_b, m, n, k) = self.0;
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [_, stride] if lhs_l.dims()[0] == 1 => stride,
            [stride, _] if lhs_l.dims()[1] == 1 => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [_, stride] if rhs_l.dims()[0] == 1 => stride,
            [stride, _] if rhs_l.dims()[1] == 1 => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        Ok((a_skip, b_skip))
    }
}

impl Map2 for MatMul {
    const OP: &'static str = "mat_mul";

    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        use gemm::{gemm, Parallelism};

        match T::DTYPE {
            DType::F16 | DType::F32 | DType::F64 => {}
            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
        }

        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];

        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];

        let mut dst = vec![T::zero(); b * m * n];
        let num_threads = crate::utils::get_num_threads();
        let parallelism = if num_threads > 1 {
            Parallelism::Rayon(num_threads)
        } else {
            Parallelism::None
        };
        let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
            (1, b * m, n, k)
        } else if a_skip == 0 && b_skip == n * k {
            (1, m, b * n, k)
        } else {
            (b, m, n, k)
        };
        for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            unsafe {
                gemm(
                    m,
                    n,
                    k,
                    dst_p.as_mut_ptr(),
                    dst_cs as isize,
                    dst_rs as isize,
                    false,
                    lhs_p.as_ptr(),
                    lhs_cs as isize,
                    lhs_rs as isize,
                    rhs_p.as_ptr(),
                    rhs_cs as isize,
                    rhs_rs as isize,
                    T::zero(),
                    T::one(),
                    false,
                    false,
                    false,
                    parallelism,
                )
            }
        }
        Ok(dst)
    }

    #[cfg(feature = "accelerate")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();

        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                crate::bail!("the accelerate backend does not support f16 matmul")
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::sgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::dgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }

    #[cfg(feature = "mkl")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();

        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f16;
                        let b = lhs_p.as_ptr() as *const f16;
                        let c = dst_p.as_mut_ptr() as *mut f16;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::hgemm(
                            transa,
                            transb,
                            n as i32,
                            m as i32,
                            k as i32,
                            f16::ONE,
                            a,
                            lda,
                            b,
                            ldb,
                            f16::ZERO,
                            c,
                            n as i32,
                        )
                    }
                }
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::sgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::dgemm(
                            transa, transb, n as i32, m as i32, k as i32, 1., a, lda, b, ldb,
                            0., c, n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }
}

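// ELU: identity for non-negative `v`, `alpha * (exp(v) - 1)` otherwise
// (`is_sign_positive` routes `+0.0` through the identity branch; `-0.0`
// still evaluates to zero on the other branch).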
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
    if v.is_sign_positive() {
        v
    } else {
        (v.exp() - T::one()) * alpha
    }
}

impl CpuStorage {
    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
        D::cpu_storage_as_slice(self)
    }

    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
        let storage0 = &storages[0];
        let s = match storage0 {
            Self::U8(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::U8(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::U8(storages)
            }
            Self::U32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::U32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::U32(storages)
            }
            Self::I16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::I16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::I16(storages)
            }
            Self::I32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::I32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::I32(storages)
            }
            Self::I64(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::I64(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::I64(storages)
            }
            Self::BF16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::BF16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::BF16(storages)
            }
            Self::F16(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F16(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F16(storages)
            }
            Self::F32(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F32(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F32(storages)
            }
            Self::F64(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F64(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F64(storages)
            }
            Self::F8E4M3(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F8E4M3(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F8E4M3(storages)
            }
            Self::F6E2M3(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F6E2M3(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F6E2M3(storages)
            }
            Self::F6E3M2(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F6E3M2(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F6E3M2(storages)
            }
            Self::F4(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F4(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F4(storages)
            }
            Self::F8E8M0(_) => {
                let storages = storages
                    .iter()
                    .map(|s| match s {
                        Self::F8E8M0(s) => Ok(s.as_slice()),
                        _ => crate::bail!("dtype mismatch"),
                    })
                    .collect::<Result<Vec<_>>>()?
                    .concat();
                Self::F8E8M0(storages)
            }
        };
        Ok(s)
    }
}

impl BackendStorage for CpuStorage {
    type Device = CpuDevice;

    fn dtype(&self) -> DType {
        match self {
            Self::U8(_) => DType::U8,
            Self::U32(_) => DType::U32,
            Self::I16(_) => DType::I16,
            Self::I32(_) => DType::I32,
            Self::I64(_) => DType::I64,
            Self::BF16(_) => DType::BF16,
            Self::F16(_) => DType::F16,
            Self::F32(_) => DType::F32,
            Self::F64(_) => DType::F64,
            Self::F8E4M3(_) => DType::F8E4M3,
            Self::F6E2M3(_) => DType::F6E2M3,
            Self::F6E3M2(_) => DType::F6E3M2,
            Self::F4(_) => DType::F4,
            Self::F8E8M0(_) => DType::F8E8M0,
        }
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        match (self, dtype) {
            (Self::U8(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::U32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::I64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::BF16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::BF16(data))
            }
            (Self::F16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                Ok(Self::BF16(data))
            }
            (Self::F32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f32);
                Ok(Self::BF16(data))
            }
            (Self::F64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f64);
                Ok(Self::BF16(data))
            }
            (Self::U8(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::U32(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::I64(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::BF16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                Ok(Self::F16(data))
            }
            (Self::F16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F16(data))
            }
            (Self::F32(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f32);
                Ok(Self::F16(data))
            }
            (Self::F64(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f64);
                Ok(Self::F16(data))
            }
            (Self::U8(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::I64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::BF16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F32(data))
            }
            (Self::F64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U8(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U8(data))
            }
            (Self::BF16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::F64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::I64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U8(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U32(data))
            }
            (Self::I64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::BF16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::F64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U8(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::I64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::I64(data))
            }
            (Self::BF16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::F64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U8(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::U32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::I64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::BF16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::F64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F64(data))
            }
            (Self::U8(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
                Ok(Self::F8E4M3(data))
            }
            (Self::U32(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
                Ok(Self::F8E4M3(data))
            }
            (Self::I64(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
                Ok(Self::F8E4M3(data))
            }
            (Self::BF16(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
                Ok(Self::F8E4M3(data))
            }
            (Self::F16(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
                Ok(Self::F8E4M3(data))
            }
            (Self::F32(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, F8E4M3::from_f32);
                Ok(Self::F8E4M3(data))
            }
            (Self::F64(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, F8E4M3::from_f64);
                Ok(Self::F8E4M3(data))
            }
            (Self::F8E4M3(storage), DType::F8E4M3) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F8E4M3(data))
            }
            (Self::F8E4M3(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F8E4M3(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F8E4M3(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F8E4M3(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                Ok(Self::BF16(data))
            }
            (Self::F8E4M3(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                Ok(Self::F16(data))
            }
            (Self::F8E4M3(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F8E4M3(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::U8(storage), DType::I16) => {
                let data = unary_map(storage, layout, |v| v as i16);
                Ok(Self::I16(data))
            }
            (Self::U32(storage), DType::I16) => {
                let data = unary_map(storage, layout, |v| v as i16);
                Ok(Self::I16(data))
            }
            (Self::I16(storage), DType::I16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::I16(data))
2120 }
2121 (Self::I32(storage), DType::I16) => {
2122 let data = unary_map(storage, layout, |v| v as i16);
2123 Ok(Self::I16(data))
2124 }
2125 (Self::I64(storage), DType::I16) => {
2126 let data = unary_map(storage, layout, |v| v as i16);
2127 Ok(Self::I16(data))
2128 }
2129 (Self::BF16(storage), DType::I16) => {
2130 let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2131 Ok(Self::I16(data))
2132 }
2133 (Self::F16(storage), DType::I16) => {
2134 let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2135 Ok(Self::I16(data))
2136 }
2137 (Self::F32(storage), DType::I16) => {
2138 let data = unary_map(storage, layout, |v| v as i16);
2139 Ok(Self::I16(data))
2140 }
2141 (Self::F64(storage), DType::I16) => {
2142 let data = unary_map(storage, layout, |v| v as i16);
2143 Ok(Self::I16(data))
2144 }
2145 (Self::F8E4M3(storage), DType::I16) => {
2146 let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2147 Ok(Self::I16(data))
2148 }
2149 (Self::U8(storage), DType::I32) => {
2151 let data = unary_map(storage, layout, |v| v as i32);
2152 Ok(Self::I32(data))
2153 }
2154 (Self::U32(storage), DType::I32) => {
2155 let data = unary_map(storage, layout, |v| v as i32);
2156 Ok(Self::I32(data))
2157 }
2158 (Self::I16(storage), DType::I32) => {
2159 let data = unary_map(storage, layout, |v| v as i32);
2160 Ok(Self::I32(data))
2161 }
2162 (Self::I32(storage), DType::I32) => {
2163 let data = unary_map(storage, layout, |v| v);
2164 Ok(Self::I32(data))
2165 }
2166 (Self::I64(storage), DType::I32) => {
2167 let data = unary_map(storage, layout, |v| v as i32);
2168 Ok(Self::I32(data))
2169 }
2170 (Self::BF16(storage), DType::I32) => {
2171 let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2172 Ok(Self::I32(data))
2173 }
2174 (Self::F16(storage), DType::I32) => {
2175 let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2176 Ok(Self::I32(data))
2177 }
2178 (Self::F32(storage), DType::I32) => {
2179 let data = unary_map(storage, layout, |v| v as i32);
2180 Ok(Self::I32(data))
2181 }
2182 (Self::F64(storage), DType::I32) => {
2183 let data = unary_map(storage, layout, |v| v as i32);
2184 Ok(Self::I32(data))
2185 }
2186 (Self::F8E4M3(storage), DType::I32) => {
2187 let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2188 Ok(Self::I32(data))
2189 }
2190 (Self::I16(storage), DType::U8) => {
2192 let data = unary_map(storage, layout, |v| v as u8);
2193 Ok(Self::U8(data))
2194 }
2195 (Self::I16(storage), DType::U32) => {
2196 let data = unary_map(storage, layout, |v| v as u32);
2197 Ok(Self::U32(data))
2198 }
2199 (Self::I16(storage), DType::I64) => {
2200 let data = unary_map(storage, layout, |v| v as i64);
2201 Ok(Self::I64(data))
2202 }
2203 (Self::I16(storage), DType::BF16) => {
2204 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2205 Ok(Self::BF16(data))
2206 }
2207 (Self::I16(storage), DType::F16) => {
2208 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2209 Ok(Self::F16(data))
2210 }
2211 (Self::I16(storage), DType::F32) => {
2212 let data = unary_map(storage, layout, |v| v as f32);
2213 Ok(Self::F32(data))
2214 }
2215 (Self::I16(storage), DType::F64) => {
2216 let data = unary_map(storage, layout, |v| v as f64);
2217 Ok(Self::F64(data))
2218 }
2219 (Self::I16(storage), DType::F8E4M3) => {
2220 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2221 Ok(Self::F8E4M3(data))
2222 }
2223 (Self::I32(storage), DType::U8) => {
2225 let data = unary_map(storage, layout, |v| v as u8);
2226 Ok(Self::U8(data))
2227 }
2228 (Self::I32(storage), DType::U32) => {
2229 let data = unary_map(storage, layout, |v| v as u32);
2230 Ok(Self::U32(data))
2231 }
2232 (Self::I32(storage), DType::I64) => {
2233 let data = unary_map(storage, layout, |v| v as i64);
2234 Ok(Self::I64(data))
2235 }
2236 (Self::I32(storage), DType::BF16) => {
2237 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2238 Ok(Self::BF16(data))
2239 }
2240 (Self::I32(storage), DType::F16) => {
2241 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2242 Ok(Self::F16(data))
2243 }
2244 (Self::I32(storage), DType::F32) => {
2245 let data = unary_map(storage, layout, |v| v as f32);
2246 Ok(Self::F32(data))
2247 }
2248 (Self::I32(storage), DType::F64) => {
2249 let data = unary_map(storage, layout, |v| v as f64);
2250 Ok(Self::F64(data))
2251 }
2252 (Self::I32(storage), DType::F8E4M3) => {
2253 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2254 Ok(Self::F8E4M3(data))
2255 }
2256 (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => {
2258 Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt())
2259 }
2260 (Self::F6E2M3(_), _)
2261 | (Self::F6E3M2(_), _)
2262 | (Self::F4(_), _)
2263 | (Self::F8E8M0(_), _) => {
2264 Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt())
2265 }
2266 }
2267 }
2268
2269 fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
2270 match op {
2271 ReduceOp::Sum => {
2272 let src_dims = layout.dims();
2273 let mut dst_dims = src_dims.to_vec();
2274 for &dim in reduce_dims.iter() {
2275 dst_dims[dim] = 1;
2276 }
2277 let dst_shape = Shape::from(dst_dims);
2278 let mut reduce_dims = reduce_dims.to_vec();
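                // The reduce dims have to be processed from left to right when converting
                // linear indexes, so sort them before computing the per-dim strides.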
                reduce_dims.sort();
                let reduce_dims_and_stride: Vec<_> = reduce_dims
                    .iter()
                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
                    .collect();
                ReduceSum {
                    dst_shape: &dst_shape,
                    reduce_dims: &reduce_dims,
                    reduce_dims_and_stride,
                }
                .map(self, layout)
            }
            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
                let reduce_dim_index = match reduce_dims {
                    [reduce_dim_index] => *reduce_dim_index,
                    _ => {
                        let op = match op {
                            ReduceOp::Min => "min",
                            ReduceOp::ArgMin => "argmin",
                            ReduceOp::Max => "max",
                            ReduceOp::ArgMax => "argmax",
                            _ => unreachable!(),
                        };
                        let dims = reduce_dims.to_vec();
                        Err(Error::OnlySingleDimension { op, dims })?
                    }
                };
                let (use_min, return_index) = match op {
                    ReduceOp::Min => (true, false),
                    ReduceOp::ArgMin => (true, true),
                    ReduceOp::Max => (false, false),
                    ReduceOp::ArgMax => (false, true),
                    _ => unreachable!(),
                };
                ReduceIndex {
                    reduce_dim_index,
                    use_min,
                    return_index,
                }
                .map(self, layout)
            }
        }
    }

    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
        Cmp(op).map(self, lhs_l, rhs, rhs_l)
    }

    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        Affine(mul, add).map(self, layout)
    }

    fn avg_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        AvgPool2D(kernel_size, stride).map(self, layout)
    }

    fn max_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        MaxPool2D(kernel_size, stride).map(self, layout)
    }

    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
        UpsampleNearest1D(sz).map(self, layout)
    }

    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
        UpsampleNearest2D(h, w).map(self, layout)
    }

    fn upsample_bilinear2d(
        &self,
        layout: &Layout,
        h: usize,
        w: usize,
        align_corners: bool,
        scale_h: Option<f64>,
        scale_w: Option<f64>,
    ) -> Result<Self> {
        UpsampleBilinear2D {
            target_h: h,
            target_w: w,
            align_corners,
            scale_h_factor: scale_h,
            scale_w_factor: scale_w,
        }
        .map(self, layout)
    }

    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
        use num_traits::Float;
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(e as f32));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(e));
                Ok(Self::F64(data))
            }
            Self::F8E4M3(storage) => {
                let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
                Ok(Self::F8E4M3(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()),
            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()),
            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()),
            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()),
            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()),
        }
    }

    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                // `f32` has no `from_f64` constructor; a plain `as` cast is the idiomatic
                // narrowing conversion here, matching the `e as f32` used in `powf`.
                let data = unary_map(storage, layout, |v| elu(v, alpha as f32));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, alpha));
                Ok(Self::F64(data))
            }
            Self::F8E4M3(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
                Ok(Self::F8E4M3(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()),
            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()),
            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()),
            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()),
        }
    }

    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
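        // Use the vectorized kernel when the op provides one for this dtype; otherwise
        // fall back to the scalar element-wise map.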
        match self {
            Self::BF16(storage) => {
                if B::BF16_VEC {
                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
                    Ok(Self::BF16(data))
                } else {
                    let data = unary_map(storage, layout, B::bf16);
                    Ok(Self::BF16(data))
                }
            }
            Self::F16(storage) => {
                if B::F16_VEC {
                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
                    Ok(Self::F16(data))
                } else {
                    let data = unary_map(storage, layout, B::f16);
                    Ok(Self::F16(data))
                }
            }
            Self::F32(storage) => {
                if B::F32_VEC {
                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
                    Ok(Self::F32(data))
                } else {
                    let data = unary_map(storage, layout, B::f32);
                    Ok(Self::F32(data))
                }
            }
            Self::F64(storage) => {
                if B::F64_VEC {
                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
                    Ok(Self::F64(data))
                } else {
                    let data = unary_map(storage, layout, B::f64);
                    Ok(Self::F64(data))
                }
            }
            Self::U8(storage) => {
                let data = unary_map(storage, layout, B::u8);
                Ok(Self::U8(data))
            }
            Self::U32(storage) => {
                let data = unary_map(storage, layout, B::u32);
                Ok(Self::U32(data))
            }
            Self::I16(storage) => {
                let data = unary_map(storage, layout, B::i16);
                Ok(Self::I16(data))
            }
            Self::I32(storage) => {
                let data = unary_map(storage, layout, B::i32);
                Ok(Self::I32(data))
            }
            Self::I64(storage) => {
                let data = unary_map(storage, layout, B::i64);
                Ok(Self::I64(data))
            }
            Self::F8E4M3(storage) => {
                let data = unary_map(storage, layout, B::f8e4m3);
                Ok(Self::F8E4M3(data))
            }
            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()),
            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "unary").bt()),
            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()),
            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()),
        }
    }

    fn binary_impl<B: BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        match (self, rhs) {
            (Self::BF16(lhs), Self::BF16(rhs)) => {
                let data = if B::BF16_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
                };
                Ok(Self::BF16(data))
            }
            (Self::F16(lhs), Self::F16(rhs)) => {
                let data = if B::F16_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
                };
                Ok(Self::F16(data))
            }
            (Self::F32(lhs), Self::F32(rhs)) => {
                let data = if B::F32_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
                };
                Ok(Self::F32(data))
            }
            (Self::F64(lhs), Self::F64(rhs)) => {
                let data = if B::F64_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
                };
                Ok(Self::F64(data))
            }
            (Self::U32(lhs), Self::U32(rhs)) => {
                let data = if B::U32_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
                };
                Ok(Self::U32(data))
            }
            (Self::I16(lhs), Self::I16(rhs)) => {
                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i16);
                Ok(Self::I16(data))
            }
            (Self::I32(lhs), Self::I32(rhs)) => {
                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i32);
                Ok(Self::I32(data))
            }
            (Self::I64(lhs), Self::I64(rhs)) => {
                let data = if B::I64_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
                };
                Ok(Self::I64(data))
            }
            (Self::U8(lhs), Self::U8(rhs)) => {
                let data = if B::U8_VEC {
                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
                } else {
                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
                };
                Ok(Self::U8(data))
            }
            (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => {
                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3);
                Ok(Self::F8E4M3(data))
            }
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: rhs.dtype(),
                op: B::NAME,
            }
            .bt()),
        }
    }

    fn copy2d(
        &self,
        dst: &mut Self,
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::U32(src), Self::U32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I16(src), Self::I16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I32(src), Self::I32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I64(src), Self::I64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::BF16(src), Self::BF16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F16(src), Self::F16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F32(src), Self::F32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F64(src), Self::F64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (_, dst) => {
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy2d",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (_, dst) => {
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy_strided",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn where_cond(
        &self,
        layout: &Layout,
        t: &Self,
        t_l: &Layout,
        f: &Self,
        f_l: &Layout,
    ) -> Result<Self> {
        match self {
            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond").bt()),
        }
    }

    fn conv1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        if !USE_IM2COL_CONV1D {
            return Conv1D(params).map(self, l, kernel, kernel_l);
        }
        let op = Im2Col1D {
            l_k: params.k_size,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        };
        let col = op.map(self, l)?;
        let b = params.b_size;
        let n = params.c_out;
        let l_out = params.l_out();
        let k = op.l_k * params.c_in;
        let m = l_out;
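        // im2col has turned the input into a (b, m, k) matrix; the kernel is viewed as a
        // (b, k, n) matrix broadcast over the batch, so the convolution reduces to a
        // batched GEMM producing (b, m, n) = (b, l_out, c_out).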
        let col_l = Layout::contiguous((b, m, k));
        let res = if kernel_l.is_contiguous() {
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make the kernel contiguous first. The copy starts at offset 0, so the
            // matmul has to read from the copy with an offset-0 layout rather than from
            // the original, non-contiguous kernel storage.
            let mut kernel_c = unsafe {
                self.device()
                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
            };
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            let kernel_l = Layout::contiguous((1, n, k))
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }

    fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        let can_use_col2im = kernel_l.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        if USE_COL2IM_CONV1D_TR && can_use_col2im {
            let (b_size, c_in, l_in) = l.shape().dims3()?;
            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
            if !kernel_l.is_contiguous() {
                crate::bail!(
                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
                )
            }
            if c_in != c_in2 {
                crate::bail!(
                    "convtr1d: shape mismatch on c_in {:?} {:?}",
                    l.shape(),
                    kernel_l.shape()
                )
            }
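            // The kernel has shape (c_in, c_out, k_size); a batch stride of 0 broadcasts
            // the same (c_in, c_out * k_size) matrix across the whole batch, so a single
            // batched matmul produces the per-position columns that col2im then scatters
            // back into the output signal.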
            let col = {
                let kernel_l_mm = Layout::new(
                    (b_size, c_in, k_size * c_out).into(),
                    vec![0, k_size * c_out, 1],
                    kernel_l.start_offset(),
                );
                self.matmul(
                    kernel,
                    (b_size, l_in, c_out * k_size, c_in),
                    &l.transpose(1, 2)?,
                    &kernel_l_mm,
                )?
            };
            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
            Col2Im1D {
                stride: params.stride,
            }
            .map(&col, &col_l)
        } else {
            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
        }
    }

    fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Conv2D(params).map(self, l, kernel, kernel_l)
    }

    fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
    }

    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
        }
    }

    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
        }
    }

    fn scatter_set(
        &mut self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<()> {
        match ids {
            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
        }
    }

    fn scatter_add_set(
        &mut self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<()> {
        match ids {
            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
        }
    }

    fn index_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I16(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
        }
    }

    fn matmul(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }

    fn device(&self) -> &Self::Device {
        &CpuDevice
    }

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }

    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
        use crate::scalar::Scalar;
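        // Fill the storage according to the layout: a single contiguous block gets one
        // `fill`, per-element blocks are written one index at a time, and longer strided
        // blocks are filled chunk by chunk.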
        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
            match l.strided_blocks() {
                crate::StridedBlocks::SingleBlock { start_offset, len } => {
                    src[start_offset..start_offset + len].fill(s)
                }
                crate::StridedBlocks::MultipleBlocks {
                    block_start_index,
                    block_len: 1,
                } => {
                    for src_index in block_start_index {
                        src[src_index] = s
                    }
                }
                crate::StridedBlocks::MultipleBlocks {
                    block_start_index,
                    block_len,
                } => {
                    for src_index in block_start_index {
                        src[src_index..src_index + block_len].fill(s)
                    }
                }
            }
        }
        match (self, s) {
            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
            (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v),
            (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v),
            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
            (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
            (Self::F6E2M3(_), _) => {
                crate::bail!("const_set not supported for dummy type F6E2M3")
            }
            (Self::F6E3M2(_), _) => {
                crate::bail!("const_set not supported for dummy type F6E3M2")
            }
            (Self::F4(_), _) => {
                crate::bail!("const_set not supported for dummy type F4")
            }
            (Self::F8E8M0(_), _) => {
                crate::bail!("const_set not supported for dummy type F8E8M0")
            }
            (st, s) => crate::bail!(
                "const_set dtype mismatch, expected {:?} but got {:?}",
                st.dtype(),
                s
            ),
        }
        Ok(())
    }
}
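
// The test below is an illustrative sketch added alongside this edit rather than part
// of the original module: it assumes `to_dtype` resolves through the `BackendStorage`
// impl above and that `Shape: From<usize>` holds, as the usize-based shapes used
// elsewhere in this crate suggest.
#[cfg(test)]
mod cpu_storage_cast_tests {
    use super::*;

    #[test]
    fn f32_to_u8_saturates() -> Result<()> {
        let storage = CpuStorage::F32(vec![0.0, 1.5, 300.0, -7.0]);
        let layout = Layout::contiguous(4);
        // `as u8` saturates in Rust, so 300.0 clamps to 255 and -7.0 to 0.
        match storage.to_dtype(&layout, DType::U8)? {
            CpuStorage::U8(data) => assert_eq!(data, vec![0u8, 1, 255, 0]),
            _ => panic!("expected U8 storage"),
        }
        Ok(())
    }
}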

impl BackendDevice for CpuDevice {
    type Storage = CpuStorage;

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cpu
    }

    fn same_device(&self, _: &Self) -> bool {
        true
    }

    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
        Ok(T::to_cpu_storage(s))
    }

    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
        Ok(s.clone())
    }

    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
        Ok(s)
    }

    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    fn get_current_seed(&self) -> Result<u64> {
        crate::bail!("cannot get the CPU rng seed with get_current_seed")
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::rng();
        match dtype {
            DType::U8
            | DType::U32
            | DType::I16
            | DType::I32
            | DType::I64
            | DType::F6E2M3
            | DType::F6E3M2
            | DType::F4
            | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(uniform))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(uniform))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F8E4M3 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
                        .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<F8E4M3, _>(uniform))
                }
                Ok(CpuStorage::F8E4M3(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::rng();
        match dtype {
            DType::U8
            | DType::U32
            | DType::I16
            | DType::I32
            | DType::I64
            | DType::F6E2M3
            | DType::F6E3M2
            | DType::F4
            | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F8E4M3 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F8E4M3(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal =
                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    #[allow(clippy::uninit_vec)]
    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
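        // The vectors below are deliberately left uninitialized: callers of
        // `alloc_uninit` are expected to overwrite every element before reading any of
        // them back.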
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::U8(v)
            }
            DType::U32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::U32(v)
            }
            DType::I16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I16(v)
            }
            DType::I32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I32(v)
            }
            DType::I64 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I64(v)
            }
            DType::BF16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::BF16(v)
            }
            DType::F16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F16(v)
            }
            DType::F32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F32(v)
            }
            DType::F64 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F64(v)
            }
            DType::F8E4M3 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F8E4M3(v)
            }
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt())
            }
        };
        Ok(storage)
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
            DType::I16 => CpuStorage::I16(vec![0i16; elem_count]),
            DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
            DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt())
            }
        };
        Ok(storage)
    }

    fn synchronize(&self) -> Result<()> {
        Ok(())
    }
}
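
// Another illustrative sketch (not part of the original file): it assumes `zeros_impl`
// resolves through the `BackendDevice` impl above and that `Shape: From<(usize, usize)>`
// holds, as the tuple-based shapes used elsewhere in this file suggest.
#[cfg(test)]
mod cpu_device_tests {
    use super::*;

    #[test]
    fn zeros_impl_allocates_zeroed_storage() -> Result<()> {
        let dev = CpuDevice;
        match dev.zeros_impl(&Shape::from((2, 3)), DType::F32)? {
            CpuStorage::F32(v) => assert_eq!(v, vec![0f32; 6]),
            _ => panic!("expected F32 storage"),
        }
        Ok(())
    }
}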

#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}
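
// A hypothetical usage sketch for `map_dtype!`, added for illustration and not part of
// the original module: the macro applies `$fn` to the inner `Vec` for each listed dtype
// and bubbles up an `UnsupportedDTypeForOp` error for every other variant, which is why
// the enclosing function must return a `Result`.
#[cfg(test)]
mod map_dtype_tests {
    use super::*;

    // Doubles every element of an F32 storage; any other dtype is rejected.
    fn double_f32(storage: CpuStorage) -> Result<CpuStorage> {
        let out = map_dtype!(
            "double-f32",
            storage,
            |v: Vec<f32>| v.into_iter().map(|x| x * 2.0).collect::<Vec<f32>>(),
            (F32)
        );
        Ok(out)
    }

    #[test]
    fn map_dtype_applies_fn_on_listed_dtypes() -> Result<()> {
        match double_f32(CpuStorage::F32(vec![1.0, 2.0]))? {
            CpuStorage::F32(v) => assert_eq!(v, vec![2.0, 4.0]),
            _ => panic!("expected F32 storage"),
        }
        // U8 is not in the whitelist, so the fallback arm returns an error.
        assert!(double_f32(CpuStorage::U8(vec![1])).is_err());
        Ok(())
    }
}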