use std::collections::HashMap;
use std::error::Error;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::audio::sample_source::{create_sample_source_from_file, MemorySampleSource};
use crate::config::samples::SampleDefinition;
/// Buffer size handed to `create_sample_source_from_file` when a sample file is opened for loading.
const DEFAULT_BUFFER_SIZE: usize = 4096;
/// A sample held fully in memory behind a shared `Arc`.
///
/// Cloning copies the `Arc` handle and the two scalar fields; the sample buffer itself is not
/// duplicated.
#[derive(Clone)]
pub struct LoadedSample {
/// Shared `f32` sample buffer. The resampler in this file indexes it as
/// `frame * channels + channel`, i.e. it treats the buffer as interleaved by channel.
data: Arc<Vec<f32>>,
/// Channel count reported by the source the sample was read from.
channel_count: u16,
/// Sample rate of `data`, after any resampling performed by the loader.
sample_rate: u32,
}
impl LoadedSample {
pub fn create_source(&self, volume: f32) -> MemorySampleSource {
MemorySampleSource::from_shared(
self.data.clone(),
self.channel_count,
self.sample_rate,
volume,
)
}
pub fn channel_count(&self) -> u16 {
self.channel_count
}
pub fn memory_size(&self) -> usize {
self.data.len() * std::mem::size_of::<f32>()
}
}
/// Loads samples into memory, resampling them to one target rate and caching them by path.
pub struct SampleLoader {
/// Samples loaded so far, keyed by the exact path passed to `load`.
cache: HashMap<PathBuf, LoadedSample>,
/// Sample rate that loaded samples are converted to when their source rate differs.
target_sample_rate: u32,
}
impl SampleLoader {
    /// Creates an empty loader that converts every loaded sample to `target_sample_rate`.
    pub fn new(target_sample_rate: u32) -> Self {
        Self {
            cache: HashMap::new(),
            target_sample_rate,
        }
    }

    /// Loads the sample at `path` fully into memory.
    ///
    /// When the source's sample rate differs from the loader's target rate the data is
    /// resampled with linear interpolation. Results are cached by `path`: a later call with
    /// the same path returns a clone of the cached [`LoadedSample`] without opening the file.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// - the sample source cannot be created for `path`,
    /// - the source reports zero channels or a zero sample rate, or the loader's target
    ///   rate is zero (these previously caused a panic),
    /// - reading a sample from the source fails, or
    /// - resampling fails.
    pub fn load(&mut self, path: &Path) -> Result<LoadedSample, Box<dyn Error>> {
        if let Some(sample) = self.cache.get(path) {
            debug!(path = ?path, "Using cached sample");
            return Ok(sample.clone());
        }

        info!(path = ?path, "Loading sample into memory");
        let mut source = create_sample_source_from_file(path, None, DEFAULT_BUFFER_SIZE).map_err(
            |e| -> Box<dyn std::error::Error> {
                format!("Failed to load sample {}: {}", path.display(), e).into()
            },
        )?;

        let source_sample_rate = source.sample_rate();
        let channel_count = source.channel_count();

        // A zero channel count or zero rate would divide by zero below (integer division in
        // the resampler, NaN/infinity passed to `Duration::from_secs_f64`), so reject such
        // input up front instead of panicking.
        if channel_count == 0 || source_sample_rate == 0 || self.target_sample_rate == 0 {
            return Err(format!(
                "Failed to load sample {}: invalid audio format \
                 (channels: {}, source sample rate: {}, target sample rate: {})",
                path.display(),
                channel_count,
                source_sample_rate,
                self.target_sample_rate
            )
            .into());
        }

        // Drain the whole source into memory.
        let mut samples = Vec::new();
        while let Some(sample) = source.next_sample()? {
            samples.push(sample);
        }

        let (final_samples, final_sample_rate) = if source_sample_rate != self.target_sample_rate {
            info!(
                source_rate = source_sample_rate,
                target_rate = self.target_sample_rate,
                "Transcoding sample"
            );
            let transcoded = self.transcode_samples(
                &samples,
                channel_count,
                source_sample_rate,
                self.target_sample_rate,
            )?;
            (transcoded, self.target_sample_rate)
        } else {
            (samples, source_sample_rate)
        };

        // Both divisors were validated as non-zero above, so the result is finite and
        // non-negative and `from_secs_f64` cannot panic on it.
        let total_samples = final_samples.len();
        let samples_per_channel = total_samples as f64 / f64::from(channel_count);
        let duration_secs = samples_per_channel / f64::from(final_sample_rate);
        let duration = Duration::from_secs_f64(duration_secs);

        let loaded = LoadedSample {
            data: Arc::new(final_samples),
            channel_count,
            sample_rate: final_sample_rate,
        };
        info!(
            path = ?path,
            channels = channel_count,
            sample_rate = final_sample_rate,
            duration_ms = duration.as_millis(),
            memory_kb = loaded.memory_size() / 1024,
            "Sample loaded"
        );

        self.cache.insert(path.to_path_buf(), loaded.clone());
        Ok(loaded)
    }

    /// Loads every file referenced by `definition` and returns them keyed by resolved path.
    ///
    /// Relative file names are resolved against `base_path`; absolute ones are used as-is.
    ///
    /// # Errors
    ///
    /// Stops at the first file that fails to load, logs a warning, and returns an error
    /// naming that file.
    pub fn load_definition(
        &mut self,
        definition: &SampleDefinition,
        base_path: &Path,
    ) -> Result<HashMap<PathBuf, LoadedSample>, Box<dyn Error>> {
        let mut loaded = HashMap::new();
        for file in definition.all_files() {
            // `Path::join` returns its argument unchanged when that argument is absolute,
            // so a single call handles both relative and absolute entries.
            let full_path = base_path.join(file);
            match self.load(&full_path) {
                Ok(sample) => {
                    loaded.insert(full_path, sample);
                }
                Err(e) => {
                    warn!(path = ?full_path, error = ?e, "Failed to load sample");
                    return Err(
                        format!("Failed to load sample {}: {}", full_path.display(), e).into(),
                    );
                }
            }
        }
        Ok(loaded)
    }

    /// Returns the combined size in bytes of all cached sample buffers.
    pub fn total_memory_usage(&self) -> usize {
        self.cache.values().map(|s| s.memory_size()).sum()
    }

    /// Resamples interleaved `samples` from `source_rate` to `target_rate` using linear
    /// interpolation between neighbouring frames of the same channel.
    ///
    /// The output holds `ceil(source_frames * target_rate / source_rate)` frames. Past the
    /// last source frame the interpolation partner falls back to the current sample, so the
    /// tail is held rather than faded to zero. A `target_rate` of zero yields an empty
    /// output.
    ///
    /// # Errors
    ///
    /// Returns an error when `channel_count` or `source_rate` is zero; both would otherwise
    /// lead to a division by zero (and a panic).
    fn transcode_samples(
        &self,
        samples: &[f32],
        channel_count: u16,
        source_rate: u32,
        target_rate: u32,
    ) -> Result<Vec<f32>, Box<dyn Error>> {
        if channel_count == 0 {
            return Err("Cannot resample audio with zero channels".into());
        }
        if source_rate == 0 {
            return Err("Cannot resample audio with a source sample rate of zero".into());
        }

        let channels = usize::from(channel_count);
        let ratio = f64::from(target_rate) / f64::from(source_rate);
        let source_frames = samples.len() / channels;
        let target_frames = (source_frames as f64 * ratio).ceil() as usize;

        let mut output = Vec::with_capacity(target_frames * channels);
        for target_frame in 0..target_frames {
            // Position of this output frame on the source timeline, split into the frame
            // index and the fractional distance towards the next frame.
            let source_pos = target_frame as f64 / ratio;
            let source_frame = source_pos.floor() as usize;
            let frac = source_pos.fract() as f32;

            for channel in 0..channels {
                let idx0 = source_frame * channels + channel;
                let idx1 = (source_frame + 1) * channels + channel;
                let s0 = samples.get(idx0).copied().unwrap_or(0.0);
                let s1 = samples.get(idx1).copied().unwrap_or(s0);
                output.push(s0 + (s1 - s0) * frac);
            }
        }
        Ok(output)
    }
}
impl std::fmt::Debug for SampleLoader {
    /// Formats a summary of the loader: number of cached samples, target sample rate and
    /// total cached memory in KiB. The sample data itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cached_samples = self.cache.len();
        let total_memory_kb = self.total_memory_usage() / 1024;

        let mut summary = f.debug_struct("SampleLoader");
        summary.field("cached_samples", &cached_samples);
        summary.field("target_sample_rate", &self.target_sample_rate);
        summary.field("total_memory_kb", &total_memory_kb);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::samples::{
        ReleaseBehavior, RetriggerBehavior, VelocityConfig, VelocityLayer,
    };

    /// Path of a fixture file inside the crate's `assets` directory.
    fn asset(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(format!("assets/{}", name))
    }

    /// The crate's `assets` directory, used as base path for sample definitions.
    fn assets_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets")
    }

    // Upsampling a mono sine wave produces the expected number of frames.
    #[test]
    fn test_transcode_samples() {
        let (source_rate, target_rate) = (44100, 48000);
        let loader = SampleLoader::new(48000);

        let omega = 2.0 * std::f32::consts::PI * 440.0;
        let mut source_samples: Vec<f32> = Vec::with_capacity(4410);
        for i in 0..4410 {
            source_samples.push((omega * i as f32 / source_rate as f32).sin());
        }

        let result = loader
            .transcode_samples(&source_samples, 1, source_rate, target_rate)
            .unwrap();

        let expected_len = (4410.0_f64 * 48000.0 / 44100.0).ceil() as usize;
        assert_eq!(result.len(), expected_len);
    }

    // The first stereo frame survives upsampling nearly unchanged.
    #[test]
    fn test_transcode_stereo() {
        let loader = SampleLoader::new(48000);
        let source_samples: Vec<f32> = [1.0f32, -1.0].repeat(4);

        let result = loader
            .transcode_samples(&source_samples, 2, 44100, 48000)
            .unwrap();

        assert!(result.len() >= 8);
        let (left, right) = (result[0], result[1]);
        assert!((left - 1.0).abs() < 0.1);
        assert!((right + 1.0).abs() < 0.1);
    }

    // Equal source and target rates reproduce the input.
    #[test]
    fn test_transcode_same_rate() {
        let loader = SampleLoader::new(44100);
        let source = vec![0.5f32, -0.5, 0.25, -0.25];

        let result = loader.transcode_samples(&source, 1, 44100, 44100).unwrap();

        assert_eq!(result.len(), source.len());
        assert!(result
            .iter()
            .zip(source.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6));
    }

    // Downsampling shortens the output to the expected frame count.
    #[test]
    fn test_transcode_downsample() {
        let loader = SampleLoader::new(22050);
        let mut source: Vec<f32> = Vec::with_capacity(4800);
        for i in 0..4800 {
            source.push((i as f32) / 4800.0);
        }

        let result = loader.transcode_samples(&source, 1, 48000, 22050).unwrap();

        assert!(result.len() < source.len());
        let expected_len = (4800.0_f64 * 22050.0 / 48000.0).ceil() as usize;
        assert_eq!(result.len(), expected_len);
    }

    // Empty input yields empty output.
    #[test]
    fn test_transcode_empty_input() {
        let loader = SampleLoader::new(48000);
        let result = loader.transcode_samples(&[], 1, 44100, 48000).unwrap();
        assert!(result.is_empty());
    }

    // A single frame is carried over to the start of the output.
    #[test]
    fn test_transcode_single_frame() {
        let loader = SampleLoader::new(48000);
        let source = vec![0.75f32];

        let result = loader.transcode_samples(&source, 1, 44100, 48000).unwrap();

        assert!(!result.is_empty());
        let first = result[0];
        assert!((first - 0.75).abs() < 0.01);
    }

    // `memory_size` reports bytes, `channel_count` echoes the field.
    #[test]
    fn test_loaded_sample_memory_size() {
        let sample = LoadedSample {
            data: Arc::new(vec![1.0f32; 1000]),
            channel_count: 2,
            sample_rate: 44100,
        };

        assert_eq!(sample.memory_size(), 1000 * std::mem::size_of::<f32>());
        assert_eq!(sample.channel_count(), 2);
    }

    // Several sources can be created from one loaded sample.
    #[test]
    fn test_loaded_sample_create_source() {
        let sample = LoadedSample {
            data: Arc::new(vec![0.5f32; 100]),
            channel_count: 1,
            sample_rate: 44100,
        };

        let _source1 = sample.create_source(1.0);
        let _source2 = sample.create_source(0.5);
    }

    // A fresh loader holds no cached memory.
    #[test]
    fn test_sample_loader_new() {
        let loader = SampleLoader::new(48000);
        assert_eq!(loader.total_memory_usage(), 0);
    }

    // The Debug output names the type and its summary fields.
    #[test]
    fn test_sample_loader_debug() {
        let loader = SampleLoader::new(44100);
        let debug_str = format!("{:?}", loader);

        let expected_parts = ["SampleLoader", "cached_samples", "target_sample_rate"];
        for part in expected_parts.iter() {
            assert!(debug_str.contains(*part));
        }
    }

    // Loading a mono WAV file populates the cache.
    #[test]
    fn test_load_wav_file() {
        let mut loader = SampleLoader::new(44100);
        let path = asset("1Channel44.1k.wav");

        let sample = loader.load(&path).unwrap();

        assert_eq!(sample.channel_count(), 1);
        assert!(sample.memory_size() > 0);
        assert!(loader.total_memory_usage() > 0);
    }

    // Loading the same path twice keeps a single cache entry.
    #[test]
    fn test_load_caches_sample() {
        let mut loader = SampleLoader::new(44100);
        let path = asset("1Channel44.1k.wav");

        let first = loader.load(&path).unwrap();
        let second = loader.load(&path).unwrap();

        assert_eq!(first.memory_size(), second.memory_size());
        assert_eq!(loader.total_memory_usage(), first.memory_size());
    }

    // A stereo file reports two channels.
    #[test]
    fn test_load_stereo_file() {
        let mut loader = SampleLoader::new(44100);
        let sample = loader.load(&asset("2Channel44.1k.wav")).unwrap();
        assert_eq!(sample.channel_count(), 2);
    }

    // A file at a different rate is loaded through the resampling path.
    #[test]
    fn test_load_with_transcoding() {
        let mut loader = SampleLoader::new(44100);
        let sample = loader.load(&asset("1Channel22.05k.wav")).unwrap();

        assert_eq!(sample.channel_count(), 1);
        assert!(sample.memory_size() > 0);
    }

    // A missing file produces a descriptive error.
    #[test]
    fn test_load_nonexistent_file() {
        let mut loader = SampleLoader::new(44100);
        let path = PathBuf::from("/nonexistent/file.wav");

        let result = loader.load(&path);

        assert!(result.is_err());
        let message = result.err().unwrap().to_string();
        assert!(message.contains("Failed to load sample"));
    }

    // A definition with one file loads exactly that file, keyed by its resolved path.
    #[test]
    fn test_load_definition_single_file() {
        let mut loader = SampleLoader::new(44100);
        let base_path = assets_dir();
        let definition = SampleDefinition::new(
            Some(String::from("1Channel44.1k.wav")),
            vec![1],
            VelocityConfig::ignore(None),
            ReleaseBehavior::PlayToCompletion,
            RetriggerBehavior::Cut,
            None,
            50,
        );

        let loaded = loader.load_definition(&definition, &base_path).unwrap();

        assert_eq!(loaded.len(), 1);
        let expected_key = base_path.join("1Channel44.1k.wav");
        assert!(loaded.contains_key(&expected_key));
    }

    // A definition referencing a missing file fails as a whole.
    #[test]
    fn test_load_definition_missing_file() {
        let mut loader = SampleLoader::new(44100);
        let base_path = assets_dir();
        let definition = SampleDefinition::new(
            Some(String::from("nonexistent.wav")),
            vec![1],
            VelocityConfig::ignore(None),
            ReleaseBehavior::PlayToCompletion,
            RetriggerBehavior::Cut,
            None,
            50,
        );

        let result = loader.load_definition(&definition, &base_path);

        assert!(result.is_err());
    }

    // A definition with two velocity layers loads both layer files.
    #[test]
    fn test_load_definition_with_layers() {
        let mut loader = SampleLoader::new(44100);
        let base_path = assets_dir();
        let layers = vec![
            VelocityLayer::new([1, 80], String::from("1Channel44.1k.wav")),
            VelocityLayer::new([81, 127], String::from("2Channel44.1k.wav")),
        ];
        let definition = SampleDefinition::new(
            None,
            vec![1],
            VelocityConfig::with_layers(layers, false),
            ReleaseBehavior::PlayToCompletion,
            RetriggerBehavior::Cut,
            None,
            50,
        );

        let loaded = loader.load_definition(&definition, &base_path).unwrap();

        assert_eq!(loaded.len(), 2);
    }

    // Resampling never mixes the left and right channels of a stereo signal.
    #[test]
    fn test_transcode_preserves_stereo_channels() {
        let loader = SampleLoader::new(48000);
        let frames = 100;

        // Left channel ramps upwards from zero, right channel mirrors it downwards.
        let mut source = Vec::with_capacity(frames * 2);
        for i in 0..frames {
            let level = i as f32 / frames as f32;
            source.extend_from_slice(&[level, -level]);
        }

        let result = loader.transcode_samples(&source, 2, 44100, 48000).unwrap();

        assert_eq!(result.len() % 2, 0);
        for pair in result.chunks(2) {
            let (left, right) = (pair[0], pair[1]);
            assert!(left >= -0.01, "L channel should be non-negative");
            assert!(right <= 0.01, "R channel should be non-positive");
        }
    }
}