1use crate::backend::BackendStorage;
3use crate::{Error, Layout, Result, WithDType};
4
5type C = super::CpuStorage;
6pub trait Map1 {
7 fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
8
9 fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
10 match vs {
11 C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
12 C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
13 C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)),
14 C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
15 C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
16 C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
17 C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
18 C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
19 C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
20 C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
21 C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
23 C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
24 C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
25 C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
26 }
27 }
28}
29
30pub trait Map1Any {
31 fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
32
33 fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
34 match vs {
35 C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
36 C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
37 C::I16(vs) => Ok(self.f(vs, layout, C::I16)?),
38 C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
39 C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
40 C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
41 C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
42 C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
43 C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
44 C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
45 C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
47 C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
48 C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
49 C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
50 }
51 }
52}
53
54pub trait Map2 {
55 const OP: &'static str;
56 fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
57
58 fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
59 match (v1, v2) {
60 (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
61 (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
62 (C::I16(v1), C::I16(v2)) => Ok(C::I16(self.f(v1, l1, v2, l2)?)),
63 (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2)?)),
64 (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
65 (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
66 (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
67 (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
68 (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
69 (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
70 _ => Err(Error::DTypeMismatchBinaryOp {
71 lhs: v1.dtype(),
72 rhs: v2.dtype(),
73 op: Self::OP,
74 }
75 .bt()),
76 }
77 }
78}
79
80pub trait Map2InPlace {
81 const OP: &'static str;
82 fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
83
84 fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
85 match (v1, v2) {
86 (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
87 (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
88 (C::I16(v1), C::I16(v2)) => self.f(v1, l1, v2, l2)?,
89 (C::I32(v1), C::I32(v2)) => self.f(v1, l1, v2, l2)?,
90 (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
91 (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
92 (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
93 (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
94 (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
95 (C::F8E4M3(v1), C::F8E4M3(v2)) => self.f(v1, l1, v2, l2)?,
96 (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
97 lhs: v1.dtype(),
98 rhs: v2.dtype(),
99 op: Self::OP,
100 }
101 .bt())?,
102 };
103 Ok(())
104 }
105}
106
107pub trait Map2U8 {
108 const OP: &'static str;
109 fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
110
111 fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
112 match (v1, v2) {
113 (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
114 (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
115 (C::I16(v1), C::I16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
116 (C::I32(v1), C::I32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
117 (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
118 (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
119 (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
120 (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
121 (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
122 (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
123 _ => Err(Error::DTypeMismatchBinaryOp {
124 lhs: v1.dtype(),
125 rhs: v2.dtype(),
126 op: Self::OP,
127 }
128 .bt()),
129 }
130 }
131}
132
133pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
134 lhs_l: &Layout,
135 rhs_l: &Layout,
136 lhs: &[T],
137 rhs: &[T],
138 mut f: F,
139) -> Vec<U> {
140 match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
141 (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
142 .iter()
143 .zip(rhs[o_r1..o_r2].iter())
144 .map(|(&l, &r)| f(l, r))
145 .collect(),
146 (Some((o_l1, o_l2)), None) => {
147 match rhs_l.offsets_b() {
149 Some(ob) => {
150 let mut i_in_block = 0;
151 let mut i_right_broadcast = 0;
152 lhs[o_l1..o_l2]
153 .iter()
154 .map(|&l| {
155 let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
156 i_right_broadcast += 1;
157 if i_right_broadcast >= ob.right_broadcast {
158 i_in_block += 1;
159 i_right_broadcast = 0;
160 }
161 if i_in_block >= ob.len {
162 i_in_block = 0
163 }
164 f(l, *r)
165 })
166 .collect()
167 }
168 None => lhs_l
169 .strided_index()
170 .zip(rhs_l.strided_index())
171 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
172 .collect(),
173 }
174 }
175 (None, Some((o_r1, o_r2))) => {
176 match lhs_l.offsets_b() {
178 Some(ob) => {
179 let mut i_in_block = 0;
180 let mut i_right_broadcast = 0;
181 rhs[o_r1..o_r2]
182 .iter()
183 .map(|&r| {
184 let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
185 i_right_broadcast += 1;
186 if i_right_broadcast >= ob.right_broadcast {
187 i_in_block += 1;
188 i_right_broadcast = 0;
189 }
190 if i_in_block >= ob.len {
191 i_in_block = 0
192 }
193 f(*l, r)
194 })
195 .collect()
196 }
197 None => lhs_l
198 .strided_index()
199 .zip(rhs_l.strided_index())
200 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
201 .collect(),
202 }
203 }
204 _ => lhs_l
205 .strided_index()
206 .zip(rhs_l.strided_index())
207 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
208 .collect(),
209 }
210}
211
212pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
214 lhs_l: &Layout,
215 rhs_l: &Layout,
216 lhs: &[T],
217 rhs: &[T],
218 mut f: F,
219 mut f_vec: FV,
220) -> Vec<T> {
221 let el_count = lhs_l.shape().elem_count();
222 match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
223 (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
224 let mut ys: Vec<T> = Vec::with_capacity(el_count);
225 let ys_to_set = ys.spare_capacity_mut();
226 let ys_to_set = unsafe {
227 std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
228 };
229 f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
230 unsafe { ys.set_len(el_count) };
232 ys
233 }
234 (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
235 Some(ob) if ob.right_broadcast == 1 => {
236 let rhs = &rhs[ob.start..ob.start + ob.len];
237 let mut ys: Vec<T> = Vec::with_capacity(el_count);
238 let ys_to_set = ys.spare_capacity_mut();
239 let ys_to_set = unsafe {
240 std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
241 };
242 let mut dst_i = 0;
243 for src_i in (o_l1..o_l2).step_by(ob.len) {
244 f_vec(
245 &lhs[src_i..src_i + ob.len],
246 rhs,
247 &mut ys_to_set[dst_i..dst_i + ob.len],
248 );
249 dst_i += ob.len;
250 }
251 unsafe { ys.set_len(el_count) };
253 ys
254 }
255 Some(ob) => {
256 let rhs = &rhs[ob.start..ob.start + ob.len];
257 let mut ys = lhs[o_l1..o_l2].to_vec();
258 for idx_l in 0..ob.left_broadcast {
259 let start = idx_l * ob.len * ob.right_broadcast;
260 for (i, &r) in rhs.iter().enumerate() {
261 let start = start + i * ob.right_broadcast;
262 for v in ys[start..start + ob.right_broadcast].iter_mut() {
263 *v = f(*v, r)
264 }
265 }
266 }
267 ys
268 }
269 None => lhs_l
270 .strided_index()
271 .zip(rhs_l.strided_index())
272 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
273 .collect(),
274 },
275 (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
276 Some(ob) if ob.right_broadcast == 1 => {
277 let lhs = &lhs[ob.start..ob.start + ob.len];
278 let mut ys: Vec<T> = Vec::with_capacity(el_count);
279 let ys_to_set = ys.spare_capacity_mut();
280 let ys_to_set = unsafe {
281 std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
282 };
283 let mut dst_i = 0;
284 for src_i in (o_r1..o_r2).step_by(ob.len) {
285 f_vec(
286 lhs,
287 &rhs[src_i..src_i + ob.len],
288 &mut ys_to_set[dst_i..dst_i + ob.len],
289 );
290 dst_i += ob.len;
291 }
292 unsafe { ys.set_len(el_count) };
294 ys
295 }
296 Some(ob) => {
297 let lhs = &lhs[ob.start..ob.start + ob.len];
298 let mut ys = rhs[o_r1..o_r2].to_vec();
299 for idx_l in 0..ob.left_broadcast {
300 let start = idx_l * ob.len * ob.right_broadcast;
301 for (i, &l) in lhs.iter().enumerate() {
302 let start = start + i * ob.right_broadcast;
303 for v in ys[start..start + ob.right_broadcast].iter_mut() {
304 *v = f(l, *v)
305 }
306 }
307 }
308 ys
309 }
310 None => lhs_l
311 .strided_index()
312 .zip(rhs_l.strided_index())
313 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
314 .collect(),
315 },
316 _ => lhs_l
317 .strided_index()
318 .zip(rhs_l.strided_index())
319 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
320 .collect(),
321 }
322}
323
324pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
325 vs: &[T],
326 layout: &Layout,
327 mut f: F,
328) -> Vec<U> {
329 match layout.strided_blocks() {
330 crate::StridedBlocks::SingleBlock { start_offset, len } => vs
331 [start_offset..start_offset + len]
332 .iter()
333 .map(|&v| f(v))
334 .collect(),
335 crate::StridedBlocks::MultipleBlocks {
336 block_start_index,
337 block_len,
338 } => {
339 let mut result = Vec::with_capacity(layout.shape().elem_count());
340 if block_len == 1 {
342 for index in block_start_index {
343 let v = unsafe { vs.get_unchecked(index) };
344 result.push(f(*v))
345 }
346 } else {
347 for index in block_start_index {
348 for offset in 0..block_len {
349 let v = unsafe { vs.get_unchecked(index + offset) };
350 result.push(f(*v))
351 }
352 }
353 }
354 result
355 }
356 }
357}
358
359pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
360 vs: &[T],
361 layout: &Layout,
362 mut f: F,
363 mut f_vec: FV,
364) -> Vec<U> {
365 match layout.strided_blocks() {
366 crate::StridedBlocks::SingleBlock { start_offset, len } => {
367 let mut ys: Vec<U> = Vec::with_capacity(len);
368 let ys_to_set = ys.spare_capacity_mut();
369 let ys_to_set = unsafe {
370 std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
371 };
372 f_vec(&vs[start_offset..start_offset + len], ys_to_set);
373 unsafe { ys.set_len(len) };
375 ys
376 }
377 crate::StridedBlocks::MultipleBlocks {
378 block_start_index,
379 block_len,
380 } => {
381 let el_count = layout.shape().elem_count();
382 if block_len == 1 {
384 let mut result = Vec::with_capacity(el_count);
385 for index in block_start_index {
386 let v = unsafe { vs.get_unchecked(index) };
387 result.push(f(*v))
388 }
389 result
390 } else {
391 let mut ys: Vec<U> = Vec::with_capacity(el_count);
392 let ys_to_set = ys.spare_capacity_mut();
393 let ys_to_set = unsafe {
394 std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
395 };
396 let mut dst_index = 0;
397 for src_index in block_start_index {
398 let vs = &vs[src_index..src_index + block_len];
399 let ys = &mut ys_to_set[dst_index..dst_index + block_len];
400 f_vec(vs, ys);
401 dst_index += block_len;
402 }
403 unsafe { ys.set_len(el_count) };
405 ys
406 }
407 }
408 }
409}