use anyhow::Result;
use std::fs;
use std::path::Path;
use super::{Confidence, Signal, SignalBuilder, SignalSource};
use crate::known_tools;
/// Sample rates (in Hz) that `detect` treats as a weak hint of text-to-speech output
/// when the audio is also mono.
const TTS_SAMPLE_RATES: &[u32] = &[22050, 24000, 16000];
/// The subset of a WAV `fmt ` chunk that this module inspects.
///
/// All fields are plain integers, so the type is cheap to copy and can be compared
/// and debug-printed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct WavFmt {
    /// Number of audio channels (little-endian u16 at payload bytes 2..4 of the `fmt ` chunk).
    pub channels: u16,
    /// Sample rate in Hz (little-endian u32 at payload bytes 4..8).
    pub sample_rate: u32,
    /// Bits per sample (little-endian u16 at payload bytes 14..16).
    pub bits_per_sample: u16,
}
/// Result of a full WAV parse: format header, `LIST`/`INFO` metadata and the location of
/// the `data` chunk payload inside the parsed buffer.
// NOTE(review): presumably some fields are only read by callers outside this file,
// hence the lint suppression — confirm and narrow it if possible.
#[allow(dead_code)]
pub(crate) struct WavFile {
/// Values decoded from the `fmt ` chunk.
pub fmt: WavFmt,
/// `(tag, value)` pairs from the `LIST`/`INFO` chunk, in file order; empty values are omitted.
pub info_entries: Vec<(String, String)>,
/// Byte offset in the parsed buffer where the `data` chunk payload begins (0 if no `data` chunk was seen).
pub pcm_start: usize,
/// Byte offset one past the end of the `data` chunk payload, clamped to the buffer length
/// (0 if no `data` chunk was seen).
pub pcm_end: usize,
}
/// Parses `data` as a WAV file and packages everything the parser found into a `WavFile`.
///
/// Yields `None` whenever `parse_wav_inner` rejects the buffer.
pub(crate) fn parse_wav_full(data: &[u8]) -> Option<WavFile> {
    parse_wav_inner(data).map(|(fmt, info_entries, pcm_start, pcm_end)| WavFile {
        fmt,
        info_entries,
        pcm_start,
        pcm_end,
    })
}
/// Parses `data` as a WAV file, keeping only the format header and the `INFO` entries.
///
/// The `data` chunk offsets reported by `parse_wav_inner` are discarded.
fn parse_wav(data: &[u8]) -> Option<(WavFmt, Vec<(String, String)>)> {
    parse_wav_inner(data).map(|(fmt, info, ..)| (fmt, info))
}
/// Walks the RIFF chunk list of a WAV buffer and extracts the pieces this module uses.
///
/// Returns `None` when `data` is shorter than 36 bytes, does not carry the `RIFF`/`WAVE`
/// magic, or contains no `fmt ` chunk with at least 16 payload bytes actually present.
/// Otherwise returns `(fmt, info_entries, data_start, data_end)`:
///
/// * `fmt` — fields decoded from the last usable `fmt ` chunk,
/// * `info_entries` — non-empty `(tag, value)` pairs from every `LIST`/`INFO` chunk,
/// * `data_start..data_end` — byte range of the last `data` chunk payload, clamped to the
///   buffer length (both 0 when no `data` chunk exists).
///
/// Truncated or malformed input never panics: every slice is bounded by the bytes that are
/// really in `data`, and offset arithmetic is checked so an oversized declared chunk size
/// simply ends the walk.
// The tuple is an internal hand-off that the thin wrappers destructure immediately, so a
// dedicated type would add nothing here.
#[allow(clippy::type_complexity)]
fn parse_wav_inner(data: &[u8]) -> Option<(WavFmt, Vec<(String, String)>, usize, usize)> {
    if data.len() < 36 || &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
        return None;
    }
    let mut fmt = None;
    let mut info_entries = Vec::new();
    let mut data_start = 0usize;
    let mut data_end = 0usize;
    // Chunks start right after the 12-byte "RIFF" <size> "WAVE" preamble.
    let mut pos = 12usize;
    while let Some((chunk_id, chunk_size)) = read_chunk_header(data, pos) {
        // Cannot overflow: the header read above proved `pos + 8 <= data.len()`.
        let chunk_data_start = pos + 8;
        // The declared size may exceed what is left in a truncated file; clamp to the buffer.
        let chunk_data_end = chunk_data_start
            .saturating_add(chunk_size)
            .min(data.len());
        let payload = &data[chunk_data_start..chunk_data_end];
        match chunk_id {
            // Guard on the bytes actually present, not on the declared size.
            b"fmt " if payload.len() >= 16 => {
                fmt = Some(WavFmt {
                    channels: u16::from_le_bytes([payload[2], payload[3]]),
                    sample_rate: u32::from_le_bytes([
                        payload[4], payload[5], payload[6], payload[7],
                    ]),
                    bits_per_sample: u16::from_le_bytes([payload[14], payload[15]]),
                });
            }
            b"data" => {
                data_start = chunk_data_start;
                data_end = chunk_data_end;
            }
            b"LIST" if payload.len() >= 4 && &payload[..4] == b"INFO" => {
                parse_info_entries(&payload[4..], &mut info_entries);
            }
            _ => {}
        }
        pos = match padded_chunk_end(chunk_data_start, chunk_size) {
            Some(next) => next,
            None => break,
        };
    }
    fmt.map(|f| (f, info_entries, data_start, data_end))
}

/// Reads the 8-byte chunk header at `pos`: a 4-byte id followed by a little-endian u32 size.
/// Returns `None` if fewer than 8 bytes are available at `pos`.
fn read_chunk_header(data: &[u8], pos: usize) -> Option<(&[u8], usize)> {
    let end = pos.checked_add(8)?;
    let header = data.get(pos..end)?;
    let size = u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;
    Some((&header[..4], size))
}

/// Offset of the chunk that follows a payload of `size` bytes starting at `start`, including
/// the pad byte after odd-sized payloads. `None` if the offset does not fit in `usize`.
fn padded_chunk_end(start: usize, size: usize) -> Option<usize> {
    size.checked_add(size & 1)
        .and_then(|padded| start.checked_add(padded))
}

/// Appends every non-empty `(tag, value)` sub-chunk found in `body` (the bytes following the
/// `INFO` list type) to `out`. Tags that are not valid UTF-8 are recorded as `"????"`; values
/// are decoded lossily and stripped of leading/trailing NUL bytes.
fn parse_info_entries(body: &[u8], out: &mut Vec<(String, String)>) {
    let mut pos = 0usize;
    while let Some((tag, size)) = read_chunk_header(body, pos) {
        let value_start = pos + 8;
        let value_end = value_start.saturating_add(size).min(body.len());
        let raw = String::from_utf8_lossy(&body[value_start..value_end]);
        let value = raw.trim_matches('\0');
        if !value.is_empty() {
            let key = std::str::from_utf8(tag).unwrap_or("????").to_string();
            out.push((key, value.to_string()));
        }
        pos = match padded_chunk_end(value_start, size) {
            Some(next) => next,
            None => break,
        };
    }
}
/// Inspects the WAV file at `path` and reports signals hinting at AI-generated audio.
///
/// Two kinds of signal can be produced:
/// * a medium-confidence `signal_wav_info_tool` for each `INFO` entry under one of the
///   inspected tags whose value `known_tools::match_ai_tool` recognises;
/// * one low-confidence `signal_wav_tts_heuristic` when the audio is mono and its sample
///   rate is listed in `TTS_SAMPLE_RATES`.
///
/// Returns an empty list when the file does not parse as WAV.
///
/// # Errors
/// Fails if the file cannot be read.
pub fn detect(path: &Path) -> Result<Vec<Signal>> {
    // INFO tags whose values are checked against the known-tool list.
    const TOOL_KEYS: [&str; 6] = ["ISFT", "ICMT", "IART", "IENG", "IPRD", "IGNR"];

    let bytes = fs::read(path)?;
    let mut signals = Vec::new();

    if let Some((fmt, info_entries)) = parse_wav(&bytes) {
        for (key, value) in &info_entries {
            if !TOOL_KEYS.contains(&key.as_str()) {
                continue;
            }
            let tool_name = match known_tools::match_ai_tool(value) {
                Some(name) => name,
                None => continue,
            };
            let builder = SignalBuilder::new(
                SignalSource::WavMetadata,
                Confidence::Medium,
                "signal_wav_info_tool",
            );
            signals.push(
                builder
                    .param("key", key.as_str())
                    .param("value", value.as_str())
                    .tool(tool_name)
                    .detail(key.as_str(), value.as_str())
                    .build(),
            );
        }

        // Mono audio at one of the listed rates is only a weak hint, hence low confidence.
        if fmt.channels == 1 && TTS_SAMPLE_RATES.contains(&fmt.sample_rate) {
            let builder = SignalBuilder::new(
                SignalSource::WavMetadata,
                Confidence::Low,
                "signal_wav_tts_heuristic",
            );
            signals.push(
                builder
                    .param("rate", fmt.sample_rate.to_string())
                    .param("bits", fmt.bits_per_sample.to_string())
                    .detail("channels", fmt.channels.to_string())
                    .detail("sample_rate", format!("{}Hz", fmt.sample_rate))
                    .detail("bits_per_sample", fmt.bits_per_sample.to_string())
                    .build(),
            );
        }
    }

    Ok(signals)
}
/// Lists human-readable properties of the WAV file at `path`.
///
/// The result starts with the sample rate, channel count and bit depth, followed by every
/// `INFO` entry in file order. Well-known tags get a descriptive label; any other tag is
/// used as its own label. Returns an empty list when the file does not parse as WAV.
///
/// # Errors
/// Fails if the file cannot be read.
pub fn dump_info(path: &Path) -> Result<Vec<(String, String)>> {
    // Display labels for the INFO tags this module knows about.
    const LABELS: [(&str, &str); 8] = [
        ("ISFT", "Software (ISFT)"),
        ("ICMT", "Comment (ICMT)"),
        ("IART", "Artist (IART)"),
        ("IENG", "Engineer (IENG)"),
        ("IPRD", "Product (IPRD)"),
        ("IGNR", "Genre (IGNR)"),
        ("INAM", "Name (INAM)"),
        ("ICRD", "Date (ICRD)"),
    ];

    let bytes = fs::read(path)?;
    let rows = match parse_wav(&bytes) {
        None => Vec::new(),
        Some((fmt, info_entries)) => {
            let mut rows = vec![
                ("Sample Rate".to_string(), format!("{}Hz", fmt.sample_rate)),
                ("Channels".to_string(), fmt.channels.to_string()),
                (
                    "Bits Per Sample".to_string(),
                    fmt.bits_per_sample.to_string(),
                ),
            ];
            rows.extend(info_entries.into_iter().map(|(tag, text)| {
                let label = LABELS
                    .iter()
                    .find(|(code, _)| *code == tag.as_str())
                    .map(|(_, pretty)| *pretty)
                    .unwrap_or(tag.as_str())
                    .to_string();
                (label, text)
            }));
            rows
        }
    };
    Ok(rows)
}
// Unit tests for the WAV parser and the detection / dump entry points.
// File-based tests write a synthetic WAV to a `tempfile::NamedTempFile`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a synthetic WAV byte buffer: RIFF/WAVE preamble, a 16-byte `fmt ` chunk with
// format tag 1, an optional `LIST`/`INFO` chunk holding `info_chunks` as (tag, value)
// pairs, and a `data` chunk of 100 zero-filled frames.
fn make_wav(
channels: u16,
sample_rate: u32,
bits_per_sample: u16,
info_chunks: &[(&str, &str)],
) -> Vec<u8> {
let byte_rate = sample_rate * channels as u32 * bits_per_sample as u32 / 8;
let block_align = channels * bits_per_sample / 8;
let data_size = 100u32 * block_align as u32;
let mut buf = Vec::new();
let mut info_buf = Vec::new();
if !info_chunks.is_empty() {
info_buf.extend_from_slice(b"INFO");
for &(key, value) in info_chunks {
let val_bytes = value.as_bytes();
// Declared sub-chunk size covers the value, its NUL terminator and any pad byte,
// rounded up to an even length.
let padded_len = ((val_bytes.len() + 1 + 1) & !1) as u32;
info_buf.extend_from_slice(key.as_bytes());
info_buf.extend_from_slice(&padded_len.to_le_bytes());
info_buf.extend_from_slice(val_bytes);
info_buf.push(0);
// Add one more zero byte when value + terminator has odd length.
if (val_bytes.len() + 1) % 2 != 0 {
info_buf.push(0);
}
}
}
let fmt_size = 16u32;
// Total bytes the LIST chunk occupies, including its 8-byte header (0 when absent).
let list_chunk_size = if info_buf.is_empty() {
0
} else {
8 + info_buf.len() as u32
};
// RIFF size: "WAVE" + fmt header/payload + data header/payload + optional LIST chunk.
let riff_size = 4 + 8 + fmt_size + 8 + data_size + list_chunk_size;
buf.extend_from_slice(b"RIFF");
buf.extend_from_slice(&riff_size.to_le_bytes());
buf.extend_from_slice(b"WAVE");
buf.extend_from_slice(b"fmt ");
buf.extend_from_slice(&fmt_size.to_le_bytes());
// fmt payload: format tag, channels, sample rate, byte rate, block align, bits per sample.
buf.extend_from_slice(&1u16.to_le_bytes());
buf.extend_from_slice(&channels.to_le_bytes());
buf.extend_from_slice(&sample_rate.to_le_bytes());
buf.extend_from_slice(&byte_rate.to_le_bytes());
buf.extend_from_slice(&block_align.to_le_bytes());
buf.extend_from_slice(&bits_per_sample.to_le_bytes());
if !info_buf.is_empty() {
buf.extend_from_slice(b"LIST");
buf.extend_from_slice(&(info_buf.len() as u32).to_le_bytes());
buf.extend_from_slice(&info_buf);
}
buf.extend_from_slice(b"data");
buf.extend_from_slice(&data_size.to_le_bytes());
buf.extend_from_slice(&vec![0u8; data_size as usize]);
buf
}
// A WAV without INFO metadata yields the fmt fields and no entries.
#[test]
fn test_parse_wav_basic() {
let wav = make_wav(1, 24000, 16, &[]);
let (fmt, info) = parse_wav(&wav).unwrap();
assert_eq!(fmt.channels, 1);
assert_eq!(fmt.sample_rate, 24000);
assert_eq!(fmt.bits_per_sample, 16);
assert!(info.is_empty());
}
// A single INFO entry is returned with its tag and NUL-trimmed value.
#[test]
fn test_parse_wav_with_info() {
let wav = make_wav(2, 44100, 16, &[("ISFT", "Test Software")]);
let (fmt, info) = parse_wav(&wav).unwrap();
assert_eq!(fmt.channels, 2);
assert_eq!(fmt.sample_rate, 44100);
assert_eq!(info.len(), 1);
assert_eq!(info[0].0, "ISFT");
assert_eq!(info[0].1, "Test Software");
}
// Mono 24 kHz audio triggers exactly one low-confidence heuristic signal.
#[test]
fn test_tts_heuristic_mono_24k() {
let wav = make_wav(1, 24000, 16, &[]);
let tmp = tempfile::NamedTempFile::new().unwrap();
fs::write(tmp.path(), &wav).unwrap();
let signals = detect(tmp.path()).unwrap();
assert_eq!(signals.len(), 1);
assert_eq!(signals[0].confidence, Confidence::Low);
}
// Stereo 44.1 kHz audio without metadata produces no signals.
#[test]
fn test_no_tts_heuristic_stereo_44k() {
let wav = make_wav(2, 44100, 16, &[]);
let tmp = tempfile::NamedTempFile::new().unwrap();
fs::write(tmp.path(), &wav).unwrap();
let signals = detect(tmp.path()).unwrap();
assert!(signals.is_empty());
}
// An ISFT value naming a known tool yields a medium-confidence signal carrying the tool id.
#[test]
fn test_info_tool_match() {
let wav = make_wav(2, 44100, 16, &[("ISFT", "ElevenLabs TTS")]);
let tmp = tempfile::NamedTempFile::new().unwrap();
fs::write(tmp.path(), &wav).unwrap();
let signals = detect(tmp.path()).unwrap();
assert_eq!(signals.len(), 1);
assert_eq!(signals[0].tool, Some("elevenlabs".to_string()));
assert_eq!(signals[0].confidence, Confidence::Medium);
}
// Input that is not a RIFF/WAVE buffer is rejected.
#[test]
fn test_not_wav() {
let data = b"not a wav file at all";
assert!(parse_wav(data).is_none());
}
// dump_info reports the format properties and labels INFO entries with their tag.
#[test]
fn test_dump_info() {
let wav = make_wav(1, 24000, 16, &[("ISFT", "MyTool"), ("ICMT", "A comment")]);
let tmp = tempfile::NamedTempFile::new().unwrap();
fs::write(tmp.path(), &wav).unwrap();
let props = dump_info(tmp.path()).unwrap();
assert!(props
.iter()
.any(|(k, v)| k == "Sample Rate" && v == "24000Hz"));
assert!(props
.iter()
.any(|(k, v)| k.contains("ISFT") && v == "MyTool"));
assert!(props
.iter()
.any(|(k, v)| k.contains("ICMT") && v == "A comment"));
}
}