//! Self-supervised pretraining objectives for signal data: masked signal
//! modeling, contrastive learning, and temporal prediction.

use crate::error::{TokenizerError, TokenizerResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Random};
use serde::{Deserialize, Serialize};

/// Configuration for masked signal modeling (MSM) pretraining.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MSMConfig {
    /// Fraction of the signal to mask, in `[0.0, 1.0]`.
    pub mask_ratio: f32,
    /// Length of each contiguous masked span.
    pub mask_length: usize,
    /// Dimensionality of the input signal.
    pub signal_dim: usize,
    /// Dimensionality of the latent embedding.
    pub embed_dim: usize,
    /// Learning rate for weight updates.
    pub learning_rate: f32,
    /// Number of training epochs.
    pub epochs: usize,
    /// Batch size used during pretraining.
    pub batch_size: usize,
}

impl Default for MSMConfig {
    fn default() -> Self {
        Self {
            mask_ratio: 0.15,
            mask_length: 16,
            signal_dim: 256,
            embed_dim: 128,
            learning_rate: 0.001,
            epochs: 100,
            batch_size: 32,
        }
    }
}

impl MSMConfig {
    /// Validates the configuration, returning an error for out-of-range values.
    pub fn validate(&self) -> TokenizerResult<()> {
        if !(0.0..=1.0).contains(&self.mask_ratio) {
            return Err(TokenizerError::invalid_input(
                "mask_ratio must be in [0.0, 1.0]",
                "MSMConfig::validate",
            ));
        }
        if self.mask_length == 0 {
            return Err(TokenizerError::invalid_input(
                "mask_length must be positive",
                "MSMConfig::validate",
            ));
        }
        if self.signal_dim == 0 || self.embed_dim == 0 {
            return Err(TokenizerError::invalid_input(
                "signal_dim and embed_dim must be positive",
                "MSMConfig::validate",
            ));
        }
        if self.learning_rate <= 0.0 || self.learning_rate >= 1.0 {
            return Err(TokenizerError::invalid_input(
                "learning_rate must be in (0.0, 1.0)",
                "MSMConfig::validate",
            ));
        }
        if self.epochs == 0 || self.batch_size == 0 {
            return Err(TokenizerError::invalid_input(
                "epochs and batch_size must be positive",
                "MSMConfig::validate",
            ));
        }
        Ok(())
    }
}

/// Masked signal modeling: a single-hidden-layer autoencoder that learns to
/// reconstruct masked regions of an input signal.
#[derive(Debug, Clone)]
pub struct MaskedSignalModeling {
    config: MSMConfig,
    /// Encoder weights, shape `(signal_dim, embed_dim)`.
    encoder: Array2<f32>,
    /// Decoder weights, shape `(embed_dim, signal_dim)`.
    decoder: Array2<f32>,
    rng: Random<StdRng>,
}

impl MaskedSignalModeling {
    /// Creates a new model with randomly initialized encoder/decoder weights.
    pub fn new(config: MSMConfig) -> TokenizerResult<Self> {
        config.validate()?;

        let mut rng = Random::seed(45);

        let encoder_scale = (2.0 / (config.signal_dim + config.embed_dim) as f32).sqrt();
        let decoder_scale = (2.0 / (config.embed_dim + config.signal_dim) as f32).sqrt();

        let encoder =
            Self::init_weights(config.signal_dim, config.embed_dim, encoder_scale, &mut rng);
        let decoder =
            Self::init_weights(config.embed_dim, config.signal_dim, decoder_scale, &mut rng);

        Ok(Self {
            config,
            encoder,
            decoder,
            rng,
        })
    }

    /// Initializes a weight matrix with small random values scaled by `scale`.
    fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
        let mut weights = Array2::zeros((rows, cols));
        for val in weights.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale;
        }
        weights
    }

    /// Creates a boolean mask covering approximately `mask_ratio` of the
    /// signal in contiguous spans of `mask_length` samples.
    fn create_mask(&mut self, signal_len: usize) -> Array1<bool> {
        let mut mask = Array1::from_elem(signal_len, false);
        let num_masks = ((signal_len as f32 * self.config.mask_ratio)
            / self.config.mask_length as f32) as usize;

        for _ in 0..num_masks {
            let start = (self.rng.gen_range(0.0..1.0)
                * (signal_len - self.config.mask_length) as f32) as usize;
            let end = (start + self.config.mask_length).min(signal_len);
            for i in start..end {
                mask[i] = true;
            }
        }

        mask
    }

    /// Zeroes out the masked positions of `signal`.
    fn apply_mask(&self, signal: &Array1<f32>, mask: &Array1<bool>) -> Array1<f32> {
        signal
            .iter()
            .zip(mask.iter())
            .map(|(&val, &is_masked)| if is_masked { 0.0 } else { val })
            .collect()
    }

    /// Encodes the (masked) signal and decodes it back to the signal space.
    fn forward(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        // Encode: linear projection followed by ReLU.
        let mut embedding = Array1::zeros(self.config.embed_dim);
        for j in 0..self.config.embed_dim {
            let mut sum = 0.0;
            for i in 0..self.config.signal_dim.min(signal.len()) {
                sum += signal[i] * self.encoder[[i, j]];
            }
            embedding[j] = sum;
        }

        embedding.mapv_inplace(|x| x.max(0.0));

        // Decode: linear projection back to the signal space.
        let mut reconstructed = Array1::zeros(self.config.signal_dim);
        for i in 0..self.config.signal_dim {
            let mut sum = 0.0;
            for j in 0..self.config.embed_dim {
                sum += embedding[j] * self.decoder[[j, i]];
            }
            reconstructed[i] = sum;
        }

        Ok(reconstructed)
    }

    /// Mean squared error computed over the masked positions only.
    fn compute_loss(
        &self,
        target: &Array1<f32>,
        prediction: &Array1<f32>,
        mask: &Array1<bool>,
    ) -> f32 {
        let mut loss = 0.0;
        let mut count = 0;

        for i in 0..target.len().min(prediction.len()).min(mask.len()) {
            if mask[i] {
                let diff = target[i] - prediction[i];
                loss += diff * diff;
                count += 1;
            }
        }

        if count > 0 {
            loss / count as f32
        } else {
            0.0
        }
    }

    /// Pretrains the model for `num_epochs` epochs and returns the average
    /// reconstruction loss for each epoch that processed at least one signal.
    pub fn pretrain(
        &mut self,
        signals: &[Array1<f32>],
        num_epochs: usize,
    ) -> TokenizerResult<Vec<f32>> {
        let mut losses = Vec::new();

        for epoch in 0..num_epochs {
            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            for signal in signals {
                // Skip signals whose length does not match the configured dimension.
                if signal.len() != self.config.signal_dim {
                    continue;
                }

                let mask = self.create_mask(signal.len());
                let masked_signal = self.apply_mask(signal, &mask);
                let reconstructed = self.forward(&masked_signal)?;

                let loss = self.compute_loss(signal, &reconstructed, &mask);
                epoch_loss += loss;
                num_batches += 1;

                self.update_weights(signal, &masked_signal, &reconstructed, &mask)?;
            }

            if num_batches > 0 {
                epoch_loss /= num_batches as f32;
                losses.push(epoch_loss);

                if epoch % 10 == 0 {
                    tracing::debug!("Epoch {}: Loss = {:.6}", epoch, epoch_loss);
                }
            }
        }

        Ok(losses)
    }

    /// Performs one gradient-descent update of the decoder and encoder weights
    /// from the masked reconstruction error.
    fn update_weights(
        &mut self,
        target: &Array1<f32>,
        input: &Array1<f32>,
        output: &Array1<f32>,
        mask: &Array1<bool>,
    ) -> TokenizerResult<()> {
        let lr = self.config.learning_rate;

        // Output error, restricted to masked positions.
        let mut output_error = Array1::zeros(self.config.signal_dim);
        for i in 0..self.config.signal_dim.min(output.len()).min(target.len()) {
            if i < mask.len() && mask[i] {
                output_error[i] = output[i] - target[i];
            }
        }

        // Recompute the hidden (post-ReLU) embedding for the masked input.
        let mut embedding = Array1::zeros(self.config.embed_dim);
        for j in 0..self.config.embed_dim {
            let mut sum = 0.0;
            for i in 0..self.config.signal_dim.min(input.len()) {
                sum += input[i] * self.encoder[[i, j]];
            }
            embedding[j] = sum.max(0.0);
        }

        // Update decoder weights.
        for j in 0..self.config.embed_dim {
            for i in 0..self.config.signal_dim {
                let gradient = output_error[i] * embedding[j];
                self.decoder[[j, i]] -= lr * gradient;
            }
        }

        // Propagate the error back through the decoder and the ReLU.
        let mut hidden_error = Array1::zeros(self.config.embed_dim);
        for j in 0..self.config.embed_dim {
            let mut sum = 0.0;
            for i in 0..self.config.signal_dim {
                sum += output_error[i] * self.decoder[[j, i]];
            }
            hidden_error[j] = if embedding[j] > 0.0 { sum } else { 0.0 };
        }

        // Update encoder weights.
        for i in 0..self.config.signal_dim.min(input.len()) {
            for j in 0..self.config.embed_dim {
                let gradient = hidden_error[j] * input[i];
                self.encoder[[i, j]] -= lr * gradient;
            }
        }

        Ok(())
    }

    /// Returns the learned encoder weights.
    pub fn encoder_weights(&self) -> &Array2<f32> {
        &self.encoder
    }

    /// Returns the learned decoder weights.
    pub fn decoder_weights(&self) -> &Array2<f32> {
        &self.decoder
    }
}

/// Configuration for contrastive pretraining of the signal encoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContrastiveConfig {
    /// Dimensionality of the embedding space.
    pub embed_dim: usize,
    /// Temperature used to scale similarities in the contrastive loss.
    pub temperature: f32,
    /// Scale of the additive augmentation noise.
    pub aug_noise_std: f32,
    /// Learning rate for weight updates.
    pub learning_rate: f32,
    /// Maximum number of negative samples per anchor.
    pub num_negatives: usize,
}

impl Default for ContrastiveConfig {
    fn default() -> Self {
        Self {
            embed_dim: 128,
            temperature: 0.07,
            aug_noise_std: 0.1,
            learning_rate: 0.001,
            num_negatives: 16,
        }
    }
}

/// Contrastive learning over augmented views of the same signal.
#[derive(Debug, Clone)]
pub struct ContrastiveLearning {
    config: ContrastiveConfig,
    /// Encoder weights, shape `(signal_dim, embed_dim)`.
    encoder: Array2<f32>,
    rng: Random<StdRng>,
}

impl ContrastiveLearning {
    /// Creates a new encoder for signals of length `signal_dim`.
    pub fn new(signal_dim: usize, config: ContrastiveConfig) -> Self {
        let mut rng = Random::seed(46);
        let scale = (2.0 / (signal_dim + config.embed_dim) as f32).sqrt();

        let mut encoder = Array2::zeros((signal_dim, config.embed_dim));
        for val in encoder.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale;
        }

        Self {
            config,
            encoder,
            rng,
        }
    }

    /// Produces an augmented view of the signal by adding random noise.
    fn augment(&mut self, signal: &Array1<f32>) -> Array1<f32> {
        signal.mapv(|x| {
            let noise = (self.rng.gen_range(-1.0..1.0)) * self.config.aug_noise_std;
            x + noise
        })
    }

    /// Encodes a signal into an L2-normalized embedding.
    fn encode(&self, signal: &Array1<f32>) -> Array1<f32> {
        let mut embedding = Array1::zeros(self.config.embed_dim);
        for j in 0..self.config.embed_dim {
            let mut sum = 0.0;
            for i in 0..signal.len().min(self.encoder.nrows()) {
                sum += signal[i] * self.encoder[[i, j]];
            }
            embedding[j] = sum;
        }

        // L2-normalize so that dot products are cosine similarities.
        let norm = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding /= norm;
        }
        embedding
    }

    /// Cosine similarity of two (already normalized) embeddings.
    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
    }

    /// Computes an InfoNCE-style contrastive loss over a batch of signals.
    pub fn contrastive_loss(&mut self, signals: &[Array1<f32>]) -> TokenizerResult<f32> {
        if signals.len() < 2 {
            return Ok(0.0);
        }

        let mut total_loss = 0.0;
        let mut count = 0;

        for i in 0..signals.len() {
            // Two augmented views of the same signal form the positive pair.
            let view1 = self.augment(&signals[i]);
            let view2 = self.augment(&signals[i]);

            let z1 = self.encode(&view1);
            let z2 = self.encode(&view2);

            let pos_sim = self.cosine_similarity(&z1, &z2) / self.config.temperature;

            // Other signals in the batch serve as negatives.
            let mut neg_sims = Vec::new();
            for (j, signal) in signals.iter().enumerate() {
                if i != j {
                    let neg_view = self.augment(signal);
                    let z_neg = self.encode(&neg_view);
                    let neg_sim = self.cosine_similarity(&z1, &z_neg) / self.config.temperature;
                    neg_sims.push(neg_sim);

                    if neg_sims.len() >= self.config.num_negatives {
                        break;
                    }
                }
            }

            let pos_exp = pos_sim.exp();
            let neg_sum: f32 = neg_sims.iter().map(|&x| x.exp()).sum();
            let loss = -(pos_exp / (pos_exp + neg_sum)).ln();

            total_loss += loss;
            count += 1;
        }

        Ok(if count > 0 {
            total_loss / count as f32
        } else {
            0.0
        })
    }

    /// Returns the learned encoder weights.
    pub fn encoder_weights(&self) -> &Array2<f32> {
        &self.encoder
    }
}

/// Configuration for temporal prediction pretraining.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalPredictionConfig {
    /// Number of context samples used as input.
    pub context_size: usize,
    /// Number of future samples to predict.
    pub prediction_size: usize,
    /// Dimensionality of the context embedding.
    pub embed_dim: usize,
    /// Learning rate for weight updates.
    pub learning_rate: f32,
}

impl Default for TemporalPredictionConfig {
    fn default() -> Self {
        Self {
            context_size: 64,
            prediction_size: 16,
            embed_dim: 128,
            learning_rate: 0.001,
        }
    }
}

/// Temporal prediction: predicts future samples from a fixed-size context
/// window.
#[derive(Debug, Clone)]
pub struct TemporalPrediction {
    config: TemporalPredictionConfig,
    /// Context encoder weights, shape `(context_size, embed_dim)`.
    context_encoder: Array2<f32>,
    /// Prediction head weights, shape `(embed_dim, prediction_size)`.
    prediction_head: Array2<f32>,
}

impl TemporalPrediction {
    /// Creates a new model with randomly initialized weights.
    pub fn new(config: TemporalPredictionConfig) -> Self {
        let mut rng = Random::seed(47);

        let encoder_scale = (2.0 / (config.context_size + config.embed_dim) as f32).sqrt();
        let head_scale = (2.0 / (config.embed_dim + config.prediction_size) as f32).sqrt();

        let mut context_encoder = Array2::zeros((config.context_size, config.embed_dim));
        let mut prediction_head = Array2::zeros((config.embed_dim, config.prediction_size));

        for val in context_encoder.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * encoder_scale;
        }
        for val in prediction_head.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * head_scale;
        }

        Self {
            config,
            context_encoder,
            prediction_head,
        }
    }

    /// Predicts `prediction_size` future samples from a context window.
    ///
    /// Returns an error if `context` does not have exactly `context_size`
    /// samples.
    pub fn predict(&self, context: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if context.len() != self.config.context_size {
            return Err(TokenizerError::encoding(
                format!(
                    "Context size mismatch: expected {}, got {}",
                    self.config.context_size,
                    context.len()
                ),
                "TemporalPrediction::predict",
            ));
        }

        // Encode the context: linear projection followed by ReLU.
        let mut embedding = Array1::zeros(self.config.embed_dim);
        for j in 0..self.config.embed_dim {
            let mut sum = 0.0;
            for i in 0..self.config.context_size {
                sum += context[i] * self.context_encoder[[i, j]];
            }
            embedding[j] = sum.max(0.0);
        }

        // Project the embedding to the prediction window.
        let mut prediction = Array1::zeros(self.config.prediction_size);
        for i in 0..self.config.prediction_size {
            let mut sum = 0.0;
            for j in 0..self.config.embed_dim {
                sum += embedding[j] * self.prediction_head[[j, i]];
            }
            prediction[i] = sum;
        }

        Ok(prediction)
    }

    /// Returns the context encoder weights.
    pub fn context_encoder_weights(&self) -> &Array2<f32> {
        &self.context_encoder
    }

    /// Returns the prediction head weights.
    pub fn prediction_head_weights(&self) -> &Array2<f32> {
        &self.prediction_head
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_msm_config_validation() {
        let config = MSMConfig::default();
        assert!(config.validate().is_ok());

        let mut bad_config = config.clone();
        bad_config.mask_ratio = 1.5;
        assert!(bad_config.validate().is_err());

        let mut bad_config = config.clone();
        bad_config.learning_rate = 1.5;
        assert!(bad_config.validate().is_err());
    }

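    // Sketch of an additional check: validate() also rejects zero-valued size
    // fields (the test name below is new, everything it exercises is existing).
    #[test]
    fn test_msm_config_rejects_zero_sizes() {
        let bad_config = MSMConfig {
            mask_length: 0,
            ..Default::default()
        };
        assert!(bad_config.validate().is_err());

        let bad_config = MSMConfig {
            embed_dim: 0,
            ..Default::default()
        };
        assert!(bad_config.validate().is_err());
    }
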
    #[test]
    fn test_msm_creation() {
        let config = MSMConfig::default();
        let msm = MaskedSignalModeling::new(config);
        assert!(msm.is_ok());
    }

    #[test]
    fn test_msm_create_mask() {
        let config = MSMConfig {
            mask_ratio: 0.2,
            mask_length: 10,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        let mask = msm.create_mask(100);
        assert_eq!(mask.len(), 100);

        // Some, but not all, positions should be masked.
        let num_masked = mask.iter().filter(|&&x| x).count();
        assert!(num_masked > 0 && num_masked < 100);
    }

    #[test]
    fn test_msm_apply_mask() {
        let config = MSMConfig::default();
        let msm = MaskedSignalModeling::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 100);
        let mask = Array1::from_vec(vec![false; 50].into_iter().chain(vec![true; 50]).collect());

        let masked = msm.apply_mask(&signal, &mask);
        assert_eq!(masked.len(), 100);

        // Unmasked positions keep their values; masked positions become 0.0.
        for i in 0..50 {
            assert!((masked[i] - signal[i]).abs() < 1e-6);
        }
        for i in 50..100 {
            assert_eq!(masked[i], 0.0);
        }
    }

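    // Sketch of an additional check: compute_loss averages squared error over
    // masked positions only, and returns 0.0 when nothing is masked.
    #[test]
    fn test_msm_compute_loss_masked_only() {
        let config = MSMConfig::default();
        let msm = MaskedSignalModeling::new(config).unwrap();

        let target = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let prediction = Array1::from_vec(vec![1.0, 0.0, 3.0, 0.0]);

        let no_mask = Array1::from_elem(4, false);
        assert_eq!(msm.compute_loss(&target, &prediction, &no_mask), 0.0);

        let mask = Array1::from_vec(vec![false, true, false, true]);
        // Masked positions 1 and 3: (2.0^2 + 4.0^2) / 2 = 10.0
        let loss = msm.compute_loss(&target, &prediction, &mask);
        assert!((loss - 10.0).abs() < 1e-6);
    }
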
    #[test]
    fn test_msm_forward() {
        let config = MSMConfig {
            signal_dim: 64,
            embed_dim: 32,
            ..Default::default()
        };
        let msm = MaskedSignalModeling::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 64);
        let reconstructed = msm.forward(&signal);
        assert!(reconstructed.is_ok());

        let reconstructed = reconstructed.unwrap();
        assert_eq!(reconstructed.len(), 64);
    }

    #[test]
    fn test_msm_pretrain() {
        let config = MSMConfig {
            signal_dim: 32,
            embed_dim: 16,
            epochs: 5,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        let signals: Vec<Array1<f32>> = (0..10)
            .map(|i| Array1::linspace(i as f32, (i + 1) as f32, 32))
            .collect();

        let losses = msm.pretrain(&signals, 5);
        assert!(losses.is_ok());

        let losses = losses.unwrap();
        assert_eq!(losses.len(), 5);

        // The loss should not blow up over training.
        assert!(losses[4] <= losses[0] * 1.5);
    }

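    // Sketch of an additional check: pretrain() skips signals whose length does
    // not match signal_dim, so a batch of mismatched signals records no
    // per-epoch losses.
    #[test]
    fn test_msm_pretrain_skips_mismatched_signals() {
        let config = MSMConfig {
            signal_dim: 32,
            embed_dim: 16,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        let signals = vec![Array1::linspace(0.0, 1.0, 8)];
        let losses = msm.pretrain(&signals, 3).unwrap();
        assert!(losses.is_empty());
    }
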
    #[test]
    fn test_contrastive_learning_creation() {
        let config = ContrastiveConfig::default();
        let cl = ContrastiveLearning::new(128, config);
        assert_eq!(cl.encoder.nrows(), 128);
    }

    #[test]
    fn test_contrastive_augment() {
        let config = ContrastiveConfig {
            aug_noise_std: 0.1,
            ..Default::default()
        };
        let mut cl = ContrastiveLearning::new(64, config);

        let signal = Array1::zeros(64);
        let augmented = cl.augment(&signal);
        assert_eq!(augmented.len(), 64);

        // Augmentation should perturb at least some samples.
        let has_noise = augmented.iter().any(|&x| x != 0.0);
        assert!(has_noise);
    }

    #[test]
    fn test_contrastive_encode() {
        let config = ContrastiveConfig::default();
        let cl = ContrastiveLearning::new(64, config);

        let signal = Array1::linspace(0.0, 1.0, 64);
        let embedding = cl.encode(&signal);
        assert_eq!(embedding.len(), cl.config.embed_dim);

        // Embeddings are L2-normalized.
        let norm: f32 = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

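    // Sketch of an additional check: encode() returns unit-norm embeddings, so
    // the cosine similarity of an embedding with itself is ~1.0.
    #[test]
    fn test_contrastive_cosine_similarity_self() {
        let config = ContrastiveConfig::default();
        let cl = ContrastiveLearning::new(64, config);

        let signal = Array1::linspace(0.0, 1.0, 64);
        let z = cl.encode(&signal);
        let sim = cl.cosine_similarity(&z, &z);
        assert!((sim - 1.0).abs() < 1e-5);
    }
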
    #[test]
    fn test_contrastive_loss() {
        let config = ContrastiveConfig {
            num_negatives: 2,
            ..Default::default()
        };
        let mut cl = ContrastiveLearning::new(32, config);

        let signals: Vec<Array1<f32>> = (0..5)
            .map(|i| Array1::linspace(i as f32, (i + 1) as f32, 32))
            .collect();

        let loss = cl.contrastive_loss(&signals);
        assert!(loss.is_ok());

        let loss = loss.unwrap();
        assert!(loss.is_finite() && loss >= 0.0);
    }

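    // Sketch of an additional check: with fewer than two signals there are no
    // negative pairs, so contrastive_loss short-circuits to 0.0.
    #[test]
    fn test_contrastive_loss_single_signal() {
        let config = ContrastiveConfig::default();
        let mut cl = ContrastiveLearning::new(32, config);

        let signals = vec![Array1::linspace(0.0, 1.0, 32)];
        let loss = cl.contrastive_loss(&signals).unwrap();
        assert_eq!(loss, 0.0);
    }
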
    #[test]
    fn test_temporal_prediction_creation() {
        let config = TemporalPredictionConfig::default();
        let tp = TemporalPrediction::new(config);
        assert_eq!(tp.context_encoder.nrows(), tp.config.context_size);
    }

    #[test]
    fn test_temporal_prediction_predict() {
        let config = TemporalPredictionConfig {
            context_size: 32,
            prediction_size: 8,
            embed_dim: 16,
            ..Default::default()
        };
        let tp = TemporalPrediction::new(config);

        let context = Array1::linspace(0.0, 1.0, 32);
        let prediction = tp.predict(&context);
        assert!(prediction.is_ok());

        let prediction = prediction.unwrap();
        assert_eq!(prediction.len(), 8);
    }

    #[test]
    fn test_temporal_prediction_wrong_context_size() {
        let config = TemporalPredictionConfig {
            context_size: 32,
            ..Default::default()
        };
        let tp = TemporalPrediction::new(config);

        // A context shorter than context_size must be rejected.
        let wrong_context = Array1::linspace(0.0, 1.0, 16);
        let prediction = tp.predict(&wrong_context);
        assert!(prediction.is_err());
    }
}