use crate::signal::traits::extrema::{ExtremaAlgorithms, ExtremaResult, ExtremumMode};
use numr::error::{Error, Result};
use numr::runtime::cpu::{CpuClient, CpuRuntime};
use numr::tensor::Tensor;
/// CPU implementation of the relative-extrema searches.
///
/// Both methods forward to `argrel_cpu`; the only difference between them is
/// the comparison direction they pass along.
impl ExtremaAlgorithms<CpuRuntime> for CpuClient {
    /// Searches `x` for relative minima.
    ///
    /// `order` and `mode` are passed through unchanged; the result (or error)
    /// is whatever `argrel_cpu` produces for `Comparison::Less`.
    fn argrelmin(
        &self,
        x: &Tensor<CpuRuntime>,
        order: usize,
        mode: ExtremumMode,
    ) -> Result<ExtremaResult<CpuRuntime>> {
        let direction = Comparison::Less;
        argrel_cpu(x, order, mode, direction)
    }

    /// Searches `x` for relative maxima.
    ///
    /// `order` and `mode` are passed through unchanged; the result (or error)
    /// is whatever `argrel_cpu` produces for `Comparison::Greater`.
    fn argrelmax(
        &self,
        x: &Tensor<CpuRuntime>,
        order: usize,
        mode: ExtremumMode,
    ) -> Result<ExtremaResult<CpuRuntime>> {
        let direction = Comparison::Greater;
        argrel_cpu(x, order, mode, direction)
    }
}
/// Direction of the strict comparison used when testing a sample against its
/// neighbours.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Comparison {
    /// The sample must be strictly smaller than its neighbours (minima search).
    Less,
    /// The sample must be strictly greater than its neighbours (maxima search).
    Greater,
}
/// Finds the relative extrema of a 1-D signal by comparing every sample with
/// its neighbours up to `order` positions away on each side.
///
/// A sample is reported when it is strictly smaller (`Comparison::Less`) or
/// strictly greater (`Comparison::Greater`) than every neighbour that is
/// examined. Which neighbours are examined depends on `mode`:
///
/// * `ExtremumMode::Wrap` - indices wrap around the ends of the signal, so
///   every sample is compared with `order` neighbours on each side.
/// * `ExtremumMode::Clip` - neighbours that would fall outside the signal are
///   skipped, so samples near the ends are only compared with the neighbours
///   that actually exist.
///
/// Comparisons involving NaN never succeed, so a NaN sample, or a sample with
/// a NaN neighbour in range, is never reported.
///
/// On success the result holds the positions of the extrema in ascending
/// order (`indices`) and a 1-D tensor with the corresponding sample values
/// (`values`), created on the device of `x`.
///
/// # Errors
///
/// Returns `Error::InvalidArgument` when `x` is not one-dimensional, when it
/// is empty, or when `order` is zero.
fn argrel_cpu(
    x: &Tensor<CpuRuntime>,
    order: usize,
    mode: ExtremumMode,
    comparison: Comparison,
) -> Result<ExtremaResult<CpuRuntime>> {
    if x.ndim() != 1 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "Input must be 1D".to_string(),
        });
    }
    let n = x.shape()[0];
    let device = x.device();
    if n == 0 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "Input signal cannot be empty".to_string(),
        });
    }
    if order == 0 {
        return Err(Error::InvalidArgument {
            arg: "order",
            reason: "Order must be at least 1".to_string(),
        });
    }

    let data: Vec<f64> = x.to_vec();

    // Offsets beyond the signal length can never change the outcome, so the
    // search radius is capped. This keeps the loop bounded by `n` even for a
    // huge `order` and keeps all index arithmetic comfortably inside `usize`.
    //  * Clip: no neighbour exists further than `n - 1` positions away.
    //  * Wrap: at offset `n` a sample is compared with itself, which always
    //    fails the strict comparison, so larger offsets are never reached.
    let reach = match mode {
        ExtremumMode::Wrap => order.min(n),
        ExtremumMode::Clip => order.min(n - 1),
    };

    // Written as a positive strict test (rather than rejecting on `>=`/`<=`)
    // so that any comparison involving NaN counts as "not an extremum".
    let beats = |val: f64, neighbour: f64| match comparison {
        Comparison::Less => val < neighbour,
        Comparison::Greater => val > neighbour,
    };

    let mut indices = Vec::new();
    let mut values = Vec::new();
    for (i, &val) in data.iter().enumerate() {
        let is_extremum = (1..=reach).all(|offset| {
            // `None` marks a neighbour that does not exist and is skipped.
            let (left, right) = match mode {
                ExtremumMode::Wrap => (
                    // `offset <= n`, so `i + n - offset` cannot underflow.
                    Some(data[(i + n - offset) % n]),
                    Some(data[(i + offset) % n]),
                ),
                ExtremumMode::Clip => (
                    i.checked_sub(offset).map(|j| data[j]),
                    (i + offset < n).then(|| data[i + offset]),
                ),
            };
            left.map_or(true, |l| beats(val, l)) && right.map_or(true, |r| beats(val, r))
        });
        if is_extremum {
            indices.push(i);
            values.push(val);
        }
    }

    let values_tensor = Tensor::from_slice(&values, &[values.len()], device);
    Ok(ExtremaResult {
        indices,
        values: values_tensor,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::CpuDevice;

    /// Creates a CPU client together with the device it was built for.
    fn setup() -> (CpuClient, CpuDevice) {
        let device = CpuDevice::new();
        (CpuClient::new(device.clone()), device)
    }

    /// Wraps `samples` in a 1-D tensor on `device`.
    fn tensor_1d(samples: Vec<f64>, device: &CpuDevice) -> Tensor<CpuRuntime> {
        Tensor::from_slice(&samples, &[samples.len()], device)
    }

    #[test]
    fn test_argrelmin_simple() {
        let (client, device) = setup();
        let x = tensor_1d(vec![1.0, 0.5, 0.0, 0.5, 1.0, 0.5, 0.0, 0.5, 1.0], &device);

        let minima = client.argrelmin(&x, 1, ExtremumMode::Clip).unwrap();

        assert_eq!(minima.indices, vec![2, 6]);
        // Both valleys sit at zero.
        let values: Vec<f64> = minima.values.to_vec();
        for value in &values[..2] {
            assert!(value.abs() < 1e-10);
        }
    }

    #[test]
    fn test_argrelmax_simple() {
        let (client, device) = setup();
        let x = tensor_1d(vec![1.0, 0.5, 0.0, 0.5, 1.0, 0.5, 0.0, 0.5, 1.0], &device);

        let maxima = client.argrelmax(&x, 1, ExtremumMode::Clip).unwrap();

        // The interior peak and both end samples are expected.
        for expected in [4, 0, 8] {
            assert!(maxima.indices.contains(&expected));
        }
    }

    #[test]
    fn test_argrelmin_higher_order() {
        let (client, device) = setup();
        let x = tensor_1d(vec![5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0], &device);

        let minima = client.argrelmin(&x, 2, ExtremumMode::Clip).unwrap();

        assert_eq!(minima.indices, vec![4]);
    }

    #[test]
    fn test_argrelmax_wrap_mode() {
        let (client, device) = setup();
        let x = tensor_1d(vec![0.0, 1.0, 0.0, 1.0, 0.0], &device);

        let maxima = client.argrelmax(&x, 1, ExtremumMode::Wrap).unwrap();

        assert_eq!(maxima.indices, vec![1, 3]);
    }

    #[test]
    fn test_argrelextrema() {
        let (client, device) = setup();
        let x = tensor_1d(vec![0.0, 1.0, 0.0, -1.0, 0.0], &device);

        let (minima, maxima) = client.argrelextrema(&x, 1, ExtremumMode::Clip).unwrap();

        assert!(maxima.indices.contains(&1));
        assert!(minima.indices.contains(&3));
    }

    #[test]
    fn test_argrelmin_no_extrema() {
        let (client, device) = setup();
        let x = tensor_1d(vec![1.0, 2.0, 3.0, 4.0, 5.0], &device);

        let minima = client.argrelmin(&x, 1, ExtremumMode::Clip).unwrap();

        // A rising ramp has no interior minimum; only the first sample is
        // reported.
        assert_eq!(minima.indices, vec![0]);
    }

    #[test]
    fn test_argrelmin_plateau() {
        let (client, device) = setup();
        let x = tensor_1d(vec![1.0, 0.0, 0.0, 0.0, 1.0], &device);

        let minima = client.argrelmin(&x, 1, ExtremumMode::Clip).unwrap();

        // Equal neighbours fail the strict comparison, so nothing is reported.
        assert!(minima.indices.is_empty());
    }
}