1use std::path::Path;
7use std::collections::HashMap;
8use serde::{Deserialize, Serialize};
9use crate::buffer::SampleBuffer;
10use aether_midi::tuning::TuningTable;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub enum RoundRobinMode {
15 Sequential,
17 Random,
19 RandomNoRepeat,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ZoneGroup {
26 pub zones: Vec<SampleZone>,
28 pub mode: RoundRobinMode,
30}
31
/// Seed used by [`RoundRobinState::new`] (and therefore by its `Default` impl), so
/// random zone selection is reproducible from run to run.
const DEFAULT_RR_SEED: u64 = 0x1234_5678_9ABC_DEF0;
34
/// Bookkeeping for round-robin zone selection, plus a seeded xorshift64
/// generator used by the random modes.
///
/// Groups are identified purely by the caller-supplied `group_idx`.
#[derive(Clone, Debug)]
pub struct RoundRobinState {
    /// Next position to play, per group, for `Sequential` mode.
    sequential_index: HashMap<usize, usize>,
    /// Index picked on the previous trigger, per group, for `RandomNoRepeat` mode.
    last_selected: HashMap<usize, usize>,
    /// Current generator state; advanced on every random draw.
    rng_state: u64,
    /// Value `rng_state` is restored to by `reset`.
    seed: u64,
}
47
48impl RoundRobinState {
49 pub fn new() -> Self {
51 Self::with_seed(DEFAULT_RR_SEED)
52 }
53}
54
55impl Default for RoundRobinState {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl RoundRobinState {
62 pub fn with_seed(seed: u64) -> Self {
63 Self {
64 sequential_index: HashMap::new(),
65 last_selected: HashMap::new(),
66 rng_state: seed,
67 seed,
68 }
69 }
70
71 pub fn reset(&mut self) {
73 self.sequential_index.clear();
74 self.last_selected.clear();
75 self.rng_state = self.seed;
76 }
77
78 pub fn xorshift64(&mut self) -> u64 {
80 let mut x = self.rng_state;
81 x ^= x << 13;
82 x ^= x >> 7;
83 x ^= x << 17;
84 self.rng_state = x;
85 x
86 }
87
88 pub fn select(
91 &mut self,
92 group_idx: usize,
93 group_len: usize,
94 mode: &RoundRobinMode,
95 ) -> usize {
96 if group_len == 0 {
97 return 0;
98 }
99
100 let n = group_len;
101 match mode {
102 RoundRobinMode::Sequential => {
103 let current = self.sequential_index.get(&group_idx).copied().unwrap_or(0);
104 let selected = current % n;
105 self.sequential_index.insert(group_idx, (current + 1) % n);
106 selected
107 }
108 RoundRobinMode::Random => (self.xorshift64() % n as u64) as usize,
109 RoundRobinMode::RandomNoRepeat => {
110 if n == 1 {
111 0
112 } else {
113 let last = self.last_selected.get(&group_idx).copied();
114 let mut selected = (self.xorshift64() % n as u64) as usize;
115 for _ in 0..n {
117 if Some(selected) != last {
118 break;
119 }
120 selected = (self.xorshift64() % n as u64) as usize;
121 }
122 if Some(selected) == last {
124 selected = (last.unwrap() + 1) % n;
125 }
126 self.last_selected.insert(group_idx, selected);
127 selected
128 }
129 }
130 }
131 }
132
133 fn next_random(&mut self) -> u64 {
135 self.xorshift64()
136 }
137
138 pub fn select_zone<'a>(
140 &mut self,
141 group_idx: usize,
142 zones: &'a [SampleZone],
143 mode: &RoundRobinMode,
144 ) -> Option<&'a SampleZone> {
145 if zones.is_empty() {
146 return None;
147 }
148
149 let n = zones.len();
150 let idx = match mode {
151 RoundRobinMode::Sequential => {
152 let current = self.sequential_index.get(&group_idx).copied().unwrap_or(0);
153 let selected = current % n;
154 self.sequential_index.insert(group_idx, (current + 1) % n);
155 selected
156 }
157 RoundRobinMode::Random => (self.next_random() % n as u64) as usize,
158 RoundRobinMode::RandomNoRepeat => {
159 if n == 1 {
160 0
161 } else {
162 let last = self.last_selected.get(&group_idx).copied();
163 let mut selected = (self.next_random() % n as u64) as usize;
164 for _ in 0..n {
166 if Some(selected) != last {
167 break;
168 }
169 selected = (self.next_random() % n as u64) as usize;
170 }
171 if Some(selected) == last {
173 selected = (last.unwrap() + 1) % n;
174 }
175 self.last_selected.insert(group_idx, selected);
176 selected
177 }
178 }
179 };
180
181 zones.get(idx)
182 }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187pub enum ArticulationType {
188 OneShot,
190 SustainLoop {
192 loop_start: usize,
194 loop_end: usize,
196 },
197 SustainRelease,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct SampleZone {
204 pub id: String,
206 pub file_path: String,
208 pub root_note: u8,
210 pub note_low: u8,
212 pub note_high: u8,
214 pub velocity_low: u8,
216 pub velocity_high: u8,
218 pub articulation: ArticulationType,
220 pub volume_db: f32,
222 pub tune_cents: f32,
224 pub release_file: Option<String>,
226}
227
228impl SampleZone {
229 pub fn matches(&self, note: u8, velocity: u8) -> bool {
231 note >= self.note_low && note <= self.note_high
232 && velocity >= self.velocity_low && velocity <= self.velocity_high
233 }
234
235 pub fn pitch_ratio(&self, target_note: u8, tuning: &TuningTable) -> f32 {
237 let root_freq = tuning.frequency(self.root_note);
238 let target_freq = tuning.frequency(target_note);
239 if root_freq > 0.0 {
240 let cents_offset = self.tune_cents;
241 (target_freq / root_freq) * 2.0f32.powf(cents_offset / 1200.0)
242 } else {
243 1.0
244 }
245 }
246
247 pub fn volume_linear(&self) -> f32 {
249 10.0f32.powf(self.volume_db / 20.0)
250 }
251}
252
253#[derive(Debug, Clone, Serialize, Deserialize)]
255#[serde(from = "SamplerInstrumentRaw")]
256pub struct SamplerInstrument {
257 pub name: String,
259 pub origin: String,
261 pub description: String,
263 pub author: String,
265 pub tuning: TuningTable,
267 #[serde(default)]
269 pub zones: Vec<SampleZone>,
270 #[serde(default)]
272 pub zone_groups: Vec<ZoneGroup>,
273 pub attack: f32,
275 pub decay: f32,
276 pub sustain: f32,
277 pub release: f32,
278 pub max_voices: usize,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284struct SamplerInstrumentRaw {
285 pub name: String,
286 pub origin: String,
287 pub description: String,
288 pub author: String,
289 pub tuning: TuningTable,
290 #[serde(default)]
291 pub zones: Vec<SampleZone>,
292 #[serde(default)]
293 pub zone_groups: Vec<ZoneGroup>,
294 pub attack: f32,
295 pub decay: f32,
296 pub sustain: f32,
297 pub release: f32,
298 pub max_voices: usize,
299}
300
301impl From<SamplerInstrumentRaw> for SamplerInstrument {
302 fn from(raw: SamplerInstrumentRaw) -> Self {
303 let mut instrument = SamplerInstrument {
304 name: raw.name,
305 origin: raw.origin,
306 description: raw.description,
307 author: raw.author,
308 tuning: raw.tuning,
309 zones: raw.zones,
310 zone_groups: raw.zone_groups,
311 attack: raw.attack,
312 decay: raw.decay,
313 sustain: raw.sustain,
314 release: raw.release,
315 max_voices: raw.max_voices,
316 };
317 instrument.normalize();
318 instrument
319 }
320}
321
322impl SamplerInstrument {
323 pub fn new(name: &str) -> Self {
324 Self {
325 name: name.into(),
326 origin: String::new(),
327 description: String::new(),
328 author: String::new(),
329 tuning: TuningTable::default(),
330 zones: Vec::new(),
331 zone_groups: Vec::new(),
332 attack: 0.005,
333 decay: 0.1,
334 sustain: 0.8,
335 release: 0.3,
336 max_voices: 16,
337 }
338 }
339
340 pub fn normalize(&mut self) {
343 if self.zone_groups.is_empty() && !self.zones.is_empty() {
344 self.zone_groups = self
345 .zones
346 .iter()
347 .cloned()
348 .map(|zone| ZoneGroup {
349 zones: vec![zone],
350 mode: RoundRobinMode::Sequential,
351 })
352 .collect();
353 }
354 }
355
356 pub fn find_zone(&self, note: u8, velocity: u8) -> Option<&SampleZone> {
359 let mut best: Option<&SampleZone> = None;
360 let mut best_dist = u8::MAX;
361 for zone in &self.zones {
362 if zone.matches(note, velocity) {
363 let dist = note.abs_diff(zone.root_note);
364 if dist < best_dist {
365 best_dist = dist;
366 best = Some(zone);
367 }
368 }
369 }
370 best
371 }
372
373 pub fn find_zone_rr<'a>(
376 &'a self,
377 note: u8,
378 velocity: u8,
379 rr_state: &mut RoundRobinState,
380 ) -> Option<&'a SampleZone> {
381 for (group_idx, group) in self.zone_groups.iter().enumerate() {
382 if group.zones.iter().any(|z| z.matches(note, velocity)) {
384 return rr_state.select_zone(group_idx, &group.zones, &group.mode);
385 }
386 }
387 None
388 }
389
390 pub fn add_zone(&mut self, zone: SampleZone) {
392 self.zones.push(zone);
393 }
394
395 pub fn save(&self, path: &Path) -> std::io::Result<()> {
397 let json = serde_json::to_string_pretty(self).unwrap();
398 std::fs::write(path, json)
399 }
400
401 pub fn load(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
403 let json = std::fs::read_to_string(path)?;
404 Ok(serde_json::from_str(&json)?)
405 }
406}
407
408pub struct LoadedInstrument {
410 pub instrument: SamplerInstrument,
411 pub buffers: HashMap<String, SampleBuffer>,
413 pub release_buffers: HashMap<String, SampleBuffer>,
415}
416
417impl LoadedInstrument {
418 pub fn load(instrument: SamplerInstrument, base_dir: &Path) -> Result<Self, Box<dyn std::error::Error>> {
420 let mut buffers = HashMap::new();
421 let mut release_buffers = HashMap::new();
422
423 for zone in &instrument.zones {
424 let path = base_dir.join(&zone.file_path);
425 let buf = SampleBuffer::load_wav(&path)?;
426 buffers.insert(zone.id.clone(), buf);
427
428 if let Some(ref rel_path) = zone.release_file {
429 let rpath = base_dir.join(rel_path);
430 if rpath.exists() {
431 let rbuf = SampleBuffer::load_wav(&rpath)?;
432 release_buffers.insert(zone.id.clone(), rbuf);
433 }
434 }
435 }
436
437 Ok(Self { instrument, buffers, release_buffers })
438 }
439}
440
// Unit and property tests for round-robin selection and legacy-instrument loading.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// `RoundRobinMode` derives `Clone` and `PartialEq`: a clone compares equal.
    #[test]
    fn test_round_robin_mode_derives() {
        let mode = RoundRobinMode::Sequential;
        let cloned = mode.clone();
        assert_eq!(mode, cloned);
    }

    /// A `ZoneGroup` can be built directly and exposes its zones and mode.
    #[test]
    fn test_zone_group_creation() {
        let zone = SampleZone {
            id: "test".into(),
            file_path: "test.wav".into(),
            root_note: 60,
            note_low: 60,
            note_high: 60,
            velocity_low: 0,
            velocity_high: 127,
            articulation: ArticulationType::OneShot,
            volume_db: 0.0,
            tune_cents: 0.0,
            release_file: None,
        };

        let group = ZoneGroup {
            zones: vec![zone],
            mode: RoundRobinMode::Sequential,
        };

        assert_eq!(group.zones.len(), 1);
        assert_eq!(group.mode, RoundRobinMode::Sequential);
    }

    /// `select_zone` in `Sequential` mode walks the zones in order and wraps around.
    #[test]
    fn test_round_robin_state_sequential() {
        let mut state = RoundRobinState::with_seed(12345);

        let zones = vec![
            SampleZone {
                id: "zone1".into(),
                file_path: "test1.wav".into(),
                root_note: 60,
                note_low: 60,
                note_high: 60,
                velocity_low: 0,
                velocity_high: 127,
                articulation: ArticulationType::OneShot,
                volume_db: 0.0,
                tune_cents: 0.0,
                release_file: None,
            },
            SampleZone {
                id: "zone2".into(),
                file_path: "test2.wav".into(),
                root_note: 60,
                note_low: 60,
                note_high: 60,
                velocity_low: 0,
                velocity_high: 127,
                articulation: ArticulationType::OneShot,
                volume_db: 0.0,
                tune_cents: 0.0,
                release_file: None,
            },
        ];

        let z1 = state.select_zone(0, &zones, &RoundRobinMode::Sequential);
        // First call on a fresh state starts at index 0.
        assert_eq!(z1.unwrap().id, "zone1");

        let z2 = state.select_zone(0, &zones, &RoundRobinMode::Sequential);
        assert_eq!(z2.unwrap().id, "zone2");

        // Third call wraps back to the first zone.
        let z3 = state.select_zone(0, &zones, &RoundRobinMode::Sequential);
        assert_eq!(z3.unwrap().id, "zone1");
    }

    /// `normalize` turns a flat zone list into one single-zone `Sequential` group per zone.
    #[test]
    fn test_sampler_instrument_normalize() {
        let mut instrument = SamplerInstrument::new("test");

        let zone = SampleZone {
            id: "zone1".into(),
            file_path: "test.wav".into(),
            root_note: 60,
            note_low: 60,
            note_high: 60,
            velocity_low: 0,
            velocity_high: 127,
            articulation: ArticulationType::OneShot,
            volume_db: 0.0,
            tune_cents: 0.0,
            release_file: None,
        };

        instrument.zones.push(zone);
        instrument.normalize();

        assert_eq!(instrument.zone_groups.len(), 1);
        assert_eq!(instrument.zone_groups[0].zones.len(), 1);
        assert_eq!(instrument.zone_groups[0].mode, RoundRobinMode::Sequential);
    }

    /// `find_zone_rr` alternates between the two zones of a `Sequential` group.
    /// The instrument here has groups only (its flat `zones` list is empty).
    #[test]
    fn test_find_zone_rr() {
        let mut instrument = SamplerInstrument::new("test");

        let zone1 = SampleZone {
            id: "zone1".into(),
            file_path: "test1.wav".into(),
            root_note: 60,
            note_low: 60,
            note_high: 60,
            velocity_low: 0,
            velocity_high: 127,
            articulation: ArticulationType::OneShot,
            volume_db: 0.0,
            tune_cents: 0.0,
            release_file: None,
        };

        let zone2 = SampleZone {
            id: "zone2".into(),
            file_path: "test2.wav".into(),
            root_note: 60,
            note_low: 60,
            note_high: 60,
            velocity_low: 0,
            velocity_high: 127,
            articulation: ArticulationType::OneShot,
            volume_db: 0.0,
            tune_cents: 0.0,
            release_file: None,
        };

        instrument.zone_groups.push(ZoneGroup {
            zones: vec![zone1, zone2],
            mode: RoundRobinMode::Sequential,
        });

        let mut rr_state = RoundRobinState::with_seed(12345);

        let z1 = instrument.find_zone_rr(60, 100, &mut rr_state);
        assert!(z1.is_some());
        assert_eq!(z1.unwrap().id, "zone1");

        let z2 = instrument.find_zone_rr(60, 100, &mut rr_state);
        assert!(z2.is_some());
        assert_eq!(z2.unwrap().id, "zone2");
    }

    proptest! {
        // Property: in `Sequential` mode, `n` consecutive `select` calls on a fresh
        // state return every index in `0..n` exactly once (one full cycle).
        #[test]
        fn prop_sequential_round_robin_full_cycle(
            n in 1usize..=16,
        ) {
            let mut zones = Vec::new();
            // Build `n` identical-range zones; only their count matters for `select`.
            for i in 0..n {
                zones.push(SampleZone {
                    id: format!("zone_{}", i),
                    file_path: format!("sample_{}.wav", i),
                    root_note: 60,
                    note_low: 60,
                    note_high: 60,
                    velocity_low: 0,
                    velocity_high: 127,
                    articulation: ArticulationType::OneShot,
                    volume_db: 0.0,
                    tune_cents: 0.0,
                    release_file: None,
                });
            }

            let group = ZoneGroup {
                zones,
                mode: RoundRobinMode::Sequential,
            };

            // The seed is irrelevant to Sequential mode but keeps construction explicit.
            let mut rr_state = RoundRobinState::with_seed(12345);

            let mut selected_indices = Vec::new();
            // Exactly one cycle's worth of selections.
            for _ in 0..n {
                let idx = rr_state.select(0, n, &group.mode);
                selected_indices.push(idx);
            }

            let mut sorted_indices = selected_indices.clone();
            // Sorting lets us compare against 0..n regardless of visit order.
            sorted_indices.sort_unstable();

            prop_assert_eq!(selected_indices.len(), n);

            let expected: Vec<usize> = (0..n).collect();
            // Sorted selections equal 0..n, i.e. a permutation of all indices.
            prop_assert_eq!(sorted_indices, expected);

            for i in 0..n {
                // Redundant with the permutation check, but gives a per-index message.
                let count = selected_indices.iter().filter(|&&x| x == i).count();
                prop_assert_eq!(count, 1, "Zone index {} should appear exactly once, but appeared {} times", i, count);
            }
        }
    }

    proptest! {
        // Property: for any seed and any group size >= 2, `RandomNoRepeat` never
        // returns the same index twice in a row, and always returns an index < n.
        #[test]
        fn prop_random_no_repeat_no_consecutive_repeats(
            n in 2usize..=16,
            seed in any::<u64>(),
        ) {
            let mut rr_state = RoundRobinState::with_seed(seed);

            let mut selected_indices = Vec::new();
            // 100 consecutive triggers of the same group (group index 0).
            for _ in 0..100 {
                let idx = rr_state.select(0, n, &RoundRobinMode::RandomNoRepeat);
                selected_indices.push(idx);
            }

            for i in 0..selected_indices.len() - 1 {
                // Compare each selection with its immediate successor.
                let current = selected_indices[i];
                let next = selected_indices[i + 1];
                prop_assert_ne!(
                    current,
                    next,
                    "RandomNoRepeat violated: zone {} was selected twice in a row at positions {} and {}",
                    current,
                    i,
                    i + 1
                );
            }

            for (i, &idx) in selected_indices.iter().enumerate() {
                // Every index must be usable to index a group of `n` zones.
                prop_assert!(
                    idx < n,
                    "Invalid zone index {} at position {} (should be < {})",
                    idx,
                    i,
                    n
                );
            }
        }
    }

    proptest! {
        // Property: a legacy JSON instrument (no `zone_groups` key) deserializes
        // with one single-zone Sequential group per zone, in order. For every
        // note and a few velocities, the legacy `find_zone` and the round-robin
        // `find_zone_rr` agree on whether *some* zone matches.
        #[test]
        fn prop_backward_compatible_instrument_loading(
            zone_count in 1usize..=10,
            note_ranges in prop::collection::vec((0u8..=127u8, 0u8..=127u8), 1..=10),
            velocity_ranges in prop::collection::vec((0u8..=127u8, 0u8..=127u8), 1..=10),
        ) {
            let mut zones = Vec::new();
            // Build zones from the generated ranges, cycling through them if
            // `zone_count` exceeds the number of generated ranges.
            for i in 0..zone_count {
                let (note_low, note_high_offset) = note_ranges[i % note_ranges.len()];
                // Note span is at most 11 semitones; saturating_add keeps it within u8.
                let note_high = note_low.saturating_add(note_high_offset % 12);
                let root_note = note_low + (note_high - note_low) / 2;

                let (vel_low, vel_high_offset) = velocity_ranges[i % velocity_ranges.len()];
                // May exceed 127; `matches` simply compares numerically.
                let vel_high = vel_low.saturating_add(vel_high_offset);

                zones.push(SampleZone {
                    id: format!("zone_{}", i),
                    file_path: format!("sample_{}.wav", i),
                    root_note,
                    note_low,
                    note_high,
                    velocity_low: vel_low,
                    velocity_high: vel_high,
                    articulation: ArticulationType::OneShot,
                    volume_db: 0.0,
                    tune_cents: 0.0,
                    release_file: None,
                });
            }

            // 12-TET frequencies for notes 0..128 with A4 (note 69) = 440 Hz.
            let frequencies: Vec<f32> = (0..128)
                .map(|n| 440.0f32 * 2.0f32.powf((n as f32 - 69.0) / 12.0))
                .collect();

            // Legacy document shape: `zones` only, no `zone_groups` key.
            let legacy_json = serde_json::json!({
                "name": "Legacy Instrument",
                "origin": "Test",
                "description": "Test legacy instrument",
                "author": "Test",
                "tuning": {
                    "name": "12-TET",
                    "description": "Standard 12-TET",
                    "frequencies": frequencies
                },
                "zones": zones,
                "attack": 0.005,
                "decay": 0.1,
                "sustain": 0.8,
                "release": 0.3,
                "max_voices": 16
            });

            // Deserialization runs `normalize` via the serde `from` conversion.
            let instrument: SamplerInstrument = serde_json::from_value(legacy_json)
                .expect("Failed to deserialize legacy instrument");

            prop_assert_eq!(instrument.zone_groups.len(), zones.len());
            // Each zone became its own Sequential group, preserving order.
            for (i, group) in instrument.zone_groups.iter().enumerate() {
                prop_assert_eq!(group.zones.len(), 1);
                prop_assert_eq!(&group.mode, &RoundRobinMode::Sequential);
                prop_assert_eq!(&group.zones[0].id, &format!("zone_{}", i));
            }

            for note in 0u8..=127 {
                // Sample the velocity axis at its low end, middle and top.
                // A fresh RR state per lookup keeps each check independent.
                for velocity in [1u8, 64, 127] {
                    let legacy_zone = instrument.find_zone(note, velocity);
                    let mut fresh_rr = RoundRobinState::with_seed(12345);
                    // Same note/velocity through the round-robin path.
                    let rr_zone = instrument.find_zone_rr(note, velocity, &mut fresh_rr);

                    prop_assert_eq!(legacy_zone.is_some(), rr_zone.is_some(),
                        // Both paths must agree on whether anything matches.
                        "note={} vel={}: legacy={:?} rr={:?}",
                        note, velocity,
                        legacy_zone.map(|z| &z.id),
                        rr_zone.map(|z| &z.id));

                    if let (Some(_legacy), Some(rr)) = (legacy_zone, rr_zone) {
                        // This test only checks presence and membership, not that
                        // both paths chose the same zone: the RR result must be one
                        // of the instrument's flat zones.
                        let rr_exists = instrument.zones.iter().any(|z| z.id == rr.id);
                        prop_assert!(rr_exists,
                            "note={} vel={}: rr returned zone '{}' not in instrument",
                            note, velocity, rr.id);
                    }
                }
            }
        }
    }
}