1use ferray_core::Array;
7use ferray_core::dimension::{Axis, Dimension, Ix1, IxDyn};
8use ferray_core::dtype::Element;
9use ferray_core::error::{FerrayError, FerrayResult};
10use num_traits::Zero;
11
12pub fn vectorize<T, U, F, D>(f: F) -> impl Fn(&Array<T, D>) -> FerrayResult<Array<U, D>>
28where
29 T: Element + Copy,
30 U: Element,
31 D: Dimension,
32 F: Fn(T) -> U,
33{
34 move |input: &Array<T, D>| {
35 let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
36 Array::from_vec(input.dim().clone(), data)
37 }
38}
39
40pub fn piecewise<T, D>(
61 x: &Array<T, D>,
62 condlist: &[Array<bool, D>],
63 funclist: &[&dyn Fn(T) -> T],
64 default: T,
65) -> FerrayResult<Array<T, D>>
66where
67 T: Element + Copy,
68 D: Dimension,
69{
70 if condlist.len() != funclist.len() {
71 return Err(FerrayError::invalid_value(format!(
72 "piecewise: condlist length ({}) must equal funclist length ({})",
73 condlist.len(),
74 funclist.len()
75 )));
76 }
77
78 for (i, cond) in condlist.iter().enumerate() {
79 if cond.shape() != x.shape() {
80 return Err(FerrayError::shape_mismatch(format!(
81 "piecewise: condlist[{i}] shape {:?} does not match x shape {:?}",
82 cond.shape(),
83 x.shape()
84 )));
85 }
86 }
87
88 let size = x.size();
89 let mut result_data = vec![default; size];
90 let x_data: Vec<T> = x.iter().copied().collect();
91
92 let cond_data: Vec<Vec<bool>> = condlist
94 .iter()
95 .map(|c| c.iter().copied().collect())
96 .collect();
97
98 for i in 0..size {
100 for (j, cond) in cond_data.iter().enumerate() {
101 if cond[i] {
102 result_data[i] = funclist[j](x_data[i]);
103 break;
104 }
105 }
106 }
107
108 Array::from_vec(x.dim().clone(), result_data)
109}
110
111pub fn apply_along_axis<T, D>(
129 func: impl Fn(&Array<T, Ix1>) -> FerrayResult<T>,
130 axis: Axis,
131 a: &Array<T, D>,
132) -> FerrayResult<Array<T, IxDyn>>
133where
134 T: Element + Copy,
135 D: Dimension,
136{
137 let ndim = a.ndim();
138 let ax = axis.index();
139 if ax >= ndim {
140 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
141 }
142
143 let lanes_iter = a.lanes(axis)?;
145 let mut results = Vec::new();
146
147 for lane in lanes_iter {
148 let owned_lane = lane.to_owned();
150 let val = func(&owned_lane)?;
151 results.push(val);
152 }
153
154 let mut result_shape: Vec<usize> = a.shape().to_vec();
156 result_shape.remove(ax);
157 if result_shape.is_empty() {
158 result_shape.push(results.len());
160 }
161
162 Array::from_vec(IxDyn::new(&result_shape), results)
163}
164
165pub fn apply_over_axes<T, F>(
180 func: F,
181 a: &Array<T, IxDyn>,
182 axes: &[usize],
183) -> FerrayResult<Array<T, IxDyn>>
184where
185 T: Element + Copy,
186 F: Fn(&Array<T, IxDyn>, Axis) -> FerrayResult<Array<T, IxDyn>>,
187{
188 let ndim = a.ndim();
189 for &ax in axes {
190 if ax >= ndim {
191 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
192 }
193 }
194
195 let mut current = a.clone();
196 for &ax in axes {
197 current = func(¤t, Axis(ax))?;
198 }
202
203 Ok(current)
204}
205
206pub fn sum_axis_keepdims<T>(a: &Array<T, IxDyn>, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
215where
216 T: Element + Copy + Zero + core::ops::Add<Output = T>,
217{
218 let ndim = a.ndim();
219 let ax = axis.index();
220 if ax >= ndim {
221 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
222 }
223
224 let reduced = a.fold_axis(axis, <T as Zero>::zero(), |acc, &x| *acc + x)?;
225
226 let mut new_shape: Vec<usize> = reduced.shape().to_vec();
228 new_shape.insert(ax, 1);
229 let data: Vec<T> = reduced.iter().copied().collect();
230 Array::from_vec(IxDyn::new(&new_shape), data)
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix2;

    /// Values shared by every 2x3 fixture, in the order passed to `from_vec`.
    const SIX: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    /// Builds a 1-D `f64` array holding `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Builds a 1-D `bool` array holding `data`.
    fn arr1_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Statically-dimensioned 2x3 fixture holding 1.0..=6.0.
    fn matrix_2x3() -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), SIX.to_vec()).unwrap()
    }

    /// Dynamically-dimensioned 2x3 fixture holding 1.0..=6.0.
    fn dyn_2x3() -> Array<f64, IxDyn> {
        Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), SIX.to_vec()).unwrap()
    }

    /// Copies the elements of a dynamic array into a `Vec` in iteration order.
    fn flat(a: &Array<f64, IxDyn>) -> Vec<f64> {
        a.iter().copied().collect()
    }

    /// Lane reducer used by the `apply_along_axis` tests: sums one lane.
    fn lane_sum(lane: &Array<f64, Ix1>) -> FerrayResult<f64> {
        let total: f64 = lane.iter().sum();
        Ok(total)
    }

    // ----- vectorize -----

    #[test]
    fn vectorize_square_ac4() {
        let square = vectorize(|x: f64| x.powi(2));
        let input = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let squared = square(&input).unwrap();
        assert_eq!(
            squared.as_slice().unwrap(),
            &[1.0, 4.0, 9.0, 16.0, 25.0][..]
        );
    }

    #[test]
    fn vectorize_matches_mapv() {
        let f = |x: f64| x.sin();
        let lifted = vectorize(f);
        let input = arr1(vec![0.0, 1.0, 2.0, 3.0]);
        let via_vectorize = lifted(&input).unwrap();
        let via_mapv = input.mapv(f);
        assert_eq!(
            via_vectorize.as_slice().unwrap(),
            via_mapv.as_slice().unwrap()
        );
    }

    #[test]
    fn vectorize_2d_generic_dimension() {
        let square = vectorize(|x: f64| x * x);
        let input = matrix_2x3();
        let squared = square(&input).unwrap();
        assert_eq!(squared.shape(), &[2, 3]);
        assert_eq!(
            squared.as_slice().unwrap(),
            &[1.0, 4.0, 9.0, 16.0, 25.0, 36.0][..]
        );
    }

    #[test]
    fn vectorize_empty() {
        let add_one = vectorize(|x: f64| x + 1.0);
        let input = arr1(vec![]);
        let output = add_one(&input).unwrap();
        assert_eq!(output.shape(), &[0]);
    }

    // ----- piecewise -----

    #[test]
    fn piecewise_basic() {
        let x = arr1(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let is_negative = arr1_bool(vec![true, true, false, false, false]);
        let is_positive = arr1_bool(vec![false, false, false, true, true]);

        let negate: &dyn Fn(f64) -> f64 = &|v| -v;
        let double: &dyn Fn(f64) -> f64 = &|v| v * 2.0;
        let out = piecewise(&x, &[is_negative, is_positive], &[negate, double], 0.0).unwrap();

        let s = out.as_slice().unwrap();
        assert_eq!(s, &[2.0, 1.0, 0.0, 2.0, 4.0]);
    }

    #[test]
    fn piecewise_first_match_wins() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        // Both masks are all-true, so only the first function may be applied.
        let first = arr1_bool(vec![true, true, true]);
        let second = arr1_bool(vec![true, true, true]);

        let times_ten: &dyn Fn(f64) -> f64 = &|v| v * 10.0;
        let times_hundred: &dyn Fn(f64) -> f64 = &|v| v * 100.0;
        let out = piecewise(&x, &[first, second], &[times_ten, times_hundred], 0.0).unwrap();

        let s = out.as_slice().unwrap();
        assert_eq!(s, &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn piecewise_no_match_uses_default() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        let never = arr1_bool(vec![false, false, false]);

        let times_ten: &dyn Fn(f64) -> f64 = &|v| v * 10.0;
        let out = piecewise(&x, &[never], &[times_ten], -999.0).unwrap();

        let s = out.as_slice().unwrap();
        assert_eq!(s, &[-999.0, -999.0, -999.0]);
    }

    #[test]
    fn piecewise_length_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        let cond = arr1_bool(vec![true, false]);
        let id_a: &dyn Fn(f64) -> f64 = &|v| v;
        let id_b: &dyn Fn(f64) -> f64 = &|v| v;
        // One condition but two functions must be rejected.
        assert!(piecewise(&x, &[cond], &[id_a, id_b], 0.0).is_err());
    }

    #[test]
    fn piecewise_shape_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        // Three-element mask against a two-element input.
        let cond = arr1_bool(vec![true, false, true]);
        let identity: &dyn Fn(f64) -> f64 = &|v| v;
        assert!(piecewise(&x, &[cond], &[identity], 0.0).is_err());
    }

    // ----- apply_along_axis -----

    #[test]
    fn apply_along_axis_col_sums_ac5() {
        let m = matrix_2x3();
        let sums = apply_along_axis(lane_sum, Axis(0), &m).unwrap();

        assert_eq!(sums.shape(), &[3]);
        assert_eq!(flat(&sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_along_axis_row_sums() {
        let m = matrix_2x3();
        let sums = apply_along_axis(lane_sum, Axis(1), &m).unwrap();

        assert_eq!(sums.shape(), &[2]);
        assert_eq!(flat(&sums), vec![6.0, 15.0]);
    }

    #[test]
    fn apply_along_axis_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let sums = apply_along_axis(lane_sum, Axis(0), &a).unwrap();
        assert_eq!(flat(&sums), vec![6.0]);
    }

    #[test]
    fn apply_along_axis_out_of_bounds() {
        let m = matrix_2x3();
        assert!(apply_along_axis(|_| Ok(0.0), Axis(5), &m).is_err());
    }

    // ----- apply_over_axes -----

    #[test]
    fn apply_over_axes_sum() {
        let a = dyn_2x3();
        let total = apply_over_axes(sum_axis_keepdims, &a, &[0, 1]).unwrap();

        assert_eq!(total.shape(), &[1, 1]);
        assert_eq!(flat(&total), vec![21.0]);
    }

    #[test]
    fn apply_over_axes_single_axis() {
        let a = dyn_2x3();
        let reduced = apply_over_axes(sum_axis_keepdims, &a, &[0]).unwrap();

        assert_eq!(reduced.shape(), &[1, 3]);
        assert_eq!(flat(&reduced), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_over_axes_out_of_bounds() {
        let a = dyn_2x3();
        assert!(apply_over_axes(sum_axis_keepdims, &a, &[5]).is_err());
    }

    // ----- sum_axis_keepdims -----

    #[test]
    fn sum_axis_keepdims_basic() {
        let a = dyn_2x3();
        let reduced = sum_axis_keepdims(&a, Axis(0)).unwrap();

        assert_eq!(reduced.shape(), &[1, 3]);
        assert_eq!(flat(&reduced), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn sum_axis_keepdims_axis1() {
        let a = dyn_2x3();
        let reduced = sum_axis_keepdims(&a, Axis(1)).unwrap();

        assert_eq!(reduced.shape(), &[2, 1]);
        assert_eq!(flat(&reduced), vec![6.0, 15.0]);
    }
}