use std::path::Path;
use crate::device::DeviceType;
use crate::error::{MlError, MlResult};
use crate::model::OnnxModel;
/// Configuration for [`AutoCaptionPipeline`]: log-mel front-end parameters,
/// decoding limits, special token ids, and optional ONNX tensor-name overrides.
#[derive(Clone, Debug)]
pub struct AutoCaptionConfig {
    /// Sample rate forwarded to `compute_log_mel_spectrogram`.
    /// NOTE(review): presumably the rate (Hz) of the samples given to
    /// `encode_audio` — confirm with callers.
    pub sample_rate: u32,
    /// Number of mel bands; also the second dimension of the encoder input
    /// tensor `[1, n_mels, n_frames]`.
    pub n_mels: usize,
    /// Forwarded as `n_fft` to `compute_log_mel_spectrogram`.
    pub n_fft: usize,
    /// Forwarded as `hop_length` to `compute_log_mel_spectrogram`.
    pub hop_length: usize,
    /// Upper bound on the number of decoder steps, and therefore on the number
    /// of tokens `caption` can return.
    pub max_decode_steps: usize,
    /// Number of leading logits considered when picking the next token.
    /// Decoder outputs shorter than this are rejected by `caption`.
    pub vocab_size: usize,
    /// Token fed to the decoder on the first step of `caption`.
    pub bos_token: u32,
    /// Token that stops `caption`; when produced, it is included in the
    /// returned token list.
    pub eos_token: u32,
    /// Encoder input tensor name. When `None`, the model's first declared
    /// input is used, falling back to `"input"`.
    pub encoder_input_name: Option<String>,
    /// Encoder output tensor name. When `None`, the model's first declared
    /// output is used, falling back to `"output"`.
    pub encoder_output_name: Option<String>,
    /// Decoder input that receives the token id; `"token"` when `None`.
    pub decoder_token_input_name: Option<String>,
    /// Decoder input that receives the encoder output; `"encoder_output"`
    /// when `None`.
    pub decoder_state_input_name: Option<String>,
    /// Decoder logits output name. When `None`, the model's first declared
    /// output is used, falling back to `"logits"`.
    pub decoder_logits_output_name: Option<String>,
}
impl Default for AutoCaptionConfig {
fn default() -> Self {
Self {
sample_rate: 16_000,
n_mels: 80,
n_fft: 400,
hop_length: 160,
max_decode_steps: 448,
vocab_size: 51_865,
bos_token: 50_258,
eos_token: 50_257,
encoder_input_name: None,
encoder_output_name: None,
decoder_token_input_name: None,
decoder_state_input_name: None,
decoder_logits_output_name: None,
}
}
}
/// Audio-to-token captioning pipeline built from a separate encoder model and
/// decoder model, both loaded as [`OnnxModel`]s.
pub struct AutoCaptionPipeline {
    /// Model run on the log-mel spectrogram in `encode_audio`.
    encoder: OnnxModel,
    /// Model run once per decoding step in `step_decode`.
    // Only the `onnx`-gated `step_decode` reads this field, so the attribute
    // silences the dead-code warning when that feature is disabled.
    #[cfg_attr(not(feature = "onnx"), allow(dead_code))]
    decoder: OnnxModel,
    /// Front-end, decoding and tensor-name settings used by every method.
    config: AutoCaptionConfig,
}
impl std::fmt::Debug for AutoCaptionPipeline {
    /// Prints a short summary of the configuration. The loaded models are not
    /// included in the output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut summary = formatter.debug_struct("AutoCaptionPipeline");
        summary.field("n_mels", &cfg.n_mels);
        summary.field("vocab_size", &cfg.vocab_size);
        summary.field("max_decode_steps", &cfg.max_decode_steps);
        summary.finish()
    }
}
impl AutoCaptionPipeline {
    /// Loads the encoder and decoder ONNX models on the CPU.
    ///
    /// Equivalent to [`AutoCaptionPipeline::new_with_device`] with
    /// [`DeviceType::Cpu`] for both models.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`OnnxModel::load`] for either
    /// path.
    pub fn new(
        encoder_path: &Path,
        decoder_path: &Path,
        config: AutoCaptionConfig,
    ) -> MlResult<Self> {
        Self::new_with_device(
            encoder_path,
            decoder_path,
            config,
            DeviceType::Cpu,
            DeviceType::Cpu,
        )
    }

    /// Loads the encoder and decoder ONNX models on the given devices.
    ///
    /// The encoder is loaded first. If loading it fails, the decoder is not
    /// attempted.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`OnnxModel::load`].
    pub fn new_with_device(
        encoder_path: &Path,
        decoder_path: &Path,
        config: AutoCaptionConfig,
        encoder_device: DeviceType,
        decoder_device: DeviceType,
    ) -> MlResult<Self> {
        let encoder = OnnxModel::load(encoder_path, encoder_device)?;
        let decoder = OnnxModel::load(decoder_path, decoder_device)?;
        Ok(Self {
            encoder,
            decoder,
            config,
        })
    }

    /// Runs the encoder on raw audio samples and returns its output as a flat
    /// buffer.
    ///
    /// The samples are converted into a log-mel spectrogram using the
    /// configured `sample_rate`, `n_mels`, `n_fft` and `hop_length`. The
    /// spectrogram is fed to the encoder as a `[1, n_mels, n_frames]` tensor.
    ///
    /// NOTE(review): that shape assumes `compute_log_mel_spectrogram` lays its
    /// output out mel-major (all frames of band 0, then band 1, ...) — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following occurs:
    /// - `samples` is empty, or `config.n_mels` is zero.
    /// - The spectrogram is empty, or its length is not a multiple of `n_mels`.
    /// - The encoder run fails.
    /// - The selected encoder output is missing from the run results.
    pub fn encode_audio(&self, samples: &[f32]) -> MlResult<Vec<f32>> {
        if samples.is_empty() {
            return Err(MlError::invalid_input(
                "encode_audio: samples must not be empty",
            ));
        }
        // `n_mels` is the divisor for the frame count below. Reject zero here
        // instead of panicking on an integer division by zero.
        if self.config.n_mels == 0 {
            return Err(MlError::invalid_input(
                "encode_audio: config.n_mels must be greater than zero",
            ));
        }
        let log_mel = oximedia_audio::spectrum::fft::compute_log_mel_spectrogram(
            samples,
            self.config.sample_rate,
            self.config.n_mels,
            self.config.n_fft,
            self.config.hop_length,
        );
        if log_mel.is_empty() {
            return Err(MlError::pipeline(
                "auto-caption-encode",
                "compute_log_mel_spectrogram returned an empty tensor",
            ));
        }
        // A length that isn't a whole number of frames would be silently
        // truncated by the division, and the data would no longer match the
        // declared tensor shape.
        if log_mel.len() % self.config.n_mels != 0 {
            return Err(MlError::pipeline(
                "auto-caption-encode",
                "compute_log_mel_spectrogram returned a tensor whose length is not a multiple of n_mels",
            ));
        }
        let n_frames = log_mel.len() / self.config.n_mels;
        // Resolve the input name: explicit override, then the model's first
        // declared input, then a generic fallback.
        let input_name = self
            .config
            .encoder_input_name
            .clone()
            .or_else(|| self.encoder.info().inputs.first().map(|s| s.name.clone()))
            .unwrap_or_else(|| "input".to_string());
        let shape = vec![1, self.config.n_mels, n_frames];
        let outputs = self
            .encoder
            .run_single(input_name.as_str(), log_mel, shape)?;
        // Same resolution order for the output name.
        let output_name = self
            .config
            .encoder_output_name
            .clone()
            .or_else(|| self.encoder.info().outputs.first().map(|s| s.name.clone()))
            .unwrap_or_else(|| "output".to_string());
        outputs.get(&output_name).cloned().ok_or_else(|| {
            MlError::postprocess(format!(
                "encode_audio: encoder output '{output_name}' not found in model run results"
            ))
        })
    }

    /// Runs one decoder step and returns the decoder's logits output.
    ///
    /// Only `token_id` and `encoder_output` are passed to the decoder. No
    /// previously generated tokens and no cached decoder state are supplied.
    ///
    /// # Errors
    ///
    /// Returns an error if the decoder run fails, or if the selected logits
    /// output is missing from the run results.
    #[cfg(feature = "onnx")]
    pub fn step_decode(&self, token_id: u32, encoder_output: &[f32]) -> MlResult<Vec<f32>> {
        use oxionnx::Tensor;
        use std::collections::HashMap;
        let token_input_name = self
            .config
            .decoder_token_input_name
            .as_deref()
            .unwrap_or("token");
        let state_input_name = self
            .config
            .decoder_state_input_name
            .as_deref()
            .unwrap_or("encoder_output");
        // The token id travels as a single f32 in a `[1, 1]` tensor. An f32
        // represents integers exactly only up to 2^24.
        let token_tensor = Tensor {
            data: vec![token_id as f32],
            shape: vec![1, 1],
        };
        // `encode_audio` returns a flat buffer, so the encoder output is passed
        // as `[1, enc_len]`. NOTE(review): confirm the decoder expects this rank.
        let enc_len = encoder_output.len();
        let state_tensor = Tensor {
            data: encoder_output.to_vec(),
            shape: vec![1, enc_len],
        };
        let mut inputs: HashMap<&str, Tensor> = HashMap::with_capacity(2);
        inputs.insert(token_input_name, token_tensor);
        inputs.insert(state_input_name, state_tensor);
        let outputs = self.decoder.run(&inputs)?;
        let logits_name = self
            .config
            .decoder_logits_output_name
            .clone()
            .or_else(|| self.decoder.info().outputs.first().map(|s| s.name.clone()))
            .unwrap_or_else(|| "logits".to_string());
        outputs
            .get(&logits_name)
            .map(|t| t.data.clone())
            .ok_or_else(|| {
                MlError::postprocess(format!(
                    "step_decode: decoder output '{logits_name}' not found in model run results"
                ))
            })
    }

    /// Stub used when the `onnx` feature is disabled.
    ///
    /// # Errors
    ///
    /// Always returns `MlError::FeatureDisabled("onnx")`.
    #[cfg(not(feature = "onnx"))]
    pub fn step_decode(&self, _token_id: u32, _encoder_output: &[f32]) -> MlResult<Vec<f32>> {
        Err(MlError::FeatureDisabled("onnx"))
    }

    /// Greedily decodes a token sequence for the given audio samples.
    ///
    /// Decoding starts from `config.bos_token`. At each step, the argmax over
    /// the first `vocab_size` logits becomes the next token. Decoding stops
    /// once `config.eos_token` is produced or after `max_decode_steps` steps.
    /// The BOS token is not included in the result. The EOS token, when
    /// produced, is included as the last element.
    ///
    /// # Errors
    ///
    /// Propagates errors from `encode_audio` and `step_decode`. Also returns
    /// an error if the logits are shorter than `vocab_size`, if the argmax
    /// fails, or if a token index does not fit in a `u32`.
    pub fn caption(&self, samples: &[f32]) -> MlResult<Vec<u32>> {
        let encoder_output = self.encode_audio(samples)?;
        let mut tokens: Vec<u32> = Vec::with_capacity(self.config.max_decode_steps);
        let mut last_token = self.config.bos_token;
        for _step in 0..self.config.max_decode_steps {
            let logits = self.step_decode(last_token, &encoder_output)?;
            if logits.len() < self.config.vocab_size {
                return Err(MlError::postprocess(format!(
                    "caption: logits length {} is less than vocab_size {}",
                    logits.len(),
                    self.config.vocab_size,
                )));
            }
            // Any extra trailing logits beyond the vocabulary are ignored.
            let vocab_logits = &logits[..self.config.vocab_size];
            let next_token = crate::postprocess::argmax(vocab_logits)
                .map_err(|e| MlError::postprocess(format!("caption: argmax failed: {e:?}")))?;
            let next_token_u32 = u32::try_from(next_token).map_err(|_| {
                MlError::postprocess(format!(
                    "caption: token index {next_token} exceeds u32::MAX"
                ))
            })?;
            tokens.push(next_token_u32);
            if next_token_u32 == self.config.eos_token {
                break;
            }
            last_token = next_token_u32;
        }
        Ok(tokens)
    }

    /// Returns the configuration this pipeline was built with.
    #[must_use]
    pub fn config(&self) -> &AutoCaptionConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every numeric default so accidental edits are caught.
    #[test]
    fn default_config_has_whisper_values() {
        let AutoCaptionConfig {
            sample_rate,
            n_mels,
            n_fft,
            hop_length,
            max_decode_steps,
            vocab_size,
            bos_token,
            eos_token,
            ..
        } = AutoCaptionConfig::default();
        assert_eq!(sample_rate, 16_000, "sample_rate");
        assert_eq!(n_mels, 80, "n_mels");
        assert_eq!(n_fft, 400, "n_fft");
        assert_eq!(hop_length, 160, "hop_length");
        assert_eq!(max_decode_steps, 448, "max_decode_steps");
        assert_eq!(vocab_size, 51_865, "vocab_size");
        assert_eq!(bos_token, 50_258, "bos_token");
        assert_eq!(eos_token, 50_257, "eos_token");
    }

    /// Constructing the pipeline from paths that do not exist must fail
    /// instead of panicking.
    #[test]
    fn new_with_missing_paths_returns_err() {
        let scratch = std::env::temp_dir();
        let [enc, dec] = [
            "oximedia-ml-autocaption-nonexistent-enc.onnx",
            "oximedia-ml-autocaption-nonexistent-dec.onnx",
        ]
        .map(|file_name| scratch.join(file_name));
        // Best-effort cleanup in case a previous run left files behind.
        for stale in [&enc, &dec] {
            let _ = std::fs::remove_file(stale);
        }
        let outcome = AutoCaptionPipeline::new(&enc, &dec, AutoCaptionConfig::default());
        assert!(
            outcome.is_err(),
            "expected Err for missing ONNX paths, got Ok"
        );
    }

    /// Smoke test for the derived `Clone` and `Debug` impls.
    #[test]
    fn config_clone_and_debug_do_not_panic() {
        let original = AutoCaptionConfig::default();
        let duplicate = original.clone();
        assert_eq!(original.sample_rate, duplicate.sample_rate);
        let _rendered = format!("{original:?}");
    }
}