trustformers 0.1.1

TrustformeRS - Rust port of Hugging Face Transformers
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
use crate::error::{Result, TrustformersError};
use crate::pipeline::{BasePipeline, Pipeline};
use crate::{AutoModel, AutoTokenizer};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;

/// Audio input for speech-to-text pipeline
///
/// The variants cover every way callers can hand audio to
/// `SpeechToTextPipeline`: a path on disk, already-decoded PCM samples,
/// base64-encoded bytes, or raw encoded bytes with an explicit format.
#[derive(Debug, Clone)]
pub enum AudioInput {
    /// File path to audio file
    FilePath(String),
    /// Raw audio samples (f32) with sample rate
    RawAudio { samples: Vec<f32>, sample_rate: u32 },
    /// Base64 encoded audio data (the pipeline assumes WAV once decoded)
    Base64(String),
    /// Audio bytes with format info
    Bytes {
        data: Vec<u8>,
        format: AudioFormat,
        sample_rate: u32,
    },
}

/// Supported audio formats
///
/// Identifies the container/codec of encoded byte input; map a file
/// extension to a variant with [`AudioFormat::from_extension`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AudioFormat {
    Wav,
    Flac,
    Mp3,
    M4a,
    Ogg,
    WebM,
}

impl AudioFormat {
    /// Map a file extension (case-insensitive, without the leading dot)
    /// to its [`AudioFormat`], or `None` if the extension is unrecognized.
    pub fn from_extension(ext: &str) -> Option<Self> {
        let normalized = ext.to_lowercase();
        let format = match normalized.as_str() {
            "wav" => Self::Wav,
            "flac" => Self::Flac,
            "mp3" => Self::Mp3,
            "m4a" => Self::M4a,
            "ogg" => Self::Ogg,
            "webm" => Self::WebM,
            _ => return None,
        };
        Some(format)
    }
}

/// Speech-to-text output
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechToTextOutput {
    /// Transcribed text
    pub text: String,
    /// Confidence score (0.0 to 1.0); `None` when the model provides none
    pub confidence: Option<f32>,
    /// Word-level timestamps, populated only when
    /// `SpeechToTextConfig::return_timestamps` is enabled
    pub word_timestamps: Option<Vec<WordTimestamp>>,
    /// Language detected (if multi-language model)
    pub language: Option<String>,
    /// Processing time in milliseconds (wall-clock, set by the pipeline)
    pub processing_time_ms: Option<u64>,
}

/// Word-level timestamp information
///
/// Times are offsets from the start of the audio clip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordTimestamp {
    pub word: String,
    pub start_time: f64, // seconds
    pub end_time: f64,   // seconds
    pub confidence: f32, // 0.0 to 1.0
}

/// Configuration for speech-to-text processing
///
/// The [`Default`] implementation follows Whisper conventions
/// (16 kHz input, 30-second clips, greedy decoding).
#[derive(Clone, Debug)]
pub struct SpeechToTextConfig {
    /// Target sample rate for audio preprocessing
    pub sample_rate: u32,
    /// Maximum audio duration in seconds
    pub max_duration: Option<f64>,
    /// Return word-level timestamps
    pub return_timestamps: bool,
    /// Target language (for multilingual models); `None` = auto-detect
    pub language: Option<String>,
    /// Task type (transcribe or translate)
    pub task: SpeechTask,
    /// Beam search configuration
    pub num_beams: usize,
    /// Use temperature sampling (0.0 = deterministic)
    pub temperature: f32,
    /// Length penalty for beam search
    pub length_penalty: f32,
    /// Repetition penalty
    pub repetition_penalty: f32,
    /// No repeat n-gram size (0 disables the constraint)
    pub no_repeat_ngram_size: usize,
    /// Chunk length for long audio (in seconds)
    pub chunk_length_s: Option<f64>,
    /// Stride length for overlapping chunks
    pub stride_length_s: Option<f64>,
}

impl Default for SpeechToTextConfig {
    fn default() -> Self {
        Self {
            sample_rate: 16000,       // Whisper default
            max_duration: Some(30.0), // 30 seconds max
            return_timestamps: false,
            language: None, // Auto-detect
            task: SpeechTask::Transcribe,
            num_beams: 1,     // Greedy decoding by default
            temperature: 0.0, // Deterministic
            length_penalty: 1.0,
            repetition_penalty: 1.0,
            no_repeat_ngram_size: 0,
            chunk_length_s: Some(30.0), // 30-second chunks
            stride_length_s: Some(5.0), // 5-second stride
        }
    }
}

/// Speech recognition task type
///
/// `Transcribe` is the default (see `SpeechToTextConfig::default`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SpeechTask {
    /// Transcribe in the same language
    Transcribe,
    /// Translate to English
    Translate,
}

/// Pipeline for speech-to-text tasks (ASR)
///
/// Wraps a model/tokenizer pair plus an audio feature extractor; the
/// extractor sits behind an `Arc` so `Clone`d pipelines share it.
#[derive(Clone)]
pub struct SpeechToTextPipeline {
    base: BasePipeline<AutoModel, AutoTokenizer>,
    config: SpeechToTextConfig,
    feature_extractor: Arc<AudioFeatureExtractor>,
}

impl SpeechToTextPipeline {
    /// Create a new speech-to-text pipeline with the default
    /// (Whisper-style) [`SpeechToTextConfig`].
    ///
    /// # Errors
    /// Propagates any error from constructing the [`AudioFeatureExtractor`].
    pub fn new(model: AutoModel, tokenizer: AutoTokenizer) -> Result<Self> {
        let base = BasePipeline::new(model, tokenizer);
        let config = SpeechToTextConfig::default();
        let feature_extractor = Arc::new(AudioFeatureExtractor::new(config.sample_rate)?);

        Ok(Self {
            base,
            config,
            feature_extractor,
        })
    }

    /// Replace the pipeline configuration.
    ///
    /// If the new configuration uses a different sample rate, the feature
    /// extractor is rebuilt so preprocessing stays consistent with the
    /// configured rate (previously the extractor silently kept the old rate).
    pub fn with_config(mut self, config: SpeechToTextConfig) -> Self {
        if config.sample_rate != self.config.sample_rate {
            if let Ok(extractor) = AudioFeatureExtractor::new(config.sample_rate) {
                self.feature_extractor = Arc::new(extractor);
            }
        }
        self.config = config;
        self
    }

    /// Set target language for transcription (e.g. "en").
    pub fn with_language(mut self, language: String) -> Self {
        self.config.language = Some(language);
        self
    }

    /// Enable or disable word-level timestamps in the output.
    pub fn with_timestamps(mut self, enable: bool) -> Self {
        self.config.return_timestamps = enable;
        self
    }

    /// Set task type (transcribe in-language, or translate to English).
    pub fn with_task(mut self, task: SpeechTask) -> Self {
        self.config.task = task;
        self
    }

    /// Set audio chunk length (seconds) for processing long audio.
    pub fn with_chunk_length(mut self, chunk_length_s: f64) -> Self {
        self.config.chunk_length_s = Some(chunk_length_s);
        self
    }

    /// Transcribe an audio file given its filesystem path.
    pub fn transcribe_file<P: AsRef<Path>>(&self, audio_path: P) -> Result<SpeechToTextOutput> {
        let input = AudioInput::FilePath(audio_path.as_ref().to_string_lossy().to_string());
        self.__call__(input)
    }

    /// Transcribe raw PCM samples captured at `sample_rate`.
    ///
    /// Samples are resampled to the configured rate during preprocessing
    /// when the rates differ.
    pub fn transcribe_samples(
        &self,
        samples: Vec<f32>,
        sample_rate: u32,
    ) -> Result<SpeechToTextOutput> {
        let input = AudioInput::RawAudio {
            samples,
            sample_rate,
        };
        self.__call__(input)
    }

    /// Transcribe a short audio chunk for streaming / real-time use.
    ///
    /// The chunk is assumed to already be at the configured sample rate.
    pub fn transcribe_streaming(&self, audio_chunk: &[f32]) -> Result<SpeechToTextOutput> {
        let input = AudioInput::RawAudio {
            samples: audio_chunk.to_vec(),
            sample_rate: self.config.sample_rate,
        };
        self.__call__(input)
    }

    /// Convert any [`AudioInput`] variant into model-ready [`AudioFeatures`].
    fn preprocess_audio(&self, input: &AudioInput) -> Result<AudioFeatures> {
        match input {
            AudioInput::FilePath(path) => {
                // Load audio file and extract features
                self.feature_extractor.load_and_extract(path)
            },
            AudioInput::RawAudio {
                samples,
                sample_rate,
            } => {
                // Resample only when the incoming rate differs from the
                // configured one.
                let resampled = if *sample_rate != self.config.sample_rate {
                    self.feature_extractor.resample(
                        samples,
                        *sample_rate,
                        self.config.sample_rate,
                    )?
                } else {
                    samples.clone()
                };

                self.feature_extractor.extract_features(&resampled)
            },
            AudioInput::Base64(encoded) => {
                let decoded = base64::decode(encoded).map_err(|e| {
                    TrustformersError::invalid_input_simple(format!(
                        "Failed to decode base64 audio: {}",
                        e
                    ))
                })?;

                // Base64 input carries no format metadata; WAV is assumed.
                self.feature_extractor.decode_and_extract(&decoded, AudioFormat::Wav)
            },
            AudioInput::Bytes {
                data,
                format,
                // NOTE(review): the declared source rate is currently unused —
                // the placeholder decoder ignores it and `resample_to` is a
                // no-op. Revisit once real decoding lands.
                sample_rate: _,
            } => {
                self.feature_extractor
                    .decode_and_extract(data, *format)?
                    .resample_to(self.config.sample_rate)
            },
        }
    }

    /// Decode model output into the final [`SpeechToTextOutput`].
    ///
    /// Placeholder: the tensor and duration are not consumed yet. A real
    /// implementation would decode token IDs with the tokenizer, extract
    /// timestamps, compute confidences, and detect the language.
    fn postprocess_output(
        &self,
        _model_output: &crate::core::tensor::Tensor,
        _audio_duration: f64,
    ) -> Result<SpeechToTextOutput> {
        let text = "Transcribed text placeholder".to_string(); // Simplified
        let confidence = Some(0.95); // Placeholder confidence

        // Fabricated timestamps, emitted only when the caller asked for them.
        let word_timestamps = if self.config.return_timestamps {
            Some(vec![
                WordTimestamp {
                    word: "Transcribed".to_string(),
                    start_time: 0.0,
                    end_time: 0.5,
                    confidence: 0.95,
                },
                WordTimestamp {
                    word: "text".to_string(),
                    start_time: 0.5,
                    end_time: 1.0,
                    confidence: 0.90,
                },
            ])
        } else {
            None
        };

        Ok(SpeechToTextOutput {
            text,
            confidence,
            word_timestamps,
            language: self.config.language.clone(),
            // Overwritten with the measured wall-clock time by `__call__`.
            processing_time_ms: Some(100),
        })
    }
}

impl Pipeline for SpeechToTextPipeline {
    type Input = AudioInput;
    type Output = SpeechToTextOutput;

    /// Run the full ASR pipeline: feature extraction, duration validation,
    /// (placeholder) model inference, and decoding into the final output.
    fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
        let timer = std::time::Instant::now();

        // Turn the raw audio into model-ready features.
        let features = self.preprocess_audio(&input)?;
        let duration = features.duration();

        // Reject clips longer than the configured limit.
        match self.config.max_duration {
            Some(limit) if duration > limit => {
                return Err(TrustformersError::invalid_input_simple(format!(
                    "Audio duration ({:.2}s) exceeds maximum allowed ({:.2}s)",
                    duration, limit
                )));
            },
            _ => {},
        }

        // Model inference is still stubbed out: the feature tensor stands in
        // for the decoder output.
        // let model_output = self.base.model.forward(input_tensor)?;
        let model_output = features.to_tensor()?;

        // Decode, then stamp the measured wall-clock processing time.
        let mut transcript = self.postprocess_output(&model_output, duration)?;
        transcript.processing_time_ms = Some(timer.elapsed().as_millis() as u64);

        Ok(transcript)
    }
}

/// Audio feature extractor for speech models
///
/// Holds the STFT/mel parameters; Whisper defaults are set in `new`.
pub struct AudioFeatureExtractor {
    // Target PCM sample rate (Hz) the extractor operates at.
    sample_rate: u32,
    // FFT window size in samples. NOTE(review): not yet read by the
    // placeholder feature code — will be used once real STFT lands.
    n_fft: usize,
    // Hop between successive frames, in samples.
    hop_length: usize,
    // Number of mel filterbank channels per frame.
    n_mels: usize,
}

impl AudioFeatureExtractor {
    /// Create an extractor producing Whisper-style mel features
    /// (400-sample FFT window, 160-sample hop, 80 mel bins).
    pub fn new(sample_rate: u32) -> Result<Self> {
        Ok(Self {
            sample_rate,
            n_fft: 400,      // Whisper default
            hop_length: 160, // Whisper default
            n_mels: 80,      // Whisper default
        })
    }

    /// Load an audio file and extract features.
    ///
    /// Placeholder: the path is not actually read yet. A real implementation
    /// would decode the file (e.g. with `symphonia` or `rodio`) and call
    /// [`Self::extract_features`] on the decoded samples.
    pub fn load_and_extract(&self, _path: &str) -> Result<AudioFeatures> {
        Ok(AudioFeatures {
            features: vec![vec![0.0; self.n_mels]; 100], // Placeholder mel spectrogram
            sample_rate: self.sample_rate,
            duration_s: 5.0, // Placeholder duration
        })
    }

    /// Extract a (placeholder) mel spectrogram from PCM samples.
    ///
    /// Frame count mirrors an STFT with this extractor's hop length:
    /// `samples.len() / hop_length + 1`. A real implementation would apply
    /// pre-emphasis, STFT, a mel filterbank, and log compression.
    pub fn extract_features(&self, samples: &[f32]) -> Result<AudioFeatures> {
        let duration_s = samples.len() as f64 / self.sample_rate as f64;
        let n_frames = (samples.len() / self.hop_length) + 1;

        Ok(AudioFeatures {
            features: vec![vec![0.0; self.n_mels]; n_frames], // Placeholder
            sample_rate: self.sample_rate,
            duration_s,
        })
    }

    /// Resample PCM audio from `from_rate` to `to_rate`.
    ///
    /// Same-rate input is returned unchanged. Otherwise this uses
    /// nearest-neighbor sample picking — adequate as a placeholder, but a
    /// production implementation should use a band-limited resampler to
    /// avoid aliasing.
    pub fn resample(&self, samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
        if from_rate == to_rate {
            return Ok(samples.to_vec());
        }

        let ratio = to_rate as f64 / from_rate as f64;
        let new_len = (samples.len() as f64 * ratio) as usize;

        // Nearest-neighbor lookup; any out-of-range index is zero-filled.
        // `get` + `unwrap_or` replaces the original bounds-checked branch.
        let resampled = (0..new_len)
            .map(|i| {
                let src_idx = (i as f64 / ratio) as usize;
                samples.get(src_idx).copied().unwrap_or(0.0)
            })
            .collect();

        Ok(resampled)
    }

    /// Decode encoded audio bytes and extract features.
    ///
    /// Placeholder: the bytes and format are not decoded yet — one second of
    /// silence at 16 kHz stands in for the decoded audio. A real
    /// implementation would dispatch to a per-format decoder; the original
    /// `match` had identical arms for every format, so it is collapsed here.
    pub fn decode_and_extract(&self, _data: &[u8], _format: AudioFormat) -> Result<AudioFeatures> {
        self.extract_features(&[0.0; 16000]) // Placeholder
    }
}

/// Audio features (typically mel spectrogram)
///
/// Produced by [`AudioFeatureExtractor`]; consumed via `to_tensor` for
/// model input.
#[derive(Debug)]
pub struct AudioFeatures {
    pub features: Vec<Vec<f32>>, // [time_frames, n_mels]
    pub sample_rate: u32,
    pub duration_s: f64,
}

impl AudioFeatures {
    /// Duration of the source audio in seconds.
    pub fn duration(&self) -> f64 {
        self.duration_s
    }

    /// Convert the feature matrix into a `[1, time_frames, n_mels]` tensor.
    ///
    /// Handles an empty spectrogram (zero frames) without panicking — the
    /// original indexed `features[0]` unconditionally.
    ///
    /// # Errors
    /// Propagates tensor-construction failures.
    pub fn to_tensor(&self) -> Result<crate::core::tensor::Tensor> {
        use crate::core::tensor::Tensor;

        // Row-major flatten: frame 0's mel bins first, then frame 1's, ...
        let flat_features: Vec<f32> = self.features.iter().flatten().copied().collect();
        // Guard against an empty spectrogram instead of indexing `[0]`.
        let n_mels = self.features.first().map_or(0, |frame| frame.len());
        let shape = vec![1, self.features.len(), n_mels]; // [batch, time, features]

        Tensor::from_vec(flat_features, &shape).map_err(Into::into)
    }

    /// Resample the underlying audio to `target_rate`.
    ///
    /// Placeholder: when the rates differ this is currently a no-op — the
    /// features are returned unchanged and `sample_rate` keeps its old
    /// value. Real resampling of already-extracted features would require
    /// re-extraction from the source samples.
    pub fn resample_to(self, target_rate: u32) -> Result<Self> {
        if self.sample_rate == target_rate {
            return Ok(self);
        }

        // Placeholder resampling
        Ok(self)
    }
}

// Minimal stdlib-only base64 decoder standing in for the `base64` crate.
// The previous stub returned `Ok(vec![])` for every input, silently
// discarding the caller's audio; this decodes standard RFC 4648 base64.
mod base64 {
    /// Decode standard-alphabet base64 with optional `=` padding.
    ///
    /// ASCII whitespace is ignored. Returns an error message for invalid
    /// characters, misplaced padding, or an impossible input length.
    pub fn decode(input: &str) -> Result<Vec<u8>, String> {
        // Map one alphabet character to its 6-bit value.
        fn sextet(c: u8) -> Result<u32, String> {
            match c {
                b'A'..=b'Z' => Ok(u32::from(c - b'A')),
                b'a'..=b'z' => Ok(u32::from(c - b'a') + 26),
                b'0'..=b'9' => Ok(u32::from(c - b'0') + 52),
                b'+' => Ok(62),
                b'/' => Ok(63),
                _ => Err(format!("invalid base64 character: {:?}", c as char)),
            }
        }

        // Drop whitespace so wrapped/pretty-printed payloads decode too.
        let bytes: Vec<u8> = input.bytes().filter(|b| !b.is_ascii_whitespace()).collect();

        // Strip and validate trailing padding ('=' only, at most two).
        let trimmed: &[u8] = match bytes.iter().position(|&b| b == b'=') {
            Some(i) => {
                if bytes[i..].iter().any(|&b| b != b'=') || bytes.len() - i > 2 {
                    return Err("invalid base64 padding".to_string());
                }
                &bytes[..i]
            },
            None => bytes.as_slice(),
        };

        // A single leftover sextet cannot encode any whole byte.
        if trimmed.len() % 4 == 1 {
            return Err("invalid base64 length".to_string());
        }

        let mut out = Vec::with_capacity(trimmed.len() * 3 / 4);
        for chunk in trimmed.chunks(4) {
            // Accumulate up to 24 bits, 6 bits per character.
            let mut acc = 0u32;
            for &c in chunk {
                acc = (acc << 6) | sextet(c)?;
            }
            match chunk.len() {
                // Full quantum: three output bytes.
                4 => {
                    out.push((acc >> 16) as u8);
                    out.push((acc >> 8) as u8);
                    out.push(acc as u8);
                },
                // Three chars carry 18 bits: two output bytes.
                3 => {
                    acc <<= 6;
                    out.push((acc >> 16) as u8);
                    out.push((acc >> 8) as u8);
                },
                // Two chars carry 12 bits: one output byte.
                2 => {
                    acc <<= 12;
                    out.push((acc >> 16) as u8);
                },
                // len 1 was rejected above; len 0 cannot come from chunks().
                _ => unreachable!("chunk lengths 0/1 are impossible here"),
            }
        }

        Ok(out)
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for the pure, stdlib-constructible pieces of this module:
    // format detection, config defaults, the placeholder feature extractor,
    // resampling, and the feature/timestamp value types. The pipeline itself
    // needs a model + tokenizer and is not exercised here.
    use super::*;

    // ---- AudioFormat tests ----

    #[test]
    fn test_audio_format_from_extension_wav() {
        let fmt = AudioFormat::from_extension("wav");
        assert!(matches!(fmt, Some(AudioFormat::Wav)));
    }

    #[test]
    fn test_audio_format_from_extension_flac() {
        let fmt = AudioFormat::from_extension("flac");
        assert!(matches!(fmt, Some(AudioFormat::Flac)));
    }

    #[test]
    fn test_audio_format_from_extension_mp3() {
        let fmt = AudioFormat::from_extension("mp3");
        assert!(matches!(fmt, Some(AudioFormat::Mp3)));
    }

    #[test]
    fn test_audio_format_from_extension_case_insensitive() {
        let fmt = AudioFormat::from_extension("WAV");
        assert!(matches!(fmt, Some(AudioFormat::Wav)));
    }

    #[test]
    fn test_audio_format_from_extension_unknown() {
        let fmt = AudioFormat::from_extension("xyz");
        assert!(fmt.is_none());
    }

    #[test]
    fn test_audio_format_all_variants() {
        // Every supported extension must round-trip to Some(format).
        let exts = ["wav", "flac", "mp3", "m4a", "ogg", "webm"];
        for ext in &exts {
            assert!(
                AudioFormat::from_extension(ext).is_some(),
                "missing: {}",
                ext
            );
        }
    }

    // ---- SpeechToTextConfig tests ----

    #[test]
    fn test_config_default_values() {
        // Pins the Whisper-style defaults documented on Default::default.
        let cfg = SpeechToTextConfig::default();
        assert_eq!(cfg.sample_rate, 16000);
        assert_eq!(cfg.max_duration, Some(30.0));
        assert!(!cfg.return_timestamps);
        assert!(cfg.language.is_none());
        assert!(matches!(cfg.task, SpeechTask::Transcribe));
        assert_eq!(cfg.num_beams, 1);
        assert!((cfg.temperature - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_config_chunk_length() {
        let cfg = SpeechToTextConfig::default();
        assert_eq!(cfg.chunk_length_s, Some(30.0));
        assert_eq!(cfg.stride_length_s, Some(5.0));
    }

    // ---- AudioFeatureExtractor tests ----

    #[test]
    fn test_extractor_creates_successfully() {
        let extractor = AudioFeatureExtractor::new(16000).expect("extractor creation succeeded");
        assert_eq!(extractor.sample_rate, 16000);
    }

    #[test]
    fn test_extract_features_duration_calculation() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 16000]; // 1 second of silence
        let features = extractor.extract_features(&samples).expect("ok");
        assert!((features.duration_s - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_extract_features_frame_count() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 1600]; // 0.1 seconds
        let features = extractor.extract_features(&samples).expect("ok");
        // n_frames = (1600 / 160) + 1 = 11
        assert_eq!(features.features.len(), 11);
    }

    #[test]
    fn test_extract_features_mel_dims() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 3200];
        let features = extractor.extract_features(&samples).expect("ok");
        // Each frame should have n_mels = 80 dimensions
        for frame in &features.features {
            assert_eq!(frame.len(), 80);
        }
    }

    // ---- Resampling tests ----

    #[test]
    fn test_resample_same_rate_no_op() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.1_f32, 0.2, 0.3];
        let resampled = extractor.resample(&samples, 16000, 16000).expect("ok");
        assert_eq!(resampled.len(), samples.len());
        for (a, b) in resampled.iter().zip(samples.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_resample_upsample_increases_length() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 100];
        let resampled = extractor.resample(&samples, 8000, 16000).expect("ok");
        assert!(
            resampled.len() > samples.len(),
            "upsampled should be longer: {}",
            resampled.len()
        );
    }

    #[test]
    fn test_resample_downsample_decreases_length() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 200];
        let resampled = extractor.resample(&samples, 16000, 8000).expect("ok");
        assert!(
            resampled.len() < samples.len(),
            "downsampled should be shorter: {}",
            resampled.len()
        );
    }

    // ---- AudioFeatures tests ----

    #[test]
    fn test_audio_features_duration() {
        let af = AudioFeatures {
            features: vec![vec![0.0; 80]; 50],
            sample_rate: 16000,
            duration_s: 3.5,
        };
        assert!((af.duration() - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_audio_features_to_tensor_shape() {
        let n_frames = 10;
        let n_mels = 80;
        let af = AudioFeatures {
            features: vec![vec![0.1; n_mels]; n_frames],
            sample_rate: 16000,
            duration_s: 1.0,
        };
        let tensor = af.to_tensor().expect("tensor creation succeeded");
        let shape = tensor.shape();
        // Expected shape: [1, n_frames, n_mels]
        assert_eq!(shape[0], 1);
        assert_eq!(shape[1], n_frames);
        assert_eq!(shape[2], n_mels);
    }

    #[test]
    fn test_audio_features_resample_to_same_rate() {
        let af = AudioFeatures {
            features: vec![vec![0.0; 80]; 5],
            sample_rate: 16000,
            duration_s: 1.0,
        };
        let result = af.resample_to(16000).expect("ok");
        assert_eq!(result.sample_rate, 16000);
    }

    // ---- WordTimestamp tests ----

    #[test]
    fn test_word_timestamp_time_ordering() {
        let ts = WordTimestamp {
            word: "hello".to_string(),
            start_time: 0.0,
            end_time: 0.5,
            confidence: 0.95,
        };
        assert!(ts.start_time < ts.end_time);
        assert!(ts.confidence >= 0.0 && ts.confidence <= 1.0);
    }

    #[test]
    fn test_word_timestamp_confidence_range() {
        let ts = WordTimestamp {
            word: "world".to_string(),
            start_time: 0.5,
            end_time: 1.0,
            confidence: 0.87,
        };
        assert!(ts.confidence >= 0.0 && ts.confidence <= 1.0);
    }

    // ---- SpeechTask enum ----

    #[test]
    fn test_speech_task_transcribe_variant() {
        let task = SpeechTask::Transcribe;
        assert!(matches!(task, SpeechTask::Transcribe));
    }

    #[test]
    fn test_speech_task_translate_variant() {
        let task = SpeechTask::Translate;
        assert!(matches!(task, SpeechTask::Translate));
    }

    // ---- Feature extraction simulation (MFCC-like) ----

    #[test]
    fn test_frame_level_processing_non_empty() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        // Use LCG to generate deterministic pseudo-random audio samples
        let mut seed = 12345u64;
        let samples: Vec<f32> = (0..4800)
            .map(|_| {
                seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
                ((seed >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0
            })
            .collect();
        let features = extractor.extract_features(&samples).expect("ok");
        assert!(!features.features.is_empty());
        assert_eq!(features.features[0].len(), 80);
    }

    #[test]
    fn test_load_and_extract_returns_features() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let af = extractor.load_and_extract("dummy_path.wav").expect("ok");
        assert!(!af.features.is_empty());
        assert!(af.duration_s > 0.0);
    }

    #[test]
    fn test_decode_and_extract_wav() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let dummy_data = vec![0u8; 512];
        let af = extractor.decode_and_extract(&dummy_data, AudioFormat::Wav).expect("ok");
        assert!(!af.features.is_empty());
    }
}