// agentic_vision/embedding.rs

use image::DynamicImage;
use ndarray::Array4;
use ort::session::Session;
use ort::value::Tensor;

use crate::types::{VisionError, VisionResult};
/// Dimensionality of the image embeddings produced by this engine
/// (CLIP ViT-B/32 visual head outputs 512 floats).
pub const EMBEDDING_DIM: u32 = 512;

/// Directory, relative to `$HOME`, where ONNX models are looked up.
const MODEL_DIR: &str = ".agentic-vision/models";

/// Expected filename of the CLIP visual ONNX model.
const MODEL_FILENAME: &str = "clip-vit-base-patch32-visual.onnx";

/// CLIP input resolution: images are resized to 224x224 before inference.
const CLIP_IMAGE_SIZE: u32 = 224;

/// Per-channel (RGB) normalization mean from CLIP preprocessing.
#[allow(clippy::excessive_precision)]
const CLIP_MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
/// Per-channel (RGB) normalization standard deviation from CLIP preprocessing.
#[allow(clippy::excessive_precision)]
const CLIP_STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];
25
26pub struct EmbeddingEngine {
28 session: Option<Session>,
29}
30
31impl EmbeddingEngine {
32 pub fn new(model_path: Option<&str>) -> VisionResult<Self> {
38 let path = if let Some(p) = model_path {
39 std::path::PathBuf::from(p)
40 } else {
41 let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
42 std::path::PathBuf::from(home)
43 .join(MODEL_DIR)
44 .join(MODEL_FILENAME)
45 };
46
47 if !path.exists() {
48 tracing::warn!(
49 "CLIP model not found at {}. Running in fallback mode (zero embeddings). \
50 Download a CLIP ONNX model to enable semantic similarity.",
51 path.display()
52 );
53 return Ok(Self { session: None });
54 }
55
56 tracing::info!("Loading CLIP model from {}", path.display());
57
58 let session = Session::builder()
59 .and_then(|b| b.with_intra_threads(1))
60 .and_then(|b| b.commit_from_file(&path))
61 .map_err(|e| VisionError::Embedding(format!("Failed to load ONNX model: {e}")))?;
62
63 tracing::info!("CLIP model loaded successfully");
64 Ok(Self {
65 session: Some(session),
66 })
67 }
68
69 pub fn has_model(&self) -> bool {
71 self.session.is_some()
72 }
73
74 pub fn embed(&mut self, img: &DynamicImage) -> VisionResult<Vec<f32>> {
78 let session = match &mut self.session {
79 Some(s) => s,
80 None => {
81 tracing::debug!("No model loaded, returning zero embedding");
82 return Ok(vec![0.0; EMBEDDING_DIM as usize]);
83 }
84 };
85
86 let resized = img.resize_exact(
88 CLIP_IMAGE_SIZE,
89 CLIP_IMAGE_SIZE,
90 image::imageops::FilterType::Lanczos3,
91 );
92 let rgb = resized.to_rgb8();
93
94 let mut tensor =
96 Array4::<f32>::zeros((1, 3, CLIP_IMAGE_SIZE as usize, CLIP_IMAGE_SIZE as usize));
97
98 for y in 0..CLIP_IMAGE_SIZE {
99 for x in 0..CLIP_IMAGE_SIZE {
100 let pixel = rgb.get_pixel(x, y);
101 for c in 0..3usize {
102 let val = pixel[c] as f32 / 255.0;
103 let normalized = (val - CLIP_MEAN[c]) / CLIP_STD[c];
104 tensor[[0, c, y as usize, x as usize]] = normalized;
105 }
106 }
107 }
108
109 let input_tensor = Tensor::from_array(tensor)
110 .map_err(|e| VisionError::Embedding(format!("Failed to create input tensor: {e}")))?;
111
112 let outputs = session
113 .run(ort::inputs![input_tensor])
114 .map_err(|e| VisionError::Embedding(format!("ONNX inference failed: {e}")))?;
115
116 let (_shape, data) = outputs[0]
118 .try_extract_tensor::<f32>()
119 .map_err(|e| VisionError::Embedding(format!("Failed to extract output: {e}")))?;
120
121 let embedding: Vec<f32> = data.to_vec();
122
123 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
125 if norm > 0.0 {
126 Ok(embedding.iter().map(|x| x / norm).collect())
127 } else {
128 Ok(embedding)
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// With a nonexistent model path the engine must construct in fallback
    /// mode and emit all-zero vectors of the expected dimension.
    #[test]
    fn test_fallback_mode() {
        let mut engine = EmbeddingEngine::new(Some("/nonexistent/model.onnx")).unwrap();
        assert!(!engine.has_model());

        let img = DynamicImage::new_rgb8(100, 100);
        let embedding = engine.embed(&img).unwrap();
        assert_eq!(embedding.len(), EMBEDDING_DIM as usize);
        assert!(embedding.iter().all(|&v| v == 0.0));
    }
}
147}