1use std::ffi::CString;
13use std::path::Path;
14use std::sync::Arc;
15use std::sync::Mutex;
16use thiserror::Error;
17
18use whisper_cpp_plus_sys as ffi;
19
20#[derive(Debug, Error)]
22pub enum QuantizeError {
23 #[error("Model file not found: {0}")]
24 FileNotFound(String),
25
26 #[error("Failed to open file: {0}")]
27 FileOpenError(String),
28
29 #[error("Failed to write file: {0}")]
30 FileWriteError(String),
31
32 #[error("Invalid model format")]
33 InvalidModel,
34
35 #[error("Invalid quantization type")]
36 InvalidQuantizationType,
37
38 #[error("Quantization failed: {0}")]
39 QuantizationFailed(String),
40}
41
42type Result<T> = std::result::Result<T, QuantizeError>;
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46#[repr(i32)]
47#[allow(non_camel_case_types)]
48pub enum QuantizationType {
49 Q4_0 = ffi::GGML_FTYPE_MOSTLY_Q4_0,
51
52 Q4_1 = ffi::GGML_FTYPE_MOSTLY_Q4_1,
54
55 Q5_0 = ffi::GGML_FTYPE_MOSTLY_Q5_0,
57
58 Q5_1 = ffi::GGML_FTYPE_MOSTLY_Q5_1,
60
61 Q8_0 = ffi::GGML_FTYPE_MOSTLY_Q8_0,
63
64 Q2_K = ffi::GGML_FTYPE_MOSTLY_Q2_K,
66
67 Q3_K = ffi::GGML_FTYPE_MOSTLY_Q3_K,
69
70 Q4_K = ffi::GGML_FTYPE_MOSTLY_Q4_K,
72
73 Q5_K = ffi::GGML_FTYPE_MOSTLY_Q5_K,
75
76 Q6_K = ffi::GGML_FTYPE_MOSTLY_Q6_K,
78}
79
80impl QuantizationType {
81 pub fn name(&self) -> &'static str {
83 match self {
84 Self::Q4_0 => "Q4_0",
85 Self::Q4_1 => "Q4_1",
86 Self::Q5_0 => "Q5_0",
87 Self::Q5_1 => "Q5_1",
88 Self::Q8_0 => "Q8_0",
89 Self::Q2_K => "Q2_K",
90 Self::Q3_K => "Q3_K",
91 Self::Q4_K => "Q4_K",
92 Self::Q5_K => "Q5_K",
93 Self::Q6_K => "Q6_K",
94 }
95 }
96
97 pub fn size_factor(&self) -> f32 {
100 match self {
101 Self::Q2_K => 0.19, Self::Q3_K => 0.26, Self::Q4_0 => 0.31, Self::Q4_1 => 0.35, Self::Q4_K => 0.33, Self::Q5_0 => 0.39, Self::Q5_1 => 0.43, Self::Q5_K => 0.41, Self::Q6_K => 0.49, Self::Q8_0 => 0.69, }
112 }
113
114 pub fn all() -> &'static [QuantizationType] {
116 &[
117 Self::Q4_0,
118 Self::Q4_1,
119 Self::Q5_0,
120 Self::Q5_1,
121 Self::Q8_0,
122 Self::Q2_K,
123 Self::Q3_K,
124 Self::Q4_K,
125 Self::Q5_K,
126 Self::Q6_K,
127 ]
128 }
129}
130
131impl std::str::FromStr for QuantizationType {
132 type Err = QuantizeError;
133
134 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
135 match s.to_uppercase().as_str() {
136 "Q4_0" | "Q40" => Ok(Self::Q4_0),
137 "Q4_1" | "Q41" => Ok(Self::Q4_1),
138 "Q5_0" | "Q50" => Ok(Self::Q5_0),
139 "Q5_1" | "Q51" => Ok(Self::Q5_1),
140 "Q8_0" | "Q80" => Ok(Self::Q8_0),
141 "Q2_K" | "Q2K" => Ok(Self::Q2_K),
142 "Q3_K" | "Q3K" => Ok(Self::Q3_K),
143 "Q4_K" | "Q4K" => Ok(Self::Q4_K),
144 "Q5_K" | "Q5K" => Ok(Self::Q5_K),
145 "Q6_K" | "Q6K" => Ok(Self::Q6_K),
146 _ => Err(QuantizeError::InvalidQuantizationType),
147 }
148 }
149}
150
151impl std::fmt::Display for QuantizationType {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(f, "{}", self.name())
154 }
155}
156
157pub type ProgressCallback = Box<dyn Fn(f32) + Send>;
159
160pub struct WhisperQuantize;
162
163impl WhisperQuantize {
164 pub fn quantize_model_file<P: AsRef<Path>>(
182 input_path: P,
183 output_path: P,
184 qtype: QuantizationType,
185 ) -> Result<()> {
186 Self::quantize_model_file_impl(input_path.as_ref(), output_path.as_ref(), qtype, None)
187 }
188
189 pub fn quantize_model_file_with_progress<P, F>(
211 input_path: P,
212 output_path: P,
213 qtype: QuantizationType,
214 callback: F,
215 ) -> Result<()>
216 where
217 P: AsRef<Path>,
218 F: Fn(f32) + Send + 'static,
219 {
220 Self::quantize_model_file_impl(
221 input_path.as_ref(),
222 output_path.as_ref(),
223 qtype,
224 Some(Box::new(callback)),
225 )
226 }
227
228 fn quantize_model_file_impl(
229 input_path: &Path,
230 output_path: &Path,
231 qtype: QuantizationType,
232 callback: Option<ProgressCallback>,
233 ) -> Result<()> {
234 if !input_path.exists() {
236 return Err(QuantizeError::FileNotFound(format!(
237 "{}",
238 input_path.display()
239 )));
240 }
241
242 let input_cstr = path_to_cstring(input_path)?;
244 let output_cstr = path_to_cstring(output_path)?;
245
246 let callback_data = callback.map(|cb| Arc::new(Mutex::new(cb)));
248 let callback_ptr = callback_data
249 .as_ref()
250 .map(|data| Arc::clone(data) as Arc<Mutex<dyn Fn(f32) + Send>>);
251
252 let ffi_callback: ffi::whisper_quantize_progress_callback = if callback_ptr.is_some() {
254 Some(quantize_progress_callback)
255 } else {
256 None
257 };
258
259 if let Some(ptr) = callback_ptr {
261 CALLBACK_DATA.with(|data| {
262 *data.borrow_mut() = Some(ptr);
263 });
264 }
265
266 let result = unsafe {
268 ffi::whisper_model_quantize(
269 input_cstr.as_ptr(),
270 output_cstr.as_ptr(),
271 qtype as i32,
272 ffi_callback,
273 )
274 };
275
276 CALLBACK_DATA.with(|data| {
278 *data.borrow_mut() = None;
279 });
280
281 match result {
283 ffi::WHISPER_QUANTIZE_OK => Ok(()),
284 ffi::WHISPER_QUANTIZE_ERROR_INVALID_MODEL => Err(QuantizeError::QuantizationFailed(
285 "Invalid model file".to_string(),
286 )),
287 ffi::WHISPER_QUANTIZE_ERROR_FILE_OPEN => Err(QuantizeError::QuantizationFailed(
288 format!("Failed to open input file: {}", input_path.display()),
289 )),
290 ffi::WHISPER_QUANTIZE_ERROR_FILE_WRITE => Err(QuantizeError::QuantizationFailed(
291 format!("Failed to write output file: {}", output_path.display()),
292 )),
293 ffi::WHISPER_QUANTIZE_ERROR_INVALID_FTYPE => Err(QuantizeError::QuantizationFailed(
294 format!("Invalid quantization type: {}", qtype),
295 )),
296 ffi::WHISPER_QUANTIZE_ERROR_QUANTIZATION_FAILED => Err(
297 QuantizeError::QuantizationFailed("Quantization failed".to_string()),
298 ),
299 _ => Err(QuantizeError::QuantizationFailed(format!(
300 "Unknown quantization error: {}",
301 result
302 ))),
303 }
304 }
305
306 pub fn get_model_quantization_type<P: AsRef<Path>>(
324 model_path: P,
325 ) -> Result<Option<QuantizationType>> {
326 let path = model_path.as_ref();
327
328 if !path.exists() {
329 return Err(QuantizeError::FileNotFound(format!("{}", path.display())));
330 }
331
332 let path_cstr = path_to_cstring(path)?;
333
334 let ftype = unsafe { ffi::whisper_model_get_ftype(path_cstr.as_ptr()) };
335
336 if ftype < 0 {
337 return Err(QuantizeError::FileOpenError(format!("{}", path.display())));
338 }
339
340 let qtype = match ftype {
342 x if x == ffi::GGML_FTYPE_ALL_F32 => None,
343 x if x == ffi::GGML_FTYPE_MOSTLY_F16 => None,
344 x if x == QuantizationType::Q4_0 as i32 => Some(QuantizationType::Q4_0),
345 x if x == QuantizationType::Q4_1 as i32 => Some(QuantizationType::Q4_1),
346 x if x == QuantizationType::Q5_0 as i32 => Some(QuantizationType::Q5_0),
347 x if x == QuantizationType::Q5_1 as i32 => Some(QuantizationType::Q5_1),
348 x if x == QuantizationType::Q8_0 as i32 => Some(QuantizationType::Q8_0),
349 x if x == QuantizationType::Q2_K as i32 => Some(QuantizationType::Q2_K),
350 x if x == QuantizationType::Q3_K as i32 => Some(QuantizationType::Q3_K),
351 x if x == QuantizationType::Q4_K as i32 => Some(QuantizationType::Q4_K),
352 x if x == QuantizationType::Q5_K as i32 => Some(QuantizationType::Q5_K),
353 x if x == QuantizationType::Q6_K as i32 => Some(QuantizationType::Q6_K),
354 _ => None,
355 };
356
357 Ok(qtype)
358 }
359
360 pub fn estimate_quantized_size<P: AsRef<Path>>(
377 model_path: P,
378 qtype: QuantizationType,
379 ) -> Result<u64> {
380 let path = model_path.as_ref();
381 let metadata = std::fs::metadata(path).map_err(|e| {
382 QuantizeError::QuantizationFailed(format!("Failed to read model file: {}", e))
383 })?;
384
385 let original_size = metadata.len();
386 let estimated_size = (original_size as f64 * qtype.size_factor() as f64) as u64;
387
388 Ok(estimated_size)
389 }
390}
391
thread_local! {
    /// Progress callback for the quantization currently running on this thread.
    ///
    /// Set immediately before the native quantizer is invoked and reset to
    /// `None` right after it returns.
    /// NOTE(review): keeping the handle thread-local assumes the native code
    /// invokes the callback on the thread that started the quantization.
    static CALLBACK_DATA: std::cell::RefCell<Option<Arc<Mutex<dyn Fn(f32) + Send>>>> =
        std::cell::RefCell::new(None);
}

/// C-ABI trampoline passed to the native quantizer; forwards `progress` to the
/// callback stored in `CALLBACK_DATA`, if any.
///
/// A panic raised by the user callback is caught here instead of unwinding out
/// of an `extern "C"` function into foreign frames (which aborts the process,
/// or is undefined behaviour on older toolchains). The panic poisons the
/// mutex, so later progress notifications for the same run are skipped while
/// the quantization itself continues.
extern "C" fn quantize_progress_callback(progress: f32) {
    // Clone the handle out of the slot so the `RefCell` borrow is released
    // before any user code runs.
    let handle = CALLBACK_DATA.with(|slot| slot.borrow().clone());

    if let Some(callback) = handle {
        // `AssertUnwindSafe`: the only state shared with the closure is the
        // mutex, and a panic while it is held poisons it, which the `lock()`
        // check below observes on subsequent calls.
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            if let Ok(guard) = callback.lock() {
                (*guard)(progress);
            }
        }));
    }
}
408
409fn path_to_cstring(path: &Path) -> Result<CString> {
411 let path_str = path
412 .to_str()
413 .ok_or_else(|| QuantizeError::QuantizationFailed("Invalid UTF-8 in path".to_string()))?;
414
415 CString::new(path_str)
416 .map_err(|_| QuantizeError::QuantizationFailed("Path contains null byte".to_string()))
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the canonical names of a sample of variants.
    #[test]
    fn test_quantization_type_names() {
        let cases = [
            (QuantizationType::Q4_0, "Q4_0"),
            (QuantizationType::Q5_1, "Q5_1"),
            (QuantizationType::Q8_0, "Q8_0"),
            (QuantizationType::Q3_K, "Q3_K"),
        ];

        for &(qtype, expected) in cases.iter() {
            assert_eq!(qtype.name(), expected);
        }
    }

    /// Parsing is case-insensitive, accepts the compact spelling and rejects
    /// unknown names.
    #[test]
    fn test_quantization_type_from_str() {
        let accepted = [
            ("q4_0", QuantizationType::Q4_0),
            ("Q5_1", QuantizationType::Q5_1),
            ("q8_0", QuantizationType::Q8_0),
            ("Q3K", QuantizationType::Q3_K),
        ];

        for &(text, expected) in accepted.iter() {
            let parsed: QuantizationType = text.parse().unwrap();
            assert_eq!(parsed, expected);
        }

        assert!("invalid".parse::<QuantizationType>().is_err());
    }

    /// Every size factor must lie strictly between 0 and 1.
    #[test]
    fn test_size_factors() {
        for qtype in QuantizationType::all().iter().copied() {
            let factor = qtype.size_factor();
            let in_range = factor > 0.0 && factor < 1.0;
            assert!(in_range, "{} has invalid size factor: {}", qtype, factor);
        }
    }
}