//! Serialization and persistence for signal tokenizers.
//!
//! Provides serde-friendly config structs for each tokenizer type and a
//! `TokenizerIO` trait for saving/loading them as JSON or raw bytes.

use crate::error::{TokenizerError, TokenizerResult};
use crate::{ContinuousTokenizer, LinearQuantizer, MuLawCodec, SignalTokenizer};
#[cfg(feature = "vqvae")]
use crate::{VQConfig, VQVAETokenizer};
use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;

/// Serializable configuration for a `ContinuousTokenizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContinuousTokenizerConfig {
    pub input_dim: usize,
    pub embed_dim: usize,
    #[serde(with = "array2_serde")]
    pub encoder: Array2<f32>,
    #[serde(with = "array2_serde")]
    pub decoder: Array2<f32>,
}

impl ContinuousTokenizerConfig {
    /// Capture the tokenizer's current dimensions and weights.
    pub fn from_tokenizer(tokenizer: &ContinuousTokenizer) -> Self {
        Self {
            input_dim: tokenizer.input_dim(),
            embed_dim: tokenizer.embed_dim(),
            encoder: tokenizer.encoder().clone(),
            decoder: tokenizer.decoder().clone(),
        }
    }

    /// Rebuild a tokenizer from the stored configuration.
    pub fn to_tokenizer(&self) -> TokenizerResult<ContinuousTokenizer> {
        let mut tokenizer = ContinuousTokenizer::new(self.input_dim, self.embed_dim);
        tokenizer.set_encoder(self.encoder.clone())?;
        tokenizer.set_decoder(self.decoder.clone())?;
        Ok(tokenizer)
    }
}

/// Serializable configuration for a `LinearQuantizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearQuantizerConfig {
    pub min: f32,
    pub max: f32,
    pub bits: u8,
}

impl LinearQuantizerConfig {
    /// Capture the quantizer's value range and bit depth.
    pub fn from_quantizer(quantizer: &LinearQuantizer) -> Self {
        let (min, max) = quantizer.range();
        Self {
            min,
            max,
            bits: quantizer.bits(),
        }
    }

    /// Rebuild a quantizer from the stored configuration.
    pub fn to_quantizer(&self) -> TokenizerResult<LinearQuantizer> {
        LinearQuantizer::new(self.min, self.max, self.bits)
    }
}

/// Serializable configuration for a `MuLawCodec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MuLawCodecConfig {
    pub mu: f32,
    pub bits: u8,
}

impl MuLawCodecConfig {
    /// Capture the codec's mu parameter and bit depth.
    pub fn from_codec(codec: &MuLawCodec) -> Self {
        Self {
            mu: codec.mu(),
            bits: codec.bits(),
        }
    }

    /// Rebuild a codec from the stored configuration.
    pub fn to_codec(&self) -> MuLawCodec {
        MuLawCodec::with_mu(self.mu, self.bits)
    }
}

/// Serializable configuration for a multi-scale tokenizer.
///
/// Note: unlike the other configs in this module, no
/// `from_tokenizer`/`to_tokenizer` conversion helpers are provided yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiScaleTokenizerConfig {
    pub input_dim: usize,
    pub embed_dim_per_level: usize,
    pub downsample_factors: Vec<usize>,
    #[serde(with = "vec_array2_serde")]
    pub encoders: Vec<Array2<f32>>,
    #[serde(with = "vec_array2_serde")]
    pub decoders: Vec<Array2<f32>>,
}

/// Serializable configuration for a `VQVAETokenizer`.
#[cfg(feature = "vqvae")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VQVAETokenizerConfig {
    pub input_dim: usize,
    pub vq_config: VQConfig,
    #[serde(with = "array2_serde")]
    pub encoder: Array2<f32>,
    #[serde(with = "array2_serde")]
    pub decoder: Array2<f32>,
    #[serde(with = "array2_serde")]
    pub codebook: Array2<f32>,
}

#[cfg(feature = "vqvae")]
impl VQVAETokenizerConfig {
    /// Capture the tokenizer's weights, codebook, and VQ settings.
    pub fn from_tokenizer(tokenizer: &VQVAETokenizer) -> Self {
        Self {
            input_dim: tokenizer.encoder().shape()[0],
            vq_config: VQConfig {
                codebook_size: tokenizer.quantizer().codebook_size(),
                embed_dim: tokenizer.quantizer().embed_dim(),
                // These training hyperparameters are not exposed by the
                // tokenizer, so fixed values are hard-coded here.
                commitment_beta: 0.25,
                ema_decay: 0.99,
                epsilon: 1e-5,
                use_ema: true,
            },
            encoder: tokenizer.encoder().clone(),
            decoder: tokenizer.decoder().clone(),
            codebook: tokenizer.quantizer().codebook().clone(),
        }
    }

    /// Rebuild a tokenizer from the stored configuration.
    ///
    /// Note: only the encoder and decoder weights are restored here; the
    /// stored `codebook` is not written back into the quantizer.
    pub fn to_tokenizer(&self) -> TokenizerResult<VQVAETokenizer> {
        let mut tokenizer = VQVAETokenizer::new(self.input_dim, self.vq_config.clone());
        tokenizer.set_encoder(self.encoder.clone())?;
        tokenizer.set_decoder(self.decoder.clone())?;
        Ok(tokenizer)
    }
}

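/// Serde adapter for `Array2<f32>` fields.
///
/// An array is stored as a wrapper holding its shape plus the data
/// flattened in row-major order. As an illustration (hypothetical values;
/// exact float formatting depends on the serializer), a 2x2 identity
/// matrix round-trips through JSON as:
///
/// ```json
/// { "shape": [2, 2], "data": [1.0, 0.0, 0.0, 1.0] }
/// ```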
mod array2_serde {
    use scirs2_core::ndarray::Array2;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    #[derive(Serialize, Deserialize)]
    struct Array2Data {
        shape: Vec<usize>,
        data: Vec<f32>,
    }

    pub fn serialize<S>(array: &Array2<f32>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let shape = array.shape().to_vec();
        let data = array.iter().cloned().collect();
        let wrapper = Array2Data { shape, data };
        wrapper.serialize(serializer)
    }

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Array2<f32>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wrapper = Array2Data::deserialize(deserializer)?;
        if wrapper.shape.len() != 2 {
            return Err(serde::de::Error::custom("Expected 2D array"));
        }
        Array2::from_shape_vec((wrapper.shape[0], wrapper.shape[1]), wrapper.data)
            .map_err(serde::de::Error::custom)
    }
}

/// Serde adapter for `Vec<Array2<f32>>` fields, using the same
/// shape-plus-flat-data wrapper as `array2_serde` for each element.
mod vec_array2_serde {
    use scirs2_core::ndarray::Array2;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    #[derive(Serialize, Deserialize)]
    struct Array2Data {
        shape: Vec<usize>,
        data: Vec<f32>,
    }

    pub fn serialize<S>(arrays: &[Array2<f32>], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let wrappers: Vec<Array2Data> = arrays
            .iter()
            .map(|array| Array2Data {
                shape: array.shape().to_vec(),
                data: array.iter().cloned().collect(),
            })
            .collect();
        wrappers.serialize(serializer)
    }

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<Array2<f32>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wrappers: Vec<Array2Data> = Vec::deserialize(deserializer)?;
        wrappers
            .into_iter()
            .map(|wrapper| {
                if wrapper.shape.len() != 2 {
                    return Err(serde::de::Error::custom("Expected 2D array"));
                }
                Array2::from_shape_vec((wrapper.shape[0], wrapper.shape[1]), wrapper.data)
                    .map_err(serde::de::Error::custom)
            })
            .collect()
    }
}

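/// Unified save/load interface for tokenizers, in either pretty-printed
/// JSON or raw serialized bytes.
///
/// A minimal usage sketch (marked `ignore` rather than compiled as a doc
/// test; the file name is illustrative and the `?` operator assumes a
/// `TokenizerResult` context):
///
/// ```ignore
/// let tokenizer = ContinuousTokenizer::new(10, 20);
/// tokenizer.save_json("tokenizer.json")?;
/// let restored = ContinuousTokenizer::load_json("tokenizer.json")?;
/// assert_eq!(tokenizer.embed_dim(), restored.embed_dim());
/// ```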
pub trait TokenizerIO: Sized {
    /// Save as pretty-printed JSON.
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()>;

    /// Load from a JSON file.
    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self>;

    /// Save as raw serialized bytes (currently JSON-encoded, without
    /// pretty-printing).
    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()>;

    /// Load from a file written by `save_binary`.
    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self>;
}

impl TokenizerIO for ContinuousTokenizer {
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = ContinuousTokenizerConfig::from_tokenizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let writer = BufWriter::new(file);
        serde_json::to_writer_pretty(writer, &config).map_err(|e| {
            TokenizerError::encoding("serialization", format!("JSON serialization failed: {}", e))
        })
    }

    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let config: ContinuousTokenizerConfig = serde_json::from_reader(reader).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("JSON deserialization failed: {}", e),
            )
        })?;
        config.to_tokenizer()
    }

    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = ContinuousTokenizerConfig::from_tokenizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let encoded = serde_json::to_vec(&config).map_err(|e| {
            TokenizerError::encoding(
                "serialization",
                format!("Binary serialization failed: {}", e),
            )
        })?;
        writer.write_all(&encoded).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to write file: {}", e))
        })
    }

    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to read file: {}", e))
        })?;
        let config: ContinuousTokenizerConfig = serde_json::from_slice(&buffer).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("Binary deserialization failed: {}", e),
            )
        })?;
        config.to_tokenizer()
    }
}

impl TokenizerIO for LinearQuantizer {
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = LinearQuantizerConfig::from_quantizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let writer = BufWriter::new(file);
        serde_json::to_writer_pretty(writer, &config).map_err(|e| {
            TokenizerError::encoding("serialization", format!("JSON serialization failed: {}", e))
        })
    }

    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let config: LinearQuantizerConfig = serde_json::from_reader(reader).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("JSON deserialization failed: {}", e),
            )
        })?;
        config.to_quantizer()
    }

    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = LinearQuantizerConfig::from_quantizer(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let encoded = serde_json::to_vec(&config).map_err(|e| {
            TokenizerError::encoding(
                "serialization",
                format!("Binary serialization failed: {}", e),
            )
        })?;
        writer.write_all(&encoded).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to write file: {}", e))
        })
    }

    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to read file: {}", e))
        })?;
        let config: LinearQuantizerConfig = serde_json::from_slice(&buffer).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("Binary deserialization failed: {}", e),
            )
        })?;
        config.to_quantizer()
    }
}

impl TokenizerIO for MuLawCodec {
    fn save_json<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = MuLawCodecConfig::from_codec(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let writer = BufWriter::new(file);
        serde_json::to_writer_pretty(writer, &config).map_err(|e| {
            TokenizerError::encoding("serialization", format!("JSON serialization failed: {}", e))
        })
    }

    fn load_json<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let config: MuLawCodecConfig = serde_json::from_reader(reader).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("JSON deserialization failed: {}", e),
            )
        })?;
        Ok(config.to_codec())
    }

    fn save_binary<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let config = MuLawCodecConfig::from_codec(self);
        let file = File::create(path).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to create file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let encoded = serde_json::to_vec(&config).map_err(|e| {
            TokenizerError::encoding(
                "serialization",
                format!("Binary serialization failed: {}", e),
            )
        })?;
        writer.write_all(&encoded).map_err(|e| {
            TokenizerError::encoding("serialization", format!("Failed to write file: {}", e))
        })
    }

    fn load_binary<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let file = File::open(path).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to open file: {}", e))
        })?;
        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            TokenizerError::decoding("deserialization", format!("Failed to read file: {}", e))
        })?;
        let config: MuLawCodecConfig = serde_json::from_slice(&buffer).map_err(|e| {
            TokenizerError::decoding(
                "deserialization",
                format!("Binary deserialization failed: {}", e),
            )
        })?;
        Ok(config.to_codec())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SignalTokenizer;
    use std::env;

    #[test]
    fn test_continuous_tokenizer_json_roundtrip() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_continuous.json");

        tokenizer.save_json(&path).unwrap();
        let loaded = ContinuousTokenizer::load_json(&path).unwrap();

        assert_eq!(tokenizer.input_dim(), loaded.input_dim());
        assert_eq!(tokenizer.embed_dim(), loaded.embed_dim());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_continuous_tokenizer_binary_roundtrip() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_continuous.bin");

        tokenizer.save_binary(&path).unwrap();
        let loaded = ContinuousTokenizer::load_binary(&path).unwrap();

        assert_eq!(tokenizer.input_dim(), loaded.input_dim());
        assert_eq!(tokenizer.embed_dim(), loaded.embed_dim());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_linear_quantizer_json_roundtrip() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_linear.json");

        quantizer.save_json(&path).unwrap();
        let loaded = LinearQuantizer::load_json(&path).unwrap();

        assert_eq!(quantizer.range(), loaded.range());
        assert_eq!(quantizer.bits(), loaded.bits());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_mulaw_codec_json_roundtrip() {
        let codec = MuLawCodec::new(8);

        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_mulaw.json");

        codec.save_json(&path).unwrap();
        let loaded = MuLawCodec::load_json(&path).unwrap();

        assert_eq!(codec.mu(), loaded.mu());
        assert_eq!(codec.bits(), loaded.bits());

        let _ = std::fs::remove_file(&path);
    }
}