1use std::sync::{Arc, Mutex};
13use arc_swap::ArcSwap;
14use aether_core::{
15 node::DspNode,
16 param::ParamBlock,
17 BUFFER_SIZE, MAX_INPUTS,
18};
19use crate::{
20 instrument::{LoadedInstrument, RoundRobinState},
21 voice::SamplerVoice,
22};
23use aether_midi::event::{MidiEvent, MidiEventKind};
24
25pub struct SamplerNode {
27 instrument: Arc<ArcSwap<Option<LoadedInstrument>>>,
31 voices: Vec<SamplerVoice>,
33 midi_queue: Arc<Mutex<Vec<MidiEvent>>>,
35 sample_rate: f32,
37 sustain_pedal: [bool; 16],
39 sustained_notes: Vec<(u8, u8)>,
41 rr_state: RoundRobinState,
43 last_instrument_name: Option<String>,
45}
46
47impl SamplerNode {
48 pub fn new(sample_rate: f32) -> Self {
49 Self {
50 instrument: Arc::new(ArcSwap::from_pointee(None)),
51 voices: Vec::with_capacity(32),
52 midi_queue: Arc::new(Mutex::new(Vec::new())),
53 sample_rate,
54 sustain_pedal: [false; 16],
55 sustained_notes: Vec::new(),
56 rr_state: RoundRobinState::new(),
57 last_instrument_name: None,
58 }
59 }
60
61 pub fn midi_queue(&self) -> Arc<Mutex<Vec<MidiEvent>>> {
63 Arc::clone(&self.midi_queue)
64 }
65
66 pub fn instrument_slot(&self) -> Arc<ArcSwap<Option<LoadedInstrument>>> {
71 Arc::clone(&self.instrument)
72 }
73
74 pub fn instrument_slot_mutex(&self) -> Arc<Mutex<Option<LoadedInstrument>>> {
79 let swap = Arc::clone(&self.instrument);
82 let mutex: Arc<Mutex<Option<LoadedInstrument>>> = Arc::new(Mutex::new(None));
83 let _ = swap; mutex
89 }
90
91 pub fn reset_round_robin(&mut self) {
92 self.rr_state.reset();
93 }
94
95 fn process_midi_events(&mut self) {
96 let events: Vec<MidiEvent> = {
98 match self.midi_queue.try_lock() {
99 Ok(mut q) => std::mem::take(&mut *q),
100 Err(_) => return, }
102 };
103
104 if events.is_empty() { return; }
105
106 let inst_guard = self.instrument.load();
108 let inst = match inst_guard.as_ref().as_ref() {
109 Some(i) => i,
110 None => {
111 if self.last_instrument_name.is_some() {
112 self.last_instrument_name = None;
113 self.rr_state.reset();
114 }
115 return;
116 }
117 };
118
119 let current_name = &inst.instrument.name;
120 if self.last_instrument_name.as_deref() != Some(current_name.as_str()) {
121 self.last_instrument_name = Some(current_name.clone());
122 self.rr_state.reset();
123 }
124
125 for event in events {
126 match event.kind {
127 MidiEventKind::NoteOn { note, velocity } => {
128 let max_voices = inst.instrument.max_voices;
129 if self.voices.len() >= max_voices {
130 self.voices.remove(0);
131 }
132 if let Some(zone) = inst.instrument.find_zone_rr(note, velocity, &mut self.rr_state) {
133 if inst.buffers.contains_key(&zone.id) {
134 let vel_linear = velocity as f32 / 127.0;
135 let pitch_ratio = zone.pitch_ratio(note, &inst.instrument.tuning) as f64;
136 let volume = zone.volume_linear() * vel_linear;
137 let voice = SamplerVoice::new(
138 note, event.channel, vel_linear,
139 pitch_ratio, volume, zone,
140 );
141 self.voices.push(voice);
142 }
143 }
144 }
145 MidiEventKind::NoteOff { note, .. } => {
146 let ch = event.channel;
147 if self.sustain_pedal[ch as usize] {
148 self.sustained_notes.push((ch, note));
149 } else {
150 for v in self.voices.iter_mut() {
151 if v.note == note && v.channel == ch && v.key_held {
152 v.release();
153 }
154 }
155 }
156 }
157 MidiEventKind::ControlChange { cc, value } => {
158 let ch = event.channel as usize;
159 if cc == aether_midi::event::cc::SUSTAIN_PEDAL {
160 let held = value >= 64;
161 self.sustain_pedal[ch] = held;
162 if !held {
163 let to_release: Vec<(u8, u8)> = self.sustained_notes.drain(..).collect();
164 for (c, n) in to_release {
165 for v in self.voices.iter_mut() {
166 if v.note == n && v.channel == c && v.key_held {
167 v.release();
168 }
169 }
170 }
171 }
172 }
173 }
174 MidiEventKind::AllNotesOff | MidiEventKind::AllSoundOff => {
175 for v in self.voices.iter_mut() { v.release(); }
176 self.sustained_notes.clear();
177 }
178 _ => {}
179 }
180 }
181 }
182
183 fn render_voices(&mut self, output: &mut [f32; BUFFER_SIZE]) {
184 let inst_guard = self.instrument.load();
186 let inst = match inst_guard.as_ref().as_ref() {
187 Some(i) => i,
188 None => return,
189 };
190
191 let sr = self.sample_rate;
192 let attack_rate = 1.0 / (inst.instrument.attack * sr).max(1.0);
193 let decay_rate = 1.0 / (inst.instrument.decay * sr).max(1.0);
194 let sustain = inst.instrument.sustain;
195 let release_rate = 1.0 / (inst.instrument.release * sr).max(1.0);
196
197 for voice in self.voices.iter_mut() {
198 if voice.is_done() { continue; }
199 let buf = match inst.buffers.get(&voice.zone_id) {
200 Some(b) => b,
201 None => continue,
202 };
203 for sample in output.iter_mut() {
204 voice.envelope.tick(attack_rate, decay_rate, sustain, release_rate);
205 let frame_pos = voice.advance(buf.frames);
206 let raw = buf.sample_at(frame_pos);
207 *sample += raw * voice.volume * voice.envelope.level;
208 }
209 }
210 self.voices.retain(|v| !v.is_done());
211 }
212}
213
214impl DspNode for SamplerNode {
215 fn process(
216 &mut self,
217 _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
218 output: &mut [f32; BUFFER_SIZE],
219 _params: &mut ParamBlock,
220 _sample_rate: f32,
221 ) {
222 output.fill(0.0);
223 self.process_midi_events();
224 self.render_voices(output);
225 }
226
227 fn type_name(&self) -> &'static str { "SamplerNode" }
228}