1#[cfg(feature = "native")]
4#[path = "ffi.rs"]
5mod ffi;
6
7#[cfg(feature = "native")]
8use std::ffi::CStr;
9#[cfg(any(feature = "native", test))]
10use std::ffi::CString;
11use std::fmt::{Display, Formatter};
12#[cfg(feature = "native")]
13use std::fs::File;
14#[cfg(any(feature = "native", test))]
15use std::fs::{self, OpenOptions};
16#[cfg(any(feature = "native", test))]
17use std::io::Write;
18#[cfg(feature = "native")]
19use std::io::{BufWriter, Read};
20use std::path::{Path, PathBuf};
21#[cfg(any(feature = "native", test))]
22use std::thread;
23#[cfg(any(feature = "native", test))]
24use std::time::{Duration, Instant};
25
26use serde::{Deserialize, Serialize};
27#[cfg(feature = "native")]
28use sha2::{Digest, Sha256};
29
30#[derive(
31 Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Default,
32)]
33pub enum WhisperCppModel {
35 #[serde(rename = "tiny.en")]
36 TinyEn,
38 #[serde(rename = "tiny")]
39 Tiny,
41 #[serde(rename = "base.en")]
42 #[default]
43 BaseEn,
45 #[serde(rename = "base")]
46 Base,
48 #[serde(rename = "small.en")]
49 SmallEn,
51 #[serde(rename = "small")]
52 Small,
54 #[serde(rename = "medium.en")]
55 MediumEn,
57 #[serde(rename = "medium")]
58 Medium,
60 #[serde(rename = "large-v1")]
61 LargeV1,
63 #[serde(rename = "large-v2")]
64 LargeV2,
66 #[serde(rename = "large-v3")]
67 LargeV3,
69 #[serde(rename = "large-v3-turbo")]
70 LargeV3Turbo,
72}
73
74impl WhisperCppModel {
75 pub const ALL: [Self; 12] = [
77 Self::TinyEn,
78 Self::Tiny,
79 Self::BaseEn,
80 Self::Base,
81 Self::SmallEn,
82 Self::Small,
83 Self::MediumEn,
84 Self::Medium,
85 Self::LargeV1,
86 Self::LargeV2,
87 Self::LargeV3,
88 Self::LargeV3Turbo,
89 ];
90
91 pub fn id(self) -> &'static str {
93 match self {
94 Self::TinyEn => "tiny.en",
95 Self::Tiny => "tiny",
96 Self::BaseEn => "base.en",
97 Self::Base => "base",
98 Self::SmallEn => "small.en",
99 Self::Small => "small",
100 Self::MediumEn => "medium.en",
101 Self::Medium => "medium",
102 Self::LargeV1 => "large-v1",
103 Self::LargeV2 => "large-v2",
104 Self::LargeV3 => "large-v3",
105 Self::LargeV3Turbo => "large-v3-turbo",
106 }
107 }
108
109 pub fn file_name(self) -> String {
111 format!("ggml-{}.bin", self.id())
112 }
113
114 pub fn download_url(self) -> String {
116 format!(
117 "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/{}",
118 self.file_name()
119 )
120 }
121
122 pub fn checksum_sha256(self) -> &'static str {
124 match self {
125 Self::TinyEn => "0d686a2a6a22b02da2ef3101d4c86e68461363a623c58f27f81b1b2d36b42317",
126 Self::Tiny => "518970a29bedb265f23ac48d486ddbc63bedffd90967b10140ae5ac61243acf3",
127 Self::BaseEn => "a03779c86df3323075f5e796cb2ce5029f00ec8869eee3fdfb897afe36c6d002",
128 Self::Base => "2f62d18b50c3f3feafbf990eec23a93d319660b1efbdd3fff55e52b7cde2e374",
129 Self::SmallEn => "0d57184d34ae7d736e5bb2db5bf83debe730bd53dcefa235a0979b9dcfd33fb3",
130 Self::Small => "edd29d67e70b000132af65205b99bb774b77abc13d10103e14f80ce2242913e1",
131 Self::MediumEn => "a163589aa264d5188df3b05ed4eac56bfd97e26910f207809d869f7e99886fd2",
132 Self::Medium => "d3d5696e6a3e0ca2aa08eb31cad208ffa1e87b3cc341f59e628fbdcf8122de9b",
133 Self::LargeV1 => "cbcb187d1e1abe979d33636cdc63381de20738eeda0885c39440b086e184248a",
134 Self::LargeV2 => "c6d6d3dcebc5e0074175386e17eba305fc5cc7d3d5dff3ecfd11e8f2bd4222d7",
135 Self::LargeV3 => "766d11cebbdf5a67c179c5774e2642b609e35e1a30240e7b559d5647c655b0a4",
136 Self::LargeV3Turbo => {
137 "5a4b65b05933d70ce9d5aa6265eb128fa5eba38f6fee40836fdedc4d2fde42ad"
138 }
139 }
140 }
141
142 pub fn multilingual(self) -> bool {
144 !matches!(
145 self,
146 Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn
147 )
148 }
149}
150
151impl Display for WhisperCppModel {
152 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
153 f.write_str(self.id())
154 }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
158pub struct WhisperCppConfig {
160 #[serde(default)]
161 pub model: WhisperCppModel,
163 pub language: Option<String>,
165 #[serde(default)]
166 pub translate: bool,
168 pub threads: Option<usize>,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct WhisperCppSegment {
175 pub index: u64,
177 pub start_seconds: Option<f64>,
179 pub end_seconds: Option<f64>,
181 pub text: String,
183 pub confidence: Option<f32>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188pub struct WhisperCppTranscription {
190 pub text: Option<String>,
192 pub language: Option<String>,
194 pub segments: Vec<WhisperCppSegment>,
196 pub source: Option<String>,
198}
199
200#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
201#[serde(rename_all = "snake_case")]
202pub enum WhisperCppPhase {
204 Preparing,
206 DownloadingModel,
208 LoadingModel,
210 Transcribing,
212}
213
214impl WhisperCppPhase {
215 pub fn as_str(self) -> &'static str {
217 match self {
218 Self::Preparing => "preparing",
219 Self::DownloadingModel => "downloading_model",
220 Self::LoadingModel => "loading_model",
221 Self::Transcribing => "transcribing",
222 }
223 }
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
227pub struct WhisperCppProgressEvent {
229 pub phase: WhisperCppPhase,
231 pub message: String,
233 pub progress: Option<f32>,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
238pub struct WhisperCppModelStatus {
240 pub model: WhisperCppModel,
242 pub cached: bool,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
247pub struct WhisperCppCatalog {
249 pub default_model: WhisperCppModel,
251 pub models: Vec<WhisperCppModelStatus>,
253}
254
255#[derive(Debug, thiserror::Error)]
256pub enum WhisperCppError {
258 #[error("I/O error: {0}")]
259 Io(#[from] std::io::Error),
261 #[error("wave input error: {0}")]
262 Wav(#[from] hound::Error),
264 #[error("network error: {0}")]
265 Http(String),
267 #[error("invalid input: {0}")]
268 InvalidInput(String),
270 #[error("unsupported language `{0}`")]
271 UnsupportedLanguage(String),
273 #[error("downloaded model `{model}` failed checksum verification")]
274 InvalidChecksum {
276 model: WhisperCppModel,
278 },
279 #[error("failed to initialize whisper.cpp from `{0}`")]
280 Initialization(String),
282 #[error("whisper.cpp inference failed for `{0}`")]
283 Inference(String),
285 #[error("invalid utf-8 returned by whisper.cpp")]
286 InvalidUtf8,
288}
289
290pub type Result<T> = std::result::Result<T, WhisperCppError>;
292
293type OwnedProgressCallback = dyn FnMut(WhisperCppProgressEvent) + 'static;
294
/// On-disk location under which model files are kept.
#[derive(Clone)]
pub struct ModelStore {
    /// Root directory of the store.
    root: PathBuf,
}
300
301impl Default for ModelStore {
302 fn default() -> Self {
303 Self {
304 root: cache_root().join("whisper-cpp"),
305 }
306 }
307}
308
309impl ModelStore {
310 pub fn new(root: PathBuf) -> Self {
312 Self { root }
313 }
314
315 pub fn models_dir(&self) -> PathBuf {
317 self.root.join("models")
318 }
319
320 pub fn model_path(&self, model: WhisperCppModel) -> PathBuf {
322 self.models_dir().join(model.file_name())
323 }
324
325 pub fn lock_path(&self, model: WhisperCppModel) -> PathBuf {
327 self.models_dir()
328 .join(format!("{}.lock", model.file_name()))
329 }
330
331 pub fn catalog(&self) -> WhisperCppCatalog {
333 WhisperCppCatalog {
334 default_model: WhisperCppModel::default(),
335 models: WhisperCppModel::ALL
336 .into_iter()
337 .map(|model| WhisperCppModelStatus {
338 model,
339 cached: self.model_path(model).is_file(),
340 })
341 .collect(),
342 }
343 }
344
345 #[cfg(feature = "native")]
346 fn ensure_model(
347 &self,
348 model: WhisperCppModel,
349 progress: &mut ProgressSink<'_>,
350 ) -> Result<PathBuf> {
351 fs::create_dir_all(self.models_dir())?;
352 let model_path = self.model_path(model);
353 if model_path.is_file() {
354 return Ok(model_path);
355 }
356
357 let _lock = FileLock::acquire(self.lock_path(model))?;
358 if model_path.is_file() {
359 return Ok(model_path);
360 }
361
362 progress.emit(
363 WhisperCppPhase::DownloadingModel,
364 format!("downloading whisper.cpp model `{model}`"),
365 Some(0.0),
366 );
367
368 let temp_path = model_path.with_extension("bin.part");
369 if temp_path.exists() {
370 let _ = fs::remove_file(&temp_path);
371 }
372
373 let response = ureq::get(&model.download_url())
374 .call()
375 .map_err(|error| WhisperCppError::Http(error.to_string()))?;
376 let total_bytes = response
377 .header("Content-Length")
378 .and_then(|value| value.parse::<u64>().ok());
379 let mut reader = response.into_reader();
380 let mut file = BufWriter::new(File::create(&temp_path)?);
381 let mut hasher = Sha256::new();
382 let mut downloaded = 0_u64;
383 let mut buffer = [0_u8; 64 * 1024];
384
385 loop {
386 let read = reader
387 .read(&mut buffer)
388 .map_err(|error| WhisperCppError::Http(error.to_string()))?;
389 if read == 0 {
390 break;
391 }
392 file.write_all(&buffer[..read])?;
393 hasher.update(&buffer[..read]);
394 downloaded += read as u64;
395 let fraction =
396 total_bytes.map(|total| (downloaded as f32 / total as f32).clamp(0.0, 1.0));
397 progress.emit(
398 WhisperCppPhase::DownloadingModel,
399 format!("downloading whisper.cpp model `{model}`"),
400 fraction,
401 );
402 }
403 file.flush()?;
404
405 let checksum = format!("{:x}", hasher.finalize());
406 if checksum != model.checksum_sha256() {
407 let _ = fs::remove_file(&temp_path);
408 return Err(WhisperCppError::InvalidChecksum { model });
409 }
410
411 fs::rename(temp_path, &model_path)?;
412 Ok(model_path)
413 }
414}
415
416pub struct WhisperCppTranscriber {
418 config: WhisperCppConfig,
419 store: ModelStore,
420 progress: Option<Box<OwnedProgressCallback>>,
421}
422
423impl WhisperCppTranscriber {
424 pub fn new(config: WhisperCppConfig) -> Self {
426 Self {
427 config,
428 store: ModelStore::default(),
429 progress: None,
430 }
431 }
432
433 pub fn with_model_store(mut self, store: ModelStore) -> Self {
435 self.store = store;
436 self
437 }
438
439 pub fn on_progress<F>(mut self, callback: F) -> Self
441 where
442 F: FnMut(WhisperCppProgressEvent) + 'static,
443 {
444 self.progress = Some(Box::new(callback));
445 self
446 }
447
448 pub fn transcribe_file(&mut self, input: &Path) -> Result<WhisperCppTranscription> {
450 let store = self.store.clone();
451 let config = self.config.clone();
452 let mut progress = ProgressSink::new(self.progress_deref_mut());
453 transcribe_impl(&store, &config, input, &mut progress)
454 }
455
456 pub fn transcribe_file_with_progress(
458 &mut self,
459 input: &Path,
460 progress: &mut dyn FnMut(WhisperCppProgressEvent),
461 ) -> Result<WhisperCppTranscription> {
462 let mut progress = ProgressSink::new(Some(progress));
463 transcribe_impl(&self.store, &self.config, input, &mut progress)
464 }
465
466 fn progress_deref_mut(&mut self) -> Option<&mut dyn FnMut(WhisperCppProgressEvent)> {
467 self.progress
468 .as_mut()
469 .map(|callback| callback.as_mut() as &mut dyn FnMut(WhisperCppProgressEvent))
470 }
471}
472
473pub fn transcription_catalog() -> WhisperCppCatalog {
475 ModelStore::default().catalog()
476}
477
/// Returns the system information string reported by whisper.cpp.
///
/// Returns `None` when built without the `native` feature, when the library
/// returns a null pointer, or when the string is not valid UTF-8.
pub fn whisper_cpp_system_info() -> Option<String> {
    #[cfg(feature = "native")]
    {
        // SAFETY: plain FFI call without arguments.
        let raw = unsafe { ffi::whisper_print_system_info() };
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is non-null; assumes the library returns a
        // NUL-terminated string that stays valid for this call.
        // TODO confirm against the bindings.
        let info = unsafe { CStr::from_ptr(raw) };
        info.to_str().ok().map(str::to_owned)
    }

    #[cfg(not(feature = "native"))]
    {
        None
    }
}
497
/// Runs a full whisper.cpp transcription of the WAV file at `input`.
///
/// Steps, in order: make sure the model file is available, decode the audio,
/// load the model, run inference, then collect segments and the reported
/// language. Progress events are emitted as each phase starts.
///
/// # Errors
///
/// Propagates model download, WAV decoding and initialization errors. Returns
/// `InvalidInput` when the audio has more samples than an `i32` can count,
/// `UnsupportedLanguage` when the requested language is rejected by the
/// library, `Inference` when inference reports a non-zero status, and
/// `InvalidUtf8` when a returned string is not valid UTF-8.
#[cfg(feature = "native")]
fn transcribe_impl(
    store: &ModelStore,
    config: &WhisperCppConfig,
    input: &Path,
    progress: &mut ProgressSink<'_>,
) -> Result<WhisperCppTranscription> {
    let model = config.model;
    progress.emit(
        WhisperCppPhase::Preparing,
        format!(
            "preparing native whisper.cpp transcription for {}",
            input.display()
        ),
        None,
    );

    let model_path = store.ensure_model(model, progress)?;

    // Decode the audio before loading the model so invalid input is reported
    // without paying for model initialization.
    let audio = read_wav_mono_f32(input)?;
    // The sample count is passed to the library as an `i32`; reject input
    // that would otherwise be silently truncated by a cast.
    let sample_count = i32::try_from(audio.samples.len()).map_err(|_| {
        WhisperCppError::InvalidInput(format!(
            "audio has {} samples, which exceeds the whisper.cpp limit of {}",
            audio.samples.len(),
            i32::MAX
        ))
    })?;

    progress.emit(
        WhisperCppPhase::LoadingModel,
        format!("loading whisper.cpp model `{model}`"),
        None,
    );
    let context = WhisperContext::from_model(&model_path)?;

    progress.emit(
        WhisperCppPhase::Transcribing,
        format!("transcribing audio with whisper.cpp model `{model}`"),
        None,
    );

    // SAFETY: FFI call that only takes a plain enum value.
    let mut params = unsafe {
        ffi::whisper_full_default_params(ffi::whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY)
    };
    params.n_threads = resolve_threads(config.threads);
    params.translate = config.translate;
    params.print_progress = false;
    params.print_realtime = false;
    params.print_special = false;
    params.print_timestamps = false;
    params.no_timestamps = false;

    // Kept alive until the end of the function because `params.language`
    // borrows its buffer during `whisper_full`.
    let requested_language = resolve_language(config)?;
    if let Some(language) = requested_language.as_ref() {
        // SAFETY: `language` is a valid NUL-terminated C string.
        let lang_id = unsafe { ffi::whisper_lang_id(language.as_ptr()) };
        if lang_id < 0 {
            return Err(WhisperCppError::UnsupportedLanguage(
                language.to_string_lossy().into_owned(),
            ));
        }
        params.language = language.as_ptr();
    } else {
        params.language = std::ptr::null();
    }
    params.detect_language = false;

    // SAFETY: `context.raw` is non-null (checked in `WhisperContext::from_model`)
    // and stays valid until `context` is dropped; the sample pointer and
    // `sample_count` describe the live `audio.samples` buffer.
    let status = unsafe {
        ffi::whisper_full(
            context.raw,
            params,
            audio.samples.as_ptr(),
            sample_count,
        )
    };
    if status != 0 {
        return Err(WhisperCppError::Inference(model_path.display().to_string()));
    }

    // SAFETY (all calls below): `context.raw` is the live context used for the
    // inference above; indices stay within the counts the library reported.
    let segment_count = unsafe { ffi::whisper_full_n_segments(context.raw) };
    let mut segments = Vec::with_capacity(segment_count.max(0) as usize);
    for index in 0..segment_count {
        let text_ptr = unsafe { ffi::whisper_full_get_segment_text(context.raw, index) };
        let text = c_string(text_ptr)?.trim().to_string();
        let start = unsafe { ffi::whisper_full_get_segment_t0(context.raw, index) };
        let end = unsafe { ffi::whisper_full_get_segment_t1(context.raw, index) };
        let token_count = unsafe { ffi::whisper_full_n_tokens(context.raw, index) };
        // Segment confidence is the mean of its token probabilities.
        let confidence = if token_count > 0 {
            let mut total = 0.0_f32;
            for token_index in 0..token_count {
                total += unsafe { ffi::whisper_full_get_token_p(context.raw, index, token_index) };
            }
            Some(total / token_count as f32)
        } else {
            None
        };
        segments.push(WhisperCppSegment {
            index: index as u64,
            start_seconds: Some(timestamp_to_seconds(start)),
            end_seconds: Some(timestamp_to_seconds(end)),
            text,
            confidence,
        });
    }

    let detected_id = unsafe { ffi::whisper_full_lang_id(context.raw) };
    let language = if detected_id >= 0 {
        Some(c_string(unsafe { ffi::whisper_lang_str(detected_id) })?)
    } else {
        None
    };
    let text = join_segments(&segments);

    Ok(WhisperCppTranscription {
        text,
        language,
        segments,
        source: Some(model_path.to_string_lossy().into_owned()),
    })
}
608
609#[cfg(not(feature = "native"))]
610fn transcribe_impl(
611 _store: &ModelStore,
612 _config: &WhisperCppConfig,
613 _input: &Path,
614 _progress: &mut ProgressSink<'_>,
615) -> Result<WhisperCppTranscription> {
616 Err(WhisperCppError::Initialization(
617 "text-transcripts was built without the `native` feature".to_string(),
618 ))
619}
620
/// Resolves the language to hand to whisper.cpp for `config`.
///
/// A missing, blank, or `auto` (case-insensitive) language falls back to the
/// model's default; any other value is trimmed and converted to a C string.
///
/// # Errors
///
/// Returns `UnsupportedLanguage` when the value contains an interior NUL byte.
#[cfg(any(feature = "native", test))]
fn resolve_language(config: &WhisperCppConfig) -> Result<Option<CString>> {
    let explicit = config
        .language
        .as_deref()
        .map(str::trim)
        .filter(|code| !code.is_empty() && !code.eq_ignore_ascii_case("auto"));

    match explicit {
        None => resolve_default_language(config.model),
        Some(code) => match CString::new(code) {
            Ok(language) => Ok(Some(language)),
            Err(_) => Err(WhisperCppError::UnsupportedLanguage(code.to_string())),
        },
    }
}
632
/// Returns the language used when none is requested explicitly.
///
/// Multilingual models get `None`; all other models get `en`.
#[cfg(any(feature = "native", test))]
fn resolve_default_language(model: WhisperCppModel) -> Result<Option<CString>> {
    if model.multilingual() {
        return Ok(None);
    }
    match CString::new("en") {
        Ok(english) => Ok(Some(english)),
        Err(_) => Err(WhisperCppError::UnsupportedLanguage("en".to_string())),
    }
}
643
644#[cfg_attr(not(feature = "native"), allow(dead_code))]
645struct ProgressSink<'a> {
646 callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>,
647}
648
649impl<'a> ProgressSink<'a> {
650 fn new(callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>) -> Self {
651 Self { callback }
652 }
653
654 #[cfg(feature = "native")]
655 fn emit(&mut self, phase: WhisperCppPhase, message: String, progress: Option<f32>) {
656 if let Some(callback) = self.callback.as_mut() {
657 callback(WhisperCppProgressEvent {
658 phase,
659 message,
660 progress,
661 });
662 }
663 }
664}
665
/// Reads the WAV file at `path` as mono `f32` samples.
///
/// Multi-channel input is downmixed by averaging the samples of each frame.
///
/// # Errors
///
/// Returns `Wav` when the file cannot be opened or decoded, and
/// `InvalidInput` when it declares no channels or a sample rate other than
/// 16 kHz.
#[cfg(feature = "native")]
fn read_wav_mono_f32(path: &Path) -> Result<AudioSamples> {
    let mut reader = hound::WavReader::open(path)?;
    let spec = reader.spec();
    let channels = spec.channels as usize;

    if channels == 0 {
        return Err(WhisperCppError::InvalidInput(
            "wav file has no channels".to_string(),
        ));
    }
    if spec.sample_rate != 16_000 {
        return Err(WhisperCppError::InvalidInput(format!(
            "expected 16 kHz wav input, got {} Hz",
            spec.sample_rate
        )));
    }

    let interleaved: Vec<f32> = match spec.sample_format {
        hound::SampleFormat::Float => {
            let mut decoded = Vec::new();
            for sample in reader.samples::<f32>() {
                decoded.push(sample?);
            }
            decoded
        }
        hound::SampleFormat::Int => read_int_samples(&mut reader, spec.bits_per_sample)?,
    };

    let samples = match channels {
        1 => interleaved,
        _ => interleaved
            .chunks(channels)
            .map(|frame| {
                let total: f32 = frame.iter().sum();
                total / frame.len() as f32
            })
            .collect(),
    };

    Ok(AudioSamples { samples })
}
701
/// Decodes integer PCM samples from `reader` and scales them to `f32`.
///
/// Each sample is divided by `2^(bits_per_sample - 1) - 1`. Samples are read
/// as `i16` for depths up to 16 bits and as `i32` otherwise.
///
/// # Errors
///
/// Returns `Wav` when a sample cannot be decoded.
#[cfg(feature = "native")]
fn read_int_samples(
    reader: &mut hound::WavReader<std::io::BufReader<File>>,
    bits_per_sample: u16,
) -> Result<Vec<f32>> {
    let magnitude_bits = u32::from(bits_per_sample.saturating_sub(1));
    let scale = ((1_i64 << magnitude_bits) - 1) as f32;

    let mut scaled = Vec::new();
    if bits_per_sample <= 16 {
        for sample in reader.samples::<i16>() {
            scaled.push(sample? as f32 / scale);
        }
    } else {
        for sample in reader.samples::<i32>() {
            scaled.push(sample? as f32 / scale);
        }
    }
    Ok(scaled)
}
720
721#[cfg(feature = "native")]
/// Picks the thread count handed to whisper.cpp.
///
/// Uses `value` when it is a positive number. `None` and `Some(0)` both fall
/// back to the available parallelism, or to 4 when that cannot be determined.
/// The result is clamped to `i32::MAX` and is always at least 1.
fn resolve_threads(value: Option<usize>) -> i32 {
    value
        // Zero threads cannot do any work, so treat it as "not configured".
        .filter(|&count| count > 0)
        .or_else(|| thread::available_parallelism().ok().map(usize::from))
        .unwrap_or(4)
        // Clamped first, so the cast below cannot truncate.
        .min(i32::MAX as usize) as i32
}
728
729#[cfg(feature = "native")]
/// Converts a timestamp counted in hundredths of a second into seconds.
fn timestamp_to_seconds(value: i64) -> f64 {
    const TICKS_PER_SECOND: f64 = 100.0;
    let ticks = value as f64;
    ticks / TICKS_PER_SECOND
}
733
/// Joins the trimmed, non-empty segment texts with single spaces.
///
/// Returns `None` when no segment contributes any text.
#[cfg(feature = "native")]
fn join_segments(segments: &[WhisperCppSegment]) -> Option<String> {
    let mut joined = String::new();
    for segment in segments {
        let piece = segment.text.trim();
        if piece.is_empty() {
            continue;
        }
        if !joined.is_empty() {
            joined.push(' ');
        }
        joined.push_str(piece);
    }

    if joined.is_empty() {
        None
    } else {
        Some(joined)
    }
}
744
/// Copies a C string returned by whisper.cpp into an owned `String`.
///
/// A null pointer yields an empty string.
///
/// # Errors
///
/// Returns `InvalidUtf8` when the bytes are not valid UTF-8.
#[cfg(feature = "native")]
fn c_string(value: *const std::ffi::c_char) -> Result<String> {
    if value.is_null() {
        return Ok(String::new());
    }
    // SAFETY: `value` is non-null; assumes it points to a NUL-terminated
    // string that stays valid for this call. TODO confirm for each caller.
    let borrowed = unsafe { CStr::from_ptr(value) };
    match borrowed.to_str() {
        Ok(text) => Ok(text.to_owned()),
        Err(_) => Err(WhisperCppError::InvalidUtf8),
    }
}
755
/// Returns the directory under which this application caches data.
///
/// Candidates are tried in order and the first variable that is set to a
/// non-empty value wins:
///
/// 1. `VIDEO_ANALYSIS_STUDIO_CACHE_DIR`, used as-is;
/// 2. `XDG_CACHE_HOME`, joined with `video-analysis-studio`;
/// 3. on Windows only, `LOCALAPPDATA`, joined with `video-analysis-studio`;
/// 4. `HOME`, joined with `.cache/video-analysis-studio`.
///
/// When none is usable, the relative path `.cache/video-analysis-studio` is
/// returned.
fn cache_root() -> PathBuf {
    const APP_DIR: &str = "video-analysis-studio";

    // An empty value counts as unset (the XDG specification requires this for
    // `XDG_CACHE_HOME`); otherwise it would yield a relative cache path.
    let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty("VIDEO_ANALYSIS_STUDIO_CACHE_DIR") {
        return PathBuf::from(dir);
    }
    if let Some(dir) = non_empty("XDG_CACHE_HOME") {
        return PathBuf::from(dir).join(APP_DIR);
    }
    if cfg!(target_os = "windows") {
        if let Some(dir) = non_empty("LOCALAPPDATA") {
            return PathBuf::from(dir).join(APP_DIR);
        }
    }
    if let Some(home) = non_empty("HOME") {
        return PathBuf::from(home).join(".cache").join(APP_DIR);
    }
    PathBuf::from(".cache/video-analysis-studio")
}
775
776#[cfg(feature = "native")]
/// Decoded audio, ready to be handed to whisper.cpp.
struct AudioSamples {
    /// Mono samples as `f32` values.
    samples: Vec<f32>,
}
780
/// Owning wrapper around a raw whisper.cpp context pointer.
#[cfg(feature = "native")]
struct WhisperContext {
    /// Raw context handle; released when this value is dropped.
    raw: *mut ffi::whisper_context,
}
785
#[cfg(feature = "native")]
impl WhisperContext {
    /// Loads the model file at `path` into a new whisper.cpp context.
    ///
    /// GPU use is requested only on macOS, and flash attention is disabled.
    ///
    /// # Errors
    ///
    /// Returns `Initialization` (carrying the displayed path) when the path
    /// contains a NUL byte or the library returns a null context.
    fn from_model(path: &Path) -> Result<Self> {
        let init_error = || WhisperCppError::Initialization(path.display().to_string());

        // NOTE(review): the lossy conversion alters non-UTF-8 paths before
        // they reach the library.
        let model_path = match CString::new(path.to_string_lossy().into_owned()) {
            Ok(c_path) => c_path,
            Err(_) => return Err(init_error()),
        };

        // SAFETY: FFI call without arguments.
        let mut params = unsafe { ffi::whisper_context_default_params() };
        params.use_gpu = cfg!(target_os = "macos");
        params.flash_attn = false;

        // SAFETY: `model_path` is a valid NUL-terminated string that outlives the call.
        let raw = unsafe { ffi::whisper_init_from_file_with_params(model_path.as_ptr(), params) };
        if raw.is_null() {
            Err(init_error())
        } else {
            Ok(Self { raw })
        }
    }
}
801
#[cfg(feature = "native")]
impl Drop for WhisperContext {
    /// Releases the underlying context unless the pointer is null.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // SAFETY: `raw` is non-null and owned by this value, which is being dropped.
        unsafe { ffi::whisper_free(self.raw) };
    }
}
810
/// Guard for a lock file; the file is removed when the guard is dropped.
#[cfg(any(feature = "native", test))]
struct FileLock {
    /// Path of the lock file held by this guard.
    path: PathBuf,
}
815
#[cfg(any(feature = "native", test))]
impl FileLock {
    /// Acquires the lock by exclusively creating the file at `path`.
    ///
    /// While the file already exists, creation is retried every 250 ms for up
    /// to 120 seconds. On success the current process id is written into the
    /// file on a best-effort basis.
    ///
    /// NOTE(review): the lock file is only removed when the guard is dropped,
    /// so a holder that exits without running `Drop` leaves a stale lock that
    /// must be deleted manually.
    ///
    /// # Errors
    ///
    /// Returns `Io` with kind `TimedOut` when the lock is still held at the
    /// deadline, and `Io` with the underlying error for any other failure.
    fn acquire(path: PathBuf) -> Result<Self> {
        let deadline = Instant::now() + Duration::from_secs(120);
        loop {
            match OpenOptions::new().create_new(true).write(true).open(&path) {
                Ok(mut file) => {
                    // The pid is informational only; a failed write must not
                    // prevent the lock from being held.
                    let _ = writeln!(file, "{}", std::process::id());
                    return Ok(Self { path });
                }
                Err(error) if error.kind() == std::io::ErrorKind::AlreadyExists => {
                    if Instant::now() >= deadline {
                        // Report a timeout that names the lock file instead of
                        // a bare "already exists" error.
                        return Err(WhisperCppError::Io(std::io::Error::new(
                            std::io::ErrorKind::TimedOut,
                            format!(
                                "timed out waiting for lock file `{}`: {error}",
                                path.display()
                            ),
                        )));
                    }
                    thread::sleep(Duration::from_millis(250));
                }
                Err(error) => return Err(WhisperCppError::Io(error)),
            }
        }
    }
}
837
#[cfg(any(feature = "native", test))]
impl Drop for FileLock {
    /// Releases the lock by deleting the lock file; failures are ignored.
    fn drop(&mut self) {
        fs::remove_file(&self.path).ok();
    }
}
844
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Root used by the tests that only compute paths.
    const TEST_ROOT: &str = "/tmp/video-analysis-studio-test";

    /// Builds a config for `model` with the given language and default options.
    fn config_for(model: WhisperCppModel, language: Option<&str>) -> WhisperCppConfig {
        WhisperCppConfig {
            model,
            language: language.map(str::to_string),
            translate: false,
            threads: None,
        }
    }

    #[test]
    fn model_metadata_matches_expected_file_names() {
        let file_name = WhisperCppModel::BaseEn.file_name();
        assert_eq!(file_name, "ggml-base.en.bin");

        let url = WhisperCppModel::LargeV3Turbo.download_url();
        assert_eq!(
            url,
            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin"
        );
    }

    #[test]
    fn catalog_uses_base_en_by_default() {
        let store = ModelStore::new(PathBuf::from(TEST_ROOT));
        let catalog = store.catalog();
        assert_eq!(catalog.default_model, WhisperCppModel::BaseEn);
        assert_eq!(catalog.models.len(), WhisperCppModel::ALL.len());
    }

    #[test]
    fn cache_paths_are_stable() {
        let store = ModelStore::new(PathBuf::from(TEST_ROOT));
        let model = WhisperCppModel::SmallEn;
        assert_eq!(
            store.model_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin")
        );
        assert_eq!(
            store.lock_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin.lock")
        );
    }

    #[test]
    fn file_lock_creates_and_releases_lock_path() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("model.lock");

        let lock = FileLock::acquire(path.clone()).unwrap();
        assert!(path.is_file());

        // Dropping the guard must delete the lock file again.
        drop(lock);
        assert!(!path.exists());
    }

    #[test]
    fn english_only_models_default_to_english() {
        let config = config_for(WhisperCppModel::BaseEn, None);
        let language = resolve_language(&config).unwrap().unwrap();
        assert_eq!(language.to_str().unwrap(), "en");
    }

    #[test]
    fn multilingual_models_default_to_auto_detection_without_detect_only_mode() {
        let config = config_for(WhisperCppModel::Base, None);
        assert!(resolve_language(&config).unwrap().is_none());
    }

    #[test]
    fn auto_language_uses_english_for_english_only_models() {
        let config = config_for(WhisperCppModel::SmallEn, Some("auto"));
        let language = resolve_language(&config).unwrap().unwrap();
        assert_eq!(language.to_str().unwrap(), "en");
    }
}