1use crate::error::{ModelError, ModelResult};
34use crate::{AutoregressiveModel, ModelType};
35use kizzasi_core::{sigmoid, CoreResult, HiddenState, SignalPredictor};
36use scirs2_core::ndarray::{Array1, Array2};
37
38#[allow(unused_imports)]
39use tracing::{debug, instrument, trace};
40
/// Deterministic xorshift64 pseudo-random generator used for reproducible
/// weight initialization without pulling in an external RNG dependency.
struct SeededRng {
    // Non-zero internal state; xorshift64 degenerates to all zeros otherwise.
    state: u64,
}

impl SeededRng {
    /// Creates a generator from `seed`, clamping to at least 1 so the
    /// xorshift state never starts at the degenerate all-zero value.
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advances the xorshift64 state and maps it to an `f32` in (-1, 1].
    fn next_f32(&mut self) -> f32 {
        let mut s = self.state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.state = s;
        // state is in [1, u64::MAX], so `unit` is in (0, 1].
        let unit = s as f64 / u64::MAX as f64;
        (unit * 2.0 - 1.0) as f32
    }
}
63
/// Identifies the kind of signal a modality encoder consumes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Audio signals.
    Audio,
    /// Vision signals.
    Vision,
    /// Generic sensor readings.
    Sensor,
    /// Control signals.
    Control,
    /// Text inputs.
    Text,
    /// User-defined modality identified by name.
    Custom(String),
}

impl std::fmt::Display for Modality {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Custom` carries a payload and needs formatting; every other
        // variant maps to a fixed label.
        if let Modality::Custom(name) = self {
            return write!(f, "Custom({name})");
        }
        let label = match self {
            Modality::Audio => "Audio",
            Modality::Vision => "Vision",
            Modality::Sensor => "Sensor",
            Modality::Control => "Control",
            Modality::Text => "Text",
            Modality::Custom(_) => unreachable!("handled by the early return above"),
        };
        f.write_str(label)
    }
}
97
/// Configuration for a single [`ModalityEncoder`].
#[derive(Debug, Clone)]
pub struct ModalityEncoderConfig {
    // Which modality this encoder handles (used in error messages).
    pub modality: Modality,
    // Dimensionality of the raw input vector; must be > 0.
    pub input_dim: usize,
    // Dimensionality of the encoded output; must be > 0 and, when used in a
    // MultiModalModel, must equal the model's fusion_dim.
    pub projection_dim: usize,
    // Number of linear layers in the encoder (ReLU between layers); must be > 0.
    pub num_layers: usize,
}
114
/// Projects a single modality's raw input vector into the shared fusion
/// space via a small MLP followed by layer normalization.
pub struct ModalityEncoder {
    config: ModalityEncoderConfig,
    // Per-layer (weight, bias) pairs; layer 0 maps input_dim -> projection_dim,
    // subsequent layers map projection_dim -> projection_dim.
    layers: Vec<(Array2<f32>, Array1<f32>)>,
    // Optional layer-norm (gamma, beta) applied to the final activation.
    norm: Option<(Array1<f32>, Array1<f32>)>,
}
123
124impl ModalityEncoder {
125 pub fn new(config: ModalityEncoderConfig) -> ModelResult<Self> {
127 if config.input_dim == 0 {
128 return Err(ModelError::invalid_config("input_dim must be > 0"));
129 }
130 if config.projection_dim == 0 {
131 return Err(ModelError::invalid_config("projection_dim must be > 0"));
132 }
133 if config.num_layers == 0 {
134 return Err(ModelError::invalid_config("num_layers must be > 0"));
135 }
136
137 let mut rng =
138 SeededRng::new(42 + config.input_dim as u64 * 7 + config.projection_dim as u64 * 13);
139
140 let mut layers = Vec::with_capacity(config.num_layers);
141
142 for i in 0..config.num_layers {
143 let (in_dim, out_dim) = if i == 0 {
144 (config.input_dim, config.projection_dim)
145 } else {
146 (config.projection_dim, config.projection_dim)
147 };
148
149 let scale = (2.0 / in_dim as f32).sqrt();
150 let weight = Array2::from_shape_fn((in_dim, out_dim), |_| rng.next_f32() * scale);
151 let bias = Array1::zeros(out_dim);
152 layers.push((weight, bias));
153 }
154
155 let gamma = Array1::ones(config.projection_dim);
157 let beta = Array1::zeros(config.projection_dim);
158 let norm = Some((gamma, beta));
159
160 Ok(Self {
161 config,
162 layers,
163 norm,
164 })
165 }
166
167 pub fn encode(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
169 if input.len() != self.config.input_dim {
170 return Err(ModelError::dimension_mismatch(
171 format!("ModalityEncoder({}) input", self.config.modality),
172 self.config.input_dim,
173 input.len(),
174 ));
175 }
176
177 let mut x = input.clone();
178 for (i, (weight, bias)) in self.layers.iter().enumerate() {
179 x = x.dot(weight) + bias;
180 if i + 1 < self.layers.len() {
182 x.mapv_inplace(|v| v.max(0.0));
183 }
184 }
185
186 if let Some((gamma, beta)) = &self.norm {
188 x = layer_norm_1d(&x, gamma, beta);
189 }
190
191 Ok(x)
192 }
193
194 pub fn input_dim(&self) -> usize {
196 self.config.input_dim
197 }
198
199 pub fn output_dim(&self) -> usize {
201 self.config.projection_dim
202 }
203}
204
205fn layer_norm_1d(x: &Array1<f32>, gamma: &Array1<f32>, beta: &Array1<f32>) -> Array1<f32> {
207 let n = x.len() as f32;
208 let mean = x.sum() / n;
209 let var = x.mapv(|v| (v - mean).powi(2)).sum() / n;
210 let std_inv = 1.0 / (var + 1e-5_f32).sqrt();
211 let normalized = x.mapv(|v| (v - mean) * std_inv);
212 &normalized * gamma + beta
213}
214
/// Strategy used by [`FusionLayer`] to combine per-modality encodings.
#[derive(Debug, Clone)]
pub enum FusionStrategy {
    /// Concatenate all encodings, then linearly project back to fusion_dim.
    Concatenation,
    /// Element-wise sum of the encodings (no learned parameters).
    Addition,
    /// Per-modality sigmoid gates applied before summation.
    Gated,
    /// Multi-head cross-attention where each modality attends over all modalities.
    CrossAttention {
        /// Number of attention heads; must divide fusion_dim evenly.
        num_heads: usize,
    },
    /// Concatenate, project down to a bottleneck (ReLU), then project up to fusion_dim.
    Bottleneck {
        /// Width of the intermediate bottleneck; must be > 0.
        bottleneck_dim: usize,
    },
}
239
/// Combines the encoded vectors of all modalities into a single
/// fusion_dim-long vector according to a [`FusionStrategy`]. Only the
/// parameter fields required by the chosen strategy are populated; the
/// rest remain `None`.
pub struct FusionLayer {
    strategy: FusionStrategy,
    fusion_dim: usize,
    num_modalities: usize,
    // Concatenation: (fusion_dim * num_modalities) x fusion_dim projection.
    concat_proj: Option<Array2<f32>>,
    // Gated: one fusion_dim x fusion_dim gate matrix per modality.
    gate_weights: Option<Vec<Array2<f32>>>,
    // CrossAttention: per-modality query projections.
    attention_q: Option<Vec<Array2<f32>>>,
    // CrossAttention: per-modality key projections.
    attention_k: Option<Vec<Array2<f32>>>,
    // CrossAttention: per-modality value projections.
    attention_v: Option<Vec<Array2<f32>>>,
    // Bottleneck: down-projection from the concatenated vector.
    bottleneck_down: Option<Array2<f32>>,
    // Bottleneck: up-projection back to fusion_dim.
    bottleneck_up: Option<Array2<f32>>,
}
261
262impl FusionLayer {
263 pub fn new(
265 strategy: FusionStrategy,
266 num_modalities: usize,
267 fusion_dim: usize,
268 ) -> ModelResult<Self> {
269 if num_modalities == 0 {
270 return Err(ModelError::invalid_config("num_modalities must be > 0"));
271 }
272 if fusion_dim == 0 {
273 return Err(ModelError::invalid_config("fusion_dim must be > 0"));
274 }
275
276 let mut rng = SeededRng::new(1337 + num_modalities as u64 * 11 + fusion_dim as u64 * 3);
277
278 let mut layer = Self {
279 strategy: strategy.clone(),
280 fusion_dim,
281 num_modalities,
282 concat_proj: None,
283 gate_weights: None,
284 attention_q: None,
285 attention_k: None,
286 attention_v: None,
287 bottleneck_down: None,
288 bottleneck_up: None,
289 };
290
291 match &strategy {
292 FusionStrategy::Concatenation => {
293 let concat_dim = fusion_dim * num_modalities;
294 let scale = (2.0 / concat_dim as f32).sqrt();
295 let proj =
296 Array2::from_shape_fn((concat_dim, fusion_dim), |_| rng.next_f32() * scale);
297 layer.concat_proj = Some(proj);
298 }
299 FusionStrategy::Addition => {
300 }
302 FusionStrategy::Gated => {
303 let scale = (2.0 / fusion_dim as f32).sqrt();
304 let gates: Vec<Array2<f32>> = (0..num_modalities)
305 .map(|_| {
306 Array2::from_shape_fn((fusion_dim, fusion_dim), |_| rng.next_f32() * scale)
307 })
308 .collect();
309 layer.gate_weights = Some(gates);
310 }
311 FusionStrategy::CrossAttention { num_heads } => {
312 if !fusion_dim.is_multiple_of(*num_heads) {
313 return Err(ModelError::invalid_config(format!(
314 "fusion_dim ({fusion_dim}) must be divisible by num_heads ({num_heads})"
315 )));
316 }
317 let scale = (2.0 / fusion_dim as f32).sqrt();
318 let make_projs = |rng: &mut SeededRng| -> Vec<Array2<f32>> {
319 (0..num_modalities)
320 .map(|_| {
321 Array2::from_shape_fn((fusion_dim, fusion_dim), |_| {
322 rng.next_f32() * scale
323 })
324 })
325 .collect()
326 };
327 layer.attention_q = Some(make_projs(&mut rng));
328 layer.attention_k = Some(make_projs(&mut rng));
329 layer.attention_v = Some(make_projs(&mut rng));
330 }
331 FusionStrategy::Bottleneck { bottleneck_dim } => {
332 if *bottleneck_dim == 0 {
333 return Err(ModelError::invalid_config("bottleneck_dim must be > 0"));
334 }
335 let concat_dim = fusion_dim * num_modalities;
336 let scale_down = (2.0 / concat_dim as f32).sqrt();
337 let scale_up = (2.0 / *bottleneck_dim as f32).sqrt();
338 layer.bottleneck_down =
339 Some(Array2::from_shape_fn((concat_dim, *bottleneck_dim), |_| {
340 rng.next_f32() * scale_down
341 }));
342 layer.bottleneck_up =
343 Some(Array2::from_shape_fn((*bottleneck_dim, fusion_dim), |_| {
344 rng.next_f32() * scale_up
345 }));
346 }
347 }
348
349 Ok(layer)
350 }
351
352 pub fn fuse(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
354 if encoded_modalities.len() != self.num_modalities {
355 return Err(ModelError::dimension_mismatch(
356 "FusionLayer modality count",
357 self.num_modalities,
358 encoded_modalities.len(),
359 ));
360 }
361
362 for (i, enc) in encoded_modalities.iter().enumerate() {
364 if enc.len() != self.fusion_dim {
365 return Err(ModelError::dimension_mismatch(
366 format!("FusionLayer modality {i} dim"),
367 self.fusion_dim,
368 enc.len(),
369 ));
370 }
371 }
372
373 match &self.strategy {
374 FusionStrategy::Concatenation => self.fuse_concatenation(encoded_modalities),
375 FusionStrategy::Addition => self.fuse_addition(encoded_modalities),
376 FusionStrategy::Gated => self.fuse_gated(encoded_modalities),
377 FusionStrategy::CrossAttention { num_heads } => {
378 self.fuse_cross_attention(encoded_modalities, *num_heads)
379 }
380 FusionStrategy::Bottleneck { .. } => self.fuse_bottleneck(encoded_modalities),
381 }
382 }
383
384 fn fuse_concatenation(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
385 let concat_dim = self.fusion_dim * self.num_modalities;
386 let mut concat = Array1::zeros(concat_dim);
387 for (i, enc) in encoded_modalities.iter().enumerate() {
388 let start = i * self.fusion_dim;
389 for (j, &val) in enc.iter().enumerate() {
390 concat[start + j] = val;
391 }
392 }
393 let proj = self
394 .concat_proj
395 .as_ref()
396 .ok_or_else(|| ModelError::not_initialized("concat_proj"))?;
397 Ok(concat.dot(proj))
398 }
399
400 fn fuse_addition(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
401 let mut result = Array1::zeros(self.fusion_dim);
402 for enc in encoded_modalities {
403 result += enc;
404 }
405 Ok(result)
406 }
407
408 fn fuse_gated(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
409 let gate_weights = self
410 .gate_weights
411 .as_ref()
412 .ok_or_else(|| ModelError::not_initialized("gate_weights"))?;
413
414 let mut result = Array1::zeros(self.fusion_dim);
415 for (i, enc) in encoded_modalities.iter().enumerate() {
416 let pre_gate = enc.dot(&gate_weights[i]);
417 let gate = sigmoid(&pre_gate);
418 result += &(enc * &gate);
419 }
420 Ok(result)
421 }
422
423 fn fuse_cross_attention(
424 &self,
425 encoded_modalities: &[Array1<f32>],
426 num_heads: usize,
427 ) -> ModelResult<Array1<f32>> {
428 let q_projs = self
429 .attention_q
430 .as_ref()
431 .ok_or_else(|| ModelError::not_initialized("attention_q"))?;
432 let k_projs = self
433 .attention_k
434 .as_ref()
435 .ok_or_else(|| ModelError::not_initialized("attention_k"))?;
436 let v_projs = self
437 .attention_v
438 .as_ref()
439 .ok_or_else(|| ModelError::not_initialized("attention_v"))?;
440
441 let head_dim = self.fusion_dim / num_heads;
442 let scale = 1.0 / (head_dim as f32).sqrt();
443 let n = self.num_modalities;
444
445 let mut fused = Array1::zeros(self.fusion_dim);
447
448 for i in 0..n {
449 let q = encoded_modalities[i].dot(&q_projs[i]);
450
451 let mut attn_output: Array1<f32> = Array1::zeros(self.fusion_dim);
453
454 for h in 0..num_heads {
456 let h_start = h * head_dim;
457 let h_end = h_start + head_dim;
458
459 let q_h = q.slice(scirs2_core::ndarray::s![h_start..h_end]);
460
461 let mut scores = Vec::with_capacity(n);
463 let mut values = Vec::with_capacity(n);
464 for j in 0..n {
465 let k = encoded_modalities[j].dot(&k_projs[j]);
466 let v = encoded_modalities[j].dot(&v_projs[j]);
467 let k_h = k.slice(scirs2_core::ndarray::s![h_start..h_end]);
468 let score = q_h.dot(&k_h) * scale;
469 scores.push(score);
470 values.push(v.slice(scirs2_core::ndarray::s![h_start..h_end]).to_owned());
471 }
472
473 let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
475 let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
476 let sum_exp: f32 = exp_scores.iter().sum();
477 let sum_exp_safe = if sum_exp.abs() < 1e-10 {
478 1e-10
479 } else {
480 sum_exp
481 };
482
483 for (j, v_h) in values.iter().enumerate() {
485 let weight = exp_scores[j] / sum_exp_safe;
486 for (k, &val) in v_h.iter().enumerate() {
487 attn_output[h_start + k] += weight * val;
488 }
489 }
490 }
491
492 fused = fused + attn_output;
493 }
494
495 let divisor = n as f32;
497 fused.mapv_inplace(|v| v / divisor);
498
499 Ok(fused)
500 }
501
502 fn fuse_bottleneck(&self, encoded_modalities: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
503 let down = self
504 .bottleneck_down
505 .as_ref()
506 .ok_or_else(|| ModelError::not_initialized("bottleneck_down"))?;
507 let up = self
508 .bottleneck_up
509 .as_ref()
510 .ok_or_else(|| ModelError::not_initialized("bottleneck_up"))?;
511
512 let concat_dim = self.fusion_dim * self.num_modalities;
514 let mut concat = Array1::zeros(concat_dim);
515 for (i, enc) in encoded_modalities.iter().enumerate() {
516 let start = i * self.fusion_dim;
517 for (j, &val) in enc.iter().enumerate() {
518 concat[start + j] = val;
519 }
520 }
521
522 let bottleneck = concat.dot(down);
524 let activated = bottleneck.mapv(|v| v.max(0.0));
525 Ok(activated.dot(up))
526 }
527}
528
/// Top-level configuration for [`MultiModalModel`].
#[derive(Debug, Clone)]
pub struct MultiModalConfig {
    // Shared width of the fusion space; every modality's projection_dim must equal it.
    pub fusion_dim: usize,
    // How the encoded modalities are combined.
    pub fusion_strategy: FusionStrategy,
    // Width of the model's output vector; must be > 0.
    pub output_dim: usize,
    // One encoder configuration per modality; at least one is required.
    pub modalities: Vec<ModalityEncoderConfig>,
    // Context window length reported by `SignalPredictor::context_window`; must be > 0.
    pub context_length: usize,
}
547
/// Multi-modal predictor: per-modality encoders -> fusion layer -> linear
/// output projection. The most recent fused vector is retained as state.
pub struct MultiModalModel {
    pub config: MultiModalConfig,
    // One encoder per configured modality, in configuration order.
    encoders: Vec<ModalityEncoder>,
    // Combines the encoded modalities into a single fusion_dim vector.
    fusion: FusionLayer,
    // Final linear head: fusion_dim x output_dim weights plus bias.
    output_proj: Array2<f32>,
    output_bias: Array1<f32>,
    // Most recent fused representation (fusion_dim long); reset() zeroes it.
    state: Array1<f32>,
}
559
560impl MultiModalModel {
561 pub fn new(config: MultiModalConfig) -> ModelResult<Self> {
563 if config.modalities.is_empty() {
564 return Err(ModelError::invalid_config(
565 "at least one modality is required",
566 ));
567 }
568 if config.fusion_dim == 0 {
569 return Err(ModelError::invalid_config("fusion_dim must be > 0"));
570 }
571 if config.output_dim == 0 {
572 return Err(ModelError::invalid_config("output_dim must be > 0"));
573 }
574 if config.context_length == 0 {
575 return Err(ModelError::invalid_config("context_length must be > 0"));
576 }
577
578 for mc in &config.modalities {
580 if mc.projection_dim != config.fusion_dim {
581 return Err(ModelError::invalid_config(format!(
582 "modality {} projection_dim ({}) must match fusion_dim ({})",
583 mc.modality, mc.projection_dim, config.fusion_dim
584 )));
585 }
586 }
587
588 let encoders: Vec<ModalityEncoder> = config
590 .modalities
591 .iter()
592 .map(|mc| ModalityEncoder::new(mc.clone()))
593 .collect::<ModelResult<Vec<_>>>()?;
594
595 let fusion = FusionLayer::new(
597 config.fusion_strategy.clone(),
598 config.modalities.len(),
599 config.fusion_dim,
600 )?;
601
602 let mut rng = SeededRng::new(99 + config.fusion_dim as u64 * 5);
604 let scale = (2.0 / config.fusion_dim as f32).sqrt();
605 let output_proj = Array2::from_shape_fn((config.fusion_dim, config.output_dim), |_| {
606 rng.next_f32() * scale
607 });
608 let output_bias = Array1::zeros(config.output_dim);
609
610 let state = Array1::zeros(config.fusion_dim);
611
612 Ok(Self {
613 config,
614 encoders,
615 fusion,
616 output_proj,
617 output_bias,
618 state,
619 })
620 }
621
622 pub fn forward_multimodal(&mut self, inputs: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
627 if inputs.len() != self.encoders.len() {
628 return Err(ModelError::dimension_mismatch(
629 "MultiModalModel input count",
630 self.encoders.len(),
631 inputs.len(),
632 ));
633 }
634
635 let encoded: Vec<Array1<f32>> = self
637 .encoders
638 .iter()
639 .zip(inputs.iter())
640 .map(|(enc, inp)| enc.encode(inp))
641 .collect::<ModelResult<Vec<_>>>()?;
642
643 let fused = self.fusion.fuse(&encoded)?;
645
646 if fused.iter().any(|v| v.is_nan() || v.is_infinite()) {
648 return Err(ModelError::numerical_instability(
649 "forward_multimodal",
650 "NaN or Inf detected after fusion",
651 ));
652 }
653
654 self.state = fused.clone();
656
657 let output = fused.dot(&self.output_proj) + &self.output_bias;
659 Ok(output)
660 }
661
662 pub fn forward_with_missing(
666 &mut self,
667 inputs: &[Option<Array1<f32>>],
668 ) -> ModelResult<Array1<f32>> {
669 if inputs.len() != self.encoders.len() {
670 return Err(ModelError::dimension_mismatch(
671 "MultiModalModel input count",
672 self.encoders.len(),
673 inputs.len(),
674 ));
675 }
676
677 let encoded: Vec<Array1<f32>> = self
679 .encoders
680 .iter()
681 .zip(inputs.iter())
682 .map(|(enc, maybe_inp)| match maybe_inp {
683 Some(inp) => enc.encode(inp),
684 None => Ok(Array1::zeros(enc.output_dim())),
685 })
686 .collect::<ModelResult<Vec<_>>>()?;
687
688 let fused = self.fusion.fuse(&encoded)?;
690
691 if fused.iter().any(|v| v.is_nan() || v.is_infinite()) {
692 return Err(ModelError::numerical_instability(
693 "forward_with_missing",
694 "NaN or Inf detected after fusion",
695 ));
696 }
697
698 self.state = fused.clone();
699
700 let output = fused.dot(&self.output_proj) + &self.output_bias;
701 Ok(output)
702 }
703
704 pub fn num_modalities(&self) -> usize {
706 self.encoders.len()
707 }
708
709 pub fn modality_names(&self) -> Vec<&Modality> {
711 self.config
712 .modalities
713 .iter()
714 .map(|mc| &mc.modality)
715 .collect()
716 }
717
718 pub fn total_params(&self) -> usize {
720 let mut count = 0usize;
721
722 for enc in &self.encoders {
724 for (w, b) in &enc.layers {
725 count += w.len() + b.len();
726 }
727 if let Some((g, b)) = &enc.norm {
728 count += g.len() + b.len();
729 }
730 }
731
732 if let Some(p) = &self.fusion.concat_proj {
734 count += p.len();
735 }
736 if let Some(gates) = &self.fusion.gate_weights {
737 for g in gates {
738 count += g.len();
739 }
740 }
741 if let Some(qs) = &self.fusion.attention_q {
742 for q in qs {
743 count += q.len();
744 }
745 }
746 if let Some(ks) = &self.fusion.attention_k {
747 for k in ks {
748 count += k.len();
749 }
750 }
751 if let Some(vs) = &self.fusion.attention_v {
752 for v in vs {
753 count += v.len();
754 }
755 }
756 if let Some(d) = &self.fusion.bottleneck_down {
757 count += d.len();
758 }
759 if let Some(u) = &self.fusion.bottleneck_up {
760 count += u.len();
761 }
762
763 count += self.output_proj.len() + self.output_bias.len();
765
766 count
767 }
768}
769
770impl SignalPredictor for MultiModalModel {
771 #[instrument(skip(self, input))]
776 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
777 let total_input_dim: usize = self.encoders.iter().map(|e| e.input_dim()).sum();
779
780 if input.len() != total_input_dim {
781 return Err(kizzasi_core::CoreError::DimensionMismatch {
782 expected: total_input_dim,
783 got: input.len(),
784 });
785 }
786
787 let mut offset = 0;
788 let mut per_modality = Vec::with_capacity(self.encoders.len());
789 for enc in &self.encoders {
790 let dim = enc.input_dim();
791 let slice = input
792 .slice(scirs2_core::ndarray::s![offset..offset + dim])
793 .to_owned();
794 per_modality.push(slice);
795 offset += dim;
796 }
797
798 self.forward_multimodal(&per_modality)
799 .map_err(|e| kizzasi_core::CoreError::Generic(e.to_string()))
800 }
801
802 #[instrument(skip(self))]
803 fn reset(&mut self) {
804 debug!("Resetting MultiModalModel state");
805 self.state = Array1::zeros(self.config.fusion_dim);
806 }
807
808 fn context_window(&self) -> usize {
809 self.config.context_length
810 }
811}
812
impl AutoregressiveModel for MultiModalModel {
    fn hidden_dim(&self) -> usize {
        self.config.fusion_dim
    }

    fn state_dim(&self) -> usize {
        self.config.fusion_dim
    }

    // The model has a single fusion stage, so it reports one layer.
    fn num_layers(&self) -> usize {
        1
    }

    fn model_type(&self) -> ModelType {
        ModelType::MultiModal
    }

    // NOTE(review): returns a freshly constructed HiddenState rather than a
    // snapshot of `self.state` — confirm whether the current fused state
    // should be exported here.
    fn get_states(&self) -> Vec<HiddenState> {
        vec![HiddenState::new(self.config.fusion_dim, 1)]
    }

    // Validates the state count only; the incoming state content is
    // discarded. NOTE(review): confirm whether `self.state` should be
    // restored from the provided HiddenState.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != 1 {
            return Err(ModelError::state_count_mismatch(
                "MultiModal",
                1,
                states.len(),
            ));
        }
        Ok(())
    }
}
846
/// Aligns modality streams arriving at different sample rates: buffers
/// incoming samples per modality and releases one aligned frame once every
/// modality has buffered enough samples for one reference-rate tick.
pub struct ModalityAligner {
    // Rate of the reference clock (same units as `modality_rates`).
    reference_rate: f32,
    // Per-modality sample rates, indexed like `buffers`.
    modality_rates: Vec<f32>,
    // Pending samples per modality, oldest first.
    buffers: Vec<Vec<Array1<f32>>>,
}
860
861impl ModalityAligner {
862 pub fn new(reference_rate: f32, modality_rates: Vec<f32>) -> Self {
864 let buffers = modality_rates.iter().map(|_| Vec::new()).collect();
865 Self {
866 reference_rate,
867 modality_rates,
868 buffers,
869 }
870 }
871
872 pub fn push(&mut self, modality_idx: usize, sample: Array1<f32>) {
874 if modality_idx < self.buffers.len() {
875 self.buffers[modality_idx].push(sample);
876 }
877 }
878
879 pub fn try_align(&mut self) -> Option<Vec<Array1<f32>>> {
884 let mut required: Vec<usize> = Vec::with_capacity(self.modality_rates.len());
890 for rate in &self.modality_rates {
891 let ratio = rate / self.reference_rate;
892 let need = ratio.ceil().max(1.0) as usize;
893 required.push(need);
894 }
895
896 for (i, &need) in required.iter().enumerate() {
898 if self.buffers[i].len() < need {
899 return None;
900 }
901 }
902
903 let mut aligned = Vec::with_capacity(self.buffers.len());
905 for (i, &need) in required.iter().enumerate() {
906 let sample = self.buffers[i][need - 1].clone();
908 self.buffers[i].drain(..need);
910 aligned.push(sample);
911 }
912
913 Some(aligned)
914 }
915
916 pub fn clear(&mut self) {
918 for buf in &mut self.buffers {
919 buf.clear();
920 }
921 }
922}
923
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a two-layer encoder config for the given modality and dimensions.
    fn make_encoder_config(
        modality: Modality,
        input_dim: usize,
        proj_dim: usize,
    ) -> ModalityEncoderConfig {
        ModalityEncoderConfig {
            modality,
            input_dim,
            projection_dim: proj_dim,
            num_layers: 2,
        }
    }

    // Three-modality model (Audio 8-d, Vision 12-d, Sensor 6-d) with
    // additive fusion into a 16-d space and a 4-d output head.
    fn make_default_config() -> MultiModalConfig {
        MultiModalConfig {
            fusion_dim: 16,
            fusion_strategy: FusionStrategy::Addition,
            output_dim: 4,
            modalities: vec![
                make_encoder_config(Modality::Audio, 8, 16),
                make_encoder_config(Modality::Vision, 12, 16),
                make_encoder_config(Modality::Sensor, 6, 16),
            ],
            context_length: 512,
        }
    }

    #[test]
    fn test_modality_encoder_creation() {
        let cfg = make_encoder_config(Modality::Audio, 8, 16);
        let enc = ModalityEncoder::new(cfg).expect("failed to create encoder");
        assert_eq!(enc.input_dim(), 8);
        assert_eq!(enc.output_dim(), 16);
    }

    #[test]
    fn test_modality_encoder_forward() {
        let cfg = make_encoder_config(Modality::Vision, 12, 16);
        let enc = ModalityEncoder::new(cfg).expect("failed to create encoder");
        let input = Array1::from_vec(vec![0.1; 12]);
        let output = enc.encode(&input).expect("encode failed");
        assert_eq!(output.len(), 16);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_fusion_concatenation() {
        let fusion_dim = 8;
        let n = 3;
        let layer = FusionLayer::new(FusionStrategy::Concatenation, n, fusion_dim)
            .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.5; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_fusion_addition() {
        let fusion_dim = 8;
        let n = 3;
        let layer = FusionLayer::new(FusionStrategy::Addition, n, fusion_dim)
            .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n).map(|_| Array1::ones(fusion_dim)).collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        // Three all-ones inputs summed element-wise must give exactly 3.0.
        for &v in out.iter() {
            assert!((v - 3.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_fusion_gated() {
        let fusion_dim = 8;
        let n = 2;
        let layer = FusionLayer::new(FusionStrategy::Gated, n, fusion_dim)
            .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.3; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_fusion_cross_attention() {
        let fusion_dim = 8;
        let n = 2;
        let layer = FusionLayer::new(
            FusionStrategy::CrossAttention { num_heads: 2 },
            n,
            fusion_dim,
        )
        .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.2; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_fusion_bottleneck() {
        let fusion_dim = 8;
        let n = 3;
        let layer = FusionLayer::new(
            FusionStrategy::Bottleneck { bottleneck_dim: 4 },
            n,
            fusion_dim,
        )
        .expect("failed to create fusion layer");
        let inputs: Vec<Array1<f32>> = (0..n)
            .map(|_| Array1::from_vec(vec![0.4; fusion_dim]))
            .collect();
        let out = layer.fuse(&inputs).expect("fuse failed");
        assert_eq!(out.len(), fusion_dim);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_multimodal_model_creation() {
        let config = make_default_config();
        let model = MultiModalModel::new(config).expect("failed to create model");
        assert_eq!(model.num_modalities(), 3);
        assert_eq!(model.modality_names().len(), 3);
        assert!(model.total_params() > 0);
    }

    #[test]
    fn test_multimodal_forward() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        let audio = Array1::from_vec(vec![0.1; 8]);
        let vision = Array1::from_vec(vec![0.2; 12]);
        let sensor = Array1::from_vec(vec![0.3; 6]);

        let out = model
            .forward_multimodal(&[audio, vision, sensor])
            .expect("forward failed");
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // A missing (None) modality contributes zeros and must not fail.
    #[test]
    fn test_multimodal_missing_modalities() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        let audio = Some(Array1::from_vec(vec![0.1; 8]));
        let vision = None; let sensor = Some(Array1::from_vec(vec![0.3; 6]));

        let out = model
            .forward_with_missing(&[audio, vision, sensor])
            .expect("forward_with_missing failed");
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // step() takes the flat concatenation of all modality inputs (8+12+6 = 26).
    #[test]
    fn test_multimodal_signal_predictor() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        let input = Array1::from_vec(vec![0.1; 26]);
        let out = model.step(&input).expect("step failed");
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|v| v.is_finite()));

        model.reset();
        assert_eq!(model.context_window(), 512);
    }

    // Rates [10, 20] against reference 10 => modalities need 1 and 2 samples
    // per aligned frame respectively.
    #[test]
    fn test_modality_aligner() {
        let mut aligner = ModalityAligner::new(10.0, vec![10.0, 20.0]);

        aligner.push(0, Array1::from_vec(vec![1.0, 2.0]));
        aligner.push(1, Array1::from_vec(vec![3.0, 4.0]));
        // Modality 1 still needs a second sample.
        assert!(aligner.try_align().is_none());

        aligner.push(1, Array1::from_vec(vec![5.0, 6.0]));
        let aligned = aligner.try_align().expect("should have aligned frame");
        assert_eq!(aligned.len(), 2);
        assert!((aligned[0][0] - 1.0).abs() < 1e-6);
        // The newest of modality 1's two samples is emitted.
        assert!((aligned[1][0] - 5.0).abs() < 1e-6);

        // All consumed samples were drained; nothing left to align.
        assert!(aligner.try_align().is_none());
    }

    // Extreme magnitudes must yield either a finite output or an explicit
    // NumericalInstability error — never silent NaN/Inf.
    #[test]
    fn test_multimodal_numerical_stability() {
        let config = make_default_config();
        let mut model = MultiModalModel::new(config).expect("failed to create model");

        let audio_large = Array1::from_vec(vec![1e6; 8]);
        let vision_large = Array1::from_vec(vec![1e6; 12]);
        let sensor_large = Array1::from_vec(vec![1e6; 6]);
        let out = model.forward_multimodal(&[audio_large, vision_large, sensor_large]);
        match out {
            Ok(o) => assert!(o.iter().all(|v| v.is_finite()), "output should be finite"),
            Err(ModelError::NumericalInstability { .. }) => {
            }
            Err(e) => panic!("unexpected error: {e}"),
        }

        let audio_small = Array1::from_vec(vec![1e-30; 8]);
        let vision_small = Array1::from_vec(vec![1e-30; 12]);
        let sensor_small = Array1::from_vec(vec![1e-30; 6]);
        let out = model
            .forward_multimodal(&[audio_small, vision_small, sensor_small])
            .expect("small inputs should not cause errors");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "output should be finite for small inputs"
        );
    }

    #[test]
    fn test_autoregressive_model_trait() {
        let config = make_default_config();
        let model = MultiModalModel::new(config).expect("failed to create model");
        assert_eq!(model.hidden_dim(), 16);
        assert_eq!(model.state_dim(), 16);
        assert_eq!(model.num_layers(), 1);
        assert_eq!(model.model_type(), ModelType::MultiModal);
        let states = model.get_states();
        assert_eq!(states.len(), 1);
    }

    #[test]
    fn test_modality_display() {
        assert_eq!(format!("{}", Modality::Audio), "Audio");
        assert_eq!(format!("{}", Modality::Vision), "Vision");
        assert_eq!(
            format!("{}", Modality::Custom("Lidar".to_string())),
            "Custom(Lidar)"
        );
    }

    #[test]
    fn test_encoder_dimension_mismatch() {
        let cfg = make_encoder_config(Modality::Audio, 8, 16);
        let enc = ModalityEncoder::new(cfg).expect("failed to create encoder");
        let bad_input = Array1::from_vec(vec![0.1; 5]); assert!(enc.encode(&bad_input).is_err());
    }

    #[test]
    fn test_aligner_clear() {
        let mut aligner = ModalityAligner::new(10.0, vec![10.0]);
        aligner.push(0, Array1::from_vec(vec![1.0]));
        aligner.clear();
        assert!(aligner.try_align().is_none());
    }
}