Skip to main content

kizzasi_model/
multimodal.rs

1//! Multi-Modal Input Fusion
2//!
3//! This module enables cross-modal signal processing by combining multiple
4//! signal streams (audio, vision, sensors, control, text) into a unified
5//! representation for prediction.
6//!
7//! # Architecture
8//!
9//! ```text
10//! Audio  ──→ [Encoder_A] ──┐
11//! Vision ──→ [Encoder_V] ──┼──→ [FusionLayer] ──→ [OutputProj] ──→ Prediction
12//! Sensor ──→ [Encoder_S] ──┘
13//! ```
14//!
15//! # Fusion Strategies
16//!
17//! - **Concatenation**: Concat all encoded modalities, project to fusion_dim
18//! - **Addition**: Element-wise sum (all modalities share same projection_dim)
19//! - **Gated**: Learned sigmoid gates control each modality's contribution
20//! - **CrossAttention**: Each modality attends to all others via scaled dot-product
21//! - **Bottleneck**: Concat → compress → expand through bottleneck layer
22//!
23//! # Missing Modalities
24//!
25//! The model gracefully handles missing modalities by substituting zero vectors,
26//! enabling robust inference when some sensor streams are unavailable.
27//!
28//! # Stream Alignment
29//!
30//! [`ModalityAligner`] synchronizes streams with different sample rates by
31//! buffering and releasing aligned frames based on sample rate ratios.
32
33use crate::error::{ModelError, ModelResult};
34use crate::{AutoregressiveModel, ModelType};
35use kizzasi_core::{sigmoid, CoreResult, HiddenState, SignalPredictor};
36use scirs2_core::ndarray::{Array1, Array2};
37
38#[allow(unused_imports)]
39use tracing::{debug, instrument, trace};
40
41// ---------------------------------------------------------------------------
42// Seeded deterministic RNG (same pattern as rwkv7.rs)
43// ---------------------------------------------------------------------------
44
45/// Simple xorshift64 PRNG for deterministic weight initialization.
46struct SeededRng {
47    state: u64,
48}
49
50impl SeededRng {
51    fn new(seed: u64) -> Self {
52        Self { state: seed.max(1) }
53    }
54
55    /// Returns a float in [-1, 1)
56    fn next_f32(&mut self) -> f32 {
57        self.state ^= self.state << 13;
58        self.state ^= self.state >> 7;
59        self.state ^= self.state << 17;
60        (self.state as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
61    }
62}
63
64// ---------------------------------------------------------------------------
65// Modality type
66// ---------------------------------------------------------------------------
67
/// Modality type identifier.
///
/// Derives `Eq + Hash` so a `Modality` can be used as a map/set key when
/// routing streams. A human-readable form is provided via `Display`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Audio signal stream
    Audio,
    /// Vision / image stream
    Vision,
    /// Generic sensor data
    Sensor,
    /// Control / action signals
    Control,
    /// Text / token embeddings
    Text,
    /// User-defined modality, identified by an arbitrary name
    Custom(String),
}
84
85impl std::fmt::Display for Modality {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            Modality::Audio => write!(f, "Audio"),
89            Modality::Vision => write!(f, "Vision"),
90            Modality::Sensor => write!(f, "Sensor"),
91            Modality::Control => write!(f, "Control"),
92            Modality::Text => write!(f, "Text"),
93            Modality::Custom(name) => write!(f, "Custom({name})"),
94        }
95    }
96}
97
98// ---------------------------------------------------------------------------
99// Modality encoder
100// ---------------------------------------------------------------------------
101
/// Configuration for a single modality encoder
#[derive(Debug, Clone)]
pub struct ModalityEncoderConfig {
    /// Which modality this encoder handles
    pub modality: Modality,
    /// Raw input dimension for this modality
    pub input_dim: usize,
    /// Dimension after projection. Must equal the model's `fusion_dim`;
    /// `MultiModalModel::new` rejects configs where it does not.
    pub projection_dim: usize,
    /// Depth of the encoder MLP (number of linear layers; must be > 0)
    pub num_layers: usize,
}
114
/// Encodes a single modality's input into a shared representation space
pub struct ModalityEncoder {
    config: ModalityEncoderConfig,
    /// (weight, bias) for each MLP layer. Weights are `(in_dim, out_dim)`,
    /// applied row-vector style as `x.dot(W) + b`.
    layers: Vec<(Array2<f32>, Array1<f32>)>,
    /// Optional layer normalization (gamma, beta) applied to the final output
    norm: Option<(Array1<f32>, Array1<f32>)>,
}
123
impl ModalityEncoder {
    /// Create a new modality encoder with deterministic weight initialization.
    ///
    /// Weights use He-style scaling (`sqrt(2 / fan_in)`) over a seeded
    /// xorshift PRNG, so construction is fully reproducible.
    ///
    /// NOTE(review): the seed is derived only from the dimensions, so two
    /// encoders with identical `(input_dim, projection_dim)` — even for
    /// different modalities — receive identical weights. Confirm intended.
    ///
    /// # Errors
    /// `ModelError::invalid_config` if `input_dim`, `projection_dim`, or
    /// `num_layers` is zero.
    pub fn new(config: ModalityEncoderConfig) -> ModelResult<Self> {
        if config.input_dim == 0 {
            return Err(ModelError::invalid_config("input_dim must be > 0"));
        }
        if config.projection_dim == 0 {
            return Err(ModelError::invalid_config("projection_dim must be > 0"));
        }
        if config.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }

        let mut rng =
            SeededRng::new(42 + config.input_dim as u64 * 7 + config.projection_dim as u64 * 13);

        let mut layers = Vec::with_capacity(config.num_layers);

        for i in 0..config.num_layers {
            // First layer maps input_dim -> projection_dim; the rest are square.
            let (in_dim, out_dim) = if i == 0 {
                (config.input_dim, config.projection_dim)
            } else {
                (config.projection_dim, config.projection_dim)
            };

            let scale = (2.0 / in_dim as f32).sqrt();
            let weight = Array2::from_shape_fn((in_dim, out_dim), |_| rng.next_f32() * scale);
            let bias = Array1::zeros(out_dim);
            layers.push((weight, bias));
        }

        // Layer normalization on the final output
        let gamma = Array1::ones(config.projection_dim);
        let beta = Array1::zeros(config.projection_dim);
        let norm = Some((gamma, beta));

        Ok(Self {
            config,
            layers,
            norm,
        })
    }

    /// Encode an input vector into the shared representation space.
    ///
    /// Applies the MLP (ReLU between layers, linear last layer) followed by
    /// layer normalization of the final activations.
    ///
    /// # Errors
    /// Dimension-mismatch error if `input.len() != input_dim`.
    pub fn encode(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        if input.len() != self.config.input_dim {
            return Err(ModelError::dimension_mismatch(
                format!("ModalityEncoder({}) input", self.config.modality),
                self.config.input_dim,
                input.len(),
            ));
        }

        let mut x = input.clone();
        for (i, (weight, bias)) in self.layers.iter().enumerate() {
            x = x.dot(weight) + bias;
            // Apply ReLU activation on all but the last layer
            if i + 1 < self.layers.len() {
                x.mapv_inplace(|v| v.max(0.0));
            }
        }

        // Apply layer normalization
        if let Some((gamma, beta)) = &self.norm {
            x = layer_norm_1d(&x, gamma, beta);
        }

        Ok(x)
    }

    /// Get the input dimension.
    pub fn input_dim(&self) -> usize {
        self.config.input_dim
    }

    /// Get the output (projection) dimension.
    pub fn output_dim(&self) -> usize {
        self.config.projection_dim
    }
}
204
205/// Simple layer normalization for a 1-D vector.
206fn layer_norm_1d(x: &Array1<f32>, gamma: &Array1<f32>, beta: &Array1<f32>) -> Array1<f32> {
207    let n = x.len() as f32;
208    let mean = x.sum() / n;
209    let var = x.mapv(|v| (v - mean).powi(2)).sum() / n;
210    let std_inv = 1.0 / (var + 1e-5_f32).sqrt();
211    let normalized = x.mapv(|v| (v - mean) * std_inv);
212    &normalized * gamma + beta
213}
214
215// ---------------------------------------------------------------------------
216// Fusion strategy
217// ---------------------------------------------------------------------------
218
/// Fusion strategy for combining multiple modalities.
///
/// Each variant corresponds to a `fuse_*` method on [`FusionLayer`].
#[derive(Debug, Clone)]
pub enum FusionStrategy {
    /// Simple concatenation followed by linear projection back to fusion_dim
    Concatenation,
    /// Element-wise addition (all modalities must share same projection_dim);
    /// introduces no learned parameters
    Addition,
    /// Gated fusion: learned sigmoid gates control each modality's contribution
    Gated,
    /// Cross-attention: each modality attends to all others
    CrossAttention {
        /// Number of attention heads; must divide fusion_dim evenly
        num_heads: usize,
    },
    /// Bottleneck: concat → compress → expand
    Bottleneck {
        /// Inner bottleneck dimension (must be > 0)
        bottleneck_dim: usize,
    },
}
239
240// ---------------------------------------------------------------------------
241// Fusion layer
242// ---------------------------------------------------------------------------
243
/// Multi-modal fusion layer implementing various fusion strategies.
///
/// Only the parameter fields for the configured strategy are populated;
/// all others stay `None`.
pub struct FusionLayer {
    strategy: FusionStrategy,
    fusion_dim: usize,
    num_modalities: usize,
    // Concatenation: (fusion_dim * num_modalities, fusion_dim) projection
    concat_proj: Option<Array2<f32>>,
    // Gated fusion: one (fusion_dim, fusion_dim) gate matrix per modality
    gate_weights: Option<Vec<Array2<f32>>>,
    // Cross-attention: one Q/K/V projection per modality
    attention_q: Option<Vec<Array2<f32>>>,
    attention_k: Option<Vec<Array2<f32>>>,
    attention_v: Option<Vec<Array2<f32>>>,
    // Bottleneck: down (concat → bottleneck) and up (bottleneck → fusion_dim)
    bottleneck_down: Option<Array2<f32>>,
    bottleneck_up: Option<Array2<f32>>,
}
261
262impl FusionLayer {
263    /// Create a new fusion layer for the given strategy.
264    pub fn new(
265        strategy: FusionStrategy,
266        num_modalities: usize,
267        fusion_dim: usize,
268    ) -> ModelResult<Self> {
269        if num_modalities == 0 {
270            return Err(ModelError::invalid_config("num_modalities must be > 0"));
271        }
272        if fusion_dim == 0 {
273            return Err(ModelError::invalid_config("fusion_dim must be > 0"));
274        }
275
276        let mut rng = SeededRng::new(1337 + num_modalities as u64 * 11 + fusion_dim as u64 * 3);
277
278        let mut layer = Self {
279            strategy: strategy.clone(),
280            fusion_dim,
281            num_modalities,
282            concat_proj: None,
283            gate_weights: None,
284            attention_q: None,
285            attention_k: None,
286            attention_v: None,
287            bottleneck_down: None,
288            bottleneck_up: None,
289        };
290
291        match &strategy {
292            FusionStrategy::Concatenation => {
293                let concat_dim = fusion_dim * num_modalities;
294                let scale = (2.0 / concat_dim as f32).sqrt();
295                let proj =
296                    Array2::from_shape_fn((concat_dim, fusion_dim), |_| rng.next_f32() * scale);
297                layer.concat_proj = Some(proj);
298            }
299            FusionStrategy::Addition => {
300                // No extra parameters needed
301            }
302            FusionStrategy::Gated => {
303                let scale = (2.0 / fusion_dim as f32).sqrt();
304                let gates: Vec<Array2<f32>> = (0..num_modalities)
305                    .map(|_| {
306                        Array2::from_shape_fn((fusion_dim, fusion_dim), |_| rng.next_f32() * scale)
307                    })
308                    .collect();
309                layer.gate_weights = Some(gates);
310            }
311            FusionStrategy::CrossAttention { num_heads } => {
312                if !fusion_dim.is_multiple_of(*num_heads) {
313                    return Err(ModelError::invalid_config(format!(
314                        "fusion_dim ({fusion_dim}) must be divisible by num_heads ({num_heads})"
315                    )));
316                }
317                let scale = (2.0 / fusion_dim as f32).sqrt();
318                let make_projs = |rng: &mut SeededRng| -> Vec<Array2<f32>> {
319                    (0..num_modalities)
320                        .map(|_| {
321                            Array2::from_shape_fn((fusion_dim, fusion_dim), |_| {
322                                rng.next_f32() * scale
323                            })
324                        })
325                        .collect()
326                };
327                layer.attention_q = Some(make_projs(&mut rng));
328                layer.attention_k = Some(make_projs(&mut rng));
329                layer.attention_v = Some(make_projs(&mut rng));
330            }
331            FusionStrategy::Bottleneck { bottleneck_dim } => {
332                if *bottleneck_dim == 0 {
333                    return Err(ModelError::invalid_config("bottleneck_dim must be > 0"));
334                }
335                let concat_dim = fusion_dim * num_modalities;
336                let scale_down = (2.0 / concat_dim as f32).sqrt();
337                let scale_up = (2.0 / *bottleneck_dim as f32).sqrt();
338                layer.bottleneck_down =
339                    Some(Array2::from_shape_fn((concat_dim, *bottleneck_dim), |_| {
340                        rng.next_f32() * scale_down
341                    }));
342                layer.bottleneck_up =
343                    Some(Array2::from_shape_fn((*bottleneck_dim, fusion_dim), |_| {
344                        rng.next_f32() * scale_up
345                    }));
346            }
347        }
348
349        Ok(layer)
350    }
351
352    /// Fuse encoded modality representations into a single vector.
353    pub fn fuse(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
354        if encoded_modalities.len() != self.num_modalities {
355            return Err(ModelError::dimension_mismatch(
356                "FusionLayer modality count",
357                self.num_modalities,
358                encoded_modalities.len(),
359            ));
360        }
361
362        // Validate dimensions
363        for (i, enc) in encoded_modalities.iter().enumerate() {
364            if enc.len() != self.fusion_dim {
365                return Err(ModelError::dimension_mismatch(
366                    format!("FusionLayer modality {i} dim"),
367                    self.fusion_dim,
368                    enc.len(),
369                ));
370            }
371        }
372
373        match &self.strategy {
374            FusionStrategy::Concatenation => self.fuse_concatenation(encoded_modalities),
375            FusionStrategy::Addition => self.fuse_addition(encoded_modalities),
376            FusionStrategy::Gated => self.fuse_gated(encoded_modalities),
377            FusionStrategy::CrossAttention { num_heads } => {
378                self.fuse_cross_attention(encoded_modalities, *num_heads)
379            }
380            FusionStrategy::Bottleneck { .. } => self.fuse_bottleneck(encoded_modalities),
381        }
382    }
383
384    fn fuse_concatenation(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
385        let concat_dim = self.fusion_dim * self.num_modalities;
386        let mut concat = Array1::zeros(concat_dim);
387        for (i, enc) in encoded_modalities.iter().enumerate() {
388            let start = i * self.fusion_dim;
389            for (j, &val) in enc.iter().enumerate() {
390                concat[start + j] = val;
391            }
392        }
393        let proj = self
394            .concat_proj
395            .as_ref()
396            .ok_or_else(|| ModelError::not_initialized("concat_proj"))?;
397        Ok(concat.dot(proj))
398    }
399
400    fn fuse_addition(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
401        let mut result = Array1::zeros(self.fusion_dim);
402        for enc in encoded_modalities {
403            result += enc;
404        }
405        Ok(result)
406    }
407
408    fn fuse_gated(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
409        let gate_weights = self
410            .gate_weights
411            .as_ref()
412            .ok_or_else(|| ModelError::not_initialized("gate_weights"))?;
413
414        let mut result = Array1::zeros(self.fusion_dim);
415        for (i, enc) in encoded_modalities.iter().enumerate() {
416            let pre_gate = enc.dot(&gate_weights[i]);
417            let gate = sigmoid(&pre_gate);
418            result += &(enc * &gate);
419        }
420        Ok(result)
421    }
422
423    fn fuse_cross_attention(
424        &self,
425        encoded_modalities: &[Array1<f32>],
426        num_heads: usize,
427    ) -> ModelResult<Array1<f32>> {
428        let q_projs = self
429            .attention_q
430            .as_ref()
431            .ok_or_else(|| ModelError::not_initialized("attention_q"))?;
432        let k_projs = self
433            .attention_k
434            .as_ref()
435            .ok_or_else(|| ModelError::not_initialized("attention_k"))?;
436        let v_projs = self
437            .attention_v
438            .as_ref()
439            .ok_or_else(|| ModelError::not_initialized("attention_v"))?;
440
441        let head_dim = self.fusion_dim / num_heads;
442        let scale = 1.0 / (head_dim as f32).sqrt();
443        let n = self.num_modalities;
444
445        // For each modality, compute attention over all other modalities and sum
446        let mut fused = Array1::zeros(self.fusion_dim);
447
448        for i in 0..n {
449            let q = encoded_modalities[i].dot(&q_projs[i]);
450
451            // Compute attention scores against all modalities (including self)
452            let mut attn_output: Array1<f32> = Array1::zeros(self.fusion_dim);
453
454            // Per-head attention
455            for h in 0..num_heads {
456                let h_start = h * head_dim;
457                let h_end = h_start + head_dim;
458
459                let q_h = q.slice(scirs2_core::ndarray::s![h_start..h_end]);
460
461                // Compute scores for all modalities
462                let mut scores = Vec::with_capacity(n);
463                let mut values = Vec::with_capacity(n);
464                for j in 0..n {
465                    let k = encoded_modalities[j].dot(&k_projs[j]);
466                    let v = encoded_modalities[j].dot(&v_projs[j]);
467                    let k_h = k.slice(scirs2_core::ndarray::s![h_start..h_end]);
468                    let score = q_h.dot(&k_h) * scale;
469                    scores.push(score);
470                    values.push(v.slice(scirs2_core::ndarray::s![h_start..h_end]).to_owned());
471                }
472
473                // Softmax over scores
474                let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
475                let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
476                let sum_exp: f32 = exp_scores.iter().sum();
477                let sum_exp_safe = if sum_exp.abs() < 1e-10 {
478                    1e-10
479                } else {
480                    sum_exp
481                };
482
483                // Weighted sum of values
484                for (j, v_h) in values.iter().enumerate() {
485                    let weight = exp_scores[j] / sum_exp_safe;
486                    for (k, &val) in v_h.iter().enumerate() {
487                        attn_output[h_start + k] += weight * val;
488                    }
489                }
490            }
491
492            fused = fused + attn_output;
493        }
494
495        // Average over modalities
496        let divisor = n as f32;
497        fused.mapv_inplace(|v| v / divisor);
498
499        Ok(fused)
500    }
501
502    fn fuse_bottleneck(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
503        let down = self
504            .bottleneck_down
505            .as_ref()
506            .ok_or_else(|| ModelError::not_initialized("bottleneck_down"))?;
507        let up = self
508            .bottleneck_up
509            .as_ref()
510            .ok_or_else(|| ModelError::not_initialized("bottleneck_up"))?;
511
512        // Concatenate
513        let concat_dim = self.fusion_dim * self.num_modalities;
514        let mut concat = Array1::zeros(concat_dim);
515        for (i, enc) in encoded_modalities.iter().enumerate() {
516            let start = i * self.fusion_dim;
517            for (j, &val) in enc.iter().enumerate() {
518                concat[start + j] = val;
519            }
520        }
521
522        // Down-project → ReLU → Up-project
523        let bottleneck = concat.dot(down);
524        let activated = bottleneck.mapv(|v| v.max(0.0));
525        Ok(activated.dot(up))
526    }
527}
528
529// ---------------------------------------------------------------------------
530// Multi-modal configuration and model
531// ---------------------------------------------------------------------------
532
/// Configuration for the multi-modal fusion model
#[derive(Debug, Clone)]
pub struct MultiModalConfig {
    /// Shared representation dimension (all encoders project to this)
    pub fusion_dim: usize,
    /// Fusion strategy
    pub fusion_strategy: FusionStrategy,
    /// Final output dimension
    pub output_dim: usize,
    /// Per-modality encoder configurations. Each entry's `projection_dim`
    /// must equal `fusion_dim`; `MultiModalModel::new` validates this.
    pub modalities: Vec<ModalityEncoderConfig>,
    /// Context window length (reported via `SignalPredictor::context_window`)
    pub context_length: usize,
}
547
/// Complete multi-modal model combining encoders, fusion, and output projection
pub struct MultiModalModel {
    /// Model configuration
    pub config: MultiModalConfig,
    /// One encoder per entry of `config.modalities`, in the same order
    encoders: Vec<ModalityEncoder>,
    fusion: FusionLayer,
    /// Linear output head: (fusion_dim, output_dim) weight ...
    output_proj: Array2<f32>,
    /// ... and its bias
    output_bias: Array1<f32>,
    /// Fused hidden state (last fusion output; zeros after reset)
    state: Array1<f32>,
}
559
560impl MultiModalModel {
561    /// Create a new multi-modal model from configuration.
562    pub fn new(config: MultiModalConfig) -> ModelResult<Self> {
563        if config.modalities.is_empty() {
564            return Err(ModelError::invalid_config(
565                "at least one modality is required",
566            ));
567        }
568        if config.fusion_dim == 0 {
569            return Err(ModelError::invalid_config("fusion_dim must be > 0"));
570        }
571        if config.output_dim == 0 {
572            return Err(ModelError::invalid_config("output_dim must be > 0"));
573        }
574        if config.context_length == 0 {
575            return Err(ModelError::invalid_config("context_length must be > 0"));
576        }
577
578        // Validate that all modalities project to fusion_dim
579        for mc in &config.modalities {
580            if mc.projection_dim != config.fusion_dim {
581                return Err(ModelError::invalid_config(format!(
582                    "modality {} projection_dim ({}) must match fusion_dim ({})",
583                    mc.modality, mc.projection_dim, config.fusion_dim
584                )));
585            }
586        }
587
588        // Build encoders
589        let encoders: Vec<ModalityEncoder> = config
590            .modalities
591            .iter()
592            .map(|mc| ModalityEncoder::new(mc.clone()))
593            .collect::<ModelResult<Vec<_>>>()?;
594
595        // Build fusion layer
596        let fusion = FusionLayer::new(
597            config.fusion_strategy.clone(),
598            config.modalities.len(),
599            config.fusion_dim,
600        )?;
601
602        // Output projection
603        let mut rng = SeededRng::new(99 + config.fusion_dim as u64 * 5);
604        let scale = (2.0 / config.fusion_dim as f32).sqrt();
605        let output_proj = Array2::from_shape_fn((config.fusion_dim, config.output_dim), |_| {
606            rng.next_f32() * scale
607        });
608        let output_bias = Array1::zeros(config.output_dim);
609
610        let state = Array1::zeros(config.fusion_dim);
611
612        Ok(Self {
613            config,
614            encoders,
615            fusion,
616            output_proj,
617            output_bias,
618            state,
619        })
620    }
621
622    /// Forward pass with all modality inputs present.
623    ///
624    /// `inputs` must contain one `Array1<f32>` per modality, in the same order
625    /// as `config.modalities`.
626    pub fn forward_multimodal(&mut self, inputs: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
627        if inputs.len() != self.encoders.len() {
628            return Err(ModelError::dimension_mismatch(
629                "MultiModalModel input count",
630                self.encoders.len(),
631                inputs.len(),
632            ));
633        }
634
635        // Encode each modality
636        let encoded: Vec<Array1<f32>> = self
637            .encoders
638            .iter()
639            .zip(inputs.iter())
640            .map(|(enc, inp)| enc.encode(inp))
641            .collect::<ModelResult<Vec<_>>>()?;
642
643        // Fuse
644        let fused = self.fusion.fuse(&encoded)?;
645
646        // Check for NaN / Inf
647        if fused.iter().any(|v| v.is_nan() || v.is_infinite()) {
648            return Err(ModelError::numerical_instability(
649                "forward_multimodal",
650                "NaN or Inf detected after fusion",
651            ));
652        }
653
654        // Update state
655        self.state = fused.clone();
656
657        // Output projection
658        let output = fused.dot(&self.output_proj) + &self.output_bias;
659        Ok(output)
660    }
661
662    /// Forward pass with optional modalities.
663    ///
664    /// Missing modalities (None) are replaced with zero vectors.
665    pub fn forward_with_missing(
666        &mut self,
667        inputs: &[Option<Array1<f32>>],
668    ) -> ModelResult<Array1<f32>> {
669        if inputs.len() != self.encoders.len() {
670            return Err(ModelError::dimension_mismatch(
671                "MultiModalModel input count",
672                self.encoders.len(),
673                inputs.len(),
674            ));
675        }
676
677        // Encode each present modality; use zeros for missing ones
678        let encoded: Vec<Array1<f32>> = self
679            .encoders
680            .iter()
681            .zip(inputs.iter())
682            .map(|(enc, maybe_inp)| match maybe_inp {
683                Some(inp) => enc.encode(inp),
684                None => Ok(Array1::zeros(enc.output_dim())),
685            })
686            .collect::<ModelResult<Vec<_>>>()?;
687
688        // Fuse
689        let fused = self.fusion.fuse(&encoded)?;
690
691        if fused.iter().any(|v| v.is_nan() || v.is_infinite()) {
692            return Err(ModelError::numerical_instability(
693                "forward_with_missing",
694                "NaN or Inf detected after fusion",
695            ));
696        }
697
698        self.state = fused.clone();
699
700        let output = fused.dot(&self.output_proj) + &self.output_bias;
701        Ok(output)
702    }
703
704    /// Number of modalities.
705    pub fn num_modalities(&self) -> usize {
706        self.encoders.len()
707    }
708
709    /// References to each modality identifier.
710    pub fn modality_names(&self) -> Vec<&Modality> {
711        self.config
712            .modalities
713            .iter()
714            .map(|mc| &mc.modality)
715            .collect()
716    }
717
718    /// Total trainable parameter count.
719    pub fn total_params(&self) -> usize {
720        let mut count = 0usize;
721
722        // Encoder parameters
723        for enc in &self.encoders {
724            for (w, b) in &enc.layers {
725                count += w.len() + b.len();
726            }
727            if let Some((g, b)) = &enc.norm {
728                count += g.len() + b.len();
729            }
730        }
731
732        // Fusion parameters
733        if let Some(p) = &self.fusion.concat_proj {
734            count += p.len();
735        }
736        if let Some(gates) = &self.fusion.gate_weights {
737            for g in gates {
738                count += g.len();
739            }
740        }
741        if let Some(qs) = &self.fusion.attention_q {
742            for q in qs {
743                count += q.len();
744            }
745        }
746        if let Some(ks) = &self.fusion.attention_k {
747            for k in ks {
748                count += k.len();
749            }
750        }
751        if let Some(vs) = &self.fusion.attention_v {
752            for v in vs {
753                count += v.len();
754            }
755        }
756        if let Some(d) = &self.fusion.bottleneck_down {
757            count += d.len();
758        }
759        if let Some(u) = &self.fusion.bottleneck_up {
760            count += u.len();
761        }
762
763        // Output projection
764        count += self.output_proj.len() + self.output_bias.len();
765
766        count
767    }
768}
769
770impl SignalPredictor for MultiModalModel {
771    /// Step with a single concatenated input.
772    ///
773    /// The input is split into per-modality chunks based on each encoder's
774    /// `input_dim`, concatenated in the order of `config.modalities`.
775    #[instrument(skip(self, input))]
776    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
777        // Split input into per-modality slices
778        let total_input_dim: usize = self.encoders.iter().map(|e| e.input_dim()).sum();
779
780        if input.len() != total_input_dim {
781            return Err(kizzasi_core::CoreError::DimensionMismatch {
782                expected: total_input_dim,
783                got: input.len(),
784            });
785        }
786
787        let mut offset = 0;
788        let mut per_modality = Vec::with_capacity(self.encoders.len());
789        for enc in &self.encoders {
790            let dim = enc.input_dim();
791            let slice = input
792                .slice(scirs2_core::ndarray::s![offset..offset + dim])
793                .to_owned();
794            per_modality.push(slice);
795            offset += dim;
796        }
797
798        self.forward_multimodal(&per_modality)
799            .map_err(|e| kizzasi_core::CoreError::Generic(e.to_string()))
800    }
801
802    #[instrument(skip(self))]
803    fn reset(&mut self) {
804        debug!("Resetting MultiModalModel state");
805        self.state = Array1::zeros(self.config.fusion_dim);
806    }
807
808    fn context_window(&self) -> usize {
809        self.config.context_length
810    }
811}
812
impl AutoregressiveModel for MultiModalModel {
    fn hidden_dim(&self) -> usize {
        self.config.fusion_dim
    }

    fn state_dim(&self) -> usize {
        self.config.fusion_dim
    }

    fn num_layers(&self) -> usize {
        // 1 fusion layer
        1
    }

    fn model_type(&self) -> ModelType {
        ModelType::MultiModal
    }

    /// NOTE(review): this returns a freshly constructed `HiddenState`, not a
    /// snapshot of `self.state`, so `get_states`/`set_states` does not
    /// round-trip the actual fused state — confirm whether this is intended.
    fn get_states(&self) -> Vec<HiddenState> {
        vec![HiddenState::new(self.config.fusion_dim, 1)]
    }

    /// Validates the state count only; the incoming state payload is
    /// currently discarded (see the note on `get_states`).
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != 1 {
            return Err(ModelError::state_count_mismatch(
                "MultiModal",
                1,
                states.len(),
            ));
        }
        Ok(())
    }
}
846
847// ---------------------------------------------------------------------------
848// Modality aligner
849// ---------------------------------------------------------------------------
850
/// Synchronizes multiple modality streams with different sample rates.
///
/// The aligner buffers incoming samples and produces aligned frames only when
/// all modalities have accumulated enough data relative to the reference rate.
pub struct ModalityAligner {
    /// Rate the output frames are aligned to
    reference_rate: f32,
    /// Per-modality sample rates, parallel to `buffers`
    modality_rates: Vec<f32>,
    /// One FIFO of pending samples per modality
    buffers: Vec<Vec<Array1<f32>>>,
}
860
861impl ModalityAligner {
862    /// Create a new aligner with the given reference and per-modality sample rates.
863    pub fn new(reference_rate: f32, modality_rates: Vec<f32>) -> Self {
864        let buffers = modality_rates.iter().map(|_| Vec::new()).collect();
865        Self {
866            reference_rate,
867            modality_rates,
868            buffers,
869        }
870    }
871
872    /// Push a sample for the given modality.
873    pub fn push(&mut self, modality_idx: usize, sample: Array1<f32>) {
874        if modality_idx < self.buffers.len() {
875            self.buffers[modality_idx].push(sample);
876        }
877    }
878
879    /// Try to produce an aligned frame.
880    ///
881    /// Returns `Some(vec)` containing one sample per modality when all
882    /// modalities have buffered enough data. Otherwise returns `None`.
883    pub fn try_align(&mut self) -> Option<Vec<Array1<f32>>> {
884        // For each modality, determine how many samples are needed per
885        // reference frame: ceil(modality_rate / reference_rate).
886        // When all modalities have at least that many samples, we take the
887        // most recent sample from each and drain the consumed portion.
888
889        let mut required: Vec<usize> = Vec::with_capacity(self.modality_rates.len());
890        for rate in &self.modality_rates {
891            let ratio = rate / self.reference_rate;
892            let need = ratio.ceil().max(1.0) as usize;
893            required.push(need);
894        }
895
896        // Check if all modalities have enough buffered data
897        for (i, &need) in required.iter().enumerate() {
898            if self.buffers[i].len() < need {
899                return None;
900            }
901        }
902
903        // Consume and return the last sample of each consumed chunk
904        let mut aligned = Vec::with_capacity(self.buffers.len());
905        for (i, &need) in required.iter().enumerate() {
906            // Take the last sample from the consumed window
907            let sample = self.buffers[i][need - 1].clone();
908            // Drain consumed samples
909            self.buffers[i].drain(..need);
910            aligned.push(sample);
911        }
912
913        Some(aligned)
914    }
915
916    /// Clear all internal buffers.
917    pub fn clear(&mut self) {
918        for buf in &mut self.buffers {
919            buf.clear();
920        }
921    }
922}
923
924// ---------------------------------------------------------------------------
925// Tests
926// ---------------------------------------------------------------------------
927
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a two-layer encoder config for the given modality and dims.
    fn make_encoder_config(
        modality: Modality,
        input_dim: usize,
        proj_dim: usize,
    ) -> ModalityEncoderConfig {
        ModalityEncoderConfig {
            modality,
            input_dim,
            projection_dim: proj_dim,
            num_layers: 2,
        }
    }

    /// Three-modality fixture (audio 8-d, vision 12-d, sensor 6-d), all
    /// projected to 16 dims, additive fusion, 4-d output, context 512.
    fn make_default_config() -> MultiModalConfig {
        MultiModalConfig {
            fusion_dim: 16,
            fusion_strategy: FusionStrategy::Addition,
            output_dim: 4,
            modalities: vec![
                make_encoder_config(Modality::Audio, 8, 16),
                make_encoder_config(Modality::Vision, 12, 16),
                make_encoder_config(Modality::Sensor, 6, 16),
            ],
            context_length: 512,
        }
    }

    // 1. Encoder construction exposes the configured input/output dims.
    #[test]
    fn test_modality_encoder_creation() {
        let cfg = make_encoder_config(Modality::Audio, 8, 16);
        let enc = ModalityEncoder::new(cfg).expect("failed to create encoder");
        assert_eq!(enc.input_dim(), 8);
        assert_eq!(enc.output_dim(), 16);
    }

    // 2. Encoding maps input_dim -> projection_dim and stays finite.
    #[test]
    fn test_modality_encoder_forward() {
        let cfg = make_encoder_config(Modality::Vision, 12, 16);
        let enc = ModalityEncoder::new(cfg).expect("failed to create encoder");
        let input = Array1::from_vec(vec![0.1; 12]);
        let output = enc.encode(&input).expect("encode failed");
        assert_eq!(output.len(), 16);
        // Output should be finite
        assert!(output.iter().all(|v| v.is_finite()));
    }

    // 3. Concatenation fusion projects back down to fusion_dim.
    #[test]
    fn test_fusion_concatenation() {
        let fusion_dim = 8;
        let n = 3;
        let layer = FusionLayer::new(FusionStrategy::Concatenation, n, fusion_dim)
            .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.5; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // 4. Addition fusion is an exact element-wise sum of the inputs.
    #[test]
    fn test_fusion_addition() {
        let fusion_dim = 8;
        let n = 3;
        let layer = FusionLayer::new(FusionStrategy::Addition, n, fusion_dim)
            .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n).map(|_| Array1::ones(fusion_dim)).collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        // Each element should be 3.0 (1+1+1)
        for &v in out.iter() {
            assert!((v - 3.0).abs() < 1e-6);
        }
    }

    // 5. Gated fusion produces a finite fusion_dim-sized output.
    #[test]
    fn test_fusion_gated() {
        let fusion_dim = 8;
        let n = 2;
        let layer = FusionLayer::new(FusionStrategy::Gated, n, fusion_dim)
            .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.3; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // 6. Cross-attention fusion (2 heads) produces a finite output.
    #[test]
    fn test_fusion_cross_attention() {
        let fusion_dim = 8;
        let n = 2;
        let layer = FusionLayer::new(
            FusionStrategy::CrossAttention { num_heads: 2 },
            n,
            fusion_dim,
        )
        .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.2; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // 7. Bottleneck fusion (compress to 4, expand back) stays finite.
    #[test]
    fn test_fusion_bottleneck() {
        let fusion_dim = 8;
        let n = 3;
        let layer = FusionLayer::new(
            FusionStrategy::Bottleneck { bottleneck_dim: 4 },
            n,
            fusion_dim,
        )
        .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.4; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // 8. Model construction reports modality count, names, and params.
    #[test]
    fn test_multimodal_model_creation() {
        let config = make_default_config();
        let model = MultiModalModel::new(config).expect("failed to create model");
        assert_eq!(model.num_modalities(), 3);
        assert_eq!(model.modality_names().len(), 3);
        assert!(model.total_params() > 0);
    }

    // 9. Full forward pass over all three modalities yields output_dim values.
    #[test]
    fn test_multimodal_forward() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        let audio = Array1::from_vec(vec![0.1; 8]);
        let vision = Array1::from_vec(vec![0.2; 12]);
        let sensor = Array1::from_vec(vec![0.3; 6]);

        let out = model
            .forward_multimodal(&[audio, vision, sensor])
            .expect("forward failed");
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // 10. A missing modality (None) is tolerated via zero-substitution.
    #[test]
    fn test_multimodal_missing_modalities() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        let audio = Some(Array1::from_vec(vec![0.1; 8]));
        let vision = None; // missing
        let sensor = Some(Array1::from_vec(vec![0.3; 6]));

        let out = model
            .forward_with_missing(&[audio, vision, sensor])
            .expect("forward_with_missing failed");
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // 11. SignalPredictor interface: step on a flat concatenated input,
    //     then reset and check the context window.
    #[test]
    fn test_multimodal_signal_predictor() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        // Total input dim = 8 + 12 + 6 = 26
        let input = Array1::from_vec(vec![0.1; 26]);
        let out = model.step(&input).expect("step failed");
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|v| v.is_finite()));

        // Test reset
        model.reset();
        assert_eq!(model.context_window(), 512);
    }

    // 12. Aligner releases a frame only when the faster stream has buffered
    //     enough samples; the slower stream's sample is kept, the faster
    //     stream contributes its most recent consumed sample.
    #[test]
    fn test_modality_aligner() {
        // Reference rate 10 Hz, modality A at 10 Hz, modality B at 20 Hz
        let mut aligner = ModalityAligner::new(10.0, vec![10.0, 20.0]);

        // Push one sample for modality A
        aligner.push(0, Array1::from_vec(vec![1.0, 2.0]));
        // Not enough for modality B yet (needs ceil(20/10) = 2 samples)
        aligner.push(1, Array1::from_vec(vec![3.0, 4.0]));
        assert!(aligner.try_align().is_none());

        // Push second sample for modality B
        aligner.push(1, Array1::from_vec(vec![5.0, 6.0]));
        let aligned = aligner.try_align().expect("should have aligned frame");
        assert_eq!(aligned.len(), 2);
        // Modality A: the first sample [1, 2]
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        // Modality B: the last consumed sample [5, 6]
        assert!((aligned[1][0] - 5.0).abs() < 1e-6);

        // After consuming, buffers should be empty
        assert!(aligner.try_align().is_none());
    }

    // 13. Extreme magnitudes: large inputs may either succeed with finite
    //     output or report NumericalInstability; tiny inputs must succeed.
    #[test]
    fn test_multimodal_numerical_stability() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        // Test with very large inputs
        let audio_large = Array1::from_vec(vec![1e6; 8]);
        let vision_large = Array1::from_vec(vec![1e6; 12]);
        let sensor_large = Array1::from_vec(vec![1e6; 6]);
        let out = model.forward_multimodal(&[audio_large, vision_large, sensor_large]);
        // Should either succeed with finite values or return a numerical error
        match out {
            Ok(o) => assert!(o.iter().all(|v| v.is_finite()), "output should be finite"),
            Err(ModelError::NumericalInstability { .. }) => {
                // acceptable
            }
            Err(e) => panic!("unexpected error: {e}"),
        }

        // Test with very small inputs
        let audio_small = Array1::from_vec(vec![1e-30; 8]);
        let vision_small = Array1::from_vec(vec![1e-30; 12]);
        let sensor_small = Array1::from_vec(vec![1e-30; 6]);
        let out = model
            .forward_multimodal(&[audio_small, vision_small, sensor_small])
            .expect("small inputs should not cause errors");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "output should be finite for small inputs"
        );
    }

    // 14. AutoregressiveModel trait accessors reflect the fixture config.
    #[test]
    fn test_autoregressive_model_trait() {
        let config = make_default_config();
        let model = MultiModalModel::new(config).expect("failed to create model");
        assert_eq!(model.hidden_dim(), 16);
        assert_eq!(model.state_dim(), 16);
        assert_eq!(model.num_layers(), 1);
        assert_eq!(model.model_type(), ModelType::MultiModal);
        let states = model.get_states();
        assert_eq!(states.len(), 1);
    }

    // 15. Display impl for Modality, including the Custom variant.
    #[test]
    fn test_modality_display() {
        assert_eq!(format!("{}", Modality::Audio), "Audio");
        assert_eq!(format!("{}", Modality::Vision), "Vision");
        assert_eq!(
            format!("{}", Modality::Custom("Lidar".to_string())),
            "Custom(Lidar)"
        );
    }

    // 16. Encoding an input of the wrong dimension is rejected.
    #[test]
    fn test_encoder_dimension_mismatch() {
        let cfg = make_encoder_config(Modality::Audio, 8, 16);
        let enc = ModalityEncoder::new(cfg).expect("failed to create encoder");
        let bad_input = Array1::from_vec(vec![0.1; 5]); // wrong dim
        assert!(enc.encode(&bad_input).is_err());
    }

    // 17. clear() empties the buffers so no frame can be produced.
    #[test]
    fn test_aligner_clear() {
        let mut aligner = ModalityAligner::new(10.0, vec![10.0]);
        aligner.push(0, Array1::from_vec(vec![1.0]));
        aligner.clear();
        assert!(aligner.try_align().is_none());
    }
}