1use std::path::Path;
7use std::collections::HashMap;
8use serde::{Deserialize, Serialize};
9use crate::buffer::SampleBuffer;
10use aether_midi::tuning::TuningTable;
11
/// How a [`ZoneGroup`] chooses among its zones each time it is triggered.
///
/// The selection itself is implemented by `RoundRobinState::select`, which keeps
/// separate history per group index.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RoundRobinMode {
    /// Walk the zones in order (0, 1, ..., n-1) and wrap back to the first.
    Sequential,
    /// Pick a pseudo-random zone on every trigger; immediate repeats are allowed.
    Random,
    /// Pick a pseudo-random zone, but never the same index as the previous pick
    /// for the same group (when the group has more than one zone).
    RandomNoRepeat,
}
22
/// A set of zones that are rotated between according to `mode`.
///
/// `SamplerInstrument::find_zone_rr` chooses a group that contains a zone matching
/// the note/velocity, then lets `RoundRobinState::select_zone` pick among *all*
/// zones of that group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZoneGroup {
    /// Candidate zones. NOTE(review): presumably interchangeable alternates covering
    /// the same key/velocity range, since selection does not re-check `matches`.
    pub zones: Vec<SampleZone>,
    /// Rotation strategy used when picking from `zones`.
    pub mode: RoundRobinMode,
}
31
/// Seed used by `RoundRobinState::new`. Any non-zero value is valid for xorshift64;
/// zero is a fixed point of the generator and must be avoided.
const DEFAULT_RR_SEED: u64 = 0x1234_5678_9ABC_DEF0;
34
/// Mutable selection state shared by all round-robin zone groups.
///
/// Per-group history is keyed by the `group_idx` passed to `select` /
/// `select_zone`; every group draws from the same xorshift64 generator.
#[derive(Debug, Clone)]
pub struct RoundRobinState {
    /// Next sequential cursor position per group (used by `Sequential` mode).
    sequential_index: HashMap<usize, usize>,
    /// Index returned by the previous pick per group (used by `RandomNoRepeat`).
    last_selected: HashMap<usize, usize>,
    /// Current xorshift64 state; advanced by every random draw.
    rng_state: u64,
    /// Seed the generator started from; `reset` restores `rng_state` to it.
    seed: u64,
}
47
48impl RoundRobinState {
49 pub fn new() -> Self {
51 Self::with_seed(DEFAULT_RR_SEED)
52 }
53
54 pub fn with_seed(seed: u64) -> Self {
56 Self {
57 sequential_index: HashMap::new(),
58 last_selected: HashMap::new(),
59 rng_state: seed,
60 seed,
61 }
62 }
63
64 pub fn reset(&mut self) {
66 self.sequential_index.clear();
67 self.last_selected.clear();
68 self.rng_state = self.seed;
69 }
70
71 pub fn xorshift64(&mut self) -> u64 {
73 let mut x = self.rng_state;
74 x ^= x << 13;
75 x ^= x >> 7;
76 x ^= x << 17;
77 self.rng_state = x;
78 x
79 }
80
81 pub fn select(
84 &mut self,
85 group_idx: usize,
86 group_len: usize,
87 mode: &RoundRobinMode,
88 ) -> usize {
89 if group_len == 0 {
90 return 0;
91 }
92
93 let n = group_len;
94 match mode {
95 RoundRobinMode::Sequential => {
96 let current = self.sequential_index.get(&group_idx).copied().unwrap_or(0);
97 let selected = current % n;
98 self.sequential_index.insert(group_idx, (current + 1) % n);
99 selected
100 }
101 RoundRobinMode::Random => (self.xorshift64() % n as u64) as usize,
102 RoundRobinMode::RandomNoRepeat => {
103 if n == 1 {
104 0
105 } else {
106 let last = self.last_selected.get(&group_idx).copied();
107 let mut selected = (self.xorshift64() % n as u64) as usize;
108 for _ in 0..n {
110 if Some(selected) != last {
111 break;
112 }
113 selected = (self.xorshift64() % n as u64) as usize;
114 }
115 if Some(selected) == last {
117 selected = (last.unwrap() + 1) % n;
118 }
119 self.last_selected.insert(group_idx, selected);
120 selected
121 }
122 }
123 }
124 }
125
126 fn next_random(&mut self) -> u64 {
128 self.xorshift64()
129 }
130
131 pub fn select_zone<'a>(
133 &mut self,
134 group_idx: usize,
135 zones: &'a [SampleZone],
136 mode: &RoundRobinMode,
137 ) -> Option<&'a SampleZone> {
138 if zones.is_empty() {
139 return None;
140 }
141
142 let n = zones.len();
143 let idx = match mode {
144 RoundRobinMode::Sequential => {
145 let current = self.sequential_index.get(&group_idx).copied().unwrap_or(0);
146 let selected = current % n;
147 self.sequential_index.insert(group_idx, (current + 1) % n);
148 selected
149 }
150 RoundRobinMode::Random => (self.next_random() % n as u64) as usize,
151 RoundRobinMode::RandomNoRepeat => {
152 if n == 1 {
153 0
154 } else {
155 let last = self.last_selected.get(&group_idx).copied();
156 let mut selected = (self.next_random() % n as u64) as usize;
157 for _ in 0..n {
159 if Some(selected) != last {
160 break;
161 }
162 selected = (self.next_random() % n as u64) as usize;
163 }
164 if Some(selected) == last {
166 selected = (last.unwrap() + 1) % n;
167 }
168 self.last_selected.insert(group_idx, selected);
169 selected
170 }
171 }
172 };
173
174 zones.get(idx)
175 }
176}
177
/// How a zone's sample is meant to be played.
///
/// This module only stores the choice. The per-variant behaviour below is
/// inferred from the names. NOTE(review): confirm against the playback/voice code.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ArticulationType {
    /// Presumably plays the sample through once, ignoring note-off.
    OneShot,
    /// Presumably loops the region between `loop_start` and `loop_end` while the note is held.
    SustainLoop {
        /// Loop start; assumed to be a sample-frame offset (TODO confirm units).
        loop_start: usize,
        /// Loop end; assumed to be a sample-frame offset (inclusive/exclusive not shown here).
        loop_end: usize,
    },
    /// Presumably sustains while held and then plays a release portion or sample.
    SustainRelease,
}
193
/// One sample mapped onto a rectangular note × velocity region.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SampleZone {
    /// Identifier. `LoadedInstrument::load` keys decoded buffers by it, so it should be unique.
    pub id: String,
    /// Sample file path, joined onto the `base_dir` given to `LoadedInstrument::load`.
    pub file_path: String,
    /// Note at which `pitch_ratio` is 1.0, apart from the `tune_cents` detune.
    pub root_note: u8,
    /// Lowest note this zone responds to (inclusive).
    pub note_low: u8,
    /// Highest note this zone responds to (inclusive).
    pub note_high: u8,
    /// Lowest velocity this zone responds to (inclusive).
    pub velocity_low: u8,
    /// Highest velocity this zone responds to (inclusive).
    pub velocity_high: u8,
    /// Playback articulation for this zone.
    pub articulation: ArticulationType,
    /// Gain in decibels; `volume_linear` converts it as `10^(dB / 20)`.
    pub volume_db: f32,
    /// Fine-tuning offset in cents, applied by `pitch_ratio` as `2^(cents / 1200)`.
    pub tune_cents: f32,
    /// Optional release sample path, also joined onto the load `base_dir`.
    /// It is loaded only if the file exists.
    pub release_file: Option<String>,
}
220
221impl SampleZone {
222 pub fn matches(&self, note: u8, velocity: u8) -> bool {
224 note >= self.note_low && note <= self.note_high
225 && velocity >= self.velocity_low && velocity <= self.velocity_high
226 }
227
228 pub fn pitch_ratio(&self, target_note: u8, tuning: &TuningTable) -> f32 {
230 let root_freq = tuning.frequency(self.root_note);
231 let target_freq = tuning.frequency(target_note);
232 if root_freq > 0.0 {
233 let cents_offset = self.tune_cents;
234 (target_freq / root_freq) * 2.0f32.powf(cents_offset / 1200.0)
235 } else {
236 1.0
237 }
238 }
239
240 pub fn volume_linear(&self) -> f32 {
242 10.0f32.powf(self.volume_db / 20.0)
243 }
244}
245
/// A sampled instrument definition, persisted as JSON via serde (`save` / `load`).
///
/// Deserialization goes through `SamplerInstrumentRaw` (`#[serde(from = ...)]`),
/// whose `From` conversion calls `normalize`. Because of that, the field-level
/// `#[serde(default)]` attributes on this struct do not affect deserialization;
/// the matching attributes on the raw struct are the ones that apply.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(from = "SamplerInstrumentRaw")]
pub struct SamplerInstrument {
    /// Display name.
    pub name: String,
    /// Descriptive metadata; not interpreted by this module.
    pub origin: String,
    /// Descriptive metadata; not interpreted by this module.
    pub description: String,
    /// Descriptive metadata; not interpreted by this module.
    pub author: String,
    /// Tuning table. NOTE(review): presumably what callers pass to `SampleZone::pitch_ratio`.
    pub tuning: TuningTable,
    /// Flat zone list (legacy format); searched by `find_zone`.
    #[serde(default)]
    pub zones: Vec<SampleZone>,
    /// Round-robin groups searched by `find_zone_rr`. `normalize` fills this from
    /// `zones` when it is empty.
    #[serde(default)]
    pub zone_groups: Vec<ZoneGroup>,
    /// Envelope attack. NOTE(review): units not fixed here; defaults suggest seconds.
    pub attack: f32,
    /// Envelope decay. NOTE(review): units not fixed here; defaults suggest seconds.
    pub decay: f32,
    /// Envelope sustain. NOTE(review): the 0.8 default suggests a 0–1 level.
    pub sustain: f32,
    /// Envelope release. NOTE(review): units not fixed here; defaults suggest seconds.
    pub release: f32,
    /// Maximum polyphony. NOTE(review): presumably enforced by the voice allocator.
    pub max_voices: usize,
}
274
/// Field-for-field deserialization twin of `SamplerInstrument`.
///
/// `SamplerInstrument` deserializes through this type (`#[serde(from = ...)]`) so
/// that `normalize` runs on every loaded instrument. `zones` and `zone_groups`
/// default to empty when absent, so both zones-only (legacy) and grouped files parse.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SamplerInstrumentRaw {
    pub name: String,
    pub origin: String,
    pub description: String,
    pub author: String,
    pub tuning: TuningTable,
    #[serde(default)]
    pub zones: Vec<SampleZone>,
    #[serde(default)]
    pub zone_groups: Vec<ZoneGroup>,
    pub attack: f32,
    pub decay: f32,
    pub sustain: f32,
    pub release: f32,
    pub max_voices: usize,
}
293
294impl From<SamplerInstrumentRaw> for SamplerInstrument {
295 fn from(raw: SamplerInstrumentRaw) -> Self {
296 let mut instrument = SamplerInstrument {
297 name: raw.name,
298 origin: raw.origin,
299 description: raw.description,
300 author: raw.author,
301 tuning: raw.tuning,
302 zones: raw.zones,
303 zone_groups: raw.zone_groups,
304 attack: raw.attack,
305 decay: raw.decay,
306 sustain: raw.sustain,
307 release: raw.release,
308 max_voices: raw.max_voices,
309 };
310 instrument.normalize();
311 instrument
312 }
313}
314
315impl SamplerInstrument {
316 pub fn new(name: &str) -> Self {
317 Self {
318 name: name.into(),
319 origin: String::new(),
320 description: String::new(),
321 author: String::new(),
322 tuning: TuningTable::default(),
323 zones: Vec::new(),
324 zone_groups: Vec::new(),
325 attack: 0.005,
326 decay: 0.1,
327 sustain: 0.8,
328 release: 0.3,
329 max_voices: 16,
330 }
331 }
332
333 pub fn normalize(&mut self) {
336 if self.zone_groups.is_empty() && !self.zones.is_empty() {
337 self.zone_groups = self
338 .zones
339 .iter()
340 .cloned()
341 .map(|zone| ZoneGroup {
342 zones: vec![zone],
343 mode: RoundRobinMode::Sequential,
344 })
345 .collect();
346 }
347 }
348
349 pub fn find_zone(&self, note: u8, velocity: u8) -> Option<&SampleZone> {
352 let mut best: Option<&SampleZone> = None;
353 let mut best_dist = u8::MAX;
354 for zone in &self.zones {
355 if zone.matches(note, velocity) {
356 let dist = note.abs_diff(zone.root_note);
357 if dist < best_dist {
358 best_dist = dist;
359 best = Some(zone);
360 }
361 }
362 }
363 best
364 }
365
366 pub fn find_zone_rr<'a>(
369 &'a self,
370 note: u8,
371 velocity: u8,
372 rr_state: &mut RoundRobinState,
373 ) -> Option<&'a SampleZone> {
374 for (group_idx, group) in self.zone_groups.iter().enumerate() {
375 if group.zones.iter().any(|z| z.matches(note, velocity)) {
377 return rr_state.select_zone(group_idx, &group.zones, &group.mode);
378 }
379 }
380 None
381 }
382
383 pub fn add_zone(&mut self, zone: SampleZone) {
385 self.zones.push(zone);
386 }
387
388 pub fn save(&self, path: &Path) -> std::io::Result<()> {
390 let json = serde_json::to_string_pretty(self).unwrap();
391 std::fs::write(path, json)
392 }
393
394 pub fn load(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
396 let json = std::fs::read_to_string(path)?;
397 Ok(serde_json::from_str(&json)?)
398 }
399}
400
/// An instrument definition together with its decoded sample data.
pub struct LoadedInstrument {
    /// The instrument definition the buffers were loaded for.
    pub instrument: SamplerInstrument,
    /// Main sample buffers keyed by `SampleZone::id`, filled by `LoadedInstrument::load`.
    pub buffers: HashMap<String, SampleBuffer>,
    /// Release sample buffers keyed by `SampleZone::id`. Only zones whose
    /// `release_file` existed on disk at load time have an entry.
    pub release_buffers: HashMap<String, SampleBuffer>,
}
409
410impl LoadedInstrument {
411 pub fn load(instrument: SamplerInstrument, base_dir: &Path) -> Result<Self, Box<dyn std::error::Error>> {
413 let mut buffers = HashMap::new();
414 let mut release_buffers = HashMap::new();
415
416 for zone in &instrument.zones {
417 let path = base_dir.join(&zone.file_path);
418 let buf = SampleBuffer::load_wav(&path)?;
419 buffers.insert(zone.id.clone(), buf);
420
421 if let Some(ref rel_path) = zone.release_file {
422 let rpath = base_dir.join(rel_path);
423 if rpath.exists() {
424 let rbuf = SampleBuffer::load_wav(&rpath)?;
425 release_buffers.insert(zone.id.clone(), rbuf);
426 }
427 }
428 }
429
430 Ok(Self { instrument, buffers, release_buffers })
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// One-shot zone on note 60 only, full velocity range, no gain or detune.
    /// Individual tests override fields with struct-update syntax.
    fn make_zone(id: &str, file_path: &str) -> SampleZone {
        SampleZone {
            id: id.into(),
            file_path: file_path.into(),
            root_note: 60,
            note_low: 60,
            note_high: 60,
            velocity_low: 0,
            velocity_high: 127,
            articulation: ArticulationType::OneShot,
            volume_db: 0.0,
            tune_cents: 0.0,
            release_file: None,
        }
    }

    #[test]
    fn test_round_robin_mode_derives() {
        let mode = RoundRobinMode::Sequential;
        let cloned = mode.clone();
        assert_eq!(mode, cloned);
    }

    #[test]
    fn test_zone_group_creation() {
        let group = ZoneGroup {
            zones: vec![make_zone("test", "test.wav")],
            mode: RoundRobinMode::Sequential,
        };

        assert_eq!(group.zones.len(), 1);
        assert_eq!(group.mode, RoundRobinMode::Sequential);
    }

    #[test]
    fn test_round_robin_state_sequential() {
        let mut state = RoundRobinState::with_seed(12345);
        let zones = [make_zone("zone1", "test1.wav"), make_zone("zone2", "test2.wav")];

        // Sequential mode walks the group in order and wraps around.
        let z1 = state.select_zone(0, &zones, &RoundRobinMode::Sequential);
        assert_eq!(z1.unwrap().id, "zone1");

        let z2 = state.select_zone(0, &zones, &RoundRobinMode::Sequential);
        assert_eq!(z2.unwrap().id, "zone2");

        let z3 = state.select_zone(0, &zones, &RoundRobinMode::Sequential);
        assert_eq!(z3.unwrap().id, "zone1");
    }

    #[test]
    fn test_sampler_instrument_normalize() {
        let mut instrument = SamplerInstrument::new("test");
        instrument.zones.push(make_zone("zone1", "test.wav"));
        instrument.normalize();

        assert_eq!(instrument.zone_groups.len(), 1);
        assert_eq!(instrument.zone_groups[0].zones.len(), 1);
        assert_eq!(instrument.zone_groups[0].mode, RoundRobinMode::Sequential);
    }

    #[test]
    fn test_find_zone_rr() {
        let mut instrument = SamplerInstrument::new("test");
        instrument.zone_groups.push(ZoneGroup {
            zones: vec![make_zone("zone1", "test1.wav"), make_zone("zone2", "test2.wav")],
            mode: RoundRobinMode::Sequential,
        });

        let mut rr_state = RoundRobinState::with_seed(12345);

        let z1 = instrument.find_zone_rr(60, 100, &mut rr_state);
        assert!(z1.is_some());
        assert_eq!(z1.unwrap().id, "zone1");

        let z2 = instrument.find_zone_rr(60, 100, &mut rr_state);
        assert!(z2.is_some());
        assert_eq!(z2.unwrap().id, "zone2");
    }

    proptest! {
        // Sequential round-robin over n zones visits every index exactly once per n selections.
        #[test]
        fn prop_sequential_round_robin_full_cycle(
            n in 1usize..=16,
        ) {
            let zones: Vec<SampleZone> = (0..n)
                .map(|i| make_zone(&format!("zone_{}", i), &format!("sample_{}.wav", i)))
                .collect();
            let group = ZoneGroup {
                zones,
                mode: RoundRobinMode::Sequential,
            };

            let mut rr_state = RoundRobinState::with_seed(12345);
            let selected_indices: Vec<usize> = (0..n)
                .map(|_| rr_state.select(0, n, &group.mode))
                .collect();

            prop_assert_eq!(selected_indices.len(), n);

            let mut sorted_indices = selected_indices.clone();
            sorted_indices.sort_unstable();
            let expected: Vec<usize> = (0..n).collect();
            prop_assert_eq!(&sorted_indices, &expected);

            for i in 0..n {
                let count = selected_indices.iter().filter(|&&x| x == i).count();
                prop_assert_eq!(count, 1, "Zone index {} should appear exactly once, but appeared {} times", i, count);
            }
        }
    }

    proptest! {
        // RandomNoRepeat never returns the same index twice in a row and always stays in range.
        #[test]
        fn prop_random_no_repeat_no_consecutive_repeats(
            n in 2usize..=16,
            seed in any::<u64>(),
        ) {
            let mut rr_state = RoundRobinState::with_seed(seed);
            let selected_indices: Vec<usize> = (0..100)
                .map(|_| rr_state.select(0, n, &RoundRobinMode::RandomNoRepeat))
                .collect();

            for (i, pair) in selected_indices.windows(2).enumerate() {
                prop_assert_ne!(
                    pair[0],
                    pair[1],
                    "RandomNoRepeat violated: zone {} was selected twice in a row at positions {} and {}",
                    pair[0],
                    i,
                    i + 1
                );
            }

            for (i, &idx) in selected_indices.iter().enumerate() {
                prop_assert!(
                    idx < n,
                    "Invalid zone index {} at position {} (should be < {})",
                    idx,
                    i,
                    n
                );
            }
        }
    }

    proptest! {
        // Deserializing a legacy (zones-only) instrument yields one Sequential group per
        // zone, and find_zone_rr then agrees with find_zone for every probed note/velocity.
        #[test]
        fn prop_backward_compatible_instrument_loading(
            zone_count in 1usize..=10,
            note_ranges in prop::collection::vec((0u8..=127u8, 0u8..=127u8), 1..=10),
            velocity_ranges in prop::collection::vec((0u8..=127u8, 0u8..=127u8), 1..=10),
        ) {
            let zones: Vec<SampleZone> = (0..zone_count)
                .map(|i| {
                    let (note_low, note_high_offset) = note_ranges[i % note_ranges.len()];
                    let note_high = note_low.saturating_add(note_high_offset % 12);
                    let (vel_low, vel_high_offset) = velocity_ranges[i % velocity_ranges.len()];
                    SampleZone {
                        root_note: note_low + (note_high - note_low) / 2,
                        note_low,
                        note_high,
                        velocity_low: vel_low,
                        velocity_high: vel_low.saturating_add(vel_high_offset),
                        ..make_zone(&format!("zone_{}", i), &format!("sample_{}.wav", i))
                    }
                })
                .collect();

            let legacy_json = serde_json::json!({
                "name": "Legacy Instrument",
                "origin": "Test",
                "description": "Test legacy instrument",
                "author": "Test",
                "tuning": {
                    "name": "12-TET",
                    "reference_note": 69,
                    "reference_freq": 440.0,
                    "cents_map": []
                },
                "zones": zones,
                "attack": 0.005,
                "decay": 0.1,
                "sustain": 0.8,
                "release": 0.3,
                "max_voices": 16
            });

            let instrument: SamplerInstrument = serde_json::from_value(legacy_json)
                .expect("Failed to deserialize legacy instrument");

            prop_assert_eq!(instrument.zone_groups.len(), zones.len());
            for (i, group) in instrument.zone_groups.iter().enumerate() {
                // Compare by reference: prop_assert_eq! binds its operands by value,
                // and these fields live behind a shared borrow.
                let expected_id = format!("zone_{}", i);
                prop_assert_eq!(group.zones.len(), 1);
                prop_assert_eq!(&group.mode, &RoundRobinMode::Sequential);
                prop_assert_eq!(&group.zones[0].id, &expected_id);
            }

            let mut rr_state = RoundRobinState::with_seed(12345);
            for note in 0u8..=127 {
                for velocity in [1u8, 64, 127] {
                    let legacy_zone = instrument.find_zone(note, velocity);
                    let rr_zone = instrument.find_zone_rr(note, velocity, &mut rr_state);
                    // Checks both presence and identity in one comparison.
                    prop_assert_eq!(legacy_zone.map(|z| &z.id), rr_zone.map(|z| &z.id));
                }
            }
        }
    }
}