use crate::{SignalTokenizer, TokenizerError, TokenizerResult};
use scirs2_core::ndarray::{s, Array1, Array2};
use serde::{Deserialize, Serialize};
use std::f32::consts::PI;

/// Supported wavelet families.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum WaveletFamily {
    /// Haar wavelet (two-tap filters).
    Haar,
    /// Daubechies-4 wavelet (four-tap filters).
    Daubechies4,
}

/// Configuration for [`WaveletTokenizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaveletConfig {
    /// Number of decomposition levels (must be > 0).
    pub levels: usize,
    /// Wavelet family used for the filter bank.
    pub family: WaveletFamily,
    /// Quantization bit depth, in `[1, 16]`.
    pub bits: u8,
}

impl Default for WaveletConfig {
    fn default() -> Self {
        Self {
            levels: 3,
            family: WaveletFamily::Haar,
            bits: 8,
        }
    }
}

/// Tokenizer based on the discrete wavelet transform (DWT).
pub struct WaveletTokenizer {
    config: WaveletConfig,
    lowpass: Vec<f32>,
    highpass: Vec<f32>,
}

impl WaveletTokenizer {
    /// Creates a new wavelet tokenizer, validating the configuration.
    pub fn new(config: WaveletConfig) -> TokenizerResult<Self> {
        if config.levels == 0 {
            return Err(TokenizerError::InvalidConfig(
                "Wavelet levels must be > 0".to_string(),
            ));
        }
        if config.bits == 0 || config.bits > 16 {
            return Err(TokenizerError::InvalidConfig(
                "Bits must be in range [1, 16]".to_string(),
            ));
        }

        let (lowpass, highpass) = match config.family {
            WaveletFamily::Haar => {
                let sqrt2_inv = 1.0 / 2.0_f32.sqrt();
                (vec![sqrt2_inv, sqrt2_inv], vec![sqrt2_inv, -sqrt2_inv])
            }
            WaveletFamily::Daubechies4 => {
                let sqrt2 = 2.0_f32.sqrt();
                let sqrt3 = 3.0_f32.sqrt();
                let h0 = (1.0 + sqrt3) / (4.0 * sqrt2);
                let h1 = (3.0 + sqrt3) / (4.0 * sqrt2);
                let h2 = (3.0 - sqrt3) / (4.0 * sqrt2);
                let h3 = (1.0 - sqrt3) / (4.0 * sqrt2);
                // Quadrature mirror filter: reverse the lowpass taps and
                // alternate the signs.
                (vec![h0, h1, h2, h3], vec![h3, -h2, h1, -h0])
            }
        };

        Ok(Self {
            config,
            lowpass,
            highpass,
        })
    }

    /// One analysis step: convolve with the filter pair (periodic extension)
    /// and downsample by two.
    fn dwt_step(&self, signal: &[f32]) -> (Vec<f32>, Vec<f32>) {
        let n = signal.len();
        let mut approx = Vec::with_capacity(n / 2);
        let mut detail = Vec::with_capacity(n / 2);

        for i in (0..n).step_by(2) {
            let mut low_sum = 0.0;
            let mut high_sum = 0.0;

            for (j, (&l, &h)) in self.lowpass.iter().zip(self.highpass.iter()).enumerate() {
                let idx = (i + j) % n;
                low_sum += signal[idx] * l;
                high_sum += signal[idx] * h;
            }

            approx.push(low_sum);
            detail.push(high_sum);
        }

        (approx, detail)
    }

    /// One synthesis step: upsample by two and convolve with the filter pair.
    fn idwt_step(&self, approx: &[f32], detail: &[f32]) -> Vec<f32> {
        let n = approx.len() * 2;
        let mut signal = vec![0.0; n];

        for i in 0..approx.len() {
            for (j, (&l, &h)) in self.lowpass.iter().zip(self.highpass.iter()).enumerate() {
                let idx = (2 * i + j) % n;
                signal[idx] += approx[i] * l + detail[i] * h;
            }
        }

        signal
    }

    /// Multi-level decomposition. Returns bands ordered coarsest-first:
    /// `[approx, detail_L, ..., detail_1]`.
    fn decompose(&self, signal: &Array1<f32>) -> Vec<Vec<f32>> {
        let mut coeffs = Vec::new();
        let mut current = signal.to_vec();

        for _ in 0..self.config.levels {
            let (approx, detail) = self.dwt_step(&current);
            coeffs.push(detail);
            current = approx;
        }

        coeffs.push(current);
        coeffs.reverse();
        coeffs
    }

    /// Inverse of [`Self::decompose`]: fold the detail bands back in,
    /// coarsest level first.
    fn reconstruct(&self, coeffs: &[Vec<f32>]) -> Vec<f32> {
        let mut current = coeffs[0].clone();

        for detail in coeffs.iter().skip(1) {
            current = self.idwt_step(&current, detail);
        }

        current
    }

    /// Uniform symmetric quantization of all bands, normalized by the
    /// global peak magnitude.
    fn quantize_coeffs(&self, coeffs: &[Vec<f32>]) -> Vec<Vec<i32>> {
        let levels = (1 << self.config.bits) as f32;
        let max_val = coeffs
            .iter()
            .flat_map(|c| c.iter())
            .map(|&x| x.abs())
            .fold(0.0_f32, f32::max);

        if max_val == 0.0 {
            return coeffs.iter().map(|c| vec![0; c.len()]).collect();
        }

        coeffs
            .iter()
            .map(|band| {
                band.iter()
                    .map(|&x| {
                        let normalized = x / max_val;
                        let quantized = (normalized * (levels / 2.0)).round();
                        quantized.clamp(-(levels / 2.0), levels / 2.0 - 1.0) as i32
                    })
                    .collect()
            })
            .collect()
    }

    fn dequantize_coeffs(&self, quantized: &[Vec<i32>], max_val: f32) -> Vec<Vec<f32>> {
        let levels = (1 << self.config.bits) as f32;

        quantized
            .iter()
            .map(|band| {
                band.iter()
                    .map(|&q| (q as f32 / (levels / 2.0)) * max_val)
                    .collect()
            })
            .collect()
    }
}

impl SignalTokenizer for WaveletTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let coeffs = self.decompose(signal);
        let quantized = self.quantize_coeffs(&coeffs);

        let tokens: Vec<f32> = quantized
            .iter()
            .flat_map(|band| band.iter().map(|&q| q as f32))
            .collect();

        Ok(Array1::from_vec(tokens))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        // The peak magnitude is not carried in the token stream, so the
        // signal is reconstructed up to an unknown scale factor.
        let max_val = 1.0;
        let quantized: Vec<i32> = tokens.iter().map(|&t| t as i32).collect();

        // Recover the band sizes: halve the remaining length once per level,
        // then reverse so the approximation band comes first.
        let mut band_sizes = Vec::new();
        let total_len = quantized.len();
        let mut remaining = total_len;
        for _ in 0..self.config.levels {
            let size = remaining / 2;
            band_sizes.push(size);
            remaining -= size;
        }
        band_sizes.push(remaining);
        band_sizes.reverse();

        let mut offset = 0;
        let mut bands = Vec::new();
        for &size in &band_sizes {
            bands.push(quantized[offset..offset + size].to_vec());
            offset += size;
        }

        let dequantized = self.dequantize_coeffs(&bands, max_val);
        let reconstructed = self.reconstruct(&dequantized);

        Ok(Array1::from_vec(reconstructed))
    }

    fn embed_dim(&self) -> usize {
        // The token count depends on the input length, so there is no fixed
        // embedding dimension.
        0
    }

    fn vocab_size(&self) -> usize {
        1 << self.config.bits
    }
}

/// Configuration for [`FourierTokenizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FourierConfig {
    /// Number of frequency bins to keep.
    pub num_bins: usize,
    /// If `true`, store only magnitudes (phase is discarded).
    pub magnitude_only: bool,
    /// Quantization bit depth.
    pub bits: u8,
}

impl Default for FourierConfig {
    fn default() -> Self {
        Self {
            num_bins: 256,
            magnitude_only: false,
            bits: 8,
        }
    }
}

/// Tokenizer based on the discrete Fourier transform.
pub struct FourierTokenizer {
    config: FourierConfig,
}

impl FourierTokenizer {
    /// Creates a new Fourier tokenizer, validating the configuration.
    pub fn new(config: FourierConfig) -> TokenizerResult<Self> {
        if config.num_bins == 0 {
            return Err(TokenizerError::InvalidConfig(
                "Number of bins must be > 0".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Naive O(n^2) DFT; adequate for short windows.
    fn fft(&self, signal: &[f32]) -> Vec<(f32, f32)> {
        let n = signal.len();
        let mut spectrum = Vec::with_capacity(n);

        for k in 0..n {
            let mut real_sum = 0.0;
            let mut imag_sum = 0.0;

            for (i, &x) in signal.iter().enumerate() {
                let angle = -2.0 * PI * (k as f32) * (i as f32) / (n as f32);
                real_sum += x * angle.cos();
                imag_sum += x * angle.sin();
            }

            spectrum.push((real_sum, imag_sum));
        }

        spectrum
    }

    /// Inverse DFT; returns the real part of the reconstruction.
    fn ifft(&self, spectrum: &[(f32, f32)]) -> Vec<f32> {
        let n = spectrum.len();
        let mut signal = Vec::with_capacity(n);

        for i in 0..n {
            let mut sum = 0.0;

            for (k, &(real, imag)) in spectrum.iter().enumerate() {
                let angle = 2.0 * PI * (k as f32) * (i as f32) / (n as f32);
                sum += real * angle.cos() - imag * angle.sin();
            }

            signal.push(sum / (n as f32));
        }

        signal
    }
}

impl SignalTokenizer for FourierTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let spectrum = self.fft(
            signal
                .as_slice()
                .expect("Signal must have contiguous layout"),
        );

        let tokens: Vec<f32> = spectrum
            .iter()
            .take(self.config.num_bins)
            .flat_map(|&(real, imag)| {
                if self.config.magnitude_only {
                    vec![(real * real + imag * imag).sqrt()]
                } else {
                    vec![real, imag]
                }
            })
            .collect();

        Ok(Array1::from_vec(tokens))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let spectrum: Vec<(f32, f32)> = if self.config.magnitude_only {
            // Phase was discarded at encode time; assume zero phase.
            tokens.iter().map(|&mag| (mag, 0.0)).collect()
        } else {
            let mut result = Vec::new();
            let tokens_slice = tokens
                .as_slice()
                .expect("Tokens must have contiguous layout");
            for i in (0..tokens_slice.len()).step_by(2) {
                let real = tokens_slice[i];
                let imag = tokens_slice.get(i + 1).copied().unwrap_or(0.0);
                result.push((real, imag));
            }
            result
        };

        let reconstructed = self.ifft(&spectrum);
        Ok(Array1::from_vec(reconstructed))
    }

    fn embed_dim(&self) -> usize {
        if self.config.magnitude_only {
            self.config.num_bins
        } else {
            self.config.num_bins * 2
        }
    }

    fn vocab_size(&self) -> usize {
        // Spectral tokens are continuous values; there is no discrete
        // vocabulary.
        0
    }
}

/// Configuration for [`DCTTokenizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DCTConfig {
    /// Number of DCT coefficients to keep.
    pub num_coeffs: usize,
    /// Quantization bit depth.
    pub bits: u8,
}

impl Default for DCTConfig {
    fn default() -> Self {
        Self {
            num_coeffs: 64,
            bits: 8,
        }
    }
}

/// Tokenizer based on the discrete cosine transform (DCT-II).
pub struct DCTTokenizer {
    config: DCTConfig,
}

impl DCTTokenizer {
    /// Creates a new DCT tokenizer, validating the configuration.
    pub fn new(config: DCTConfig) -> TokenizerResult<Self> {
        if config.num_coeffs == 0 {
            return Err(TokenizerError::InvalidConfig(
                "Number of coefficients must be > 0".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Orthonormal DCT-II.
    fn dct(&self, signal: &[f32]) -> Vec<f32> {
        let n = signal.len();
        let mut coeffs = Vec::with_capacity(n);

        for k in 0..n {
            let mut sum = 0.0;
            for (i, &x) in signal.iter().enumerate() {
                sum += x * ((PI * k as f32 * (2 * i + 1) as f32) / (2.0 * n as f32)).cos();
            }

            let scale = if k == 0 {
                (1.0 / n as f32).sqrt()
            } else {
                (2.0 / n as f32).sqrt()
            };

            coeffs.push(sum * scale);
        }

        coeffs
    }

    /// Orthonormal DCT-III, the inverse of [`Self::dct`].
    fn idct(&self, coeffs: &[f32]) -> Vec<f32> {
        let n = coeffs.len();
        let mut signal = Vec::with_capacity(n);

        for i in 0..n {
            let mut sum = 0.0;

            for (k, &c) in coeffs.iter().enumerate() {
                let scale = if k == 0 {
                    (1.0 / n as f32).sqrt()
                } else {
                    (2.0 / n as f32).sqrt()
                };

                sum += c * scale * ((PI * k as f32 * (2 * i + 1) as f32) / (2.0 * n as f32)).cos();
            }

            signal.push(sum);
        }

        signal
    }

    /// Uniform symmetric quantization of the first `num_coeffs` coefficients.
    fn quantize(&self, coeffs: &[f32]) -> Vec<i32> {
        let levels = (1 << self.config.bits) as f32;
        let max_val = coeffs
            .iter()
            .take(self.config.num_coeffs)
            .map(|&x| x.abs())
            .fold(0.0_f32, f32::max);

        if max_val == 0.0 {
            return vec![0; self.config.num_coeffs];
        }

        coeffs
            .iter()
            .take(self.config.num_coeffs)
            .map(|&x| {
                let normalized = x / max_val;
                let quantized = (normalized * (levels / 2.0)).round();
                quantized.clamp(-(levels / 2.0), levels / 2.0 - 1.0) as i32
            })
            .collect()
    }

    fn dequantize(&self, quantized: &[i32], max_val: f32) -> Vec<f32> {
        let levels = (1 << self.config.bits) as f32;

        quantized
            .iter()
            .map(|&q| (q as f32 / (levels / 2.0)) * max_val)
            .collect()
    }
}

impl SignalTokenizer for DCTTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let coeffs = self.dct(
            signal
                .as_slice()
                .expect("Signal must have contiguous layout"),
        );
        let quantized = self.quantize(&coeffs);

        let tokens: Vec<f32> = quantized.iter().map(|&q| q as f32).collect();
        Ok(Array1::from_vec(tokens))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        // As with the wavelet tokenizer, the peak magnitude is not carried
        // in the token stream, so reconstruction is up to an unknown scale.
        let max_val = 1.0;
        let quantized: Vec<i32> = tokens.iter().map(|&t| t as i32).collect();
        let coeffs = self.dequantize(&quantized, max_val);

        // Zero-pad any coefficients that were truncated at encode time.
        let mut full_coeffs = coeffs;
        while full_coeffs.len() < tokens.len() {
            full_coeffs.push(0.0);
        }

        let reconstructed = self.idct(&full_coeffs);
        Ok(Array1::from_vec(reconstructed))
    }

    fn embed_dim(&self) -> usize {
        self.config.num_coeffs
    }

    fn vocab_size(&self) -> usize {
        1 << self.config.bits
    }
}

/// Configuration for [`KMeansTokenizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KMeansConfig {
    /// Number of clusters (vocabulary size).
    pub num_clusters: usize,
    /// Window length used for clustering.
    pub embed_dim: usize,
    /// Maximum number of Lloyd iterations.
    pub max_iterations: usize,
    /// Convergence threshold on total centroid movement.
    pub tolerance: f32,
}

impl Default for KMeansConfig {
    fn default() -> Self {
        Self {
            num_clusters: 256,
            embed_dim: 16,
            max_iterations: 100,
            tolerance: 1e-4,
        }
    }
}

/// Tokenizer that maps sliding windows of the signal to the nearest
/// k-means centroid.
pub struct KMeansTokenizer {
    config: KMeansConfig,
    centroids: Array2<f32>,
    trained: bool,
}

impl KMeansTokenizer {
    /// Creates a new, untrained k-means tokenizer.
    pub fn new(config: KMeansConfig) -> TokenizerResult<Self> {
        if config.num_clusters == 0 {
            return Err(TokenizerError::InvalidConfig(
                "Number of clusters must be > 0".to_string(),
            ));
        }
        if config.embed_dim == 0 {
            return Err(TokenizerError::InvalidConfig(
                "Embedding dimension must be > 0".to_string(),
            ));
        }

        let centroids = Array2::zeros((config.num_clusters, config.embed_dim));

        Ok(Self {
            config,
            centroids,
            trained: false,
        })
    }

    /// Fits the codebook with k-means++ initialization followed by Lloyd
    /// iterations over all sliding windows of the training signals.
    pub fn train(&mut self, data: &[Array1<f32>]) -> TokenizerResult<()> {
        if data.is_empty() {
            return Err(TokenizerError::InvalidConfig(
                "No training data".to_string(),
            ));
        }

        // Extract every sliding window of length `embed_dim`.
        let mut windows = Vec::new();
        for signal in data {
            for i in 0..=signal.len().saturating_sub(self.config.embed_dim) {
                let window = signal.slice(s![i..i + self.config.embed_dim]).to_owned();
                windows.push(window);
            }
        }

        if windows.len() < self.config.num_clusters {
            return Err(TokenizerError::InvalidConfig(
                "Not enough data for clustering".to_string(),
            ));
        }

        self.kmeans_plus_plus_init(&windows)?;

        for iteration in 0..self.config.max_iterations {
            let assignments = self.assign_clusters(&windows);

            let old_centroids = self.centroids.clone();
            self.update_centroids(&windows, &assignments)?;

            let change = self.compute_centroid_change(&old_centroids);
            if change < self.config.tolerance {
                tracing::debug!("K-means converged at iteration {}", iteration);
                break;
            }
        }

        self.trained = true;
        Ok(())
    }

    /// k-means++ seeding: pick the first centroid uniformly at random, then
    /// pick each subsequent centroid with probability proportional to its
    /// squared distance from the nearest centroid chosen so far.
    fn kmeans_plus_plus_init(&mut self, windows: &[Array1<f32>]) -> TokenizerResult<()> {
        use scirs2_core::random::quick::{random_f32, random_usize};

        let first_idx = random_usize(0, windows.len() - 1);
        self.centroids.row_mut(0).assign(&windows[first_idx].view());

        for k in 1..self.config.num_clusters {
            let mut distances = vec![f32::MAX; windows.len()];

            for (i, window) in windows.iter().enumerate() {
                for j in 0..k {
                    let centroid = self.centroids.row(j);
                    let dist = Self::euclidean_distance(window, &centroid.to_owned());
                    distances[i] = distances[i].min(dist);
                }
            }

            // Sample an index weighted by squared distance.
            let total: f32 = distances.iter().map(|&d| d * d).sum();
            let mut threshold = random_f32() * total;
            let mut chosen_idx = 0;

            for (i, &dist) in distances.iter().enumerate() {
                threshold -= dist * dist;
                if threshold <= 0.0 {
                    chosen_idx = i;
                    break;
                }
            }

            self.centroids
                .row_mut(k)
                .assign(&windows[chosen_idx].view());
        }

        Ok(())
    }

    fn assign_clusters(&self, windows: &[Array1<f32>]) -> Vec<usize> {
        windows
            .iter()
            .map(|window| self.find_nearest_centroid(window))
            .collect()
    }

    /// Recomputes each centroid as the mean of its assigned windows. A
    /// cluster that receives no assignments is left at the zero vector.
    fn update_centroids(
        &mut self,
        windows: &[Array1<f32>],
        assignments: &[usize],
    ) -> TokenizerResult<()> {
        let mut counts = vec![0usize; self.config.num_clusters];
        self.centroids.fill(0.0);

        // Accumulate per-cluster sums.
        for (window, &cluster) in windows.iter().zip(assignments.iter()) {
            for (i, &val) in window.iter().enumerate() {
                self.centroids[[cluster, i]] += val;
            }
            counts[cluster] += 1;
        }

        // Divide by the counts to obtain the means.
        for (k, &count) in counts.iter().enumerate().take(self.config.num_clusters) {
            if count > 0 {
                for i in 0..self.config.embed_dim {
                    self.centroids[[k, i]] /= count as f32;
                }
            }
        }

        Ok(())
    }

    fn find_nearest_centroid(&self, window: &Array1<f32>) -> usize {
        (0..self.config.num_clusters)
            .min_by(|&a, &b| {
                let dist_a = Self::euclidean_distance(window, &self.centroids.row(a).to_owned());
                let dist_b = Self::euclidean_distance(window, &self.centroids.row(b).to_owned());
                dist_a
                    .partial_cmp(&dist_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .expect("Range must be non-empty")
    }

    fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Frobenius norm of the difference between the current and previous
    /// centroid matrices, used as the convergence criterion.
    fn compute_centroid_change(&self, old_centroids: &Array2<f32>) -> f32 {
        self.centroids
            .iter()
            .zip(old_centroids.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Returns `true` once [`Self::train`] has completed.
    pub fn is_trained(&self) -> bool {
        self.trained
    }

    /// Returns the learned centroid matrix (`num_clusters x embed_dim`).
    pub fn centroids(&self) -> &Array2<f32> {
        &self.centroids
    }
}

impl SignalTokenizer for KMeansTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if !self.trained {
            return Err(TokenizerError::InvalidConfig(
                "K-means model not trained".to_string(),
            ));
        }

        let mut tokens = Vec::new();

        // One token per sliding window: the index of the nearest centroid.
        for i in 0..=signal.len().saturating_sub(self.config.embed_dim) {
            let window = signal.slice(s![i..i + self.config.embed_dim]).to_owned();
            let cluster = self.find_nearest_centroid(&window);
            tokens.push(cluster as f32);
        }

        Ok(Array1::from_vec(tokens))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if !self.trained {
            return Err(TokenizerError::InvalidConfig(
                "K-means model not trained".to_string(),
            ));
        }

        // Overlap-add the centroids of consecutive windows, then average.
        let output_len = tokens.len() + self.config.embed_dim - 1;
        let mut signal = vec![0.0; output_len];
        let mut counts = vec![0.0; output_len];

        for (i, &token) in tokens.iter().enumerate() {
            let cluster = token as usize;
            if cluster >= self.config.num_clusters {
                return Err(TokenizerError::invalid_input(
                    "decoding",
                    "Invalid cluster index",
                ));
            }

            let centroid = self.centroids.row(cluster);
            for (j, &val) in centroid.iter().enumerate() {
                signal[i + j] += val;
                counts[i + j] += 1.0;
            }
        }

        // Normalize each sample by how many windows covered it.
        for (s, c) in signal.iter_mut().zip(counts.iter()) {
            if *c > 0.0 {
                *s /= c;
            }
        }

        Ok(Array1::from_vec(signal))
    }

    fn embed_dim(&self) -> usize {
        self.config.embed_dim
    }

    fn vocab_size(&self) -> usize {
        self.config.num_clusters
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wavelet_haar_basic() {
        let config = WaveletConfig {
            levels: 2,
            family: WaveletFamily::Haar,
            bits: 8,
        };
        let tokenizer = WaveletTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert!(!tokens.is_empty());

        let reconstructed = tokenizer.decode(&tokens).unwrap();
        assert_eq!(reconstructed.len(), signal.len());
    }
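
    // Supplementary sanity check: the Haar filter bank above is orthonormal
    // with periodic extension, so `decompose` followed by `reconstruct`
    // (bypassing quantization) should be exact up to floating-point error.
    #[test]
    fn test_wavelet_haar_lossless_roundtrip() {
        let config = WaveletConfig {
            levels: 2,
            family: WaveletFamily::Haar,
            bits: 8,
        };
        let tokenizer = WaveletTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, -2.0, 3.0, 0.5, -1.5, 2.5, 0.0, 4.0]);
        let coeffs = tokenizer.decompose(&signal);
        let reconstructed = tokenizer.reconstruct(&coeffs);

        for (x, y) in signal.iter().zip(reconstructed.iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }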

    #[test]
    fn test_wavelet_daubechies4() {
        let config = WaveletConfig {
            levels: 1,
            family: WaveletFamily::Daubechies4,
            bits: 8,
        };
        let tokenizer = WaveletTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert!(!tokens.is_empty());
    }
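
    // Supplementary check of standard identities for the D4 scaling filter
    // built in `new`: the taps sum to sqrt(2) and have unit energy. This
    // guards against sign or ordering typos in the coefficient table.
    #[test]
    fn test_daubechies4_filter_identities() {
        let config = WaveletConfig {
            levels: 1,
            family: WaveletFamily::Daubechies4,
            bits: 8,
        };
        let tokenizer = WaveletTokenizer::new(config).unwrap();

        let sum: f32 = tokenizer.lowpass.iter().sum();
        let energy: f32 = tokenizer.lowpass.iter().map(|h| h * h).sum();

        assert!((sum - 2.0_f32.sqrt()).abs() < 1e-5);
        assert!((energy - 1.0).abs() < 1e-5);
    }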

    #[test]
    fn test_wavelet_invalid_config() {
        let config = WaveletConfig {
            levels: 0,
            family: WaveletFamily::Haar,
            bits: 8,
        };
        assert!(WaveletTokenizer::new(config).is_err());

        let config = WaveletConfig {
            levels: 1,
            family: WaveletFamily::Haar,
            bits: 0,
        };
        assert!(WaveletTokenizer::new(config).is_err());
    }

    #[test]
    fn test_fourier_magnitude_only() {
        let config = FourierConfig {
            num_bins: 8,
            magnitude_only: true,
            bits: 8,
        };
        let tokenizer = FourierTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert_eq!(tokens.len(), 8); // one magnitude per bin

        let reconstructed = tokenizer.decode(&tokens).unwrap();
        assert_eq!(reconstructed.len(), 8);
    }
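
    // Supplementary spectral check: a single-cycle cosine should concentrate
    // its magnitude spectrum in bin 1 and the conjugate bin n - 1 (each with
    // magnitude n / 2), leaving the remaining bins near zero.
    #[test]
    fn test_fourier_pure_tone_peak() {
        let config = FourierConfig {
            num_bins: 8,
            magnitude_only: true,
            bits: 8,
        };
        let tokenizer = FourierTokenizer::new(config).unwrap();

        let n = 8usize;
        let signal = Array1::from_vec(
            (0..n)
                .map(|i| (2.0 * PI * i as f32 / n as f32).cos())
                .collect(),
        );
        let tokens = tokenizer.encode(&signal).unwrap();

        assert!((tokens[1] - 4.0).abs() < 1e-3); // n/2 at the tone's bin
        assert!((tokens[7] - 4.0).abs() < 1e-3); // conjugate-symmetric bin
        assert!(tokens[0].abs() < 1e-3); // zero-mean signal: empty DC bin
        assert!(tokens[2].abs() < 1e-3); // off-tone bins stay near zero
    }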

    #[test]
    fn test_fourier_complex() {
        let config = FourierConfig {
            num_bins: 4,
            magnitude_only: false,
            bits: 8,
        };
        let tokenizer = FourierTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert_eq!(tokens.len(), 8); // real and imaginary parts per bin
    }
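
    // Supplementary round-trip check: when the full complex spectrum is kept
    // (num_bins == signal length, magnitude_only == false), decode inverts
    // encode up to floating-point error.
    #[test]
    fn test_fourier_complex_roundtrip() {
        let config = FourierConfig {
            num_bins: 4,
            magnitude_only: false,
            bits: 8,
        };
        let tokenizer = FourierTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        let reconstructed = tokenizer.decode(&tokens).unwrap();

        assert_eq!(reconstructed.len(), signal.len());
        for (x, y) in signal.iter().zip(reconstructed.iter()) {
            assert!((x - y).abs() < 1e-3);
        }
    }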

    #[test]
    fn test_dct_basic() {
        let config = DCTConfig {
            num_coeffs: 8,
            bits: 8,
        };
        let tokenizer = DCTTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert_eq!(tokens.len(), 8);

        let reconstructed = tokenizer.decode(&tokens).unwrap();
        assert_eq!(reconstructed.len(), 8);
    }
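
    // Supplementary check of the orthonormal DCT-II / DCT-III pair directly,
    // bypassing quantization: idct(dct(x)) should reproduce x up to
    // floating-point error.
    #[test]
    fn test_dct_idct_roundtrip() {
        let config = DCTConfig {
            num_coeffs: 8,
            bits: 8,
        };
        let tokenizer = DCTTokenizer::new(config).unwrap();

        let signal = vec![1.0, -2.0, 3.0, 0.5, -1.5, 2.5, 0.0, 4.0];
        let coeffs = tokenizer.dct(&signal);
        let reconstructed = tokenizer.idct(&coeffs);

        for (x, y) in signal.iter().zip(reconstructed.iter()) {
            assert!((x - y).abs() < 1e-4);
        }
    }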

    #[test]
    fn test_dct_compression() {
        let config = DCTConfig {
            num_coeffs: 4,
            bits: 8,
        };
        let tokenizer = DCTTokenizer::new(config).unwrap();

        // An 8-sample signal is compressed to num_coeffs = 4 tokens.
        let signal = Array1::from_vec(vec![1.0, 1.1, 1.2, 1.1, 1.0, 0.9, 0.8, 0.9]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert_eq!(tokens.len(), 4);
    }
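
    // Supplementary quantization check: with the true peak magnitude
    // supplied, dequantize(quantize(x)) stays within one quantization step.
    // The peak value itself clamps to levels/2 - 1, hence the looser bound.
    #[test]
    fn test_dct_quantize_dequantize() {
        let config = DCTConfig {
            num_coeffs: 4,
            bits: 8,
        };
        let tokenizer = DCTTokenizer::new(config).unwrap();

        let coeffs = vec![1.0_f32, -0.5, 0.25, 0.0];
        let quantized = tokenizer.quantize(&coeffs);
        let restored = tokenizer.dequantize(&quantized, 1.0);

        for (x, y) in coeffs.iter().zip(restored.iter()) {
            assert!((x - y).abs() <= 1.0 / 64.0);
        }
    }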

    #[test]
    fn test_kmeans_training() {
        let config = KMeansConfig {
            num_clusters: 4,
            embed_dim: 4,
            max_iterations: 50,
            tolerance: 1e-3,
        };
        let mut tokenizer = KMeansTokenizer::new(config).unwrap();

        let data = vec![
            Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]),
            Array1::from_vec(vec![3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]),
        ];

        assert!(!tokenizer.is_trained());
        tokenizer.train(&data).unwrap();
        assert!(tokenizer.is_trained());

        let centroids = tokenizer.centroids();
        assert_eq!(centroids.shape(), &[4, 4]);
    }

    #[test]
    fn test_kmeans_encode_decode() {
        let config = KMeansConfig {
            num_clusters: 8,
            embed_dim: 4,
            max_iterations: 100,
            tolerance: 1e-4,
        };
        let mut tokenizer = KMeansTokenizer::new(config).unwrap();

        let data = vec![
            Array1::from_vec((0..32).map(|x| x as f32).collect::<Vec<_>>()),
            Array1::from_vec((0..32).map(|x| (x as f32).sin()).collect::<Vec<_>>()),
        ];

        tokenizer.train(&data).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let tokens = tokenizer.encode(&signal).unwrap();
        assert!(!tokens.is_empty());

        let reconstructed = tokenizer.decode(&tokens).unwrap();
        assert!(!reconstructed.is_empty());
    }
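
    // Supplementary check with a known 3-4-5 triangle value for the private
    // distance helper used by clustering and nearest-centroid lookup.
    #[test]
    fn test_euclidean_distance() {
        let a = Array1::from_vec(vec![0.0, 0.0]);
        let b = Array1::from_vec(vec![3.0, 4.0]);
        let d = KMeansTokenizer::euclidean_distance(&a, &b);
        assert!((d - 5.0).abs() < 1e-6);
    }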

    #[test]
    fn test_kmeans_untrained_error() {
        let config = KMeansConfig::default();
        let tokenizer = KMeansTokenizer::new(config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(tokenizer.encode(&signal).is_err());
    }

    #[test]
    fn test_kmeans_invalid_config() {
        let config = KMeansConfig {
            num_clusters: 0,
            embed_dim: 4,
            max_iterations: 10,
            tolerance: 1e-3,
        };
        assert!(KMeansTokenizer::new(config).is_err());
    }

    #[test]
    fn test_signal_tokenizer_trait() {
        let tokenizers: Vec<Box<dyn SignalTokenizer>> = vec![
            Box::new(
                WaveletTokenizer::new(WaveletConfig {
                    levels: 1,
                    family: WaveletFamily::Haar,
                    bits: 8,
                })
                .unwrap(),
            ),
            Box::new(
                FourierTokenizer::new(FourierConfig {
                    num_bins: 8,
                    magnitude_only: true,
                    bits: 8,
                })
                .unwrap(),
            ),
            Box::new(
                DCTTokenizer::new(DCTConfig {
                    num_coeffs: 8,
                    bits: 8,
                })
                .unwrap(),
            ),
        ];

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        for tokenizer in tokenizers {
            let tokens = tokenizer.encode(&signal).unwrap();
            assert!(!tokens.is_empty());
            assert!(tokenizer.vocab_size() > 0 || tokenizer.embed_dim() > 0);
        }
    }
}