use anyhow::{anyhow, Result};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// The kind of data an embedding was produced from.
///
/// Used as the key under which entries and queries store their embeddings; the
/// index tracks one embedding dimension per `Modality`. `Custom` carries a `u8`
/// tag, presumably so callers can distinguish additional modalities of their own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum Modality {
    Text,
    Image,
    Audio,
    Video,
    Code,
    Custom(u8),
}
impl Modality {
    /// Returns the lowercase name of this modality (`"text"`, `"image"`, ...).
    ///
    /// Every [`Modality::Custom`] value maps to `"custom"`; its numeric tag is not
    /// part of the name. The result is a string literal, so it is returned as
    /// `&'static str` and may outlive the `Modality` it was read from.
    pub fn name(&self) -> &'static str {
        match self {
            Modality::Text => "text",
            Modality::Image => "image",
            Modality::Audio => "audio",
            Modality::Video => "video",
            Modality::Code => "code",
            Modality::Custom(_) => "custom",
        }
    }
}
/// How a query's per-modality embeddings are combined during search.
///
/// For the early-fusion strategies (`Average`, `WeightedAverage`, `Concatenate`,
/// `Max`), a query carrying exactly one embedding uses that embedding unchanged.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub enum MultiModalFusion {
    /// Element-wise mean of the query embeddings.
    Average,
    /// Element-wise weighted sum of the query embeddings; weights are not
    /// normalised. Weights are paired positionally with the query's modalities.
    /// NOTE(review): the pairing order is decided by the index's fusion code;
    /// check it before relying on a particular order.
    WeightedAverage(Vec<f32>),
    /// Query embeddings appended end to end into one longer vector.
    Concatenate,
    /// Element-wise maximum of the query embeddings.
    Max,
    /// Each query modality is searched on its own and per-entry scores are summed.
    /// Each modality's contribution is scaled by its weight (1.0 when the modality
    /// is absent from `weights`).
    LateFusion { weights: HashMap<Modality, f32> },
}
/// An item stored in the index: an id plus at most one embedding per modality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalEntry {
    /// Caller-supplied identifier reported back in search results.
    pub id: String,
    /// Embeddings keyed by the modality they were produced from.
    pub embeddings: HashMap<Modality, Vec<f32>>,
    /// Optional key/value metadata, matched exactly by query filters.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
impl MultiModalEntry {
    /// Creates an entry with the given id, no embeddings and no metadata.
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        Self {
            id,
            embeddings: HashMap::default(),
            metadata: None,
        }
    }

    /// Builder-style: stores `embedding` under `modality`, replacing any
    /// embedding previously stored for that modality.
    pub fn add_embedding(self, modality: Modality, embedding: Vec<f32>) -> Self {
        let mut entry = self;
        entry.embeddings.insert(modality, embedding);
        entry
    }

    /// Builder-style: sets (or replaces) the entry's metadata.
    pub fn with_metadata(self, metadata: HashMap<String, serde_json::Value>) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Returns the embedding stored for `modality`, if any.
    pub fn get_embedding(&self, modality: Modality) -> Option<&Vec<f32>> {
        let embeddings = &self.embeddings;
        embeddings.get(&modality)
    }

    /// Returns `true` if an embedding is stored for `modality`.
    pub fn has_modality(&self, modality: Modality) -> bool {
        self.get_embedding(modality).is_some()
    }
}
/// A search request against a `MultiModalIndex`.
///
/// Build one with [`MultiModalQuery::new`] or [`MultiModalQuery::multi`] and
/// refine it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct MultiModalQuery {
    /// Query embeddings keyed by the modality they were produced from.
    pub embeddings: HashMap<Modality, Vec<f32>>,
    /// Entry modalities to compare against in early-fusion search; `None` means
    /// every modality the index has registered. Late fusion ignores this field and
    /// compares each query modality with the same modality of each entry.
    pub target_modalities: Option<Vec<Modality>>,
    /// Maximum number of results returned.
    pub limit: usize,
    /// Strategy used to combine `embeddings`.
    pub fusion: MultiModalFusion,
    /// Exact-match metadata filter: every key/value pair must be present in an
    /// entry's metadata. When set, entries without metadata never match.
    pub filter: Option<HashMap<String, serde_json::Value>>,
}
impl MultiModalQuery {
    /// Number of results a freshly built query asks for.
    const DEFAULT_LIMIT: usize = 10;

    /// Builds a query from a single embedding of `modality`.
    ///
    /// Defaults: limit 10, [`MultiModalFusion::Average`], no target-modality
    /// restriction and no metadata filter.
    pub fn new(embedding: Vec<f32>, modality: Modality) -> Self {
        Self::multi(HashMap::from([(modality, embedding)]))
    }

    /// Builds a query from several embeddings keyed by modality, using the same
    /// defaults as [`MultiModalQuery::new`].
    pub fn multi(embeddings: HashMap<Modality, Vec<f32>>) -> Self {
        Self {
            embeddings,
            target_modalities: None,
            limit: Self::DEFAULT_LIMIT,
            fusion: MultiModalFusion::Average,
            filter: None,
        }
    }

    /// Restricts early-fusion comparisons to a single entry modality.
    pub fn with_target_modality(self, modality: Modality) -> Self {
        self.with_target_modalities(vec![modality])
    }

    /// Restricts early-fusion comparisons to the given entry modalities.
    pub fn with_target_modalities(self, modalities: Vec<Modality>) -> Self {
        Self {
            target_modalities: Some(modalities),
            ..self
        }
    }

    /// Sets the maximum number of results.
    pub fn with_limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Sets the fusion strategy.
    pub fn with_fusion(self, fusion: MultiModalFusion) -> Self {
        Self { fusion, ..self }
    }

    /// Sets the exact-match metadata filter.
    pub fn with_filter(self, filter: HashMap<String, serde_json::Value>) -> Self {
        Self {
            filter: Some(filter),
            ..self
        }
    }
}
/// One search hit.
#[derive(Debug, Clone)]
pub struct MultiModalResult {
    /// Id of the matching entry.
    pub id: String,
    /// Ranking score; higher is better. Early fusion uses `1 / (1 + distance)`.
    /// Late fusion sums `weight / (1 + distance)` over every query modality the
    /// entry matched.
    pub score: f32,
    /// Euclidean distance between the query vector and the entry embedding. Under
    /// late fusion this is the distance for the first query modality that matched.
    pub distance: f32,
    /// Entry modality the hit was scored against. Under late fusion this is the
    /// first query modality that matched.
    pub modality: Modality,
    /// Copy of the entry's metadata, if any.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// In-memory, brute-force index over [`MultiModalEntry`] values.
///
/// Every search scans all stored entries; there is no approximate
/// nearest-neighbour structure.
#[derive(Debug)]
pub struct MultiModalIndex {
    /// Stored entries, in insertion order.
    entries: Vec<MultiModalEntry>,
    /// Embedding length per modality. The first embedding of a modality added to
    /// the index fixes it, and removing entries does not release it.
    dimensions: HashMap<Modality, usize>,
}
impl MultiModalIndex {
    /// Creates an empty index.
    ///
    /// # Errors
    ///
    /// Never fails today; the `Result` is kept for API compatibility.
    pub fn new() -> Result<Self> {
        Ok(Self {
            entries: Vec::new(),
            dimensions: HashMap::new(),
        })
    }

    /// Adds an entry that has a single embedding of `modality`.
    ///
    /// Equivalent to building a [`MultiModalEntry`] with that one embedding (and
    /// the optional metadata) and passing it to [`add_entry`](Self::add_entry).
    ///
    /// # Errors
    ///
    /// Returns an error if `embedding.len()` differs from the dimension already
    /// registered for `modality`. The index is left unchanged in that case.
    pub fn add_with_modality(
        &mut self,
        id: impl Into<String>,
        embedding: Vec<f32>,
        modality: Modality,
        metadata: Option<HashMap<String, serde_json::Value>>,
    ) -> Result<()> {
        let mut entry = MultiModalEntry::new(id).add_embedding(modality, embedding);
        entry.metadata = metadata;
        self.add_entry(entry)
    }

    /// Adds an entry carrying any number of modalities.
    ///
    /// The first embedding seen for a modality fixes that modality's dimension.
    /// Later embeddings of the same modality must match it. Duplicate ids are not
    /// rejected.
    ///
    /// # Errors
    ///
    /// Returns an error if any embedding's length differs from the registered
    /// dimension for its modality. All embeddings are validated before anything is
    /// recorded, so a rejected entry registers no new dimensions.
    pub fn add_entry(&mut self, entry: MultiModalEntry) -> Result<()> {
        // Validate first: registering dimensions while validating would leave the
        // index partially updated when a later modality of the same entry fails.
        for (modality, embedding) in &entry.embeddings {
            if let Some(&existing_dim) = self.dimensions.get(modality) {
                if existing_dim != embedding.len() {
                    return Err(anyhow!(
                        "Dimension mismatch for modality {:?}: expected {}, got {}",
                        modality,
                        existing_dim,
                        embedding.len()
                    ));
                }
            }
        }
        for (modality, embedding) in &entry.embeddings {
            self.dimensions.entry(*modality).or_insert(embedding.len());
        }
        self.entries.push(entry);
        Ok(())
    }

    /// Searches the index and returns at most `query.limit` hits, best first.
    ///
    /// `MultiModalFusion::LateFusion` scores each query modality separately and
    /// sums the scores per entry. Every other strategy fuses the query embeddings
    /// into one vector and compares it with each target modality of each entry.
    ///
    /// # Errors
    ///
    /// For early-fusion strategies, returns an error when the query cannot be
    /// fused: no embeddings, mismatched query embedding lengths, or a
    /// `WeightedAverage` weight count that differs from the number of query
    /// embeddings.
    pub fn search(&self, query: &MultiModalQuery) -> Result<Vec<MultiModalResult>> {
        if self.entries.is_empty() {
            return Ok(Vec::new());
        }
        match &query.fusion {
            MultiModalFusion::LateFusion { weights } => self.late_fusion_search(query, weights),
            _ => self.early_fusion_search(query),
        }
    }

    /// Fuses the query into one vector, then scores it against every target
    /// modality of every entry that passes the filter.
    ///
    /// Each (entry, target modality) pair yields its own result, scored as
    /// `1 / (1 + distance)`. Distances pair elements positionally and ignore the
    /// tail of the longer vector, so a fused query whose length differs from an
    /// entry embedding is compared only over the shared prefix.
    fn early_fusion_search(&self, query: &MultiModalQuery) -> Result<Vec<MultiModalResult>> {
        let target_modalities: Vec<Modality> = match &query.target_modalities {
            Some(targets) => targets.clone(),
            None => {
                // Fixed order so ties between equal scores are reported deterministically.
                let mut all: Vec<Modality> = self.dimensions.keys().copied().collect();
                all.sort_by_key(|modality| Self::modality_sort_key(*modality));
                all
            }
        };
        // The fused query does not depend on the entry or the target modality, so
        // it is built once here rather than inside the per-entry closure.
        let query_embedding = Self::fuse_query_embeddings(&query.embeddings, &query.fusion)?;
        let metadata_filter = query.filter.as_ref();

        let mut results: Vec<MultiModalResult> = self
            .entries
            .par_iter()
            .filter(|entry| Self::matches_filter(entry, metadata_filter))
            .flat_map(|entry| {
                target_modalities
                    .iter()
                    .filter_map(|&modality| {
                        let entry_embedding = entry.get_embedding(modality)?;
                        let distance = euclidean_distance(&query_embedding, entry_embedding);
                        Some(MultiModalResult {
                            id: entry.id.clone(),
                            score: 1.0 / (1.0 + distance),
                            distance,
                            modality,
                            metadata: entry.metadata.clone(),
                        })
                    })
                    .collect::<Vec<_>>()
            })
            .collect();

        Self::rank_results(&mut results, query.limit);
        Ok(results)
    }

    /// Scores each query modality against the same modality of every entry and
    /// sums the weighted scores per entry id.
    ///
    /// A modality missing from `weights` is weighted 1.0. The returned
    /// `distance`/`modality` come from the first query modality (in
    /// [`modality_sort_key`](Self::modality_sort_key) order) that matched the entry.
    fn late_fusion_search(
        &self,
        query: &MultiModalQuery,
        weights: &HashMap<Modality, f32>,
    ) -> Result<Vec<MultiModalResult>> {
        let metadata_filter = query.filter.as_ref();
        let mut by_id: HashMap<String, MultiModalResult> = HashMap::new();

        for (modality, query_embedding) in Self::ordered_embeddings(&query.embeddings) {
            let weight = weights.get(&modality).copied().unwrap_or(1.0);
            for entry in &self.entries {
                let entry_embedding = match entry.get_embedding(modality) {
                    Some(embedding) => embedding,
                    None => continue,
                };
                if !Self::matches_filter(entry, metadata_filter) {
                    continue;
                }
                let distance = euclidean_distance(query_embedding, entry_embedding);
                let score = (1.0 / (1.0 + distance)) * weight;
                by_id
                    .entry(entry.id.clone())
                    .and_modify(|result| result.score += score)
                    .or_insert_with(|| MultiModalResult {
                        id: entry.id.clone(),
                        score,
                        distance,
                        modality,
                        metadata: entry.metadata.clone(),
                    });
            }
        }

        let mut results: Vec<MultiModalResult> = by_id.into_values().collect();
        Self::rank_results(&mut results, query.limit);
        Ok(results)
    }

    /// Combines the query's embeddings into the single vector used by early fusion.
    ///
    /// Query modalities are visited in [`modality_sort_key`](Self::modality_sort_key)
    /// order: `Text`, `Image`, `Audio`, `Video`, `Code`, then `Custom` by tag. That
    /// is the order `WeightedAverage` weights are paired with modalities and the
    /// order `Concatenate` appends embeddings. A single query embedding is returned
    /// unchanged whatever the strategy.
    ///
    /// # Errors
    ///
    /// - the query has no embeddings;
    /// - `Average`, `WeightedAverage` or `Max` receive embeddings of differing lengths;
    /// - `WeightedAverage` has a different number of weights than embeddings;
    /// - the strategy is `LateFusion`, which `late_fusion_search` handles instead.
    fn fuse_query_embeddings(
        query_embeddings: &HashMap<Modality, Vec<f32>>,
        strategy: &MultiModalFusion,
    ) -> Result<Vec<f32>> {
        let ordered = Self::ordered_embeddings(query_embeddings);
        match ordered.as_slice() {
            [] => return Err(anyhow!("Query must contain at least one embedding")),
            [(_, only)] => return Ok(only.to_vec()),
            _ => {}
        }

        match strategy {
            MultiModalFusion::Average => {
                let mut fused = vec![0.0; Self::common_dimension(&ordered)?];
                for (_, embedding) in &ordered {
                    for (acc, &val) in fused.iter_mut().zip(embedding.iter()) {
                        *acc += val;
                    }
                }
                let n = ordered.len() as f32;
                for val in &mut fused {
                    *val /= n;
                }
                Ok(fused)
            }
            MultiModalFusion::WeightedAverage(weights) => {
                // A silent zip would drop modalities (or weights) on a count mismatch.
                if weights.len() != ordered.len() {
                    return Err(anyhow!(
                        "WeightedAverage expects one weight per query modality ({}), got {}",
                        ordered.len(),
                        weights.len()
                    ));
                }
                let mut fused = vec![0.0; Self::common_dimension(&ordered)?];
                for ((_, embedding), &weight) in ordered.iter().zip(weights.iter()) {
                    for (acc, &val) in fused.iter_mut().zip(embedding.iter()) {
                        *acc += val * weight;
                    }
                }
                Ok(fused)
            }
            MultiModalFusion::Concatenate => {
                let total: usize = ordered.iter().map(|(_, embedding)| embedding.len()).sum();
                let mut fused = Vec::with_capacity(total);
                for (_, embedding) in &ordered {
                    fused.extend_from_slice(embedding);
                }
                Ok(fused)
            }
            MultiModalFusion::Max => {
                let mut fused = vec![f32::NEG_INFINITY; Self::common_dimension(&ordered)?];
                for (_, embedding) in &ordered {
                    for (acc, &val) in fused.iter_mut().zip(embedding.iter()) {
                        *acc = (*acc).max(val);
                    }
                }
                Ok(fused)
            }
            MultiModalFusion::LateFusion { .. } => {
                Err(anyhow!("Late fusion not applicable in early fusion search"))
            }
        }
    }

    /// Returns the query embeddings sorted by [`modality_sort_key`](Self::modality_sort_key),
    /// so fusion does not depend on `HashMap` iteration order (which varies per map).
    fn ordered_embeddings(embeddings: &HashMap<Modality, Vec<f32>>) -> Vec<(Modality, &[f32])> {
        let mut ordered: Vec<(Modality, &[f32])> = embeddings
            .iter()
            .map(|(modality, embedding)| (*modality, embedding.as_slice()))
            .collect();
        ordered.sort_by_key(|(modality, _)| Self::modality_sort_key(*modality));
        ordered
    }

    /// Deterministic ordering key: declaration order of `Modality`, then the
    /// `Custom` tag.
    fn modality_sort_key(modality: Modality) -> (u8, u8) {
        match modality {
            Modality::Text => (0, 0),
            Modality::Image => (1, 0),
            Modality::Audio => (2, 0),
            Modality::Video => (3, 0),
            Modality::Code => (4, 0),
            Modality::Custom(tag) => (5, tag),
        }
    }

    /// Returns the length shared by all `ordered` embeddings. If any length differs
    /// from the first embedding's, returns an error naming that modality.
    fn common_dimension(ordered: &[(Modality, &[f32])]) -> Result<usize> {
        let dim = ordered.first().map_or(0, |(_, embedding)| embedding.len());
        match ordered.iter().find(|(_, embedding)| embedding.len() != dim) {
            Some((modality, embedding)) => Err(anyhow!(
                "Query embedding dimension mismatch for modality {:?}: expected {}, got {}",
                modality,
                dim,
                embedding.len()
            )),
            None => Ok(dim),
        }
    }

    /// Applies the query's exact-match metadata filter.
    ///
    /// With no filter every entry matches. With a filter, an entry matches only if
    /// it has metadata containing every filter key with an equal value. An entry
    /// without metadata never matches, even an empty filter.
    fn matches_filter(
        entry: &MultiModalEntry,
        filter: Option<&HashMap<String, serde_json::Value>>,
    ) -> bool {
        match (filter, &entry.metadata) {
            (None, _) => true,
            (Some(_), None) => false,
            (Some(filter), Some(meta)) => filter.iter().all(|(k, v)| meta.get(k) == Some(v)),
        }
    }

    /// Sorts results best-first and keeps at most `limit` of them.
    ///
    /// Uses a total order, so NaN scores (produced when an embedding contains NaN)
    /// cannot make the sort panic. NaN results sort after every numeric score.
    fn rank_results(results: &mut Vec<MultiModalResult>, limit: usize) {
        results.sort_by(|a, b| {
            a.score
                .is_nan()
                .cmp(&b.score.is_nan())
                .then_with(|| b.score.total_cmp(&a.score))
        });
        results.truncate(limit);
    }

    /// Removes the first entry whose id equals `id`.
    ///
    /// Returns `Ok(true)` if an entry was removed and `Ok(false)` otherwise. Because
    /// duplicate ids are allowed, each call removes at most one of them. Registered
    /// modality dimensions are kept even if no entry of that modality remains.
    pub fn remove(&mut self, id: &str) -> Result<bool> {
        match self.entries.iter().position(|e| e.id == id) {
            Some(pos) => {
                self.entries.remove(pos);
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Summarises the index: entry count, entries per modality, and registered
    /// dimensions.
    pub fn stats(&self) -> MultiModalStats {
        let mut modality_counts: HashMap<Modality, usize> = HashMap::new();
        for modality in self.entries.iter().flat_map(|entry| entry.embeddings.keys()) {
            *modality_counts.entry(*modality).or_insert(0) += 1;
        }
        MultiModalStats {
            total_entries: self.entries.len(),
            modality_counts,
            dimensions: self.dimensions.clone(),
        }
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if the index holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
impl Default for MultiModalIndex {
    /// Creates an empty index; equivalent to [`MultiModalIndex::new`].
    fn default() -> Self {
        // `new` only wraps an empty index in `Ok`, so this cannot fail.
        Self::new().expect("MultiModalIndex::new is infallible")
    }
}
/// Snapshot of an index's contents, as returned by `MultiModalIndex::stats`.
#[derive(Debug, Clone)]
pub struct MultiModalStats {
    /// Number of stored entries.
    pub total_entries: usize,
    /// For each modality, the number of entries that carry an embedding of it.
    pub modality_counts: HashMap<Modality, usize>,
    /// Registered embedding length per modality.
    pub dimensions: HashMap<Modality, usize>,
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired positionally and pairing stops at the shorter slice. When
/// the lengths differ, the extra trailing elements of the longer input are ignored
/// rather than reported.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An exact text match ranks the identical text entry first.
    #[test]
    fn test_multimodal_basic() {
        let mut index = MultiModalIndex::new().unwrap();
        index
            .add_with_modality("doc1", vec![0.1, 0.2, 0.3], Modality::Text, None)
            .unwrap();
        index
            .add_with_modality("img1", vec![0.4, 0.5, 0.6], Modality::Image, None)
            .unwrap();
        assert_eq!(index.len(), 2);
        let query = MultiModalQuery::new(vec![0.1, 0.2, 0.3], Modality::Text);
        let results = index.search(&query).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "doc1");
    }

    /// A text query restricted to the image modality only returns image hits.
    #[test]
    fn test_cross_modal_search() {
        let mut index = MultiModalIndex::new().unwrap();
        index
            .add_with_modality("doc1", vec![0.1; 64], Modality::Text, None)
            .unwrap();
        index
            .add_with_modality("img1", vec![0.2; 64], Modality::Image, None)
            .unwrap();
        index
            .add_with_modality("img2", vec![0.3; 64], Modality::Image, None)
            .unwrap();
        let query = MultiModalQuery::new(vec![0.25; 64], Modality::Text)
            .with_target_modality(Modality::Image)
            .with_limit(10);
        let results = index.search(&query).unwrap();
        assert!(!results.is_empty());
        for result in &results {
            assert_eq!(result.modality, Modality::Image);
        }
    }

    #[test]
    fn test_multimodal_entry() {
        let entry = MultiModalEntry::new("item1")
            .add_embedding(Modality::Text, vec![0.1, 0.2])
            .add_embedding(Modality::Image, vec![0.3, 0.4]);
        assert!(entry.has_modality(Modality::Text));
        assert!(entry.has_modality(Modality::Image));
        assert!(!entry.has_modality(Modality::Audio));
        assert_eq!(entry.get_embedding(Modality::Text), Some(&vec![0.1, 0.2]));
    }

    /// Late fusion sums weighted per-modality scores; the exact match wins.
    #[test]
    fn test_late_fusion() {
        let mut index = MultiModalIndex::new().unwrap();
        let entry1 = MultiModalEntry::new("item1")
            .add_embedding(Modality::Text, vec![0.1; 32])
            .add_embedding(Modality::Image, vec![0.2; 32]);
        let entry2 = MultiModalEntry::new("item2")
            .add_embedding(Modality::Text, vec![0.5; 32])
            .add_embedding(Modality::Image, vec![0.6; 32]);
        index.add_entry(entry1).unwrap();
        index.add_entry(entry2).unwrap();
        let mut query_embeddings = HashMap::new();
        query_embeddings.insert(Modality::Text, vec![0.1; 32]);
        query_embeddings.insert(Modality::Image, vec![0.2; 32]);
        let mut weights = HashMap::new();
        weights.insert(Modality::Text, 0.7);
        weights.insert(Modality::Image, 0.3);
        let query = MultiModalQuery::multi(query_embeddings)
            .with_fusion(MultiModalFusion::LateFusion { weights })
            .with_limit(10);
        let results = index.search(&query).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "item1");
    }

    #[test]
    fn test_metadata_filter() {
        let mut index = MultiModalIndex::new().unwrap();
        let mut meta1 = HashMap::new();
        meta1.insert("category".to_string(), serde_json::json!("tech"));
        let mut meta2 = HashMap::new();
        meta2.insert("category".to_string(), serde_json::json!("sports"));
        index
            .add_with_modality("doc1", vec![0.1; 32], Modality::Text, Some(meta1))
            .unwrap();
        index
            .add_with_modality("doc2", vec![0.2; 32], Modality::Text, Some(meta2))
            .unwrap();
        let mut filter = HashMap::new();
        filter.insert("category".to_string(), serde_json::json!("tech"));
        let query = MultiModalQuery::new(vec![0.15; 32], Modality::Text).with_filter(filter);
        let results = index.search(&query).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_remove() {
        let mut index = MultiModalIndex::new().unwrap();
        index
            .add_with_modality("doc1", vec![0.1; 32], Modality::Text, None)
            .unwrap();
        index
            .add_with_modality("doc2", vec![0.2; 32], Modality::Text, None)
            .unwrap();
        assert_eq!(index.len(), 2);
        let removed = index.remove("doc1").unwrap();
        assert!(removed);
        assert_eq!(index.len(), 1);
        let removed = index.remove("doc1").unwrap();
        assert!(!removed);
    }

    #[test]
    fn test_stats() {
        let mut index = MultiModalIndex::new().unwrap();
        index
            .add_with_modality("doc1", vec![0.1; 64], Modality::Text, None)
            .unwrap();
        index
            .add_with_modality("img1", vec![0.2; 128], Modality::Image, None)
            .unwrap();
        index
            .add_with_modality("img2", vec![0.3; 128], Modality::Image, None)
            .unwrap();
        let stats = index.stats();
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.modality_counts.get(&Modality::Text), Some(&1));
        assert_eq!(stats.modality_counts.get(&Modality::Image), Some(&2));
        assert_eq!(stats.dimensions.get(&Modality::Text), Some(&64));
        assert_eq!(stats.dimensions.get(&Modality::Image), Some(&128));
    }

    #[test]
    fn test_fusion_strategies() {
        let mut index = MultiModalIndex::new().unwrap();
        let entry = MultiModalEntry::new("item1")
            .add_embedding(Modality::Text, vec![0.5; 64])
            .add_embedding(Modality::Image, vec![0.5; 64]);
        index.add_entry(entry).unwrap();
        let mut query_embeddings = HashMap::new();
        query_embeddings.insert(Modality::Text, vec![0.6; 64]);
        query_embeddings.insert(Modality::Image, vec![0.4; 64]);
        let query =
            MultiModalQuery::multi(query_embeddings.clone()).with_fusion(MultiModalFusion::Average);
        let results = index.search(&query).unwrap();
        assert!(!results.is_empty());
        let weights = vec![0.7, 0.3];
        let query = MultiModalQuery::multi(query_embeddings)
            .with_fusion(MultiModalFusion::WeightedAverage(weights));
        let results = index.search(&query).unwrap();
        assert!(!results.is_empty());
    }

    /// Searching an empty index is not an error and yields no hits.
    #[test]
    fn test_search_on_empty_index_returns_nothing() {
        let index = MultiModalIndex::new().unwrap();
        let query = MultiModalQuery::new(vec![0.1; 4], Modality::Text);
        assert!(index.search(&query).unwrap().is_empty());
    }

    /// Results come back best-first and are cut to the query limit.
    #[test]
    fn test_results_sorted_and_limited() {
        let mut index = MultiModalIndex::new().unwrap();
        for i in 0..5 {
            index
                .add_with_modality(format!("doc{}", i), vec![i as f32; 8], Modality::Text, None)
                .unwrap();
        }
        let query = MultiModalQuery::new(vec![0.0; 8], Modality::Text).with_limit(2);
        let results = index.search(&query).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc0");
        assert_eq!(results[1].id, "doc1");
        assert!(results[0].score >= results[1].score);
    }

    /// Embeddings whose length disagrees with the registered dimension are rejected.
    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let mut index = MultiModalIndex::new().unwrap();
        index
            .add_with_modality("doc1", vec![0.1; 4], Modality::Text, None)
            .unwrap();
        assert!(index
            .add_with_modality("doc2", vec![0.1; 8], Modality::Text, None)
            .is_err());
        let bad_entry = MultiModalEntry::new("doc3").add_embedding(Modality::Text, vec![0.1; 2]);
        assert!(index.add_entry(bad_entry).is_err());
        assert_eq!(index.len(), 1);
    }

    /// The Max and Concatenate early-fusion strategies produce hits.
    #[test]
    fn test_max_and_concatenate_fusion() {
        let mut index = MultiModalIndex::new().unwrap();
        let entry = MultiModalEntry::new("item1")
            .add_embedding(Modality::Text, vec![0.5; 4])
            .add_embedding(Modality::Image, vec![0.5; 4]);
        index.add_entry(entry).unwrap();
        let mut query_embeddings = HashMap::new();
        query_embeddings.insert(Modality::Text, vec![0.6; 4]);
        query_embeddings.insert(Modality::Image, vec![0.4; 4]);

        let query =
            MultiModalQuery::multi(query_embeddings.clone()).with_fusion(MultiModalFusion::Max);
        let results = index.search(&query).unwrap();
        // One hit per registered target modality (Text and Image), both for `item1`.
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id == "item1"));

        let query =
            MultiModalQuery::multi(query_embeddings).with_fusion(MultiModalFusion::Concatenate);
        assert!(!index.search(&query).unwrap().is_empty());
    }
}