use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;
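/// Configuration of a single resolution level in a [`MultiScaleTokenizer`]:
/// the downsampling factor applied to the input, the embedding width produced
/// at this level, and the downsampled input length the level's projections expect.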
#[derive(Debug, Clone)]
pub struct ScaleLevel {
downsample_factor: usize,
embed_dim: usize,
input_dim: usize,
}
impl ScaleLevel {
pub fn new(downsample_factor: usize, embed_dim: usize, input_dim: usize) -> Self {
Self {
downsample_factor,
embed_dim,
input_dim,
}
}
}
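/// Multi-resolution signal tokenizer.
///
/// Each scale level downsamples the input by its factor, projects the result
/// through a per-level linear encoder, and the per-level embeddings are
/// concatenated into a single token vector. Decoding projects each embedding
/// back through the matching linear decoder, upsamples it to the original
/// length, and averages the per-level reconstructions.
///
/// A minimal usage sketch, mirroring `test_multiscale_basic` below (marked
/// `ignore` because the import path depends on how the crate exposes this
/// module):
///
/// ```ignore
/// let tokenizer = MultiScaleTokenizer::new(64, 16);
/// assert_eq!(tokenizer.num_levels(), 3);       // default factors [1, 2, 4]
/// assert_eq!(tokenizer.total_embed_dim(), 48); // 3 levels x 16 dims
/// let signal = Array1::from_vec((0..64).map(|i| (i as f32 * 0.1).sin()).collect());
/// let encoded = tokenizer.encode(&signal).unwrap();  // 48-dim concatenated embedding
/// let decoded = tokenizer.decode(&encoded).unwrap(); // 64-sample reconstruction
/// ```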
#[derive(Debug, Clone)]
pub struct MultiScaleTokenizer {
encoders: Vec<Array2<f32>>,
decoders: Vec<Array2<f32>>,
levels: Vec<ScaleLevel>,
input_dim: usize,
pool_method: PoolMethod,
upsample_method: UpsampleMethod,
}
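/// Pooling strategy used when downsampling a signal by an integer factor.
///
/// * `Stride`  - keep every `factor`-th sample.
/// * `Average` - mean of each window of `factor` samples (default).
/// * `Max`     - maximum of each window of `factor` samples.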
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PoolMethod {
Stride,
#[default]
Average,
Max,
}
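/// Interpolation strategy used when upsampling a reconstruction back to the
/// original signal length.
///
/// * `Repeat` - repeat each sample `factor` times (nearest-neighbour).
/// * `Linear` - linear interpolation between neighbouring samples (default).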
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UpsampleMethod {
Repeat,
#[default]
Linear,
}
impl MultiScaleTokenizer {
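/// Creates a tokenizer with the default downsampling factors `[1, 2, 4]`.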
pub fn new(input_dim: usize, embed_dim_per_level: usize) -> Self {
Self::with_factors(input_dim, embed_dim_per_level, &[1, 2, 4])
}
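/// Creates a tokenizer with explicit downsampling factors.
///
/// Each factor adds one scale level with its own randomly initialized linear
/// encoder/decoder pair; factors larger than `input_dim` (downsampled length
/// of zero) are skipped.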
pub fn with_factors(input_dim: usize, embed_dim_per_level: usize, factors: &[usize]) -> Self {
let mut rng = thread_rng();
let mut encoders = Vec::with_capacity(factors.len());
let mut decoders = Vec::with_capacity(factors.len());
let mut levels = Vec::with_capacity(factors.len());
for &factor in factors {
let level_input_dim = input_dim / factor;
if level_input_dim == 0 {
continue;
}
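// Xavier-style uniform initialization: weights drawn from [-s, s] with
// s = sqrt(2 / (fan_in + fan_out)); the decoder below uses the same scheme.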
let enc_scale = (2.0 / (level_input_dim + embed_dim_per_level) as f32).sqrt();
let encoder = Array2::from_shape_fn((level_input_dim, embed_dim_per_level), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * enc_scale
});
let dec_scale = (2.0 / (embed_dim_per_level + level_input_dim) as f32).sqrt();
let decoder = Array2::from_shape_fn((embed_dim_per_level, level_input_dim), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * dec_scale
});
encoders.push(encoder);
decoders.push(decoder);
levels.push(ScaleLevel::new(
factor,
embed_dim_per_level,
level_input_dim,
));
}
Self {
encoders,
decoders,
levels,
input_dim,
pool_method: PoolMethod::default(),
upsample_method: UpsampleMethod::default(),
}
}
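/// Sets the pooling method used when downsampling (builder-style).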
pub fn with_pool_method(mut self, method: PoolMethod) -> Self {
self.pool_method = method;
self
}
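/// Sets the interpolation method used when upsampling (builder-style).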
pub fn with_upsample_method(mut self, method: UpsampleMethod) -> Self {
self.upsample_method = method;
self
}
pub fn num_levels(&self) -> usize {
self.levels.len()
}
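/// Total length of the concatenated embedding across all levels.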
pub fn total_embed_dim(&self) -> usize {
self.levels.iter().map(|l| l.embed_dim).sum()
}
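/// Downsamples `signal` by `factor` using the configured pooling method.
/// A factor of 0 or 1 returns the signal unchanged.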
fn downsample(&self, signal: &Array1<f32>, factor: usize) -> Array1<f32> {
if factor <= 1 {
return signal.clone();
}
let new_len = signal.len() / factor;
if new_len == 0 {
return Array1::zeros(1);
}
match self.pool_method {
PoolMethod::Stride => {
Array1::from_vec((0..new_len).map(|i| signal[i * factor]).collect())
}
PoolMethod::Average => Array1::from_vec(
(0..new_len)
.map(|i| {
let start = i * factor;
let end = (start + factor).min(signal.len());
signal.iter().skip(start).take(end - start).sum::<f32>()
/ (end - start) as f32
})
.collect(),
),
PoolMethod::Max => Array1::from_vec(
(0..new_len)
.map(|i| {
let start = i * factor;
let end = (start + factor).min(signal.len());
signal
.iter()
.skip(start)
.take(end - start)
.cloned()
.fold(f32::NEG_INFINITY, f32::max)
})
.collect(),
),
}
}
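/// Upsamples `signal` by `factor` to `target_len` samples using the configured
/// interpolation method. A factor of 0 or 1 returns the signal unchanged.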
fn upsample(&self, signal: &Array1<f32>, factor: usize, target_len: usize) -> Array1<f32> {
if factor <= 1 {
return signal.clone();
}
match self.upsample_method {
UpsampleMethod::Repeat => {
let mut result = Vec::with_capacity(target_len);
for &val in signal.iter() {
for _ in 0..factor {
if result.len() < target_len {
result.push(val);
}
}
}
while result.len() < target_len {
result.push(*signal.last().unwrap_or(&0.0));
}
Array1::from_vec(result)
}
UpsampleMethod::Linear => {
if signal.len() < 2 {
return Array1::from_elem(target_len, signal.first().copied().unwrap_or(0.0));
}
let mut result = Vec::with_capacity(target_len);
for i in 0..target_len {
let src_pos = i as f32 / factor as f32;
let src_idx = src_pos.floor() as usize;
let t = src_pos - src_idx as f32;
let val = if src_idx + 1 < signal.len() {
signal[src_idx] * (1.0 - t) + signal[src_idx + 1] * t
} else {
signal[signal.len() - 1]
};
result.push(val);
}
Array1::from_vec(result)
}
}
}
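/// Encodes `signal` at a single scale level: downsample by the level's factor,
/// zero-pad or truncate to the level's expected input length, then apply the
/// level's linear encoder.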
pub fn encode_level(&self, signal: &Array1<f32>, level: usize) -> TokenizerResult<Array1<f32>> {
if level >= self.levels.len() {
return Err(TokenizerError::InvalidConfig(format!(
"Level {} out of range (0..{})",
level,
self.levels.len()
)));
}
let factor = self.levels[level].downsample_factor;
let downsampled = self.downsample(signal, factor);
if downsampled.len() != self.levels[level].input_dim {
let mut resized = Array1::zeros(self.levels[level].input_dim);
for i in 0..resized.len().min(downsampled.len()) {
resized[i] = downsampled[i];
}
return Ok(resized.dot(&self.encoders[level]));
}
Ok(downsampled.dot(&self.encoders[level]))
}
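/// Decodes one level's embedding back to the full input length: apply the
/// level's linear decoder, then upsample by the level's factor.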
pub fn decode_level(
&self,
embedding: &Array1<f32>,
level: usize,
) -> TokenizerResult<Array1<f32>> {
if level >= self.levels.len() {
return Err(TokenizerError::InvalidConfig(format!(
"Level {} out of range (0..{})",
level,
self.levels.len()
)));
}
if embedding.len() != self.levels[level].embed_dim {
return Err(TokenizerError::dim_mismatch(
self.levels[level].embed_dim,
embedding.len(),
"dimension validation",
));
}
let decoded = embedding.dot(&self.decoders[level]);
let factor = self.levels[level].downsample_factor;
Ok(self.upsample(&decoded, factor, self.input_dim))
}
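/// Encodes `signal` at every level, returning one embedding per level.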
pub fn encode_all(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
let mut embeddings = Vec::with_capacity(self.levels.len());
for level in 0..self.levels.len() {
embeddings.push(self.encode_level(signal, level)?);
}
Ok(embeddings)
}
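/// Decodes one embedding per level and averages the per-level reconstructions
/// with equal weight.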
pub fn decode_all(&self, embeddings: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
if embeddings.len() != self.levels.len() {
return Err(TokenizerError::InvalidConfig(format!(
"Expected {} embeddings, got {}",
self.levels.len(),
embeddings.len()
)));
}
let mut result = Array1::zeros(self.input_dim);
let weight = 1.0 / self.levels.len() as f32;
for (level, embedding) in embeddings.iter().enumerate() {
let decoded = self.decode_level(embedding, level)?;
result = &result + &(&decoded * weight);
}
Ok(result)
}
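/// Encodes `signal` at every level and concatenates the per-level embeddings
/// into a single vector of length `total_embed_dim()`.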
pub fn encode_concat(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
let embeddings = self.encode_all(signal)?;
let total_len: usize = embeddings.iter().map(|e| e.len()).sum();
let mut result = Vec::with_capacity(total_len);
for emb in embeddings {
result.extend(emb.iter());
}
Ok(Array1::from_vec(result))
}
}
impl SignalTokenizer for MultiScaleTokenizer {
fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
if signal.len() != self.input_dim {
return Err(TokenizerError::dim_mismatch(
self.input_dim,
signal.len(),
"dimension validation",
));
}
self.encode_concat(signal)
}
fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
if tokens.len() != self.total_embed_dim() {
return Err(TokenizerError::dim_mismatch(
self.total_embed_dim(),
tokens.len(),
"dimension validation",
));
}
let mut embeddings = Vec::with_capacity(self.levels.len());
let mut offset = 0;
for level in &self.levels {
let end = offset + level.embed_dim;
let embedding: Array1<f32> = Array1::from_vec(
tokens
.iter()
.skip(offset)
.take(level.embed_dim)
.cloned()
.collect(),
);
embeddings.push(embedding);
offset = end;
}
self.decode_all(&embeddings)
}
fn embed_dim(&self) -> usize {
self.total_embed_dim()
}
fn vocab_size(&self) -> usize {
0
}
}
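/// Residual pyramid tokenizer built on top of [`MultiScaleTokenizer`], using
/// power-of-two downsampling factors `[1, 2, 4, ...]`.
///
/// With residual coding enabled (the default), the finest level encodes the
/// raw signal and every coarser level encodes only what the previous levels
/// failed to reconstruct; decoding then sums the per-level reconstructions.
///
/// A minimal usage sketch, mirroring `test_pyramid_tokenizer` below (marked
/// `ignore` for the same reason as the [`MultiScaleTokenizer`] example):
///
/// ```ignore
/// let tokenizer = PyramidTokenizer::new(64, 16, 3); // factors [1, 2, 4]
/// let signal = Array1::from_vec((0..64).map(|i| (i as f32 * 0.1).sin()).collect());
/// let embeddings = tokenizer.encode_pyramid(&signal).unwrap();  // one embedding per level
/// let decoded = tokenizer.decode_pyramid(&embeddings).unwrap(); // 64-sample reconstruction
/// ```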
#[derive(Debug, Clone)]
pub struct PyramidTokenizer {
inner: MultiScaleTokenizer,
use_residual: bool,
}
impl PyramidTokenizer {
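/// Builds a pyramid with `num_levels` levels and power-of-two downsampling
/// factors `[1, 2, 4, ...]`, with residual coding enabled.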
pub fn new(input_dim: usize, embed_dim_per_level: usize, num_levels: usize) -> Self {
let factors: Vec<usize> = (0..num_levels).map(|i| 1 << i).collect();
let inner = MultiScaleTokenizer::with_factors(input_dim, embed_dim_per_level, &factors);
Self {
inner,
use_residual: true,
}
}
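/// Disables residual coding; each level then encodes the raw signal
/// independently, as in [`MultiScaleTokenizer`].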
pub fn without_residual(mut self) -> Self {
self.use_residual = false;
self
}
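/// Encodes `signal` level by level. With residual coding, each level encodes
/// the current residual and its reconstruction is subtracted before moving on
/// to the next (coarser) level.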
pub fn encode_pyramid(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
if !self.use_residual {
return self.inner.encode_all(signal);
}
let mut embeddings = Vec::with_capacity(self.inner.num_levels());
let mut residual = signal.clone();
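// Each level encodes the running residual; its reconstruction is then
// subtracted so that later (coarser) levels only model what remains.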
for level in 0..self.inner.num_levels() {
let embedding = self.inner.encode_level(&residual, level)?;
embeddings.push(embedding.clone());
let reconstruction = self.inner.decode_level(&embedding, level)?;
residual = &residual - &reconstruction;
}
Ok(embeddings)
}
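/// Decodes per-level embeddings. With residual coding the level
/// reconstructions are summed; otherwise they are averaged via
/// [`MultiScaleTokenizer::decode_all`].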
pub fn decode_pyramid(&self, embeddings: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
if !self.use_residual {
return self.inner.decode_all(embeddings);
}
let mut result = Array1::zeros(self.inner.input_dim);
for (level, embedding) in embeddings.iter().enumerate() {
let decoded = self.inner.decode_level(embedding, level)?;
result = &result + &decoded;
}
Ok(result)
}
pub fn num_levels(&self) -> usize {
self.inner.num_levels()
}
}
impl SignalTokenizer for PyramidTokenizer {
fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
let embeddings = self.encode_pyramid(signal)?;
let total_len: usize = embeddings.iter().map(|e| e.len()).sum();
let mut result = Vec::with_capacity(total_len);
for emb in embeddings {
result.extend(emb.iter());
}
Ok(Array1::from_vec(result))
}
fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
let total_dim = self.inner.total_embed_dim();
if tokens.len() != total_dim {
return Err(TokenizerError::dim_mismatch(
total_dim,
tokens.len(),
"dimension validation",
));
}
let mut embeddings = Vec::new();
let mut offset = 0;
for level in &self.inner.levels {
let end = offset + level.embed_dim;
let embedding = Array1::from_vec(
tokens
.iter()
.skip(offset)
.take(level.embed_dim)
.cloned()
.collect(),
);
embeddings.push(embedding);
offset = end;
}
self.decode_pyramid(&embeddings)
}
fn embed_dim(&self) -> usize {
self.inner.total_embed_dim()
}
fn vocab_size(&self) -> usize {
0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_multiscale_basic() {
let tokenizer = MultiScaleTokenizer::new(64, 16);
assert_eq!(tokenizer.num_levels(), 3);
assert_eq!(tokenizer.total_embed_dim(), 48);
let signal = Array1::from_vec((0..64).map(|i| (i as f32 * 0.1).sin()).collect());
let encoded = tokenizer.encode(&signal).unwrap();
assert_eq!(encoded.len(), 48);
let decoded = tokenizer.decode(&encoded).unwrap();
assert_eq!(decoded.len(), 64);
}
#[test]
fn test_downsample_average() {
let tokenizer = MultiScaleTokenizer::new(8, 4);
let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
let down = tokenizer.downsample(&signal, 2);
assert_eq!(down.len(), 4);
assert!((down[0] - 1.5).abs() < 0.01);
assert!((down[1] - 3.5).abs() < 0.01);
}
#[test]
fn test_downsample_stride() {
let tokenizer = MultiScaleTokenizer::new(8, 4).with_pool_method(PoolMethod::Stride);
let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
let down = tokenizer.downsample(&signal, 2);
assert_eq!(down.len(), 4);
assert_eq!(down[0], 1.0);
assert_eq!(down[1], 3.0);
}
#[test]
fn test_upsample_repeat() {
let tokenizer = MultiScaleTokenizer::new(8, 4).with_upsample_method(UpsampleMethod::Repeat);
let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let up = tokenizer.upsample(&signal, 2, 8);
assert_eq!(up.len(), 8);
assert_eq!(up[0], 1.0);
assert_eq!(up[1], 1.0);
assert_eq!(up[2], 2.0);
assert_eq!(up[3], 2.0);
}
#[test]
fn test_upsample_linear() {
let tokenizer = MultiScaleTokenizer::new(8, 4).with_upsample_method(UpsampleMethod::Linear);
let signal = Array1::from_vec(vec![0.0, 2.0]);
let up = tokenizer.upsample(&signal, 4, 8);
assert_eq!(up.len(), 8);
assert!(up[0].abs() < 0.01);
assert!((up[2] - 1.0).abs() < 0.01);
}
#[test]
fn test_encode_level() {
let tokenizer = MultiScaleTokenizer::new(64, 16);
let signal = Array1::from_vec((0..64).map(|i| i as f32).collect());
let enc0 = tokenizer.encode_level(&signal, 0).unwrap();
assert_eq!(enc0.len(), 16);
let enc1 = tokenizer.encode_level(&signal, 1).unwrap();
assert_eq!(enc1.len(), 16);
let enc2 = tokenizer.encode_level(&signal, 2).unwrap();
assert_eq!(enc2.len(), 16);
}
#[test]
fn test_pyramid_tokenizer() {
let tokenizer = PyramidTokenizer::new(64, 16, 3);
assert_eq!(tokenizer.num_levels(), 3);
let signal = Array1::from_vec((0..64).map(|i| (i as f32 * 0.1).sin()).collect());
let embeddings = tokenizer.encode_pyramid(&signal).unwrap();
assert_eq!(embeddings.len(), 3);
let decoded = tokenizer.decode_pyramid(&embeddings).unwrap();
assert_eq!(decoded.len(), 64);
}
#[test]
fn test_pyramid_residual() {
let tokenizer = PyramidTokenizer::new(32, 8, 3);
let signal = Array1::from_vec((0..32).map(|i| (i as f32 * 0.2).sin()).collect());
let embeddings = tokenizer.encode_pyramid(&signal).unwrap();
let variances: Vec<f32> = embeddings
.iter()
.map(|e| {
let mean = e.sum() / e.len() as f32;
e.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / e.len() as f32
})
.collect();
assert!(variances[0] > 0.0);
}
#[test]
fn test_custom_factors() {
let tokenizer = MultiScaleTokenizer::with_factors(100, 10, &[1, 5, 10, 20]);
assert_eq!(tokenizer.num_levels(), 4);
let signal = Array1::from_vec((0..100).map(|i| i as f32).collect());
let encoded = tokenizer.encode(&signal).unwrap();
assert_eq!(encoded.len(), 40);
let decoded = tokenizer.decode(&encoded).unwrap();
assert_eq!(decoded.len(), 100);
}
}