1use crate::error::{TokenizerError, TokenizerResult};
26use scirs2_core::ndarray::{Array1, Array2};
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use std::path::Path;
30
/// Container for exporting model weights and configuration in a
/// PyTorch-compatible JSON layout (serialized/deserialized via serde).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchCompat {
    /// Named tensors keyed by parameter name (e.g. "encoder").
    pub weights: HashMap<String, TensorInfo>,
    /// Model architecture / hyperparameter description.
    pub config: ModelConfig,
    /// Target PyTorch version string (set to "2.0.0" by `new`).
    pub torch_version: String,
}
44
/// A single stored tensor: its shape, declared element type, and values.
///
/// Values are always held as `f32` regardless of `dtype`; `dtype` only
/// records the intended PyTorch element type for export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Dimension sizes, outermost first.
    pub shape: Vec<usize>,
    /// Intended PyTorch element type (see [`DType`]).
    pub dtype: DType,
    /// Flattened tensor values in logical (row-major) order.
    pub data: Vec<f32>,
}
55
/// Element types supported for PyTorch tensor export.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
    /// 32-bit IEEE 754 float (`torch.float32`).
    Float32,
    /// 16-bit IEEE 754 float (`torch.float16`).
    Float16,
    /// 64-bit IEEE 754 float (`torch.float64`).
    Float64,
    /// 32-bit signed integer (`torch.int32`).
    Int32,
    /// 64-bit signed integer (`torch.int64`).
    Int64,
}
70
71impl DType {
72 pub fn size_bytes(&self) -> usize {
74 match self {
75 DType::Float32 => 4,
76 DType::Float16 => 2,
77 DType::Float64 => 8,
78 DType::Int32 => 4,
79 DType::Int64 => 8,
80 }
81 }
82
83 pub fn torch_name(&self) -> &'static str {
85 match self {
86 DType::Float32 => "torch.float32",
87 DType::Float16 => "torch.float16",
88 DType::Float64 => "torch.float64",
89 DType::Int32 => "torch.int32",
90 DType::Int64 => "torch.int64",
91 }
92 }
93}
94
/// High-level description of the exported model's architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Free-form model identifier (e.g. "continuous_tokenizer").
    pub model_type: String,
    /// Input feature dimensionality.
    pub input_dim: usize,
    /// Output feature dimensionality.
    pub output_dim: usize,
    /// Arbitrary extra hyperparameters, stored as raw JSON values.
    pub hyperparameters: HashMap<String, serde_json::Value>,
}
107
108impl PyTorchCompat {
109 pub fn new(config: ModelConfig) -> Self {
111 Self {
112 weights: HashMap::new(),
113 config,
114 torch_version: "2.0.0".to_string(),
115 }
116 }
117
118 pub fn add_weight(&mut self, name: impl Into<String>, array: &Array2<f32>) {
120 let shape = array.shape().to_vec();
121 let data = array.iter().copied().collect();
122
123 self.weights.insert(
124 name.into(),
125 TensorInfo {
126 shape,
127 dtype: DType::Float32,
128 data,
129 },
130 );
131 }
132
133 pub fn add_weight_1d(&mut self, name: impl Into<String>, array: &Array1<f32>) {
135 let shape = vec![array.len()];
136 let data = array.iter().copied().collect();
137
138 self.weights.insert(
139 name.into(),
140 TensorInfo {
141 shape,
142 dtype: DType::Float32,
143 data,
144 },
145 );
146 }
147
148 pub fn get_weight(&self, name: &str) -> TokenizerResult<Array2<f32>> {
150 let tensor = self
151 .weights
152 .get(name)
153 .ok_or_else(|| TokenizerError::InvalidConfig(format!("Weight '{}' not found", name)))?;
154
155 if tensor.shape.len() != 2 {
156 return Err(TokenizerError::InvalidConfig(format!(
157 "Expected 2D tensor, got {}D",
158 tensor.shape.len()
159 )));
160 }
161
162 Array2::from_shape_vec((tensor.shape[0], tensor.shape[1]), tensor.data.clone())
163 .map_err(|e| TokenizerError::InvalidConfig(format!("Shape mismatch: {}", e)))
164 }
165
166 pub fn get_weight_1d(&self, name: &str) -> TokenizerResult<Array1<f32>> {
168 let tensor = self
169 .weights
170 .get(name)
171 .ok_or_else(|| TokenizerError::InvalidConfig(format!("Weight '{}' not found", name)))?;
172
173 if tensor.shape.len() != 1 {
174 return Err(TokenizerError::InvalidConfig(format!(
175 "Expected 1D tensor, got {}D",
176 tensor.shape.len()
177 )));
178 }
179
180 Ok(Array1::from_vec(tensor.data.clone()))
181 }
182
183 pub fn save<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
185 let json = serde_json::to_string_pretty(self).map_err(|e| {
186 TokenizerError::SerializationError(format!("JSON serialization failed: {}", e))
187 })?;
188
189 std::fs::write(path, json).map_err(TokenizerError::IoError)?;
190
191 Ok(())
192 }
193
194 pub fn load<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
196 let json = std::fs::read_to_string(path).map_err(TokenizerError::IoError)?;
197
198 serde_json::from_str(&json).map_err(|e| {
199 TokenizerError::SerializationError(format!("JSON deserialization failed: {}", e))
200 })
201 }
202
203 pub fn weight_names(&self) -> Vec<String> {
205 self.weights.keys().cloned().collect()
206 }
207
208 pub fn num_parameters(&self) -> usize {
210 self.weights.values().map(|t| t.data.len()).sum()
211 }
212}
213
/// Descriptive metadata for an audio signal (sample rate, format, tags).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioMetadata {
    /// Samples per second (validated non-zero by `new`).
    pub sample_rate: u32,
    /// Bits per sample; `new` accepts only 8, 16, 24, or 32.
    pub bit_depth: u8,
    /// Channel count; `new` accepts only 1..=8.
    pub num_channels: u8,
    /// Number of samples, if known (set by `from_signal`).
    pub num_samples: Option<usize>,
    /// Duration in seconds, if known (set by `from_signal`).
    pub duration_secs: Option<f64>,
    /// Free-form key/value tags (artist, title, ...).
    pub tags: HashMap<String, String>,
}
233
234impl AudioMetadata {
235 pub fn new(sample_rate: u32, bit_depth: u8, num_channels: u8) -> TokenizerResult<Self> {
237 if sample_rate == 0 {
239 return Err(TokenizerError::InvalidConfig(
240 "Sample rate must be positive".into(),
241 ));
242 }
243
244 if ![8, 16, 24, 32].contains(&bit_depth) {
245 return Err(TokenizerError::InvalidConfig(format!(
246 "Invalid bit depth: {}. Must be 8, 16, 24, or 32",
247 bit_depth
248 )));
249 }
250
251 if num_channels == 0 || num_channels > 8 {
252 return Err(TokenizerError::InvalidConfig(format!(
253 "Invalid number of channels: {}. Must be 1-8",
254 num_channels
255 )));
256 }
257
258 Ok(Self {
259 sample_rate,
260 bit_depth,
261 num_channels,
262 num_samples: None,
263 duration_secs: None,
264 tags: HashMap::new(),
265 })
266 }
267
268 pub fn from_signal(
270 signal: &Array1<f32>,
271 sample_rate: u32,
272 bit_depth: u8,
273 num_channels: u8,
274 ) -> TokenizerResult<Self> {
275 let mut metadata = Self::new(sample_rate, bit_depth, num_channels)?;
276 metadata.num_samples = Some(signal.len());
277 metadata.duration_secs = Some(signal.len() as f64 / sample_rate as f64);
278 Ok(metadata)
279 }
280
281 pub fn set_tag(&mut self, key: impl Into<String>, value: impl Into<String>) {
283 self.tags.insert(key.into(), value.into());
284 }
285
286 pub fn get_tag(&self, key: &str) -> Option<&str> {
288 self.tags.get(key).map(|s| s.as_str())
289 }
290
291 pub fn nyquist_frequency(&self) -> f32 {
293 self.sample_rate as f32 / 2.0
294 }
295
296 pub fn duration(&self) -> Option<f64> {
298 self.duration_secs
299 .or_else(|| self.num_samples.map(|n| n as f64 / self.sample_rate as f64))
300 }
301
302 pub fn to_wav_metadata(&self) -> String {
304 serde_json::to_string_pretty(self).unwrap_or_default()
305 }
306
307 pub fn from_wav_metadata(json: &str) -> TokenizerResult<Self> {
309 serde_json::from_str(json).map_err(|e| {
310 TokenizerError::SerializationError(format!("Failed to parse metadata: {}", e))
311 })
312 }
313}
314
/// Settings describing an ONNX export of the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxConfig {
    /// ONNX opset version to target (default 14).
    pub opset_version: i64,
    /// Names of the graph's input tensors.
    pub input_names: Vec<String>,
    /// Names of the graph's output tensors.
    pub output_names: Vec<String>,
    /// Per-tensor list of axes that are dynamic (e.g. axis 0 = batch).
    pub dynamic_axes: HashMap<String, Vec<i64>>,
}
327
328impl Default for OnnxConfig {
329 fn default() -> Self {
330 Self {
331 opset_version: 14,
332 input_names: vec!["input".to_string()],
333 output_names: vec!["output".to_string()],
334 dynamic_axes: HashMap::new(),
335 }
336 }
337}
338
339impl OnnxConfig {
340 pub fn for_tokenizer(_input_dim: usize, _output_dim: usize) -> Self {
342 let mut config = Self::default();
343
344 let mut dynamic_axes = HashMap::new();
346 dynamic_axes.insert("input".to_string(), vec![0]); dynamic_axes.insert("output".to_string(), vec![0]); config.dynamic_axes = dynamic_axes;
349
350 config
351 }
352
353 pub fn to_json(&self) -> TokenizerResult<String> {
355 serde_json::to_string_pretty(self).map_err(|e| {
356 TokenizerError::SerializationError(format!("ONNX config serialization failed: {}", e))
357 })
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    // Adding one 2-D weight is reflected in the map and parameter count.
    #[test]
    fn test_pytorch_compat_basic() {
        let cfg = ModelConfig {
            model_type: "continuous_tokenizer".to_string(),
            input_dim: 128,
            output_dim: 256,
            hyperparameters: HashMap::new(),
        };
        let mut compat = PyTorchCompat::new(cfg);

        let encoder = Array2::from_shape_fn((128, 256), |(i, j)| (i + j) as f32 * 0.01);
        compat.add_weight("encoder", &encoder);

        assert_eq!(compat.weights.len(), 1);
        assert_eq!(compat.num_parameters(), 128 * 256);
    }

    // A stored weight round-trips through get_weight with shape and
    // values intact.
    #[test]
    fn test_pytorch_compat_roundtrip() {
        let cfg = ModelConfig {
            model_type: "test".to_string(),
            input_dim: 10,
            output_dim: 20,
            hyperparameters: HashMap::new(),
        };
        let mut compat = PyTorchCompat::new(cfg);

        let original = Array2::from_shape_fn((10, 20), |(i, j)| (i * 20 + j) as f32);
        compat.add_weight("test_weight", &original);

        let retrieved = compat.get_weight("test_weight").unwrap();
        assert_eq!(retrieved.shape(), &[10, 20]);
        assert_eq!(retrieved[[0, 0]], 0.0);
        assert_eq!(retrieved[[9, 19]], 199.0);
    }

    // Valid parameters produce metadata with the expected fields.
    #[test]
    fn test_audio_metadata_creation() {
        let meta = AudioMetadata::new(44100, 16, 2).unwrap();
        assert_eq!(meta.sample_rate, 44100);
        assert_eq!(meta.bit_depth, 16);
        assert_eq!(meta.num_channels, 2);
        assert_eq!(meta.nyquist_frequency(), 22050.0);
    }

    // Each validation rule rejects out-of-range input.
    #[test]
    fn test_audio_metadata_validation() {
        // Zero sample rate.
        assert!(AudioMetadata::new(0, 16, 2).is_err());
        // Unsupported bit depth.
        assert!(AudioMetadata::new(44100, 13, 2).is_err());
        // Channel count out of 1..=8.
        assert!(AudioMetadata::new(44100, 16, 0).is_err());
        assert!(AudioMetadata::new(44100, 16, 9).is_err());
    }

    // from_signal fills in sample count and a one-second duration for
    // 44100 samples at 44.1 kHz.
    #[test]
    fn test_audio_metadata_from_signal() {
        let signal = Array1::from_vec(vec![0.0; 44100]);
        let meta = AudioMetadata::from_signal(&signal, 44100, 16, 1).unwrap();

        assert_eq!(meta.num_samples, Some(44100));
        assert!((meta.duration().unwrap() - 1.0).abs() < 1e-6);
    }

    // Tags are stored and retrieved by key; missing keys yield None.
    #[test]
    fn test_audio_metadata_tags() {
        let mut meta = AudioMetadata::new(44100, 16, 2).unwrap();
        meta.set_tag("artist", "Test Artist");
        meta.set_tag("title", "Test Title");

        assert_eq!(meta.get_tag("artist"), Some("Test Artist"));
        assert_eq!(meta.get_tag("title"), Some("Test Title"));
        assert_eq!(meta.get_tag("nonexistent"), None);
    }

    // JSON serialization round-trips the core fields.
    #[test]
    fn test_audio_metadata_serialization() {
        let meta = AudioMetadata::new(48000, 24, 2).unwrap();
        let json = meta.to_wav_metadata();
        let restored = AudioMetadata::from_wav_metadata(&json).unwrap();

        assert_eq!(restored.sample_rate, 48000);
        assert_eq!(restored.bit_depth, 24);
        assert_eq!(restored.num_channels, 2);
    }

    // Element sizes and PyTorch names for the dtype enum.
    #[test]
    fn test_dtype_properties() {
        assert_eq!(DType::Float32.size_bytes(), 4);
        assert_eq!(DType::Float16.size_bytes(), 2);
        assert_eq!(DType::Float64.size_bytes(), 8);

        assert_eq!(DType::Float32.torch_name(), "torch.float32");
        assert_eq!(DType::Int64.torch_name(), "torch.int64");
    }

    // Default ONNX config: opset 14, single input/output.
    #[test]
    fn test_onnx_config_default() {
        let cfg = OnnxConfig::default();
        assert_eq!(cfg.opset_version, 14);
        assert_eq!(cfg.input_names, vec!["input"]);
        assert_eq!(cfg.output_names, vec!["output"]);
    }

    // Tokenizer config adds dynamic axes for both tensors.
    #[test]
    fn test_onnx_config_for_tokenizer() {
        let cfg = OnnxConfig::for_tokenizer(128, 256);
        assert_eq!(cfg.opset_version, 14);
        assert!(cfg.dynamic_axes.contains_key("input"));
        assert!(cfg.dynamic_axes.contains_key("output"));
    }

    // Serialized config contains the expected JSON keys.
    #[test]
    fn test_onnx_config_serialization() {
        let cfg = OnnxConfig::for_tokenizer(100, 200);
        let json = cfg.to_json().unwrap();
        assert!(json.contains("\"opset_version\""));
        assert!(json.contains("\"input_names\""));
    }
}
488}