use crate::error::{TokenizerError, TokenizerResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Random};
use serde::{Deserialize, Serialize};
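/// Configuration for masked signal modeling (MSM) pretraining: masking ratio and
/// span length, input (`signal_dim`) and latent (`embed_dim`) sizes, plus the
/// learning rate, epoch count, and batch size.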
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MSMConfig {
pub mask_ratio: f32,
pub mask_length: usize,
pub signal_dim: usize,
pub embed_dim: usize,
pub learning_rate: f32,
pub epochs: usize,
pub batch_size: usize,
}
impl Default for MSMConfig {
fn default() -> Self {
Self {
mask_ratio: 0.15,
mask_length: 16,
signal_dim: 256,
embed_dim: 128,
learning_rate: 0.001,
epochs: 100,
batch_size: 32,
}
}
}
impl MSMConfig {
pub fn validate(&self) -> TokenizerResult<()> {
if !(0.0..=1.0).contains(&self.mask_ratio) {
return Err(TokenizerError::invalid_input(
"mask_ratio must be in [0.0, 1.0]",
"MSMConfig::validate",
));
}
if self.mask_length == 0 {
return Err(TokenizerError::invalid_input(
"mask_length must be positive",
"MSMConfig::validate",
));
}
if self.signal_dim == 0 || self.embed_dim == 0 {
return Err(TokenizerError::invalid_input(
"signal_dim and embed_dim must be positive",
"MSMConfig::validate",
));
}
if self.learning_rate <= 0.0 || self.learning_rate >= 1.0 {
return Err(TokenizerError::invalid_input(
"learning_rate must be in (0.0, 1.0)",
"MSMConfig::validate",
));
}
if self.epochs == 0 || self.batch_size == 0 {
return Err(TokenizerError::invalid_input(
"epochs and batch_size must be positive",
"MSMConfig::validate",
));
}
Ok(())
}
}
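/// Masked signal modeling pretrainer: a single-hidden-layer autoencoder
/// (linear encoder with ReLU, linear decoder) trained to reconstruct masked
/// spans of a signal from the unmasked context.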
#[derive(Debug)]
pub struct MaskedSignalModeling {
config: MSMConfig,
encoder: Array2<f32>,
decoder: Array2<f32>,
rng: Random<StdRng>,
}
impl MaskedSignalModeling {
pub fn new(config: MSMConfig) -> TokenizerResult<Self> {
config.validate()?;
let mut rng = Random::seed(45);
let encoder_scale = (2.0 / (config.signal_dim + config.embed_dim) as f32).sqrt();
let decoder_scale = (2.0 / (config.embed_dim + config.signal_dim) as f32).sqrt();
let encoder =
Self::init_weights(config.signal_dim, config.embed_dim, encoder_scale, &mut rng);
let decoder =
Self::init_weights(config.embed_dim, config.signal_dim, decoder_scale, &mut rng);
Ok(Self {
config,
encoder,
decoder,
rng,
})
}
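/// Initializes a weight matrix with uniform values in ±`scale`; callers pass a
/// Xavier-style scale of sqrt(2 / (fan_in + fan_out)).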
fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
let mut weights = Array2::zeros((rows, cols));
for val in weights.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * scale;
}
weights
}
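/// Builds a boolean mask covering roughly `mask_ratio` of the signal with
/// contiguous spans of `mask_length` samples; spans may overlap, so the
/// realized ratio can fall below the target.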
fn create_mask(&mut self, signal_len: usize) -> Array1<bool> {
let mut mask = Array1::from_elem(signal_len, false);
let num_masks = ((signal_len as f32 * self.config.mask_ratio)
/ self.config.mask_length as f32) as usize;
// Guard against underflow when the signal is shorter than one mask span.
let max_start = signal_len.saturating_sub(self.config.mask_length);
for _ in 0..num_masks {
let start = (self.rng.gen_range(0.0..1.0) * max_start as f32) as usize;
let end = (start + self.config.mask_length).min(signal_len);
for i in start..end {
mask[i] = true;
}
}
mask
}
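/// Zeroes out the masked positions of `signal`, leaving the rest untouched.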
fn apply_mask(&self, signal: &Array1<f32>, mask: &Array1<bool>) -> Array1<f32> {
signal
.iter()
.zip(mask.iter())
.map(|(&val, &is_masked)| if is_masked { 0.0 } else { val })
.collect()
}
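/// Encodes the (masked) signal into a ReLU-activated embedding, then linearly
/// decodes it back to `signal_dim` samples.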
fn forward(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
let mut embedding = Array1::zeros(self.config.embed_dim);
for j in 0..self.config.embed_dim {
let mut sum = 0.0;
for i in 0..self.config.signal_dim.min(signal.len()) {
sum += signal[i] * self.encoder[[i, j]];
}
embedding[j] = sum;
}
embedding.mapv_inplace(|x| x.max(0.0));
let mut reconstructed = Array1::zeros(self.config.signal_dim);
for i in 0..self.config.signal_dim {
let mut sum = 0.0;
for j in 0..self.config.embed_dim {
sum += embedding[j] * self.decoder[[j, i]];
}
reconstructed[i] = sum;
}
Ok(reconstructed)
}
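/// Mean squared reconstruction error computed over the masked positions only.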
fn compute_loss(
&self,
target: &Array1<f32>,
prediction: &Array1<f32>,
mask: &Array1<bool>,
) -> f32 {
let mut loss = 0.0;
let mut count = 0;
for i in 0..target.len().min(prediction.len()).min(mask.len()) {
if mask[i] {
let diff = target[i] - prediction[i];
loss += diff * diff;
count += 1;
}
}
if count > 0 {
loss / count as f32
} else {
0.0
}
}
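/// Runs masked-reconstruction pretraining for `num_epochs` epochs and returns
/// the mean loss per epoch. Signals whose length differs from `signal_dim`
/// are skipped.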
pub fn pretrain(
&mut self,
signals: &[Array1<f32>],
num_epochs: usize,
) -> TokenizerResult<Vec<f32>> {
let mut losses = Vec::new();
for epoch in 0..num_epochs {
let mut epoch_loss = 0.0;
let mut num_batches = 0;
for signal in signals {
if signal.len() != self.config.signal_dim {
continue;
}
let mask = self.create_mask(signal.len());
let masked_signal = self.apply_mask(signal, &mask);
let reconstructed = self.forward(&masked_signal)?;
let loss = self.compute_loss(signal, &reconstructed, &mask);
epoch_loss += loss;
num_batches += 1;
self.update_weights(signal, &masked_signal, &reconstructed, &mask)?;
}
if num_batches > 0 {
epoch_loss /= num_batches as f32;
losses.push(epoch_loss);
if epoch % 10 == 0 {
tracing::debug!("Epoch {}: Loss = {:.6}", epoch, epoch_loss);
}
}
}
Ok(losses)
}
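/// One SGD step: backpropagates the masked reconstruction error through the
/// decoder and the ReLU-gated encoder, updating both weight matrices in place.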
fn update_weights(
&mut self,
target: &Array1<f32>,
input: &Array1<f32>,
output: &Array1<f32>,
mask: &Array1<bool>,
) -> TokenizerResult<()> {
let lr = self.config.learning_rate;
let mut output_error = Array1::zeros(self.config.signal_dim);
for i in 0..self.config.signal_dim.min(output.len()).min(target.len()) {
if i < mask.len() && mask[i] {
output_error[i] = output[i] - target[i];
}
}
let mut embedding = Array1::zeros(self.config.embed_dim);
for j in 0..self.config.embed_dim {
let mut sum = 0.0;
for i in 0..self.config.signal_dim.min(input.len()) {
sum += input[i] * self.encoder[[i, j]];
}
embedding[j] = sum.max(0.0);
}
for j in 0..self.config.embed_dim {
for i in 0..self.config.signal_dim {
let gradient = output_error[i] * embedding[j];
self.decoder[[j, i]] -= lr * gradient;
}
}
let mut hidden_error = Array1::zeros(self.config.embed_dim);
for j in 0..self.config.embed_dim {
let mut sum = 0.0;
for i in 0..self.config.signal_dim {
sum += output_error[i] * self.decoder[[j, i]];
}
hidden_error[j] = if embedding[j] > 0.0 { sum } else { 0.0 };
}
for i in 0..self.config.signal_dim.min(input.len()) {
for j in 0..self.config.embed_dim {
let gradient = hidden_error[j] * input[i];
self.encoder[[i, j]] -= lr * gradient;
}
}
Ok(())
}
pub fn encoder_weights(&self) -> &Array2<f32> {
&self.encoder
}
pub fn decoder_weights(&self) -> &Array2<f32> {
&self.decoder
}
}
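/// Configuration for contrastive pretraining: embedding size, softmax
/// temperature, scale of the additive-noise augmentation, and the number of
/// negatives sampled per anchor.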
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContrastiveConfig {
pub embed_dim: usize,
pub temperature: f32,
pub aug_noise_std: f32,
pub learning_rate: f32,
pub num_negatives: usize,
}
impl Default for ContrastiveConfig {
fn default() -> Self {
Self {
embed_dim: 128,
temperature: 0.07,
aug_noise_std: 0.1,
learning_rate: 0.001,
num_negatives: 16,
}
}
}
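/// Contrastive learner: a linear encoder producing L2-normalized embeddings,
/// scored with an InfoNCE-style loss over noise-augmented views.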
#[derive(Debug)]
pub struct ContrastiveLearning {
config: ContrastiveConfig,
encoder: Array2<f32>,
rng: Random<StdRng>,
}
impl ContrastiveLearning {
pub fn new(signal_dim: usize, config: ContrastiveConfig) -> Self {
let mut rng = Random::seed(46);
let scale = (2.0 / (signal_dim + config.embed_dim) as f32).sqrt();
let mut encoder = Array2::zeros((signal_dim, config.embed_dim));
for val in encoder.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * scale;
}
Self {
config,
encoder,
rng,
}
}
fn augment(&mut self, signal: &Array1<f32>) -> Array1<f32> {
signal.mapv(|x| {
let noise = (self.rng.gen_range(-1.0..1.0)) * self.config.aug_noise_std;
x + noise
})
}
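/// Projects the signal through the linear encoder and L2-normalizes the result,
/// so the dot product in `cosine_similarity` equals the cosine similarity.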
fn encode(&self, signal: &Array1<f32>) -> Array1<f32> {
let mut embedding = Array1::zeros(self.config.embed_dim);
for j in 0..self.config.embed_dim {
let mut sum = 0.0;
for i in 0..signal.len().min(self.encoder.nrows()) {
sum += signal[i] * self.encoder[[i, j]];
}
embedding[j] = sum;
}
let norm = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
embedding /= norm;
}
embedding
}
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
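/// InfoNCE-style loss: two augmented views of each signal form the positive
/// pair, and up to `num_negatives` other signals supply negatives. Returns the
/// mean loss over the batch, or 0.0 when fewer than two signals are given.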
pub fn contrastive_loss(&mut self, signals: &[Array1<f32>]) -> TokenizerResult<f32> {
if signals.len() < 2 {
return Ok(0.0);
}
let mut total_loss = 0.0;
let mut count = 0;
for i in 0..signals.len() {
let view1 = self.augment(&signals[i]);
let view2 = self.augment(&signals[i]);
let z1 = self.encode(&view1);
let z2 = self.encode(&view2);
let pos_sim = self.cosine_similarity(&z1, &z2) / self.config.temperature;
let mut neg_sims = Vec::new();
for (j, signal) in signals.iter().enumerate() {
if i != j {
let neg_view = self.augment(signal);
let z_neg = self.encode(&neg_view);
let neg_sim = self.cosine_similarity(&z1, &z_neg) / self.config.temperature;
neg_sims.push(neg_sim);
if neg_sims.len() >= self.config.num_negatives {
break;
}
}
}
let pos_exp = pos_sim.exp();
let neg_sum: f32 = neg_sims.iter().map(|&x| x.exp()).sum();
let loss = -(pos_exp / (pos_exp + neg_sum)).ln();
total_loss += loss;
count += 1;
}
Ok(if count > 0 {
total_loss / count as f32
} else {
0.0
})
}
pub fn encoder_weights(&self) -> &Array2<f32> {
&self.encoder
}
}
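/// Configuration for the temporal prediction head: context window length,
/// number of predicted samples, embedding size, and learning rate.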
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalPredictionConfig {
pub context_size: usize,
pub prediction_size: usize,
pub embed_dim: usize,
pub learning_rate: f32,
}
impl Default for TemporalPredictionConfig {
fn default() -> Self {
Self {
context_size: 64,
prediction_size: 16,
embed_dim: 128,
learning_rate: 0.001,
}
}
}
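/// Temporal predictor: a ReLU-activated context encoder followed by a linear
/// head that predicts the next `prediction_size` samples from a fixed-length
/// context window.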
#[derive(Debug, Clone)]
pub struct TemporalPrediction {
config: TemporalPredictionConfig,
context_encoder: Array2<f32>,
prediction_head: Array2<f32>,
}
impl TemporalPrediction {
pub fn new(config: TemporalPredictionConfig) -> Self {
let mut rng = Random::seed(47);
let encoder_scale = (2.0 / (config.context_size + config.embed_dim) as f32).sqrt();
let head_scale = (2.0 / (config.embed_dim + config.prediction_size) as f32).sqrt();
let mut context_encoder = Array2::zeros((config.context_size, config.embed_dim));
let mut prediction_head = Array2::zeros((config.embed_dim, config.prediction_size));
for val in context_encoder.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * encoder_scale;
}
for val in prediction_head.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * head_scale;
}
Self {
config,
context_encoder,
prediction_head,
}
}
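/// Predicts the next `prediction_size` samples from a context of exactly
/// `context_size` samples; returns an error on a size mismatch.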
pub fn predict(&self, context: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
if context.len() != self.config.context_size {
return Err(TokenizerError::encoding(
format!(
"Context size mismatch: expected {}, got {}",
self.config.context_size,
context.len()
),
"TemporalPrediction::predict",
));
}
let mut embedding = Array1::zeros(self.config.embed_dim);
for j in 0..self.config.embed_dim {
let mut sum = 0.0;
for i in 0..self.config.context_size {
sum += context[i] * self.context_encoder[[i, j]];
}
embedding[j] = sum.max(0.0);
}
let mut prediction = Array1::zeros(self.config.prediction_size);
for i in 0..self.config.prediction_size {
let mut sum = 0.0;
for j in 0..self.config.embed_dim {
sum += embedding[j] * self.prediction_head[[j, i]];
}
prediction[i] = sum;
}
Ok(prediction)
}
pub fn context_encoder_weights(&self) -> &Array2<f32> {
&self.context_encoder
}
pub fn prediction_head_weights(&self) -> &Array2<f32> {
&self.prediction_head
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_msm_config_validation() {
let config = MSMConfig::default();
assert!(config.validate().is_ok());
let mut bad_config = config.clone();
bad_config.mask_ratio = 1.5;
assert!(bad_config.validate().is_err());
let mut bad_config = config.clone();
bad_config.learning_rate = 1.5;
assert!(bad_config.validate().is_err());
}
#[test]
fn test_msm_creation() {
let config = MSMConfig::default();
let msm = MaskedSignalModeling::new(config);
assert!(msm.is_ok());
}
#[test]
fn test_msm_create_mask() {
let config = MSMConfig {
mask_ratio: 0.2,
mask_length: 10,
..Default::default()
};
let mut msm = MaskedSignalModeling::new(config).unwrap();
let mask = msm.create_mask(100);
assert_eq!(mask.len(), 100);
let num_masked = mask.iter().filter(|&&x| x).count();
assert!(num_masked > 0 && num_masked < 100);
}
#[test]
fn test_msm_apply_mask() {
let config = MSMConfig::default();
let msm = MaskedSignalModeling::new(config).unwrap();
let signal = Array1::linspace(0.0, 1.0, 100);
let mask = Array1::from_vec(vec![false; 50].into_iter().chain(vec![true; 50]).collect());
let masked = msm.apply_mask(&signal, &mask);
assert_eq!(masked.len(), 100);
for i in 0..50 {
assert!((masked[i] - signal[i]).abs() < 1e-6);
}
for i in 50..100 {
assert_eq!(masked[i], 0.0);
}
}
#[test]
fn test_msm_forward() {
let config = MSMConfig {
signal_dim: 64,
embed_dim: 32,
..Default::default()
};
let msm = MaskedSignalModeling::new(config).unwrap();
let signal = Array1::linspace(0.0, 1.0, 64);
let reconstructed = msm.forward(&signal);
assert!(reconstructed.is_ok());
let reconstructed = reconstructed.unwrap();
assert_eq!(reconstructed.len(), 64);
}
#[test]
fn test_msm_pretrain() {
let config = MSMConfig {
signal_dim: 32,
embed_dim: 16,
epochs: 5,
..Default::default()
};
let mut msm = MaskedSignalModeling::new(config).unwrap();
let signals: Vec<Array1<f32>> = (0..10)
.map(|i| Array1::linspace(i as f32, (i + 1) as f32, 32))
.collect();
let losses = msm.pretrain(&signals, 5);
assert!(losses.is_ok());
let losses = losses.unwrap();
assert_eq!(losses.len(), 5);
assert!(losses[4] <= losses[0] * 1.5);
}
#[test]
fn test_contrastive_learning_creation() {
let config = ContrastiveConfig::default();
let cl = ContrastiveLearning::new(128, config);
assert_eq!(cl.encoder.nrows(), 128);
}
#[test]
fn test_contrastive_augment() {
let config = ContrastiveConfig {
aug_noise_std: 0.1,
..Default::default()
};
let mut cl = ContrastiveLearning::new(64, config);
let signal = Array1::zeros(64);
let augmented = cl.augment(&signal);
assert_eq!(augmented.len(), 64);
let has_noise = augmented.iter().any(|&x| x != 0.0);
assert!(has_noise);
}
#[test]
fn test_contrastive_encode() {
let config = ContrastiveConfig::default();
let cl = ContrastiveLearning::new(64, config);
let signal = Array1::linspace(0.0, 1.0, 64);
let embedding = cl.encode(&signal);
assert_eq!(embedding.len(), cl.config.embed_dim);
let norm: f32 = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5);
}
#[test]
fn test_contrastive_loss() {
let config = ContrastiveConfig {
num_negatives: 2,
..Default::default()
};
let mut cl = ContrastiveLearning::new(32, config);
let signals: Vec<Array1<f32>> = (0..5)
.map(|i| Array1::linspace(i as f32, (i + 1) as f32, 32))
.collect();
let loss = cl.contrastive_loss(&signals);
assert!(loss.is_ok());
let loss = loss.unwrap();
assert!(loss.is_finite() && loss >= 0.0);
}
#[test]
fn test_temporal_prediction_creation() {
let config = TemporalPredictionConfig::default();
let tp = TemporalPrediction::new(config);
assert_eq!(tp.context_encoder.nrows(), tp.config.context_size);
}
#[test]
fn test_temporal_prediction_predict() {
let config = TemporalPredictionConfig {
context_size: 32,
prediction_size: 8,
embed_dim: 16,
..Default::default()
};
let tp = TemporalPrediction::new(config);
let context = Array1::linspace(0.0, 1.0, 32);
let prediction = tp.predict(&context);
assert!(prediction.is_ok());
let prediction = prediction.unwrap();
assert_eq!(prediction.len(), 8);
}
#[test]
fn test_temporal_prediction_wrong_context_size() {
let config = TemporalPredictionConfig {
context_size: 32,
..Default::default()
};
let tp = TemporalPrediction::new(config);
let wrong_context = Array1::linspace(0.0, 1.0, 16);
let prediction = tp.predict(&wrong_context);
assert!(prediction.is_err());
}
}