#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <numeric>
#include <string_view>
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
namespace {
// Translate the Python-facing `norm` keyword into an mx::fft::FFTNorm value.
// `op` is the name of the calling binding; it is only used to prefix the
// error message. Any string other than "backward", "ortho" or "forward"
// raises std::invalid_argument.
mx::fft::FFTNorm parse_norm(std::string_view norm, std::string_view op) {
  struct NormName {
    std::string_view text;
    mx::fft::FFTNorm value;
  };
  static const NormName kKnownNorms[] = {
      {"backward", mx::fft::FFTNorm::Backward},
      {"ortho", mx::fft::FFTNorm::Ortho},
      {"forward", mx::fft::FFTNorm::Forward},
  };
  for (const auto& candidate : kKnownNorms) {
    if (candidate.text == norm) {
      return candidate.value;
    }
  }
  std::string message("[");
  message.append(op);
  message.append(
      "] Invalid norm. Expected one of {'backward', 'ortho', 'forward'}.");
  throw std::invalid_argument(message);
}
} // namespace
// Default axes for the n-dimensional transforms when `s` is given but `axes`
// is None: the trailing `count` axes, i.e. [-count, ..., -1]. This implements
// the documented default ("the FFT is over the last len(s) axes").
static std::vector<int> fft_trailing_axes(size_t count) {
  std::vector<int> axes(count);
  // Start at -count (computed as a signed int, not a negated size_t).
  std::iota(axes.begin(), axes.end(), -static_cast<int>(count));
  return axes;
}

// Register the `mlx.core.fft` submodule and bind every FFT routine on it.
// Each binding parses the `norm` keyword with parse_norm and then dispatches
// to the mx::fft overload that matches which optional arguments were given.
void init_fft(nb::module_& parent_module) {
  auto m = parent_module.def_submodule(
      "fft", "mlx.core.fft: Fast Fourier Transforms.");

  m.def(
      "fft",
      [](const mx::array& a,
         const std::optional<int>& n,
         int axis,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "fft");
        if (n.has_value()) {
          return mx::fft::fft(a, n.value(), axis, fft_norm, s);
        } else {
          return mx::fft::fft(a, axis, fft_norm, s);
        }
      },
      "a"_a,
      "n"_a = nb::none(),
      "axis"_a = -1,
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
One dimensional discrete Fourier Transform.
Args:
a (array): The input array.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n``. The default value is ``a.shape[axis]``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The DFT of the input along the given axis.
)pbdoc");

  m.def(
      "ifft",
      [](const mx::array& a,
         const std::optional<int>& n,
         int axis,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "ifft");
        if (n.has_value()) {
          return mx::fft::ifft(a, n.value(), axis, fft_norm, s);
        } else {
          return mx::fft::ifft(a, axis, fft_norm, s);
        }
      },
      "a"_a,
      "n"_a = nb::none(),
      "axis"_a = -1,
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
One dimensional inverse discrete Fourier Transform.
Args:
a (array): The input array.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n``. The default value is ``a.shape[axis]``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The inverse DFT of the input along the given axis.
)pbdoc");

  // The 2-D variants default `axes` to [-2, -1]; an explicit `axes=None`
  // together with `s` is rejected rather than guessed.
  m.def(
      "fft2",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "fft2");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::fftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::fftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          throw std::invalid_argument(
              "[fft2] `axes` should not be `None` if `s` is not `None`.");
        } else {
          return mx::fft::fftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a.none() = std::vector<int>{-2, -1},
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
Two dimensional discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The DFT of the input along the given axes.
)pbdoc");

  m.def(
      "ifft2",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "ifft2");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::ifftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::ifftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          throw std::invalid_argument(
              "[ifft2] `axes` should not be `None` if `s` is not `None`.");
        } else {
          return mx::fft::ifftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a.none() = std::vector<int>{-2, -1},
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
Two dimensional inverse discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The inverse DFT of the input along the given axes.
)pbdoc");

  // The n-D variants default `axes` to None. With `s` given they transform
  // the last len(s) axes (as documented); with neither, all axes.
  m.def(
      "fftn",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "fftn");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::fftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::fftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          return mx::fft::fftn(
              a, n.value(), fft_trailing_axes(n->size()), fft_norm, s);
        } else {
          return mx::fft::fftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a = nb::none(),
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
n-dimensional discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The DFT of the input along the given axes.
)pbdoc");

  m.def(
      "ifftn",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "ifftn");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::ifftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::ifftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          return mx::fft::ifftn(
              a, n.value(), fft_trailing_axes(n->size()), fft_norm, s);
        } else {
          return mx::fft::ifftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a = nb::none(),
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
n-dimensional inverse discrete Fourier Transform.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The inverse DFT of the input along the given axes.
)pbdoc");

  m.def(
      "rfft",
      [](const mx::array& a,
         const std::optional<int>& n,
         int axis,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "rfft");
        if (n.has_value()) {
          return mx::fft::rfft(a, n.value(), axis, fft_norm, s);
        } else {
          return mx::fft::rfft(a, axis, fft_norm, s);
        }
      },
      "a"_a,
      "n"_a = nb::none(),
      "axis"_a = -1,
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
One dimensional discrete Fourier Transform on a real input.
The output has the same shape as the input except along ``axis`` in
which case it has size ``n // 2 + 1``.
Args:
a (array): The input array. If the array is complex it will be silently
cast to a real type.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n``. The default value is ``a.shape[axis]``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The DFT of the input along the given axis. The output
data type will be complex.
)pbdoc");

  m.def(
      "irfft",
      [](const mx::array& a,
         const std::optional<int>& n,
         int axis,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "irfft");
        if (n.has_value()) {
          return mx::fft::irfft(a, n.value(), axis, fft_norm, s);
        } else {
          return mx::fft::irfft(a, axis, fft_norm, s);
        }
      },
      "a"_a,
      "n"_a = nb::none(),
      "axis"_a = -1,
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
The inverse of :func:`rfft`.
The output has the same shape as the input except along ``axis`` in
which case it has size ``n``.
Args:
a (array): The input array.
n (int, optional): Size of the transformed axis. The
corresponding axis in the input is truncated or padded with
zeros to match ``n // 2 + 1``. The default value is
``2 * (a.shape[axis] - 1)``.
axis (int, optional): Axis along which to perform the FFT. The
default is ``-1``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The real array containing the inverse of :func:`rfft`.
)pbdoc");

  m.def(
      "rfft2",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "rfft2");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::rfftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::rfftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          throw std::invalid_argument(
              "[rfft2] `axes` should not be `None` if `s` is not `None`.");
        } else {
          return mx::fft::rfftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a.none() = std::vector<int>{-2, -1},
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
Two dimensional real discrete Fourier Transform.
The output has the same shape as the input except along the dimensions in
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
treated as the real axis and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array. If the array is complex it will be silently
cast to a real type.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The real DFT of the input along the given axes. The output
data type will be complex.
)pbdoc");

  m.def(
      "irfft2",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "irfft2");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::irfftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::irfftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          throw std::invalid_argument(
              "[irfft2] `axes` should not be `None` if `s` is not `None`.");
        } else {
          return mx::fft::irfftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a.none() = std::vector<int>{-2, -1},
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
The inverse of :func:`rfft2`.
Note the input is generally complex. The dimensions of the input
specified in ``axes`` are padded or truncated to match the sizes
from ``s``. The last axis in ``axes`` is treated as the real axis
and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s`` except for the last axis
which has size ``s[-1] // 2 + 1``. The default value is the
sizes of ``a`` along ``axes`` except for the last axis which
defaults to ``2 * (a.shape[axes[-1]] - 1)``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``[-2, -1]``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The real array containing the inverse of :func:`rfft2`.
)pbdoc");

  m.def(
      "rfftn",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "rfftn");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::rfftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::rfftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          return mx::fft::rfftn(
              a, n.value(), fft_trailing_axes(n->size()), fft_norm, s);
        } else {
          return mx::fft::rfftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a = nb::none(),
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
n-dimensional real discrete Fourier Transform.
The output has the same shape as the input except along the dimensions in
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
treated as the real axis and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array. If the array is complex it will be silently
cast to a real type.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s``. The default value is the
sizes of ``a`` along ``axes``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The real DFT of the input along the given axes. The output
data type will be complex.
)pbdoc");

  m.def(
      "irfftn",
      [](const mx::array& a,
         const std::optional<mx::Shape>& n,
         const std::optional<std::vector<int>>& axes,
         const std::string& norm,
         mx::StreamOrDevice s) {
        auto fft_norm = parse_norm(norm, "irfftn");
        if (axes.has_value() && n.has_value()) {
          return mx::fft::irfftn(a, n.value(), axes.value(), fft_norm, s);
        } else if (axes.has_value()) {
          return mx::fft::irfftn(a, axes.value(), fft_norm, s);
        } else if (n.has_value()) {
          return mx::fft::irfftn(
              a, n.value(), fft_trailing_axes(n->size()), fft_norm, s);
        } else {
          return mx::fft::irfftn(a, fft_norm, s);
        }
      },
      "a"_a,
      "s"_a = nb::none(),
      "axes"_a = nb::none(),
      "norm"_a = "backward",
      "stream"_a = nb::none(),
      R"pbdoc(
The inverse of :func:`rfftn`.
Note the input is generally complex. The dimensions of the input
specified in ``axes`` are padded or truncated to match the sizes
from ``s``. The last axis in ``axes`` is treated as the real axis
and will have size ``s[-1] // 2 + 1``.
Args:
a (array): The input array.
s (list(int), optional): Sizes of the transformed axes. The
corresponding axes in the input are truncated or padded with
zeros to match the sizes in ``s`` except for the last axis
which has size ``s[-1] // 2 + 1``. The default value is the
sizes of ``a`` along ``axes`` except for the last axis which
defaults to ``2 * (a.shape[axes[-1]] - 1)``.
axes (list(int), optional): Axes along which to perform the FFT.
The default is ``None`` in which case the FFT is over the last
``len(s)`` axes or all axes if ``s`` is also ``None``.
norm (str, optional): One of ``"backward"``, ``"ortho"``, or
``"forward"``. Default is ``"backward"``.
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");

  m.def(
      "fftfreq",
      [](int n, double d, mx::StreamOrDevice s) {
        return mx::fft::fftfreq(n, d, s);
      },
      "n"_a,
      "d"_a = 1.0,
      "stream"_a = nb::none(),
      R"pbdoc(
Return the discrete Fourier Transform sample frequencies.
Args:
n (int): Window length.
d (float, optional): Sample spacing. The default is ``1.0``.
Returns:
array: The sample frequencies as a one-dimensional array of type ``float32``.
)pbdoc");

  m.def(
      "rfftfreq",
      [](int n, double d, mx::StreamOrDevice s) {
        return mx::fft::rfftfreq(n, d, s);
      },
      "n"_a,
      "d"_a = 1.0,
      "stream"_a = nb::none(),
      R"pbdoc(
Return the discrete Fourier Transform sample frequencies
for use with :func:`rfft` and :func:`irfft`.
The returned array contains the non-negative frequency terms
in the range ``[0, floor(n/2)]``.
Args:
n (int): Window length.
d (float, optional): Sample spacing. The default is ``1.0``.
Returns:
array: The sample frequencies as a one-dimensional array of type ``float32``.
)pbdoc");

  // `axes` may be None (shift every axis), a single int, or a list of ints.
  m.def(
      "fftshift",
      [](const mx::array& a, const IntOrVec& axes, mx::StreamOrDevice s) {
        if (std::holds_alternative<std::monostate>(axes)) {
          return mx::fft::fftshift(a, s);
        } else if (auto pv = std::get_if<int>(&axes); pv) {
          return mx::fft::fftshift(a, {*pv}, s);
        } else {
          return mx::fft::fftshift(a, std::get<std::vector<int>>(axes), s);
        }
      },
      "a"_a,
      "axes"_a = nb::none(),
      "stream"_a = nb::none(),
      R"pbdoc(
Shift the zero-frequency component to the center of the spectrum.
Args:
a (array): The input array.
axes (int or list(int), optional): Axis or axes over which to perform the shift.
If ``None``, shift all axes.
Returns:
array: The shifted array with the same shape as the input.
)pbdoc");

  m.def(
      "ifftshift",
      [](const mx::array& a, const IntOrVec& axes, mx::StreamOrDevice s) {
        if (std::holds_alternative<std::monostate>(axes)) {
          return mx::fft::ifftshift(a, s);
        } else if (auto pv = std::get_if<int>(&axes); pv) {
          return mx::fft::ifftshift(a, {*pv}, s);
        } else {
          return mx::fft::ifftshift(a, std::get<std::vector<int>>(axes), s);
        }
      },
      "a"_a,
      "axes"_a = nb::none(),
      "stream"_a = nb::none(),
      R"pbdoc(
The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes,
the behavior differs for odd-length axes.
Args:
a (array): The input array.
axes (int or list(int), optional): Axis or axes over which to perform the inverse shift.
If ``None``, shift all axes.
Returns:
array: The inverse-shifted array with the same shape as the input.
)pbdoc");
}