use aether_core::{node::DspNode, param::ParamBlock, state::StateBlob, BUFFER_SIZE, MAX_INPUTS};
use std::f32::consts::TAU;
#[inline(always)]
fn fast_sin_norm(phase: f32) -> f32 {
let x = (phase - 0.5) * TAU;
let x2 = x * x;
x * (0.999_999_4 + x2 * (-0.166_666_58 + x2 * (0.008_333_331 - x2 * 0.000_198_409)))
}
#[derive(Clone, Copy)]
struct OscState {
phase: f32,
}
pub struct Oscillator {
phase: f32,
tuning: Option<Box<[f32; 128]>>,
}
impl Oscillator {
pub fn new() -> Self {
Self { phase: 0.0, tuning: None }
}
pub fn set_tuning(&mut self, frequencies: [f32; 128]) {
self.tuning = Some(Box::new(frequencies));
}
pub fn clear_tuning(&mut self) {
self.tuning = None;
}
}
#[inline(always)]
fn blep(t: f32, dt: f32) -> f32 {
if t < dt {
let t = t / dt;
2.0 * t - t * t - 1.0
} else if t > 1.0 - dt {
let t = (t - 1.0) / dt;
t * t + 2.0 * t + 1.0
} else {
0.0
}
}
impl Default for Oscillator {
fn default() -> Self {
Self::new()
}
}
impl DspNode for Oscillator {
fn process(
&mut self,
_inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
output: &mut [f32; BUFFER_SIZE],
params: &mut ParamBlock,
sample_rate: f32,
) {
let amp = params.get(1).current.clamp(0.0, 1.0);
let wave = params.get(2).current as u32;
let freq = if let Some(ref tuning) = self.tuning {
let midi = params.get(3).current;
if midi >= 0.0 {
let note = (midi as usize).min(127);
tuning[note].max(0.01)
} else {
params.get(0).current.max(0.01)
}
} else {
params.get(0).current.max(0.01)
};
let phase_inc = freq / sample_rate;
match wave {
0 => {
let mut phase = self.phase;
for out in output.iter_mut() {
*out = fast_sin_norm(phase) * amp;
phase = (phase + phase_inc).fract();
}
self.phase = phase;
}
1 => {
let mut phase = self.phase;
for out in output.iter_mut() {
let mut saw = 2.0 * phase - 1.0;
saw -= blep(phase, phase_inc);
*out = saw * amp;
phase = (phase + phase_inc).fract();
}
self.phase = phase;
}
2 => {
let mut phase = self.phase;
for out in output.iter_mut() {
let mut sq = if phase < 0.5 { 1.0f32 } else { -1.0f32 };
sq += blep(phase, phase_inc);
sq -= blep((phase + 0.5).fract(), phase_inc);
*out = sq * amp;
phase = (phase + phase_inc).fract();
}
self.phase = phase;
}
_ => {
let mut phase = self.phase;
for out in output.iter_mut() {
let tri = if phase < 0.5 {
4.0 * phase - 1.0
} else {
3.0 - 4.0 * phase
};
*out = tri * amp;
phase = (phase + phase_inc).fract();
}
self.phase = phase;
}
}
if params.params[..params.count].iter().any(|p| p.step != 0.0) {
for _ in 0..BUFFER_SIZE {
params.tick_all();
}
}
}
fn capture_state(&self) -> StateBlob {
StateBlob::from_value(&OscState { phase: self.phase })
}
fn restore_state(&mut self, state: StateBlob) {
let s: OscState = state.to_value();
self.phase = s.phase;
}
fn type_name(&self) -> &'static str {
"Oscillator"
}
}