1use crate::backend::{BackendDevice, BackendStorage};
3use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
4use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
5use float8::F8E4M3;
6use half::{bf16, f16};
7use rayon::prelude::*;
8
9mod utils;
10pub use utils::{
11 binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
12};
13mod conv2d;
14use conv2d::Conv2D;
15
// Compile-time toggle. By its name it presumably selects an im2col-based path for conv1d.
// NOTE(review): the code reading this flag is outside this chunk — confirm.
const USE_IM2COL_CONV1D: bool = true;
// Compile-time toggle. By its name it presumably selects a col2im-based path for the
// transposed conv1d. NOTE(review): the code reading this flag is outside this chunk — confirm.
const USE_COL2IM_CONV1D_TR: bool = true;
18
/// Owned CPU buffer of tensor elements, with one variant per supported element type.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    /// Unsigned 8-bit integers.
    U8(Vec<u8>),
    /// Unsigned 32-bit integers.
    U32(Vec<u32>),
    /// Signed 16-bit integers.
    I16(Vec<i16>),
    /// Signed 32-bit integers.
    I32(Vec<i32>),
    /// Signed 64-bit integers.
    I64(Vec<i64>),
    /// `half::bf16` values.
    BF16(Vec<bf16>),
    /// `half::f16` values.
    F16(Vec<f16>),
    /// 32-bit floats.
    F32(Vec<f32>),
    /// 64-bit floats.
    F64(Vec<f64>),
    /// `float8::F8E4M3` values.
    F8E4M3(Vec<F8E4M3>),
    // The variants below are carried as raw bytes.
    // NOTE(review): how elements are packed into the bytes is not visible here — confirm.
    F6E2M3(Vec<u8>),
    F6E3M2(Vec<u8>),
    F4(Vec<u8>),
    F8E8M0(Vec<u8>),
}
39
/// Borrowed counterpart of [`CpuStorage`]: the same set of variants, each holding a slice
/// instead of an owned `Vec`.
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    /// Unsigned 8-bit integers.
    U8(&'a [u8]),
    /// Unsigned 32-bit integers.
    U32(&'a [u32]),
    /// Signed 16-bit integers.
    I16(&'a [i16]),
    /// Signed 32-bit integers.
    I32(&'a [i32]),
    /// Signed 64-bit integers.
    I64(&'a [i64]),
    /// `half::bf16` values.
    BF16(&'a [bf16]),
    /// `half::f16` values.
    F16(&'a [f16]),
    /// 32-bit floats.
    F32(&'a [f32]),
    /// 64-bit floats.
    F64(&'a [f64]),
    /// `float8::F8E4M3` values.
    F8E4M3(&'a [F8E4M3]),
    // The variants below are carried as raw bytes.
    // NOTE(review): how elements are packed into the bytes is not visible here — confirm.
    F6E2M3(&'a [u8]),
    F6E3M2(&'a [u8]),
    F4(&'a [u8]),
    F8E8M0(&'a [u8]),
}
58
/// Stateless handle for the CPU device (unit struct, no fields).
// NOTE(review): presumably the `BackendDevice` impl lives further down the file — confirm.
#[derive(Debug, Clone)]
pub struct CpuDevice;
61
62struct Cmp(CmpOp);
63impl Map2U8 for Cmp {
64 const OP: &'static str = "cmp";
65 #[inline(always)]
66 fn f<T: WithDType>(
67 &self,
68 lhs: &[T],
69 lhs_l: &Layout,
70 rhs: &[T],
71 rhs_l: &Layout,
72 ) -> Result<Vec<u8>> {
73 let dst = match self.0 {
74 CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
75 CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
76 CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
77 CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
78 CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
79 CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
80 };
81 Ok(dst)
82 }
83}
84
85struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
86
87impl<I: IntDType> Map2 for WCond<'_, I> {
88 const OP: &'static str = "where";
89 #[inline(always)]
90 fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
91 let vs = match (
92 self.1.contiguous_offsets(),
93 t_l.contiguous_offsets(),
94 f_l.contiguous_offsets(),
95 ) {
96 (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
97 let pred = &self.0[o1..o2];
98 let t = &t[o_t1..o_t2];
99 let f = &f[o_f1..o_f2];
100 pred.iter()
101 .zip(t.iter().zip(f.iter()))
102 .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
103 .collect::<Vec<_>>()
104 }
105 _ => self
106 .1
107 .strided_index()
108 .zip(t_l.strided_index().zip(f_l.strided_index()))
109 .map(|(i_p, (i_t, i_f))| {
110 if self.0[i_p].is_true() {
111 t[i_t]
112 } else {
113 f[i_f]
114 }
115 })
116 .collect::<Vec<_>>(),
117 };
118 Ok(vs)
119 }
120}
121
122struct ReduceIndex {
123 reduce_dim_index: usize,
124 use_min: bool,
125 return_index: bool,
126}
127
128impl ReduceIndex {
129 #[inline(always)]
131 fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
132 where
133 T: Clone + Copy,
134 U: Clone + Copy,
135 F: Fn(T, T) -> bool,
136 G: Fn(T, usize) -> U,
137 {
138 let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
139 let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
140 let dst_len = src_l.shape().elem_count() / reduce_dim_size;
141 let mut dst: Vec<U> = Vec::with_capacity(dst_len);
142 let dst_to_set = dst.spare_capacity_mut();
143 let dst_to_set =
144 unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
145 match src_l.contiguous_offsets() {
146 Some((o1, o2)) => {
147 let src = &src[o1..o2];
148 if reduce_dim_stride == 1 {
149 for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
150 let start_src_i = start_src_i * reduce_dim_size;
151 let src = &src[start_src_i..start_src_i + reduce_dim_size];
152 let mut acc = 0;
153 let mut val = src[0];
154 for (src_i, &s) in src.iter().enumerate() {
155 if f(val, s) {
156 acc = src_i;
157 val = s
158 }
159 }
160 *dst_v = g(val, acc)
161 }
162 } else {
163 for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
164 let (p, q) = (
165 start_src_i / reduce_dim_stride,
166 start_src_i % reduce_dim_stride,
167 );
168 let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
170 let src = &src[start_src_i..];
171 let mut acc = 0;
172 let mut val = src[0];
173 for src_i in 0..reduce_dim_size {
174 let s = src[src_i * reduce_dim_stride];
175 if f(val, s) {
176 acc = src_i;
177 val = s
178 }
179 }
180 *dst_v = g(val, acc)
181 }
182 }
183 }
184 None => {
185 let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
186 for (unstr_index, src_index) in l.strided_index().enumerate() {
187 let src = &src[src_index..];
188 let mut acc = 0;
189 let mut val = src[0];
190 for src_i in 0..reduce_dim_size {
191 let s = src[src_i * reduce_dim_stride];
192 if f(val, s) {
193 acc = src_i;
194 val = s
195 }
196 }
197 dst_to_set[unstr_index] = g(val, acc)
198 }
199 }
200 }
201 unsafe { dst.set_len(dst_len) };
202 Ok(dst)
203 }
204}
205
206impl Map1Any for ReduceIndex {
207 #[inline(always)]
208 fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
209 &self,
210 src: &[T],
211 src_l: &Layout,
212 wrap: W,
213 ) -> Result<CpuStorage> {
214 if src_l.shape().elem_count() == 0 {
215 Err(Error::EmptyTensor { op: "reduce" }.bt())?
216 }
217 let dst = match (self.return_index, self.use_min) {
218 (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
219 (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
220 (true, true) => {
221 CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
222 }
223 (true, false) => {
224 CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
225 }
226 };
227 Ok(dst)
228 }
229}
230
231struct ReduceSum<'a> {
232 dst_shape: &'a Shape,
233 reduce_dims: &'a [usize],
234 reduce_dims_and_stride: Vec<(usize, usize)>,
235}
236
237impl ReduceSum<'_> {
238 #[inline(always)]
239 fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
240 where
241 T: WithDType,
242 {
243 let mut dst = vec![start_elt; self.dst_shape.elem_count()];
244 match src_l.contiguous_offsets() {
245 Some((o1, o2)) => {
246 let src = &src[o1..o2];
247 let reduce_over_last_dims = self
251 .reduce_dims
252 .iter()
253 .rev()
254 .enumerate()
255 .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
256 if reduce_over_last_dims {
257 let reduce_sz = self
258 .reduce_dims_and_stride
259 .iter()
260 .map(|(u, _)| u)
261 .product::<usize>();
262 for (dst_i, dst_v) in dst.iter_mut().enumerate() {
263 let src_i = dst_i * reduce_sz;
264 unsafe {
265 T::vec_reduce_sum(
266 src[src_i..src_i + reduce_sz].as_ptr(),
267 dst_v,
268 reduce_sz,
269 )
270 };
271 }
272 return Ok(dst);
273 };
274 for (unstr_index, &src) in src.iter().enumerate() {
275 let mut dst_index = unstr_index;
276 for &(dim, stride) in self.reduce_dims_and_stride.iter() {
278 let (pre, post) = (dst_index / stride, dst_index % stride);
280 dst_index = (pre / dim) * stride + post;
281 }
282 dst[dst_index] += src;
283 }
284 }
285 None => {
286 for (unstr_index, src_index) in src_l.strided_index().enumerate() {
287 let mut dst_index = unstr_index;
288 for &(dim, stride) in self.reduce_dims_and_stride.iter() {
290 let (pre, post) = (dst_index / stride, dst_index % stride);
292 dst_index = (pre / dim) * stride + post;
293 }
294 dst[dst_index] += src[src_index];
295 }
296 }
297 }
298 Ok(dst)
299 }
300}
301
302impl Map1 for ReduceSum<'_> {
303 #[inline(always)]
304 fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
305 self.fold_impl(src, src_l, T::zero())
306 }
307}
308
309struct Affine(f64, f64);
310
311impl Map1 for Affine {
312 fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
313 let mul = T::from_f64(self.0);
314 let add = T::from_f64(self.1);
315 Ok(unary_map(vs, layout, |v| v * mul + add))
316 }
317}
318
319struct AvgPool2D((usize, usize), (usize, usize));
320
321impl Map1 for AvgPool2D {
322 fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
323 let (k_h, k_w) = self.0;
325 let (s_h, s_w) = self.1;
326 let (b_sz, c, h, w) = layout.shape().dims4()?;
327 let stride = layout.stride();
328 let (stride_h, stride_w) = (stride[2], stride[3]);
329 let h_out = (h - k_h) / s_h + 1;
330 let w_out = (w - k_w) / s_w + 1;
331 let src_index = layout.start_offset();
332 let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
333 let scale = 1f64 / (k_h * k_w) as f64;
334 let scale = T::from_f64(scale);
335 for b_idx in 0..b_sz {
336 let dst = &mut dst[b_idx * c * h_out * w_out..];
337 let src_index = src_index + b_idx * stride[0];
338 for c_idx in 0..c {
339 let dst = &mut dst[c_idx * h_out * w_out..];
340 let src_index = src_index + c_idx * stride[1];
341 for h_idx in 0..h_out {
342 for w_idx in 0..w_out {
343 let mut sum = T::zero();
344 for m in 0..k_h {
345 for n in 0..k_w {
346 let m = s_h * h_idx + m;
347 let n = s_w * w_idx + n;
348 sum += src[src_index + m * stride_h + n * stride_w]
349 }
350 }
351 dst[h_idx * w_out + w_idx] = sum * scale;
352 }
353 }
354 }
355 }
356 Ok(dst)
357 }
358}
359
360struct MaxPool2D((usize, usize), (usize, usize));
361
362impl Map1 for MaxPool2D {
363 fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
364 let (k_h, k_w) = self.0;
366 let (s_h, s_w) = self.1;
367 let (b_sz, c, h, w) = layout.shape().dims4()?;
368 let stride = layout.stride();
369 let (stride_h, stride_w) = (stride[2], stride[3]);
370 let h_out = (h - k_h) / s_h + 1;
371 let w_out = (w - k_w) / s_w + 1;
372 let src_index = layout.start_offset();
373 let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
374 for b_idx in 0..b_sz {
375 let dst = &mut dst[b_idx * c * h_out * w_out..];
376 let src_index = src_index + b_idx * stride[0];
377 for c_idx in 0..c {
378 let dst = &mut dst[c_idx * h_out * w_out..];
379 let src_index = src_index + c_idx * stride[1];
380 for h_idx in 0..h_out {
381 for w_idx in 0..w_out {
382 let mut largest =
383 src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
384 for m in 0..k_h {
385 for n in 0..k_w {
386 let m = s_h * h_idx + m;
387 let n = s_w * w_idx + n;
388 if largest < src[src_index + m * stride_h + n * stride_w] {
389 largest = src[src_index + m * stride_h + n * stride_w]
390 }
391 }
392 }
393 dst[h_idx * w_out + w_idx] = largest;
394 }
395 }
396 }
397 }
398 Ok(dst)
399 }
400}
401
402struct UpsampleNearest1D(usize);
403
404impl Map1 for UpsampleNearest1D {
405 fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
406 let dst_sz = self.0;
408 let (b_sz, c, src_sz) = layout.shape().dims3()?;
409 let stride = layout.stride();
410 let stride_sz = stride[2];
411 let src_index = layout.start_offset();
412 let scale_sz = src_sz as f64 / dst_sz as f64;
413 let mut dst = vec![T::zero(); b_sz * c * dst_sz];
414 let src_idxs = (0..dst_sz)
415 .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
416 .collect::<Vec<_>>();
417 for b_idx in 0..b_sz {
418 let dst = &mut dst[b_idx * c * dst_sz..];
419 let src_index = src_index + b_idx * stride[0];
420 for c_idx in 0..c {
421 let dst = &mut dst[c_idx * dst_sz..];
422 let src_index = src_index + c_idx * stride[1];
423 for (idx, src_idx) in src_idxs.iter().enumerate() {
424 dst[idx] = src[src_index + src_idx * stride_sz]
425 }
426 }
427 }
428 Ok(dst)
429 }
430}
431
432struct UpsampleNearest2D(usize, usize);
433
434impl Map1 for UpsampleNearest2D {
435 fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
436 let (dst_h, dst_w) = (self.0, self.1);
438 let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
439 let stride = layout.stride();
440 let (stride_h, stride_w) = (stride[2], stride[3]);
441 let src_index = layout.start_offset();
442 let scale_h = src_h as f64 / dst_h as f64;
443 let scale_w = src_w as f64 / dst_w as f64;
444 let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
445 let src_h_idxs = (0..dst_h)
446 .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
447 .collect::<Vec<_>>();
448 let src_w_idxs = (0..dst_w)
449 .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
450 .collect::<Vec<_>>();
451 for b_idx in 0..b_sz {
452 let dst = &mut dst[b_idx * c * dst_h * dst_w..];
453 let src_index = src_index + b_idx * stride[0];
454 for c_idx in 0..c {
455 let dst = &mut dst[c_idx * dst_h * dst_w..];
456 let src_index = src_index + c_idx * stride[1];
457 for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
458 for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
459 let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
460 dst[h_idx * dst_w + w_idx] = src[src_index]
461 }
462 }
463 }
464 }
465 Ok(dst)
466 }
467}
468
469struct UpsampleBilinear2D {
470 target_h: usize,
471 target_w: usize,
472 align_corners: bool,
473 scale_h_factor: Option<f64>,
474 scale_w_factor: Option<f64>,
475}
476
477impl Map1 for UpsampleBilinear2D {
478 fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
479 let (batch, channels, height_in, width_in) = layout.shape().dims4()?;
480 let height_out = self.target_h;
481 let width_out = self.target_w;
482
483 if height_in == height_out && width_in == width_out {
485 return Ok(src.to_vec());
486 }
487
488 let stride = layout.stride();
489 let src_offset = layout.start_offset();
490
491 let scale_h = if self.align_corners {
493 if height_out > 1 {
494 (height_in - 1) as f64 / (height_out - 1) as f64
495 } else {
496 0.0
497 }
498 } else {
499 if let Some(scale_factor) = self.scale_h_factor {
503 1.0 / scale_factor
504 } else {
505 height_in as f64 / height_out as f64
506 }
507 };
508
509 let scale_w = if self.align_corners {
510 if width_out > 1 {
511 (width_in - 1) as f64 / (width_out - 1) as f64
512 } else {
513 0.0
514 }
515 } else if let Some(scale_factor) = self.scale_w_factor {
516 1.0 / scale_factor
517 } else {
518 width_in as f64 / width_out as f64
519 };
520
521 let mut h_indices = Vec::with_capacity(height_out);
523 for h_out in 0..height_out {
524 let src_h = if self.align_corners {
525 scale_h * h_out as f64
526 } else {
527 scale_h * (h_out as f64 + 0.5) - 0.5
528 };
529 let src_h_clamped = src_h.max(0.0);
530 let h0 = src_h_clamped.floor() as usize;
531 let h1 = (h0 + 1).min(height_in - 1);
532 let weight_h = (src_h_clamped - h0 as f64).clamp(0.0, 1.0);
533 h_indices.push((h0, h1, weight_h));
534 }
535
536 let mut w_indices = Vec::with_capacity(width_out);
538 for w_out in 0..width_out {
539 let src_w = if self.align_corners {
540 scale_w * w_out as f64
541 } else {
542 scale_w * (w_out as f64 + 0.5) - 0.5
543 };
544 let src_w_clamped = src_w.max(0.0);
545 let w0 = src_w_clamped.floor() as usize;
546 let w1 = (w0 + 1).min(width_in - 1);
547 let weight_w = (src_w_clamped - w0 as f64).clamp(0.0, 1.0);
548 w_indices.push((w0, w1, weight_w));
549 }
550
551 let mut dst = vec![T::zero(); batch * channels * height_out * width_out];
553
554 for b in 0..batch {
556 for c in 0..channels {
557 let base_idx = src_offset + b * stride[0] + c * stride[1];
558 let dst_base = (b * channels + c) * height_out * width_out;
559
560 for (h_out, &(h0, h1, weight_h)) in h_indices.iter().enumerate() {
561 for (w_out, &(w0, w1, weight_w)) in w_indices.iter().enumerate() {
562 let idx_00 = base_idx + h0 * stride[2] + w0 * stride[3];
564 let idx_10 = base_idx + h0 * stride[2] + w1 * stride[3];
565 let idx_01 = base_idx + h1 * stride[2] + w0 * stride[3];
566 let idx_11 = base_idx + h1 * stride[2] + w1 * stride[3];
567
568 let v00 = src[idx_00].to_f64();
569 let v10 = src[idx_10].to_f64();
570 let v01 = src[idx_01].to_f64();
571 let v11 = src[idx_11].to_f64();
572
573 let v_top = v00 * (1.0 - weight_w) + v10 * weight_w;
575 let v_bottom = v01 * (1.0 - weight_w) + v11 * weight_w;
576 let value = v_top * (1.0 - weight_h) + v_bottom * weight_h;
577
578 dst[dst_base + h_out * width_out + w_out] = T::from_f64(value);
579 }
580 }
581 }
582 }
583
584 Ok(dst)
585 }
586}
587
588struct Gather<'a, I: IntDType> {
589 ids: &'a [I],
590 ids_l: &'a Layout,
591 dim: usize,
592}
593
594impl<I: IntDType> Map1 for Gather<'_, I> {
595 fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
596 let ids = match self.ids_l.contiguous_offsets() {
597 Some((a, b)) => &self.ids[a..b],
598 None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
599 };
600 let src = match src_l.contiguous_offsets() {
601 Some((a, b)) => &src[a..b],
602 None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
603 };
604 let dim = self.dim;
605 let ids_dims = self.ids_l.dims();
606 let src_dims = src_l.dims();
607 let dst_len: usize = ids_dims.iter().product();
608 let dst_left_len: usize = ids_dims[..dim].iter().product();
609 let dst_dim_len = ids_dims[dim];
610 let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
611
612 let src_dim_len = src_dims[dim];
613 let src_right_len: usize = src_dims[dim + 1..].iter().product();
614
615 let mut dst = vec![T::zero(); dst_len];
616 for left_i in 0..dst_left_len {
617 let start_src_idx = left_i * src_right_len * src_dim_len;
618 let start_dst_idx = left_i * dst_right_len * dst_dim_len;
619 for i in 0..dst_dim_len {
620 let start_dst_idx = start_dst_idx + i * dst_right_len;
621 for right_i in 0..dst_right_len {
622 let dst_idx = start_dst_idx + right_i;
623 let index = ids[dst_idx];
624 if index == I::max_value() {
625 dst[dst_idx] = T::zero();
626 } else {
627 let index = index.as_usize();
628 if index >= src_dim_len {
629 Err(Error::InvalidIndex {
630 index,
631 size: src_dim_len,
632 op: "gather",
633 }
634 .bt())?
635 }
636 let src_idx = start_src_idx + index * src_right_len + right_i;
637 dst[dst_idx] = src[src_idx]
638 }
639 }
640 }
641 }
642 Ok(dst)
643 }
644}
645
646struct IndexSelect<'a, T: IntDType> {
647 ids: &'a [T],
648 ids_l: &'a Layout,
649 dim: usize,
650}
651
652impl<I: IntDType> Map1 for IndexSelect<'_, I> {
653 fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
654 let src = match layout.contiguous_offsets() {
655 Some((a, b)) => &src[a..b],
656 None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
657 };
658 let dim = self.dim;
659 let n_ids = match self.ids_l.dims() {
660 [n_ids] => *n_ids,
661 d => Err(Error::UnexpectedNumberOfDims {
662 expected: 1,
663 got: d.len(),
664 shape: self.ids_l.shape().clone(),
665 }
666 .bt())?,
667 };
668 let stride_ids = self.ids_l.stride()[0];
669 let mut dst_dims = layout.dims().to_vec();
670 let src_dim = dst_dims[dim];
671 dst_dims[dim] = n_ids;
672 let dst_len: usize = dst_dims.iter().product();
673 let left_len: usize = dst_dims[..dim].iter().product();
674 let right_len: usize = dst_dims[dim + 1..].iter().product();
675 let mut dst = vec![T::zero(); dst_len];
676 for left_i in 0..left_len {
677 let start_src_idx = left_i * right_len * src_dim;
678 let start_dst_idx = left_i * right_len * n_ids;
679 for i in 0..n_ids {
680 let start_dst_idx = start_dst_idx + i * right_len;
681 let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
682 if index == I::max_value() {
683 dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
684 } else {
685 let index = index.as_usize();
686 if index >= src_dim {
687 Err(Error::InvalidIndex {
688 index,
689 size: src_dim,
690 op: "index-select",
691 }
692 .bt())?
693 }
694 let start_src_idx = start_src_idx + index * right_len;
695 dst[start_dst_idx..start_dst_idx + right_len]
696 .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
697 }
698 }
699 }
700 Ok(dst)
701 }
702}
703
/// Strategy for combining a scattered source element into a destination slot.
trait ElemUpdate {
    /// Updates `dst` in place using `src`.
    fn f<T: WithDType>(dst: &mut T, src: T);
}

/// [`ElemUpdate`] that overwrites the destination.
struct Set;
/// [`ElemUpdate`] that accumulates into the destination.
struct Add;

impl ElemUpdate for Set {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst += src
    }
}
722
723struct Scatter<'a, I: IntDType, M: ElemUpdate> {
724 ids: &'a [I],
725 ids_l: &'a Layout,
726 dim: usize,
727 _phantom: std::marker::PhantomData<M>,
728}
729
730impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
731 fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
732 Self {
733 ids,
734 ids_l,
735 dim,
736 _phantom: Default::default(),
737 }
738 }
739}
740
741impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
742 const OP: &'static str = "scatter";
743 fn f<T: WithDType>(
744 &self,
745 dst: &mut [T],
746 dst_l: &Layout,
747 src: &[T],
748 src_l: &Layout,
749 ) -> Result<()> {
750 let dst = match dst_l.contiguous_offsets() {
751 None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
752 Some((o1, o2)) => &mut dst[o1..o2],
753 };
754
755 let src = match src_l.contiguous_offsets() {
756 None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
757 Some((o1, o2)) => &src[o1..o2],
758 };
759
760 let dim = self.dim;
761 let ids_dims = self.ids_l.dims();
762 let dst_dims = dst_l.dims();
763 let dst_dim_len = dst_dims[dim];
764 let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
765
766 let ids_left_len: usize = ids_dims[..dim].iter().product();
767 let ids_dim_len = ids_dims[dim];
768 let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
769
770 let ids = match self.ids_l.contiguous_offsets() {
771 Some((a, b)) => &self.ids[a..b],
772 None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
773 };
774 for left_i in 0..ids_left_len {
775 let start_ids_idx = left_i * ids_right_len * ids_dim_len;
776 let start_dst_idx = left_i * dst_right_len * dst_dim_len;
777 for i in 0..ids_dim_len {
778 let start_ids_idx = start_ids_idx + i * ids_right_len;
779 for right_i in 0..dst_right_len {
780 let ids_idx = start_ids_idx + right_i;
781 let index = ids[ids_idx];
782 if index == I::max_value() {
783 continue;
784 }
785 let index = index.as_usize();
786 if index >= dst_dim_len {
787 Err(Error::InvalidIndex {
788 index,
789 size: dst_dim_len,
790 op: "gather",
791 }
792 .bt())?
793 }
794 let dst_idx = start_dst_idx + index * dst_right_len + right_i;
795 M::f(&mut dst[dst_idx], src[ids_idx])
796 }
797 }
798 }
799
800 Ok(())
801 }
802}
803
804struct IndexAdd<'a, I: IntDType> {
805 ids: &'a [I],
806 dim: usize,
807}
808
809impl<I: IntDType> Map2 for IndexAdd<'_, I> {
810 const OP: &'static str = "index-add";
811 fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
814 let dst_len = l1.shape().elem_count();
815 let mut dst = vec![T::zero(); dst_len];
816 copy_strided_src_(v1, &mut dst, 0, l1);
817 let src = match src_l.contiguous_offsets() {
818 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
819 Some((o1, o2)) => &src[o1..o2],
820 };
821 let dim = self.dim;
822 let max_idx = l1.dims()[dim];
823 let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
824 let src_dim_sz = src_l.dims()[dim];
825 let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
826 if dim == 0 {
827 for (src_idx, dst_idx) in self.ids.iter().enumerate() {
828 if *dst_idx == I::max_value() {
829 continue;
830 }
831 let dst_idx = dst_idx.as_usize();
832 if dst_idx >= max_idx {
833 Err(Error::InvalidIndex {
834 index: dst_idx,
835 op: "index-add",
836 size: max_idx,
837 })?
838 }
839 let src_idx = src_idx * post_dim;
840 let dst_idx = dst_idx * post_dim;
841 let src = &src[src_idx..src_idx + post_dim];
842 let dst = &mut dst[dst_idx..dst_idx + post_dim];
843 for (d, &s) in dst.iter_mut().zip(src.iter()) {
844 *d += s
845 }
846 }
847 } else {
848 for (src_idx, dst_idx) in self.ids.iter().enumerate() {
849 if *dst_idx == I::max_value() {
850 continue;
851 }
852 let dst_idx = dst_idx.as_usize();
853 if dst_idx >= max_idx {
854 Err(Error::InvalidIndex {
855 index: dst_idx,
856 op: "index-add",
857 size: max_idx,
858 })?
859 }
860 for pre_i in 0..pre_dim {
861 let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
862 let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
863 let src = &src[pre_src_i..pre_src_i + post_dim];
864 let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
865 for (d, &s) in dst.iter_mut().zip(src.iter()) {
866 *d += s
867 }
868 }
869 }
870 }
871 Ok(dst)
872 }
873}
874
/// Copies `d1` rows of `d2` elements from `src` to `dst`.
///
/// Row `r` is read at `src_offset + r * src_stride1` and written at
/// `dst_offset + r * dst_stride1`.
///
/// # Panics
/// Panics when a row range falls outside `src` or `dst`.
// The argument list mirrors the 2D copy parameters one-to-one, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    let rows_are_packed = src_stride1 == d2 && dst_stride1 == d2;
    if rows_are_packed {
        // Both sides store their rows back to back: a single bulk copy is enough.
        let total = d1 * d2;
        dst[dst_offset..dst_offset + total]
            .copy_from_slice(&src[src_offset..src_offset + total]);
    } else {
        for row in 0..d1 {
            let src_start = src_offset + row * src_stride1;
            let dst_start = dst_offset + row * dst_stride1;
            dst[dst_start..dst_start + d2].copy_from_slice(&src[src_start..src_start + d2]);
        }
    }
}
901
902fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
903 match src_l.strided_blocks() {
904 crate::StridedBlocks::SingleBlock { start_offset, len } => dst
905 [dst_offset..dst_offset + len]
906 .copy_from_slice(&src[start_offset..start_offset + len]),
907 crate::StridedBlocks::UniformBlocks {
908 start_offset,
909 block_len,
910 count,
911 src_stride,
912 } => copy2d_(
913 src,
914 dst,
915 count,
916 block_len,
917 src_stride,
918 block_len,
919 start_offset,
920 dst_offset,
921 ),
922 crate::StridedBlocks::MultipleBlocks {
923 block_start_index,
924 block_len: 1,
925 } => {
926 let n = block_start_index.len();
927 let dst = &mut dst[dst_offset..dst_offset + n];
928 for (dst_elem, src_index) in dst.iter_mut().zip(block_start_index) {
929 *dst_elem = src[src_index]
930 }
931 }
932 crate::StridedBlocks::MultipleBlocks {
933 block_start_index,
934 block_len,
935 } => {
936 let n_blocks = block_start_index.len();
937 let dst = &mut dst[dst_offset..dst_offset + n_blocks * block_len];
938 let mut dst_index = 0;
939 for src_index in block_start_index {
940 dst[dst_index..dst_index + block_len]
941 .copy_from_slice(&src[src_index..src_index + block_len]);
942 dst_index += block_len;
943 }
944 }
945 }
946}
947
/// Direct (non-im2col) 1D convolution kernel, parameterised by `ParamsConv1D`.
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl Map2 for Conv1D<'_> {
    const OP: &'static str = "conv1d";

    /// Convolves `inp` with kernel `k` and returns a dense buffer indexed as
    /// `b_idx * c_out * l_out + dst_c_idx * l_out + dst_l`.
    ///
    /// Both inputs are read through their layout (start offset and the three strides).
    /// Kernel elements are addressed as `[c_out, c_in, k_size]` through `k_l`'s strides.
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        // Not declared `mut`: it is written through a raw pointer inside the parallel loop.
        let dst = vec![T::zero(); dst_elems];

        // Repack the input as [b_size, l_in, c_in] so the channels of one position are
        // consecutive and can be fed to `vec_dot`.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        // Kernel taps are processed one after the other; for each tap the output channels
        // are processed in parallel.
        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                // Kernel weights of this (output channel, tap) pair, one per input channel.
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        // Position in the padded input; skip taps that fall in the padding.
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        // Both operands must hold at least `c_in` elements for `vec_dot`.
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        // SAFETY: the two asserts above guarantee `c_in` readable elements
                        // behind each pointer; `d` is a valid local.
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
                        // SAFETY (as far as this block shows): `dst_idx` is
                        // `b_idx * c_out * l_out + dst_c_idx * l_out + dst_l` with
                        // `dst_l < l_out`, so tasks with different `dst_c_idx` never touch
                        // the same element and the index stays below `dst_elems`.
                        // NOTE(review): the write goes through a pointer obtained from a
                        // shared borrow of `dst` (`as_ptr` cast to `*mut`) — confirm this
                        // is acceptable.
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}
1009
1010struct Im2Col1D {
1011 l_k: usize,
1012 stride: usize,
1013 dilation: usize,
1014 padding: usize,
1015}
1016
1017impl Im2Col1D {
1018 fn l_out(&self, l: usize) -> usize {
1019 (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
1020 }
1021}
1022
1023impl Map1 for Im2Col1D {
1024 fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
1025 let &Self {
1026 l_k,
1027 stride,
1028 dilation,
1029 padding,
1030 } = self;
1031 let (b, c, l) = layout.shape().dims3()?;
1032 let l_out = self.l_out(l);
1033 let src = &vs[layout.start_offset()..];
1034 let mut dst = vec![T::zero(); b * l_out * c * l_k];
1035 let (src_s0, src_s1, src_s2) = {
1036 let s = layout.stride();
1037 (s[0], s[1], s[2])
1038 };
1039 for b_idx in 0..b {
1045 let src_idx = b_idx * src_s0;
1046 let dst_idx = b_idx * l_out * c * l_k;
1047 for l_idx in 0..l_out {
1048 let dst_idx = dst_idx + l_idx * c * l_k;
1049 for c_idx in 0..c {
1050 let dst_idx = dst_idx + c_idx * l_k;
1051 let src_idx = c_idx * src_s1 + src_idx;
1052 for l_k_idx in 0..l_k {
1053 let src_l = l_idx * stride + l_k_idx * dilation;
1054 if padding != 0 && (src_l < padding || src_l >= l + padding) {
1055 continue;
1056 }
1057 let src_l = src_l - padding;
1058 let src_idx = src_idx + src_l * src_s2;
1059 let dst_idx = dst_idx + l_k_idx;
1060 dst[dst_idx] = src[src_idx]
1061 }
1062 }
1063 }
1064 }
1065 Ok(dst)
1066 }
1067}
1068
1069struct Im2Col {
1070 h_k: usize,
1071 w_k: usize,
1072 stride: usize,
1073 dilation: usize,
1074 padding: usize,
1075}
1076
1077impl Im2Col {
1078 fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
1079 let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
1080 let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
1081 (h_out, w_out)
1082 }
1083}
1084
1085impl Map1 for Im2Col {
1086 fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
1087 let &Self {
1088 h_k,
1089 w_k,
1090 stride,
1091 dilation,
1092 padding,
1093 } = self;
1094 let (b, c, h, w) = layout.shape().dims4()?;
1095 let (h_out, w_out) = self.hw_out(h, w);
1096 let src = &vs[layout.start_offset()..];
1097 let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
1098 let (src_s0, src_s1, src_s2, src_s3) = {
1099 let s = layout.stride();
1100 (s[0], s[1], s[2], s[3])
1101 };
1102 for b_idx in 0..b {
1108 let src_idx = b_idx * src_s0;
1109 let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
1110 for h_idx in 0..h_out {
1111 let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
1112 for w_idx in 0..w_out {
1113 let dst_idx = dst_idx + w_idx * c * h_k * w_k;
1114 for c_idx in 0..c {
1115 let dst_idx = dst_idx + c_idx * h_k * w_k;
1116 let src_idx = c_idx * src_s1 + src_idx;
1117 for h_k_idx in 0..h_k {
1118 let src_h = h_idx * stride + h_k_idx * dilation;
1119 if padding != 0 && (src_h < padding || src_h >= h + padding) {
1120 continue;
1121 }
1122 let src_h = src_h - padding;
1123 let src_idx = src_idx + src_h * src_s2;
1124 let dst_idx = dst_idx + h_k_idx * w_k;
1125 for w_k_idx in 0..w_k {
1126 let src_w = w_idx * stride + w_k_idx * dilation;
1127 if padding != 0 && (src_w < padding || src_w >= w + padding) {
1128 continue;
1129 }
1130 let src_w = src_w - padding;
1131 let src_idx = src_idx + src_w * src_s3;
1132 let dst_idx = dst_idx + w_k_idx;
1133 dst[dst_idx] = src[src_idx]
1134 }
1135 }
1136 }
1137 }
1138 }
1139 }
1140 Ok(dst)
1141 }
1142}
1143
1144struct Col2Im1D {
1145 stride: usize,
1146}
1147
1148impl Map1 for Col2Im1D {
1149 fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
1150 let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
1151 let stride = self.stride;
1152 let l_out = (l_in - 1) * stride + k_size;
1153 let mut im = vec![T::zero(); b_size * c_out * l_out];
1154 let (dst_s0, dst_s1) = (c_out * l_out, l_out);
1155 let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
1156 for l_in_i in 0..l_in {
1157 for k_i in 0..k_size {
1158 let l_out_i = l_in_i * stride + k_i;
1159 for b_i in 0..b_size {
1160 for c_i in 0..c_out {
1161 let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
1162 let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
1163 im[dst_idx] += col[src_idx]
1164 }
1165 }
1166 }
1167 }
1168 Ok(im)
1169 }
1170}
1171
/// 1D transposed convolution kernel; wraps a borrowed set of convolution parameters.
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
1173
1174impl Map2 for ConvTranspose1D<'_> {
1175 const OP: &'static str = "conv_transpose1d";
1176 fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1177 let p = self.0;
1178 let inp = &inp[inp_l.start_offset()..];
1179 let k = &k[k_l.start_offset()..];
1180 let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
1181 let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
1182 let l_out = p.l_out();
1183
1184 let dst_elems = p.c_out * l_out * p.b_size;
1186 let dst = vec![T::zero(); dst_elems];
1187 let dst_s0 = p.c_out * l_out;
1188 let dst_s1 = l_out;
1189 let dst_s2 = 1;
1190
1191 let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
1193 let cont_s0 = p.l_in * p.c_in;
1194 let cont_s1 = p.c_in;
1195 for b_idx in 0..p.b_size {
1196 for l_idx in 0..p.l_in {
1197 for c_idx in 0..p.c_in {
1198 let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
1199 let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
1200 inp_cont[dst_idx] = inp[src_idx]
1201 }
1202 }
1203 }
1204
1205 for k_idx in 0..p.k_size {
1206 (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1207 let k_cont = (0..p.c_in)
1208 .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
1209 .collect::<Vec<_>>();
1210 for b_idx in 0..p.b_size {
1211 for l_idx in 0..p.l_in {
1212 let out_idx = l_idx * p.stride + k_idx * p.dilation;
1213 if out_idx < p.padding {
1214 continue;
1215 }
1216 let out_idx = out_idx - p.padding;
1217 if out_idx < l_out {
1218 let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
1219 let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
1220 let mut d = T::zero();
1221 unsafe {
1222 T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
1223 }
1224 let dst_p = dst.as_ptr();
1225 unsafe {
1229 let ptr = dst_p.add(dst_idx) as *mut T;
1230 *ptr += d
1231 }
1232 }
1233 }
1234 }
1235 })
1236 }
1237 Ok(dst)
1238 }
1239}
1240
/// 2D transposed convolution kernel; wraps a borrowed set of convolution parameters.
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
1242
1243impl Map2 for ConvTranspose2D<'_> {
1244 const OP: &'static str = "conv_transpose2d";
1245 fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1246 let p = self.0;
1247 let inp = &inp[inp_l.start_offset()..];
1248 let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
1249 let k = &k[k_l.start_offset()..];
1250 let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
1251 let (out_h, out_w) = (p.out_h(), p.out_w());
1252
1253 let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
1255 let dst_s0 = p.c_out * out_h * out_w;
1256 let dst_s1 = out_h * out_w;
1257 let dst_s2 = out_w;
1258 let dst_s3 = 1;
1259
1260 let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
1262 let cont_s0 = p.i_h * p.i_w * p.c_in;
1263 let cont_s1 = p.i_w * p.c_in;
1264 let cont_s2 = p.c_in;
1265 for b_idx in 0..p.b_size {
1266 for h_idx in 0..p.i_h {
1267 for w_idx in 0..p.i_w {
1268 for c_idx in 0..p.c_in {
1269 let src_idx =
1270 b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
1271 let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
1272 inp_cont[dst_idx] = inp[src_idx]
1273 }
1274 }
1275 }
1276 }
1277
1278 for k_y in 0..p.k_h {
1279 for k_x in 0..p.k_w {
1280 (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1281 let k_cont = (0..p.c_in)
1282 .map(|c_in_idx| {
1283 k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
1284 })
1285 .collect::<Vec<_>>();
1286 for b_idx in 0..p.b_size {
1287 for inp_y in 0..p.i_h {
1288 for inp_x in 0..p.i_w {
1289 let out_x = inp_x * p.stride + k_x * p.dilation;
1290 let out_y = inp_y * p.stride + k_y * p.dilation;
1291 if out_x < p.padding || out_y < p.padding {
1292 continue;
1293 }
1294 let out_x = out_x - p.padding;
1295 let out_y = out_y - p.padding;
1296 if out_x < out_w && out_y < out_h {
1297 let inp_cont = &inp_cont
1298 [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
1299 let dst_idx = b_idx * dst_s0
1300 + out_y * dst_s2
1301 + out_x * dst_s3
1302 + dst_c_idx * dst_s1;
1303 let mut d = T::zero();
1304 unsafe {
1305 T::vec_dot(
1306 inp_cont.as_ptr(),
1307 k_cont.as_ptr(),
1308 &mut d,
1309 p.c_in,
1310 )
1311 }
1312 let dst_p = dst.as_ptr();
1313 unsafe {
1317 let ptr = dst_p.add(dst_idx) as *mut T;
1318 *ptr += d
1319 }
1320 }
1321 }
1322 }
1323 }
1324 })
1325 }
1326 }
1327 Ok(dst)
1328 }
1329}
1330
/// Batched matrix multiplication; the tuple holds the problem size as `(b, m, n, k)`.
struct MatMul((usize, usize, usize, usize));
1332
1333impl MatMul {
1334 fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
1335 Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
1336 lhs_l: lhs_l.clone(),
1337 rhs_l: rhs_l.clone(),
1338 bmnk: self.0,
1339 msg,
1340 }))
1341 .bt()
1342 }
1343
1344 fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
1345 let lhs_stride = lhs_l.stride();
1346 let rhs_stride = rhs_l.stride();
1347 let rank = lhs_stride.len();
1348 let (_b, m, n, k) = self.0;
1349 let a_skip: usize = match lhs_stride[..rank - 2] {
1350 [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
1351 [_, stride] if lhs_l.dims()[0] == 1 => stride,
1352 [stride, _] if lhs_l.dims()[1] == 1 => stride,
1353 [stride] => stride,
1354 [] => m * k,
1355 _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
1356 };
1357 let b_skip: usize = match rhs_stride[..rank - 2] {
1358 [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
1359 [_, stride] if rhs_l.dims()[0] == 1 => stride,
1360 [stride, _] if rhs_l.dims()[1] == 1 => stride,
1361 [stride] => stride,
1362 [] => n * k,
1363 _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
1364 };
1365 Ok((a_skip, b_skip))
1366 }
1367}
1368
1369impl Map2 for MatMul {
1370 const OP: &'static str = "mat_mul";
1371
1372 #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
1373 fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1374 &self,
1375 lhs: &[T],
1376 lhs_l: &Layout,
1377 rhs: &[T],
1378 rhs_l: &Layout,
1379 ) -> Result<Vec<T>> {
1380 use gemm::{gemm, Parallelism};
1381
1382 match T::DTYPE {
1383 DType::F16 | DType::F32 | DType::F64 => {}
1384 _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
1385 }
1386
1387 let (b, m, n, k) = self.0;
1388 let lhs = &lhs[lhs_l.start_offset()..];
1389 let rhs = &rhs[rhs_l.start_offset()..];
1390
1391 let lhs_stride = lhs_l.stride();
1392 let rhs_stride = rhs_l.stride();
1393 let rank = lhs_stride.len();
1394 let lhs_cs = lhs_stride[rank - 1];
1395 let lhs_rs = lhs_stride[rank - 2];
1396
1397 let rhs_cs = rhs_stride[rank - 1];
1398 let rhs_rs = rhs_stride[rank - 2];
1399
1400 let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1401 let c_skip: usize = m * n;
1402
1403 let dst_shape: Shape = (m, n).into();
1404 let dst_strides = dst_shape.stride_contiguous();
1405 let dst_rs = dst_strides[0];
1406 let dst_cs = dst_strides[1];
1407
1408 let mut dst = vec![T::zero(); b * m * n];
1409 let num_threads = crate::utils::get_num_threads();
1410 let parallelism = if num_threads > 1 {
1411 Parallelism::Rayon(num_threads)
1412 } else {
1413 Parallelism::None
1414 };
1415 let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
1416 (1, b * m, n, k)
1419 } else if a_skip == 0 && b_skip == n * k {
1420 (1, m, b * n, k)
1421 } else {
1422 (b, m, n, k)
1423 };
1424 for step in 0..b {
1425 let lhs_p = &lhs[step * a_skip..];
1426 let rhs_p = &rhs[step * b_skip..];
1427 let dst_p = &mut dst[step * c_skip..];
1428 unsafe {
1429 gemm(
1430 m,
1431 n,
1432 k,
1433 dst_p.as_mut_ptr(),
1434 dst_cs as isize,
1435 dst_rs as isize,
1436 false,
1437 lhs_p.as_ptr(),
1438 lhs_cs as isize,
1439 lhs_rs as isize,
1440 rhs_p.as_ptr(),
1441 rhs_cs as isize,
1442 rhs_rs as isize,
1443 T::zero(),
1444 T::one(),
1445 false,
1446 false,
1447 false,
1448 parallelism,
1449 )
1450 }
1451 }
1452 Ok(dst)
1453 }
1454
1455 #[cfg(feature = "accelerate")]
1456 fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1457 &self,
1458 lhs: &[T],
1459 lhs_l: &Layout,
1460 rhs: &[T],
1461 rhs_l: &Layout,
1462 ) -> Result<Vec<T>> {
1463 let (b, m, n, k) = self.0;
1464 let lhs = &lhs[lhs_l.start_offset()..];
1465 let rhs = &rhs[rhs_l.start_offset()..];
1466
1467 let lhs_stride = lhs_l.stride();
1468 let rhs_stride = rhs_l.stride();
1469
1470 let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1471 let c_skip: usize = m * n;
1472
1473 let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1474 let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1475 let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1476 let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1477
1478 let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1479 (n as i32, b'N')
1480 } else if rhs_m1 == k && rhs_m2 == 1 {
1481 (k as i32, b'T')
1482 } else {
1483 Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1484 };
1485 let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1487 (k as i32, b'N')
1488 } else if lhs_m1 == m && lhs_m2 == 1 {
1489 (m as i32, b'T')
1490 } else {
1491 Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1492 };
1493
1494 let mut dst = vec![T::zero(); b * m * n];
1495 match T::DTYPE {
1496 DType::F16 => {
1497 crate::bail!("the accelerate backend does not support f16 matmul")
1498 }
1499 DType::F32 => {
1500 for step in 0..b {
1501 let lhs_p = &lhs[step * a_skip..];
1502 let rhs_p = &rhs[step * b_skip..];
1503 let dst_p = &mut dst[step * c_skip..];
1504 unsafe {
1505 let a = rhs_p.as_ptr() as *const f32;
1506 let b = lhs_p.as_ptr() as *const f32;
1507 let c = dst_p.as_mut_ptr() as *mut f32;
1508 let a = std::slice::from_raw_parts(a, a_skip);
1509 let b = std::slice::from_raw_parts(b, b_skip);
1510 let c = std::slice::from_raw_parts_mut(c, c_skip);
1511 crate::accelerate::sgemm(
1512 transa, transb, n as i32, m as i32,
1513 k as i32, 1., a,
1514 lda, b, ldb,
1515 0., c, n as i32,
1516 )
1517 }
1518 }
1519 }
1520 DType::F64 => {
1521 for step in 0..b {
1522 let lhs_p = &lhs[step * a_skip..];
1523 let rhs_p = &rhs[step * b_skip..];
1524 let dst_p = &mut dst[step * c_skip..];
1525 unsafe {
1526 let a = rhs_p.as_ptr() as *const f64;
1527 let b = lhs_p.as_ptr() as *const f64;
1528 let c = dst_p.as_mut_ptr() as *mut f64;
1529 let a = std::slice::from_raw_parts(a, a_skip);
1530 let b = std::slice::from_raw_parts(b, b_skip);
1531 let c = std::slice::from_raw_parts_mut(c, c_skip);
1532 crate::accelerate::dgemm(
1533 transa, transb, n as i32, m as i32,
1534 k as i32, 1., a,
1535 lda, b, ldb,
1536 0., c, n as i32,
1537 )
1538 }
1539 }
1540 }
1541 dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1542 }
1543 Ok(dst)
1544 }
1545
1546 #[cfg(feature = "mkl")]
1547 fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1548 &self,
1549 lhs: &[T],
1550 lhs_l: &Layout,
1551 rhs: &[T],
1552 rhs_l: &Layout,
1553 ) -> Result<Vec<T>> {
1554 let (b, m, n, k) = self.0;
1555 let lhs = &lhs[lhs_l.start_offset()..];
1556 let rhs = &rhs[rhs_l.start_offset()..];
1557
1558 let lhs_stride = lhs_l.stride();
1559 let rhs_stride = rhs_l.stride();
1560
1561 let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1562 let c_skip: usize = m * n;
1563
1564 let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1565 let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1566 let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1567 let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1568
1569 let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1570 (n as i32, b'N')
1571 } else if rhs_m1 == k && rhs_m2 == 1 {
1572 (k as i32, b'T')
1573 } else {
1574 Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1575 };
1576 let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1578 (k as i32, b'N')
1579 } else if lhs_m1 == m && lhs_m2 == 1 {
1580 (m as i32, b'T')
1581 } else {
1582 Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1583 };
1584
1585 let mut dst = vec![T::zero(); b * m * n];
1586 match T::DTYPE {
1587 DType::F16 => {
1588 for step in 0..b {
1589 let lhs_p = &lhs[step * a_skip..];
1590 let rhs_p = &rhs[step * b_skip..];
1591 let dst_p = &mut dst[step * c_skip..];
1592 unsafe {
1593 let a = rhs_p.as_ptr() as *const f16;
1594 let b = lhs_p.as_ptr() as *const f16;
1595 let c = dst_p.as_mut_ptr() as *mut f16;
1596 let a = std::slice::from_raw_parts(a, a_skip);
1597 let b = std::slice::from_raw_parts(b, b_skip);
1598 let c = std::slice::from_raw_parts_mut(c, c_skip);
1599 crate::mkl::hgemm(
1600 transa,
1601 transb,
1602 n as i32,
1603 m as i32,
1604 k as i32,
1605 f16::ONE,
1606 a,
1607 lda,
1608 b,
1609 ldb,
1610 f16::ZERO,
1611 c,
1612 n as i32,
1613 )
1614 }
1615 }
1616 }
1617 DType::F32 => {
1618 for step in 0..b {
1619 let lhs_p = &lhs[step * a_skip..];
1620 let rhs_p = &rhs[step * b_skip..];
1621 let dst_p = &mut dst[step * c_skip..];
1622 unsafe {
1623 let a = rhs_p.as_ptr() as *const f32;
1624 let b = lhs_p.as_ptr() as *const f32;
1625 let c = dst_p.as_mut_ptr() as *mut f32;
1626 let a = std::slice::from_raw_parts(a, a_skip);
1627 let b = std::slice::from_raw_parts(b, b_skip);
1628 let c = std::slice::from_raw_parts_mut(c, c_skip);
1629 crate::mkl::sgemm(
1630 transa, transb, n as i32, m as i32,
1631 k as i32, 1., a,
1632 lda, b, ldb,
1633 0., c, n as i32,
1634 )
1635 }
1636 }
1637 }
1638 DType::F64 => {
1639 for step in 0..b {
1640 let lhs_p = &lhs[step * a_skip..];
1641 let rhs_p = &rhs[step * b_skip..];
1642 let dst_p = &mut dst[step * c_skip..];
1643 unsafe {
1644 let a = rhs_p.as_ptr() as *const f64;
1645 let b = lhs_p.as_ptr() as *const f64;
1646 let c = dst_p.as_mut_ptr() as *mut f64;
1647 let a = std::slice::from_raw_parts(a, a_skip);
1648 let b = std::slice::from_raw_parts(b, b_skip);
1649 let c = std::slice::from_raw_parts_mut(c, c_skip);
1650 crate::mkl::dgemm(
1651 transa, transb, n as i32, m as i32,
1652 k as i32, 1., a,
1653 lda, b, ldb,
1654 0., c, n as i32,
1655 )
1656 }
1657 }
1658 }
1659 dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1660 }
1661 Ok(dst)
1662 }
1663}
1664
1665fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
1666 if v.is_sign_positive() {
1667 v
1668 } else {
1669 (v.exp() - T::one()) * alpha
1670 }
1671}
1672
1673impl CpuStorage {
1674 pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
1675 D::cpu_storage_as_slice(self)
1676 }
1677
1678 pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
1679 let storage0 = &storages[0];
1680 let s = match storage0 {
1681 Self::U8(_) => {
1682 let storages = storages
1683 .iter()
1684 .map(|s| match s {
1685 Self::U8(s) => Ok(s.as_slice()),
1686 _ => crate::bail!("dtype mismatch"),
1687 })
1688 .collect::<Result<Vec<_>>>()?
1689 .concat();
1690 Self::U8(storages)
1691 }
1692 Self::U32(_) => {
1693 let storages = storages
1694 .iter()
1695 .map(|s| match s {
1696 Self::U32(s) => Ok(s.as_slice()),
1697 _ => crate::bail!("dtype mismatch"),
1698 })
1699 .collect::<Result<Vec<_>>>()?
1700 .concat();
1701 Self::U32(storages)
1702 }
1703 Self::I16(_) => {
1704 let storages = storages
1705 .iter()
1706 .map(|s| match s {
1707 Self::I16(s) => Ok(s.as_slice()),
1708 _ => crate::bail!("dtype mismatch"),
1709 })
1710 .collect::<Result<Vec<_>>>()?
1711 .concat();
1712 Self::I16(storages)
1713 }
1714 Self::I32(_) => {
1715 let storages = storages
1716 .iter()
1717 .map(|s| match s {
1718 Self::I32(s) => Ok(s.as_slice()),
1719 _ => crate::bail!("dtype mismatch"),
1720 })
1721 .collect::<Result<Vec<_>>>()?
1722 .concat();
1723 Self::I32(storages)
1724 }
1725 Self::I64(_) => {
1726 let storages = storages
1727 .iter()
1728 .map(|s| match s {
1729 Self::I64(s) => Ok(s.as_slice()),
1730 _ => crate::bail!("dtype mismatch"),
1731 })
1732 .collect::<Result<Vec<_>>>()?
1733 .concat();
1734 Self::I64(storages)
1735 }
1736 Self::BF16(_) => {
1737 let storages = storages
1738 .iter()
1739 .map(|s| match s {
1740 Self::BF16(s) => Ok(s.as_slice()),
1741 _ => crate::bail!("dtype mismatch"),
1742 })
1743 .collect::<Result<Vec<_>>>()?
1744 .concat();
1745 Self::BF16(storages)
1746 }
1747 Self::F16(_) => {
1748 let storages = storages
1749 .iter()
1750 .map(|s| match s {
1751 Self::F16(s) => Ok(s.as_slice()),
1752 _ => crate::bail!("dtype mismatch"),
1753 })
1754 .collect::<Result<Vec<_>>>()?
1755 .concat();
1756 Self::F16(storages)
1757 }
1758 Self::F32(_) => {
1759 let storages = storages
1760 .iter()
1761 .map(|s| match s {
1762 Self::F32(s) => Ok(s.as_slice()),
1763 _ => crate::bail!("dtype mismatch"),
1764 })
1765 .collect::<Result<Vec<_>>>()?
1766 .concat();
1767 Self::F32(storages)
1768 }
1769 Self::F64(_) => {
1770 let storages = storages
1771 .iter()
1772 .map(|s| match s {
1773 Self::F64(s) => Ok(s.as_slice()),
1774 _ => crate::bail!("dtype mismatch"),
1775 })
1776 .collect::<Result<Vec<_>>>()?
1777 .concat();
1778 Self::F64(storages)
1779 }
1780 Self::F8E4M3(_) => {
1781 let storages = storages
1782 .iter()
1783 .map(|s| match s {
1784 Self::F8E4M3(s) => Ok(s.as_slice()),
1785 _ => crate::bail!("dtype mismatch"),
1786 })
1787 .collect::<Result<Vec<_>>>()?
1788 .concat();
1789 Self::F8E4M3(storages)
1790 }
1791 Self::F6E2M3(_) => {
1792 let storages = storages
1793 .iter()
1794 .map(|s| match s {
1795 Self::F6E2M3(s) => Ok(s.as_slice()),
1796 _ => crate::bail!("dtype mismatch"),
1797 })
1798 .collect::<Result<Vec<_>>>()?
1799 .concat();
1800 Self::F6E2M3(storages)
1801 }
1802 Self::F6E3M2(_) => {
1803 let storages = storages
1804 .iter()
1805 .map(|s| match s {
1806 Self::F6E3M2(s) => Ok(s.as_slice()),
1807 _ => crate::bail!("dtype mismatch"),
1808 })
1809 .collect::<Result<Vec<_>>>()?
1810 .concat();
1811 Self::F6E3M2(storages)
1812 }
1813 Self::F4(_) => {
1814 let storages = storages
1815 .iter()
1816 .map(|s| match s {
1817 Self::F4(s) => Ok(s.as_slice()),
1818 _ => crate::bail!("dtype mismatch"),
1819 })
1820 .collect::<Result<Vec<_>>>()?
1821 .concat();
1822 Self::F4(storages)
1823 }
1824 Self::F8E8M0(_) => {
1825 let storages = storages
1826 .iter()
1827 .map(|s| match s {
1828 Self::F8E8M0(s) => Ok(s.as_slice()),
1829 _ => crate::bail!("dtype mismatch"),
1830 })
1831 .collect::<Result<Vec<_>>>()?
1832 .concat();
1833 Self::F8E8M0(storages)
1834 }
1835 };
1836 Ok(s)
1837 }
1838}
1839
1840impl BackendStorage for CpuStorage {
1841 type Device = CpuDevice;
1842
1843 fn dtype(&self) -> DType {
1844 match self {
1845 Self::U8(_) => DType::U8,
1846 Self::U32(_) => DType::U32,
1847 Self::I16(_) => DType::I16,
1848 Self::I32(_) => DType::I32,
1849 Self::I64(_) => DType::I64,
1850 Self::BF16(_) => DType::BF16,
1851 Self::F16(_) => DType::F16,
1852 Self::F32(_) => DType::F32,
1853 Self::F64(_) => DType::F64,
1854 Self::F8E4M3(_) => DType::F8E4M3,
1855 Self::F6E2M3(_) => DType::F6E2M3,
1856 Self::F6E3M2(_) => DType::F6E3M2,
1857 Self::F4(_) => DType::F4,
1858 Self::F8E8M0(_) => DType::F8E8M0,
1859 }
1860 }
1861
1862 fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
1863 match (self, dtype) {
1865 (Self::U8(storage), DType::BF16) => {
1866 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1867 Ok(Self::BF16(data))
1868 }
1869 (Self::U32(storage), DType::BF16) => {
1870 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1871 Ok(Self::BF16(data))
1872 }
1873 (Self::I64(storage), DType::BF16) => {
1874 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1875 Ok(Self::BF16(data))
1876 }
1877 (Self::BF16(storage), DType::BF16) => {
1878 let data = unary_map(storage, layout, |v| v);
1879 Ok(Self::BF16(data))
1880 }
1881 (Self::F16(storage), DType::BF16) => {
1882 let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
1883 Ok(Self::BF16(data))
1884 }
1885 (Self::F32(storage), DType::BF16) => {
1886 let data = unary_map(storage, layout, bf16::from_f32);
1887 Ok(Self::BF16(data))
1888 }
1889 (Self::F64(storage), DType::BF16) => {
1890 let data = unary_map(storage, layout, bf16::from_f64);
1891 Ok(Self::BF16(data))
1892 }
1893 (Self::U8(storage), DType::F16) => {
1894 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1895 Ok(Self::F16(data))
1896 }
1897 (Self::U32(storage), DType::F16) => {
1898 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1899 Ok(Self::F16(data))
1900 }
1901 (Self::I64(storage), DType::F16) => {
1902 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1903 Ok(Self::F16(data))
1904 }
1905 (Self::BF16(storage), DType::F16) => {
1906 let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
1907 Ok(Self::F16(data))
1908 }
1909 (Self::F16(storage), DType::F16) => {
1910 let data = unary_map(storage, layout, |v| v);
1911 Ok(Self::F16(data))
1912 }
1913 (Self::F32(storage), DType::F16) => {
1914 let data = unary_map(storage, layout, f16::from_f32);
1915 Ok(Self::F16(data))
1916 }
1917 (Self::F64(storage), DType::F16) => {
1918 let data = unary_map(storage, layout, f16::from_f64);
1919 Ok(Self::F16(data))
1920 }
1921 (Self::U8(storage), DType::F32) => {
1922 let data = unary_map(storage, layout, |v| v as f32);
1923 Ok(Self::F32(data))
1924 }
1925 (Self::U32(storage), DType::F32) => {
1926 let data = unary_map(storage, layout, |v| v as f32);
1927 Ok(Self::F32(data))
1928 }
1929 (Self::I64(storage), DType::F32) => {
1930 let data = unary_map(storage, layout, |v| v as f32);
1931 Ok(Self::F32(data))
1932 }
1933 (Self::BF16(storage), DType::F32) => {
1934 let data = unary_map(storage, layout, |v| v.to_f32());
1935 Ok(Self::F32(data))
1936 }
1937 (Self::F16(storage), DType::F32) => {
1938 let data = unary_map(storage, layout, |v| v.to_f32());
1939 Ok(Self::F32(data))
1940 }
1941 (Self::F32(storage), DType::F32) => {
1942 let data = unary_map(storage, layout, |v| v);
1943 Ok(Self::F32(data))
1944 }
1945 (Self::F64(storage), DType::F32) => {
1946 let data = unary_map(storage, layout, |v| v as f32);
1947 Ok(Self::F32(data))
1948 }
1949 (Self::U8(storage), DType::U8) => {
1950 let data = unary_map(storage, layout, |v| v);
1951 Ok(Self::U8(data))
1952 }
1953 (Self::BF16(storage), DType::U8) => {
1954 let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1955 Ok(Self::U8(data))
1956 }
1957 (Self::F16(storage), DType::U8) => {
1958 let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1959 Ok(Self::U8(data))
1960 }
1961 (Self::F32(storage), DType::U8) => {
1962 let data = unary_map(storage, layout, |v| v as u8);
1963 Ok(Self::U8(data))
1964 }
1965 (Self::F64(storage), DType::U8) => {
1966 let data = unary_map(storage, layout, |v| v as u8);
1967 Ok(Self::U8(data))
1968 }
1969 (Self::U32(storage), DType::U8) => {
1970 let data = unary_map(storage, layout, |v| v as u8);
1971 Ok(Self::U8(data))
1972 }
1973 (Self::I64(storage), DType::U8) => {
1974 let data = unary_map(storage, layout, |v| v as u8);
1975 Ok(Self::U8(data))
1976 }
1977 (Self::U8(storage), DType::U32) => {
1978 let data = unary_map(storage, layout, |v| v as u32);
1979 Ok(Self::U32(data))
1980 }
1981 (Self::U32(storage), DType::U32) => {
1982 let data = unary_map(storage, layout, |v| v);
1983 Ok(Self::U32(data))
1984 }
1985 (Self::I64(storage), DType::U32) => {
1986 let data = unary_map(storage, layout, |v| v as u32);
1987 Ok(Self::U32(data))
1988 }
1989 (Self::BF16(storage), DType::U32) => {
1990 let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1991 Ok(Self::U32(data))
1992 }
1993 (Self::F16(storage), DType::U32) => {
1994 let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1995 Ok(Self::U32(data))
1996 }
1997 (Self::F32(storage), DType::U32) => {
1998 let data = unary_map(storage, layout, |v| v as u32);
1999 Ok(Self::U32(data))
2000 }
2001 (Self::F64(storage), DType::U32) => {
2002 let data = unary_map(storage, layout, |v| v as u32);
2003 Ok(Self::U32(data))
2004 }
2005 (Self::U8(storage), DType::I64) => {
2006 let data = unary_map(storage, layout, |v| v as i64);
2007 Ok(Self::I64(data))
2008 }
2009 (Self::U32(storage), DType::I64) => {
2010 let data = unary_map(storage, layout, |v| v as i64);
2011 Ok(Self::I64(data))
2012 }
2013 (Self::I64(storage), DType::I64) => {
2014 let data = unary_map(storage, layout, |v| v);
2015 Ok(Self::I64(data))
2016 }
2017 (Self::BF16(storage), DType::I64) => {
2018 let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2019 Ok(Self::I64(data))
2020 }
2021 (Self::F16(storage), DType::I64) => {
2022 let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2023 Ok(Self::I64(data))
2024 }
2025 (Self::F32(storage), DType::I64) => {
2026 let data = unary_map(storage, layout, |v| v as i64);
2027 Ok(Self::I64(data))
2028 }
2029 (Self::F64(storage), DType::I64) => {
2030 let data = unary_map(storage, layout, |v| v as i64);
2031 Ok(Self::I64(data))
2032 }
2033 (Self::U8(storage), DType::F64) => {
2034 let data = unary_map(storage, layout, |v| v as f64);
2035 Ok(Self::F64(data))
2036 }
2037 (Self::U32(storage), DType::F64) => {
2038 let data = unary_map(storage, layout, |v| v as f64);
2039 Ok(Self::F64(data))
2040 }
2041 (Self::I64(storage), DType::F64) => {
2042 let data = unary_map(storage, layout, |v| v as f64);
2043 Ok(Self::F64(data))
2044 }
2045 (Self::BF16(storage), DType::F64) => {
2046 let data = unary_map(storage, layout, |v| v.to_f64());
2047 Ok(Self::F64(data))
2048 }
2049 (Self::F16(storage), DType::F64) => {
2050 let data = unary_map(storage, layout, |v| v.to_f64());
2051 Ok(Self::F64(data))
2052 }
2053 (Self::F32(storage), DType::F64) => {
2054 let data = unary_map(storage, layout, |v| v as f64);
2055 Ok(Self::F64(data))
2056 }
2057 (Self::F64(storage), DType::F64) => {
2058 let data = unary_map(storage, layout, |v| v);
2059 Ok(Self::F64(data))
2060 }
2061 (Self::U8(storage), DType::F8E4M3) => {
2063 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2064 Ok(Self::F8E4M3(data))
2065 }
2066 (Self::U32(storage), DType::F8E4M3) => {
2067 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2068 Ok(Self::F8E4M3(data))
2069 }
2070 (Self::I64(storage), DType::F8E4M3) => {
2071 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2072 Ok(Self::F8E4M3(data))
2073 }
2074 (Self::BF16(storage), DType::F8E4M3) => {
2075 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
2076 Ok(Self::F8E4M3(data))
2077 }
2078 (Self::F16(storage), DType::F8E4M3) => {
2079 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
2080 Ok(Self::F8E4M3(data))
2081 }
2082 (Self::F32(storage), DType::F8E4M3) => {
2083 let data = unary_map(storage, layout, F8E4M3::from_f32);
2084 Ok(Self::F8E4M3(data))
2085 }
2086 (Self::F64(storage), DType::F8E4M3) => {
2087 let data = unary_map(storage, layout, F8E4M3::from_f64);
2088 Ok(Self::F8E4M3(data))
2089 }
2090 (Self::F8E4M3(storage), DType::F8E4M3) => {
2091 let data = unary_map(storage, layout, |v| v);
2092 Ok(Self::F8E4M3(data))
2093 }
2094 (Self::F8E4M3(storage), DType::U8) => {
2096 let data = unary_map(storage, layout, |v| v.to_f32() as u8);
2097 Ok(Self::U8(data))
2098 }
2099 (Self::F8E4M3(storage), DType::U32) => {
2100 let data = unary_map(storage, layout, |v| v.to_f32() as u32);
2101 Ok(Self::U32(data))
2102 }
2103 (Self::F8E4M3(storage), DType::I64) => {
2104 let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2105 Ok(Self::I64(data))
2106 }
2107 (Self::F8E4M3(storage), DType::BF16) => {
2108 let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
2109 Ok(Self::BF16(data))
2110 }
2111 (Self::F8E4M3(storage), DType::F16) => {
2112 let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
2113 Ok(Self::F16(data))
2114 }
2115 (Self::F8E4M3(storage), DType::F32) => {
2116 let data = unary_map(storage, layout, |v| v.to_f32());
2117 Ok(Self::F32(data))
2118 }
2119 (Self::F8E4M3(storage), DType::F64) => {
2120 let data = unary_map(storage, layout, |v| v.to_f64());
2121 Ok(Self::F64(data))
2122 }
2123 (Self::U8(storage), DType::I16) => {
2125 let data = unary_map(storage, layout, |v| v as i16);
2126 Ok(Self::I16(data))
2127 }
2128 (Self::U32(storage), DType::I16) => {
2129 let data = unary_map(storage, layout, |v| v as i16);
2130 Ok(Self::I16(data))
2131 }
2132 (Self::I16(storage), DType::I16) => {
2133 let data = unary_map(storage, layout, |v| v);
2134 Ok(Self::I16(data))
2135 }
2136 (Self::I32(storage), DType::I16) => {
2137 let data = unary_map(storage, layout, |v| v as i16);
2138 Ok(Self::I16(data))
2139 }
2140 (Self::I64(storage), DType::I16) => {
2141 let data = unary_map(storage, layout, |v| v as i16);
2142 Ok(Self::I16(data))
2143 }
2144 (Self::BF16(storage), DType::I16) => {
2145 let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2146 Ok(Self::I16(data))
2147 }
2148 (Self::F16(storage), DType::I16) => {
2149 let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2150 Ok(Self::I16(data))
2151 }
2152 (Self::F32(storage), DType::I16) => {
2153 let data = unary_map(storage, layout, |v| v as i16);
2154 Ok(Self::I16(data))
2155 }
2156 (Self::F64(storage), DType::I16) => {
2157 let data = unary_map(storage, layout, |v| v as i16);
2158 Ok(Self::I16(data))
2159 }
2160 (Self::F8E4M3(storage), DType::I16) => {
2161 let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2162 Ok(Self::I16(data))
2163 }
2164 (Self::U8(storage), DType::I32) => {
2166 let data = unary_map(storage, layout, |v| v as i32);
2167 Ok(Self::I32(data))
2168 }
2169 (Self::U32(storage), DType::I32) => {
2170 let data = unary_map(storage, layout, |v| v as i32);
2171 Ok(Self::I32(data))
2172 }
2173 (Self::I16(storage), DType::I32) => {
2174 let data = unary_map(storage, layout, |v| v as i32);
2175 Ok(Self::I32(data))
2176 }
2177 (Self::I32(storage), DType::I32) => {
2178 let data = unary_map(storage, layout, |v| v);
2179 Ok(Self::I32(data))
2180 }
2181 (Self::I64(storage), DType::I32) => {
2182 let data = unary_map(storage, layout, |v| v as i32);
2183 Ok(Self::I32(data))
2184 }
2185 (Self::BF16(storage), DType::I32) => {
2186 let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2187 Ok(Self::I32(data))
2188 }
2189 (Self::F16(storage), DType::I32) => {
2190 let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2191 Ok(Self::I32(data))
2192 }
2193 (Self::F32(storage), DType::I32) => {
2194 let data = unary_map(storage, layout, |v| v as i32);
2195 Ok(Self::I32(data))
2196 }
2197 (Self::F64(storage), DType::I32) => {
2198 let data = unary_map(storage, layout, |v| v as i32);
2199 Ok(Self::I32(data))
2200 }
2201 (Self::F8E4M3(storage), DType::I32) => {
2202 let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2203 Ok(Self::I32(data))
2204 }
2205 (Self::I16(storage), DType::U8) => {
2207 let data = unary_map(storage, layout, |v| v as u8);
2208 Ok(Self::U8(data))
2209 }
2210 (Self::I16(storage), DType::U32) => {
2211 let data = unary_map(storage, layout, |v| v as u32);
2212 Ok(Self::U32(data))
2213 }
2214 (Self::I16(storage), DType::I64) => {
2215 let data = unary_map(storage, layout, |v| v as i64);
2216 Ok(Self::I64(data))
2217 }
2218 (Self::I16(storage), DType::BF16) => {
2219 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2220 Ok(Self::BF16(data))
2221 }
2222 (Self::I16(storage), DType::F16) => {
2223 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2224 Ok(Self::F16(data))
2225 }
2226 (Self::I16(storage), DType::F32) => {
2227 let data = unary_map(storage, layout, |v| v as f32);
2228 Ok(Self::F32(data))
2229 }
2230 (Self::I16(storage), DType::F64) => {
2231 let data = unary_map(storage, layout, |v| v as f64);
2232 Ok(Self::F64(data))
2233 }
2234 (Self::I16(storage), DType::F8E4M3) => {
2235 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2236 Ok(Self::F8E4M3(data))
2237 }
2238 (Self::I32(storage), DType::U8) => {
2240 let data = unary_map(storage, layout, |v| v as u8);
2241 Ok(Self::U8(data))
2242 }
2243 (Self::I32(storage), DType::U32) => {
2244 let data = unary_map(storage, layout, |v| v as u32);
2245 Ok(Self::U32(data))
2246 }
2247 (Self::I32(storage), DType::I64) => {
2248 let data = unary_map(storage, layout, |v| v as i64);
2249 Ok(Self::I64(data))
2250 }
2251 (Self::I32(storage), DType::BF16) => {
2252 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2253 Ok(Self::BF16(data))
2254 }
2255 (Self::I32(storage), DType::F16) => {
2256 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2257 Ok(Self::F16(data))
2258 }
2259 (Self::I32(storage), DType::F32) => {
2260 let data = unary_map(storage, layout, |v| v as f32);
2261 Ok(Self::F32(data))
2262 }
2263 (Self::I32(storage), DType::F64) => {
2264 let data = unary_map(storage, layout, |v| v as f64);
2265 Ok(Self::F64(data))
2266 }
2267 (Self::I32(storage), DType::F8E4M3) => {
2268 let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2269 Ok(Self::F8E4M3(data))
2270 }
2271 (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => {
2273 Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt())
2274 }
2275 (Self::F6E2M3(_), _)
2276 | (Self::F6E3M2(_), _)
2277 | (Self::F4(_), _)
2278 | (Self::F8E8M0(_), _) => {
2279 Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt())
2280 }
2281 }
2282 }
2283
2284 fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
2285 match op {
2286 ReduceOp::Sum => {
2287 let src_dims = layout.dims();
2288 let mut dst_dims = src_dims.to_vec();
2289 for &dim in reduce_dims.iter() {
2290 dst_dims[dim] = 1;
2291 }
2292 let dst_shape = Shape::from(dst_dims);
2293 let mut reduce_dims = reduce_dims.to_vec();
2294 reduce_dims.sort();
2297 let reduce_dims_and_stride: Vec<_> = reduce_dims
2298 .iter()
2299 .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
2300 .collect();
2301 ReduceSum {
2302 dst_shape: &dst_shape,
2303 reduce_dims: &reduce_dims,
2304 reduce_dims_and_stride,
2305 }
2306 .map(self, layout)
2307 }
2308 ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
2309 let reduce_dim_index = match reduce_dims {
2310 [reduce_dim_index] => *reduce_dim_index,
2311 _ => {
2312 let op = match op {
2313 ReduceOp::Min => "min",
2314 ReduceOp::ArgMin => "argmin",
2315 ReduceOp::Max => "max",
2316 ReduceOp::ArgMax => "argmax",
2317 _ => unreachable!(),
2318 };
2319 let dims = reduce_dims.to_vec();
2320 Err(Error::OnlySingleDimension { op, dims })?
2321 }
2322 };
2323 let (use_min, return_index) = match op {
2324 ReduceOp::Min => (true, false),
2325 ReduceOp::ArgMin => (true, true),
2326 ReduceOp::Max => (false, false),
2327 ReduceOp::ArgMax => (false, true),
2328 _ => unreachable!(),
2329 };
2330 ReduceIndex {
2331 reduce_dim_index,
2332 use_min,
2333 return_index,
2334 }
2335 .map(self, layout)
2336 }
2337 }
2338 }
2339
2340 fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
2341 Cmp(op).map(self, lhs_l, rhs, rhs_l)
2342 }
2343
2344 fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
2345 Affine(mul, add).map(self, layout)
2346 }
2347
2348 fn avg_pool2d(
2349 &self,
2350 layout: &Layout,
2351 kernel_size: (usize, usize),
2352 stride: (usize, usize),
2353 ) -> Result<Self> {
2354 AvgPool2D(kernel_size, stride).map(self, layout)
2355 }
2356
2357 fn max_pool2d(
2358 &self,
2359 layout: &Layout,
2360 kernel_size: (usize, usize),
2361 stride: (usize, usize),
2362 ) -> Result<Self> {
2363 MaxPool2D(kernel_size, stride).map(self, layout)
2364 }
2365
2366 fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
2367 UpsampleNearest1D(sz).map(self, layout)
2368 }
2369
2370 fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
2371 UpsampleNearest2D(h, w).map(self, layout)
2372 }
2373
2374 fn upsample_bilinear2d(
2375 &self,
2376 layout: &Layout,
2377 h: usize,
2378 w: usize,
2379 align_corners: bool,
2380 scale_h: Option<f64>,
2381 scale_w: Option<f64>,
2382 ) -> Result<Self> {
2383 UpsampleBilinear2D {
2384 target_h: h,
2385 target_w: w,
2386 align_corners,
2387 scale_h_factor: scale_h,
2388 scale_w_factor: scale_w,
2389 }
2390 .map(self, layout)
2391 }
2392
2393 fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
2394 use num_traits::Float;
2395 match self {
2397 Self::BF16(storage) => {
2398 let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
2399 Ok(Self::BF16(data))
2400 }
2401 Self::F16(storage) => {
2402 let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
2403 Ok(Self::F16(data))
2404 }
2405 Self::F32(storage) => {
2406 let data = unary_map(storage, layout, |v| v.powf(e as f32));
2407 Ok(Self::F32(data))
2408 }
2409 Self::F64(storage) => {
2410 let data = unary_map(storage, layout, |v| v.powf(e));
2411 Ok(Self::F64(data))
2412 }
2413 Self::F8E4M3(storage) => {
2414 let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
2415 Ok(Self::F8E4M3(data))
2416 }
2417 Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
2418 Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
2419 Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()),
2420 Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()),
2421 Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
2422 Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()),
2423 Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()),
2424 Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()),
2425 Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()),
2426 }
2427 }
2428
2429 fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
2430 match self {
2432 Self::BF16(storage) => {
2433 let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
2434 Ok(Self::BF16(data))
2435 }
2436 Self::F16(storage) => {
2437 let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
2438 Ok(Self::F16(data))
2439 }
2440 Self::F32(storage) => {
2441 let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
2442 Ok(Self::F32(data))
2443 }
2444 Self::F64(storage) => {
2445 let data = unary_map(storage, layout, |v| elu(v, alpha));
2446 Ok(Self::F64(data))
2447 }
2448 Self::F8E4M3(storage) => {
2449 let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
2450 Ok(Self::F8E4M3(data))
2451 }
2452 Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2453 Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2454 Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
2455 Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
2456 Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2457 Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()),
2458 Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()),
2459 Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()),
2460 Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()),
2461 }
2462 }
2463
2464 fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
2465 match self {
2466 Self::BF16(storage) => {
2467 let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
2468 Ok(Self::BF16(data))
2469 }
2470 Self::F16(storage) => {
2471 let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
2472 Ok(Self::F16(data))
2473 }
2474 Self::F32(storage) => {
2475 let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
2476 Ok(Self::F32(data))
2477 }
2478 Self::F64(storage) => {
2479 let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
2480 Ok(Self::F64(data))
2481 }
2482 Self::U8(storage) => {
2483 let data = unary_map(storage, layout, B::u8);
2484 Ok(Self::U8(data))
2485 }
2486 Self::U32(storage) => {
2487 let data = unary_map(storage, layout, B::u32);
2488 Ok(Self::U32(data))
2489 }
2490 Self::I16(storage) => {
2491 let data = unary_map(storage, layout, B::i16);
2492 Ok(Self::I16(data))
2493 }
2494 Self::I32(storage) => {
2495 let data = unary_map(storage, layout, B::i32);
2496 Ok(Self::I32(data))
2497 }
2498 Self::I64(storage) => {
2499 let data = unary_map(storage, layout, B::i64);
2500 Ok(Self::I64(data))
2501 }
2502 Self::F8E4M3(storage) => {
2503 let data = unary_map(storage, layout, B::f8e4m3);
2504 Ok(Self::F8E4M3(data))
2505 }
2506 Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()),
2507 Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "unary").bt()),
2508 Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()),
2509 Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()),
2510 }
2511 }
2512
2513 fn binary_impl<B: BinaryOpT>(
2514 &self,
2515 rhs: &Self,
2516 lhs_l: &Layout,
2517 rhs_l: &Layout,
2518 ) -> Result<Self> {
2519 match (self, rhs) {
2520 (Self::BF16(lhs), Self::BF16(rhs)) => {
2521 let data = binary_map_vec(
2522 lhs_l,
2523 rhs_l,
2524 lhs,
2525 rhs,
2526 B::bf16,
2527 B::bf16_vec,
2528 B::bf16_scalar_vec,
2529 );
2530 Ok(Self::BF16(data))
2531 }
2532 (Self::F16(lhs), Self::F16(rhs)) => {
2533 let data = binary_map_vec(
2534 lhs_l,
2535 rhs_l,
2536 lhs,
2537 rhs,
2538 B::f16,
2539 B::f16_vec,
2540 B::f16_scalar_vec,
2541 );
2542 Ok(Self::F16(data))
2543 }
2544 (Self::F32(lhs), Self::F32(rhs)) => {
2545 let data = binary_map_vec(
2546 lhs_l,
2547 rhs_l,
2548 lhs,
2549 rhs,
2550 B::f32,
2551 B::f32_vec,
2552 B::f32_scalar_vec,
2553 );
2554 Ok(Self::F32(data))
2555 }
2556 (Self::F64(lhs), Self::F64(rhs)) => {
2557 let data = binary_map_vec(
2558 lhs_l,
2559 rhs_l,
2560 lhs,
2561 rhs,
2562 B::f64,
2563 B::f64_vec,
2564 B::f64_scalar_vec,
2565 );
2566 Ok(Self::F64(data))
2567 }
2568 (Self::U32(lhs), Self::U32(rhs)) => {
2569 let data = binary_map_vec(
2570 lhs_l,
2571 rhs_l,
2572 lhs,
2573 rhs,
2574 B::u32,
2575 B::u32_vec,
2576 B::u32_scalar_vec,
2577 );
2578 Ok(Self::U32(data))
2579 }
2580 (Self::I16(lhs), Self::I16(rhs)) => {
2581 let data = binary_map_vec(
2582 lhs_l,
2583 rhs_l,
2584 lhs,
2585 rhs,
2586 B::i16,
2587 B::i16_vec,
2588 B::i16_scalar_vec,
2589 );
2590 Ok(Self::I16(data))
2591 }
2592 (Self::I32(lhs), Self::I32(rhs)) => {
2593 let data = binary_map_vec(
2594 lhs_l,
2595 rhs_l,
2596 lhs,
2597 rhs,
2598 B::i32,
2599 B::i32_vec,
2600 B::i32_scalar_vec,
2601 );
2602 Ok(Self::I32(data))
2603 }
2604 (Self::I64(lhs), Self::I64(rhs)) => {
2605 let data = binary_map_vec(
2606 lhs_l,
2607 rhs_l,
2608 lhs,
2609 rhs,
2610 B::i64,
2611 B::i64_vec,
2612 B::i64_scalar_vec,
2613 );
2614 Ok(Self::I64(data))
2615 }
2616 (Self::U8(lhs), Self::U8(rhs)) => {
2617 let data =
2618 binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec, B::u8_scalar_vec);
2619 Ok(Self::U8(data))
2620 }
2621 (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => {
2622 let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3);
2623 Ok(Self::F8E4M3(data))
2624 }
2625 _ => {
2626 Err(Error::DTypeMismatchBinaryOp {
2628 lhs: self.dtype(),
2629 rhs: rhs.dtype(),
2630 op: B::NAME,
2631 }
2632 .bt())
2633 }
2634 }
2635 }
2636
2637 fn copy2d(
2638 &self,
2639 dst: &mut Self,
2640 d1: usize,
2641 d2: usize,
2642 src_s: usize,
2643 dst_s: usize,
2644 src_o: usize,
2645 dst_o: usize,
2646 ) -> Result<()> {
2647 match (self, dst) {
2648 (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
2649 (Self::U32(src), Self::U32(dst)) => {
2650 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2651 }
2652 (Self::I16(src), Self::I16(dst)) => {
2653 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2654 }
2655 (Self::I32(src), Self::I32(dst)) => {
2656 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2657 }
2658 (Self::I64(src), Self::I64(dst)) => {
2659 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2660 }
2661 (Self::BF16(src), Self::BF16(dst)) => {
2662 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2663 }
2664 (Self::F16(src), Self::F16(dst)) => {
2665 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2666 }
2667 (Self::F32(src), Self::F32(dst)) => {
2668 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2669 }
2670 (Self::F64(src), Self::F64(dst)) => {
2671 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2672 }
2673 (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
2674 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2675 }
2676 (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
2677 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2678 }
2679 (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
2680 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2681 }
2682 (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
2683 (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
2684 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2685 }
2686 (_, dst) => {
2687 return Err(Error::DTypeMismatchBinaryOp {
2688 lhs: self.dtype(),
2689 rhs: dst.dtype(),
2690 op: "copy2d",
2691 }
2692 .bt());
2693 }
2694 }
2695 Ok(())
2696 }
2697
2698 fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
2699 match (self, dst) {
2700 (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2701 (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2702 (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2703 (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2704 (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2705 (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2706 (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2707 (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2708 (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2709 (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
2710 copy_strided_src_(src, dst, dst_offset, src_l)
2711 }
2712 (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
2713 copy_strided_src_(src, dst, dst_offset, src_l)
2714 }
2715 (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
2716 copy_strided_src_(src, dst, dst_offset, src_l)
2717 }
2718 (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2719 (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
2720 copy_strided_src_(src, dst, dst_offset, src_l)
2721 }
2722 (_, dst) => {
2723 return Err(Error::DTypeMismatchBinaryOp {
2725 lhs: self.dtype(),
2726 rhs: dst.dtype(),
2727 op: "copy_strided",
2728 }
2729 .bt());
2730 }
2731 }
2732 Ok(())
2733 }
2734
2735 fn where_cond(
2736 &self,
2737 layout: &Layout,
2738 t: &Self,
2739 t_l: &Layout,
2740 f: &Self,
2741 f_l: &Layout,
2742 ) -> Result<Self> {
2743 match self {
2744 Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2745 Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2746 Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2747 Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2748 Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2749 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
2750 }
2751 }
2752
2753 fn conv1d(
2754 &self,
2755 l: &Layout,
2756 kernel: &Self,
2757 kernel_l: &Layout,
2758 params: &crate::conv::ParamsConv1D,
2759 ) -> Result<Self> {
2760 if !USE_IM2COL_CONV1D {
2761 return Conv1D(params).map(self, l, kernel, kernel_l);
2762 }
2763 let op = Im2Col1D {
2764 l_k: params.k_size,
2765 padding: params.padding,
2766 stride: params.stride,
2767 dilation: params.dilation,
2768 };
2769 let col = op.map(self, l)?;
2770 let b = params.b_size;
2771 let n = params.c_out;
2772 let l_out = params.l_out();
2773 let k = op.l_k * params.c_in;
2774 let m = l_out;
2775 let col_l = Layout::contiguous((b, m, k));
2776 let res = if kernel_l.is_contiguous() {
2777 let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2778 .transpose(1, 2)?
2779 .broadcast_as((b, k, n))?;
2780 col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2781 } else {
2782 let mut kernel_c = unsafe {
2784 self.device()
2785 .alloc_uninit(kernel_l.shape(), kernel.dtype())?
2786 };
2787 kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2788 let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2789 .transpose(1, 2)?
2790 .broadcast_as((b, k, n))?;
2791 col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2792 };
2793 let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
2794 let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
2795 res.copy_strided_src(&mut res_t, 0, &res_l)?;
2796 Ok(res_t)
2797 }
2798
2799 fn conv_transpose1d(
2800 &self,
2801 l: &Layout,
2802 kernel: &Self,
2803 kernel_l: &Layout,
2804 params: &crate::conv::ParamsConvTranspose1D,
2805 ) -> Result<Self> {
2806 let can_use_col2im = kernel_l.is_contiguous()
2807 && params.dilation == 1
2808 && params.padding == 0
2809 && params.output_padding == 0;
2810 if USE_COL2IM_CONV1D_TR && can_use_col2im {
2811 let (b_size, c_in, l_in) = l.shape().dims3()?;
2812 let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
2813 if !kernel_l.is_contiguous() {
2814 crate::bail!(
2815 "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
2816 )
2817 }
2818 if c_in != c_in2 {
2819 crate::bail!(
2820 "convtr1d: shape mismatch on c_in {:?} {:?}",
2821 l.shape(),
2822 kernel_l.shape()
2823 )
2824 }
2825 let col = {
2826 let kernel_l_mm = Layout::new(
2828 (b_size, c_in, k_size * c_out).into(),
2829 vec![0, k_size * c_out, 1],
2830 kernel_l.start_offset(),
2831 );
2832 self.matmul(
2833 kernel,
2834 (
2835 b_size,
2836 l_in,
2837 c_out * k_size,
2838 c_in,
2839 ),
2840 &l.transpose(1, 2)?,
2841 &kernel_l_mm,
2842 )?
2843 };
2844 let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
2845 Col2Im1D {
2846 stride: params.stride,
2847 }
2848 .map(&col, &col_l)
2849 } else {
2850 ConvTranspose1D(params).map(self, l, kernel, kernel_l)
2851 }
2852 }
2853
2854 fn conv2d(
2855 &self,
2856 l: &Layout,
2857 kernel: &Self,
2858 kernel_l: &Layout,
2859 params: &crate::conv::ParamsConv2D,
2860 ) -> Result<Self> {
2861 Conv2D(params).map(self, l, kernel, kernel_l)
2862 }
2863
2864 fn conv_transpose2d(
2865 &self,
2866 l: &Layout,
2867 kernel: &Self,
2868 kernel_l: &Layout,
2869 params: &crate::conv::ParamsConvTranspose2D,
2870 ) -> Result<Self> {
2871 ConvTranspose2D(params).map(self, l, kernel, kernel_l)
2872 }
2873
2874 fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
2875 match ids {
2876 Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2877 Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2878 Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2879 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
2880 }
2881 }
2882
2883 fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
2884 match ids {
2885 Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
2886 Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
2887 Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
2888 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
2889 }
2890 }
2891
2892 fn scatter_set(
2893 &mut self,
2894 l: &Layout,
2895 ids: &Self,
2896 ids_l: &Layout,
2897 src: &Self,
2898 src_l: &Layout,
2899 dim: usize,
2900 ) -> Result<()> {
2901 match ids {
2902 Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2903 Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2904 Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2905 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
2906 }
2907 }
2908
2909 fn scatter_add_set(
2910 &mut self,
2911 l: &Layout,
2912 ids: &Self,
2913 ids_l: &Layout,
2914 src: &Self,
2915 src_l: &Layout,
2916 dim: usize,
2917 ) -> Result<()> {
2918 match ids {
2919 Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2920 Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2921 Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2922 Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2923 Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2924 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
2925 }
2926 }
2927
2928 fn index_add(
2929 &self,
2930 l: &Layout,
2931 ids: &Self,
2932 ids_l: &Layout,
2933 src: &Self,
2934 src_l: &Layout,
2935 dim: usize,
2936 ) -> Result<Self> {
2937 match ids {
2938 Self::U8(ids) => {
2939 let ids = match ids_l.contiguous_offsets() {
2940 Some((a, b)) => &ids[a..b],
2941 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2942 };
2943 IndexAdd { ids, dim }.map(self, l, src, src_l)
2944 }
2945 Self::U32(ids) => {
2946 let ids = match ids_l.contiguous_offsets() {
2947 Some((a, b)) => &ids[a..b],
2948 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2949 };
2950 IndexAdd { ids, dim }.map(self, l, src, src_l)
2951 }
2952 Self::I16(ids) => {
2953 let ids = match ids_l.contiguous_offsets() {
2954 Some((a, b)) => &ids[a..b],
2955 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2956 };
2957 IndexAdd { ids, dim }.map(self, l, src, src_l)
2958 }
2959 Self::I32(ids) => {
2960 let ids = match ids_l.contiguous_offsets() {
2961 Some((a, b)) => &ids[a..b],
2962 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2963 };
2964 IndexAdd { ids, dim }.map(self, l, src, src_l)
2965 }
2966 Self::I64(ids) => {
2967 let ids = match ids_l.contiguous_offsets() {
2968 Some((a, b)) => &ids[a..b],
2969 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2970 };
2971 IndexAdd { ids, dim }.map(self, l, src, src_l)
2972 }
2973 _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
2974 }
2975 }
2976
2977 fn matmul(
2978 &self,
2979 rhs: &Self,
2980 bmnk: (usize, usize, usize, usize),
2981 lhs_l: &Layout,
2982 rhs_l: &Layout,
2983 ) -> Result<Self> {
2984 MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
2985 }
2986
2987 fn device(&self) -> &Self::Device {
2988 &CpuDevice
2989 }
2990
2991 fn try_clone(&self, _: &Layout) -> Result<Self> {
2992 Ok(self.clone())
2993 }
2994
2995 fn to_cpu_storage(&self) -> Result<CpuStorage> {
2996 Ok(self.clone())
2997 }
2998
2999 fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
3000 use crate::scalar::Scalar;
3001 fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
3002 match l.strided_blocks() {
3003 crate::StridedBlocks::SingleBlock { start_offset, len } => {
3004 src[start_offset..start_offset + len].fill(s)
3005 }
3006 crate::StridedBlocks::UniformBlocks {
3007 start_offset,
3008 block_len,
3009 count,
3010 src_stride,
3011 } => {
3012 for i in 0..count {
3013 let start = start_offset + i * src_stride;
3014 src[start..start + block_len].fill(s)
3015 }
3016 }
3017 crate::StridedBlocks::MultipleBlocks {
3018 block_start_index,
3019 block_len: 1,
3020 } => {
3021 for src_index in block_start_index {
3022 src[src_index] = s
3023 }
3024 }
3025 crate::StridedBlocks::MultipleBlocks {
3026 block_start_index,
3027 block_len,
3028 } => {
3029 for src_index in block_start_index {
3030 src[src_index..src_index + block_len].fill(s)
3031 }
3032 }
3033 }
3034 }
3035 match (self, s) {
3036 (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
3037 (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
3038 (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
3039 (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
3040 (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
3041 (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
3042 (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v),
3043 (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v),
3044 (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
3045 (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
3046 (Self::F6E2M3(_), _) => {
3048 crate::bail!("const_set not supported for dummy type F6E2M3")
3049 }
3050 (Self::F6E3M2(_), _) => {
3051 crate::bail!("const_set not supported for dummy type F6E3M2")
3052 }
3053 (Self::F4(_), _) => {
3054 crate::bail!("const_set not supported for dummy type F4")
3055 }
3056 (Self::F8E8M0(_), _) => {
3057 crate::bail!("const_set not supported for dummy type F8E8M0")
3058 }
3059 (st, s) => crate::bail!(
3060 "const_set dtype mismatch, expected {:?} but got {:?}",
3061 st.dtype(),
3062 s
3063 ),
3064 }
3065 Ok(())
3066 }
3067}
3068
3069impl BackendDevice for CpuDevice {
3070 type Storage = CpuStorage;
3071
3072 fn location(&self) -> crate::DeviceLocation {
3073 crate::DeviceLocation::Cpu
3074 }
3075
3076 fn same_device(&self, _: &Self) -> bool {
3077 true
3078 }
3079
3080 fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
3081 Ok(T::to_cpu_storage(s))
3082 }
3083
3084 fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
3085 Ok(s.clone())
3086 }
3087
3088 fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
3089 Ok(s)
3090 }
3091
3092 fn new(_: usize) -> Result<Self> {
3093 Ok(Self)
3094 }
3095
3096 fn set_seed(&self, _seed: u64) -> Result<()> {
3097 crate::bail!("cannot seed the CPU rng with set_seed")
3098 }
3099
3100 fn get_current_seed(&self) -> Result<u64> {
3101 crate::bail!("cannot get the CPU rng seed with get_current_seed")
3102 }
3103
3104 fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
3105 use rand::prelude::*;
3106
3107 let elem_count = shape.elem_count();
3108 let mut rng = rand::rng();
3109 match dtype {
3110 DType::U8
3111 | DType::U32
3112 | DType::I16
3113 | DType::I32
3114 | DType::I64
3115 | DType::F6E2M3
3116 | DType::F6E3M2
3117 | DType::F4
3118 | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
3119 DType::BF16 => {
3120 let mut data = Vec::with_capacity(elem_count);
3121 let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
3122 .map_err(Error::wrap)?;
3123 for _i in 0..elem_count {
3124 data.push(rng.sample::<bf16, _>(uniform))
3125 }
3126 Ok(CpuStorage::BF16(data))
3127 }
3128 DType::F16 => {
3129 let mut data = Vec::with_capacity(elem_count);
3130 let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
3131 .map_err(Error::wrap)?;
3132 for _i in 0..elem_count {
3133 data.push(rng.sample::<f16, _>(uniform))
3134 }
3135 Ok(CpuStorage::F16(data))
3136 }
3137 DType::F8E4M3 => {
3138 let mut data = Vec::with_capacity(elem_count);
3139 let uniform =
3140 rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
3141 .map_err(Error::wrap)?;
3142 for _i in 0..elem_count {
3143 data.push(rng.sample::<F8E4M3, _>(uniform))
3144 }
3145 Ok(CpuStorage::F8E4M3(data))
3146 }
3147 DType::F32 => {
3148 let mut data = Vec::with_capacity(elem_count);
3149 let uniform =
3150 rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
3151 for _i in 0..elem_count {
3152 data.push(rng.sample::<f32, _>(uniform))
3153 }
3154 Ok(CpuStorage::F32(data))
3155 }
3156 DType::F64 => {
3157 let mut data = Vec::with_capacity(elem_count);
3158 let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
3159 for _i in 0..elem_count {
3160 data.push(rng.sample::<f64, _>(uniform))
3161 }
3162 Ok(CpuStorage::F64(data))
3163 }
3164 }
3165 }
3166
3167 fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
3168 use rand::prelude::*;
3169
3170 let elem_count = shape.elem_count();
3171 let mut rng = rand::rng();
3172 match dtype {
3173 DType::U8
3174 | DType::U32
3175 | DType::I16
3176 | DType::I32
3177 | DType::I64
3178 | DType::F6E2M3
3179 | DType::F6E3M2
3180 | DType::F4
3181 | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
3182 DType::BF16 => {
3183 let mut data = Vec::with_capacity(elem_count);
3184 let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
3185 .map_err(Error::wrap)?;
3186 for _i in 0..elem_count {
3187 data.push(normal.sample(&mut rng))
3188 }
3189 Ok(CpuStorage::BF16(data))
3190 }
3191 DType::F16 => {
3192 let mut data = Vec::with_capacity(elem_count);
3193 let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
3194 .map_err(Error::wrap)?;
3195 for _i in 0..elem_count {
3196 data.push(normal.sample(&mut rng))
3197 }
3198 Ok(CpuStorage::F16(data))
3199 }
3200 DType::F8E4M3 => {
3201 let mut data = Vec::with_capacity(elem_count);
3202 let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
3203 .map_err(Error::wrap)?;
3204 for _i in 0..elem_count {
3205 data.push(normal.sample(&mut rng))
3206 }
3207 Ok(CpuStorage::F8E4M3(data))
3208 }
3209 DType::F32 => {
3210 let mut data = Vec::with_capacity(elem_count);
3211 let normal =
3212 rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
3213 for _i in 0..elem_count {
3214 data.push(normal.sample(&mut rng))
3215 }
3216 Ok(CpuStorage::F32(data))
3217 }
3218 DType::F64 => {
3219 let mut data = Vec::with_capacity(elem_count);
3220 let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
3221 for _i in 0..elem_count {
3222 data.push(normal.sample(&mut rng))
3223 }
3224 Ok(CpuStorage::F64(data))
3225 }
3226 }
3227 }
3228
3229 #[allow(clippy::uninit_vec)]
3230 unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
3231 let elem_count = shape.elem_count();
3232 let storage = match dtype {
3237 DType::U8 => {
3238 let mut v = Vec::with_capacity(elem_count);
3239 v.set_len(elem_count);
3240 CpuStorage::U8(v)
3241 }
3242 DType::U32 => {
3243 let mut v = Vec::with_capacity(elem_count);
3244 v.set_len(elem_count);
3245 CpuStorage::U32(v)
3246 }
3247 DType::I16 => {
3248 let mut v = Vec::with_capacity(elem_count);
3249 v.set_len(elem_count);
3250 CpuStorage::I16(v)
3251 }
3252 DType::I32 => {
3253 let mut v = Vec::with_capacity(elem_count);
3254 v.set_len(elem_count);
3255 CpuStorage::I32(v)
3256 }
3257 DType::I64 => {
3258 let mut v = Vec::with_capacity(elem_count);
3259 v.set_len(elem_count);
3260 CpuStorage::I64(v)
3261 }
3262 DType::BF16 => {
3263 let mut v = Vec::with_capacity(elem_count);
3264 v.set_len(elem_count);
3265 CpuStorage::BF16(v)
3266 }
3267 DType::F16 => {
3268 let mut v = Vec::with_capacity(elem_count);
3269 v.set_len(elem_count);
3270 CpuStorage::F16(v)
3271 }
3272 DType::F32 => {
3273 let mut v = Vec::with_capacity(elem_count);
3274 v.set_len(elem_count);
3275 CpuStorage::F32(v)
3276 }
3277 DType::F64 => {
3278 let mut v = Vec::with_capacity(elem_count);
3279 v.set_len(elem_count);
3280 CpuStorage::F64(v)
3281 }
3282 DType::F8E4M3 => {
3283 let mut v = Vec::with_capacity(elem_count);
3284 v.set_len(elem_count);
3285 CpuStorage::F8E4M3(v)
3286 }
3287 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
3288 return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt())
3289 }
3290 };
3291 Ok(storage)
3292 }
3293
3294 fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
3295 let elem_count = shape.elem_count();
3296 let storage = match dtype {
3297 DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
3298 DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
3299 DType::I16 => CpuStorage::I16(vec![0i16; elem_count]),
3300 DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
3301 DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
3302 DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
3303 DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
3304 DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
3305 DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
3306 DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
3307 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
3308 return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt())
3309 }
3310 };
3311 Ok(storage)
3312 }
3313
3314 fn synchronize(&self) -> Result<()> {
3315 Ok(())
3316 }
3317}
3318
/// Applies `$map_fn` to the payload of `$storage` when it is one of the
/// listed `CpuStorage` variants, and wraps the result back in that variant.
///
/// For any other variant the macro leaves the enclosing function through `?`
/// with an `Error::UnsupportedDTypeForOp` built from the storage dtype and
/// `$op_name`, so it can only be used where `?` on that error is valid.
/// `CpuStorage` and `Error` are resolved at the call site.
#[macro_export]
macro_rules! map_dtype {
    ($op_name:expr, $storage:ident, $map_fn:expr, ($($variant:ident),+)) => {
        match $storage {
            $(
                CpuStorage::$variant(__inner) => CpuStorage::$variant($map_fn(__inner)),
            )+
            __unsupported => {
                Err(Error::UnsupportedDTypeForOp(__unsupported.dtype(), $op_name).bt())?
            }
        }
    };
}