use crate::error::Result;
use candle_core::{Device, Tensor};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// An input modality supported by the multi-modal pipeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Modality {
    /// Natural-language text.
    Text,
    /// Still images.
    Image,
    /// Audio waveforms or features.
    Audio,
    /// Video frame sequences.
    Video,
    /// A user-defined modality, identified by a numeric tag.
    Custom(u32),
}

impl Modality {
    /// Returns a short, stable string name for the modality.
    pub fn as_str(&self) -> &'static str {
        match self {
            Modality::Text => "text",
            Modality::Image => "image",
            Modality::Audio => "audio",
            Modality::Video => "video",
            Modality::Custom(_) => "custom",
        }
    }

    /// Returns the default embedding dimension for the modality, or `None`
    /// for custom modalities, which must supply their own.
    pub fn default_embedding_dim(&self) -> Option<usize> {
        match self {
            Modality::Text => Some(768),
            Modality::Image => Some(2048),
            Modality::Audio => Some(512),
            Modality::Video => Some(1024),
            Modality::Custom(_) => None,
        }
    }
}

/// Top-level configuration for a multi-modal processing pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalConfig {
    /// Per-modality configuration.
    pub modalities: HashMap<Modality, ModalityConfig>,
    /// Cross-modal attention hyperparameters.
    pub cross_modal_attention: CrossModalAttentionConfig,
    /// Strategy used to fuse modality embeddings.
    pub fusion_strategy: FusionStrategy,
    /// Maximum sequence length per modality.
    pub max_sequence_lengths: HashMap<Modality, usize>,
    /// Whether processing is distributed across multiple devices.
    pub distributed: bool,
}

impl Default for MultiModalConfig {
    fn default() -> Self {
        let mut modalities = HashMap::new();
        modalities.insert(Modality::Text, ModalityConfig::default_text());
        modalities.insert(Modality::Image, ModalityConfig::default_image());

        let mut max_lengths = HashMap::new();
        max_lengths.insert(Modality::Text, 512);
        max_lengths.insert(Modality::Image, 196);

        Self {
            modalities,
            cross_modal_attention: CrossModalAttentionConfig::default(),
            fusion_strategy: FusionStrategy::EarlyFusion,
            max_sequence_lengths: max_lengths,
            distributed: false,
        }
    }
}

/// Configuration for a single modality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityConfig {
    /// Preprocessing applied to raw inputs for this modality.
    pub preprocessing: PreprocessingConfig,
    /// Dimensionality of this modality's embeddings.
    pub embedding_dim: usize,
    /// Whether this modality needs a specialized attention pattern.
    pub requires_special_attention: bool,
    /// Optional device placement override (not serialized).
    #[serde(skip)]
    pub device_placement: Option<Device>,
}

impl ModalityConfig {
    /// Default configuration for text: 512-token sequences, 768-dim embeddings.
    pub fn default_text() -> Self {
        Self {
            preprocessing: PreprocessingConfig::Text {
                tokenizer_path: None,
                max_length: 512,
                padding: true,
                truncation: true,
            },
            embedding_dim: 768,
            requires_special_attention: false,
            device_placement: None,
        }
    }

    /// Default configuration for images: 224x224 inputs split into 16x16
    /// patches (196 patches), 2048-dim embeddings.
    pub fn default_image() -> Self {
        Self {
            preprocessing: PreprocessingConfig::Image {
                resize: Some((224, 224)),
                normalize: true,
                patch_size: 16,
            },
            embedding_dim: 2048,
            requires_special_attention: true,
            device_placement: None,
        }
    }
}

/// Modality-specific preprocessing parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PreprocessingConfig {
    /// Tokenization settings for text.
    Text {
        tokenizer_path: Option<String>,
        max_length: usize,
        padding: bool,
        truncation: bool,
    },
    /// Resizing, normalization, and patching settings for images.
    Image {
        resize: Option<(u32, u32)>,
        normalize: bool,
        patch_size: u32,
    },
    /// Framing and mel-spectrogram settings for audio.
    Audio {
        sample_rate: u32,
        frame_length: usize,
        hop_length: usize,
        n_mels: usize,
    },
    /// Frame sampling settings for video.
    Video {
        frame_rate: f32,
        frame_size: (u32, u32),
        temporal_window: usize,
    },
    /// Free-form settings for custom modalities.
    Custom(HashMap<String, serde_json::Value>),
}

/// Hyperparameters for cross-modal attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dropout probability applied to attention weights.
    pub dropout: f32,
    /// Whether to scale attention scores (as in scaled dot-product attention).
    pub scaled_attention: bool,
    /// Temperature applied to attention scores before the softmax.
    pub temperature: f32,
}

impl Default for CrossModalAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            dropout: 0.1,
            scaled_attention: true,
            temperature: 1.0,
        }
    }
}

/// How per-modality embeddings are combined into a single representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum FusionStrategy {
    /// Combine modality inputs at the embedding stage, before the shared
    /// backbone.
    EarlyFusion,
    /// Fuse at the given intermediate layer indices.
    MiddleFusion { fusion_layers: Vec<usize> },
    /// Process each modality independently and fuse only the final embeddings.
    LateFusion,
    /// Fuse with a learned attention layer of the given dimensionality.
    AttentionFusion { attention_dim: usize },
    /// A user-defined strategy with free-form numeric parameters.
    Custom {
        strategy_name: String,
        params: HashMap<String, f32>,
    },
}

/// A batch of inputs spanning one or more modalities.
#[derive(Debug)]
pub struct MultiModalInput {
    /// Input tensors, keyed by modality.
    pub modality_inputs: HashMap<Modality, ModalityInput>,
    /// Attention masks, keyed by modality.
    pub attention_masks: HashMap<Modality, Tensor>,
    /// Number of samples in the batch.
    pub batch_size: usize,
}

/// An input tensor tagged with its modality.
#[derive(Debug)]
pub enum ModalityInput {
    Text(Tensor),
    Image(Tensor),
    Audio(Tensor),
    Video(Tensor),
    Custom(Tensor),
}

impl ModalityInput {
    /// Returns a reference to the underlying tensor, whatever the modality.
    pub fn tensor(&self) -> &Tensor {
        match self {
            ModalityInput::Text(t) => t,
            ModalityInput::Image(t) => t,
            ModalityInput::Audio(t) => t,
            ModalityInput::Video(t) => t,
            ModalityInput::Custom(t) => t,
        }
    }

    /// Returns the shape of the underlying tensor.
    pub fn shape(&self) -> &[usize] {
        self.tensor().shape().dims()
    }

    /// Returns the modality tag for this input. Note that the custom tag is
    /// not stored alongside the tensor, so custom inputs always report
    /// `Modality::Custom(0)`.
    pub fn modality(&self) -> Modality {
        match self {
            ModalityInput::Text(_) => Modality::Text,
            ModalityInput::Image(_) => Modality::Image,
            ModalityInput::Audio(_) => Modality::Audio,
            ModalityInput::Video(_) => Modality::Video,
            ModalityInput::Custom(_) => Modality::Custom(0),
        }
    }
}

/// The result of multi-modal processing.
#[derive(Debug)]
pub struct MultiModalOutput {
    /// The fused, modality-agnostic embeddings.
    pub fused_embeddings: Tensor,
    /// Per-modality embeddings prior to fusion.
    pub modality_embeddings: HashMap<Modality, Tensor>,
    /// Attention weights for each (query, key) modality pair.
    pub attention_weights: HashMap<(Modality, Modality), Tensor>,
    /// Free-form metadata about the processing run.
    pub metadata: HashMap<String, serde_json::Value>,
}

/// The core interface implemented by multi-modal processors.
pub trait MultiModalProcessor: Send + Sync {
    /// Processes a multi-modal input batch into fused embeddings.
    fn process(&self, input: MultiModalInput) -> Result<MultiModalOutput>;

    /// Lists the modalities this processor can handle.
    fn supported_modalities(&self) -> Vec<Modality>;

    /// Returns the processor's configuration.
    fn config(&self) -> &MultiModalConfig;

    /// Applies modality-specific preprocessing to a raw input tensor.
    fn preprocess_modality(&self, modality: Modality, input: &Tensor) -> Result<Tensor>;

    /// Computes cross-modal attention from `query_modality` to `key_modality`,
    /// returning the attended values together with the attention weights.
    fn cross_modal_attention(
        &self,
        query_modality: Modality,
        key_modality: Modality,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<(Tensor, Tensor)>;
}

/// Timing, memory, and attention statistics collected during processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalStats {
    /// Per-modality processing time.
    pub processing_times: HashMap<Modality, f64>,
    /// Per-modality memory usage.
    pub memory_usage: HashMap<Modality, usize>,
    /// Attention statistics for each (query, key) modality pair.
    pub attention_stats: HashMap<(Modality, Modality), AttentionStats>,
    /// Total processing time in milliseconds.
    pub total_time_ms: f64,
    /// Number of samples processed.
    pub samples_processed: usize,
}

/// Summary statistics over a set of attention weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionStats {
    /// Mean attention weight.
    pub avg_attention: f32,
    /// Maximum attention weight.
    pub max_attention: f32,
    /// Minimum attention weight.
    pub min_attention: f32,
    /// Entropy of the attention distribution.
    pub entropy: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_modality_str_representation() {
        assert_eq!(Modality::Text.as_str(), "text");
        assert_eq!(Modality::Image.as_str(), "image");
        assert_eq!(Modality::Audio.as_str(), "audio");
        assert_eq!(Modality::Video.as_str(), "video");
        assert_eq!(Modality::Custom(42).as_str(), "custom");
    }
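
    // Sketch of an additional consistency check: the per-modality defaults
    // should agree with `Modality::default_embedding_dim`, so the two sources
    // of truth stay in sync.
    #[test]
    fn test_default_modality_configs_match_embedding_dims() {
        assert_eq!(
            Some(ModalityConfig::default_text().embedding_dim),
            Modality::Text.default_embedding_dim()
        );
        assert_eq!(
            Some(ModalityConfig::default_image().embedding_dim),
            Modality::Image.default_embedding_dim()
        );
    }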

    #[test]
    fn test_modality_embedding_dims() {
        assert_eq!(Modality::Text.default_embedding_dim(), Some(768));
        assert_eq!(Modality::Image.default_embedding_dim(), Some(2048));
        assert_eq!(Modality::Audio.default_embedding_dim(), Some(512));
        assert_eq!(Modality::Video.default_embedding_dim(), Some(1024));
        assert_eq!(Modality::Custom(0).default_embedding_dim(), None);
    }
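
    // Sketch: exercises the `ModalityInput` accessors with a concrete candle
    // tensor. Assumes the CPU device is available, which holds for candle's
    // default backend.
    #[test]
    fn test_modality_input_accessors() {
        use candle_core::DType;

        let t = Tensor::zeros((2, 196, 2048), DType::F32, &Device::Cpu).unwrap();
        let input = ModalityInput::Image(t);
        assert_eq!(input.modality(), Modality::Image);
        assert_eq!(input.shape(), &[2, 196, 2048]);
    }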

    #[test]
    fn test_default_config() {
        let config = MultiModalConfig::default();
        assert!(config.modalities.contains_key(&Modality::Text));
        assert!(config.modalities.contains_key(&Modality::Image));
        assert_eq!(config.fusion_strategy, FusionStrategy::EarlyFusion);
        assert!(!config.distributed);
    }
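
    // Sketch: the default config should survive a JSON round-trip. This works
    // because the default map keys are unit variants (`Text`, `Image`), which
    // serde_json can encode as strings; a map keyed by `Modality::Custom(_)`
    // would not serialize to JSON.
    #[test]
    fn test_default_config_json_roundtrip() {
        let config = MultiModalConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let restored: MultiModalConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.fusion_strategy, FusionStrategy::EarlyFusion);
        assert_eq!(restored.max_sequence_lengths[&Modality::Text], 512);
        assert_eq!(restored.max_sequence_lengths[&Modality::Image], 196);
    }

    // Sketch: a minimal, do-nothing implementation of `MultiModalProcessor`
    // showing how the trait's pieces fit together. The "fusion" here is a
    // placeholder that passes one modality's tensor through unchanged; a real
    // processor would embed each modality, attend across them, and fuse.
    struct StubProcessor {
        config: MultiModalConfig,
    }

    impl MultiModalProcessor for StubProcessor {
        fn process(&self, input: MultiModalInput) -> Result<MultiModalOutput> {
            let modality_embeddings: HashMap<Modality, Tensor> = input
                .modality_inputs
                .iter()
                .map(|(m, i)| (*m, i.tensor().clone()))
                .collect();
            // Placeholder fusion: reuse an arbitrary modality's tensor as-is.
            let fused_embeddings = modality_embeddings
                .values()
                .next()
                .expect("at least one modality input")
                .clone();
            Ok(MultiModalOutput {
                fused_embeddings,
                modality_embeddings,
                attention_weights: HashMap::new(),
                metadata: HashMap::new(),
            })
        }

        fn supported_modalities(&self) -> Vec<Modality> {
            self.config.modalities.keys().copied().collect()
        }

        fn config(&self) -> &MultiModalConfig {
            &self.config
        }

        fn preprocess_modality(&self, _modality: Modality, input: &Tensor) -> Result<Tensor> {
            Ok(input.clone())
        }

        fn cross_modal_attention(
            &self,
            _query_modality: Modality,
            _key_modality: Modality,
            query: &Tensor,
            _key: &Tensor,
            value: &Tensor,
        ) -> Result<(Tensor, Tensor)> {
            // Identity "attention": values pass through; the query stands in
            // for the attention weights.
            Ok((value.clone(), query.clone()))
        }
    }

    #[test]
    fn test_stub_processor_passthrough() {
        use candle_core::DType;

        let text = Tensor::zeros((2, 4, 768), DType::F32, &Device::Cpu).unwrap();
        let mut modality_inputs = HashMap::new();
        modality_inputs.insert(Modality::Text, ModalityInput::Text(text));

        let processor = StubProcessor {
            config: MultiModalConfig::default(),
        };
        let output = processor
            .process(MultiModalInput {
                modality_inputs,
                attention_masks: HashMap::new(),
                batch_size: 2,
            })
            .unwrap();

        assert_eq!(output.fused_embeddings.dims(), &[2, 4, 768]);
        assert!(output.modality_embeddings.contains_key(&Modality::Text));
    }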
}