#![allow(dead_code)]
use cpal::traits::{DeviceTrait, HostTrait};
use hound::{self, SampleFormat};
use lazy_static::lazy_static;
use std::f32::consts::PI;
use std::io::Cursor;
use std::path::Path;
use std::sync::Mutex;
use crate::missing_device_error::MissingDeviceError;
lazy_static! {
/// Audio host chosen by `set_host_and_audio_device`; `None` until a host has
/// been stored there.
pub static ref HOST: Mutex<Option<cpal::Host>> = Mutex::new(None);
/// Name of the device that `set_host_and_audio_device` expects to find on
/// [`HOST`]; empty until a name has been stored there.
pub static ref DEVICE_NAME: Mutex<String> = Mutex::new(String::new());
}
pub fn set_host_and_audio_device() -> Result<(), MissingDeviceError> {
#[cfg(target_os = "windows")]
{
let host =
cpal::host_from_id(cpal::HostId::Asio).map_err(|e| MissingDeviceError::from(e))?; *HOST.lock().unwrap() = Some(host);
*DEVICE_NAME.lock().unwrap() = "Focusrite USB ASIO".to_string();
}
#[cfg(target_os = "linux")]
{
let host = cpal::default_host();
*HOST.lock().unwrap() = Some(host);
*DEVICE_NAME.lock().unwrap() = "hw:CARD=USB,DEV=0".to_string();
}
let device_exists = match (*HOST)
.lock()
.unwrap()
.as_ref()
.unwrap()
.devices()
.map_err(|_| "failed to get devices".to_string())
{
Ok(mut devices) => devices.any(|d| {
d.name()
.map_or(false, |name: String| name == *DEVICE_NAME.lock().unwrap())
}),
Err(_) => false,
};
if !device_exists {
return Err(MissingDeviceError::Error("Device not found".to_string()));
}
Ok(())
}
/// Generates a full-scale sine tone as 32-bit integer PCM samples.
///
/// * `frequency` – tone frequency in Hz.
/// * `duration` – length in seconds; zero or negative durations yield an empty
///   vector (the float-to-`usize` cast saturates at 0).
/// * `fs` – sample rate in Hz.
///
/// Returns `(fs * duration) as usize` samples of `sin(2π · frequency · n / fs)`
/// scaled to the `i32` range.
pub fn generate_sine_wave(frequency: u32, duration: f32, fs: u32) -> Vec<i32> {
    let num_samples = (fs as f32 * duration) as usize;
    let fs_wide = u64::from(fs);
    let freq_wide = u64::from(frequency);
    (0..num_samples)
        .map(|n| {
            // sin(2π·f·n/fs) depends only on (f·n) mod fs. Reducing in u64
            // integer arithmetic first means the product cannot overflow
            // ((fs-1)·f < 2^64), and the argument passed to `sin` stays within
            // one period, so long signals keep full phase accuracy.
            let cycle_position = ((n as u64 % fs_wide) * freq_wide) % fs_wide;
            let phase = (2 * cycle_position) as f32 * PI / fs as f32;
            // `as i32` saturates, so a sine value of 1.0 maps to i32::MAX.
            (phase.sin() * i32::MAX as f32) as i32
        })
        .collect()
}
/// Returns `duration_seconds` worth of white noise at sample rate `fs`.
///
/// The noise comes from the WAV file embedded at compile time from
/// `../assets/full_spectrum_white_noise.wav`. The first
/// `(duration_seconds * fs) as usize` samples are returned, or fewer if the
/// file is shorter. Zero or negative durations return an empty vector.
///
/// `_scalar` is currently ignored.
///
/// # Panics
///
/// Panics if the embedded file cannot be decoded or if its sample rate differs
/// from `fs`.
pub fn generate_gaussian_white_noise(
    duration_seconds: f32,
    fs: u32,
    _scalar: Option<f32>,
) -> Vec<i32> {
    let wav_file_contents: &[u8] = include_bytes!("../assets/full_spectrum_white_noise.wav");
    // Decode straight from the embedded static bytes; copying them into a
    // temporary Vec first is unnecessary.
    let mut white_noise = read_wave_file_data(Cursor::new(wav_file_contents), fs)
        .expect("embedded white-noise WAV asset could not be decoded");
    // `as usize` saturates, so negative durations keep 0 samples.
    white_noise.truncate((duration_seconds * fs as f32) as usize);
    // Release the unused tail of the full-file buffer.
    white_noise.shrink_to_fit();
    white_noise
}
/// Prints every input and output device of the configured [`HOST`] together
/// with its default channel count.
///
/// A device whose name or default configuration cannot be queried is reported
/// inline, and the listing continues with the next device.
///
/// # Errors
///
/// Returns an error if [`HOST`] holds no host, or if the host fails to
/// enumerate its input or output devices.
pub(crate) fn print_devices() -> Result<(), Box<dyn std::error::Error>> {
    let binding = HOST.lock().unwrap();
    let host = binding.as_ref().ok_or("Host not initialized")?;

    println!("Input devices:");
    for device in host.input_devices()? {
        let name = device
            .name()
            .unwrap_or_else(|e| format!("<unknown: {}>", e));
        match device.default_input_config() {
            Ok(config) => println!(
                "Device: {}, input channels: {}",
                name,
                config.channels()
            ),
            Err(e) => println!("Device: {}, input channels: <unavailable: {}>", name, e),
        }
    }

    println!("Output devices:");
    for device in host.output_devices()? {
        let name = device
            .name()
            .unwrap_or_else(|e| format!("<unknown: {}>", e));
        match device.default_output_config() {
            Ok(config) => println!(
                "Device: {}, output channels: {}",
                name,
                config.channels()
            ),
            Err(e) => println!("Device: {}, output channels: <unavailable: {}>", name, e),
        }
    }

    Ok(())
}
/// Places `signal` on channel `playback_index` of an `output_channels`-channel
/// layout and fills every other channel with silence of the same length.
///
/// Returns one `Vec` per channel. Returns an empty `Vec` when `playback_index`
/// is not a valid channel index (`playback_index >= output_channels`).
pub fn format_signal_for_multichannel(
    signal: Vec<i32>,
    playback_index: usize,
    output_channels: usize,
) -> Vec<Vec<i32>> {
    if playback_index < output_channels {
        let frame_count = signal.len();
        // Moved into the target channel exactly once; `take` leaves `None` behind.
        let mut pending = Some(signal);
        (0..output_channels)
            .map(|channel| {
                if channel == playback_index {
                    pending.take().unwrap_or_default()
                } else {
                    vec![0; frame_count]
                }
            })
            .collect()
    } else {
        Vec::new()
    }
}
/// Writes `data` to `filename` as a mono, 32-bit signed-integer PCM WAV file
/// at `sample_rate` Hz.
///
/// # Errors
///
/// Returns an error if the file cannot be created, if a sample cannot be
/// written, or if finalizing the file fails.
pub fn save_to_wav(data: &Vec<i32>, filename: &str, sample_rate: u32) -> Result<(), anyhow::Error> {
    let mut writer = hound::WavWriter::create(
        filename,
        hound::WavSpec {
            channels: 1,
            sample_rate,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Int,
        },
    )?;
    data.iter()
        .try_for_each(|&sample| writer.write_sample(sample))?;
    writer.finalize()?;
    Ok(())
}
/// Decodes an in-memory WAV file into `i32` samples.
///
/// Wraps `byte_data` in a [`Cursor`] and delegates to `read_wave_file_data`
/// with the expected sample rate `fs`.
/// NOTE(review): the `_dart` suffix suggests this is the entry point used by
/// Dart bindings — confirm.
pub fn read_wave_file_dart(byte_data: Vec<u8>, fs: u32) -> Result<Vec<i32>, hound::Error> {
    read_wave_file_data(Cursor::new(byte_data), fs)
}
fn read_wave_file_data<R: std::io::Read + std::io::Seek>(
reader: R,
fs: u32,
) -> Result<Vec<i32>, hound::Error> {
let mut reader = hound::WavReader::new(reader)?;
let spec = reader.spec();
assert_eq!(spec.sample_rate, fs, "Sample rate of WAV file does not match the sample rate of the audio interface.\n\tWAV file sample rate: {}\n\tAudio interface sample rate: {}", spec.sample_rate, fs);
let samples: Vec<i32> = match (spec.sample_format, spec.bits_per_sample) {
(SampleFormat::Int, _) => reader.samples::<i32>().collect::<Result<Vec<_>, _>>()?,
(SampleFormat::Float, 32) => {
let float_samples: Vec<f32> = reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?;
float_samples
.into_iter()
.map(|sample| (sample * std::i32::MAX as f32) as i32)
.collect()
}
_ => return Err(hound::Error::Unsupported),
};
Ok(samples)
}
pub fn read_wave_file(filepath: &Path, fs: u32) -> Result<Vec<i32>, hound::Error> {
let mut reader = hound::WavReader::open(filepath)?;
let spec = reader.spec();
assert_eq!(spec.sample_rate, fs, "Sample rate of WAV file does not match the sample rate of the audio interface.\n\tWAV file sample rate: {}\n\tAudio interface sample rate: {}", spec.sample_rate, fs);
let samples: Vec<i32> = match (spec.sample_format, spec.bits_per_sample) {
(SampleFormat::Int, _) => reader.samples::<i32>().collect::<Result<Vec<_>, _>>()?,
(SampleFormat::Float, 32) => {
let float_samples: Vec<f32> = reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?;
float_samples
.into_iter()
.map(|sample| (sample * std::i32::MAX as f32) as i32)
.collect()
}
_ => return Err(hound::Error::Unsupported),
};
Ok(samples)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate shared by every signal-generation test.
    const FS: u32 = 48_000;
    /// Tone frequency used by the sine-wave tests.
    const TONE_HZ: u32 = 1_000;

    /// Expected sample count for `duration` seconds at [`FS`], computed with
    /// the same float-to-`usize` conversion the generators use.
    fn expected_len(duration: f32) -> usize {
        (FS as f32 * duration) as usize
    }

    #[test]
    fn test_set_host_and_audio_device() {
        // Hardware-dependent: the result is deliberately ignored, so this only
        // checks that the call returns without panicking.
        let _ = set_host_and_audio_device();
    }

    #[test]
    fn test_generate_sine_wave() {
        assert_eq!(generate_sine_wave(TONE_HZ, 1.0, FS).len(), expected_len(1.0));
    }

    #[test]
    fn test_zero_duration_generate_sine_wave() {
        assert!(generate_sine_wave(TONE_HZ, 0.0, FS).is_empty());
    }

    #[test]
    fn test_fractional_duration_generate_sine_wave() {
        assert_eq!(generate_sine_wave(TONE_HZ, 0.5, FS).len(), expected_len(0.5));
    }

    #[test]
    fn test_negative_duration_generate_sine_wave() {
        // Negative durations saturate to zero samples.
        assert!(generate_sine_wave(TONE_HZ, -1.0, FS).is_empty());
    }

    #[test]
    fn test_generate_gaussian_white_noise() {
        assert_eq!(
            generate_gaussian_white_noise(1.0, FS, None).len(),
            expected_len(1.0)
        );
    }

    #[test]
    fn test_zero_duration_generate_gaussian_white_noise() {
        assert!(generate_gaussian_white_noise(0.0, FS, None).is_empty());
    }

    #[test]
    fn test_fractional_duration_generate_gaussian_white_noise() {
        assert_eq!(
            generate_gaussian_white_noise(0.5, FS, None).len(),
            expected_len(0.5)
        );
    }

    #[test]
    fn test_negative_duration_generate_gaussian_white_noise() {
        assert!(generate_gaussian_white_noise(-1.0, FS, None).is_empty());
    }

    #[test]
    fn test_format_signal_for_multichannel() {
        let signal = vec![1, 2, 3, 4, 5];
        let (playback_index, output_channels) = (1, 3);
        let channels =
            format_signal_for_multichannel(signal.clone(), playback_index, output_channels);
        assert_eq!(channels.len(), output_channels);
        assert_eq!(channels[playback_index], signal);
    }

    #[test]
    fn test_higher_playback_index_than_number_of_channels() {
        let channels = format_signal_for_multichannel(vec![1, 2, 3, 4, 5], 4, 3);
        assert!(channels.is_empty());
    }
}