use std::path::{Path, PathBuf};
use std::sync::Arc;
use oximedia_ml::{DeviceType, ModelCache, OnnxModel};
use crate::error::{MirError, MirResult};
/// Input tensor name used by `MusicTagger` when the model metadata lists no inputs.
const FALLBACK_INPUT_NAME: &str = "input";
/// Number of top-ranked tags a freshly built `MusicTagger` reports unless overridden.
pub const DEFAULT_TOP_K: usize = 5;
/// Post-processing applied to raw model outputs before they are ranked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TagActivation {
    /// Pass the outputs through `oximedia_ml::postprocess::softmax`. This is the default.
    #[default]
    Softmax,
    /// Pass the outputs through `oximedia_ml::postprocess::sigmoid_slice`.
    Sigmoid,
    /// Leave the outputs untouched.
    None,
}
/// A single tag together with its score and class position.
#[derive(Debug, Clone, PartialEq)]
pub struct TagActivationScore {
    /// Label text. `activate_and_rank` substitutes `class_<index>` when the
    /// label list has no entry for this index.
    pub label: String,
    /// Score of the tag after the selected activation has been applied.
    pub score: f32,
    /// Index reported by the ranking step; also the position used to look up
    /// the label.
    pub index: usize,
}
/// Tag list produced by one classification, plus the activation that was used.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct MusicTags {
    /// Tag scores in the order they were handed to `MusicTags::new`.
    tags: Vec<TagActivationScore>,
    /// Activation recorded alongside the scores.
    activation: TagActivation,
}
impl MusicTags {
#[must_use]
pub fn new(tags: Vec<TagActivationScore>, activation: TagActivation) -> Self {
Self { tags, activation }
}
#[must_use]
pub fn top(&self) -> &[TagActivationScore] {
&self.tags
}
#[must_use]
pub fn len(&self) -> usize {
self.tags.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tags.is_empty()
}
#[must_use]
pub fn activation(&self) -> TagActivation {
self.activation
}
#[must_use]
pub fn into_inner(self) -> Vec<TagActivationScore> {
self.tags
}
#[must_use]
pub fn best(&self) -> Option<&TagActivationScore> {
self.tags.first()
}
}
/// Tagger that runs a model and turns its output into ranked `MusicTags`.
pub struct MusicTagger {
/// Shared handle to the model; `classify` runs inference through it.
model: Arc<OnnxModel>,
/// Path associated with the model; reported in error messages and `Debug` output.
model_path: PathBuf,
/// Name of the input tensor passed to `run_single`.
input_name: String,
/// Key looked up in the outputs returned by `run_single`.
output_name: String,
/// Labels indexed by class position; may be empty or shorter than the model
/// output, in which case `class_<index>` labels are synthesized.
labels: Vec<String>,
/// Maximum number of tags requested from the ranking step.
top_k: usize,
/// Activation applied to the raw model output before ranking.
activation: TagActivation,
}
impl MusicTagger {
    /// Loads the model stored at `model_path` for `device` and wraps it in a
    /// tagger with default settings.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `OnnxModel::load` reports.
    pub fn from_path(model_path: impl AsRef<Path>, device: DeviceType) -> MirResult<Self> {
        let path = model_path.as_ref().to_path_buf();
        let loaded = OnnxModel::load(&path, device)?;
        Ok(Self::build(Arc::new(loaded), path))
    }

    /// Wraps an already loaded, shared model. `model_path` is only stored for
    /// diagnostics; nothing is read from it here.
    #[must_use]
    pub fn from_shared_model(model: Arc<OnnxModel>, model_path: PathBuf) -> Self {
        Self::build(model, model_path)
    }

    /// Obtains the model through `cache` and wraps it in a tagger with
    /// default settings.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `ModelCache::get_or_load` reports.
    pub fn from_cache(
        cache: &ModelCache,
        model_path: impl AsRef<Path>,
        device: DeviceType,
    ) -> MirResult<Self> {
        let path = model_path.as_ref().to_path_buf();
        let shared = cache.get_or_load(&path, device)?;
        Ok(Self::build(shared, path))
    }

    /// Common constructor: derives tensor names from the model metadata and
    /// fills every other setting with its default.
    fn build(model: Arc<OnnxModel>, model_path: PathBuf) -> Self {
        let info = model.info();
        // First declared input, or the fallback name when none is declared.
        let input_name = info
            .inputs
            .first()
            .map_or_else(|| FALLBACK_INPUT_NAME.to_string(), |spec| spec.name.clone());
        // First declared output, or an empty name when none is declared.
        let output_name = info
            .outputs
            .first()
            .map_or_else(String::new, |spec| spec.name.clone());
        Self {
            model,
            model_path,
            input_name,
            output_name,
            labels: Vec::new(),
            top_k: DEFAULT_TOP_K,
            activation: TagActivation::default(),
        }
    }

    /// Overrides the input tensor name.
    #[must_use]
    pub fn with_input_name(self, name: impl Into<String>) -> Self {
        Self {
            input_name: name.into(),
            ..self
        }
    }

    /// Overrides the output tensor name.
    #[must_use]
    pub fn with_output_name(self, name: impl Into<String>) -> Self {
        Self {
            output_name: name.into(),
            ..self
        }
    }

    /// Replaces the label list used to name ranked classes.
    #[must_use]
    pub fn with_labels(self, labels: Vec<String>) -> Self {
        Self { labels, ..self }
    }

    /// Sets how many tags to report; a request of `0` is raised to `1`.
    #[must_use]
    pub fn with_top_k(self, top_k: usize) -> Self {
        let clamped = if top_k == 0 { 1 } else { top_k };
        Self {
            top_k: clamped,
            ..self
        }
    }

    /// Selects the activation applied to the raw model output.
    #[must_use]
    pub fn with_activation(self, activation: TagActivation) -> Self {
        Self { activation, ..self }
    }

    /// Name of the input tensor fed to the model.
    #[must_use]
    pub fn input_name(&self) -> &str {
        self.input_name.as_str()
    }

    /// Name of the output tensor read from the model.
    #[must_use]
    pub fn output_name(&self) -> &str {
        self.output_name.as_str()
    }

    /// Path associated with the model.
    #[must_use]
    pub fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// Labels currently configured (possibly empty).
    #[must_use]
    pub fn labels(&self) -> &[String] {
        self.labels.as_slice()
    }

    /// Number of tags requested per classification.
    #[must_use]
    pub fn top_k(&self) -> usize {
        self.top_k
    }

    /// Activation currently configured.
    #[must_use]
    pub fn activation(&self) -> TagActivation {
        self.activation
    }

    /// Returns another handle to the shared model.
    #[must_use]
    pub fn shared_model(&self) -> Arc<OnnxModel> {
        Arc::clone(&self.model)
    }

    /// Runs the model on `tensor` (interpreted with `shape`) and ranks the
    /// values found under the configured output name.
    ///
    /// # Errors
    ///
    /// * Any error from `run_single` is propagated.
    /// * `MirError::Ml` when the configured output name is absent from the
    ///   model outputs.
    /// * `MirError::AnalysisFailed` when that output is empty.
    /// * Any error from `activate_and_rank` is propagated.
    pub fn classify(&self, tensor: &[f32], shape: &[usize]) -> MirResult<MusicTags> {
        let input = tensor.to_vec();
        let dims = shape.to_vec();
        let mut outputs = self.model.run_single(&self.input_name, input, dims)?;

        let logits = match outputs.remove(&self.output_name) {
            Some(values) => values,
            None => {
                return Err(MirError::Ml(oximedia_ml::MlError::pipeline(
                    "music-tag",
                    format!(
                        "output '{}' missing from model '{}'",
                        self.output_name,
                        self.model_path.display(),
                    ),
                )));
            }
        };

        if logits.is_empty() {
            return Err(MirError::AnalysisFailed(format!(
                "music-tagging model '{}' returned an empty output tensor",
                self.model_path.display(),
            )));
        }

        activate_and_rank(&logits, &self.labels, self.top_k, self.activation)
            .map(|ranked| MusicTags::new(ranked, self.activation))
    }
}
/// Diagnostic formatting for the tagger's configuration.
///
/// The `model` field is not printed and the label list is summarized by its
/// length (`labels_len`). Because a field is left out, the output is closed
/// with `finish_non_exhaustive`, which appends `..` so readers can tell the
/// listing is partial.
impl std::fmt::Debug for MusicTagger {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MusicTagger")
            .field("input_name", &self.input_name)
            .field("output_name", &self.output_name)
            .field("model_path", &self.model_path)
            .field("labels_len", &self.labels.len())
            .field("top_k", &self.top_k)
            .field("activation", &self.activation)
            .finish_non_exhaustive()
    }
}
pub fn activate_and_rank(
logits: &[f32],
labels: &[String],
top_k: usize,
activation: TagActivation,
) -> MirResult<Vec<TagActivationScore>> {
let scores = apply_activation(logits, activation);
let effective_k = top_k.min(scores.len()).max(1);
let ranked = oximedia_ml::postprocess::top_k(&scores, effective_k)?;
let mut out = Vec::with_capacity(ranked.len());
for (idx, score) in ranked {
let label = labels
.get(idx)
.cloned()
.unwrap_or_else(|| format!("class_{idx}"));
out.push(TagActivationScore {
label,
score,
index: idx,
});
}
Ok(out)
}
/// Returns a new vector holding `logits` transformed by `activation`.
///
/// `TagActivation::None` copies the input unchanged; the other variants
/// delegate to the matching `oximedia_ml::postprocess` function.
#[must_use]
pub fn apply_activation(logits: &[f32], activation: TagActivation) -> Vec<f32> {
    match activation {
        TagActivation::None => Vec::from(logits),
        TagActivation::Sigmoid => oximedia_ml::postprocess::sigmoid_slice(logits),
        TagActivation::Softmax => oximedia_ml::postprocess::softmax(logits),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_ml::MlError;

    /// Turns string literals into the owned label list the API expects.
    fn label_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Shorthand for building a score entry.
    fn entry(label: &str, score: f32, index: usize) -> TagActivationScore {
        TagActivationScore {
            label: label.to_string(),
            score,
            index,
        }
    }

    // Softmax output must add up to one.
    #[test]
    fn softmax_activation_sums_to_one() {
        let sum: f32 = apply_activation(&[1.0, 2.0, 3.0], TagActivation::Softmax)
            .iter()
            .sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax output must sum to 1, got {sum}",
        );
    }

    // Sigmoid output stays inside [0, 1] and is centred on 0.5.
    #[test]
    fn sigmoid_activation_stays_in_zero_one() {
        let probs = apply_activation(&[-10.0, 0.0, 10.0], TagActivation::Sigmoid);
        for p in &probs {
            assert!((0.0..=1.0).contains(p), "sigmoid out of range: {p}");
        }
        let (low, mid, high) = (probs[0], probs[1], probs[2]);
        assert!(low < 0.001);
        assert!((mid - 0.5).abs() < 1e-6);
        assert!(high > 0.999);
    }

    // `None` must hand the input back unchanged.
    #[test]
    fn none_activation_is_identity() {
        let raw = vec![1.5_f32, -2.5, 0.0];
        assert_eq!(apply_activation(&raw, TagActivation::None), raw);
    }

    // Ranked output is ordered from highest to lowest score.
    #[test]
    fn activate_and_rank_sorts_descending() {
        let labels = label_list(&["a", "b", "c", "d"]);
        let logits = [0.1_f32, 5.0, 0.3, 0.2];
        let ranked = activate_and_rank(&logits, &labels, 4, TagActivation::Softmax).expect("ok");
        assert_eq!(ranked.len(), 4);
        let winner = &ranked[0];
        assert_eq!(winner.label, "b");
        assert_eq!(winner.index, 1);
        for w in ranked.windows(2) {
            assert!(
                w[0].score >= w[1].score,
                "ranking violates descending invariant: {w:?}",
            );
        }
    }

    // Only the requested number of entries is returned.
    #[test]
    fn activate_and_rank_top_k_truncates() {
        let labels = label_list(&["a", "b", "c", "d", "e"]);
        let logits = [0.1_f32, 0.5, 0.3, 0.7, 0.2];
        let ranked = activate_and_rank(&logits, &labels, 2, TagActivation::Softmax).expect("ok");
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].index, 3);
        assert_eq!(ranked[1].index, 1);
    }

    // Without labels every entry gets a synthetic name.
    #[test]
    fn activate_and_rank_missing_labels_generates_synthetic() {
        let labels: Vec<String> = Vec::new();
        let logits = [0.1_f32, 0.9];
        let ranked = activate_and_rank(&logits, &labels, 2, TagActivation::Softmax).expect("ok");
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].index, 1);
        assert_eq!(ranked[0].label, "class_1");
        assert_eq!(ranked[1].index, 0);
        assert_eq!(ranked[1].label, "class_0");
    }

    // Indices past the end of the label list get a synthetic name.
    #[test]
    fn activate_and_rank_shorter_labels_falls_back_to_synthetic() {
        let labels = label_list(&["a", "b"]);
        let logits = [0.1_f32, 0.2, 0.9];
        let ranked = activate_and_rank(&logits, &labels, 3, TagActivation::Softmax).expect("ok");
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].index, 2);
        assert_eq!(ranked[0].label, "class_2");
    }

    // Empty input surfaces the post-processing error.
    #[test]
    fn activate_and_rank_empty_logits_errors() {
        let err = activate_and_rank(&[], &[], 3, TagActivation::Softmax).expect_err("must fail");
        assert!(
            matches!(err, MirError::Ml(MlError::Postprocess(_))),
            "expected MlError::Postprocess, got {err:?}",
        );
    }

    // A request for zero entries still yields the single best one.
    #[test]
    fn activate_and_rank_top_k_zero_is_clamped_to_one() {
        let labels = label_list(&["a", "b"]);
        let logits = [0.2_f32, 0.8];
        let ranked = activate_and_rank(&logits, &labels, 0, TagActivation::Softmax).expect("ok");
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].index, 1);
    }

    // Every accessor reflects what was passed to the constructor.
    #[test]
    fn music_tags_getters_work() {
        let tags = MusicTags::new(
            vec![entry("rock", 0.6, 0), entry("jazz", 0.3, 1)],
            TagActivation::Sigmoid,
        );
        assert_eq!(tags.len(), 2);
        assert!(!tags.is_empty());
        assert_eq!(tags.activation(), TagActivation::Sigmoid);
        assert_eq!(tags.best().map(|t| t.label.as_str()), Some("rock"));
        assert_eq!(tags.top().len(), 2);
        let owned = tags.into_inner();
        assert_eq!(owned.len(), 2);
    }

    // The default value holds no tags at all.
    #[test]
    fn music_tags_empty_reports_empty() {
        let tags = MusicTags::default();
        assert!(tags.is_empty());
        assert_eq!(tags.len(), 0);
        assert!(tags.best().is_none());
        assert_eq!(tags.top().len(), 0);
    }

    #[test]
    fn tag_activation_default_is_softmax() {
        assert_eq!(TagActivation::default(), TagActivation::Softmax);
    }

    // `MlError` converts into `MirError::Ml` through `From`.
    #[test]
    fn mir_error_from_ml_error_is_wired() {
        let mir_err: MirError = MlError::FeatureDisabled("onnx").into();
        assert!(
            matches!(mir_err, MirError::Ml(MlError::FeatureDisabled("onnx"))),
            "expected MirError::Ml(FeatureDisabled), got {mir_err:?}",
        );
    }

    // Sigmoid scores each class on its own, so two strong logits both score high.
    #[test]
    fn sigmoid_multi_label_preserves_independence() {
        let probs = apply_activation(&[4.0_f32, 4.0], TagActivation::Sigmoid);
        assert!(probs[0] > 0.9);
        assert!(probs[1] > 0.9);
    }

    // Loading a model that does not exist reports an ML error.
    #[test]
    fn from_path_missing_file_returns_ml_error() {
        let err = MusicTagger::from_path(
            "/does-not-exist-oximedia-mir-music-tagger.onnx",
            DeviceType::Cpu,
        )
        .expect_err("loading a nonexistent model must fail");
        assert!(
            matches!(err, MirError::Ml(_)),
            "expected MirError::Ml, got {err:?}",
        );
    }
}