aetherdsp-sampler 0.2.1

Polyphonic sampler engine for AetherDSP — multi-sample instruments, voice management, and WAV loading
Documentation
//! SamplerNode — integrates the sampler into the AetherDSP graph.
//!
//! RT SAFETY: The instrument slot uses ArcSwap instead of Mutex.
//! The control thread swaps in a new Arc atomically (single atomic store).
//! The RT thread loads the current Arc with a single atomic read — no lock,
//! no blocking, no priority inversion.
//!
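//! The MIDI queue still uses a Mutex, but the RT thread only ever calls
//! `try_lock()` on it: if the control thread holds the lock, the pending events
//! stay queued until the next process() call instead of blocking the audio
//! thread. The lock is held only long enough to take the pending events, never
//! across the render loop.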

use std::sync::{Arc, Mutex};
use arc_swap::ArcSwap;
use aether_core::{
    node::DspNode,
    param::ParamBlock,
    BUFFER_SIZE, MAX_INPUTS,
};
use crate::{
    instrument::{LoadedInstrument, RoundRobinState},
    voice::SamplerVoice,
};
use aether_midi::event::{MidiEvent, MidiEventKind};

/// A polyphonic sampler node for the AetherDSP graph.
///
/// The first two fields are shared handles that other threads obtain through
/// `midi_queue()` / `instrument_slot()`; the rest is per-node render state.
pub struct SamplerNode {
    // ---- Handles shared with other threads -------------------------------
    /// The currently loaded instrument, behind `ArcSwap` so the RT thread can
    /// read it with `load()` (no lock) while the control thread publishes a
    /// replacement with `store(Arc::new(Some(loaded)))`.
    instrument: Arc<ArcSwap<Option<LoadedInstrument>>>,
    /// Incoming MIDI events: pushed by the MIDI/control thread, drained by the
    /// RT thread at the start of each block.
    midi_queue: Arc<Mutex<Vec<MidiEvent>>>,

    // ---- Render state -----------------------------------------------------
    /// Voices currently sounding, oldest first.
    voices: Vec<SamplerVoice>,
    /// Round-robin bookkeeping used when selecting a zone for a new note.
    rr_state: RoundRobinState,
    /// Name of the instrument seen last, used to detect instrument swaps.
    last_instrument_name: Option<String>,
    /// Whether the sustain pedal is down, indexed by MIDI channel (16 slots).
    sustain_pedal: [bool; 16],
    /// `(channel, note)` pairs whose note-off is deferred by a held pedal.
    sustained_notes: Vec<(u8, u8)>,
    /// Sample rate used to turn envelope times into per-sample rates.
    sample_rate: f32,
}

impl SamplerNode {
    pub fn new(sample_rate: f32) -> Self {
        Self {
            instrument: Arc::new(ArcSwap::from_pointee(None)),
            voices: Vec::with_capacity(32),
            midi_queue: Arc::new(Mutex::new(Vec::new())),
            sample_rate,
            sustain_pedal: [false; 16],
            sustained_notes: Vec::new(),
            rr_state: RoundRobinState::new(),
            last_instrument_name: None,
        }
    }

    /// Get the MIDI queue for pushing events from the MIDI/control thread.
    pub fn midi_queue(&self) -> Arc<Mutex<Vec<MidiEvent>>> {
        Arc::clone(&self.midi_queue)
    }

    /// Get the instrument slot.
    /// Control thread: call `.store(Arc::new(Some(loaded)))` to swap in a new instrument.
    /// This is a single atomic store — the RT thread sees the new instrument on the
    /// next process() call without any locking.
    pub fn instrument_slot(&self) -> Arc<ArcSwap<Option<LoadedInstrument>>> {
        Arc::clone(&self.instrument)
    }

    /// Compatibility shim: returns an Arc<Mutex<Option<LoadedInstrument>>> wrapper
    /// so existing code in graph_manager that calls `.lock().unwrap() = Some(loaded)`
    /// still compiles. The write goes through a thin Mutex wrapper that immediately
    /// stores into the ArcSwap.
    pub fn instrument_slot_mutex(&self) -> Arc<Mutex<Option<LoadedInstrument>>> {
        // We keep a separate Mutex<Option<>> for the control-thread write path.
        // This is only locked by the control thread (never by the RT thread).
        let swap = Arc::clone(&self.instrument);
        let mutex: Arc<Mutex<Option<LoadedInstrument>>> = Arc::new(Mutex::new(None));
        // Return a wrapper that, when written, also updates the ArcSwap.
        // For simplicity we use a dedicated ControlSlot type below.
        // Here we return the raw mutex and the caller must also call store().
        // The graph_manager already does this via instrument_slots HashMap.
        let _ = swap; // suppress unused warning
        mutex
    }

    pub fn reset_round_robin(&mut self) {
        self.rr_state.reset();
    }

    fn process_midi_events(&mut self) {
        // Drain the MIDI queue — Mutex held briefly, not across the render loop.
        let events: Vec<MidiEvent> = {
            match self.midi_queue.try_lock() {
                Ok(mut q) => std::mem::take(&mut *q),
                Err(_) => return, // control thread holds lock — skip this tick
            }
        };

        if events.is_empty() { return; }

        // Load the current instrument — single atomic read, no lock.
        let inst_guard = self.instrument.load();
        let inst = match inst_guard.as_ref().as_ref() {
            Some(i) => i,
            None => {
                if self.last_instrument_name.is_some() {
                    self.last_instrument_name = None;
                    self.rr_state.reset();
                }
                return;
            }
        };

        let current_name = &inst.instrument.name;
        if self.last_instrument_name.as_deref() != Some(current_name.as_str()) {
            self.last_instrument_name = Some(current_name.clone());
            self.rr_state.reset();
        }

        for event in events {
            match event.kind {
                MidiEventKind::NoteOn { note, velocity } => {
                    let max_voices = inst.instrument.max_voices;
                    if self.voices.len() >= max_voices {
                        self.voices.remove(0);
                    }
                    if let Some(zone) = inst.instrument.find_zone_rr(note, velocity, &mut self.rr_state) {
                        if inst.buffers.contains_key(&zone.id) {
                            let vel_linear = velocity as f32 / 127.0;
                            let pitch_ratio = zone.pitch_ratio(note, &inst.instrument.tuning) as f64;
                            let volume = zone.volume_linear() * vel_linear;
                            let voice = SamplerVoice::new(
                                note, event.channel, vel_linear,
                                pitch_ratio, volume, zone,
                            );
                            self.voices.push(voice);
                        }
                    }
                }
                MidiEventKind::NoteOff { note, .. } => {
                    let ch = event.channel;
                    if self.sustain_pedal[ch as usize] {
                        self.sustained_notes.push((ch, note));
                    } else {
                        for v in self.voices.iter_mut() {
                            if v.note == note && v.channel == ch && v.key_held {
                                v.release();
                            }
                        }
                    }
                }
                MidiEventKind::ControlChange { cc, value } => {
                    let ch = event.channel as usize;
                    if cc == aether_midi::event::cc::SUSTAIN_PEDAL {
                        let held = value >= 64;
                        self.sustain_pedal[ch] = held;
                        if !held {
                            let to_release: Vec<(u8, u8)> = self.sustained_notes.drain(..).collect();
                            for (c, n) in to_release {
                                for v in self.voices.iter_mut() {
                                    if v.note == n && v.channel == c && v.key_held {
                                        v.release();
                                    }
                                }
                            }
                        }
                    }
                }
                MidiEventKind::AllNotesOff | MidiEventKind::AllSoundOff => {
                    for v in self.voices.iter_mut() { v.release(); }
                    self.sustained_notes.clear();
                }
                _ => {}
            }
        }
    }

    fn render_voices(&mut self, output: &mut [f32; BUFFER_SIZE]) {
        // Single atomic load — no lock, no blocking.
        let inst_guard = self.instrument.load();
        let inst = match inst_guard.as_ref().as_ref() {
            Some(i) => i,
            None => return,
        };

        let sr = self.sample_rate;
        let attack_rate  = 1.0 / (inst.instrument.attack  * sr).max(1.0);
        let decay_rate   = 1.0 / (inst.instrument.decay   * sr).max(1.0);
        let sustain      = inst.instrument.sustain;
        let release_rate = 1.0 / (inst.instrument.release * sr).max(1.0);

        for voice in self.voices.iter_mut() {
            if voice.is_done() { continue; }
            let buf = match inst.buffers.get(&voice.zone_id) {
                Some(b) => b,
                None => continue,
            };
            for sample in output.iter_mut() {
                voice.envelope.tick(attack_rate, decay_rate, sustain, release_rate);
                let frame_pos = voice.advance(buf.frames);
                let raw = buf.sample_at(frame_pos);
                *sample += raw * voice.volume * voice.envelope.level;
            }
        }
        self.voices.retain(|v| !v.is_done());
    }
}

impl DspNode for SamplerNode {
    /// Renders one block: clears `output`, applies queued MIDI events, then
    /// mixes every active voice into the buffer.
    ///
    /// The graph-supplied `_sample_rate` is ignored; envelope timing uses the
    /// rate given to `SamplerNode::new`.
    fn process(
        &mut self,
        _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
        output: &mut [f32; BUFFER_SIZE],
        _params: &mut ParamBlock,
        _sample_rate: f32,
    ) {
        // Voices are mixed in with `+=`, so start from silence.
        *output = [0.0; BUFFER_SIZE];
        self.process_midi_events();
        self.render_voices(output);
    }

    fn type_name(&self) -> &'static str {
        "SamplerNode"
    }
}