kizzasi_tokenizer/
pretraining.rs

1//! Self-supervised pre-training for tokenizers.
2//!
3//! This module implements various self-supervised learning objectives for
4//! pre-training signal tokenizers without requiring labeled data.
5//!
//! # Supported Methods
//!
//! - **Masked Signal Modeling (MSM)**: Predict masked portions of the signal
//! - **Contrastive Learning**: Learn representations by contrasting augmented views of the
//!   same signal against views of other signals (NT-Xent loss)
//! - **Temporal Prediction**: Predict future signal segments from past context
//! - **Denoising**: Reconstruct clean signals from noisy inputs (no implementation in this
//!   module yet)
12//!
13//! # Example
14//!
15//! ```
16//! use kizzasi_tokenizer::{MaskedSignalModeling, MSMConfig};
17//! use scirs2_core::ndarray::Array1;
18//!
19//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! let config = MSMConfig {
21//!     mask_ratio: 0.15,
22//!     mask_length: 16,
23//!     learning_rate: 0.001,
24//!     ..Default::default()
25//! };
26//!
27//! let mut msm = MaskedSignalModeling::new(config)?;
28//! let signals = vec![Array1::linspace(0.0, 1.0, 256); 32];
29//! msm.pretrain(&signals, 10)?;
30//! # Ok(())
31//! # }
32//! ```
33
34use crate::error::{TokenizerError, TokenizerResult};
35use scirs2_core::ndarray::{Array1, Array2};
36use scirs2_core::random::{rngs::StdRng, Random};
37use serde::{Deserialize, Serialize};
38
39/// Configuration for Masked Signal Modeling
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct MSMConfig {
42    /// Ratio of signal to mask (0.0 to 1.0)
43    pub mask_ratio: f32,
44    /// Length of each masked segment
45    pub mask_length: usize,
46    /// Signal dimension
47    pub signal_dim: usize,
48    /// Embedding dimension
49    pub embed_dim: usize,
50    /// Learning rate
51    pub learning_rate: f32,
52    /// Number of training epochs
53    pub epochs: usize,
54    /// Batch size
55    pub batch_size: usize,
56}
57
58impl Default for MSMConfig {
59    fn default() -> Self {
60        Self {
61            mask_ratio: 0.15,
62            mask_length: 16,
63            signal_dim: 256,
64            embed_dim: 128,
65            learning_rate: 0.001,
66            epochs: 100,
67            batch_size: 32,
68        }
69    }
70}
71
72impl MSMConfig {
73    /// Validate the configuration
74    pub fn validate(&self) -> TokenizerResult<()> {
75        if !(0.0..=1.0).contains(&self.mask_ratio) {
76            return Err(TokenizerError::invalid_input(
77                "mask_ratio must be in [0.0, 1.0]",
78                "MSMConfig::validate",
79            ));
80        }
81        if self.mask_length == 0 {
82            return Err(TokenizerError::invalid_input(
83                "mask_length must be positive",
84                "MSMConfig::validate",
85            ));
86        }
87        if self.signal_dim == 0 || self.embed_dim == 0 {
88            return Err(TokenizerError::invalid_input(
89                "signal_dim and embed_dim must be positive",
90                "MSMConfig::validate",
91            ));
92        }
93        if !(0.0..1.0).contains(&self.learning_rate) {
94            return Err(TokenizerError::invalid_input(
95                "learning_rate must be in (0.0, 1.0)",
96                "MSMConfig::validate",
97            ));
98        }
99        if self.epochs == 0 || self.batch_size == 0 {
100            return Err(TokenizerError::invalid_input(
101                "epochs and batch_size must be positive",
102                "MSMConfig::validate",
103            ));
104        }
105        Ok(())
106    }
107}
108
109/// Masked Signal Modeling pre-trainer
110///
111/// Learns signal representations by predicting masked portions of the input.
112#[derive(Debug, Clone)]
113pub struct MaskedSignalModeling {
114    /// Configuration
115    config: MSMConfig,
116    /// Encoder weights [signal_dim, embed_dim]
117    encoder: Array2<f32>,
118    /// Decoder weights [embed_dim, signal_dim]
119    decoder: Array2<f32>,
120    /// Random number generator
121    rng: Random<StdRng>,
122}
123
124impl MaskedSignalModeling {
125    /// Create a new MSM pre-trainer
126    pub fn new(config: MSMConfig) -> TokenizerResult<Self> {
127        config.validate()?;
128
129        let mut rng = Random::seed(45);
130
131        // Xavier initialization
132        let encoder_scale = (2.0 / (config.signal_dim + config.embed_dim) as f32).sqrt();
133        let decoder_scale = (2.0 / (config.embed_dim + config.signal_dim) as f32).sqrt();
134
135        let encoder =
136            Self::init_weights(config.signal_dim, config.embed_dim, encoder_scale, &mut rng);
137        let decoder =
138            Self::init_weights(config.embed_dim, config.signal_dim, decoder_scale, &mut rng);
139
140        Ok(Self {
141            config,
142            encoder,
143            decoder,
144            rng,
145        })
146    }
147
148    /// Initialize weights with Xavier uniform distribution
149    fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
150        let mut weights = Array2::zeros((rows, cols));
151        for val in weights.iter_mut() {
152            *val = (rng.gen_range(-1.0..1.0)) * scale;
153        }
154        weights
155    }
156
157    /// Create a mask for the signal
158    ///
159    /// Returns a boolean array where true indicates masked positions
160    fn create_mask(&mut self, signal_len: usize) -> Array1<bool> {
161        let mut mask = Array1::from_elem(signal_len, false);
162        let num_masks = ((signal_len as f32 * self.config.mask_ratio)
163            / self.config.mask_length as f32) as usize;
164
165        for _ in 0..num_masks {
166            let start = (self.rng.gen_range(0.0..1.0)
167                * (signal_len - self.config.mask_length) as f32) as usize;
168            let end = (start + self.config.mask_length).min(signal_len);
169            for i in start..end {
170                mask[i] = true;
171            }
172        }
173
174        mask
175    }
176
177    /// Apply mask to signal (replace masked values with zeros)
178    fn apply_mask(&self, signal: &Array1<f32>, mask: &Array1<bool>) -> Array1<f32> {
179        signal
180            .iter()
181            .zip(mask.iter())
182            .map(|(&val, &is_masked)| if is_masked { 0.0 } else { val })
183            .collect()
184    }
185
186    /// Forward pass: encode and decode
187    fn forward(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
188        // Encode: signal -> embedding
189        let mut embedding = Array1::zeros(self.config.embed_dim);
190        for j in 0..self.config.embed_dim {
191            let mut sum = 0.0;
192            for i in 0..self.config.signal_dim.min(signal.len()) {
193                sum += signal[i] * self.encoder[[i, j]];
194            }
195            embedding[j] = sum;
196        }
197
198        // Apply ReLU activation
199        embedding.mapv_inplace(|x| x.max(0.0));
200
201        // Decode: embedding -> reconstructed signal
202        let mut reconstructed = Array1::zeros(self.config.signal_dim);
203        for i in 0..self.config.signal_dim {
204            let mut sum = 0.0;
205            for j in 0..self.config.embed_dim {
206                sum += embedding[j] * self.decoder[[j, i]];
207            }
208            reconstructed[i] = sum;
209        }
210
211        Ok(reconstructed)
212    }
213
214    /// Compute MSE loss between target and prediction
215    fn compute_loss(
216        &self,
217        target: &Array1<f32>,
218        prediction: &Array1<f32>,
219        mask: &Array1<bool>,
220    ) -> f32 {
221        let mut loss = 0.0;
222        let mut count = 0;
223
224        for i in 0..target.len().min(prediction.len()).min(mask.len()) {
225            if mask[i] {
226                let diff = target[i] - prediction[i];
227                loss += diff * diff;
228                count += 1;
229            }
230        }
231
232        if count > 0 {
233            loss / count as f32
234        } else {
235            0.0
236        }
237    }
238
239    /// Pre-train on a dataset of signals
240    pub fn pretrain(
241        &mut self,
242        signals: &[Array1<f32>],
243        num_epochs: usize,
244    ) -> TokenizerResult<Vec<f32>> {
245        let mut losses = Vec::new();
246
247        for epoch in 0..num_epochs {
248            let mut epoch_loss = 0.0;
249            let mut num_batches = 0;
250
251            // Process each signal
252            for signal in signals {
253                if signal.len() != self.config.signal_dim {
254                    continue; // Skip signals with wrong dimension
255                }
256
257                // Create mask
258                let mask = self.create_mask(signal.len());
259
260                // Apply mask
261                let masked_signal = self.apply_mask(signal, &mask);
262
263                // Forward pass
264                let reconstructed = self.forward(&masked_signal)?;
265
266                // Compute loss (only on masked positions)
267                let loss = self.compute_loss(signal, &reconstructed, &mask);
268                epoch_loss += loss;
269                num_batches += 1;
270
271                // Backward pass (simplified gradient descent)
272                self.update_weights(signal, &masked_signal, &reconstructed, &mask)?;
273            }
274
275            if num_batches > 0 {
276                epoch_loss /= num_batches as f32;
277                losses.push(epoch_loss);
278
279                if epoch % 10 == 0 {
280                    tracing::debug!("Epoch {}: Loss = {:.6}", epoch, epoch_loss);
281                }
282            }
283        }
284
285        Ok(losses)
286    }
287
288    /// Update weights using gradient descent (simplified)
289    fn update_weights(
290        &mut self,
291        target: &Array1<f32>,
292        input: &Array1<f32>,
293        output: &Array1<f32>,
294        mask: &Array1<bool>,
295    ) -> TokenizerResult<()> {
296        let lr = self.config.learning_rate;
297
298        // Compute output error (only for masked positions)
299        let mut output_error = Array1::zeros(self.config.signal_dim);
300        for i in 0..self.config.signal_dim.min(output.len()).min(target.len()) {
301            if i < mask.len() && mask[i] {
302                output_error[i] = output[i] - target[i];
303            }
304        }
305
306        // Compute embedding
307        let mut embedding = Array1::zeros(self.config.embed_dim);
308        for j in 0..self.config.embed_dim {
309            let mut sum = 0.0;
310            for i in 0..self.config.signal_dim.min(input.len()) {
311                sum += input[i] * self.encoder[[i, j]];
312            }
313            embedding[j] = sum.max(0.0); // ReLU
314        }
315
316        // Update decoder weights
317        for j in 0..self.config.embed_dim {
318            for i in 0..self.config.signal_dim {
319                let gradient = output_error[i] * embedding[j];
320                self.decoder[[j, i]] -= lr * gradient;
321            }
322        }
323
324        // Compute hidden error
325        let mut hidden_error = Array1::zeros(self.config.embed_dim);
326        for j in 0..self.config.embed_dim {
327            let mut sum = 0.0;
328            for i in 0..self.config.signal_dim {
329                sum += output_error[i] * self.decoder[[j, i]];
330            }
331            // ReLU derivative (0 if embedding[j] <= 0, else 1)
332            hidden_error[j] = if embedding[j] > 0.0 { sum } else { 0.0 };
333        }
334
335        // Update encoder weights
336        for i in 0..self.config.signal_dim.min(input.len()) {
337            for j in 0..self.config.embed_dim {
338                let gradient = hidden_error[j] * input[i];
339                self.encoder[[i, j]] -= lr * gradient;
340            }
341        }
342
343        Ok(())
344    }
345
346    /// Get the learned encoder weights
347    pub fn encoder_weights(&self) -> &Array2<f32> {
348        &self.encoder
349    }
350
351    /// Get the learned decoder weights
352    pub fn decoder_weights(&self) -> &Array2<f32> {
353        &self.decoder
354    }
355}
356
357/// Configuration for Contrastive Learning
358#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct ContrastiveConfig {
360    /// Embedding dimension
361    pub embed_dim: usize,
362    /// Temperature for contrastive loss
363    pub temperature: f32,
364    /// Augmentation noise std
365    pub aug_noise_std: f32,
366    /// Learning rate
367    pub learning_rate: f32,
368    /// Number of negative samples
369    pub num_negatives: usize,
370}
371
372impl Default for ContrastiveConfig {
373    fn default() -> Self {
374        Self {
375            embed_dim: 128,
376            temperature: 0.07,
377            aug_noise_std: 0.1,
378            learning_rate: 0.001,
379            num_negatives: 16,
380        }
381    }
382}
383
384/// Contrastive learning pre-trainer
385///
386/// Learns representations by maximizing agreement between augmented views
387/// of the same signal while minimizing agreement with different signals.
388#[derive(Debug, Clone)]
389pub struct ContrastiveLearning {
390    /// Configuration
391    config: ContrastiveConfig,
392    /// Encoder weights
393    encoder: Array2<f32>,
394    /// Random number generator
395    rng: Random<StdRng>,
396}
397
398impl ContrastiveLearning {
399    /// Create a new contrastive learning pre-trainer
400    pub fn new(signal_dim: usize, config: ContrastiveConfig) -> Self {
401        let mut rng = Random::seed(46);
402        let scale = (2.0 / (signal_dim + config.embed_dim) as f32).sqrt();
403
404        let mut encoder = Array2::zeros((signal_dim, config.embed_dim));
405        for val in encoder.iter_mut() {
406            *val = (rng.gen_range(-1.0..1.0)) * scale;
407        }
408
409        Self {
410            config,
411            encoder,
412            rng,
413        }
414    }
415
416    /// Apply data augmentation (add noise)
417    fn augment(&mut self, signal: &Array1<f32>) -> Array1<f32> {
418        signal.mapv(|x| {
419            let noise = (self.rng.gen_range(-1.0..1.0)) * self.config.aug_noise_std;
420            x + noise
421        })
422    }
423
424    /// Encode signal to embedding
425    fn encode(&self, signal: &Array1<f32>) -> Array1<f32> {
426        let mut embedding = Array1::zeros(self.config.embed_dim);
427        for j in 0..self.config.embed_dim {
428            let mut sum = 0.0;
429            for i in 0..signal.len().min(self.encoder.nrows()) {
430                sum += signal[i] * self.encoder[[i, j]];
431            }
432            embedding[j] = sum;
433        }
434
435        // L2 normalization
436        let norm = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
437        if norm > 0.0 {
438            embedding /= norm;
439        }
440        embedding
441    }
442
443    /// Compute cosine similarity
444    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
445        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
446    }
447
448    /// Compute contrastive loss (NT-Xent)
449    pub fn contrastive_loss(&mut self, signals: &[Array1<f32>]) -> TokenizerResult<f32> {
450        if signals.len() < 2 {
451            return Ok(0.0);
452        }
453
454        let mut total_loss = 0.0;
455        let mut count = 0;
456
457        for i in 0..signals.len() {
458            // Create two augmented views of the same signal
459            let view1 = self.augment(&signals[i]);
460            let view2 = self.augment(&signals[i]);
461
462            let z1 = self.encode(&view1);
463            let z2 = self.encode(&view2);
464
465            // Positive pair similarity
466            let pos_sim = self.cosine_similarity(&z1, &z2) / self.config.temperature;
467
468            // Negative pairs (from other signals)
469            let mut neg_sims = Vec::new();
470            for (j, signal) in signals.iter().enumerate() {
471                if i != j {
472                    let neg_view = self.augment(signal);
473                    let z_neg = self.encode(&neg_view);
474                    let neg_sim = self.cosine_similarity(&z1, &z_neg) / self.config.temperature;
475                    neg_sims.push(neg_sim);
476
477                    if neg_sims.len() >= self.config.num_negatives {
478                        break;
479                    }
480                }
481            }
482
483            // NT-Xent loss: -log(exp(pos) / (exp(pos) + sum(exp(neg))))
484            let pos_exp = pos_sim.exp();
485            let neg_sum: f32 = neg_sims.iter().map(|&x| x.exp()).sum();
486            let loss = -(pos_exp / (pos_exp + neg_sum)).ln();
487
488            total_loss += loss;
489            count += 1;
490        }
491
492        Ok(if count > 0 {
493            total_loss / count as f32
494        } else {
495            0.0
496        })
497    }
498
499    /// Get encoder weights
500    pub fn encoder_weights(&self) -> &Array2<f32> {
501        &self.encoder
502    }
503}
504
505/// Configuration for Temporal Prediction
506#[derive(Debug, Clone, Serialize, Deserialize)]
507pub struct TemporalPredictionConfig {
508    /// Context window size
509    pub context_size: usize,
510    /// Prediction horizon
511    pub prediction_size: usize,
512    /// Embedding dimension
513    pub embed_dim: usize,
514    /// Learning rate
515    pub learning_rate: f32,
516}
517
518impl Default for TemporalPredictionConfig {
519    fn default() -> Self {
520        Self {
521            context_size: 64,
522            prediction_size: 16,
523            embed_dim: 128,
524            learning_rate: 0.001,
525        }
526    }
527}
528
529/// Temporal prediction pre-trainer
530///
531/// Learns representations by predicting future signal segments from past context.
532#[derive(Debug, Clone)]
533pub struct TemporalPrediction {
534    /// Configuration
535    config: TemporalPredictionConfig,
536    /// Context encoder weights
537    context_encoder: Array2<f32>,
538    /// Prediction head weights
539    prediction_head: Array2<f32>,
540}
541
542impl TemporalPrediction {
543    /// Create a new temporal prediction pre-trainer
544    pub fn new(config: TemporalPredictionConfig) -> Self {
545        let mut rng = Random::seed(47);
546
547        let encoder_scale = (2.0 / (config.context_size + config.embed_dim) as f32).sqrt();
548        let head_scale = (2.0 / (config.embed_dim + config.prediction_size) as f32).sqrt();
549
550        let mut context_encoder = Array2::zeros((config.context_size, config.embed_dim));
551        let mut prediction_head = Array2::zeros((config.embed_dim, config.prediction_size));
552
553        for val in context_encoder.iter_mut() {
554            *val = (rng.gen_range(-1.0..1.0)) * encoder_scale;
555        }
556        for val in prediction_head.iter_mut() {
557            *val = (rng.gen_range(-1.0..1.0)) * head_scale;
558        }
559
560        Self {
561            config,
562            context_encoder,
563            prediction_head,
564        }
565    }
566
567    /// Predict future segment from context
568    pub fn predict(&self, context: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
569        if context.len() != self.config.context_size {
570            return Err(TokenizerError::encoding(
571                format!(
572                    "Context size mismatch: expected {}, got {}",
573                    self.config.context_size,
574                    context.len()
575                ),
576                "TemporalPrediction::predict",
577            ));
578        }
579
580        // Encode context
581        let mut embedding = Array1::zeros(self.config.embed_dim);
582        for j in 0..self.config.embed_dim {
583            let mut sum = 0.0;
584            for i in 0..self.config.context_size {
585                sum += context[i] * self.context_encoder[[i, j]];
586            }
587            embedding[j] = sum.max(0.0); // ReLU
588        }
589
590        // Predict future
591        let mut prediction = Array1::zeros(self.config.prediction_size);
592        for i in 0..self.config.prediction_size {
593            let mut sum = 0.0;
594            for j in 0..self.config.embed_dim {
595                sum += embedding[j] * self.prediction_head[[j, i]];
596            }
597            prediction[i] = sum;
598        }
599
600        Ok(prediction)
601    }
602
603    /// Get context encoder weights
604    pub fn context_encoder_weights(&self) -> &Array2<f32> {
605        &self.context_encoder
606    }
607
608    /// Get prediction head weights
609    pub fn prediction_head_weights(&self) -> &Array2<f32> {
610        &self.prediction_head
611    }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_msm_config_validation() {
        let config = MSMConfig::default();
        assert!(config.validate().is_ok());

        let mut bad_config = config.clone();
        bad_config.mask_ratio = 1.5;
        assert!(bad_config.validate().is_err());

        let mut bad_config = config.clone();
        bad_config.learning_rate = 1.5;
        assert!(bad_config.validate().is_err());
    }

    #[test]
    fn test_msm_creation() {
        let config = MSMConfig::default();
        let msm = MaskedSignalModeling::new(config);
        assert!(msm.is_ok());
    }

    #[test]
    fn test_msm_create_mask() {
        let config = MSMConfig {
            mask_ratio: 0.2,
            mask_length: 10,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        let mask = msm.create_mask(100);
        assert_eq!(mask.len(), 100);

        // Count masked positions
        let num_masked = mask.iter().filter(|&&x| x).count();
        assert!(num_masked > 0 && num_masked < 100);
    }

    #[test]
    fn test_msm_create_mask_zero_ratio_masks_nothing() {
        let config = MSMConfig {
            mask_ratio: 0.0,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        let mask = msm.create_mask(256);
        assert_eq!(mask.len(), 256);
        assert!(mask.iter().all(|&m| !m));
    }

    #[test]
    fn test_msm_apply_mask() {
        let config = MSMConfig::default();
        let msm = MaskedSignalModeling::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 100);
        let mask = Array1::from_vec(vec![false; 50].into_iter().chain(vec![true; 50]).collect());

        let masked = msm.apply_mask(&signal, &mask);
        assert_eq!(masked.len(), 100);

        // First half should be unchanged, second half should be zero
        for i in 0..50 {
            assert!((masked[i] - signal[i]).abs() < 1e-6);
        }
        for i in 50..100 {
            assert_eq!(masked[i], 0.0);
        }
    }

    #[test]
    fn test_msm_forward() {
        let config = MSMConfig {
            signal_dim: 64,
            embed_dim: 32,
            ..Default::default()
        };
        let msm = MaskedSignalModeling::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 64);
        let reconstructed = msm.forward(&signal);
        assert!(reconstructed.is_ok());

        let reconstructed = reconstructed.unwrap();
        assert_eq!(reconstructed.len(), 64);
    }

    #[test]
    fn test_msm_pretrain() {
        // floor(signal_dim * mask_ratio / mask_length) = floor(32 * 0.25 / 4) = 2 spans
        // per signal. With the defaults (0.15, 16), 32-sample signals get zero spans,
        // so the loss would be identically 0 and the test vacuous.
        let config = MSMConfig {
            signal_dim: 32,
            embed_dim: 16,
            mask_ratio: 0.25,
            mask_length: 4,
            epochs: 5,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        // Amplitudes stay in [0, 1] so plain SGD at the default learning rate is stable.
        let signals: Vec<Array1<f32>> = (0..10)
            .map(|i| Array1::linspace(i as f32 * 0.1, (i + 1) as f32 * 0.1, 32))
            .collect();

        let losses = msm.pretrain(&signals, 5);
        assert!(losses.is_ok());

        let losses = losses.unwrap();
        assert_eq!(losses.len(), 5);
        assert!(losses.iter().all(|l| l.is_finite()));

        // Masking actually happened, so there is a non-zero reconstruction error.
        assert!(losses[0] > 0.0);

        // Loss should generally decrease
        assert!(losses[4] <= losses[0] * 1.5); // Allow some variance
    }

    #[test]
    fn test_msm_pretrain_skips_mismatched_signals() {
        let config = MSMConfig {
            signal_dim: 32,
            embed_dim: 16,
            ..Default::default()
        };
        let mut msm = MaskedSignalModeling::new(config).unwrap();

        let signals: Vec<Array1<f32>> = vec![Array1::linspace(0.0, 1.0, 16); 3];
        let losses = msm.pretrain(&signals, 4).unwrap();
        assert!(losses.is_empty());
    }

    #[test]
    fn test_contrastive_learning_creation() {
        let config = ContrastiveConfig::default();
        let cl = ContrastiveLearning::new(128, config);
        assert_eq!(cl.encoder.nrows(), 128);
    }

    #[test]
    fn test_contrastive_augment() {
        let config = ContrastiveConfig {
            aug_noise_std: 0.1,
            ..Default::default()
        };
        let mut cl = ContrastiveLearning::new(64, config);

        let signal = Array1::zeros(64);
        let augmented = cl.augment(&signal);
        assert_eq!(augmented.len(), 64);

        // Should have some noise
        let has_noise = augmented.iter().any(|&x| x != 0.0);
        assert!(has_noise);
    }

    #[test]
    fn test_contrastive_encode() {
        let config = ContrastiveConfig::default();
        let cl = ContrastiveLearning::new(64, config);

        let signal = Array1::linspace(0.0, 1.0, 64);
        let embedding = cl.encode(&signal);
        assert_eq!(embedding.len(), cl.config.embed_dim);

        // Check L2 normalization
        let norm: f32 = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_contrastive_loss() {
        let config = ContrastiveConfig {
            num_negatives: 2,
            ..Default::default()
        };
        let mut cl = ContrastiveLearning::new(32, config);

        let signals: Vec<Array1<f32>> = (0..5)
            .map(|i| Array1::linspace(i as f32, (i + 1) as f32, 32))
            .collect();

        let loss = cl.contrastive_loss(&signals);
        assert!(loss.is_ok());

        let loss = loss.unwrap();
        assert!(loss.is_finite() && loss >= 0.0);
    }

    #[test]
    fn test_contrastive_loss_single_signal_is_zero() {
        let mut cl = ContrastiveLearning::new(16, ContrastiveConfig::default());

        let signals: Vec<Array1<f32>> = vec![Array1::linspace(0.0, 1.0, 16)];
        assert_eq!(cl.contrastive_loss(&signals).unwrap(), 0.0);
    }

    #[test]
    fn test_temporal_prediction_creation() {
        let config = TemporalPredictionConfig::default();
        let tp = TemporalPrediction::new(config);
        assert_eq!(tp.context_encoder.nrows(), tp.config.context_size);
    }

    #[test]
    fn test_temporal_prediction_predict() {
        let config = TemporalPredictionConfig {
            context_size: 32,
            prediction_size: 8,
            embed_dim: 16,
            ..Default::default()
        };
        let tp = TemporalPrediction::new(config);

        let context = Array1::linspace(0.0, 1.0, 32);
        let prediction = tp.predict(&context);
        assert!(prediction.is_ok());

        let prediction = prediction.unwrap();
        assert_eq!(prediction.len(), 8);
    }

    #[test]
    fn test_temporal_prediction_wrong_context_size() {
        let config = TemporalPredictionConfig {
            context_size: 32,
            ..Default::default()
        };
        let tp = TemporalPrediction::new(config);

        let wrong_context = Array1::linspace(0.0, 1.0, 16); // Wrong size
        let prediction = tp.predict(&wrong_context);
        assert!(prediction.is_err());
    }
}