// agentic_vision/embedding.rs
1//! CLIP embedding generation via ONNX Runtime.
2
3use image::DynamicImage;
4use ndarray::Array4;
5use ort::session::Session;
6use ort::value::Tensor;
7
8use crate::types::{VisionError, VisionResult};
9
10/// Default embedding dimension for CLIP ViT-B/32.
11pub const EMBEDDING_DIM: u32 = 512;
12
13/// Default model directory.
14const MODEL_DIR: &str = ".agentic-vision/models";
15
16/// Default model filename.
17const MODEL_FILENAME: &str = "clip-vit-base-patch32-visual.onnx";
18
19/// CLIP image preprocessing constants.
20const CLIP_IMAGE_SIZE: u32 = 224;
21#[allow(clippy::excessive_precision)]
22const CLIP_MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
23#[allow(clippy::excessive_precision)]
24const CLIP_STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];
25
26/// Engine for generating CLIP image embeddings.
27pub struct EmbeddingEngine {
28    session: Option<Session>,
29}
30
31impl EmbeddingEngine {
32    /// Create a new embedding engine.
33    ///
34    /// If `model_path` is provided, loads the model from that path.
35    /// Otherwise, looks in `~/.agentic-vision/models/`.
36    /// If no model is found, the engine operates in fallback mode (zero vectors).
37    pub fn new(model_path: Option<&str>) -> VisionResult<Self> {
38        let path = if let Some(p) = model_path {
39            std::path::PathBuf::from(p)
40        } else {
41            let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
42            std::path::PathBuf::from(home)
43                .join(MODEL_DIR)
44                .join(MODEL_FILENAME)
45        };
46
47        if !path.exists() {
48            tracing::warn!(
49                "CLIP model not found at {}. Running in fallback mode (zero embeddings). \
50                 Download a CLIP ONNX model to enable semantic similarity.",
51                path.display()
52            );
53            return Ok(Self { session: None });
54        }
55
56        tracing::info!("Loading CLIP model from {}", path.display());
57
58        let session = Session::builder()
59            .and_then(|b| b.with_intra_threads(1))
60            .and_then(|b| b.commit_from_file(&path))
61            .map_err(|e| VisionError::Embedding(format!("Failed to load ONNX model: {e}")))?;
62
63        tracing::info!("CLIP model loaded successfully");
64        Ok(Self {
65            session: Some(session),
66        })
67    }
68
69    /// Check if the engine has a loaded model.
70    pub fn has_model(&self) -> bool {
71        self.session.is_some()
72    }
73
74    /// Generate an embedding for an image.
75    ///
76    /// Returns a 512-dimensional vector. If no model is loaded, returns zeros.
77    pub fn embed(&mut self, img: &DynamicImage) -> VisionResult<Vec<f32>> {
78        let session = match &mut self.session {
79            Some(s) => s,
80            None => {
81                tracing::debug!("No model loaded, returning zero embedding");
82                return Ok(vec![0.0; EMBEDDING_DIM as usize]);
83            }
84        };
85
86        // Preprocess: resize to 224x224, normalize with CLIP mean/std
87        let resized = img.resize_exact(
88            CLIP_IMAGE_SIZE,
89            CLIP_IMAGE_SIZE,
90            image::imageops::FilterType::Lanczos3,
91        );
92        let rgb = resized.to_rgb8();
93
94        // Create NCHW tensor [1, 3, 224, 224]
95        let mut tensor =
96            Array4::<f32>::zeros((1, 3, CLIP_IMAGE_SIZE as usize, CLIP_IMAGE_SIZE as usize));
97
98        for y in 0..CLIP_IMAGE_SIZE {
99            for x in 0..CLIP_IMAGE_SIZE {
100                let pixel = rgb.get_pixel(x, y);
101                for c in 0..3usize {
102                    let val = pixel[c] as f32 / 255.0;
103                    let normalized = (val - CLIP_MEAN[c]) / CLIP_STD[c];
104                    tensor[[0, c, y as usize, x as usize]] = normalized;
105                }
106            }
107        }
108
109        let input_tensor = Tensor::from_array(tensor)
110            .map_err(|e| VisionError::Embedding(format!("Failed to create input tensor: {e}")))?;
111
112        let outputs = session
113            .run(ort::inputs![input_tensor])
114            .map_err(|e| VisionError::Embedding(format!("ONNX inference failed: {e}")))?;
115
116        // Extract the embedding from the first output
117        let (_shape, data) = outputs[0]
118            .try_extract_tensor::<f32>()
119            .map_err(|e| VisionError::Embedding(format!("Failed to extract output: {e}")))?;
120
121        let embedding: Vec<f32> = data.to_vec();
122
123        // L2 normalize
124        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
125        if norm > 0.0 {
126            Ok(embedding.iter().map(|x| x / norm).collect())
127        } else {
128            Ok(embedding)
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// With a nonexistent model path the engine must construct successfully,
    /// report no model, and yield a zero vector of the documented dimension.
    #[test]
    fn test_fallback_mode() {
        let mut engine = EmbeddingEngine::new(Some("/nonexistent/model.onnx")).unwrap();
        assert!(!engine.has_model());

        let img = DynamicImage::new_rgb8(100, 100);
        let embedding = engine.embed(&img).unwrap();
        assert_eq!(embedding.len(), EMBEDDING_DIM as usize);
        assert!(embedding.iter().all(|&v| v == 0.0));
    }
}