Skip to main content

imgfprint/embed/
mod.rs

//! Semantic embedding support for CLIP-style image representations.
//!
//! Provides types and traits for working with high-dimensional semantic embeddings
//! obtained from external providers (OpenAI CLIP, HuggingFace, local models).
//!
//! ## Example
//!
//! ```rust,ignore
//! use imgfprint::{Embedding, EmbeddingProvider, ImageFingerprinter, ImgFprintError};
//!
//! // Define your provider implementation
//! struct MyClipProvider;
//!
//! impl EmbeddingProvider for MyClipProvider {
//!     fn embed(&self, image: &[u8]) -> Result<Embedding, ImgFprintError> {
//!         // Call an external API or a local model here, then wrap the
//!         // returned vector with `Embedding::new` or `Embedding::new_with_model`.
//!         todo!()
//!     }
//! }
//!
//! let provider = MyClipProvider;
//! let embedding = ImageFingerprinter::semantic_embedding(&provider, &image_bytes)?;
//! ```
24
25use crate::error::ImgFprintError;
26
27#[cfg(feature = "local-embedding")]
28pub mod local;
29
30#[cfg(feature = "local-embedding")]
31pub use local::{LocalProvider, LocalProviderConfig};
32
/// A semantic embedding: a dense `f32` vector describing the content of an image.
///
/// Embeddings are high-dimensional vectors (typically 512-1024 dimensions) in
/// which semantically similar images are mapped to nearby vectors.
///
/// # Immutability
///
/// Both fields are private and this type exposes no mutating methods, so a value
/// cannot be changed after construction. Read the data through
/// [`as_slice()`](Embedding::as_slice) (borrowed) or
/// [`vector()`](Embedding::vector) (cloned).
///
/// # Validation
///
/// The constructors ([`Embedding::new`], [`Embedding::new_with_model`]) reject:
/// - an empty vector
/// - any element that is NaN or infinite
///
/// Note: with the `serde` feature enabled, the derived `Deserialize`
/// implementation fills the fields directly and does not go through the
/// constructors, so the checks above are not applied to deserialized data.
///
/// # Examples
///
/// ```rust
/// use imgfprint::Embedding;
///
/// # fn example() -> Result<(), imgfprint::ImgFprintError> {
/// // Create from a Vec<f32>
/// let embedding = Embedding::new(vec![0.1, 0.2, 0.3, 0.4])?;
///
/// // Access the vector
/// assert_eq!(embedding.len(), 4);
/// assert_eq!(embedding.as_slice(), &[0.1, 0.2, 0.3, 0.4]);
/// # Ok(())
/// # }
/// ```
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(deny_unknown_fields)
)]
#[derive(Clone, Debug, PartialEq)]
#[allow(clippy::derive_partial_eq_without_eq)] // f32 is not `Eq` (NaN != NaN), so `Eq` cannot be derived
pub struct Embedding {
    /// The raw embedding components, in model output order.
    vector: Vec<f32>,
    /// Optional identifier of the model that produced this embedding.
    ///
    /// Vectors produced by different models live in unrelated spaces, so
    /// comparing them is meaningless even when their dimensions happen to match
    /// (e.g., two different encoders that both emit 512-dimensional vectors).
    /// Recording the model lets callers detect such mismatches.
    ///
    /// Omitted from serialized output when `None`.
    #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
    model_id: Option<String>,
}
79
80impl Embedding {
81    /// Creates a new embedding from a vector of f32 values.
82    ///
83    /// # Arguments
84    ///
85    /// * `vector` - The embedding vector. Must be non-empty and contain only finite values.
86    ///
87    /// # Errors
88    ///
89    /// Returns [`ImgFprintError::InvalidEmbedding`] if:
90    /// - The vector is empty
91    /// - Any element is NaN or infinity
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// use imgfprint::Embedding;
97    ///
98    /// # fn example() -> Result<(), imgfprint::ImgFprintError> {
99    /// let embedding = Embedding::new(vec![0.1, 0.2, 0.3])?;
100    /// assert_eq!(embedding.len(), 3);
101    /// # Ok(())
102    /// # }
103    /// ```
104    pub fn new(vector: Vec<f32>) -> Result<Self, ImgFprintError> {
105        Self::new_with_model(vector, None)
106    }
107
108    /// Creates a new embedding from a vector with an optional model identifier.
109    ///
110    /// # Arguments
111    ///
112    /// * `vector` - The embedding vector. Must be non-empty and contain only finite values.
113    /// * `model_id` - Optional model identifier (e.g., "clip-vit-base-patch32")
114    ///
115    /// # Errors
116    ///
117    /// Returns [`ImgFprintError::InvalidEmbedding`] if:
118    /// - The vector is empty
119    /// - Any element is NaN or infinity
120    pub fn new_with_model(
121        vector: Vec<f32>,
122        model_id: Option<String>,
123    ) -> Result<Self, ImgFprintError> {
124        if vector.is_empty() {
125            return Err(ImgFprintError::InvalidEmbedding(
126                "embedding vector cannot be empty".to_string(),
127            ));
128        }
129
130        if vector.iter().any(|&v| !v.is_finite()) {
131            return Err(ImgFprintError::InvalidEmbedding(
132                "embedding contains non-finite values (NaN or infinity)".to_string(),
133            ));
134        }
135
136        Ok(Self { vector, model_id })
137    }
138
139    /// Returns the model identifier if set.
140    #[inline]
141    pub fn model_id(&self) -> Option<&str> {
142        self.model_id.as_deref()
143    }
144
145    /// Returns the embedding vector as a slice.
146    ///
147    /// # Examples
148    ///
149    /// ```rust
150    /// use imgfprint::Embedding;
151    ///
152    /// # fn example() -> Result<(), imgfprint::ImgFprintError> {
153    /// let embedding = Embedding::new(vec![0.1, 0.2, 0.3])?;
154    /// let slice = embedding.as_slice();
155    /// assert_eq!(slice, &[0.1, 0.2, 0.3]);
156    /// # Ok(())
157    /// # }
158    /// ```
159    #[inline]
160    pub fn as_slice(&self) -> &[f32] {
161        &self.vector
162    }
163
164    /// Returns a clone of the embedding vector.
165    ///
166    /// # Performance
167    /// This method allocates and copies the entire vector (O(n)).
168    /// For zero-copy access, prefer [`as_slice()`](Embedding::as_slice).
169    ///
170    /// # Examples
171    ///
172    /// ```rust
173    /// use imgfprint::Embedding;
174    ///
175    /// # fn example() -> Result<(), imgfprint::ImgFprintError> {
176    /// let embedding = Embedding::new(vec![0.1, 0.2, 0.3])?;
177    /// let vec = embedding.vector();
178    /// assert_eq!(vec, vec![0.1, 0.2, 0.3]);
179    /// # Ok(())
180    /// # }
181    /// ```
182    pub fn vector(&self) -> Vec<f32> {
183        self.vector.clone()
184    }
185
186    /// Returns the dimensionality (length) of the embedding.
187    ///
188    /// # Examples
189    ///
190    /// ```rust
191    /// use imgfprint::Embedding;
192    ///
193    /// # fn example() -> Result<(), imgfprint::ImgFprintError> {
194    /// let embedding = Embedding::new(vec![0.1, 0.2, 0.3, 0.4])?;
195    /// assert_eq!(embedding.len(), 4);
196    /// # Ok(())
197    /// # }
198    /// ```
199    #[inline]
200    #[allow(clippy::len_without_is_empty)] // Valid embeddings are never empty
201    pub fn len(&self) -> usize {
202        self.vector.len()
203    }
204
205    /// Returns the dimensionality of the embedding.
206    ///
207    /// Alias for [`len()`](Embedding::len) for semantic clarity.
208    ///
209    /// # Examples
210    ///
211    /// ```rust
212    /// use imgfprint::Embedding;
213    ///
214    /// # fn example() -> Result<(), imgfprint::ImgFprintError> {
215    /// let embedding = Embedding::new(vec![0.1; 512])?;
216    /// assert_eq!(embedding.dimension(), 512);
217    /// # Ok(())
218    /// # }
219    /// ```
220    #[inline]
221    pub fn dimension(&self) -> usize {
222        self.len()
223    }
224}
225
/// Trait for embedding providers.
///
/// A provider turns raw image bytes into an [`Embedding`]. Implement this trait
/// to integrate external embedding services such as:
/// - OpenAI CLIP API
/// - HuggingFace Inference API
/// - Local CLIP models (via ONNX, PyTorch, etc.)
///
/// The SDK itself does not perform HTTP requests or model inference.
/// It only defines this abstraction, allowing users to bring their own
/// provider implementation.
///
/// NOTE(review): this module also re-exports `local::LocalProvider` behind the
/// `local-embedding` feature, which presumably is a bundled implementation of
/// this trait — confirm whether the statement above still holds with that
/// feature enabled.
///
/// # Example Implementation
///
/// ```rust
/// use imgfprint::{EmbeddingProvider, Embedding, ImgFprintError};
///
/// struct DummyProvider;
///
/// impl EmbeddingProvider for DummyProvider {
///     fn embed(&self, _image: &[u8]) -> Result<Embedding, ImgFprintError> {
///         // In a real implementation, this would call an external API
///         // or run a local model to generate embeddings
///         Embedding::new(vec![0.1, 0.2, 0.3, 0.4])
///     }
/// }
/// ```
pub trait EmbeddingProvider {
    /// Generates a semantic embedding for the given image bytes.
    ///
    /// This is the only required method. It takes `&self`, so implementations
    /// that need mutable state must manage it through interior mutability.
    ///
    /// # Arguments
    ///
    /// * `image` - Raw image bytes in any supported format (PNG, JPEG, etc.)
    ///
    /// # Errors
    ///
    /// Implementations should return:
    /// - [`ImgFprintError::ProviderError`] for provider-specific failures (network, auth, etc.)
    /// - [`ImgFprintError::InvalidImage`] if the image format is not supported by the provider
    /// - [`ImgFprintError::InvalidEmbedding`] if the returned embedding is invalid
    ///   (for example, when [`Embedding::new`] rejects an empty or non-finite vector)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use imgfprint::{EmbeddingProvider, ImgFprintError};
    ///
    /// # fn example<P: EmbeddingProvider>(provider: &P) -> Result<(), ImgFprintError> {
    /// let image_bytes = vec![0u8; 1000]; // Your image data
    /// let embedding = provider.embed(&image_bytes)?;
    /// # Ok(())
    /// # }
    /// ```
    fn embed(&self, image: &[u8]) -> Result<Embedding, ImgFprintError>;
}
279
280/// Computes cosine similarity between two embeddings.
281///
282/// Cosine similarity measures the cosine of the angle between two vectors,
283/// providing a value in the range [-1.0, 1.0]:
284/// - 1.0: vectors point in the same direction (identical orientation)
285/// - 0.0: vectors are orthogonal (unrelated)
286/// - -1.0: vectors point in opposite directions
287///
288/// For normalized embeddings (common in CLIP models), the range is [0.0, 1.0].
289///
290/// # Arguments
291///
292/// * `a` - First embedding
293/// * `b` - Second embedding
294///
295/// # Returns
296///
297/// Cosine similarity as an `f32` in the range [-1.0, 1.0].
298///
299/// # Errors
300///
301/// Returns [`ImgFprintError::EmbeddingDimensionMismatch`] if the embeddings
302/// have different dimensions.
303///
304/// # Performance
305///
306/// This function:
307/// - Runs in O(n) time where n is the embedding dimension
308/// - Performs no heap allocations
309/// - Operates directly on slices for cache efficiency
310/// - Is SIMD-friendly (the compiler can auto-vectorize the loop)
311///
312/// # Examples
313///
314/// ```rust
315/// use imgfprint::{Embedding, semantic_similarity};
316///
317/// # fn example() -> Result<(), imgfprint::ImgFprintError> {
318/// let a = Embedding::new(vec![1.0, 0.0, 0.0])?;
319/// let b = Embedding::new(vec![1.0, 0.0, 0.0])?;
320///
321/// let sim = semantic_similarity(&a, &b)?;
322/// assert!((sim - 1.0).abs() < 1e-6);
323/// # Ok(())
324/// # }
325/// ```
326pub fn semantic_similarity(a: &Embedding, b: &Embedding) -> Result<f32, ImgFprintError> {
327    // Check for model ID mismatch
328    if let (Some(a_model), Some(b_model)) = (a.model_id(), b.model_id()) {
329        if a_model != b_model {
330            return Err(ImgFprintError::InvalidEmbedding(format!(
331                "model ID mismatch: '{}' vs '{}'",
332                a_model, b_model
333            )));
334        }
335    }
336
337    let a_vec = a.as_slice();
338    let b_vec = b.as_slice();
339
340    if a_vec.len() != b_vec.len() {
341        return Err(ImgFprintError::EmbeddingDimensionMismatch {
342            expected: a_vec.len(),
343            actual: b_vec.len(),
344        });
345    }
346
347    // Note: Embedding::new() already validates finiteness, so no need to re-check here.
348    // Compute dot product and norms in a single pass for better cache locality
349    let mut dot_product: f32 = 0.0;
350    let mut norm_first_sq: f32 = 0.0;
351    let mut norm_second_sq: f32 = 0.0;
352
353    for i in 0..a_vec.len() {
354        let a_i = a_vec[i];
355        let b_i = b_vec[i];
356
357        dot_product += a_i * b_i;
358        norm_first_sq += a_i * a_i;
359        norm_second_sq += b_i * b_i;
360    }
361
362    let norm_a = norm_first_sq.sqrt();
363    let norm_b = norm_second_sq.sqrt();
364
365    // Handle zero vectors (shouldn't happen with valid embeddings, but be safe)
366    if norm_a == 0.0 || norm_b == 0.0 {
367        return Err(ImgFprintError::InvalidEmbedding(
368            "cannot compute similarity for zero vector".to_string(),
369        ));
370    }
371
372    Ok(dot_product / (norm_a * norm_b))
373}
374
/// Unit tests for `Embedding` construction, the `EmbeddingProvider` trait and
/// `semantic_similarity`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_new_valid() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let embedding = Embedding::new(vector.clone()).unwrap();

        assert_eq!(embedding.len(), 4);
        assert_eq!(embedding.dimension(), 4);
        assert_eq!(embedding.as_slice(), &vector);
        assert_eq!(embedding.vector(), vector);
    }

    #[test]
    fn test_embedding_empty_vector() {
        let result = Embedding::new(vec![]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("empty")
        ));
    }

    #[test]
    fn test_embedding_nan_values() {
        let result = Embedding::new(vec![0.1, f32::NAN, 0.3]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("non-finite")
        ));
    }

    #[test]
    fn test_embedding_infinity_values() {
        let result = Embedding::new(vec![0.1, f32::INFINITY, 0.3]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("non-finite")
        ));
    }

    #[test]
    fn test_embedding_negative_infinity() {
        let result = Embedding::new(vec![0.1, f32::NEG_INFINITY, 0.3]);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("non-finite")
        ));
    }

    // Helper to create embeddings for tests
    fn emb(vector: Vec<f32>) -> Embedding {
        Embedding::new(vector).unwrap()
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = emb(vec![1.0, 0.0, 0.0, 0.0]);
        let b = emb(vec![1.0, 0.0, 0.0, 0.0]);

        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "Identical vectors should have similarity 1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = emb(vec![1.0, 0.0, 0.0]);
        let b = emb(vec![0.0, 1.0, 0.0]);

        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            sim.abs() < 1e-6,
            "Orthogonal vectors should have similarity ~0.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = emb(vec![1.0, 0.0, 0.0]);
        let b = emb(vec![-1.0, 0.0, 0.0]);

        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            (sim - (-1.0)).abs() < 1e-6,
            "Opposite vectors should have similarity -1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_45_degrees() {
        let a = emb(vec![1.0, 0.0]);
        let b = emb(vec![1.0, 1.0]);

        let sim = semantic_similarity(&a, &b).unwrap();
        let expected = 1.0 / f32::sqrt(2.0);
        assert!(
            (sim - expected).abs() < 1e-5,
            "45-degree angle similarity should be ~0.707, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = emb(vec![1.0, 0.0, 0.0]);
        let b = emb(vec![1.0, 0.0]);

        let result = semantic_similarity(&a, &b);
        assert!(matches!(
            result,
            Err(ImgFprintError::EmbeddingDimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_cosine_similarity_with_normalization() {
        // [0.5; 4] has unit L2 norm, mimicking an already-normalized embedding;
        // two identical unit vectors must score 1.0.
        let a = emb(vec![0.5, 0.5, 0.5, 0.5]);
        let b = emb(vec![0.5, 0.5, 0.5, 0.5]);

        let sim = semantic_similarity(&a, &b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_various_dimensions() {
        // Test with various typical CLIP embedding dimensions
        for dim in [512, 768, 1024] {
            let a = emb(vec![1.0; dim]);
            let b = emb(vec![1.0; dim]);

            let sim = semantic_similarity(&a, &b).unwrap();
            assert!((sim - 1.0).abs() < 1e-6, "Failed for dimension {}", dim);
        }
    }

    #[test]
    fn test_cosine_similarity_negative_values() {
        let a = emb(vec![1.0, -1.0, 1.0, -1.0]);
        let b = emb(vec![-1.0, 1.0, -1.0, 1.0]);

        let sim = semantic_similarity(&a, &b).unwrap();
        assert!(
            (sim - (-1.0)).abs() < 1e-6,
            "Opposite signs should give -1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = emb(vec![1.0, 0.0, 0.0]);

        // Built with a struct literal so this test exercises the zero-norm
        // guard in `semantic_similarity` regardless of what the constructors
        // choose to validate (they currently accept all-zero vectors).
        let zero_embedding = Embedding {
            vector: vec![0.0; 3],
            model_id: None,
        };

        let result = semantic_similarity(&a, &zero_embedding);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("zero vector")
        ));
    }

    #[test]
    fn test_embedding_clone() {
        let a = emb(vec![0.1, 0.2, 0.3]);
        let b = a.clone();

        assert_eq!(a.as_slice(), b.as_slice());
        assert_eq!(a.len(), b.len());
    }

    #[test]
    fn test_embedding_partial_eq() {
        let a = emb(vec![0.1, 0.2, 0.3]);
        let b = emb(vec![0.1, 0.2, 0.3]);
        let c = emb(vec![0.3, 0.2, 0.1]);

        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    // Test provider trait with a mock implementation
    struct MockProvider {
        return_value: Vec<f32>,
    }

    impl EmbeddingProvider for MockProvider {
        fn embed(&self, _image: &[u8]) -> Result<Embedding, ImgFprintError> {
            Embedding::new(self.return_value.clone())
        }
    }

    #[test]
    fn test_embedding_provider_mock() {
        let provider = MockProvider {
            return_value: vec![0.1, 0.2, 0.3],
        };

        let image_bytes = vec![0u8; 100];
        let embedding = provider.embed(&image_bytes).unwrap();

        assert_eq!(embedding.as_slice(), &[0.1, 0.2, 0.3]);
    }

    #[test]
    fn test_embedding_provider_error_propagation() {
        struct FailingProvider;

        impl EmbeddingProvider for FailingProvider {
            fn embed(&self, _image: &[u8]) -> Result<Embedding, ImgFprintError> {
                Err(ImgFprintError::ProviderError("network timeout".to_string()))
            }
        }

        let provider = FailingProvider;
        let image_bytes = vec![0u8; 100];
        let result = provider.embed(&image_bytes);

        assert!(matches!(
            result,
            Err(ImgFprintError::ProviderError(msg)) if msg == "network timeout"
        ));
    }

    #[test]
    fn test_embedding_new_with_model() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let embedding =
            Embedding::new_with_model(vector.clone(), Some("clip-vit-base-patch32".to_string()))
                .unwrap();

        assert_eq!(embedding.len(), 4);
        assert_eq!(embedding.model_id(), Some("clip-vit-base-patch32"));
        assert_eq!(embedding.as_slice(), &vector);
    }

    #[test]
    fn test_embedding_model_id_mismatch() {
        let a = Embedding::new_with_model(vec![0.1; 512], Some("model-a".to_string())).unwrap();
        let b = Embedding::new_with_model(vec![0.1; 512], Some("model-b".to_string())).unwrap();

        let result = semantic_similarity(&a, &b);
        assert!(matches!(
            result,
            Err(ImgFprintError::InvalidEmbedding(msg)) if msg.contains("model ID mismatch")
        ));
    }

    #[test]
    fn test_embedding_same_model_id_ok() {
        let a =
            Embedding::new_with_model(vec![1.0; 512], Some("clip-vit-base-patch32".to_string()))
                .unwrap();
        let b =
            Embedding::new_with_model(vec![1.0; 512], Some("clip-vit-base-patch32".to_string()))
                .unwrap();

        let sim = semantic_similarity(&a, &b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_missing_model_id_ok() {
        // When one or both lack model_id, comparison should work
        let a = Embedding::new(vec![1.0; 512]).unwrap();
        let b = Embedding::new_with_model(vec![1.0; 512], None).unwrap();
        let c = Embedding::new_with_model(vec![1.0; 512], Some("model-a".to_string())).unwrap();

        // Both without model_id
        let sim1 = semantic_similarity(&a, &b).unwrap();
        assert!((sim1 - 1.0).abs() < 1e-6);

        // One without, one with - should still work
        let sim2 = semantic_similarity(&a, &c).unwrap();
        assert!((sim2 - 1.0).abs() < 1e-6);
    }
}