use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::{ComplexOps, ReduceOps, ScalarOps, ShapeOps, TensorOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn reverse_1d_impl<R, C>(_client: &C, tensor: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>,
{
if tensor.ndim() != 1 {
return Err(Error::InvalidArgument {
arg: "tensor",
reason: "reverse_1d requires 1D tensor".to_string(),
});
}
tensor.flip(0)
}
pub fn reverse_2d_impl<R, C>(_client: &C, tensor: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>,
{
if tensor.ndim() != 2 {
return Err(Error::InvalidArgument {
arg: "tensor",
reason: "reverse_2d requires 2D tensor".to_string(),
});
}
tensor.flip_dims(&[0, 1])
}
pub fn complex_mul_impl<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + RuntimeClient<R>,
{
let dtype = a.dtype();
if a.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: a.dtype(),
rhs: b.dtype(),
});
}
if a.shape() != b.shape() {
return Err(Error::ShapeMismatch {
expected: a.shape().to_vec(),
got: b.shape().to_vec(),
});
}
if !dtype.is_complex() {
return Err(Error::UnsupportedDType {
dtype,
op: "complex_mul",
});
}
client.mul(a, b)
}
/// Divides complex tensor `a` by complex tensor `b` element-wise.
///
/// Uses the identity `a / b = (a * conj(b)) / (re(b)^2 + im(b)^2)`, so the final step
/// divides a complex numerator by a real-valued denominator (`complex_div_real`).
///
/// NOTE(review): no shape check is done here (unlike `complex_mul_impl`); shape handling
/// is left to the underlying client ops — confirm that is intended.
///
/// # Errors
///
/// * [`Error::UnsupportedDType`] if `a` is not complex (checked first);
/// * [`Error::DTypeMismatch`] if `b`'s dtype differs from `a`'s.
///
/// Errors from the client operations are propagated.
#[allow(dead_code)]
pub fn complex_divide_impl<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ComplexOps<R> + RuntimeClient<R>,
{
    let lhs_dtype = a.dtype();
    if !lhs_dtype.is_complex() {
        return Err(Error::UnsupportedDType {
            dtype: lhs_dtype,
            op: "complex_divide",
        });
    }

    let rhs_dtype = b.dtype();
    if lhs_dtype != rhs_dtype {
        return Err(Error::DTypeMismatch {
            lhs: lhs_dtype,
            rhs: rhs_dtype,
        });
    }

    // Numerator: a * conj(b).
    let scaled = client.mul(a, &client.conj(b)?)?;

    // Denominator: |b|^2 = re(b)^2 + im(b)^2, which is purely real.
    let real_part = client.real(b)?;
    let imag_part = client.imag(b)?;
    let modulus_sq = client.add(
        &client.mul(&real_part, &real_part)?,
        &client.mul(&imag_part, &imag_part)?,
    )?;

    client.complex_div_real(&scaled, &modulus_sq)
}
/// Trend-removal strategy applied along the last axis by `detrend_tensor_impl`.
///
/// The default is [`DetrendMode::None`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum DetrendMode {
    /// Leave the data unchanged.
    #[default]
    None,
    /// Subtract the mean along the last axis.
    Constant,
    /// Subtract a least-squares straight-line fit along the last axis.
    Linear,
}
#[allow(dead_code)]
pub fn detrend_tensor_impl<R, C>(
client: &C,
tensor: &Tensor<R>,
mode: DetrendMode,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + ReduceOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
match mode {
DetrendMode::None => Ok(tensor.clone()),
DetrendMode::Constant => {
let ndim = tensor.ndim();
let last_dim = ndim - 1;
let mean = client.mean(tensor, &[last_dim], true)?;
client.sub(tensor, &mean)
}
DetrendMode::Linear => {
let ndim = tensor.ndim();
let last_dim = ndim - 1;
let n = tensor.shape()[last_dim];
let _device = tensor.device();
let dtype = tensor.dtype();
if n < 2 {
return Ok(tensor.clone());
}
let x = client.arange(0.0, n as f64, 1.0, dtype)?;
let x_mean = (n - 1) as f64 / 2.0;
let y_mean = client.mean(tensor, &[last_dim], true)?;
let y_centered = client.sub(tensor, &y_mean)?;
let x_centered = client.add_scalar(&x, -x_mean)?;
let x_centered_broadcast = if ndim == 1 {
x_centered.clone()
} else {
let mut shape = vec![1usize; ndim];
shape[last_dim] = n;
x_centered.reshape(&shape)?
};
let xy_product = client.mul(&x_centered_broadcast, &y_centered)?;
let numerator = client.sum(&xy_product, &[last_dim], true)?;
let denom_val = (n as f64) * ((n * n - 1) as f64) / 12.0;
let b = client.div_scalar(&numerator, denom_val)?;
let b_x_mean = client.mul_scalar(&b, x_mean)?;
let a = client.sub(&y_mean, &b_x_mean)?;
let trend_bx = client.mul(&b, &x_centered_broadcast)?;
let trend = client.add(&a, &trend_bx)?;
client.sub(tensor, &trend)
}
}
}
/// Computes `|z|^power` for every element `z` of a complex `tensor`.
///
/// The squared magnitude `re^2 + im^2` is formed first:
/// * a `power` within `1e-10` of 2 returns it directly;
/// * a `power` within `1e-10` of 1 takes its square root;
/// * any other `power` raises it to `power / 2`.
///
/// The result is cast to `output_dtype` only if its dtype differs.
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] if `tensor` is not complex. Errors from the client
/// operations (including the cast) are propagated.
pub fn complex_magnitude_pow_impl<R, C>(
    client: &C,
    tensor: &Tensor<R>,
    power: f64,
    output_dtype: DType,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    const EPS: f64 = 1e-10;

    let dtype = tensor.dtype();
    if !dtype.is_complex() {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "complex_magnitude_pow",
        });
    }

    let real_part = client.real(tensor)?;
    let imag_part = client.imag(tensor)?;
    let squared_modulus = client.add(
        &client.mul(&real_part, &real_part)?,
        &client.mul(&imag_part, &imag_part)?,
    )?;

    let is_near = |target: f64| (power - target).abs() < EPS;
    let powered = if is_near(2.0) {
        squared_modulus
    } else if is_near(1.0) {
        client.sqrt(&squared_modulus)?
    } else {
        client.pow_scalar(&squared_modulus, power / 2.0)?
    };

    if powered.dtype() == output_dtype {
        Ok(powered)
    } else {
        client.cast(&powered, output_dtype)
    }
}
/// Splits a 1-D `signal` into overlapping windows of `nperseg` samples.
///
/// Consecutive windows start `nperseg - noverlap` samples apart. Trailing samples that do
/// not fill a complete window are dropped. The windows are combined with
/// `client.stack(.., 0)`, so for a signal of length `n` the expected result shape is
/// `[num_segments, nperseg]`, where
/// `num_segments = (n - nperseg) / (nperseg - noverlap) + 1`.
///
/// # Errors
///
/// Returns [`Error::InvalidArgument`] when:
/// * `signal` is not 1-D;
/// * `nperseg` is zero;
/// * `noverlap >= nperseg`, which would make the hop between windows zero or negative;
/// * the signal is shorter than `nperseg`.
///
/// Errors from `narrow` and `stack` are propagated.
pub fn extract_segments_impl<R, C>(
    client: &C,
    signal: &Tensor<R>,
    nperseg: usize,
    noverlap: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ShapeOps<R> + RuntimeClient<R>,
{
    if signal.ndim() != 1 {
        return Err(Error::InvalidArgument {
            arg: "signal",
            reason: "Signal must be 1D".to_string(),
        });
    }
    if nperseg == 0 {
        return Err(Error::InvalidArgument {
            arg: "nperseg",
            reason: "nperseg must be greater than zero".to_string(),
        });
    }
    // Validate before subtracting: `nperseg - noverlap` underflows when noverlap > nperseg.
    // That would panic in debug builds, and in release it wraps to a huge step that
    // silently produces a single segment.
    if noverlap >= nperseg {
        return Err(Error::InvalidArgument {
            arg: "noverlap",
            reason: "noverlap must be less than nperseg".to_string(),
        });
    }
    let step = nperseg - noverlap;

    let n = signal.shape()[0];
    let num_segments = match n.checked_sub(nperseg) {
        Some(slack) => slack / step + 1,
        None => {
            return Err(Error::InvalidArgument {
                arg: "signal",
                reason: "Signal too short for given segment parameters".to_string(),
            })
        }
    };

    let segments = (0..num_segments)
        .map(|i| signal.narrow(0, i * step, nperseg))
        .collect::<Result<Vec<_>>>()?;
    let segment_refs: Vec<&Tensor<R>> = segments.iter().collect();
    client.stack(&segment_refs, 0)
}
/// Computes the power spectrum `|X|^2` of a complex FFT result.
///
/// Forms `conj(X) * X` and returns its real part. In exact arithmetic the imaginary part
/// of that product is zero, and the real part equals `re(X)^2 + im(X)^2`.
///
/// Errors from `conj`, `mul`, and `real` are propagated.
pub fn power_spectrum_impl<R, C>(client: &C, fft_result: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ComplexOps<R> + RuntimeClient<R>,
{
    let squared = client.mul(&client.conj(fft_result)?, fft_result)?;
    client.real(&squared)
}