use crate::enrollment::{Enrollment, FilterbankEmbedder, SpeakerEmbedder, SpeakerEmbedding};
use crate::error::VoiceError;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
/// Sample rate, in Hz, that PCM audio is resampled to before it is handed to
/// the embedder in `enroll_from_pcm`.
const TARGET_SAMPLE_RATE: u32 = 16_000;
/// Summary of one enrollment file on disk, as produced by `list_enrollments`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnrollmentInfo {
/// Label read from the loaded enrollment.
pub label: String,
/// Location of the `.toml` file the enrollment was loaded from.
pub path: PathBuf,
/// Model identifier copied from the enrollment's embedding.
pub model_id: String,
}
/// Resolves the directory that holds enrollment files: `$HOME/.car/voiceprints`.
///
/// The directory is only computed here; it is not created.
///
/// # Errors
/// Returns `VoiceError::Config` when the `HOME` environment variable is unset.
pub fn enrollment_dir() -> Result<PathBuf, VoiceError> {
    match std::env::var_os("HOME") {
        Some(home) => Ok(PathBuf::from(home).join(".car").join("voiceprints")),
        None => Err(VoiceError::Config(
            "HOME not set; cannot resolve enrollment dir".to_string(),
        )),
    }
}
/// Builds the on-disk path `<enrollment dir>/<label>.toml` for `label`.
///
/// # Errors
/// Returns `VoiceError::Config` when `label` fails `is_safe_label`, and
/// forwards any error from `enrollment_dir`.
pub fn enrollment_path(label: &str) -> Result<PathBuf, VoiceError> {
    if is_safe_label(label) {
        let dir = enrollment_dir()?;
        let file_name = format!("{}.toml", label);
        return Ok(dir.join(file_name));
    }
    Err(VoiceError::Config(format!(
        "invalid enrollment label '{}': only [a-zA-Z0-9._-] allowed",
        label
    )))
}
/// Builds an [`Enrollment`] for `label` from raw 16-bit PCM audio.
///
/// `samples` is treated as interleaved frames of `channels` samples each. The
/// audio is averaged down to mono, resampled to `TARGET_SAMPLE_RATE` when
/// `sample_rate` differs, and passed to a `FilterbankEmbedder`.
///
/// `label` is stored as given; it is not validated here (validation happens
/// when a path is derived from it in `enrollment_path`).
///
/// # Errors
/// * `VoiceError::Config` when `samples` is empty, `channels` is 0, or
///   `sample_rate` is 0.
/// * `VoiceError::Stt` when the embedder returns an empty vector.
pub fn enroll_from_pcm(
    label: &str,
    samples: &[i16],
    sample_rate: u32,
    channels: u16,
) -> Result<Enrollment, VoiceError> {
    if samples.is_empty() {
        return Err(VoiceError::Config("empty PCM buffer".to_string()));
    }
    if channels == 0 {
        return Err(VoiceError::Config("channels must be ≥ 1".to_string()));
    }
    // A zero rate (e.g. from a malformed WAV header) would give the resampler
    // an infinite ratio and make it panic on allocation, so reject it up front.
    if sample_rate == 0 {
        return Err(VoiceError::Config("sample rate must be ≥ 1 Hz".to_string()));
    }

    let mono = downmix_to_mono(samples, channels);
    let resampled = if sample_rate == TARGET_SAMPLE_RATE {
        mono
    } else {
        linear_resample(&mono, sample_rate, TARGET_SAMPLE_RATE)
    };

    let embedder = FilterbankEmbedder::new();
    let values = embedder.embed(&resampled);
    if values.is_empty() {
        return Err(VoiceError::Stt(
            "embedder returned empty vector — sample too short or too quiet".to_string(),
        ));
    }

    let embedding = SpeakerEmbedding::new(values, embedder.model_id().to_string());
    Ok(Enrollment {
        label: label.to_string(),
        embedding,
    })
}
/// Reads a WAV file and enrolls its audio under `label` via `enroll_from_pcm`.
///
/// Accepted encodings are 16-bit integer samples and float samples. Float
/// samples are clamped to `[-1.0, 1.0]` and scaled by `i16::MAX` before being
/// cast to `i16`.
///
/// # Errors
/// Returns `VoiceError::Config` when the file cannot be opened, a sample
/// cannot be read, or the integer bit depth is anything other than 16.
/// Errors from `enroll_from_pcm` are forwarded unchanged.
pub fn enroll_from_wav(label: &str, wav_path: &Path) -> Result<Enrollment, VoiceError> {
    // Shared formatter for failures while decoding individual samples.
    fn sample_read_error<E: std::fmt::Display>(e: E) -> VoiceError {
        VoiceError::Config(format!("read WAV samples: {}", e))
    }

    let mut reader = match hound::WavReader::open(wav_path) {
        Ok(opened) => opened,
        Err(e) => {
            return Err(VoiceError::Config(format!(
                "open WAV {}: {}",
                wav_path.display(),
                e
            )));
        }
    };
    let spec = reader.spec();

    let pcm: Vec<i16> = match (spec.sample_format, spec.bits_per_sample) {
        (hound::SampleFormat::Int, 16) => reader
            .samples::<i16>()
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(sample_read_error)?,
        (hound::SampleFormat::Int, other) => {
            return Err(VoiceError::Config(format!(
                "unsupported WAV bit depth {} (only 16-bit signed and 32-bit float)",
                other
            )));
        }
        (hound::SampleFormat::Float, _) => reader
            .samples::<f32>()
            .map(|sample| sample.map(|s| (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16))
            .collect::<std::result::Result<Vec<i16>, _>>()
            .map_err(sample_read_error)?,
    };

    enroll_from_pcm(label, &pcm, spec.sample_rate, spec.channels)
}
/// Writes `enrollment` to `<enrollment dir>/<label>.toml` and returns that path.
///
/// The enrollment directory (and any missing parents) is created first.
///
/// # Errors
/// Forwards failures from resolving the directory, creating it, validating
/// the label, or saving the file.
pub fn save_enrollment(enrollment: &Enrollment) -> Result<PathBuf, VoiceError> {
    // Directory creation deliberately precedes label validation, as before.
    fs::create_dir_all(enrollment_dir()?)?;
    let target = enrollment_path(&enrollment.label)?;
    enrollment.save_to(&target)?;
    Ok(target)
}
/// Lists every loadable enrollment in the enrollment directory, sorted by label.
///
/// Only files with a `toml` extension are considered. Files that fail to load
/// are logged with `tracing::warn!` and skipped. A missing directory yields an
/// empty list.
///
/// # Errors
/// Forwards failures from resolving the directory or reading its entries.
pub fn list_enrollments() -> Result<Vec<EnrollmentInfo>, VoiceError> {
    let dir = enrollment_dir()?;
    let mut infos = Vec::new();
    if !dir.exists() {
        return Ok(infos);
    }

    for entry in fs::read_dir(&dir)? {
        let file = entry?.path();
        let is_toml = matches!(
            file.extension().and_then(|ext| ext.to_str()),
            Some("toml")
        );
        if !is_toml {
            continue;
        }
        match Enrollment::load_from(&file) {
            Ok(enrollment) => infos.push(EnrollmentInfo {
                label: enrollment.label,
                model_id: enrollment.embedding.model.clone(),
                path: file,
            }),
            Err(_) => {
                tracing::warn!("[enrollment] skipping malformed file {}", file.display());
            }
        }
    }

    infos.sort_by(|lhs, rhs| lhs.label.cmp(&rhs.label));
    Ok(infos)
}
/// Loads the enrollment stored under `label`.
///
/// # Errors
/// Forwards failures from `enrollment_path` (invalid label, unset `HOME`) and
/// from `Enrollment::load_from`.
pub fn load_enrollment(label: &str) -> Result<Enrollment, VoiceError> {
    let source = enrollment_path(label)?;
    Enrollment::load_from(&source)
}
/// Deletes the enrollment file for `label`.
///
/// A file that does not exist is treated as success.
///
/// # Errors
/// Forwards failures from `enrollment_path`, and returns `VoiceError::Io` for
/// any removal error other than "not found".
pub fn remove_enrollment(label: &str) -> Result<(), VoiceError> {
    let target = enrollment_path(label)?;
    if let Err(err) = fs::remove_file(&target) {
        if err.kind() != std::io::ErrorKind::NotFound {
            return Err(VoiceError::Io(err));
        }
    }
    Ok(())
}
/// Returns `true` when `label` is 1 to 64 bytes long and consists solely of
/// ASCII letters, ASCII digits, `-`, `_`, or `.`.
fn is_safe_label(label: &str) -> bool {
    if label.is_empty() || label.len() > 64 {
        return false;
    }
    // Every accepted character is ASCII, so checking bytes is equivalent to
    // checking chars: any multi-byte character contributes bytes >= 0x80.
    label.bytes().all(|b| {
        matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_' | b'.')
    })
}
/// Averages interleaved multi-channel PCM down to a single channel.
///
/// With `channels` of 0 or 1 the input is copied unchanged. Otherwise each
/// complete frame of `channels` samples is replaced by its integer mean
/// (truncated toward zero); a trailing partial frame is dropped.
fn downmix_to_mono(samples: &[i16], channels: u16) -> Vec<i16> {
    let width = usize::from(channels);
    if width < 2 {
        return samples.to_vec();
    }

    let divisor = i32::from(channels);
    let mut mono = Vec::with_capacity(samples.len() / width);
    for frame in samples.chunks_exact(width) {
        // Accumulate in i32 so the per-frame sum cannot overflow i16.
        let total = frame.iter().fold(0i32, |acc, &s| acc + i32::from(s));
        mono.push((total / divisor) as i16);
    }
    mono
}
/// Resamples mono PCM from `from_rate` to `to_rate` by linear interpolation.
///
/// * Equal rates or empty input return a copy of `samples`.
/// * A zero `from_rate` or `to_rate` (when the rates differ) returns an empty
///   vector, since no meaningful time base exists.
/// * Otherwise the output holds `floor(samples.len() * to_rate / from_rate)`
///   samples, each interpolated between its two nearest input samples and
///   truncated to `i16`.
fn linear_resample(samples: &[i16], from_rate: u32, to_rate: u32) -> Vec<i16> {
    if from_rate == to_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // A zero source rate would make the ratio infinite; the saturating cast
    // below would then request usize::MAX elements and panic on allocation.
    if from_rate == 0 || to_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(to_rate) / f64::from(from_rate);
    let out_len = (samples.len() as f64 * ratio) as usize;
    let last = samples.len() - 1;

    (0..out_len)
        .map(|i| {
            let src = i as f64 / ratio;
            // Clamp so floating-point rounding can never index past the end.
            let lo = (src.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (src - lo as f64) as f32;
            let v = samples[lo] as f32 * (1.0 - frac) + samples[hi] as f32 * frac;
            v as i16
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Serialises tests that rewrite `HOME`. The environment is process-wide
    /// and the test harness runs tests on parallel threads, so without this
    /// lock one test could redirect another's enrollment directory mid-run.
    static HOME_LOCK: Mutex<()> = Mutex::new(());

    /// Points `HOME` at a fresh temporary directory for the guard's lifetime
    /// and restores the previous value on drop.
    struct ScopedHome {
        previous: Option<OsString>,
        _dir: tempfile::TempDir,
        // Declared last so the lock is released only after `HOME` has been
        // restored and the temporary directory removed.
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedHome {
        fn new() -> Self {
            // A poisoned lock only means another test panicked; the guarded
            // state is `()`, so it is safe to keep going.
            let lock = HOME_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let dir = tempdir().unwrap();
            let previous = std::env::var_os("HOME");
            std::env::set_var("HOME", dir.path());
            Self {
                previous,
                _dir: dir,
                _lock: lock,
            }
        }
    }

    impl Drop for ScopedHome {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var("HOME", value),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    #[test]
    fn label_validation_accepts_normal_names() {
        assert!(is_safe_label("alice"));
        assert!(is_safe_label("alice_smith"));
        assert!(is_safe_label("alice-1"));
        assert!(is_safe_label("alice.work"));
    }

    #[test]
    fn label_validation_rejects_path_traversal() {
        assert!(!is_safe_label(""));
        assert!(!is_safe_label("../etc/passwd"));
        assert!(!is_safe_label("alice/bob"));
        assert!(!is_safe_label("alice bob"));
        assert!(!is_safe_label("alice\""));
        assert!(!is_safe_label(&"a".repeat(65)));
    }

    #[test]
    fn downmix_stereo_averages_channels() {
        let stereo = vec![1000i16, 2000, -1000, -2000];
        let mono = downmix_to_mono(&stereo, 2);
        assert_eq!(mono, vec![1500, -1500]);
    }

    #[test]
    fn downmix_mono_is_passthrough() {
        let mono = vec![1, 2, 3];
        assert_eq!(downmix_to_mono(&mono, 1), mono);
    }

    #[test]
    fn linear_resample_changes_length_proportionally() {
        let input = vec![0i16; 1000];
        let out = linear_resample(&input, 48_000, 16_000);
        assert!((out.len() as i32 - 333).abs() <= 1);
    }

    #[test]
    fn linear_resample_passthrough_at_same_rate() {
        let input = vec![1i16, 2, 3, 4];
        assert_eq!(linear_resample(&input, 16_000, 16_000), input);
    }

    #[test]
    fn enrollment_round_trips_through_disk() {
        let _home = ScopedHome::new();
        let mut samples = vec![0i16; 32_000];
        for (i, s) in samples.iter_mut().enumerate() {
            *s = ((i as f32 * 0.1).sin() * 8000.0) as i16;
        }
        let enrollment = enroll_from_pcm("test_speaker", &samples, 16_000, 1).unwrap();
        assert_eq!(enrollment.label, "test_speaker");
        assert!(!enrollment.embedding.values.is_empty());
        let path = save_enrollment(&enrollment).unwrap();
        assert!(path.exists());
        let listed = list_enrollments().unwrap();
        assert!(listed.iter().any(|e| e.label == "test_speaker"));
        let loaded = load_enrollment("test_speaker").unwrap();
        assert_eq!(loaded.label, "test_speaker");
        remove_enrollment("test_speaker").unwrap();
        let listed_after = list_enrollments().unwrap();
        assert!(listed_after.iter().all(|e| e.label != "test_speaker"));
    }

    #[test]
    fn remove_nonexistent_is_ok() {
        let _home = ScopedHome::new();
        remove_enrollment("nobody").unwrap();
    }

    #[test]
    fn enroll_from_pcm_rejects_invalid_inputs() {
        assert!(enroll_from_pcm("a", &[], 16_000, 1).is_err());
        assert!(enroll_from_pcm("a", &[1, 2, 3], 16_000, 0).is_err());
    }
}