1use crate::{DistanceMetric, VectorIndex};
10use ipfrs_core::{Cid, Error, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub enum Modality {
17 Text,
19 Image,
21 Audio,
23 Video,
25 Code,
27}
28
29impl Modality {
30 pub fn default_dim(&self) -> usize {
32 match self {
33 Modality::Text => 768, Modality::Image => 512, Modality::Audio => 768, Modality::Video => 768, Modality::Code => 768, }
39 }
40
41 pub fn default_metric(&self) -> DistanceMetric {
43 match self {
44 Modality::Text => DistanceMetric::Cosine,
45 Modality::Image => DistanceMetric::L2,
46 Modality::Audio => DistanceMetric::Cosine,
47 Modality::Video => DistanceMetric::L2,
48 Modality::Code => DistanceMetric::Cosine,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct MultiModalEmbedding {
56 pub vector: Vec<f32>,
58 pub modality: Modality,
60 pub metadata: HashMap<String, String>,
62}
63
64impl MultiModalEmbedding {
65 pub fn new(vector: Vec<f32>, modality: Modality) -> Self {
67 Self {
68 vector,
69 modality,
70 metadata: HashMap::new(),
71 }
72 }
73
74 pub fn with_metadata(mut self, key: String, value: String) -> Self {
76 self.metadata.insert(key, value);
77 self
78 }
79
80 pub fn dim(&self) -> usize {
82 self.vector.len()
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct MultiModalConfig {
89 pub unified_dim: usize,
91 pub project_to_unified: bool,
93 pub modality_weights: HashMap<Modality, f32>,
95}
96
97impl Default for MultiModalConfig {
98 fn default() -> Self {
99 let mut weights = HashMap::new();
100 weights.insert(Modality::Text, 1.0);
101 weights.insert(Modality::Image, 1.0);
102 weights.insert(Modality::Audio, 1.0);
103 weights.insert(Modality::Video, 1.0);
104 weights.insert(Modality::Code, 1.0);
105
106 Self {
107 unified_dim: 768,
108 project_to_unified: false,
109 modality_weights: weights,
110 }
111 }
112}
113
114pub struct MultiModalIndex {
116 indices: HashMap<Modality, VectorIndex>,
118 config: MultiModalConfig,
120 projections: HashMap<Modality, Vec<Vec<f32>>>,
122}
123
124impl MultiModalIndex {
125 pub fn new(config: MultiModalConfig) -> Self {
127 Self {
128 indices: HashMap::new(),
129 config,
130 projections: HashMap::new(),
131 }
132 }
133
134 pub fn register_modality(&mut self, modality: Modality, dim: usize) -> Result<()> {
136 let metric = modality.default_metric();
137
138 let index_dim = if self.config.project_to_unified {
141 self.config.unified_dim
142 } else {
143 dim
144 };
145
146 let index = VectorIndex::new(index_dim, metric, 16, 200)?;
147 self.indices.insert(modality, index);
148
149 if self.config.project_to_unified && dim != self.config.unified_dim {
151 self.init_projection(modality, dim)?;
152 }
153
154 Ok(())
155 }
156
157 fn init_projection(&mut self, modality: Modality, from_dim: usize) -> Result<()> {
159 let to_dim = self.config.unified_dim;
160
161 let mut projection = Vec::with_capacity(from_dim);
164
165 use rand::Rng;
166 let mut rng = rand::rng();
167 let scale = (1.0 / to_dim as f32).sqrt();
168
169 for _ in 0..from_dim {
170 let mut row = Vec::with_capacity(to_dim);
171 for _ in 0..to_dim {
172 let val: f32 = rng.random_range(-1.0..1.0);
174 row.push(val * scale);
175 }
176 projection.push(row);
177 }
178
179 self.projections.insert(modality, projection);
180 Ok(())
181 }
182
183 fn project_embedding(&self, embedding: &[f32], modality: Modality) -> Vec<f32> {
185 if !self.config.project_to_unified {
186 return embedding.to_vec();
187 }
188
189 if let Some(projection) = self.projections.get(&modality) {
190 let mut result = vec![0.0; self.config.unified_dim];
191
192 for (i, row) in projection.iter().enumerate() {
193 if i >= embedding.len() {
194 break;
195 }
196 for (j, &proj_val) in row.iter().enumerate() {
197 result[j] += embedding[i] * proj_val;
198 }
199 }
200
201 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
203 if norm > 0.0 {
204 for val in &mut result {
205 *val /= norm;
206 }
207 }
208
209 result
210 } else {
211 embedding.to_vec()
212 }
213 }
214
215 pub fn add(&mut self, cid: Cid, embedding: MultiModalEmbedding) -> Result<()> {
217 let projected = self.project_embedding(&embedding.vector, embedding.modality);
219
220 let index = self.indices.get_mut(&embedding.modality).ok_or_else(|| {
221 Error::InvalidInput(format!("Modality {:?} not registered", embedding.modality))
222 })?;
223
224 index.insert(&cid, &projected)?;
225
226 Ok(())
227 }
228
229 pub fn search_modality(
231 &self,
232 query: &MultiModalEmbedding,
233 k: usize,
234 ef_search: Option<usize>,
235 ) -> Result<Vec<(Cid, f32)>> {
236 let index = self.indices.get(&query.modality).ok_or_else(|| {
237 Error::InvalidInput(format!("Modality {:?} not registered", query.modality))
238 })?;
239
240 let projected = self.project_embedding(&query.vector, query.modality);
241 let ef_search = ef_search.unwrap_or(50);
242
243 let results = index.search(&projected, k, ef_search)?;
244 Ok(results.into_iter().map(|r| (r.cid, r.score)).collect())
245 }
246
247 pub fn search_cross_modal(
249 &self,
250 query: &MultiModalEmbedding,
251 k: usize,
252 ef_search: Option<usize>,
253 ) -> Result<Vec<(Cid, f32, Modality)>> {
254 let mut all_results = Vec::new();
255 let projected_query = self.project_embedding(&query.vector, query.modality);
256 let ef_search = ef_search.unwrap_or(50);
257
258 for (modality, index) in &self.indices {
260 let weight = self
261 .config
262 .modality_weights
263 .get(modality)
264 .copied()
265 .unwrap_or(1.0);
266
267 match index.search(&projected_query, k * 2, ef_search) {
268 Ok(results) => {
269 for result in results {
270 let weighted_score = result.score * weight;
272 all_results.push((result.cid, weighted_score, *modality));
273 }
274 }
275 Err(_) => continue,
276 }
277 }
278
279 all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
281 all_results.truncate(k);
282
283 Ok(all_results)
284 }
285
286 pub fn stats(&self) -> HashMap<Modality, ModalityStats> {
288 let mut stats = HashMap::new();
289
290 for (modality, index) in &self.indices {
291 stats.insert(
292 *modality,
293 ModalityStats {
294 num_embeddings: index.len(),
295 dimension: index.dimension(),
296 metric: modality.default_metric(),
297 },
298 );
299 }
300
301 stats
302 }
303
304 pub fn len_for_modality(&self, modality: Modality) -> usize {
306 self.indices
307 .get(&modality)
308 .map(|idx| idx.len())
309 .unwrap_or(0)
310 }
311
312 pub fn is_empty(&self) -> bool {
314 self.indices.values().all(|idx| idx.is_empty())
315 }
316
317 pub fn total_len(&self) -> usize {
319 self.indices.values().map(|idx| idx.len()).sum()
320 }
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct ModalityStats {
326 pub num_embeddings: usize,
328 pub dimension: usize,
330 pub metric: DistanceMetric,
332}
333
334pub struct ModalityAlignment {
336 #[allow(dead_code)]
338 source: Modality,
339 #[allow(dead_code)]
341 target: Modality,
342 transform: Vec<Vec<f32>>,
344}
345
346impl ModalityAlignment {
347 pub fn new(source: Modality, target: Modality, source_dim: usize, target_dim: usize) -> Self {
349 let mut transform = vec![vec![0.0; target_dim]; source_dim];
351 let min_dim = source_dim.min(target_dim);
352
353 for (i, row) in transform.iter_mut().enumerate().take(min_dim) {
354 row[i] = 1.0;
355 }
356
357 Self {
358 source,
359 target,
360 transform,
361 }
362 }
363
364 pub fn learn_from_pairs(&mut self, pairs: &[(Vec<f32>, Vec<f32>)]) -> Result<()> {
368 if pairs.is_empty() {
369 return Err(Error::InvalidInput("No pairs provided".into()));
370 }
371
372 let source_dim = pairs[0].0.len();
375 let target_dim = pairs[0].1.len();
376
377 let mut transform = vec![vec![0.0; target_dim]; source_dim];
378
379 for (source_vec, target_vec) in pairs {
380 for (i, &source_val) in source_vec.iter().enumerate().take(source_dim) {
381 for (j, &target_val) in target_vec.iter().enumerate().take(target_dim) {
382 transform[i][j] += source_val * target_val;
383 }
384 }
385 }
386
387 let n = pairs.len() as f32;
389 for row in &mut transform {
390 for val in row {
391 *val /= n;
392 }
393 }
394
395 self.transform = transform;
396 Ok(())
397 }
398
399 pub fn transform_embedding(&self, source: &[f32]) -> Vec<f32> {
401 let target_dim = self.transform[0].len();
402 let mut result = vec![0.0; target_dim];
403
404 for (i, row) in self.transform.iter().enumerate() {
405 if i >= source.len() {
406 break;
407 }
408 for (j, &val) in row.iter().enumerate() {
409 result[j] += source[i] * val;
410 }
411 }
412
413 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
415 if norm > 0.0 {
416 for val in &mut result {
417 *val /= norm;
418 }
419 }
420
421 result
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic CIDv1 (codec 0x55, raw) from a SHA2-256 digest of `index`.
    fn generate_test_cid(index: usize) -> Cid {
        use multihash_codetable::{Code, MultihashDigest};
        let data = format!("multimodal_test_{}", index);
        let hash = Code::Sha2_256.digest(data.as_bytes());
        Cid::new_v1(0x55, hash)
    }

    #[test]
    fn test_modality_defaults() {
        assert_eq!(Modality::Text.default_dim(), 768);
        assert_eq!(Modality::Image.default_dim(), 512);
        assert_eq!(Modality::Audio.default_dim(), 768);
        assert_eq!(Modality::Video.default_dim(), 768);
        assert_eq!(Modality::Code.default_dim(), 768);

        assert_eq!(Modality::Text.default_metric(), DistanceMetric::Cosine);
        assert_eq!(Modality::Image.default_metric(), DistanceMetric::L2);
        assert_eq!(Modality::Audio.default_metric(), DistanceMetric::Cosine);
        assert_eq!(Modality::Video.default_metric(), DistanceMetric::L2);
        assert_eq!(Modality::Code.default_metric(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_multimodal_embedding_creation() {
        let vec = vec![0.1, 0.2, 0.3];
        let emb = MultiModalEmbedding::new(vec.clone(), Modality::Text);

        assert_eq!(emb.vector, vec);
        assert_eq!(emb.modality, Modality::Text);
        assert_eq!(emb.dim(), 3);
        assert!(emb.metadata.is_empty());
    }

    #[test]
    fn test_multimodal_index_creation() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);

        assert!(index.is_empty());
        assert_eq!(index.total_len(), 0);

        index.register_modality(Modality::Text, 768).unwrap();
        index.register_modality(Modality::Image, 512).unwrap();

        assert_eq!(index.len_for_modality(Modality::Text), 0);
        assert_eq!(index.len_for_modality(Modality::Image), 0);
        assert_eq!(index.len_for_modality(Modality::Audio), 0);
    }

    #[test]
    fn test_add_and_search_single_modality() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);
        index.register_modality(Modality::Text, 3).unwrap();

        let cid1 = generate_test_cid(1);
        let emb1 = MultiModalEmbedding::new(vec![1.0, 0.0, 0.0], Modality::Text);
        index.add(cid1, emb1).unwrap();

        let cid2 = generate_test_cid(2);
        let emb2 = MultiModalEmbedding::new(vec![0.0, 1.0, 0.0], Modality::Text);
        index.add(cid2, emb2).unwrap();

        assert_eq!(index.len_for_modality(Modality::Text), 2);
        assert_eq!(index.total_len(), 2);
        assert!(!index.is_empty());

        let query = MultiModalEmbedding::new(vec![0.9, 0.1, 0.0], Modality::Text);
        let results = index.search_modality(&query, 1, None).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, cid1);
    }

    #[test]
    fn test_add_unregistered_modality_fails() {
        let mut index = MultiModalIndex::new(MultiModalConfig::default());
        index.register_modality(Modality::Text, 3).unwrap();

        let emb = MultiModalEmbedding::new(vec![1.0, 0.0, 0.0], Modality::Audio);
        assert!(index.add(generate_test_cid(6), emb).is_err());
        assert_eq!(index.total_len(), 0);
    }

    #[test]
    fn test_cross_modal_search() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);

        index.register_modality(Modality::Text, 3).unwrap();
        index.register_modality(Modality::Image, 3).unwrap();

        let cid1 = generate_test_cid(3);
        let emb1 = MultiModalEmbedding::new(vec![1.0, 0.0, 0.0], Modality::Text);
        index.add(cid1, emb1).unwrap();

        let cid2 = generate_test_cid(4);
        let emb2 = MultiModalEmbedding::new(vec![0.0, 1.0, 0.0], Modality::Image);
        index.add(cid2, emb2).unwrap();

        let query = MultiModalEmbedding::new(vec![0.9, 0.1, 0.0], Modality::Text);
        let results = index.search_cross_modal(&query, 2, None).unwrap();

        // Both 3-dimensional indices accept the query, so each contributes its single entry,
        // tagged with the modality it came from.
        assert_eq!(results.len(), 2);
        assert!(results
            .iter()
            .any(|(cid, _, modality)| *cid == cid1 && *modality == Modality::Text));
        assert!(results
            .iter()
            .any(|(cid, _, modality)| *cid == cid2 && *modality == Modality::Image));
    }

    #[test]
    fn test_modality_alignment() {
        let mut alignment = ModalityAlignment::new(Modality::Text, Modality::Image, 3, 3);

        let pairs = vec![
            (vec![1.0, 0.0, 0.0], vec![0.9, 0.1, 0.0]),
            (vec![0.0, 1.0, 0.0], vec![0.1, 0.9, 0.0]),
        ];

        alignment.learn_from_pairs(&pairs).unwrap();

        let source = vec![1.0, 0.0, 0.0];
        let transformed = alignment.transform_embedding(&source);

        assert_eq!(transformed.len(), 3);
        assert!(transformed[0] > 0.5);
    }

    #[test]
    fn test_alignment_identity_before_learning() {
        let alignment = ModalityAlignment::new(Modality::Text, Modality::Image, 2, 2);
        let out = alignment.transform_embedding(&[3.0, 4.0]);

        assert_eq!(out.len(), 2);
        assert!((out[0] - 0.6).abs() < 1e-6);
        assert!((out[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_modality_stats() {
        let config = MultiModalConfig::default();
        let mut index = MultiModalIndex::new(config);

        index.register_modality(Modality::Text, 768).unwrap();
        index.register_modality(Modality::Image, 512).unwrap();

        let stats = index.stats();

        assert_eq!(stats.len(), 2);
        let text = stats.get(&Modality::Text).unwrap();
        let image = stats.get(&Modality::Image).unwrap();
        assert_eq!(text.dimension, 768);
        assert_eq!(image.dimension, 512);
        assert_eq!(text.num_embeddings, 0);
        assert_eq!(text.metric, DistanceMetric::Cosine);
        assert_eq!(image.metric, DistanceMetric::L2);
    }

    #[test]
    fn test_projection() {
        let config = MultiModalConfig {
            project_to_unified: true,
            unified_dim: 512,
            ..Default::default()
        };

        let mut index = MultiModalIndex::new(config);
        index.register_modality(Modality::Text, 768).unwrap();

        let cid = generate_test_cid(5);
        let emb = MultiModalEmbedding::new(vec![0.5; 768], Modality::Text);
        index.add(cid, emb).unwrap();

        assert_eq!(index.len_for_modality(Modality::Text), 1);
        // With projection enabled the index is built in the unified space, not at 768.
        assert_eq!(index.stats().get(&Modality::Text).unwrap().dimension, 512);
    }
}