use rustfft::num_complex::Complex;
use rustfft::{Fft, FftPlanner};
use std::path::Path;
use std::sync::Arc;
use super::config::NoiseConfig;
use super::phase::PhaseGenerator;
use super::spectral::SpectralMLP;
use super::{NoiseError, NoiseResult};
/// Number of samples at the start of each generated buffer that are blended
/// towards the last sample of the previous buffer (see `NoiseGenerator::generate`).
/// Buffers shorter than this are left unblended.
const CROSSFADE_LEN: usize = 32;
/// Buffer-oriented noise synthesizer: per-bin magnitudes from an MLP are
/// paired with generated phases, turned into a time-domain buffer with an
/// inverse FFT, and post-processed (scaling, clamping, crossfade).
pub struct NoiseGenerator {
/// Active configuration; supplies `buffer_size` and `sample_rate` and is
/// encoded (together with `time`) as the MLP input.
config: NoiseConfig,
/// Model whose `forward` output is used as the per-bin magnitudes.
mlp: SpectralMLP,
/// Source of the per-bin phases for each generated buffer.
phase_gen: PhaseGenerator,
/// Inverse FFT planned for `config.buffer_size` points.
ifft: Arc<dyn Fft<f32>>,
/// Running sum of `buffer_size / sample_rate` over all generated buffers
/// (seconds, assuming `sample_rate` is in Hz).
time: f64,
/// Total number of samples produced since construction or the last `reset`.
sample_counter: u64,
/// Final sample of the previously generated buffer; start value of the crossfade.
prev_last_sample: f32,
/// `true` once at least one buffer has been generated, i.e. when
/// `prev_last_sample` holds a real value.
has_prev: bool,
}
/// Hand-written `Debug` that prints only the configuration and the two
/// progress counters. The remaining fields (model, phase generator, FFT plan,
/// crossfade state) are omitted and the output is marked as non-exhaustive.
impl std::fmt::Debug for NoiseGenerator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("NoiseGenerator");
        builder.field("config", &self.config);
        builder.field("time", &self.time);
        builder.field("sample_counter", &self.sample_counter);
        builder.finish_non_exhaustive()
    }
}
impl NoiseGenerator {
    /// Creates a generator whose magnitude model comes from
    /// `SpectralMLP::random_init`, sized for `config.buffer_size`.
    ///
    /// # Errors
    ///
    /// Returns whatever `config.validate()` reports for an invalid configuration.
    pub fn new(config: NoiseConfig) -> NoiseResult<Self> {
        config.validate()?;
        let mlp = Self::default_mlp(Self::freq_bins(config.buffer_size));
        Ok(Self::assemble(config, mlp))
    }

    /// Creates a generator whose magnitude model is loaded from `path` via
    /// `SpectralMLP::load_apr`.
    ///
    /// # Errors
    ///
    /// * the error from `config.validate()` if the configuration is invalid
    ///   (checked before the file is touched);
    /// * the error from `SpectralMLP::load_apr` if loading fails;
    /// * `NoiseError::ModelError` if the loaded model's `n_freqs()` is not
    ///   `config.buffer_size / 2 + 1`.
    pub fn from_apr<P: AsRef<Path>>(path: P, config: NoiseConfig) -> NoiseResult<Self> {
        config.validate()?;
        let mlp = SpectralMLP::load_apr(path)?;
        Self::ensure_model_fits(&mlp, &config)?;
        Ok(Self::assemble(config, mlp))
    }

    /// Creates a generator around a caller-supplied magnitude model.
    ///
    /// # Errors
    ///
    /// * the error from `config.validate()` if the configuration is invalid;
    /// * `NoiseError::ModelError` if `mlp.n_freqs()` is not
    ///   `config.buffer_size / 2 + 1`.
    pub fn with_mlp(config: NoiseConfig, mlp: SpectralMLP) -> NoiseResult<Self> {
        config.validate()?;
        Self::ensure_model_fits(&mlp, &config)?;
        Ok(Self::assemble(config, mlp))
    }

    /// Fills `output` with the next buffer of noise and advances the
    /// generator's clock.
    ///
    /// Steps, in order:
    /// 1. encode the configuration at the current `time` and run the MLP to
    ///    obtain one magnitude per frequency bin;
    /// 2. draw one phase per bin from the phase generator;
    /// 3. build a conjugate-symmetric spectrum of exactly `buffer_size` bins
    ///    (DC and, for even sizes, the Nyquist bin are kept purely real);
    /// 4. run the inverse FFT and scale the real parts by `1 / sqrt(buffer_size)`;
    /// 5. if the peak exceeds 1.0, rescale so the peak is 0.95; then clamp to
    ///    `[-1, 1]` and replace non-finite samples with 0;
    /// 6. from the second buffer on, blend the first `CROSSFADE_LEN` samples
    ///    towards the previous buffer's last sample;
    /// 7. advance `time` by `buffer_size / sample_rate` and `sample_counter`
    ///    by `buffer_size`.
    ///
    /// # Errors
    ///
    /// Returns `NoiseError::BufferSizeMismatch` when `output.len()` differs
    /// from `config.buffer_size`; nothing is modified in that case.
    ///
    /// # Panics
    ///
    /// Indexes the MLP output and the phase vector up to bin `n_freqs - 1`,
    /// so it panics if either is shorter than that.
    pub fn generate(&mut self, output: &mut [f32]) -> NoiseResult<()> {
        let n = self.config.buffer_size;
        if output.len() != n {
            return Err(NoiseError::BufferSizeMismatch {
                expected: n,
                actual: output.len(),
            });
        }

        let n_freqs = Self::freq_bins(n);
        let config_vec = self.config.encode(self.time);
        let magnitudes = self.mlp.forward(&config_vec);
        let phases = self.phase_gen.generate(n_freqs);

        // Bins `1..paired_end` each have a distinct mirrored (conjugate)
        // partner. For even N the last one-sided bin is the self-conjugate
        // Nyquist bin and is emitted once; for odd N there is no Nyquist bin,
        // so the last one-sided bin must be mirrored like any other.
        // Mirroring only `1..n_freqs - 1` in the odd case would leave the
        // spectrum one element short of the planned FFT length.
        let has_nyquist = n % 2 == 0;
        let paired_end = if has_nyquist { n_freqs - 1 } else { n_freqs };

        let mut spectrum: Vec<Complex<f32>> = Vec::with_capacity(n);

        // DC bin: real-valued.
        spectrum.push(Complex::new(magnitudes[0], 0.0));

        // Positive-frequency half.
        for k in 1..paired_end {
            let (mag, phase) = (magnitudes[k], phases[k]);
            spectrum.push(Complex::new(mag * phase.cos(), mag * phase.sin()));
        }

        // Nyquist bin (even N only): real-valued.
        if has_nyquist {
            spectrum.push(Complex::new(magnitudes[n_freqs - 1], 0.0));
        }

        // Negative-frequency half: complex conjugates in reverse bin order.
        for k in (1..paired_end).rev() {
            let (mag, phase) = (magnitudes[k], phases[k]);
            spectrum.push(Complex::new(mag * phase.cos(), -mag * phase.sin()));
        }

        self.ifft.process(&mut spectrum);

        // Keep only the real part, scaled by 1/sqrt(N), and track the peak.
        let norm = 1.0 / (n as f32).sqrt();
        let mut max_abs = 0.0f32;
        for (out, bin) in output.iter_mut().zip(spectrum.iter()) {
            *out = bin.re * norm;
            max_abs = max_abs.max(out.abs());
        }

        // Peak-normalise only when the buffer would otherwise clip.
        if max_abs > 1.0 {
            let scale = 0.95 / max_abs;
            for sample in output.iter_mut() {
                *sample *= scale;
            }
        }

        // Final safety net: hard clamp, then silence anything that is still
        // not a finite number (`clamp` passes NaN through unchanged).
        for sample in output.iter_mut() {
            *sample = sample.clamp(-1.0, 1.0);
            if !sample.is_finite() {
                *sample = 0.0;
            }
        }

        // Smooth the seam with the previous buffer: each of the first
        // CROSSFADE_LEN samples becomes a 70/30 mix of
        //   (a) a fade from the previous buffer's last sample into the new sample, and
        //   (b) a straight ramp from that last sample to the sample at the
        //       end of the crossfade window (captured before modification).
        if self.has_prev && output.len() >= CROSSFADE_LEN {
            let start_val = self.prev_last_sample;
            let end_val = output[CROSSFADE_LEN - 1];
            for (i, sample) in output[..CROSSFADE_LEN].iter_mut().enumerate() {
                let t = (i + 1) as f32 / CROSSFADE_LEN as f32;
                let interp = start_val * (1.0 - t) + *sample * t;
                let ramp_target = start_val + (end_val - start_val) * t;
                *sample = interp * 0.7 + ramp_target * 0.3;
            }
        }

        self.prev_last_sample = output.last().copied().unwrap_or(0.0);
        self.has_prev = true;

        self.time += n as f64 / f64::from(self.config.sample_rate);
        self.sample_counter += n as u64;
        Ok(())
    }

    /// Replaces the active configuration.
    ///
    /// When `buffer_size` changes, the inverse FFT is re-planned for the new
    /// size. If the current model's `n_freqs()` no longer matches the new
    /// size, the model is **replaced** by a fresh `SpectralMLP::random_init`
    /// model. A model supplied through `from_apr` or `with_mlp` is therefore
    /// discarded silently in that case.
    ///
    /// `time`, `sample_counter`, the phase generator and the crossfade state
    /// are left untouched.
    ///
    /// # Errors
    ///
    /// Returns the error from `config.validate()`; the generator is not
    /// modified in that case.
    pub fn update_config(&mut self, config: NoiseConfig) -> NoiseResult<()> {
        config.validate()?;
        if config.buffer_size != self.config.buffer_size {
            let mut planner = FftPlanner::new();
            self.ifft = planner.plan_fft_inverse(config.buffer_size);

            let n_freqs = Self::freq_bins(config.buffer_size);
            if self.mlp.n_freqs() != n_freqs {
                self.mlp = Self::default_mlp(n_freqs);
            }
        }
        self.config = config;
        Ok(())
    }

    /// The configuration currently in effect.
    #[must_use]
    pub fn config(&self) -> &NoiseConfig {
        &self.config
    }

    /// Accumulated `buffer_size / sample_rate` over all buffers generated
    /// since construction or the last [`reset`](Self::reset).
    #[must_use]
    pub fn time(&self) -> f64 {
        self.time
    }

    /// Number of samples generated since construction or the last
    /// [`reset`](Self::reset).
    #[must_use]
    pub fn sample_counter(&self) -> u64 {
        self.sample_counter
    }

    /// Rewinds the generator: clears `time` and `sample_counter`, resets the
    /// phase generator with the default seed (12345, the same value the
    /// constructors pass to `PhaseGenerator::new`), and forgets the previous
    /// buffer so the next one is not crossfaded.
    ///
    /// The configuration and the magnitude model are kept.
    pub fn reset(&mut self) {
        self.time = 0.0;
        self.sample_counter = 0;
        self.phase_gen.reset(12345);
        self.prev_last_sample = 0.0;
        self.has_prev = false;
    }

    /// Resets the phase generator with `seed`. Timing counters and crossfade
    /// state are not affected.
    pub fn set_phase_seed(&mut self, seed: u64) {
        self.phase_gen.reset(seed);
    }

    /// Number of one-sided spectrum bins (DC through the highest frequency)
    /// used for a buffer of `buffer_size` samples.
    fn freq_bins(buffer_size: usize) -> usize {
        buffer_size / 2 + 1
    }

    /// Builds the fallback magnitude model with `n_freqs` outputs.
    // NOTE(review): 8, 64 and 42 are presumably input width, hidden width and
    // RNG seed; confirm against `SpectralMLP::random_init`.
    fn default_mlp(n_freqs: usize) -> SpectralMLP {
        SpectralMLP::random_init(8, 64, n_freqs, 42)
    }

    /// Verifies that `mlp` produces exactly one magnitude per spectrum bin of
    /// `config.buffer_size`, returning `NoiseError::ModelError` otherwise.
    fn ensure_model_fits(mlp: &SpectralMLP, config: &NoiseConfig) -> NoiseResult<()> {
        let expected = Self::freq_bins(config.buffer_size);
        if mlp.n_freqs() == expected {
            return Ok(());
        }
        Err(NoiseError::ModelError(format!(
            "Model n_freqs {} doesn't match config buffer_size {} (expected {})",
            mlp.n_freqs(),
            config.buffer_size,
            expected
        )))
    }

    /// Shared constructor tail: plans the inverse FFT for `config.buffer_size`,
    /// creates the phase generator with the default seed, and zeroes all
    /// runtime state. Callers must have validated `config` and `mlp` already.
    fn assemble(config: NoiseConfig, mlp: SpectralMLP) -> Self {
        let mut planner = FftPlanner::new();
        let ifft = planner.plan_fft_inverse(config.buffer_size);
        let phase_gen = PhaseGenerator::new(12345);
        Self {
            config,
            mlp,
            phase_gen,
            ifft,
            time: 0.0,
            sample_counter: 0,
            prev_last_sample: 0.0,
            has_prev: false,
        }
    }
}
/// Iterating a `NoiseGenerator` yields successive freshly allocated buffers of
/// `config.buffer_size` samples, each produced by `generate`.
///
/// Iteration stops (`None`) only if `generate` returns an error; the error
/// value itself is discarded.
impl Iterator for NoiseGenerator {
    type Item = Vec<f32>;

    fn next(&mut self) -> Option<Self::Item> {
        // Sized from the current config so `generate` accepts its length.
        let mut chunk = vec![0.0; self.config.buffer_size];
        self.generate(&mut chunk).ok().map(|()| chunk)
    }
}
#[cfg(test)]
#[path = "generator_tests.rs"]
mod tests;