use alloc::format;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use std::path::Path;
use tract_onnx::prelude::*;
use crate::{AfpError, AudioBuffer, Result};
/// Configuration for [`WatermarkDetector`].
#[derive(Clone, Debug)]
pub struct WatermarkConfig {
    /// Filesystem path to the ONNX detector model.
    pub model_path: String,
    /// Number of payload bits to decode from the model's message output (must be ≤ 32).
    pub message_bits: u8,
    /// Detection threshold in [0, 1]; a mean detection score strictly above
    /// this value counts as "watermark detected".
    pub threshold: f32,
    /// Expected input sample rate in Hz; audio at any other rate is rejected.
    pub sample_rate: u32,
}
impl WatermarkConfig {
    /// Builds a configuration for the model at `model_path` using the
    /// defaults (16 message bits, 0.5 threshold, 16 kHz sample rate —
    /// presumably matching the AudioSeal convention; see tests).
    #[must_use]
    pub fn new(model_path: impl Into<String>) -> Self {
        let model_path = model_path.into();
        Self {
            model_path,
            message_bits: 16,
            threshold: 0.5,
            sample_rate: 16_000,
        }
    }
}
/// Outcome of a single watermark detection pass.
#[derive(Clone, Debug)]
pub struct WatermarkResult {
    /// True when `confidence` is strictly greater than the configured threshold.
    pub detected: bool,
    /// Mean of the model's detection scores; 0.0 when the model emitted none.
    pub confidence: f32,
    /// Decoded message with bit `i` taken from the i-th message logit
    /// (set when the logit is ≥ 0); 0 when the model emitted fewer logits
    /// than the configured `message_bits`.
    pub message: u32,
    /// Raw detection scores, flattened from the model's first output.
    pub localization: Vec<f32>,
}
/// ONNX-backed audio watermark detector.
///
/// Holds the validated configuration and the parsed (but not yet typed or
/// optimized) inference model; a runnable plan is built inside `detect`
/// because the concrete input length is only known per call.
pub struct WatermarkDetector {
    // Configuration validated by `WatermarkDetector::new`.
    cfg: WatermarkConfig,
    // Parsed ONNX graph; cloned and specialized for each input length.
    model: InferenceModel,
}
impl WatermarkDetector {
pub fn new(cfg: WatermarkConfig) -> Result<Self> {
if cfg.message_bits > 32 {
return Err(AfpError::Config(format!(
"message_bits ({}) > 32",
cfg.message_bits,
)));
}
if !(0.0..=1.0).contains(&cfg.threshold) {
return Err(AfpError::Config(format!(
"threshold {} not in [0, 1]",
cfg.threshold,
)));
}
if cfg.sample_rate == 0 {
return Err(AfpError::Config("sample_rate must be > 0".to_string()));
}
if cfg.model_path.is_empty() {
return Err(AfpError::ModelNotFound(String::new()));
}
let path = Path::new(&cfg.model_path);
if !path.exists() {
return Err(AfpError::ModelNotFound(cfg.model_path.clone()));
}
let model = tract_onnx::onnx()
.model_for_path(path)
.map_err(|e| AfpError::ModelLoad(format!("load: {e}")))?;
Ok(Self { cfg, model })
}
#[must_use]
pub fn config(&self) -> &WatermarkConfig {
&self.cfg
}
pub fn detect(&mut self, audio: AudioBuffer<'_>) -> Result<WatermarkResult> {
if audio.rate.hz() != self.cfg.sample_rate {
return Err(AfpError::UnsupportedSampleRate(audio.rate.hz()));
}
let n = audio.samples.len();
if n == 0 {
return Err(AfpError::AudioTooShort { needed: 1, got: 0 });
}
let input_tensor = Tensor::from_shape(&[1, 1, n], audio.samples)
.map_err(|e| AfpError::Inference(format!("input shape: {e}")))?;
let runnable = self
.model
.clone()
.with_input_fact(
0,
InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 1, n)),
)
.map_err(|e| AfpError::Inference(format!("input fact: {e}")))?
.into_typed()
.map_err(|e| AfpError::Inference(format!("type: {e}")))?
.into_runnable()
.map_err(|e| AfpError::Inference(format!("runnable: {e}")))?;
let outputs = runnable
.run(tvec!(input_tensor.into()))
.map_err(|e| AfpError::Inference(format!("run: {e}")))?;
if outputs.len() < 2 {
return Err(AfpError::Inference(format!(
"expected ≥ 2 outputs (detection, message), got {}",
outputs.len(),
)));
}
let detection = outputs[0]
.to_array_view::<f32>()
.map_err(|e| AfpError::Inference(format!("detection view: {e}")))?;
let localization: Vec<f32> = detection.iter().copied().collect();
let confidence = if localization.is_empty() {
0.0
} else {
localization.iter().sum::<f32>() / localization.len() as f32
};
let detected = confidence > self.cfg.threshold;
let message_view = outputs[1]
.to_array_view::<f32>()
.map_err(|e| AfpError::Inference(format!("message view: {e}")))?;
let bits = self.cfg.message_bits.min(32) as usize;
let mut message: u32 = 0;
if message_view.len() >= bits {
for (i, &logit) in message_view.iter().take(bits).enumerate() {
if logit >= 0.0 {
message |= 1u32 << i;
}
}
}
Ok(WatermarkResult {
detected,
confidence,
message,
localization,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a temp-file path unique per process, stem, and call.
    fn unique_path(stem: &str) -> std::path::PathBuf {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let n = COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("audiofp-watermark-test-{pid}-{stem}-{n}.bin"))
    }

    #[test]
    fn empty_model_path_returns_model_not_found() {
        match WatermarkDetector::new(WatermarkConfig::new("")) {
            Err(AfpError::ModelNotFound(_)) => (),
            Err(other) => panic!("expected ModelNotFound, got {other:?}"),
            Ok(_) => panic!("expected ModelNotFound, got Ok"),
        }
    }

    #[test]
    fn missing_model_returns_model_not_found() {
        let cfg = WatermarkConfig::new("/nonexistent/path/to/audioseal.onnx");
        match WatermarkDetector::new(cfg) {
            Err(AfpError::ModelNotFound(_)) => (),
            Err(other) => panic!("expected ModelNotFound, got {other:?}"),
            Ok(_) => panic!("expected ModelNotFound, got Ok"),
        }
    }

    #[test]
    fn message_bits_above_32_is_rejected() {
        let cfg = WatermarkConfig {
            message_bits: 33,
            ..WatermarkConfig::new("/tmp/dummy.onnx")
        };
        match WatermarkDetector::new(cfg) {
            Err(AfpError::Config(_)) => (),
            Err(other) => panic!("expected Config error, got {other:?}"),
            Ok(_) => panic!("expected Config error, got Ok"),
        }
    }

    #[test]
    fn threshold_outside_unit_interval_is_rejected() {
        for bad in [-0.5_f32, 1.1, -1.0] {
            let cfg = WatermarkConfig {
                threshold: bad,
                ..WatermarkConfig::new("/tmp/dummy.onnx")
            };
            match WatermarkDetector::new(cfg) {
                Err(AfpError::Config(_)) => (),
                Err(other) => panic!("expected Config for threshold={bad}, got {other:?}"),
                Ok(_) => panic!("expected Config for threshold={bad}, got Ok"),
            }
        }
    }

    #[test]
    fn zero_sample_rate_is_rejected() {
        let cfg = WatermarkConfig {
            sample_rate: 0,
            ..WatermarkConfig::new("/tmp/dummy.onnx")
        };
        match WatermarkDetector::new(cfg) {
            Err(AfpError::Config(_)) => (),
            Err(other) => panic!("expected Config error, got {other:?}"),
            Ok(_) => panic!("expected Config error, got Ok"),
        }
    }

    #[test]
    fn corrupt_onnx_returns_model_load_error() {
        let path = unique_path("corrupt");
        // 64 bytes of garbage that is definitely not a valid ONNX protobuf.
        std::fs::write(&path, [0xAA_u8; 64]).unwrap();
        let res = WatermarkDetector::new(WatermarkConfig::new(path.to_string_lossy().into_owned()));
        std::fs::remove_file(&path).ok();
        match res {
            Err(AfpError::ModelLoad(_)) => (),
            Err(other) => panic!("expected ModelLoad, got {other:?}"),
            Ok(_) => panic!("expected ModelLoad, got Ok"),
        }
    }

    #[test]
    fn config_constructor_uses_audioseal_defaults() {
        let cfg = WatermarkConfig::new("model.onnx");
        assert_eq!(cfg.message_bits, 16);
        assert_eq!(cfg.threshold, 0.5);
        assert_eq!(cfg.sample_rate, 16_000);
    }
}