memvid_core/
clip.rs

1//! CLIP (Contrastive Language-Image Pre-training) visual embeddings module.
2//!
3//! This module provides visual understanding capabilities using MobileCLIP-S2,
4//! enabling semantic search across images and PDF pages with natural language queries.
5//!
6//! # Design Philosophy
7//!
8//! - **Synchronous with Parallelism**: CLIP runs in parallel with text embedding via rayon.
9//!   Since CLIP (~25ms) is faster than text embedding (~200-500ms), it adds zero latency.
10//! - **Separate Index**: CLIP embeddings (512 dims) are stored in a separate index from
11//!   text embeddings (384/768/1536 dims) because dimensions must match within an index.
12//! - **Auto-detection**: Images and PDFs with images are automatically processed without flags.
13//! - **Graceful Degradation**: Works without CLIP, just loses visual search capability.
14
15use blake3::hash;
16#[cfg(feature = "clip")]
17use image::DynamicImage;
18#[cfg(all(feature = "clip", not(feature = "pdfium")))]
19use image::{ImageBuffer, Luma, Rgb};
20#[cfg(all(feature = "clip", not(feature = "pdfium")))]
21use lopdf::{Dictionary, Document, Object, ObjectId};
22use serde::{Deserialize, Serialize};
23use std::borrow::Cow;
24#[cfg(all(feature = "clip", not(feature = "pdfium")))]
25use std::collections::HashSet;
26use std::path::{Path, PathBuf};
27use std::time::Duration;
28
29use crate::{MemvidError, Result, types::FrameId};
30
31// ============================================================================
32// Configuration Constants
33// ============================================================================
34
35/// CLIP index decode limit (512MB max)
36const CLIP_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
37
/// Embedding width produced by MobileCLIP-S2.
pub const MOBILECLIP_DIMS: u32 = 512;

/// Embedding width produced by SigLIP-base.
pub const SIGLIP_DIMS: u32 = 768;

/// Square input resolution (pixels per side) expected by MobileCLIP-S2.
pub const MOBILECLIP_INPUT_SIZE: u32 = 256;

/// Square input resolution (pixels per side) expected by SigLIP-base.
pub const SIGLIP_INPUT_SIZE: u32 = 224;

/// Images narrower or shorter than this many pixels (icons, bullets) are not embedded.
pub const MIN_IMAGE_DIM: u32 = 64;

/// Largest accepted ratio between an image's longer and shorter side; more elongated
/// images (dividers, rules, lines) are not embedded.
pub const MAX_ASPECT_RATIO: f32 = 10.0;

/// Color variance below which an image is treated as a solid fill (backgrounds,
/// spacers) and not embedded.
pub const MIN_COLOR_VARIANCE: f32 = 0.01;

/// Idle period after which the model may be unloaded (five minutes).
pub const MODEL_UNLOAD_TIMEOUT: Duration = Duration::from_secs(5 * 60);
61
62// ============================================================================
63// Bincode Configuration
64// ============================================================================
65
66fn clip_config() -> impl bincode::config::Config {
67    bincode::config::standard()
68        .with_fixed_int_encoding()
69        .with_little_endian()
70}
71
72// ============================================================================
73// Model Registry
74// ============================================================================
75
/// Static description of a CLIP checkpoint hosted on HuggingFace.
///
/// Records where the ONNX encoders and the tokenizer can be fetched from, their
/// approximate download sizes, and the embedding/input geometry of the model.
#[derive(Debug, Clone)]
pub struct ClipModelInfo {
    /// Registry key used for lookups (e.g. `"mobileclip-s2"`).
    pub name: &'static str,
    /// Download URL of the vision encoder (`.onnx`).
    pub vision_url: &'static str,
    /// Download URL of the text encoder (`.onnx`).
    pub text_url: &'static str,
    /// Download URL of the BPE tokenizer definition (`tokenizer.json`).
    pub tokenizer_url: &'static str,
    /// Approximate size of the vision encoder, in megabytes.
    pub vision_size_mb: f32,
    /// Approximate size of the text encoder, in megabytes.
    pub text_size_mb: f32,
    /// Length of the embedding vectors the model produces.
    pub dims: u32,
    /// Side length, in pixels, of the square images the vision encoder consumes.
    pub input_resolution: u32,
    /// Marks the registry entry used when no (or an unknown) model is requested.
    pub is_default: bool,
}
98
99/// Available CLIP models registry
100pub static CLIP_MODELS: &[ClipModelInfo] = &[
101    // MobileCLIP-S2 int8 quantized (smallest, but requires INT8 ONNX support)
102    // Note: INT8 quantized models don't work on all platforms (ConvInteger not supported)
103    ClipModelInfo {
104        name: "mobileclip-s2-int8",
105        vision_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/vision_model_int8.onnx",
106        text_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/text_model_int8.onnx",
107        tokenizer_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/tokenizer.json",
108        vision_size_mb: 36.7,
109        text_size_mb: 64.1,
110        dims: MOBILECLIP_DIMS,
111        input_resolution: MOBILECLIP_INPUT_SIZE,
112        is_default: false,
113    },
114    // Alternative: SigLIP-base quantized (higher quality, but may have INT8 issues)
115    ClipModelInfo {
116        name: "siglip-base",
117        vision_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/onnx/vision_model_quantized.onnx",
118        text_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/onnx/text_model_quantized.onnx",
119        tokenizer_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/tokenizer.json",
120        vision_size_mb: 99.5,
121        text_size_mb: 111.0,
122        dims: SIGLIP_DIMS,
123        input_resolution: SIGLIP_INPUT_SIZE,
124        is_default: false,
125    },
126    // Default: MobileCLIP-S2 fp16 (works on all platforms, good balance of size/quality)
127    ClipModelInfo {
128        name: "mobileclip-s2",
129        vision_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/vision_model_fp16.onnx",
130        text_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/text_model_fp16.onnx",
131        tokenizer_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/tokenizer.json",
132        vision_size_mb: 71.7,
133        text_size_mb: 127.0,
134        dims: MOBILECLIP_DIMS,
135        input_resolution: MOBILECLIP_INPUT_SIZE,
136        is_default: true,
137    },
138];
139
140/// Get model info by name, defaults to mobileclip-s2
141pub fn get_model_info(name: &str) -> &'static ClipModelInfo {
142    CLIP_MODELS
143        .iter()
144        .find(|m| m.name == name)
145        .unwrap_or_else(|| {
146            CLIP_MODELS
147                .iter()
148                .find(|m| m.is_default)
149                .expect("default model")
150        })
151}
152
153/// Get the default model info
154pub fn default_model_info() -> &'static ClipModelInfo {
155    CLIP_MODELS
156        .iter()
157        .find(|m| m.is_default)
158        .expect("default model exists")
159}
160
161// ============================================================================
162// CLIP Document and Index Types (mirrors vec.rs pattern)
163// ============================================================================
164
165/// A document with CLIP embedding stored in the index
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct ClipDocument {
168    /// Frame ID this embedding belongs to
169    pub frame_id: FrameId,
170    /// CLIP embedding vector (512 or 768 dims depending on model)
171    pub embedding: Vec<f32>,
172    /// Optional page number (for PDFs)
173    #[serde(default)]
174    pub page: Option<u32>,
175}
176
177/// Builder for constructing CLIP index artifacts
178#[derive(Default)]
179pub struct ClipIndexBuilder {
180    documents: Vec<ClipDocument>,
181}
182
183impl ClipIndexBuilder {
184    pub fn new() -> Self {
185        Self::default()
186    }
187
188    /// Add a document with its CLIP embedding
189    pub fn add_document<I>(&mut self, frame_id: FrameId, page: Option<u32>, embedding: I)
190    where
191        I: Into<Vec<f32>>,
192    {
193        self.documents.push(ClipDocument {
194            frame_id,
195            embedding: embedding.into(),
196            page,
197        });
198    }
199
200    /// Finish building and produce the index artifact
201    pub fn finish(self) -> Result<ClipIndexArtifact> {
202        let bytes = bincode::serde::encode_to_vec(&self.documents, clip_config())?;
203
204        let checksum = *hash(&bytes).as_bytes();
205        let dimension = self
206            .documents
207            .first()
208            .map(|doc| doc.embedding.len() as u32)
209            .unwrap_or(0);
210
211        Ok(ClipIndexArtifact {
212            bytes,
213            vector_count: self.documents.len() as u64,
214            dimension,
215            checksum,
216        })
217    }
218}
219
/// A serialized CLIP index ready to be persisted, plus metadata describing it.
#[derive(Debug, Clone)]
pub struct ClipIndexArtifact {
    /// bincode encoding of the index's `Vec<ClipDocument>`.
    pub bytes: Vec<u8>,
    /// Number of documents (one vector each) contained in `bytes`.
    pub vector_count: u64,
    /// Length of the first document's embedding (512 for MobileCLIP, 768 for
    /// SigLIP); 0 when the index is empty.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
}
232
233/// In-memory CLIP index for similarity search
234#[derive(Debug, Clone)]
235pub struct ClipIndex {
236    documents: Vec<ClipDocument>,
237}
238
239impl ClipIndex {
240    /// Create a new empty CLIP index
241    pub fn new() -> Self {
242        Self {
243            documents: Vec::new(),
244        }
245    }
246
247    /// Add a document with its CLIP embedding
248    pub fn add_document<I>(&mut self, frame_id: FrameId, page: Option<u32>, embedding: I)
249    where
250        I: Into<Vec<f32>>,
251    {
252        self.documents.push(ClipDocument {
253            frame_id,
254            embedding: embedding.into(),
255            page,
256        });
257    }
258
259    /// Decode CLIP index from bytes
260    pub fn decode(bytes: &[u8]) -> Result<Self> {
261        let (documents, read) = bincode::serde::decode_from_slice::<Vec<ClipDocument>, _>(
262            bytes,
263            bincode::config::standard()
264                .with_fixed_int_encoding()
265                .with_little_endian()
266                .with_limit::<CLIP_DECODE_LIMIT>(),
267        )?;
268
269        if read != bytes.len() {
270            return Err(MemvidError::InvalidToc {
271                reason: Cow::Owned(format!(
272                    "CLIP index decode: expected {} bytes, read {}",
273                    bytes.len(),
274                    read
275                )),
276            });
277        }
278
279        tracing::debug!(
280            bytes_len = bytes.len(),
281            docs_count = documents.len(),
282            "decoded CLIP index"
283        );
284
285        Ok(Self { documents })
286    }
287
288    /// Search for similar embeddings using L2 distance
289    pub fn search(&self, query: &[f32], limit: usize) -> Vec<ClipSearchHit> {
290        if query.is_empty() {
291            return Vec::new();
292        }
293
294        let mut hits: Vec<ClipSearchHit> = self
295            .documents
296            .iter()
297            .map(|doc| {
298                let distance = l2_distance(query, &doc.embedding);
299                ClipSearchHit {
300                    frame_id: doc.frame_id,
301                    page: doc.page,
302                    distance,
303                }
304            })
305            .collect();
306
307        hits.sort_by(|a, b| {
308            a.distance
309                .partial_cmp(&b.distance)
310                .unwrap_or(std::cmp::Ordering::Equal)
311        });
312        hits.truncate(limit);
313        hits
314    }
315
316    /// Get all entries in the index
317    pub fn entries(&self) -> impl Iterator<Item = (FrameId, Option<u32>, &[f32])> + '_ {
318        self.documents
319            .iter()
320            .map(|doc| (doc.frame_id, doc.page, doc.embedding.as_slice()))
321    }
322
323    /// Get embedding for a specific frame
324    pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
325        self.documents
326            .iter()
327            .find(|doc| doc.frame_id == frame_id)
328            .map(|doc| doc.embedding.as_slice())
329    }
330
331    /// Remove a document from the index
332    pub fn remove(&mut self, frame_id: FrameId) {
333        self.documents.retain(|doc| doc.frame_id != frame_id);
334    }
335
336    /// Number of documents in the index
337    pub fn len(&self) -> usize {
338        self.documents.len()
339    }
340
341    /// Check if index is empty
342    pub fn is_empty(&self) -> bool {
343        self.documents.is_empty()
344    }
345
346    /// Encode the CLIP index to bytes and produce an artifact for persistence
347    pub fn encode(&self) -> Result<ClipIndexArtifact> {
348        let bytes = bincode::serde::encode_to_vec(&self.documents, clip_config())?;
349
350        let checksum = *hash(&bytes).as_bytes();
351        let dimension = self
352            .documents
353            .first()
354            .map(|doc| doc.embedding.len() as u32)
355            .unwrap_or(0);
356
357        Ok(ClipIndexArtifact {
358            bytes,
359            vector_count: self.documents.len() as u64,
360            dimension,
361            checksum,
362        })
363    }
364}
365
366/// Search result from CLIP index
367#[derive(Debug, Clone, PartialEq)]
368pub struct ClipSearchHit {
369    /// Frame ID of the matched document
370    pub frame_id: FrameId,
371    /// Optional page number (for PDFs)
372    pub page: Option<u32>,
373    /// L2 distance to query (lower is more similar)
374    pub distance: f32,
375}
376
/// Euclidean (L2) distance between `a` and `b`.
///
/// Only the overlapping prefix is compared: if the slices differ in length, the
/// extra elements of the longer one are ignored. Callers should therefore pass
/// vectors of equal dimension.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).powi(2))
        .sum();
    sum_of_squares.sqrt()
}
385
386// ============================================================================
387// Image Filtering (Junk Detection)
388// ============================================================================
389
390/// Metadata about an image for filtering
391#[derive(Debug, Clone)]
392pub struct ImageInfo {
393    pub width: u32,
394    pub height: u32,
395    pub color_variance: f32,
396}
397
398impl ImageInfo {
399    /// Check if this image should be processed for CLIP embedding
400    pub fn should_embed(&self) -> bool {
401        // Skip tiny images (icons, bullets)
402        if self.width < MIN_IMAGE_DIM || self.height < MIN_IMAGE_DIM {
403            return false;
404        }
405
406        // Skip extreme aspect ratios (dividers, lines)
407        let aspect = self.width as f32 / self.height as f32;
408        if aspect > MAX_ASPECT_RATIO || aspect < (1.0 / MAX_ASPECT_RATIO) {
409            return false;
410        }
411
412        // Skip near-solid colors (backgrounds, spacers)
413        if self.color_variance < MIN_COLOR_VARIANCE {
414            return false;
415        }
416
417        true
418    }
419}
420
421/// Filter a list of images, keeping only those worth embedding
422pub fn filter_junk_images<T, F>(images: Vec<T>, get_info: F) -> Vec<T>
423where
424    F: Fn(&T) -> ImageInfo,
425{
426    images
427        .into_iter()
428        .filter(|img| get_info(img).should_embed())
429        .collect()
430}
431
432// ============================================================================
433// CLIP Model Configuration
434// ============================================================================
435
436/// Configuration for CLIP model initialization
437#[derive(Debug, Clone)]
438pub struct ClipConfig {
439    /// Model name (e.g., "mobileclip-s2", "siglip-base")
440    pub model_name: String,
441    /// Directory where models are cached
442    pub models_dir: PathBuf,
443    /// Whether to run in offline mode (no downloads)
444    pub offline: bool,
445}
446
447impl Default for ClipConfig {
448    fn default() -> Self {
449        // Use ~/.memvid/models as default, consistent with CLI's model installation
450        let models_dir = std::env::var("MEMVID_MODELS_DIR")
451            .ok()
452            .map(PathBuf::from)
453            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
454            .unwrap_or_else(|| PathBuf::from(".memvid/models"));
455
456        let model_name =
457            std::env::var("MEMVID_CLIP_MODEL").unwrap_or_else(|_| "mobileclip-s2".to_string());
458
459        let offline = std::env::var("MEMVID_OFFLINE").is_ok();
460
461        Self {
462            model_name,
463            models_dir,
464            offline,
465        }
466    }
467}
468
469// ============================================================================
470// CLIP Error Types
471// ============================================================================
472
473/// CLIP-specific errors
474#[derive(Debug, thiserror::Error)]
475pub enum ClipError {
476    /// Model not found and offline mode enabled
477    #[error("CLIP model '{model}' not found. {hint}")]
478    ModelNotFound { model: String, hint: String },
479
480    /// Image decode failed
481    #[error("Failed to decode image at {path:?}: {cause}")]
482    ImageDecodeError { path: PathBuf, cause: String },
483
484    /// Image bytes decode failed
485    #[error("Failed to decode image bytes: {cause}")]
486    ImageBytesDecodeError { cause: String },
487
488    /// ONNX runtime error
489    #[error("CLIP inference error: {cause}")]
490    InferenceError { cause: String },
491
492    /// Model download failed
493    #[error("Failed to download CLIP model: {cause}")]
494    DownloadError { cause: String },
495
496    /// Model file corrupted or invalid
497    #[error("CLIP model file is corrupted: {cause}")]
498    ModelCorrupted { cause: String },
499}
500
501impl From<ClipError> for MemvidError {
502    fn from(err: ClipError) -> Self {
503        MemvidError::EmbeddingFailed {
504            reason: err.to_string().into_boxed_str(),
505        }
506    }
507}
508
509// ============================================================================
510// CLIP Model (Feature-gated implementation)
511// ============================================================================
512
513#[cfg(feature = "clip")]
514mod model {
515    use super::*;
516    use image::{DynamicImage, GenericImageView, imageops::FilterType};
517    use ndarray::{Array, Array4};
518    use ort::session::{Session, builder::GraphOptimizationLevel};
519    use ort::value::Tensor;
520    use std::sync::Mutex;
521    use std::time::Instant;
522    use tokenizers::{
523        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
524        TruncationParams, TruncationStrategy,
525    };
526
527    /// CLIP model with lazy-loaded vision and text encoders
528    pub struct ClipModel {
529        config: ClipConfig,
530        model_info: &'static ClipModelInfo,
531        /// Lazy-loaded vision encoder session
532        vision_session: Mutex<Option<Session>>,
533        /// Lazy-loaded text encoder session
534        text_session: Mutex<Option<Session>>,
535        /// Lazy-loaded tokenizer matching the text encoder
536        tokenizer: Mutex<Option<Tokenizer>>,
537        /// Last time the model was used (for idle unloading)
538        last_used: Mutex<Instant>,
539    }
540
541    impl ClipModel {
542        /// Create a new CLIP model with the given configuration
543        pub fn new(config: ClipConfig) -> Result<Self> {
544            let model_info = get_model_info(&config.model_name);
545
546            Ok(Self {
547                config,
548                model_info,
549                vision_session: Mutex::new(None),
550                text_session: Mutex::new(None),
551                tokenizer: Mutex::new(None),
552                last_used: Mutex::new(Instant::now()),
553            })
554        }
555
556        /// Create with default configuration
557        pub fn default_model() -> Result<Self> {
558            Self::new(ClipConfig::default())
559        }
560
561        /// Get model info
562        pub fn model_info(&self) -> &'static ClipModelInfo {
563            self.model_info
564        }
565
566        /// Get embedding dimensions
567        pub fn dims(&self) -> u32 {
568            self.model_info.dims
569        }
570
571        /// Ensure model file exists, downloading if necessary
572        fn ensure_model_file(&self, kind: &str) -> Result<PathBuf> {
573            let filename = format!("{}_{}.onnx", self.model_info.name, kind);
574            let path = self.config.models_dir.join(&filename);
575
576            if path.exists() {
577                return Ok(path);
578            }
579
580            if self.config.offline {
581                return Err(ClipError::ModelNotFound {
582                    model: self.model_info.name.to_string(),
583                    hint: format!(
584                        "Run: memvid model download {} (or disable MEMVID_OFFLINE)",
585                        self.model_info.name
586                    ),
587                }
588                .into());
589            }
590
591            // Create models directory if needed
592            std::fs::create_dir_all(&self.config.models_dir).map_err(|e| {
593                ClipError::DownloadError {
594                    cause: format!("Failed to create models directory: {}", e),
595                }
596            })?;
597
598            // Provide manual download instructions
599            Err(ClipError::DownloadError {
600                cause: format!(
601                    "Automatic download not yet implemented. Please download manually:\n\
602                     curl -L '{}' -o '{}'",
603                    if kind == "vision" {
604                        self.model_info.vision_url
605                    } else {
606                        self.model_info.text_url
607                    },
608                    path.display()
609                ),
610            }
611            .into())
612        }
613
614        /// Ensure tokenizer file exists, downloading if necessary
615        fn ensure_tokenizer_file(&self) -> Result<PathBuf> {
616            let filename = format!("{}_tokenizer.json", self.model_info.name);
617            let path = self.config.models_dir.join(&filename);
618
619            if path.exists() {
620                return Ok(path);
621            }
622
623            if self.config.offline {
624                return Err(ClipError::ModelNotFound {
625                    model: self.model_info.name.to_string(),
626                    hint: format!(
627                        "Tokenizer missing at {}. Copy tokenizer.json from {}",
628                        path.display(),
629                        self.model_info.tokenizer_url
630                    ),
631                }
632                .into());
633            }
634
635            std::fs::create_dir_all(&self.config.models_dir).map_err(|e| {
636                ClipError::DownloadError {
637                    cause: format!("Failed to create models directory: {}", e),
638                }
639            })?;
640
641            Err(ClipError::DownloadError {
642                cause: format!(
643                    "Automatic download not yet implemented. Please download manually:\n\
644                     curl -L '{}' -o '{}'",
645                    self.model_info.tokenizer_url,
646                    path.display()
647                ),
648            }
649            .into())
650        }
651
652        /// Load vision session lazily
653        fn load_vision_session(&self) -> Result<()> {
654            let mut session_guard = self
655                .vision_session
656                .lock()
657                .map_err(|_| MemvidError::Lock("Failed to lock vision session".into()))?;
658
659            if session_guard.is_some() {
660                return Ok(());
661            }
662
663            let vision_path = self.ensure_model_file("vision")?;
664
665            tracing::debug!(path = %vision_path.display(), "Loading CLIP vision model");
666
667            let session = Session::builder()
668                .map_err(|e| ClipError::InferenceError {
669                    cause: e.to_string(),
670                })?
671                .with_optimization_level(GraphOptimizationLevel::Level3)
672                .map_err(|e| ClipError::InferenceError {
673                    cause: e.to_string(),
674                })?
675                .with_intra_threads(4)
676                .map_err(|e| ClipError::InferenceError {
677                    cause: e.to_string(),
678                })?
679                .commit_from_file(&vision_path)
680                .map_err(|e| ClipError::InferenceError {
681                    cause: format!("Failed to load vision model: {}", e),
682                })?;
683
684            *session_guard = Some(session);
685            tracing::info!(model = %self.model_info.name, "CLIP vision model loaded");
686
687            Ok(())
688        }
689
690        /// Load text session lazily
691        fn load_text_session(&self) -> Result<()> {
692            let mut session_guard = self
693                .text_session
694                .lock()
695                .map_err(|_| MemvidError::Lock("Failed to lock text session".into()))?;
696
697            if session_guard.is_some() {
698                return Ok(());
699            }
700
701            let text_path = self.ensure_model_file("text")?;
702
703            tracing::debug!(path = %text_path.display(), "Loading CLIP text model");
704
705            let session = Session::builder()
706                .map_err(|e| ClipError::InferenceError {
707                    cause: e.to_string(),
708                })?
709                .with_optimization_level(GraphOptimizationLevel::Level3)
710                .map_err(|e| ClipError::InferenceError {
711                    cause: e.to_string(),
712                })?
713                .with_intra_threads(4)
714                .map_err(|e| ClipError::InferenceError {
715                    cause: e.to_string(),
716                })?
717                .commit_from_file(&text_path)
718                .map_err(|e| ClipError::InferenceError {
719                    cause: format!("Failed to load text model: {}", e),
720                })?;
721
722            *session_guard = Some(session);
723            tracing::info!(model = %self.model_info.name, "CLIP text model loaded");
724
725            Ok(())
726        }
727
728        /// Load tokenizer lazily (matches the text model vocab/BPE)
729        fn load_tokenizer(&self) -> Result<()> {
730            let mut tokenizer_guard = self
731                .tokenizer
732                .lock()
733                .map_err(|_| MemvidError::Lock("Failed to lock CLIP tokenizer".into()))?;
734
735            if tokenizer_guard.is_some() {
736                return Ok(());
737            }
738
739            let tokenizer_path = self.ensure_tokenizer_file()?;
740
741            tracing::debug!(path = %tokenizer_path.display(), "Loading CLIP tokenizer");
742
743            let mut tokenizer =
744                Tokenizer::from_file(&tokenizer_path).map_err(|e| ClipError::InferenceError {
745                    cause: format!("Failed to load tokenizer: {}", e),
746                })?;
747
748            tokenizer.with_padding(Some(PaddingParams {
749                strategy: PaddingStrategy::Fixed(77),
750                direction: PaddingDirection::Right,
751                pad_to_multiple_of: None,
752                pad_id: 0,
753                pad_type_id: 0,
754                pad_token: "[PAD]".to_string(),
755            }));
756
757            tokenizer
758                .with_truncation(Some(TruncationParams {
759                    max_length: 77,
760                    strategy: TruncationStrategy::LongestFirst,
761                    stride: 0,
762                    direction: TruncationDirection::Right,
763                }))
764                .map_err(|e| ClipError::InferenceError {
765                    cause: format!("Failed to apply truncation config: {}", e),
766                })?;
767
768            *tokenizer_guard = Some(tokenizer);
769            tracing::info!(model = %self.model_info.name, "CLIP tokenizer loaded");
770
771            Ok(())
772        }
773
774        /// Preprocess image for CLIP inference
775        ///
776        /// MobileCLIP-S2 uses:
777        /// - Input size: 256x256
778        /// - Resize: shortest edge to 256, preserve aspect, center-crop
779        /// - Normalization: scale to [0, 1] (no mean/std shift per preprocessor_config)
780        /// - Format: NCHW (batch, channels, height, width)
781        fn preprocess_image(&self, image: &DynamicImage) -> Array4<f32> {
782            let size = self.model_info.input_resolution;
783            let rgb_input = image.to_rgb8();
784            let (w, h) = rgb_input.dimensions();
785
786            // Resize shortest edge to target while preserving aspect ratio
787            let scale = size as f32 / w.min(h) as f32;
788            let new_w = ((w as f32) * scale).round().max(1.0) as u32;
789            let new_h = ((h as f32) * scale).round().max(1.0) as u32;
790            let resized = image.resize_exact(new_w, new_h, FilterType::Triangle);
791
792            // Center crop to (size, size)
793            let start_x = (resized.width().saturating_sub(size)) / 2;
794            let start_y = (resized.height().saturating_sub(size)) / 2;
795
796            // Create array in NCHW format: [1, 3, H, W]
797            let mut array = Array4::<f32>::zeros((1, 3, size as usize, size as usize));
798
799            for y in 0..size as usize {
800                for x in 0..size as usize {
801                    let pixel = resized.get_pixel(start_x + x as u32, start_y + y as u32);
802                    array[[0, 0, y, x]] = pixel[0] as f32 / 255.0;
803                    array[[0, 1, y, x]] = pixel[1] as f32 / 255.0;
804                    array[[0, 2, y, x]] = pixel[2] as f32 / 255.0;
805                }
806            }
807
808            array
809        }
810
811        /// Encode an image to CLIP embedding
812        pub fn encode_image(&self, image: &DynamicImage) -> Result<Vec<f32>> {
813            // Ensure vision session is loaded
814            self.load_vision_session()?;
815
816            // Preprocess the image
817            let pixel_values = self.preprocess_image(image);
818
819            // Update last used timestamp
820            if let Ok(mut last) = self.last_used.lock() {
821                *last = Instant::now();
822            }
823
824            // Run inference
825            let mut session_guard = self
826                .vision_session
827                .lock()
828                .map_err(|_| MemvidError::Lock("Failed to lock vision session".into()))?;
829
830            let session = session_guard
831                .as_mut()
832                .ok_or_else(|| ClipError::InferenceError {
833                    cause: "Vision session not loaded".to_string(),
834                })?;
835
836            // Get input and output names from session before running
837            let input_name = session
838                .inputs
839                .first()
840                .map(|i| i.name.clone())
841                .unwrap_or_else(|| "pixel_values".into());
842            let output_name = session
843                .outputs
844                .first()
845                .map(|o| o.name.clone())
846                .unwrap_or_else(|| "image_embeds".into());
847
848            // Create tensor from ndarray
849            let input_tensor =
850                Tensor::from_array(pixel_values).map_err(|e| ClipError::InferenceError {
851                    cause: format!("Failed to create input tensor: {}", e),
852                })?;
853
854            // Run the model
855            let outputs = session
856                .run(ort::inputs![input_name => input_tensor])
857                .map_err(|e| ClipError::InferenceError {
858                    cause: format!("Vision inference failed: {}", e),
859                })?;
860
861            // Extract embeddings from first output
862            let output = outputs
863                .get(&output_name)
864                .ok_or_else(|| ClipError::InferenceError {
865                    cause: format!("No output '{}' from vision model", output_name),
866                })?;
867
868            let (_shape, data) =
869                output
870                    .try_extract_tensor::<f32>()
871                    .map_err(|e| ClipError::InferenceError {
872                        cause: format!("Failed to extract embeddings: {}", e),
873                    })?;
874
875            // Get the embedding from the raw data
876            let embedding: Vec<f32> = data.to_vec();
877            if embedding.iter().any(|v| !v.is_finite()) {
878                return Err(ClipError::InferenceError {
879                    cause: "Vision embedding contains non-finite values".to_string(),
880                }
881                .into());
882            }
883            let normalized = l2_normalize(&embedding);
884
885            tracing::debug!(dims = normalized.len(), "Generated CLIP image embedding");
886
887            Ok(normalized)
888        }
889
890        /// Encode image bytes to CLIP embedding
891        pub fn encode_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>> {
892            let image =
893                image::load_from_memory(bytes).map_err(|e| ClipError::ImageBytesDecodeError {
894                    cause: e.to_string(),
895                })?;
896            self.encode_image(&image)
897        }
898
899        /// Encode an image file to CLIP embedding
900        pub fn encode_image_file(&self, path: &Path) -> Result<Vec<f32>> {
901            let image = image::open(path).map_err(|e| ClipError::ImageDecodeError {
902                path: path.to_path_buf(),
903                cause: e.to_string(),
904            })?;
905            self.encode_image(&image)
906        }
907
908        /// Encode text to CLIP embedding (for query)
909        pub fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
910            // Ensure text session is loaded
911            self.load_text_session()?;
912            self.load_tokenizer()?;
913
914            // Tokenize the text using the model's tokenizer
915            let encoding = {
916                let tokenizer_guard = self
917                    .tokenizer
918                    .lock()
919                    .map_err(|_| MemvidError::Lock("Failed to lock CLIP tokenizer".into()))?;
920                let tokenizer =
921                    tokenizer_guard
922                        .as_ref()
923                        .ok_or_else(|| ClipError::InferenceError {
924                            cause: "Tokenizer not loaded".to_string(),
925                        })?;
926
927                tokenizer
928                    .encode(text, true)
929                    .map_err(|e| ClipError::InferenceError {
930                        cause: format!("Text tokenization failed: {}", e),
931                    })?
932            };
933
934            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|id| *id as i64).collect();
935            let attention_mask: Vec<i64> = encoding
936                .get_attention_mask()
937                .iter()
938                .map(|id| *id as i64)
939                .collect();
940            let max_length = input_ids.len();
941
942            // Create input arrays
943            let input_ids_array =
944                Array::from_shape_vec((1, max_length), input_ids).map_err(|e| {
945                    ClipError::InferenceError {
946                        cause: e.to_string(),
947                    }
948                })?;
949            let attention_mask_array = Array::from_shape_vec((1, max_length), attention_mask)
950                .map_err(|e| ClipError::InferenceError {
951                    cause: e.to_string(),
952                })?;
953
954            // Update last used timestamp
955            if let Ok(mut last) = self.last_used.lock() {
956                *last = Instant::now();
957            }
958
959            // Run inference
960            let mut session_guard = self
961                .text_session
962                .lock()
963                .map_err(|_| MemvidError::Lock("Failed to lock text session".into()))?;
964
965            let session = session_guard
966                .as_mut()
967                .ok_or_else(|| ClipError::InferenceError {
968                    cause: "Text session not loaded".to_string(),
969                })?;
970
971            // Get input and output names from session before running
972            let input_names: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
973            let output_name = session
974                .outputs
975                .first()
976                .map(|o| o.name.clone())
977                .unwrap_or_else(|| "text_embeds".into());
978
979            // Create tensors from ndarray
980            let input_ids_tensor =
981                Tensor::from_array(input_ids_array).map_err(|e| ClipError::InferenceError {
982                    cause: format!("Failed to create input_ids tensor: {}", e),
983                })?;
984            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
985                ClipError::InferenceError {
986                    cause: format!("Failed to create attention_mask tensor: {}", e),
987                }
988            })?;
989
990            // Build inputs based on what the model expects
991            let outputs = if input_names.len() >= 2 {
992                session
993                    .run(ort::inputs![
994                        input_names[0].clone() => input_ids_tensor,
995                        input_names[1].clone() => attention_mask_tensor
996                    ])
997                    .map_err(|e| ClipError::InferenceError {
998                        cause: format!("Text inference failed: {}", e),
999                    })?
1000            } else {
1001                // Single input model
1002                let name = input_names
1003                    .first()
1004                    .cloned()
1005                    .unwrap_or_else(|| "input_ids".to_string());
1006                session
1007                    .run(ort::inputs![name => input_ids_tensor])
1008                    .map_err(|e| ClipError::InferenceError {
1009                        cause: format!("Text inference failed: {}", e),
1010                    })?
1011            };
1012
1013            // Extract embeddings from output
1014            let output = outputs
1015                .get(&output_name)
1016                .ok_or_else(|| ClipError::InferenceError {
1017                    cause: format!("No output '{}' from text model", output_name),
1018                })?;
1019
1020            let (_shape, data) =
1021                output
1022                    .try_extract_tensor::<f32>()
1023                    .map_err(|e| ClipError::InferenceError {
1024                        cause: format!("Failed to extract text embeddings: {}", e),
1025                    })?;
1026
1027            // Flatten and normalize the embedding
1028            let embedding: Vec<f32> = data.to_vec();
1029            if embedding.iter().any(|v| !v.is_finite()) {
1030                return Err(ClipError::InferenceError {
1031                    cause: "Text embedding contains non-finite values".to_string(),
1032                }
1033                .into());
1034            }
1035            let normalized = l2_normalize(&embedding);
1036
1037            tracing::debug!(
1038                text_len = text.len(),
1039                dims = normalized.len(),
1040                "Generated CLIP text embedding"
1041            );
1042
1043            Ok(normalized)
1044        }
1045
1046        /// Maybe unload model if unused for too long (memory management)
1047        pub fn maybe_unload(&self) -> Result<()> {
1048            let last_used = self
1049                .last_used
1050                .lock()
1051                .map_err(|_| MemvidError::Lock("Failed to check last_used".into()))?;
1052
1053            if last_used.elapsed() > MODEL_UNLOAD_TIMEOUT {
1054                tracing::debug!(model = %self.model_info.name, "Model idle, unloading sessions");
1055
1056                // Unload vision session
1057                if let Ok(mut guard) = self.vision_session.lock() {
1058                    *guard = None;
1059                }
1060
1061                // Unload text session
1062                if let Ok(mut guard) = self.text_session.lock() {
1063                    *guard = None;
1064                }
1065
1066                // Unload tokenizer
1067                if let Ok(mut guard) = self.tokenizer.lock() {
1068                    *guard = None;
1069                }
1070            }
1071
1072            Ok(())
1073        }
1074
1075        /// Force unload all sessions
1076        pub fn unload(&self) -> Result<()> {
1077            if let Ok(mut guard) = self.vision_session.lock() {
1078                *guard = None;
1079            }
1080            if let Ok(mut guard) = self.text_session.lock() {
1081                *guard = None;
1082            }
1083            if let Ok(mut guard) = self.tokenizer.lock() {
1084                *guard = None;
1085            }
1086            tracing::debug!(model = %self.model_info.name, "CLIP sessions unloaded");
1087            Ok(())
1088        }
1089
1090        /// Check if vision model is loaded
1091        pub fn is_vision_loaded(&self) -> bool {
1092            self.vision_session
1093                .lock()
1094                .map(|g| g.is_some())
1095                .unwrap_or(false)
1096        }
1097
1098        /// Check if text model is loaded
1099        pub fn is_text_loaded(&self) -> bool {
1100            self.text_session
1101                .lock()
1102                .map(|g| g.is_some())
1103                .unwrap_or(false)
1104        }
1105    }
1106
1107    /// L2 normalize a vector (unit length)
1108    fn l2_normalize(v: &[f32]) -> Vec<f32> {
1109        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
1110        if norm.is_finite() && norm > 1e-10 {
1111            v.iter().map(|x| x / norm).collect()
1112        } else {
1113            // Fall back to zeros to avoid NaNs propagating through distances
1114            vec![0.0; v.len()]
1115        }
1116    }
1117
1118    /// Calculate color variance of an image
1119    pub fn calculate_color_variance(image: &DynamicImage) -> f32 {
1120        let rgb = image.to_rgb8();
1121        let (width, height) = rgb.dimensions();
1122        let total_pixels = (width * height) as f32;
1123
1124        if total_pixels == 0.0 {
1125            return 0.0;
1126        }
1127
1128        // Calculate mean
1129        let mut sum_r = 0.0f32;
1130        let mut sum_g = 0.0f32;
1131        let mut sum_b = 0.0f32;
1132
1133        for pixel in rgb.pixels() {
1134            sum_r += pixel[0] as f32;
1135            sum_g += pixel[1] as f32;
1136            sum_b += pixel[2] as f32;
1137        }
1138
1139        let mean_r = sum_r / total_pixels;
1140        let mean_g = sum_g / total_pixels;
1141        let mean_b = sum_b / total_pixels;
1142
1143        // Calculate variance
1144        let mut var_r = 0.0f32;
1145        let mut var_g = 0.0f32;
1146        let mut var_b = 0.0f32;
1147
1148        for pixel in rgb.pixels() {
1149            var_r += (pixel[0] as f32 - mean_r).powi(2);
1150            var_g += (pixel[1] as f32 - mean_g).powi(2);
1151            var_b += (pixel[2] as f32 - mean_b).powi(2);
1152        }
1153
1154        // Average variance across channels, normalized to 0-1
1155        ((var_r + var_g + var_b) / (3.0 * total_pixels)) / (255.0 * 255.0)
1156    }
1157
1158    /// Get ImageInfo from a DynamicImage
1159    pub fn get_image_info(image: &DynamicImage) -> ImageInfo {
1160        let (width, height) = image.dimensions();
1161        let color_variance = calculate_color_variance(image);
1162
1163        ImageInfo {
1164            width,
1165            height,
1166            color_variance,
1167        }
1168    }
1169}
1170
1171#[cfg(feature = "clip")]
1172pub use model::*;
1173
1174#[cfg(all(feature = "clip", feature = "pdfium"))]
1175use pdfium_render::prelude::{PdfPageRenderRotation, PdfRenderConfig, Pdfium};
1176
/// Rasterize the first `max_pages` pages of a PDF for CLIP embedding (pdfium).
///
/// Each page is rendered with a target width of `target_px` and capped at
/// `target_px` in both dimensions. It is returned with its 1-based page number.
///
/// # Errors
///
/// Fails if the system pdfium library cannot be bound, the PDF cannot be
/// opened, or any page fails to render.
#[cfg(all(feature = "clip", feature = "pdfium"))]
pub fn render_pdf_pages_for_clip(
    path: &Path,
    max_pages: usize,
    target_px: u32,
) -> Result<Vec<(u32, DynamicImage)>> {
    let pdfium = Pdfium::new(Pdfium::bind_to_system_library().map_err(|e| {
        ClipError::InferenceError {
            cause: format!("Failed to bind pdfium: {}", e),
        }
    })?);
    let document = pdfium
        .load_pdf_from_file(path, None)
        .map_err(|e| ClipError::InferenceError {
            cause: format!("Failed to load PDF for CLIP rendering: {}", e),
        })?;

    let px = target_px as i32;
    let render_config = PdfRenderConfig::new()
        .set_target_width(px)
        .set_maximum_height(px)
        .set_maximum_width(px)
        .rotate_if_landscape(PdfPageRenderRotation::None, false);

    let mut rendered = Vec::new();
    for (index, page) in document.pages().iter().enumerate().take(max_pages) {
        let page_number = index + 1;
        let image = page
            .render_with_config(&render_config)
            .map_err(|e| ClipError::InferenceError {
                cause: format!("Failed to render PDF page {}: {}", page_number, e),
            })?
            .as_image();
        rendered.push((page_number as u32, image));
    }

    Ok(rendered)
}
1217
/// Extract embedded images from a PDF for CLIP embedding (lopdf fallback).
///
/// Without the `pdfium` feature, pages cannot be rasterized. Instead, this walks
/// each page's `/XObject` resources and decodes image XObjects directly. Every
/// returned image is tagged with the number of the page it was found on.
///
/// In this mode `max_pages` caps the total number of returned *images* (a page
/// may contribute several or none), and `_target_px` is ignored.
///
/// Supported images:
/// - `DCTDecode`/`JPXDecode` streams, decoded by the `image` crate.
/// - 8-bit raw samples in a gray or RGB colour space: `DeviceGray`, `DeviceRGB`,
///   `CalGray`, `CalRGB`, or `ICCBased` with 1 or 3 components.
///
/// Anything else (CMYK, indexed, stencil masks, non-8-bit samples, ...) is
/// skipped rather than misread as RGB.
///
/// # Errors
///
/// Returns an error if the PDF cannot be loaded or a page's resources cannot
/// be read.
#[cfg(all(feature = "clip", not(feature = "pdfium")))]
pub fn render_pdf_pages_for_clip(
    path: &Path,
    max_pages: usize,
    _target_px: u32,
) -> Result<Vec<(u32, DynamicImage)>> {
    /// Read a strictly positive integer entry (e.g. `/Width`) that fits in `u32`.
    fn positive_u32(dict: &Dictionary, key: &[u8]) -> Option<u32> {
        dict.get(key)
            .ok()
            .and_then(|obj| obj.as_i64().ok())
            .and_then(|value| u32::try_from(value).ok())
            .filter(|&value| value > 0)
    }

    /// Samples per pixel for an image's `/ColorSpace`, or `None` when the colour
    /// space cannot be mapped onto a `Luma<u8>` / `Rgb<u8>` buffer.
    fn color_channels(doc: &Document, dict: &Dictionary) -> Option<usize> {
        let color_space = match dict.get(b"ColorSpace") {
            // No colour space declared: keep the historical RGB assumption.
            Err(_) => return Some(3),
            Ok(Object::Reference(id)) => doc.get_object(*id).ok()?,
            Ok(obj) => obj,
        };
        match color_space {
            Object::Name(name) => match name.as_slice() {
                b"DeviceGray" => Some(1),
                b"DeviceRGB" => Some(3),
                // DeviceCMYK etc. would otherwise be misread as RGB.
                _ => None,
            },
            Object::Array(parts) => match parts.first()?.as_name().ok()? {
                b"CalGray" => Some(1),
                b"CalRGB" => Some(3),
                b"ICCBased" => {
                    let profile = match parts.get(1)? {
                        Object::Reference(id) => doc.get_object(*id).ok()?,
                        other => other,
                    };
                    // `/N` is the number of colour components of the ICC profile.
                    match profile.as_stream().ok()?.dict.get(b"N").ok()?.as_i64().ok()? {
                        1 => Some(1),
                        3 => Some(3),
                        _ => None,
                    }
                }
                _ => None,
            },
            _ => None,
        }
    }

    /// Decode one image XObject, or `None` if its encoding is unsupported or
    /// its sample data is shorter than its declared size.
    fn decode_image(doc: &Document, stream: &lopdf::Stream) -> Option<DynamicImage> {
        let dict = &stream.dict;
        let width = positive_u32(dict, b"Width")?;
        let height = positive_u32(dict, b"Height")?;

        // Fall back to the stored bytes when lopdf cannot undo the filter chain
        // (e.g. DCT/JPX, which only the image crate understands).
        let read_data = || {
            stream
                .decompressed_content()
                .unwrap_or_else(|_| stream.content.clone())
        };

        let filters: Vec<&[u8]> = match dict.get(b"Filter") {
            Ok(Object::Name(name)) => vec![name.as_slice()],
            Ok(Object::Array(items)) => items.iter().filter_map(|o| o.as_name().ok()).collect(),
            _ => Vec::new(),
        };
        // Encoded images are never raw samples, so a failed decode means "skip".
        if filters
            .iter()
            .any(|f| matches!(*f, b"DCTDecode" | b"JPXDecode"))
        {
            return image::load_from_memory(&read_data()).ok();
        }

        // Stencil masks are 1-bit and carry no colour information.
        if matches!(dict.get(b"ImageMask"), Ok(Object::Boolean(true))) {
            return None;
        }
        // Only 8-bit samples map directly onto `Luma<u8>` / `Rgb<u8>` buffers.
        // A missing entry keeps the historical 8-bit assumption.
        let bits = dict
            .get(b"BitsPerComponent")
            .ok()
            .and_then(|o| o.as_i64().ok())
            .unwrap_or(8);
        if bits != 8 {
            return None;
        }
        let channels = color_channels(doc, dict)?;
        // Checked: Width/Height come from an untrusted file.
        let expected = (width as usize)
            .checked_mul(height as usize)?
            .checked_mul(channels)?;

        let mut data = read_data();
        if data.len() < expected {
            return None;
        }
        // Bytes beyond the declared image size are ignored.
        data.truncate(expected);

        if channels == 1 {
            ImageBuffer::<Luma<u8>, _>::from_raw(width, height, data).map(DynamicImage::ImageLuma8)
        } else {
            ImageBuffer::<Rgb<u8>, _>::from_raw(width, height, data).map(DynamicImage::ImageRgb8)
        }
    }

    /// Append the decodable images of one page to `out`, tagged with
    /// `page_num`. Decrements `*remaining` once per image and stops at zero.
    fn extract_images_from_page(
        doc: &Document,
        page_num: u32,
        page_id: ObjectId,
        remaining: &mut usize,
        out: &mut Vec<(u32, DynamicImage)>,
    ) -> Result<()> {
        if *remaining == 0 {
            return Ok(());
        }

        let (resources_opt, resource_ids) = doc.get_page_resources(page_id).map_err(|e| {
            ClipError::InferenceError {
                cause: format!("Failed to read PDF resources: {}", e),
            }
        })?;

        // The page's inline resources plus each referenced resource dictionary,
        // each visited once (borrowed, not cloned).
        let mut resource_dicts: Vec<&Dictionary> = Vec::new();
        if let Some(dict) = &resources_opt {
            resource_dicts.push(dict);
        }
        let mut seen_resources = HashSet::new();
        for res_id in resource_ids {
            if seen_resources.insert(res_id) {
                if let Ok(dict) = doc.get_dictionary(res_id) {
                    resource_dicts.push(dict);
                }
            }
        }

        // One image can be reachable from several resource dictionaries (or under
        // several names); embed it only once per page.
        let mut seen_images = HashSet::new();
        for resources in resource_dicts {
            let xobjects = match resources.get(b"XObject") {
                Ok(Object::Dictionary(dict)) => dict,
                Ok(Object::Reference(id)) => match doc.get_dictionary(*id) {
                    Ok(dict) => dict,
                    Err(_) => continue,
                },
                _ => continue,
            };
            for (_, obj) in xobjects.iter() {
                let Object::Reference(id) = obj else {
                    continue;
                };
                if !seen_images.insert(*id) {
                    continue;
                }
                let Ok(stream) = doc.get_object(*id).and_then(Object::as_stream) else {
                    continue;
                };
                let is_image =
                    matches!(stream.dict.get(b"Subtype"), Ok(Object::Name(n)) if n == b"Image");
                if !is_image {
                    continue;
                }
                if let Some(image) = decode_image(doc, stream) {
                    out.push((page_num, image));
                    // `remaining` counts images still allowed; stop exactly at the cap.
                    *remaining -= 1;
                    if *remaining == 0 {
                        return Ok(());
                    }
                }
            }
        }

        Ok(())
    }

    let doc = Document::load(path).map_err(|e| ClipError::InferenceError {
        cause: format!("Failed to load PDF for image extraction: {}", e),
    })?;

    let mut remaining = max_pages;
    let mut images: Vec<(u32, DynamicImage)> = Vec::new();

    for (page_num, page_id) in doc.get_pages() {
        if remaining == 0 {
            break;
        }
        extract_images_from_page(&doc, page_num as u32, page_id, &mut remaining, &mut images)?;
    }

    Ok(images)
}
1398
1399// ============================================================================
1400// CLIP Embedding Provider Trait
1401// ============================================================================
1402
1403/// Trait for CLIP visual embedding providers.
1404///
1405/// Unlike text `EmbeddingProvider`, CLIP providers handle both:
1406/// - **Image encoding**: Generate embeddings from images (for indexing)
1407/// - **Text encoding**: Generate embeddings from text (for queries)
1408///
1409/// This allows natural language queries against visual content.
1410///
1411/// # Example
1412///
1413/// ```ignore
1414/// use memvid_core::clip::{ClipEmbeddingProvider, ClipConfig};
1415///
1416/// // Create provider
1417/// let provider = ClipModel::new(ClipConfig::default())?;
1418///
1419/// // Encode image for indexing
1420/// let image_embedding = provider.embed_image_file(&path)?;
1421///
1422/// // Encode query text for search
1423/// let query_embedding = provider.embed_query("a photo of a cat")?;
1424///
1425/// // Search uses cosine similarity between query and image embeddings
1426/// ```
1427pub trait ClipEmbeddingProvider: Send + Sync {
1428    /// Return the provider kind (e.g., "mobileclip", "siglip").
1429    fn kind(&self) -> &str;
1430
1431    /// Return the model identifier.
1432    fn model(&self) -> &str;
1433
1434    /// Return the embedding dimension.
1435    fn dimension(&self) -> usize;
1436
1437    /// Generate an embedding for an image file.
1438    fn embed_image_file(&self, path: &Path) -> Result<Vec<f32>>;
1439
1440    /// Generate an embedding for image bytes.
1441    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>>;
1442
1443    /// Generate an embedding for a text query (for searching).
1444    fn embed_query(&self, text: &str) -> Result<Vec<f32>>;
1445
1446    /// Generate embeddings for multiple image files.
1447    ///
1448    /// Default implementation calls `embed_image_file` in a loop.
1449    /// Providers should override this for efficient batch processing.
1450    fn embed_image_batch(&self, paths: &[&Path]) -> Result<Vec<Vec<f32>>> {
1451        let mut embeddings = Vec::with_capacity(paths.len());
1452        for path in paths {
1453            embeddings.push(self.embed_image_file(path)?);
1454        }
1455        Ok(embeddings)
1456    }
1457
1458    /// Check if the provider is ready to generate embeddings.
1459    fn is_ready(&self) -> bool {
1460        true
1461    }
1462
1463    /// Initialize the provider (e.g., load models).
1464    fn init(&mut self) -> Result<()> {
1465        Ok(())
1466    }
1467
1468    /// Unload models to free memory.
1469    fn unload(&self) -> Result<()> {
1470        Ok(())
1471    }
1472}
1473
/// Result of a single CLIP embedding operation (one image or one text query):
/// a single embedding vector.
pub type ClipEmbeddingResult = Result<Vec<f32>>;
/// Result of a batch CLIP embedding operation (e.g. `embed_image_batch`):
/// a list of embedding vectors.
pub type ClipBatchEmbeddingResult = Result<Vec<Vec<f32>>>;
1477
1478// ============================================================================
1479// ClipEmbeddingProvider Implementation (Feature-gated)
1480// ============================================================================
1481
/// Exposes `ClipModel` (lazily loaded ONNX sessions) through the generic
/// provider interface; every method delegates to the model's own API.
#[cfg(feature = "clip")]
impl ClipEmbeddingProvider for ClipModel {
    fn kind(&self) -> &str {
        "clip"
    }

    fn model(&self) -> &str {
        self.model_info().name
    }

    fn dimension(&self) -> usize {
        self.model_info().dims as usize
    }

    fn embed_image_file(&self, path: &Path) -> Result<Vec<f32>> {
        self.encode_image_file(path)
    }

    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>> {
        self.encode_image_bytes(bytes)
    }

    fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        self.encode_text(text)
    }

    fn embed_image_batch(&self, paths: &[&Path]) -> Result<Vec<Vec<f32>>> {
        // TODO: Use rayon for parallel processing. Note that inference itself is
        // serialized by the vision-session mutex, so only decoding/preprocessing
        // would overlap. For now: sequential, stopping at the first failure.
        paths.iter().map(|path| self.encode_image_file(path)).collect()
    }

    fn is_ready(&self) -> bool {
        // Sessions are loaded lazily on first use, so the provider is always usable.
        true
    }

    fn unload(&self) -> Result<()> {
        ClipModel::unload(self)
    }
}
1526
1527// ============================================================================
1528// CLIP Index Manifest (for TOC)
1529// ============================================================================
1530
/// Manifest for the CLIP index, stored in the TOC.
///
/// Describes where the serialized CLIP index lives in the file and summarizes
/// its contents. Field order is part of the serialized form; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClipIndexManifest {
    /// Byte offset of the index data in the file
    pub bytes_offset: u64,
    /// Length of the index data in bytes
    pub bytes_length: u64,
    /// Number of embedding vectors in the index
    pub vector_count: u64,
    /// Dimensionality of each embedding (e.g. 512 for MobileCLIP-S2)
    pub dimension: u32,
    /// Blake3 checksum — presumably of the index bytes at
    /// `bytes_offset..bytes_offset + bytes_length`; confirm against the writer
    pub checksum: [u8; 32],
    /// Name of the model used to generate the embeddings
    pub model_name: String,
}
1547
1548// ============================================================================
1549// Tests
1550// ============================================================================
1551
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clip_index_builder_roundtrip() {
        let mut builder = ClipIndexBuilder::new();
        builder.add_document(1, None, vec![0.1, 0.2, 0.3, 0.4]);
        builder.add_document(2, None, vec![0.5, 0.6, 0.7, 0.8]);

        let artifact = builder.finish().expect("finish");
        assert_eq!(artifact.vector_count, 2);
        assert_eq!(artifact.dimension, 4);

        let index = ClipIndex::decode(&artifact.bytes).expect("decode");
        assert_eq!(index.len(), 2);

        // Querying with an indexed vector must return it first, at ~zero distance.
        let hits = index.search(&[0.1, 0.2, 0.3, 0.4], 10);
        assert!(!hits.is_empty(), "search returned no hits");
        assert_eq!(hits[0].frame_id, 1);
        assert!(hits[0].distance < 0.001); // Should be very close
    }

    #[test]
    fn clip_index_search() {
        let mut builder = ClipIndexBuilder::new();
        builder.add_document(1, None, vec![1.0, 0.0, 0.0]);
        builder.add_document(2, None, vec![0.0, 1.0, 0.0]);
        builder.add_document(3, None, vec![0.0, 0.0, 1.0]);

        let artifact = builder.finish().expect("finish");
        let index = ClipIndex::decode(&artifact.bytes).expect("decode");

        // Search for [1, 0, 0] - should find frame 1 first
        let hits = index.search(&[1.0, 0.0, 0.0], 3);
        assert!(!hits.is_empty(), "search returned no hits");
        assert_eq!(hits[0].frame_id, 1);

        // Search for [0, 1, 0] - should find frame 2 first
        let hits = index.search(&[0.0, 1.0, 0.0], 3);
        assert!(!hits.is_empty(), "search returned no hits");
        assert_eq!(hits[0].frame_id, 2);

        // Search for [0, 0, 1] - should find frame 3 first
        let hits = index.search(&[0.0, 0.0, 1.0], 3);
        assert!(!hits.is_empty(), "search returned no hits");
        assert_eq!(hits[0].frame_id, 3);
    }

    #[test]
    fn l2_distance_calculation() {
        let d = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-6);

        let d = l2_distance(&[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
        assert!(d.abs() < 1e-6);
    }

    #[test]
    fn image_info_filtering() {
        // Tiny image - should skip
        let tiny = ImageInfo {
            width: 32,
            height: 32,
            color_variance: 0.5,
        };
        assert!(!tiny.should_embed());

        // Good image
        let good = ImageInfo {
            width: 256,
            height: 256,
            color_variance: 0.5,
        };
        assert!(good.should_embed());

        // Extreme aspect ratio
        let wide = ImageInfo {
            width: 1000,
            height: 10,
            color_variance: 0.5,
        };
        assert!(!wide.should_embed());

        // Solid color
        let solid = ImageInfo {
            width: 256,
            height: 256,
            color_variance: 0.001,
        };
        assert!(!solid.should_embed());
    }

    #[test]
    fn model_registry() {
        let default = default_model_info();
        assert_eq!(default.name, "mobileclip-s2");
        assert_eq!(default.dims, 512);
        assert!(default.is_default);

        let siglip = get_model_info("siglip-base");
        assert_eq!(siglip.dims, 768);

        // Unknown model returns default
        let unknown = get_model_info("nonexistent");
        assert_eq!(unknown.name, "mobileclip-s2");
    }

    #[test]
    fn clip_config_defaults() {
        // Clear the env vars to test true defaults
        // SAFETY: `remove_var` is unsafe because it races with any concurrent
        // environment access from other threads, and the test harness runs
        // tests in parallel threads. This is only sound while no other test in
        // this binary reads or writes the environment at the same time.
        // NOTE(review): if more env-dependent tests are added, serialize them
        // behind a shared lock.
        unsafe {
            std::env::remove_var("MEMVID_CLIP_MODEL");
            std::env::remove_var("MEMVID_OFFLINE");
        }

        let config = ClipConfig::default();
        assert_eq!(config.model_name, "mobileclip-s2");
        assert!(!config.offline);
    }

    #[test]
    fn clip_embedding_provider_trait() {
        // Test that the trait is properly defined
        fn assert_send_sync<T: Send + Sync>() {}

        // The trait should require Send + Sync
        assert_send_sync::<Box<dyn super::ClipEmbeddingProvider>>();
    }
}