// aether_sampler/node.rs

//! SamplerNode — integrates the sampler into the AetherDSP graph.
//!
//! This is a `DspNode` that:
//!   1. Receives MIDI events via a mutex-protected queue filled by the MIDI thread
//!   2. Manages polyphonic voices
//!   3. Renders audio into the output buffer

use std::sync::{Arc, Mutex};

use aether_core::{node::DspNode, param::ParamBlock, BUFFER_SIZE, MAX_INPUTS};
use aether_midi::event::{MidiEvent, MidiEventKind};

use crate::{
    instrument::{LoadedInstrument, RoundRobinState},
    voice::SamplerVoice,
};

20/// A polyphonic sampler node.
21pub struct SamplerNode {
22    /// The loaded instrument (shared with the instrument maker UI).
23    instrument: Arc<Mutex<Option<LoadedInstrument>>>,
24    /// Active voices.
25    voices: Vec<SamplerVoice>,
26    /// Pending MIDI events (written by MIDI thread, read by audio thread).
27    midi_queue: Arc<Mutex<Vec<MidiEvent>>>,
28    /// Sample rate.
29    sample_rate: f32,
30    /// Sustain pedal state per channel.
31    sustain_pedal: [bool; 16],
32    /// Notes held by sustain pedal (released by key but sustained by pedal).
33    sustained_notes: Vec<(u8, u8)>, // (channel, note)
34    /// Round-robin state for zone selection.
35    rr_state: RoundRobinState,
36    /// Last loaded instrument name (to detect instrument changes).
37    last_instrument_name: Option<String>,
38}
39
40impl SamplerNode {
41    pub fn new(sample_rate: f32) -> Self {
42        Self {
43            instrument: Arc::new(Mutex::new(None)),
44            voices: Vec::with_capacity(32),
45            midi_queue: Arc::new(Mutex::new(Vec::new())),
46            sample_rate,
47            sustain_pedal: [false; 16],
48            sustained_notes: Vec::new(),
49            rr_state: RoundRobinState::new(),
50            last_instrument_name: None,
51        }
52    }
53
54    /// Get the MIDI queue for pushing events from the MIDI thread.
55    pub fn midi_queue(&self) -> Arc<Mutex<Vec<MidiEvent>>> {
56        Arc::clone(&self.midi_queue)
57    }
58
59    /// Get the instrument slot for loading/replacing instruments.
60    pub fn instrument_slot(&self) -> Arc<Mutex<Option<LoadedInstrument>>> {
61        Arc::clone(&self.instrument)
62    }
63
64    /// Reset the round-robin state (call when loading a new instrument).
65    pub fn reset_round_robin(&mut self) {
66        self.rr_state.reset();
67    }
68
69    fn process_midi_events(&mut self) {
70        let events: Vec<MidiEvent> = {
71            let mut q = self.midi_queue.lock().unwrap();
72            std::mem::take(&mut *q)
73        };
74
75        let inst_guard = self.instrument.lock().unwrap();
76        let inst = match inst_guard.as_ref() {
77            Some(i) => i,
78            None => {
79                // No instrument loaded, reset tracking
80                if self.last_instrument_name.is_some() {
81                    self.last_instrument_name = None;
82                    self.rr_state.reset();
83                }
84                return;
85            }
86        };
87
88        // Check if instrument changed and reset round-robin state if so
89        let current_name = &inst.instrument.name;
90        if self.last_instrument_name.as_ref() != Some(current_name) {
91            self.last_instrument_name = Some(current_name.clone());
92            self.rr_state.reset();
93        }
94
95        for event in events {
96            match event.kind {
97                MidiEventKind::NoteOn { note, velocity } => {
98                    // Steal oldest voice if at max polyphony
99                    let max_voices = inst.instrument.max_voices;
100                    if self.voices.len() >= max_voices {
101                        self.voices.remove(0);
102                    }
103
104                    if let Some(zone) = inst.instrument.find_zone_rr(note, velocity, &mut self.rr_state) {
105                        if inst.buffers.contains_key(&zone.id) {
106                            let vel_linear = velocity as f32 / 127.0;
107                            let pitch_ratio = zone.pitch_ratio(note, &inst.instrument.tuning) as f64;
108                            let volume = zone.volume_linear() * vel_linear;
109                            let voice = SamplerVoice::new(
110                                note, event.channel, vel_linear,
111                                pitch_ratio, volume, zone,
112                            );
113                            self.voices.push(voice);
114                        }
115                    }
116                }
117
118                MidiEventKind::NoteOff { note, .. } => {
119                    let ch = event.channel;
120                    if self.sustain_pedal[ch as usize] {
121                        self.sustained_notes.push((ch, note));
122                    } else {
123                        for v in self.voices.iter_mut() {
124                            if v.note == note && v.channel == ch && v.key_held {
125                                v.release();
126                            }
127                        }
128                    }
129                }
130
131                MidiEventKind::ControlChange { cc, value } => {
132                    let ch = event.channel as usize;
133                    if cc == aether_midi::event::cc::SUSTAIN_PEDAL {
134                        let held = value >= 64;
135                        self.sustain_pedal[ch] = held;
136                        if !held {
137                            // Release all sustained notes
138                            let to_release: Vec<(u8, u8)> = self.sustained_notes.drain(..).collect();
139                            for (c, n) in to_release {
140                                for v in self.voices.iter_mut() {
141                                    if v.note == n && v.channel == c && v.key_held {
142                                        v.release();
143                                    }
144                                }
145                            }
146                        }
147                    }
148                }
149
150                MidiEventKind::AllNotesOff | MidiEventKind::AllSoundOff => {
151                    for v in self.voices.iter_mut() {
152                        v.release();
153                    }
154                    self.sustained_notes.clear();
155                }
156
157                _ => {}
158            }
159        }
160    }
161
162    fn render_voices(&mut self, output: &mut [f32; BUFFER_SIZE]) {
163        let inst_guard = self.instrument.lock().unwrap();
164        let inst = match inst_guard.as_ref() {
165            Some(i) => i,
166            None => return,
167        };
168
169        let sr = self.sample_rate;
170        let attack_rate = 1.0 / (inst.instrument.attack * sr).max(1.0);
171        let decay_rate = 1.0 / (inst.instrument.decay * sr).max(1.0);
172        let sustain = inst.instrument.sustain;
173        let release_rate = 1.0 / (inst.instrument.release * sr).max(1.0);
174
175        for voice in self.voices.iter_mut() {
176            if voice.is_done() { continue; }
177
178            let buf = match inst.buffers.get(&voice.zone_id) {
179                Some(b) => b,
180                None => continue,
181            };
182
183            for sample in output.iter_mut() {
184                voice.envelope.tick(attack_rate, decay_rate, sustain, release_rate);
185                let frame_pos = voice.advance(buf.frames);
186                let raw = buf.sample_at(frame_pos);
187                *sample += raw * voice.volume * voice.envelope.level;
188            }
189        }
190
191        // Remove finished voices
192        self.voices.retain(|v| !v.is_done());
193    }
194}
195
196impl DspNode for SamplerNode {
197    fn process(
198        &mut self,
199        _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
200        output: &mut [f32; BUFFER_SIZE],
201        _params: &mut ParamBlock,
202        _sample_rate: f32,
203    ) {
204        output.fill(0.0);
205        self.process_midi_events();
206        self.render_voices(output);
207    }
208
209    fn type_name(&self) -> &'static str { "SamplerNode" }
210}