use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Deterministic xorshift64 pseudo-random generator.
///
/// Used for reproducible weight initialization: the same seed always
/// yields the same stream, with no external RNG dependency.
struct SeededRng {
    state: u64,
}
impl SeededRng {
    /// Creates a generator; a zero seed is promoted to 1 because the
    /// all-zero state is a fixed point of the xorshift recurrence.
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
    /// Advances the xorshift64 state and maps it to roughly [-1.0, 1.0).
    fn next_f32(&mut self) -> f32 {
        let mut s = self.state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.state = s;
        // Scale the full u64 range onto [-1, 1) via f64 to keep precision.
        (s as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
    }
}
/// Identifies the kind of signal a tokenizer handles.
///
/// Built-in kinds cover common robotics/media streams; `Custom` allows
/// arbitrary user-defined modalities keyed by name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModalityKind {
    /// Audio features (e.g. spectral frames).
    Audio,
    /// Control/actuation commands.
    Control,
    /// Generic sensor readings (e.g. IMU).
    Sensor,
    /// Video/visual features.
    Video,
    /// User-defined modality identified by its name.
    Custom(String),
}
impl ModalityKind {
    /// Returns the canonical string key used to register and look up
    /// this modality; custom kinds are prefixed with `custom_`.
    pub fn key(&self) -> String {
        match self {
            ModalityKind::Audio => String::from("audio"),
            ModalityKind::Control => String::from("control"),
            ModalityKind::Sensor => String::from("sensor"),
            ModalityKind::Video => String::from("video"),
            ModalityKind::Custom(s) => format!("custom_{s}"),
        }
    }
    /// Derives a deterministic RNG seed from the modality key using a
    /// djb2-style hash (acc * 33 + byte), so each modality gets
    /// distinct but reproducible initialization.
    fn seed(&self) -> u64 {
        let mut acc: u64 = 5381;
        for byte in self.key().bytes() {
            acc = acc.wrapping_mul(33).wrapping_add(u64::from(byte));
        }
        acc
    }
}
/// Configuration for a single modality's tokenizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityTokenizerConfig {
    /// Which modality this tokenizer handles (also determines the
    /// deterministic weight-initialization seed).
    pub modality: ModalityKind,
    /// Dimensionality of the raw input signal.
    pub input_dim: usize,
    /// Dimensionality of the token embedding produced by the encoder.
    pub token_dim: usize,
    /// Number of entries in the vector-quantization codebook.
    pub codebook_size: usize,
    /// Number of quantization stages.
    // NOTE(review): num_stages is validated but not otherwise read in
    // this file — confirm whether multi-stage quantization is planned.
    pub num_stages: usize,
}
impl ModalityTokenizerConfig {
    /// Checks that every size/dimension field is non-zero.
    ///
    /// # Errors
    /// Returns `TokenizerError::InvalidConfig` describing the first
    /// field (in declaration order) that is zero.
    pub fn validate(&self) -> TokenizerResult<()> {
        let checks: [(usize, &str); 4] = [
            (self.input_dim, "input_dim must be > 0"),
            (self.token_dim, "token_dim must be > 0"),
            (self.codebook_size, "codebook_size must be > 0"),
            (self.num_stages, "num_stages must be >= 1"),
        ];
        for (value, message) in checks {
            if value == 0 {
                return Err(TokenizerError::InvalidConfig(message.into()));
            }
        }
        Ok(())
    }
}
/// Tanh-based approximation of the GELU activation (Hendrycks & Gimpel):
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
fn gelu(x: f32) -> f32 {
    // sqrt(2/pi) to f32 precision.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6_f32;
    let inner = SQRT_2_OVER_PI * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
/// Tokenizer for a single modality: a linear+GELU encoder followed by
/// nearest-neighbor vector quantization against a fixed codebook.
pub struct ModalityTokenizer {
    /// Validated configuration this tokenizer was built from.
    config: ModalityTokenizerConfig,
    /// Encoder weight matrix, shape `(input_dim, token_dim)`.
    encoder: Array2<f32>,
    /// Encoder bias, length `token_dim` (initialized to zeros).
    encoder_bias: Array1<f32>,
    /// Quantization codebook, shape `(codebook_size, token_dim)`.
    codebook: Array2<f32>,
}
impl ModalityTokenizer {
    /// Builds a tokenizer with deterministic, modality-seeded weights:
    /// a Xavier-scaled encoder, a zero bias, and a norm-scaled codebook.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if any configured size is zero.
    pub fn new(config: ModalityTokenizerConfig) -> TokenizerResult<Self> {
        config.validate()?;
        // One RNG per modality; draw order (encoder, then codebook)
        // determines the weights and must stay stable.
        let mut rng = SeededRng::new(config.modality.seed());
        // Xavier/Glorot-style bound for the dense encoder weights.
        let encoder_scale = (6.0_f32 / (config.input_dim + config.token_dim) as f32).sqrt();
        let encoder = Array2::from_shape_fn((config.input_dim, config.token_dim), |_| {
            rng.next_f32() * encoder_scale
        });
        let encoder_bias = Array1::zeros(config.token_dim);
        // Scale codebook rows so squared norms stay O(1) in token_dim.
        let codebook_scale = 1.0_f32 / (config.token_dim as f32).sqrt();
        let codebook = Array2::from_shape_fn((config.codebook_size, config.token_dim), |_| {
            rng.next_f32() * codebook_scale
        });
        Ok(Self {
            config,
            encoder,
            encoder_bias,
            codebook,
        })
    }
    /// Projects an `input_dim` signal to a `token_dim` embedding via
    /// the linear encoder plus bias, followed by GELU.
    ///
    /// # Errors
    /// `dim_mismatch` if `input` is not `input_dim` long.
    pub fn encode(&self, input: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if input.len() != self.config.input_dim {
            return Err(TokenizerError::dim_mismatch(
                self.config.input_dim,
                input.len(),
                "ModalityTokenizer::encode input_dim",
            ));
        }
        let pre_activation = input.dot(&self.encoder) + &self.encoder_bias;
        Ok(pre_activation.mapv(gelu))
    }
    /// Finds the codebook row nearest to `embedding` (squared Euclidean
    /// distance) and returns `(index, owned copy of that row)`.
    ///
    /// # Errors
    /// `dim_mismatch` if `embedding` is not `token_dim` long.
    pub fn quantize(&self, embedding: &Array1<f32>) -> TokenizerResult<(usize, Array1<f32>)> {
        if embedding.len() != self.config.token_dim {
            return Err(TokenizerError::dim_mismatch(
                self.config.token_dim,
                embedding.len(),
                "ModalityTokenizer::quantize embedding dim",
            ));
        }
        // Linear scan; strict `<` keeps the lowest index on ties.
        let (best_idx, _best_dist) = (0..self.config.codebook_size).fold(
            (0usize, f32::INFINITY),
            |(best_k, best_d), k| {
                let diff = embedding - &self.codebook.row(k);
                let dist = diff.dot(&diff);
                if dist < best_d {
                    (k, dist)
                } else {
                    (best_k, best_d)
                }
            },
        );
        Ok((best_idx, self.codebook.row(best_idx).to_owned()))
    }
    /// Returns the codebook entry for `token_idx`.
    ///
    /// # Errors
    /// `out_of_range` if `token_idx >= codebook_size`.
    pub fn decode(&self, token_idx: usize) -> TokenizerResult<Array1<f32>> {
        if token_idx >= self.config.codebook_size {
            return Err(TokenizerError::out_of_range(
                token_idx as f32,
                0.0,
                (self.config.codebook_size - 1) as f32,
                "ModalityTokenizer::decode token_idx",
            ));
        }
        Ok(self.codebook.row(token_idx).to_owned())
    }
    /// Maps a `token_dim` embedding back to `input_dim` space by
    /// multiplying with the encoder matrix (a linear approximation of
    /// the inverse; the GELU nonlinearity is not inverted).
    ///
    /// # Errors
    /// `dim_mismatch` if `embedding` is not `token_dim` long.
    pub fn decode_embedding(&self, embedding: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if embedding.len() != self.config.token_dim {
            return Err(TokenizerError::dim_mismatch(
                self.config.token_dim,
                embedding.len(),
                "ModalityTokenizer::decode_embedding embedding dim",
            ));
        }
        Ok(self.encoder.dot(embedding))
    }
    /// Dimensionality of the raw input signal.
    pub fn input_dim(&self) -> usize {
        self.config.input_dim
    }
    /// Dimensionality of token embeddings.
    pub fn token_dim(&self) -> usize {
        self.config.token_dim
    }
    /// Number of codebook entries.
    pub fn codebook_size(&self) -> usize {
        self.config.codebook_size
    }
    /// Read-only view of the full codebook matrix.
    pub fn codebook(&self) -> &Array2<f32> {
        &self.codebook
    }
    /// Confidence in a quantization, in (0, 1]: 1 / (1 + ||e - q||),
    /// i.e. 1.0 for an exact codebook hit, decaying with distance.
    pub fn confidence(&self, embedding: &Array1<f32>, quantized: &Array1<f32>) -> f32 {
        let residual = embedding - quantized;
        let distance = residual.dot(&residual).sqrt();
        1.0 / (1.0 + distance)
    }
}
/// A single quantized token in the shared cross-modal space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalToken {
    /// Which modality produced this token.
    pub modality: ModalityKind,
    /// Index of the nearest codebook entry in that modality's codebook.
    pub token_idx: usize,
    /// The (pre-quantization) aligned embedding in the shared space.
    pub embedding: Array1<f32>,
    /// Quantization confidence in (0, 1]; higher means the embedding
    /// was closer to its codebook entry.
    pub confidence: f32,
}
/// An ordered sequence of tokens drawn from possibly different
/// modalities, all embedded in the same `shared_dim` space.
pub struct CrossModalSequence {
    /// Tokens in arrival order.
    pub tokens: Vec<CrossModalToken>,
    /// Dimensionality of the shared embedding space.
    pub shared_dim: usize,
}
impl CrossModalSequence {
    /// Creates an empty sequence in a `shared_dim`-dimensional space.
    pub fn new(shared_dim: usize) -> Self {
        Self {
            tokens: Vec::new(),
            shared_dim,
        }
    }
    /// Appends a token at the end of the sequence.
    pub fn push(&mut self, token: CrossModalToken) {
        self.tokens.push(token);
    }
    /// Number of tokens held.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }
    /// True when the sequence holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
    /// Stacks the token embeddings into an `(n, shared_dim)` matrix.
    /// Rows shorter than `shared_dim` are zero-padded; longer ones are
    /// truncated.
    pub fn to_embedding_matrix(&self) -> Array2<f32> {
        let rows = self.tokens.len();
        if rows == 0 {
            return Array2::zeros((0, self.shared_dim));
        }
        let mut matrix = Array2::zeros((rows, self.shared_dim));
        for (r, token) in self.tokens.iter().enumerate() {
            let width = token.embedding.len().min(self.shared_dim);
            for (c, &value) in token.embedding.iter().take(width).enumerate() {
                matrix[[r, c]] = value;
            }
        }
        matrix
    }
    /// Borrowed views of all tokens belonging to `modality`, in order.
    pub fn filter_by_modality(&self, modality: &ModalityKind) -> Vec<&CrossModalToken> {
        self.tokens
            .iter()
            .filter(|token| &token.modality == modality)
            .collect()
    }
    /// Distinct modalities in order of first appearance.
    pub fn modalities_present(&self) -> Vec<&ModalityKind> {
        let mut seen: Vec<&ModalityKind> = Vec::new();
        for token in &self.tokens {
            let already_listed = seen.iter().any(|m| *m == &token.modality);
            if !already_listed {
                seen.push(&token.modality);
            }
        }
        seen
    }
}
/// Accumulates tokens from multiple modalities and flushes them into a
/// `CrossModalSequence`, tracking per-modality counts along the way.
pub struct CrossModalAligner {
    /// Dimensionality of the shared embedding space.
    shared_dim: usize,
    /// Number of buffered tokens per modality key.
    modality_counts: HashMap<String, usize>,
    /// Tokens accumulated since the last `flush`, in arrival order.
    buffer: Vec<CrossModalToken>,
}
impl CrossModalAligner {
    /// Creates an empty aligner targeting the given shared dimension.
    pub fn new(shared_dim: usize) -> Self {
        Self {
            shared_dim,
            modality_counts: HashMap::new(),
            buffer: Vec::new(),
        }
    }
    /// Buffers a token and bumps its modality's counter.
    pub fn push_token(&mut self, token: CrossModalToken) {
        let counter = self.modality_counts.entry(token.modality.key()).or_insert(0);
        *counter += 1;
        self.buffer.push(token);
    }
    /// Drains every buffered token (in arrival order) into a fresh
    /// sequence and resets all per-modality counters.
    pub fn flush(&mut self) -> CrossModalSequence {
        let mut sequence = CrossModalSequence::new(self.shared_dim);
        sequence.tokens.extend(self.buffer.drain(..));
        self.modality_counts.clear();
        sequence
    }
    /// Number of tokens currently buffered.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }
    /// True when no tokens are buffered.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
    /// Number of buffered tokens for one specific modality.
    pub fn count_for_modality(&self, modality: &ModalityKind) -> usize {
        self.modality_counts
            .get(&modality.key())
            .copied()
            .unwrap_or(0)
    }
}
/// Top-level tokenizer that maps signals from several registered
/// modalities into one shared embedding space.
pub struct CrossModalTokenizer {
    /// Dimensionality of the shared embedding space; every registered
    /// modality's `token_dim` must equal this.
    shared_dim: usize,
    /// Per-modality tokenizers, keyed by `ModalityKind::key()`.
    tokenizers: HashMap<String, ModalityTokenizer>,
    /// Shared alignment projection (near-identity at initialization).
    shared_proj: Array2<f32>,
    /// Shared alignment bias (zeros at initialization).
    shared_bias: Array1<f32>,
    /// Additive per-modality "type tag" embeddings, keyed like
    /// `tokenizers`.
    modality_embeddings: HashMap<String, Array1<f32>>,
}
impl CrossModalTokenizer {
    /// Creates an empty cross-modal tokenizer with a shared embedding
    /// space of `shared_dim` dimensions.
    ///
    /// The shared projection is seeded deterministically and starts as
    /// identity plus small noise, so alignment begins near pass-through.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if `shared_dim` is zero.
    pub fn new(shared_dim: usize) -> TokenizerResult<Self> {
        if shared_dim == 0 {
            return Err(TokenizerError::InvalidConfig(
                "shared_dim must be > 0".into(),
            ));
        }
        // Fixed seed: the projection must be identical across runs.
        let mut rng = SeededRng::new(0xdeadbeef_cafebabe);
        let scale = 0.01_f32 / (shared_dim as f32).sqrt();
        let shared_proj = Array2::from_shape_fn((shared_dim, shared_dim), |(i, j)| {
            let identity = if i == j { 1.0_f32 } else { 0.0_f32 };
            identity + rng.next_f32() * scale
        });
        let shared_bias = Array1::zeros(shared_dim);
        Ok(Self {
            shared_dim,
            tokenizers: HashMap::new(),
            shared_proj,
            shared_bias,
            modality_embeddings: HashMap::new(),
        })
    }
    /// Registers a modality: builds its tokenizer and a deterministic
    /// additive modality embedding ("type tag") in the shared space.
    ///
    /// Re-adding an existing modality replaces its tokenizer and tag.
    ///
    /// # Errors
    /// `InvalidConfig` if `config.token_dim != shared_dim` or the
    /// config itself fails validation.
    pub fn add_modality(&mut self, config: ModalityTokenizerConfig) -> TokenizerResult<()> {
        if config.token_dim != self.shared_dim {
            return Err(TokenizerError::InvalidConfig(format!(
                "ModalityTokenizerConfig.token_dim ({}) must equal shared_dim ({})",
                config.token_dim, self.shared_dim
            )));
        }
        config.validate()?;
        let key = config.modality.key();
        // Offset the seed so the tag embedding differs from the
        // tokenizer weights drawn from the same modality seed.
        let modality_seed = config.modality.seed().wrapping_add(0x1234_5678_9abc_def0);
        let mut rng = SeededRng::new(modality_seed);
        let embed_scale = 0.02_f32;
        let mod_emb = Array1::from_shape_fn(self.shared_dim, |_| rng.next_f32() * embed_scale);
        let tokenizer = ModalityTokenizer::new(config)?;
        self.tokenizers.insert(key.clone(), tokenizer);
        self.modality_embeddings.insert(key, mod_emb);
        Ok(())
    }
    /// Tokenizes one signal: encode -> add modality tag -> shared
    /// projection + bias -> quantize against that modality's codebook.
    ///
    /// # Errors
    /// `InvalidConfig` if the modality is not registered; dimension
    /// errors propagate from `encode`/`quantize`.
    pub fn tokenize(
        &self,
        modality: &ModalityKind,
        input: &Array1<f32>,
    ) -> TokenizerResult<CrossModalToken> {
        let key = modality.key();
        let tok = self.tokenizers.get(&key).ok_or_else(|| {
            TokenizerError::InvalidConfig(format!("modality '{key}' not registered"))
        })?;
        let mod_emb = self.modality_embeddings.get(&key).ok_or_else(|| {
            TokenizerError::InternalError(format!("missing modality embedding for '{key}'"))
        })?;
        let encoded = tok.encode(input)?;
        let with_mod = encoded + mod_emb;
        let aligned = with_mod.dot(&self.shared_proj) + &self.shared_bias;
        let (token_idx, quantized) = tok.quantize(&aligned)?;
        let confidence = tok.confidence(&aligned, &quantized);
        Ok(CrossModalToken {
            modality: modality.clone(),
            token_idx,
            // The pre-quantization aligned embedding is kept, not the
            // quantized codebook entry.
            embedding: aligned,
            confidence,
        })
    }
    /// Tokenizes a batch of `(modality, signal)` pairs into a sequence,
    /// preserving input order. Fails fast on the first error.
    pub fn tokenize_batch(
        &self,
        inputs: &[(ModalityKind, Array1<f32>)],
    ) -> TokenizerResult<CrossModalSequence> {
        let mut seq = CrossModalSequence::new(self.shared_dim);
        for (modality, signal) in inputs {
            let token = self.tokenize(modality, signal)?;
            seq.push(token);
        }
        Ok(seq)
    }
    /// Reconstructs an approximate input-space signal from a token:
    /// codebook lookup, subtract the modality tag, then the modality's
    /// linear decode.
    ///
    /// NOTE(review): this subtracts the modality embedding but does not
    /// invert `shared_proj`; since the projection is near-identity the
    /// result is approximate — confirm this is intended.
    ///
    /// # Errors
    /// `InvalidConfig` if the token's modality is not registered;
    /// `out_of_range` if `token_idx` exceeds that modality's codebook.
    pub fn decode(&self, token: &CrossModalToken) -> TokenizerResult<Array1<f32>> {
        let key = token.modality.key();
        let tok = self.tokenizers.get(&key).ok_or_else(|| {
            TokenizerError::InvalidConfig(format!("modality '{key}' not registered"))
        })?;
        let mod_emb = self.modality_embeddings.get(&key).ok_or_else(|| {
            TokenizerError::InternalError(format!("missing modality embedding for '{key}'"))
        })?;
        let quantized = tok.decode(token.token_idx)?;
        let without_mod = quantized - mod_emb;
        tok.decode_embedding(&without_mod)
    }
    /// Dimensionality of the shared embedding space.
    pub fn shared_dim(&self) -> usize {
        self.shared_dim
    }
    /// Number of registered modalities.
    pub fn num_modalities(&self) -> usize {
        self.tokenizers.len()
    }
    /// Sorted list of registered modality keys.
    pub fn modality_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tokenizers.keys().cloned().collect();
        names.sort();
        names
    }
    /// Preset: 64-dim shared space with audio (16-d), control (6-d),
    /// and sensor (9-d) modalities, sized for robotics workloads.
    pub fn robotics_preset() -> TokenizerResult<Self> {
        let mut cmt = Self::new(64)?;
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 512,
            num_stages: 1,
        })?;
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 64,
            codebook_size: 256,
            num_stages: 1,
        })?;
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 64,
            codebook_size: 256,
            num_stages: 1,
        })?;
        Ok(cmt)
    }
    /// Preset: 256-dim shared space with audio (80-d, e.g. mel frames)
    /// and video (512-d feature) modalities.
    pub fn audio_video_preset() -> TokenizerResult<Self> {
        let mut cmt = Self::new(256)?;
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 80,
            token_dim: 256,
            codebook_size: 1024,
            num_stages: 2,
        })?;
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Video,
            input_dim: 512,
            token_dim: 256,
            codebook_size: 2048,
            num_stages: 2,
        })?;
        Ok(cmt)
    }
}
impl SignalTokenizer for CrossModalTokenizer {
    /// Encodes a concatenated multi-modal signal into concatenated
    /// shared-space embeddings.
    ///
    /// The input is interpreted as per-modality signals laid out in
    /// alphabetical order of modality key (the order returned by
    /// `modality_names`); the output holds one `shared_dim`-long
    /// embedding per modality, in the same order.
    ///
    /// # Errors
    /// `dim_mismatch` if `signal` is not exactly the sum of all
    /// registered input dims; tokenization errors propagate.
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        // `modality_names` already returns a sorted list; the extra
        // re-sort that used to live here was redundant.
        let names = self.modality_names();
        let total_input_dim: usize = names.iter().map(|n| self.tokenizers[n].input_dim()).sum();
        if signal.len() != total_input_dim {
            return Err(TokenizerError::dim_mismatch(
                total_input_dim,
                signal.len(),
                "CrossModalTokenizer::encode total_input_dim",
            ));
        }
        let mut out = Vec::with_capacity(names.len() * self.shared_dim);
        let mut offset = 0usize;
        for name in &names {
            let tok = &self.tokenizers[name];
            let dim = tok.input_dim();
            // Carve this modality's slice out of the flat input.
            let slice = signal.slice(scirs2_core::ndarray::s![offset..offset + dim]);
            let input_owned = slice.to_owned();
            let modality = Self::key_to_modality_kind(name);
            let token = self.tokenize(&modality, &input_owned)?;
            out.extend_from_slice(
                token.embedding.as_slice().ok_or_else(|| {
                    TokenizerError::InternalError("embedding not contiguous".into())
                })?,
            );
            offset += dim;
        }
        Ok(Array1::from_vec(out))
    }
    /// Approximate inverse of `encode`: splits the flat embedding into
    /// per-modality `shared_dim` chunks (same alphabetical order) and
    /// linearly decodes each back to its input space.
    ///
    /// # Errors
    /// `dim_mismatch` if `tokens` is not `num_modalities * shared_dim`
    /// long; decode errors propagate.
    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        // Already sorted by `modality_names`; no re-sort needed.
        let names = self.modality_names();
        let n = names.len();
        if n == 0 {
            return Ok(Array1::zeros(0));
        }
        let expected = n * self.shared_dim;
        if tokens.len() != expected {
            return Err(TokenizerError::dim_mismatch(
                expected,
                tokens.len(),
                "CrossModalTokenizer::decode embedding length",
            ));
        }
        // Pre-size: one input_dim-long reconstruction per modality.
        let total_output: usize = names.iter().map(|n| self.tokenizers[n].input_dim()).sum();
        let mut out = Vec::with_capacity(total_output);
        for (i, name) in names.iter().enumerate() {
            let start = i * self.shared_dim;
            let end = start + self.shared_dim;
            let emb_slice = tokens
                .slice(scirs2_core::ndarray::s![start..end])
                .to_owned();
            let tok = &self.tokenizers[name];
            let mod_emb = &self.modality_embeddings[name];
            // Strip the additive modality tag before linear decode.
            let without_mod = emb_slice - mod_emb;
            let reconstructed = tok.decode_embedding(&without_mod)?;
            out.extend_from_slice(reconstructed.as_slice().ok_or_else(|| {
                TokenizerError::InternalError("reconstructed not contiguous".into())
            })?);
        }
        Ok(Array1::from_vec(out))
    }
    /// Total embedding width: one `shared_dim` slot per modality.
    fn embed_dim(&self) -> usize {
        self.tokenizers.len() * self.shared_dim
    }
    /// There is no single flat vocabulary (each modality has its own
    /// codebook), so this reports 0.
    fn vocab_size(&self) -> usize {
        0
    }
}
impl CrossModalTokenizer {
    /// Maps a registry key (as produced by `ModalityKind::key`) back to
    /// its `ModalityKind`. Unrecognized keys become `Custom`, with a
    /// leading `custom_` prefix stripped when present.
    fn key_to_modality_kind(key: &str) -> ModalityKind {
        match key {
            "audio" => ModalityKind::Audio,
            "control" => ModalityKind::Control,
            "sensor" => ModalityKind::Sensor,
            "video" => ModalityKind::Video,
            other => ModalityKind::Custom(
                other.strip_prefix("custom_").unwrap_or(other).to_string(),
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    // Convenience fixture: an all-zeros vector of length n.
    fn zeros(n: usize) -> Array1<f32> {
        Array1::zeros(n)
    }
    // Convenience fixture: an all-ones vector of length n.
    fn ones(n: usize) -> Array1<f32> {
        Array1::ones(n)
    }
    // Construction from a valid config exposes the configured dims and
    // allocates a (codebook_size, token_dim) codebook.
    #[test]
    fn test_modality_tokenizer_creation() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 128,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("should create successfully");
        assert_eq!(tok.input_dim(), 16);
        assert_eq!(tok.token_dim(), 64);
        assert_eq!(tok.codebook_size(), 128);
        assert_eq!(tok.codebook().shape(), [128, 64]);
    }
    // encode maps input_dim -> token_dim and rejects wrong-length input.
    #[test]
    fn test_modality_tokenizer_encode() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("create");
        let input = ones(6);
        let emb = tok.encode(&input).expect("encode");
        assert_eq!(emb.len(), 32, "embedding must be token_dim");
        let bad = ones(5);
        assert!(tok.encode(&bad).is_err());
    }
    // quantize returns an in-range index and a token_dim-long row.
    #[test]
    fn test_modality_tokenizer_quantize() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 16,
            codebook_size: 32,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("create");
        let emb = zeros(16);
        let (idx, quantized) = tok.quantize(&emb).expect("quantize");
        assert!(idx < 32, "token index must be within codebook");
        assert_eq!(quantized.len(), 16, "quantized must be token_dim");
    }
    // encode -> quantize -> decode/decode_embedding preserves dims end to end.
    #[test]
    fn test_modality_tokenizer_decode_roundtrip() {
        let cfg = ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 8,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        };
        let tok = ModalityTokenizer::new(cfg).expect("create");
        let input = ones(8);
        let emb = tok.encode(&input).expect("encode");
        let (idx, _quantized) = tok.quantize(&emb).expect("quantize");
        let code = tok.decode(idx).expect("decode");
        assert_eq!(code.len(), 32, "decoded codebook entry must be token_dim");
        let reconstructed = tok.decode_embedding(&emb).expect("decode_embedding");
        assert_eq!(reconstructed.len(), 8, "reconstructed must be input_dim");
    }
    // CrossModalToken is a plain data carrier; fields round-trip as set.
    #[test]
    fn test_cross_modal_token_creation() {
        let token = CrossModalToken {
            modality: ModalityKind::Video,
            token_idx: 42,
            embedding: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            confidence: 0.95,
        };
        assert_eq!(token.token_idx, 42);
        assert!((token.confidence - 0.95).abs() < 1e-6);
        assert_eq!(token.modality, ModalityKind::Video);
        assert_eq!(token.embedding.len(), 3);
    }
    // push/len/is_empty plus modality filtering and distinct-modality listing.
    #[test]
    fn test_cross_modal_sequence_operations() {
        let mut seq = CrossModalSequence::new(8);
        assert!(seq.is_empty());
        seq.push(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 0,
            embedding: Array1::zeros(8),
            confidence: 0.8,
        });
        seq.push(CrossModalToken {
            modality: ModalityKind::Control,
            token_idx: 1,
            embedding: Array1::ones(8),
            confidence: 0.7,
        });
        seq.push(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 2,
            embedding: Array1::zeros(8),
            confidence: 0.9,
        });
        assert_eq!(seq.len(), 3);
        assert!(!seq.is_empty());
        let audio_tokens = seq.filter_by_modality(&ModalityKind::Audio);
        assert_eq!(audio_tokens.len(), 2);
        let control_tokens = seq.filter_by_modality(&ModalityKind::Control);
        assert_eq!(control_tokens.len(), 1);
        let video_tokens = seq.filter_by_modality(&ModalityKind::Video);
        assert_eq!(video_tokens.len(), 0);
        let mods = seq.modalities_present();
        assert_eq!(mods.len(), 2);
    }
    // to_embedding_matrix yields (n, shared_dim), including the empty case.
    #[test]
    fn test_cross_modal_sequence_embedding_matrix() {
        let shared_dim = 16;
        let mut seq = CrossModalSequence::new(shared_dim);
        for _ in 0..5 {
            seq.push(CrossModalToken {
                modality: ModalityKind::Sensor,
                token_idx: 0,
                embedding: Array1::zeros(shared_dim),
                confidence: 1.0,
            });
        }
        let mat = seq.to_embedding_matrix();
        assert_eq!(mat.shape(), [5, shared_dim]);
        let empty = CrossModalSequence::new(shared_dim);
        let empty_mat = empty.to_embedding_matrix();
        assert_eq!(empty_mat.shape(), [0, shared_dim]);
    }
    // add_modality registers matching token_dim configs and rejects a
    // config whose token_dim differs from shared_dim.
    #[test]
    fn test_cross_modal_tokenizer_add_modality() {
        let mut cmt = CrossModalTokenizer::new(32).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        })
        .expect("add audio");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 32,
            codebook_size: 32,
            num_stages: 1,
        })
        .expect("add control");
        assert_eq!(cmt.num_modalities(), 2);
        let names = cmt.modality_names();
        assert!(names.contains(&"audio".to_string()));
        assert!(names.contains(&"control".to_string()));
        let bad = cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 16, codebook_size: 32,
            num_stages: 1,
        });
        assert!(bad.is_err());
    }
    // tokenize produces an in-range index, a shared_dim embedding, a
    // (0, 1] confidence, and errors for an unregistered modality.
    #[test]
    fn test_cross_modal_tokenizer_tokenize() {
        let mut cmt = CrossModalTokenizer::new(64).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 128,
            num_stages: 1,
        })
        .expect("add audio");
        let input = ones(16);
        let token = cmt
            .tokenize(&ModalityKind::Audio, &input)
            .expect("tokenize");
        assert_eq!(token.modality, ModalityKind::Audio);
        assert!(token.token_idx < 128);
        assert_eq!(token.embedding.len(), 64);
        assert!(token.confidence > 0.0 && token.confidence <= 1.0);
        assert!(cmt.tokenize(&ModalityKind::Video, &ones(512)).is_err());
    }
    // tokenize_batch preserves order/length and supports mixed modalities.
    #[test]
    fn test_cross_modal_tokenizer_batch() {
        let mut cmt = CrossModalTokenizer::new(64).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Audio,
            input_dim: 16,
            token_dim: 64,
            codebook_size: 128,
            num_stages: 1,
        })
        .expect("add audio");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Control,
            input_dim: 6,
            token_dim: 64,
            codebook_size: 64,
            num_stages: 1,
        })
        .expect("add control");
        let inputs = vec![
            (ModalityKind::Audio, ones(16)),
            (ModalityKind::Control, zeros(6)),
            (ModalityKind::Audio, zeros(16)),
        ];
        let seq = cmt.tokenize_batch(&inputs).expect("batch");
        assert_eq!(seq.len(), 3);
        assert_eq!(seq.shared_dim, 64);
        let mat = seq.to_embedding_matrix();
        assert_eq!(mat.shape(), [3, 64]);
        let audio_tokens = seq.filter_by_modality(&ModalityKind::Audio);
        assert_eq!(audio_tokens.len(), 2);
    }
    // decode reconstructs to input_dim and rejects unregistered modalities.
    #[test]
    fn test_cross_modal_tokenizer_decode() {
        let mut cmt = CrossModalTokenizer::new(32).expect("new");
        cmt.add_modality(ModalityTokenizerConfig {
            modality: ModalityKind::Sensor,
            input_dim: 9,
            token_dim: 32,
            codebook_size: 64,
            num_stages: 1,
        })
        .expect("add sensor");
        let input = ones(9);
        let token = cmt
            .tokenize(&ModalityKind::Sensor, &input)
            .expect("tokenize");
        let reconstructed = cmt.decode(&token).expect("decode");
        assert_eq!(reconstructed.len(), 9, "decoded must match input_dim");
        let bad_token = CrossModalToken {
            modality: ModalityKind::Video,
            token_idx: 0,
            embedding: Array1::zeros(32),
            confidence: 1.0,
        };
        assert!(cmt.decode(&bad_token).is_err());
    }
    // The robotics preset registers audio/control/sensor with the
    // documented dimensions and tokenizes each modality end to end.
    #[test]
    fn test_cross_modal_robotics_preset() {
        let cmt = CrossModalTokenizer::robotics_preset().expect("robotics preset");
        assert_eq!(cmt.shared_dim(), 64);
        assert_eq!(cmt.num_modalities(), 3);
        let names = cmt.modality_names();
        assert!(names.contains(&"audio".to_string()));
        assert!(names.contains(&"control".to_string()));
        assert!(names.contains(&"sensor".to_string()));
        let audio_token = cmt
            .tokenize(&ModalityKind::Audio, &ones(16))
            .expect("audio tokenize");
        assert_eq!(audio_token.embedding.len(), 64);
        let control_token = cmt
            .tokenize(&ModalityKind::Control, &zeros(6))
            .expect("control tokenize");
        assert!(control_token.token_idx < 256);
        let sensor_token = cmt
            .tokenize(&ModalityKind::Sensor, &ones(9))
            .expect("sensor tokenize");
        assert!(sensor_token.confidence > 0.0);
        let inputs = vec![
            (ModalityKind::Audio, ones(16)),
            (ModalityKind::Control, zeros(6)),
            (ModalityKind::Sensor, ones(9)),
        ];
        let seq = cmt.tokenize_batch(&inputs).expect("batch");
        assert_eq!(seq.len(), 3);
    }
    // Aligner buffers tokens, tracks per-modality counts, and resets on flush.
    #[test]
    fn test_cross_modal_aligner() {
        let mut aligner = CrossModalAligner::new(64);
        assert!(aligner.is_empty());
        aligner.push_token(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 0,
            embedding: Array1::zeros(64),
            confidence: 0.9,
        });
        aligner.push_token(CrossModalToken {
            modality: ModalityKind::Control,
            token_idx: 1,
            embedding: Array1::ones(64),
            confidence: 0.8,
        });
        aligner.push_token(CrossModalToken {
            modality: ModalityKind::Audio,
            token_idx: 2,
            embedding: Array1::zeros(64),
            confidence: 0.7,
        });
        assert_eq!(aligner.len(), 3);
        assert!(!aligner.is_empty());
        assert_eq!(aligner.count_for_modality(&ModalityKind::Audio), 2);
        assert_eq!(aligner.count_for_modality(&ModalityKind::Control), 1);
        assert_eq!(aligner.count_for_modality(&ModalityKind::Sensor), 0);
        let seq = aligner.flush();
        assert_eq!(seq.len(), 3);
        assert!(aligner.is_empty(), "buffer cleared after flush");
        assert_eq!(aligner.count_for_modality(&ModalityKind::Audio), 0);
        let mat = seq.to_embedding_matrix();
        assert_eq!(mat.shape(), [3, 64]);
    }
    // Keys are stable strings and seeds are deterministic and distinct
    // across modalities.
    #[test]
    fn test_modality_kind_key_and_seed() {
        assert_eq!(ModalityKind::Audio.key(), "audio");
        assert_eq!(ModalityKind::Control.key(), "control");
        assert_eq!(ModalityKind::Sensor.key(), "sensor");
        assert_eq!(ModalityKind::Video.key(), "video");
        assert_eq!(ModalityKind::Custom("robot".into()).key(), "custom_robot");
        assert_eq!(ModalityKind::Audio.seed(), ModalityKind::Audio.seed());
        assert_ne!(ModalityKind::Audio.seed(), ModalityKind::Control.seed());
    }
    // The audio/video preset registers both modalities with 256-d
    // shared embeddings and in-range token indices.
    #[test]
    fn test_audio_video_preset() {
        let cmt = CrossModalTokenizer::audio_video_preset().expect("audio_video preset");
        assert_eq!(cmt.shared_dim(), 256);
        assert_eq!(cmt.num_modalities(), 2);
        let audio_tok = cmt
            .tokenize(&ModalityKind::Audio, &ones(80))
            .expect("audio tokenize");
        assert_eq!(audio_tok.embedding.len(), 256);
        let video_tok = cmt
            .tokenize(&ModalityKind::Video, &ones(512))
            .expect("video tokenize");
        assert!(video_tok.token_idx < 2048);
    }
}