kizzasi_inference/
multimodal.rs

1//! Multi-modal inference pipeline support
2//!
3//! Provides infrastructure for handling multiple input modalities and fusing them
4//! for unified inference. Supports various fusion strategies and modality-specific
5//! preprocessing.
6//!
7//! # Supported Modalities
8//!
9//! - Audio: Time-series audio signals
10//! - Video: Frame-based visual data
11//! - Sensor: Generic sensor readings (IMU, temperature, etc.)
12//! - Text: Embedded text representations
13//!
14//! # Fusion Strategies
15//!
16//! - Early fusion: Concatenate modalities before model
17//! - Late fusion: Process separately, combine outputs
18//! - Cross-attention: Attend across modalities
19//! - Hierarchical: Multi-level fusion
20//!
21//! # Example
22//!
//! ```rust,ignore
//! use kizzasi_inference::engine::EngineConfig;
//! use kizzasi_inference::multimodal::{MultiModalPipeline, ModalityType, FusionStrategy};
//!
//! let mut pipeline = MultiModalPipeline::builder()
//!     .engine_config(EngineConfig::new(6, 10))
//!     .modality(ModalityType::Audio, 3)
//!     .modality(ModalityType::Video, 3)
//!     .fusion_strategy(FusionStrategy::EarlyFusion)
//!     .build()?;
//!
//! let audio_input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
//! let video_input = Array1::from_vec(vec![0.4, 0.5, 0.6]);
//!
//! let output = pipeline.forward(&[
//!     (ModalityType::Audio, audio_input),
//!     (ModalityType::Video, video_input),
//! ])?;
//! ```
40
41use crate::engine::{EngineConfig, InferenceEngine};
42use crate::error::{InferenceError, InferenceResult};
43use scirs2_core::ndarray::{Array1, Array2, Axis};
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46use std::sync::Arc;
47
48/// Supported modality types
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50pub enum ModalityType {
51    /// Audio signals (time-series)
52    Audio,
53    /// Video frames (image sequences)
54    Video,
55    /// Sensor data (IMU, temperature, pressure, etc.)
56    Sensor,
57    /// Text embeddings
58    Text,
59    /// Custom modality
60    Custom(&'static str),
61}
62
63impl ModalityType {
64    /// Get the modality name as a string
65    pub fn name(&self) -> &str {
66        match self {
67            ModalityType::Audio => "audio",
68            ModalityType::Video => "video",
69            ModalityType::Sensor => "sensor",
70            ModalityType::Text => "text",
71            ModalityType::Custom(name) => name,
72        }
73    }
74}
75
76/// Fusion strategy for combining multiple modalities
77#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
78pub enum FusionStrategy {
79    /// Concatenate all modalities before processing
80    #[default]
81    EarlyFusion,
82    /// Process each modality separately, then combine outputs
83    LateFusion,
84    /// Weighted average of modality outputs
85    WeightedFusion,
86    /// Maximum pooling across modalities
87    MaxPooling,
88    /// Cross-attention between modalities
89    CrossAttention,
90    /// Hierarchical multi-level fusion
91    Hierarchical,
92}
93
/// Preprocessing function applied to one modality's raw input before fusion.
///
/// Shared (`Arc`) and thread-safe (`Send + Sync`); receives the raw input
/// array and returns the transformed array, or an inference error if
/// preprocessing fails.
pub type ModalityPreprocessor =
    Arc<dyn Fn(&Array1<f32>) -> InferenceResult<Array1<f32>> + Send + Sync>;
97
98/// Configuration for a single modality
99#[derive(Clone)]
100pub struct ModalityConfig {
101    /// Type of modality
102    pub modality_type: ModalityType,
103    /// Expected input dimension
104    pub input_dim: usize,
105    /// Optional preprocessor
106    pub preprocessor: Option<ModalityPreprocessor>,
107    /// Weight for fusion (used in weighted fusion)
108    pub fusion_weight: f32,
109}
110
111impl ModalityConfig {
112    /// Create a new modality configuration
113    pub fn new(modality_type: ModalityType, input_dim: usize) -> Self {
114        Self {
115            modality_type,
116            input_dim,
117            preprocessor: None,
118            fusion_weight: 1.0,
119        }
120    }
121
122    /// Set the preprocessor
123    pub fn preprocessor(mut self, preprocessor: ModalityPreprocessor) -> Self {
124        self.preprocessor = Some(preprocessor);
125        self
126    }
127
128    /// Set the fusion weight
129    pub fn fusion_weight(mut self, weight: f32) -> Self {
130        self.fusion_weight = weight;
131        self
132    }
133}
134
135/// Multi-modal inference pipeline
136pub struct MultiModalPipeline {
137    /// Base inference engine
138    engine: InferenceEngine,
139    /// Modality configurations
140    modalities: HashMap<ModalityType, ModalityConfig>,
141    /// Fusion strategy
142    fusion_strategy: FusionStrategy,
143    /// Total expected input dimension (sum of all modality dims)
144    #[allow(dead_code)]
145    total_input_dim: usize,
146}
147
148impl MultiModalPipeline {
149    /// Create a new builder
150    pub fn builder() -> MultiModalPipelineBuilder {
151        MultiModalPipelineBuilder::new()
152    }
153
154    /// Forward pass with multi-modal inputs
155    ///
156    /// # Arguments
157    ///
158    /// * `inputs` - Slice of (modality_type, input_array) pairs
159    ///
160    /// # Returns
161    ///
162    /// Fused output array
163    pub fn forward(
164        &mut self,
165        inputs: &[(ModalityType, Array1<f32>)],
166    ) -> InferenceResult<Array1<f32>> {
167        // Validate inputs
168        for (modality, input) in inputs {
169            let config = self.modalities.get(modality).ok_or_else(|| {
170                InferenceError::PipelineConfig(format!("Unknown modality: {:?}", modality))
171            })?;
172
173            if input.len() != config.input_dim {
174                return Err(InferenceError::DimensionMismatch {
175                    expected: config.input_dim,
176                    got: input.len(),
177                });
178            }
179        }
180
181        // Preprocess each modality
182        let mut preprocessed: HashMap<ModalityType, Array1<f32>> = HashMap::new();
183        for (modality, input) in inputs {
184            let config = &self.modalities[modality];
185            let processed = if let Some(ref preprocessor) = config.preprocessor {
186                preprocessor(input)?
187            } else {
188                input.clone()
189            };
190            preprocessed.insert(*modality, processed);
191        }
192
193        // Apply fusion strategy
194        let fused = self.fuse(&preprocessed)?;
195
196        // Run inference
197        self.engine.step(&fused)
198    }
199
200    /// Apply fusion strategy to combine modalities
201    fn fuse(
202        &mut self,
203        inputs: &HashMap<ModalityType, Array1<f32>>,
204    ) -> InferenceResult<Array1<f32>> {
205        match self.fusion_strategy {
206            FusionStrategy::EarlyFusion => self.early_fusion(inputs),
207            FusionStrategy::LateFusion => self.late_fusion(inputs),
208            FusionStrategy::WeightedFusion => self.weighted_fusion(inputs),
209            FusionStrategy::MaxPooling => self.max_pooling_fusion(inputs),
210            FusionStrategy::CrossAttention => self.cross_attention_fusion(inputs),
211            FusionStrategy::Hierarchical => self.hierarchical_fusion(inputs),
212        }
213    }
214
215    /// Early fusion: concatenate all modalities
216    fn early_fusion(
217        &self,
218        inputs: &HashMap<ModalityType, Array1<f32>>,
219    ) -> InferenceResult<Array1<f32>> {
220        let mut result = Vec::new();
221
222        // Concatenate in a deterministic order (by modality name)
223        let mut sorted_modalities: Vec<_> = inputs.keys().collect();
224        sorted_modalities.sort_by_key(|m| m.name());
225
226        for modality in sorted_modalities {
227            let input = &inputs[modality];
228            let slice = input.as_slice().ok_or_else(|| {
229                InferenceError::ForwardError(
230                    "Array data not contiguous in early fusion".to_string(),
231                )
232            })?;
233            result.extend_from_slice(slice);
234        }
235
236        Ok(Array1::from_vec(result))
237    }
238
239    /// Late fusion: average modalities element-wise, then concatenate
240    ///
241    /// Note: With a single engine, true late fusion (separate processing per modality)
242    /// isn't possible. This implements a simplified version that averages modalities
243    /// of the same dimension before concatenation.
244    fn late_fusion(
245        &mut self,
246        inputs: &HashMap<ModalityType, Array1<f32>>,
247    ) -> InferenceResult<Array1<f32>> {
248        if inputs.is_empty() {
249            return Err(InferenceError::PipelineConfig(
250                "No modalities to fuse".into(),
251            ));
252        }
253
254        // Group modalities by dimension
255        let mut by_dim: std::collections::HashMap<usize, Vec<Array1<f32>>> =
256            std::collections::HashMap::new();
257        for input in inputs.values() {
258            by_dim.entry(input.len()).or_default().push(input.clone());
259        }
260
261        // Average within each dimension group, then concatenate
262        let mut result = Vec::new();
263        let mut dims: Vec<_> = by_dim.keys().cloned().collect();
264        dims.sort();
265
266        for dim in dims {
267            let arrays = &by_dim[&dim];
268            let mut averaged = Array1::zeros(dim);
269            for arr in arrays {
270                averaged += arr;
271            }
272            averaged /= arrays.len() as f32;
273            let slice = averaged.as_slice().ok_or_else(|| {
274                InferenceError::ForwardError("Array data not contiguous in late fusion".to_string())
275            })?;
276            result.extend_from_slice(slice);
277        }
278
279        Ok(Array1::from_vec(result))
280    }
281
282    /// Weighted fusion: weighted average based on modality weights
283    fn weighted_fusion(
284        &self,
285        inputs: &HashMap<ModalityType, Array1<f32>>,
286    ) -> InferenceResult<Array1<f32>> {
287        let mut result = Vec::new();
288        let mut total_weight = 0.0;
289
290        // Weighted concatenation
291        for (modality, input) in inputs {
292            let config = &self.modalities[modality];
293            let weight = config.fusion_weight;
294            total_weight += weight;
295
296            let weighted = input.mapv(|x| x * weight);
297            let slice = weighted.as_slice().ok_or_else(|| {
298                InferenceError::ForwardError(
299                    "Array data not contiguous in weighted fusion".to_string(),
300                )
301            })?;
302            result.extend_from_slice(slice);
303        }
304
305        // Normalize by total weight
306        let normalized: Vec<f32> = result.iter().map(|x| x / total_weight).collect();
307        Ok(Array1::from_vec(normalized))
308    }
309
310    /// Max pooling fusion: element-wise max across modalities of same dimension, then concatenate
311    fn max_pooling_fusion(
312        &self,
313        inputs: &HashMap<ModalityType, Array1<f32>>,
314    ) -> InferenceResult<Array1<f32>> {
315        if inputs.is_empty() {
316            return Err(InferenceError::PipelineConfig(
317                "No modalities to fuse".into(),
318            ));
319        }
320
321        // Group modalities by dimension
322        let mut by_dim: std::collections::HashMap<usize, Vec<Array1<f32>>> =
323            std::collections::HashMap::new();
324        for input in inputs.values() {
325            by_dim.entry(input.len()).or_default().push(input.clone());
326        }
327
328        // Max pool within each dimension group, then concatenate
329        let mut result = Vec::new();
330        let mut dims: Vec<_> = by_dim.keys().cloned().collect();
331        dims.sort();
332
333        for dim in dims {
334            let arrays = &by_dim[&dim];
335            if arrays.len() == 1 {
336                let slice = arrays[0].as_slice().ok_or_else(|| {
337                    InferenceError::ForwardError(
338                        "Array data not contiguous in max pooling".to_string(),
339                    )
340                })?;
341                result.extend_from_slice(slice);
342            } else {
343                // Stack and max pool
344                let nrows = arrays.len();
345                let ncols = dim;
346                let mut stacked = Array2::zeros((nrows, ncols));
347                for (i, arr) in arrays.iter().enumerate() {
348                    for (j, &val) in arr.iter().enumerate() {
349                        stacked[[i, j]] = val;
350                    }
351                }
352                let pooled = stacked.map_axis(Axis(0), |col| {
353                    col.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
354                });
355                let slice = pooled.as_slice().ok_or_else(|| {
356                    InferenceError::ForwardError(
357                        "Array data not contiguous in max pooling result".to_string(),
358                    )
359                })?;
360                result.extend_from_slice(slice);
361            }
362        }
363
364        Ok(Array1::from_vec(result))
365    }
366
367    /// Cross-attention fusion: attention-weighted combination then concatenation
368    fn cross_attention_fusion(
369        &self,
370        inputs: &HashMap<ModalityType, Array1<f32>>,
371    ) -> InferenceResult<Array1<f32>> {
372        if inputs.is_empty() {
373            return Err(InferenceError::PipelineConfig(
374                "No modalities to fuse".into(),
375            ));
376        }
377
378        if inputs.len() == 1 {
379            let single_input = inputs.values().next().ok_or_else(|| {
380                InferenceError::ForwardError("No input found in hierarchical fusion".to_string())
381            })?;
382            return Ok(single_input.clone());
383        }
384
385        // Group modalities by dimension
386        let mut by_dim: std::collections::HashMap<usize, Vec<Array1<f32>>> =
387            std::collections::HashMap::new();
388        for input in inputs.values() {
389            by_dim.entry(input.len()).or_default().push(input.clone());
390        }
391
392        // Apply attention within each dimension group, then concatenate
393        let mut result = Vec::new();
394        let mut dims: Vec<_> = by_dim.keys().cloned().collect();
395        dims.sort();
396
397        for dim in dims {
398            let modalities = &by_dim[&dim];
399            let n = modalities.len();
400
401            if n == 1 {
402                let slice = modalities[0].as_slice().ok_or_else(|| {
403                    InferenceError::ForwardError(
404                        "Array data not contiguous in hierarchical fusion".to_string(),
405                    )
406                })?;
407                result.extend_from_slice(slice);
408            } else {
409                // Compute pairwise attention scores (dot products)
410                let mut attention_weights = vec![0.0; n];
411                for i in 0..n {
412                    for j in 0..n {
413                        if i != j {
414                            let dot_product: f32 = modalities[i]
415                                .iter()
416                                .zip(modalities[j].iter())
417                                .map(|(a, b)| a * b)
418                                .sum();
419                            attention_weights[i] += dot_product.abs();
420                        }
421                    }
422                }
423
424                // Normalize attention weights
425                let total: f32 = attention_weights.iter().sum();
426                if total > 0.0 {
427                    for weight in &mut attention_weights {
428                        *weight /= total;
429                    }
430                } else {
431                    // Uniform weights if no attention signal
432                    let uniform = 1.0 / n as f32;
433                    attention_weights.fill(uniform);
434                }
435
436                // Weighted sum
437                let mut weighted_result = Array1::zeros(dim);
438                for (i, modality) in modalities.iter().enumerate() {
439                    weighted_result += &(modality * attention_weights[i]);
440                }
441                let slice = weighted_result.as_slice().ok_or_else(|| {
442                    InferenceError::ForwardError(
443                        "Array data not contiguous in cross-attention result".to_string(),
444                    )
445                })?;
446                result.extend_from_slice(slice);
447            }
448        }
449
450        Ok(Array1::from_vec(result))
451    }
452
453    /// Hierarchical fusion: multi-level combination
454    ///
455    /// Combines early fusion (direct concatenation) and weighted fusion
456    /// by blending their results element-wise.
457    fn hierarchical_fusion(
458        &mut self,
459        inputs: &HashMap<ModalityType, Array1<f32>>,
460    ) -> InferenceResult<Array1<f32>> {
461        // First level: early fusion (direct concatenation)
462        let early = self.early_fusion(inputs)?;
463
464        // Second level: weighted fusion (using modality weights)
465        let weighted = self.weighted_fusion(inputs)?;
466
467        // Both should have the same dimension
468        if early.len() != weighted.len() {
469            return Err(InferenceError::PipelineConfig(format!(
470                "Fusion dimension mismatch: early={}, weighted={}",
471                early.len(),
472                weighted.len()
473            )));
474        }
475
476        // Blend both strategies (element-wise average)
477        let result = (early + weighted) / 2.0;
478        Ok(result)
479    }
480
481    /// Reset the pipeline state
482    pub fn reset(&mut self) {
483        self.engine.reset();
484    }
485
486    /// Get the fusion strategy
487    pub fn fusion_strategy(&self) -> FusionStrategy {
488        self.fusion_strategy
489    }
490
491    /// Get modality configurations
492    pub fn modalities(&self) -> &HashMap<ModalityType, ModalityConfig> {
493        &self.modalities
494    }
495
496    /// Get the underlying engine
497    pub fn engine(&self) -> &InferenceEngine {
498        &self.engine
499    }
500
501    /// Get mutable access to the engine
502    pub fn engine_mut(&mut self) -> &mut InferenceEngine {
503        &mut self.engine
504    }
505}
506
507/// Builder for multi-modal pipelines
508pub struct MultiModalPipelineBuilder {
509    engine_config: Option<EngineConfig>,
510    modalities: HashMap<ModalityType, ModalityConfig>,
511    fusion_strategy: FusionStrategy,
512}
513
514impl MultiModalPipelineBuilder {
515    /// Create a new builder
516    pub fn new() -> Self {
517        Self {
518            engine_config: None,
519            modalities: HashMap::new(),
520            fusion_strategy: FusionStrategy::default(),
521        }
522    }
523
524    /// Set the engine configuration
525    pub fn engine_config(mut self, config: EngineConfig) -> Self {
526        self.engine_config = Some(config);
527        self
528    }
529
530    /// Add a modality configuration
531    pub fn add_modality(mut self, config: ModalityConfig) -> Self {
532        self.modalities.insert(config.modality_type, config);
533        self
534    }
535
536    /// Add a modality with type and dimension
537    pub fn modality(mut self, modality_type: ModalityType, input_dim: usize) -> Self {
538        let config = ModalityConfig::new(modality_type, input_dim);
539        self.modalities.insert(modality_type, config);
540        self
541    }
542
543    /// Set the fusion strategy
544    pub fn fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
545        self.fusion_strategy = strategy;
546        self
547    }
548
549    /// Build the pipeline
550    pub fn build(self) -> InferenceResult<MultiModalPipeline> {
551        if self.modalities.is_empty() {
552            return Err(InferenceError::PipelineConfig(
553                "At least one modality must be configured".into(),
554            ));
555        }
556
557        // Calculate total input dimension
558        let total_input_dim: usize = self.modalities.values().map(|c| c.input_dim).sum();
559
560        let engine_config = self
561            .engine_config
562            .ok_or_else(|| InferenceError::PipelineConfig("engine_config not set".into()))?;
563
564        let engine = InferenceEngine::new(engine_config);
565
566        Ok(MultiModalPipeline {
567            engine,
568            modalities: self.modalities,
569            fusion_strategy: self.fusion_strategy,
570            total_input_dim,
571        })
572    }
573}
574
575impl Default for MultiModalPipelineBuilder {
576    fn default() -> Self {
577        Self::new()
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use kizzasi_model::s4::{S4Config, S4D};

    /// Build a small S4 model matching the given input dimension for tests.
    fn create_test_model(input_dim: usize, _output_dim: usize) -> Box<S4D> {
        let config = S4Config::new()
            .input_dim(input_dim)
            .hidden_dim(32)
            .state_dim(16)
            .num_layers(1)
            .diagonal(true);
        Box::new(S4D::new(config).unwrap())
    }

    #[test]
    fn test_modality_type_name() {
        assert_eq!(ModalityType::Audio.name(), "audio");
        assert_eq!(ModalityType::Video.name(), "video");
        assert_eq!(ModalityType::Sensor.name(), "sensor");
        assert_eq!(ModalityType::Text.name(), "text");
        assert_eq!(ModalityType::Custom("xyz").name(), "xyz");
    }

    #[test]
    fn test_multimodal_builder() {
        let engine_config = EngineConfig::new(6, 10);
        let pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .build();

        assert!(pipeline.is_ok());
        let p = pipeline.unwrap();
        assert_eq!(p.modalities().len(), 2);
        assert_eq!(p.total_input_dim, 6);
    }

    #[test]
    fn test_multimodal_no_modalities() {
        let engine_config = EngineConfig::new(3, 10);
        let result = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn test_early_fusion() {
        let engine_config = EngineConfig::new(6, 6);
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(6, 6));

        let audio = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let video = Array1::from_vec(vec![0.4, 0.5, 0.6]);

        let result =
            pipeline.forward(&[(ModalityType::Audio, audio), (ModalityType::Video, video)]);

        assert!(result.is_ok());
    }

    #[test]
    fn test_weighted_fusion() {
        let engine_config = EngineConfig::new(4, 4);

        let audio_config = ModalityConfig::new(ModalityType::Audio, 2).fusion_weight(2.0);
        let video_config = ModalityConfig::new(ModalityType::Video, 2).fusion_weight(1.0);

        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .add_modality(audio_config)
            .add_modality(video_config)
            .fusion_strategy(FusionStrategy::WeightedFusion)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(4, 4));

        let audio = Array1::from_vec(vec![0.3, 0.6]);
        let video = Array1::from_vec(vec![0.1, 0.2]);

        let result =
            pipeline.forward(&[(ModalityType::Audio, audio), (ModalityType::Video, video)]);

        assert!(result.is_ok());
    }

    #[test]
    fn test_dimension_mismatch() {
        let engine_config = EngineConfig::new(6, 10);
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .build()
            .unwrap();

        let audio = Array1::from_vec(vec![0.1, 0.2]); // Wrong dimension!
        let video = Array1::from_vec(vec![0.4, 0.5, 0.6]);

        let result =
            pipeline.forward(&[(ModalityType::Audio, audio), (ModalityType::Video, video)]);

        assert!(result.is_err());
    }

    #[test]
    fn test_unknown_modality() {
        let engine_config = EngineConfig::new(3, 10);
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 3)
            .build()
            .unwrap();

        let video = Array1::from_vec(vec![0.4, 0.5, 0.6]);

        let result = pipeline.forward(&[(ModalityType::Video, video)]);
        assert!(result.is_err());
    }

    #[test]
    fn test_max_pooling_fusion() {
        let engine_config = EngineConfig::new(3, 3);
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Video, 3)
            .fusion_strategy(FusionStrategy::MaxPooling)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(3, 3));

        let audio = Array1::from_vec(vec![0.1, 0.9, 0.3]);
        let video = Array1::from_vec(vec![0.8, 0.2, 0.6]);

        let result =
            pipeline.forward(&[(ModalityType::Audio, audio), (ModalityType::Video, video)]);

        assert!(result.is_ok());
    }

    #[test]
    fn test_cross_attention_fusion() {
        let engine_config = EngineConfig::new(3, 3);
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 3)
            .modality(ModalityType::Sensor, 3)
            .fusion_strategy(FusionStrategy::CrossAttention)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(3, 3));

        let audio = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let sensor = Array1::from_vec(vec![0.4, 0.5, 0.6]);

        let result =
            pipeline.forward(&[(ModalityType::Audio, audio), (ModalityType::Sensor, sensor)]);

        assert!(result.is_ok());
    }

    #[test]
    fn test_hierarchical_fusion() {
        let engine_config = EngineConfig::new(4, 4);
        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .modality(ModalityType::Audio, 2)
            .modality(ModalityType::Text, 2)
            .fusion_strategy(FusionStrategy::Hierarchical)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(4, 4));

        let audio = Array1::from_vec(vec![0.1, 0.2]);
        let text = Array1::from_vec(vec![0.3, 0.4]);

        let result = pipeline.forward(&[(ModalityType::Audio, audio), (ModalityType::Text, text)]);

        assert!(result.is_ok());
    }

    #[test]
    fn test_modality_preprocessor() {
        let engine_config = EngineConfig::new(3, 3);

        // Preprocessor that doubles every element of its input.
        let preprocessor: ModalityPreprocessor = Arc::new(|input| Ok(input.mapv(|x| x * 2.0)));

        let config = ModalityConfig::new(ModalityType::Audio, 3).preprocessor(preprocessor);

        let mut pipeline = MultiModalPipeline::builder()
            .engine_config(engine_config)
            .add_modality(config)
            .build()
            .unwrap();

        pipeline.engine_mut().set_model(create_test_model(3, 3));

        let audio = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let result = pipeline.forward(&[(ModalityType::Audio, audio)]);

        assert!(result.is_ok());
    }
}