kizzasi_tokenizer/
compat.rs

1//! Compatibility and interoperability with other frameworks
2//!
3//! This module provides utilities for importing/exporting tokenizer weights
4//! and configurations to/from other ML frameworks and standard formats.
5//!
6//! # Supported Formats
7//!
//! - **PyTorch**: Import/export weights as JSON in a PyTorch-compatible layout (binary safetensors output is not yet implemented)
9//! - **ONNX**: Export tokenizer operations for ONNX runtime
10//! - **Audio Metadata**: WAV/FLAC metadata for signal properties
11//!
12//! # Examples
13//!
14//! ```rust,ignore
15//! use kizzasi_tokenizer::compat::{PyTorchCompat, AudioMetadata};
16//!
17//! // Export to PyTorch format
18//! let pytorch_compat = PyTorchCompat::from_tokenizer(&tokenizer)?;
//! pytorch_compat.save("model.json")?;
20//!
21//! // Add audio metadata
//! let metadata = AudioMetadata::new(44100, 16, 1)?;
23//! ```
24
25use crate::error::{TokenizerError, TokenizerResult};
26use scirs2_core::ndarray::{Array1, Array2};
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use std::path::Path;
30
/// PyTorch-compatible weight export/import
///
/// Provides utilities to save and load tokenizer weights in a format
/// compatible with PyTorch models using safetensors.
///
/// NOTE(review): `save`/`load` in the impl below actually serialize this
/// struct as serde JSON, not the binary safetensors format — the wording
/// above reflects the intended interchange target, not the on-disk bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchCompat {
    /// Model weights as named tensors, keyed by parameter name.
    pub weights: HashMap<String, TensorInfo>,
    /// Model configuration carried alongside the weights.
    pub config: ModelConfig,
    /// PyTorch version string this export targets ("2.0.0" from `new`).
    pub torch_version: String,
}
44
/// Tensor information for PyTorch compatibility
///
/// Stores a tensor as a shape plus flattened data. Note that `data` is
/// always `Vec<f32>` regardless of `dtype`; `dtype` records the intended
/// target type, not the in-memory representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Tensor shape (one entry per dimension).
    pub shape: Vec<usize>,
    /// Declared tensor data type for the target framework.
    pub dtype: DType,
    /// Flattened tensor data, stored as f32 independent of `dtype`.
    pub data: Vec<f32>,
}
55
/// Data type enum for cross-framework compatibility
///
/// Scalar element types shared by the PyTorch/ONNX export paths; see
/// `size_bytes` and `torch_name` for per-variant metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
    /// 32-bit floating point
    Float32,
    /// 16-bit floating point (half precision)
    Float16,
    /// 64-bit floating point
    Float64,
    /// 32-bit integer
    Int32,
    /// 64-bit integer
    Int64,
}
70
71impl DType {
72    /// Get the size in bytes of this dtype
73    pub fn size_bytes(&self) -> usize {
74        match self {
75            DType::Float32 => 4,
76            DType::Float16 => 2,
77            DType::Float64 => 8,
78            DType::Int32 => 4,
79            DType::Int64 => 8,
80        }
81    }
82
83    /// Get the PyTorch dtype string
84    pub fn torch_name(&self) -> &'static str {
85        match self {
86            DType::Float32 => "torch.float32",
87            DType::Float16 => "torch.float16",
88            DType::Float64 => "torch.float64",
89            DType::Int32 => "torch.int32",
90            DType::Int64 => "torch.int64",
91        }
92    }
93}
94
/// Model configuration for framework compatibility
///
/// Minimal architecture description serialized alongside the weights so
/// importers can reconstruct layer dimensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model type identifier (free-form string set by the exporter).
    pub model_type: String,
    /// Input dimension
    pub input_dim: usize,
    /// Output/embedding dimension
    pub output_dim: usize,
    /// Additional hyperparameters as arbitrary JSON values.
    pub hyperparameters: HashMap<String, serde_json::Value>,
}
107
108impl PyTorchCompat {
109    /// Create a new PyTorch compatibility wrapper
110    pub fn new(config: ModelConfig) -> Self {
111        Self {
112            weights: HashMap::new(),
113            config,
114            torch_version: "2.0.0".to_string(),
115        }
116    }
117
118    /// Add a weight tensor
119    pub fn add_weight(&mut self, name: impl Into<String>, array: &Array2<f32>) {
120        let shape = array.shape().to_vec();
121        let data = array.iter().copied().collect();
122
123        self.weights.insert(
124            name.into(),
125            TensorInfo {
126                shape,
127                dtype: DType::Float32,
128                data,
129            },
130        );
131    }
132
133    /// Add a 1D weight tensor (bias, etc.)
134    pub fn add_weight_1d(&mut self, name: impl Into<String>, array: &Array1<f32>) {
135        let shape = vec![array.len()];
136        let data = array.iter().copied().collect();
137
138        self.weights.insert(
139            name.into(),
140            TensorInfo {
141                shape,
142                dtype: DType::Float32,
143                data,
144            },
145        );
146    }
147
148    /// Get a weight tensor as Array2
149    pub fn get_weight(&self, name: &str) -> TokenizerResult<Array2<f32>> {
150        let tensor = self
151            .weights
152            .get(name)
153            .ok_or_else(|| TokenizerError::InvalidConfig(format!("Weight '{}' not found", name)))?;
154
155        if tensor.shape.len() != 2 {
156            return Err(TokenizerError::InvalidConfig(format!(
157                "Expected 2D tensor, got {}D",
158                tensor.shape.len()
159            )));
160        }
161
162        Array2::from_shape_vec((tensor.shape[0], tensor.shape[1]), tensor.data.clone())
163            .map_err(|e| TokenizerError::InvalidConfig(format!("Shape mismatch: {}", e)))
164    }
165
166    /// Get a 1D weight tensor
167    pub fn get_weight_1d(&self, name: &str) -> TokenizerResult<Array1<f32>> {
168        let tensor = self
169            .weights
170            .get(name)
171            .ok_or_else(|| TokenizerError::InvalidConfig(format!("Weight '{}' not found", name)))?;
172
173        if tensor.shape.len() != 1 {
174            return Err(TokenizerError::InvalidConfig(format!(
175                "Expected 1D tensor, got {}D",
176                tensor.shape.len()
177            )));
178        }
179
180        Ok(Array1::from_vec(tensor.data.clone()))
181    }
182
183    /// Save to safetensors format (PyTorch compatible)
184    pub fn save<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
185        let json = serde_json::to_string_pretty(self).map_err(|e| {
186            TokenizerError::SerializationError(format!("JSON serialization failed: {}", e))
187        })?;
188
189        std::fs::write(path, json).map_err(TokenizerError::IoError)?;
190
191        Ok(())
192    }
193
194    /// Load from safetensors format
195    pub fn load<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
196        let json = std::fs::read_to_string(path).map_err(TokenizerError::IoError)?;
197
198        serde_json::from_str(&json).map_err(|e| {
199            TokenizerError::SerializationError(format!("JSON deserialization failed: {}", e))
200        })
201    }
202
203    /// Export weight names for ONNX mapping
204    pub fn weight_names(&self) -> Vec<String> {
205        self.weights.keys().cloned().collect()
206    }
207
208    /// Get total number of parameters
209    pub fn num_parameters(&self) -> usize {
210        self.weights.values().map(|t| t.data.len()).sum()
211    }
212}
213
/// Audio metadata for signal processing
///
/// Stores standard audio properties that can be embedded in WAV/FLAC files
/// or used for proper signal reconstruction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioMetadata {
    /// Sample rate in Hz (validated non-zero in `new`)
    pub sample_rate: u32,
    /// Bit depth; `new` accepts only 8, 16, 24, or 32
    pub bit_depth: u8,
    /// Number of channels (1=mono, 2=stereo); `new` accepts 1-8
    pub num_channels: u8,
    /// Total number of samples; `None` until set (e.g. by `from_signal`)
    pub num_samples: Option<usize>,
    /// Duration in seconds; `None` until set (e.g. by `from_signal`)
    pub duration_secs: Option<f64>,
    /// Additional free-form metadata tags (artist, title, ...)
    pub tags: HashMap<String, String>,
}
233
234impl AudioMetadata {
235    /// Create new audio metadata
236    pub fn new(sample_rate: u32, bit_depth: u8, num_channels: u8) -> TokenizerResult<Self> {
237        // Validate parameters
238        if sample_rate == 0 {
239            return Err(TokenizerError::InvalidConfig(
240                "Sample rate must be positive".into(),
241            ));
242        }
243
244        if ![8, 16, 24, 32].contains(&bit_depth) {
245            return Err(TokenizerError::InvalidConfig(format!(
246                "Invalid bit depth: {}. Must be 8, 16, 24, or 32",
247                bit_depth
248            )));
249        }
250
251        if num_channels == 0 || num_channels > 8 {
252            return Err(TokenizerError::InvalidConfig(format!(
253                "Invalid number of channels: {}. Must be 1-8",
254                num_channels
255            )));
256        }
257
258        Ok(Self {
259            sample_rate,
260            bit_depth,
261            num_channels,
262            num_samples: None,
263            duration_secs: None,
264            tags: HashMap::new(),
265        })
266    }
267
268    /// Create metadata from signal length
269    pub fn from_signal(
270        signal: &Array1<f32>,
271        sample_rate: u32,
272        bit_depth: u8,
273        num_channels: u8,
274    ) -> TokenizerResult<Self> {
275        let mut metadata = Self::new(sample_rate, bit_depth, num_channels)?;
276        metadata.num_samples = Some(signal.len());
277        metadata.duration_secs = Some(signal.len() as f64 / sample_rate as f64);
278        Ok(metadata)
279    }
280
281    /// Set a metadata tag
282    pub fn set_tag(&mut self, key: impl Into<String>, value: impl Into<String>) {
283        self.tags.insert(key.into(), value.into());
284    }
285
286    /// Get a metadata tag
287    pub fn get_tag(&self, key: &str) -> Option<&str> {
288        self.tags.get(key).map(|s| s.as_str())
289    }
290
291    /// Compute Nyquist frequency
292    pub fn nyquist_frequency(&self) -> f32 {
293        self.sample_rate as f32 / 2.0
294    }
295
296    /// Get duration in seconds
297    pub fn duration(&self) -> Option<f64> {
298        self.duration_secs
299            .or_else(|| self.num_samples.map(|n| n as f64 / self.sample_rate as f64))
300    }
301
302    /// Export as WAV-compatible metadata JSON
303    pub fn to_wav_metadata(&self) -> String {
304        serde_json::to_string_pretty(self).unwrap_or_default()
305    }
306
307    /// Import from WAV-compatible metadata JSON
308    pub fn from_wav_metadata(json: &str) -> TokenizerResult<Self> {
309        serde_json::from_str(json).map_err(|e| {
310            TokenizerError::SerializationError(format!("Failed to parse metadata: {}", e))
311        })
312    }
313}
314
/// ONNX export configuration
///
/// Describes the graph-level export settings: opset, named inputs/outputs,
/// and which axes of each named tensor are variable-length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxConfig {
    /// ONNX opset version (14 by default)
    pub opset_version: i64,
    /// Input tensor names
    pub input_names: Vec<String>,
    /// Output tensor names
    pub output_names: Vec<String>,
    /// Dynamic axes: tensor name -> indices of variable-length dimensions
    pub dynamic_axes: HashMap<String, Vec<i64>>,
}
327
328impl Default for OnnxConfig {
329    fn default() -> Self {
330        Self {
331            opset_version: 14,
332            input_names: vec!["input".to_string()],
333            output_names: vec!["output".to_string()],
334            dynamic_axes: HashMap::new(),
335        }
336    }
337}
338
339impl OnnxConfig {
340    /// Create ONNX config for a tokenizer
341    pub fn for_tokenizer(_input_dim: usize, _output_dim: usize) -> Self {
342        let mut config = Self::default();
343
344        // Add dynamic batch dimension
345        let mut dynamic_axes = HashMap::new();
346        dynamic_axes.insert("input".to_string(), vec![0]); // Batch dimension
347        dynamic_axes.insert("output".to_string(), vec![0]); // Batch dimension
348        config.dynamic_axes = dynamic_axes;
349
350        config
351    }
352
353    /// Export configuration as JSON
354    pub fn to_json(&self) -> TokenizerResult<String> {
355        serde_json::to_string_pretty(self).map_err(|e| {
356            TokenizerError::SerializationError(format!("ONNX config serialization failed: {}", e))
357        })
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the minimal `ModelConfig` used across the PyTorch tests.
    fn make_config(model_type: &str, input_dim: usize, output_dim: usize) -> ModelConfig {
        ModelConfig {
            model_type: model_type.to_string(),
            input_dim,
            output_dim,
            hyperparameters: HashMap::new(),
        }
    }

    #[test]
    fn test_pytorch_compat_basic() {
        let mut wrapper = PyTorchCompat::new(make_config("continuous_tokenizer", 128, 256));

        let encoder = Array2::from_shape_fn((128, 256), |(i, j)| (i + j) as f32 * 0.01);
        wrapper.add_weight("encoder", &encoder);

        assert_eq!(wrapper.weights.len(), 1);
        assert_eq!(wrapper.num_parameters(), 128 * 256);
    }

    #[test]
    fn test_pytorch_compat_roundtrip() {
        let mut wrapper = PyTorchCompat::new(make_config("test", 10, 20));
        let stored = Array2::from_shape_fn((10, 20), |(i, j)| (i * 20 + j) as f32);
        wrapper.add_weight("test_weight", &stored);

        let restored = wrapper.get_weight("test_weight").unwrap();
        assert_eq!(restored.shape(), &[10, 20]);
        assert_eq!(restored[[0, 0]], 0.0);
        assert_eq!(restored[[9, 19]], 199.0);
    }

    #[test]
    fn test_audio_metadata_creation() {
        let meta = AudioMetadata::new(44100, 16, 2).unwrap();
        assert_eq!(meta.sample_rate, 44100);
        assert_eq!(meta.bit_depth, 16);
        assert_eq!(meta.num_channels, 2);
        assert_eq!(meta.nyquist_frequency(), 22050.0);
    }

    #[test]
    fn test_audio_metadata_validation() {
        // Zero sample rate, unsupported bit depth, and out-of-range channel
        // counts must all be rejected.
        assert!(AudioMetadata::new(0, 16, 2).is_err());
        assert!(AudioMetadata::new(44100, 13, 2).is_err());
        assert!(AudioMetadata::new(44100, 16, 0).is_err());
        assert!(AudioMetadata::new(44100, 16, 9).is_err());
    }

    #[test]
    fn test_audio_metadata_from_signal() {
        // Exactly one second of mono audio at 44.1 kHz.
        let one_second = Array1::from_vec(vec![0.0; 44100]);
        let meta = AudioMetadata::from_signal(&one_second, 44100, 16, 1).unwrap();

        assert_eq!(meta.num_samples, Some(44100));
        assert!((meta.duration().unwrap() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_audio_metadata_tags() {
        let mut meta = AudioMetadata::new(44100, 16, 2).unwrap();
        meta.set_tag("artist", "Test Artist");
        meta.set_tag("title", "Test Title");

        assert_eq!(meta.get_tag("artist"), Some("Test Artist"));
        assert_eq!(meta.get_tag("title"), Some("Test Title"));
        assert_eq!(meta.get_tag("nonexistent"), None);
    }

    #[test]
    fn test_audio_metadata_serialization() {
        let original = AudioMetadata::new(48000, 24, 2).unwrap();
        let roundtripped = AudioMetadata::from_wav_metadata(&original.to_wav_metadata()).unwrap();

        assert_eq!(roundtripped.sample_rate, 48000);
        assert_eq!(roundtripped.bit_depth, 24);
        assert_eq!(roundtripped.num_channels, 2);
    }

    #[test]
    fn test_dtype_properties() {
        for (dtype, bytes) in [
            (DType::Float32, 4),
            (DType::Float16, 2),
            (DType::Float64, 8),
        ] {
            assert_eq!(dtype.size_bytes(), bytes);
        }

        assert_eq!(DType::Float32.torch_name(), "torch.float32");
        assert_eq!(DType::Int64.torch_name(), "torch.int64");
    }

    #[test]
    fn test_onnx_config_default() {
        let cfg = OnnxConfig::default();
        assert_eq!(cfg.opset_version, 14);
        assert_eq!(cfg.input_names, vec!["input"]);
        assert_eq!(cfg.output_names, vec!["output"]);
    }

    #[test]
    fn test_onnx_config_for_tokenizer() {
        let cfg = OnnxConfig::for_tokenizer(128, 256);
        assert_eq!(cfg.opset_version, 14);
        assert!(cfg.dynamic_axes.contains_key("input"));
        assert!(cfg.dynamic_axes.contains_key("output"));
    }

    #[test]
    fn test_onnx_config_serialization() {
        let json = OnnxConfig::for_tokenizer(100, 200).to_json().unwrap();
        assert!(json.contains("\"opset_version\""));
        assert!(json.contains("\"input_names\""));
    }
}