car-voice 0.15.1

Voice I/O capability for CAR — mic capture, VAD, listener/speaker traits
Documentation
//! On-disk store for speaker enrollments.
//!
//! Layout:
//! ```text
//! ~/.car/voiceprints/
//!   <label>.toml          # one Enrollment record per file
//! ```
//!
//! Provides the file-IO + WAV-decode helpers consumers need to build
//! and manage enrollments without driving the listener-arming
//! protocol manually. The runtime arming path
//! (`SpeakerPipeline::arm_enrollment` → `capture_enrollment`) still
//! exists for callers who want enrollment to come from a live mic
//! turn; this module is for the file-based path JS / Python / WS
//! callers actually want.

use crate::enrollment::{Enrollment, FilterbankEmbedder, SpeakerEmbedder, SpeakerEmbedding};
use crate::error::VoiceError;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};

/// Sample rate (Hz) every PCM buffer is converted to before embedding.
/// Must stay in sync with `enrollment::EMBED_SAMPLE_RATE` — the embedder
/// consumes 16 kHz mono audio.
const TARGET_SAMPLE_RATE: u32 = 16000;

/// Summary record returned by [`list_enrollments`].
///
/// One entry is produced per `*.toml` enrollment file that loaded
/// successfully; files that fail to load are skipped by the lister.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnrollmentInfo {
    /// Speaker label read from the enrollment record itself (not derived
    /// from the file name).
    pub label: String,
    /// Path of the `.toml` file the record was loaded from.
    pub path: PathBuf,
    /// Identifier of the embedding model that produced the stored voiceprint
    /// (copied from the record's `embedding.model`).
    pub model_id: String,
}

/// Resolve the default enrollment directory: `~/.car/voiceprints/`.
///
/// The home directory comes from `HOME`, falling back to `USERPROFILE`
/// (where Windows keeps it). Empty values are treated as unset, so an empty
/// `HOME` can no longer silently resolve to a path relative to the current
/// working directory.
///
/// Doesn't touch disk; the save helpers create the directory on first use.
///
/// # Errors
/// Returns [`VoiceError::Config`] if no non-empty home variable is set.
pub fn enrollment_dir() -> Result<PathBuf, VoiceError> {
    let home = ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|&key| std::env::var_os(key))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
        .ok_or_else(|| {
            VoiceError::Config("HOME not set; cannot resolve enrollment dir".to_string())
        })?;
    Ok(home.join(".car").join("voiceprints"))
}

/// Compute `<enrollment dir>/<label>.toml` without touching the filesystem.
///
/// # Errors
/// [`VoiceError::Config`] if `label` contains anything outside
/// `[a-zA-Z0-9._-]` (or is empty / longer than 64 bytes), or if the
/// enrollment directory itself can't be resolved.
pub fn enrollment_path(label: &str) -> Result<PathBuf, VoiceError> {
    if is_safe_label(label) {
        let file_name = format!("{}.toml", label);
        return Ok(enrollment_dir()?.join(file_name));
    }
    Err(VoiceError::Config(format!(
        "invalid enrollment label '{}': only [a-zA-Z0-9._-] allowed",
        label
    )))
}

/// Enroll a speaker from a 16-bit signed PCM buffer.
///
/// `samples` may be interleaved multi-channel audio. Each frame of
/// `channels` samples is averaged down to mono, and a trailing partial frame
/// is dropped. The result is then linearly resampled from `sample_rate` to
/// 16 kHz and embedded with [`FilterbankEmbedder`].
///
/// Returns the resulting [`Enrollment`] without persisting it and without
/// validating `label`. Use [`save_enrollment`] to write it to
/// `~/.car/voiceprints/<label>.toml`; that step rejects unsafe labels.
///
/// # Errors
/// - [`VoiceError::Config`] if `samples` is empty, `channels` is 0, or
///   `sample_rate` is 0.
/// - [`VoiceError::Stt`] if the embedder returns an empty vector.
pub fn enroll_from_pcm(
    label: &str,
    samples: &[i16],
    sample_rate: u32,
    channels: u16,
) -> Result<Enrollment, VoiceError> {
    if samples.is_empty() {
        return Err(VoiceError::Config("empty PCM buffer".to_string()));
    }
    if channels == 0 {
        return Err(VoiceError::Config("channels must be ≥ 1".to_string()));
    }
    // A zero rate makes the resample ratio infinite, and the resampler would
    // try to allocate a `usize::MAX`-length buffer and panic. Reject it here.
    if sample_rate == 0 {
        return Err(VoiceError::Config("sample_rate must be > 0".to_string()));
    }
    let mono = downmix_to_mono(samples, channels);
    let resampled = if sample_rate == TARGET_SAMPLE_RATE {
        // Already at the embedder's rate: move the buffer, skip the copy.
        mono
    } else {
        linear_resample(&mono, sample_rate, TARGET_SAMPLE_RATE)
    };
    let embedder = FilterbankEmbedder::new();
    let values = embedder.embed(&resampled);
    if values.is_empty() {
        return Err(VoiceError::Stt(
            "embedder returned empty vector — sample too short or too quiet".to_string(),
        ));
    }
    let embedding = SpeakerEmbedding::new(values, embedder.model_id().to_string());
    Ok(Enrollment {
        label: label.to_string(),
        embedding,
    })
}

/// Enroll a speaker from a WAV file on disk.
///
/// Any channel count and sample rate is accepted; downmixing and resampling
/// are done by [`enroll_from_pcm`]. Supported sample encodings:
/// - 8-, 16-, 24- and 32-bit signed integer PCM, rescaled to 16-bit;
/// - 32-bit float PCM, clamped to `[-1.0, 1.0]` and scaled by `i16::MAX`.
///
/// # Errors
/// [`VoiceError::Config`] if the file can't be opened or decoded, or if it
/// uses an unsupported bit depth. Otherwise returns whatever
/// [`enroll_from_pcm`] returns.
pub fn enroll_from_wav(label: &str, wav_path: &Path) -> Result<Enrollment, VoiceError> {
    let mut reader = hound::WavReader::open(wav_path)
        .map_err(|e| VoiceError::Config(format!("open WAV {}: {}", wav_path.display(), e)))?;
    let spec = reader.spec();
    let samples_i16: Vec<i16> = match spec.sample_format {
        hound::SampleFormat::Int => match spec.bits_per_sample {
            16 => reader
                .samples::<i16>()
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|e| VoiceError::Config(format!("read WAV samples: {}", e)))?,
            // hound decodes these depths into an i32 holding the signed
            // sample at its native width. It also converts 8-bit WAV data
            // from unsigned to signed. A shift moves the value onto the
            // 16-bit scale, so no extra dependency is needed.
            bits @ (8 | 24 | 32) => {
                let wide: Vec<i32> = reader
                    .samples::<i32>()
                    .collect::<std::result::Result<Vec<_>, _>>()
                    .map_err(|e| VoiceError::Config(format!("read WAV samples: {}", e)))?;
                wide.into_iter()
                    .map(|s| {
                        let scaled = if bits > 16 {
                            s >> (bits - 16)
                        } else {
                            s << (16 - bits)
                        };
                        scaled as i16
                    })
                    .collect()
            }
            other => {
                return Err(VoiceError::Config(format!(
                    "unsupported WAV bit depth {} (only 8/16/24/32-bit signed and 32-bit float)",
                    other
                )));
            }
        },
        hound::SampleFormat::Float => {
            let f: Vec<f32> = reader
                .samples::<f32>()
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|e| VoiceError::Config(format!("read WAV samples: {}", e)))?;
            f.into_iter()
                .map(|s| (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16)
                .collect()
        }
    };
    enroll_from_pcm(label, &samples_i16, spec.sample_rate, spec.channels)
}

/// Persist an [`Enrollment`] to `~/.car/voiceprints/<label>.toml`,
/// creating the directory if needed. Returns the path written.
///
/// The label is validated before anything is created on disk. The target
/// path is resolved only once, so the directory that gets created always
/// matches the file that gets written, even if `HOME` changes concurrently.
///
/// # Errors
/// - [`VoiceError::Config`] for an unsafe label or an unresolvable home
///   directory.
/// - I/O errors from creating the directory.
/// - Anything `Enrollment::save_to` returns.
pub fn save_enrollment(enrollment: &Enrollment) -> Result<PathBuf, VoiceError> {
    let path = enrollment_path(&enrollment.label)?;
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)?;
    }
    enrollment.save_to(&path)?;
    Ok(path)
}

/// List every enrollment under the default directory.
pub fn list_enrollments() -> Result<Vec<EnrollmentInfo>, VoiceError> {
    let dir = enrollment_dir()?;
    if !dir.exists() {
        return Ok(Vec::new());
    }
    let mut out = Vec::new();
    for entry in fs::read_dir(&dir)? {
        let entry = entry?;
        let path = entry.path();
        if path.extension().and_then(|s| s.to_str()) != Some("toml") {
            continue;
        }
        let Ok(enrollment) = Enrollment::load_from(&path) else {
            tracing::warn!("[enrollment] skipping malformed file {}", path.display());
            continue;
        };
        out.push(EnrollmentInfo {
            label: enrollment.label,
            path: path.clone(),
            model_id: enrollment.embedding.model.clone(),
        });
    }
    out.sort_by(|a, b| a.label.cmp(&b.label));
    Ok(out)
}

/// Load the enrollment stored for `label`.
///
/// # Errors
/// Fails if `label` is not a safe file name, the enrollment directory can't
/// be resolved, or `Enrollment::load_from` fails on the resolved file.
pub fn load_enrollment(label: &str) -> Result<Enrollment, VoiceError> {
    let path = enrollment_path(label)?;
    Enrollment::load_from(&path)
}

/// Delete the enrollment file stored for `label`.
///
/// Removing a label that has no file is a no-op. An invalid label is rejected
/// before the filesystem is touched. Any other filesystem failure is
/// reported as [`VoiceError::Io`].
pub fn remove_enrollment(label: &str) -> Result<(), VoiceError> {
    let target = enrollment_path(label)?;
    if let Err(err) = fs::remove_file(&target) {
        if err.kind() != std::io::ErrorKind::NotFound {
            return Err(VoiceError::Io(err));
        }
    }
    Ok(())
}

// ─── Helpers ──────────────────────────────────────────────────────────────

/// Accept only labels that are safe to use as a file stem.
///
/// A label must be 1–64 bytes, made solely of ASCII letters, digits, `-`,
/// `_` and `.`. Path separators and whitespace are rejected, so a label can
/// never name a different directory. This is conservative on purpose.
fn is_safe_label(label: &str) -> bool {
    let allowed = |c: char| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.');
    (1..=64).contains(&label.len()) && label.chars().all(allowed)
}

/// Average interleaved frames of `channels` samples into a single channel.
///
/// A `channels` value of 0 or 1 returns a copy of the input. The division
/// truncates toward zero. A trailing partial frame (when `samples.len()` isn't
/// a multiple of `channels`) is dropped.
fn downmix_to_mono(samples: &[i16], channels: u16) -> Vec<i16> {
    let width = usize::from(channels);
    if width < 2 {
        return samples.to_vec();
    }
    let divisor = i32::from(channels);
    let mut mono = Vec::with_capacity(samples.len() / width);
    for frame in samples.chunks_exact(width) {
        // i32 holds up to 65535 full-scale i16 samples without overflow.
        let total = frame.iter().fold(0i32, |acc, &s| acc + i32::from(s));
        mono.push((total / divisor) as i16);
    }
    mono
}

/// Resample `samples` from `from_rate` to `to_rate` Hz by linear
/// interpolation between neighbouring input samples.
///
/// Behavior by case:
/// - The output length is `floor(len * to_rate / from_rate)`.
/// - Equal rates, or empty input, return a copy of the input.
/// - A zero `to_rate` yields an empty buffer.
/// - A zero `from_rate` has no meaningful ratio and also yields an empty
///   buffer. Previously it computed an infinite length and panicked on a
///   `usize::MAX` allocation.
///
/// Interpolated values are truncated toward zero.
fn linear_resample(samples: &[i16], from_rate: u32, to_rate: u32) -> Vec<i16> {
    if from_rate == to_rate || samples.is_empty() {
        return samples.to_vec();
    }
    if from_rate == 0 {
        return Vec::new();
    }
    let ratio = to_rate as f64 / from_rate as f64;
    let out_len = (samples.len() as f64 * ratio) as usize;
    let last = samples.len() - 1;
    (0..out_len)
        .map(|i| {
            let src = i as f64 / ratio;
            // Clamp defensively: float rounding at extreme ratios must never
            // turn into an out-of-bounds index.
            let lo = (src.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (src - lo as f64) as f32;
            let v = samples[lo] as f32 * (1.0 - frac) + samples[hi] as f32 * frac;
            v as i16
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Serializes tests that mutate the process-wide `HOME` variable.
    ///
    /// The default test harness runs tests on parallel threads. Without this
    /// lock, one test can repoint `HOME` while another is mid round-trip,
    /// so `save_enrollment` and `list_enrollments` resolve different
    /// directories and the test flakes.
    static HOME_LOCK: Mutex<()> = Mutex::new(());

    /// Overrides `HOME` for the guard's lifetime while holding [`HOME_LOCK`].
    /// On drop it restores the previous value, or unsets it if there was none.
    struct HomeOverride {
        previous: Option<OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl HomeOverride {
        fn set(path: &std::path::Path) -> Self {
            // A panicking test poisons the lock, but the guarded data is
            // `()`, so it's safe to recover and keep going.
            let lock = HOME_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let previous = std::env::var_os("HOME");
            std::env::set_var("HOME", path);
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for HomeOverride {
        fn drop(&mut self) {
            // `Drop::drop` runs before the `_lock` field is dropped, so HOME
            // is restored while the lock is still held.
            match self.previous.take() {
                Some(home) => std::env::set_var("HOME", home),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    #[test]
    fn label_validation_accepts_normal_names() {
        assert!(is_safe_label("alice"));
        assert!(is_safe_label("alice_smith"));
        assert!(is_safe_label("alice-1"));
        assert!(is_safe_label("alice.work"));
    }

    #[test]
    fn label_validation_rejects_path_traversal() {
        assert!(!is_safe_label(""));
        assert!(!is_safe_label("../etc/passwd"));
        assert!(!is_safe_label("alice/bob"));
        assert!(!is_safe_label("alice bob"));
        assert!(!is_safe_label("alice\""));
        assert!(!is_safe_label(&"a".repeat(65)));
    }

    #[test]
    fn downmix_stereo_averages_channels() {
        let stereo = vec![1000i16, 2000, -1000, -2000];
        let mono = downmix_to_mono(&stereo, 2);
        assert_eq!(mono, vec![1500, -1500]);
    }

    #[test]
    fn downmix_mono_is_passthrough() {
        let mono = vec![1, 2, 3];
        assert_eq!(downmix_to_mono(&mono, 1), mono);
    }

    #[test]
    fn linear_resample_changes_length_proportionally() {
        let input = vec![0i16; 1000];
        let out = linear_resample(&input, 48_000, 16_000);
        // ±1 for integer rounding.
        assert!((out.len() as i32 - 333).abs() <= 1);
    }

    #[test]
    fn linear_resample_passthrough_at_same_rate() {
        let input = vec![1i16, 2, 3, 4];
        assert_eq!(linear_resample(&input, 16_000, 16_000), input);
    }

    /// Generate enough non-zero PCM to satisfy the embedder, then
    /// round-trip through enroll → save → list → load → remove.
    #[test]
    fn enrollment_round_trips_through_disk() {
        // Point HOME at a temp dir so we don't pollute the real ~/.car.
        // `_home` is declared after `tmp`, so it drops first: HOME is
        // restored before the temp dir is deleted.
        let tmp = tempdir().unwrap();
        let _home = HomeOverride::set(tmp.path());

        // 2 seconds of 16 kHz sine-ish signal.
        let mut samples = vec![0i16; 32_000];
        for (i, s) in samples.iter_mut().enumerate() {
            *s = ((i as f32 * 0.1).sin() * 8000.0) as i16;
        }

        let enrollment = enroll_from_pcm("test_speaker", &samples, 16_000, 1).unwrap();
        assert_eq!(enrollment.label, "test_speaker");
        assert!(!enrollment.embedding.values.is_empty());

        let path = save_enrollment(&enrollment).unwrap();
        assert!(path.exists());

        let listed = list_enrollments().unwrap();
        assert!(listed.iter().any(|e| e.label == "test_speaker"));

        let loaded = load_enrollment("test_speaker").unwrap();
        assert_eq!(loaded.label, "test_speaker");

        remove_enrollment("test_speaker").unwrap();
        let listed_after = list_enrollments().unwrap();
        assert!(listed_after.iter().all(|e| e.label != "test_speaker"));
    }

    #[test]
    fn remove_nonexistent_is_ok() {
        let tmp = tempdir().unwrap();
        let _home = HomeOverride::set(tmp.path());
        // No enrollment exists; remove must succeed silently.
        remove_enrollment("nobody").unwrap();
    }

    #[test]
    fn enroll_from_pcm_rejects_invalid_inputs() {
        assert!(enroll_from_pcm("a", &[], 16_000, 1).is_err());
        assert!(enroll_from_pcm("a", &[1, 2, 3], 16_000, 0).is_err());
    }
}