1use crate::array::owned::Array;
11use crate::array::view::ArrayView;
12use crate::dimension::broadcast::broadcast_shapes;
13use crate::dimension::{Dimension, IxDyn};
14use crate::dtype::Element;
15use crate::error::{FerrayError, FerrayResult};
16
/// Stateless namespace for broadcasting elementwise iteration helpers.
///
/// All functionality lives in associated functions (see the `impl NdIter`
/// block); the struct itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct NdIter;
34
35impl NdIter {
36 pub fn binary_map<T, U, D1, D2, F>(
51 a: &Array<T, D1>,
52 b: &Array<T, D2>,
53 f: F,
54 ) -> FerrayResult<Array<U, IxDyn>>
55 where
56 T: Element + Copy,
57 U: Element,
58 D1: Dimension,
59 D2: Dimension,
60 F: Fn(T, T) -> U,
61 {
62 let shape = broadcast_shapes(a.shape(), b.shape())?;
63 let a_view = a.broadcast_to(&shape)?;
64 let b_view = b.broadcast_to(&shape)?;
65
66 let total: usize = shape.iter().product();
67 let mut data = Vec::with_capacity(total);
68
69 if let (Some(a_slice), Some(b_slice)) = (a_view.as_slice(), b_view.as_slice()) {
71 for (&ai, &bi) in a_slice.iter().zip(b_slice.iter()) {
72 data.push(f(ai, bi));
73 }
74 } else {
75 for (&ai, &bi) in a_view.iter().zip(b_view.iter()) {
77 data.push(f(ai, bi));
78 }
79 }
80
81 Array::from_vec(IxDyn::from(&shape[..]), data)
82 }
83
84 pub fn binary_map_mixed<A, B, U, D1, D2, F>(
93 a: &Array<A, D1>,
94 b: &Array<B, D2>,
95 f: F,
96 ) -> FerrayResult<Array<U, IxDyn>>
97 where
98 A: Element + Copy,
99 B: Element + Copy,
100 U: Element,
101 D1: Dimension,
102 D2: Dimension,
103 F: Fn(A, B) -> U,
104 {
105 let shape = broadcast_shapes(a.shape(), b.shape())?;
106 let a_view = a.broadcast_to(&shape)?;
107 let b_view = b.broadcast_to(&shape)?;
108
109 let total: usize = shape.iter().product();
110 let mut data = Vec::with_capacity(total);
111
112 for (&ai, &bi) in a_view.iter().zip(b_view.iter()) {
113 data.push(f(ai, bi));
114 }
115
116 Array::from_vec(IxDyn::from(&shape[..]), data)
117 }
118
119 pub fn binary_map_into<T, D1, D2>(
127 a: &Array<T, D1>,
128 b: &Array<T, D2>,
129 out: &mut Array<T, IxDyn>,
130 f: impl Fn(T, T) -> T,
131 ) -> FerrayResult<()>
132 where
133 T: Element + Copy,
134 D1: Dimension,
135 D2: Dimension,
136 {
137 let shape = broadcast_shapes(a.shape(), b.shape())?;
138 if out.shape() != &shape[..] {
139 return Err(FerrayError::shape_mismatch(format!(
140 "output shape {:?} does not match broadcast shape {:?}",
141 out.shape(),
142 shape
143 )));
144 }
145
146 let a_view = a.broadcast_to(&shape)?;
147 let b_view = b.broadcast_to(&shape)?;
148
149 if let Some(out_slice) = out.as_slice_mut() {
150 for ((o, &ai), &bi) in out_slice.iter_mut().zip(a_view.iter()).zip(b_view.iter()) {
151 *o = f(ai, bi);
152 }
153 } else {
154 for ((&ai, &bi), o) in a_view.iter().zip(b_view.iter()).zip(out.iter_mut()) {
156 *o = f(ai, bi);
157 }
158 }
159
160 Ok(())
161 }
162
163 pub fn unary_map<T, U, D, F>(a: &Array<T, D>, f: F) -> FerrayResult<Array<U, IxDyn>>
168 where
169 T: Element + Copy,
170 U: Element,
171 D: Dimension,
172 F: Fn(T) -> U,
173 {
174 let shape = a.shape().to_vec();
175 let total: usize = shape.iter().product();
176 let mut data = Vec::with_capacity(total);
177
178 if let Some(slice) = a.as_slice() {
179 for &x in slice {
180 data.push(f(x));
181 }
182 } else {
183 for &x in a.iter() {
184 data.push(f(x));
185 }
186 }
187
188 Array::from_vec(IxDyn::from(&shape[..]), data)
189 }
190
191 pub fn unary_map_into<T, D>(
193 a: &Array<T, D>,
194 out: &mut Array<T, IxDyn>,
195 f: impl Fn(T) -> T,
196 ) -> FerrayResult<()>
197 where
198 T: Element + Copy,
199 D: Dimension,
200 {
201 if a.shape() != out.shape() {
202 return Err(FerrayError::shape_mismatch(format!(
203 "input shape {:?} does not match output shape {:?}",
204 a.shape(),
205 out.shape()
206 )));
207 }
208
209 if let (Some(in_slice), Some(out_slice)) = (a.as_slice(), out.as_slice_mut()) {
210 for (o, &x) in out_slice.iter_mut().zip(in_slice.iter()) {
211 *o = f(x);
212 }
213 } else {
214 for (o, &x) in out.iter_mut().zip(a.iter()) {
215 *o = f(x);
216 }
217 }
218
219 Ok(())
220 }
221
222 pub fn broadcast_shape(a_shape: &[usize], b_shape: &[usize]) -> FerrayResult<Vec<usize>> {
226 broadcast_shapes(a_shape, b_shape)
227 }
228
229 pub fn binary_iter<'a, T, D1, D2>(
234 a: &'a Array<T, D1>,
235 b: &'a Array<T, D2>,
236 ) -> FerrayResult<BinaryBroadcastIter<'a, T>>
237 where
238 T: Element + Copy,
239 D1: Dimension,
240 D2: Dimension,
241 {
242 let shape = broadcast_shapes(a.shape(), b.shape())?;
243 let a_view = a.broadcast_to(&shape)?;
244 let b_view = b.broadcast_to(&shape)?;
245
246 let a_data: Vec<T> = a_view.iter().copied().collect();
250 let b_data: Vec<T> = b_view.iter().copied().collect();
251
252 Ok(BinaryBroadcastIter {
253 a_view,
254 b_view,
255 a_data,
256 b_data,
257 index: 0,
258 })
259 }
260}
261
262pub struct BinaryBroadcastIter<'a, T: Element> {
273 a_view: ArrayView<'a, T, IxDyn>,
275 b_view: ArrayView<'a, T, IxDyn>,
277 a_data: Vec<T>,
279 b_data: Vec<T>,
281 index: usize,
283}
284
285impl<T: Element + Copy> Iterator for BinaryBroadcastIter<'_, T> {
286 type Item = (T, T);
287
288 #[inline]
289 fn next(&mut self) -> Option<Self::Item> {
290 if self.index >= self.a_data.len() {
291 return None;
292 }
293 let i = self.index;
294 self.index += 1;
295 Some((self.a_data[i], self.b_data[i]))
296 }
297
298 #[inline]
299 fn size_hint(&self) -> (usize, Option<usize>) {
300 let remaining = self.a_data.len() - self.index;
301 (remaining, Some(remaining))
302 }
303}
304
305impl<T: Element + Copy> ExactSizeIterator for BinaryBroadcastIter<'_, T> {}
306
307impl<T: Element> BinaryBroadcastIter<'_, T> {
308 pub fn map_collect<U, F>(self, f: F) -> Vec<U>
313 where
314 T: Copy,
315 F: Fn(T, T) -> U,
316 {
317 self.a_view
318 .iter()
319 .zip(self.b_view.iter())
320 .map(|(&a, &b)| f(a, b))
321 .collect()
322 }
323
324 pub fn for_each<F>(self, mut f: F)
326 where
327 T: Copy,
328 F: FnMut(T, T),
329 {
330 for (&a, &b) in self.a_view.iter().zip(self.b_view.iter()) {
331 f(a, b);
332 }
333 }
334
335 #[must_use]
337 pub fn shape(&self) -> &[usize] {
338 self.a_view.shape()
339 }
340
341 #[must_use]
343 pub fn size(&self) -> usize {
344 self.a_view.size()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    #[test]
    fn binary_map_same_shape() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(c.shape(), &[3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_broadcast_1d_to_2d() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(c.shape(), &[3, 4]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0
            ]
        );
    }

    #[test]
    fn binary_map_broadcast_scalar() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![100.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x * y).unwrap();
        assert_eq!(c.shape(), &[3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![100.0, 200.0, 300.0]);
    }

    #[test]
    fn binary_map_incompatible_shapes() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let result = NdIter::binary_map(&a, &b, |x, y| x + y);
        assert!(result.is_err());
    }

    #[test]
    fn binary_map_to_bool() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 5.0, 3.0, 7.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, 3.0, 3.0, 6.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x > y).unwrap();
        assert_eq!(c.shape(), &[4]);
        let data: Vec<bool> = c.iter().copied().collect();
        assert_eq!(data, vec![false, true, false, true]);
    }

    #[test]
    fn binary_map_into_preallocated() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();
        NdIter::binary_map_into(&a, &b, &mut out, |x, y| x + y).unwrap();
        let data: Vec<f64> = out.iter().copied().collect();
        assert_eq!(data, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_into_broadcast() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3, 4])).unwrap();
        NdIter::binary_map_into(&a, &b, &mut out, |x, y| x + y).unwrap();
        assert_eq!(out.shape(), &[3, 4]);
        let data: Vec<f64> = out.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0
            ]
        );
    }

    #[test]
    fn binary_map_into_wrong_shape_error() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[5])).unwrap();
        let result = NdIter::binary_map_into(&a, &b, &mut out, |x, y| x + y);
        assert!(result.is_err());
    }

    #[test]
    fn unary_map_basic() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 4.0, 9.0, 16.0]).unwrap();
        let c = NdIter::unary_map(&a, f64::sqrt).unwrap();
        assert_eq!(c.shape(), &[4]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn unary_map_into_preallocated() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 4.0, 9.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();
        NdIter::unary_map_into(&a, &mut out, |x| x * 2.0).unwrap();
        let data: Vec<f64> = out.iter().copied().collect();
        assert_eq!(data, vec![2.0, 8.0, 18.0]);
    }

    #[test]
    fn unary_map_into_wrong_shape_error() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 4.0, 9.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[4])).unwrap();
        let result = NdIter::unary_map_into(&a, &mut out, |x| x * 2.0);
        assert!(result.is_err());
    }

    #[test]
    fn broadcast_shape_helper() {
        assert_eq!(NdIter::broadcast_shape(&[3, 1], &[4]).unwrap(), vec![3, 4]);
        assert!(NdIter::broadcast_shape(&[3], &[4]).is_err());
    }

    #[test]
    fn binary_iter_shape() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.shape(), &[2, 3]);
    }

    #[test]
    fn binary_iter_incompatible_shapes() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(NdIter::binary_iter(&a, &b).is_err());
    }

    #[test]
    fn binary_iter_map_collect() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        let result: Vec<f64> = iter.map_collect(|x, y| x + y);
        assert_eq!(result, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_3d_broadcast() {
        // (2, 1, 4) + (3, 1) -> (2, 3, 4)
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let b = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![10, 20, 30]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(c.shape(), &[2, 3, 4]);
        let data: Vec<i32> = c.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                11, 12, 13, 14, 21, 22, 23, 24, 31, 32, 33, 34, //
                15, 16, 17, 18, 25, 26, 27, 28, 35, 36, 37, 38,
            ]
        );
    }

    #[test]
    fn binary_iter_next() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let mut iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.next(), Some((1.0, 10.0)));
        assert_eq!(iter.next(), Some((2.0, 20.0)));
        assert_eq!(iter.next(), Some((3.0, 30.0)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn binary_iter_for_loop() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        let sums: Vec<i32> = iter.map(|(x, y)| x + y).collect();
        assert_eq!(sums, vec![11, 22, 33]);
    }

    #[test]
    fn binary_iter_broadcast_with_next() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.len(), 6);
        let pairs: Vec<(f64, f64)> = iter.collect();
        assert_eq!(
            pairs,
            vec![
                (1.0, 10.0),
                (1.0, 20.0),
                (1.0, 30.0),
                (2.0, 10.0),
                (2.0, 20.0),
                (2.0, 30.0),
            ]
        );
    }

    #[test]
    fn binary_iter_exact_size() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0; 5]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![2.0; 5]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.len(), 5);
    }

    #[test]
    fn binary_iter_for_each_method() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        let mut sum = 0.0;
        iter.for_each(|x, y| sum += x + y);
        assert!((sum - 66.0).abs() < 1e-10);
    }

    #[test]
    fn binary_map_mixed_i32_f64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.5, 1.5, 2.5]).unwrap();
        let c = NdIter::binary_map_mixed(&a, &b, |x, y| x as f64 + y).unwrap();
        assert_eq!(c.shape(), &[3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.5, 3.5, 5.5]);
    }

    #[test]
    fn binary_map_mixed_broadcast() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let c = NdIter::binary_map_mixed(&a, &b, |x, y| x as f64 + y).unwrap();
        assert_eq!(c.shape(), &[3, 4]);
        let first: f64 = *c.iter().next().unwrap();
        assert!((first - 1.1).abs() < 1e-10);
    }

    #[test]
    fn binary_map_mixed_incompatible_shapes() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let result = NdIter::binary_map_mixed(&a, &b, |x, y| f64::from(x) + y);
        assert!(result.is_err());
    }

    #[test]
    fn binary_map_mixed_to_bool() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 5, 3, 7]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, 3.0, 3.0, 6.0]).unwrap();
        let c = NdIter::binary_map_mixed(&a, &b, |x, y| (x as f64) > y).unwrap();
        let data: Vec<bool> = c.iter().copied().collect();
        assert_eq!(data, vec![false, true, false, true]);
    }
}