#pragma once
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <utility>
#include <vector>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
/// Returns the smallest power of two that is >= n (1 for n == 0 or n == 1).
///
/// @throws std::overflow_error if no power of two >= n fits in std::size_t.
///         Without this check the doubling loop would wrap p to 0 and spin
///         forever (e.g. when a negative int is cast to std::size_t).
static inline std::size_t next_pow2(std::size_t n) {
    // Largest power of two representable in std::size_t.
    constexpr std::size_t kMaxPow2 = (~std::size_t{0} >> 1) + 1;
    if (n > kMaxPow2) {
        throw std::overflow_error("next_pow2: no representable power of two >= n");
    }
    std::size_t p = 1;
    while (p < n) p <<= 1;
    return p;
}
/// In-place iterative radix-2 Cooley-Tukey FFT.
///
/// @param buf      n complex samples; overwritten with the transform.
/// @param n        transform length; must be a power of two (0 and 1 are no-ops).
/// @param inverse  false: forward transform, X[k] = sum_t x[t] * exp(-2*pi*i*k*t/n).
///                 true:  inverse transform with exp(+2*pi*i*k*t/n), scaled by 1/n
///                        so that inverse(forward(x)) reproduces x.
/// @throws std::invalid_argument if n > 1 is not a power of two (the butterfly
///         stages would otherwise silently compute a wrong result).
static void fft_radix2(std::complex<float>* buf, std::size_t n, bool inverse) {
    if (n <= 1) return;
    if ((n & (n - 1)) != 0) {
        throw std::invalid_argument("fft_radix2: length must be a power of two");
    }

    // Bit-reversal permutation so the butterflies below can run in place.
    for (std::size_t i = 1, j = 0; i < n; ++i) {
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) {
            j ^= bit;
        }
        j ^= bit;
        if (i < j) std::swap(buf[i], buf[j]);
    }

    // Twiddle table twiddle[k] = exp(+-2*pi*i*k/n), evaluated directly in double.
    // Using a table avoids the repeated product w *= wn, whose single-precision
    // rounding error grows with the stage length.
    const std::size_t half_n = n / 2;
    std::vector<std::complex<float>> twiddle(half_n);
    const double step = (inverse ? 2.0 : -2.0) * M_PI / static_cast<double>(n);
    for (std::size_t k = 0; k < half_n; ++k) {
        const double angle = step * static_cast<double>(k);
        twiddle[k] = {static_cast<float>(std::cos(angle)),
                      static_cast<float>(std::sin(angle))};
    }

    for (std::size_t len = 2; len <= n; len <<= 1) {
        const std::size_t half = len / 2;
        // twiddle[j * stride] == exp(+-2*pi*i*j/len) for this stage.
        const std::size_t stride = n / len;
        for (std::size_t i = 0; i < n; i += len) {
            for (std::size_t j = 0; j < half; ++j) {
                const std::complex<float> u = buf[i + j];
                const std::complex<float> v = buf[i + j + half] * twiddle[j * stride];
                buf[i + j] = u + v;
                buf[i + j + half] = u - v;
            }
        }
    }

    if (inverse) {
        const float inv_n = 1.0f / static_cast<float>(n);
        for (std::size_t i = 0; i < n; ++i) {
            buf[i] *= inv_n;
        }
    }
}
/// Frequency-domain filter bank used by the CPU scattering path.
///
/// Every spectrum stored here is indexed by FFT bin and is expected to hold
/// exactly fft_len entries. Scalar members default to zero so a
/// default-constructed bank is never read uninitialized.
struct CPUFilterBank {
    /// Number of FFT bins in each spectrum (a power of two when produced by next_pow2).
    std::size_t fft_len{0};
    /// Number of band-pass filters; expected to equal filters.size().
    std::size_t n_filters{0};
    /// Band-pass spectra, one vector of fft_len bins per filter.
    std::vector<std::vector<std::complex<float>>> filters;
    /// Low-pass (averaging) spectrum of fft_len bins.
    std::vector<std::complex<float>> phi;
    /// Mean over the band-pass filters of (sum_k |psi[k]|) / fft_len.
    float l1_norm{0.0f};
};
/// Frequency-domain band-pass filter (Gaussian bump) of length fft_len.
///
/// Bin k, for 0 <= k <= fft_len/2, holds the real value
///     pi^(-1/4) * exp(-(omega_k - centre_freq)^2 / (2 * sigma^2)),
///     omega_k = 2*pi*k / fft_len.
/// Bins above fft_len/2 (negative frequencies) stay zero, so only positive
/// frequencies are passed.
/// NOTE(review): no zero-mean correction term is subtracted, so bin 0 is
/// generally non-zero; confirm this is intended for a "Morlet" filter.
///
/// @param fft_len      number of bins; 0 yields an empty spectrum.
/// @param centre_freq  centre frequency in radians/sample.
/// @param sigma        Gaussian width in radians/sample; expected to be > 0.
/// @return fft_len bins with zero imaginary part.
static std::vector<std::complex<float>> morlet_freq(
    std::size_t fft_len,
    float centre_freq,
    float sigma)
{
    std::vector<std::complex<float>> psi(fft_len, {0.0f, 0.0f});
    // The loop below always touches bin 0; an empty spectrum has none.
    if (fft_len == 0) return psi;
    const float norm = std::pow(static_cast<float>(M_PI), -0.25f);
    for (std::size_t k = 0; k <= fft_len / 2; ++k) {
        const float omega = 2.0f * static_cast<float>(M_PI)
                          * static_cast<float>(k)
                          / static_cast<float>(fft_len);
        const float diff = omega - centre_freq;
        const float val = norm * std::exp(-0.5f * diff * diff / (sigma * sigma));
        psi[k] = {val, 0.0f};
    }
    return psi;
}
/// Frequency-domain Gaussian low-pass filter of length fft_len.
///
/// Bin k, for 0 <= k <= fft_len/2, holds exp(-omega_k^2 / (2 * cutoff_freq^2)),
/// with omega_k = 2*pi*k / fft_len. Bins above fft_len/2 mirror bin fft_len - k,
/// so the spectrum is real and symmetric.
/// NOTE(review): `sigma` is accepted but never used; cutoff_freq alone sets the
/// Gaussian width. build_cpu_morlet_bank passes sigma = cutoff/2, which hints
/// that a different width may have been intended. Confirm before changing.
///
/// @param fft_len      number of bins; 0 yields an empty spectrum.
/// @param cutoff_freq  Gaussian width in radians/sample; expected to be > 0.
/// @param sigma        currently unused (see note above).
/// @return fft_len bins with zero imaginary part.
static std::vector<std::complex<float>> lowpass_freq(
    std::size_t fft_len,
    float cutoff_freq,
    float sigma)
{
    static_cast<void>(sigma);  // unused; kept for interface compatibility
    std::vector<std::complex<float>> phi(fft_len, {0.0f, 0.0f});
    // The loop below always touches bin 0; an empty spectrum has none.
    if (fft_len == 0) return phi;
    for (std::size_t k = 0; k <= fft_len / 2; ++k) {
        const float omega = 2.0f * static_cast<float>(M_PI)
                          * static_cast<float>(k)
                          / static_cast<float>(fft_len);
        const float val = std::exp(-0.5f * omega * omega / (cutoff_freq * cutoff_freq));
        phi[k] = {val, 0.0f};
    }
    // Mirror the positive-frequency half onto the negative frequencies.
    for (std::size_t k = fft_len / 2 + 1; k < fft_len; ++k) {
        phi[k] = phi[fft_len - k];
    }
    return phi;
}
/// Builds the CPU filter bank: J octaves with Q band-pass filters per octave.
///
/// Filter i (0 <= i < J*Q) comes from morlet_freq with centre
/// omega_0 = pi * 2^(-i/Q) radians/sample and width sigma = omega_0 / Q, so
/// omega_0 / sigma stays equal to Q across the bank. The low-pass spectrum phi
/// comes from lowpass_freq with cutoff pi * 2^(-J). l1_norm is the mean over
/// filters of (sum_k |psi[k]|) / fft_len, or 0 when there are no filters.
///
/// @param J           number of octaves; must be >= 0.
/// @param Q           filters per octave; must be >= 0 (Q == 0 gives no band-pass filters).
/// @param signal_len  signal length; must be >= 0. Spectra use next_pow2(signal_len) bins.
/// @throws std::invalid_argument if J, Q or signal_len is negative.
static CPUFilterBank build_cpu_morlet_bank(
    int32_t J,
    int32_t Q,
    int32_t signal_len)
{
    // Negative values would wrap to huge std::size_t values below: next_pow2
    // would never terminate and reserve() would request absurd memory.
    if (J < 0 || Q < 0 || signal_len < 0) {
        throw std::invalid_argument(
            "build_cpu_morlet_bank: J, Q and signal_len must be non-negative");
    }
    const std::size_t fft_len = next_pow2(static_cast<std::size_t>(signal_len));
    const std::size_t n_filters = static_cast<std::size_t>(J)
                                * static_cast<std::size_t>(Q);
    CPUFilterBank bank{};
    bank.fft_len = fft_len;
    bank.n_filters = n_filters;
    bank.filters.reserve(n_filters);

    const float omega_max = static_cast<float>(M_PI);
    const float q_float = static_cast<float>(Q);
    float l1_accum = 0.0f;
    for (std::size_t i = 0; i < n_filters; ++i) {
        // Q filters per halving of the centre frequency.
        const float ratio = static_cast<float>(i) / q_float;
        const float omega_0 = omega_max * std::pow(2.0f, -ratio);
        const float sigma = omega_0 / q_float;
        std::vector<std::complex<float>> psi = morlet_freq(fft_len, omega_0, sigma);
        float psi_l1 = 0.0f;
        for (const std::complex<float>& bin : psi) {
            psi_l1 += std::abs(bin);
        }
        l1_accum += psi_l1 / static_cast<float>(fft_len);
        bank.filters.push_back(std::move(psi));
    }
    bank.l1_norm = (n_filters > 0) ? (l1_accum / static_cast<float>(n_filters))
                                   : 0.0f;

    const float phi_cutoff = omega_max * std::pow(2.0f, -static_cast<float>(J));
    bank.phi = lowpass_freq(fft_len, phi_cutoff, phi_cutoff * 0.5f);
    return bank;
}
/// One scattering layer applied to a single signal.
///
/// For each band-pass filter psi_f in `bank`:
///     modulus_out[f] = | IFFT( FFT(input) * psi_f ) |   (first signal_len samples)
/// and `output` receives the low-pass-filtered mean of those moduli:
///     output = Re( IFFT( FFT(mean_f modulus_out[f]) * phi ) )   (first signal_len samples)
/// Every signal is zero-padded to bank.fft_len before its FFT.
///
/// @param input        signal_len samples.
/// @param output       receives signal_len samples; zero-filled when the bank has no filters.
/// @param signal_len   number of samples; must not exceed bank.fft_len.
/// @param bank         filters[f] and phi are expected to hold bank.fft_len bins each.
/// @param modulus_out  resized to bank.n_filters vectors of signal_len samples each.
/// @throws std::invalid_argument if signal_len > bank.fft_len (the zero-padded
///         buffers would otherwise be overrun).
static void scatter_layer(
    const float* input,
    float* output, std::size_t signal_len,
    const CPUFilterBank& bank,
    std::vector<std::vector<float>>& modulus_out) {
    const std::size_t N = bank.fft_len;
    if (signal_len > N) {
        throw std::invalid_argument("scatter_layer: signal_len exceeds bank.fft_len");
    }
    modulus_out.resize(bank.n_filters);
    if (bank.n_filters == 0) {
        std::memset(output, 0, signal_len * sizeof(float));
        return;
    }

    // Spectrum of the zero-padded input, shared by every filter.
    std::vector<std::complex<float>> X(N, {0.0f, 0.0f});
    for (std::size_t i = 0; i < signal_len; ++i) {
        X[i] = {input[i], 0.0f};
    }
    fft_radix2(X.data(), N, false);

    // Band-pass stage. Y is reused across filters; every bin is overwritten.
    // U accumulates the sum of all moduli, zero-padded to N.
    std::vector<std::complex<float>> Y(N);
    std::vector<std::complex<float>> U(N, {0.0f, 0.0f});
    for (std::size_t f = 0; f < bank.n_filters; ++f) {
        const std::vector<std::complex<float>>& psi = bank.filters[f];
        for (std::size_t k = 0; k < N; ++k) {
            Y[k] = X[k] * psi[k];
        }
        fft_radix2(Y.data(), N, true);
        std::vector<float>& mod = modulus_out[f];
        mod.resize(signal_len);
        for (std::size_t i = 0; i < signal_len; ++i) {
            mod[i] = std::abs(Y[i]);
            U[i] += mod[i];
        }
    }

    // Low-pass stage. Filtering is linear, so filtering the summed moduli once
    // equals (up to rounding) filtering each modulus separately and summing,
    // and it saves two FFTs per filter.
    fft_radix2(U.data(), N, false);
    for (std::size_t k = 0; k < N; ++k) {
        U[k] *= bank.phi[k];
    }
    fft_radix2(U.data(), N, true);

    const float inv = 1.0f / static_cast<float>(bank.n_filters);
    for (std::size_t i = 0; i < signal_len; ++i) {
        output[i] = U[i].real() * inv;
    }
}
static void cpu_wst_forward(
const float* input,
float* output, int32_t signal_len,
int32_t batch_size,
int32_t depth,
const CPUFilterBank& bank,
float& l1_norm_out)
{
const std::size_t sig_len = static_cast<std::size_t>(signal_len);
l1_norm_out = bank.l1_norm;
for (int32_t b = 0; b < batch_size; ++b) {
const float* sig_in = input + static_cast<std::size_t>(b) * sig_len;
float* sig_out = output + static_cast<std::size_t>(b) * sig_len;
std::vector<float> current(sig_in, sig_in + sig_len);
for (int32_t d = 0; d < depth; ++d) {
std::vector<std::vector<float>> modulus;
scatter_layer(current.data(), sig_out, sig_len, bank, modulus);
if (!modulus.empty() && !modulus[0].empty()) {
current = modulus[0];
} else {
std::fill(current.begin(), current.end(), 0.0f);
}
}
}
}