use crate::core::traits::{Model, Tokenizer};
use crate::error::Result;
use crate::pipeline::{BasePipeline, Device, Pipeline};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use trustformers_core::cache::CacheKeyBuilder;
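/// Configuration for the multi-modal pipeline: per-modality input limits,
/// the fusion strategy, attention settings, and sampling parameters
/// (`temperature`, `top_k`, `top_p`) for any generative head.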
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalConfig {
pub max_text_length: usize,
pub max_image_size: (usize, usize),
pub max_audio_duration: f64,
pub fusion_strategy: FusionStrategy,
pub normalize_inputs: bool,
pub attention_config: AttentionConfig,
pub cross_modal_attention: bool,
pub temperature: f32,
pub top_k: Option<usize>,
pub top_p: Option<f32>,
}
impl Default for MultiModalConfig {
fn default() -> Self {
Self {
max_text_length: 512,
max_image_size: (224, 224),
max_audio_duration: 30.0,
fusion_strategy: FusionStrategy::Concatenation,
normalize_inputs: true,
attention_config: AttentionConfig::default(),
cross_modal_attention: true,
temperature: 1.0,
top_k: None,
top_p: None,
}
}
}
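/// Strategy for combining per-modality feature sequences into a single
/// representation. `CrossAttention`, `GatedFusion`, and `TransformerFusion`
/// currently fall back to the simpler strategies inside [`FusionLayer`].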
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
Concatenation,
Addition,
WeightedAverage,
CrossAttention,
GatedFusion,
TransformerFusion,
}
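/// Hyperparameters for the attention computation used in cross-modal fusion.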
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
pub num_heads: usize,
pub head_dim: usize,
pub dropout: f32,
pub use_relative_position: bool,
pub max_relative_position: i32,
}
impl Default for AttentionConfig {
fn default() -> Self {
Self {
num_heads: 8,
head_dim: 64,
dropout: 0.1,
use_relative_position: true,
max_relative_position: 128,
}
}
}
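/// A single multi-modal request. Every modality is optional; at least one
/// should be provided for the pipeline to produce useful features.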
#[derive(Debug, Clone)]
pub struct MultiModalInput {
pub text: Option<String>,
pub image: Option<Vec<u8>>,
pub audio: Option<Vec<u8>>,
pub video: Option<Vec<u8>>,
pub metadata: HashMap<String, String>,
pub modality_weights: Option<HashMap<String, f32>>,
}
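/// Feature sequences extracted per modality, plus their dimensions and
/// attention masks keyed by modality name ("text", "image", "audio", "video").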
#[derive(Debug, Clone)]
pub struct ModalityFeatures {
pub text_features: Option<Vec<Vec<f32>>>,
pub image_features: Option<Vec<Vec<f32>>>,
pub audio_features: Option<Vec<Vec<f32>>>,
pub video_features: Option<Vec<Vec<f32>>>,
pub feature_dims: HashMap<String, usize>,
pub attention_masks: HashMap<String, Vec<bool>>,
}
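/// Pipeline output: optional per-modality artifacts, classification results,
/// attention maps, cross-modal similarity scores, and processing metadata.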
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalOutput {
pub text: Option<String>,
pub image: Option<Vec<u8>>,
pub audio: Option<Vec<u8>>,
pub classifications: Option<Vec<ClassificationResult>>,
pub attention_weights: Option<AttentionWeights>,
pub cross_modal_similarities: Option<HashMap<String, f32>>,
pub metadata: ProcessingMetadata,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationResult {
pub label: String,
pub score: f32,
pub modality_contributions: HashMap<String, f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionWeights {
pub text_to_image: Option<Vec<Vec<f32>>>,
pub image_to_text: Option<Vec<Vec<f32>>>,
pub audio_to_text: Option<Vec<Vec<f32>>>,
pub cross_modal_attention: Option<Vec<Vec<f32>>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessingMetadata {
pub processing_time_ms: u64,
pub modalities_used: Vec<String>,
pub fusion_strategy_used: String,
pub model_confidence: f32,
pub feature_extraction_time_ms: HashMap<String, u64>,
}
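/// A pipeline that extracts features from text, image, audio, and video
/// inputs, fuses them with the configured [`FusionStrategy`], and optionally
/// computes cross-modal attention between modality pairs.
///
/// A minimal usage sketch, assuming `MyModel` and `MyTokenizer` are types
/// implementing this crate's `Model` and `Tokenizer` traits (hypothetical
/// names for illustration):
///
/// ```ignore
/// let pipeline = MultiModalPipeline::new(MyModel::default(), MyTokenizer::default())?
///     .with_fusion_strategy(FusionStrategy::CrossAttention)
///     .with_cross_modal_attention(true);
/// ```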
pub struct MultiModalPipeline<M, T> {
base: BasePipeline<M, T>,
config: MultiModalConfig,
text_processor: Arc<TextProcessor>,
image_processor: Arc<ImageProcessor>,
audio_processor: Arc<AudioProcessor>,
video_processor: Arc<VideoProcessor>,
fusion_layer: Arc<FusionLayer>,
}
impl<M, T> MultiModalPipeline<M, T>
where
M: Model + Send + Sync + 'static,
T: Tokenizer + Send + Sync + 'static,
{
pub fn new(model: M, tokenizer: T) -> Result<Self> {
Ok(Self {
base: BasePipeline::new(model, tokenizer),
config: MultiModalConfig::default(),
text_processor: Arc::new(TextProcessor::new()),
image_processor: Arc::new(ImageProcessor::new()),
audio_processor: Arc::new(AudioProcessor::new()),
video_processor: Arc::new(VideoProcessor::new()),
fusion_layer: Arc::new(FusionLayer::new()),
})
}
pub fn with_config(mut self, config: MultiModalConfig) -> Self {
self.config = config;
self
}
pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
self.config.fusion_strategy = strategy;
self
}
pub fn with_cross_modal_attention(mut self, enabled: bool) -> Self {
self.config.cross_modal_attention = enabled;
self
}
pub fn to_device(mut self, device: Device) -> Self {
self.base = self.base.to_device(device);
self
}
pub fn process_multimodal(&self, input: &MultiModalInput) -> Result<ModalityFeatures> {
let mut features = ModalityFeatures {
text_features: None,
image_features: None,
audio_features: None,
video_features: None,
feature_dims: HashMap::new(),
attention_masks: HashMap::new(),
};
        if let Some(text) = &input.text {
            let text_features = self.text_processor.process(text, &self.config)?;
            // `first()` guards against empty feature sequences (e.g. empty
            // text), which would otherwise panic on `[0]`.
            features
                .feature_dims
                .insert("text".to_string(), text_features.first().map_or(0, Vec::len));
            features
                .attention_masks
                .insert("text".to_string(), vec![true; text_features.len()]);
            features.text_features = Some(text_features);
        }
        if let Some(image) = &input.image {
            let image_features = self.image_processor.process(image, &self.config)?;
            features
                .feature_dims
                .insert("image".to_string(), image_features.first().map_or(0, Vec::len));
            features
                .attention_masks
                .insert("image".to_string(), vec![true; image_features.len()]);
            features.image_features = Some(image_features);
        }
        if let Some(audio) = &input.audio {
            let audio_features = self.audio_processor.process(audio, &self.config)?;
            features
                .feature_dims
                .insert("audio".to_string(), audio_features.first().map_or(0, Vec::len));
            features
                .attention_masks
                .insert("audio".to_string(), vec![true; audio_features.len()]);
            features.audio_features = Some(audio_features);
        }
        if let Some(video) = &input.video {
            let video_features = self.video_processor.process(video, &self.config)?;
            features
                .feature_dims
                .insert("video".to_string(), video_features.first().map_or(0, Vec::len));
            features
                .attention_masks
                .insert("video".to_string(), vec![true; video_features.len()]);
            features.video_features = Some(video_features);
        }
Ok(features)
}
pub fn fuse_features(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
self.fusion_layer.fuse(features, &self.config)
}
pub fn compute_cross_modal_attention(
&self,
features: &ModalityFeatures,
) -> Result<AttentionWeights> {
let mut attention_weights = AttentionWeights {
text_to_image: None,
image_to_text: None,
audio_to_text: None,
cross_modal_attention: None,
};
if let (Some(text_features), Some(image_features)) =
(&features.text_features, &features.image_features)
{
attention_weights.text_to_image =
Some(self.compute_attention_weights(text_features, image_features)?);
attention_weights.image_to_text =
Some(self.compute_attention_weights(image_features, text_features)?);
}
if let (Some(audio_features), Some(text_features)) =
(&features.audio_features, &features.text_features)
{
attention_weights.audio_to_text =
Some(self.compute_attention_weights(audio_features, text_features)?);
}
Ok(attention_weights)
}
    fn compute_attention_weights(
        &self,
        query_features: &[Vec<f32>],
        key_features: &[Vec<f32>],
    ) -> Result<Vec<Vec<f32>>> {
        let mut attention_weights = Vec::new();
        for query in query_features {
            // Scaled dot-product scores: q . k / sqrt(d).
            let scores: Vec<f32> = key_features
                .iter()
                .map(|key| {
                    let dot: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
                    dot / (query.len() as f32).sqrt()
                })
                .collect();
            // Softmax with max subtraction so large scores cannot overflow exp().
            let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut query_weights: Vec<f32> =
                scores.iter().map(|s| (s - max_score).exp()).collect();
            let sum: f32 = query_weights.iter().sum();
            if sum > 0.0 {
                query_weights.iter_mut().for_each(|w| *w /= sum);
            }
            attention_weights.push(query_weights);
        }
        Ok(attention_weights)
    }
    fn compute_cross_modal_similarities(
        &self,
        features: &ModalityFeatures,
    ) -> HashMap<String, f32> {
        let mut similarities = HashMap::new();
        // Compare the first feature vector of each available modality pair;
        // `and_then(first)` skips modalities that produced no features.
        let text = features.text_features.as_ref().and_then(|f| f.first());
        let image = features.image_features.as_ref().and_then(|f| f.first());
        let audio = features.audio_features.as_ref().and_then(|f| f.first());
        if let (Some(text), Some(image)) = (text, image) {
            similarities
                .insert("text_image".to_string(), self.compute_feature_similarity(text, image));
        }
        if let (Some(text), Some(audio)) = (text, audio) {
            similarities
                .insert("text_audio".to_string(), self.compute_feature_similarity(text, audio));
        }
        if let (Some(image), Some(audio)) = (image, audio) {
            similarities
                .insert("image_audio".to_string(), self.compute_feature_similarity(image, audio));
        }
        similarities
    }
fn compute_feature_similarity(&self, features1: &[f32], features2: &[f32]) -> f32 {
let min_len = features1.len().min(features2.len());
let dot_product: f32 = features1[..min_len]
.iter()
.zip(features2[..min_len].iter())
.map(|(a, b)| a * b)
.sum();
let norm1: f32 = features1[..min_len].iter().map(|x| x * x).sum::<f32>().sqrt();
let norm2: f32 = features2[..min_len].iter().map(|x| x * x).sum::<f32>().sqrt();
if norm1 > 0.0 && norm2 > 0.0 {
dot_product / (norm1 * norm2)
} else {
0.0
}
}
}
impl<M, T> Pipeline for MultiModalPipeline<M, T>
where
M: Model + Send + Sync + 'static,
T: Tokenizer + Send + Sync + 'static,
{
type Input = MultiModalInput;
type Output = MultiModalOutput;
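    /// Runs the full pipeline: cache lookup, per-modality feature extraction,
    /// fusion, optional cross-modal attention, and output assembly. The
    /// dunder name follows this `Pipeline` trait's Python-style `__call__`
    /// convention.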
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
let start_time = std::time::Instant::now();
let mut feature_extraction_times = HashMap::new();
let cache_key = if let Some(cache) = &self.base.cache {
let mut builder = CacheKeyBuilder::new("multimodal", "inference");
if let Some(text) = &input.text {
builder = builder.with_text(text);
}
if let Some(image) = &input.image {
builder = builder.with_param("image", image);
}
            if let Some(audio) = &input.audio {
                builder = builder.with_param("audio", audio);
            }
            // Video participates in the cache key too, so that inputs
            // differing only in video bytes do not collide.
            if let Some(video) = &input.video {
                builder = builder.with_param("video", video);
            }
builder = builder.with_param(
"config",
&serde_json::to_string(&self.config).unwrap_or_default(),
);
let key = builder.build();
if let Some(cached) = cache.get(&key) {
if let Ok(output) = serde_json::from_slice::<MultiModalOutput>(&cached) {
return Ok(output);
}
}
Some(key)
} else {
None
};
let feature_start = std::time::Instant::now();
let features = self.process_multimodal(&input)?;
let feature_time = feature_start.elapsed().as_millis() as u64;
        // The processors run sequentially inside `process_multimodal`, so the
        // measured time is split evenly across the modalities actually present.
        let mut active: Vec<&str> = Vec::new();
        if input.text.is_some() {
            active.push("text");
        }
        if input.image.is_some() {
            active.push("image");
        }
        if input.audio.is_some() {
            active.push("audio");
        }
        if input.video.is_some() {
            active.push("video");
        }
        let per_modality = feature_time / active.len().max(1) as u64;
        for name in &active {
            feature_extraction_times.insert((*name).to_string(), per_modality);
        }
let _fused_features = self.fuse_features(&features)?;
let attention_weights = if self.config.cross_modal_attention {
Some(self.compute_cross_modal_attention(&features)?)
} else {
None
};
let cross_modal_similarities = Some(self.compute_cross_modal_similarities(&features));
        let modalities_used: Vec<String> = active.iter().map(|m| m.to_string()).collect();
        // NOTE: the generated text and classification below are placeholders;
        // the fused features are not yet decoded by a real model head.
        let output = MultiModalOutput {
            text: input.text.clone().map(|t| format!("Processed: {}", t)),
            image: None,
            audio: None,
            classifications: Some(vec![ClassificationResult {
                label: "positive".to_string(),
                score: 0.85,
                modality_contributions: [("text".to_string(), 0.4), ("image".to_string(), 0.6)]
                    .into_iter()
                    .collect(),
            }]),
attention_weights,
cross_modal_similarities,
metadata: ProcessingMetadata {
processing_time_ms: start_time.elapsed().as_millis() as u64,
modalities_used,
fusion_strategy_used: format!("{:?}", self.config.fusion_strategy),
model_confidence: 0.85,
feature_extraction_time_ms: feature_extraction_times,
},
};
if let (Some(cache), Some(key)) = (&self.base.cache, cache_key) {
if let Ok(serialized) = serde_json::to_vec(&output) {
cache.insert(key, serialized);
}
}
Ok(output)
}
}
pub struct TextProcessor;
impl Default for TextProcessor {
fn default() -> Self {
Self::new()
}
}
impl TextProcessor {
pub fn new() -> Self {
Self
}
    pub fn process(&self, text: &str, config: &MultiModalConfig) -> Result<Vec<Vec<f32>>> {
        // Placeholder embedding: whitespace tokenization plus deterministic
        // sinusoidal vectors that depend only on token position, not content.
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let max_tokens = config.max_text_length.min(tokens.len());
        let mut features = Vec::new();
        for i in 0..max_tokens {
            let embedding: Vec<f32> =
                (0..768).map(|j| ((i * 768 + j) as f32).sin() * 0.1).collect();
            features.push(embedding);
        }
        Ok(features)
    }
}
pub struct ImageProcessor;
impl Default for ImageProcessor {
fn default() -> Self {
Self::new()
}
}
impl ImageProcessor {
pub fn new() -> Self {
Self
}
    pub fn process(&self, _image: &[u8], config: &MultiModalConfig) -> Result<Vec<Vec<f32>>> {
        // Placeholder ViT-style featurizer: the patch count is derived from
        // the configured maximum image size, not from the bytes passed in.
        let patch_size = 16;
        let (width, height) = config.max_image_size;
        let num_patches = (width / patch_size) * (height / patch_size);
let mut features = Vec::new();
for i in 0..num_patches {
let embedding: Vec<f32> =
(0..768).map(|j| ((i * 768 + j) as f32).cos() * 0.1).collect();
features.push(embedding);
}
Ok(features)
}
}
pub struct AudioProcessor;
impl Default for AudioProcessor {
fn default() -> Self {
Self::new()
}
}
impl AudioProcessor {
pub fn new() -> Self {
Self
}
    pub fn process(&self, _audio: &[u8], config: &MultiModalConfig) -> Result<Vec<Vec<f32>>> {
        // Placeholder featurizer: the frame count follows from the configured
        // maximum duration at an assumed 16 kHz sample rate and 512-sample hop.
        let sample_rate = 16000;
        let hop_length = 512;
        let num_frames =
            ((config.max_audio_duration * sample_rate as f64) / hop_length as f64) as usize;
let mut features = Vec::new();
for i in 0..num_frames {
let embedding: Vec<f32> =
(0..128).map(|j| ((i * 128 + j) as f32).sin() * 0.2).collect();
features.push(embedding);
}
Ok(features)
}
}
pub struct VideoProcessor;
impl Default for VideoProcessor {
fn default() -> Self {
Self::new()
}
}
impl VideoProcessor {
pub fn new() -> Self {
Self
}
    pub fn process(&self, _video: &[u8], config: &MultiModalConfig) -> Result<Vec<Vec<f32>>> {
        // The config has no dedicated video-duration field, so the audio
        // duration cap doubles as the clip-length bound at an assumed 30 fps.
        let frames_per_second = 30;
        let max_frames = (config.max_audio_duration * frames_per_second as f64) as usize;
let mut features = Vec::new();
for i in 0..max_frames {
let embedding: Vec<f32> =
(0..512).map(|j| ((i * 512 + j) as f32).cos() * 0.15).collect();
features.push(embedding);
}
Ok(features)
}
}
pub struct FusionLayer;
impl Default for FusionLayer {
fn default() -> Self {
Self::new()
}
}
impl FusionLayer {
pub fn new() -> Self {
Self
}
pub fn fuse(
&self,
features: &ModalityFeatures,
config: &MultiModalConfig,
) -> Result<Vec<Vec<f32>>> {
match config.fusion_strategy {
FusionStrategy::Concatenation => self.concatenate_features(features),
FusionStrategy::Addition => self.add_features(features),
FusionStrategy::WeightedAverage => self.weighted_average_features(features),
FusionStrategy::CrossAttention => self.cross_attention_fusion(features),
FusionStrategy::GatedFusion => self.gated_fusion(features),
FusionStrategy::TransformerFusion => self.transformer_fusion(features),
}
}
fn concatenate_features(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
let mut fused_features = Vec::new();
let max_len = [
features.text_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.image_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.audio_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.video_features.as_ref().map(|f| f.len()).unwrap_or(0),
]
.into_iter()
.max()
.unwrap_or(0);
for i in 0..max_len {
let mut combined_feature = Vec::new();
if let Some(text_features) = &features.text_features {
if i < text_features.len() {
combined_feature.extend_from_slice(&text_features[i]);
}
}
if let Some(image_features) = &features.image_features {
if i < image_features.len() {
combined_feature.extend_from_slice(&image_features[i]);
}
}
if let Some(audio_features) = &features.audio_features {
if i < audio_features.len() {
combined_feature.extend_from_slice(&audio_features[i]);
}
}
if let Some(video_features) = &features.video_features {
if i < video_features.len() {
combined_feature.extend_from_slice(&video_features[i]);
}
}
if !combined_feature.is_empty() {
fused_features.push(combined_feature);
}
}
Ok(fused_features)
}
    fn add_features(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
        // Element-wise mean across modalities that reach `common_dim`. Audio
        // (128-dim) and video (512-dim) features fall short and are skipped.
        let mut fused_features = Vec::new();
        let common_dim = 768;
let max_len = [
features.text_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.image_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.audio_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.video_features.as_ref().map(|f| f.len()).unwrap_or(0),
]
.into_iter()
.max()
.unwrap_or(0);
for i in 0..max_len {
let mut combined_feature = vec![0.0; common_dim];
let mut count = 0;
if let Some(text_features) = &features.text_features {
if i < text_features.len() && text_features[i].len() >= common_dim {
for j in 0..common_dim {
combined_feature[j] += text_features[i][j];
}
count += 1;
}
}
if let Some(image_features) = &features.image_features {
if i < image_features.len() && image_features[i].len() >= common_dim {
for j in 0..common_dim {
combined_feature[j] += image_features[i][j];
}
count += 1;
}
}
if count > 0 {
combined_feature.iter_mut().for_each(|x| *x /= count as f32);
fused_features.push(combined_feature);
}
}
Ok(fused_features)
}
    fn weighted_average_features(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
        // Fixed blend weights. Audio and video features never reach
        // `common_dim` in this module, so only text and image contribute; the
        // unused weights stay underscore-prefixed until projections exist.
        let text_weight = 0.4;
        let image_weight = 0.6;
        let _audio_weight = 0.3;
        let _video_weight = 0.2;
        let mut fused_features = Vec::new();
        let common_dim = 768;
let max_len = [
features.text_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.image_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.audio_features.as_ref().map(|f| f.len()).unwrap_or(0),
features.video_features.as_ref().map(|f| f.len()).unwrap_or(0),
]
.into_iter()
.max()
.unwrap_or(0);
for i in 0..max_len {
let mut combined_feature = vec![0.0; common_dim];
let mut total_weight = 0.0;
if let Some(text_features) = &features.text_features {
if i < text_features.len() && text_features[i].len() >= common_dim {
for j in 0..common_dim {
combined_feature[j] += text_features[i][j] * text_weight;
}
total_weight += text_weight;
}
}
if let Some(image_features) = &features.image_features {
if i < image_features.len() && image_features[i].len() >= common_dim {
for j in 0..common_dim {
combined_feature[j] += image_features[i][j] * image_weight;
}
total_weight += image_weight;
}
}
if total_weight > 0.0 {
combined_feature.iter_mut().for_each(|x| *x /= total_weight);
fused_features.push(combined_feature);
}
}
Ok(fused_features)
}
    fn cross_attention_fusion(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
        // TODO: real cross-attention fusion; falls back to concatenation.
        self.concatenate_features(features)
    }
    fn gated_fusion(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
        // TODO: learned gating; falls back to a fixed weighted average.
        self.weighted_average_features(features)
    }
    fn transformer_fusion(&self, features: &ModalityFeatures) -> Result<Vec<Vec<f32>>> {
        // TODO: transformer-based fusion; falls back to concatenation.
        self.concatenate_features(features)
    }
}
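/// Convenience constructor for a [`MultiModalPipeline`] with the default
/// configuration.
///
/// A minimal end-to-end sketch, assuming `MyModel` and `MyTokenizer`
/// implement the required traits and `image_bytes` holds an encoded image
/// (hypothetical names for illustration):
///
/// ```ignore
/// let pipeline = multimodal_pipeline(MyModel::default(), MyTokenizer::default())?;
/// let output = pipeline.__call__(MultiModalInput {
///     text: Some("Describe this image".to_string()),
///     image: Some(image_bytes),
///     audio: None,
///     video: None,
///     metadata: HashMap::new(),
///     modality_weights: None,
/// })?;
/// ```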
pub fn multimodal_pipeline<M, T>(model: M, tokenizer: T) -> Result<MultiModalPipeline<M, T>>
where
M: Model + Send + Sync + 'static,
T: Tokenizer + Send + Sync + 'static,
{
MultiModalPipeline::new(model, tokenizer)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default_values() {
let cfg = MultiModalConfig::default();
assert_eq!(cfg.max_text_length, 512);
assert_eq!(cfg.max_image_size, (224, 224));
assert!((cfg.max_audio_duration - 30.0).abs() < 1e-6);
assert!(cfg.normalize_inputs);
assert!(cfg.cross_modal_attention);
assert!((cfg.temperature - 1.0).abs() < 1e-6);
}
#[test]
fn test_config_clone() {
let cfg = MultiModalConfig {
max_text_length: 256,
..MultiModalConfig::default()
};
assert_eq!(cfg.clone().max_text_length, 256);
}
#[test]
fn test_attention_config_default() {
let acfg = AttentionConfig::default();
assert_eq!(acfg.num_heads, 8);
assert_eq!(acfg.head_dim, 64);
assert!((acfg.dropout - 0.1).abs() < 1e-6);
assert!(acfg.use_relative_position);
assert_eq!(acfg.max_relative_position, 128);
}
#[test]
fn test_input_text_only() {
let input = MultiModalInput {
text: Some("Hello world".to_string()),
image: None,
audio: None,
video: None,
metadata: HashMap::new(),
modality_weights: None,
};
assert!(input.text.is_some());
assert!(input.image.is_none());
}
#[test]
fn test_input_image_plus_text() {
let input = MultiModalInput {
text: Some("Describe this image".to_string()),
image: Some(vec![0u8; 100]),
audio: None,
video: None,
metadata: HashMap::new(),
modality_weights: None,
};
assert!(input.text.is_some());
assert!(input.image.is_some());
}
#[test]
fn test_input_multimodality_flags() {
let input = MultiModalInput {
text: Some("text".to_string()),
image: Some(vec![1, 2, 3]),
audio: Some(vec![4, 5, 6]),
video: None,
metadata: HashMap::new(),
modality_weights: None,
};
let mut modalities = Vec::new();
if input.text.is_some() {
modalities.push("text");
}
if input.image.is_some() {
modalities.push("image");
}
if input.audio.is_some() {
modalities.push("audio");
}
if input.video.is_some() {
modalities.push("video");
}
assert_eq!(modalities.len(), 3);
}
#[test]
fn test_text_processor_produces_features() {
let processor = TextProcessor::new();
let cfg = MultiModalConfig::default();
let features =
processor.process("Hello world test", &cfg).expect("text processing succeeded");
assert_eq!(features.len(), 3);
        assert_eq!(features[0].len(), 768);
    }
#[test]
fn test_text_processor_respects_max_length() {
let processor = TextProcessor::new();
let cfg = MultiModalConfig {
max_text_length: 2,
..MultiModalConfig::default()
};
let text = "one two three four five";
let features = processor.process(text, &cfg).expect("text processing succeeded");
assert!(features.len() <= 2);
}
#[test]
fn test_text_processor_empty_text() {
let processor = TextProcessor::new();
let cfg = MultiModalConfig::default();
let features = processor.process("", &cfg).expect("empty text processing succeeded");
assert!(features.is_empty());
}
#[test]
fn test_image_processor_produces_patch_features() {
let processor = ImageProcessor::new();
let cfg = MultiModalConfig::default();
let dummy_image = vec![0u8; 224 * 224 * 3];
let features = processor.process(&dummy_image, &cfg).expect("image processing succeeded");
assert_eq!(features.len(), 196);
assert_eq!(features[0].len(), 768);
}
#[test]
fn test_image_processor_feature_dimensionality() {
let processor = ImageProcessor::new();
let cfg = MultiModalConfig {
max_image_size: (32, 32),
..MultiModalConfig::default()
};
let dummy = vec![0u8; 32 * 32 * 3];
let features = processor.process(&dummy, &cfg).expect("ok");
assert_eq!(features.len(), 4);
}
#[test]
fn test_audio_processor_produces_frames() {
let processor = AudioProcessor::new();
let cfg = MultiModalConfig {
max_audio_duration: 1.0,
..MultiModalConfig::default()
};
let dummy_audio = vec![0u8; 16000];
let features = processor.process(&dummy_audio, &cfg).expect("audio processing succeeded");
assert!(!features.is_empty());
        assert_eq!(features[0].len(), 128);
    }
#[test]
fn test_fusion_concatenation_non_empty() {
let fusion = FusionLayer::new();
let cfg = MultiModalConfig {
fusion_strategy: FusionStrategy::Concatenation,
..MultiModalConfig::default()
};
let features = ModalityFeatures {
text_features: Some(vec![vec![0.1; 768]; 3]),
image_features: None,
audio_features: None,
video_features: None,
feature_dims: HashMap::new(),
attention_masks: HashMap::new(),
};
let fused = fusion.fuse(&features, &cfg).expect("fusion succeeded");
assert!(!fused.is_empty());
}
#[test]
fn test_fusion_addition_with_two_modalities() {
let fusion = FusionLayer::new();
let cfg = MultiModalConfig {
fusion_strategy: FusionStrategy::Addition,
..MultiModalConfig::default()
};
let features = ModalityFeatures {
text_features: Some(vec![vec![1.0; 768]]),
image_features: Some(vec![vec![2.0; 768]]),
audio_features: None,
video_features: None,
feature_dims: HashMap::new(),
attention_masks: HashMap::new(),
};
let fused = fusion.fuse(&features, &cfg).expect("fusion succeeded");
assert_eq!(fused.len(), 1);
assert!(
(fused[0][0] - 1.5).abs() < 1e-4,
"expected 1.5, got {}",
fused[0][0]
);
}
#[test]
fn test_fusion_weighted_average() {
let fusion = FusionLayer::new();
let cfg = MultiModalConfig {
fusion_strategy: FusionStrategy::WeightedAverage,
..MultiModalConfig::default()
};
let features = ModalityFeatures {
text_features: Some(vec![vec![1.0; 768]]),
image_features: Some(vec![vec![1.0; 768]]),
audio_features: None,
video_features: None,
feature_dims: HashMap::new(),
attention_masks: HashMap::new(),
};
let fused = fusion.fuse(&features, &cfg).expect("fusion succeeded");
assert!(!fused.is_empty());
}
#[test]
fn test_attention_weights_normalised() {
let query = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
let key = vec![
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
];
let mut attention_weights = Vec::new();
for q in &query {
let mut q_weights = Vec::new();
for k in &key {
let dot: f32 = q.iter().zip(k.iter()).map(|(a, b)| a * b).sum();
let score = (dot / (q.len() as f32).sqrt()).exp();
q_weights.push(score);
}
let sum: f32 = q_weights.iter().sum();
if sum > 0.0 {
q_weights.iter_mut().for_each(|w| *w /= sum);
}
attention_weights.push(q_weights);
}
for row in &attention_weights {
let sum: f32 = row.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "row sum = {}", sum);
}
}
#[test]
fn test_cross_modal_similarity_range() {
let a = [1.0_f32, 0.0, 0.0];
let b = [0.0_f32, 1.0, 0.0];
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
let sim = if na > 0.0 && nb > 0.0 { dot / (na * nb) } else { 0.0 };
assert!((-1.0..=1.0).contains(&sim), "sim = {}", sim);
}
#[test]
fn test_classification_result_score_in_range() {
let result = ClassificationResult {
label: "positive".to_string(),
score: 0.85,
modality_contributions: HashMap::new(),
};
assert!(result.score >= 0.0 && result.score <= 1.0);
}
#[test]
fn test_processing_metadata_modalities_list() {
let meta = ProcessingMetadata {
processing_time_ms: 42,
modalities_used: vec!["text".to_string(), "image".to_string()],
fusion_strategy_used: "Concatenation".to_string(),
model_confidence: 0.85,
feature_extraction_time_ms: HashMap::new(),
};
assert_eq!(meta.modalities_used.len(), 2);
assert!(meta.modalities_used.contains(&"text".to_string()));
}
}