use crate::backbone::{Backbone, BackboneExt, BackboneType, MobileNetV3, MobileNetV3Config};
use crate::error::{CnnError, CnnResult};
use crate::layers::TensorShape;
/// Common interface for types that turn images into fixed-length feature vectors.
///
/// The `Send + Sync` supertraits allow one extractor to be shared across threads.
pub trait EmbeddingExtractorExt: Send + Sync {
    /// Dimensionality of the embeddings this extractor produces.
    fn embedding_dim(&self) -> usize;

    /// Whether this extractor normalizes the embeddings it returns.
    fn is_normalized(&self) -> bool;

    /// Extracts the embedding of a single image of the given `height` and `width`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; typically when `image` does not match the stated size.
    fn extract(&self, image: &[f32], height: usize, width: usize) -> CnnResult<Vec<f32>>;

    /// Extracts the embedding of a single image whose layout is described by `shape`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; typically when `image` does not match `shape`.
    fn extract_with_shape(&self, image: &[f32], shape: &TensorShape) -> CnnResult<Vec<f32>>;

    /// Extracts one embedding per image from `images`, a flat buffer that holds
    /// `batch_size` images of `height` x `width` each. The result holds one
    /// vector per image.
    ///
    /// # Errors
    ///
    /// Implementation-defined; typically when `images` has the wrong length or
    /// any single extraction fails.
    fn extract_batch(
        &self,
        images: &[f32],
        batch_size: usize,
        height: usize,
        width: usize,
    ) -> CnnResult<Vec<Vec<f32>>>;
}
/// Configuration used to build a [`MobileNetEmbedder`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EmbeddingConfig {
    /// Which MobileNetV3 variant backs the extractor.
    pub backbone_type: BackboneType,
    /// Whether extracted embeddings are L2-normalized.
    pub normalize: bool,
    /// Nominal input size (224 in every preset of this type).
    pub input_size: usize,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
backbone_type: BackboneType::MobileNetV3Small,
normalize: true,
input_size: 224,
}
}
}
impl EmbeddingConfig {
    /// Preset using the MobileNetV3-Small backbone, with normalization on and input size 224.
    pub fn mobilenet_v3_small() -> Self {
        Self {
            backbone_type: BackboneType::MobileNetV3Small,
            normalize: true,
            input_size: 224,
        }
    }

    /// Preset using the MobileNetV3-Large backbone. Every other setting matches
    /// [`EmbeddingConfig::mobilenet_v3_small`].
    pub fn mobilenet_v3_large() -> Self {
        Self {
            backbone_type: BackboneType::MobileNetV3Large,
            ..Self::mobilenet_v3_small()
        }
    }

    /// Builder-style setter that returns this configuration with the
    /// `normalize` field replaced.
    pub fn normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }
}
/// Embedding extractor backed by a MobileNetV3 feature backbone.
#[derive(Debug, Clone)]
pub struct MobileNetEmbedder {
    /// Backbone whose `forward_features` output is used as the embedding.
    backbone: MobileNetV3,
    /// When `true`, embeddings are L2-normalized before being returned.
    normalize: bool,
    /// Input size copied from the [`EmbeddingConfig`].
    input_size: usize,
}
impl MobileNetEmbedder {
    /// Norms at or below this value count as zero. Such vectors are left
    /// unscaled instead of being divided by (nearly) nothing.
    const MIN_NORM: f32 = 1e-10;

    /// Builds an embedder from `config`.
    ///
    /// The backbone config is created with the argument `0`. This is presumably
    /// the classifier class count; embeddings come from `forward_features`, so no
    /// head should be needed (TODO confirm against `MobileNetV3Config`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `MobileNetV3::new`.
    pub fn new(config: EmbeddingConfig) -> CnnResult<Self> {
        let backbone_config = match config.backbone_type {
            BackboneType::MobileNetV3Small => MobileNetV3Config::small(0),
            BackboneType::MobileNetV3Large => MobileNetV3Config::large(0),
        };
        let backbone = MobileNetV3::new(backbone_config)?;
        Ok(Self {
            backbone,
            normalize: config.normalize,
            input_size: config.input_size,
        })
    }

    /// Embedder built from [`EmbeddingConfig::mobilenet_v3_small`].
    ///
    /// # Errors
    ///
    /// Same as [`MobileNetEmbedder::new`].
    pub fn v3_small() -> CnnResult<Self> {
        Self::new(EmbeddingConfig::mobilenet_v3_small())
    }

    /// Embedder built from [`EmbeddingConfig::mobilenet_v3_large`].
    ///
    /// # Errors
    ///
    /// Same as [`MobileNetEmbedder::new`].
    pub fn v3_large() -> CnnResult<Self> {
        Self::new(EmbeddingConfig::mobilenet_v3_large())
    }

    /// Returns this embedder with L2 normalization of the output turned off.
    #[must_use]
    pub fn without_normalization(mut self) -> Self {
        self.normalize = false;
        self
    }

    /// Borrows the underlying backbone.
    pub fn backbone(&self) -> &MobileNetV3 {
        &self.backbone
    }

    /// Input size taken from the configuration this embedder was built with.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Euclidean (L2) norm of `vec`.
    fn l2_norm(vec: &[f32]) -> f32 {
        vec.iter().map(|&x| x * x).sum::<f32>().sqrt()
    }

    /// Scales `vec` in place to unit L2 norm. Vectors whose norm is at most
    /// [`Self::MIN_NORM`] are left untouched.
    fn l2_normalize_inplace(&self, vec: &mut [f32]) {
        let norm = Self::l2_norm(vec);
        if norm > Self::MIN_NORM {
            vec.iter_mut().for_each(|x| *x /= norm);
        }
    }

    /// Allocating counterpart of [`Self::l2_normalize_inplace`]. It returns a
    /// normalized copy of `vec` and leaves the input unchanged.
    // Nothing in this file calls it; kept as the copy-returning variant of the
    // in-place helper, sharing its implementation so the two cannot diverge.
    #[allow(dead_code)]
    fn l2_normalize(&self, vec: &[f32]) -> Vec<f32> {
        let mut out = vec.to_vec();
        self.l2_normalize_inplace(&mut out);
        out
    }
}
impl EmbeddingExtractorExt for MobileNetEmbedder {
    /// Extracts the embedding of one 3-channel image, described to the backbone
    /// as `TensorShape::new(1, 3, height, width)`.
    ///
    /// # Errors
    ///
    /// Same as [`EmbeddingExtractorExt::extract_with_shape`].
    fn extract(&self, image: &[f32], height: usize, width: usize) -> CnnResult<Vec<f32>> {
        let shape = TensorShape::new(1, 3, height, width);
        self.extract_with_shape(image, &shape)
    }

    /// Runs the backbone's `forward_features` on `image`. If normalization is
    /// enabled, the result is then L2-normalized.
    ///
    /// # Errors
    ///
    /// Returns `CnnError::DimensionMismatch` when `image.len()` differs from
    /// `shape.numel()`, and propagates any backbone error.
    fn extract_with_shape(&self, image: &[f32], shape: &TensorShape) -> CnnResult<Vec<f32>> {
        if image.len() != shape.numel() {
            return Err(CnnError::DimensionMismatch(format!(
                "Image has {} elements, expected {} for shape {}",
                image.len(),
                shape.numel(),
                shape
            )));
        }
        let mut embedding = self.backbone.forward_features(image, shape)?;
        if self.normalize {
            self.l2_normalize_inplace(&mut embedding);
        }
        Ok(embedding)
    }

    /// Extracts embeddings sequentially for `batch_size` images packed
    /// back-to-back in `images`. Each image takes `3 * height * width`
    /// consecutive elements.
    ///
    /// # Errors
    ///
    /// Returns `CnnError::DimensionMismatch` when `images` has the wrong length
    /// or the expected length overflows `usize`. Otherwise returns the first
    /// per-image error, in batch order.
    fn extract_batch(
        &self,
        images: &[f32],
        batch_size: usize,
        height: usize,
        width: usize,
    ) -> CnnResult<Vec<Vec<f32>>> {
        // Use checked arithmetic. Plain `*` panics on overflow in debug builds
        // and wraps in release builds, where a wrapped product could even match
        // `images.len()` and cause the wrong ranges to be sliced below.
        let image_size = 3usize
            .checked_mul(height)
            .and_then(|n| n.checked_mul(width))
            .ok_or_else(|| {
                CnnError::DimensionMismatch(format!(
                    "Image size 3x{}x{} overflows usize",
                    height, width
                ))
            })?;
        let expected = image_size.checked_mul(batch_size).ok_or_else(|| {
            CnnError::DimensionMismatch(format!(
                "Batch of {} images of {} elements each overflows usize",
                batch_size, image_size
            ))
        })?;
        if images.len() != expected {
            return Err(CnnError::DimensionMismatch(format!(
                "Images have {} elements, expected {} for batch of {} images",
                images.len(),
                expected,
                batch_size
            )));
        }
        // `expected` fits in usize, so every `start + image_size` below does too.
        (0..batch_size)
            .map(|i| {
                let start = i * image_size;
                self.extract(&images[start..start + image_size], height, width)
            })
            .collect()
    }

    /// Output dimensionality reported by the backbone (`output_dim`).
    fn embedding_dim(&self) -> usize {
        self.backbone.output_dim()
    }

    /// Whether embeddings are L2-normalized before being returned.
    fn is_normalized(&self) -> bool {
        self.normalize
    }
}
#[cfg(feature = "parallel")]
pub mod parallel {
    //! Rayon-based parallel variants of the batch extraction methods.
    use super::*;
    use rayon::prelude::*;

    /// Parallel batch extraction, available for every [`EmbeddingExtractorExt`].
    pub trait ParallelEmbedding: EmbeddingExtractorExt {
        /// Like [`EmbeddingExtractorExt::extract_batch`], but processes the images
        /// in parallel with rayon.
        ///
        /// # Errors
        ///
        /// Returns `CnnError::DimensionMismatch` when `images` has the wrong length
        /// or the expected length overflows `usize`. Otherwise returns the first
        /// per-image error, in batch order.
        fn extract_batch_parallel(
            &self,
            images: &[f32],
            batch_size: usize,
            height: usize,
            width: usize,
        ) -> CnnResult<Vec<Vec<f32>>>;

        /// Extracts embeddings in parallel for a list of separately stored images.
        /// All images share the same `height` and `width`.
        ///
        /// # Errors
        ///
        /// Returns the first per-image error, in input order.
        fn extract_many_parallel(
            &self,
            images: &[&[f32]],
            height: usize,
            width: usize,
        ) -> CnnResult<Vec<Vec<f32>>>;
    }

    /// Computes `(image_size, total_len)` for `batch_size` images of
    /// `3 * height * width` elements each. Returns `CnnError::DimensionMismatch`
    /// if either product overflows `usize`.
    fn checked_batch_len(
        batch_size: usize,
        height: usize,
        width: usize,
    ) -> CnnResult<(usize, usize)> {
        let image_size = 3usize
            .checked_mul(height)
            .and_then(|n| n.checked_mul(width))
            .ok_or_else(|| {
                CnnError::DimensionMismatch(format!(
                    "Image size 3x{}x{} overflows usize",
                    height, width
                ))
            })?;
        let total = image_size.checked_mul(batch_size).ok_or_else(|| {
            CnnError::DimensionMismatch(format!(
                "Batch of {} images of {} elements each overflows usize",
                batch_size, image_size
            ))
        })?;
        Ok((image_size, total))
    }

    impl<T: EmbeddingExtractorExt + Sync> ParallelEmbedding for T {
        fn extract_batch_parallel(
            &self,
            images: &[f32],
            batch_size: usize,
            height: usize,
            width: usize,
        ) -> CnnResult<Vec<Vec<f32>>> {
            // Checked arithmetic: an unchecked product would panic (debug) or
            // wrap (release) and could then pass the length check below.
            let (image_size, expected) = checked_batch_len(batch_size, height, width)?;
            if images.len() != expected {
                return Err(CnnError::DimensionMismatch(format!(
                    "Images have {} elements, expected {} for batch of {} images",
                    images.len(),
                    expected,
                    batch_size
                )));
            }
            // Collect every per-image result in index order before folding, so the
            // error reported is deterministic: the first failing image by index.
            let results: Vec<CnnResult<Vec<f32>>> = (0..batch_size)
                .into_par_iter()
                .map(|i| {
                    let start = i * image_size;
                    self.extract(&images[start..start + image_size], height, width)
                })
                .collect();
            results.into_iter().collect()
        }

        fn extract_many_parallel(
            &self,
            images: &[&[f32]],
            height: usize,
            width: usize,
        ) -> CnnResult<Vec<Vec<f32>>> {
            let results: Vec<CnnResult<Vec<f32>>> = images
                .par_iter()
                .map(|image| self.extract(image, height, width))
                .collect();
            results.into_iter().collect()
        }
    }
}
/// Cosine similarity of `a` and `b`: `dot(a, b) / (|a| * |b|)`, clamped to `[-1, 1]`.
///
/// Returns `0.0` when the slices have different lengths or either vector has
/// zero norm (including empty inputs). For already L2-normalized vectors the
/// result equals their dot product, up to rounding.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // One pass accumulates the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    // Multiply the square roots separately instead of taking sqrt of the
    // product, so the intermediate value is less likely to overflow.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    // Rounding can push the ratio slightly past +/-1; clamp it back into range.
    (dot / denom).clamp(-1.0, 1.0)
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Slices of different lengths are not comparable; in that case the function
/// returns `f32::MAX`.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    } else {
        f32::MAX
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// L2 norm helper for assertions.
    fn norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_embedder_creation() {
        let embedder = MobileNetEmbedder::v3_small().unwrap();
        assert_eq!(embedder.embedding_dim(), 576);
        assert!(embedder.is_normalized());
    }

    #[test]
    fn test_embedder_v3_large() {
        let embedder = MobileNetEmbedder::v3_large().unwrap();
        assert_eq!(embedder.embedding_dim(), 960);
    }

    #[test]
    fn test_embedder_config() {
        let config = EmbeddingConfig::mobilenet_v3_small().normalize(false);
        let embedder = MobileNetEmbedder::new(config).unwrap();
        assert!(!embedder.is_normalized());
    }

    #[test]
    fn test_extract_embedding() {
        let embedder = MobileNetEmbedder::v3_small().unwrap();
        let image = vec![0.5f32; 3 * 224 * 224];
        let embedding = embedder.extract(&image, 224, 224).unwrap();
        assert_eq!(embedding.len(), 576);
        // Normalized output has unit norm unless the raw features were all ~zero.
        let n = norm(&embedding);
        assert!((n - 1.0).abs() < 1e-5 || n < 1e-10);
    }

    #[test]
    fn test_extract_rejects_wrong_length() {
        let embedder = MobileNetEmbedder::v3_small().unwrap();
        let image = vec![0.5f32; 10];
        assert!(embedder.extract(&image, 224, 224).is_err());
    }

    #[test]
    fn test_extract_batch() {
        let embedder = MobileNetEmbedder::v3_small().unwrap();
        let batch_size = 2;
        let images = vec![0.5f32; batch_size * 3 * 224 * 224];
        let embeddings = embedder.extract_batch(&images, batch_size, 224, 224).unwrap();
        assert_eq!(embeddings.len(), batch_size);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 576);
            let n = norm(embedding);
            assert!((n - 1.0).abs() < 1e-5 || n < 1e-10);
        }
    }

    #[test]
    fn test_extract_batch_rejects_wrong_length() {
        let embedder = MobileNetEmbedder::v3_small().unwrap();
        // Enough data for one image, but the caller claims two.
        let images = vec![0.5f32; 3 * 224 * 224];
        assert!(embedder.extract_batch(&images, 2, 224, 224).is_err());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_length_mismatch() {
        let a = [1.0f32, 2.0];
        let b = [1.0f32];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
        assert_eq!(euclidean_distance(&a, &b), f32::MAX);
    }

    #[test]
    fn test_without_normalization() {
        let embedder = MobileNetEmbedder::v3_small().unwrap().without_normalization();
        assert!(!embedder.is_normalized());
    }
}