use super::types::AudioData;
use crate::transforms::Transform;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
pub struct AudioToTensor;
impl Transform<AudioData> for AudioToTensor {
type Output = Tensor<f32>;
fn transform(&self, input: AudioData) -> Result<Self::Output> {
let channels = input.channels;
let samples_per_channel = input.samples.len() / channels;
if channels == 1 {
Ok(Tensor::from_data(
input.samples,
vec![1, samples_per_channel],
torsh_core::device::DeviceType::Cpu,
)?)
} else {
let mut channel_data = vec![0.0f32; input.samples.len()];
for i in 0..samples_per_channel {
for c in 0..channels {
let src_idx = i * channels + c;
let dst_idx = c * samples_per_channel + i;
channel_data[dst_idx] = input.samples[src_idx];
}
}
Ok(Tensor::from_data(
channel_data,
vec![channels, samples_per_channel],
torsh_core::device::DeviceType::Cpu,
)?)
}
}
}
/// Reconstructs [`AudioData`] from a channel-major `(channels, samples)` tensor.
///
/// A tensor carries no sample-rate information, so the rate is supplied at
/// construction time and stamped onto every converted output.
pub struct TensorToAudio {
    sample_rate: u32,
}

impl TensorToAudio {
    /// Creates a converter that tags its outputs with `sample_rate` (Hz).
    pub fn new(sample_rate: u32) -> Self {
        TensorToAudio { sample_rate }
    }
}
impl Transform<Tensor<f32>> for TensorToAudio {
    type Output = AudioData;

    /// Re-interleaves a channel-major `(channels, samples)` tensor back into
    /// frame-major audio samples, tagged with this converter's sample rate.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidShape` for non-2D tensors and propagates
    /// any failure while reading the tensor's data.
    fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
        let shape = input.shape();
        if shape.ndim() != 2 {
            return Err(TorshError::InvalidShape(
                "Expected 2D tensor (channels, samples)".to_string(),
            ));
        }
        let dims = shape.dims();
        let channels = dims[0];
        let samples_per_channel = dims[1];
        let flat = input.to_vec()?;

        // Mono needs no reordering; multi-channel data becomes interleaved so
        // that each frame holds one sample per channel (LRLR… for stereo).
        let audio_samples = if channels == 1 {
            flat
        } else {
            let mut interleaved = vec![0.0f32; flat.len()];
            for frame in 0..samples_per_channel {
                for ch in 0..channels {
                    interleaved[frame * channels + ch] = flat[ch * samples_per_channel + frame];
                }
            }
            interleaved
        };

        Ok(AudioData::new(audio_samples, self.sample_rate, channels))
    }
}
pub mod transforms {
use super::*;
/// Resamples audio to a fixed target sample rate using linear interpolation.
pub struct Resample {
    target_sample_rate: u32,
}

impl Resample {
    /// Creates a resampler targeting `target_sample_rate` (Hz).
    pub fn new(target_sample_rate: u32) -> Self {
        Resample { target_sample_rate }
    }
}
impl Transform<AudioData> for Resample {
    type Output = AudioData;

    /// Linearly resamples `input` to the target sample rate.
    ///
    /// Interpolation is done frame-wise, per channel, so multi-channel
    /// (interleaved) audio does not blend samples from neighbouring channels
    /// — the previous implementation interpolated directly over the
    /// interleaved buffer, which mixed adjacent channels together. The output
    /// stays interleaved with the same channel count. Input already at the
    /// target rate is returned unchanged; empty input yields empty output.
    fn transform(&self, input: AudioData) -> Result<Self::Output> {
        if input.sample_rate == self.target_sample_rate {
            return Ok(input);
        }
        // Treat a zero channel count as mono to avoid division by zero.
        let channels = input.channels.max(1);
        let ratio = self.target_sample_rate as f32 / input.sample_rate as f32;
        let src_frames = input.samples.len() / channels;
        let dst_frames = (src_frames as f32 * ratio) as usize;
        let mut resampled = Vec::with_capacity(dst_frames * channels);
        for frame in 0..dst_frames {
            // Fractional position of this output frame in the source signal.
            let src_pos = frame as f32 / ratio;
            let lo = src_pos.floor() as usize;
            // Clamp the upper neighbour so the last frame interpolates with itself.
            let hi = (lo + 1).min(src_frames - 1);
            let fraction = src_pos - lo as f32;
            if lo < src_frames {
                for c in 0..channels {
                    // Blend the two nearest frames of channel `c` only.
                    let sample = input.samples[lo * channels + c] * (1.0 - fraction)
                        + input.samples[hi * channels + c] * fraction;
                    resampled.push(sample);
                }
            }
        }
        Ok(AudioData::new(
            resampled,
            self.target_sample_rate,
            input.channels,
        ))
    }
}
/// Pads or truncates audio to an exact per-channel length.
pub struct FixedLength {
    length: usize,
    pad_value: f32,
}

impl FixedLength {
    /// Creates a transform targeting `length` samples per channel,
    /// padding with silence (`0.0`) by default.
    pub fn new(length: usize) -> Self {
        FixedLength {
            length,
            pad_value: 0.0,
        }
    }

    /// Builder-style override of the value used for padding.
    pub fn with_pad_value(mut self, pad_value: f32) -> Self {
        self.pad_value = pad_value;
        self
    }
}
impl Transform<AudioData> for FixedLength {
    type Output = AudioData;

    /// Forces the interleaved buffer to exactly `length * channels` samples:
    /// longer input is truncated, shorter input is padded with `pad_value`.
    /// Sample rate and channel count pass through unchanged.
    fn transform(&self, input: AudioData) -> Result<Self::Output> {
        let target_total_length = self.length * input.channels;
        let mut samples = input.samples;
        // `Vec::resize` covers both directions: it truncates when the buffer
        // is too long and appends `pad_value` when it is too short.
        samples.resize(target_total_length, self.pad_value);
        Ok(AudioData::new(samples, input.sample_rate, input.channels))
    }
}
/// Rescales audio so its RMS level matches a fixed target.
pub struct Normalize {
    target_rms: f32,
}

impl Normalize {
    /// Creates a normalizer with the given target RMS level.
    pub fn new(target_rms: f32) -> Self {
        Normalize { target_rms }
    }
}
impl Transform<AudioData> for Normalize {
    type Output = AudioData;

    /// Applies a single gain factor to every sample so the signal's RMS level
    /// equals `self.target_rms`. Silent input (RMS of exactly 0.0) is
    /// returned untouched to avoid dividing by zero.
    fn transform(&self, input: AudioData) -> Result<Self::Output> {
        let mean_square =
            input.samples.iter().map(|&x| x * x).sum::<f32>() / input.samples.len() as f32;
        let rms = mean_square.sqrt();
        if rms == 0.0 {
            return Ok(input);
        }
        let gain = self.target_rms / rms;
        let scaled: Vec<f32> = input.samples.iter().map(|&x| x * gain).collect();
        Ok(AudioData::new(scaled, input.sample_rate, input.channels))
    }
}
/// Adds uniform random noise of a fixed amplitude to each sample.
pub struct AddNoise {
    noise_level: f32,
}

impl AddNoise {
    /// Creates the transform with the given noise amplitude.
    ///
    /// # Panics
    /// Panics when `noise_level` is negative.
    pub fn new(noise_level: f32) -> Self {
        assert!(noise_level >= 0.0, "Noise level must be non-negative");
        AddNoise { noise_level }
    }
}
impl Transform<AudioData> for AddNoise {
    type Output = AudioData;

    /// Adds uniform noise drawn from `[-noise_level, noise_level)` to every
    /// sample, leaving sample rate and channel count unchanged.
    ///
    /// NOTE(review): the RNG is seeded with the constant `42` on every call,
    /// so the same input always receives the exact same "noise" pattern —
    /// presumably for reproducibility, but confirm this is intentional; as a
    /// data-augmentation transform it never produces varied noise.
    fn transform(&self, input: AudioData) -> Result<Self::Output> {
        #[allow(unused_imports)]
        use scirs2_core::random::{Random, Rng};
        // Fixed seed: output is fully deterministic (see note above).
        let mut rng = Random::seed(42);
        let noisy_samples: Vec<f32> = input
            .samples
            .iter()
            .map(|&sample| {
                // Uniform draw in [-1, 1) scaled by the configured amplitude.
                let noise = rng.gen_range(-1.0..1.0) * self.noise_level;
                sample + noise
            })
            .collect();
        Ok(AudioData::new(
            noisy_samples,
            input.sample_rate,
            input.channels,
        ))
    }
}
}