use crate::error::{TokenizerError, TokenizerResult};
use crate::persistence::{ModelCheckpoint, ModelMetadata, ModelVersion};
use crate::SignalTokenizer;
use candle_core::{Device, Result as CandleResult, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarMap};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use std::path::Path;

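/// Continuous tokenizer that maps fixed-length signals to dense embeddings via a
/// pair of linear projections (encoder: `input_dim × embed_dim`, decoder:
/// `embed_dim × input_dim`) initialized with Xavier-style scaling.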
#[derive(Debug, Clone)]
pub struct ContinuousTokenizer {
    encoder: Array2<f32>,
    decoder: Array2<f32>,
    input_dim: usize,
    embed_dim: usize,
}

impl ContinuousTokenizer {
    pub fn new(input_dim: usize, embed_dim: usize) -> Self {
        let mut rng = thread_rng();

        let enc_scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let encoder = Array2::from_shape_fn((input_dim, embed_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * enc_scale
        });

        let dec_scale = (2.0 / (embed_dim + input_dim) as f32).sqrt();
        let decoder = Array2::from_shape_fn((embed_dim, input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * dec_scale
        });

        Self {
            encoder,
            decoder,
            input_dim,
            embed_dim,
        }
    }

    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    pub fn set_encoder(&mut self, weights: Array2<f32>) -> TokenizerResult<()> {
        if weights.shape() != [self.input_dim, self.embed_dim] {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim * self.embed_dim,
                weights.len(),
                "dimension validation",
            ));
        }
        self.encoder = weights;
        Ok(())
    }

    pub fn set_decoder(&mut self, weights: Array2<f32>) -> TokenizerResult<()> {
        if weights.shape() != [self.embed_dim, self.input_dim] {
            return Err(TokenizerError::dim_mismatch(
                self.embed_dim * self.input_dim,
                weights.len(),
                "dimension validation",
            ));
        }
        self.decoder = weights;
        Ok(())
    }

    pub fn encoder(&self) -> &Array2<f32> {
        &self.encoder
    }

    pub fn decoder(&self) -> &Array2<f32> {
        &self.decoder
    }
}

impl SignalTokenizer for ContinuousTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if signal.len() != self.input_dim {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim,
                signal.len(),
                "dimension validation",
            ));
        }
        Ok(signal.dot(&self.encoder))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if tokens.len() != self.embed_dim {
            return Err(TokenizerError::dim_mismatch(
                self.embed_dim,
                tokens.len(),
                "dimension validation",
            ));
        }
        Ok(tokens.dot(&self.decoder))
    }

    fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    fn vocab_size(&self) -> usize {
        0
    }
}

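/// Hyperparameters for AdamW-based autoencoder training.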
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    pub learning_rate: f64,
    pub weight_decay: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub num_epochs: usize,
    pub batch_size: usize,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            weight_decay: 1e-4,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            num_epochs: 100,
            batch_size: 32,
        }
    }
}

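/// Reconstruction-quality metrics (MSE, MAE, RMSE, and SNR in dB) for a signal
/// and its encode/decode round trip.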
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReconstructionMetrics {
    pub mse: f32,
    pub mae: f32,
    pub snr_db: f32,
    pub rmse: f32,
}

impl ReconstructionMetrics {
    pub fn compute(original: &Array1<f32>, reconstructed: &Array1<f32>) -> Self {
        assert_eq!(
            original.len(),
            reconstructed.len(),
            "Signal lengths must match"
        );

        let n = original.len() as f32;

        let mse: f32 = original
            .iter()
            .zip(reconstructed.iter())
            .map(|(o, r)| (o - r).powi(2))
            .sum::<f32>()
            / n;

        let mae: f32 = original
            .iter()
            .zip(reconstructed.iter())
            .map(|(o, r)| (o - r).abs())
            .sum::<f32>()
            / n;

        let rmse = mse.sqrt();

        let signal_power: f32 = original.iter().map(|x| x.powi(2)).sum::<f32>() / n;
        let noise_power = mse;
        let snr_db = if noise_power > 0.0 {
            10.0 * (signal_power / noise_power).log10()
        } else {
            f32::INFINITY
        };

        Self {
            mse,
            mae,
            snr_db,
            rmse,
        }
    }

    pub fn is_acceptable(&self, mse_threshold: f32, snr_threshold_db: f32) -> bool {
        self.mse < mse_threshold && self.snr_db > snr_threshold_db
    }
}

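/// Trainable continuous tokenizer: a linear autoencoder whose encoder/decoder
/// weights live in a candle `VarMap` and are optimized with AdamW on CPU.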
pub struct TrainableContinuousTokenizer {
    varmap: VarMap,
    encoder_var: Var,
    decoder_var: Var,
    input_dim: usize,
    embed_dim: usize,
    device: Device,
}

impl TrainableContinuousTokenizer {
    pub fn new(input_dim: usize, embed_dim: usize) -> CandleResult<Self> {
        let device = Device::Cpu;
        let varmap = VarMap::new();

        // Xavier-style init: scale N(0, 1) samples by sqrt(2 / (fan_in + fan_out)).
        // `Tensor::affine(mul, add)` computes `x * mul + add`, so the scale must be
        // the first argument.
        let enc_scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let encoder_init = Tensor::randn(0f32, 1.0, (input_dim, embed_dim), &device)?
            .affine(enc_scale as f64, 0.0)?;
        let encoder_var = Var::from_tensor(&encoder_init)?;

        let dec_scale = (2.0 / (embed_dim + input_dim) as f32).sqrt();
        let decoder_init = Tensor::randn(0f32, 1.0, (embed_dim, input_dim), &device)?
            .affine(dec_scale as f64, 0.0)?;
        let decoder_var = Var::from_tensor(&decoder_init)?;

        varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("encoder".to_string(), encoder_var.clone());
        varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("decoder".to_string(), decoder_var.clone());

        Ok(Self {
            varmap,
            encoder_var,
            decoder_var,
            input_dim,
            embed_dim,
            device,
        })
    }

    fn forward_encode(&self, signal: &Tensor) -> CandleResult<Tensor> {
        signal.matmul(self.encoder_var.as_tensor())
    }

    fn forward_decode(&self, embeddings: &Tensor) -> CandleResult<Tensor> {
        embeddings.matmul(self.decoder_var.as_tensor())
    }

    fn forward(&self, signal: &Tensor) -> CandleResult<Tensor> {
        let embeddings = self.forward_encode(signal)?;
        self.forward_decode(&embeddings)
    }

    fn compute_loss(&self, original: &Tensor, reconstructed: &Tensor) -> CandleResult<Tensor> {
        let diff = (original - reconstructed)?;
        let squared = diff.sqr()?;
        squared.mean_all()
    }

    pub fn train_batch(
        &self,
        signals: &[Array1<f32>],
        optimizer: &mut AdamW,
    ) -> TokenizerResult<f32> {
        let batch_data: Vec<f32> = signals.iter().flat_map(|s| s.iter().copied()).collect();
        let batch_tensor =
            Tensor::from_slice(&batch_data, (signals.len(), self.input_dim), &self.device)
                .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let reconstructed = self
            .forward(&batch_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let loss = self
            .compute_loss(&batch_tensor, &reconstructed)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        optimizer
            .backward_step(&loss)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let loss_val = loss
            .to_vec0::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        Ok(loss_val)
    }

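    /// Trains the autoencoder with AdamW over mini-batches and returns the
    /// average reconstruction loss per epoch.
    ///
    /// Minimal usage sketch (illustrative only; marked `ignore` because the
    /// crate's public re-export path is not shown here):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let tokenizer = TrainableContinuousTokenizer::new(8, 16)?;
    /// // Toy training set: 64 sinusoid snippets of length 8.
    /// let data: Vec<Array1<f32>> = (0..64)
    ///     .map(|i| Array1::from_vec((0..8).map(|j| ((i + j) as f32 * 0.1).sin()).collect()))
    ///     .collect();
    /// let config = TrainingConfig { num_epochs: 20, ..Default::default() };
    /// let losses = tokenizer.train(&data, &config)?;
    /// assert_eq!(losses.len(), config.num_epochs);
    /// ```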
    pub fn train(
        &self,
        training_data: &[Array1<f32>],
        config: &TrainingConfig,
    ) -> TokenizerResult<Vec<f32>> {
        let params = ParamsAdamW {
            lr: config.learning_rate,
            weight_decay: config.weight_decay,
            beta1: config.beta1,
            beta2: config.beta2,
            eps: config.eps,
        };
        let mut optimizer = AdamW::new(self.varmap.all_vars(), params).map_err(|e| {
            TokenizerError::InternalError(format!("Failed to create optimizer: {}", e))
        })?;

        let mut loss_history = Vec::with_capacity(config.num_epochs);

        for epoch in 0..config.num_epochs {
            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            for batch_start in (0..training_data.len()).step_by(config.batch_size) {
                let batch_end = (batch_start + config.batch_size).min(training_data.len());
                let batch = &training_data[batch_start..batch_end];

                let loss = self.train_batch(batch, &mut optimizer)?;
                epoch_loss += loss;
                num_batches += 1;
            }

            let avg_loss = epoch_loss / num_batches as f32;
            loss_history.push(avg_loss);

            if (epoch + 1) % 10 == 0 {
                tracing::debug!(
                    "Epoch {}/{}: Loss = {:.6}",
                    epoch + 1,
                    config.num_epochs,
                    avg_loss
                );
            }
        }

        Ok(loss_history)
    }

    pub fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if signal.len() != self.input_dim {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim,
                signal.len(),
                "dimension validation",
            ));
        }

        let signal_data: Vec<f32> = signal.iter().copied().collect();
        let signal_tensor = Tensor::from_slice(&signal_data, (1, self.input_dim), &self.device)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let embeddings = self
            .forward_encode(&signal_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let result_vec = embeddings
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        Ok(Array1::from_vec(result_vec[0].clone()))
    }

    pub fn decode(&self, embeddings: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if embeddings.len() != self.embed_dim {
            return Err(TokenizerError::dim_mismatch(
                self.embed_dim,
                embeddings.len(),
                "dimension validation",
            ));
        }

        let emb_data: Vec<f32> = embeddings.iter().copied().collect();
        let emb_tensor = Tensor::from_slice(&emb_data, (1, self.embed_dim), &self.device)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let reconstructed = self
            .forward_decode(&emb_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let result_vec = reconstructed
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        Ok(Array1::from_vec(result_vec[0].clone()))
    }

    pub fn get_encoder_weights(&self) -> TokenizerResult<Array2<f32>> {
        let tensor = self.encoder_var.as_tensor();
        let data = tensor
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let mut result = Array2::zeros((self.input_dim, self.embed_dim));
        for (i, row) in data.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                result[[i, j]] = val;
            }
        }

        Ok(result)
    }

    pub fn get_decoder_weights(&self) -> TokenizerResult<Array2<f32>> {
        let tensor = self.decoder_var.as_tensor();
        let data = tensor
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let mut result = Array2::zeros((self.embed_dim, self.input_dim));
        for (i, row) in data.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                result[[i, j]] = val;
            }
        }

        Ok(result)
    }

    pub fn evaluate(&self, test_data: &[Array1<f32>]) -> TokenizerResult<ReconstructionMetrics> {
        let mut total_mse = 0.0;
        let mut total_mae = 0.0;
        let mut total_signal_power = 0.0;
        let mut total_noise_power = 0.0;
        let mut total_samples = 0;

        for signal in test_data {
            let embeddings = self.encode(signal)?;
            let reconstructed = self.decode(&embeddings)?;

            let metrics = ReconstructionMetrics::compute(signal, &reconstructed);
            total_mse += metrics.mse;
            total_mae += metrics.mae;

            let signal_power: f32 =
                signal.iter().map(|x| x.powi(2)).sum::<f32>() / signal.len() as f32;
            total_signal_power += signal_power;
            total_noise_power += metrics.mse;
            total_samples += 1;
        }

        let avg_mse = total_mse / total_samples as f32;
        let avg_mae = total_mae / total_samples as f32;
        let avg_rmse = avg_mse.sqrt();
        let avg_snr_db = if total_noise_power > 0.0 {
            10.0 * (total_signal_power / total_noise_power).log10()
        } else {
            f32::INFINITY
        };

        Ok(ReconstructionMetrics {
            mse: avg_mse,
            mae: avg_mae,
            snr_db: avg_snr_db,
            rmse: avg_rmse,
        })
    }

    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

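    /// Persists the trained weights plus metadata (version, dims, and optional
    /// training config and metrics) as a checkpoint on disk.
    ///
    /// Round-trip sketch (illustrative only; the file name is a placeholder and
    /// the snippet is marked `ignore` because the crate's re-export path is not
    /// shown here):
    ///
    /// ```ignore
    /// let tokenizer = TrainableContinuousTokenizer::new(4, 8)?;
    /// tokenizer.save_checkpoint("tokenizer.safetensors", "1.0.0", None, None)?;
    ///
    /// // Inspect metadata without restoring weights, then load the full model.
    /// let meta = TrainableContinuousTokenizer::peek_checkpoint("tokenizer.safetensors")?;
    /// assert_eq!(meta.input_dim, 4);
    /// let restored = TrainableContinuousTokenizer::load_checkpoint("tokenizer.safetensors")?;
    /// ```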
    pub fn save_checkpoint<P: AsRef<Path>>(
        &self,
        path: P,
        version: &str,
        training_config: Option<TrainingConfig>,
        metrics: Option<ReconstructionMetrics>,
    ) -> TokenizerResult<()> {
        let version = ModelVersion::parse(version)?;

        let mut metadata = ModelMetadata::new(
            version,
            "TrainableContinuousTokenizer".to_string(),
            self.input_dim,
            self.embed_dim,
        );

        metadata.training_config = training_config;
        metadata.metrics = metrics;

        let mut checkpoint = ModelCheckpoint::new(metadata);

        let encoder_weights = self.get_encoder_weights()?;
        let decoder_weights = self.get_decoder_weights()?;

        checkpoint.add_array2("encoder".to_string(), &encoder_weights);
        checkpoint.add_array2("decoder".to_string(), &decoder_weights);

        checkpoint.save(path)
    }

    pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let checkpoint = ModelCheckpoint::load(path)?;

        if checkpoint.metadata.model_type != "TrainableContinuousTokenizer" {
            return Err(TokenizerError::InvalidConfig(format!(
                "Expected TrainableContinuousTokenizer, got {}",
                checkpoint.metadata.model_type
            )));
        }

        let input_dim = checkpoint.metadata.input_dim;
        let embed_dim = checkpoint.metadata.embed_dim;

        let mut tokenizer = Self::new(input_dim, embed_dim)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let encoder_weights = checkpoint.get_array2("encoder")?;
        let decoder_weights = checkpoint.get_array2("decoder")?;

        let encoder_tensor = Tensor::from_slice(
            encoder_weights
                .as_slice()
                .expect("Encoder weights must have contiguous layout"),
            (input_dim, embed_dim),
            &tokenizer.device,
        )
        .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let decoder_tensor = Tensor::from_slice(
            decoder_weights
                .as_slice()
                .expect("Decoder weights must have contiguous layout"),
            (embed_dim, input_dim),
            &tokenizer.device,
        )
        .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        tokenizer.encoder_var = Var::from_tensor(&encoder_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;
        tokenizer.decoder_var = Var::from_tensor(&decoder_tensor)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        tokenizer
            .varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("encoder".to_string(), tokenizer.encoder_var.clone());
        tokenizer
            .varmap
            .data()
            .lock()
            .expect("VarMap lock should not be poisoned")
            .insert("decoder".to_string(), tokenizer.decoder_var.clone());

        Ok(tokenizer)
    }

    pub fn peek_checkpoint<P: AsRef<Path>>(path: P) -> TokenizerResult<ModelMetadata> {
        let checkpoint = ModelCheckpoint::load(path)?;
        Ok(checkpoint.metadata)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_continuous_tokenizer() {
        let tokenizer = ContinuousTokenizer::new(3, 64);

        let signal = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let encoded = tokenizer.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 64);

        let decoded = tokenizer.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 3);
    }

    #[test]
    fn test_dimension_mismatch() {
        let tokenizer = ContinuousTokenizer::new(3, 64);
        let signal = Array1::from_vec(vec![0.1, 0.2]);
        assert!(tokenizer.encode(&signal).is_err());
    }

    #[test]
    fn test_reconstruction_metrics() {
        let original = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let reconstructed = Array1::from_vec(vec![1.1, 1.9, 3.1, 3.9]);

        let metrics = ReconstructionMetrics::compute(&original, &reconstructed);

        assert!(metrics.mse > 0.0);
        assert!(metrics.mae > 0.0);
        assert!(metrics.rmse > 0.0);
        assert!(metrics.snr_db.is_finite());
        assert!(metrics.snr_db > 0.0);
    }

    #[test]
    fn test_reconstruction_metrics_perfect() {
        let original = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let reconstructed = original.clone();

        let metrics = ReconstructionMetrics::compute(&original, &reconstructed);

        assert_eq!(metrics.mse, 0.0);
        assert_eq!(metrics.mae, 0.0);
        assert_eq!(metrics.rmse, 0.0);
        assert!(metrics.snr_db.is_infinite());
    }

    #[test]
    fn test_metrics_is_acceptable() {
        let original = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let reconstructed = Array1::from_vec(vec![1.01, 2.01, 3.01, 4.01]);

        let metrics = ReconstructionMetrics::compute(&original, &reconstructed);

        assert!(metrics.is_acceptable(0.01, 10.0));
        assert!(!metrics.is_acceptable(0.0001, 100.0));
    }

    #[test]
    fn test_trainable_tokenizer_creation() {
        let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();

        assert_eq!(tokenizer.input_dim(), 8);
        assert_eq!(tokenizer.embed_dim(), 16);
    }

    #[test]
    fn test_trainable_encode_decode() {
        let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();

        let signal = Array1::from_vec((0..8).map(|i| i as f32 * 0.1).collect());
        let embeddings = tokenizer.encode(&signal).unwrap();
        let reconstructed = tokenizer.decode(&embeddings).unwrap();

        assert_eq!(embeddings.len(), 16);
        assert_eq!(reconstructed.len(), 8);
    }

    #[test]
    fn test_trainable_tokenizer_training() {
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let training_data: Vec<Array1<f32>> = (0..50)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).sin()).collect()))
            .collect();

        let config = TrainingConfig {
            num_epochs: 10,
            batch_size: 8,
            learning_rate: 1e-3,
            ..Default::default()
        };

        let loss_history = tokenizer.train(&training_data, &config).unwrap();

        assert_eq!(loss_history.len(), 10);
        assert!(loss_history[loss_history.len() - 1] < loss_history[0] * 2.0);
    }

    #[test]
    fn test_trainable_tokenizer_evaluation() {
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let test_data: Vec<Array1<f32>> = (0..10)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).cos()).collect()))
            .collect();

        let metrics = tokenizer.evaluate(&test_data).unwrap();

        assert!(metrics.mse >= 0.0);
        assert!(metrics.mae >= 0.0);
        assert!(metrics.rmse >= 0.0);
        assert!(!metrics.snr_db.is_nan());
    }

    #[test]
    fn test_trainable_get_weights() {
        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let encoder_weights = tokenizer.get_encoder_weights().unwrap();
        let decoder_weights = tokenizer.get_decoder_weights().unwrap();

        assert_eq!(encoder_weights.shape(), &[4, 8]);
        assert_eq!(decoder_weights.shape(), &[8, 4]);
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.num_epochs, 100);
        assert_eq!(config.batch_size, 32);
    }

    #[test]
    fn test_trainable_convergence() {
        let tokenizer = TrainableContinuousTokenizer::new(8, 16).unwrap();

        let training_data: Vec<Array1<f32>> = (0..100)
            .map(|i| {
                let freq = (i % 5 + 1) as f32 * 0.1;
                Array1::from_vec((0..8).map(|j| (j as f32 * freq).sin()).collect())
            })
            .collect();

        let metrics_before = tokenizer.evaluate(&training_data[..10]).unwrap();

        let config = TrainingConfig {
            num_epochs: 20,
            batch_size: 16,
            learning_rate: 1e-2,
            ..Default::default()
        };
        tokenizer.train(&training_data, &config).unwrap();

        let metrics_after = tokenizer.evaluate(&training_data[..10]).unwrap();

        assert!(
            metrics_after.mse < metrics_before.mse,
            "MSE should decrease: before={}, after={}",
            metrics_before.mse,
            metrics_after.mse
        );

        if metrics_before.snr_db.is_finite() {
            assert!(metrics_after.snr_db > metrics_before.snr_db);
        }
    }

    #[test]
    fn test_save_load_checkpoint() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_trainable_checkpoint.safetensors");

        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        let training_data: Vec<Array1<f32>> = (0..20)
            .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).sin()).collect()))
            .collect();

        let config = TrainingConfig {
            num_epochs: 5,
            batch_size: 4,
            learning_rate: 1e-3,
            ..Default::default()
        };

        tokenizer.train(&training_data, &config).unwrap();

        let metrics_before = tokenizer.evaluate(&training_data[..5]).unwrap();

        tokenizer
            .save_checkpoint(
                &checkpoint_path,
                "1.0.0",
                Some(config.clone()),
                Some(metrics_before.clone()),
            )
            .unwrap();

        let loaded_tokenizer =
            TrainableContinuousTokenizer::load_checkpoint(&checkpoint_path).unwrap();

        assert_eq!(loaded_tokenizer.input_dim(), 4);
        assert_eq!(loaded_tokenizer.embed_dim(), 8);

        let metrics_loaded = loaded_tokenizer.evaluate(&training_data[..5]).unwrap();

        assert!(
            (metrics_loaded.mse - metrics_before.mse).abs() < 1e-4,
            "Loaded model MSE should match: before={}, loaded={}",
            metrics_before.mse,
            metrics_loaded.mse
        );

        let test_signal = Array1::from_vec((0..4).map(|i| (i as f32) * 0.1).collect());
        let encoded_original = tokenizer.encode(&test_signal).unwrap();
        let encoded_loaded = loaded_tokenizer.encode(&test_signal).unwrap();

        for (o, l) in encoded_original.iter().zip(encoded_loaded.iter()) {
            assert!(
                (o - l).abs() < 1e-4,
                "Encoded values should match: original={}, loaded={}",
                o,
                l
            );
        }

        std::fs::remove_file(&checkpoint_path).ok();
    }

    #[test]
    fn test_peek_checkpoint() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_peek_checkpoint.safetensors");

        let tokenizer = TrainableContinuousTokenizer::new(6, 12).unwrap();

        let config = TrainingConfig {
            num_epochs: 1,
            batch_size: 4,
            ..Default::default()
        };

        tokenizer
            .save_checkpoint(&checkpoint_path, "2.1.3", Some(config.clone()), None)
            .unwrap();

        let metadata = TrainableContinuousTokenizer::peek_checkpoint(&checkpoint_path).unwrap();

        assert_eq!(metadata.model_type, "TrainableContinuousTokenizer");
        assert_eq!(metadata.input_dim, 6);
        assert_eq!(metadata.embed_dim, 12);
        assert_eq!(metadata.version.to_string(), "2.1.3");
        assert!(metadata.training_config.is_some());

        std::fs::remove_file(&checkpoint_path).ok();
    }

    #[test]
    fn test_checkpoint_version_compatibility() {
        use std::env;

        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_version_checkpoint.safetensors");

        let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();

        tokenizer
            .save_checkpoint(&checkpoint_path, "1.0.0", None, None)
            .unwrap();

        let metadata = TrainableContinuousTokenizer::peek_checkpoint(&checkpoint_path).unwrap();

        let current_version = ModelVersion::new(1, 0, 0);
        assert!(metadata.version.is_compatible_with(&current_version));

        let incompatible_version = ModelVersion::new(2, 0, 0);
        assert!(!metadata.version.is_compatible_with(&incompatible_version));

        std::fs::remove_file(&checkpoint_path).ok();
    }
945
946 #[test]
947 fn test_save_checkpoint_with_metrics() {
948 use std::env;
949
950 let temp_dir = env::temp_dir();
951 let checkpoint_path = temp_dir.join("test_metrics_checkpoint.safetensors");
952
953 let tokenizer = TrainableContinuousTokenizer::new(4, 8).unwrap();
954
955 let test_data: Vec<Array1<f32>> = (0..10)
956 .map(|i| Array1::from_vec((0..4).map(|j| ((i + j) as f32 * 0.1).cos()).collect()))
957 .collect();
958
959 let metrics = tokenizer.evaluate(&test_data).unwrap();
960
961 tokenizer
962 .save_checkpoint(&checkpoint_path, "1.0.0", None, Some(metrics.clone()))
963 .unwrap();
964
965 let checkpoint = crate::persistence::ModelCheckpoint::load(&checkpoint_path).unwrap();
967 assert!(checkpoint.metadata.metrics.is_some());
968
969 let loaded_metrics = checkpoint.metadata.metrics.unwrap();
970 assert_eq!(loaded_metrics.mse, metrics.mse);
971 assert_eq!(loaded_metrics.mae, metrics.mae);
972 assert_eq!(loaded_metrics.rmse, metrics.rmse);
973
974 std::fs::remove_file(&checkpoint_path).ok();
976 }
977}