kizzasi_tokenizer/
specialized.rs

1//! Specialized tokenizers for signal processing
2//!
3//! This module provides domain-specific tokenization strategies:
4//! - Wavelet-based: Multi-resolution time-frequency analysis
5//! - Fourier-based: Frequency domain representation via FFT
6//! - DCT-based: Discrete Cosine Transform (JPEG-style compression)
7//! - K-means: Clustering-based vector quantization
8
9use crate::{SignalTokenizer, TokenizerError, TokenizerResult};
10use scirs2_core::ndarray::{s, Array1, Array2};
11use serde::{Deserialize, Serialize};
12use std::f32::consts::PI;
13
14/// Wavelet family for decomposition
15#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
16pub enum WaveletFamily {
17    /// Haar wavelet (simplest, discontinuous)
18    Haar,
19    /// Daubechies 4-tap wavelet (smooth, compact support)
20    Daubechies4,
21}
22
23/// Configuration for wavelet-based tokenization
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct WaveletConfig {
26    /// Number of decomposition levels
27    pub levels: usize,
28    /// Wavelet family to use
29    pub family: WaveletFamily,
30    /// Quantization bits per coefficient
31    pub bits: u8,
32}
33
34impl Default for WaveletConfig {
35    fn default() -> Self {
36        Self {
37            levels: 3,
38            family: WaveletFamily::Haar,
39            bits: 8,
40        }
41    }
42}
43
44/// Wavelet-based tokenizer using multi-resolution decomposition
45pub struct WaveletTokenizer {
46    config: WaveletConfig,
47    lowpass: Vec<f32>,
48    highpass: Vec<f32>,
49}
50
51impl WaveletTokenizer {
52    /// Create a new wavelet tokenizer
53    pub fn new(config: WaveletConfig) -> TokenizerResult<Self> {
54        if config.levels == 0 {
55            return Err(TokenizerError::InvalidConfig(
56                "Wavelet levels must be > 0".to_string(),
57            ));
58        }
59        if config.bits == 0 || config.bits > 16 {
60            return Err(TokenizerError::InvalidConfig(
61                "Bits must be in range [1, 16]".to_string(),
62            ));
63        }
64
65        let (lowpass, highpass) = match config.family {
66            WaveletFamily::Haar => {
67                // Haar wavelet filters (normalized)
68                let sqrt2_inv = 1.0 / 2.0_f32.sqrt();
69                (vec![sqrt2_inv, sqrt2_inv], vec![sqrt2_inv, -sqrt2_inv])
70            }
71            WaveletFamily::Daubechies4 => {
72                // Daubechies-4 wavelet coefficients
73                let sqrt2 = 2.0_f32.sqrt();
74                let sqrt3 = 3.0_f32.sqrt();
75                let h0 = (1.0 + sqrt3) / (4.0 * sqrt2);
76                let h1 = (3.0 + sqrt3) / (4.0 * sqrt2);
77                let h2 = (3.0 - sqrt3) / (4.0 * sqrt2);
78                let h3 = (1.0 - sqrt3) / (4.0 * sqrt2);
79                (
80                    vec![h0, h1, h2, h3],
81                    vec![h3, -h2, h1, -h0], // QMF relationship
82                )
83            }
84        };
85
86        Ok(Self {
87            config,
88            lowpass,
89            highpass,
90        })
91    }
92
93    /// Forward wavelet transform (one level)
94    fn dwt_step(&self, signal: &[f32]) -> (Vec<f32>, Vec<f32>) {
95        let n = signal.len();
96        let mut approx = Vec::with_capacity(n / 2);
97        let mut detail = Vec::with_capacity(n / 2);
98
99        for i in (0..n).step_by(2) {
100            let mut low_sum = 0.0;
101            let mut high_sum = 0.0;
102
103            for (j, (&l, &h)) in self.lowpass.iter().zip(self.highpass.iter()).enumerate() {
104                let idx = (i + j) % n; // Circular boundary
105                low_sum += signal[idx] * l;
106                high_sum += signal[idx] * h;
107            }
108
109            approx.push(low_sum);
110            detail.push(high_sum);
111        }
112
113        (approx, detail)
114    }
115
116    /// Inverse wavelet transform (one level)
117    fn idwt_step(&self, approx: &[f32], detail: &[f32]) -> Vec<f32> {
118        let n = approx.len() * 2;
119        let mut signal = vec![0.0; n];
120
121        for i in 0..approx.len() {
122            for (j, (&l, &h)) in self.lowpass.iter().zip(self.highpass.iter()).enumerate() {
123                let idx = (2 * i + j) % n;
124                signal[idx] += approx[i] * l + detail[i] * h;
125            }
126        }
127
128        signal
129    }
130
131    /// Multi-level decomposition
132    fn decompose(&self, signal: &Array1<f32>) -> Vec<Vec<f32>> {
133        let mut coeffs = Vec::new();
134        let mut current = signal.to_vec();
135
136        for _ in 0..self.config.levels {
137            let (approx, detail) = self.dwt_step(&current);
138            coeffs.push(detail);
139            current = approx;
140        }
141
142        // Add final approximation
143        coeffs.push(current);
144        coeffs.reverse(); // [approx, detail_N, ..., detail_1]
145        coeffs
146    }
147
148    /// Multi-level reconstruction
149    fn reconstruct(&self, coeffs: &[Vec<f32>]) -> Vec<f32> {
150        let mut current = coeffs[0].clone();
151
152        for detail in coeffs.iter().skip(1) {
153            current = self.idwt_step(&current, detail);
154        }
155
156        current
157    }
158
159    /// Quantize coefficients
160    fn quantize_coeffs(&self, coeffs: &[Vec<f32>]) -> Vec<Vec<i32>> {
161        let levels = (1 << self.config.bits) as f32;
162        let max_val = coeffs
163            .iter()
164            .flat_map(|c| c.iter())
165            .map(|&x| x.abs())
166            .fold(0.0_f32, f32::max);
167
168        if max_val == 0.0 {
169            return coeffs.iter().map(|c| vec![0; c.len()]).collect();
170        }
171
172        coeffs
173            .iter()
174            .map(|band| {
175                band.iter()
176                    .map(|&x| {
177                        let normalized = x / max_val; // [-1, 1]
178                        let quantized = (normalized * (levels / 2.0)).round();
179                        quantized.clamp(-(levels / 2.0), levels / 2.0 - 1.0) as i32
180                    })
181                    .collect()
182            })
183            .collect()
184    }
185
186    /// Dequantize coefficients
187    fn dequantize_coeffs(&self, quantized: &[Vec<i32>], max_val: f32) -> Vec<Vec<f32>> {
188        let levels = (1 << self.config.bits) as f32;
189
190        quantized
191            .iter()
192            .map(|band| {
193                band.iter()
194                    .map(|&q| (q as f32 / (levels / 2.0)) * max_val)
195                    .collect()
196            })
197            .collect()
198    }
199}
200
201impl SignalTokenizer for WaveletTokenizer {
202    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
203        let coeffs = self.decompose(signal);
204        let quantized = self.quantize_coeffs(&coeffs);
205
206        // Flatten into 1D array
207        let tokens: Vec<f32> = quantized
208            .iter()
209            .flat_map(|band| band.iter().map(|&q| q as f32))
210            .collect();
211
212        Ok(Array1::from_vec(tokens))
213    }
214
215    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
216        // Reconstruct coefficient structure
217        // This is a simplified version - in practice, we'd need to store band sizes
218        let max_val = 1.0; // Simplified - should be stored with coefficients
219        let quantized: Vec<i32> = tokens.iter().map(|&t| t as i32).collect();
220
221        // Estimate band sizes (simplified - assumes power-of-2 signal length)
222        let mut band_sizes = Vec::new();
223        let total_len = quantized.len();
224        let mut remaining = total_len;
225        for _ in 0..self.config.levels {
226            let size = remaining / 2;
227            band_sizes.push(size);
228            remaining -= size;
229        }
230        band_sizes.push(remaining);
231        band_sizes.reverse();
232
233        let mut offset = 0;
234        let mut bands = Vec::new();
235        for &size in &band_sizes {
236            bands.push(quantized[offset..offset + size].to_vec());
237            offset += size;
238        }
239
240        let dequantized = self.dequantize_coeffs(&bands, max_val);
241        let reconstructed = self.reconstruct(&dequantized);
242
243        Ok(Array1::from_vec(reconstructed))
244    }
245
246    fn embed_dim(&self) -> usize {
247        // Variable based on signal length and decomposition
248        0 // Indicates variable length
249    }
250
251    fn vocab_size(&self) -> usize {
252        1 << self.config.bits
253    }
254}
255
256/// Configuration for Fourier-based tokenization
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct FourierConfig {
259    /// Number of frequency bins to keep
260    pub num_bins: usize,
261    /// Whether to use magnitude only (discard phase)
262    pub magnitude_only: bool,
263    /// Quantization bits
264    pub bits: u8,
265}
266
267impl Default for FourierConfig {
268    fn default() -> Self {
269        Self {
270            num_bins: 256,
271            magnitude_only: false,
272            bits: 8,
273        }
274    }
275}
276
277/// Fourier-based tokenizer using FFT
278pub struct FourierTokenizer {
279    config: FourierConfig,
280}
281
282impl FourierTokenizer {
283    /// Create a new Fourier tokenizer
284    pub fn new(config: FourierConfig) -> TokenizerResult<Self> {
285        if config.num_bins == 0 {
286            return Err(TokenizerError::InvalidConfig(
287                "Number of bins must be > 0".to_string(),
288            ));
289        }
290        Ok(Self { config })
291    }
292
293    /// Compute FFT (simplified DFT for now)
294    fn fft(&self, signal: &[f32]) -> Vec<(f32, f32)> {
295        let n = signal.len();
296        let mut spectrum = Vec::with_capacity(n);
297
298        for k in 0..n {
299            let mut real_sum = 0.0;
300            let mut imag_sum = 0.0;
301
302            for (i, &x) in signal.iter().enumerate() {
303                let angle = -2.0 * PI * (k as f32) * (i as f32) / (n as f32);
304                real_sum += x * angle.cos();
305                imag_sum += x * angle.sin();
306            }
307
308            spectrum.push((real_sum, imag_sum));
309        }
310
311        spectrum
312    }
313
314    /// Compute inverse FFT
315    fn ifft(&self, spectrum: &[(f32, f32)]) -> Vec<f32> {
316        let n = spectrum.len();
317        let mut signal = Vec::with_capacity(n);
318
319        for i in 0..n {
320            let mut sum = 0.0;
321
322            for (k, &(real, imag)) in spectrum.iter().enumerate() {
323                let angle = 2.0 * PI * (k as f32) * (i as f32) / (n as f32);
324                sum += real * angle.cos() - imag * angle.sin();
325            }
326
327            signal.push(sum / (n as f32));
328        }
329
330        signal
331    }
332}
333
334impl SignalTokenizer for FourierTokenizer {
335    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
336        let spectrum = self.fft(
337            signal
338                .as_slice()
339                .expect("Signal must have contiguous layout"),
340        );
341
342        let tokens: Vec<f32> = spectrum
343            .iter()
344            .take(self.config.num_bins)
345            .flat_map(|&(real, imag)| {
346                if self.config.magnitude_only {
347                    vec![(real * real + imag * imag).sqrt()]
348                } else {
349                    vec![real, imag]
350                }
351            })
352            .collect();
353
354        Ok(Array1::from_vec(tokens))
355    }
356
357    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
358        let spectrum: Vec<(f32, f32)> = if self.config.magnitude_only {
359            tokens
360                .iter()
361                .map(|&mag| (mag, 0.0)) // Zero phase
362                .collect()
363        } else {
364            // Manually iterate in chunks of 2
365            let mut result = Vec::new();
366            let tokens_slice = tokens
367                .as_slice()
368                .expect("Tokens must have contiguous layout");
369            for i in (0..tokens_slice.len()).step_by(2) {
370                let real = tokens_slice[i];
371                let imag = tokens_slice.get(i + 1).copied().unwrap_or(0.0);
372                result.push((real, imag));
373            }
374            result
375        };
376
377        let reconstructed = self.ifft(&spectrum);
378        Ok(Array1::from_vec(reconstructed))
379    }
380
381    fn embed_dim(&self) -> usize {
382        if self.config.magnitude_only {
383            self.config.num_bins
384        } else {
385            self.config.num_bins * 2
386        }
387    }
388
389    fn vocab_size(&self) -> usize {
390        0 // Continuous
391    }
392}
393
394/// Configuration for DCT-based tokenization
395#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct DCTConfig {
397    /// Number of DCT coefficients to keep
398    pub num_coeffs: usize,
399    /// Quantization bits
400    pub bits: u8,
401}
402
403impl Default for DCTConfig {
404    fn default() -> Self {
405        Self {
406            num_coeffs: 64,
407            bits: 8,
408        }
409    }
410}
411
412/// DCT-based tokenizer (Type-II DCT, like JPEG)
413pub struct DCTTokenizer {
414    config: DCTConfig,
415}
416
417impl DCTTokenizer {
418    /// Create a new DCT tokenizer
419    pub fn new(config: DCTConfig) -> TokenizerResult<Self> {
420        if config.num_coeffs == 0 {
421            return Err(TokenizerError::InvalidConfig(
422                "Number of coefficients must be > 0".to_string(),
423            ));
424        }
425        Ok(Self { config })
426    }
427
428    /// Compute DCT-II
429    fn dct(&self, signal: &[f32]) -> Vec<f32> {
430        let n = signal.len();
431        let mut coeffs = Vec::with_capacity(n);
432
433        for k in 0..n {
434            let mut sum = 0.0;
435            for (i, &x) in signal.iter().enumerate() {
436                sum += x * ((PI * k as f32 * (2 * i + 1) as f32) / (2.0 * n as f32)).cos();
437            }
438
439            let scale = if k == 0 {
440                (1.0 / n as f32).sqrt()
441            } else {
442                (2.0 / n as f32).sqrt()
443            };
444
445            coeffs.push(sum * scale);
446        }
447
448        coeffs
449    }
450
451    /// Compute inverse DCT-II (DCT-III)
452    fn idct(&self, coeffs: &[f32]) -> Vec<f32> {
453        let n = coeffs.len();
454        let mut signal = Vec::with_capacity(n);
455
456        for i in 0..n {
457            let mut sum = 0.0;
458
459            for (k, &c) in coeffs.iter().enumerate() {
460                let scale = if k == 0 {
461                    (1.0 / n as f32).sqrt()
462                } else {
463                    (2.0 / n as f32).sqrt()
464                };
465
466                sum += c * scale * ((PI * k as f32 * (2 * i + 1) as f32) / (2.0 * n as f32)).cos();
467            }
468
469            signal.push(sum);
470        }
471
472        signal
473    }
474
475    /// Quantize DCT coefficients (zig-zag scan and quantization)
476    fn quantize(&self, coeffs: &[f32]) -> Vec<i32> {
477        let levels = (1 << self.config.bits) as f32;
478        let max_val = coeffs
479            .iter()
480            .take(self.config.num_coeffs)
481            .map(|&x| x.abs())
482            .fold(0.0_f32, f32::max);
483
484        if max_val == 0.0 {
485            return vec![0; self.config.num_coeffs];
486        }
487
488        coeffs
489            .iter()
490            .take(self.config.num_coeffs)
491            .map(|&x| {
492                let normalized = x / max_val;
493                let quantized = (normalized * (levels / 2.0)).round();
494                quantized.clamp(-(levels / 2.0), levels / 2.0 - 1.0) as i32
495            })
496            .collect()
497    }
498
499    /// Dequantize coefficients
500    fn dequantize(&self, quantized: &[i32], max_val: f32) -> Vec<f32> {
501        let levels = (1 << self.config.bits) as f32;
502
503        quantized
504            .iter()
505            .map(|&q| (q as f32 / (levels / 2.0)) * max_val)
506            .collect()
507    }
508}
509
510impl SignalTokenizer for DCTTokenizer {
511    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
512        let coeffs = self.dct(
513            signal
514                .as_slice()
515                .expect("Signal must have contiguous layout"),
516        );
517        let quantized = self.quantize(&coeffs);
518
519        let tokens: Vec<f32> = quantized.iter().map(|&q| q as f32).collect();
520        Ok(Array1::from_vec(tokens))
521    }
522
523    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
524        let max_val = 1.0; // Simplified - should be stored
525        let quantized: Vec<i32> = tokens.iter().map(|&t| t as i32).collect();
526        let coeffs = self.dequantize(&quantized, max_val);
527
528        // Pad with zeros if needed
529        let mut full_coeffs = coeffs;
530        while full_coeffs.len() < tokens.len() {
531            full_coeffs.push(0.0);
532        }
533
534        let reconstructed = self.idct(&full_coeffs);
535        Ok(Array1::from_vec(reconstructed))
536    }
537
538    fn embed_dim(&self) -> usize {
539        self.config.num_coeffs
540    }
541
542    fn vocab_size(&self) -> usize {
543        1 << self.config.bits
544    }
545}
546
547/// Configuration for K-means clustering tokenizer
548#[derive(Debug, Clone, Serialize, Deserialize)]
549pub struct KMeansConfig {
550    /// Number of clusters (codebook size)
551    pub num_clusters: usize,
552    /// Embedding dimension (window size)
553    pub embed_dim: usize,
554    /// Maximum iterations for k-means
555    pub max_iterations: usize,
556    /// Convergence tolerance
557    pub tolerance: f32,
558}
559
560impl Default for KMeansConfig {
561    fn default() -> Self {
562        Self {
563            num_clusters: 256,
564            embed_dim: 16,
565            max_iterations: 100,
566            tolerance: 1e-4,
567        }
568    }
569}
570
571/// K-means clustering tokenizer for vector quantization
572pub struct KMeansTokenizer {
573    config: KMeansConfig,
574    centroids: Array2<f32>,
575    trained: bool,
576}
577
578impl KMeansTokenizer {
579    /// Create a new K-means tokenizer (untrained)
580    pub fn new(config: KMeansConfig) -> TokenizerResult<Self> {
581        if config.num_clusters == 0 {
582            return Err(TokenizerError::InvalidConfig(
583                "Number of clusters must be > 0".to_string(),
584            ));
585        }
586        if config.embed_dim == 0 {
587            return Err(TokenizerError::InvalidConfig(
588                "Embedding dimension must be > 0".to_string(),
589            ));
590        }
591
592        let centroids = Array2::zeros((config.num_clusters, config.embed_dim));
593
594        Ok(Self {
595            config,
596            centroids,
597            trained: false,
598        })
599    }
600
601    /// Train the k-means model on data
602    pub fn train(&mut self, data: &[Array1<f32>]) -> TokenizerResult<()> {
603        if data.is_empty() {
604            return Err(TokenizerError::InvalidConfig(
605                "No training data".to_string(),
606            ));
607        }
608
609        // Extract windows from signals
610        let mut windows = Vec::new();
611        for signal in data {
612            for i in 0..=signal.len().saturating_sub(self.config.embed_dim) {
613                let window = signal.slice(s![i..i + self.config.embed_dim]).to_owned();
614                windows.push(window);
615            }
616        }
617
618        if windows.len() < self.config.num_clusters {
619            return Err(TokenizerError::InvalidConfig(
620                "Not enough data for clustering".to_string(),
621            ));
622        }
623
624        // Initialize centroids with k-means++
625        self.kmeans_plus_plus_init(&windows)?;
626
627        // Run k-means iterations
628        for iteration in 0..self.config.max_iterations {
629            // Assignment step
630            let assignments = self.assign_clusters(&windows);
631
632            // Update step
633            let old_centroids = self.centroids.clone();
634            self.update_centroids(&windows, &assignments)?;
635
636            // Check convergence
637            let change = self.compute_centroid_change(&old_centroids);
638            if change < self.config.tolerance {
639                tracing::debug!("K-means converged at iteration {}", iteration);
640                break;
641            }
642        }
643
644        self.trained = true;
645        Ok(())
646    }
647
648    /// Initialize centroids using k-means++
649    fn kmeans_plus_plus_init(&mut self, windows: &[Array1<f32>]) -> TokenizerResult<()> {
650        use scirs2_core::random::quick::{random_f32, random_usize};
651
652        // Choose first centroid randomly
653        let first_idx = random_usize(0, windows.len() - 1);
654        self.centroids.row_mut(0).assign(&windows[first_idx].view());
655
656        // Choose remaining centroids
657        for k in 1..self.config.num_clusters {
658            let mut distances = vec![f32::MAX; windows.len()];
659
660            // Compute distances to nearest centroid
661            for (i, window) in windows.iter().enumerate() {
662                for j in 0..k {
663                    let centroid = self.centroids.row(j);
664                    let dist = Self::euclidean_distance(window, &centroid.to_owned());
665                    distances[i] = distances[i].min(dist);
666                }
667            }
668
669            // Choose next centroid with probability proportional to distance squared
670            let total: f32 = distances.iter().map(|&d| d * d).sum();
671            let mut threshold = random_f32() * total;
672            let mut chosen_idx = 0;
673
674            for (i, &dist) in distances.iter().enumerate() {
675                threshold -= dist * dist;
676                if threshold <= 0.0 {
677                    chosen_idx = i;
678                    break;
679                }
680            }
681
682            self.centroids
683                .row_mut(k)
684                .assign(&windows[chosen_idx].view());
685        }
686
687        Ok(())
688    }
689
690    /// Assign windows to nearest clusters
691    fn assign_clusters(&self, windows: &[Array1<f32>]) -> Vec<usize> {
692        windows
693            .iter()
694            .map(|window| self.find_nearest_centroid(window))
695            .collect()
696    }
697
698    /// Update centroids based on assignments
699    fn update_centroids(
700        &mut self,
701        windows: &[Array1<f32>],
702        assignments: &[usize],
703    ) -> TokenizerResult<()> {
704        let mut counts = vec![0usize; self.config.num_clusters];
705        self.centroids.fill(0.0);
706
707        // Accumulate
708        for (window, &cluster) in windows.iter().zip(assignments.iter()) {
709            for (i, &val) in window.iter().enumerate() {
710                self.centroids[[cluster, i]] += val;
711            }
712            counts[cluster] += 1;
713        }
714
715        // Average (handle empty clusters by keeping old centroid)
716        for (k, &count) in counts.iter().enumerate().take(self.config.num_clusters) {
717            if count > 0 {
718                for i in 0..self.config.embed_dim {
719                    self.centroids[[k, i]] /= count as f32;
720                }
721            }
722        }
723
724        Ok(())
725    }
726
727    /// Find nearest centroid for a window
728    fn find_nearest_centroid(&self, window: &Array1<f32>) -> usize {
729        (0..self.config.num_clusters)
730            .min_by(|&a, &b| {
731                let dist_a = Self::euclidean_distance(window, &self.centroids.row(a).to_owned());
732                let dist_b = Self::euclidean_distance(window, &self.centroids.row(b).to_owned());
733                dist_a
734                    .partial_cmp(&dist_b)
735                    .unwrap_or(std::cmp::Ordering::Equal)
736            })
737            .expect("Range must be non-empty")
738    }
739
740    /// Compute Euclidean distance
741    fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
742        a.iter()
743            .zip(b.iter())
744            .map(|(x, y)| (x - y).powi(2))
745            .sum::<f32>()
746            .sqrt()
747    }
748
749    /// Compute change in centroids
750    fn compute_centroid_change(&self, old_centroids: &Array2<f32>) -> f32 {
751        self.centroids
752            .iter()
753            .zip(old_centroids.iter())
754            .map(|(a, b)| (a - b).powi(2))
755            .sum::<f32>()
756            .sqrt()
757    }
758
759    /// Check if model is trained
760    pub fn is_trained(&self) -> bool {
761        self.trained
762    }
763
764    /// Get centroids
765    pub fn centroids(&self) -> &Array2<f32> {
766        &self.centroids
767    }
768}
769
770impl SignalTokenizer for KMeansTokenizer {
771    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
772        if !self.trained {
773            return Err(TokenizerError::InvalidConfig(
774                "K-means model not trained".to_string(),
775            ));
776        }
777
778        let mut tokens = Vec::new();
779
780        // Extract windows and assign to clusters
781        for i in 0..=signal.len().saturating_sub(self.config.embed_dim) {
782            let window = signal.slice(s![i..i + self.config.embed_dim]).to_owned();
783            let cluster = self.find_nearest_centroid(&window);
784            tokens.push(cluster as f32);
785        }
786
787        Ok(Array1::from_vec(tokens))
788    }
789
790    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
791        if !self.trained {
792            return Err(TokenizerError::InvalidConfig(
793                "K-means model not trained".to_string(),
794            ));
795        }
796
797        // Simple overlap-add reconstruction
798        let output_len = tokens.len() + self.config.embed_dim - 1;
799        let mut signal = vec![0.0; output_len];
800        let mut counts = vec![0.0; output_len];
801
802        for (i, &token) in tokens.iter().enumerate() {
803            let cluster = token as usize;
804            if cluster >= self.config.num_clusters {
805                return Err(TokenizerError::invalid_input(
806                    "decoding",
807                    "Invalid cluster index",
808                ));
809            }
810
811            let centroid = self.centroids.row(cluster);
812            for (j, &val) in centroid.iter().enumerate() {
813                signal[i + j] += val;
814                counts[i + j] += 1.0;
815            }
816        }
817
818        // Average overlapping regions
819        for (s, c) in signal.iter_mut().zip(counts.iter()) {
820            if *c > 0.0 {
821                *s /= c;
822            }
823        }
824
825        Ok(Array1::from_vec(signal))
826    }
827
828    fn embed_dim(&self) -> usize {
829        self.config.embed_dim
830    }
831
832    fn vocab_size(&self) -> usize {
833        self.config.num_clusters
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// The ramp `[1.0, 2.0, ..., 8.0]` shared by several tests.
    fn ramp8() -> Array1<f32> {
        Array1::from_vec((1..=8).map(|v| v as f32).collect())
    }

    /// Builds a wavelet tokenizer without unwrapping, so error cases can be checked.
    fn wavelet(
        levels: usize,
        family: WaveletFamily,
        bits: u8,
    ) -> TokenizerResult<WaveletTokenizer> {
        WaveletTokenizer::new(WaveletConfig {
            levels,
            family,
            bits,
        })
    }

    /// Builds a Fourier tokenizer with 8-bit depth.
    fn fourier(num_bins: usize, magnitude_only: bool) -> FourierTokenizer {
        FourierTokenizer::new(FourierConfig {
            num_bins,
            magnitude_only,
            bits: 8,
        })
        .unwrap()
    }

    /// Builds an 8-bit DCT tokenizer keeping `num_coeffs` coefficients.
    fn dct(num_coeffs: usize) -> DCTTokenizer {
        DCTTokenizer::new(DCTConfig {
            num_coeffs,
            bits: 8,
        })
        .unwrap()
    }

    /// Builds a k-means tokenizer without unwrapping, so error cases can be checked.
    fn kmeans(
        num_clusters: usize,
        embed_dim: usize,
        max_iterations: usize,
        tolerance: f32,
    ) -> TokenizerResult<KMeansTokenizer> {
        KMeansTokenizer::new(KMeansConfig {
            num_clusters,
            embed_dim,
            max_iterations,
            tolerance,
        })
    }

    #[test]
    fn test_wavelet_haar_basic() {
        let tokenizer = wavelet(2, WaveletFamily::Haar, 8).unwrap();
        let signal = ramp8();

        let tokens = tokenizer.encode(&signal).unwrap();
        assert!(!tokens.is_empty());

        let reconstructed = tokenizer.decode(&tokens).unwrap();
        assert_eq!(reconstructed.len(), signal.len());
    }

    #[test]
    fn test_wavelet_daubechies4() {
        let tokenizer = wavelet(1, WaveletFamily::Daubechies4, 8).unwrap();
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(!tokenizer.encode(&signal).unwrap().is_empty());
    }

    #[test]
    fn test_wavelet_invalid_config() {
        // Zero levels and zero bits must both be rejected.
        assert!(wavelet(0, WaveletFamily::Haar, 8).is_err());
        assert!(wavelet(1, WaveletFamily::Haar, 0).is_err());
    }

    #[test]
    fn test_fourier_magnitude_only() {
        let tokenizer = fourier(8, true);
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0, 0.0]);

        let tokens = tokenizer.encode(&signal).unwrap();
        assert_eq!(tokens.len(), 8); // one magnitude per bin

        assert_eq!(tokenizer.decode(&tokens).unwrap().len(), 8);
    }

    #[test]
    fn test_fourier_complex() {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let tokens = fourier(4, false).encode(&signal).unwrap();
        assert_eq!(tokens.len(), 8); // 4 bins x (real, imag)
    }

    #[test]
    fn test_dct_basic() {
        let tokenizer = dct(8);

        let tokens = tokenizer.encode(&ramp8()).unwrap();
        assert_eq!(tokens.len(), 8);

        assert_eq!(tokenizer.decode(&tokens).unwrap().len(), 8);
    }

    #[test]
    fn test_dct_compression() {
        // A smooth signal is reduced to its four lowest-frequency coefficients.
        let signal = Array1::from_vec(vec![1.0, 1.1, 1.2, 1.1, 1.0, 0.9, 0.8, 0.9]);
        assert_eq!(dct(4).encode(&signal).unwrap().len(), 4);
    }

    #[test]
    fn test_kmeans_training() {
        let mut tokenizer = kmeans(4, 4, 50, 1e-3).unwrap();

        // Two signals, each a pair of four-sample plateaus.
        let plateaus = |lo: f32, hi: f32| Array1::from_vec([[lo; 4], [hi; 4]].concat());
        let data = vec![plateaus(1.0, 2.0), plateaus(3.0, 4.0)];

        assert!(!tokenizer.is_trained());
        tokenizer.train(&data).unwrap();
        assert!(tokenizer.is_trained());

        assert_eq!(tokenizer.centroids().shape(), &[4, 4]);
    }

    #[test]
    fn test_kmeans_encode_decode() {
        let mut tokenizer = kmeans(8, 4, 100, 1e-4).unwrap();

        let ramp: Vec<f32> = (0..32).map(|x| x as f32).collect();
        let wave: Vec<f32> = ramp.iter().map(|x| x.sin()).collect();
        tokenizer
            .train(&[Array1::from_vec(ramp), Array1::from_vec(wave)])
            .unwrap();

        let tokens = tokenizer.encode(&ramp8()).unwrap();
        assert!(!tokens.is_empty());

        assert!(!tokenizer.decode(&tokens).unwrap().is_empty());
    }

    #[test]
    fn test_kmeans_untrained_error() {
        let tokenizer = KMeansTokenizer::new(KMeansConfig::default()).unwrap();
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(tokenizer.encode(&signal).is_err());
    }

    #[test]
    fn test_kmeans_invalid_config() {
        assert!(kmeans(0, 4, 10, 1e-3).is_err());
    }

    #[test]
    fn test_signal_tokenizer_trait() {
        let tokenizers: Vec<Box<dyn SignalTokenizer>> = vec![
            Box::new(wavelet(1, WaveletFamily::Haar, 8).unwrap()),
            Box::new(fourier(8, true)),
            Box::new(dct(8)),
        ];
        let signal = ramp8();

        for tokenizer in &tokenizers {
            assert!(!tokenizer.encode(&signal).unwrap().is_empty());
            assert!(tokenizer.vocab_size() > 0 || tokenizer.embed_dim() > 0);
        }
    }
}