1use ferray_core::Array;
7use ferray_core::dimension::{Axis, Dimension, Ix1, IxDyn};
8use ferray_core::dtype::Element;
9use ferray_core::error::{FerrayError, FerrayResult};
10
11pub fn vectorize<T, U, F>(f: F) -> impl Fn(&Array<T, Ix1>) -> FerrayResult<Array<U, Ix1>>
25where
26 T: Element + Copy,
27 U: Element,
28 F: Fn(T) -> U,
29{
30 move |input: &Array<T, Ix1>| {
31 let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
32 Array::from_vec(Ix1::new([data.len()]), data)
33 }
34}
35
36pub fn vectorize_nd<T, U, F, D>(f: F) -> impl Fn(&Array<T, D>) -> FerrayResult<Array<U, D>>
46where
47 T: Element + Copy,
48 U: Element,
49 D: Dimension,
50 F: Fn(T) -> U,
51{
52 move |input: &Array<T, D>| {
53 let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
54 Array::from_vec(input.dim().clone(), data)
55 }
56}
57
58pub fn piecewise<T, D>(
76 x: &Array<T, D>,
77 condlist: &[Array<bool, D>],
78 funclist: &[Box<dyn Fn(T) -> T>],
79 default: T,
80) -> FerrayResult<Array<T, D>>
81where
82 T: Element + Copy,
83 D: Dimension,
84{
85 if condlist.len() != funclist.len() {
86 return Err(FerrayError::invalid_value(format!(
87 "piecewise: condlist length ({}) must equal funclist length ({})",
88 condlist.len(),
89 funclist.len()
90 )));
91 }
92
93 for (i, cond) in condlist.iter().enumerate() {
94 if cond.shape() != x.shape() {
95 return Err(FerrayError::shape_mismatch(format!(
96 "piecewise: condlist[{i}] shape {:?} does not match x shape {:?}",
97 cond.shape(),
98 x.shape()
99 )));
100 }
101 }
102
103 let size = x.size();
104 let mut result_data = vec![default; size];
105 let x_data: Vec<T> = x.iter().copied().collect();
106
107 let cond_data: Vec<Vec<bool>> = condlist
109 .iter()
110 .map(|c| c.iter().copied().collect())
111 .collect();
112
113 for i in 0..size {
115 for (j, cond) in cond_data.iter().enumerate() {
116 if cond[i] {
117 result_data[i] = funclist[j](x_data[i]);
118 break;
119 }
120 }
121 }
122
123 Array::from_vec(x.dim().clone(), result_data)
124}
125
126pub fn apply_along_axis<T, D>(
144 func: impl Fn(&Array<T, Ix1>) -> FerrayResult<T>,
145 axis: Axis,
146 a: &Array<T, D>,
147) -> FerrayResult<Array<T, IxDyn>>
148where
149 T: Element + Copy,
150 D: Dimension,
151{
152 let ndim = a.ndim();
153 let ax = axis.index();
154 if ax >= ndim {
155 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
156 }
157
158 let lanes_iter = a.lanes(axis)?;
160 let mut results = Vec::new();
161
162 for lane in lanes_iter {
163 let owned_lane = lane.to_owned();
165 let val = func(&owned_lane)?;
166 results.push(val);
167 }
168
169 let mut result_shape: Vec<usize> = a.shape().to_vec();
171 result_shape.remove(ax);
172 if result_shape.is_empty() {
173 result_shape.push(results.len());
175 }
176
177 Array::from_vec(IxDyn::new(&result_shape), results)
178}
179
180pub fn apply_over_axes(
198 func: impl Fn(&Array<f64, IxDyn>, Axis) -> FerrayResult<Array<f64, IxDyn>>,
199 a: &Array<f64, IxDyn>,
200 axes: &[usize],
201) -> FerrayResult<Array<f64, IxDyn>> {
202 let ndim = a.ndim();
203 for &ax in axes {
204 if ax >= ndim {
205 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
206 }
207 }
208
209 let mut current = a.clone();
210 for &ax in axes {
211 current = func(¤t, Axis(ax))?;
212 }
216
217 Ok(current)
218}
219
220pub fn sum_axis_keepdims(a: &Array<f64, IxDyn>, axis: Axis) -> FerrayResult<Array<f64, IxDyn>> {
227 let ndim = a.ndim();
228 let ax = axis.index();
229 if ax >= ndim {
230 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
231 }
232
233 let reduced = a.fold_axis(axis, 0.0, |acc, &x| *acc + x)?;
234
235 let mut new_shape: Vec<usize> = reduced.shape().to_vec();
237 new_shape.insert(ax, 1);
238 let data: Vec<f64> = reduced.iter().copied().collect();
239 Array::from_vec(IxDyn::new(&new_shape), data)
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix2;

    /// Builds a 1-D `f64` array holding `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::from_vec(shape, data).unwrap()
    }

    /// Builds a 1-D `bool` array holding `data`.
    fn arr1_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::from_vec(shape, data).unwrap()
    }

    /// The values 1.0 through 6.0 shared by every 2x3 fixture.
    fn one_to_six() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    /// A statically 2-D, 2x3 fixture.
    fn fixed_2x3() -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), one_to_six()).unwrap()
    }

    /// A dynamic-rank, 2x3 fixture.
    fn dyn_2x3() -> Array<f64, IxDyn> {
        Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), one_to_six()).unwrap()
    }

    /// Lane reducer used by the `apply_along_axis` tests: sums the lane.
    fn lane_sum(lane: &Array<f64, Ix1>) -> FerrayResult<f64> {
        let total: f64 = lane.iter().sum();
        Ok(total)
    }

    /// Copies an array's elements, in iteration order, into a `Vec`.
    fn values_of<D: Dimension>(a: &Array<f64, D>) -> Vec<f64> {
        a.iter().copied().collect()
    }

    // ---- vectorize / vectorize_nd ----

    #[test]
    fn vectorize_square_ac4() {
        let square = vectorize(|x: f64| x.powi(2));
        let squared = square(&arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0])).unwrap();
        assert_eq!(
            squared.as_slice().unwrap(),
            &[1.0, 4.0, 9.0, 16.0, 25.0]
        );
    }

    #[test]
    fn vectorize_matches_mapv() {
        let f = |x: f64| x.sin();
        let lifted = vectorize(f);
        let input = arr1(vec![0.0, 1.0, 2.0, 3.0]);

        let from_vectorize = lifted(&input).unwrap();
        let from_mapv = input.mapv(f);

        assert_eq!(
            from_vectorize.as_slice().unwrap(),
            from_mapv.as_slice().unwrap()
        );
    }

    #[test]
    fn vectorize_nd_2d() {
        let square = vectorize_nd(|x: f64| x * x);
        let squared = square(&fixed_2x3()).unwrap();

        assert_eq!(squared.shape(), &[2, 3]);
        assert_eq!(
            squared.as_slice().unwrap(),
            &[1.0, 4.0, 9.0, 16.0, 25.0, 36.0]
        );
    }

    #[test]
    fn vectorize_empty() {
        let add_one = vectorize(|x: f64| x + 1.0);
        let out = add_one(&arr1(vec![])).unwrap();
        assert_eq!(out.shape(), &[0]);
    }

    // ---- piecewise ----

    #[test]
    fn piecewise_basic() {
        let x = arr1(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let negative = arr1_bool(vec![true, true, false, false, false]);
        let positive = arr1_bool(vec![false, false, false, true, true]);
        // Negate the negatives, double the positives, leave zero at the default.
        let branches: [Box<dyn Fn(f64) -> f64>; 2] =
            [Box::new(|v: f64| -v), Box::new(|v: f64| v * 2.0)];

        let out = piecewise(&x, &[negative, positive], &branches, 0.0).unwrap();

        assert_eq!(out.as_slice().unwrap(), &[2.0, 1.0, 0.0, 2.0, 4.0]);
    }

    #[test]
    fn piecewise_first_match_wins() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        // Both masks are true everywhere, so only the first branch may fire.
        let first = arr1_bool(vec![true, true, true]);
        let second = arr1_bool(vec![true, true, true]);
        let branches: [Box<dyn Fn(f64) -> f64>; 2] =
            [Box::new(|v: f64| v * 10.0), Box::new(|v: f64| v * 100.0)];

        let out = piecewise(&x, &[first, second], &branches, 0.0).unwrap();

        assert_eq!(out.as_slice().unwrap(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn piecewise_no_match_uses_default() {
        let x = arr1(vec![1.0, 2.0, 3.0]);
        let never = arr1_bool(vec![false, false, false]);
        let branches: [Box<dyn Fn(f64) -> f64>; 1] = [Box::new(|v: f64| v * 10.0)];

        let out = piecewise(&x, &[never], &branches, -999.0).unwrap();

        assert_eq!(out.as_slice().unwrap(), &[-999.0, -999.0, -999.0]);
    }

    #[test]
    fn piecewise_length_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        let cond = arr1_bool(vec![true, false]);
        // One mask but two functions.
        let branches: [Box<dyn Fn(f64) -> f64>; 2] =
            [Box::new(|v: f64| v), Box::new(|v: f64| v)];

        let outcome = piecewise(&x, &[cond], &branches, 0.0);

        assert!(outcome.is_err());
    }

    #[test]
    fn piecewise_shape_mismatch() {
        let x = arr1(vec![1.0, 2.0]);
        // Three mask entries against two data entries.
        let cond = arr1_bool(vec![true, false, true]);
        let branches: [Box<dyn Fn(f64) -> f64>; 1] = [Box::new(|v: f64| v)];

        let outcome = piecewise(&x, &[cond], &branches, 0.0);

        assert!(outcome.is_err());
    }

    // ---- apply_along_axis ----

    #[test]
    fn apply_along_axis_col_sums_ac5() {
        let sums = apply_along_axis(lane_sum, Axis(0), &fixed_2x3()).unwrap();

        assert_eq!(sums.shape(), &[3]);
        assert_eq!(values_of(&sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_along_axis_row_sums() {
        let sums = apply_along_axis(lane_sum, Axis(1), &fixed_2x3()).unwrap();

        assert_eq!(sums.shape(), &[2]);
        assert_eq!(values_of(&sums), vec![6.0, 15.0]);
    }

    #[test]
    fn apply_along_axis_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let sums = apply_along_axis(lane_sum, Axis(0), &a).unwrap();
        assert_eq!(values_of(&sums), vec![6.0]);
    }

    #[test]
    fn apply_along_axis_out_of_bounds() {
        let outcome = apply_along_axis(|_| Ok(0.0), Axis(5), &fixed_2x3());
        assert!(outcome.is_err());
    }

    // ---- apply_over_axes ----

    #[test]
    fn apply_over_axes_sum() {
        let total = apply_over_axes(sum_axis_keepdims, &dyn_2x3(), &[0, 1]).unwrap();

        assert_eq!(total.shape(), &[1, 1]);
        assert_eq!(values_of(&total), vec![21.0]);
    }

    #[test]
    fn apply_over_axes_single_axis() {
        let sums = apply_over_axes(sum_axis_keepdims, &dyn_2x3(), &[0]).unwrap();

        assert_eq!(sums.shape(), &[1, 3]);
        assert_eq!(values_of(&sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn apply_over_axes_out_of_bounds() {
        let outcome = apply_over_axes(sum_axis_keepdims, &dyn_2x3(), &[5]);
        assert!(outcome.is_err());
    }

    // ---- sum_axis_keepdims ----

    #[test]
    fn sum_axis_keepdims_basic() {
        let sums = sum_axis_keepdims(&dyn_2x3(), Axis(0)).unwrap();

        assert_eq!(sums.shape(), &[1, 3]);
        assert_eq!(values_of(&sums), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn sum_axis_keepdims_axis1() {
        let sums = sum_axis_keepdims(&dyn_2x3(), Axis(1)).unwrap();

        assert_eq!(sums.shape(), &[2, 1]);
        assert_eq!(values_of(&sums), vec![6.0, 15.0]);
    }
}