mod error;
mod tensor;
pub mod layers;
pub mod simd;
pub mod kernels;
pub mod quantize;
pub mod int8;
#[cfg(feature = "backbone")]
pub mod backbone;
#[cfg(feature = "backbone")]
pub mod embedding;
pub mod contrastive;
pub use error::{CnnError, CnnResult};
pub use tensor::Tensor;
#[cfg(feature = "backbone")]
pub use backbone::{
Backbone, BackboneExt, BackboneType,
MobileNetV3, MobileNetV3Config,
MobileNetV3Small, MobileNetV3Large, MobileNetConfig,
ConvBNActivation, InvertedResidual, SqueezeExcitation,
create_backbone, mobilenet_v3_small, mobilenet_v3_large,
};
#[cfg(feature = "backbone")]
pub use embedding::{
MobileNetEmbedder, EmbeddingExtractorExt,
EmbeddingConfig as MobileNetEmbeddingConfig,
cosine_similarity, euclidean_distance,
};
use serde::{Deserialize, Serialize};
/// Configuration for [`CnnEmbedder`].
///
/// Serializable with `serde` so it can be stored alongside other settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Side length, in pixels, handed to the stem convolution as both spatial dimensions.
pub input_size: u32,
/// Length of the embedding vector returned by `CnnEmbedder::extract`.
pub embedding_dim: usize,
/// When `true`, `CnnEmbedder::extract` L2-normalizes the embedding before returning it.
pub normalize: bool,
/// Quantization flag. NOTE(review): not read by `CnnEmbedder` in this file — confirm where it is consumed.
pub quantized: bool,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
input_size: 224,
embedding_dim: 512,
normalize: true,
quantized: false,
}
}
}
/// Small CNN feature extractor that turns an RGBA image into an embedding vector.
///
/// The pipeline run by `extract` is: preprocessing, a 3x3 convolution with
/// batch normalization and ReLU, global average pooling, a linear projection,
/// and optional L2 normalization.
#[derive(Debug, Clone)]
pub struct CnnEmbedder {
// Settings that control the input size, embedding width and normalization.
config: EmbeddingConfig,
// Randomly initialised layer parameters used by the forward pass.
weights: EmbedderWeights,
}
/// Parameters of the layers evaluated by [`CnnEmbedder`].
#[derive(Debug, Clone)]
struct EmbedderWeights {
// Kernel weights passed to `layers::conv2d_3x3`.
conv_weights: Vec<f32>,
// Batch-norm scale, passed to `layers::batch_norm`.
bn_gamma: Vec<f32>,
// Batch-norm shift, passed to `layers::batch_norm`.
bn_beta: Vec<f32>,
// Batch-norm running mean, passed to `layers::batch_norm`.
bn_mean: Vec<f32>,
// Batch-norm running variance, passed to `layers::batch_norm`.
bn_var: Vec<f32>,
// Projection matrix, indexed as `input_index * embedding_dim + output_index`.
projection: Vec<f32>,
}
impl EmbedderWeights {
    /// Number of colour channels fed to the stem convolution (RGB).
    const IN_CHANNELS: usize = 3;
    /// Number of feature channels produced by the stem convolution; this is
    /// also the input width of the projection layer.
    const STEM_CHANNELS: usize = 16;
    /// Embedding width assumed by `Default`, matching `EmbeddingConfig::default()`.
    const DEFAULT_EMBEDDING_DIM: usize = 512;

    /// Builds randomly initialised weights whose projection matrix holds
    /// `STEM_CHANNELS * embedding_dim` entries.
    ///
    /// Convolution and projection weights are drawn uniformly from
    /// `[-0.1, 0.1)`; batch normalization starts as the identity transform
    /// (gamma = 1, beta = 0, mean = 0, variance = 1).
    fn with_embedding_dim(embedding_dim: usize) -> Self {
        use rand::Rng;

        let mut rng = rand::thread_rng();
        let conv_size = 3 * 3 * Self::IN_CHANNELS * Self::STEM_CHANNELS;
        let bn_size = Self::STEM_CHANNELS;
        // The projection must match the requested embedding width: `project`
        // indexes it as `i * embedding_dim + o`, so a fixed-size matrix would
        // be read out of bounds for any width above the default.
        let proj_size = Self::STEM_CHANNELS * embedding_dim;
        Self {
            conv_weights: (0..conv_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
            bn_gamma: vec![1.0; bn_size],
            bn_beta: vec![0.0; bn_size],
            bn_mean: vec![0.0; bn_size],
            bn_var: vec![1.0; bn_size],
            projection: (0..proj_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
        }
    }
}

impl Default for EmbedderWeights {
    /// Random weights sized for the default 512-dimensional embedding.
    fn default() -> Self {
        Self::with_embedding_dim(Self::DEFAULT_EMBEDDING_DIM)
    }
}

impl CnnEmbedder {
    /// Creates an embedder with randomly initialised weights sized for `config`.
    ///
    /// The projection layer is allocated for `config.embedding_dim`, so every
    /// embedding width (not only the default 512) can be extracted.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `CnnResult` return type is kept for API
    /// stability.
    pub fn new(config: EmbeddingConfig) -> CnnResult<Self> {
        let weights = EmbedderWeights::with_embedding_dim(config.embedding_dim);
        Ok(Self { config, weights })
    }

    /// Creates an embedder with a 224-pixel input and a 576-dimensional embedding.
    pub fn new_v3_small() -> CnnResult<Self> {
        Self::new(EmbeddingConfig {
            input_size: 224,
            embedding_dim: 576,
            normalize: true,
            quantized: false,
        })
    }

    /// Creates an embedder with a 224-pixel input and a 960-dimensional embedding.
    pub fn new_v3_large() -> CnnResult<Self> {
        Self::new(EmbeddingConfig {
            input_size: 224,
            embedding_dim: 960,
            normalize: true,
            quantized: false,
        })
    }

    /// Computes the embedding of an RGBA image.
    ///
    /// `image_data` must hold exactly `width * height * 4` bytes (8-bit RGBA,
    /// pixel-interleaved), and the image must be `input_size` pixels on each
    /// side. The alpha channel is ignored.
    ///
    /// Returns a vector of `embedding_dim` values, L2-normalized when
    /// `config.normalize` is set.
    ///
    /// # Errors
    ///
    /// Returns `CnnError::InvalidInput` when the buffer length does not match
    /// the stated dimensions, when the dimensions differ from the configured
    /// input size, or when the network produces an empty feature map.
    pub fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
        // Multiply in u128: `u32 * u32 * 4` can exceed both u32 and u64, and
        // an overflowed size could wrongly accept a short buffer.
        let expected_size = u128::from(width) * u128::from(height) * 4;
        if image_data.len() as u128 != expected_size {
            return Err(CnnError::InvalidInput(format!(
                "Expected {} bytes for {}x{} RGBA image, got {}",
                expected_size, width, height, image_data.len()
            )));
        }

        // The forward pass always tells the convolution that the input is
        // `input_size` x `input_size`; reject anything else up front instead
        // of handing the layer a buffer of a different shape.
        let side = self.config.input_size;
        if width != side || height != side {
            return Err(CnnError::InvalidInput(format!(
                "Expected {}x{} input image, got {}x{}",
                side, side, width, height
            )));
        }

        let rgb_float = self.preprocess(image_data)?;
        let features = self.forward(&rgb_float)?;
        let pooled = self.global_avg_pool(&features)?;
        let mut embedding = self.project(&pooled)?;
        if self.config.normalize {
            self.l2_normalize(&mut embedding);
        }
        Ok(embedding)
    }

    /// Length of the embedding vectors produced by [`CnnEmbedder::extract`].
    pub fn embedding_dim(&self) -> usize {
        self.config.embedding_dim
    }

    /// Configured side length, in pixels, of the expected input image.
    pub fn input_size(&self) -> u32 {
        self.config.input_size
    }

    /// Converts interleaved RGBA bytes into interleaved, normalized RGB floats.
    ///
    /// Each colour byte is scaled to `[0, 1]` and then standardised with the
    /// per-channel mean and standard deviation below. Alpha is dropped.
    fn preprocess(&self, image_data: &[u8]) -> CnnResult<Vec<f32>> {
        const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
        const STD: [f32; 3] = [0.229, 0.224, 0.225];

        let mut rgb = Vec::with_capacity(image_data.len() / 4 * 3);
        for pixel in image_data.chunks_exact(4) {
            // Zipping with the 3-element tables stops before the alpha byte.
            for ((&byte, &mean), &std) in pixel.iter().zip(MEAN.iter()).zip(STD.iter()) {
                rgb.push((f32::from(byte) / 255.0 - mean) / std);
            }
        }
        Ok(rgb)
    }

    /// Runs the stem: 3x3 convolution, batch normalization, then ReLU.
    fn forward(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
        let side = self.config.input_size as usize;
        let conv_out = layers::conv2d_3x3(
            input,
            &self.weights.conv_weights,
            EmbedderWeights::IN_CHANNELS,
            EmbedderWeights::STEM_CHANNELS,
            side,
            side,
        );
        let bn_out = layers::batch_norm(
            &conv_out,
            &self.weights.bn_gamma,
            &self.weights.bn_beta,
            &self.weights.bn_mean,
            &self.weights.bn_var,
            1e-5,
        );
        // ReLU: clamp negative activations to zero.
        let activated: Vec<f32> = bn_out.iter().map(|&x| x.max(0.0)).collect();
        Ok(activated)
    }

    /// Averages each channel over all spatial positions.
    ///
    /// `features` is read as channel-interleaved data (`position * channels +
    /// channel`); any trailing partial position is ignored.
    ///
    /// # Errors
    ///
    /// Returns `CnnError::InvalidInput` when `features` holds no complete
    /// spatial position, which would otherwise divide by zero and yield NaN.
    fn global_avg_pool(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
        let channels = EmbedderWeights::STEM_CHANNELS;
        let spatial = features.len() / channels;
        if spatial == 0 {
            return Err(CnnError::InvalidInput(format!(
                "Feature map has {} values, fewer than the {} channels needed for pooling",
                features.len(),
                channels
            )));
        }

        let mut pooled = vec![0.0f32; channels];
        for position in features.chunks_exact(channels) {
            for (acc, &value) in pooled.iter_mut().zip(position) {
                *acc += value;
            }
        }

        let inv_spatial = 1.0 / spatial as f32;
        for p in pooled.iter_mut() {
            *p *= inv_spatial;
        }
        Ok(pooled)
    }

    /// Applies the linear projection to `features`, producing `embedding_dim` values.
    ///
    /// # Errors
    ///
    /// Returns `CnnError::InvalidInput` when the projection matrix does not
    /// hold exactly `features.len() * embedding_dim` weights.
    fn project(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
        let in_dim = features.len();
        let out_dim = self.config.embedding_dim;
        let weights = &self.weights.projection;

        // Check once up front so the indexing below can never go out of bounds.
        if in_dim.checked_mul(out_dim) != Some(weights.len()) {
            return Err(CnnError::InvalidInput(format!(
                "Projection holds {} weights, expected {}x{}",
                weights.len(),
                in_dim,
                out_dim
            )));
        }

        let output = (0..out_dim)
            .map(|o| {
                features
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |sum, (i, &f)| sum + f * weights[i * out_dim + o])
            })
            .collect();
        Ok(output)
    }

    /// Scales `vec` in place to unit Euclidean length.
    ///
    /// Vectors whose norm is at most `1e-10` are left untouched to avoid
    /// dividing by (almost) zero.
    fn l2_normalize(&self, vec: &mut [f32]) {
        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in vec.iter_mut() {
                *x /= norm;
            }
        }
    }
}
/// Common interface for types that turn an RGBA image into an embedding vector.
pub trait EmbeddingExtractor {
/// Computes the embedding for `image_data`, an image of `width` x `height` pixels.
///
/// Returns the embedding on success, or an error through `CnnResult`.
fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>>;
/// Returns the embedding dimension reported by the implementor.
fn embedding_dim(&self) -> usize;
}
/// Exposes [`CnnEmbedder`] through the [`EmbeddingExtractor`] trait by
/// delegating to its inherent methods of the same names.
impl EmbeddingExtractor for CnnEmbedder {
/// Delegates to the inherent `CnnEmbedder::extract`.
fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
// Path-qualified call: inherent associated functions take precedence over
// trait methods of the same name, so this does not recurse into the trait.
CnnEmbedder::extract(self, image_data, width, height)
}
/// Delegates to the inherent `CnnEmbedder::embedding_dim`.
fn embedding_dim(&self) -> usize {
CnnEmbedder::embedding_dim(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration yields a 512-dimensional embedder.
    #[test]
    fn test_embedder_creation() {
        let config = EmbeddingConfig::default();
        let embedder = CnnEmbedder::new(config).unwrap();
        assert_eq!(embedder.embedding_dim(), 512);
    }

    /// The "small" preset reports a 576-dimensional embedding.
    #[test]
    fn test_v3_small() {
        let small = CnnEmbedder::new_v3_small().unwrap();
        assert_eq!(small.embedding_dim(), 576);
    }

    /// The "large" preset reports a 960-dimensional embedding.
    #[test]
    fn test_v3_large() {
        let large = CnnEmbedder::new_v3_large().unwrap();
        assert_eq!(large.embedding_dim(), 960);
    }

    /// A tiny uniform-grey image produces an embedding of the requested
    /// length that is either unit-norm or (numerically) zero.
    #[test]
    fn test_extract_embedding() {
        let side: u32 = 4;
        let config = EmbeddingConfig {
            input_size: side,
            embedding_dim: 8,
            normalize: true,
            quantized: false,
        };
        let embedder = CnnEmbedder::new(config).unwrap();

        let rgba = vec![128u8; (side * side * 4) as usize];
        let embedding = embedder.extract(&rgba, side, side).unwrap();
        assert_eq!(embedding.len(), 8);

        let squared_sum: f32 = embedding.iter().map(|v| v * v).sum::<f32>();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-5 || norm < 1e-10);
    }

    /// A buffer whose length disagrees with the stated dimensions is rejected.
    #[test]
    fn test_invalid_input() {
        let embedder = CnnEmbedder::new(EmbeddingConfig::default()).unwrap();
        let too_short = vec![0u8; 100];
        assert!(embedder.extract(&too_short, 10, 10).is_err());
    }
}