1use crate::error::{TokenizerError, TokenizerResult};
22use crate::SignalTokenizer;
23use scirs2_core::ndarray::{Array1, Array2};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26
/// Small deterministic pseudo-random generator built on xorshift steps.
///
/// It exists so that weights can be initialised reproducibly from a seed
/// without pulling in an external RNG.
struct SeededRng {
    /// Current 64-bit generator state; [`SeededRng::new`] never stores zero.
    state: u64,
}

impl SeededRng {
    /// Creates a generator from `seed`.
    ///
    /// A seed of `0` is replaced by `1`: a zero state would be mapped to
    /// zero by every xor/shift step and the generator would never advance.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self { state }
    }

    /// Advances the state by one step (shifts 13 / 7 / 17) and maps the new
    /// state linearly from `[0, u64::MAX]` to an `f32` in `[-1.0, 1.0]`.
    fn next_f32(&mut self) -> f32 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;

        // Normalise in f64 first so the full 64-bit state contributes before
        // the final narrowing to f32.
        let unit = x as f64 / u64::MAX as f64;
        (unit * 2.0 - 1.0) as f32
    }
}
49
50#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
59pub enum ModalityKind {
60 Audio,
62 Control,
64 Sensor,
66 Video,
68 Custom(String),
70}
71
72impl ModalityKind {
73 pub fn key(&self) -> String {
75 match self {
76 ModalityKind::Audio => "audio".to_string(),
77 ModalityKind::Control => "control".to_string(),
78 ModalityKind::Sensor => "sensor".to_string(),
79 ModalityKind::Video => "video".to_string(),
80 ModalityKind::Custom(s) => format!("custom_{s}"),
81 }
82 }
83
84 fn seed(&self) -> u64 {
86 let key = self.key();
88 key.bytes().fold(5381u64, |acc, b| {
89 acc.wrapping_mul(33).wrapping_add(b as u64)
90 })
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ModalityTokenizerConfig {
101 pub modality: ModalityKind,
103 pub input_dim: usize,
105 pub token_dim: usize,
107 pub codebook_size: usize,
109 pub num_stages: usize,
111}
112
113impl ModalityTokenizerConfig {
114 pub fn validate(&self) -> TokenizerResult<()> {
116 if self.input_dim == 0 {
117 return Err(TokenizerError::InvalidConfig(
118 "input_dim must be > 0".into(),
119 ));
120 }
121 if self.token_dim == 0 {
122 return Err(TokenizerError::InvalidConfig(
123 "token_dim must be > 0".into(),
124 ));
125 }
126 if self.codebook_size == 0 {
127 return Err(TokenizerError::InvalidConfig(
128 "codebook_size must be > 0".into(),
129 ));
130 }
131 if self.num_stages == 0 {
132 return Err(TokenizerError::InvalidConfig(
133 "num_stages must be >= 1".into(),
134 ));
135 }
136 Ok(())
137 }
138}
139
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
#[inline]
fn gelu(x: f32) -> f32 {
    /// `sqrt(2 / pi)` rounded to `f32` precision.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;

    // The cube is built by repeated multiplication (not `powi`) so that the
    // rounding of every intermediate stays the same.
    let cubic_term = 0.044715 * x * x * x;
    let inner = SQRT_2_OVER_PI * (x + cubic_term);
    0.5 * x * (1.0 + inner.tanh())
}
152
153pub struct ModalityTokenizer {
164 config: ModalityTokenizerConfig,
165 encoder: Array2<f32>,
167 encoder_bias: Array1<f32>,
169 codebook: Array2<f32>,
171}
172
173impl ModalityTokenizer {
174 pub fn new(config: ModalityTokenizerConfig) -> TokenizerResult<Self> {
176 config.validate()?;
177
178 let seed = config.modality.seed();
179 let mut rng = SeededRng::new(seed);
180
181 let enc_scale = (6.0_f32 / (config.input_dim + config.token_dim) as f32).sqrt();
183 let encoder = Array2::from_shape_fn((config.input_dim, config.token_dim), |_| {
184 rng.next_f32() * enc_scale
185 });
186
187 let encoder_bias = Array1::zeros(config.token_dim);
188
189 let cb_scale = 1.0_f32 / (config.token_dim as f32).sqrt();
191 let codebook = Array2::from_shape_fn((config.codebook_size, config.token_dim), |_| {
192 rng.next_f32() * cb_scale
193 });
194
195 Ok(Self {
196 config,
197 encoder,
198 encoder_bias,
199 codebook,
200 })
201 }
202
203 pub fn encode(&self, input: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
205 if input.len() != self.config.input_dim {
206 return Err(TokenizerError::dim_mismatch(
207 self.config.input_dim,
208 input.len(),
209 "ModalityTokenizer::encode input_dim",
210 ));
211 }
212
213 let pre_act = input.dot(&self.encoder) + &self.encoder_bias;
215
216 let activated = pre_act.mapv(gelu);
218 Ok(activated)
219 }
220
221 pub fn quantize(&self, embedding: &Array1<f32>) -> TokenizerResult<(usize, Array1<f32>)> {
225 if embedding.len() != self.config.token_dim {
226 return Err(TokenizerError::dim_mismatch(
227 self.config.token_dim,
228 embedding.len(),
229 "ModalityTokenizer::quantize embedding dim",
230 ));
231 }
232
233 let mut best_idx = 0usize;
234 let mut best_dist = f32::INFINITY;
235
236 for k in 0..self.config.codebook_size {
237 let code = self.codebook.row(k);
238 let diff = embedding - &code;
239 let dist = diff.dot(&diff); if dist < best_dist {
241 best_dist = dist;
242 best_idx = k;
243 }
244 }
245
246 let quantized = self.codebook.row(best_idx).to_owned();
247 Ok((best_idx, quantized))
248 }
249
250 pub fn decode(&self, token_idx: usize) -> TokenizerResult<Array1<f32>> {
252 if token_idx >= self.config.codebook_size {
253 return Err(TokenizerError::out_of_range(
254 token_idx as f32,
255 0.0,
256 (self.config.codebook_size - 1) as f32,
257 "ModalityTokenizer::decode token_idx",
258 ));
259 }
260 Ok(self.codebook.row(token_idx).to_owned())
261 }
262
263 pub fn decode_embedding(&self, embedding: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
268 if embedding.len() != self.config.token_dim {
269 return Err(TokenizerError::dim_mismatch(
270 self.config.token_dim,
271 embedding.len(),
272 "ModalityTokenizer::decode_embedding embedding dim",
273 ));
274 }
275 let reconstructed = self.encoder.dot(embedding);
279 Ok(reconstructed)
280 }
281
282 pub fn input_dim(&self) -> usize {
284 self.config.input_dim
285 }
286
287 pub fn token_dim(&self) -> usize {
289 self.config.token_dim
290 }
291
292 pub fn codebook_size(&self) -> usize {
294 self.config.codebook_size
295 }
296
297 pub fn codebook(&self) -> &Array2<f32> {
299 &self.codebook
300 }
301
302 pub fn confidence(&self, embedding: &Array1<f32>, quantized: &Array1<f32>) -> f32 {
304 let diff = embedding - quantized;
305 let dist = diff.dot(&diff).sqrt();
306 1.0 / (1.0 + dist)
307 }
308}
309
/// A single quantized token produced for one modality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalToken {
    /// Modality the token was produced from.
    pub modality: ModalityKind,
    /// Index of the selected codebook entry.
    pub token_idx: usize,
    /// Embedding associated with the token (the pre-quantization aligned
    /// vector when built by `CrossModalTokenizer::tokenize`).
    pub embedding: Array1<f32>,
    /// Confidence score attached to the token.
    pub confidence: f32,
}
329
330pub struct CrossModalSequence {
336 pub tokens: Vec<CrossModalToken>,
338 pub shared_dim: usize,
340}
341
342impl CrossModalSequence {
343 pub fn new(shared_dim: usize) -> Self {
345 Self {
346 tokens: Vec::new(),
347 shared_dim,
348 }
349 }
350
351 pub fn push(&mut self, token: CrossModalToken) {
353 self.tokens.push(token);
354 }
355
356 pub fn len(&self) -> usize {
358 self.tokens.len()
359 }
360
361 pub fn is_empty(&self) -> bool {
363 self.tokens.is_empty()
364 }
365
366 pub fn to_embedding_matrix(&self) -> Array2<f32> {
370 let n = self.tokens.len();
371 if n == 0 {
372 return Array2::zeros((0, self.shared_dim));
373 }
374 let mut mat = Array2::zeros((n, self.shared_dim));
375 for (i, tok) in self.tokens.iter().enumerate() {
376 let row_len = tok.embedding.len().min(self.shared_dim);
377 for j in 0..row_len {
378 mat[[i, j]] = tok.embedding[j];
379 }
380 }
381 mat
382 }
383
384 pub fn filter_by_modality(&self, modality: &ModalityKind) -> Vec<&CrossModalToken> {
386 self.tokens
387 .iter()
388 .filter(|t| &t.modality == modality)
389 .collect()
390 }
391
392 pub fn modalities_present(&self) -> Vec<&ModalityKind> {
394 let mut seen: Vec<&ModalityKind> = Vec::new();
395 for tok in &self.tokens {
396 if !seen.contains(&&tok.modality) {
397 seen.push(&tok.modality);
398 }
399 }
400 seen
401 }
402}
403
404pub struct CrossModalAligner {
412 shared_dim: usize,
413 modality_counts: HashMap<String, usize>,
414 buffer: Vec<CrossModalToken>,
415}
416
417impl CrossModalAligner {
418 pub fn new(shared_dim: usize) -> Self {
420 Self {
421 shared_dim,
422 modality_counts: HashMap::new(),
423 buffer: Vec::new(),
424 }
425 }
426
427 pub fn push_token(&mut self, token: CrossModalToken) {
429 let key = token.modality.key();
430 *self.modality_counts.entry(key).or_insert(0) += 1;
431 self.buffer.push(token);
432 }
433
434 pub fn flush(&mut self) -> CrossModalSequence {
436 let mut seq = CrossModalSequence::new(self.shared_dim);
437 for tok in self.buffer.drain(..) {
438 seq.push(tok);
439 }
440 self.modality_counts.clear();
441 seq
442 }
443
444 pub fn len(&self) -> usize {
446 self.buffer.len()
447 }
448
449 pub fn is_empty(&self) -> bool {
451 self.buffer.is_empty()
452 }
453
454 pub fn count_for_modality(&self, modality: &ModalityKind) -> usize {
456 self.modality_counts
457 .get(&modality.key())
458 .copied()
459 .unwrap_or(0)
460 }
461}
462
463pub struct CrossModalTokenizer {
475 shared_dim: usize,
476 tokenizers: HashMap<String, ModalityTokenizer>,
478 shared_proj: Array2<f32>,
480 shared_bias: Array1<f32>,
482 modality_embeddings: HashMap<String, Array1<f32>>,
484}
485
486impl CrossModalTokenizer {
487 pub fn new(shared_dim: usize) -> TokenizerResult<Self> {
489 if shared_dim == 0 {
490 return Err(TokenizerError::InvalidConfig(
491 "shared_dim must be > 0".into(),
492 ));
493 }
494
495 let mut rng = SeededRng::new(0xdeadbeef_cafebabe);
497 let scale = 0.01_f32 / (shared_dim as f32).sqrt();
498 let shared_proj = Array2::from_shape_fn((shared_dim, shared_dim), |(i, j)| {
499 let identity = if i == j { 1.0_f32 } else { 0.0_f32 };
500 identity + rng.next_f32() * scale
501 });
502 let shared_bias = Array1::zeros(shared_dim);
503
504 Ok(Self {
505 shared_dim,
506 tokenizers: HashMap::new(),
507 shared_proj,
508 shared_bias,
509 modality_embeddings: HashMap::new(),
510 })
511 }
512
513 pub fn add_modality(&mut self, config: ModalityTokenizerConfig) -> TokenizerResult<()> {
517 if config.token_dim != self.shared_dim {
518 return Err(TokenizerError::InvalidConfig(format!(
519 "ModalityTokenizerConfig.token_dim ({}) must equal shared_dim ({})",
520 config.token_dim, self.shared_dim
521 )));
522 }
523 config.validate()?;
524
525 let key = config.modality.key();
526 let modality_seed = config.modality.seed().wrapping_add(0x1234_5678_9abc_def0);
527 let mut rng = SeededRng::new(modality_seed);
528 let embed_scale = 0.02_f32;
529 let mod_emb = Array1::from_shape_fn(self.shared_dim, |_| rng.next_f32() * embed_scale);
530
531 let tokenizer = ModalityTokenizer::new(config)?;
532 self.tokenizers.insert(key.clone(), tokenizer);
533 self.modality_embeddings.insert(key, mod_emb);
534 Ok(())
535 }
536
537 pub fn tokenize(
545 &self,
546 modality: &ModalityKind,
547 input: &Array1<f32>,
548 ) -> TokenizerResult<CrossModalToken> {
549 let key = modality.key();
550 let tok = self.tokenizers.get(&key).ok_or_else(|| {
551 TokenizerError::InvalidConfig(format!("modality '{key}' not registered"))
552 })?;
553 let mod_emb = self.modality_embeddings.get(&key).ok_or_else(|| {
554 TokenizerError::InternalError(format!("missing modality embedding for '{key}'"))
555 })?;
556
557 let encoded = tok.encode(input)?;
559
560 let with_mod = encoded + mod_emb;
562
563 let aligned = with_mod.dot(&self.shared_proj) + &self.shared_bias;
565
566 let (token_idx, quantized) = tok.quantize(&aligned)?;
568 let confidence = tok.confidence(&aligned, &quantized);
569
570 Ok(CrossModalToken {
571 modality: modality.clone(),
572 token_idx,
573 embedding: aligned,
574 confidence,
575 })
576 }
577
578 pub fn tokenize_batch(
580 &self,
581 inputs: &[(ModalityKind, Array1<f32>)],
582 ) -> TokenizerResult<CrossModalSequence> {
583 let mut seq = CrossModalSequence::new(self.shared_dim);
584 for (modality, signal) in inputs {
585 let token = self.tokenize(modality, signal)?;
586 seq.push(token);
587 }
588 Ok(seq)
589 }
590
591 pub fn decode(&self, token: &CrossModalToken) -> TokenizerResult<Array1<f32>> {
597 let key = token.modality.key();
598 let tok = self.tokenizers.get(&key).ok_or_else(|| {
599 TokenizerError::InvalidConfig(format!("modality '{key}' not registered"))
600 })?;
601 let mod_emb = self.modality_embeddings.get(&key).ok_or_else(|| {
602 TokenizerError::InternalError(format!("missing modality embedding for '{key}'"))
603 })?;
604
605 let quantized = tok.decode(token.token_idx)?;
607
608 let without_mod = quantized - mod_emb;
611
612 tok.decode_embedding(&without_mod)
614 }
615
616 pub fn shared_dim(&self) -> usize {
618 self.shared_dim
619 }
620
621 pub fn num_modalities(&self) -> usize {
623 self.tokenizers.len()
624 }
625
626 pub fn modality_names(&self) -> Vec<String> {
628 let mut names: Vec<String> = self.tokenizers.keys().cloned().collect();
629 names.sort();
630 names
631 }
632
633 pub fn robotics_preset() -> TokenizerResult<Self> {
639 let mut cmt = Self::new(64)?;
640 cmt.add_modality(ModalityTokenizerConfig {
641 modality: ModalityKind::Audio,
642 input_dim: 16,
643 token_dim: 64,
644 codebook_size: 512,
645 num_stages: 1,
646 })?;
647 cmt.add_modality(ModalityTokenizerConfig {
648 modality: ModalityKind::Control,
649 input_dim: 6,
650 token_dim: 64,
651 codebook_size: 256,
652 num_stages: 1,
653 })?;
654 cmt.add_modality(ModalityTokenizerConfig {
655 modality: ModalityKind::Sensor,
656 input_dim: 9,
657 token_dim: 64,
658 codebook_size: 256,
659 num_stages: 1,
660 })?;
661 Ok(cmt)
662 }
663
664 pub fn audio_video_preset() -> TokenizerResult<Self> {
666 let mut cmt = Self::new(256)?;
667 cmt.add_modality(ModalityTokenizerConfig {
668 modality: ModalityKind::Audio,
669 input_dim: 80,
670 token_dim: 256,
671 codebook_size: 1024,
672 num_stages: 2,
673 })?;
674 cmt.add_modality(ModalityTokenizerConfig {
675 modality: ModalityKind::Video,
676 input_dim: 512,
677 token_dim: 256,
678 codebook_size: 2048,
679 num_stages: 2,
680 })?;
681 Ok(cmt)
682 }
683}
684
685impl SignalTokenizer for CrossModalTokenizer {
698 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
705 let mut names = self.modality_names();
706 names.sort();
707
708 let total_input_dim: usize = names.iter().map(|n| self.tokenizers[n].input_dim()).sum();
710 if signal.len() != total_input_dim {
711 return Err(TokenizerError::dim_mismatch(
712 total_input_dim,
713 signal.len(),
714 "CrossModalTokenizer::encode total_input_dim",
715 ));
716 }
717
718 let mut out = Vec::with_capacity(names.len() * self.shared_dim);
719 let mut offset = 0usize;
720
721 for name in &names {
722 let tok = &self.tokenizers[name];
723 let dim = tok.input_dim();
724 let slice = signal.slice(scirs2_core::ndarray::s![offset..offset + dim]);
725 let input_owned = slice.to_owned();
726
727 let modality = Self::key_to_modality_kind(name);
730 let token = self.tokenize(&modality, &input_owned)?;
731 out.extend_from_slice(
732 token.embedding.as_slice().ok_or_else(|| {
733 TokenizerError::InternalError("embedding not contiguous".into())
734 })?,
735 );
736 offset += dim;
737 }
738
739 Ok(Array1::from_vec(out))
740 }
741
742 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
744 let mut names = self.modality_names();
745 names.sort();
746 let n = names.len();
747
748 if n == 0 {
749 return Ok(Array1::zeros(0));
750 }
751
752 let expected = n * self.shared_dim;
753 if tokens.len() != expected {
754 return Err(TokenizerError::dim_mismatch(
755 expected,
756 tokens.len(),
757 "CrossModalTokenizer::decode embedding length",
758 ));
759 }
760
761 let mut out = Vec::new();
762
763 for (i, name) in names.iter().enumerate() {
764 let start = i * self.shared_dim;
765 let end = start + self.shared_dim;
766 let emb_slice = tokens
767 .slice(scirs2_core::ndarray::s![start..end])
768 .to_owned();
769
770 let tok = &self.tokenizers[name];
771 let mod_emb = &self.modality_embeddings[name];
772
773 let without_mod = emb_slice - mod_emb;
775
776 let reconstructed = tok.decode_embedding(&without_mod)?;
778 out.extend_from_slice(reconstructed.as_slice().ok_or_else(|| {
779 TokenizerError::InternalError("reconstructed not contiguous".into())
780 })?);
781 }
782
783 Ok(Array1::from_vec(out))
784 }
785
786 fn embed_dim(&self) -> usize {
788 self.tokenizers.len() * self.shared_dim
789 }
790
791 fn vocab_size(&self) -> usize {
793 0
794 }
795}
796
797impl CrossModalTokenizer {
798 fn key_to_modality_kind(key: &str) -> ModalityKind {
800 match key {
801 "audio" => ModalityKind::Audio,
802 "control" => ModalityKind::Control,
803 "sensor" => ModalityKind::Sensor,
804 "video" => ModalityKind::Video,
805 other => {
806 let custom_name = other.strip_prefix("custom_").unwrap_or(other);
807 ModalityKind::Custom(custom_name.to_string())
808 }
809 }
810 }
811}
812
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// All-zero vector of length `n`.
    fn zeros(n: usize) -> Array1<f32> {
        Array1::from_elem(n, 0.0)
    }

    /// All-one vector of length `n`.
    fn ones(n: usize) -> Array1<f32> {
        Array1::from_elem(n, 1.0)
    }

    /// Single-stage configuration; every test in this module uses one stage.
    fn single_stage(
        modality: ModalityKind,
        input_dim: usize,
        token_dim: usize,
        codebook_size: usize,
    ) -> ModalityTokenizerConfig {
        ModalityTokenizerConfig {
            modality,
            input_dim,
            token_dim,
            codebook_size,
            num_stages: 1,
        }
    }

    /// Builds a token literal without repeating the field names.
    fn token(
        modality: ModalityKind,
        token_idx: usize,
        embedding: Array1<f32>,
        confidence: f32,
    ) -> CrossModalToken {
        CrossModalToken {
            modality,
            token_idx,
            embedding,
            confidence,
        }
    }

    /// Construction exposes the configured dimensions and codebook shape.
    #[test]
    fn test_modality_tokenizer_creation() {
        let tok = ModalityTokenizer::new(single_stage(ModalityKind::Audio, 16, 64, 128))
            .expect("should create successfully");
        assert_eq!(tok.input_dim(), 16);
        assert_eq!(tok.token_dim(), 64);
        assert_eq!(tok.codebook_size(), 128);
        assert_eq!(tok.codebook().shape(), [128, 64]);
    }

    /// Encoding yields a `token_dim` vector and rejects a wrong input length.
    #[test]
    fn test_modality_tokenizer_encode() {
        let tok =
            ModalityTokenizer::new(single_stage(ModalityKind::Control, 6, 32, 64)).expect("create");
        let emb = tok.encode(&ones(6)).expect("encode");
        assert_eq!(emb.len(), 32, "embedding must be token_dim");

        assert!(tok.encode(&ones(5)).is_err());
    }

    /// Quantization returns an in-range index and a `token_dim` vector.
    #[test]
    fn test_modality_tokenizer_quantize() {
        let tok =
            ModalityTokenizer::new(single_stage(ModalityKind::Sensor, 9, 16, 32)).expect("create");
        let (idx, quantized) = tok.quantize(&zeros(16)).expect("quantize");
        assert!(idx < 32, "token index must be within codebook");
        assert_eq!(quantized.len(), 16, "quantized must be token_dim");
    }

    /// Encode → quantize → decode keeps the expected dimensions.
    #[test]
    fn test_modality_tokenizer_decode_roundtrip() {
        let tok =
            ModalityTokenizer::new(single_stage(ModalityKind::Audio, 8, 32, 64)).expect("create");
        let emb = tok.encode(&ones(8)).expect("encode");
        let (idx, _quantized) = tok.quantize(&emb).expect("quantize");
        let code = tok.decode(idx).expect("decode");
        assert_eq!(code.len(), 32, "decoded codebook entry must be token_dim");

        let reconstructed = tok.decode_embedding(&emb).expect("decode_embedding");
        assert_eq!(reconstructed.len(), 8, "reconstructed must be input_dim");
    }

    /// Token fields hold exactly what they were constructed with.
    #[test]
    fn test_cross_modal_token_creation() {
        let tok = CrossModalToken {
            modality: ModalityKind::Video,
            token_idx: 42,
            embedding: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            confidence: 0.95,
        };
        assert_eq!(tok.token_idx, 42);
        assert!((tok.confidence - 0.95).abs() < 1e-6);
        assert_eq!(tok.modality, ModalityKind::Video);
        assert_eq!(tok.embedding.len(), 3);
    }

    /// Push, length, filtering and distinct-modality queries on a sequence.
    #[test]
    fn test_cross_modal_sequence_operations() {
        let mut seq = CrossModalSequence::new(8);
        assert!(seq.is_empty());

        seq.push(token(ModalityKind::Audio, 0, zeros(8), 0.8));
        seq.push(token(ModalityKind::Control, 1, ones(8), 0.7));
        seq.push(token(ModalityKind::Audio, 2, zeros(8), 0.9));

        assert_eq!(seq.len(), 3);
        assert!(!seq.is_empty());

        assert_eq!(seq.filter_by_modality(&ModalityKind::Audio).len(), 2);
        assert_eq!(seq.filter_by_modality(&ModalityKind::Control).len(), 1);
        assert_eq!(seq.filter_by_modality(&ModalityKind::Video).len(), 0);

        assert_eq!(seq.modalities_present().len(), 2);
    }

    /// The embedding matrix has one row per token, also when empty.
    #[test]
    fn test_cross_modal_sequence_embedding_matrix() {
        let shared_dim = 16;
        let mut seq = CrossModalSequence::new(shared_dim);
        for _ in 0..5 {
            seq.push(token(ModalityKind::Sensor, 0, zeros(shared_dim), 1.0));
        }
        assert_eq!(seq.to_embedding_matrix().shape(), [5, shared_dim]);

        let empty = CrossModalSequence::new(shared_dim);
        assert_eq!(empty.to_embedding_matrix().shape(), [0, shared_dim]);
    }

    /// Modalities register under their key; a mismatched token_dim fails.
    #[test]
    fn test_cross_modal_tokenizer_add_modality() {
        let mut cmt = CrossModalTokenizer::new(32).expect("new");
        cmt.add_modality(single_stage(ModalityKind::Audio, 16, 32, 64))
            .expect("add audio");
        cmt.add_modality(single_stage(ModalityKind::Control, 6, 32, 32))
            .expect("add control");

        assert_eq!(cmt.num_modalities(), 2);
        let names = cmt.modality_names();
        assert!(names.iter().any(|n| n == "audio"));
        assert!(names.iter().any(|n| n == "control"));

        // token_dim 16 does not match the shared dimension of 32.
        let bad = cmt.add_modality(single_stage(ModalityKind::Sensor, 9, 16, 32));
        assert!(bad.is_err());
    }

    /// Tokenizing a registered modality works; an unregistered one fails.
    #[test]
    fn test_cross_modal_tokenizer_tokenize() {
        let mut cmt = CrossModalTokenizer::new(64).expect("new");
        cmt.add_modality(single_stage(ModalityKind::Audio, 16, 64, 128))
            .expect("add audio");

        let tok = cmt
            .tokenize(&ModalityKind::Audio, &ones(16))
            .expect("tokenize");
        assert_eq!(tok.modality, ModalityKind::Audio);
        assert!(tok.token_idx < 128);
        assert_eq!(tok.embedding.len(), 64);
        assert!(tok.confidence > 0.0 && tok.confidence <= 1.0);

        assert!(cmt.tokenize(&ModalityKind::Video, &ones(512)).is_err());
    }

    /// Batch tokenization keeps order and the shared dimension.
    #[test]
    fn test_cross_modal_tokenizer_batch() {
        let mut cmt = CrossModalTokenizer::new(64).expect("new");
        cmt.add_modality(single_stage(ModalityKind::Audio, 16, 64, 128))
            .expect("add audio");
        cmt.add_modality(single_stage(ModalityKind::Control, 6, 64, 64))
            .expect("add control");

        let inputs = vec![
            (ModalityKind::Audio, ones(16)),
            (ModalityKind::Control, zeros(6)),
            (ModalityKind::Audio, zeros(16)),
        ];
        let seq = cmt.tokenize_batch(&inputs).expect("batch");
        assert_eq!(seq.len(), 3);
        assert_eq!(seq.shared_dim, 64);

        assert_eq!(seq.to_embedding_matrix().shape(), [3, 64]);
        assert_eq!(seq.filter_by_modality(&ModalityKind::Audio).len(), 2);
    }

    /// Decoding returns an input-sized vector; unknown modalities fail.
    #[test]
    fn test_cross_modal_tokenizer_decode() {
        let mut cmt = CrossModalTokenizer::new(32).expect("new");
        cmt.add_modality(single_stage(ModalityKind::Sensor, 9, 32, 64))
            .expect("add sensor");

        let tok = cmt
            .tokenize(&ModalityKind::Sensor, &ones(9))
            .expect("tokenize");

        let reconstructed = cmt.decode(&tok).expect("decode");
        assert_eq!(reconstructed.len(), 9, "decoded must match input_dim");

        let bad_token = token(ModalityKind::Video, 0, zeros(32), 1.0);
        assert!(cmt.decode(&bad_token).is_err());
    }

    /// The robotics preset registers audio, control and sensor at width 64.
    #[test]
    fn test_cross_modal_robotics_preset() {
        let cmt = CrossModalTokenizer::robotics_preset().expect("robotics preset");
        assert_eq!(cmt.shared_dim(), 64);
        assert_eq!(cmt.num_modalities(), 3);

        let names = cmt.modality_names();
        for expected in ["audio", "control", "sensor"] {
            assert!(names.iter().any(|n| n == expected));
        }

        let audio_token = cmt
            .tokenize(&ModalityKind::Audio, &ones(16))
            .expect("audio tokenize");
        assert_eq!(audio_token.embedding.len(), 64);

        let control_token = cmt
            .tokenize(&ModalityKind::Control, &zeros(6))
            .expect("control tokenize");
        assert!(control_token.token_idx < 256);

        let sensor_token = cmt
            .tokenize(&ModalityKind::Sensor, &ones(9))
            .expect("sensor tokenize");
        assert!(sensor_token.confidence > 0.0);

        let inputs = vec![
            (ModalityKind::Audio, ones(16)),
            (ModalityKind::Control, zeros(6)),
            (ModalityKind::Sensor, ones(9)),
        ];
        let seq = cmt.tokenize_batch(&inputs).expect("batch");
        assert_eq!(seq.len(), 3);
    }

    /// The aligner counts per modality and resets on flush.
    #[test]
    fn test_cross_modal_aligner() {
        let mut aligner = CrossModalAligner::new(64);
        assert!(aligner.is_empty());

        aligner.push_token(token(ModalityKind::Audio, 0, zeros(64), 0.9));
        aligner.push_token(token(ModalityKind::Control, 1, ones(64), 0.8));
        aligner.push_token(token(ModalityKind::Audio, 2, zeros(64), 0.7));

        assert_eq!(aligner.len(), 3);
        assert!(!aligner.is_empty());
        assert_eq!(aligner.count_for_modality(&ModalityKind::Audio), 2);
        assert_eq!(aligner.count_for_modality(&ModalityKind::Control), 1);
        assert_eq!(aligner.count_for_modality(&ModalityKind::Sensor), 0);

        let seq = aligner.flush();
        assert_eq!(seq.len(), 3);
        assert!(aligner.is_empty(), "buffer cleared after flush");
        assert_eq!(aligner.count_for_modality(&ModalityKind::Audio), 0);

        assert_eq!(seq.to_embedding_matrix().shape(), [3, 64]);
    }

    /// Keys are stable strings; seeds are deterministic and differ by kind.
    #[test]
    fn test_modality_kind_key_and_seed() {
        assert_eq!(ModalityKind::Audio.key(), "audio");
        assert_eq!(ModalityKind::Control.key(), "control");
        assert_eq!(ModalityKind::Sensor.key(), "sensor");
        assert_eq!(ModalityKind::Video.key(), "video");
        assert_eq!(ModalityKind::Custom("robot".into()).key(), "custom_robot");

        assert_eq!(ModalityKind::Audio.seed(), ModalityKind::Audio.seed());
        assert_ne!(ModalityKind::Audio.seed(), ModalityKind::Control.seed());
    }

    /// The audio/video preset registers two modalities at width 256.
    #[test]
    fn test_audio_video_preset() {
        let cmt = CrossModalTokenizer::audio_video_preset().expect("audio_video preset");
        assert_eq!(cmt.shared_dim(), 256);
        assert_eq!(cmt.num_modalities(), 2);

        let audio_tok = cmt
            .tokenize(&ModalityKind::Audio, &ones(80))
            .expect("audio tokenize");
        assert_eq!(audio_tok.embedding.len(), 256);

        let video_tok = cmt
            .tokenize(&ModalityKind::Video, &ones(512))
            .expect("video tokenize");
        assert!(video_tok.token_idx < 2048);
    }
}