use blake3::hash;
#[cfg(feature = "clip")]
use image::DynamicImage;
#[cfg(all(feature = "clip", not(feature = "pdfium")))]
use image::{ImageBuffer, Luma, Rgb};
#[cfg(all(feature = "clip", not(feature = "pdfium")))]
use lopdf::{Dictionary, Document, Object, ObjectId};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
#[cfg(all(feature = "clip", not(feature = "pdfium")))]
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::time::Duration;

use crate::{MemvidError, Result, types::FrameId};
/// Upper bound on the serialized CLIP index size accepted by `decode`.
const CLIP_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;

/// Embedding dimension produced by the MobileCLIP-S2 models.
pub const MOBILECLIP_DIMS: u32 = 512;

/// Embedding dimension produced by the SigLIP base model.
pub const SIGLIP_DIMS: u32 = 768;

/// Input resolution (pixels per side) expected by MobileCLIP-S2.
pub const MOBILECLIP_INPUT_SIZE: u32 = 256;

/// Input resolution (pixels per side) expected by SigLIP base.
pub const SIGLIP_INPUT_SIZE: u32 = 224;

/// Images narrower or shorter than this are skipped during embedding.
pub const MIN_IMAGE_DIM: u32 = 64;

/// Images with a width/height ratio beyond this (either way) are skipped.
pub const MAX_ASPECT_RATIO: f32 = 10.0;

/// Images with normalized color variance below this are treated as blank.
pub const MIN_COLOR_VARIANCE: f32 = 0.01;

/// Idle time after which loaded ONNX sessions are eligible for unloading.
pub const MODEL_UNLOAD_TIMEOUT: Duration = Duration::from_secs(300);

/// Fixed-int, little-endian bincode configuration used for the CLIP index
/// payload so the on-disk layout is stable across platforms.
fn clip_config() -> impl bincode::config::Config {
    bincode::config::standard()
        .with_fixed_int_encoding()
        .with_little_endian()
}

/// Metadata for a downloadable CLIP model variant.
#[derive(Debug, Clone)]
pub struct ClipModelInfo {
    /// Registry name, e.g. "mobileclip-s2".
    pub name: &'static str,
    /// Download URL for the ONNX vision encoder.
    pub vision_url: &'static str,
    /// Download URL for the ONNX text encoder.
    pub text_url: &'static str,
    /// Download URL for the tokenizer definition.
    pub tokenizer_url: &'static str,
    /// Approximate vision model size, in megabytes.
    pub vision_size_mb: f32,
    /// Approximate text model size, in megabytes.
    pub text_size_mb: f32,
    /// Embedding dimension.
    pub dims: u32,
    /// Expected input resolution (pixels per side).
    pub input_resolution: u32,
    /// Whether this entry is the default model.
    pub is_default: bool,
}

/// Registry of supported CLIP model variants.
pub static CLIP_MODELS: &[ClipModelInfo] = &[
    ClipModelInfo {
        name: "mobileclip-s2-int8",
        vision_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/vision_model_int8.onnx",
        text_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/text_model_int8.onnx",
        tokenizer_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/tokenizer.json",
        vision_size_mb: 36.7,
        text_size_mb: 64.1,
        dims: MOBILECLIP_DIMS,
        input_resolution: MOBILECLIP_INPUT_SIZE,
        is_default: false,
    },
    ClipModelInfo {
        name: "siglip-base",
        vision_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/onnx/vision_model_quantized.onnx",
        text_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/onnx/text_model_quantized.onnx",
        tokenizer_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/tokenizer.json",
        vision_size_mb: 99.5,
        text_size_mb: 111.0,
        dims: SIGLIP_DIMS,
        input_resolution: SIGLIP_INPUT_SIZE,
        is_default: false,
    },
    ClipModelInfo {
        name: "mobileclip-s2",
        vision_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/vision_model_fp16.onnx",
        text_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/text_model_fp16.onnx",
        tokenizer_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/tokenizer.json",
        vision_size_mb: 71.7,
        text_size_mb: 127.0,
        dims: MOBILECLIP_DIMS,
        input_resolution: MOBILECLIP_INPUT_SIZE,
        is_default: true,
    },
];

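/// Looks up a model by registry name, falling back to the default entry when
/// the name is unknown.
///
/// ```ignore
/// let info = get_model_info("siglip-base");
/// assert_eq!(info.dims, SIGLIP_DIMS);
/// ```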
pub fn get_model_info(name: &str) -> &'static ClipModelInfo {
    CLIP_MODELS
        .iter()
        .find(|m| m.name == name)
        .unwrap_or_else(|| {
            CLIP_MODELS
                .iter()
                .find(|m| m.is_default)
                .expect("default model")
        })
}

/// Returns the registry entry marked `is_default`.
pub fn default_model_info() -> &'static ClipModelInfo {
    CLIP_MODELS
        .iter()
        .find(|m| m.is_default)
        .expect("default model exists")
}

/// A single embedded frame stored in the CLIP index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClipDocument {
    /// Frame this embedding belongs to.
    pub frame_id: FrameId,
    /// The embedding vector.
    pub embedding: Vec<f32>,
    /// Source page number, when the frame came from a paginated document.
    #[serde(default)]
    pub page: Option<u32>,
}

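/// Accumulates [`ClipDocument`]s and serializes them into a
/// [`ClipIndexArtifact`].
///
/// A minimal usage sketch (ids and embeddings are illustrative):
///
/// ```ignore
/// let mut builder = ClipIndexBuilder::new();
/// builder.add_document(1, None, vec![0.1, 0.2, 0.3]);
/// let artifact = builder.finish()?;
/// ```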
#[derive(Default)]
pub struct ClipIndexBuilder {
    documents: Vec<ClipDocument>,
}

impl ClipIndexBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn add_document<I>(&mut self, frame_id: FrameId, page: Option<u32>, embedding: I)
    where
        I: Into<Vec<f32>>,
    {
        self.documents.push(ClipDocument {
            frame_id,
            embedding: embedding.into(),
            page,
        });
    }

    pub fn finish(self) -> Result<ClipIndexArtifact> {
        let bytes = bincode::serde::encode_to_vec(&self.documents, clip_config())?;

        let checksum = *hash(&bytes).as_bytes();
        let dimension = self
            .documents
            .first()
            .map(|doc| doc.embedding.len() as u32)
            .unwrap_or(0);

        Ok(ClipIndexArtifact {
            bytes,
            vector_count: self.documents.len() as u64,
            dimension,
            checksum,
        })
    }
}

/// Serialized CLIP index plus integrity metadata.
#[derive(Debug, Clone)]
pub struct ClipIndexArtifact {
    /// Bincode-encoded `Vec<ClipDocument>`.
    pub bytes: Vec<u8>,
    /// Number of embeddings in the payload.
    pub vector_count: u64,
    /// Embedding dimension (0 when the index is empty).
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
}

/// In-memory, brute-force CLIP vector index.
#[derive(Debug, Clone)]
pub struct ClipIndex {
    documents: Vec<ClipDocument>,
}

impl ClipIndex {
    pub fn new() -> Self {
        Self {
            documents: Vec::new(),
        }
    }

    pub fn add_document<I>(&mut self, frame_id: FrameId, page: Option<u32>, embedding: I)
    where
        I: Into<Vec<f32>>,
    {
        self.documents.push(ClipDocument {
            frame_id,
            embedding: embedding.into(),
            page,
        });
    }

    /// Decodes a serialized index, rejecting payloads with trailing bytes.
    pub fn decode(bytes: &[u8]) -> Result<Self> {
        let (documents, read) = bincode::serde::decode_from_slice::<Vec<ClipDocument>, _>(
            bytes,
            bincode::config::standard()
                .with_fixed_int_encoding()
                .with_little_endian()
                .with_limit::<CLIP_DECODE_LIMIT>(),
        )?;

        if read != bytes.len() {
            return Err(MemvidError::InvalidToc {
                reason: Cow::Owned(format!(
                    "CLIP index decode: expected {} bytes, read {}",
                    bytes.len(),
                    read
                )),
            });
        }

        tracing::debug!(
            bytes_len = bytes.len(),
            docs_count = documents.len(),
            "decoded CLIP index"
        );

        Ok(Self { documents })
    }

    /// Scores every stored embedding against `query` by L2 distance and
    /// returns the `limit` closest hits, nearest first.
    pub fn search(&self, query: &[f32], limit: usize) -> Vec<ClipSearchHit> {
        if query.is_empty() {
            return Vec::new();
        }

        let mut hits: Vec<ClipSearchHit> = self
            .documents
            .iter()
            .map(|doc| {
                let distance = l2_distance(query, &doc.embedding);
                ClipSearchHit {
                    frame_id: doc.frame_id,
                    page: doc.page,
                    distance,
                }
            })
            .collect();

        hits.sort_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        hits.truncate(limit);
        hits
    }

    pub fn entries(&self) -> impl Iterator<Item = (FrameId, Option<u32>, &[f32])> + '_ {
        self.documents
            .iter()
            .map(|doc| (doc.frame_id, doc.page, doc.embedding.as_slice()))
    }

    pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
        self.documents
            .iter()
            .find(|doc| doc.frame_id == frame_id)
            .map(|doc| doc.embedding.as_slice())
    }

    pub fn remove(&mut self, frame_id: FrameId) {
        self.documents.retain(|doc| doc.frame_id != frame_id);
    }

    pub fn len(&self) -> usize {
        self.documents.len()
    }

    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    pub fn encode(&self) -> Result<ClipIndexArtifact> {
        let bytes = bincode::serde::encode_to_vec(&self.documents, clip_config())?;

        let checksum = *hash(&bytes).as_bytes();
        let dimension = self
            .documents
            .first()
            .map(|doc| doc.embedding.len() as u32)
            .unwrap_or(0);

        Ok(ClipIndexArtifact {
            bytes,
            vector_count: self.documents.len() as u64,
            dimension,
            checksum,
        })
    }
}

/// A single search result from [`ClipIndex::search`].
#[derive(Debug, Clone, PartialEq)]
pub struct ClipSearchHit {
    /// Frame the matching embedding belongs to.
    pub frame_id: FrameId,
    /// Source page number, when known.
    pub page: Option<u32>,
    /// L2 distance to the query (smaller is closer).
    pub distance: f32,
}

/// Euclidean (L2) distance over the overlapping prefix of `a` and `b`.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

/// Basic image statistics used to decide whether an image is worth embedding.
#[derive(Debug, Clone)]
pub struct ImageInfo {
    pub width: u32,
    pub height: u32,
    pub color_variance: f32,
}

impl ImageInfo {
    /// Returns `true` when the image is large enough, not extremely
    /// elongated, and not essentially a solid color.
    pub fn should_embed(&self) -> bool {
        if self.width < MIN_IMAGE_DIM || self.height < MIN_IMAGE_DIM {
            return false;
        }

        let aspect = self.width as f32 / self.height as f32;
        if aspect > MAX_ASPECT_RATIO || aspect < (1.0 / MAX_ASPECT_RATIO) {
            return false;
        }

        if self.color_variance < MIN_COLOR_VARIANCE {
            return false;
        }

        true
    }
}

/// Drops images whose [`ImageInfo`] fails [`ImageInfo::should_embed`].
pub fn filter_junk_images<T, F>(images: Vec<T>, get_info: F) -> Vec<T>
where
    F: Fn(&T) -> ImageInfo,
{
    images
        .into_iter()
        .filter(|img| get_info(img).should_embed())
        .collect()
}

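/// Runtime configuration for CLIP model loading, resolved from environment
/// variables with sensible fallbacks.
///
/// A minimal usage sketch (env values are illustrative):
///
/// ```ignore
/// // Honors MEMVID_MODELS_DIR, MEMVID_CLIP_MODEL, and MEMVID_OFFLINE.
/// let config = ClipConfig::default();
/// let model = ClipModel::new(config)?;
/// ```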
#[derive(Debug, Clone)]
pub struct ClipConfig {
    /// Registry name of the model to load (`MEMVID_CLIP_MODEL`).
    pub model_name: String,
    /// Directory where model files are cached (`MEMVID_MODELS_DIR`).
    pub models_dir: PathBuf,
    /// When set (`MEMVID_OFFLINE`), missing files are an error instead of a
    /// download.
    pub offline: bool,
}

impl Default for ClipConfig {
    fn default() -> Self {
        let models_dir = std::env::var("MEMVID_MODELS_DIR")
            .ok()
            .map(PathBuf::from)
            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
            .unwrap_or_else(|| PathBuf::from(".memvid/models"));

        let model_name =
            std::env::var("MEMVID_CLIP_MODEL").unwrap_or_else(|_| "mobileclip-s2".to_string());

        let offline = std::env::var("MEMVID_OFFLINE").is_ok();

        Self {
            model_name,
            models_dir,
            offline,
        }
    }
}

/// Errors specific to the CLIP subsystem.
#[derive(Debug, thiserror::Error)]
pub enum ClipError {
    #[error("CLIP model '{model}' not found. {hint}")]
    ModelNotFound { model: String, hint: String },

    #[error("Failed to decode image at {path:?}: {cause}")]
    ImageDecodeError { path: PathBuf, cause: String },

    #[error("Failed to decode image bytes: {cause}")]
    ImageBytesDecodeError { cause: String },

    #[error("CLIP inference error: {cause}")]
    InferenceError { cause: String },

    #[error("Failed to download CLIP model: {cause}")]
    DownloadError { cause: String },

    #[error("CLIP model file is corrupted: {cause}")]
    ModelCorrupted { cause: String },
}

impl From<ClipError> for MemvidError {
    fn from(err: ClipError) -> Self {
        MemvidError::EmbeddingFailed {
            reason: err.to_string().into_boxed_str(),
        }
    }
}

/// ONNX-backed CLIP implementation, compiled in behind the `clip` feature.
#[cfg(feature = "clip")]
mod model {
    use super::*;
    use image::{DynamicImage, GenericImageView, imageops::FilterType};
    use ndarray::{Array, Array4};
    use ort::session::{Session, builder::GraphOptimizationLevel};
    use ort::value::Tensor;
    use std::sync::Mutex;
    use std::time::Instant;
    use tokenizers::{
        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
        TruncationParams, TruncationStrategy,
    };

    /// Lazily loaded ONNX CLIP model with separate vision and text sessions.
    pub struct ClipModel {
        config: ClipConfig,
        model_info: &'static ClipModelInfo,
        /// Vision encoder session, loaded on first use.
        vision_session: Mutex<Option<Session>>,
        /// Text encoder session, loaded on first use.
        text_session: Mutex<Option<Session>>,
        /// Tokenizer, loaded on first use.
        tokenizer: Mutex<Option<Tokenizer>>,
        /// Timestamp of the last encode call, used for idle unloading.
        last_used: Mutex<Instant>,
    }

    impl ClipModel {
        pub fn new(config: ClipConfig) -> Result<Self> {
            let model_info = get_model_info(&config.model_name);

            Ok(Self {
                config,
                model_info,
                vision_session: Mutex::new(None),
                text_session: Mutex::new(None),
                tokenizer: Mutex::new(None),
                last_used: Mutex::new(Instant::now()),
            })
        }

        pub fn default_model() -> Result<Self> {
            Self::new(ClipConfig::default())
        }

        pub fn model_info(&self) -> &'static ClipModelInfo {
            self.model_info
        }

        pub fn dims(&self) -> u32 {
            self.model_info.dims
        }

        /// Resolves the on-disk path for the `kind` ("vision" or "text")
        /// model file, erroring with download instructions when it is missing.
        fn ensure_model_file(&self, kind: &str) -> Result<PathBuf> {
            let filename = format!("{}_{}.onnx", self.model_info.name, kind);
            let path = self.config.models_dir.join(&filename);

            if path.exists() {
                return Ok(path);
            }

            if self.config.offline {
                return Err(ClipError::ModelNotFound {
                    model: self.model_info.name.to_string(),
                    hint: format!(
                        "Run: memvid model download {} (or disable MEMVID_OFFLINE)",
                        self.model_info.name
                    ),
                }
                .into());
            }

            std::fs::create_dir_all(&self.config.models_dir).map_err(|e| {
                ClipError::DownloadError {
                    cause: format!("Failed to create models directory: {}", e),
                }
            })?;

            Err(ClipError::DownloadError {
                cause: format!(
                    "Automatic download not yet implemented. Please download manually:\n\
                     curl -L '{}' -o '{}'",
                    if kind == "vision" {
                        self.model_info.vision_url
                    } else {
                        self.model_info.text_url
                    },
                    path.display()
                ),
            }
            .into())
        }

        /// Like `ensure_model_file`, but for the tokenizer definition.
        fn ensure_tokenizer_file(&self) -> Result<PathBuf> {
            let filename = format!("{}_tokenizer.json", self.model_info.name);
            let path = self.config.models_dir.join(&filename);

            if path.exists() {
                return Ok(path);
            }

            if self.config.offline {
                return Err(ClipError::ModelNotFound {
                    model: self.model_info.name.to_string(),
                    hint: format!(
                        "Tokenizer missing at {}. Copy tokenizer.json from {}",
                        path.display(),
                        self.model_info.tokenizer_url
                    ),
                }
                .into());
            }

            std::fs::create_dir_all(&self.config.models_dir).map_err(|e| {
                ClipError::DownloadError {
                    cause: format!("Failed to create models directory: {}", e),
                }
            })?;

            Err(ClipError::DownloadError {
                cause: format!(
                    "Automatic download not yet implemented. Please download manually:\n\
                     curl -L '{}' -o '{}'",
                    self.model_info.tokenizer_url,
                    path.display()
                ),
            }
            .into())
        }

        fn load_vision_session(&self) -> Result<()> {
            let mut session_guard = self
                .vision_session
                .lock()
                .map_err(|_| MemvidError::Lock("Failed to lock vision session".into()))?;

            if session_guard.is_some() {
                return Ok(());
            }

            let vision_path = self.ensure_model_file("vision")?;

            tracing::debug!(path = %vision_path.display(), "Loading CLIP vision model");

            let session = Session::builder()
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?
                .with_intra_threads(4)
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?
                .commit_from_file(&vision_path)
                .map_err(|e| ClipError::InferenceError {
                    cause: format!("Failed to load vision model: {}", e),
                })?;

            *session_guard = Some(session);
            tracing::info!(model = %self.model_info.name, "CLIP vision model loaded");

            Ok(())
        }

        fn load_text_session(&self) -> Result<()> {
            let mut session_guard = self
                .text_session
                .lock()
                .map_err(|_| MemvidError::Lock("Failed to lock text session".into()))?;

            if session_guard.is_some() {
                return Ok(());
            }

            let text_path = self.ensure_model_file("text")?;

            tracing::debug!(path = %text_path.display(), "Loading CLIP text model");

            let session = Session::builder()
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?
                .with_intra_threads(4)
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?
                .commit_from_file(&text_path)
                .map_err(|e| ClipError::InferenceError {
                    cause: format!("Failed to load text model: {}", e),
                })?;

            *session_guard = Some(session);
            tracing::info!(model = %self.model_info.name, "CLIP text model loaded");

            Ok(())
        }

        fn load_tokenizer(&self) -> Result<()> {
            let mut tokenizer_guard = self
                .tokenizer
                .lock()
                .map_err(|_| MemvidError::Lock("Failed to lock CLIP tokenizer".into()))?;

            if tokenizer_guard.is_some() {
                return Ok(());
            }

            let tokenizer_path = self.ensure_tokenizer_file()?;

            tracing::debug!(path = %tokenizer_path.display(), "Loading CLIP tokenizer");

            let mut tokenizer =
                Tokenizer::from_file(&tokenizer_path).map_err(|e| ClipError::InferenceError {
                    cause: format!("Failed to load tokenizer: {}", e),
                })?;

            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::Fixed(77),
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id: 0,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));

            tokenizer
                .with_truncation(Some(TruncationParams {
                    max_length: 77,
                    strategy: TruncationStrategy::LongestFirst,
                    stride: 0,
                    direction: TruncationDirection::Right,
                }))
                .map_err(|e| ClipError::InferenceError {
                    cause: format!("Failed to apply truncation config: {}", e),
                })?;

            *tokenizer_guard = Some(tokenizer);
            tracing::info!(model = %self.model_info.name, "CLIP tokenizer loaded");

            Ok(())
        }

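        /// Preprocesses an image for the vision encoder: scales the shorter
        /// edge to the model's input resolution, center-crops to a square,
        /// and emits an NCHW `[1, 3, size, size]` tensor with values in
        /// `[0, 1]`. Note: no mean/std normalization beyond the 1/255 scaling.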
        fn preprocess_image(&self, image: &DynamicImage) -> Array4<f32> {
            let size = self.model_info.input_resolution;
            let rgb_input = image.to_rgb8();
            let (w, h) = rgb_input.dimensions();

            // Scale so the shorter edge matches the model input size.
            let scale = size as f32 / w.min(h) as f32;
            let new_w = ((w as f32) * scale).round().max(1.0) as u32;
            let new_h = ((h as f32) * scale).round().max(1.0) as u32;
            let resized = image.resize_exact(new_w, new_h, FilterType::Triangle);

            // Center-crop to a size x size square.
            let start_x = (resized.width().saturating_sub(size)) / 2;
            let start_y = (resized.height().saturating_sub(size)) / 2;

            let mut array = Array4::<f32>::zeros((1, 3, size as usize, size as usize));

            for y in 0..size as usize {
                for x in 0..size as usize {
                    let pixel = resized.get_pixel(start_x + x as u32, start_y + y as u32);
                    array[[0, 0, y, x]] = pixel[0] as f32 / 255.0;
                    array[[0, 1, y, x]] = pixel[1] as f32 / 255.0;
                    array[[0, 2, y, x]] = pixel[2] as f32 / 255.0;
                }
            }

            array
        }

        /// Embeds a single image, returning an L2-normalized vector.
        pub fn encode_image(&self, image: &DynamicImage) -> Result<Vec<f32>> {
            self.load_vision_session()?;

            let pixel_values = self.preprocess_image(image);

            if let Ok(mut last) = self.last_used.lock() {
                *last = Instant::now();
            }

            let mut session_guard = self
                .vision_session
                .lock()
                .map_err(|_| MemvidError::Lock("Failed to lock vision session".into()))?;

            let session = session_guard
                .as_mut()
                .ok_or_else(|| ClipError::InferenceError {
                    cause: "Vision session not loaded".to_string(),
                })?;

            // Fall back to conventional tensor names when introspection fails.
            let input_name = session
                .inputs
                .first()
                .map(|i| i.name.clone())
                .unwrap_or_else(|| "pixel_values".into());
            let output_name = session
                .outputs
                .first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "image_embeds".into());

            let input_tensor =
                Tensor::from_array(pixel_values).map_err(|e| ClipError::InferenceError {
                    cause: format!("Failed to create input tensor: {}", e),
                })?;

            let outputs = session
                .run(ort::inputs![input_name => input_tensor])
                .map_err(|e| ClipError::InferenceError {
                    cause: format!("Vision inference failed: {}", e),
                })?;

            let output = outputs
                .get(&output_name)
                .ok_or_else(|| ClipError::InferenceError {
                    cause: format!("No output '{}' from vision model", output_name),
                })?;

            let (_shape, data) =
                output
                    .try_extract_tensor::<f32>()
                    .map_err(|e| ClipError::InferenceError {
                        cause: format!("Failed to extract embeddings: {}", e),
                    })?;

            let embedding: Vec<f32> = data.to_vec();
            if embedding.iter().any(|v| !v.is_finite()) {
                return Err(ClipError::InferenceError {
                    cause: "Vision embedding contains non-finite values".to_string(),
                }
                .into());
            }
            let normalized = l2_normalize(&embedding);

            tracing::debug!(dims = normalized.len(), "Generated CLIP image embedding");

            Ok(normalized)
        }

        /// Decodes image bytes and embeds the result.
        pub fn encode_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>> {
            let image =
                image::load_from_memory(bytes).map_err(|e| ClipError::ImageBytesDecodeError {
                    cause: e.to_string(),
                })?;
            self.encode_image(&image)
        }

        /// Opens an image file and embeds the result.
        pub fn encode_image_file(&self, path: &Path) -> Result<Vec<f32>> {
            let image = image::open(path).map_err(|e| ClipError::ImageDecodeError {
                path: path.to_path_buf(),
                cause: e.to_string(),
            })?;
            self.encode_image(&image)
        }

        /// Embeds a text query with the model's tokenizer and text encoder,
        /// returning an L2-normalized vector.
        pub fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
            self.load_text_session()?;
            self.load_tokenizer()?;

            let encoding = {
                let tokenizer_guard = self
                    .tokenizer
                    .lock()
                    .map_err(|_| MemvidError::Lock("Failed to lock CLIP tokenizer".into()))?;
                let tokenizer =
                    tokenizer_guard
                        .as_ref()
                        .ok_or_else(|| ClipError::InferenceError {
                            cause: "Tokenizer not loaded".to_string(),
                        })?;

                tokenizer
                    .encode(text, true)
                    .map_err(|e| ClipError::InferenceError {
                        cause: format!("Text tokenization failed: {}", e),
                    })?
            };

            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|id| *id as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|id| *id as i64)
                .collect();
            let max_length = input_ids.len();

            let input_ids_array =
                Array::from_shape_vec((1, max_length), input_ids).map_err(|e| {
                    ClipError::InferenceError {
                        cause: e.to_string(),
                    }
                })?;
            let attention_mask_array = Array::from_shape_vec((1, max_length), attention_mask)
                .map_err(|e| ClipError::InferenceError {
                    cause: e.to_string(),
                })?;

            if let Ok(mut last) = self.last_used.lock() {
                *last = Instant::now();
            }

            let mut session_guard = self
                .text_session
                .lock()
                .map_err(|_| MemvidError::Lock("Failed to lock text session".into()))?;

            let session = session_guard
                .as_mut()
                .ok_or_else(|| ClipError::InferenceError {
                    cause: "Text session not loaded".to_string(),
                })?;

            let input_names: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
            let output_name = session
                .outputs
                .first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "text_embeds".into());

            let input_ids_tensor =
                Tensor::from_array(input_ids_array).map_err(|e| ClipError::InferenceError {
                    cause: format!("Failed to create input_ids tensor: {}", e),
                })?;
            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
                ClipError::InferenceError {
                    cause: format!("Failed to create attention_mask tensor: {}", e),
                }
            })?;

            // Some text models take only input_ids; feed the attention mask
            // when the session declares a second input.
            let outputs = if input_names.len() >= 2 {
                session
                    .run(ort::inputs![
                        input_names[0].clone() => input_ids_tensor,
                        input_names[1].clone() => attention_mask_tensor
                    ])
                    .map_err(|e| ClipError::InferenceError {
                        cause: format!("Text inference failed: {}", e),
                    })?
            } else {
                let name = input_names
                    .first()
                    .cloned()
                    .unwrap_or_else(|| "input_ids".to_string());
                session
                    .run(ort::inputs![name => input_ids_tensor])
                    .map_err(|e| ClipError::InferenceError {
                        cause: format!("Text inference failed: {}", e),
                    })?
            };

            let output = outputs
                .get(&output_name)
                .ok_or_else(|| ClipError::InferenceError {
                    cause: format!("No output '{}' from text model", output_name),
                })?;

            let (_shape, data) =
                output
                    .try_extract_tensor::<f32>()
                    .map_err(|e| ClipError::InferenceError {
                        cause: format!("Failed to extract text embeddings: {}", e),
                    })?;

            let embedding: Vec<f32> = data.to_vec();
            if embedding.iter().any(|v| !v.is_finite()) {
                return Err(ClipError::InferenceError {
                    cause: "Text embedding contains non-finite values".to_string(),
                }
                .into());
            }
            let normalized = l2_normalize(&embedding);

            tracing::debug!(
                text_len = text.len(),
                dims = normalized.len(),
                "Generated CLIP text embedding"
            );

            Ok(normalized)
        }

        /// Drops sessions and tokenizer when the model has been idle longer
        /// than [`MODEL_UNLOAD_TIMEOUT`], freeing the memory they hold.
        pub fn maybe_unload(&self) -> Result<()> {
            let last_used = self
                .last_used
                .lock()
                .map_err(|_| MemvidError::Lock("Failed to check last_used".into()))?;

            if last_used.elapsed() > MODEL_UNLOAD_TIMEOUT {
                tracing::debug!(model = %self.model_info.name, "Model idle, unloading sessions");

                if let Ok(mut guard) = self.vision_session.lock() {
                    *guard = None;
                }

                if let Ok(mut guard) = self.text_session.lock() {
                    *guard = None;
                }

                if let Ok(mut guard) = self.tokenizer.lock() {
                    *guard = None;
                }
            }

            Ok(())
        }

        /// Unconditionally drops sessions and tokenizer.
        pub fn unload(&self) -> Result<()> {
            if let Ok(mut guard) = self.vision_session.lock() {
                *guard = None;
            }
            if let Ok(mut guard) = self.text_session.lock() {
                *guard = None;
            }
            if let Ok(mut guard) = self.tokenizer.lock() {
                *guard = None;
            }
            tracing::debug!(model = %self.model_info.name, "CLIP sessions unloaded");
            Ok(())
        }

        pub fn is_vision_loaded(&self) -> bool {
            self.vision_session
                .lock()
                .map(|g| g.is_some())
                .unwrap_or(false)
        }

        pub fn is_text_loaded(&self) -> bool {
            self.text_session
                .lock()
                .map(|g| g.is_some())
                .unwrap_or(false)
        }
    }

    /// Scales `v` to unit L2 norm; returns a zero vector when the norm is
    /// degenerate (zero, subnormal, or non-finite).
    fn l2_normalize(v: &[f32]) -> Vec<f32> {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm.is_finite() && norm > 1e-10 {
            v.iter().map(|x| x / norm).collect()
        } else {
            vec![0.0; v.len()]
        }
    }

    /// Mean per-channel variance of the RGB pixels, normalized to `[0, 1]`
    /// by the maximum possible variance (255^2).
    pub fn calculate_color_variance(image: &DynamicImage) -> f32 {
        let rgb = image.to_rgb8();
        let (width, height) = rgb.dimensions();
        let total_pixels = (width * height) as f32;

        if total_pixels == 0.0 {
            return 0.0;
        }

        // First pass: per-channel means.
        let mut sum_r = 0.0f32;
        let mut sum_g = 0.0f32;
        let mut sum_b = 0.0f32;

        for pixel in rgb.pixels() {
            sum_r += pixel[0] as f32;
            sum_g += pixel[1] as f32;
            sum_b += pixel[2] as f32;
        }

        let mean_r = sum_r / total_pixels;
        let mean_g = sum_g / total_pixels;
        let mean_b = sum_b / total_pixels;

        // Second pass: per-channel sums of squared deviations.
        let mut var_r = 0.0f32;
        let mut var_g = 0.0f32;
        let mut var_b = 0.0f32;

        for pixel in rgb.pixels() {
            var_r += (pixel[0] as f32 - mean_r).powi(2);
            var_g += (pixel[1] as f32 - mean_g).powi(2);
            var_b += (pixel[2] as f32 - mean_b).powi(2);
        }

        ((var_r + var_g + var_b) / (3.0 * total_pixels)) / (255.0 * 255.0)
    }

    /// Computes the [`ImageInfo`] used by the junk-image filter.
    pub fn get_image_info(image: &DynamicImage) -> ImageInfo {
        let (width, height) = image.dimensions();
        let color_variance = calculate_color_variance(image);

        ImageInfo {
            width,
            height,
            color_variance,
        }
    }
}

#[cfg(feature = "clip")]
pub use model::*;

#[cfg(all(feature = "clip", feature = "pdfium"))]
use pdfium_render::prelude::{PdfPageRenderRotation, PdfRenderConfig, Pdfium};

/// Renders up to `max_pages` PDF pages to images for CLIP embedding, using
/// the system pdfium library. Page numbers in the result are 1-based.
#[cfg(all(feature = "clip", feature = "pdfium"))]
pub fn render_pdf_pages_for_clip(
    path: &Path,
    max_pages: usize,
    target_px: u32,
) -> Result<Vec<(u32, DynamicImage)>> {
    let bindings = Pdfium::bind_to_system_library().map_err(|e| ClipError::InferenceError {
        cause: format!("Failed to bind pdfium: {}", e),
    })?;
    let pdfium = Pdfium::new(bindings);
    let document =
        pdfium
            .load_pdf_from_file(path, None)
            .map_err(|e| ClipError::InferenceError {
                cause: format!("Failed to load PDF for CLIP rendering: {}", e),
            })?;

    let render_config = PdfRenderConfig::new()
        .set_target_width(target_px as i32)
        .set_maximum_height(target_px as i32)
        .set_maximum_width(target_px as i32)
        .rotate_if_landscape(PdfPageRenderRotation::None, false);

    let mut pages = Vec::new();
    for (index, page) in document.pages().iter().enumerate() {
        if index >= max_pages {
            break;
        }
        let rendered = page
            .render_with_config(&render_config)
            .map_err(|e| ClipError::InferenceError {
                cause: format!("Failed to render PDF page {}: {}", index + 1, e),
            })?
            .as_image();
        pages.push(((index + 1) as u32, rendered));
    }

    Ok(pages)
}

/// Fallback when pdfium is unavailable: instead of rasterizing pages, pulls
/// up to `max_pages` embedded raster images straight out of the PDF via lopdf.
#[cfg(all(feature = "clip", not(feature = "pdfium")))]
pub fn render_pdf_pages_for_clip(
    path: &Path,
    max_pages: usize,
    _target_px: u32,
) -> Result<Vec<(u32, DynamicImage)>> {
    fn extract_images_from_page(
        doc: &Document,
        page_id: ObjectId,
        remaining: &mut usize,
        out: &mut Vec<(u32, DynamicImage)>,
    ) -> Result<()> {
        if *remaining == 0 {
            return Ok(());
        }

        let (resources_opt, resource_ids) = doc.get_page_resources(page_id).map_err(|e| {
            ClipError::InferenceError {
                cause: format!("Failed to read PDF resources: {}", e),
            }
        })?;

        let mut seen = HashSet::new();
        let mut resource_dicts: Vec<Dictionary> = Vec::new();

        if let Some(dict) = resources_opt {
            resource_dicts.push(dict.clone());
        }
        for res_id in resource_ids {
            if seen.insert(res_id) {
                if let Ok(dict) = doc.get_dictionary(res_id) {
                    resource_dicts.push(dict.clone());
                }
            }
        }

        for dict in resource_dicts {
            if let Ok(xobjects) = dict.get(b"XObject") {
                let xobj_dict = match xobjects {
                    Object::Dictionary(d) => Some(d),
                    Object::Reference(id) => doc.get_dictionary(*id).ok(),
                    _ => None,
                };
                if let Some(xobj_dict) = xobj_dict {
                    for (_, obj) in xobj_dict.iter() {
                        let id = match obj {
                            Object::Reference(id) => *id,
                            _ => continue,
                        };
                        let stream = match doc.get_object(id).and_then(Object::as_stream) {
                            Ok(s) => s,
                            Err(_) => continue,
                        };
                        let subtype = stream.dict.get(b"Subtype").ok();
                        let is_image = matches!(subtype, Some(Object::Name(n)) if n == b"Image");
                        if !is_image {
                            continue;
                        }

                        let width = stream
                            .dict
                            .get(b"Width")
                            .ok()
                            .and_then(|o| o.as_i64().ok())
                            .unwrap_or(0);
                        let height = stream
                            .dict
                            .get(b"Height")
                            .ok()
                            .and_then(|o| o.as_i64().ok())
                            .unwrap_or(0);
                        if width <= 0 || height <= 0 {
                            continue;
                        }

                        let filters = stream
                            .dict
                            .get(b"Filter")
                            .ok()
                            .and_then(|f| match f {
                                Object::Name(n) => Some(vec![n.clone()]),
                                Object::Array(arr) => Some(
                                    arr.iter()
                                        .filter_map(|o| o.as_name().ok().map(|n| n.to_vec()))
                                        .collect(),
                                ),
                                _ => None,
                            })
                            .unwrap_or_default();

                        let data = stream
                            .decompressed_content()
                            .unwrap_or_else(|_| stream.content.clone());

                        // JPEG/JPEG2000 streams decode directly. The page
                        // number placeholder (1) is patched by the caller.
                        if filters
                            .iter()
                            .any(|f| f == b"DCTDecode" || f == b"JPXDecode")
                        {
                            if let Ok(img) = image::load_from_memory(&data) {
                                out.push((1, img));
                                *remaining -= 1;
                                if *remaining == 0 {
                                    return Ok(());
                                }
                                continue;
                            }
                        }

                        let color_space = stream
                            .dict
                            .get(b"ColorSpace")
                            .ok()
                            .and_then(|o| o.as_name().ok())
                            .unwrap_or(b"DeviceRGB");
                        let channels = if color_space == b"DeviceGray" { 1 } else { 3 };

                        let expected = width as usize * height as usize * channels;
                        if data.len() >= expected && channels == 3 {
                            if let Some(buf) = ImageBuffer::<Rgb<u8>, _>::from_raw(
                                width as u32,
                                height as u32,
                                data.clone(),
                            ) {
                                out.push((1, DynamicImage::ImageRgb8(buf)));
                                *remaining -= 1;
                                if *remaining == 0 {
                                    return Ok(());
                                }
                                continue;
                            }
                        } else if data.len() >= expected && channels == 1 {
                            if let Some(buf) = ImageBuffer::<Luma<u8>, _>::from_raw(
                                width as u32,
                                height as u32,
                                data.clone(),
                            ) {
                                out.push((1, DynamicImage::ImageLuma8(buf)));
                                *remaining -= 1;
                                if *remaining == 0 {
                                    return Ok(());
                                }
                                continue;
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

    let doc = Document::load(path).map_err(|e| ClipError::InferenceError {
        cause: format!("Failed to load PDF for image extraction: {}", e),
    })?;

    let mut remaining = max_pages;
    let mut pages: Vec<(u32, DynamicImage)> = Vec::new();

    for (page_num, page_id) in doc.get_pages() {
        if remaining == 0 {
            break;
        }
        let start_len = pages.len();
        extract_images_from_page(&doc, page_id, &mut remaining, &mut pages)?;
        // Stamp the real page number onto everything extracted from this page.
        for entry in pages.iter_mut().skip(start_len) {
            entry.0 = page_num;
        }
    }

    Ok(pages)
}

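/// Abstraction over CLIP-style embedding backends (local ONNX, remote, mock).
///
/// A minimal sketch of an implementor (the type and its fixed vectors are
/// illustrative, not part of the crate):
///
/// ```ignore
/// struct FixedProvider;
///
/// impl ClipEmbeddingProvider for FixedProvider {
///     fn kind(&self) -> &str { "fixed" }
///     fn model(&self) -> &str { "none" }
///     fn dimension(&self) -> usize { 2 }
///     fn embed_image_file(&self, _path: &Path) -> Result<Vec<f32>> { Ok(vec![0.0, 1.0]) }
///     fn embed_image_bytes(&self, _bytes: &[u8]) -> Result<Vec<f32>> { Ok(vec![0.0, 1.0]) }
///     fn embed_query(&self, _text: &str) -> Result<Vec<f32>> { Ok(vec![0.0, 1.0]) }
/// }
/// ```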
pub trait ClipEmbeddingProvider: Send + Sync {
    /// Provider kind, e.g. "clip".
    fn kind(&self) -> &str;

    /// Model name in use.
    fn model(&self) -> &str;

    /// Embedding dimension produced by this provider.
    fn dimension(&self) -> usize;

    fn embed_image_file(&self, path: &Path) -> Result<Vec<f32>>;

    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>>;

    fn embed_query(&self, text: &str) -> Result<Vec<f32>>;

    /// Default batch implementation: embeds files one at a time.
    fn embed_image_batch(&self, paths: &[&Path]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(paths.len());
        for path in paths {
            embeddings.push(self.embed_image_file(path)?);
        }
        Ok(embeddings)
    }

    fn is_ready(&self) -> bool {
        true
    }

    fn init(&mut self) -> Result<()> {
        Ok(())
    }

    fn unload(&self) -> Result<()> {
        Ok(())
    }
}

pub type ClipEmbeddingResult = Result<Vec<f32>>;
pub type ClipBatchEmbeddingResult = Result<Vec<Vec<f32>>>;

#[cfg(feature = "clip")]
impl ClipEmbeddingProvider for ClipModel {
    fn kind(&self) -> &str {
        "clip"
    }

    fn model(&self) -> &str {
        self.model_info().name
    }

    fn dimension(&self) -> usize {
        self.model_info().dims as usize
    }

    fn embed_image_file(&self, path: &Path) -> Result<Vec<f32>> {
        self.encode_image_file(path)
    }

    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>> {
        self.encode_image_bytes(bytes)
    }

    fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        self.encode_text(text)
    }

    fn embed_image_batch(&self, paths: &[&Path]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(paths.len());
        for path in paths {
            embeddings.push(self.encode_image_file(path)?);
        }
        Ok(embeddings)
    }

    fn is_ready(&self) -> bool {
        true
    }

    fn unload(&self) -> Result<()> {
        ClipModel::unload(self)
    }
}

/// Locates and validates a CLIP index payload inside a container file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClipIndexManifest {
    /// Byte offset of the payload within the container.
    pub bytes_offset: u64,
    /// Payload length in bytes.
    pub bytes_length: u64,
    /// Number of embeddings in the payload.
    pub vector_count: u64,
    /// Embedding dimension.
    pub dimension: u32,
    /// BLAKE3 hash of the payload bytes.
    pub checksum: [u8; 32],
    /// Name of the model that produced the embeddings.
    pub model_name: String,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clip_index_builder_roundtrip() {
        let mut builder = ClipIndexBuilder::new();
        builder.add_document(1, None, vec![0.1, 0.2, 0.3, 0.4]);
        builder.add_document(2, None, vec![0.5, 0.6, 0.7, 0.8]);

        let artifact = builder.finish().expect("finish");
        assert_eq!(artifact.vector_count, 2);
        assert_eq!(artifact.dimension, 4);

        let index = ClipIndex::decode(&artifact.bytes).expect("decode");
        assert_eq!(index.len(), 2);

        let hits = index.search(&[0.1, 0.2, 0.3, 0.4], 10);
        assert_eq!(hits[0].frame_id, 1);
        assert!(hits[0].distance < 0.001);
    }
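
    // Sketch test (not in the original suite): re-encoding a decoded index
    // should preserve its contents.
    #[test]
    fn clip_index_reencode_roundtrip() {
        let mut index = ClipIndex::new();
        index.add_document(7, Some(3), vec![0.25, 0.5]);

        let artifact = index.encode().expect("encode");
        assert_eq!(artifact.vector_count, 1);

        let decoded = ClipIndex::decode(&artifact.bytes).expect("decode");
        assert_eq!(decoded.embedding_for(7), Some(&[0.25, 0.5][..]));
    }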

    #[test]
    fn clip_index_search() {
        let mut builder = ClipIndexBuilder::new();
        builder.add_document(1, None, vec![1.0, 0.0, 0.0]);
        builder.add_document(2, None, vec![0.0, 1.0, 0.0]);
        builder.add_document(3, None, vec![0.0, 0.0, 1.0]);

        let artifact = builder.finish().expect("finish");
        let index = ClipIndex::decode(&artifact.bytes).expect("decode");

        let hits = index.search(&[1.0, 0.0, 0.0], 3);
        assert_eq!(hits[0].frame_id, 1);

        let hits = index.search(&[0.0, 1.0, 0.0], 3);
        assert_eq!(hits[0].frame_id, 2);
    }

    #[test]
    fn l2_distance_calculation() {
        let d = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-6);

        let d = l2_distance(&[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
        assert!(d.abs() < 1e-6);
    }
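
    // Sketch test (not in the original suite): covers remove() and
    // embedding_for() on a small in-memory index.
    #[test]
    fn clip_index_remove_and_lookup() {
        let mut index = ClipIndex::new();
        index.add_document(1, Some(2), vec![1.0, 0.0]);
        index.add_document(2, None, vec![0.0, 1.0]);

        assert_eq!(index.embedding_for(1), Some(&[1.0, 0.0][..]));

        index.remove(1);
        assert!(index.embedding_for(1).is_none());
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }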

    #[test]
    fn image_info_filtering() {
        let tiny = ImageInfo {
            width: 32,
            height: 32,
            color_variance: 0.5,
        };
        assert!(!tiny.should_embed());

        let good = ImageInfo {
            width: 256,
            height: 256,
            color_variance: 0.5,
        };
        assert!(good.should_embed());

        let wide = ImageInfo {
            width: 1000,
            height: 10,
            color_variance: 0.5,
        };
        assert!(!wide.should_embed());

        let solid = ImageInfo {
            width: 256,
            height: 256,
            color_variance: 0.001,
        };
        assert!(!solid.should_embed());
    }
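
    // Sketch test (not in the original suite): exercises filter_junk_images
    // with plain (width, height) tuples standing in for real images.
    #[test]
    fn junk_image_filtering() {
        let images = vec![(32u32, 32u32), (256, 256), (1000, 10)];
        let kept = filter_junk_images(images, |&(w, h)| ImageInfo {
            width: w,
            height: h,
            color_variance: 0.5,
        });
        assert_eq!(kept, vec![(256, 256)]);
    }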

    #[test]
    fn model_registry() {
        let default = default_model_info();
        assert_eq!(default.name, "mobileclip-s2");
        assert_eq!(default.dims, 512);
        assert!(default.is_default);

        let siglip = get_model_info("siglip-base");
        assert_eq!(siglip.dims, 768);

        let unknown = get_model_info("nonexistent");
        assert_eq!(unknown.name, "mobileclip-s2");
    }

    #[test]
    fn clip_config_defaults() {
        // `remove_var` is unsafe in edition 2024; these tests assume no
        // concurrent access to the process environment.
        unsafe {
            std::env::remove_var("MEMVID_CLIP_MODEL");
            std::env::remove_var("MEMVID_OFFLINE");
        }

        let config = ClipConfig::default();
        assert_eq!(config.model_name, "mobileclip-s2");
        assert!(!config.offline);
    }

    #[test]
    fn clip_embedding_provider_trait() {
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<Box<dyn super::ClipEmbeddingProvider>>();
    }
}