use neutts::npy::{write_npy_i32, load_npy_i32};
use neutts::tokens::{ids_to_token_str, extract_ids};
use neutts::preprocess::TextPreprocessor;
use neutts::cache::RefCodeCache;
/// Creates a fresh, uniquely named directory under the system temp dir for a test.
///
/// The name combines `tag`, the current process id, the full number of
/// nanoseconds since the Unix epoch and a per-process counter. Using the full
/// timestamp (not just its sub-second part) prevents names from repeating
/// across runs, so a test never inherits stale files from an earlier run. The
/// counter keeps two calls within the same nanosecond distinct.
///
/// # Panics
/// Panics if the system clock reads earlier than the Unix epoch or the
/// directory cannot be created.
fn tmp_dir(tag: &str) -> std::path::PathBuf {
    static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_nanos();
    let seq = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let d = std::env::temp_dir().join(format!(
        "neutts_e2e_{}_{}_{}_{}",
        tag,
        std::process::id(),
        nanos,
        seq
    ));
    std::fs::create_dir_all(&d).expect("failed to create test temp directory");
    d
}
/// Builds an in-memory 16-bit mono PCM WAV file containing `n_samples` of silence.
///
/// The result is a canonical 44-byte RIFF/WAVE header followed by
/// `n_samples * 2` zero bytes (all-zero samples are silence in signed PCM).
///
/// # Panics
/// Panics if the data chunk or byte rate does not fit in the 32-bit fields of
/// the WAV header, rather than silently truncating them.
fn make_silence_wav(n_samples: usize, sample_rate: u32) -> Vec<u8> {
    const CHANNELS: u16 = 1;
    const BITS_PER_SAMPLE: u16 = 16;
    const BLOCK_ALIGN: u16 = CHANNELS * BITS_PER_SAMPLE / 8;
    const HEADER_LEN: usize = 44;

    let data_len = n_samples
        .checked_mul(usize::from(BLOCK_ALIGN))
        .expect("WAV data length overflows usize");
    let data_size = u32::try_from(data_len).expect("WAV data chunk exceeds u32::MAX bytes");
    // RIFF chunk size = everything after the 8-byte "RIFF"+size prefix.
    let riff_size = data_size
        .checked_add(36)
        .expect("RIFF chunk size exceeds u32::MAX bytes");
    let byte_rate = sample_rate
        .checked_mul(u32::from(BLOCK_ALIGN))
        .expect("WAV byte rate exceeds u32::MAX");

    let mut buf = Vec::with_capacity(HEADER_LEN + data_len);
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&riff_size.to_le_bytes());
    buf.extend_from_slice(b"WAVE");
    // "fmt " sub-chunk: 16 bytes describing uncompressed PCM.
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes());
    buf.extend_from_slice(&1u16.to_le_bytes()); // audio format: 1 = PCM
    buf.extend_from_slice(&CHANNELS.to_le_bytes());
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&byte_rate.to_le_bytes());
    buf.extend_from_slice(&BLOCK_ALIGN.to_le_bytes());
    buf.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());
    // "data" sub-chunk header, then the zero-filled samples.
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_size.to_le_bytes());
    buf.resize(HEADER_LEN + data_len, 0);
    buf
}
/// Encoding ids as speech tokens and extracting them back must be lossless,
/// even when the backbone appends its end-of-generation marker.
#[test]
fn e2e_token_pipeline_roundtrip() {
    // Zero, a few interior ids and the largest id used in this file.
    let original_ids = vec![0i32, 128, 512, 1023, 65535];
    let encoded = ids_to_token_str(&original_ids);
    let backbone_output = format!("{encoded}<|SPEECH_GENERATION_END|>");
    assert_eq!(
        extract_ids(&backbone_output),
        original_ids,
        "token round-trip mismatch"
    );
}
/// Round-trips a stride-66 sweep of the id range 0..=65535 through the token
/// string encoding, including the top id itself.
#[test]
fn e2e_token_pipeline_all_fsq_range() {
    // `step_by(66)` stops at 65_472, so the upper boundary is appended
    // explicitly; 65_535 is the largest id the round-trip test above treats
    // as valid.
    let ids: Vec<i32> = (0..=65_535)
        .step_by(66)
        .chain(std::iter::once(65_535))
        .collect();
    let encoded = ids_to_token_str(&ids);
    let decoded = extract_ids(&encoded);
    assert_eq!(
        decoded, ids,
        "sweep over the id range must survive the token round-trip"
    );
}
/// Encodes `audio` as an in-memory 16-bit mono PCM WAV file.
///
/// If any sample's magnitude exceeds 1.0, the whole signal is scaled down so
/// its peak lands at ±1.0; otherwise samples are used unchanged. Each sample is
/// then multiplied by `i16::MAX`, clamped to the `i16` range and truncated
/// toward zero, and written little-endian after a 44-byte RIFF/WAVE header.
fn to_wav_bytes_standalone(audio: &[f32], sample_rate: u32) -> Vec<u8> {
    let peak = audio.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
    let gain = if peak > 1.0 { 1.0 / peak } else { 1.0 };
    let full_scale = i16::MAX as f32;
    let (lo, hi) = (i16::MIN as f32, i16::MAX as f32);

    let payload_len = audio.len() * 2;
    let data_size = payload_len as u32;
    let byte_rate = sample_rate * 2;

    let mut wav = Vec::with_capacity(44 + payload_len);
    {
        let mut put = |bytes: &[u8]| wav.extend_from_slice(bytes);
        // RIFF container header.
        put(b"RIFF");
        put(&(36 + data_size).to_le_bytes());
        put(b"WAVE");
        // "fmt " sub-chunk: PCM, mono, rate, byte rate, block align, 16-bit.
        put(b"fmt ");
        put(&16u32.to_le_bytes());
        put(&1u16.to_le_bytes());
        put(&1u16.to_le_bytes());
        put(&sample_rate.to_le_bytes());
        put(&byte_rate.to_le_bytes());
        put(&2u16.to_le_bytes());
        put(&16u16.to_le_bytes());
        // "data" sub-chunk header; samples follow.
        put(b"data");
        put(&data_size.to_le_bytes());
    }
    wav.extend(audio.iter().flat_map(|&x| {
        let pcm = (x * gain * full_scale).clamp(lo, hi) as i16;
        pcm.to_le_bytes()
    }));
    wav
}
/// A silent clip must produce the four expected chunk ids and a data chunk
/// sized at two bytes per sample.
#[test]
fn e2e_wav_bytes_structure_silent() {
    // 100 ms of silence at 24 kHz: 2400 samples -> 4800 payload bytes.
    let bytes = to_wav_bytes_standalone(&vec![0.0f32; 2400], 24_000);

    let chunk_ids: [(usize, &[u8; 4], &str); 4] = [
        (0, b"RIFF", "RIFF header"),
        (8, b"WAVE", "WAVE chunk"),
        (12, b"fmt ", "fmt sub-chunk"),
        (36, b"data", "data sub-chunk"),
    ];
    for (offset, id, what) in chunk_ids {
        assert_eq!(&bytes[offset..offset + 4], id, "{}", what);
    }

    let data_size = u32::from_le_bytes([bytes[40], bytes[41], bytes[42], bytes[43]]);
    assert_eq!(data_size, 4800);
    assert_eq!(bytes.len(), 44 + 4800);
}
/// Input louder than full scale is normalised so the loudest sample lands
/// exactly at full scale and the rest keep their relative levels.
#[test]
fn e2e_wav_bytes_peak_normalization() {
    // Peak |-3.0| gives gain 1/3: samples become 2/3, -1 and 1/6 of full scale.
    let audio = [2.0f32, -3.0, 0.5];
    let bytes = to_wav_bytes_standalone(&audio, 24_000);
    let sample = |i: usize| i16::from_le_bytes([bytes[44 + 2 * i], bytes[45 + 2 * i]]);
    let (s0, s1, s2) = (sample(0), sample(1), sample(2));

    // Comparing an i16 against i16::MIN is always true, so pin the actual
    // normalised values instead.
    assert_eq!(
        s1,
        -i16::MAX,
        "loudest sample (scaled to -1.0) should map to -i16::MAX, got {s1}"
    );
    assert!(
        (21_844..=21_845).contains(&s0),
        "first sample should be ~2/3 of full scale, got {s0}"
    );
    assert!(
        (5_460..=5_462).contains(&s2),
        "third sample should be ~1/6 of full scale, got {s2}"
    );
}
/// A signal whose peak is exactly 1.0 must not be rescaled: ±1.0 map straight
/// to ±i16::MAX.
#[test]
fn e2e_wav_bytes_unit_signal_not_amplified() {
    let bytes = to_wav_bytes_standalone(&[1.0f32, -1.0, 0.0], 24_000);
    let read_i16 = |offset: usize| i16::from_le_bytes([bytes[offset], bytes[offset + 1]]);

    let first = read_i16(44);
    assert_eq!(first, i16::MAX, "peak 1.0 sample should map to i16::MAX");

    // -1.0 * i16::MAX is -32767, one above i16::MIN.
    let second = read_i16(46);
    assert_eq!(
        second,
        i16::MIN + 1,
        "peak -1.0 sample should map close to i16::MIN"
    );
}
/// Writes 200 codes to an `.npy` file and checks they read back unchanged.
#[test]
fn e2e_npy_save_and_reload() {
    let dir = tmp_dir("npy_save");
    let path = dir.join("ref.npy");
    // Multiplying by 307 scatters the values (up to ~61k) instead of 0..200.
    let codes: Vec<i32> = (0..200).map(|i| i * 307 % 65_536).collect();
    write_npy_i32(&path, &codes).unwrap();
    let loaded = load_npy_i32(&path).unwrap();
    // Best-effort cleanup before asserting, so a failure does not leak the dir.
    let _ = std::fs::remove_dir_all(&dir);
    assert_eq!(loaded, codes, "codes should survive NPY round-trip exactly");
}
/// An empty code array must survive an `.npy` write/read round-trip.
#[test]
fn e2e_npy_empty_codes() {
    let dir = tmp_dir("npy_empty");
    let path = dir.join("empty.npy");
    write_npy_i32(&path, &[]).unwrap();
    let loaded = load_npy_i32(&path).unwrap();
    // Best-effort cleanup before asserting, so a failure does not leak the dir.
    let _ = std::fs::remove_dir_all(&dir);
    assert!(loaded.is_empty(), "an empty code array should reload as empty");
}
/// Exercises the reference-code cache lifecycle: miss, store, hit, evict,
/// re-store and clear.
#[test]
fn e2e_cache_store_reload_evict_clear() {
    let dir = tmp_dir("cache_e2e");
    let cache = RefCodeCache::with_dir(&dir).unwrap();
    // 800 samples at 16 kHz = 50 ms of silence as the reference clip.
    let wav_bytes = make_silence_wav(800, 16_000);
    let wav_path = dir.join("reference.wav");
    std::fs::write(&wav_path, &wav_bytes).unwrap();

    assert!(
        cache.try_load(&wav_path).unwrap().is_none(),
        "fresh cache should miss"
    );

    let codes: Vec<i32> = vec![10, 20, 30, 40, 1023];
    let miss = cache.store(&wav_path, &codes).unwrap();
    assert!(!miss.is_hit(), "store should report a miss");
    assert!(miss.path().exists(), "store should write a cache entry");

    let (loaded, hit) = cache.try_load(&wav_path).unwrap().unwrap();
    assert!(hit.is_hit(), "reload after store should be a hit");
    assert_eq!(loaded, codes, "cached codes should round-trip exactly");

    assert!(cache.evict(&wav_path).unwrap(), "evict should remove the entry");
    assert!(
        cache.try_load(&wav_path).unwrap().is_none(),
        "evicted entry should miss"
    );
    assert!(
        !cache.evict(&wav_path).unwrap(),
        "second evict should find nothing"
    );

    cache.store(&wav_path, &codes).unwrap();
    let n = cache.clear().unwrap();
    assert_eq!(n, 1, "clear should remove exactly the one stored entry");
    assert!(
        cache.try_load(&wav_path).unwrap().is_none(),
        "cleared cache should miss"
    );

    // Best-effort: do not leave test artefacts in the system temp dir.
    let _ = std::fs::remove_dir_all(&dir);
}
/// Realistic prose with dates, ordinals, abbreviations, currency, a percentage
/// and punctuation should come out lowercase and symbol-free.
#[test]
fn e2e_preprocessor_realistic_tts_input() {
    // Includes "12%" so the percent-sign assertion below is actually exercised.
    let input = "On January 1st, 2025, the product sold 1.5K units at $49.99 each — \
                 that's $74,985 total, or roughly 74K dollars, up 12% from last year.";
    let output = TextPreprocessor::new().process(input);
    assert!(
        !output.chars().any(char::is_uppercase),
        "output should be lowercase: {output}"
    );
    assert!(output.contains("first"), "1st → first: {output}");
    assert!(output.contains("thousand"), "1.5K → thousand: {output}");
    assert!(output.contains("dollar"), "$49.99 → dollars: {output}");
    assert!(!output.contains('$'), "currency symbols should be removed: {output}");
    assert!(!output.contains('%'), "percent signs should be removed: {output}");
    assert!(!output.contains('—'), "em-dash should be removed: {output}");
    assert!(!output.contains(','), "commas should be removed: {output}");
    println!("Preprocessed: {output}");
}
/// Code-like text with scientific notation should be lowercased and have its
/// exponent spelled out in words.
#[test]
fn e2e_preprocessor_code_snippet() {
    let output =
        TextPreprocessor::new().process("Training lr=1e-4, batch_size=32, max_steps=10K.");
    let has_uppercase = output.chars().any(char::is_uppercase);
    assert!(!has_uppercase, "should be lowercase: {output}");
    assert!(output.contains("times ten to the"), "1e-4 → …: {output}");
    println!("Code snippet: {output}");
}