use std::path::Path;
use std::sync::Arc;
use oximedia_ml::{
argmax, softmax, top_k, DeviceType, MlError, ModelCache, OnnxModel, PipelineInfo, PipelineTask,
};
use crate::{CaptionGenError, CaptionGenResult};
/// Input tensor name used by `CaptionEncoder` when the loaded model declares no inputs.
const FALLBACK_INPUT_NAME: &str = "input";
/// Result of one encoder forward pass: a flat buffer of values plus its shape.
#[derive(Clone, Debug, PartialEq)]
pub struct EncoderOutput {
    /// Output values, stored flat.
    pub logits: Vec<f32>,
    /// Dimensions describing `logits`. When produced by `CaptionEncoder::encode` this is
    /// either the model's declared output shape or the flat shape `[logits.len()]`.
    pub shape: Vec<usize>,
}

impl EncoderOutput {
    /// Number of values in the flat `logits` buffer (independent of `shape`).
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { logits, .. } = self;
        logits.len()
    }

    /// `true` when the `logits` buffer holds no values.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Caption encoder backed by an ONNX model.
///
/// Feeds one `f32` tensor to the model input named [`input_name`](Self::input_name) and
/// returns the values of the output named [`output_name`](Self::output_name). The model
/// is held in an [`Arc`] so it can be shared (see [`CaptionEncoder::from_shared_model`]
/// and [`CaptionEncoder::shared_model`]).
pub struct CaptionEncoder {
    /// Loaded model handle.
    model: Arc<OnnxModel>,
    /// Model input that [`CaptionEncoder::encode`] feeds.
    input_name: String,
    /// Model output that [`CaptionEncoder::encode`] reads back.
    output_name: String,
}

impl CaptionEncoder {
    /// Loads the ONNX model at `model_path` on `device` and wraps it in an encoder.
    ///
    /// Tensor names default to the model's first declared input and output (`"input"`
    /// or an empty name respectively when none are declared); override them with
    /// [`with_input_name`](Self::with_input_name) /
    /// [`with_output_name`](Self::with_output_name).
    ///
    /// # Errors
    ///
    /// Propagates the load failure as a [`CaptionGenError`] (for example
    /// `CaptionGenError::Ml` when the file does not exist).
    pub fn from_path(model_path: impl AsRef<Path>, device: DeviceType) -> CaptionGenResult<Self> {
        let path = model_path.as_ref();
        let model = Arc::new(OnnxModel::load(path, device)?);
        Ok(Self::build(model))
    }

    /// Wraps an already-loaded (possibly shared) model.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn from_shared_model(model: Arc<OnnxModel>) -> CaptionGenResult<Self> {
        Ok(Self::build(model))
    }

    /// Obtains the model for `model_path` / `device` through `cache`
    /// (`ModelCache::get_or_load`) and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cache lookup or load.
    pub fn from_cache(
        cache: &ModelCache,
        model_path: impl AsRef<Path>,
        device: DeviceType,
    ) -> CaptionGenResult<Self> {
        let model = cache.get_or_load(model_path.as_ref(), device)?;
        Self::from_shared_model(model)
    }

    /// Builds an encoder whose tensor names come from the model metadata: the first
    /// declared input (falling back to `FALLBACK_INPUT_NAME`) and the first declared
    /// output (falling back to an empty name).
    fn build(model: Arc<OnnxModel>) -> Self {
        let info = model.info();
        let input_name = info
            .inputs
            .first()
            .map(|spec| spec.name.clone())
            .unwrap_or_else(|| FALLBACK_INPUT_NAME.to_string());
        let output_name = info
            .outputs
            .first()
            .map(|spec| spec.name.clone())
            .unwrap_or_default();
        Self {
            model,
            input_name,
            output_name,
        }
    }

    /// Overrides the model input name used by [`encode`](Self::encode).
    #[must_use]
    pub fn with_input_name(mut self, name: impl Into<String>) -> Self {
        self.input_name = name.into();
        self
    }

    /// Overrides the model output name read by [`encode`](Self::encode).
    #[must_use]
    pub fn with_output_name(mut self, name: impl Into<String>) -> Self {
        self.output_name = name.into();
        self
    }

    /// Name of the model input this encoder feeds.
    #[must_use]
    pub fn input_name(&self) -> &str {
        &self.input_name
    }

    /// Name of the model output this encoder reads.
    #[must_use]
    pub fn output_name(&self) -> &str {
        &self.output_name
    }

    /// Returns another handle to the underlying model (reference-count bump, no copy).
    #[must_use]
    pub fn shared_model(&self) -> Arc<OnnxModel> {
        self.model.clone()
    }

    /// Static pipeline metadata describing this encoder.
    #[must_use]
    pub fn info(&self) -> PipelineInfo {
        PipelineInfo {
            id: "caption-gen/custom-encoder",
            name: "Caption Encoder",
            task: PipelineTask::Custom,
            input_size: None,
        }
    }

    /// Runs the model on one input tensor and returns the configured output.
    ///
    /// `tensor` and `shape` are copied into owned buffers and passed to the model under
    /// [`input_name`](Self::input_name). The returned [`EncoderOutput::shape`] is the
    /// model's declared shape for the output when every dimension is a positive constant
    /// and their product equals the number of returned values; otherwise it is the flat
    /// shape `[logits.len()]`.
    ///
    /// # Errors
    ///
    /// Propagates model execution errors, and returns `CaptionGenError::Ml` when the
    /// configured output name is absent from the model's results.
    pub fn encode(&self, tensor: &[f32], shape: &[usize]) -> CaptionGenResult<EncoderOutput> {
        let shape_vec = shape.to_vec();
        let data = tensor.to_vec();
        let mut outputs = self.model.run_single(&self.input_name, data, shape_vec)?;
        let logits = outputs.remove(&self.output_name).ok_or_else(|| {
            CaptionGenError::Ml(MlError::pipeline(
                "caption-gen",
                format!("output '{}' missing from model", self.output_name),
            ))
        })?;
        let out_shape = self
            .declared_output_shape(logits.len())
            .unwrap_or_else(|| vec![logits.len()]);
        Ok(EncoderOutput {
            logits,
            shape: out_shape,
        })
    }

    /// Returns the declared shape of `self.output_name` if it is fully static (every
    /// dimension a positive constant that fits in `usize`) and its element count equals
    /// `len`; `None` otherwise.
    fn declared_output_shape(&self, len: usize) -> Option<Vec<usize>> {
        let info = self.model.info();
        let spec = info
            .outputs
            .iter()
            .find(|spec| spec.name == self.output_name)?;
        // Any dynamic (None) or non-positive dimension disqualifies the declared shape.
        // `try_from` avoids silently truncating on targets with a narrow `usize`.
        let dims = spec
            .shape
            .iter()
            .map(|dim| match dim {
                Some(v) if *v > 0 => usize::try_from(*v).ok(),
                _ => None,
            })
            .collect::<Option<Vec<usize>>>()?;
        // Checked product: the dims come from model metadata, and a plain `product()`
        // would panic on overflow in debug builds (and wrap in release builds).
        let elements = dims.iter().try_fold(1_usize, |acc, &d| acc.checked_mul(d))?;
        if elements == len {
            Some(dims)
        } else {
            None
        }
    }
}

impl std::fmt::Debug for CaptionEncoder {
    // The model handle is not printed; `finish_non_exhaustive` marks the omission with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CaptionEncoder")
            .field("input_name", &self.input_name)
            .field("output_name", &self.output_name)
            .finish_non_exhaustive()
    }
}
/// Greedy decoding: for each of the `seq_len` steps, emits the index of the largest
/// score in that step's row.
///
/// `logits` is read as `seq_len` contiguous rows of `vocab_size` scores each.
///
/// # Errors
///
/// Returns [`CaptionGenError::InvalidParameter`] if `vocab_size` or `seq_len` is zero,
/// if `seq_len * vocab_size` overflows `usize` or differs from `logits.len()`, if
/// `argmax` fails on a row, or if a chosen index does not fit in `u32`.
pub fn greedy_decode(
    logits: &[f32],
    vocab_size: usize,
    seq_len: usize,
) -> CaptionGenResult<Vec<u32>> {
    // Reject degenerate dimensions first (vocab_size takes precedence when both are 0).
    match (vocab_size, seq_len) {
        (0, _) => {
            return Err(CaptionGenError::InvalidParameter(
                "vocab_size must be > 0".into(),
            ))
        }
        (_, 0) => {
            return Err(CaptionGenError::InvalidParameter(
                "seq_len must be > 0".into(),
            ))
        }
        _ => {}
    }

    let total = match seq_len.checked_mul(vocab_size) {
        Some(n) => n,
        None => {
            return Err(CaptionGenError::InvalidParameter(
                "seq_len * vocab_size overflows usize".into(),
            ))
        }
    };

    if logits.len() != total {
        let got = logits.len();
        return Err(CaptionGenError::InvalidParameter(format!(
            "logits len {got} does not match seq_len ({seq_len}) * vocab_size ({vocab_size}) = {total}"
        )));
    }

    // The length check guarantees exactly `seq_len` full rows.
    logits
        .chunks_exact(vocab_size)
        .enumerate()
        .map(|(step, row)| {
            let best = argmax(row).map_err(|e| {
                CaptionGenError::InvalidParameter(format!(
                    "greedy_decode: argmax failed on step {step}: {e:?}"
                ))
            })?;
            u32_from_usize(best)
        })
        .collect()
}
/// Top-k sampling: at each step, draws one token from the softmax over that row's `k`
/// highest logits, using a deterministic uniform draw derived from `seed` and the step.
///
/// `logits` is read as `seq_len` contiguous rows of `vocab_size` scores. `k == 0`
/// delegates to [`greedy_decode`]; `k > vocab_size` is clamped to `vocab_size`. If a row
/// yields no top-k candidates, that step falls back to `argmax`. Identical inputs and
/// `seed` always produce identical output.
///
/// # Errors
///
/// Returns [`CaptionGenError::InvalidParameter`] if `vocab_size` or `seq_len` is zero,
/// if `seq_len * vocab_size` overflows `usize` or differs from `logits.len()`, if
/// `top_k` / `argmax` fails on a row, or if a chosen index does not fit in `u32`.
pub fn top_k_sample(
    logits: &[f32],
    vocab_size: usize,
    seq_len: usize,
    k: usize,
    seed: u64,
) -> CaptionGenResult<Vec<u32>> {
    if vocab_size == 0 {
        return Err(CaptionGenError::InvalidParameter(
            "vocab_size must be > 0".into(),
        ));
    }
    if seq_len == 0 {
        return Err(CaptionGenError::InvalidParameter(
            "seq_len must be > 0".into(),
        ));
    }
    let total = match seq_len.checked_mul(vocab_size) {
        Some(n) => n,
        None => {
            return Err(CaptionGenError::InvalidParameter(
                "seq_len * vocab_size overflows usize".into(),
            ))
        }
    };
    if logits.len() != total {
        let got = logits.len();
        return Err(CaptionGenError::InvalidParameter(format!(
            "logits len {got} does not match seq_len ({seq_len}) * vocab_size ({vocab_size}) = {total}"
        )));
    }

    if k == 0 {
        return greedy_decode(logits, vocab_size, seq_len);
    }
    let k = k.min(vocab_size);

    let mut tokens: Vec<u32> = Vec::with_capacity(seq_len);
    for (step, row) in logits.chunks_exact(vocab_size).enumerate() {
        let candidates = top_k(row, k).map_err(|e| {
            CaptionGenError::InvalidParameter(format!(
                "top_k_sample: top_k failed on step {step}: {e:?}"
            ))
        })?;

        let vocab_idx = if candidates.is_empty() {
            argmax(row).map_err(|e| {
                CaptionGenError::InvalidParameter(format!(
                    "top_k_sample: fallback argmax failed on step {step}: {e:?}"
                ))
            })?
        } else {
            let scores: Vec<f32> = candidates.iter().map(|(_, v)| *v).collect();
            let probs = softmax(&scores);
            let draw = xorshift_uniform_f32(seed.wrapping_add(step as u64));
            // Inverse-CDF pick; if float rounding leaves `draw` above the running total,
            // the last candidate is used.
            let mut cumulative: f32 = 0.0;
            let pick = probs
                .iter()
                .position(|&p| {
                    cumulative += p;
                    draw < cumulative
                })
                .unwrap_or(candidates.len() - 1);
            candidates[pick].0
        };

        tokens.push(u32_from_usize(vocab_idx)?);
    }
    Ok(tokens)
}
/// Narrows a vocabulary index to the `u32` token type emitted by the decoders.
///
/// # Errors
///
/// Returns [`CaptionGenError::InvalidParameter`] when `v` exceeds `u32::MAX`.
fn u32_from_usize(v: usize) -> CaptionGenResult<u32> {
    match u32::try_from(v) {
        Ok(token) => Ok(token),
        Err(_) => Err(CaptionGenError::InvalidParameter(format!(
            "token index {v} exceeds u32::MAX"
        ))),
    }
}
/// Scrambles `seed` with one xorshift64* step and returns the 64-bit output.
///
/// Applies Marsaglia's xorshift (shifts 12, 25, 27) followed by Vigna's xorshift64*
/// output multiplier. Used here as a stateless hash of the seed. Because the xorshift
/// step maps 0 to 0, a zero seed is first replaced by a fixed non-zero constant, so the
/// result is never zero.
fn xorshift64_star(seed: u64) -> u64 {
    // Substituted for a zero seed: floor(2^64 / golden ratio).
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    // Canonical xorshift64* multiplier (2685821657736338717). The previous value
    // 0x2545_F491_4F6C_EDD1 was a mistyped copy of this constant.
    const MULTIPLIER: u64 = 0x2545_F491_4F6C_DD1D;

    let mut s = if seed == 0 {
        ZERO_SEED_REPLACEMENT
    } else {
        seed
    };
    s ^= s >> 12;
    s ^= s << 25;
    s ^= s >> 27;
    s.wrapping_mul(MULTIPLIER)
}
/// Maps `seed` to a uniform `f32` in `[0, 1)`.
///
/// Keeps the top 24 bits of one `xorshift64_star` output (exactly the `f32` significand
/// precision) and divides by `2^24`. The quotient is exact and at most `1 - 2^-24`.
fn xorshift_uniform_f32(seed: u64) -> f32 {
    const MANTISSA_BITS: u32 = 24;
    let scale = (1u32 << MANTISSA_BITS) as f32;
    let top = xorshift64_star(seed) >> (64 - MANTISSA_BITS);
    top as u32 as f32 / scale
}
#[cfg(test)]
mod tests {
// Unit tests for the encoder output type, the greedy / top-k decoders, the PRNG helpers,
// and error propagation from the ML layer.
use super::*;
// len/is_empty must follow the logits buffer, regardless of the shape vector.
#[test]
fn encoder_output_len_and_is_empty_match_buffer() {
let o = EncoderOutput {
logits: vec![1.0, 2.0, 3.0],
shape: vec![1, 3],
};
assert_eq!(o.len(), 3);
assert!(!o.is_empty());
let empty = EncoderOutput {
logits: Vec::new(),
shape: vec![0],
};
assert_eq!(empty.len(), 0);
assert!(empty.is_empty());
}
// Two rows of vocab 3. Row 0 peaks at index 1. Row 1 ties at 0.4 (indices 0 and 2); the
// expected 0 assumes `argmax` returns the first maximum.
#[test]
fn greedy_decode_picks_argmax_per_row() {
let logits = vec![
// row 0: [0.1, 0.8, 0.1], row 1: [0.4, 0.2, 0.4]
0.1, 0.8, 0.1, 0.4, 0.2, 0.4, ];
let out = greedy_decode(&logits, 3, 2).expect("ok");
assert_eq!(out, vec![1_u32, 0_u32]);
}
// Zero vocab_size or zero seq_len is rejected as InvalidParameter.
#[test]
fn greedy_decode_rejects_zero_vocab_or_seq_len() {
let e1 = greedy_decode(&[1.0, 2.0], 0, 1).expect_err("vocab=0");
assert!(matches!(e1, CaptionGenError::InvalidParameter(_)));
let e2 = greedy_decode(&[1.0, 2.0], 2, 0).expect_err("seq_len=0");
assert!(matches!(e2, CaptionGenError::InvalidParameter(_)));
}
// 5 values cannot form 2 rows of 3.
#[test]
fn greedy_decode_rejects_mismatched_buffer_length() {
let e = greedy_decode(&[0.0, 0.0, 0.0, 0.0, 0.0], 3, 2).expect_err("mismatched len");
assert!(matches!(e, CaptionGenError::InvalidParameter(_)));
}
// k == 0 is documented as a fallback to greedy decoding.
#[test]
fn top_k_sample_with_k_zero_matches_greedy_decode() {
let logits = vec![0.1, 0.8, 0.1, 0.4, 0.2, 0.4];
let greedy = greedy_decode(&logits, 3, 2).expect("ok");
let sampled = top_k_sample(&logits, 3, 2, 0, 42).expect("ok");
assert_eq!(greedy, sampled);
}
// Same logits + same seed must give the same tokens (3 rows of vocab 4, k = 2).
#[test]
fn top_k_sample_is_deterministic_for_identical_seed() {
let logits = vec![0.5, 0.2, 0.1, 0.2, 0.1, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 0.4];
let a = top_k_sample(&logits, 4, 3, 2, 12345).expect("ok");
let b = top_k_sample(&logits, 4, 3, 2, 12345).expect("ok");
assert_eq!(a, b);
assert_eq!(a.len(), 3);
}
// Over many seeds, every sampled token must come from that row's top-2 set.
#[test]
fn top_k_sample_only_emits_tokens_from_top_k_set() {
let logits = vec![
// row 0 (vocab 5): top-2 are index 1 (0.9) and index 3 (0.7)
// row 1: top-2 are index 0 (0.6) and index 2 (0.5)
0.1, 0.9, 0.3, 0.7, 0.2, 0.6, 0.4, 0.5, 0.2, 0.1, ];
for seed in 0..64_u64 {
let out = top_k_sample(&logits, 5, 2, 2, seed).expect("ok");
assert!(out[0] == 1 || out[0] == 3, "step 0 outside top-2: {out:?}");
assert!(out[1] == 0 || out[1] == 2, "step 1 outside top-2: {out:?}");
}
}
// Same size validation as greedy_decode: zero vocab, zero seq_len, length mismatch.
#[test]
fn top_k_sample_rejects_invalid_sizes() {
let e1 = top_k_sample(&[1.0, 2.0], 0, 1, 1, 0).expect_err("vocab=0");
assert!(matches!(e1, CaptionGenError::InvalidParameter(_)));
let e2 = top_k_sample(&[1.0, 2.0], 2, 0, 1, 0).expect_err("seq_len=0");
assert!(matches!(e2, CaptionGenError::InvalidParameter(_)));
let e3 = top_k_sample(&[1.0, 2.0, 3.0], 2, 2, 1, 0).expect_err("len mismatch");
assert!(matches!(e3, CaptionGenError::InvalidParameter(_)));
}
// k larger than the vocabulary is clamped rather than rejected.
#[test]
fn top_k_sample_clamps_k_greater_than_vocab() {
let logits = vec![0.1, 0.8, 0.1, 0.2, 0.3, 0.5];
let out = top_k_sample(&logits, 3, 2, 100, 7).expect("ok");
assert_eq!(out.len(), 2);
for &t in &out {
assert!(t < 3, "token {t} outside vocab size 3");
}
}
// The zero seed is remapped, so the generator never outputs 0 for it.
#[test]
fn xorshift64_star_is_nonzero_for_zero_seed() {
assert_ne!(xorshift64_star(0), 0);
}
// The uniform helper must stay in the half-open interval [0, 1).
#[test]
fn xorshift_uniform_f32_stays_in_unit_interval() {
for seed in 0..4096_u64 {
let r = xorshift_uniform_f32(seed);
assert!((0.0..1.0).contains(&r), "seed {seed} produced {r}");
}
}
// `?` on an MlError must convert into CaptionGenError::Ml, keeping the inner value.
#[test]
fn ml_error_roundtrips_into_caption_gen_error() {
fn forward() -> CaptionGenResult<()> {
Err(MlError::FeatureDisabled("onnx"))?;
Ok(())
}
let err = forward().expect_err("must propagate");
assert!(matches!(
err,
CaptionGenError::Ml(MlError::FeatureDisabled("onnx"))
));
}
// Loading a path that does not exist must surface as CaptionGenError::Ml.
#[test]
fn from_path_missing_file_returns_ml_error() {
let path = std::env::temp_dir().join("oximedia-caption-gen-nonexistent-encoder.onnx");
// Best-effort cleanup so the path really is missing.
if path.exists() {
let _ = std::fs::remove_file(&path);
}
let err = CaptionEncoder::from_path(&path, DeviceType::Cpu)
.expect_err("loading a missing model must fail");
assert!(
matches!(err, CaptionGenError::Ml(_)),
"expected CaptionGenError::Ml, got {err:?}",
);
}
}