// text_transcripts/whisper_cpp.rs

//! Native whisper.cpp transcription support for text transcripts.
2
3#[cfg(feature = "native")]
4#[path = "ffi.rs"]
5mod ffi;
6
7#[cfg(feature = "native")]
8use std::ffi::CStr;
9#[cfg(any(feature = "native", test))]
10use std::ffi::CString;
11use std::fmt::{Display, Formatter};
12#[cfg(feature = "native")]
13use std::fs::File;
14#[cfg(any(feature = "native", test))]
15use std::fs::{self, OpenOptions};
16#[cfg(any(feature = "native", test))]
17use std::io::Write;
18#[cfg(feature = "native")]
19use std::io::{BufWriter, Read};
20use std::path::{Path, PathBuf};
21#[cfg(any(feature = "native", test))]
22use std::thread;
23#[cfg(any(feature = "native", test))]
24use std::time::{Duration, Instant};
25
26use serde::{Deserialize, Serialize};
27#[cfg(feature = "native")]
28use sha2::{Digest, Sha256};
29
30#[derive(
31    Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Default,
32)]
33/// Variants describing whisper cpp model.
34pub enum WhisperCppModel {
35    #[serde(rename = "tiny.en")]
36    /// The tiny en variant.
37    TinyEn,
38    #[serde(rename = "tiny")]
39    /// The tiny variant.
40    Tiny,
41    #[serde(rename = "base.en")]
42    #[default]
43    /// The base en variant.
44    BaseEn,
45    #[serde(rename = "base")]
46    /// The base variant.
47    Base,
48    #[serde(rename = "small.en")]
49    /// The small en variant.
50    SmallEn,
51    #[serde(rename = "small")]
52    /// The small variant.
53    Small,
54    #[serde(rename = "medium.en")]
55    /// The medium en variant.
56    MediumEn,
57    #[serde(rename = "medium")]
58    /// The medium variant.
59    Medium,
60    #[serde(rename = "large-v1")]
61    /// The large v1 variant.
62    LargeV1,
63    #[serde(rename = "large-v2")]
64    /// The large v2 variant.
65    LargeV2,
66    #[serde(rename = "large-v3")]
67    /// The large v3 variant.
68    LargeV3,
69    #[serde(rename = "large-v3-turbo")]
70    /// The large v3 turbo variant.
71    LargeV3Turbo,
72}
73
74impl WhisperCppModel {
75    /// Constant for all.
76    pub const ALL: [Self; 12] = [
77        Self::TinyEn,
78        Self::Tiny,
79        Self::BaseEn,
80        Self::Base,
81        Self::SmallEn,
82        Self::Small,
83        Self::MediumEn,
84        Self::Medium,
85        Self::LargeV1,
86        Self::LargeV2,
87        Self::LargeV3,
88        Self::LargeV3Turbo,
89    ];
90
91    /// Returns identifier.
92    pub fn id(self) -> &'static str {
93        match self {
94            Self::TinyEn => "tiny.en",
95            Self::Tiny => "tiny",
96            Self::BaseEn => "base.en",
97            Self::Base => "base",
98            Self::SmallEn => "small.en",
99            Self::Small => "small",
100            Self::MediumEn => "medium.en",
101            Self::Medium => "medium",
102            Self::LargeV1 => "large-v1",
103            Self::LargeV2 => "large-v2",
104            Self::LargeV3 => "large-v3",
105            Self::LargeV3Turbo => "large-v3-turbo",
106        }
107    }
108
109    /// Returns file name.
110    pub fn file_name(self) -> String {
111        format!("ggml-{}.bin", self.id())
112    }
113
114    /// Returns download URL.
115    pub fn download_url(self) -> String {
116        format!(
117            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/{}",
118            self.file_name()
119        )
120    }
121
122    /// Returns checksum sha256.
123    pub fn checksum_sha256(self) -> &'static str {
124        match self {
125            Self::TinyEn => "0d686a2a6a22b02da2ef3101d4c86e68461363a623c58f27f81b1b2d36b42317",
126            Self::Tiny => "518970a29bedb265f23ac48d486ddbc63bedffd90967b10140ae5ac61243acf3",
127            Self::BaseEn => "a03779c86df3323075f5e796cb2ce5029f00ec8869eee3fdfb897afe36c6d002",
128            Self::Base => "2f62d18b50c3f3feafbf990eec23a93d319660b1efbdd3fff55e52b7cde2e374",
129            Self::SmallEn => "0d57184d34ae7d736e5bb2db5bf83debe730bd53dcefa235a0979b9dcfd33fb3",
130            Self::Small => "edd29d67e70b000132af65205b99bb774b77abc13d10103e14f80ce2242913e1",
131            Self::MediumEn => "a163589aa264d5188df3b05ed4eac56bfd97e26910f207809d869f7e99886fd2",
132            Self::Medium => "d3d5696e6a3e0ca2aa08eb31cad208ffa1e87b3cc341f59e628fbdcf8122de9b",
133            Self::LargeV1 => "cbcb187d1e1abe979d33636cdc63381de20738eeda0885c39440b086e184248a",
134            Self::LargeV2 => "c6d6d3dcebc5e0074175386e17eba305fc5cc7d3d5dff3ecfd11e8f2bd4222d7",
135            Self::LargeV3 => "766d11cebbdf5a67c179c5774e2642b609e35e1a30240e7b559d5647c655b0a4",
136            Self::LargeV3Turbo => {
137                "5a4b65b05933d70ce9d5aa6265eb128fa5eba38f6fee40836fdedc4d2fde42ad"
138            }
139        }
140    }
141
142    /// Returns multilingual.
143    pub fn multilingual(self) -> bool {
144        !matches!(
145            self,
146            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn
147        )
148    }
149}
150
151impl Display for WhisperCppModel {
152    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
153        f.write_str(self.id())
154    }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
158/// Data type for whisper cpp config.
159pub struct WhisperCppConfig {
160    #[serde(default)]
161    /// The model value.
162    pub model: WhisperCppModel,
163    /// Language tag for this value.
164    pub language: Option<String>,
165    #[serde(default)]
166    /// The translate value.
167    pub translate: bool,
168    /// The threads value.
169    pub threads: Option<usize>,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173/// Data type for whisper cpp segment.
174pub struct WhisperCppSegment {
175    /// The index value.
176    pub index: u64,
177    /// The start seconds value.
178    pub start_seconds: Option<f64>,
179    /// The end seconds value.
180    pub end_seconds: Option<f64>,
181    /// Text content for this value.
182    pub text: String,
183    /// Confidence score for this value.
184    pub confidence: Option<f32>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188/// Data type for whisper cpp transcription.
189pub struct WhisperCppTranscription {
190    /// Text content for this value.
191    pub text: Option<String>,
192    /// Language tag for this value.
193    pub language: Option<String>,
194    /// The segments value.
195    pub segments: Vec<WhisperCppSegment>,
196    /// The source value.
197    pub source: Option<String>,
198}
199
200#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
201#[serde(rename_all = "snake_case")]
202/// Variants describing whisper cpp phase.
203pub enum WhisperCppPhase {
204    /// The preparing variant.
205    Preparing,
206    /// The downloading model variant.
207    DownloadingModel,
208    /// The loading model variant.
209    LoadingModel,
210    /// The transcribing variant.
211    Transcribing,
212}
213
214impl WhisperCppPhase {
215    /// Borrows this value as a str.
216    pub fn as_str(self) -> &'static str {
217        match self {
218            Self::Preparing => "preparing",
219            Self::DownloadingModel => "downloading_model",
220            Self::LoadingModel => "loading_model",
221            Self::Transcribing => "transcribing",
222        }
223    }
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
227/// Data type for whisper cpp progress event.
228pub struct WhisperCppProgressEvent {
229    /// The phase value.
230    pub phase: WhisperCppPhase,
231    /// The message value.
232    pub message: String,
233    /// The progress value.
234    pub progress: Option<f32>,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
238/// Data type for whisper cpp model status.
239pub struct WhisperCppModelStatus {
240    /// The model value.
241    pub model: WhisperCppModel,
242    /// The cached value.
243    pub cached: bool,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
247/// Data type for whisper cpp catalog.
248pub struct WhisperCppCatalog {
249    /// The default model value.
250    pub default_model: WhisperCppModel,
251    /// The models value.
252    pub models: Vec<WhisperCppModelStatus>,
253}
254
255#[derive(Debug, thiserror::Error)]
256/// Variants describing whisper cpp error.
257pub enum WhisperCppError {
258    #[error("I/O error: {0}")]
259    /// The I/O variant.
260    Io(#[from] std::io::Error),
261    #[error("wave input error: {0}")]
262    /// The wav variant.
263    Wav(#[from] hound::Error),
264    #[error("network error: {0}")]
265    /// The http variant.
266    Http(String),
267    #[error("invalid input: {0}")]
268    /// The invalid input variant.
269    InvalidInput(String),
270    #[error("unsupported language `{0}`")]
271    /// The unsupported language variant.
272    UnsupportedLanguage(String),
273    #[error("downloaded model `{model}` failed checksum verification")]
274    /// The invalid checksum variant.
275    InvalidChecksum {
276        /// Model associated with this variant.
277        model: WhisperCppModel,
278    },
279    #[error("failed to initialize whisper.cpp from `{0}`")]
280    /// The initialization variant.
281    Initialization(String),
282    #[error("whisper.cpp inference failed for `{0}`")]
283    /// The inference variant.
284    Inference(String),
285    #[error("invalid utf-8 returned by whisper.cpp")]
286    /// The invalid utf8 variant.
287    InvalidUtf8,
288}
289
290/// Type alias for result.
291pub type Result<T> = std::result::Result<T, WhisperCppError>;
292
293type OwnedProgressCallback = dyn FnMut(WhisperCppProgressEvent) + 'static;
294
/// On-disk cache of whisper.cpp model files, rooted at a single directory.
#[derive(Clone)]
pub struct ModelStore {
    /// Directory under which the `models` folder lives.
    root: PathBuf,
}
300
301impl Default for ModelStore {
302    fn default() -> Self {
303        Self {
304            root: cache_root().join("whisper-cpp"),
305        }
306    }
307}
308
309impl ModelStore {
310    /// Creates a new value.
311    pub fn new(root: PathBuf) -> Self {
312        Self { root }
313    }
314
315    /// Returns models dir.
316    pub fn models_dir(&self) -> PathBuf {
317        self.root.join("models")
318    }
319
320    /// Returns model path.
321    pub fn model_path(&self, model: WhisperCppModel) -> PathBuf {
322        self.models_dir().join(model.file_name())
323    }
324
325    /// Returns lock path.
326    pub fn lock_path(&self, model: WhisperCppModel) -> PathBuf {
327        self.models_dir()
328            .join(format!("{}.lock", model.file_name()))
329    }
330
331    /// Returns catalog.
332    pub fn catalog(&self) -> WhisperCppCatalog {
333        WhisperCppCatalog {
334            default_model: WhisperCppModel::default(),
335            models: WhisperCppModel::ALL
336                .into_iter()
337                .map(|model| WhisperCppModelStatus {
338                    model,
339                    cached: self.model_path(model).is_file(),
340                })
341                .collect(),
342        }
343    }
344
345    #[cfg(feature = "native")]
346    fn ensure_model(
347        &self,
348        model: WhisperCppModel,
349        progress: &mut ProgressSink<'_>,
350    ) -> Result<PathBuf> {
351        fs::create_dir_all(self.models_dir())?;
352        let model_path = self.model_path(model);
353        if model_path.is_file() {
354            return Ok(model_path);
355        }
356
357        let _lock = FileLock::acquire(self.lock_path(model))?;
358        if model_path.is_file() {
359            return Ok(model_path);
360        }
361
362        progress.emit(
363            WhisperCppPhase::DownloadingModel,
364            format!("downloading whisper.cpp model `{model}`"),
365            Some(0.0),
366        );
367
368        let temp_path = model_path.with_extension("bin.part");
369        if temp_path.exists() {
370            let _ = fs::remove_file(&temp_path);
371        }
372
373        let response = ureq::get(&model.download_url())
374            .call()
375            .map_err(|error| WhisperCppError::Http(error.to_string()))?;
376        let total_bytes = response
377            .header("Content-Length")
378            .and_then(|value| value.parse::<u64>().ok());
379        let mut reader = response.into_reader();
380        let mut file = BufWriter::new(File::create(&temp_path)?);
381        let mut hasher = Sha256::new();
382        let mut downloaded = 0_u64;
383        let mut buffer = [0_u8; 64 * 1024];
384
385        loop {
386            let read = reader
387                .read(&mut buffer)
388                .map_err(|error| WhisperCppError::Http(error.to_string()))?;
389            if read == 0 {
390                break;
391            }
392            file.write_all(&buffer[..read])?;
393            hasher.update(&buffer[..read]);
394            downloaded += read as u64;
395            let fraction =
396                total_bytes.map(|total| (downloaded as f32 / total as f32).clamp(0.0, 1.0));
397            progress.emit(
398                WhisperCppPhase::DownloadingModel,
399                format!("downloading whisper.cpp model `{model}`"),
400                fraction,
401            );
402        }
403        file.flush()?;
404
405        let checksum = format!("{:x}", hasher.finalize());
406        if checksum != model.checksum_sha256() {
407            let _ = fs::remove_file(&temp_path);
408            return Err(WhisperCppError::InvalidChecksum { model });
409        }
410
411        fs::rename(temp_path, &model_path)?;
412        Ok(model_path)
413    }
414}
415
416/// Data type for whisper cpp transcriber.
417pub struct WhisperCppTranscriber {
418    config: WhisperCppConfig,
419    store: ModelStore,
420    progress: Option<Box<OwnedProgressCallback>>,
421}
422
423impl WhisperCppTranscriber {
424    /// Creates a new value.
425    pub fn new(config: WhisperCppConfig) -> Self {
426        Self {
427            config,
428            store: ModelStore::default(),
429            progress: None,
430        }
431    }
432
433    /// Returns this value with model store.
434    pub fn with_model_store(mut self, store: ModelStore) -> Self {
435        self.store = store;
436        self
437    }
438
439    /// Returns on progress.
440    pub fn on_progress<F>(mut self, callback: F) -> Self
441    where
442        F: FnMut(WhisperCppProgressEvent) + 'static,
443    {
444        self.progress = Some(Box::new(callback));
445        self
446    }
447
448    /// Returns transcribe file.
449    pub fn transcribe_file(&mut self, input: &Path) -> Result<WhisperCppTranscription> {
450        let store = self.store.clone();
451        let config = self.config.clone();
452        let mut progress = ProgressSink::new(self.progress_deref_mut());
453        transcribe_impl(&store, &config, input, &mut progress)
454    }
455
456    /// Returns transcribe file with progress.
457    pub fn transcribe_file_with_progress(
458        &mut self,
459        input: &Path,
460        progress: &mut dyn FnMut(WhisperCppProgressEvent),
461    ) -> Result<WhisperCppTranscription> {
462        let mut progress = ProgressSink::new(Some(progress));
463        transcribe_impl(&self.store, &self.config, input, &mut progress)
464    }
465
466    fn progress_deref_mut(&mut self) -> Option<&mut dyn FnMut(WhisperCppProgressEvent)> {
467        self.progress
468            .as_mut()
469            .map(|callback| callback.as_mut() as &mut dyn FnMut(WhisperCppProgressEvent))
470    }
471}
472
473/// Returns transcription catalog.
474pub fn transcription_catalog() -> WhisperCppCatalog {
475    ModelStore::default().catalog()
476}
477
/// Returns whisper.cpp's system-information string, if available.
///
/// Yields `None` when built without the `native` feature, when whisper.cpp
/// returns a null pointer, or when the string is not valid UTF-8.
pub fn whisper_cpp_system_info() -> Option<String> {
    #[cfg(not(feature = "native"))]
    {
        None
    }

    #[cfg(feature = "native")]
    {
        // SAFETY: the call takes no arguments.
        let raw = unsafe { ffi::whisper_print_system_info() };
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is non-null; whisper.cpp is assumed to return a
        // NUL-terminated string that stays valid while it is copied here.
        let info = unsafe { CStr::from_ptr(raw) };
        info.to_str().ok().map(|text| text.to_string())
    }
}
497
/// Runs a full whisper.cpp transcription of the WAV file at `input`.
///
/// Steps, in order: make sure the configured model is cached (downloading it
/// if needed), decode the WAV into mono `f32` samples, load the model, run
/// `whisper_full` with greedy sampling, then collect the segments, each
/// segment's mean token probability, and the language reported by the context.
///
/// Each progress phase is emitted immediately before the work it names:
/// `LoadingModel` before the model is opened, `Transcribing` before inference.
///
/// # Errors
/// Returns `InvalidInput` when the audio cannot be used (including audio with
/// more samples than fit in an `i32`), `UnsupportedLanguage` when whisper.cpp
/// rejects the language tag, `Initialization` when the model cannot be loaded,
/// `Inference` when `whisper_full` returns a non-zero status, and
/// `InvalidUtf8` when a returned string is not valid UTF-8. Model download and
/// WAV decoding errors are propagated unchanged.
#[cfg(feature = "native")]
fn transcribe_impl(
    store: &ModelStore,
    config: &WhisperCppConfig,
    input: &Path,
    progress: &mut ProgressSink<'_>,
) -> Result<WhisperCppTranscription> {
    let model = config.model;
    progress.emit(
        WhisperCppPhase::Preparing,
        format!(
            "preparing native whisper.cpp transcription for {}",
            input.display()
        ),
        None,
    );

    let model_path = store.ensure_model(model, progress)?;

    let audio = read_wav_mono_f32(input)?;
    // `whisper_full` takes the sample count as an `i32`; reject audio that
    // does not fit rather than letting an `as` cast wrap.
    let sample_count = i32::try_from(audio.samples.len()).map_err(|_| {
        WhisperCppError::InvalidInput(format!(
            "wav input has {} samples, more than the {} whisper.cpp accepts in one call",
            audio.samples.len(),
            i32::MAX
        ))
    })?;

    progress.emit(
        WhisperCppPhase::LoadingModel,
        format!("loading whisper.cpp model `{model}`"),
        None,
    );
    let context = WhisperContext::from_model(&model_path)?;

    progress.emit(
        WhisperCppPhase::Transcribing,
        format!("transcribing audio with whisper.cpp model `{model}`"),
        None,
    );

    // SAFETY: the call takes only a by-value sampling strategy.
    let mut params = unsafe {
        ffi::whisper_full_default_params(ffi::whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY)
    };
    params.n_threads = resolve_threads(config.threads);
    params.translate = config.translate;
    params.print_progress = false;
    params.print_realtime = false;
    params.print_special = false;
    params.print_timestamps = false;
    params.no_timestamps = false;

    // Kept alive until the end of the function because `params.language`
    // borrows its buffer across the `whisper_full` call.
    let requested_language = resolve_language(config)?;
    params.language = match requested_language.as_ref() {
        Some(tag) => {
            // SAFETY: `tag` is a live, NUL-terminated `CString`.
            let lang_id = unsafe { ffi::whisper_lang_id(tag.as_ptr()) };
            if lang_id < 0 {
                return Err(WhisperCppError::UnsupportedLanguage(
                    tag.to_string_lossy().into_owned(),
                ));
            }
            tag.as_ptr()
        }
        None => std::ptr::null(),
    };
    params.detect_language = false;

    // SAFETY: `context.raw` is non-null (checked by `from_model`) and is freed
    // only when `context` drops; the sample pointer and count describe the
    // live `audio.samples` buffer.
    let status = unsafe {
        ffi::whisper_full(
            context.raw,
            params,
            audio.samples.as_ptr(),
            sample_count,
        )
    };
    if status != 0 {
        return Err(WhisperCppError::Inference(model_path.display().to_string()));
    }

    // SAFETY (all context queries below): `context.raw` is the live, non-null
    // context that `whisper_full` just ran on, and every index passed stays
    // within the counts whisper.cpp reported.
    let segment_count = unsafe { ffi::whisper_full_n_segments(context.raw) };
    let mut segments = Vec::with_capacity(segment_count.max(0) as usize);
    for index in 0..segment_count {
        let text_ptr = unsafe { ffi::whisper_full_get_segment_text(context.raw, index) };
        let text = c_string(text_ptr)?.trim().to_string();
        let start = unsafe { ffi::whisper_full_get_segment_t0(context.raw, index) };
        let end = unsafe { ffi::whisper_full_get_segment_t1(context.raw, index) };
        let token_count = unsafe { ffi::whisper_full_n_tokens(context.raw, index) };
        // Confidence is the arithmetic mean of the segment's token probabilities.
        let confidence = if token_count > 0 {
            let mut total = 0.0_f32;
            for token_index in 0..token_count {
                total += unsafe { ffi::whisper_full_get_token_p(context.raw, index, token_index) };
            }
            Some(total / token_count as f32)
        } else {
            None
        };
        segments.push(WhisperCppSegment {
            index: index as u64,
            start_seconds: Some(timestamp_to_seconds(start)),
            end_seconds: Some(timestamp_to_seconds(end)),
            text,
            confidence,
        });
    }

    let detected_id = unsafe { ffi::whisper_full_lang_id(context.raw) };
    let language = if detected_id >= 0 {
        // SAFETY: `detected_id` was just reported by whisper.cpp; `c_string`
        // handles a null return.
        Some(c_string(unsafe { ffi::whisper_lang_str(detected_id) })?)
    } else {
        None
    };
    let text = join_segments(&segments);

    Ok(WhisperCppTranscription {
        text,
        language,
        segments,
        source: Some(model_path.to_string_lossy().into_owned()),
    })
}
608
609#[cfg(not(feature = "native"))]
610fn transcribe_impl(
611    _store: &ModelStore,
612    _config: &WhisperCppConfig,
613    _input: &Path,
614    _progress: &mut ProgressSink<'_>,
615) -> Result<WhisperCppTranscription> {
616    Err(WhisperCppError::Initialization(
617        "text-transcripts was built without the `native` feature".to_string(),
618    ))
619}
620
/// Resolves the language to hand to whisper.cpp for `config`.
///
/// An explicit, non-blank tag other than `"auto"` (compared ignoring ASCII
/// case) is used after trimming. Anything else defers to the model's default.
///
/// # Errors
/// Returns `UnsupportedLanguage` when the tag contains an interior NUL byte.
#[cfg(any(feature = "native", test))]
fn resolve_language(config: &WhisperCppConfig) -> Result<Option<CString>> {
    let explicit = config
        .language
        .as_deref()
        .map(str::trim)
        .filter(|tag| !tag.is_empty() && !tag.eq_ignore_ascii_case("auto"));

    match explicit {
        Some(tag) => match CString::new(tag) {
            Ok(encoded) => Ok(Some(encoded)),
            Err(_) => Err(WhisperCppError::UnsupportedLanguage(tag.to_string())),
        },
        None => resolve_default_language(config.model),
    }
}
632
/// Returns the language used when the configuration does not pick one.
///
/// Multilingual models get `None`; every other model gets `"en"`.
#[cfg(any(feature = "native", test))]
fn resolve_default_language(model: WhisperCppModel) -> Result<Option<CString>> {
    if model.multilingual() {
        return Ok(None);
    }
    match CString::new("en") {
        Ok(english) => Ok(Some(english)),
        Err(_) => Err(WhisperCppError::UnsupportedLanguage("en".to_string())),
    }
}
643
644#[cfg_attr(not(feature = "native"), allow(dead_code))]
645struct ProgressSink<'a> {
646    callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>,
647}
648
649impl<'a> ProgressSink<'a> {
650    fn new(callback: Option<&'a mut dyn FnMut(WhisperCppProgressEvent)>) -> Self {
651        Self { callback }
652    }
653
654    #[cfg(feature = "native")]
655    fn emit(&mut self, phase: WhisperCppPhase, message: String, progress: Option<f32>) {
656        if let Some(callback) = self.callback.as_mut() {
657            callback(WhisperCppProgressEvent {
658                phase,
659                message,
660                progress,
661            });
662        }
663    }
664}
665
/// Reads the WAV file at `path` and returns its audio as mono `f32` samples.
///
/// Multi-channel audio is down-mixed by averaging the samples of each frame.
///
/// # Errors
/// Returns `Wav` when the file cannot be opened or decoded, and `InvalidInput`
/// when it declares zero channels or a sample rate other than 16 kHz.
#[cfg(feature = "native")]
fn read_wav_mono_f32(path: &Path) -> Result<AudioSamples> {
    let mut reader = hound::WavReader::open(path)?;
    let spec = reader.spec();
    let channel_count = usize::from(spec.channels);

    if channel_count == 0 {
        return Err(WhisperCppError::InvalidInput(
            "wav file has no channels".to_string(),
        ));
    }
    if spec.sample_rate != 16_000 {
        let message = format!("expected 16 kHz wav input, got {} Hz", spec.sample_rate);
        return Err(WhisperCppError::InvalidInput(message));
    }

    let interleaved: Vec<f32> = match spec.sample_format {
        hound::SampleFormat::Float => {
            let decoded: std::result::Result<Vec<f32>, hound::Error> =
                reader.samples::<f32>().collect();
            decoded?
        }
        hound::SampleFormat::Int => read_int_samples(&mut reader, spec.bits_per_sample)?,
    };

    let samples = match channel_count {
        1 => interleaved,
        _ => interleaved
            .chunks(channel_count)
            .map(|frame| {
                let sum: f32 = frame.iter().copied().sum();
                sum / frame.len() as f32
            })
            .collect(),
    };

    Ok(AudioSamples { samples })
}
701
/// Decodes integer PCM samples from `reader`, scaled into floating point.
///
/// Each sample is divided by `2^(bits_per_sample - 1) - 1`, the largest
/// positive value at that bit depth. Depths up to 16 bits are read as `i16`,
/// deeper ones as `i32`.
///
/// # Errors
/// Returns `Wav` when a sample cannot be decoded.
#[cfg(feature = "native")]
fn read_int_samples(
    reader: &mut hound::WavReader<std::io::BufReader<File>>,
    bits_per_sample: u16,
) -> Result<Vec<f32>> {
    let shift = bits_per_sample.saturating_sub(1) as u32;
    let scale = ((1_i64 << shift) - 1) as f32;

    let normalized: std::result::Result<Vec<f32>, hound::Error> = if bits_per_sample <= 16 {
        reader
            .samples::<i16>()
            .map(|decoded| decoded.map(|value| value as f32 / scale))
            .collect()
    } else {
        reader
            .samples::<i32>()
            .map(|decoded| decoded.map(|value| value as f32 / scale))
            .collect()
    };
    Ok(normalized?)
}
720
721#[cfg(feature = "native")]
/// Chooses the worker-thread count to pass to whisper.cpp.
///
/// A positive `value` is used as given. `None` and `Some(0)` both fall back to
/// the machine's available parallelism, or to 4 when that cannot be
/// determined. Zero is treated as unset because a zero thread count is never
/// a usable setting. The result is capped at `i32::MAX`.
fn resolve_threads(value: Option<usize>) -> i32 {
    let threads = value
        .filter(|&requested| requested > 0)
        .or_else(|| thread::available_parallelism().ok().map(usize::from))
        .unwrap_or(4);
    threads.min(i32::MAX as usize) as i32
}
728
729#[cfg(feature = "native")]
/// Converts a whisper.cpp segment timestamp to seconds by dividing by 100.
///
/// NOTE(review): this implies timestamps are in hundredths of a second —
/// confirm against whisper.h.
fn timestamp_to_seconds(value: i64) -> f64 {
    const TICKS_PER_SECOND: f64 = 100.0;
    let ticks = value as f64;
    ticks / TICKS_PER_SECOND
}
733
/// Joins the trimmed, non-empty segment texts with single spaces.
///
/// Returns `None` when no segment contributes any text.
#[cfg(feature = "native")]
fn join_segments(segments: &[WhisperCppSegment]) -> Option<String> {
    let mut joined = String::new();
    let pieces = segments
        .iter()
        .map(|segment| segment.text.trim())
        .filter(|piece| !piece.is_empty());
    for piece in pieces {
        if !joined.is_empty() {
            joined.push(' ');
        }
        joined.push_str(piece);
    }

    if joined.is_empty() {
        None
    } else {
        Some(joined)
    }
}
744
/// Copies a C string returned by whisper.cpp into an owned `String`.
///
/// A null pointer yields an empty string.
///
/// # Errors
/// Returns `InvalidUtf8` when the bytes are not valid UTF-8.
#[cfg(feature = "native")]
fn c_string(value: *const std::ffi::c_char) -> Result<String> {
    if value.is_null() {
        return Ok(String::new());
    }
    // SAFETY: `value` is non-null; callers pass pointers obtained from
    // whisper.cpp, which are assumed to be NUL-terminated and valid for the
    // duration of this call.
    let text = unsafe { CStr::from_ptr(value) };
    match text.to_str() {
        Ok(valid) => Ok(valid.to_string()),
        Err(_) => Err(WhisperCppError::InvalidUtf8),
    }
}
755
/// Returns the directory under which this application caches data.
///
/// The first usable entry wins:
/// 1. `VIDEO_ANALYSIS_STUDIO_CACHE_DIR`, used as given;
/// 2. `XDG_CACHE_HOME/video-analysis-studio`;
/// 3. on Windows builds, `LOCALAPPDATA/video-analysis-studio`;
/// 4. `HOME/.cache/video-analysis-studio`;
/// 5. the relative path `.cache/video-analysis-studio`.
///
/// A variable that is set but empty is treated as unset. Otherwise an empty
/// value would silently turn the cache root into a path relative to the
/// current working directory.
fn cache_root() -> PathBuf {
    let env_dir = |key: &str| {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    if let Some(dir) = env_dir("VIDEO_ANALYSIS_STUDIO_CACHE_DIR") {
        return dir;
    }
    if let Some(dir) = env_dir("XDG_CACHE_HOME") {
        return dir.join("video-analysis-studio");
    }
    if cfg!(target_os = "windows") {
        if let Some(dir) = env_dir("LOCALAPPDATA") {
            return dir.join("video-analysis-studio");
        }
    }
    if let Some(home) = env_dir("HOME") {
        return home.join(".cache").join("video-analysis-studio");
    }
    PathBuf::from(".cache/video-analysis-studio")
}
775
/// Decoded audio ready to be handed to whisper.cpp.
#[cfg(feature = "native")]
struct AudioSamples {
    /// Mono samples as `f32`.
    samples: Vec<f32>,
}
780
/// Owning handle to a whisper.cpp context; the context is freed on drop.
#[cfg(feature = "native")]
struct WhisperContext {
    /// Raw context pointer; non-null when created through `from_model`.
    raw: *mut ffi::whisper_context,
}
785
#[cfg(feature = "native")]
impl WhisperContext {
    /// Loads the model file at `path` into a new whisper.cpp context.
    ///
    /// GPU use is requested only on macOS builds, and flash attention is
    /// turned off. The path is converted lossily to UTF-8 before it is passed
    /// to whisper.cpp.
    ///
    /// # Errors
    /// Returns `Initialization` when the path contains a NUL byte or
    /// whisper.cpp returns a null context.
    fn from_model(path: &Path) -> Result<Self> {
        let init_error = || WhisperCppError::Initialization(path.display().to_string());

        let model_path =
            CString::new(path.to_string_lossy().into_owned()).map_err(|_| init_error())?;

        // SAFETY: the call takes no arguments.
        let mut params = unsafe { ffi::whisper_context_default_params() };
        params.use_gpu = cfg!(target_os = "macos");
        params.flash_attn = false;

        // SAFETY: `model_path` is a live, NUL-terminated `CString` for the
        // whole call.
        let raw = unsafe { ffi::whisper_init_from_file_with_params(model_path.as_ptr(), params) };
        if raw.is_null() {
            Err(init_error())
        } else {
            Ok(Self { raw })
        }
    }
}
801
#[cfg(feature = "native")]
impl Drop for WhisperContext {
    /// Frees the underlying whisper.cpp context unless the pointer is null.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // SAFETY: `raw` is non-null and this handle is its only owner, so it
        // is freed exactly once here.
        unsafe { ffi::whisper_free(self.raw) };
    }
}
810
/// Guard for a lock file; the file is removed when the guard is dropped.
#[cfg(any(feature = "native", test))]
struct FileLock {
    /// Location of the lock file owned by this guard.
    path: PathBuf,
}
815
#[cfg(any(feature = "native", test))]
impl FileLock {
    /// Takes the lock by exclusively creating the file at `path`.
    ///
    /// While the file already exists, creation is retried every 250 ms for up
    /// to 120 seconds. On success the current process id is written into the
    /// file on a best-effort basis.
    ///
    /// # Errors
    /// Returns `Io` with the last `AlreadyExists` error once the deadline has
    /// passed, or immediately with any other I/O error.
    fn acquire(path: PathBuf) -> Result<Self> {
        const WAIT_LIMIT: Duration = Duration::from_secs(120);
        const RETRY_DELAY: Duration = Duration::from_millis(250);

        let deadline = Instant::now() + WAIT_LIMIT;
        loop {
            let attempt = OpenOptions::new().create_new(true).write(true).open(&path);
            let error = match attempt {
                Ok(mut file) => {
                    // The pid is informational only; a failed write is ignored.
                    let _ = writeln!(file, "{}", std::process::id());
                    return Ok(Self { path });
                }
                Err(error) => error,
            };

            let contended = error.kind() == std::io::ErrorKind::AlreadyExists;
            if !contended || Instant::now() >= deadline {
                return Err(WhisperCppError::Io(error));
            }
            thread::sleep(RETRY_DELAY);
        }
    }
}
837
#[cfg(any(feature = "native", test))]
impl Drop for FileLock {
    /// Releases the lock by deleting the lock file; a failed removal is ignored.
    fn drop(&mut self) {
        let lock_file = &self.path;
        let _ = fs::remove_file(lock_file);
    }
}
844
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Store root used by the path-only tests; nothing is written to it.
    const TEST_ROOT: &str = "/tmp/video-analysis-studio-test";

    /// Builds a config for `model` and `language` with translation off and no thread override.
    fn config_for(model: WhisperCppModel, language: Option<&str>) -> WhisperCppConfig {
        WhisperCppConfig {
            model,
            language: language.map(str::to_string),
            translate: false,
            threads: None,
        }
    }

    #[test]
    fn model_metadata_matches_expected_file_names() {
        let file_name = WhisperCppModel::BaseEn.file_name();
        assert_eq!(file_name, "ggml-base.en.bin");

        let url = WhisperCppModel::LargeV3Turbo.download_url();
        assert_eq!(
            url,
            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin"
        );
    }

    #[test]
    fn catalog_uses_base_en_by_default() {
        let store = ModelStore::new(PathBuf::from(TEST_ROOT));
        let catalog = store.catalog();
        assert_eq!(catalog.default_model, WhisperCppModel::BaseEn);
        assert_eq!(catalog.models.len(), WhisperCppModel::ALL.len());
    }

    #[test]
    fn cache_paths_are_stable() {
        let store = ModelStore::new(PathBuf::from(TEST_ROOT));
        let model = WhisperCppModel::SmallEn;

        assert_eq!(
            store.model_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin")
        );
        assert_eq!(
            store.lock_path(model),
            PathBuf::from("/tmp/video-analysis-studio-test/models/ggml-small.en.bin.lock")
        );
    }

    #[test]
    fn file_lock_creates_and_releases_lock_path() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("model.lock");

        let lock = FileLock::acquire(path.clone()).unwrap();
        assert!(path.is_file());

        // Dropping the guard must delete the lock file.
        drop(lock);
        assert!(!path.exists());
    }

    #[test]
    fn english_only_models_default_to_english() {
        let config = config_for(WhisperCppModel::BaseEn, None);

        let language = resolve_language(&config).unwrap().unwrap();
        assert_eq!(language.to_str().unwrap(), "en");
    }

    #[test]
    fn multilingual_models_default_to_auto_detection_without_detect_only_mode() {
        let config = config_for(WhisperCppModel::Base, None);

        assert_eq!(resolve_language(&config).unwrap(), None);
    }

    #[test]
    fn auto_language_uses_english_for_english_only_models() {
        let config = config_for(WhisperCppModel::SmallEn, Some("auto"));

        let language = resolve_language(&config).unwrap().unwrap();
        assert_eq!(language.to_str().unwrap(), "en");
    }
}