1use crate::backend::BackendStorage;
3use crate::{Error, Layout, Result, WithDType};
4
5type C = super::CpuStorage;
6pub trait Map1 {
7 fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
8
9 fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
10 match vs {
11 C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
12 C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
13 C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
14 C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
15 C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
16 C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
17 C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
18 }
19 }
20}
21
22pub trait Map1Any {
23 fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
24
25 fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
26 match vs {
27 C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
28 C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
29 C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
30 C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
31 C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
32 C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
33 C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
34 }
35 }
36}
37
38pub trait Map2 {
39 const OP: &'static str;
40 fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
41
42 fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
43 match (v1, v2) {
44 (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
45 (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
46 (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
47 (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
48 (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
49 (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
50 (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
51 _ => Err(Error::DTypeMismatchBinaryOp {
52 lhs: v1.dtype(),
53 rhs: v2.dtype(),
54 op: Self::OP,
55 }
56 .bt()),
57 }
58 }
59}
60
61pub trait Map2InPlace {
62 const OP: &'static str;
63 fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
64
65 fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
66 match (v1, v2) {
67 (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
68 (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
69 (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
70 (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
71 (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
72 (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
73 (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
74 (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
75 lhs: v1.dtype(),
76 rhs: v2.dtype(),
77 op: Self::OP,
78 }
79 .bt())?,
80 };
81 Ok(())
82 }
83}
84
85pub trait Map2U8 {
86 const OP: &'static str;
87 fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
88
89 fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
90 match (v1, v2) {
91 (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
92 (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
93 (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
94 (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
95 (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
96 (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
97 (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
98 _ => Err(Error::DTypeMismatchBinaryOp {
99 lhs: v1.dtype(),
100 rhs: v2.dtype(),
101 op: Self::OP,
102 }
103 .bt()),
104 }
105 }
106}
107
108pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
109 lhs_l: &Layout,
110 rhs_l: &Layout,
111 lhs: &[T],
112 rhs: &[T],
113 mut f: F,
114) -> Vec<U> {
115 match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
116 (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
117 .iter()
118 .zip(rhs[o_r1..o_r2].iter())
119 .map(|(&l, &r)| f(l, r))
120 .collect(),
121 (Some((o_l1, o_l2)), None) => {
122 match rhs_l.offsets_b() {
124 Some(ob) => {
125 let mut i_in_block = 0;
126 let mut i_right_broadcast = 0;
127 lhs[o_l1..o_l2]
128 .iter()
129 .map(|&l| {
130 let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
131 i_right_broadcast += 1;
132 if i_right_broadcast >= ob.right_broadcast {
133 i_in_block += 1;
134 i_right_broadcast = 0;
135 }
136 if i_in_block >= ob.len {
137 i_in_block = 0
138 }
139 f(l, *r)
140 })
141 .collect()
142 }
143 None => lhs_l
144 .strided_index()
145 .zip(rhs_l.strided_index())
146 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
147 .collect(),
148 }
149 }
150 (None, Some((o_r1, o_r2))) => {
151 match lhs_l.offsets_b() {
153 Some(ob) => {
154 let mut i_in_block = 0;
155 let mut i_right_broadcast = 0;
156 rhs[o_r1..o_r2]
157 .iter()
158 .map(|&r| {
159 let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
160 i_right_broadcast += 1;
161 if i_right_broadcast >= ob.right_broadcast {
162 i_in_block += 1;
163 i_right_broadcast = 0;
164 }
165 if i_in_block >= ob.len {
166 i_in_block = 0
167 }
168 f(*l, r)
169 })
170 .collect()
171 }
172 None => lhs_l
173 .strided_index()
174 .zip(rhs_l.strided_index())
175 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
176 .collect(),
177 }
178 }
179 _ => lhs_l
180 .strided_index()
181 .zip(rhs_l.strided_index())
182 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
183 .collect(),
184 }
185}
186
187pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
189 lhs_l: &Layout,
190 rhs_l: &Layout,
191 lhs: &[T],
192 rhs: &[T],
193 mut f: F,
194 mut f_vec: FV,
195) -> Vec<T> {
196 let el_count = lhs_l.shape().elem_count();
197 match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
198 (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
199 let mut ys: Vec<T> = Vec::with_capacity(el_count);
200 let ys_to_set = ys.spare_capacity_mut();
201 let ys_to_set = unsafe {
202 std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
203 };
204 f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
205 unsafe { ys.set_len(el_count) };
207 ys
208 }
209 (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
210 Some(ob) if ob.right_broadcast == 1 => {
211 let rhs = &rhs[ob.start..ob.start + ob.len];
212 let mut ys: Vec<T> = Vec::with_capacity(el_count);
213 let ys_to_set = ys.spare_capacity_mut();
214 let ys_to_set = unsafe {
215 std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
216 };
217 let mut dst_i = 0;
218 for src_i in (o_l1..o_l2).step_by(ob.len) {
219 f_vec(
220 &lhs[src_i..src_i + ob.len],
221 rhs,
222 &mut ys_to_set[dst_i..dst_i + ob.len],
223 );
224 dst_i += ob.len;
225 }
226 unsafe { ys.set_len(el_count) };
228 ys
229 }
230 Some(ob) => {
231 let rhs = &rhs[ob.start..ob.start + ob.len];
232 let mut ys = lhs[o_l1..o_l2].to_vec();
233 for idx_l in 0..ob.left_broadcast {
234 let start = idx_l * ob.len * ob.right_broadcast;
235 for (i, &r) in rhs.iter().enumerate() {
236 let start = start + i * ob.right_broadcast;
237 for v in ys[start..start + ob.right_broadcast].iter_mut() {
238 *v = f(*v, r)
239 }
240 }
241 }
242 ys
243 }
244 None => lhs_l
245 .strided_index()
246 .zip(rhs_l.strided_index())
247 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
248 .collect(),
249 },
250 (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
251 Some(ob) if ob.right_broadcast == 1 => {
252 let lhs = &lhs[ob.start..ob.start + ob.len];
253 let mut ys: Vec<T> = Vec::with_capacity(el_count);
254 let ys_to_set = ys.spare_capacity_mut();
255 let ys_to_set = unsafe {
256 std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
257 };
258 let mut dst_i = 0;
259 for src_i in (o_r1..o_r2).step_by(ob.len) {
260 f_vec(
261 lhs,
262 &rhs[src_i..src_i + ob.len],
263 &mut ys_to_set[dst_i..dst_i + ob.len],
264 );
265 dst_i += ob.len;
266 }
267 unsafe { ys.set_len(el_count) };
269 ys
270 }
271 Some(ob) => {
272 let lhs = &lhs[ob.start..ob.start + ob.len];
273 let mut ys = rhs[o_r1..o_r2].to_vec();
274 for idx_l in 0..ob.left_broadcast {
275 let start = idx_l * ob.len * ob.right_broadcast;
276 for (i, &l) in lhs.iter().enumerate() {
277 let start = start + i * ob.right_broadcast;
278 for v in ys[start..start + ob.right_broadcast].iter_mut() {
279 *v = f(l, *v)
280 }
281 }
282 }
283 ys
284 }
285 None => lhs_l
286 .strided_index()
287 .zip(rhs_l.strided_index())
288 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
289 .collect(),
290 },
291 _ => lhs_l
292 .strided_index()
293 .zip(rhs_l.strided_index())
294 .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
295 .collect(),
296 }
297}
298
299pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
300 vs: &[T],
301 layout: &Layout,
302 mut f: F,
303) -> Vec<U> {
304 match layout.strided_blocks() {
305 crate::StridedBlocks::SingleBlock { start_offset, len } => vs
306 [start_offset..start_offset + len]
307 .iter()
308 .map(|&v| f(v))
309 .collect(),
310 crate::StridedBlocks::MultipleBlocks {
311 block_start_index,
312 block_len,
313 } => {
314 let mut result = Vec::with_capacity(layout.shape().elem_count());
315 if block_len == 1 {
317 for index in block_start_index {
318 let v = unsafe { vs.get_unchecked(index) };
319 result.push(f(*v))
320 }
321 } else {
322 for index in block_start_index {
323 for offset in 0..block_len {
324 let v = unsafe { vs.get_unchecked(index + offset) };
325 result.push(f(*v))
326 }
327 }
328 }
329 result
330 }
331 }
332}
333
334pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
335 vs: &[T],
336 layout: &Layout,
337 mut f: F,
338 mut f_vec: FV,
339) -> Vec<U> {
340 match layout.strided_blocks() {
341 crate::StridedBlocks::SingleBlock { start_offset, len } => {
342 let mut ys: Vec<U> = Vec::with_capacity(len);
343 let ys_to_set = ys.spare_capacity_mut();
344 let ys_to_set = unsafe {
345 std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
346 };
347 f_vec(&vs[start_offset..start_offset + len], ys_to_set);
348 unsafe { ys.set_len(len) };
350 ys
351 }
352 crate::StridedBlocks::MultipleBlocks {
353 block_start_index,
354 block_len,
355 } => {
356 let el_count = layout.shape().elem_count();
357 if block_len == 1 {
359 let mut result = Vec::with_capacity(el_count);
360 for index in block_start_index {
361 let v = unsafe { vs.get_unchecked(index) };
362 result.push(f(*v))
363 }
364 result
365 } else {
366 let mut ys: Vec<U> = Vec::with_capacity(el_count);
367 let ys_to_set = ys.spare_capacity_mut();
368 let ys_to_set = unsafe {
369 std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
370 };
371 let mut dst_index = 0;
372 for src_index in block_start_index {
373 let vs = &vs[src_index..src_index + block_len];
374 let ys = &mut ys_to_set[dst_index..dst_index + block_len];
375 f_vec(vs, ys);
376 dst_index += block_len;
377 }
378 unsafe { ys.set_len(el_count) };
380 ys
381 }
382 }
383 }
384}