use std::ffi::c_int;
use derive_more::{Display, IsVariant};
use crate::{
array::Array,
error::{Error, LengthMismatchPayload, Result, check},
shape::dim_ptr,
stream::default_stream,
};
/// Normalization mode accepted by the FFT functions in this module.
///
/// `Display` prints the lowercase name returned by [`FftNorm::as_str`]
/// (`"backward"`, `"ortho"`, `"forward"`), and each variant converts into the matching
/// mlx-c `MLX_FFT_NORM_*` constant.
///
/// NOTE(review): the names mirror NumPy's `norm` strings; presumably the scaling follows
/// NumPy's convention as well (`Backward`: unscaled forward / `1/n` inverse; `Ortho`:
/// `1/sqrt(n)` both ways; `Forward`: `1/n` forward / unscaled inverse) — confirm against
/// the mlx-c documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Display, IsVariant)]
#[display("{}", self.as_str())]
pub enum FftNorm {
    /// The default mode; converts to `MLX_FFT_NORM_BACKWARD`.
    #[default]
    Backward,
    /// Converts to `MLX_FFT_NORM_ORTHO`.
    Ortho,
    /// Converts to `MLX_FFT_NORM_FORWARD`.
    Forward,
}
impl FftNorm {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Backward => "backward",
Self::Ortho => "ortho",
Self::Forward => "forward",
}
}
}
impl From<FftNorm> for mlxrs_sys::mlx_fft_norm {
fn from(n: FftNorm) -> Self {
match n {
FftNorm::Backward => mlxrs_sys::mlx_fft_norm__MLX_FFT_NORM_BACKWARD,
FftNorm::Ortho => mlxrs_sys::mlx_fft_norm__MLX_FFT_NORM_ORTHO,
FftNorm::Forward => mlxrs_sys::mlx_fft_norm__MLX_FFT_NORM_FORWARD,
}
}
}
/// Fills in defaults for the `n` (per-axis sizes) and `axes` arguments of the multi-axis
/// FFT wrappers and returns the `(n, axes)` pair to hand to mlx-c.
///
/// * `axes` given, `n` empty: `n` becomes `a`'s extent along each axis (negative axes
///   count from the end). If any axis is out of range, `n` is returned empty and mlx-c is
///   left to reject the call.
/// * `axes` and `n` both given: both are returned unchanged.
/// * only `n` given: the trailing `n.len()` axes are used. More entries than `a` has
///   dimensions is an [`Error::LengthMismatch`].
/// * neither given: every axis is used (only the last two, or fewer for rank < 2, when
///   `last_two` is set), with `n` taken from `a`'s shape.
fn resolve_fft(
    a: &Array,
    n: &[i32],
    axes: &[i32],
    last_two: bool,
) -> Result<(Vec<i32>, Vec<i32>)> {
    let rank = a.ndim() as i32;
    match (n.is_empty(), axes.is_empty()) {
        // Explicit axes, implicit sizes: look each axis up in the shape.
        (true, false) => {
            let shape = a.shape();
            let extents: Option<Vec<i32>> = axes
                .iter()
                .map(|&axis| {
                    let idx = if axis < 0 { axis + rank } else { axis };
                    if idx < 0 {
                        None
                    } else {
                        shape.get(idx as usize).map(|&len| len as i32)
                    }
                })
                .collect();
            // Any out-of-range axis collapses to an empty `n`.
            Ok((extents.unwrap_or_default(), axes.to_vec()))
        }
        // Everything explicit: pass straight through.
        (false, false) => Ok((n.to_vec(), axes.to_vec())),
        // Explicit sizes, implicit axes: pair them with the trailing axes.
        (false, true) => {
            if n.len() > a.ndim() {
                return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                    "fftn/fft2: n vs array rank (default axes)",
                    a.ndim(),
                    n.len(),
                )));
            }
            let first = (rank - n.len() as i32).max(0);
            Ok((n.to_vec(), (first..rank).collect()))
        }
        // Nothing explicit: choose the axes, then read their sizes from the shape.
        (true, true) => {
            let first = if last_two { (rank - 2).max(0) } else { 0 };
            let shape = a.shape();
            let (sizes, picked): (Vec<i32>, Vec<i32>) = (first..rank)
                .map(|axis| (shape.get(axis as usize).copied().unwrap_or(0) as i32, axis))
                .unzip();
            Ok((sizes, picked))
        }
    }
}
/// One-dimensional FFT of `a` along `axis`, computed by mlx-c `mlx_fft_fft`.
///
/// `n` is forwarded as the transform length and `norm` selects the scaling mode. The
/// call is issued with [`default_stream`].
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn fft(a: &Array, n: i32, axis: i32, norm: FftNorm) -> Result<Array> {
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (len, dim) = (n as c_int, axis as c_int);
    // SAFETY (assumed): `a.0` stays a valid handle while `a` is borrowed, and mlx-c writes
    // the result through the `&mut res.0` out-parameter.
    let status = unsafe {
        mlxrs_sys::mlx_fft_fft(&mut res.0, a.0, len, dim, norm.into(), default_stream())
    };
    check(status)?;
    Ok(res)
}
/// One-dimensional inverse FFT of `a` along `axis`, computed by mlx-c `mlx_fft_ifft`.
///
/// `n` is forwarded as the transform length and `norm` selects the scaling mode. The
/// call is issued with [`default_stream`].
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn ifft(a: &Array, n: i32, axis: i32, norm: FftNorm) -> Result<Array> {
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (len, dim) = (n as c_int, axis as c_int);
    // SAFETY (assumed): `a.0` stays a valid handle while `a` is borrowed, and mlx-c writes
    // the result through the `&mut res.0` out-parameter.
    let status = unsafe {
        mlxrs_sys::mlx_fft_ifft(&mut res.0, a.0, len, dim, norm.into(), default_stream())
    };
    check(status)?;
    Ok(res)
}
/// One-dimensional real-input FFT of `a` along `axis`, computed by mlx-c `mlx_fft_rfft`.
///
/// `n` is forwarded as the transform length and `norm` selects the scaling mode. The
/// call is issued with [`default_stream`].
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn rfft(a: &Array, n: i32, axis: i32, norm: FftNorm) -> Result<Array> {
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (len, dim) = (n as c_int, axis as c_int);
    // SAFETY (assumed): `a.0` stays a valid handle while `a` is borrowed, and mlx-c writes
    // the result through the `&mut res.0` out-parameter.
    let status = unsafe {
        mlxrs_sys::mlx_fft_rfft(&mut res.0, a.0, len, dim, norm.into(), default_stream())
    };
    check(status)?;
    Ok(res)
}
/// One-dimensional inverse real FFT of `a` along `axis`, computed by mlx-c
/// `mlx_fft_irfft`.
///
/// `n` is forwarded as the output length (it has no default here; callers must supply
/// it) and `norm` selects the scaling mode. The call is issued with [`default_stream`].
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn irfft(a: &Array, n: i32, axis: i32, norm: FftNorm) -> Result<Array> {
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (len, dim) = (n as c_int, axis as c_int);
    // SAFETY (assumed): `a.0` stays a valid handle while `a` is borrowed, and mlx-c writes
    // the result through the `&mut res.0` out-parameter.
    let status = unsafe {
        mlxrs_sys::mlx_fft_irfft(&mut res.0, a.0, len, dim, norm.into(), default_stream())
    };
    check(status)?;
    Ok(res)
}
/// N-dimensional FFT of `a`, computed by mlx-c `mlx_fft_fftn`.
///
/// `n` lists the transform size per axis and `axes` the axes to transform; either may be
/// empty. With only `axes`, sizes are read from `a`'s shape. With only `n`, the trailing
/// `n.len()` axes are used. With neither, every axis is transformed at its current size.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn fftn(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let (sizes, dims) = resolve_fft(a, n, axes, false)?;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `sizes` and `dims` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    let status = unsafe {
        mlxrs_sys::mlx_fft_fftn(
            &mut res.0,
            a.0,
            dim_ptr(&sizes),
            sizes.len(),
            dim_ptr(&dims),
            dims.len(),
            norm.into(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(res)
}
/// N-dimensional inverse FFT of `a`, computed by mlx-c `mlx_fft_ifftn`.
///
/// `n` lists the transform size per axis and `axes` the axes to transform; either may be
/// empty. With only `axes`, sizes are read from `a`'s shape. With only `n`, the trailing
/// `n.len()` axes are used. With neither, every axis is transformed at its current size.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn ifftn(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let (sizes, dims) = resolve_fft(a, n, axes, false)?;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `sizes` and `dims` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    let status = unsafe {
        mlxrs_sys::mlx_fft_ifftn(
            &mut res.0,
            a.0,
            dim_ptr(&sizes),
            sizes.len(),
            dim_ptr(&dims),
            dims.len(),
            norm.into(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(res)
}
/// Two-dimensional FFT of `a`, computed by mlx-c `mlx_fft_fft2`.
///
/// When both `n` and `axes` are empty, the last two axes of `a` are transformed (fewer if
/// `a` has rank < 2) at their current sizes. With only `axes`, sizes are read from `a`'s
/// shape. With only `n`, the trailing `n.len()` axes are used.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn fft2(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let (sizes, dims) = resolve_fft(a, n, axes, true)?;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `sizes` and `dims` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    let status = unsafe {
        mlxrs_sys::mlx_fft_fft2(
            &mut res.0,
            a.0,
            dim_ptr(&sizes),
            sizes.len(),
            dim_ptr(&dims),
            dims.len(),
            norm.into(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(res)
}
/// Two-dimensional inverse FFT of `a`, computed by mlx-c `mlx_fft_ifft2`.
///
/// When both `n` and `axes` are empty, the last two axes of `a` are transformed (fewer if
/// `a` has rank < 2) at their current sizes. With only `axes`, sizes are read from `a`'s
/// shape. With only `n`, the trailing `n.len()` axes are used.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn ifft2(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let (sizes, dims) = resolve_fft(a, n, axes, true)?;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `sizes` and `dims` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    let status = unsafe {
        mlxrs_sys::mlx_fft_ifft2(
            &mut res.0,
            a.0,
            dim_ptr(&sizes),
            sizes.len(),
            dim_ptr(&dims),
            dims.len(),
            norm.into(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(res)
}
/// N-dimensional real-input FFT of `a`, computed by mlx-c `mlx_fft_rfftn`.
///
/// `n` lists the transform size per axis and `axes` the axes to transform; either may be
/// empty. With only `axes`, sizes are read from `a`'s shape. With only `n`, the trailing
/// `n.len()` axes are used. With neither, every axis is transformed at its current size.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn rfftn(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let (sizes, dims) = resolve_fft(a, n, axes, false)?;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `sizes` and `dims` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    let status = unsafe {
        mlxrs_sys::mlx_fft_rfftn(
            &mut res.0,
            a.0,
            dim_ptr(&sizes),
            sizes.len(),
            dim_ptr(&dims),
            dims.len(),
            norm.into(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(res)
}
/// N-dimensional inverse real FFT of `a`, computed by mlx-c `mlx_fft_irfftn`.
///
/// `n` lists the *output* size per transformed axis and `axes` the axes to transform;
/// either may be empty. With only `axes`, sizes come from `a`'s shape. With only `n`, the
/// trailing `n.len()` axes are used. With neither, every axis is transformed.
///
/// When `n` is empty, the last transformed axis defaults to an output length of
/// `2 * (m - 1)`, where `m` is `a`'s length along that axis (the NumPy / MLX convention
/// for inverse real transforms). An even-length axis therefore round-trips through
/// `rfftn`. Pass `n` explicitly to recover an odd length.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn irfftn(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let n_was_empty = n.is_empty();
    let (mut n, axes) = resolve_fft(a, n, axes, false)?;
    if n_was_empty {
        // `resolve_fft` fills a missing `n` with the *input* extents. The input's last
        // axis holds only the `len / 2 + 1` half-spectrum bins, so using that as the
        // output length would silently truncate the signal. The default is `2 * (m - 1)`.
        // (If `n` came back empty because an axis was out of range, this is a no-op and
        // mlx-c reports the error.)
        if let Some(last) = n.last_mut() {
            *last = (*last - 1).saturating_mul(2);
        }
    }
    let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `n` and `axes` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    check(unsafe {
        mlxrs_sys::mlx_fft_irfftn(
            &mut out.0,
            a.0,
            dim_ptr(&n),
            n.len(),
            dim_ptr(&axes),
            axes.len(),
            norm.into(),
            default_stream(),
        )
    })?;
    Ok(out)
}
/// Two-dimensional real-input FFT of `a`, computed by mlx-c `mlx_fft_rfft2`.
///
/// When both `n` and `axes` are empty, the last two axes of `a` are transformed (fewer if
/// `a` has rank < 2) at their current sizes. With only `axes`, sizes are read from `a`'s
/// shape. With only `n`, the trailing `n.len()` axes are used.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn rfft2(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let (sizes, dims) = resolve_fft(a, n, axes, true)?;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `sizes` and `dims` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    let status = unsafe {
        mlxrs_sys::mlx_fft_rfft2(
            &mut res.0,
            a.0,
            dim_ptr(&sizes),
            sizes.len(),
            dim_ptr(&dims),
            dims.len(),
            norm.into(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(res)
}
/// Two-dimensional inverse real FFT of `a`, computed by mlx-c `mlx_fft_irfft2`.
///
/// `n` lists the *output* size per transformed axis and `axes` the axes to transform.
/// When both are empty, the last two axes of `a` are used (fewer if `a` has rank < 2).
/// With only `axes`, sizes come from `a`'s shape. With only `n`, the trailing `n.len()`
/// axes are used.
///
/// When `n` is empty, the last transformed axis defaults to an output length of
/// `2 * (m - 1)`, where `m` is `a`'s length along that axis (the NumPy / MLX convention
/// for inverse real transforms). An even-length axis therefore round-trips through
/// `rfft2`. Pass `n` explicitly to recover an odd length.
///
/// # Errors
///
/// [`Error::LengthMismatch`] if `axes` is empty and `n` is longer than `a`'s rank;
/// otherwise any failure reported by mlx-c through [`check`].
pub fn irfft2(a: &Array, n: &[i32], axes: &[i32], norm: FftNorm) -> Result<Array> {
    let n_was_empty = n.is_empty();
    let (mut n, axes) = resolve_fft(a, n, axes, true)?;
    if n_was_empty {
        // `resolve_fft` fills a missing `n` with the *input* extents. The input's last
        // axis holds only the `len / 2 + 1` half-spectrum bins, so using that as the
        // output length would silently truncate the signal. The default is `2 * (m - 1)`.
        // (If `n` came back empty because an axis was out of range, this is a no-op and
        // mlx-c reports the error.)
        if let Some(last) = n.last_mut() {
            *last = (*last - 1).saturating_mul(2);
        }
    }
    let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `n` and `axes` live until the end of this function, so the
    // pointers from `dim_ptr` stay valid for the call, and each length matches its vector.
    check(unsafe {
        mlxrs_sys::mlx_fft_irfft2(
            &mut out.0,
            a.0,
            dim_ptr(&n),
            n.len(),
            dim_ptr(&axes),
            axes.len(),
            norm.into(),
            default_stream(),
        )
    })?;
    Ok(out)
}
/// Wraps mlx-c `mlx_fft_fftfreq`, forwarding `n` and `d` unchanged.
///
/// NOTE(review): presumably `n` is the window length and `d` the sample spacing, with the
/// same output layout as NumPy's `fftfreq` — confirm against mlx-c.
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn fftfreq(n: i32, d: f64) -> Result<Array> {
    let count = n as c_int;
    let mut freqs = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): mlx-c writes the result through the `&mut freqs.0` out-parameter.
    let status = unsafe { mlxrs_sys::mlx_fft_fftfreq(&mut freqs.0, count, d, default_stream()) };
    check(status)?;
    Ok(freqs)
}
/// Wraps mlx-c `mlx_fft_rfftfreq`, forwarding `n` and `d` unchanged.
///
/// NOTE(review): presumably `n` is the window length and `d` the sample spacing, with the
/// same output layout as NumPy's `rfftfreq` — confirm against mlx-c.
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn rfftfreq(n: i32, d: f64) -> Result<Array> {
    let count = n as c_int;
    let mut freqs = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): mlx-c writes the result through the `&mut freqs.0` out-parameter.
    let status = unsafe { mlxrs_sys::mlx_fft_rfftfreq(&mut freqs.0, count, d, default_stream()) };
    check(status)?;
    Ok(freqs)
}
/// Wraps mlx-c `mlx_fft_fftshift`.
///
/// An empty `axes` selects every axis of `a`; a non-empty `axes` is passed through
/// unchanged. NOTE(review): presumably this recentres the zero-frequency term like
/// NumPy's `fftshift` — confirm against mlx-c.
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn fftshift(a: &Array, axes: &[i32]) -> Result<Array> {
    let dims = resolve_fft(a, &[], axes, false)?.1;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `dims` lives until the end of this function, so the pointer from
    // `dim_ptr` stays valid for the call.
    let status = unsafe {
        mlxrs_sys::mlx_fft_fftshift(&mut res.0, a.0, dim_ptr(&dims), dims.len(), default_stream())
    };
    check(status)?;
    Ok(res)
}
/// Wraps mlx-c `mlx_fft_ifftshift`.
///
/// An empty `axes` selects every axis of `a`; a non-empty `axes` is passed through
/// unchanged. NOTE(review): presumably this undoes [`fftshift`] like NumPy's
/// `ifftshift` — confirm against mlx-c.
///
/// # Errors
///
/// Any failure status from mlx-c, converted by [`check`].
pub fn ifftshift(a: &Array, axes: &[i32]) -> Result<Array> {
    let dims = resolve_fft(a, &[], axes, false)?.1;
    let mut res = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY (assumed): `dims` lives until the end of this function, so the pointer from
    // `dim_ptr` stays valid for the call.
    let status = unsafe {
        mlxrs_sys::mlx_fft_ifftshift(&mut res.0, a.0, dim_ptr(&dims), dims.len(), default_stream())
    };
    check(status)?;
    Ok(res)
}