use std::path::{Path, PathBuf};
use std::sync::Arc;
use oximedia_ml::{DeviceType, ModelCache, OnnxModel};
use crate::error::{RecommendError, RecommendResult};
/// Input tensor name used by `EmbeddingExtractor` when the model metadata
/// declares no inputs at all.
const FALLBACK_INPUT_NAME: &str = "input";
/// A dense content embedding used for similarity ranking.
///
/// Embeddings built with [`ContentEmbedding::new`] are scaled to unit L2 norm;
/// [`ContentEmbedding::from_normalized`] stores its input unchanged.
#[derive(Clone, Debug, PartialEq)]
pub struct ContentEmbedding {
    vector: Vec<f32>,
}

impl ContentEmbedding {
    /// Builds an embedding from `vector`, scaling it to unit L2 norm.
    ///
    /// The norm is computed and applied in `f64`. Squaring `f32` components in
    /// `f32` overflows to infinity for magnitudes above roughly `1.8e19` and
    /// flushes to zero below roughly `1e-23`, which would wrongly reject valid,
    /// finite, non-zero vectors; every finite `f32` squares without overflow or
    /// underflow in `f64`.
    ///
    /// # Errors
    ///
    /// Returns [`RecommendError::InvalidSimilarity`] if `vector` is empty
    /// (payload `0.0`), all zeros, or contains a NaN or infinite component; in
    /// the latter cases the payload is the squared norm narrowed to `f32`.
    pub fn new(mut vector: Vec<f32>) -> RecommendResult<Self> {
        if vector.is_empty() {
            return Err(RecommendError::InvalidSimilarity(0.0));
        }
        let norm_sq: f64 = vector
            .iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum();
        // For any realistic length, a non-finite sum can only come from a NaN
        // or infinite component; a zero sum means every component is zero.
        if !norm_sq.is_finite() || norm_sq <= 0.0 {
            return Err(RecommendError::InvalidSimilarity(norm_sq as f32));
        }
        let inv_norm = norm_sq.sqrt().recip();
        for x in &mut vector {
            // Each scaled component has magnitude at most ~1, so narrowing back
            // to f32 cannot overflow.
            *x = (f64::from(*x) * inv_norm) as f32;
        }
        Ok(Self { vector })
    }

    /// Wraps `vector` as-is, without validation or normalisation.
    ///
    /// Use only when the vector is already known to be unit-length (and
    /// non-empty); nothing here checks that.
    #[must_use]
    pub fn from_normalized(vector: Vec<f32>) -> Self {
        Self { vector }
    }

    /// Number of components in the embedding.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.vector.len()
    }

    /// Borrows the components.
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.vector
    }

    /// Consumes the embedding and returns its components.
    #[must_use]
    pub fn into_inner(self) -> Vec<f32> {
        self.vector
    }

    /// Cosine similarity with `other`, delegated to `oximedia_ml::cosine_similarity`.
    ///
    /// Dimensions are not checked here; for mismatched lengths the result is
    /// whatever that function returns.
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        oximedia_ml::cosine_similarity(&self.vector, &other.vector)
    }

    /// Euclidean (L2) distance to `other`.
    ///
    /// Returns `f32::INFINITY` when the dimensions differ. A NaN component in
    /// either vector yields NaN rather than a misleading finite distance.
    #[must_use]
    pub fn euclidean_distance(&self, other: &Self) -> f32 {
        if self.vector.len() != other.vector.len() {
            return f32::INFINITY;
        }
        let sum_sq: f32 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(&a, &b)| {
                let d = a - b;
                d * d
            })
            .sum();
        // No clamp: a sum of squares is never negative, and `f32::max(0.0)`
        // would silently turn NaN into 0.0 (i.e. "identical").
        sum_sq.sqrt()
    }
}
/// Runs an ONNX model and turns one of its output tensors into a
/// [`ContentEmbedding`].
///
/// The model is held behind an [`Arc`] so it can be shared, e.g. with a
/// [`ModelCache`] or other extractors.
pub struct EmbeddingExtractor {
    /// Loaded model invoked by [`EmbeddingExtractor::extract`].
    model: Arc<OnnxModel>,
    /// Path the model came from; shown in error messages and `Debug` output.
    model_path: PathBuf,
    /// Name of the input tensor that `extract` feeds.
    input_name: String,
    /// Name of the output tensor that `extract` reads back.
    output_name: String,
}

impl EmbeddingExtractor {
    /// Loads the model at `model_path` onto `device` and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `OnnxModel::load` (converted into
    /// [`RecommendError`]), e.g. when the file does not exist.
    pub fn from_path(model_path: impl AsRef<Path>, device: DeviceType) -> RecommendResult<Self> {
        let owned_path = model_path.as_ref().to_path_buf();
        let loaded = OnnxModel::load(&owned_path, device)?;
        Ok(Self::build(Arc::new(loaded), owned_path))
    }

    /// Wraps an already-loaded, shared model; `model_path` is used only for
    /// diagnostics.
    #[must_use]
    pub fn from_shared_model(model: Arc<OnnxModel>, model_path: PathBuf) -> Self {
        Self::build(model, model_path)
    }

    /// Fetches the model from `cache` (loading it on a miss) and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ModelCache::get_or_load`.
    pub fn from_cache(
        cache: &ModelCache,
        model_path: impl AsRef<Path>,
        device: DeviceType,
    ) -> RecommendResult<Self> {
        let owned_path = model_path.as_ref().to_path_buf();
        let shared = cache.get_or_load(&owned_path, device)?;
        Ok(Self::from_shared_model(shared, owned_path))
    }

    /// Shared constructor: picks tensor names from the model metadata.
    ///
    /// The input name is the first declared input, or `FALLBACK_INPUT_NAME`
    /// if none is declared; the output name is the first declared output, or
    /// an empty string if none is declared.
    fn build(model: Arc<OnnxModel>, model_path: PathBuf) -> Self {
        // Scope the metadata borrow so `model` can be moved into `Self` below.
        let (input_name, output_name) = {
            let info = model.info();
            let input = match info.inputs.first() {
                Some(spec) => spec.name.clone(),
                None => FALLBACK_INPUT_NAME.to_string(),
            };
            let output = match info.outputs.first() {
                Some(spec) => spec.name.clone(),
                None => String::new(),
            };
            (input, output)
        };
        Self {
            model,
            model_path,
            input_name,
            output_name,
        }
    }

    /// Overrides the input tensor name taken from the model metadata.
    #[must_use]
    pub fn with_input_name(self, name: impl Into<String>) -> Self {
        Self {
            input_name: name.into(),
            ..self
        }
    }

    /// Overrides the output tensor name taken from the model metadata.
    #[must_use]
    pub fn with_output_name(self, name: impl Into<String>) -> Self {
        Self {
            output_name: name.into(),
            ..self
        }
    }

    /// Name of the input tensor fed by [`EmbeddingExtractor::extract`].
    #[must_use]
    pub fn input_name(&self) -> &str {
        self.input_name.as_str()
    }

    /// Name of the output tensor read by [`EmbeddingExtractor::extract`].
    #[must_use]
    pub fn output_name(&self) -> &str {
        self.output_name.as_str()
    }

    /// Path the model was loaded from (or was labelled with).
    #[must_use]
    pub fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// Returns another handle to the shared model (reference-count bump).
    #[must_use]
    pub fn shared_model(&self) -> Arc<OnnxModel> {
        Arc::clone(&self.model)
    }

    /// Runs the model on `data` (with `shape`) and normalises the named output.
    ///
    /// # Errors
    ///
    /// - any error from running the model;
    /// - [`RecommendError::Ml`] if `output_name` is absent from the results;
    /// - any error from [`ContentEmbedding::new`] (e.g. an empty or all-zero
    ///   output).
    pub fn extract(&self, data: Vec<f32>, shape: Vec<usize>) -> RecommendResult<ContentEmbedding> {
        let mut outputs = self.model.run_single(&self.input_name, data, shape)?;
        match outputs.remove(&self.output_name) {
            Some(raw) => ContentEmbedding::new(raw),
            None => Err(RecommendError::Ml(oximedia_ml::MlError::pipeline(
                "embed",
                format!(
                    "output '{}' missing from model '{}'",
                    self.output_name,
                    self.model_path.display(),
                ),
            ))),
        }
    }
}

impl std::fmt::Debug for EmbeddingExtractor {
    /// Shows the tensor names and model path; the model itself is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EmbeddingExtractor");
        out.field("input_name", &self.input_name);
        out.field("output_name", &self.output_name);
        out.field("model_path", &self.model_path);
        out.finish()
    }
}
/// Ranks `candidates` by cosine similarity to `query`, best match first.
///
/// Returns at most `top_k` `(candidate_index, score)` pairs in descending
/// score order. Equal scores keep ascending candidate-index order, and NaN
/// scores always rank after every real score. Returns an empty vector when
/// `candidates` is empty or `top_k` is zero.
#[must_use]
pub fn rank_by_similarity(
    query: &ContentEmbedding,
    candidates: &[ContentEmbedding],
    top_k: usize,
) -> Vec<(usize, f32)> {
    /// Strict total order on `(index, score)`: higher score first, NaN last,
    /// ties broken by lower index. A comparator built on
    /// `partial_cmp(..).unwrap_or(Equal)` is not transitive once NaN appears,
    /// which leaves sort output unspecified (newer std sorts may even panic).
    fn best_first(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        let by_score = match (a.1.is_nan(), b.1.is_nan()) {
            // Neither is NaN, so `partial_cmp` always returns `Some`.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        };
        by_score.then_with(|| a.0.cmp(&b.0))
    }

    if candidates.is_empty() || top_k == 0 {
        return Vec::new();
    }
    let mut scored: Vec<(usize, f32)> = candidates
        .iter()
        .enumerate()
        .map(|(idx, c)| (idx, query.cosine_similarity(c)))
        .collect();
    if top_k < scored.len() {
        // Move the best `top_k` entries to the front in O(n) rather than
        // sorting every candidate.
        scored.select_nth_unstable_by(top_k, best_first);
        scored.truncate(top_k);
    }
    // `best_first` is a strict total order, so an unstable sort is deterministic.
    scored.sort_unstable_by(best_first);
    scored
}
/// Face-image content extraction, compiled only with the `face-embedder-sim`
/// feature.
#[cfg(feature = "face-embedder-sim")]
pub mod face {
    use super::{ContentEmbedding, RecommendResult};
    use oximedia_ml::pipelines::{FaceEmbedder, FaceImage};
    use oximedia_ml::{DeviceType, TypedPipeline};
    use std::path::Path;

    /// Produces [`ContentEmbedding`]s from raw face images via a [`FaceEmbedder`].
    pub struct FaceContentExtractor {
        embedder: FaceEmbedder,
    }

    impl FaceContentExtractor {
        /// Loads the face-embedding model at `path` onto `device`.
        ///
        /// # Errors
        ///
        /// Propagates any model-loading error, converted into the crate error type.
        pub fn from_path(path: impl AsRef<Path>, device: DeviceType) -> RecommendResult<Self> {
            Ok(Self {
                embedder: FaceEmbedder::load(path, device)?,
            })
        }

        /// Embeds one face image given as raw `pixels` with `width` and `height`.
        ///
        /// The embedder output is wrapped with
        /// [`ContentEmbedding::from_normalized`], i.e. not re-normalised.
        /// NOTE(review): this assumes `FaceEmbedder` already emits unit-length
        /// vectors; confirm against the pipeline.
        ///
        /// # Errors
        ///
        /// Propagates errors from building the `FaceImage` or running the embedder.
        pub fn extract(
            &self,
            pixels: Vec<u8>,
            width: u32,
            height: u32,
        ) -> RecommendResult<ContentEmbedding> {
            let embedded = self.embedder.run(FaceImage::new(pixels, width, height)?)?;
            Ok(ContentEmbedding::from_normalized(embedded.into_inner()))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_ml::MlError;

    /// Shorthand for a validated embedding; panics with "ok" if construction fails.
    fn emb(values: &[f32]) -> ContentEmbedding {
        ContentEmbedding::new(values.to_vec()).expect("ok")
    }

    #[test]
    fn content_embedding_new_normalises_vector() {
        let embedding = emb(&[3.0, 4.0]);
        let length = embedding
            .as_slice()
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((length - 1.0).abs() < 1e-5);
        assert_eq!(embedding.dim(), 2);
    }

    #[test]
    fn content_embedding_new_rejects_empty() {
        let outcome = ContentEmbedding::new(Vec::<f32>::new());
        let failure = outcome.expect_err("must fail");
        assert!(matches!(failure, RecommendError::InvalidSimilarity(_)));
    }

    #[test]
    fn content_embedding_new_rejects_zero_vector() {
        let outcome = ContentEmbedding::new(vec![0.0_f32; 4]);
        let failure = outcome.expect_err("must fail");
        assert!(matches!(failure, RecommendError::InvalidSimilarity(_)));
    }

    #[test]
    fn from_normalized_bypasses_normalisation() {
        let raw = [1.0_f32, 2.0, 3.0];
        let embedding = ContentEmbedding::from_normalized(raw.to_vec());
        assert_eq!(embedding.as_slice(), &raw[..]);
    }

    #[test]
    fn cosine_similarity_identical_is_one() {
        let (a, b) = (emb(&[1.0, 0.0, 0.0]), emb(&[1.0, 0.0, 0.0]));
        assert!((a.cosine_similarity(&b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_is_zero() {
        let (a, b) = (emb(&[1.0, 0.0]), emb(&[0.0, 1.0]));
        assert!(a.cosine_similarity(&b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_antiparallel_is_minus_one() {
        let (a, b) = (emb(&[1.0, 0.0]), emb(&[-1.0, 0.0]));
        assert!((a.cosine_similarity(&b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn euclidean_distance_on_unit_norm_obeys_identity() {
        let (a, b) = (emb(&[1.0, 0.0]), emb(&[0.0, 1.0]));
        // For unit vectors: |a - b| = sqrt(2 - 2 cos(a, b)).
        let expected = (2.0 - 2.0 * a.cosine_similarity(&b)).max(0.0).sqrt();
        assert!((a.euclidean_distance(&b) - expected).abs() < 1e-5);
    }

    #[test]
    fn euclidean_distance_dim_mismatch_is_infinity() {
        let (a, b) = (emb(&[1.0, 0.0]), emb(&[1.0, 0.0, 0.0]));
        assert!(a.euclidean_distance(&b).is_infinite());
    }

    #[test]
    fn rank_by_similarity_returns_descending_top_k() {
        let query = emb(&[1.0, 0.0, 0.0]);
        let candidates: Vec<ContentEmbedding> = [
            [0.0_f32, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.5, 0.5, 0.0],
        ]
        .iter()
        .map(|values| emb(values))
        .collect();
        let ranked = rank_by_similarity(&query, &candidates, 2);
        assert_eq!(ranked.len(), 2);
        let (best_idx, best_score) = ranked[0];
        assert_eq!(best_idx, 1);
        assert!((best_score - 1.0).abs() < 1e-5);
        assert_eq!(ranked[1].0, 2);
    }

    #[test]
    fn rank_by_similarity_empty_candidates_returns_empty() {
        let query = emb(&[1.0, 0.0]);
        let ranked = rank_by_similarity(&query, &[], 5);
        assert!(ranked.is_empty());
    }

    #[test]
    fn rank_by_similarity_top_k_zero_returns_empty() {
        let query = emb(&[1.0, 0.0]);
        let pool = vec![emb(&[1.0, 0.0])];
        let ranked = rank_by_similarity(&query, &pool, 0);
        assert!(ranked.is_empty());
    }

    #[test]
    fn rank_by_similarity_larger_top_k_than_candidates_is_capped() {
        let query = emb(&[1.0, 0.0]);
        let pool = vec![emb(&[1.0, 0.0]), emb(&[0.0, 1.0])];
        let ranked = rank_by_similarity(&query, &pool, 100);
        assert_eq!(ranked.len(), 2);
    }

    #[test]
    fn ml_error_from_conversion_is_wired() {
        let converted: RecommendError = MlError::FeatureDisabled("onnx").into();
        match converted {
            RecommendError::Ml(inner) => {
                assert!(matches!(inner, MlError::FeatureDisabled("onnx")));
            }
            other => panic!("unexpected conversion result: {other:?}"),
        }
    }

    #[test]
    fn from_path_missing_file_returns_ml_error() {
        let bogus = std::path::Path::new("/does-not-exist-oximedia-recommend-embedding.onnx");
        let err = EmbeddingExtractor::from_path(bogus, DeviceType::Cpu)
            .expect_err("loading a nonexistent model must fail");
        let is_ml = matches!(err, RecommendError::Ml(_));
        assert!(is_ml, "expected RecommendError::Ml, got {err:?}");
    }
}