use std::path::{Path, PathBuf};
use std::sync::Arc;
use oximedia_ml::pipelines::{
SceneClassification as MlSceneClassification, SceneClassifier as MlSceneClassifier,
SceneClassifierConfig, SceneImage,
};
use oximedia_ml::{DeviceType, ModelCache, OnnxModel, PipelineInfo, TypedPipeline};
use crate::error::{SceneError, SceneResult};
/// Wrapper around the `oximedia_ml` ONNX scene classifier that turns raw RGB
/// frames into `(label, score)` pairs.
///
/// Build one with [`MlSceneEnricher::from_path`],
/// [`MlSceneEnricher::from_shared_model`] or [`MlSceneEnricher::from_cache`],
/// then call [`MlSceneEnricher::classify_frame`] per frame.
pub struct MlSceneEnricher {
    /// Mirror of the classifier configuration's `top_k`, kept alongside the
    /// classifier so it can be read without touching the pipeline.
    top_k: usize,
    /// The underlying `oximedia_ml` scene-classification pipeline.
    classifier: MlSceneClassifier,
}
impl MlSceneEnricher {
pub fn from_path(
model_path: impl AsRef<Path>,
labels: Vec<String>,
device: DeviceType,
) -> SceneResult<Self> {
let config = SceneClassifierConfig {
labels: Some(labels),
..SceneClassifierConfig::default()
};
let top_k = config.top_k;
let classifier = MlSceneClassifier::load_with_config(model_path, device, config)?;
Ok(Self { classifier, top_k })
}
#[must_use]
pub fn from_shared_model(
model: Arc<OnnxModel>,
labels: Vec<String>,
model_path: PathBuf,
) -> Self {
let config = SceneClassifierConfig {
labels: Some(labels),
..SceneClassifierConfig::default()
};
let top_k = config.top_k;
let classifier = MlSceneClassifier::from_shared(model, config, model_path);
Self { classifier, top_k }
}
pub fn from_cache(
cache: &ModelCache,
model_path: impl AsRef<Path>,
labels: Vec<String>,
device: DeviceType,
) -> SceneResult<Self> {
let path = model_path.as_ref().to_path_buf();
let model = cache.get_or_load(&path, device)?;
Ok(Self::from_shared_model(model, labels, path))
}
#[must_use]
pub fn with_top_k(mut self, k: usize) -> Self {
let shared = self.classifier.shared_model();
let path = self.classifier.model_path().to_path_buf();
let mut config = self.classifier.config().clone();
config.top_k = k;
self.top_k = k;
self.classifier = MlSceneClassifier::from_shared(shared, config, path);
self
}
#[must_use]
pub fn top_k(&self) -> usize {
self.top_k
}
#[must_use]
pub fn info(&self) -> PipelineInfo {
self.classifier.info()
}
pub fn classify_frame(
&self,
rgb: &[u8],
width: usize,
height: usize,
) -> SceneResult<Vec<(String, f32)>> {
let expected = width
.checked_mul(height)
.and_then(|wh| wh.checked_mul(3))
.ok_or_else(|| {
SceneError::InvalidDimensions(format!(
"width*height*3 overflows usize: width={width} height={height}"
))
})?;
if rgb.len() != expected {
return Err(SceneError::InvalidDimensions(format!(
"expected {expected} bytes, got {}",
rgb.len()
)));
}
let w32 = u32::try_from(width).map_err(|_| {
SceneError::InvalidDimensions(format!("width {width} does not fit in u32"))
})?;
let h32 = u32::try_from(height).map_err(|_| {
SceneError::InvalidDimensions(format!("height {height} does not fit in u32"))
})?;
let image = SceneImage::new(rgb.to_vec(), w32, h32)?;
let raw: Vec<MlSceneClassification> = self.classifier.run(image)?;
Ok(raw.into_iter().map(label_score_pair).collect())
}
}
/// Debug output shows the configuration (`top_k` and the model path) rather
/// than the classifier's internals. The omitted state is flagged by the
/// trailing `..` that `finish_non_exhaustive` emits.
impl std::fmt::Debug for MlSceneEnricher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MlSceneEnricher")
            .field("top_k", &self.top_k)
            .field("model_path", &self.classifier.model_path())
            // The classifier itself is not printed; say so instead of
            // presenting the struct as if these were all of its fields.
            .finish_non_exhaustive()
    }
}
/// Converts one raw classifier prediction into a `(label, score)` pair.
///
/// The model-supplied label is used when present. Otherwise a placeholder of
/// the form `class_<index>` is synthesised from the class index.
fn label_score_pair(pred: MlSceneClassification) -> (String, f32) {
    // Exhaustive destructuring: a new field on the prediction type becomes a
    // compile error here instead of being silently ignored.
    let MlSceneClassification {
        class_index,
        label,
        score,
    } = pred;
    match label {
        Some(name) => (name, score),
        None => (format!("class_{class_index}"), score),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_ml::MlError;
    use std::path::PathBuf;

    /// A prediction that carries a label keeps that label verbatim.
    #[test]
    fn label_score_pair_uses_label_when_present() {
        let (label, score) = label_score_pair(MlSceneClassification {
            class_index: 7,
            label: Some(String::from("nature")),
            score: 0.42,
        });
        assert_eq!(label.as_str(), "nature");
        assert!((score - 0.42).abs() < f32::EPSILON);
    }

    /// A prediction without a label falls back to `class_<index>`.
    #[test]
    fn label_score_pair_falls_back_to_class_index() {
        let (label, score) = label_score_pair(MlSceneClassification {
            class_index: 13,
            label: None,
            score: 0.77,
        });
        assert_eq!(label.as_str(), "class_13");
        assert!((score - 0.77).abs() < f32::EPSILON);
    }

    /// Loading from a path that does not exist must surface as
    /// `SceneError::MlError`, not as some other variant.
    #[test]
    fn from_path_missing_file_returns_ml_error() {
        let missing_model = PathBuf::from("/does-not-exist-oximedia-scene.onnx");
        let labels = Vec::from(["indoor", "outdoor"].map(String::from));
        let err = MlSceneEnricher::from_path(&missing_model, labels, DeviceType::Cpu)
            .expect_err("loading a nonexistent model must fail");
        match err {
            SceneError::MlError(_) => {}
            other => panic!("expected SceneError::MlError, got {other:?}"),
        }
    }

    /// An `MlError` converts into `SceneError::MlError` with the inner error
    /// preserved unchanged.
    #[test]
    fn ml_error_from_conversion_is_wired() {
        let converted: SceneError = MlError::FeatureDisabled("onnx").into();
        match converted {
            SceneError::MlError(inner) => {
                assert!(matches!(inner, MlError::FeatureDisabled("onnx")));
            }
            other => panic!("unexpected conversion result: {other:?}"),
        }
    }
}