1use crate::array::owned::Array;
11use crate::array::view::ArrayView;
12use crate::dimension::broadcast::broadcast_shapes;
13use crate::dimension::{Dimension, IxDyn};
14use crate::dtype::Element;
15use crate::error::{FerrayError, FerrayResult};
16
17pub struct NdIter;
34
35impl NdIter {
36 pub fn binary_map<T, U, D1, D2, F>(
51 a: &Array<T, D1>,
52 b: &Array<T, D2>,
53 f: F,
54 ) -> FerrayResult<Array<U, IxDyn>>
55 where
56 T: Element + Copy,
57 U: Element,
58 D1: Dimension,
59 D2: Dimension,
60 F: Fn(T, T) -> U,
61 {
62 let shape = broadcast_shapes(a.shape(), b.shape())?;
63 let a_view = a.broadcast_to(&shape)?;
64 let b_view = b.broadcast_to(&shape)?;
65
66 let total: usize = shape.iter().product();
67 let mut data = Vec::with_capacity(total);
68
69 if let (Some(a_slice), Some(b_slice)) = (a_view.as_slice(), b_view.as_slice()) {
71 for (&ai, &bi) in a_slice.iter().zip(b_slice.iter()) {
72 data.push(f(ai, bi));
73 }
74 } else {
75 for (&ai, &bi) in a_view.iter().zip(b_view.iter()) {
77 data.push(f(ai, bi));
78 }
79 }
80
81 Array::from_vec(IxDyn::from(&shape[..]), data)
82 }
83
84 pub fn binary_map_mixed<A, B, U, D1, D2, F>(
93 a: &Array<A, D1>,
94 b: &Array<B, D2>,
95 f: F,
96 ) -> FerrayResult<Array<U, IxDyn>>
97 where
98 A: Element + Copy,
99 B: Element + Copy,
100 U: Element,
101 D1: Dimension,
102 D2: Dimension,
103 F: Fn(A, B) -> U,
104 {
105 let shape = broadcast_shapes(a.shape(), b.shape())?;
106 let a_view = a.broadcast_to(&shape)?;
107 let b_view = b.broadcast_to(&shape)?;
108
109 let total: usize = shape.iter().product();
110 let mut data = Vec::with_capacity(total);
111
112 for (&ai, &bi) in a_view.iter().zip(b_view.iter()) {
113 data.push(f(ai, bi));
114 }
115
116 Array::from_vec(IxDyn::from(&shape[..]), data)
117 }
118
119 pub fn binary_map_into<T, D1, D2>(
127 a: &Array<T, D1>,
128 b: &Array<T, D2>,
129 out: &mut Array<T, IxDyn>,
130 f: impl Fn(T, T) -> T,
131 ) -> FerrayResult<()>
132 where
133 T: Element + Copy,
134 D1: Dimension,
135 D2: Dimension,
136 {
137 let shape = broadcast_shapes(a.shape(), b.shape())?;
138 if out.shape() != &shape[..] {
139 return Err(FerrayError::shape_mismatch(format!(
140 "output shape {:?} does not match broadcast shape {:?}",
141 out.shape(),
142 shape
143 )));
144 }
145
146 let a_view = a.broadcast_to(&shape)?;
147 let b_view = b.broadcast_to(&shape)?;
148
149 if let Some(out_slice) = out.as_slice_mut() {
150 for ((o, &ai), &bi) in out_slice.iter_mut().zip(a_view.iter()).zip(b_view.iter()) {
151 *o = f(ai, bi);
152 }
153 } else {
154 for ((&ai, &bi), o) in a_view.iter().zip(b_view.iter()).zip(out.iter_mut()) {
156 *o = f(ai, bi);
157 }
158 }
159
160 Ok(())
161 }
162
163 pub fn unary_map<T, U, D, F>(a: &Array<T, D>, f: F) -> FerrayResult<Array<U, IxDyn>>
168 where
169 T: Element + Copy,
170 U: Element,
171 D: Dimension,
172 F: Fn(T) -> U,
173 {
174 let shape = a.shape().to_vec();
175 let total: usize = shape.iter().product();
176 let mut data = Vec::with_capacity(total);
177
178 if let Some(slice) = a.as_slice() {
179 for &x in slice {
180 data.push(f(x));
181 }
182 } else {
183 for &x in a.iter() {
184 data.push(f(x));
185 }
186 }
187
188 Array::from_vec(IxDyn::from(&shape[..]), data)
189 }
190
191 pub fn unary_map_into<T, D>(
193 a: &Array<T, D>,
194 out: &mut Array<T, IxDyn>,
195 f: impl Fn(T) -> T,
196 ) -> FerrayResult<()>
197 where
198 T: Element + Copy,
199 D: Dimension,
200 {
201 if a.shape() != out.shape() {
202 return Err(FerrayError::shape_mismatch(format!(
203 "input shape {:?} does not match output shape {:?}",
204 a.shape(),
205 out.shape()
206 )));
207 }
208
209 if let (Some(in_slice), Some(out_slice)) = (a.as_slice(), out.as_slice_mut()) {
210 for (o, &x) in out_slice.iter_mut().zip(in_slice.iter()) {
211 *o = f(x);
212 }
213 } else {
214 for (o, &x) in out.iter_mut().zip(a.iter()) {
215 *o = f(x);
216 }
217 }
218
219 Ok(())
220 }
221
222 pub fn broadcast_shape(a_shape: &[usize], b_shape: &[usize]) -> FerrayResult<Vec<usize>> {
226 broadcast_shapes(a_shape, b_shape)
227 }
228
229 pub fn binary_iter<'a, T, D1, D2>(
234 a: &'a Array<T, D1>,
235 b: &'a Array<T, D2>,
236 ) -> FerrayResult<BinaryBroadcastIter<'a, T>>
237 where
238 T: Element + Copy,
239 D1: Dimension,
240 D2: Dimension,
241 {
242 let shape = broadcast_shapes(a.shape(), b.shape())?;
243 let a_view = a.broadcast_to(&shape)?;
244 let b_view = b.broadcast_to(&shape)?;
245
246 let a_data: Vec<T> = a_view.iter().copied().collect();
250 let b_data: Vec<T> = b_view.iter().copied().collect();
251
252 Ok(BinaryBroadcastIter {
253 a_view,
254 b_view,
255 a_data,
256 b_data,
257 index: 0,
258 })
259 }
260}
261
262pub struct BinaryBroadcastIter<'a, T: Element> {
273 a_view: ArrayView<'a, T, IxDyn>,
275 b_view: ArrayView<'a, T, IxDyn>,
277 a_data: Vec<T>,
279 b_data: Vec<T>,
281 index: usize,
283}
284
285impl<T: Element + Copy> Iterator for BinaryBroadcastIter<'_, T> {
286 type Item = (T, T);
287
288 #[inline]
289 fn next(&mut self) -> Option<Self::Item> {
290 if self.index >= self.a_data.len() {
291 return None;
292 }
293 let i = self.index;
294 self.index += 1;
295 Some((self.a_data[i], self.b_data[i]))
296 }
297
298 #[inline]
299 fn size_hint(&self) -> (usize, Option<usize>) {
300 let remaining = self.a_data.len() - self.index;
301 (remaining, Some(remaining))
302 }
303}
304
305impl<T: Element + Copy> ExactSizeIterator for BinaryBroadcastIter<'_, T> {}
306
307impl<'a, T: Element> BinaryBroadcastIter<'a, T> {
308 pub fn map_collect<U, F>(self, f: F) -> Vec<U>
313 where
314 T: Copy,
315 F: Fn(T, T) -> U,
316 {
317 self.a_view
318 .iter()
319 .zip(self.b_view.iter())
320 .map(|(&a, &b)| f(a, b))
321 .collect()
322 }
323
324 pub fn for_each<F>(self, mut f: F)
326 where
327 T: Copy,
328 F: FnMut(T, T),
329 {
330 for (&a, &b) in self.a_view.iter().zip(self.b_view.iter()) {
331 f(a, b);
332 }
333 }
334
335 pub fn shape(&self) -> &[usize] {
337 self.a_view.shape()
338 }
339
340 pub fn size(&self) -> usize {
342 self.a_view.size()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds a 1-D `f64` array whose length matches `values`.
    fn f64_1d(values: &[f64]) -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([values.len()]), values.to_vec()).unwrap()
    }

    /// Builds a 1-D `i32` array whose length matches `values`.
    fn i32_1d(values: &[i32]) -> Array<i32, Ix1> {
        Array::<i32, Ix1>::from_vec(Ix1::new([values.len()]), values.to_vec()).unwrap()
    }

    /// Builds a 2-D `f64` array of shape `[rows, cols]` from `values`.
    fn f64_2d(rows: usize, cols: usize, values: &[f64]) -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), values.to_vec()).unwrap()
    }

    #[test]
    fn binary_map_same_shape() {
        let lhs = f64_1d(&[1.0, 2.0, 3.0]);
        let rhs = f64_1d(&[10.0, 20.0, 30.0]);

        let sum = NdIter::binary_map(&lhs, &rhs, |x, y| x + y).unwrap();

        assert_eq!(sum.shape(), &[3]);
        assert_eq!(
            sum.iter().copied().collect::<Vec<f64>>(),
            vec![11.0, 22.0, 33.0]
        );
    }

    #[test]
    fn binary_map_broadcast_1d_to_2d() {
        // A (3, 1) column against a length-4 row broadcasts to (3, 4).
        let column = f64_2d(3, 1, &[1.0, 2.0, 3.0]);
        let row = f64_1d(&[10.0, 20.0, 30.0, 40.0]);

        let sum = NdIter::binary_map(&column, &row, |x, y| x + y).unwrap();

        assert_eq!(sum.shape(), &[3, 4]);
        let expected = vec![
            11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0,
        ];
        assert_eq!(sum.iter().copied().collect::<Vec<f64>>(), expected);
    }

    #[test]
    fn binary_map_broadcast_scalar() {
        let values = f64_1d(&[1.0, 2.0, 3.0]);
        let factor = f64_1d(&[100.0]);

        let scaled = NdIter::binary_map(&values, &factor, |x, y| x * y).unwrap();

        assert_eq!(scaled.shape(), &[3]);
        assert_eq!(
            scaled.iter().copied().collect::<Vec<f64>>(),
            vec![100.0, 200.0, 300.0]
        );
    }

    #[test]
    fn binary_map_incompatible_shapes() {
        let three = f64_1d(&[1.0, 2.0, 3.0]);
        let four = f64_1d(&[1.0, 2.0, 3.0, 4.0]);

        assert!(NdIter::binary_map(&three, &four, |x, y| x + y).is_err());
    }

    #[test]
    fn binary_map_to_bool() {
        let lhs = f64_1d(&[1.0, 5.0, 3.0, 7.0]);
        let rhs = f64_1d(&[2.0, 3.0, 3.0, 6.0]);

        let greater = NdIter::binary_map(&lhs, &rhs, |x, y| x > y).unwrap();

        assert_eq!(greater.shape(), &[4]);
        assert_eq!(
            greater.iter().copied().collect::<Vec<bool>>(),
            vec![false, true, false, true]
        );
    }

    #[test]
    fn binary_map_into_preallocated() {
        let lhs = f64_1d(&[1.0, 2.0, 3.0]);
        let rhs = f64_1d(&[10.0, 20.0, 30.0]);
        let mut dest = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();

        NdIter::binary_map_into(&lhs, &rhs, &mut dest, |x, y| x + y).unwrap();

        assert_eq!(
            dest.iter().copied().collect::<Vec<f64>>(),
            vec![11.0, 22.0, 33.0]
        );
    }

    #[test]
    fn binary_map_into_wrong_shape_error() {
        let lhs = f64_1d(&[1.0, 2.0, 3.0]);
        let rhs = f64_1d(&[10.0, 20.0, 30.0]);
        // Destination has 5 elements while the broadcast shape is [3].
        let mut dest = Array::<f64, IxDyn>::zeros(IxDyn::new(&[5])).unwrap();

        let outcome = NdIter::binary_map_into(&lhs, &rhs, &mut dest, |x, y| x + y);

        assert!(outcome.is_err());
    }

    #[test]
    fn unary_map_basic() {
        let squares = f64_1d(&[1.0, 4.0, 9.0, 16.0]);

        let roots = NdIter::unary_map(&squares, |x| x.sqrt()).unwrap();

        assert_eq!(roots.shape(), &[4]);
        assert_eq!(
            roots.iter().copied().collect::<Vec<f64>>(),
            vec![1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn unary_map_into_preallocated() {
        let input = f64_1d(&[1.0, 4.0, 9.0]);
        let mut dest = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();

        NdIter::unary_map_into(&input, &mut dest, |x| x * 2.0).unwrap();

        assert_eq!(
            dest.iter().copied().collect::<Vec<f64>>(),
            vec![2.0, 8.0, 18.0]
        );
    }

    #[test]
    fn binary_iter_shape() {
        let column = f64_2d(2, 1, &[1.0, 2.0]);
        let row = f64_1d(&[10.0, 20.0, 30.0]);

        let pairs = NdIter::binary_iter(&column, &row).unwrap();

        assert_eq!(pairs.shape(), &[2, 3]);
    }

    #[test]
    fn binary_iter_map_collect() {
        let lhs = f64_1d(&[1.0, 2.0, 3.0]);
        let rhs = f64_1d(&[10.0, 20.0, 30.0]);

        let pairs = NdIter::binary_iter(&lhs, &rhs).unwrap();
        let sums: Vec<f64> = pairs.map_collect(|x, y| x + y);

        assert_eq!(sums, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_3d_broadcast() {
        use crate::dimension::Ix3;

        // (2, 1, 4) against (3, 1) broadcasts to (2, 3, 4).
        let cube =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let column = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![10, 20, 30]).unwrap();

        let sum = NdIter::binary_map(&cube, &column, |x, y| x + y).unwrap();

        assert_eq!(sum.shape(), &[2, 3, 4]);
        let first = sum.iter().next().copied();
        assert_eq!(first, Some(11));
    }

    #[test]
    fn binary_iter_next() {
        let lhs = f64_1d(&[1.0, 2.0, 3.0]);
        let rhs = f64_1d(&[10.0, 20.0, 30.0]);

        let mut pairs = NdIter::binary_iter(&lhs, &rhs).unwrap();

        for expected in [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)] {
            assert_eq!(pairs.next(), Some(expected));
        }
        assert_eq!(pairs.next(), None);
    }

    #[test]
    fn binary_iter_for_loop() {
        let lhs = i32_1d(&[1, 2, 3]);
        let rhs = i32_1d(&[10, 20, 30]);

        let pairs = NdIter::binary_iter(&lhs, &rhs).unwrap();
        let sums: Vec<i32> = pairs.map(|(x, y)| x + y).collect();

        assert_eq!(sums, vec![11, 22, 33]);
    }

    #[test]
    fn binary_iter_broadcast_with_next() {
        let column = f64_2d(2, 1, &[1.0, 2.0]);
        let row = f64_1d(&[10.0, 20.0, 30.0]);

        let pairs = NdIter::binary_iter(&column, &row).unwrap();
        assert_eq!(pairs.len(), 6);

        let collected: Vec<(f64, f64)> = pairs.collect();
        let expected = vec![
            (1.0, 10.0),
            (1.0, 20.0),
            (1.0, 30.0),
            (2.0, 10.0),
            (2.0, 20.0),
            (2.0, 30.0),
        ];
        assert_eq!(collected, expected);
    }

    #[test]
    fn binary_iter_exact_size() {
        let ones = f64_1d(&[1.0; 5]);
        let twos = f64_1d(&[2.0; 5]);

        let pairs = NdIter::binary_iter(&ones, &twos).unwrap();

        assert_eq!(pairs.len(), 5);
    }

    #[test]
    fn binary_iter_for_each_method() {
        let lhs = f64_1d(&[1.0, 2.0, 3.0]);
        let rhs = f64_1d(&[10.0, 20.0, 30.0]);

        let pairs = NdIter::binary_iter(&lhs, &rhs).unwrap();
        let mut total = 0.0_f64;
        // Two-argument closure: this is the inherent `for_each`.
        pairs.for_each(|x, y| total += x + y);

        assert!((total - 66.0).abs() < 1e-10);
    }

    #[test]
    fn binary_map_mixed_i32_f64() {
        let ints = i32_1d(&[1, 2, 3]);
        let floats = f64_1d(&[0.5, 1.5, 2.5]);

        let sum = NdIter::binary_map_mixed(&ints, &floats, |x, y| x as f64 + y).unwrap();

        assert_eq!(sum.shape(), &[3]);
        assert_eq!(
            sum.iter().copied().collect::<Vec<f64>>(),
            vec![1.5, 3.5, 5.5]
        );
    }

    #[test]
    fn binary_map_mixed_broadcast() {
        let column = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
        let row = f64_1d(&[0.1, 0.2, 0.3, 0.4]);

        let sum = NdIter::binary_map_mixed(&column, &row, |x, y| x as f64 + y).unwrap();

        assert_eq!(sum.shape(), &[3, 4]);
        let first: f64 = *sum.iter().next().unwrap();
        assert!((first - 1.1).abs() < 1e-10);
    }

    #[test]
    fn binary_map_mixed_to_bool() {
        let ints = i32_1d(&[1, 5, 3, 7]);
        let floats = f64_1d(&[2.0, 3.0, 3.0, 6.0]);

        let greater = NdIter::binary_map_mixed(&ints, &floats, |x, y| (x as f64) > y).unwrap();

        assert_eq!(
            greater.iter().copied().collect::<Vec<bool>>(),
            vec![false, true, false, true]
        );
    }
}