ipfrs_semantic/multimodal.rs

//! Multi-modal embedding support for unified semantic search across text, image, audio, video, and code.
//!
//! This module provides infrastructure for:
//! - Unified embedding space across modalities
//! - Cross-modal similarity search
//! - Modality-specific distance metrics
//! - Embedding projection and alignment
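//!
//! A minimal usage sketch (module path and re-exports are assumed here;
//! error handling elided):
//!
//! ```no_run
//! use ipfrs_semantic::multimodal::{
//!     Modality, MultiModalConfig, MultiModalEmbedding, MultiModalIndex,
//! };
//!
//! # fn demo(cid: ipfrs_core::Cid) -> ipfrs_core::Result<()> {
//! let mut index = MultiModalIndex::new(MultiModalConfig::default());
//! index.register_modality(Modality::Text, 3)?;
//! index.add(cid, MultiModalEmbedding::new(vec![1.0, 0.0, 0.0], Modality::Text))?;
//!
//! let query = MultiModalEmbedding::new(vec![0.9, 0.1, 0.0], Modality::Text);
//! let _hits = index.search_modality(&query, 1, None)?;
//! # Ok(())
//! # }
//! ```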

use crate::{DistanceMetric, VectorIndex};
use ipfrs_core::{Cid, Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Supported modalities for embeddings
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Modality {
    /// Text embeddings (e.g., from BERT, GPT)
    Text,
    /// Image embeddings (e.g., from ResNet, CLIP)
    Image,
    /// Audio embeddings (e.g., from Wav2Vec, CLAP)
    Audio,
    /// Video embeddings (e.g., from VideoMAE)
    Video,
    /// Code embeddings (e.g., from CodeBERT)
    Code,
}

impl Modality {
    /// Get the default embedding dimension for this modality
    pub fn default_dim(&self) -> usize {
        match self {
            Modality::Text => 768,  // BERT-base
            Modality::Image => 512, // CLIP ViT-B (raw ResNet-50 features are 2048-d)
            Modality::Audio => 768, // Wav2Vec 2.0
            Modality::Video => 768, // VideoMAE
            Modality::Code => 768,  // CodeBERT
        }
    }

    /// Get the recommended distance metric for this modality
    pub fn default_metric(&self) -> DistanceMetric {
        match self {
            Modality::Text => DistanceMetric::Cosine,
            Modality::Image => DistanceMetric::L2,
            Modality::Audio => DistanceMetric::Cosine,
            Modality::Video => DistanceMetric::L2,
            Modality::Code => DistanceMetric::Cosine,
        }
    }
}

/// Multi-modal embedding with modality information
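///
/// A small construction sketch (metadata key and value are arbitrary,
/// illustrative strings):
///
/// ```no_run
/// # use ipfrs_semantic::multimodal::{Modality, MultiModalEmbedding};
/// let emb = MultiModalEmbedding::new(vec![0.1, 0.2, 0.3], Modality::Text)
///     .with_metadata("source".to_string(), "example".to_string());
/// assert_eq!(emb.dim(), 3);
/// ```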
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalEmbedding {
    /// The embedding vector
    pub vector: Vec<f32>,
    /// The modality this embedding belongs to
    pub modality: Modality,
    /// Optional metadata about the embedding source
    pub metadata: HashMap<String, String>,
}

impl MultiModalEmbedding {
    /// Create a new multi-modal embedding
    pub fn new(vector: Vec<f32>, modality: Modality) -> Self {
        Self {
            vector,
            modality,
            metadata: HashMap::new(),
        }
    }

    /// Add metadata to the embedding
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Get the dimension of the embedding
    pub fn dim(&self) -> usize {
        self.vector.len()
    }
}

/// Configuration for multi-modal index
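///
/// A sketch of a non-default setup: project all modalities into a shared
/// 512-dimensional space and give image results a non-default weight.
///
/// ```no_run
/// # use ipfrs_semantic::multimodal::{Modality, MultiModalConfig};
/// let mut config = MultiModalConfig::default();
/// config.project_to_unified = true;
/// config.unified_dim = 512;
/// config.modality_weights.insert(Modality::Image, 2.0);
/// ```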
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalConfig {
    /// Target dimension for unified embedding space
    pub unified_dim: usize,
    /// Whether to project embeddings to unified dimension
    pub project_to_unified: bool,
    /// Modality-specific weights for cross-modal search
    pub modality_weights: HashMap<Modality, f32>,
}

impl Default for MultiModalConfig {
    fn default() -> Self {
        let mut weights = HashMap::new();
        weights.insert(Modality::Text, 1.0);
        weights.insert(Modality::Image, 1.0);
        weights.insert(Modality::Audio, 1.0);
        weights.insert(Modality::Video, 1.0);
        weights.insert(Modality::Code, 1.0);

        Self {
            unified_dim: 768,
            project_to_unified: false,
            modality_weights: weights,
        }
    }
}

/// Multi-modal index for unified semantic search
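///
/// Typical flow: `register_modality` once per modality, `add` embeddings,
/// then query with `search_modality` or `search_cross_modal` (see the
/// module-level example).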
pub struct MultiModalIndex {
    /// Separate indices for each modality
    indices: HashMap<Modality, VectorIndex>,
    /// Configuration
    config: MultiModalConfig,
    /// Projection matrices for each modality (if using unified embedding space)
    projections: HashMap<Modality, Vec<Vec<f32>>>,
}

impl MultiModalIndex {
    /// Create a new multi-modal index
    pub fn new(config: MultiModalConfig) -> Self {
        Self {
            indices: HashMap::new(),
            config,
            projections: HashMap::new(),
        }
    }

    /// Register a modality with the index
    pub fn register_modality(&mut self, modality: Modality, dim: usize) -> Result<()> {
        let metric = modality.default_metric();

        // If projection is enabled, create index with unified dimension
        // Otherwise, use the modality's dimension
        let index_dim = if self.config.project_to_unified {
            self.config.unified_dim
        } else {
            dim
        };

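        // The trailing arguments (16, 200) are construction-time tuning
        // parameters for `VectorIndex` (presumably HNSW-style M and
        // ef_construction); see `VectorIndex::new` for their exact meaning.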
        let index = VectorIndex::new(index_dim, metric, 16, 200)?;
        self.indices.insert(modality, index);

        // Initialize projection matrix if needed
        if self.config.project_to_unified && dim != self.config.unified_dim {
            self.init_projection(modality, dim)?;
        }

        Ok(())
    }

    /// Initialize random projection matrix for dimensionality reduction/expansion
    fn init_projection(&mut self, modality: Modality, from_dim: usize) -> Result<()> {
        let to_dim = self.config.unified_dim;

        // Random projection in the spirit of the Johnson-Lindenstrauss lemma.
        // Elements are drawn uniformly from [-1, 1) and scaled by 1/sqrt(to_dim),
        // a cheap surrogate for N(0, 1/to_dim) Gaussian entries.
        let mut projection = Vec::with_capacity(from_dim);

        use rand::Rng;
        let mut rng = rand::rng();
        let scale = (1.0 / to_dim as f32).sqrt();

        for _ in 0..from_dim {
            let mut row = Vec::with_capacity(to_dim);
            for _ in 0..to_dim {
                // Sample uniformly from [-1, 1), then scale
                let val: f32 = rng.random_range(-1.0..1.0);
                row.push(val * scale);
            }
            projection.push(row);
        }

        self.projections.insert(modality, projection);
        Ok(())
    }

    /// Project embedding to unified dimension
    fn project_embedding(&self, embedding: &[f32], modality: Modality) -> Vec<f32> {
        if !self.config.project_to_unified {
            return embedding.to_vec();
        }

        if let Some(projection) = self.projections.get(&modality) {
            let mut result = vec![0.0; self.config.unified_dim];

            for (i, row) in projection.iter().enumerate() {
                if i >= embedding.len() {
                    break;
                }
                for (j, &proj_val) in row.iter().enumerate() {
                    result[j] += embedding[i] * proj_val;
                }
            }

            // Normalize
            let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for val in &mut result {
                    *val /= norm;
                }
            }

            result
        } else {
            embedding.to_vec()
        }
    }

    /// Add an embedding to the index
    pub fn add(&mut self, cid: Cid, embedding: MultiModalEmbedding) -> Result<()> {
        // Project embedding first to avoid borrowing issues
        let projected = self.project_embedding(&embedding.vector, embedding.modality);

        let index = self.indices.get_mut(&embedding.modality).ok_or_else(|| {
            Error::InvalidInput(format!("Modality {:?} not registered", embedding.modality))
        })?;

        index.insert(&cid, &projected)?;

        Ok(())
    }

    /// Search within a specific modality
    pub fn search_modality(
        &self,
        query: &MultiModalEmbedding,
        k: usize,
        ef_search: Option<usize>,
    ) -> Result<Vec<(Cid, f32)>> {
        let index = self.indices.get(&query.modality).ok_or_else(|| {
            Error::InvalidInput(format!("Modality {:?} not registered", query.modality))
        })?;

        let projected = self.project_embedding(&query.vector, query.modality);
        let ef_search = ef_search.unwrap_or(50);

        let results = index.search(&projected, k, ef_search)?;
        Ok(results.into_iter().map(|r| (r.cid, r.score)).collect())
    }

    /// Cross-modal search: search across all modalities
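    ///
    /// Each registered modality is searched with the (optionally projected)
    /// query; every hit's score is multiplied by that modality's weight, and
    /// the merged list is sorted ascending (scores are treated as distances)
    /// before the top `k` are returned. A usage sketch:
    ///
    /// ```no_run
    /// # use ipfrs_semantic::multimodal::{Modality, MultiModalEmbedding, MultiModalIndex};
    /// # fn demo(index: &MultiModalIndex) -> ipfrs_core::Result<()> {
    /// let query = MultiModalEmbedding::new(vec![0.9, 0.1, 0.0], Modality::Text);
    /// for (cid, score, modality) in index.search_cross_modal(&query, 5, None)? {
    ///     println!("{:?} -> {} ({:?})", cid, score, modality);
    /// }
    /// # Ok(())
    /// # }
    /// ```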
    pub fn search_cross_modal(
        &self,
        query: &MultiModalEmbedding,
        k: usize,
        ef_search: Option<usize>,
    ) -> Result<Vec<(Cid, f32, Modality)>> {
        let mut all_results = Vec::new();
        let projected_query = self.project_embedding(&query.vector, query.modality);
        let ef_search = ef_search.unwrap_or(50);

        // Search each modality
        for (modality, index) in &self.indices {
            let weight = self
                .config
                .modality_weights
                .get(modality)
                .copied()
                .unwrap_or(1.0);

            match index.search(&projected_query, k * 2, ef_search) {
                Ok(results) => {
                    for result in results {
                        // Scale the raw score by the modality's weight; since
                        // scores are treated as distances below, a larger
                        // weight de-prioritizes that modality.
                        let weighted_score = result.score * weight;
                        all_results.push((result.cid, weighted_score, *modality));
                    }
                }
                Err(_) => continue,
            }
        }

        // Sort ascending (lower score = better match) and keep the top k
        all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        all_results.truncate(k);

        Ok(all_results)
    }

    /// Get statistics for each modality
    pub fn stats(&self) -> HashMap<Modality, ModalityStats> {
        let mut stats = HashMap::new();

        for (modality, index) in &self.indices {
            stats.insert(
                *modality,
                ModalityStats {
                    num_embeddings: index.len(),
                    dimension: index.dimension(),
                    metric: modality.default_metric(),
                },
            );
        }

        stats
    }

    /// Get the number of embeddings in a specific modality
    pub fn len_for_modality(&self, modality: Modality) -> usize {
        self.indices
            .get(&modality)
            .map(|idx| idx.len())
            .unwrap_or(0)
    }

    /// Check if the index is empty
    pub fn is_empty(&self) -> bool {
        self.indices.values().all(|idx| idx.is_empty())
    }

    /// Total number of embeddings across all modalities
    pub fn total_len(&self) -> usize {
        self.indices.values().map(|idx| idx.len()).sum()
    }
}

/// Statistics for a specific modality
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityStats {
    /// Number of embeddings in this modality
    pub num_embeddings: usize,
    /// Dimension of embeddings
    pub dimension: usize,
    /// Distance metric used
    pub metric: DistanceMetric,
}

/// Alignment between two modalities for cross-modal search
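///
/// A sketch of the intended use (toy dimensions and hand-made pairs):
///
/// ```no_run
/// # use ipfrs_semantic::multimodal::{Modality, ModalityAlignment};
/// let mut alignment = ModalityAlignment::new(Modality::Text, Modality::Image, 3, 3);
/// alignment
///     .learn_from_pairs(&[
///         (vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]),
///         (vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]),
///     ])
///     .unwrap();
/// let in_image_space = alignment.transform_embedding(&[1.0, 0.0, 0.0]);
/// assert_eq!(in_image_space.len(), 3);
/// ```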
pub struct ModalityAlignment {
    /// Source modality
    #[allow(dead_code)]
    source: Modality,
    /// Target modality
    #[allow(dead_code)]
    target: Modality,
    /// Learned transformation matrix (source_dim × target_dim)
    transform: Vec<Vec<f32>>,
}

impl ModalityAlignment {
    /// Create a new modality alignment
    pub fn new(source: Modality, target: Modality, source_dim: usize, target_dim: usize) -> Self {
        // Initialize with identity-like transformation
        let mut transform = vec![vec![0.0; target_dim]; source_dim];
        let min_dim = source_dim.min(target_dim);

        for (i, row) in transform.iter_mut().enumerate().take(min_dim) {
            row[i] = 1.0;
        }

        Self {
            source,
            target,
            transform,
        }
    }

    /// Learn alignment from paired examples
    ///
    /// This is a simplified version; in practice, you'd use CCA, GCCA, or neural networks.
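    ///
    /// Concretely, the simplified version computes
    /// `transform[i][j] = (1/n) * sum over pairs of source[i] * target[j]`,
    /// i.e. an uncentered cross-covariance estimate over the `n` pairs.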
    pub fn learn_from_pairs(&mut self, pairs: &[(Vec<f32>, Vec<f32>)]) -> Result<()> {
        if pairs.is_empty() {
            return Err(Error::InvalidInput("No pairs provided".into()));
        }

        // Simplified learning: use average mapping
        // In practice, use Canonical Correlation Analysis (CCA) or neural networks
        let source_dim = pairs[0].0.len();
        let target_dim = pairs[0].1.len();

        let mut transform = vec![vec![0.0; target_dim]; source_dim];

        for (source_vec, target_vec) in pairs {
            for (i, &source_val) in source_vec.iter().enumerate().take(source_dim) {
                for (j, &target_val) in target_vec.iter().enumerate().take(target_dim) {
                    transform[i][j] += source_val * target_val;
                }
            }
        }

        // Normalize by number of pairs
        let n = pairs.len() as f32;
        for row in &mut transform {
            for val in row {
                *val /= n;
            }
        }

        self.transform = transform;
        Ok(())
    }

    /// Transform a source embedding to target modality space
    pub fn transform_embedding(&self, source: &[f32]) -> Vec<f32> {
        let target_dim = self.transform[0].len();
        let mut result = vec![0.0; target_dim];

        for (i, row) in self.transform.iter().enumerate() {
            if i >= source.len() {
                break;
            }
            for (j, &val) in row.iter().enumerate() {
                result[j] += source[i] * val;
            }
        }

        // Normalize
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut result {
                *val /= norm;
            }
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn generate_test_cid(index: usize) -> Cid {
        use multihash_codetable::{Code, MultihashDigest};
        let data = format!("multimodal_test_{}", index);
        let hash = Code::Sha2_256.digest(data.as_bytes());
        Cid::new_v1(0x55, hash)
    }

    #[test]
    fn test_modality_defaults() {
        assert_eq!(Modality::Text.default_dim(), 768);
        assert_eq!(Modality::Image.default_dim(), 512);
        assert_eq!(Modality::Text.default_metric(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_multimodal_embedding_creation() {
        let vec = vec![0.1, 0.2, 0.3];
        let emb = MultiModalEmbedding::new(vec.clone(), Modality::Text);

        assert_eq!(emb.vector, vec);
        assert_eq!(emb.modality, Modality::Text);
        assert_eq!(emb.dim(), 3);
    }

    #[test]
    fn test_multimodal_index_creation() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);

        assert!(index.is_empty());
        assert_eq!(index.total_len(), 0);

        // Register modalities
        index.register_modality(Modality::Text, 768).unwrap();
        index.register_modality(Modality::Image, 512).unwrap();

        assert_eq!(index.len_for_modality(Modality::Text), 0);
        assert_eq!(index.len_for_modality(Modality::Image), 0);
    }

    #[test]
    fn test_add_and_search_single_modality() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);
        index.register_modality(Modality::Text, 3).unwrap();

        // Add embeddings
        let cid1 = generate_test_cid(1);
        let emb1 = MultiModalEmbedding::new(vec![1.0, 0.0, 0.0], Modality::Text);
        index.add(cid1, emb1).unwrap();

        let cid2 = generate_test_cid(2);
        let emb2 = MultiModalEmbedding::new(vec![0.0, 1.0, 0.0], Modality::Text);
        index.add(cid2, emb2).unwrap();

        assert_eq!(index.len_for_modality(Modality::Text), 2);

        // Search
        let query = MultiModalEmbedding::new(vec![0.9, 0.1, 0.0], Modality::Text);
        let results = index.search_modality(&query, 1, None).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, cid1);
    }

    #[test]
    fn test_cross_modal_search() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);

        index.register_modality(Modality::Text, 3).unwrap();
        index.register_modality(Modality::Image, 3).unwrap();

        // Add text embedding
        let cid1 = generate_test_cid(3);
        let emb1 = MultiModalEmbedding::new(vec![1.0, 0.0, 0.0], Modality::Text);
        index.add(cid1, emb1).unwrap();

        // Add image embedding
        let cid2 = generate_test_cid(4);
        let emb2 = MultiModalEmbedding::new(vec![0.0, 1.0, 0.0], Modality::Image);
        index.add(cid2, emb2).unwrap();

        // Cross-modal search from text
        let query = MultiModalEmbedding::new(vec![0.9, 0.1, 0.0], Modality::Text);
        let results = index.search_cross_modal(&query, 2, None).unwrap();

        assert!(!results.is_empty());
    }

    #[test]
    fn test_modality_alignment() {
        let mut alignment = ModalityAlignment::new(Modality::Text, Modality::Image, 3, 3);

        // Create some paired examples
        let pairs = vec![
            (vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]),
            (vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]),
        ];

        alignment.learn_from_pairs(&pairs).unwrap();

        // Transform a source embedding
        let source = vec![1.0, 0.0, 0.0];
        let transformed = alignment.transform_embedding(&source);

        assert_eq!(transformed.len(), 3);
        assert!(transformed[0] > 0.5); // Should be close to target space
    }

    #[test]
    fn test_modality_stats() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);

        index.register_modality(Modality::Text, 768).unwrap();
        index.register_modality(Modality::Image, 512).unwrap();

        let stats = index.stats();

        assert_eq!(stats.len(), 2);
        assert_eq!(stats.get(&Modality::Text).unwrap().dimension, 768);
        assert_eq!(stats.get(&Modality::Image).unwrap().dimension, 512);
    }

    #[test]
    fn test_projection() {
        let config = MultiModalConfig {
            project_to_unified: true,
            unified_dim: 512,
            ..Default::default()
        };

        let mut index = MultiModalIndex::new(config);
        index.register_modality(Modality::Text, 768).unwrap();

        // Add an embedding (should be projected from 768 to 512)
        let cid = generate_test_cid(5);
        let emb = MultiModalEmbedding::new(vec![0.5; 768], Modality::Text);
        index.add(cid, emb).unwrap();

        assert_eq!(index.len_for_modality(Modality::Text), 1);
    }
}