use std::fs::File;
use std::io::{BufReader, Cursor};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use crossbeam_channel::{Receiver, Sender};
use rodio::{Decoder, Source};

use crate::{Result, VoiceError};
/// The four ambient-bed layers mixed under the voice output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StemRole {
    Pad,
    Texture,
    Harmonic,
    Percussive,
}

impl StemRole {
    /// Every role, in the fixed order used for the 4-slot stem/gain arrays.
    pub const ALL: [StemRole; 4] = [
        Self::Pad,
        Self::Texture,
        Self::Harmonic,
        Self::Percussive,
    ];

    /// Asset file name for this stem inside a theme's `stems/` directory.
    pub const fn file_name(self) -> &'static str {
        match self {
            Self::Pad => "pad.mp3",
            Self::Texture => "texture.mp3",
            Self::Harmonic => "harmonic.mp3",
            Self::Percussive => "percussive.mp3",
        }
    }

    /// Slot of this role inside the fixed 4-element stem/gain arrays.
    /// Relies on the declaration-order discriminants (Pad = 0 … Percussive = 3).
    fn index(self) -> usize {
        self as usize
    }
}
/// An `f32` stored as its IEEE-754 bit pattern in an `AtomicU32`, so gain
/// values can be shared across threads without locks.
///
/// All accesses use `Ordering::Relaxed`; each value is independent, so no
/// cross-variable ordering is required.
#[derive(Default)]
pub struct AtomicF32(AtomicU32);

impl AtomicF32 {
    /// Wraps `v` as an atomic cell.
    pub fn new(v: f32) -> Self {
        AtomicF32(AtomicU32::new(f32::to_bits(v)))
    }

    /// Reads the current value.
    pub fn load(&self) -> f32 {
        f32::from_bits(self.0.load(Ordering::Relaxed))
    }

    /// Replaces the current value.
    pub fn store(&self, v: f32) {
        self.0.store(f32::to_bits(v), Ordering::Relaxed);
    }
}
/// Cloneable control-plane handle to a `VoiceAudioMixer`.
///
/// Gain changes are communicated through lock-free atomics shared with the
/// audio side; TTS audio and stem swaps travel over unbounded channels that
/// the mixer drains in `pull`.
#[derive(Clone)]
pub struct VoiceMixerHandle {
    // Per-stem gains, indexed by `StemRole::index()`.
    stem_gains: Arc<[AtomicF32; 4]>,
    // Gain applied to the final mixed sample.
    master_gain: Arc<AtomicF32>,
    // Decoded, format-converted TTS clips queued for playback.
    tts_tx: Sender<Vec<f32>>,
    // Replacement stem buffers in [pad, texture, harmonic, percussive] order.
    stem_swap_tx: Sender<[Vec<f32>; 4]>,
    // True while the mixer is playing a TTS clip.
    speaking_now: Arc<std::sync::atomic::AtomicBool>,
    // Set to request that the mixer drop active and queued TTS audio.
    stop_tts_flag: Arc<std::sync::atomic::AtomicBool>,
    // Output sample rate the mixer renders at.
    target_rate: u32,
    // Output channel count the mixer renders at.
    target_channels: u16,
}
impl VoiceMixerHandle {
    /// Sets the gain for one stem layer.
    pub fn set_stem_gain(&self, role: StemRole, gain: f32) {
        self.stem_gains[role.index()].store(gain);
    }

    /// Sets the gain applied to the final mix.
    pub fn set_master_gain(&self, gain: f32) {
        self.master_gain.store(gain);
    }

    /// Decodes `bytes` as an audio clip, converts it to the mixer's output
    /// format, and queues it for playback. Failures are logged, not returned.
    pub fn queue_tts(&self, bytes: Vec<u8>) {
        let decoder = match Decoder::new(Cursor::new(bytes)) {
            Ok(decoder) => decoder,
            Err(e) => {
                tracing::warn!("[mixer] queue_tts decode failed: {e}");
                return;
            }
        };
        let pcm = decode_source_to_interleaved(decoder, self.target_rate, self.target_channels);
        if pcm.is_empty() {
            tracing::warn!("[mixer] queue_tts produced zero samples");
            return;
        }
        if self.tts_tx.send(pcm).is_err() {
            tracing::warn!("[mixer] queue_tts send failed (mixer gone)");
        }
    }

    /// True while the mixer is playing a TTS clip.
    pub fn is_speaking(&self) -> bool {
        self.speaking_now.load(Ordering::Relaxed)
    }

    /// Shared flag mirroring `is_speaking`, for callers that poll it directly.
    pub fn speaking_flag(&self) -> Arc<std::sync::atomic::AtomicBool> {
        Arc::clone(&self.speaking_now)
    }

    /// Sends a replacement stem set (pad, texture, harmonic, percussive) for
    /// the mixer to crossfade in.
    pub fn swap_stems(&self, stems: [Vec<f32>; 4]) {
        if self.stem_swap_tx.send(stems).is_err() {
            tracing::warn!("[mixer] swap_stems send failed (mixer gone)");
        }
    }

    /// Output sample rate the mixer renders at.
    pub fn target_rate(&self) -> u32 {
        self.target_rate
    }

    /// Output channel count the mixer renders at.
    pub fn target_channels(&self) -> u16 {
        self.target_channels
    }

    /// Asks the mixer to drop the active TTS clip and anything queued, and
    /// clears the speaking flag immediately.
    pub fn stop_tts(&self) {
        self.stop_tts_flag.store(true, Ordering::Relaxed);
        self.speaking_now.store(false, Ordering::Relaxed);
    }
}
/// Audio-thread side of the voice mixer: owns the looping stem buffers, the
/// active TTS clip, and the receiving ends of the control channels held by
/// `VoiceMixerHandle`.
pub struct VoiceAudioMixer {
    // Currently-active looping stems, indexed per `StemRole::index()`.
    stems: [PreloadedStem; 4],
    // Previous stem set, still audible while a crossfade is in progress.
    fading_out: Option<[PreloadedStem; 4]>,
    // Interleaved samples left in the current crossfade (0 = not fading).
    crossfade_samples_remaining: usize,
    // Total interleaved samples the current crossfade spans.
    crossfade_total_samples: usize,
    // Per-stem gains shared with the handle.
    stem_gains: Arc<[AtomicF32; 4]>,
    // Gain applied to the final mixed sample.
    master_gain: Arc<AtomicF32>,
    // Queued TTS clips from `VoiceMixerHandle::queue_tts`.
    tts_rx: Receiver<Vec<f32>>,
    // Replacement stem sets from `VoiceMixerHandle::swap_stems`.
    stem_swap_rx: Receiver<[Vec<f32>; 4]>,
    // Output sample rate the mixer renders at.
    target_rate: u32,
    // Output channel count the mixer renders at.
    target_channels: u16,
    // TTS clip currently mixed over the stem bed, if any.
    active_tts: Option<TtsClip>,
    // Mirrors whether a TTS clip is playing; shared with the handle.
    speaking_now: Arc<std::sync::atomic::AtomicBool>,
    // Request flag: drop active and queued TTS on the next `pull`.
    stop_tts_flag: Arc<std::sync::atomic::AtomicBool>,
}
/// Length, in seconds, of the linear crossfade applied when a new stem set
/// replaces the current one (see `VoiceAudioMixer::pull`).
const CROSSFADE_SECONDS: f32 = 0.75;
/// A fully-decoded stem buffer that loops forever when read.
struct PreloadedStem {
    // Interleaved f32 samples.
    samples: Vec<f32>,
    // Next read index; wraps back to 0 at the end of `samples`.
    cursor: usize,
}

impl PreloadedStem {
    /// Returns the next sample, looping back to the start when the buffer is
    /// exhausted. An empty buffer yields silence (0.0).
    fn next_sample(&mut self) -> f32 {
        if self.samples.is_empty() {
            return 0.0;
        }
        let sample = self.samples[self.cursor];
        self.cursor = (self.cursor + 1) % self.samples.len();
        sample
    }
}
/// A decoded TTS clip, played once from start to finish (no looping).
struct TtsClip {
    // Interleaved f32 samples.
    samples: Vec<f32>,
    // Next read index; advances monotonically, never wraps.
    cursor: usize,
}

impl TtsClip {
    /// Returns the next sample, or `None` once the clip is exhausted.
    fn next_sample(&mut self) -> Option<f32> {
        let sample = self.samples.get(self.cursor).copied()?;
        self.cursor += 1;
        Some(sample)
    }
}
impl VoiceAudioMixer {
pub fn new(
asset_root: &PathBuf,
theme_hint: &str,
target_rate: u32,
target_channels: u16,
speaking_now: Arc<std::sync::atomic::AtomicBool>,
) -> Result<(Self, VoiceMixerHandle)> {
let stems = preload_stems(asset_root, theme_hint, target_rate, target_channels)?;
let stem_gains = Arc::new([
AtomicF32::new(0.0),
AtomicF32::new(0.0),
AtomicF32::new(0.0),
AtomicF32::new(0.0),
]);
let master_gain = Arc::new(AtomicF32::new(1.0));
let stop_tts_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
let (tts_tx, tts_rx) = crossbeam_channel::unbounded::<Vec<f32>>();
let (stem_swap_tx, stem_swap_rx) = crossbeam_channel::unbounded::<[Vec<f32>; 4]>();
let handle = VoiceMixerHandle {
stem_gains: Arc::clone(&stem_gains),
master_gain: Arc::clone(&master_gain),
tts_tx,
stem_swap_tx,
speaking_now: Arc::clone(&speaking_now),
stop_tts_flag: Arc::clone(&stop_tts_flag),
target_rate,
target_channels,
};
let mixer = Self {
stems,
fading_out: None,
crossfade_samples_remaining: 0,
crossfade_total_samples: 0,
stem_gains,
master_gain,
tts_rx,
stem_swap_rx,
target_rate,
target_channels,
active_tts: None,
speaking_now,
stop_tts_flag,
};
Ok((mixer, handle))
}
pub fn from_stem_samples(
stems: [Vec<f32>; 4],
target_rate: u32,
target_channels: u16,
speaking_now: Arc<std::sync::atomic::AtomicBool>,
) -> (Self, VoiceMixerHandle) {
let [pad, texture, harmonic, percussive] = stems;
let stems = [
PreloadedStem {
samples: pad,
cursor: 0,
},
PreloadedStem {
samples: texture,
cursor: 0,
},
PreloadedStem {
samples: harmonic,
cursor: 0,
},
PreloadedStem {
samples: percussive,
cursor: 0,
},
];
let stem_gains = Arc::new([
AtomicF32::new(0.0),
AtomicF32::new(0.0),
AtomicF32::new(0.0),
AtomicF32::new(0.0),
]);
let master_gain = Arc::new(AtomicF32::new(1.0));
let stop_tts_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
let (tts_tx, tts_rx) = crossbeam_channel::unbounded::<Vec<f32>>();
let (stem_swap_tx, stem_swap_rx) = crossbeam_channel::unbounded::<[Vec<f32>; 4]>();
let handle = VoiceMixerHandle {
stem_gains: Arc::clone(&stem_gains),
master_gain: Arc::clone(&master_gain),
tts_tx,
stem_swap_tx,
speaking_now: Arc::clone(&speaking_now),
stop_tts_flag: Arc::clone(&stop_tts_flag),
target_rate,
target_channels,
};
let mixer = Self {
stems,
fading_out: None,
crossfade_samples_remaining: 0,
crossfade_total_samples: 0,
stem_gains,
master_gain,
tts_rx,
stem_swap_rx,
target_rate,
target_channels,
active_tts: None,
speaking_now,
stop_tts_flag,
};
(mixer, handle)
}
pub fn pull(&mut self, out: &mut [f32]) {
if self.stop_tts_flag.swap(false, Ordering::Relaxed) {
self.active_tts = None;
self.speaking_now.store(false, Ordering::Relaxed);
while self.tts_rx.try_recv().is_ok() {}
}
let mut latest_swap: Option<[Vec<f32>; 4]> = None;
while let Ok(new_stems) = self.stem_swap_rx.try_recv() {
latest_swap = Some(new_stems);
}
if let Some([pad, tex, harm, perc]) = latest_swap {
let new_stems = [
PreloadedStem {
samples: pad,
cursor: 0,
},
PreloadedStem {
samples: tex,
cursor: 0,
},
PreloadedStem {
samples: harm,
cursor: 0,
},
PreloadedStem {
samples: perc,
cursor: 0,
},
];
let prior = std::mem::replace(&mut self.stems, new_stems);
self.fading_out = Some(prior);
let frames = (self.target_rate as f32 * CROSSFADE_SECONDS) as usize;
self.crossfade_total_samples = frames.saturating_mul(self.target_channels as usize);
self.crossfade_samples_remaining = self.crossfade_total_samples;
}
if self.active_tts.is_none() {
if let Ok(samples) = self.tts_rx.try_recv() {
self.active_tts = Some(TtsClip { samples, cursor: 0 });
self.speaking_now.store(true, Ordering::Relaxed);
}
}
let master = self.master_gain.load();
let gains = [
self.stem_gains[0].load(),
self.stem_gains[1].load(),
self.stem_gains[2].load(),
self.stem_gains[3].load(),
];
for slot in out.iter_mut() {
let s0 = self.stems[0].next_sample();
let s1 = self.stems[1].next_sample();
let s2 = self.stems[2].next_sample();
let s3 = self.stems[3].next_sample();
let new_mix = s0 * gains[0] + s1 * gains[1] + s2 * gains[2] + s3 * gains[3];
let bed_mix = if let Some(out_stems) = self.fading_out.as_mut() {
let o0 = out_stems[0].next_sample();
let o1 = out_stems[1].next_sample();
let o2 = out_stems[2].next_sample();
let o3 = out_stems[3].next_sample();
let old_mix = o0 * gains[0] + o1 * gains[1] + o2 * gains[2] + o3 * gains[3];
let t = if self.crossfade_total_samples > 0 {
1.0 - (self.crossfade_samples_remaining as f32
/ self.crossfade_total_samples as f32)
} else {
1.0
};
let t = t.clamp(0.0, 1.0);
if self.crossfade_samples_remaining > 0 {
self.crossfade_samples_remaining -= 1;
}
if self.crossfade_samples_remaining == 0 {
self.fading_out = None;
self.crossfade_total_samples = 0;
}
(1.0 - t) * old_mix + t * new_mix
} else {
new_mix
};
let mut sample = bed_mix;
if let Some(tts) = self.active_tts.as_mut() {
if let Some(t) = tts.next_sample() {
sample += t;
} else {
self.active_tts = None;
self.speaking_now.store(false, Ordering::Relaxed);
}
}
*slot = (sample * master).clamp(-1.0, 1.0);
}
}
}
fn preload_stems(
asset_root: &PathBuf,
theme_hint: &str,
target_rate: u32,
target_channels: u16,
) -> Result<[PreloadedStem; 4]> {
let stems_dir = asset_root
.join("audio")
.join("themes")
.join(theme_hint)
.join("stems");
let mut stems: Vec<PreloadedStem> = Vec::with_capacity(4);
for role in StemRole::ALL {
let path = stems_dir.join(role.file_name());
let mut samples = decode_to_memory(&path, target_rate, target_channels)?;
stitch_loop_seam(&mut samples, target_rate, target_channels);
tracing::info!(
"[mixer] preloaded stem {} ({} samples, {:.1}s @ {} Hz {} ch)",
role.file_name(),
samples.len(),
samples.len() as f32 / (target_rate as f32 * target_channels as f32),
target_rate,
target_channels,
);
stems.push(PreloadedStem { samples, cursor: 0 });
}
let [pad, texture, harmonic, percussive]: [PreloadedStem; 4] = stems
.try_into()
.map_err(|_| VoiceError::Device("expected exactly 4 stems".into()))?;
Ok([pad, texture, harmonic, percussive])
}
/// Smooths the wrap-around seam of a looping sample buffer in place.
///
/// Crossfades roughly 30 ms at both ends of the buffer with complementary
/// equal-power weights (`cos^2(t*PI/2) + sin^2(t*PI/2) = 1`): the final fade
/// region blends from the original tail into the opening content (so the last
/// sample lands near the buffer's opening), while the opening fade region
/// blends from the original opening toward the tail content.
///
/// Fix over the previous version: both weights used `cos((1 - t) * PI/2)`
/// style terms that rise 0 -> 1 *together*, so the blended regions dipped to
/// silence at one edge and doubled in amplitude at the other — adding clicks
/// instead of removing them. The weights now always sum to 1.
///
/// Buffers too short for a meaningful fade (under 8 fade frames) are left
/// untouched, as are empty buffers.
fn stitch_loop_seam(samples: &mut Vec<f32>, target_rate: u32, target_channels: u16) {
    if samples.is_empty() {
        return;
    }
    let channels = target_channels.max(1) as usize;
    let frames = samples.len() / channels;
    // ~30 ms fade, capped at a quarter of the loop length.
    let fade_frames = ((target_rate as usize * 30) / 1000).min(frames / 4);
    if fade_frames < 8 {
        return;
    }
    // Unmodified copies of the first and last fade regions.
    let head: Vec<f32> = samples[..fade_frames * channels].to_vec();
    let tail_start = (frames - fade_frames) * channels;
    let tail: Vec<f32> = samples[tail_start..].to_vec();
    for f in 0..fade_frames {
        let t = f as f32 / (fade_frames - 1).max(1) as f32;
        // Equal-power pair: tail fades out (1 -> 0), head fades in (0 -> 1).
        let w_tail = (t * std::f32::consts::FRAC_PI_2).cos().powi(2);
        let w_head = (t * std::f32::consts::FRAC_PI_2).sin().powi(2);
        for c in 0..channels {
            let src = tail[f * channels + c] * w_tail + head[f * channels + c] * w_head;
            samples[(frames - fade_frames + f) * channels + c] = src;
        }
    }
    for f in 0..fade_frames {
        let t = f as f32 / (fade_frames - 1).max(1) as f32;
        // Equal-power pair: head fades out (1 -> 0), tail fades in (0 -> 1).
        let w_head = (t * std::f32::consts::FRAC_PI_2).cos().powi(2);
        let w_tail = (t * std::f32::consts::FRAC_PI_2).sin().powi(2);
        for c in 0..channels {
            let src = head[f * channels + c] * w_head + tail[f * channels + c] * w_tail;
            samples[f * channels + c] = src;
        }
    }
}
fn decode_to_memory(path: &PathBuf, target_rate: u32, target_channels: u16) -> Result<Vec<f32>> {
let file = File::open(path)
.map_err(|e| VoiceError::Device(format!("open stem {}: {e}", path.display())))?;
let decoder = Decoder::new(BufReader::new(file))
.map_err(|e| VoiceError::Device(format!("decode stem {}: {e}", path.display())))?;
let samples = decode_source_to_interleaved(decoder, target_rate, target_channels);
if samples.is_empty() {
return Err(VoiceError::Device(format!(
"stem {} decoded to zero samples",
path.display()
)));
}
Ok(samples)
}
/// Converts a decoded audio source into interleaved `f32` samples at the
/// requested sample rate and channel count.
///
/// Samples are normalized from `i16` by dividing by `i16::MAX`,
/// de-interleaved per source channel, linearly resampled, then re-interleaved:
/// * mono source -> duplicated across all target channels;
/// * multi-channel source, mono target -> channels averaged;
/// * otherwise -> source channels copied positionally, extra target channels
///   filled with silence.
///
/// Returns an empty vector when the source yields no samples.
///
/// Fix: frame count is now the *minimum* per-channel length. The old code
/// indexed every channel by channel 0's length, which panicked out of bounds
/// whenever the source's total sample count was not a multiple of its channel
/// count (channel buffers then differ in length by one).
fn decode_source_to_interleaved<S>(source: S, target_rate: u32, target_channels: u16) -> Vec<f32>
where
    S: Source<Item = i16>,
{
    let src_channels = source.channels().max(1) as usize;
    let src_rate = source.sample_rate().max(1);
    let target_channels = target_channels.max(1) as usize;
    // De-interleave into one buffer per source channel.
    let mut channel_data: Vec<Vec<f32>> = vec![Vec::new(); src_channels];
    for (idx, sample) in source.enumerate() {
        channel_data[idx % src_channels].push(sample as f32 / i16::MAX as f32);
    }
    if channel_data.iter().all(|c| c.is_empty()) {
        return Vec::new();
    }
    let resampled: Vec<Vec<f32>> = channel_data
        .into_iter()
        .map(|c| linear_resample(&c, src_rate, target_rate))
        .collect();
    // Shortest channel bounds the frame count so indexing below stays in range.
    let frames = resampled.iter().map(Vec::len).min().unwrap_or(0);
    let mut out = Vec::with_capacity(frames * target_channels);
    if target_channels == 1 && src_channels > 1 {
        // Downmix: average all source channels into one.
        let inv = 1.0 / src_channels as f32;
        for f in 0..frames {
            out.push(resampled.iter().map(|c| c[f]).sum::<f32>() * inv);
        }
    } else {
        for f in 0..frames {
            for ch in 0..target_channels {
                let s = if src_channels == 1 {
                    // Mono source: copy the single channel everywhere.
                    resampled[0][f]
                } else if ch < src_channels {
                    resampled[ch][f]
                } else {
                    // More target than source channels: pad with silence.
                    0.0
                };
                out.push(s);
            }
        }
    }
    out
}
#[cfg(test)]
mod stitch_tests {
    use super::stitch_loop_seam;

    /// An empty buffer passes through untouched.
    #[test]
    fn empty_buffer_is_a_noop() {
        let mut samples: Vec<f32> = Vec::new();
        stitch_loop_seam(&mut samples, 44_100, 2);
        assert!(samples.is_empty());
    }

    /// Buffers shorter than the minimum fade length are left unchanged.
    #[test]
    fn short_buffer_skips_crossfade_safely() {
        let mut samples = vec![0.5f32; 16];
        let before = samples.clone();
        stitch_loop_seam(&mut samples, 44_100, 2);
        assert_eq!(samples, before);
    }

    /// Stitching a one-second mono ramp must shrink the wrap-around jump.
    #[test]
    fn seam_value_matches_after_stitch() {
        let frames = 44_100;
        let mut samples: Vec<f32> = (0..frames)
            .map(|i| (i as f32 / frames as f32) - 0.5)
            .collect();
        let pre_jump = (samples[frames - 1] - samples[0]).abs();
        assert!(pre_jump > 0.9);
        stitch_loop_seam(&mut samples, 44_100, 1);
        let post_jump = (samples[frames - 1] - samples[0]).abs();
        assert!(
            post_jump < pre_jump,
            "stitch should reduce wrap discontinuity (was {pre_jump}, now {post_jump})"
        );
    }
}
#[cfg(test)]
mod resample_tests {
    use super::linear_resample;

    /// Identical rates must return the input unchanged.
    #[test]
    fn passthrough_when_rates_match() {
        let input = vec![0.1, 0.2, 0.3, 0.4];
        assert_eq!(linear_resample(&input, 44100, 44100), input);
    }

    /// Doubling the rate doubles the sample count.
    #[test]
    fn upsample_doubles_length_approx() {
        let input = vec![0.0, 1.0, 0.0, 1.0];
        assert_eq!(linear_resample(&input, 22050, 44100).len(), 8);
    }

    /// Halving the rate halves the sample count.
    #[test]
    fn downsample_halves_length_approx() {
        let input = vec![0.0, 0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25];
        assert_eq!(linear_resample(&input, 44100, 22050).len(), 4);
    }

    /// An empty slice resamples to an empty vector.
    #[test]
    fn empty_input_yields_empty_output() {
        assert!(linear_resample(&[], 44100, 22050).is_empty());
    }
}
/// Resamples `samples` from `src_rate` to `target_rate` with linear
/// interpolation.
///
/// Output length is `floor(len / (src_rate / target_rate))`. Positions past
/// the last input sample clamp to it, and matching rates return a plain copy.
fn linear_resample(samples: &[f32], src_rate: u32, target_rate: u32) -> Vec<f32> {
    if samples.is_empty() {
        return Vec::new();
    }
    if src_rate == target_rate {
        return samples.to_vec();
    }
    let step = src_rate as f64 / target_rate as f64;
    let out_len = (samples.len() as f64 / step) as usize;
    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            let base = pos as usize;
            // Clamp the upper neighbor so the final samples never read past
            // the end of the input.
            let next = (base + 1).min(samples.len() - 1);
            let frac = (pos - base as f64) as f32;
            samples[base] * (1.0 - frac) + samples[next] * frac
        })
        .collect()
}