mlmf/multimodal.rs

//! Multi-Modal Model Support
//!
//! This module provides support for multi-modal models that process different
//! types of input data (text, images, audio, video) and coordinate cross-modal
//! attention and fusion.

use crate::error::Result;
use candle_core::{DType, Device, Tensor};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Supported modalities for multi-modal models
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Modality {
    /// Text input (tokens, embeddings)
    Text,
    /// Image input (pixels, features)
    Image,
    /// Audio input (waveforms, spectrograms)
    Audio,
    /// Video input (frames, temporal features)
    Video,
    /// Custom modality with identifier
    Custom(u32),
}

impl Modality {
    /// Get the string representation of the modality
    pub fn as_str(&self) -> &'static str {
        match self {
            Modality::Text => "text",
            Modality::Image => "image",
            Modality::Audio => "audio",
            Modality::Video => "video",
            Modality::Custom(_) => "custom",
        }
    }

    /// Get the default embedding dimension for this modality
    pub fn default_embedding_dim(&self) -> Option<usize> {
        match self {
            Modality::Text => Some(768),   // BERT-base dimension
            Modality::Image => Some(2048), // ResNet-50 dimension
            Modality::Audio => Some(512),  // Audio transformer dimension
            Modality::Video => Some(1024), // Video transformer dimension
            Modality::Custom(_) => None,
        }
    }
}

/// Configuration for multi-modal processing
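///
/// # Example
///
/// A minimal sketch of adding an audio modality to the default configuration.
/// It assumes this module is publicly exported as `mlmf::multimodal`; the audio
/// settings (16 kHz sample rate, 400/160 frame and hop lengths, 80 mel bins) and
/// the 1,500-step sequence limit are illustrative placeholders.
///
/// ```
/// use mlmf::multimodal::{Modality, ModalityConfig, MultiModalConfig, PreprocessingConfig};
///
/// let mut config = MultiModalConfig::default();
/// config.modalities.insert(
///     Modality::Audio,
///     ModalityConfig {
///         preprocessing: PreprocessingConfig::Audio {
///             sample_rate: 16_000,
///             frame_length: 400,
///             hop_length: 160,
///             n_mels: 80,
///         },
///         embedding_dim: 512,
///         requires_special_attention: false,
///         device_placement: None,
///     },
/// );
/// config.max_sequence_lengths.insert(Modality::Audio, 1_500);
/// assert!(config.modalities.contains_key(&Modality::Audio));
/// ```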
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalConfig {
    /// Supported modalities and their configurations
    pub modalities: HashMap<Modality, ModalityConfig>,
    /// Cross-modal attention configuration
    pub cross_modal_attention: CrossModalAttentionConfig,
    /// Fusion strategy for combining modalities
    pub fusion_strategy: FusionStrategy,
    /// Maximum sequence length per modality
    pub max_sequence_lengths: HashMap<Modality, usize>,
    /// Whether to use distributed processing
    pub distributed: bool,
}

impl Default for MultiModalConfig {
    fn default() -> Self {
        let mut modalities = HashMap::new();
        modalities.insert(Modality::Text, ModalityConfig::default_text());
        modalities.insert(Modality::Image, ModalityConfig::default_image());

        let mut max_lengths = HashMap::new();
        max_lengths.insert(Modality::Text, 512);
        max_lengths.insert(Modality::Image, 196); // 14x14 patches

        Self {
            modalities,
            cross_modal_attention: CrossModalAttentionConfig::default(),
            fusion_strategy: FusionStrategy::EarlyFusion,
            max_sequence_lengths: max_lengths,
            distributed: false,
        }
    }
}

/// Configuration for a specific modality
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityConfig {
    /// Input preprocessing configuration
    pub preprocessing: PreprocessingConfig,
    /// Embedding dimension
    pub embedding_dim: usize,
    /// Whether this modality requires special attention
    pub requires_special_attention: bool,
    /// Device placement for this modality (not serialized)
    #[serde(skip)]
    pub device_placement: Option<Device>,
}

impl ModalityConfig {
    /// Default configuration for text modality
    pub fn default_text() -> Self {
        Self {
            preprocessing: PreprocessingConfig::Text {
                tokenizer_path: None,
                max_length: 512,
                padding: true,
                truncation: true,
            },
            embedding_dim: 768,
            requires_special_attention: false,
            device_placement: None,
        }
    }

    /// Default configuration for image modality
    pub fn default_image() -> Self {
        Self {
            preprocessing: PreprocessingConfig::Image {
                resize: Some((224, 224)),
                normalize: true,
                patch_size: 16,
            },
            embedding_dim: 2048,
            requires_special_attention: true,
            device_placement: None,
        }
    }
}

/// Preprocessing configuration for different modalities
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PreprocessingConfig {
    /// Text preprocessing
    Text {
        tokenizer_path: Option<String>,
        max_length: usize,
        padding: bool,
        truncation: bool,
    },
    /// Image preprocessing
    Image {
        resize: Option<(u32, u32)>,
        normalize: bool,
        patch_size: u32,
    },
    /// Audio preprocessing
    Audio {
        sample_rate: u32,
        frame_length: usize,
        hop_length: usize,
        n_mels: usize,
    },
    /// Video preprocessing
    Video {
        frame_rate: f32,
        frame_size: (u32, u32),
        temporal_window: usize,
    },
    /// Custom preprocessing
    Custom(HashMap<String, serde_json::Value>),
}

/// Cross-modal attention configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalAttentionConfig {
    /// Number of attention heads for cross-modal attention
    pub num_heads: usize,
    /// Attention dropout rate
    pub dropout: f32,
    /// Whether to use scaled dot-product attention
    pub scaled_attention: bool,
    /// Temperature for attention scaling
    pub temperature: f32,
}

impl Default for CrossModalAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            dropout: 0.1,
            scaled_attention: true,
            temperature: 1.0,
        }
    }
}

/// Strategy for fusing multiple modalities
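///
/// # Example
///
/// A hypothetical strategy choice, assuming this module is publicly exported as
/// `mlmf::multimodal`; the layer indices are placeholders.
///
/// ```
/// use mlmf::multimodal::FusionStrategy;
///
/// // Fuse the modality streams after intermediate layers 4 and 8.
/// let strategy = FusionStrategy::MiddleFusion { fusion_layers: vec![4, 8] };
/// assert!(matches!(strategy, FusionStrategy::MiddleFusion { .. }));
/// ```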
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum FusionStrategy {
    /// Concatenate modality embeddings early
    EarlyFusion,
    /// Fuse modalities at intermediate layers
    MiddleFusion { fusion_layers: Vec<usize> },
    /// Fuse modalities at the final layer
    LateFusion,
    /// Attention-based fusion
    AttentionFusion { attention_dim: usize },
    /// Custom fusion with parameters
    Custom {
        strategy_name: String,
        params: HashMap<String, f32>,
    },
}

/// Input data for multi-modal processing
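///
/// # Example
///
/// A minimal sketch of a text-only batch built from placeholder token ids,
/// assuming this module is publicly exported as `mlmf::multimodal`.
///
/// ```
/// use std::collections::HashMap;
/// use candle_core::{DType, Device, Tensor};
/// use mlmf::multimodal::{Modality, ModalityInput, MultiModalInput};
///
/// // Two sequences of 16 placeholder token ids each.
/// let tokens = Tensor::zeros((2, 16), DType::U32, &Device::Cpu).unwrap();
/// let mut modality_inputs = HashMap::new();
/// modality_inputs.insert(Modality::Text, ModalityInput::Text(tokens));
///
/// let input = MultiModalInput {
///     modality_inputs,
///     attention_masks: HashMap::new(),
///     batch_size: 2,
/// };
/// assert_eq!(input.modality_inputs[&Modality::Text].shape(), &[2, 16]);
/// ```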
#[derive(Debug)]
pub struct MultiModalInput {
    /// Input data per modality
    pub modality_inputs: HashMap<Modality, ModalityInput>,
    /// Optional attention masks per modality
    pub attention_masks: HashMap<Modality, Tensor>,
    /// Batch size
    pub batch_size: usize,
}

/// Input data for a specific modality
#[derive(Debug)]
pub enum ModalityInput {
    /// Text tokens
    Text(Tensor), // Shape: [batch_size, sequence_length]
    /// Image pixels or features
    Image(Tensor), // Shape: [batch_size, channels, height, width] or [batch_size, patches, features]
    /// Audio waveform or features
    Audio(Tensor), // Shape: [batch_size, time_steps, features]
    /// Video frames
    Video(Tensor), // Shape: [batch_size, frames, channels, height, width]
    /// Custom tensor data
    Custom(Tensor),
}

impl ModalityInput {
    /// Get the tensor from the modality input
    pub fn tensor(&self) -> &Tensor {
        match self {
            ModalityInput::Text(t) => t,
            ModalityInput::Image(t) => t,
            ModalityInput::Audio(t) => t,
            ModalityInput::Video(t) => t,
            ModalityInput::Custom(t) => t,
        }
    }

    /// Get the shape of the input tensor
    pub fn shape(&self) -> &[usize] {
        self.tensor().shape().dims()
    }

    /// Get the modality type
    pub fn modality(&self) -> Modality {
        match self {
            ModalityInput::Text(_) => Modality::Text,
            ModalityInput::Image(_) => Modality::Image,
            ModalityInput::Audio(_) => Modality::Audio,
            ModalityInput::Video(_) => Modality::Video,
            ModalityInput::Custom(_) => Modality::Custom(0),
        }
    }
}

/// Output from multi-modal processing
#[derive(Debug)]
pub struct MultiModalOutput {
    /// Fused representation
    pub fused_embeddings: Tensor,
    /// Per-modality embeddings
    pub modality_embeddings: HashMap<Modality, Tensor>,
    /// Cross-modal attention weights
    pub attention_weights: HashMap<(Modality, Modality), Tensor>,
    /// Additional metadata
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Trait for multi-modal processors
pub trait MultiModalProcessor: Send + Sync {
    /// Process multi-modal input
    fn process(&self, input: MultiModalInput) -> Result<MultiModalOutput>;

    /// Get supported modalities
    fn supported_modalities(&self) -> Vec<Modality>;

    /// Get configuration
    fn config(&self) -> &MultiModalConfig;

    /// Preprocess input for a specific modality
    fn preprocess_modality(&self, modality: Modality, input: &Tensor) -> Result<Tensor>;

    /// Apply cross-modal attention
    fn cross_modal_attention(
        &self,
        query_modality: Modality,
        key_modality: Modality,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<(Tensor, Tensor)>; // (attended_output, attention_weights)
}
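
// The helper below is an illustrative sketch only, not the required
// implementation of `MultiModalProcessor::cross_modal_attention`. It shows
// single-head scaled dot-product attention over tensors shaped
// [batch, seq_len, dim]; multi-head projections, masking, and dropout are
// deliberately omitted, and errors are returned as `candle_core::Result` to
// avoid assuming a conversion into this crate's error type.
#[allow(dead_code)]
fn scaled_dot_product_attention_sketch(
    query: &Tensor, // [batch, q_len, dim]
    key: &Tensor,   // [batch, k_len, dim]
    value: &Tensor, // [batch, k_len, dim]
    temperature: f32,
) -> candle_core::Result<(Tensor, Tensor)> {
    use candle_core::D;

    // Scale scores by 1 / (sqrt(dim) * temperature), mirroring the
    // `scaled_attention` and `temperature` fields of CrossModalAttentionConfig.
    let dim = query.dim(D::Minus1)? as f64;
    let scale = 1.0 / (dim.sqrt() * temperature as f64);
    let scores = query.matmul(&key.t()?)?.affine(scale, 0.0)?; // [batch, q_len, k_len]

    // Numerically stable softmax over the key dimension.
    let max = scores.max_keepdim(D::Minus1)?;
    let exp = scores.broadcast_sub(&max)?.exp()?;
    let weights = exp.broadcast_div(&exp.sum_keepdim(D::Minus1)?)?;

    // Weighted sum of the values: [batch, q_len, dim].
    let attended = weights.matmul(value)?;
    Ok((attended, weights))
}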

/// Statistics for multi-modal processing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalStats {
    /// Processing time per modality (in milliseconds)
    pub processing_times: HashMap<Modality, f64>,
    /// Memory usage per modality (in bytes)
    pub memory_usage: HashMap<Modality, usize>,
    /// Cross-modal attention statistics
    pub attention_stats: HashMap<(Modality, Modality), AttentionStats>,
    /// Total processing time
    pub total_time_ms: f64,
    /// Number of processed samples
    pub samples_processed: usize,
}

/// Statistics for attention mechanisms
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionStats {
    /// Average attention score
    pub avg_attention: f32,
    /// Maximum attention score
    pub max_attention: f32,
    /// Minimum attention score
    pub min_attention: f32,
    /// Attention entropy (measure of attention distribution)
    pub entropy: f32,
}
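
// Illustrative sketch of how `AttentionStats` could be derived from a tensor of
// attention weights; it is not part of the public API. The entropy is computed
// once over the flattened, renormalized weights (with a small epsilon to avoid
// ln(0)) rather than averaged per attention row, which keeps the sketch short.
#[allow(dead_code)]
fn attention_stats_sketch(weights: &Tensor) -> candle_core::Result<AttentionStats> {
    let flat = weights.flatten_all()?.to_dtype(DType::F32)?;
    let avg_attention = flat.mean_all()?.to_scalar::<f32>()?;
    let max_attention = flat.max(0)?.to_scalar::<f32>()?;
    let min_attention = flat.min(0)?.to_scalar::<f32>()?;

    // Renormalize so the flattened weights form a probability distribution,
    // then compute Shannon entropy H = -sum(p * ln p).
    let probs = flat.broadcast_div(&flat.sum_all()?)?;
    let plogp = probs.mul(&probs.affine(1.0, 1e-9)?.log()?)?;
    let entropy = -plogp.sum_all()?.to_scalar::<f32>()?;

    Ok(AttentionStats {
        avg_attention,
        max_attention,
        min_attention,
        entropy,
    })
}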

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_modality_str_representation() {
        assert_eq!(Modality::Text.as_str(), "text");
        assert_eq!(Modality::Image.as_str(), "image");
        assert_eq!(Modality::Audio.as_str(), "audio");
        assert_eq!(Modality::Video.as_str(), "video");
        assert_eq!(Modality::Custom(42).as_str(), "custom");
    }

    #[test]
    fn test_modality_embedding_dims() {
        assert_eq!(Modality::Text.default_embedding_dim(), Some(768));
        assert_eq!(Modality::Image.default_embedding_dim(), Some(2048));
        assert_eq!(Modality::Audio.default_embedding_dim(), Some(512));
        assert_eq!(Modality::Video.default_embedding_dim(), Some(1024));
        assert_eq!(Modality::Custom(0).default_embedding_dim(), None);
    }

    #[test]
    fn test_default_config() {
        let config = MultiModalConfig::default();
        assert!(config.modalities.contains_key(&Modality::Text));
        assert!(config.modalities.contains_key(&Modality::Image));
        assert_eq!(config.fusion_strategy, FusionStrategy::EarlyFusion);
        assert!(!config.distributed);
    }
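
    // Construction test for ModalityInput, using a CPU zeros tensor as a
    // stand-in for real token ids.
    #[test]
    fn test_modality_input_shape_and_kind() {
        let tokens = Tensor::zeros((2, 8), DType::U32, &Device::Cpu).unwrap();
        let input = ModalityInput::Text(tokens);
        assert_eq!(input.modality(), Modality::Text);
        assert_eq!(input.shape(), &[2, 8]);
    }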
}