Skip to main content

kizzasi_tokenizer/
cross_modal.rs

1//! Cross-Modal Tokenization
2//!
3//! Provides unified tokenization of multiple signal modalities (audio, control signals,
4//! sensor data, video features) into a shared embedding space. This enables cross-modal
5//! alignment and joint reasoning over heterogeneous sensory streams.
6//!
7//! ## Architecture
8//!
9//! Each modality has its own linear encoder projecting into a shared embedding space.
10//! A discrete codebook (per modality) provides token indices for autoregressive use.
11//! A shared alignment projection aligns all modality embeddings into a common manifold.
12//! Modality-type embeddings (learned offsets) allow the model to distinguish modalities.
13//!
14//! ## Design Principles
15//!
16//! - **Unified token space**: All modalities produce tokens of `shared_dim` size.
//! - **Residual codebooks**: `num_stages` is reserved for multi-stage residual VQ; quantization is currently single-stage (one nearest-neighbour lookup per token).
18//! - **Confidence**: `1 / (1 + distance_to_nearest)` — high confidence = close match.
19//! - **Pure Rust**: No C/Fortran dependencies; deterministic xorshift64 initialization.
20
21use crate::error::{TokenizerError, TokenizerResult};
22use crate::SignalTokenizer;
23use scirs2_core::ndarray::{Array1, Array2};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26
27// ---------------------------------------------------------------------------
28// Deterministic PRNG (xorshift64)
29// ---------------------------------------------------------------------------
30
/// Minimal xorshift64 generator used for reproducible weight initialisation.
///
/// A zero seed is bumped to `1` so the state never starts at xorshift's
/// all-zero fixed point.
struct SeededRng {
    state: u64,
}

impl SeededRng {
    /// Creates a generator from `seed`; `0` is treated as `1`.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        SeededRng { state }
    }

    /// Advances the generator (shifts 13, 7, 17) and maps the new state linearly
    /// onto `[-1, 1]`.
    ///
    /// After rounding to `f32`, both endpoints `-1.0` and `1.0` can be returned
    /// (e.g. the first draw from seed `1` is exactly `-1.0`).
    fn next_f32(&mut self) -> f32 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;

        let unit = x as f64 / u64::MAX as f64;
        (unit * 2.0 - 1.0) as f32
    }
}
49
50// ---------------------------------------------------------------------------
51// Modality descriptor
52// ---------------------------------------------------------------------------
53
54/// Identifies the type of a signal modality.
55///
56/// Used to route signals to the correct per-modality encoder and to
57/// add the learned modality-type embedding.
58#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
59pub enum ModalityKind {
60    /// Continuous amplitude waveform (e.g. raw audio samples or mel features).
61    Audio,
62    /// Robot joint angles, velocities, or action commands.
63    Control,
64    /// IMU, pressure, temperature, or other physical sensor readings.
65    Sensor,
66    /// Pixel features or CNN/ViT embeddings from video frames.
67    Video,
68    /// User-defined modality with a descriptive name.
69    Custom(String),
70}
71
72impl ModalityKind {
73    /// Canonical string key used in hash-maps.
74    pub fn key(&self) -> String {
75        match self {
76            ModalityKind::Audio => "audio".to_string(),
77            ModalityKind::Control => "control".to_string(),
78            ModalityKind::Sensor => "sensor".to_string(),
79            ModalityKind::Video => "video".to_string(),
80            ModalityKind::Custom(s) => format!("custom_{s}"),
81        }
82    }
83
84    /// Deterministic seed derived from the modality name (for xorshift64 init).
85    fn seed(&self) -> u64 {
86        // Simple djb2-style hash of the key bytes
87        let key = self.key();
88        key.bytes().fold(5381u64, |acc, b| {
89            acc.wrapping_mul(33).wrapping_add(b as u64)
90        })
91    }
92}
93
94// ---------------------------------------------------------------------------
95// Per-modality configuration
96// ---------------------------------------------------------------------------
97
98/// Configuration for a single modality's tokenizer.
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ModalityTokenizerConfig {
101    /// Which modality this tokenizer handles.
102    pub modality: ModalityKind,
103    /// Dimensionality of the raw input signal for this modality.
104    pub input_dim: usize,
105    /// Shared embedding dimension (must equal `CrossModalTokenizer::shared_dim`).
106    pub token_dim: usize,
107    /// Number of VQ codebook entries.
108    pub codebook_size: usize,
109    /// Number of residual VQ stages (1 = standard VQ, >1 = RVQ).
110    pub num_stages: usize,
111}
112
113impl ModalityTokenizerConfig {
114    /// Validate the configuration fields.
115    pub fn validate(&self) -> TokenizerResult<()> {
116        if self.input_dim == 0 {
117            return Err(TokenizerError::InvalidConfig(
118                "input_dim must be > 0".into(),
119            ));
120        }
121        if self.token_dim == 0 {
122            return Err(TokenizerError::InvalidConfig(
123                "token_dim must be > 0".into(),
124            ));
125        }
126        if self.codebook_size == 0 {
127            return Err(TokenizerError::InvalidConfig(
128                "codebook_size must be > 0".into(),
129            ));
130        }
131        if self.num_stages == 0 {
132            return Err(TokenizerError::InvalidConfig(
133                "num_stages must be >= 1".into(),
134            ));
135        }
136        Ok(())
137    }
138}
139
140// ---------------------------------------------------------------------------
141// GELU activation
142// ---------------------------------------------------------------------------
143
/// Gaussian Error Linear Unit `x · Φ(x)`, using the tanh approximation
/// `0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))`.
#[inline]
fn gelu(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    let cubic_term = 0.044715 * x * x * x;
    let gate = (SQRT_2_OVER_PI * (x + cubic_term)).tanh();
    0.5 * x * (1.0 + gate)
}
152
153// ---------------------------------------------------------------------------
154// Per-modality tokenizer
155// ---------------------------------------------------------------------------
156
157/// Encodes raw signals from a single modality into a shared token embedding space.
158///
159/// Internally uses:
160/// - A linear encoder with GELU activation: `input_dim → token_dim`
161/// - A nearest-neighbour codebook for discrete token assignment
162/// - A linear decoder (transpose of encoder weights) for approximate reconstruction
163pub struct ModalityTokenizer {
164    config: ModalityTokenizerConfig,
165    /// Encoder weight matrix: shape `(input_dim, token_dim)`.
166    encoder: Array2<f32>,
167    /// Encoder bias: shape `(token_dim,)`.
168    encoder_bias: Array1<f32>,
169    /// Codebook: shape `(codebook_size, token_dim)`.
170    codebook: Array2<f32>,
171}
172
173impl ModalityTokenizer {
174    /// Create a new modality tokenizer with deterministic weight initialization.
175    pub fn new(config: ModalityTokenizerConfig) -> TokenizerResult<Self> {
176        config.validate()?;
177
178        let seed = config.modality.seed();
179        let mut rng = SeededRng::new(seed);
180
181        // Xavier / Glorot uniform initialization scale
182        let enc_scale = (6.0_f32 / (config.input_dim + config.token_dim) as f32).sqrt();
183        let encoder = Array2::from_shape_fn((config.input_dim, config.token_dim), |_| {
184            rng.next_f32() * enc_scale
185        });
186
187        let encoder_bias = Array1::zeros(config.token_dim);
188
189        // Codebook: small random init, scaled by 1/sqrt(token_dim)
190        let cb_scale = 1.0_f32 / (config.token_dim as f32).sqrt();
191        let codebook = Array2::from_shape_fn((config.codebook_size, config.token_dim), |_| {
192            rng.next_f32() * cb_scale
193        });
194
195        Ok(Self {
196            config,
197            encoder,
198            encoder_bias,
199            codebook,
200        })
201    }
202
203    /// Project raw input through the encoder (linear + GELU), producing a `token_dim` embedding.
204    pub fn encode(&self, input: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
205        if input.len() != self.config.input_dim {
206            return Err(TokenizerError::dim_mismatch(
207                self.config.input_dim,
208                input.len(),
209                "ModalityTokenizer::encode input_dim",
210            ));
211        }
212
213        // Linear: out = input @ encoder + bias
214        let pre_act = input.dot(&self.encoder) + &self.encoder_bias;
215
216        // GELU element-wise
217        let activated = pre_act.mapv(gelu);
218        Ok(activated)
219    }
220
221    /// Find the nearest codebook entry (L2) and return `(token_idx, quantized_embedding)`.
222    ///
223    /// Confidence is defined as `1 / (1 + min_distance)`.
224    pub fn quantize(&self, embedding: &Array1<f32>) -> TokenizerResult<(usize, Array1<f32>)> {
225        if embedding.len() != self.config.token_dim {
226            return Err(TokenizerError::dim_mismatch(
227                self.config.token_dim,
228                embedding.len(),
229                "ModalityTokenizer::quantize embedding dim",
230            ));
231        }
232
233        let mut best_idx = 0usize;
234        let mut best_dist = f32::INFINITY;
235
236        for k in 0..self.config.codebook_size {
237            let code = self.codebook.row(k);
238            let diff = embedding - &code;
239            let dist = diff.dot(&diff); // squared L2
240            if dist < best_dist {
241                best_dist = dist;
242                best_idx = k;
243            }
244        }
245
246        let quantized = self.codebook.row(best_idx).to_owned();
247        Ok((best_idx, quantized))
248    }
249
250    /// Decode a discrete token index back to the embedding space (codebook lookup).
251    pub fn decode(&self, token_idx: usize) -> TokenizerResult<Array1<f32>> {
252        if token_idx >= self.config.codebook_size {
253            return Err(TokenizerError::out_of_range(
254                token_idx as f32,
255                0.0,
256                (self.config.codebook_size - 1) as f32,
257                "ModalityTokenizer::decode token_idx",
258            ));
259        }
260        Ok(self.codebook.row(token_idx).to_owned())
261    }
262
263    /// Decode an embedding back to the raw input space via the encoder weights.
264    ///
265    /// This is a pseudo-inverse: `out = encoder @ embedding` where
266    /// `encoder: (input_dim, token_dim)` and `embedding: (token_dim,)` → `(input_dim,)`.
267    pub fn decode_embedding(&self, embedding: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
268        if embedding.len() != self.config.token_dim {
269            return Err(TokenizerError::dim_mismatch(
270                self.config.token_dim,
271                embedding.len(),
272                "ModalityTokenizer::decode_embedding embedding dim",
273            ));
274        }
275        // encoder: (input_dim, token_dim)
276        // We want W @ e  where W: (input_dim, token_dim) and e: (token_dim,) → (input_dim,)
277        // ndarray Array2::dot(&Array1): (m, n) @ (n,) = (m,)
278        let reconstructed = self.encoder.dot(embedding);
279        Ok(reconstructed)
280    }
281
282    /// Raw input dimension.
283    pub fn input_dim(&self) -> usize {
284        self.config.input_dim
285    }
286
287    /// Shared token / embedding dimension.
288    pub fn token_dim(&self) -> usize {
289        self.config.token_dim
290    }
291
292    /// Number of discrete codebook entries.
293    pub fn codebook_size(&self) -> usize {
294        self.config.codebook_size
295    }
296
297    /// Read-only reference to the codebook.
298    pub fn codebook(&self) -> &Array2<f32> {
299        &self.codebook
300    }
301
302    /// Compute confidence for an embedding/quantized pair: `1 / (1 + dist)`.
303    pub fn confidence(&self, embedding: &Array1<f32>, quantized: &Array1<f32>) -> f32 {
304        let diff = embedding - quantized;
305        let dist = diff.dot(&diff).sqrt();
306        1.0 / (1.0 + dist)
307    }
308}
309
310// ---------------------------------------------------------------------------
311// Cross-modal token
312// ---------------------------------------------------------------------------
313
/// A single token produced by cross-modal tokenization.
///
/// Carries both the modality identity and the token value in the shared space.
/// Built by `CrossModalTokenizer::tokenize`; `CrossModalTokenizer::decode` reads
/// only `modality` and `token_idx` when reconstructing the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalToken {
    /// Which modality produced this token.
    pub modality: ModalityKind,
    /// Discrete codebook index within that modality's codebook.
    pub token_idx: usize,
    /// Continuous embedding in the shared `token_dim`-dimensional space.
    /// This is the aligned vector *before* quantization, not the codebook entry.
    pub embedding: Array1<f32>,
    /// Quantization confidence: `1 / (1 + ||embedding - nearest_code||)`.
    /// In `(0, 1]` for finite embeddings; a value near 1 means the signal matched
    /// a codebook entry closely.
    pub confidence: f32,
}
329
330// ---------------------------------------------------------------------------
331// Cross-modal sequence
332// ---------------------------------------------------------------------------
333
334/// An ordered sequence of cross-modal tokens drawn from one or more modalities.
335pub struct CrossModalSequence {
336    /// Ordered tokens.
337    pub tokens: Vec<CrossModalToken>,
338    /// Embedding dimensionality shared by all tokens.
339    pub shared_dim: usize,
340}
341
342impl CrossModalSequence {
343    /// Create an empty sequence.
344    pub fn new(shared_dim: usize) -> Self {
345        Self {
346            tokens: Vec::new(),
347            shared_dim,
348        }
349    }
350
351    /// Append a token to the sequence.
352    pub fn push(&mut self, token: CrossModalToken) {
353        self.tokens.push(token);
354    }
355
356    /// Number of tokens in the sequence.
357    pub fn len(&self) -> usize {
358        self.tokens.len()
359    }
360
361    /// Returns `true` if the sequence contains no tokens.
362    pub fn is_empty(&self) -> bool {
363        self.tokens.is_empty()
364    }
365
366    /// Build a `(num_tokens, shared_dim)` embedding matrix from the sequence.
367    ///
368    /// Each row corresponds to one token's continuous embedding.
369    pub fn to_embedding_matrix(&self) -> Array2<f32> {
370        let n = self.tokens.len();
371        if n == 0 {
372            return Array2::zeros((0, self.shared_dim));
373        }
374        let mut mat = Array2::zeros((n, self.shared_dim));
375        for (i, tok) in self.tokens.iter().enumerate() {
376            let row_len = tok.embedding.len().min(self.shared_dim);
377            for j in 0..row_len {
378                mat[[i, j]] = tok.embedding[j];
379            }
380        }
381        mat
382    }
383
384    /// Return all tokens belonging to the specified modality.
385    pub fn filter_by_modality(&self, modality: &ModalityKind) -> Vec<&CrossModalToken> {
386        self.tokens
387            .iter()
388            .filter(|t| &t.modality == modality)
389            .collect()
390    }
391
392    /// Return the distinct modalities present in this sequence (in order of first appearance).
393    pub fn modalities_present(&self) -> Vec<&ModalityKind> {
394        let mut seen: Vec<&ModalityKind> = Vec::new();
395        for tok in &self.tokens {
396            if !seen.contains(&&tok.modality) {
397                seen.push(&tok.modality);
398            }
399        }
400        seen
401    }
402}
403
404// ---------------------------------------------------------------------------
405// Cross-modal aligner
406// ---------------------------------------------------------------------------
407
408/// Buffers tokens from different modalities and flushes them as an aligned sequence.
409///
410/// Useful for synchronising multi-modal streams at a common time step boundary.
411pub struct CrossModalAligner {
412    shared_dim: usize,
413    modality_counts: HashMap<String, usize>,
414    buffer: Vec<CrossModalToken>,
415}
416
417impl CrossModalAligner {
418    /// Create a new aligner for the given shared embedding dimension.
419    pub fn new(shared_dim: usize) -> Self {
420        Self {
421            shared_dim,
422            modality_counts: HashMap::new(),
423            buffer: Vec::new(),
424        }
425    }
426
427    /// Add a token to the alignment buffer.
428    pub fn push_token(&mut self, token: CrossModalToken) {
429        let key = token.modality.key();
430        *self.modality_counts.entry(key).or_insert(0) += 1;
431        self.buffer.push(token);
432    }
433
434    /// Consume the buffer and return it as a `CrossModalSequence`.
435    pub fn flush(&mut self) -> CrossModalSequence {
436        let mut seq = CrossModalSequence::new(self.shared_dim);
437        for tok in self.buffer.drain(..) {
438            seq.push(tok);
439        }
440        self.modality_counts.clear();
441        seq
442    }
443
444    /// Number of tokens currently buffered.
445    pub fn len(&self) -> usize {
446        self.buffer.len()
447    }
448
449    /// Returns `true` if the buffer is empty.
450    pub fn is_empty(&self) -> bool {
451        self.buffer.is_empty()
452    }
453
454    /// How many tokens from a given modality are currently buffered.
455    pub fn count_for_modality(&self, modality: &ModalityKind) -> usize {
456        self.modality_counts
457            .get(&modality.key())
458            .copied()
459            .unwrap_or(0)
460    }
461}
462
463// ---------------------------------------------------------------------------
464// Cross-modal tokenizer
465// ---------------------------------------------------------------------------
466
467/// Unified cross-modal tokenizer.
468///
469/// Manages one `ModalityTokenizer` per registered modality and applies:
470/// 1. Per-modality linear projection (input → shared_dim)
471/// 2. Modality-type embedding offset (for identity disambiguation)
472/// 3. Shared alignment projection (shared_dim → shared_dim)
473/// 4. Nearest-neighbour codebook quantization
474pub struct CrossModalTokenizer {
475    shared_dim: usize,
476    /// Per-modality tokenizers keyed by `ModalityKind::key()`.
477    tokenizers: HashMap<String, ModalityTokenizer>,
478    /// Shared alignment weight: `(shared_dim, shared_dim)`.
479    shared_proj: Array2<f32>,
480    /// Shared alignment bias: `(shared_dim,)`.
481    shared_bias: Array1<f32>,
482    /// Per-modality learned offset vectors: `(shared_dim,)`.
483    modality_embeddings: HashMap<String, Array1<f32>>,
484}
485
486impl CrossModalTokenizer {
487    /// Create a new cross-modal tokenizer with the given shared embedding dimension.
488    pub fn new(shared_dim: usize) -> TokenizerResult<Self> {
489        if shared_dim == 0 {
490            return Err(TokenizerError::InvalidConfig(
491                "shared_dim must be > 0".into(),
492            ));
493        }
494
495        // Initialize shared alignment projection as identity + small noise (xorshift64)
496        let mut rng = SeededRng::new(0xdeadbeef_cafebabe);
497        let scale = 0.01_f32 / (shared_dim as f32).sqrt();
498        let shared_proj = Array2::from_shape_fn((shared_dim, shared_dim), |(i, j)| {
499            let identity = if i == j { 1.0_f32 } else { 0.0_f32 };
500            identity + rng.next_f32() * scale
501        });
502        let shared_bias = Array1::zeros(shared_dim);
503
504        Ok(Self {
505            shared_dim,
506            tokenizers: HashMap::new(),
507            shared_proj,
508            shared_bias,
509            modality_embeddings: HashMap::new(),
510        })
511    }
512
513    /// Register a new modality.
514    ///
515    /// The `config.token_dim` must equal `self.shared_dim`.
516    pub fn add_modality(&mut self, config: ModalityTokenizerConfig) -> TokenizerResult<()> {
517        if config.token_dim != self.shared_dim {
518            return Err(TokenizerError::InvalidConfig(format!(
519                "ModalityTokenizerConfig.token_dim ({}) must equal shared_dim ({})",
520                config.token_dim, self.shared_dim
521            )));
522        }
523        config.validate()?;
524
525        let key = config.modality.key();
526        let modality_seed = config.modality.seed().wrapping_add(0x1234_5678_9abc_def0);
527        let mut rng = SeededRng::new(modality_seed);
528        let embed_scale = 0.02_f32;
529        let mod_emb = Array1::from_shape_fn(self.shared_dim, |_| rng.next_f32() * embed_scale);
530
531        let tokenizer = ModalityTokenizer::new(config)?;
532        self.tokenizers.insert(key.clone(), tokenizer);
533        self.modality_embeddings.insert(key, mod_emb);
534        Ok(())
535    }
536
537    /// Tokenize a single-modality input.
538    ///
539    /// Steps:
540    /// 1. Encode raw input → `shared_dim` embedding (per-modality encoder + GELU)
541    /// 2. Add modality-type embedding offset
542    /// 3. Apply shared alignment projection
543    /// 4. Quantize against per-modality codebook
544    pub fn tokenize(
545        &self,
546        modality: &ModalityKind,
547        input: &Array1<f32>,
548    ) -> TokenizerResult<CrossModalToken> {
549        let key = modality.key();
550        let tok = self.tokenizers.get(&key).ok_or_else(|| {
551            TokenizerError::InvalidConfig(format!("modality '{key}' not registered"))
552        })?;
553        let mod_emb = self.modality_embeddings.get(&key).ok_or_else(|| {
554            TokenizerError::InternalError(format!("missing modality embedding for '{key}'"))
555        })?;
556
557        // 1. Per-modality encode
558        let encoded = tok.encode(input)?;
559
560        // 2. Add modality-type offset
561        let with_mod = encoded + mod_emb;
562
563        // 3. Shared alignment: aligned = with_mod @ shared_proj + shared_bias
564        let aligned = with_mod.dot(&self.shared_proj) + &self.shared_bias;
565
566        // 4. Quantize
567        let (token_idx, quantized) = tok.quantize(&aligned)?;
568        let confidence = tok.confidence(&aligned, &quantized);
569
570        Ok(CrossModalToken {
571            modality: modality.clone(),
572            token_idx,
573            embedding: aligned,
574            confidence,
575        })
576    }
577
578    /// Tokenize a batch of (modality, signal) pairs and return them as a `CrossModalSequence`.
579    pub fn tokenize_batch(
580        &self,
581        inputs: &[(ModalityKind, Array1<f32>)],
582    ) -> TokenizerResult<CrossModalSequence> {
583        let mut seq = CrossModalSequence::new(self.shared_dim);
584        for (modality, signal) in inputs {
585            let token = self.tokenize(modality, signal)?;
586            seq.push(token);
587        }
588        Ok(seq)
589    }
590
591    /// Decode a `CrossModalToken` back to the raw input space.
592    ///
593    /// Uses the per-modality codebook entry as the quantized embedding,
594    /// inverts the shared projection, removes the modality offset,
595    /// and applies the pseudo-inverse decoder.
596    pub fn decode(&self, token: &CrossModalToken) -> TokenizerResult<Array1<f32>> {
597        let key = token.modality.key();
598        let tok = self.tokenizers.get(&key).ok_or_else(|| {
599            TokenizerError::InvalidConfig(format!("modality '{key}' not registered"))
600        })?;
601        let mod_emb = self.modality_embeddings.get(&key).ok_or_else(|| {
602            TokenizerError::InternalError(format!("missing modality embedding for '{key}'"))
603        })?;
604
605        // Codebook lookup gives quantized embedding in shared space
606        let quantized = tok.decode(token.token_idx)?;
607
608        // Invert shared projection (approximate: use transpose)
609        // aligned ≈ quantized  (we skip full inverse for efficiency)
610        let without_mod = quantized - mod_emb;
611
612        // Pseudo-inverse decode through encoder transpose
613        tok.decode_embedding(&without_mod)
614    }
615
616    /// The shared embedding dimension.
617    pub fn shared_dim(&self) -> usize {
618        self.shared_dim
619    }
620
621    /// Number of registered modalities.
622    pub fn num_modalities(&self) -> usize {
623        self.tokenizers.len()
624    }
625
626    /// Sorted list of registered modality keys.
627    pub fn modality_names(&self) -> Vec<String> {
628        let mut names: Vec<String> = self.tokenizers.keys().cloned().collect();
629        names.sort();
630        names
631    }
632
633    // -----------------------------------------------------------------------
634    // Presets
635    // -----------------------------------------------------------------------
636
637    /// Robotics preset: audio (16-dim), control (6-dim), sensor (9-dim) → shared_dim 64.
638    pub fn robotics_preset() -> TokenizerResult<Self> {
639        let mut cmt = Self::new(64)?;
640        cmt.add_modality(ModalityTokenizerConfig {
641            modality: ModalityKind::Audio,
642            input_dim: 16,
643            token_dim: 64,
644            codebook_size: 512,
645            num_stages: 1,
646        })?;
647        cmt.add_modality(ModalityTokenizerConfig {
648            modality: ModalityKind::Control,
649            input_dim: 6,
650            token_dim: 64,
651            codebook_size: 256,
652            num_stages: 1,
653        })?;
654        cmt.add_modality(ModalityTokenizerConfig {
655            modality: ModalityKind::Sensor,
656            input_dim: 9,
657            token_dim: 64,
658            codebook_size: 256,
659            num_stages: 1,
660        })?;
661        Ok(cmt)
662    }
663
664    /// Audio-video preset: audio (80-dim), video (512-dim) → shared_dim 256.
665    pub fn audio_video_preset() -> TokenizerResult<Self> {
666        let mut cmt = Self::new(256)?;
667        cmt.add_modality(ModalityTokenizerConfig {
668            modality: ModalityKind::Audio,
669            input_dim: 80,
670            token_dim: 256,
671            codebook_size: 1024,
672            num_stages: 2,
673        })?;
674        cmt.add_modality(ModalityTokenizerConfig {
675            modality: ModalityKind::Video,
676            input_dim: 512,
677            token_dim: 256,
678            codebook_size: 2048,
679            num_stages: 2,
680        })?;
681        Ok(cmt)
682    }
683}
684
685// ---------------------------------------------------------------------------
686// SignalTokenizer implementation
687// ---------------------------------------------------------------------------
688
689/// `SignalTokenizer` implementation for `CrossModalTokenizer`.
690///
691/// Treats the input as a concatenation of registered modality signals (in
692/// registration order). Encodes each slice, concatenates the resulting
693/// embeddings, and returns the full multi-modal embedding vector.
694///
695/// For `decode`, the embedding is split back into per-modality chunks,
696/// decoded, and concatenated.
697impl SignalTokenizer for CrossModalTokenizer {
698    /// Encode a concatenated multi-modal signal.
699    ///
700    /// The input must be the concatenation of all registered modalities'
701    /// raw signals (in sorted key order). Each modality's token embedding
702    /// is concatenated into a single output vector of length
703    /// `num_modalities * shared_dim`.
704    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
705        let mut names = self.modality_names();
706        names.sort();
707
708        // Verify total input length
709        let total_input_dim: usize = names.iter().map(|n| self.tokenizers[n].input_dim()).sum();
710        if signal.len() != total_input_dim {
711            return Err(TokenizerError::dim_mismatch(
712                total_input_dim,
713                signal.len(),
714                "CrossModalTokenizer::encode total_input_dim",
715            ));
716        }
717
718        let mut out = Vec::with_capacity(names.len() * self.shared_dim);
719        let mut offset = 0usize;
720
721        for name in &names {
722            let tok = &self.tokenizers[name];
723            let dim = tok.input_dim();
724            let slice = signal.slice(scirs2_core::ndarray::s![offset..offset + dim]);
725            let input_owned = slice.to_owned();
726
727            // Find the ModalityKind from the stored tokenizer config
728            // (we use the key to re-derive the modality kind by checking all known kinds)
729            let modality = Self::key_to_modality_kind(name);
730            let token = self.tokenize(&modality, &input_owned)?;
731            out.extend_from_slice(
732                token.embedding.as_slice().ok_or_else(|| {
733                    TokenizerError::InternalError("embedding not contiguous".into())
734                })?,
735            );
736            offset += dim;
737        }
738
739        Ok(Array1::from_vec(out))
740    }
741
742    /// Decode a concatenated embedding vector back to the raw input space.
743    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
744        let mut names = self.modality_names();
745        names.sort();
746        let n = names.len();
747
748        if n == 0 {
749            return Ok(Array1::zeros(0));
750        }
751
752        let expected = n * self.shared_dim;
753        if tokens.len() != expected {
754            return Err(TokenizerError::dim_mismatch(
755                expected,
756                tokens.len(),
757                "CrossModalTokenizer::decode embedding length",
758            ));
759        }
760
761        let mut out = Vec::new();
762
763        for (i, name) in names.iter().enumerate() {
764            let start = i * self.shared_dim;
765            let end = start + self.shared_dim;
766            let emb_slice = tokens
767                .slice(scirs2_core::ndarray::s![start..end])
768                .to_owned();
769
770            let tok = &self.tokenizers[name];
771            let mod_emb = &self.modality_embeddings[name];
772
773            // Remove modality offset
774            let without_mod = emb_slice - mod_emb;
775
776            // Pseudo-inverse decode
777            let reconstructed = tok.decode_embedding(&without_mod)?;
778            out.extend_from_slice(reconstructed.as_slice().ok_or_else(|| {
779                TokenizerError::InternalError("reconstructed not contiguous".into())
780            })?);
781        }
782
783        Ok(Array1::from_vec(out))
784    }
785
786    /// Total output embedding dimension: `num_modalities * shared_dim`.
787    fn embed_dim(&self) -> usize {
788        self.tokenizers.len() * self.shared_dim
789    }
790
791    /// Returns 0 (continuous-style tokenizer; each modality has its own discrete codebook).
792    fn vocab_size(&self) -> usize {
793        0
794    }
795}
796
797impl CrossModalTokenizer {
798    /// Reconstruct a `ModalityKind` from a string key.
799    fn key_to_modality_kind(key: &str) -> ModalityKind {
800        match key {
801            "audio" => ModalityKind::Audio,
802            "control" => ModalityKind::Control,
803            "sensor" => ModalityKind::Sensor,
804            "video" => ModalityKind::Video,
805            other => {
806                let custom_name = other.strip_prefix("custom_").unwrap_or(other);
807                ModalityKind::Custom(custom_name.to_string())
808            }
809        }
810    }
811}
812
813// ---------------------------------------------------------------------------
814// Tests
815// ---------------------------------------------------------------------------
816
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Helper: create a zero-filled input of a given length.
    fn zeros(n: usize) -> Array1<f32> {
        Array1::zeros(n)
    }

    // Helper: small non-zero input.
    fn ones(n: usize) -> Array1<f32> {
        Array1::ones(n)
    }

    // -----------------------------------------------------------------------
    // 1. ModalityTokenizer creation
    // -----------------------------------------------------------------------
    #[test]
    fn test_modality_tokenizer_creation() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 128,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("should create successfully");
        assert_eq!(tok.input_dim(), 16);
        assert_eq!(tok.token_dim(), 64);
        assert_eq!(tok.codebook_size(), 128);
        assert_eq!(tok.codebook().shape(), [128, 64]);
    }

    // -----------------------------------------------------------------------
    // 2. ModalityTokenizer encode produces correct shape
    // -----------------------------------------------------------------------
    #[test]
    fn test_modality_tokenizer_encode() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("create");
        let input = ones(6);
        let emb = tok.encode(&input).expect("encode");
        assert_eq!(emb.len(), 32, "embedding must be token_dim");

        // Wrong dimension should error
        let bad = ones(5);
        assert!(tok.encode(&bad).is_err());
    }

    // -----------------------------------------------------------------------
    // 3. ModalityTokenizer quantize returns valid token index
    // -----------------------------------------------------------------------
    #[test]
    fn test_modality_tokenizer_quantize() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 16,
            codebook_size: 32,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("create");
        let emb = zeros(16);
        let (idx, quantized) = tok.quantize(&emb).expect("quantize");
        assert!(idx < 32, "token index must be within codebook");
        assert_eq!(quantized.len(), 16, "quantized must be token_dim");
    }

    // -----------------------------------------------------------------------
    // 4. decode(quantize(encode(x))) roundtrip
    // -----------------------------------------------------------------------
    #[test]
    fn test_modality_tokenizer_decode_roundtrip() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 8,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("create");
        let input = ones(8);
        let emb = tok.encode(&input).expect("encode");
        let (idx, _quantized) = tok.quantize(&emb).expect("quantize");
        let code = tok.decode(idx).expect("decode");
        assert_eq!(code.len(), 32, "decoded codebook entry must be token_dim");

        // Soft decode: pseudo-inverse should return input_dim vector
        let reconstructed = tok.decode_embedding(&emb).expect("decode_embedding");
        assert_eq!(reconstructed.len(), 8, "reconstructed must be input_dim");
    }

    // -----------------------------------------------------------------------
    // 5. CrossModalToken creation
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_token_creation() {
        let token = CrossModalToken {
            modality: ModalityKind::Video,
            token_idx: 42,
            embedding: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            confidence: 0.95,
        };
        assert_eq!(token.token_idx, 42);
        assert!((token.confidence - 0.95).abs() < 1e-6);
        assert_eq!(token.modality, ModalityKind::Video);
        assert_eq!(token.embedding.len(), 3);
    }

    // -----------------------------------------------------------------------
    // 6. CrossModalSequence push, len, filter_by_modality
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_sequence_operations() {
        let mut seq = CrossModalSequence::new(8);
        assert!(seq.is_empty());

        seq.push(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 0,
            embedding: Array1::zeros(8),
            confidence: 0.8,
        });
        seq.push(CrossModalToken {
            modality: ModalityKind::Control,
            token_idx: 1,
            embedding: Array1::ones(8),
            confidence: 0.7,
        });
        seq.push(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 2,
            embedding: Array1::zeros(8),
            confidence: 0.9,
        });

        assert_eq!(seq.len(), 3);
        assert!(!seq.is_empty());

        let audio_tokens = seq.filter_by_modality(&ModalityKind::Audio);
        assert_eq!(audio_tokens.len(), 2);

        let control_tokens = seq.filter_by_modality(&ModalityKind::Control);
        assert_eq!(control_tokens.len(), 1);

        let video_tokens = seq.filter_by_modality(&ModalityKind::Video);
        assert_eq!(video_tokens.len(), 0);

        let mods = seq.modalities_present();
        assert_eq!(mods.len(), 2);
    }

    // -----------------------------------------------------------------------
    // 7. CrossModalSequence embedding matrix shape
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_sequence_embedding_matrix() {
        let shared_dim = 16;
        let mut seq = CrossModalSequence::new(shared_dim);
        for _ in 0..5 {
            seq.push(CrossModalToken {
                modality: ModalityKind::Sensor,
                token_idx: 0,
                embedding: Array1::zeros(shared_dim),
                confidence: 1.0,
            });
        }
        let mat = seq.to_embedding_matrix();
        assert_eq!(mat.shape(), [5, shared_dim]);

        // Empty sequence
        let empty = CrossModalSequence::new(shared_dim);
        let empty_mat = empty.to_embedding_matrix();
        assert_eq!(empty_mat.shape(), [0, shared_dim]);
    }

    // -----------------------------------------------------------------------
    // 8. CrossModalTokenizer add_modality
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_tokenizer_add_modality() {
        let mut cmt = CrossModalTokenizer::new(32).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        })
        .expect("add audio");

        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 32,
            codebook_size: 32,
            num_stages: 1,
        })
        .expect("add control");

        assert_eq!(cmt.num_modalities(), 2);
        let names = cmt.modality_names();
        assert!(names.contains(&"audio".to_string()));
        assert!(names.contains(&"control".to_string()));

        // Wrong token_dim should fail
        let bad = cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 16, // mismatch
            codebook_size: 32,
            num_stages: 1,
        });
        assert!(bad.is_err());
    }

    // -----------------------------------------------------------------------
    // 9. CrossModalTokenizer tokenize single modality
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_tokenizer_tokenize() {
        let mut cmt = CrossModalTokenizer::new(64).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 128,
            num_stages: 1,
        })
        .expect("add audio");

        let input = ones(16);
        let token = cmt
            .tokenize(&ModalityKind::Audio, &input)
            .expect("tokenize");
        assert_eq!(token.modality, ModalityKind::Audio);
        assert!(token.token_idx < 128);
        assert_eq!(token.embedding.len(), 64);
        assert!(token.confidence > 0.0 && token.confidence <= 1.0);

        // Unregistered modality should error
        assert!(cmt.tokenize(&ModalityKind::Video, &ones(512)).is_err());
    }

    // -----------------------------------------------------------------------
    // 10. CrossModalTokenizer tokenize_batch
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_tokenizer_batch() {
        let mut cmt = CrossModalTokenizer::new(64).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 128,
            num_stages: 1,
        })
        .expect("add audio");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 64,
            codebook_size: 64,
            num_stages: 1,
        })
        .expect("add control");

        let inputs = vec![
            (ModalityKind::Audio, ones(16)),
            (ModalityKind::Control, zeros(6)),
            (ModalityKind::Audio, zeros(16)),
        ];
        let seq = cmt.tokenize_batch(&inputs).expect("batch");
        assert_eq!(seq.len(), 3);
        assert_eq!(seq.shared_dim, 64);

        let mat = seq.to_embedding_matrix();
        assert_eq!(mat.shape(), [3, 64]);

        let audio_tokens = seq.filter_by_modality(&ModalityKind::Audio);
        assert_eq!(audio_tokens.len(), 2);
    }

    // -----------------------------------------------------------------------
    // 11. CrossModalTokenizer decode
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_tokenizer_decode() {
        let mut cmt = CrossModalTokenizer::new(32).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        })
        .expect("add sensor");

        let input = ones(9);
        let token = cmt
            .tokenize(&ModalityKind::Sensor, &input)
            .expect("tokenize");

        let reconstructed = cmt.decode(&token).expect("decode");
        assert_eq!(reconstructed.len(), 9, "decoded must match input_dim");

        // Decoding token for an unregistered modality should error
        let bad_token = CrossModalToken {
            modality: ModalityKind::Video,
            token_idx: 0,
            embedding: Array1::zeros(32),
            confidence: 1.0,
        };
        assert!(cmt.decode(&bad_token).is_err());
    }

    // -----------------------------------------------------------------------
    // 12. Robotics preset
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_robotics_preset() {
        let cmt = CrossModalTokenizer::robotics_preset().expect("robotics preset");
        assert_eq!(cmt.shared_dim(), 64);
        assert_eq!(cmt.num_modalities(), 3);

        let names = cmt.modality_names();
        assert!(names.contains(&"audio".to_string()));
        assert!(names.contains(&"control".to_string()));
        assert!(names.contains(&"sensor".to_string()));

        // Tokenize all three modalities
        let audio_token = cmt
            .tokenize(&ModalityKind::Audio, &ones(16))
            .expect("audio tokenize");
        assert_eq!(audio_token.embedding.len(), 64);

        let control_token = cmt
            .tokenize(&ModalityKind::Control, &zeros(6))
            .expect("control tokenize");
        assert!(control_token.token_idx < 256);

        let sensor_token = cmt
            .tokenize(&ModalityKind::Sensor, &ones(9))
            .expect("sensor tokenize");
        assert!(sensor_token.confidence > 0.0);

        // tokenize_batch
        let inputs = vec![
            (ModalityKind::Audio, ones(16)),
            (ModalityKind::Control, zeros(6)),
            (ModalityKind::Sensor, ones(9)),
        ];
        let seq = cmt.tokenize_batch(&inputs).expect("batch");
        assert_eq!(seq.len(), 3);
    }

    // -----------------------------------------------------------------------
    // 13. CrossModalAligner push and flush
    // -----------------------------------------------------------------------
    #[test]
    fn test_cross_modal_aligner() {
        let mut aligner = CrossModalAligner::new(64);
        assert!(aligner.is_empty());

        aligner.push_token(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 0,
            embedding: Array1::zeros(64),
            confidence: 0.9,
        });
        aligner.push_token(CrossModalToken {
            modality: ModalityKind::Control,
            token_idx: 1,
            embedding: Array1::ones(64),
            confidence: 0.8,
        });
        aligner.push_token(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 2,
            embedding: Array1::zeros(64),
            confidence: 0.7,
        });

        assert_eq!(aligner.len(), 3);
        assert!(!aligner.is_empty());
        assert_eq!(aligner.count_for_modality(&ModalityKind::Audio), 2);
        assert_eq!(aligner.count_for_modality(&ModalityKind::Control), 1);
        assert_eq!(aligner.count_for_modality(&ModalityKind::Sensor), 0);

        let seq = aligner.flush();
        assert_eq!(seq.len(), 3);
        assert!(aligner.is_empty(), "buffer cleared after flush");
        assert_eq!(aligner.count_for_modality(&ModalityKind::Audio), 0);

        let mat = seq.to_embedding_matrix();
        assert_eq!(mat.shape(), [3, 64]);
    }

    // -----------------------------------------------------------------------
    // 14. ModalityKind key and seed determinism
    // -----------------------------------------------------------------------
    #[test]
    fn test_modality_kind_key_and_seed() {
        assert_eq!(ModalityKind::Audio.key(), "audio");
        assert_eq!(ModalityKind::Control.key(), "control");
        assert_eq!(ModalityKind::Sensor.key(), "sensor");
        assert_eq!(ModalityKind::Video.key(), "video");
        assert_eq!(ModalityKind::Custom("robot".into()).key(), "custom_robot");

        // Seeds must be deterministic
        assert_eq!(ModalityKind::Audio.seed(), ModalityKind::Audio.seed());
        assert_ne!(ModalityKind::Audio.seed(), ModalityKind::Control.seed());
    }

    // -----------------------------------------------------------------------
    // 15. Audio-video preset
    // -----------------------------------------------------------------------
    #[test]
    fn test_audio_video_preset() {
        let cmt = CrossModalTokenizer::audio_video_preset().expect("audio_video preset");
        assert_eq!(cmt.shared_dim(), 256);
        assert_eq!(cmt.num_modalities(), 2);

        let audio_tok = cmt
            .tokenize(&ModalityKind::Audio, &ones(80))
            .expect("audio tokenize");
        assert_eq!(audio_tok.embedding.len(), 256);

        let video_tok = cmt
            .tokenize(&ModalityKind::Video, &ones(512))
            .expect("video tokenize");
        assert!(video_tok.token_idx < 2048);
    }

    // -----------------------------------------------------------------------
    // 16. key_to_modality_kind inverts ModalityKind::key
    // -----------------------------------------------------------------------
    #[test]
    fn test_key_to_modality_kind_roundtrip() {
        // Every modality must survive a key() -> key_to_modality_kind() round trip,
        // otherwise registry keys can no longer be mapped back to their modality.
        let kinds = [
            ModalityKind::Audio,
            ModalityKind::Control,
            ModalityKind::Sensor,
            ModalityKind::Video,
            ModalityKind::Custom("robot".into()),
        ];
        for kind in kinds {
            let key = kind.key();
            assert_eq!(
                CrossModalTokenizer::key_to_modality_kind(&key),
                kind,
                "key {:?} must map back to its modality",
                key
            );
        }

        // Keys without the `custom_` prefix fall back to a custom modality named verbatim.
        assert_eq!(
            CrossModalTokenizer::key_to_modality_kind("lidar"),
            ModalityKind::Custom("lidar".to_string())
        );
    }
}