Skip to main content

whisper_cpp_plus/
params.rs

1use std::ffi::CString;
2use whisper_cpp_plus_sys as ffi;
3
/// Decoding strategy selected when constructing [`FullParams`].
///
/// Each variant carries the single tuning value that `FullParams::new`
/// writes into the matching section of the raw whisper parameters.
#[derive(Debug, Clone, Copy)]
pub enum SamplingStrategy {
    /// Greedy sampling; `best_of` is stored in `greedy.best_of`.
    Greedy { best_of: i32 },
    /// Beam search; `beam_size` is stored in `beam_search.beam_size`.
    BeamSearch { beam_size: i32 },
}
9
10#[derive(Clone)]
11pub struct FullParams {
12    pub(crate) inner: ffi::whisper_full_params,
13    language: Option<CString>,
14    initial_prompt: Option<CString>,
15}
16
17// FullParams is Send and Sync because we only use it in controlled contexts
18unsafe impl Send for FullParams {}
19unsafe impl Sync for FullParams {}
20
21impl FullParams {
22    pub fn new(strategy: SamplingStrategy) -> Self {
23        let inner = unsafe {
24            match strategy {
25                SamplingStrategy::Greedy { best_of } => {
26                    let mut params = ffi::whisper_full_default_params(
27                        ffi::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY,
28                    );
29                    params.greedy.best_of = best_of;
30                    params
31                }
32                SamplingStrategy::BeamSearch { beam_size } => {
33                    let mut params = ffi::whisper_full_default_params(
34                        ffi::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH,
35                    );
36                    params.beam_search.beam_size = beam_size;
37                    params
38                }
39            }
40        };
41
42        let mut params = Self {
43            inner,
44            language: None,
45            initial_prompt: None,
46        };
47
48        params.inner.n_threads = (num_cpus::get() / 2).max(1) as i32;
49        params.inner.suppress_blank = true;
50        params.inner.suppress_nst = true;
51        params.inner.temperature = 0.0;
52        params.inner.max_initial_ts = 1.0;
53        params.inner.length_penalty = -1.0;
54
55        params
56    }
57
58    pub(crate) fn as_raw(&self) -> ffi::whisper_full_params {
59        let mut params = self.inner;
60
61        if let Some(ref lang) = self.language {
62            params.language = lang.as_ptr();
63        }
64
65        if let Some(ref prompt) = self.initial_prompt {
66            params.initial_prompt = prompt.as_ptr();
67        }
68
69        params
70    }
71
72    pub fn language(mut self, lang: &str) -> Self {
73        self.language = CString::new(lang).ok();
74        if let Some(ref lang_cstr) = self.language {
75            self.inner.language = lang_cstr.as_ptr();
76        }
77        self
78    }
79
80    pub fn translate(mut self, translate: bool) -> Self {
81        self.inner.translate = translate;
82        self
83    }
84
85    pub fn no_context(mut self, no_context: bool) -> Self {
86        self.inner.no_context = no_context;
87        self
88    }
89
90    pub fn no_timestamps(mut self, no_timestamps: bool) -> Self {
91        self.inner.no_timestamps = no_timestamps;
92        self
93    }
94
95    pub fn single_segment(mut self, single_segment: bool) -> Self {
96        self.inner.single_segment = single_segment;
97        self
98    }
99
100    pub fn print_special(mut self, print_special: bool) -> Self {
101        self.inner.print_special = print_special;
102        self
103    }
104
105    pub fn print_progress(mut self, print_progress: bool) -> Self {
106        self.inner.print_progress = print_progress;
107        self
108    }
109
110    pub fn print_realtime(mut self, print_realtime: bool) -> Self {
111        self.inner.print_realtime = print_realtime;
112        self
113    }
114
115    pub fn print_timestamps(mut self, print_timestamps: bool) -> Self {
116        self.inner.print_timestamps = print_timestamps;
117        self
118    }
119
120    pub fn token_timestamps(mut self, token_timestamps: bool) -> Self {
121        self.inner.token_timestamps = token_timestamps;
122        self
123    }
124
125    pub fn thold_pt(mut self, thold_pt: f32) -> Self {
126        self.inner.thold_pt = thold_pt;
127        self
128    }
129
130    pub fn thold_ptsum(mut self, thold_ptsum: f32) -> Self {
131        self.inner.thold_ptsum = thold_ptsum;
132        self
133    }
134
135    pub fn max_len(mut self, max_len: i32) -> Self {
136        self.inner.max_len = max_len;
137        self
138    }
139
140    pub fn split_on_word(mut self, split_on_word: bool) -> Self {
141        self.inner.split_on_word = split_on_word;
142        self
143    }
144
145    pub fn max_tokens(mut self, max_tokens: i32) -> Self {
146        self.inner.max_tokens = max_tokens;
147        self
148    }
149
150
151    pub fn debug_mode(mut self, debug_mode: bool) -> Self {
152        self.inner.debug_mode = debug_mode;
153        self
154    }
155
156    pub fn audio_ctx(mut self, audio_ctx: i32) -> Self {
157        self.inner.audio_ctx = audio_ctx;
158        self
159    }
160
161    pub fn tdrz_enable(mut self, tdrz_enable: bool) -> Self {
162        self.inner.tdrz_enable = tdrz_enable;
163        self
164    }
165
166    pub fn suppress_regex(mut self, suppress_regex: Option<&str>) -> Self {
167        if let Some(regex) = suppress_regex {
168            if let Ok(c_regex) = CString::new(regex) {
169                self.inner.suppress_regex = c_regex.as_ptr();
170            }
171        } else {
172            self.inner.suppress_regex = std::ptr::null();
173        }
174        self
175    }
176
177    pub fn initial_prompt(mut self, prompt: &str) -> Self {
178        self.initial_prompt = CString::new(prompt).ok();
179        if let Some(ref prompt_cstr) = self.initial_prompt {
180            self.inner.initial_prompt = prompt_cstr.as_ptr();
181        }
182        self
183    }
184
185    pub fn prompt_tokens(mut self, tokens: &[i32]) -> Self {
186        self.inner.prompt_tokens = tokens.as_ptr();
187        self.inner.prompt_n_tokens = tokens.len() as i32;
188        self
189    }
190
191    pub fn temperature(mut self, temperature: f32) -> Self {
192        self.inner.temperature = temperature;
193        self
194    }
195
196    pub fn temperature_inc(mut self, temperature_inc: f32) -> Self {
197        self.inner.temperature_inc = temperature_inc;
198        self
199    }
200
201    pub fn entropy_thold(mut self, entropy_thold: f32) -> Self {
202        self.inner.entropy_thold = entropy_thold;
203        self
204    }
205
206    pub fn logprob_thold(mut self, logprob_thold: f32) -> Self {
207        self.inner.logprob_thold = logprob_thold;
208        self
209    }
210
211    pub fn n_threads(mut self, n_threads: i32) -> Self {
212        self.inner.n_threads = n_threads;
213        self
214    }
215
216    pub fn offset_ms(mut self, offset_ms: i32) -> Self {
217        self.inner.offset_ms = offset_ms;
218        self
219    }
220
221    pub fn duration_ms(mut self, duration_ms: i32) -> Self {
222        self.inner.duration_ms = duration_ms;
223        self
224    }
225}
226
227impl Default for FullParams {
228    fn default() -> Self {
229        Self::new(SamplingStrategy::Greedy { best_of: 1 })
230    }
231}
232
233#[derive(Clone)]
234pub struct TranscriptionParams {
235    params: FullParams,
236}
237
238impl TranscriptionParams {
239    pub fn builder() -> TranscriptionParamsBuilder {
240        TranscriptionParamsBuilder::new()
241    }
242
243    pub(crate) fn into_full_params(self) -> FullParams {
244        self.params
245    }
246}
247
248#[derive(Clone)]
249pub struct TranscriptionParamsBuilder {
250    params: FullParams,
251}
252
253impl TranscriptionParamsBuilder {
254    pub fn new() -> Self {
255        Self {
256            params: FullParams::default(),
257        }
258    }
259
260    pub fn language(mut self, lang: &str) -> Self {
261        self.params = self.params.language(lang);
262        self
263    }
264
265    pub fn translate(mut self, translate: bool) -> Self {
266        self.params = self.params.translate(translate);
267        self
268    }
269
270    pub fn temperature(mut self, temperature: f32) -> Self {
271        self.params = self.params.temperature(temperature);
272        self
273    }
274
275    pub fn enable_timestamps(mut self) -> Self {
276        self.params = self.params.no_timestamps(false);
277        self
278    }
279
280    pub fn disable_timestamps(mut self) -> Self {
281        self.params = self.params.no_timestamps(true);
282        self
283    }
284
285    pub fn single_segment(mut self, single: bool) -> Self {
286        self.params = self.params.single_segment(single);
287        self
288    }
289
290    pub fn max_tokens(mut self, max_tokens: i32) -> Self {
291        self.params = self.params.max_tokens(max_tokens);
292        self
293    }
294
295    pub fn initial_prompt(mut self, prompt: &str) -> Self {
296        self.params = self.params.initial_prompt(prompt);
297        self
298    }
299
300    pub fn n_threads(mut self, n_threads: i32) -> Self {
301        self.params = self.params.n_threads(n_threads);
302        self
303    }
304
305    pub fn build(self) -> TranscriptionParams {
306        TranscriptionParams {
307            params: self.params,
308        }
309    }
310}
311
312impl Default for TranscriptionParamsBuilder {
313    fn default() -> Self {
314        Self::new()
315    }
316}