kizzasi_tokenizer/
continuous.rs

//! Continuous (non-discrete) tokenization
//!
//! For AGSP, continuous tokenization is often preferred as it preserves
//! the full precision of the signal. This is essentially a learned
//! linear projection from signal space to embedding space.
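//!
//! A minimal round-trip sketch (illustrative only: the import paths are
//! assumed, and with untrained random weights the reconstruction is only
//! approximate):
//!
//! ```ignore
//! use kizzasi_tokenizer::continuous::ContinuousTokenizer;
//! use kizzasi_tokenizer::SignalTokenizer;
//! use scirs2_core::ndarray::Array1;
//!
//! // Project a 3-sample signal into a 64-dimensional embedding and back.
//! let tokenizer = ContinuousTokenizer::new(3, 64);
//! let signal = Array1::from_vec(vec![0.1, 0.2, 0.3]);
//! let tokens = tokenizer.encode(&signal).unwrap();        // length 64
//! let reconstructed = tokenizer.decode(&tokens).unwrap(); // length 3
//! ```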

use crate::error::{TokenizerError, TokenizerResult};
use crate::persistence::{ModelCheckpoint, ModelMetadata, ModelVersion};
use crate::SignalTokenizer;
use candle_core::{Device, Result as CandleResult, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarMap};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use std::path::Path;

/// Continuous tokenizer that projects signals to embedding space
#[derive(Debug, Clone)]
pub struct ContinuousTokenizer {
    /// Encoder projection (input_dim -> embed_dim)
    encoder: Array2<f32>,
    /// Decoder projection (embed_dim -> input_dim)
    decoder: Array2<f32>,
    /// Input dimension
    input_dim: usize,
    /// Embedding dimension
    embed_dim: usize,
}

impl ContinuousTokenizer {
    /// Create a new continuous tokenizer
    pub fn new(input_dim: usize, embed_dim: usize) -> Self {
        let mut rng = thread_rng();

        // Xavier-style initialization: uniform in [-scale, scale]
        let enc_scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let encoder = Array2::from_shape_fn((input_dim, embed_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * enc_scale
        });

        let dec_scale = (2.0 / (embed_dim + input_dim) as f32).sqrt();
        let decoder = Array2::from_shape_fn((embed_dim, input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * dec_scale
        });

        Self {
            encoder,
            decoder,
            input_dim,
            embed_dim,
        }
    }

    /// Get input dimension
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Set encoder weights
    pub fn set_encoder(&mut self, weights: Array2<f32>) -> TokenizerResult<()> {
        if weights.shape() != [self.input_dim, self.embed_dim] {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim * self.embed_dim,
                weights.len(),
                "dimension validation",
            ));
        }
        self.encoder = weights;
        Ok(())
    }

    /// Set decoder weights
    pub fn set_decoder(&mut self, weights: Array2<f32>) -> TokenizerResult<()> {
        if weights.shape() != [self.embed_dim, self.input_dim] {
            return Err(TokenizerError::dim_mismatch(
                self.embed_dim * self.input_dim,
                weights.len(),
                "dimension validation",
            ));
        }
        self.decoder = weights;
        Ok(())
    }

    /// Get encoder weights
    pub fn encoder(&self) -> &Array2<f32> {
        &self.encoder
    }

    /// Get decoder weights
    pub fn decoder(&self) -> &Array2<f32> {
        &self.decoder
    }
}

impl SignalTokenizer for ContinuousTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if signal.len() != self.input_dim {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim,
                signal.len(),
                "dimension validation",
            ));
        }
        Ok(signal.dot(&self.encoder))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if tokens.len() != self.embed_dim {
            return Err(TokenizerError::dim_mismatch(
                self.embed_dim,
                tokens.len(),
                "dimension validation",
            ));
        }
        Ok(tokens.dot(&self.decoder))
    }

    fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    fn vocab_size(&self) -> usize {
        0 // Continuous = no discrete vocabulary
    }
}

/// Training configuration for trainable tokenizer
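///
/// Defaults can be overridden per field with struct-update syntax, e.g.
/// `TrainingConfig { num_epochs: 10, ..Default::default() }`.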
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Learning rate
    pub learning_rate: f64,
    /// Weight decay (L2 regularization)
    pub weight_decay: f64,
    /// Beta1 for AdamW
    pub beta1: f64,
    /// Beta2 for AdamW
    pub beta2: f64,
    /// Epsilon for AdamW
    pub eps: f64,
    /// Number of training epochs
    pub num_epochs: usize,
    /// Batch size
    pub batch_size: usize,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            weight_decay: 1e-4,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            num_epochs: 100,
            batch_size: 32,
        }
    }
}

/// Reconstruction loss metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReconstructionMetrics {
    /// Mean Squared Error
    pub mse: f32,
    /// Mean Absolute Error
    pub mae: f32,
    /// Signal-to-Noise Ratio (dB)
    pub snr_db: f32,
    /// Root Mean Squared Error
    pub rmse: f32,
}

impl ReconstructionMetrics {
    /// Compute metrics from original and reconstructed signals
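    ///
    /// The SNR treats the MSE as the noise power:
    /// `SNR_dB = 10 * log10(signal_power / mse)`, where `signal_power` is the
    /// mean of the squared original samples.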
    pub fn compute(original: &Array1<f32>, reconstructed: &Array1<f32>) -> Self {
        assert_eq!(
            original.len(),
            reconstructed.len(),
            "Signal lengths must match"
        );

        let n = original.len() as f32;

        // MSE
        let mse: f32 = original
            .iter()
            .zip(reconstructed.iter())
            .map(|(o, r)| (o - r).powi(2))
            .sum::<f32>()
            / n;

        // MAE
        let mae: f32 = original
            .iter()
            .zip(reconstructed.iter())
            .map(|(o, r)| (o - r).abs())
            .sum::<f32>()
            / n;

        // RMSE
        let rmse = mse.sqrt();

        // SNR (Signal-to-Noise Ratio)
        let signal_power: f32 = original.iter().map(|x| x.powi(2)).sum::<f32>() / n;
        let noise_power = mse;
        let snr_db = if noise_power > 0.0 {
            10.0 * (signal_power / noise_power).log10()
        } else {
            f32::INFINITY
        };

        Self {
            mse,
            mae,
            snr_db,
            rmse,
        }
    }

    /// Check if reconstruction quality is acceptable
    pub fn is_acceptable(&self, mse_threshold: f32, snr_threshold_db: f32) -> bool {
        self.mse < mse_threshold && self.snr_db > snr_threshold_db
    }
}

/// Trainable continuous tokenizer with gradient descent
pub struct TrainableContinuousTokenizer {
    /// Variable map for parameters
    varmap: VarMap,
    /// Encoder weights variable
    encoder_var: Var,
    /// Decoder weights variable
    decoder_var: Var,
    /// Input dimension
    input_dim: usize,
    /// Embedding dimension
    embed_dim: usize,
    /// Device (CPU or CUDA)
    device: Device,
}

impl TrainableContinuousTokenizer {
    /// Create a new trainable continuous tokenizer
    pub fn new(input_dim: usize, embed_dim: usize) -> CandleResult<Self> {
        let device = Device::Cpu;
        let varmap = VarMap::new();

        // Xavier initialization for encoder: zero-mean normal with
        // std = sqrt(2 / (fan_in + fan_out))
        let enc_scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let encoder_init = Tensor::randn(0f32, enc_scale, (input_dim, embed_dim), &device)?;
        let encoder_var = Var::from_tensor(&encoder_init)?;

        // Xavier initialization for decoder
        let dec_scale = (2.0 / (embed_dim + input_dim) as f32).sqrt();
        let decoder_init = Tensor::randn(0f32, dec_scale, (embed_dim, input_dim), &device)?;
        let decoder_var = Var::from_tensor(&decoder_init)?;

        // Add variables to varmap
        varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("encoder".to_string(), encoder_var.clone());
        varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("decoder".to_string(), decoder_var.clone());

        Ok(Self {
            varmap,
            encoder_var,
            decoder_var,
            input_dim,
            embed_dim,
            device,
        })
    }

    /// Encode a signal to embeddings
    fn forward_encode(&self, signal: &Tensor) -> CandleResult<Tensor> {
        signal.matmul(self.encoder_var.as_tensor())
    }

    /// Decode embeddings back to signal
    fn forward_decode(&self, embeddings: &Tensor) -> CandleResult<Tensor> {
        embeddings.matmul(self.decoder_var.as_tensor())
    }

    /// Full forward pass (encode then decode)
    fn forward(&self, signal: &Tensor) -> CandleResult<Tensor> {
        let embeddings = self.forward_encode(signal)?;
        self.forward_decode(&embeddings)
    }

    /// Compute reconstruction loss (MSE)
    fn compute_loss(&self, original: &Tensor, reconstructed: &Tensor) -> CandleResult<Tensor> {
        let diff = (original - reconstructed)?;
        let squared = diff.sqr()?;
        squared.mean_all()
    }

    /// Train on a batch of signals
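    ///
    /// Every signal in the batch must have length `input_dim`; otherwise the
    /// batch tensor construction fails with a shape error.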
    pub fn train_batch(
        &self,
        signals: &[Array1<f32>],
        optimizer: &mut AdamW,
    ) -> TokenizerResult<f32> {
        // Convert signals to tensor
        let batch_data: Vec<f32> = signals.iter().flat_map(|s| s.iter().copied()).collect();
        let batch_tensor =
            Tensor::from_slice(&batch_data, (signals.len(), self.input_dim), &self.device)
                .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Forward pass
        let reconstructed = self
            .forward(&batch_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Compute loss
        let loss = self
            .compute_loss(&batch_tensor, &reconstructed)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Backward pass
        optimizer
            .backward_step(&loss)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Return loss value
        let loss_val = loss
            .to_vec0::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        Ok(loss_val)
    }

    /// Train on a dataset
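    ///
    /// Returns the average loss per epoch. A minimal sketch (illustrative
    /// only; `signals` stands for any `Vec<Array1<f32>>` of length-8
    /// signals):
    ///
    /// ```ignore
    /// let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();
    /// let config = TrainingConfig { num_epochs: 50, ..Default::default() };
    /// let loss_history = tokenizer.train(&signals, &config).unwrap();
    /// assert_eq!(loss_history.len(), 50);
    /// ```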
    pub fn train(
        &self,
        training_data: &[Array1<f32>],
        config: &TrainingConfig,
    ) -> TokenizerResult<Vec<f32>> {
        // Guard against degenerate inputs: an empty dataset or a zero batch
        // size would otherwise divide by zero (NaN loss) or panic in step_by
        if training_data.is_empty() || config.batch_size == 0 {
            return Err(TokenizerError::InvalidConfig(
                "training data must be non-empty and batch_size must be > 0".to_string(),
            ));
        }

        // Create optimizer
        let params = ParamsAdamW {
            lr: config.learning_rate,
            weight_decay: config.weight_decay,
            beta1: config.beta1,
            beta2: config.beta2,
            eps: config.eps,
        };
        let mut optimizer = AdamW::new(self.varmap.all_vars(), params).map_err(|e| {
            TokenizerError::InternalError(format!("Failed to create optimizer: {}", e))
        })?;

        let mut loss_history = Vec::with_capacity(config.num_epochs);

        // Training loop
        for epoch in 0..config.num_epochs {
            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            // Process in batches
            for batch_start in (0..training_data.len()).step_by(config.batch_size) {
                let batch_end = (batch_start + config.batch_size).min(training_data.len());
                let batch = &training_data[batch_start..batch_end];

                let loss = self.train_batch(batch, &mut optimizer)?;
                epoch_loss += loss;
                num_batches += 1;
            }

            let avg_loss = epoch_loss / num_batches as f32;
            loss_history.push(avg_loss);

            // Log progress every 10 epochs
            if (epoch + 1) % 10 == 0 {
                tracing::debug!(
                    "Epoch {}/{}: Loss = {:.6}",
                    epoch + 1,
                    config.num_epochs,
                    avg_loss
                );
            }
        }

        Ok(loss_history)
    }

    /// Encode a signal (inference)
    pub fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if signal.len() != self.input_dim {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim,
                signal.len(),
                "dimension validation",
            ));
        }

        let signal_data: Vec<f32> = signal.iter().copied().collect();
        let signal_tensor = Tensor::from_slice(&signal_data, (1, self.input_dim), &self.device)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let embeddings = self
            .forward_encode(&signal_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let result_vec = embeddings
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        Ok(Array1::from_vec(result_vec[0].clone()))
    }

    /// Decode embeddings (inference)
    pub fn decode(&self, embeddings: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if embeddings.len() != self.embed_dim {
            return Err(TokenizerError::dim_mismatch(
                self.embed_dim,
                embeddings.len(),
                "dimension validation",
            ));
        }

        let emb_data: Vec<f32> = embeddings.iter().copied().collect();
        let emb_tensor = Tensor::from_slice(&emb_data, (1, self.embed_dim), &self.device)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let reconstructed = self
            .forward_decode(&emb_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let result_vec = reconstructed
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        Ok(Array1::from_vec(result_vec[0].clone()))
    }

    /// Get encoder weights as Array2
    pub fn get_encoder_weights(&self) -> TokenizerResult<Array2<f32>> {
        let tensor = self.encoder_var.as_tensor();
        let data = tensor
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let mut result = Array2::zeros((self.input_dim, self.embed_dim));
        for (i, row) in data.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                result[[i, j]] = val;
            }
        }

        Ok(result)
    }

    /// Get decoder weights as Array2
    pub fn get_decoder_weights(&self) -> TokenizerResult<Array2<f32>> {
        let tensor = self.decoder_var.as_tensor();
        let data = tensor
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let mut result = Array2::zeros((self.embed_dim, self.input_dim));
        for (i, row) in data.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                result[[i, j]] = val;
            }
        }

        Ok(result)
    }

    /// Evaluate reconstruction quality on test data
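    ///
    /// MSE and MAE are averaged across signals, while the SNR is computed
    /// from signal and noise power pooled over the whole test set, so it is
    /// not the mean of per-signal SNRs.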
    pub fn evaluate(&self, test_data: &[Array1<f32>]) -> TokenizerResult<ReconstructionMetrics> {
        // Guard against an empty test set, which would divide by zero below
        if test_data.is_empty() {
            return Err(TokenizerError::InvalidConfig(
                "test data must not be empty".to_string(),
            ));
        }

        let mut total_mse = 0.0;
        let mut total_mae = 0.0;
        let mut total_signal_power = 0.0;
        let mut total_noise_power = 0.0;
        let mut total_samples = 0;

        for signal in test_data {
            let embeddings = self.encode(signal)?;
            let reconstructed = self.decode(&embeddings)?;

            let metrics = ReconstructionMetrics::compute(signal, &reconstructed);
            total_mse += metrics.mse;
            total_mae += metrics.mae;

            let signal_power: f32 =
                signal.iter().map(|x| x.powi(2)).sum::<f32>() / signal.len() as f32;
            total_signal_power += signal_power;
            total_noise_power += metrics.mse;
            total_samples += 1;
        }

        let avg_mse = total_mse / total_samples as f32;
        let avg_mae = total_mae / total_samples as f32;
        let avg_rmse = avg_mse.sqrt();
        let avg_snr_db = if total_noise_power > 0.0 {
            10.0 * (total_signal_power / total_noise_power).log10()
        } else {
            f32::INFINITY
        };

        Ok(ReconstructionMetrics {
            mse: avg_mse,
            mae: avg_mae,
            snr_db: avg_snr_db,
            rmse: avg_rmse,
        })
    }

    /// Get embedding dimension
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Get input dimension
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Save model to checkpoint
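    ///
    /// A round-trip sketch (illustrative only; the file name is an
    /// assumption):
    ///
    /// ```ignore
    /// tokenizer.save_checkpoint("model.ckpt", "1.0.0", None, None).unwrap();
    /// let restored = TrainableContinuousTokenizer::load_checkpoint("model.ckpt").unwrap();
    /// assert_eq!(restored.input_dim(), tokenizer.input_dim());
    /// ```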
    pub fn save_checkpoint<P: AsRef<Path>>(
        &self,
        path: P,
        version: &str,
        training_config: Option<TrainingConfig>,
        metrics: Option<ReconstructionMetrics>,
    ) -> TokenizerResult<()> {
        let version = ModelVersion::parse(version)?;

        let mut metadata = ModelMetadata::new(
            version,
            "TrainableContinuousTokenizer".to_string(),
            self.input_dim,
            self.embed_dim,
        );

        metadata.training_config = training_config;
        metadata.metrics = metrics;

        let mut checkpoint = ModelCheckpoint::new(metadata);

        // Add encoder and decoder weights
        let encoder_weights = self.get_encoder_weights()?;
        let decoder_weights = self.get_decoder_weights()?;

        checkpoint.add_array2("encoder".to_string(), &encoder_weights);
        checkpoint.add_array2("decoder".to_string(), &decoder_weights);

        checkpoint.save(path)
    }

    /// Load model from checkpoint
    pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let checkpoint = ModelCheckpoint::load(path)?;

        // Verify model type
        if checkpoint.metadata.model_type != "TrainableContinuousTokenizer" {
            return Err(TokenizerError::InvalidConfig(format!(
                "Expected TrainableContinuousTokenizer, got {}",
                checkpoint.metadata.model_type
            )));
        }

        // Create new tokenizer with the right dimensions
        let input_dim = checkpoint.metadata.input_dim;
        let embed_dim = checkpoint.metadata.embed_dim;

        let mut tokenizer = Self::new(input_dim, embed_dim)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Load weights
        let encoder_weights = checkpoint.get_array2("encoder")?;
        let decoder_weights = checkpoint.get_array2("decoder")?;

        // Set the weights by creating new tensors
        let encoder_tensor = Tensor::from_slice(
            encoder_weights
                .as_slice()
                .expect("Encoder weights must have contiguous layout"),
            (input_dim, embed_dim),
            &tokenizer.device,
        )
        .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let decoder_tensor = Tensor::from_slice(
            decoder_weights
                .as_slice()
                .expect("Decoder weights must have contiguous layout"),
            (embed_dim, input_dim),
            &tokenizer.device,
        )
        .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Update the variables
        tokenizer.encoder_var = Var::from_tensor(&encoder_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;
        tokenizer.decoder_var = Var::from_tensor(&decoder_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        // Update varmap
        tokenizer
            .varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("encoder".to_string(), tokenizer.encoder_var.clone());
        tokenizer
            .varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("decoder".to_string(), tokenizer.decoder_var.clone());

        Ok(tokenizer)
    }

    /// Get metadata from a checkpoint file without loading the full model
    pub fn peek_checkpoint<P: AsRef<Path>>(path: P) -> TokenizerResult<ModelMetadata> {
        let checkpoint = ModelCheckpoint::load(path)?;
        Ok(checkpoint.metadata)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_continuous_tokenizer() {
        let tokenizer = ContinuousTokenizer::new(3, 64);

        let signal = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let encoded = tokenizer.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 64);

        let decoded = tokenizer.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 3);
    }

    #[test]
    fn test_dimension_mismatch() {
        let tokenizer = ContinuousTokenizer::new(3, 64);
        let signal = Array1::from_vec(vec![0.1, 0.2]); // Wrong size
        assert!(tokenizer.encode(&signal).is_err());
    }

    #[test]
    fn test_reconstruction_metrics() {
        let original = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let reconstructed = Array1::from_vec(vec![1.1, 1.9, 3.1, 3.9]);

        let metrics = ReconstructionMetrics::compute(&original, &reconstructed);

        assert!(metrics.mse > 0.0);
        assert!(metrics.mae > 0.0);
        assert!(metrics.rmse > 0.0);
        assert!(metrics.snr_db.is_finite());
        assert!(metrics.snr_db > 0.0); // Should be positive for small errors
    }

    #[test]
    fn test_reconstruction_metrics_perfect() {
        let original = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let reconstructed = original.clone();

        let metrics = ReconstructionMetrics::compute(&original, &reconstructed);

        assert_eq!(metrics.mse, 0.0);
        assert_eq!(metrics.mae, 0.0);
        assert_eq!(metrics.rmse, 0.0);
        assert!(metrics.snr_db.is_infinite());
    }

    #[test]
    fn test_metrics_is_acceptable() {
        let original = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let reconstructed = Array1::from_vec(vec![1.01, 2.01, 3.01, 4.01]);

        let metrics = ReconstructionMetrics::compute(&original, &reconstructed);

        assert!(metrics.is_acceptable(0.01, 10.0)); // Loose MSE and SNR thresholds
        assert!(!metrics.is_acceptable(0.0001, 100.0)); // Very strict thresholds
    }

    #[test]
    fn test_trainable_tokenizer_creation() {
        let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();

        assert_eq!(tokenizer.input_dim(), 8);
        assert_eq!(tokenizer.embed_dim(), 16);
    }

    #[test]
    fn test_trainable_encode_decode() {
        let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();

        let signal = Array1::from_vec((0..8).map(|i| i as f32 * 0.1).collect());
        let embeddings = tokenizer.encode(&signal).unwrap();
        let reconstructed = tokenizer.decode(&embeddings).unwrap();

        assert_eq!(embeddings.len(), 16);
        assert_eq!(reconstructed.len(), 8);
    }

    #[test]
    fn test_trainable_tokenizer_training() {
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        // Generate synthetic training data
        let training_data: Vec<Array1<f32>> = (0..50)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).sin()).collect()))
            .collect();

        let config = TrainingConfig {
            num_epochs: 10,
            batch_size: 8,
            learning_rate: 1e-3,
            ..Default::default()
        };

        let loss_history = tokenizer.train(&training_data, &config).unwrap();

        assert_eq!(loss_history.len(), 10);
        // Ten epochs is too few to guarantee a monotone decrease, so only
        // require that the loss does not blow up
        assert!(loss_history[loss_history.len() - 1] < loss_history[0] * 2.0);
    }

    #[test]
    fn test_trainable_tokenizer_evaluation() {
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        // Generate test data
        let test_data: Vec<Array1<f32>> = (0..10)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).cos()).collect()))
            .collect();

        let metrics = tokenizer.evaluate(&test_data).unwrap();

        assert!(metrics.mse >= 0.0);
        assert!(metrics.mae >= 0.0);
        assert!(metrics.rmse >= 0.0);
        assert!(!metrics.snr_db.is_nan());
    }

    #[test]
    fn test_trainable_get_weights() {
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let encoder_weights = tokenizer.get_encoder_weights().unwrap();
        let decoder_weights = tokenizer.get_decoder_weights().unwrap();

        assert_eq!(encoder_weights.shape(), &[4, 8]);
        assert_eq!(decoder_weights.shape(), &[8, 4]);
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.num_epochs, 100);
        assert_eq!(config.batch_size, 32);
    }

    #[test]
    fn test_trainable_convergence() {
        // Test that training actually improves reconstruction
        let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();

        // Create a simple dataset (sine waves with different frequencies)
        let training_data: Vec<Array1<f32>> = (0..100)
            .map(|i| {
                let freq = (i % 5 + 1) as f32 * 0.1;
                Array1::from_vec((0..8).map(|j| (j as f32 * freq).sin()).collect())
            })
            .collect();

        // Evaluate before training
        let metrics_before = tokenizer.evaluate(&training_data[..10]).unwrap();

        // Train
        let config = TrainingConfig {
            num_epochs: 20,
            batch_size: 16,
            learning_rate: 1e-2,
            ..Default::default()
        };
        tokenizer.train(&training_data, &config).unwrap();

        // Evaluate after training
        let metrics_after = tokenizer.evaluate(&training_data[..10]).unwrap();

        // MSE should decrease after training
        assert!(
            metrics_after.mse < metrics_before.mse,
            "MSE should decrease: before={}, after={}",
            metrics_before.mse,
            metrics_after.mse
        );

        // SNR should improve (increase) after training
        if metrics_before.snr_db.is_finite() {
            assert!(metrics_after.snr_db > metrics_before.snr_db);
        }
    }

    #[test]
    fn test_save_load_checkpoint() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_trainable_checkpoint.safetensors");

        // Create and train a tokenizer
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let training_data: Vec<Array1<f32>> = (0..20)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).sin()).collect()))
            .collect();

        let config = TrainingConfig {
            num_epochs: 5,
            batch_size: 4,
            learning_rate: 1e-3,
            ..Default::default()
        };

        tokenizer.train(&training_data, &config).unwrap();

        // Evaluate before saving
        let metrics_before = tokenizer.evaluate(&training_data[..5]).unwrap();

        // Save checkpoint
        tokenizer
            .save_checkpoint(
                &checkpoint_path,
                "1.0.0",
                Some(config.clone()),
                Some(metrics_before.clone()),
            )
            .unwrap();

        // Load checkpoint
        let loaded_tokenizer =
            TrainableContinuousTokenizer::load_checkpoint(&checkpoint_path).unwrap();

        // Verify dimensions
        assert_eq!(loaded_tokenizer.input_dim(), 4);
        assert_eq!(loaded_tokenizer.embed_dim(), 8);

        // Evaluate loaded model - should have similar performance
        let metrics_loaded = loaded_tokenizer.evaluate(&training_data[..5]).unwrap();

        // Metrics should be very close
        assert!(
            (metrics_loaded.mse - metrics_before.mse).abs() < 1e-4,
            "Loaded model MSE should match: before={}, loaded={}",
            metrics_before.mse,
            metrics_loaded.mse
        );

        // Test encoding/decoding with loaded model
        let test_signal = Array1::from_vec((0..4).map(|i| (i as f32) * 0.1).collect());
        let encoded_original = tokenizer.encode(&test_signal).unwrap();
        let encoded_loaded = loaded_tokenizer.encode(&test_signal).unwrap();

        // Encodings should be identical
        for (o, l) in encoded_original.iter().zip(encoded_loaded.iter()) {
            assert!(
                (o - l).abs() < 1e-4,
                "Encoded values should match: original={}, loaded={}",
                o,
                l
            );
        }

        // Cleanup
        std::fs::remove_file(&checkpoint_path).ok();
    }

    #[test]
    fn test_peek_checkpoint() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_peek_checkpoint.safetensors");

        let tokenizer = TrainableContinuousTokenizer::new(6, 12).unwrap();

        let config = TrainingConfig {
            num_epochs: 1,
            batch_size: 4,
            ..Default::default()
        };

        tokenizer
            .save_checkpoint(&checkpoint_path, "2.1.3", Some(config.clone()), None)
            .unwrap();

        // Peek at metadata without loading full model
        let metadata = TrainableContinuousTokenizer::peek_checkpoint(&checkpoint_path).unwrap();

        assert_eq!(metadata.model_type, "TrainableContinuousTokenizer");
        assert_eq!(metadata.input_dim, 6);
        assert_eq!(metadata.embed_dim, 12);
        assert_eq!(metadata.version.to_string(), "2.1.3");
        assert!(metadata.training_config.is_some());

        // Cleanup
        std::fs::remove_file(&checkpoint_path).ok();
    }

    #[test]
    fn test_checkpoint_version_compatibility() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_version_checkpoint.safetensors");

        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        tokenizer
            .save_checkpoint(&checkpoint_path, "1.0.0", None, None)
            .unwrap();

        // Load and check version
        let metadata = TrainableContinuousTokenizer::peek_checkpoint(&checkpoint_path).unwrap();

        let current_version = ModelVersion::new(1, 0, 0);
        assert!(metadata.version.is_compatible_with(&current_version));

        let incompatible_version = ModelVersion::new(2, 0, 0);
        assert!(!metadata.version.is_compatible_with(&incompatible_version));

        // Cleanup
        std::fs::remove_file(&checkpoint_path).ok();
    }

    #[test]
    fn test_save_checkpoint_with_metrics() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_metrics_checkpoint.safetensors");

        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let test_data: Vec<Array1<f32>> = (0..10)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).cos()).collect()))
            .collect();

        let metrics = tokenizer.evaluate(&test_data).unwrap();

        tokenizer
            .save_checkpoint(&checkpoint_path, "1.0.0", None, Some(metrics.clone()))
            .unwrap();

        // Load and verify metrics
        let checkpoint = crate::persistence::ModelCheckpoint::load(&checkpoint_path).unwrap();
        assert!(checkpoint.metadata.metrics.is_some());

        let loaded_metrics = checkpoint.metadata.metrics.unwrap();
        assert_eq!(loaded_metrics.mse, metrics.mse);
        assert_eq!(loaded_metrics.mae, metrics.mae);
        assert_eq!(loaded_metrics.rmse, metrics.rmse);

        // Cleanup
        std::fs::remove_file(&checkpoint_path).ok();
    }
}