memvid_core/clip.rs

1// Safe expect: Static CLIP model lookup with guaranteed default.
2#![allow(clippy::unwrap_used, clippy::expect_used)]
3//! CLIP (Contrastive Language-Image Pre-training) visual embeddings module.
4//!
5//! This module provides visual understanding capabilities using MobileCLIP-S2,
6//! enabling semantic search across images and PDF pages with natural language queries.
7//!
8//! # Design Philosophy
9//!
10//! - **Synchronous with Parallelism**: CLIP runs in parallel with text embedding via rayon.
//!   Because CLIP inference (~25ms) finishes well before text embedding (~200-500ms), it adds no extra wall-clock latency.
12//! - **Separate Index**: CLIP embeddings (512 dims) are stored in a separate index from
13//!   text embeddings (384/768/1536 dims) because dimensions must match within an index.
14//! - **Auto-detection**: Images and PDFs with images are automatically processed without flags.
15//! - **Graceful Degradation**: Works without CLIP, just loses visual search capability.
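//!
//! # Example
//!
//! A minimal sketch of the intended flow (requires the `clip` feature and a locally
//! cached model; the file name and frame id below are placeholders):
//!
//! ```ignore
//! use std::path::Path;
//! use memvid_core::clip::{ClipConfig, ClipIndex, ClipModel};
//!
//! let model = ClipModel::new(ClipConfig::default())?;
//!
//! // Index an image, then search it with a natural-language query.
//! // (In the ingest pipeline this runs alongside text embedding, e.g. via rayon.)
//! let mut index = ClipIndex::new();
//! index.add_document(1, None, model.encode_image_file(Path::new("photo.jpg"))?);
//!
//! let query = model.encode_text("a photo of a cat")?;
//! let hits = index.search(&query, 5);
//! ```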
16
17use blake3::hash;
18#[cfg(feature = "clip")]
19use image::DynamicImage;
20#[cfg(all(feature = "clip", not(feature = "pdfium")))]
21use image::{ImageBuffer, Luma, Rgb};
22#[cfg(all(feature = "clip", not(feature = "pdfium")))]
23use lopdf::{Dictionary, Document, Object, ObjectId};
24use serde::{Deserialize, Serialize};
25use std::borrow::Cow;
26#[cfg(all(feature = "clip", not(feature = "pdfium")))]
27use std::collections::HashSet;
28use std::path::{Path, PathBuf};
29use std::time::Duration;
30
31use crate::{MemvidError, Result, types::FrameId};
32
33// ============================================================================
34// Configuration Constants
35// ============================================================================
36
37/// CLIP index decode limit (512MB max)
38#[allow(clippy::cast_possible_truncation)]
39const CLIP_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
40
41/// MobileCLIP-S2 embedding dimensions
42pub const MOBILECLIP_DIMS: u32 = 512;
43
44/// SigLIP-base embedding dimensions
45pub const SIGLIP_DIMS: u32 = 768;
46
47/// Default input resolution for MobileCLIP-S2
48pub const MOBILECLIP_INPUT_SIZE: u32 = 256;
49
50/// Default input resolution for `SigLIP`
51pub const SIGLIP_INPUT_SIZE: u32 = 224;
52
53/// Minimum image dimension to process (skip icons, bullets)
54pub const MIN_IMAGE_DIM: u32 = 64;
55
/// Maximum allowed aspect ratio (longer side : shorter side); more extreme shapes (dividers, lines) are skipped
57pub const MAX_ASPECT_RATIO: f32 = 10.0;
58
59/// Minimum color variance threshold (skip solid backgrounds)
60pub const MIN_COLOR_VARIANCE: f32 = 0.01;
61
62/// Model unload timeout (5 minutes idle)
63pub const MODEL_UNLOAD_TIMEOUT: Duration = Duration::from_secs(300);
64
65// ============================================================================
66// Bincode Configuration
67// ============================================================================
68
69fn clip_config() -> impl bincode::config::Config {
70    bincode::config::standard()
71        .with_fixed_int_encoding()
72        .with_little_endian()
73}
74
75// ============================================================================
76// Model Registry
77// ============================================================================
78
79/// Available CLIP models with verified `HuggingFace` URLs
80#[derive(Debug, Clone)]
81pub struct ClipModelInfo {
82    /// Model identifier
83    pub name: &'static str,
84    /// URL for vision encoder ONNX model
85    pub vision_url: &'static str,
86    /// URL for text encoder ONNX model
87    pub text_url: &'static str,
88    /// URL for tokenizer JSON (BPE)
89    pub tokenizer_url: &'static str,
90    /// Vision model size in MB
91    pub vision_size_mb: f32,
92    /// Text model size in MB
93    pub text_size_mb: f32,
94    /// Output embedding dimensions
95    pub dims: u32,
96    /// Input image resolution
97    pub input_resolution: u32,
98    /// Whether this is the default model
99    pub is_default: bool,
100}
101
102/// Available CLIP models registry
103pub static CLIP_MODELS: &[ClipModelInfo] = &[
104    // MobileCLIP-S2 int8 quantized (smallest, but requires INT8 ONNX support)
105    // Note: INT8 quantized models don't work on all platforms (ConvInteger not supported)
106    ClipModelInfo {
107        name: "mobileclip-s2-int8",
108        vision_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/vision_model_int8.onnx",
109        text_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/text_model_int8.onnx",
110        tokenizer_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/tokenizer.json",
111        vision_size_mb: 36.7,
112        text_size_mb: 64.1,
113        dims: MOBILECLIP_DIMS,
114        input_resolution: MOBILECLIP_INPUT_SIZE,
115        is_default: false,
116    },
117    // Alternative: SigLIP-base quantized (higher quality, but may have INT8 issues)
118    ClipModelInfo {
119        name: "siglip-base",
120        vision_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/onnx/vision_model_quantized.onnx",
121        text_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/onnx/text_model_quantized.onnx",
122        tokenizer_url: "https://huggingface.co/Xenova/siglip-base-patch16-224/resolve/main/tokenizer.json",
123        vision_size_mb: 99.5,
124        text_size_mb: 111.0,
125        dims: SIGLIP_DIMS,
126        input_resolution: SIGLIP_INPUT_SIZE,
127        is_default: false,
128    },
129    // Default: MobileCLIP-S2 fp16 (works on all platforms, good balance of size/quality)
130    ClipModelInfo {
131        name: "mobileclip-s2",
132        vision_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/vision_model_fp16.onnx",
133        text_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/onnx/text_model_fp16.onnx",
134        tokenizer_url: "https://huggingface.co/Xenova/mobileclip_s2/resolve/main/tokenizer.json",
135        vision_size_mb: 71.7,
136        text_size_mb: 127.0,
137        dims: MOBILECLIP_DIMS,
138        input_resolution: MOBILECLIP_INPUT_SIZE,
139        is_default: true,
140    },
141];
142
/// Get model info by name, falling back to the default model (mobileclip-s2) if the name is unknown
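///
/// Illustrative usage (mirrors the behaviour exercised in the tests below):
///
/// ```ignore
/// // Known names resolve directly; unknown names fall back to the default model.
/// assert_eq!(get_model_info("siglip-base").dims, 768);
/// assert_eq!(get_model_info("no-such-model").name, "mobileclip-s2");
/// ```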
144#[must_use]
145pub fn get_model_info(name: &str) -> &'static ClipModelInfo {
146    CLIP_MODELS
147        .iter()
148        .find(|m| m.name == name)
149        .unwrap_or_else(|| {
150            CLIP_MODELS
151                .iter()
152                .find(|m| m.is_default)
153                .expect("default model")
154        })
155}
156
157/// Get the default model info
158#[must_use]
159pub fn default_model_info() -> &'static ClipModelInfo {
160    CLIP_MODELS
161        .iter()
162        .find(|m| m.is_default)
163        .expect("default model exists")
164}
165
166// ============================================================================
167// CLIP Document and Index Types (mirrors vec.rs pattern)
168// ============================================================================
169
170/// A document with CLIP embedding stored in the index
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct ClipDocument {
173    /// Frame ID this embedding belongs to
174    pub frame_id: FrameId,
175    /// CLIP embedding vector (512 or 768 dims depending on model)
176    pub embedding: Vec<f32>,
177    /// Optional page number (for PDFs)
178    #[serde(default)]
179    pub page: Option<u32>,
180}
181
182/// Builder for constructing CLIP index artifacts
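///
/// Illustrative usage (the embedding values here are placeholders):
///
/// ```ignore
/// let mut builder = ClipIndexBuilder::new();
/// builder.add_document(42, Some(1), vec![0.0_f32; 512]);
/// let artifact = builder.finish()?;
/// assert_eq!(artifact.dimension, 512);
/// ```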
183#[derive(Default)]
184pub struct ClipIndexBuilder {
185    documents: Vec<ClipDocument>,
186}
187
188impl ClipIndexBuilder {
189    #[must_use]
190    pub fn new() -> Self {
191        Self::default()
192    }
193
194    /// Add a document with its CLIP embedding
195    pub fn add_document<I>(&mut self, frame_id: FrameId, page: Option<u32>, embedding: I)
196    where
197        I: Into<Vec<f32>>,
198    {
199        self.documents.push(ClipDocument {
200            frame_id,
201            embedding: embedding.into(),
202            page,
203        });
204    }
205
206    /// Finish building and produce the index artifact
207    pub fn finish(self) -> Result<ClipIndexArtifact> {
208        let bytes = bincode::serde::encode_to_vec(&self.documents, clip_config())?;
209
210        let checksum = *hash(&bytes).as_bytes();
211        let dimension = self
212            .documents
213            .first()
214            .map_or(0, |doc| u32::try_from(doc.embedding.len()).unwrap_or(0));
215
216        Ok(ClipIndexArtifact {
217            bytes,
218            vector_count: self.documents.len() as u64,
219            dimension,
220            checksum,
221        })
222    }
223}
224
225/// Artifact produced by the CLIP index builder
226#[derive(Debug, Clone)]
227pub struct ClipIndexArtifact {
228    /// Serialized index bytes
229    pub bytes: Vec<u8>,
230    /// Number of vectors in the index
231    pub vector_count: u64,
232    /// Embedding dimension (512 for `MobileCLIP`, 768 for `SigLIP`)
233    pub dimension: u32,
234    /// Blake3 checksum of the bytes
235    pub checksum: [u8; 32],
236}
237
238/// In-memory CLIP index for similarity search
239#[derive(Debug, Clone)]
240pub struct ClipIndex {
241    documents: Vec<ClipDocument>,
242}
243
244impl Default for ClipIndex {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250impl ClipIndex {
251    /// Create a new empty CLIP index
252    #[must_use]
253    pub fn new() -> Self {
254        Self {
255            documents: Vec::new(),
256        }
257    }
258
259    /// Add a document with its CLIP embedding
260    pub fn add_document<I>(&mut self, frame_id: FrameId, page: Option<u32>, embedding: I)
261    where
262        I: Into<Vec<f32>>,
263    {
264        self.documents.push(ClipDocument {
265            frame_id,
266            embedding: embedding.into(),
267            page,
268        });
269    }
270
271    /// Decode CLIP index from bytes
272    pub fn decode(bytes: &[u8]) -> Result<Self> {
273        let (documents, read) = bincode::serde::decode_from_slice::<Vec<ClipDocument>, _>(
274            bytes,
275            bincode::config::standard()
276                .with_fixed_int_encoding()
277                .with_little_endian()
278                .with_limit::<CLIP_DECODE_LIMIT>(),
279        )?;
280
281        if read != bytes.len() {
282            return Err(MemvidError::InvalidToc {
283                reason: Cow::Owned(format!(
284                    "CLIP index decode: expected {} bytes, read {}",
285                    bytes.len(),
286                    read
287                )),
288            });
289        }
290
291        tracing::debug!(
292            bytes_len = bytes.len(),
293            docs_count = documents.len(),
294            "decoded CLIP index"
295        );
296
297        Ok(Self { documents })
298    }
299
    /// Search for similar embeddings using L2 distance (lower distance = more similar).
    /// For L2-normalized embeddings (as produced by the CLIP encoders in this module),
    /// this ranking is equivalent to ranking by cosine similarity.
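    ///
    /// Illustrative usage (`query` is an embedding produced by the same CLIP model):
    ///
    /// ```ignore
    /// let hits = index.search(&query, 10);
    /// if let Some(best) = hits.first() {
    ///     println!("frame {:?} at distance {}", best.frame_id, best.distance);
    /// }
    /// ```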
301    #[must_use]
302    pub fn search(&self, query: &[f32], limit: usize) -> Vec<ClipSearchHit> {
303        if query.is_empty() {
304            return Vec::new();
305        }
306
307        let mut hits: Vec<ClipSearchHit> = self
308            .documents
309            .iter()
310            .map(|doc| {
311                let distance = l2_distance(query, &doc.embedding);
312                ClipSearchHit {
313                    frame_id: doc.frame_id,
314                    page: doc.page,
315                    distance,
316                }
317            })
318            .collect();
319
320        hits.sort_by(|a, b| {
321            a.distance
322                .partial_cmp(&b.distance)
323                .unwrap_or(std::cmp::Ordering::Equal)
324        });
325        hits.truncate(limit);
326        hits
327    }
328
329    /// Get all entries in the index
330    pub fn entries(&self) -> impl Iterator<Item = (FrameId, Option<u32>, &[f32])> + '_ {
331        self.documents
332            .iter()
333            .map(|doc| (doc.frame_id, doc.page, doc.embedding.as_slice()))
334    }
335
336    /// Get embedding for a specific frame
337    #[must_use]
338    pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
339        self.documents
340            .iter()
341            .find(|doc| doc.frame_id == frame_id)
342            .map(|doc| doc.embedding.as_slice())
343    }
344
345    /// Remove a document from the index
346    pub fn remove(&mut self, frame_id: FrameId) {
347        self.documents.retain(|doc| doc.frame_id != frame_id);
348    }
349
350    /// Number of documents in the index
351    #[must_use]
352    pub fn len(&self) -> usize {
353        self.documents.len()
354    }
355
356    /// Check if index is empty
357    #[must_use]
358    pub fn is_empty(&self) -> bool {
359        self.documents.is_empty()
360    }
361
362    /// Encode the CLIP index to bytes and produce an artifact for persistence
363    pub fn encode(&self) -> Result<ClipIndexArtifact> {
364        let bytes = bincode::serde::encode_to_vec(&self.documents, clip_config())?;
365
366        let checksum = *hash(&bytes).as_bytes();
367        let dimension = self
368            .documents
369            .first()
370            .map_or(0, |doc| u32::try_from(doc.embedding.len()).unwrap_or(0));
371
372        Ok(ClipIndexArtifact {
373            bytes,
374            vector_count: self.documents.len() as u64,
375            dimension,
376            checksum,
377        })
378    }
379}
380
381/// Search result from CLIP index
382#[derive(Debug, Clone, PartialEq)]
383pub struct ClipSearchHit {
384    /// Frame ID of the matched document
385    pub frame_id: FrameId,
386    /// Optional page number (for PDFs)
387    pub page: Option<u32>,
388    /// L2 distance to query (lower is more similar)
389    pub distance: f32,
390}
391
392/// L2 (Euclidean) distance between two vectors
393fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
394    a.iter()
395        .zip(b.iter())
396        .map(|(x, y)| (x - y).powi(2))
397        .sum::<f32>()
398        .sqrt()
399}
400
401// ============================================================================
402// Image Filtering (Junk Detection)
403// ============================================================================
404
405/// Metadata about an image for filtering
406#[derive(Debug, Clone)]
407pub struct ImageInfo {
408    pub width: u32,
409    pub height: u32,
410    pub color_variance: f32,
411}
412
413impl ImageInfo {
414    /// Check if this image should be processed for CLIP embedding
415    #[must_use]
416    pub fn should_embed(&self) -> bool {
417        // Skip tiny images (icons, bullets)
418        if self.width < MIN_IMAGE_DIM || self.height < MIN_IMAGE_DIM {
419            return false;
420        }
421
422        // Skip extreme aspect ratios (dividers, lines)
423        let aspect = self.width as f32 / self.height as f32;
424        if !((1.0 / MAX_ASPECT_RATIO)..=MAX_ASPECT_RATIO).contains(&aspect) {
425            return false;
426        }
427
428        // Skip near-solid colors (backgrounds, spacers)
429        if self.color_variance < MIN_COLOR_VARIANCE {
430            return false;
431        }
432
433        true
434    }
435}
436
437/// Filter a list of images, keeping only those worth embedding
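///
/// Illustrative usage with the feature-gated `get_image_info` helper:
///
/// ```ignore
/// // `images` is a Vec<DynamicImage>; icons, dividers, and solid fills are dropped.
/// let kept = filter_junk_images(images, get_image_info);
/// ```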
438pub fn filter_junk_images<T, F>(images: Vec<T>, get_info: F) -> Vec<T>
439where
440    F: Fn(&T) -> ImageInfo,
441{
442    images
443        .into_iter()
444        .filter(|img| get_info(img).should_embed())
445        .collect()
446}
447
448// ============================================================================
449// CLIP Model Configuration
450// ============================================================================
451
452/// Configuration for CLIP model initialization
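///
/// Defaults come from `Default` (overridable via the `MEMVID_MODELS_DIR`,
/// `MEMVID_CLIP_MODEL`, and `MEMVID_OFFLINE` environment variables), or the fields
/// can be set explicitly (paths below are placeholders):
///
/// ```ignore
/// let config = ClipConfig {
///     model_name: "siglip-base".to_string(),
///     models_dir: "/opt/models".into(),
///     offline: true,
/// };
/// ```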
453#[derive(Debug, Clone)]
454pub struct ClipConfig {
455    /// Model name (e.g., "mobileclip-s2", "siglip-base")
456    pub model_name: String,
457    /// Directory where models are cached
458    pub models_dir: PathBuf,
459    /// Whether to run in offline mode (no downloads)
460    pub offline: bool,
461}
462
463impl Default for ClipConfig {
464    fn default() -> Self {
465        // Use ~/.memvid/models as default, consistent with CLI's model installation
466        let models_dir = std::env::var("MEMVID_MODELS_DIR")
467            .ok()
468            .map(PathBuf::from)
469            .or_else(|| dirs_next::home_dir().map(|d| d.join(".memvid/models")))
470            .unwrap_or_else(|| PathBuf::from(".memvid/models"));
471
472        let model_name =
473            std::env::var("MEMVID_CLIP_MODEL").unwrap_or_else(|_| "mobileclip-s2".to_string());
474
475        let offline = std::env::var("MEMVID_OFFLINE").is_ok();
476
477        Self {
478            model_name,
479            models_dir,
480            offline,
481        }
482    }
483}
484
485// ============================================================================
486// CLIP Error Types
487// ============================================================================
488
489/// CLIP-specific errors
490#[derive(Debug, thiserror::Error)]
491pub enum ClipError {
492    /// Model not found and offline mode enabled
493    #[error("CLIP model '{model}' not found. {hint}")]
494    ModelNotFound { model: String, hint: String },
495
496    /// Image decode failed
497    #[error("Failed to decode image at {path:?}: {cause}")]
498    ImageDecodeError { path: PathBuf, cause: String },
499
500    /// Image bytes decode failed
501    #[error("Failed to decode image bytes: {cause}")]
502    ImageBytesDecodeError { cause: String },
503
504    /// ONNX runtime error
505    #[error("CLIP inference error: {cause}")]
506    InferenceError { cause: String },
507
508    /// Model download failed
509    #[error("Failed to download CLIP model: {cause}")]
510    DownloadError { cause: String },
511
512    /// Model file corrupted or invalid
513    #[error("CLIP model file is corrupted: {cause}")]
514    ModelCorrupted { cause: String },
515}
516
517impl From<ClipError> for MemvidError {
518    fn from(err: ClipError) -> Self {
519        MemvidError::EmbeddingFailed {
520            reason: err.to_string().into_boxed_str(),
521        }
522    }
523}
524
525// ============================================================================
526// CLIP Model (Feature-gated implementation)
527// ============================================================================
528
529#[cfg(feature = "clip")]
530mod model {
531    use super::*;
532    use image::{DynamicImage, GenericImageView, imageops::FilterType};
533    use ndarray::{Array, Array4};
534    use ort::session::{Session, builder::GraphOptimizationLevel};
535    use ort::value::Tensor;
536    use std::sync::Mutex;
537    use std::time::Instant;
538    use tokenizers::{
539        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
540        TruncationParams, TruncationStrategy,
541    };
542
543    /// CLIP model with lazy-loaded vision and text encoders
544    pub struct ClipModel {
545        config: ClipConfig,
546        model_info: &'static ClipModelInfo,
547        /// Lazy-loaded vision encoder session
548        vision_session: Mutex<Option<Session>>,
549        /// Lazy-loaded text encoder session
550        text_session: Mutex<Option<Session>>,
551        /// Lazy-loaded tokenizer matching the text encoder
552        tokenizer: Mutex<Option<Tokenizer>>,
553        /// Last time the model was used (for idle unloading)
554        last_used: Mutex<Instant>,
555    }
556
557    impl ClipModel {
558        /// Create a new CLIP model with the given configuration
559        pub fn new(config: ClipConfig) -> Result<Self> {
560            let model_info = get_model_info(&config.model_name);
561
562            Ok(Self {
563                config,
564                model_info,
565                vision_session: Mutex::new(None),
566                text_session: Mutex::new(None),
567                tokenizer: Mutex::new(None),
568                last_used: Mutex::new(Instant::now()),
569            })
570        }
571
572        /// Create with default configuration
573        pub fn default_model() -> Result<Self> {
574            Self::new(ClipConfig::default())
575        }
576
577        /// Get model info
578        pub fn model_info(&self) -> &'static ClipModelInfo {
579            self.model_info
580        }
581
582        /// Get embedding dimensions
583        pub fn dims(&self) -> u32 {
584            self.model_info.dims
585        }
586
587        /// Ensure model file exists, downloading if necessary
588        fn ensure_model_file(&self, kind: &str) -> Result<PathBuf> {
589            let filename = format!("{}_{}.onnx", self.model_info.name, kind);
590            let path = self.config.models_dir.join(&filename);
591
592            if path.exists() {
593                return Ok(path);
594            }
595
596            if self.config.offline {
597                return Err(ClipError::ModelNotFound {
598                    model: self.model_info.name.to_string(),
599                    hint: format!(
600                        "Run: memvid model download {} (or disable MEMVID_OFFLINE)",
601                        self.model_info.name
602                    ),
603                }
604                .into());
605            }
606
607            // Create models directory if needed
608            std::fs::create_dir_all(&self.config.models_dir).map_err(|e| {
609                ClipError::DownloadError {
610                    cause: format!("Failed to create models directory: {}", e),
611                }
612            })?;
613
614            // Provide manual download instructions
615            Err(ClipError::DownloadError {
616                cause: format!(
617                    "Automatic download not yet implemented. Please download manually:\n\
618                     curl -L '{}' -o '{}'",
619                    if kind == "vision" {
620                        self.model_info.vision_url
621                    } else {
622                        self.model_info.text_url
623                    },
624                    path.display()
625                ),
626            }
627            .into())
628        }
629
630        /// Ensure tokenizer file exists, downloading if necessary
631        fn ensure_tokenizer_file(&self) -> Result<PathBuf> {
632            let filename = format!("{}_tokenizer.json", self.model_info.name);
633            let path = self.config.models_dir.join(&filename);
634
635            if path.exists() {
636                return Ok(path);
637            }
638
639            if self.config.offline {
640                return Err(ClipError::ModelNotFound {
641                    model: self.model_info.name.to_string(),
642                    hint: format!(
643                        "Tokenizer missing at {}. Copy tokenizer.json from {}",
644                        path.display(),
645                        self.model_info.tokenizer_url
646                    ),
647                }
648                .into());
649            }
650
651            std::fs::create_dir_all(&self.config.models_dir).map_err(|e| {
652                ClipError::DownloadError {
653                    cause: format!("Failed to create models directory: {}", e),
654                }
655            })?;
656
657            Err(ClipError::DownloadError {
658                cause: format!(
659                    "Automatic download not yet implemented. Please download manually:\n\
660                     curl -L '{}' -o '{}'",
661                    self.model_info.tokenizer_url,
662                    path.display()
663                ),
664            }
665            .into())
666        }
667
668        /// Load vision session lazily
669        fn load_vision_session(&self) -> Result<()> {
670            let mut session_guard = self
671                .vision_session
672                .lock()
673                .map_err(|_| MemvidError::Lock("Failed to lock vision session".into()))?;
674
675            if session_guard.is_some() {
676                return Ok(());
677            }
678
679            let vision_path = self.ensure_model_file("vision")?;
680
681            tracing::debug!(path = %vision_path.display(), "Loading CLIP vision model");
682
683            let session = Session::builder()
684                .map_err(|e| ClipError::InferenceError {
685                    cause: e.to_string(),
686                })?
687                .with_optimization_level(GraphOptimizationLevel::Level3)
688                .map_err(|e| ClipError::InferenceError {
689                    cause: e.to_string(),
690                })?
691                .with_intra_threads(4)
692                .map_err(|e| ClipError::InferenceError {
693                    cause: e.to_string(),
694                })?
695                .commit_from_file(&vision_path)
696                .map_err(|e| ClipError::InferenceError {
697                    cause: format!("Failed to load vision model: {}", e),
698                })?;
699
700            *session_guard = Some(session);
701            tracing::info!(model = %self.model_info.name, "CLIP vision model loaded");
702
703            Ok(())
704        }
705
706        /// Load text session lazily
707        fn load_text_session(&self) -> Result<()> {
708            let mut session_guard = self
709                .text_session
710                .lock()
711                .map_err(|_| MemvidError::Lock("Failed to lock text session".into()))?;
712
713            if session_guard.is_some() {
714                return Ok(());
715            }
716
717            let text_path = self.ensure_model_file("text")?;
718
719            tracing::debug!(path = %text_path.display(), "Loading CLIP text model");
720
721            let session = Session::builder()
722                .map_err(|e| ClipError::InferenceError {
723                    cause: e.to_string(),
724                })?
725                .with_optimization_level(GraphOptimizationLevel::Level3)
726                .map_err(|e| ClipError::InferenceError {
727                    cause: e.to_string(),
728                })?
729                .with_intra_threads(4)
730                .map_err(|e| ClipError::InferenceError {
731                    cause: e.to_string(),
732                })?
733                .commit_from_file(&text_path)
734                .map_err(|e| ClipError::InferenceError {
735                    cause: format!("Failed to load text model: {}", e),
736                })?;
737
738            *session_guard = Some(session);
739            tracing::info!(model = %self.model_info.name, "CLIP text model loaded");
740
741            Ok(())
742        }
743
744        /// Load tokenizer lazily (matches the text model vocab/BPE)
745        fn load_tokenizer(&self) -> Result<()> {
746            let mut tokenizer_guard = self
747                .tokenizer
748                .lock()
749                .map_err(|_| MemvidError::Lock("Failed to lock CLIP tokenizer".into()))?;
750
751            if tokenizer_guard.is_some() {
752                return Ok(());
753            }
754
755            let tokenizer_path = self.ensure_tokenizer_file()?;
756
757            tracing::debug!(path = %tokenizer_path.display(), "Loading CLIP tokenizer");
758
759            let mut tokenizer =
760                Tokenizer::from_file(&tokenizer_path).map_err(|e| ClipError::InferenceError {
761                    cause: format!("Failed to load tokenizer: {}", e),
762                })?;
763
764            tokenizer.with_padding(Some(PaddingParams {
765                strategy: PaddingStrategy::Fixed(77),
766                direction: PaddingDirection::Right,
767                pad_to_multiple_of: None,
768                pad_id: 0,
769                pad_type_id: 0,
770                pad_token: "[PAD]".to_string(),
771            }));
772
773            tokenizer
774                .with_truncation(Some(TruncationParams {
775                    max_length: 77,
776                    strategy: TruncationStrategy::LongestFirst,
777                    stride: 0,
778                    direction: TruncationDirection::Right,
779                }))
780                .map_err(|e| ClipError::InferenceError {
781                    cause: format!("Failed to apply truncation config: {}", e),
782                })?;
783
784            *tokenizer_guard = Some(tokenizer);
785            tracing::info!(model = %self.model_info.name, "CLIP tokenizer loaded");
786
787            Ok(())
788        }
789
790        /// Preprocess image for CLIP inference
791        ///
792        /// MobileCLIP-S2 uses:
793        /// - Input size: 256x256
794        /// - Resize: shortest edge to 256, preserve aspect, center-crop
795        /// - Normalization: scale to [0, 1] (no mean/std shift per preprocessor_config)
796        /// - Format: NCHW (batch, channels, height, width)
797        fn preprocess_image(&self, image: &DynamicImage) -> Array4<f32> {
798            let size = self.model_info.input_resolution;
799            let rgb_input = image.to_rgb8();
800            let (w, h) = rgb_input.dimensions();
801
802            // Resize shortest edge to target while preserving aspect ratio
803            let scale = size as f32 / w.min(h) as f32;
804            let new_w = ((w as f32) * scale).round().max(1.0) as u32;
805            let new_h = ((h as f32) * scale).round().max(1.0) as u32;
806            let resized = image.resize_exact(new_w, new_h, FilterType::Triangle);
807
808            // Center crop to (size, size)
809            let start_x = (resized.width().saturating_sub(size)) / 2;
810            let start_y = (resized.height().saturating_sub(size)) / 2;
811
812            // Create array in NCHW format: [1, 3, H, W]
813            let mut array = Array4::<f32>::zeros((1, 3, size as usize, size as usize));
814
815            for y in 0..size as usize {
816                for x in 0..size as usize {
817                    let pixel = resized.get_pixel(start_x + x as u32, start_y + y as u32);
818                    array[[0, 0, y, x]] = pixel[0] as f32 / 255.0;
819                    array[[0, 1, y, x]] = pixel[1] as f32 / 255.0;
820                    array[[0, 2, y, x]] = pixel[2] as f32 / 255.0;
821                }
822            }
823
824            array
825        }
826
827        /// Encode an image to CLIP embedding
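        ///
        /// Illustrative usage (the vision session is loaded lazily on first call):
        ///
        /// ```ignore
        /// let image = image::open("photo.jpg")?;
        /// // Returns an L2-normalized vector with `model.dims()` elements.
        /// let embedding = model.encode_image(&image)?;
        /// ```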
828        pub fn encode_image(&self, image: &DynamicImage) -> Result<Vec<f32>> {
829            // Ensure vision session is loaded
830            self.load_vision_session()?;
831
832            // Preprocess the image
833            let pixel_values = self.preprocess_image(image);
834
835            // Update last used timestamp
836            if let Ok(mut last) = self.last_used.lock() {
837                *last = Instant::now();
838            }
839
840            // Run inference
841            let mut session_guard = self
842                .vision_session
843                .lock()
844                .map_err(|_| MemvidError::Lock("Failed to lock vision session".into()))?;
845
846            let session = session_guard
847                .as_mut()
848                .ok_or_else(|| ClipError::InferenceError {
849                    cause: "Vision session not loaded".to_string(),
850                })?;
851
852            // Get input and output names from session before running
853            let input_name = session
854                .inputs
855                .first()
856                .map(|i| i.name.clone())
857                .unwrap_or_else(|| "pixel_values".into());
858            let output_name = session
859                .outputs
860                .first()
861                .map(|o| o.name.clone())
862                .unwrap_or_else(|| "image_embeds".into());
863
864            // Create tensor from ndarray
865            let input_tensor =
866                Tensor::from_array(pixel_values).map_err(|e| ClipError::InferenceError {
867                    cause: format!("Failed to create input tensor: {}", e),
868                })?;
869
870            // Run the model
871            let outputs = session
872                .run(ort::inputs![input_name => input_tensor])
873                .map_err(|e| ClipError::InferenceError {
874                    cause: format!("Vision inference failed: {}", e),
875                })?;
876
877            // Extract embeddings from first output
878            let output = outputs
879                .get(&output_name)
880                .ok_or_else(|| ClipError::InferenceError {
881                    cause: format!("No output '{}' from vision model", output_name),
882                })?;
883
884            let (_shape, data) =
885                output
886                    .try_extract_tensor::<f32>()
887                    .map_err(|e| ClipError::InferenceError {
888                        cause: format!("Failed to extract embeddings: {}", e),
889                    })?;
890
891            // Get the embedding from the raw data
892            let embedding: Vec<f32> = data.to_vec();
893            if embedding.iter().any(|v| !v.is_finite()) {
894                return Err(ClipError::InferenceError {
895                    cause: "Vision embedding contains non-finite values".to_string(),
896                }
897                .into());
898            }
899            let normalized = l2_normalize(&embedding);
900
901            tracing::debug!(dims = normalized.len(), "Generated CLIP image embedding");
902
903            Ok(normalized)
904        }
905
906        /// Encode image bytes to CLIP embedding
907        pub fn encode_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>> {
908            let image =
909                image::load_from_memory(bytes).map_err(|e| ClipError::ImageBytesDecodeError {
910                    cause: e.to_string(),
911                })?;
912            self.encode_image(&image)
913        }
914
915        /// Encode an image file to CLIP embedding
916        pub fn encode_image_file(&self, path: &Path) -> Result<Vec<f32>> {
917            let image = image::open(path).map_err(|e| ClipError::ImageDecodeError {
918                path: path.to_path_buf(),
919                cause: e.to_string(),
920            })?;
921            self.encode_image(&image)
922        }
923
924        /// Encode text to CLIP embedding (for query)
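        ///
        /// Illustrative usage (`clip_index` is an existing [`ClipIndex`]):
        ///
        /// ```ignore
        /// let query = model.encode_text("a diagram of a neural network")?;
        /// let hits = clip_index.search(&query, 10);
        /// ```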
925        pub fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
926            // Ensure text session is loaded
927            self.load_text_session()?;
928            self.load_tokenizer()?;
929
930            // Tokenize the text using the model's tokenizer
931            let encoding = {
932                let tokenizer_guard = self
933                    .tokenizer
934                    .lock()
935                    .map_err(|_| MemvidError::Lock("Failed to lock CLIP tokenizer".into()))?;
936                let tokenizer =
937                    tokenizer_guard
938                        .as_ref()
939                        .ok_or_else(|| ClipError::InferenceError {
940                            cause: "Tokenizer not loaded".to_string(),
941                        })?;
942
943                tokenizer
944                    .encode(text, true)
945                    .map_err(|e| ClipError::InferenceError {
946                        cause: format!("Text tokenization failed: {}", e),
947                    })?
948            };
949
950            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|id| *id as i64).collect();
951            let attention_mask: Vec<i64> = encoding
952                .get_attention_mask()
953                .iter()
954                .map(|id| *id as i64)
955                .collect();
956            let max_length = input_ids.len();
957
958            // Create input arrays
959            let input_ids_array =
960                Array::from_shape_vec((1, max_length), input_ids).map_err(|e| {
961                    ClipError::InferenceError {
962                        cause: e.to_string(),
963                    }
964                })?;
965            let attention_mask_array = Array::from_shape_vec((1, max_length), attention_mask)
966                .map_err(|e| ClipError::InferenceError {
967                    cause: e.to_string(),
968                })?;
969
970            // Update last used timestamp
971            if let Ok(mut last) = self.last_used.lock() {
972                *last = Instant::now();
973            }
974
975            // Run inference
976            let mut session_guard = self
977                .text_session
978                .lock()
979                .map_err(|_| MemvidError::Lock("Failed to lock text session".into()))?;
980
981            let session = session_guard
982                .as_mut()
983                .ok_or_else(|| ClipError::InferenceError {
984                    cause: "Text session not loaded".to_string(),
985                })?;
986
987            // Get input and output names from session before running
988            let input_names: Vec<String> = session.inputs.iter().map(|i| i.name.clone()).collect();
989            let output_name = session
990                .outputs
991                .first()
992                .map(|o| o.name.clone())
993                .unwrap_or_else(|| "text_embeds".into());
994
995            // Create tensors from ndarray
996            let input_ids_tensor =
997                Tensor::from_array(input_ids_array).map_err(|e| ClipError::InferenceError {
998                    cause: format!("Failed to create input_ids tensor: {}", e),
999                })?;
1000            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
1001                ClipError::InferenceError {
1002                    cause: format!("Failed to create attention_mask tensor: {}", e),
1003                }
1004            })?;
1005
1006            // Build inputs based on what the model expects
1007            let outputs = if input_names.len() >= 2 {
1008                session
1009                    .run(ort::inputs![
1010                        input_names[0].clone() => input_ids_tensor,
1011                        input_names[1].clone() => attention_mask_tensor
1012                    ])
1013                    .map_err(|e| ClipError::InferenceError {
1014                        cause: format!("Text inference failed: {}", e),
1015                    })?
1016            } else {
1017                // Single input model
1018                let name = input_names
1019                    .first()
1020                    .cloned()
1021                    .unwrap_or_else(|| "input_ids".to_string());
1022                session
1023                    .run(ort::inputs![name => input_ids_tensor])
1024                    .map_err(|e| ClipError::InferenceError {
1025                        cause: format!("Text inference failed: {}", e),
1026                    })?
1027            };
1028
1029            // Extract embeddings from output
1030            let output = outputs
1031                .get(&output_name)
1032                .ok_or_else(|| ClipError::InferenceError {
1033                    cause: format!("No output '{}' from text model", output_name),
1034                })?;
1035
1036            let (_shape, data) =
1037                output
1038                    .try_extract_tensor::<f32>()
1039                    .map_err(|e| ClipError::InferenceError {
1040                        cause: format!("Failed to extract text embeddings: {}", e),
1041                    })?;
1042
1043            // Flatten and normalize the embedding
1044            let embedding: Vec<f32> = data.to_vec();
1045            if embedding.iter().any(|v| !v.is_finite()) {
1046                return Err(ClipError::InferenceError {
1047                    cause: "Text embedding contains non-finite values".to_string(),
1048                }
1049                .into());
1050            }
1051            let normalized = l2_normalize(&embedding);
1052
1053            tracing::debug!(
1054                text_len = text.len(),
1055                dims = normalized.len(),
1056                "Generated CLIP text embedding"
1057            );
1058
1059            Ok(normalized)
1060        }
1061
1062        /// Maybe unload model if unused for too long (memory management)
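        ///
        /// One way to use it is from periodic, caller-driven housekeeping, for example:
        ///
        /// ```ignore
        /// // Sessions are dropped after MODEL_UNLOAD_TIMEOUT of inactivity and
        /// // reload lazily on the next encode call.
        /// model.maybe_unload()?;
        /// ```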
1063        pub fn maybe_unload(&self) -> Result<()> {
1064            let last_used = self
1065                .last_used
1066                .lock()
1067                .map_err(|_| MemvidError::Lock("Failed to check last_used".into()))?;
1068
1069            if last_used.elapsed() > MODEL_UNLOAD_TIMEOUT {
1070                tracing::debug!(model = %self.model_info.name, "Model idle, unloading sessions");
1071
1072                // Unload vision session
1073                if let Ok(mut guard) = self.vision_session.lock() {
1074                    *guard = None;
1075                }
1076
1077                // Unload text session
1078                if let Ok(mut guard) = self.text_session.lock() {
1079                    *guard = None;
1080                }
1081
1082                // Unload tokenizer
1083                if let Ok(mut guard) = self.tokenizer.lock() {
1084                    *guard = None;
1085                }
1086            }
1087
1088            Ok(())
1089        }
1090
1091        /// Force unload all sessions
1092        pub fn unload(&self) -> Result<()> {
1093            if let Ok(mut guard) = self.vision_session.lock() {
1094                *guard = None;
1095            }
1096            if let Ok(mut guard) = self.text_session.lock() {
1097                *guard = None;
1098            }
1099            if let Ok(mut guard) = self.tokenizer.lock() {
1100                *guard = None;
1101            }
1102            tracing::debug!(model = %self.model_info.name, "CLIP sessions unloaded");
1103            Ok(())
1104        }
1105
1106        /// Check if vision model is loaded
1107        pub fn is_vision_loaded(&self) -> bool {
1108            self.vision_session
1109                .lock()
1110                .map(|g| g.is_some())
1111                .unwrap_or(false)
1112        }
1113
1114        /// Check if text model is loaded
1115        pub fn is_text_loaded(&self) -> bool {
1116            self.text_session
1117                .lock()
1118                .map(|g| g.is_some())
1119                .unwrap_or(false)
1120        }
1121    }
1122
1123    /// L2 normalize a vector (unit length)
1124    fn l2_normalize(v: &[f32]) -> Vec<f32> {
1125        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
1126        if norm.is_finite() && norm > 1e-10 {
1127            v.iter().map(|x| x / norm).collect()
1128        } else {
1129            // Fall back to zeros to avoid NaNs propagating through distances
1130            vec![0.0; v.len()]
1131        }
1132    }
1133
1134    /// Calculate color variance of an image
1135    pub fn calculate_color_variance(image: &DynamicImage) -> f32 {
1136        let rgb = image.to_rgb8();
1137        let (width, height) = rgb.dimensions();
1138        let total_pixels = (width * height) as f32;
1139
1140        if total_pixels == 0.0 {
1141            return 0.0;
1142        }
1143
1144        // Calculate mean
1145        let mut sum_r = 0.0f32;
1146        let mut sum_g = 0.0f32;
1147        let mut sum_b = 0.0f32;
1148
1149        for pixel in rgb.pixels() {
1150            sum_r += pixel[0] as f32;
1151            sum_g += pixel[1] as f32;
1152            sum_b += pixel[2] as f32;
1153        }
1154
1155        let mean_r = sum_r / total_pixels;
1156        let mean_g = sum_g / total_pixels;
1157        let mean_b = sum_b / total_pixels;
1158
1159        // Calculate variance
1160        let mut var_r = 0.0f32;
1161        let mut var_g = 0.0f32;
1162        let mut var_b = 0.0f32;
1163
1164        for pixel in rgb.pixels() {
1165            var_r += (pixel[0] as f32 - mean_r).powi(2);
1166            var_g += (pixel[1] as f32 - mean_g).powi(2);
1167            var_b += (pixel[2] as f32 - mean_b).powi(2);
1168        }
1169
1170        // Average variance across channels, normalized to 0-1
1171        ((var_r + var_g + var_b) / (3.0 * total_pixels)) / (255.0 * 255.0)
1172    }
1173
1174    /// Get ImageInfo from a DynamicImage
1175    pub fn get_image_info(image: &DynamicImage) -> ImageInfo {
1176        let (width, height) = image.dimensions();
1177        let color_variance = calculate_color_variance(image);
1178
1179        ImageInfo {
1180            width,
1181            height,
1182            color_variance,
1183        }
1184    }
1185}
1186
1187#[cfg(feature = "clip")]
1188pub use model::*;
1189
1190#[cfg(all(feature = "clip", feature = "pdfium"))]
1191use pdfium_render::prelude::{PdfPageRenderRotation, PdfRenderConfig, Pdfium};
1192
1193/// Render PDF pages to images suitable for CLIP embedding (feature-gated).
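///
/// Illustrative usage (`model`, `index`, and `frame_id` come from the caller):
///
/// ```ignore
/// let pages = render_pdf_pages_for_clip(Path::new("report.pdf"), 10, 512)?;
/// for (page_no, image) in &pages {
///     let embedding = model.encode_image(image)?;
///     index.add_document(frame_id, Some(*page_no), embedding);
/// }
/// ```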
1194#[cfg(all(feature = "clip", feature = "pdfium"))]
1195pub fn render_pdf_pages_for_clip(
1196    path: &Path,
1197    max_pages: usize,
1198    target_px: u32,
1199) -> Result<Vec<(u32, DynamicImage)>> {
1200    let bindings = Pdfium::bind_to_system_library().map_err(|e| ClipError::InferenceError {
1201        cause: format!("Failed to bind pdfium: {}", e),
1202    })?;
1203    let pdfium = Pdfium::new(bindings);
1204    let document =
1205        pdfium
1206            .load_pdf_from_file(path, None)
1207            .map_err(|e| ClipError::InferenceError {
1208                cause: format!("Failed to load PDF for CLIP rendering: {}", e),
1209            })?;
1210
1211    let render_config = PdfRenderConfig::new()
1212        .set_target_width(target_px as i32)
1213        .set_maximum_height(target_px as i32)
1214        .set_maximum_width(target_px as i32)
1215        .rotate_if_landscape(PdfPageRenderRotation::None, false);
1216
1217    let mut pages = Vec::new();
1218    for (index, page) in document.pages().iter().enumerate() {
1219        if index >= max_pages {
1220            break;
1221        }
1222        let rendered = page
1223            .render_with_config(&render_config)
1224            .map_err(|e| ClipError::InferenceError {
1225                cause: format!("Failed to render PDF page {}: {}", index + 1, e),
1226            })?
1227            .as_image();
1228        pages.push(((index + 1) as u32, rendered));
1229    }
1230
1231    Ok(pages)
1232}
1233
1234#[cfg(all(feature = "clip", not(feature = "pdfium")))]
1235pub fn render_pdf_pages_for_clip(
1236    path: &Path,
1237    max_pages: usize,
1238    _target_px: u32,
1239) -> Result<Vec<(u32, DynamicImage)>> {
1240    fn extract_images_from_page(
1241        doc: &Document,
1242        page_id: ObjectId,
1243        remaining: &mut usize,
1244        out: &mut Vec<(u32, DynamicImage)>,
1245    ) -> Result<()> {
1246        if *remaining == 0 {
1247            return Ok(());
1248        }
1249
1250        let (resources_opt, resource_ids) =
1251            doc.get_page_resources(page_id)
1252                .map_err(|e| ClipError::InferenceError {
1253                    cause: format!("Failed to read PDF resources: {}", e),
1254                })?;
1255
1256        let mut seen = HashSet::new();
1257        let mut resource_dicts: Vec<Dictionary> = Vec::new();
1258
1259        if let Some(dict) = resources_opt {
1260            resource_dicts.push(dict.clone());
1261        }
1262        for res_id in resource_ids {
1263            if seen.insert(res_id) {
1264                if let Ok(dict) = doc.get_dictionary(res_id) {
1265                    resource_dicts.push(dict.clone());
1266                }
1267            }
1268        }
1269
1270        for dict in resource_dicts {
1271            if let Ok(xobjects) = dict.get(b"XObject") {
1272                let xobj_dict = match xobjects {
1273                    Object::Dictionary(d) => Some(d),
1274                    Object::Reference(id) => doc.get_dictionary(*id).ok(),
1275                    _ => None,
1276                };
1277                if let Some(xobj_dict) = xobj_dict {
1278                    for (_, obj) in xobj_dict.iter() {
1279                        let id = match obj {
1280                            Object::Reference(id) => *id,
1281                            _ => continue,
1282                        };
1283                        let stream = match doc.get_object(id).and_then(Object::as_stream) {
1284                            Ok(s) => s,
1285                            Err(_) => continue,
1286                        };
1287                        let subtype = stream.dict.get(b"Subtype").ok();
1288                        let is_image = matches!(subtype, Some(Object::Name(n)) if n == b"Image");
1289                        if !is_image {
1290                            continue;
1291                        }
1292
1293                        let width = stream
1294                            .dict
1295                            .get(b"Width")
1296                            .ok()
1297                            .and_then(|o| o.as_i64().ok())
1298                            .unwrap_or(0);
1299                        let height = stream
1300                            .dict
1301                            .get(b"Height")
1302                            .ok()
1303                            .and_then(|o| o.as_i64().ok())
1304                            .unwrap_or(0);
1305                        if width <= 0 || height <= 0 {
1306                            continue;
1307                        }
1308
1309                        let filters = stream
1310                            .dict
1311                            .get(b"Filter")
1312                            .ok()
1313                            .and_then(|f| match f {
1314                                Object::Name(n) => Some(vec![n.clone()]),
1315                                Object::Array(arr) => Some(
1316                                    arr.iter()
1317                                        .filter_map(|o| o.as_name().ok().map(|n| n.to_vec()))
1318                                        .collect(),
1319                                ),
1320                                _ => None,
1321                            })
1322                            .unwrap_or_default();
1323
1324                        let data = stream
1325                            .decompressed_content()
1326                            .unwrap_or_else(|_| stream.content.clone());
1327
1328                        // If DCT/JPX, hand to image crate directly
1329                        if filters
1330                            .iter()
1331                            .any(|f| f == b"DCTDecode" || f == b"JPXDecode")
1332                        {
1333                            if let Ok(img) = image::load_from_memory(&data) {
1334                                out.push((1, img));
                                *remaining -= 1;
                                if *remaining == 0 {
                                    return Ok(());
                                }
1340                                continue;
1341                            }
1342                        }
1343
1344                        let color_space = stream
1345                            .dict
1346                            .get(b"ColorSpace")
1347                            .ok()
1348                            .and_then(|o| o.as_name().ok())
1349                            .unwrap_or(b"DeviceRGB");
1350                        let channels = if color_space == b"DeviceGray" { 1 } else { 3 };
1351
1352                        let expected = width as usize * height as usize * channels;
1353                        if data.len() >= expected && channels == 3 {
1354                            if let Some(buf) = ImageBuffer::<Rgb<u8>, _>::from_raw(
1355                                width as u32,
1356                                height as u32,
1357                                data.clone(),
1358                            ) {
1359                                out.push((1, DynamicImage::ImageRgb8(buf)));
                                *remaining -= 1;
                                if *remaining == 0 {
                                    return Ok(());
                                }
1365                                continue;
1366                            }
1367                        } else if data.len() >= expected && channels == 1 {
1368                            if let Some(buf) = ImageBuffer::<Luma<u8>, _>::from_raw(
1369                                width as u32,
1370                                height as u32,
1371                                data.clone(),
1372                            ) {
1373                                out.push((1, DynamicImage::ImageLuma8(buf)));
                                *remaining -= 1;
                                if *remaining == 0 {
                                    return Ok(());
                                }
1379                                continue;
1380                            }
1381                        }
1382                    }
1383                }
1384            }
1385        }
1386
1387        Ok(())
1388    }
1389
1390    let doc = Document::load(path).map_err(|e| ClipError::InferenceError {
1391        cause: format!("Failed to load PDF for image extraction: {}", e),
1392    })?;
1393
1394    let mut remaining = max_pages;
1395    let mut pages: Vec<(u32, DynamicImage)> = Vec::new();
1396
1397    for (page_num, page_id) in doc.get_pages() {
1398        if remaining == 0 {
1399            break;
1400        }
1401        let start_len = pages.len();
1402        extract_images_from_page(&doc, page_id, &mut remaining, &mut pages)?;
1403        if pages.len() > start_len {
1404            for entry in pages.iter_mut().skip(start_len) {
1405                entry.0 = page_num as u32;
1406            }
1407        }
1408    }
1409
1410    Ok(pages)
1411}
1412
1413// ============================================================================
1414// CLIP Embedding Provider Trait
1415// ============================================================================
1416
1417/// Trait for CLIP visual embedding providers.
1418///
1419/// Unlike text `EmbeddingProvider`, CLIP providers handle both:
1420/// - **Image encoding**: Generate embeddings from images (for indexing)
1421/// - **Text encoding**: Generate embeddings from text (for queries)
1422///
1423/// This allows natural language queries against visual content.
1424///
1425/// # Example
1426///
1427/// ```ignore
/// use memvid_core::clip::{ClipConfig, ClipEmbeddingProvider, ClipModel};
1429///
1430/// // Create provider
1431/// let provider = ClipModel::new(ClipConfig::default())?;
1432///
1433/// // Encode image for indexing
1434/// let image_embedding = provider.embed_image_file(&path)?;
1435///
1436/// // Encode query text for search
1437/// let query_embedding = provider.embed_query("a photo of a cat")?;
1438///
/// // Embeddings are L2-normalized, so ranking by L2 distance matches cosine similarity
1440/// ```
1441pub trait ClipEmbeddingProvider: Send + Sync {
    /// Return the provider kind identifier (the built-in ONNX provider returns "clip").
1443    fn kind(&self) -> &str;
1444
1445    /// Return the model identifier.
1446    fn model(&self) -> &str;
1447
1448    /// Return the embedding dimension.
1449    fn dimension(&self) -> usize;
1450
1451    /// Generate an embedding for an image file.
1452    fn embed_image_file(&self, path: &Path) -> Result<Vec<f32>>;
1453
1454    /// Generate an embedding for image bytes.
1455    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>>;
1456
1457    /// Generate an embedding for a text query (for searching).
1458    fn embed_query(&self, text: &str) -> Result<Vec<f32>>;
1459
1460    /// Generate embeddings for multiple image files.
1461    ///
1462    /// Default implementation calls `embed_image_file` in a loop.
1463    /// Providers should override this for efficient batch processing.
1464    fn embed_image_batch(&self, paths: &[&Path]) -> Result<Vec<Vec<f32>>> {
1465        let mut embeddings = Vec::with_capacity(paths.len());
1466        for path in paths {
1467            embeddings.push(self.embed_image_file(path)?);
1468        }
1469        Ok(embeddings)
1470    }
1471
1472    /// Check if the provider is ready to generate embeddings.
1473    fn is_ready(&self) -> bool {
1474        true
1475    }
1476
1477    /// Initialize the provider (e.g., load models).
1478    fn init(&mut self) -> Result<()> {
1479        Ok(())
1480    }
1481
1482    /// Unload models to free memory.
1483    fn unload(&self) -> Result<()> {
1484        Ok(())
1485    }
1486}
1487
1488/// Result type for CLIP embedding operations
1489pub type ClipEmbeddingResult = Result<Vec<f32>>;
1490pub type ClipBatchEmbeddingResult = Result<Vec<Vec<f32>>>;
1491
1492// ============================================================================
1493// ClipEmbeddingProvider Implementation (Feature-gated)
1494// ============================================================================
1495
1496#[cfg(feature = "clip")]
1497impl ClipEmbeddingProvider for ClipModel {
1498    fn kind(&self) -> &str {
1499        "clip"
1500    }
1501
1502    fn model(&self) -> &str {
1503        self.model_info().name
1504    }
1505
1506    fn dimension(&self) -> usize {
1507        self.model_info().dims as usize
1508    }
1509
1510    fn embed_image_file(&self, path: &Path) -> Result<Vec<f32>> {
1511        self.encode_image_file(path)
1512    }
1513
1514    fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vec<f32>> {
1515        self.encode_image_bytes(bytes)
1516    }
1517
1518    fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
1519        self.encode_text(text)
1520    }
1521
1522    fn embed_image_batch(&self, paths: &[&Path]) -> Result<Vec<Vec<f32>>> {
1523        let mut embeddings = Vec::with_capacity(paths.len());
1524        for path in paths {
1525            embeddings.push(self.encode_image_file(path)?);
1526        }
1527        Ok(embeddings)
1528    }
1529
1530    fn is_ready(&self) -> bool {
1531        // CLIP models are lazy-loaded, so always "ready"
1532        true
1533    }
1534
1535    fn unload(&self) -> Result<()> {
1536        ClipModel::unload(self)
1537    }
1538}
1539
1540// ============================================================================
1541// CLIP Index Manifest (for TOC)
1542// ============================================================================
1543
1544/// Manifest for CLIP index stored in TOC
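///
/// A sketch of how a manifest could be filled in from a [`ClipIndexArtifact`] once the
/// byte range inside the container file is known (the offset here is caller-supplied):
///
/// ```ignore
/// let manifest = ClipIndexManifest {
///     bytes_offset: offset,
///     bytes_length: artifact.bytes.len() as u64,
///     vector_count: artifact.vector_count,
///     dimension: artifact.dimension,
///     checksum: artifact.checksum,
///     model_name: "mobileclip-s2".to_string(),
/// };
/// ```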
1545#[derive(Debug, Clone, Serialize, Deserialize)]
1546pub struct ClipIndexManifest {
1547    /// Byte offset in file
1548    pub bytes_offset: u64,
1549    /// Length in bytes
1550    pub bytes_length: u64,
1551    /// Number of vectors
1552    pub vector_count: u64,
1553    /// Embedding dimensions
1554    pub dimension: u32,
1555    /// Blake3 checksum
1556    pub checksum: [u8; 32],
1557    /// Model name used to generate embeddings
1558    pub model_name: String,
1559}
1560
1561// ============================================================================
1562// Tests
1563// ============================================================================
1564
1565#[cfg(test)]
1566mod tests {
1567    use super::*;
1568
1569    #[test]
1570    fn clip_index_builder_roundtrip() {
1571        let mut builder = ClipIndexBuilder::new();
1572        builder.add_document(1, None, vec![0.1, 0.2, 0.3, 0.4]);
1573        builder.add_document(2, None, vec![0.5, 0.6, 0.7, 0.8]);
1574
1575        let artifact = builder.finish().expect("finish");
1576        assert_eq!(artifact.vector_count, 2);
1577        assert_eq!(artifact.dimension, 4);
1578
1579        let index = ClipIndex::decode(&artifact.bytes).expect("decode");
1580        assert_eq!(index.len(), 2);
1581
1582        let hits = index.search(&[0.1, 0.2, 0.3, 0.4], 10);
1583        assert_eq!(hits[0].frame_id, 1);
1584        assert!(hits[0].distance < 0.001); // Should be very close
1585    }
1586
1587    #[test]
1588    fn clip_index_search() {
1589        let mut builder = ClipIndexBuilder::new();
1590        builder.add_document(1, None, vec![1.0, 0.0, 0.0]);
1591        builder.add_document(2, None, vec![0.0, 1.0, 0.0]);
1592        builder.add_document(3, None, vec![0.0, 0.0, 1.0]);
1593
1594        let artifact = builder.finish().expect("finish");
1595        let index = ClipIndex::decode(&artifact.bytes).expect("decode");
1596
1597        // Search for [1, 0, 0] - should find frame 1 first
1598        let hits = index.search(&[1.0, 0.0, 0.0], 3);
1599        assert_eq!(hits[0].frame_id, 1);
1600
1601        // Search for [0, 1, 0] - should find frame 2 first
1602        let hits = index.search(&[0.0, 1.0, 0.0], 3);
1603        assert_eq!(hits[0].frame_id, 2);
1604    }
1605
1606    #[test]
1607    fn l2_distance_calculation() {
1608        let d = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
1609        assert!((d - 5.0).abs() < 1e-6);
1610
1611        let d = l2_distance(&[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
1612        assert!(d.abs() < 1e-6);
1613    }
1614
1615    #[test]
1616    fn image_info_filtering() {
1617        // Tiny image - should skip
1618        let tiny = ImageInfo {
1619            width: 32,
1620            height: 32,
1621            color_variance: 0.5,
1622        };
1623        assert!(!tiny.should_embed());
1624
1625        // Good image
1626        let good = ImageInfo {
1627            width: 256,
1628            height: 256,
1629            color_variance: 0.5,
1630        };
1631        assert!(good.should_embed());
1632
1633        // Extreme aspect ratio
1634        let wide = ImageInfo {
1635            width: 1000,
1636            height: 10,
1637            color_variance: 0.5,
1638        };
1639        assert!(!wide.should_embed());
1640
1641        // Solid color
1642        let solid = ImageInfo {
1643            width: 256,
1644            height: 256,
1645            color_variance: 0.001,
1646        };
1647        assert!(!solid.should_embed());
1648    }
1649
1650    #[test]
1651    fn model_registry() {
1652        let default = default_model_info();
1653        assert_eq!(default.name, "mobileclip-s2");
1654        assert_eq!(default.dims, 512);
1655        assert!(default.is_default);
1656
1657        let siglip = get_model_info("siglip-base");
1658        assert_eq!(siglip.dims, 768);
1659
1660        // Unknown model returns default
1661        let unknown = get_model_info("nonexistent");
1662        assert_eq!(unknown.name, "mobileclip-s2");
1663    }
1664
1665    #[test]
1666    fn clip_config_defaults() {
1667        // Clear the env vars to test true defaults
1668        // SAFETY: No other threads are modifying these env vars in this test
1669        unsafe {
1670            std::env::remove_var("MEMVID_CLIP_MODEL");
1671            std::env::remove_var("MEMVID_OFFLINE");
1672        }
1673
1674        let config = ClipConfig::default();
1675        assert_eq!(config.model_name, "mobileclip-s2");
1676        assert!(!config.offline);
1677    }
1678
1679    #[test]
1680    fn clip_embedding_provider_trait() {
1681        // Test that the trait is properly defined
1682        fn assert_send_sync<T: Send + Sync>() {}
1683
1684        // The trait should require Send + Sync
1685        assert_send_sync::<Box<dyn super::ClipEmbeddingProvider>>();
1686    }
1687}