kizzasi_tokenizer/serde_utils.rs

//! Serialization and deserialization utilities for tokenizers
//!
//! Provides methods to save and load tokenizer weights and configurations
//! in JSON and binary formats.
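//!
//! # Example
//!
//! A minimal sketch of the intended round-trip (assumes `ContinuousTokenizer`
//! and `TokenizerIO` are reachable from the crate root; adjust the paths to
//! match the actual re-exports):
//!
//! ```ignore
//! use kizzasi_tokenizer::{ContinuousTokenizer, TokenizerIO};
//!
//! let tokenizer = ContinuousTokenizer::new(10, 20);
//! tokenizer.save_json("tokenizer.json").unwrap();
//! let restored = ContinuousTokenizer::load_json("tokenizer.json").unwrap();
//! assert_eq!(tokenizer.embed_dim(), restored.embed_dim());
//! ```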

use crate::error::{TokenizerError, TokenizerResult};
use crate::{ContinuousTokenizer, LinearQuantizer, MuLawCodec, SignalTokenizer};
#[cfg(feature = "vqvae")]
use crate::{VQConfig, VQVAETokenizer};
use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;

/// Serializable configuration for ContinuousTokenizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContinuousTokenizerConfig {
    pub input_dim: usize,
    pub embed_dim: usize,
    #[serde(with = "array2_serde")]
    pub encoder: Array2<f32>,
    #[serde(with = "array2_serde")]
    pub decoder: Array2<f32>,
}

impl ContinuousTokenizerConfig {
    /// Convert from ContinuousTokenizer
    pub fn from_tokenizer(tokenizer: &ContinuousTokenizer) -> Self {
        Self {
            input_dim: tokenizer.input_dim(),
            embed_dim: tokenizer.embed_dim(),
            encoder: tokenizer.encoder().clone(),
            decoder: tokenizer.decoder().clone(),
        }
    }

    /// Convert to ContinuousTokenizer
    pub fn to_tokenizer(&self) -> TokenizerResult<ContinuousTokenizer> {
        let mut tokenizer = ContinuousTokenizer::new(self.input_dim, self.embed_dim);
        tokenizer.set_encoder(self.encoder.clone())?;
        tokenizer.set_decoder(self.decoder.clone())?;
        Ok(tokenizer)
    }
}

/// Serializable configuration for LinearQuantizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearQuantizerConfig {
    pub min: f32,
    pub max: f32,
    pub bits: u8,
}

impl LinearQuantizerConfig {
    /// Convert from LinearQuantizer
    pub fn from_quantizer(quantizer: &LinearQuantizer) -> Self {
        let (min, max) = quantizer.range();
        Self {
            min,
            max,
            bits: quantizer.bits(),
        }
    }

    /// Convert to LinearQuantizer
    pub fn to_quantizer(&self) -> TokenizerResult<LinearQuantizer> {
        LinearQuantizer::new(self.min, self.max, self.bits)
    }
}

/// Serializable configuration for MuLawCodec
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MuLawCodecConfig {
    pub mu: f32,
    pub bits: u8,
}

impl MuLawCodecConfig {
    /// Convert from MuLawCodec
    pub fn from_codec(codec: &MuLawCodec) -> Self {
        Self {
            mu: codec.mu(),
            bits: codec.bits(),
        }
    }

    /// Convert to MuLawCodec
    pub fn to_codec(&self) -> MuLawCodec {
        MuLawCodec::with_mu(self.mu, self.bits)
    }
}

/// Serializable configuration for MultiScaleTokenizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiScaleTokenizerConfig {
    pub input_dim: usize,
    pub embed_dim_per_level: usize,
    pub downsample_factors: Vec<usize>,
    #[serde(with = "vec_array2_serde")]
    pub encoders: Vec<Array2<f32>>,
    #[serde(with = "vec_array2_serde")]
    pub decoders: Vec<Array2<f32>>,
}

/// Serializable configuration for VQVAETokenizer
#[cfg(feature = "vqvae")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VQVAETokenizerConfig {
    pub input_dim: usize,
    pub vq_config: VQConfig,
    #[serde(with = "array2_serde")]
    pub encoder: Array2<f32>,
    #[serde(with = "array2_serde")]
    pub decoder: Array2<f32>,
    #[serde(with = "array2_serde")]
    pub codebook: Array2<f32>,
}

#[cfg(feature = "vqvae")]
impl VQVAETokenizerConfig {
    /// Convert from VQVAETokenizer
    pub fn from_tokenizer(tokenizer: &VQVAETokenizer) -> Self {
        Self {
            input_dim: tokenizer.encoder().shape()[0],
            vq_config: VQConfig {
                codebook_size: tokenizer.quantizer().codebook_size(),
                embed_dim: tokenizer.quantizer().embed_dim(),
                // The quantizer does not expose its training hyperparameters,
                // so the library defaults are assumed here.
                commitment_beta: 0.25,
                ema_decay: 0.99,
                epsilon: 1e-5,
                use_ema: true,
            },
            encoder: tokenizer.encoder().clone(),
            decoder: tokenizer.decoder().clone(),
            codebook: tokenizer.quantizer().codebook().clone(),
        }
    }

    /// Convert to VQVAETokenizer
    pub fn to_tokenizer(&self) -> TokenizerResult<VQVAETokenizer> {
        let mut tokenizer = VQVAETokenizer::new(self.input_dim, self.vq_config.clone());
        tokenizer.set_encoder(self.encoder.clone())?;
        tokenizer.set_decoder(self.decoder.clone())?;
        // Limitation: VectorQuantizer has no codebook setter yet, so the
        // serialized codebook is not restored; the returned tokenizer keeps
        // its freshly initialized codebook.
        Ok(tokenizer)
    }
}

/// Custom serde module for `Array2<f32>`
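///
/// An array is serialized as `{ "shape": [rows, cols], "data": [...] }`,
/// with elements flattened in row-major (logical iteration) order.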
mod array2_serde {
    use scirs2_core::ndarray::Array2;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    #[derive(Serialize, Deserialize)]
    struct Array2Data {
        shape: Vec<usize>,
        data: Vec<f32>,
    }

    pub fn serialize<S>(array: &Array2<f32>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let shape = array.shape().to_vec();
        let data = array.iter().copied().collect();
        let wrapper = Array2Data { shape, data };
        wrapper.serialize(serializer)
    }

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Array2<f32>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wrapper = Array2Data::deserialize(deserializer)?;
        if wrapper.shape.len() != 2 {
            return Err(serde::de::Error::custom("Expected 2D array"));
        }
        Array2::from_shape_vec((wrapper.shape[0], wrapper.shape[1]), wrapper.data)
            .map_err(serde::de::Error::custom)
    }
}

/// Custom serde module for `Vec<Array2<f32>>`
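///
/// Each element uses the same `{ "shape", "data" }` layout as `array2_serde`.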
mod vec_array2_serde {
    use scirs2_core::ndarray::Array2;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    #[derive(Serialize, Deserialize)]
    struct Array2Data {
        shape: Vec<usize>,
        data: Vec<f32>,
    }

    pub fn serialize<S>(arrays: &[Array2<f32>], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let wrappers: Vec<Array2Data> = arrays
            .iter()
            .map(|array| Array2Data {
                shape: array.shape().to_vec(),
                data: array.iter().copied().collect(),
            })
            .collect();
        wrappers.serialize(serializer)
    }

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<Array2<f32>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wrappers: Vec<Array2Data> = Vec::deserialize(deserializer)?;
        wrappers
            .into_iter()
            .map(|wrapper| {
                if wrapper.shape.len() != 2 {
                    return Err(serde::de::Error::custom("Expected 2D array"));
                }
                Array2::from_shape_vec((wrapper.shape[0], wrapper.shape[1]), wrapper.data)
                    .map_err(serde::de::Error::custom)
            })
            .collect()
    }
}

/// Trait for saving and loading tokenizers
pub trait TokenizerIO: Sized {
    /// Save to a JSON file
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()>;

    /// Load from a JSON file
    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self>;

    /// Save to a binary file
    ///
    /// The implementations in this module currently write compact JSON bytes;
    /// a dedicated binary codec (e.g. bincode) could be substituted without
    /// changing this interface.
    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()>;

    /// Load from a binary file (same format as `save_binary`)
    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self>;
}

// Implement TokenizerIO for ContinuousTokenizer
impl TokenizerIO for ContinuousTokenizer {
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = ContinuousTokenizerConfig::from_tokenizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, &config).map_err(|e| {
            TokenizerError::encoding("serialization", format!("JSON serialization failed: {}", e))
        })?;
        // Flush explicitly so buffered write errors surface here instead of
        // being silently swallowed when the BufWriter is dropped.
        writer.flush().map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to flush file: {}", e))
        })
    }

    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let config: ContinuousTokenizerConfig = serde_json::from_reader(reader).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("JSON deserialization failed: {}", e),
            )
        })?;
        config.to_tokenizer()
    }

    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = ContinuousTokenizerConfig::from_tokenizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let encoded = serde_json::to_vec(&config).map_err(|e| {
            TokenizerError::encoding(
                "serialization",
                format!("Binary serialization failed: {}", e),
            )
        })?;
        writer.write_all(&encoded).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to write file: {}", e))
        })?;
        writer.flush().map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to flush file: {}", e))
        })
    }

    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to read file: {}", e))
        })?;
        let config: ContinuousTokenizerConfig = serde_json::from_slice(&buffer).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("Binary deserialization failed: {}", e),
            )
        })?;
        config.to_tokenizer()
    }
}

// Implement TokenizerIO for LinearQuantizer
impl TokenizerIO for LinearQuantizer {
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = LinearQuantizerConfig::from_quantizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, &config).map_err(|e| {
            TokenizerError::encoding("serialization", format!("JSON serialization failed: {}", e))
        })?;
        writer.flush().map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to flush file: {}", e))
        })
    }

    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let config: LinearQuantizerConfig = serde_json::from_reader(reader).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("JSON deserialization failed: {}", e),
            )
        })?;
        config.to_quantizer()
    }

    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = LinearQuantizerConfig::from_quantizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let encoded = serde_json::to_vec(&config).map_err(|e| {
            TokenizerError::encoding(
                "serialization",
                format!("Binary serialization failed: {}", e),
            )
        })?;
        writer.write_all(&encoded).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to write file: {}", e))
        })?;
        writer.flush().map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to flush file: {}", e))
        })
    }

    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to read file: {}", e))
        })?;
        let config: LinearQuantizerConfig = serde_json::from_slice(&buffer).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("Binary deserialization failed: {}", e),
            )
        })?;
        config.to_quantizer()
    }
}

// Implement TokenizerIO for MuLawCodec
impl TokenizerIO for MuLawCodec {
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = MuLawCodecConfig::from_codec(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, &config).map_err(|e| {
            TokenizerError::encoding("serialization", format!("JSON serialization failed: {}", e))
        })?;
        writer.flush().map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to flush file: {}", e))
        })
    }

    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let config: MuLawCodecConfig = serde_json::from_reader(reader).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("JSON deserialization failed: {}", e),
            )
        })?;
        Ok(config.to_codec())
    }

    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = MuLawCodecConfig::from_codec(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let encoded = serde_json::to_vec(&config).map_err(|e| {
            TokenizerError::encoding(
                "serialization",
                format!("Binary serialization failed: {}", e),
            )
        })?;
        writer.write_all(&encoded).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to write file: {}", e))
        })?;
        writer.flush().map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to flush file: {}", e))
        })
    }

    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to read file: {}", e))
        })?;
        let config: MuLawCodecConfig = serde_json::from_slice(&buffer).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("Binary deserialization failed: {}", e),
            )
        })?;
        Ok(config.to_codec())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SignalTokenizer;
    use std::env;

    #[test]
    fn test_continuous_tokenizer_json_roundtrip() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_continuous.json");

        tokenizer.save_json(&path).unwrap();
        let loaded = ContinuousTokenizer::load_json(&path).unwrap();

        assert_eq!(tokenizer.input_dim(), loaded.input_dim());
        assert_eq!(tokenizer.embed_dim(), loaded.embed_dim());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_continuous_tokenizer_binary_roundtrip() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_continuous.bin");

        tokenizer.save_binary(&path).unwrap();
        let loaded = ContinuousTokenizer::load_binary(&path).unwrap();

        assert_eq!(tokenizer.input_dim(), loaded.input_dim());
        assert_eq!(tokenizer.embed_dim(), loaded.embed_dim());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_linear_quantizer_json_roundtrip() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_linear.json");

        quantizer.save_json(&path).unwrap();
        let loaded = LinearQuantizer::load_json(&path).unwrap();

        assert_eq!(quantizer.range(), loaded.range());
        assert_eq!(quantizer.bits(), loaded.bits());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }
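
    // Binary counterpart of the JSON round-trip above; exercises the same
    // accessors (`range`, `bits`) through `save_binary`/`load_binary`.
    #[test]
    fn test_linear_quantizer_binary_roundtrip() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_linear.bin");

        quantizer.save_binary(&path).unwrap();
        let loaded = LinearQuantizer::load_binary(&path).unwrap();

        assert_eq!(quantizer.range(), loaded.range());
        assert_eq!(quantizer.bits(), loaded.bits());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }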

    #[test]
    fn test_mulaw_codec_json_roundtrip() {
        let codec = MuLawCodec::new(8);

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_mulaw.json");

        codec.save_json(&path).unwrap();
        let loaded = MuLawCodec::load_json(&path).unwrap();

        assert_eq!(codec.mu(), loaded.mu());
        assert_eq!(codec.bits(), loaded.bits());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }
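
    // A focused check of the custom serde module: round-trips a 2x2 matrix
    // through JSON via a local wrapper struct and verifies that shape and
    // element order survive.
    #[derive(Serialize, Deserialize)]
    struct ArrayWrapper {
        #[serde(with = "super::array2_serde")]
        array: Array2<f32>,
    }

    #[test]
    fn test_array2_serde_roundtrip() {
        let array = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let wrapper = ArrayWrapper {
            array: array.clone(),
        };

        let json = serde_json::to_string(&wrapper).unwrap();
        let restored: ArrayWrapper = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.array, array);
    }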
}