1use std::sync::{Arc, Mutex};
9use aether_core::{
10 node::DspNode,
11 param::ParamBlock,
12 BUFFER_SIZE, MAX_INPUTS,
13};
14use crate::{
15 instrument::{LoadedInstrument, RoundRobinState},
16 voice::SamplerVoice,
17};
18use aether_midi::event::{MidiEvent, MidiEventKind};
19
20pub struct SamplerNode {
22 instrument: Arc<Mutex<Option<LoadedInstrument>>>,
24 voices: Vec<SamplerVoice>,
26 midi_queue: Arc<Mutex<Vec<MidiEvent>>>,
28 sample_rate: f32,
30 sustain_pedal: [bool; 16],
32 sustained_notes: Vec<(u8, u8)>, rr_state: RoundRobinState,
36 last_instrument_name: Option<String>,
38}
39
40impl SamplerNode {
41 pub fn new(sample_rate: f32) -> Self {
42 Self {
43 instrument: Arc::new(Mutex::new(None)),
44 voices: Vec::with_capacity(32),
45 midi_queue: Arc::new(Mutex::new(Vec::new())),
46 sample_rate,
47 sustain_pedal: [false; 16],
48 sustained_notes: Vec::new(),
49 rr_state: RoundRobinState::new(),
50 last_instrument_name: None,
51 }
52 }
53
54 pub fn midi_queue(&self) -> Arc<Mutex<Vec<MidiEvent>>> {
56 Arc::clone(&self.midi_queue)
57 }
58
59 pub fn instrument_slot(&self) -> Arc<Mutex<Option<LoadedInstrument>>> {
61 Arc::clone(&self.instrument)
62 }
63
64 pub fn reset_round_robin(&mut self) {
66 self.rr_state.reset();
67 }
68
69 fn process_midi_events(&mut self) {
70 let events: Vec<MidiEvent> = {
71 let mut q = self.midi_queue.lock().unwrap();
72 std::mem::take(&mut *q)
73 };
74
75 let inst_guard = self.instrument.lock().unwrap();
76 let inst = match inst_guard.as_ref() {
77 Some(i) => i,
78 None => {
79 if self.last_instrument_name.is_some() {
81 self.last_instrument_name = None;
82 self.rr_state.reset();
83 }
84 return;
85 }
86 };
87
88 let current_name = &inst.instrument.name;
90 if self.last_instrument_name.as_ref() != Some(current_name) {
91 self.last_instrument_name = Some(current_name.clone());
92 self.rr_state.reset();
93 }
94
95 for event in events {
96 match event.kind {
97 MidiEventKind::NoteOn { note, velocity } => {
98 let max_voices = inst.instrument.max_voices;
100 if self.voices.len() >= max_voices {
101 self.voices.remove(0);
102 }
103
104 if let Some(zone) = inst.instrument.find_zone_rr(note, velocity, &mut self.rr_state) {
105 if inst.buffers.contains_key(&zone.id) {
106 let vel_linear = velocity as f32 / 127.0;
107 let pitch_ratio = zone.pitch_ratio(note, &inst.instrument.tuning) as f64;
108 let volume = zone.volume_linear() * vel_linear;
109 let voice = SamplerVoice::new(
110 note, event.channel, vel_linear,
111 pitch_ratio, volume, zone,
112 );
113 self.voices.push(voice);
114 }
115 }
116 }
117
118 MidiEventKind::NoteOff { note, .. } => {
119 let ch = event.channel;
120 if self.sustain_pedal[ch as usize] {
121 self.sustained_notes.push((ch, note));
122 } else {
123 for v in self.voices.iter_mut() {
124 if v.note == note && v.channel == ch && v.key_held {
125 v.release();
126 }
127 }
128 }
129 }
130
131 MidiEventKind::ControlChange { cc, value } => {
132 let ch = event.channel as usize;
133 if cc == aether_midi::event::cc::SUSTAIN_PEDAL {
134 let held = value >= 64;
135 self.sustain_pedal[ch] = held;
136 if !held {
137 let to_release: Vec<(u8, u8)> = self.sustained_notes.drain(..).collect();
139 for (c, n) in to_release {
140 for v in self.voices.iter_mut() {
141 if v.note == n && v.channel == c && v.key_held {
142 v.release();
143 }
144 }
145 }
146 }
147 }
148 }
149
150 MidiEventKind::AllNotesOff | MidiEventKind::AllSoundOff => {
151 for v in self.voices.iter_mut() {
152 v.release();
153 }
154 self.sustained_notes.clear();
155 }
156
157 _ => {}
158 }
159 }
160 }
161
162 fn render_voices(&mut self, output: &mut [f32; BUFFER_SIZE]) {
163 let inst_guard = self.instrument.lock().unwrap();
164 let inst = match inst_guard.as_ref() {
165 Some(i) => i,
166 None => return,
167 };
168
169 let sr = self.sample_rate;
170 let attack_rate = 1.0 / (inst.instrument.attack * sr).max(1.0);
171 let decay_rate = 1.0 / (inst.instrument.decay * sr).max(1.0);
172 let sustain = inst.instrument.sustain;
173 let release_rate = 1.0 / (inst.instrument.release * sr).max(1.0);
174
175 for voice in self.voices.iter_mut() {
176 if voice.is_done() { continue; }
177
178 let buf = match inst.buffers.get(&voice.zone_id) {
179 Some(b) => b,
180 None => continue,
181 };
182
183 for sample in output.iter_mut() {
184 voice.envelope.tick(attack_rate, decay_rate, sustain, release_rate);
185 let frame_pos = voice.advance(buf.frames);
186 let raw = buf.sample_at(frame_pos);
187 *sample += raw * voice.volume * voice.envelope.level;
188 }
189 }
190
191 self.voices.retain(|v| !v.is_done());
193 }
194}
195
196impl DspNode for SamplerNode {
197 fn process(
198 &mut self,
199 _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
200 output: &mut [f32; BUFFER_SIZE],
201 _params: &mut ParamBlock,
202 _sample_rate: f32,
203 ) {
204 output.fill(0.0);
205 self.process_midi_events();
206 self.render_voices(output);
207 }
208
209 fn type_name(&self) -> &'static str { "SamplerNode" }
210}