1use crate::engine::{EngineConfig, InferenceEngine};
42use crate::error::{InferenceError, InferenceResult};
43use scirs2_core::ndarray::{Array1, Array2, Axis};
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46use std::sync::Arc;
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50pub enum ModalityType {
51 Audio,
53 Video,
55 Sensor,
57 Text,
59 Custom(&'static str),
61}
62
63impl ModalityType {
64 pub fn name(&self) -> &str {
66 match self {
67 ModalityType::Audio => "audio",
68 ModalityType::Video => "video",
69 ModalityType::Sensor => "sensor",
70 ModalityType::Text => "text",
71 ModalityType::Custom(name) => name,
72 }
73 }
74}
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
78pub enum FusionStrategy {
79 #[default]
81 EarlyFusion,
82 LateFusion,
84 WeightedFusion,
86 MaxPooling,
88 CrossAttention,
90 Hierarchical,
92}
93
94pub type ModalityPreprocessor =
96 Arc<dyn Fn(&Array1<f32>) -> InferenceResult<Array1<f32>> + Send + Sync>;
97
98#[derive(Clone)]
100pub struct ModalityConfig {
101 pub modality_type: ModalityType,
103 pub input_dim: usize,
105 pub preprocessor: Option<ModalityPreprocessor>,
107 pub fusion_weight: f32,
109}
110
111impl ModalityConfig {
112 pub fn new(modality_type: ModalityType, input_dim: usize) -> Self {
114 Self {
115 modality_type,
116 input_dim,
117 preprocessor: None,
118 fusion_weight: 1.0,
119 }
120 }
121
122 pub fn preprocessor(mut self, preprocessor: ModalityPreprocessor) -> Self {
124 self.preprocessor = Some(preprocessor);
125 self
126 }
127
128 pub fn fusion_weight(mut self, weight: f32) -> Self {
130 self.fusion_weight = weight;
131 self
132 }
133}
134
135pub struct MultiModalPipeline {
137 engine: InferenceEngine,
139 modalities: HashMap<ModalityType, ModalityConfig>,
141 fusion_strategy: FusionStrategy,
143 #[allow(dead_code)]
145 total_input_dim: usize,
146}
147
148impl MultiModalPipeline {
149 pub fn builder() -> MultiModalPipelineBuilder {
151 MultiModalPipelineBuilder::new()
152 }
153
154 pub fn forward(
164 &mut self,
165 inputs: &[(ModalityType, Array1<f32>)],
166 ) -> InferenceResult<Array1<f32>> {
167 for (modality, input) in inputs {
169 let config = self.modalities.get(modality).ok_or_else(|| {
170 InferenceError::PipelineConfig(format!("Unknown modality: {:?}", modality))
171 })?;
172
173 if input.len() != config.input_dim {
174 return Err(InferenceError::DimensionMismatch {
175 expected: config.input_dim,
176 got: input.len(),
177 });
178 }
179 }
180
181 let mut preprocessed: HashMap<ModalityType, Array1<f32>> = HashMap::new();
183 for (modality, input) in inputs {
184 let config = &self.modalities[modality];
185 let processed = if let Some(ref preprocessor) = config.preprocessor {
186 preprocessor(input)?
187 } else {
188 input.clone()
189 };
190 preprocessed.insert(*modality, processed);
191 }
192
193 let fused = self.fuse(&preprocessed)?;
195
196 self.engine.step(&fused)
198 }
199
200 fn fuse(
202 &mut self,
203 inputs: &HashMap<ModalityType, Array1<f32>>,
204 ) -> InferenceResult<Array1<f32>> {
205 match self.fusion_strategy {
206 FusionStrategy::EarlyFusion => self.early_fusion(inputs),
207 FusionStrategy::LateFusion => self.late_fusion(inputs),
208 FusionStrategy::WeightedFusion => self.weighted_fusion(inputs),
209 FusionStrategy::MaxPooling => self.max_pooling_fusion(inputs),
210 FusionStrategy::CrossAttention => self.cross_attention_fusion(inputs),
211 FusionStrategy::Hierarchical => self.hierarchical_fusion(inputs),
212 }
213 }
214
215 fn early_fusion(
217 &self,
218 inputs: &HashMap<ModalityType, Array1<f32>>,
219 ) -> InferenceResult<Array1<f32>> {
220 let mut result = Vec::new();
221
222 let mut sorted_modalities: Vec<_> = inputs.keys().collect();
224 sorted_modalities.sort_by_key(|m| m.name());
225
226 for modality in sorted_modalities {
227 let input = &inputs[modality];
228 let slice = input.as_slice().ok_or_else(|| {
229 InferenceError::ForwardError(
230 "Array data not contiguous in early fusion".to_string(),
231 )
232 })?;
233 result.extend_from_slice(slice);
234 }
235
236 Ok(Array1::from_vec(result))
237 }
238
239 fn late_fusion(
245 &mut self,
246 inputs: &HashMap<ModalityType, Array1<f32>>,
247 ) -> InferenceResult<Array1<f32>> {
248 if inputs.is_empty() {
249 return Err(InferenceError::PipelineConfig(
250 "No modalities to fuse".into(),
251 ));
252 }
253
254 let mut by_dim: std::collections::HashMap<usize, Vec<Array1<f32>>> =
256 std::collections::HashMap::new();
257 for input in inputs.values() {
258 by_dim.entry(input.len()).or_default().push(input.clone());
259 }
260
261 let mut result = Vec::new();
263 let mut dims: Vec<_> = by_dim.keys().cloned().collect();
264 dims.sort();
265
266 for dim in dims {
267 let arrays = &by_dim[&dim];
268 let mut averaged = Array1::zeros(dim);
269 for arr in arrays {
270 averaged += arr;
271 }
272 averaged /= arrays.len() as f32;
273 let slice = averaged.as_slice().ok_or_else(|| {
274 InferenceError::ForwardError("Array data not contiguous in late fusion".to_string())
275 })?;
276 result.extend_from_slice(slice);
277 }
278
279 Ok(Array1::from_vec(result))
280 }
281
282 fn weighted_fusion(
284 &self,
285 inputs: &HashMap<ModalityType, Array1<f32>>,
286 ) -> InferenceResult<Array1<f32>> {
287 let mut result = Vec::new();
288 let mut total_weight = 0.0;
289
290 for (modality, input) in inputs {
292 let config = &self.modalities[modality];
293 let weight = config.fusion_weight;
294 total_weight += weight;
295
296 let weighted = input.mapv(|x| x * weight);
297 let slice = weighted.as_slice().ok_or_else(|| {
298 InferenceError::ForwardError(
299 "Array data not contiguous in weighted fusion".to_string(),
300 )
301 })?;
302 result.extend_from_slice(slice);
303 }
304
305 let normalized: Vec<f32> = result.iter().map(|x| x / total_weight).collect();
307 Ok(Array1::from_vec(normalized))
308 }
309
310 fn max_pooling_fusion(
312 &self,
313 inputs: &HashMap<ModalityType, Array1<f32>>,
314 ) -> InferenceResult<Array1<f32>> {
315 if inputs.is_empty() {
316 return Err(InferenceError::PipelineConfig(
317 "No modalities to fuse".into(),
318 ));
319 }
320
321 let mut by_dim: std::collections::HashMap<usize, Vec<Array1<f32>>> =
323 std::collections::HashMap::new();
324 for input in inputs.values() {
325 by_dim.entry(input.len()).or_default().push(input.clone());
326 }
327
328 let mut result = Vec::new();
330 let mut dims: Vec<_> = by_dim.keys().cloned().collect();
331 dims.sort();
332
333 for dim in dims {
334 let arrays = &by_dim[&dim];
335 if arrays.len() == 1 {
336 let slice = arrays[0].as_slice().ok_or_else(|| {
337 InferenceError::ForwardError(
338 "Array data not contiguous in max pooling".to_string(),
339 )
340 })?;
341 result.extend_from_slice(slice);
342 } else {
343 let nrows = arrays.len();
345 let ncols = dim;
346 let mut stacked = Array2::zeros((nrows, ncols));
347 for (i, arr) in arrays.iter().enumerate() {
348 for (j, &val) in arr.iter().enumerate() {
349 stacked[[i, j]] = val;
350 }
351 }
352 let pooled = stacked.map_axis(Axis(0), |col| {
353 col.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
354 });
355 let slice = pooled.as_slice().ok_or_else(|| {
356 InferenceError::ForwardError(
357 "Array data not contiguous in max pooling result".to_string(),
358 )
359 })?;
360 result.extend_from_slice(slice);
361 }
362 }
363
364 Ok(Array1::from_vec(result))
365 }
366
367 fn cross_attention_fusion(
369 &self,
370 inputs: &HashMap<ModalityType, Array1<f32>>,
371 ) -> InferenceResult<Array1<f32>> {
372 if inputs.is_empty() {
373 return Err(InferenceError::PipelineConfig(
374 "No modalities to fuse".into(),
375 ));
376 }
377
378 if inputs.len() == 1 {
379 let single_input = inputs.values().next().ok_or_else(|| {
380 InferenceError::ForwardError("No input found in hierarchical fusion".to_string())
381 })?;
382 return Ok(single_input.clone());
383 }
384
385 let mut by_dim: std::collections::HashMap<usize, Vec<Array1<f32>>> =
387 std::collections::HashMap::new();
388 for input in inputs.values() {
389 by_dim.entry(input.len()).or_default().push(input.clone());
390 }
391
392 let mut result = Vec::new();
394 let mut dims: Vec<_> = by_dim.keys().cloned().collect();
395 dims.sort();
396
397 for dim in dims {
398 let modalities = &by_dim[&dim];
399 let n = modalities.len();
400
401 if n == 1 {
402 let slice = modalities[0].as_slice().ok_or_else(|| {
403 InferenceError::ForwardError(
404 "Array data not contiguous in hierarchical fusion".to_string(),
405 )
406 })?;
407 result.extend_from_slice(slice);
408 } else {
409 let mut attention_weights = vec![0.0; n];
411 for i in 0..n {
412 for j in 0..n {
413 if i != j {
414 let dot_product: f32 = modalities[i]
415 .iter()
416 .zip(modalities[j].iter())
417 .map(|(a, b)| a * b)
418 .sum();
419 attention_weights[i] += dot_product.abs();
420 }
421 }
422 }
423
424 let total: f32 = attention_weights.iter().sum();
426 if total > 0.0 {
427 for weight in &mut attention_weights {
428 *weight /= total;
429 }
430 } else {
431 let uniform = 1.0 / n as f32;
433 attention_weights.fill(uniform);
434 }
435
436 let mut weighted_result = Array1::zeros(dim);
438 for (i, modality) in modalities.iter().enumerate() {
439 weighted_result += &(modality * attention_weights[i]);
440 }
441 let slice = weighted_result.as_slice().ok_or_else(|| {
442 InferenceError::ForwardError(
443 "Array data not contiguous in cross-attention result".to_string(),
444 )
445 })?;
446 result.extend_from_slice(slice);
447 }
448 }
449
450 Ok(Array1::from_vec(result))
451 }
452
453 fn hierarchical_fusion(
458 &mut self,
459 inputs: &HashMap<ModalityType, Array1<f32>>,
460 ) -> InferenceResult<Array1<f32>> {
461 let early = self.early_fusion(inputs)?;
463
464 let weighted = self.weighted_fusion(inputs)?;
466
467 if early.len() != weighted.len() {
469 return Err(InferenceError::PipelineConfig(format!(
470 "Fusion dimension mismatch: early={}, weighted={}",
471 early.len(),
472 weighted.len()
473 )));
474 }
475
476 let result = (early + weighted) / 2.0;
478 Ok(result)
479 }
480
481 pub fn reset(&mut self) {
483 self.engine.reset();
484 }
485
486 pub fn fusion_strategy(&self) -> FusionStrategy {
488 self.fusion_strategy
489 }
490
491 pub fn modalities(&self) -> &HashMap<ModalityType, ModalityConfig> {
493 &self.modalities
494 }
495
496 pub fn engine(&self) -> &InferenceEngine {
498 &self.engine
499 }
500
501 pub fn engine_mut(&mut self) -> &mut InferenceEngine {
503 &mut self.engine
504 }
505}
506
507pub struct MultiModalPipelineBuilder {
509 engine_config: Option<EngineConfig>,
510 modalities: HashMap<ModalityType, ModalityConfig>,
511 fusion_strategy: FusionStrategy,
512}
513
514impl MultiModalPipelineBuilder {
515 pub fn new() -> Self {
517 Self {
518 engine_config: None,
519 modalities: HashMap::new(),
520 fusion_strategy: FusionStrategy::default(),
521 }
522 }
523
524 pub fn engine_config(mut self, config: EngineConfig) -> Self {
526 self.engine_config = Some(config);
527 self
528 }
529
530 pub fn add_modality(mut self, config: ModalityConfig) -> Self {
532 self.modalities.insert(config.modality_type, config);
533 self
534 }
535
536 pub fn modality(mut self, modality_type: ModalityType, input_dim: usize) -> Self {
538 let config = ModalityConfig::new(modality_type, input_dim);
539 self.modalities.insert(modality_type, config);
540 self
541 }
542
543 pub fn fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
545 self.fusion_strategy = strategy;
546 self
547 }
548
549 pub fn build(self) -> InferenceResult<MultiModalPipeline> {
551 if self.modalities.is_empty() {
552 return Err(InferenceError::PipelineConfig(
553 "At least one modality must be configured".into(),
554 ));
555 }
556
557 let total_input_dim: usize = self.modalities.values().map(|c| c.input_dim).sum();
559
560 let engine_config = self
561 .engine_config
562 .ok_or_else(|| InferenceError::PipelineConfig("engine_config not set".into()))?;
563
564 let engine = InferenceEngine::new(engine_config);
565
566 Ok(MultiModalPipeline {
567 engine,
568 modalities: self.modalities,
569 fusion_strategy: self.fusion_strategy,
570 total_input_dim,
571 })
572 }
573}
574
575impl Default for MultiModalPipelineBuilder {
576 fn default() -> Self {
577 Self::new()
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use kizzasi_model::s4::{S4Config, S4D};

    /// Builds a boxed single-layer diagonal S4D model for the given input
    /// width. `_output_dim` is accepted but not used by the configuration.
    fn create_test_model(input_dim: usize, _output_dim: usize) -> Box<S4D> {
        let model_config = S4Config::new()
            .input_dim(input_dim)
            .hidden_dim(32)
            .state_dim(16)
            .num_layers(1)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();
        Box::new(model)
    }

    /// Every variant reports its expected label.
    #[test]
    fn test_modality_type_name() {
        let expected = [
            (ModalityType::Audio, "audio"),
            (ModalityType::Video, "video"),
            (ModalityType::Sensor, "sensor"),
            (ModalityType::Text, "text"),
            (ModalityType::Custom("xyz"), "xyz"),
        ];
        for (modality, label) in expected.iter() {
            assert_eq!(modality.name(), *label);
        }
    }

    /// The builder records both modalities and sums their input lengths.
    #[test]
    fn test_multimodal_builder() {
        let built = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(6, 10))
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .build();

        assert!(built.is_ok());
        let pipeline = built.unwrap();
        assert_eq!(pipeline.modalities().len(), 2);
        assert_eq!(pipeline.total_input_dim, 6);
    }

    /// Building without any modality fails.
    #[test]
    fn test_multimodal_no_modalities() {
        let outcome = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(3, 10))
            .build();

        assert!(outcome.is_err());
    }

    /// Early fusion of two three-element inputs runs through the engine.
    #[test]
    fn test_early_fusion() {
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(6, 6))
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(6, 6));

        let inputs = [
            (ModalityType::Audio, Array1::from_vec(vec![0.1, 0.2, 0.3])),
            (ModalityType::Video, Array1::from_vec(vec![0.4, 0.5, 0.6])),
        ];

        pipeline.forward(&inputs).unwrap();
    }

    /// Weighted fusion with distinct per-modality weights succeeds.
    #[test]
    fn test_weighted_fusion() {
        let audio_settings = ModalityConfig::new(ModalityType::Audio, 2).fusion_weight(2.0);
        let video_settings = ModalityConfig::new(ModalityType::Video, 2).fusion_weight(1.0);

        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(4, 4))
            .add_modality(audio_settings)
            .add_modality(video_settings)
            .fusion_strategy(FusionStrategy::WeightedFusion)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(4, 4));

        let inputs = [
            (ModalityType::Audio, Array1::from_vec(vec![0.3, 0.6])),
            (ModalityType::Video, Array1::from_vec(vec![0.1, 0.2])),
        ];

        let outcome = pipeline.forward(&inputs);
        assert!(outcome.is_ok());
    }

    /// An input shorter than its configured length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(6, 10))
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .build()
            .unwrap();

        // Audio carries two elements although three were configured.
        let inputs = [
            (ModalityType::Audio, Array1::from_vec(vec![0.1, 0.2])),
            (ModalityType::Video, Array1::from_vec(vec![0.4, 0.5, 0.6])),
        ];

        let outcome = pipeline.forward(&inputs);
        assert!(outcome.is_err());
    }

    /// An input for a modality that was never configured is rejected.
    #[test]
    fn test_unknown_modality() {
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(3, 10))
            .modality(ModalityType::Audio, 3)
            .build()
            .unwrap();

        let inputs = [(ModalityType::Video, Array1::from_vec(vec![0.4, 0.5, 0.6]))];

        let outcome = pipeline.forward(&inputs);
        assert!(outcome.is_err());
    }

    /// Max pooling over two same-length inputs succeeds.
    #[test]
    fn test_max_pooling_fusion() {
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(3, 3))
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .fusion_strategy(FusionStrategy::MaxPooling)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(3, 3));

        let inputs = [
            (ModalityType::Audio, Array1::from_vec(vec![0.1, 0.9, 0.3])),
            (ModalityType::Video, Array1::from_vec(vec![0.8, 0.2, 0.6])),
        ];

        let outcome = pipeline.forward(&inputs);
        assert!(outcome.is_ok());
    }

    /// Cross-attention fusion over two same-length inputs succeeds.
    #[test]
    fn test_cross_attention_fusion() {
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(3, 3))
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Sensor, 3)
            .fusion_strategy(FusionStrategy::CrossAttention)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(3, 3));

        let inputs = [
            (ModalityType::Audio, Array1::from_vec(vec![0.1, 0.2, 0.3])),
            (ModalityType::Sensor, Array1::from_vec(vec![0.4, 0.5, 0.6])),
        ];

        let outcome = pipeline.forward(&inputs);
        assert!(outcome.is_ok());
    }

    /// Hierarchical fusion of two two-element inputs runs through the engine.
    #[test]
    fn test_hierarchical_fusion() {
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(4, 4))
            .modality(ModalityType::Audio, 2)
            .modality(ModalityType::Text, 2)
            .fusion_strategy(FusionStrategy::Hierarchical)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(4, 4));

        let inputs = [
            (ModalityType::Audio, Array1::from_vec(vec![0.1, 0.2])),
            (ModalityType::Text, Array1::from_vec(vec![0.3, 0.4])),
        ];

        pipeline.forward(&inputs).unwrap();
    }

    /// A modality with a doubling preprocessor runs through the pipeline.
    #[test]
    fn test_modality_preprocessor() {
        let doubling: ModalityPreprocessor = Arc::new(|input| Ok(input.mapv(|x| x * 2.0)));
        let audio_settings = ModalityConfig::new(ModalityType::Audio, 3).preprocessor(doubling);

        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(EngineConfig::new(3, 3))
            .add_modality(audio_settings)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(3, 3));

        let inputs = [(ModalityType::Audio, Array1::from_vec(vec![0.1, 0.2, 0.3]))];

        let outcome = pipeline.forward(&inputs);
        assert!(outcome.is_ok());
    }
}