1use std::path::PathBuf;
2
3use hf_hub::{Cache, Repo};
4use strum::{Display, EnumIter, EnumString};
5use tokio::sync::mpsc::UnboundedSender;
6
7use crate::{
8 Error, Event,
9 download::{ProgressType, download_file},
10};
11
12struct HFCoordinates {
13 repo: Repo,
14 model: String,
15}
16
17#[derive(Default, Clone, Debug, EnumIter, EnumString, Display)]
19#[strum(serialize_all = "snake_case")]
20pub enum Model {
21 #[strum(serialize = "tiny", to_string = "Tiny - tiny")]
23 Tiny,
24 #[strum(serialize = "tiny-q5_1", to_string = "Tiny - tiny-q5_1")]
26 TinyQ5_1,
27 #[strum(serialize = "tiny-q8_0", to_string = "Tiny - tiny-q8_0")]
29 TinyQ8_0,
30 #[strum(serialize = "tiny_en", to_string = "TinyEn - tiny_en")]
32 TinyEn,
33 #[strum(serialize = "tiny_en-q5_1", to_string = "TinyEn - tiny_en-q5_1")]
35 TinyEnQ5_1,
36 #[strum(serialize = "tiny_en-q8_0", to_string = "Tiny - tiny_en-q8_0")]
38 TinyEnQ8_0,
39 #[default]
41 #[strum(serialize = "base", to_string = "Base - base")]
42 Base,
43 #[strum(serialize = "base-q5_1", to_string = "Base - base-q5_1")]
45 BaseQ5_1,
46 #[strum(serialize = "base-q8_0", to_string = "Base - base-q8_0")]
48 BaseQ8_0,
49 #[strum(serialize = "base_en", to_string = "BaseEn - base_en")]
51 BaseEn,
52 #[strum(serialize = "base_en-q5_1", to_string = "BaseEn -base_en-q5_1")]
54 BaseEnQ5_1,
55 #[strum(serialize = "base_en-q8_0", to_string = "BaseEn - base_en-q8_0")]
57 BaseEnQ8_0,
58 #[strum(serialize = "small", to_string = "Small - small")]
60 Small,
61 #[strum(serialize = "small-q5_1", to_string = "Small - small-q5_1")]
63 SmallQ5_1,
64 #[strum(serialize = "small-q8_0", to_string = "Small - small-q8_0")]
66 SmallQ8_0,
67 #[strum(serialize = "small_en", to_string = "SmallEn - small_en")]
69 SmallEn,
70 #[strum(serialize = "small_en-q5_1", to_string = "SmallEn - small_en-q5_1")]
72 SmallEnQ5_1,
73 #[strum(serialize = "small_en-q8_0", to_string = "SmallEn - small_en-q8_0")]
75 SmallEnQ8_0,
76 #[strum(serialize = "medium", to_string = "Medium - medium")]
78 Medium,
79 #[strum(serialize = "medium-q5_0", to_string = "Medium - medium-q5_0")]
81 MediumQ5_0,
82 #[strum(serialize = "medium-q8_0", to_string = "Medium - medium-q8_0")]
84 MediumQ8_0,
85 #[strum(serialize = "medium_en", to_string = "MediumEn - medium_en")]
87 MediumEn,
88 #[strum(serialize = "medium_en-q5_0 ", to_string = "MediumEn - medium_en-q5_0")]
90 MediumEnQ5_0,
91 #[strum(serialize = "medium_en-q8_0", to_string = "MediumEn - medium_en-q8_0")]
93 MediumEnQ8_0,
94 #[strum(serialize = "large", to_string = "Large V1 - large")]
96 Large,
97 #[strum(serialize = "large_v2", to_string = "Large V2 - large_v2")]
99 LargeV2,
100 #[strum(serialize = "large_v2-q5_0", to_string = "Large V2 - large_v2-q5_0")]
101 LargeV2Q5_0,
102 #[strum(serialize = "large_v2-q8_0", to_string = "Large V2 - large_v2-q8_0")]
103 LargeV2Q8_0,
104 #[strum(serialize = "large_v3", to_string = "Large V3 - large_v3")]
106 LargeV3,
107 #[strum(serialize = "large_v3-q5_0", to_string = "Large V3 - large_v3-q5_0")]
109 LargeV3Q5_0,
110 #[strum(
112 serialize = "large_v3_turbo",
113 to_string = "Large V3 Turbo - large_v3_turbo"
114 )]
115 LargeV3Turbo,
116 #[strum(
118 serialize = "large_v3_turbo-q5_0",
119 to_string = "Large V3 Turbo - large_v3_turbo-q5_0"
120 )]
121 LargeV3TurboQ5_0,
122 #[strum(
124 serialize = "large_v3_turbo-q8_0",
125 to_string = "Large V3 Turbo - large_v3_turbo-q8_0"
126 )]
127 LargeV3TurboQ8_0,
128}
129
130impl Model {
131 fn hf_coordinates(&self) -> HFCoordinates {
132 let repo = Repo::with_revision(
133 "ggerganov/whisper.cpp".to_owned(),
134 hf_hub::RepoType::Model,
135 "main".to_owned(),
136 );
137 match self {
138 Model::Tiny => HFCoordinates {
139 repo,
140 model: "ggml-tiny.bin".to_owned(),
141 },
142 Model::TinyEn => HFCoordinates {
143 repo,
144 model: "ggml-tiny.en.bin".to_owned(),
145 },
146 Model::Base => HFCoordinates {
147 repo,
148 model: "ggml-base.bin".to_owned(),
149 },
150 Model::BaseEn => HFCoordinates {
151 repo,
152 model: "ggml-base.en.bin".to_owned(),
153 },
154 Model::Small => HFCoordinates {
155 repo,
156 model: "ggml-small.bin".to_owned(),
157 },
158 Model::SmallEn => HFCoordinates {
159 repo,
160 model: "ggml-small.en.bin".to_owned(),
161 },
162 Model::Medium => HFCoordinates {
163 repo,
164 model: "ggml-medium.bin".to_owned(),
165 },
166 Model::MediumEn => HFCoordinates {
167 repo,
168 model: "ggml-medium.en.bin".to_owned(),
169 },
170 Model::Large => HFCoordinates {
171 repo,
172 model: "ggml-large-v1.bin".to_owned(),
173 },
174 Model::LargeV2 => HFCoordinates {
175 repo,
176 model: "ggml-large-v2.bin".to_owned(),
177 },
178 Model::LargeV3 => HFCoordinates {
179 repo,
180 model: "ggml-large-v3.bin".to_owned(),
181 },
182 Model::TinyQ5_1 => HFCoordinates {
183 repo,
184 model: "ggml-tiny-q5_1.bin".to_owned(),
185 },
186 Model::TinyQ8_0 => HFCoordinates {
187 repo,
188 model: "ggml-tiny-q8_0.bin".to_owned(),
189 },
190 Model::TinyEnQ5_1 => HFCoordinates {
191 repo,
192 model: "ggml-tiny.en-q5_1.bin".to_owned(),
193 },
194 Model::TinyEnQ8_0 => HFCoordinates {
195 repo,
196 model: "ggml-tiny.en-q8_0.bin".to_owned(),
197 },
198 Model::BaseQ5_1 => HFCoordinates {
199 repo,
200 model: "ggml-base-q5_1.bin".to_owned(),
201 },
202 Model::BaseQ8_0 => HFCoordinates {
203 repo,
204 model: "ggml-base-q8_0.bin".to_owned(),
205 },
206 Model::BaseEnQ5_1 => HFCoordinates {
207 repo,
208 model: "ggml-base.en-q5_1.bin".to_owned(),
209 },
210 Model::BaseEnQ8_0 => HFCoordinates {
211 repo,
212 model: "ggml-base.en-q8_0.bin".to_owned(),
213 },
214 Model::SmallQ5_1 => HFCoordinates {
215 repo,
216 model: "ggml-small-q5_1.bin".to_owned(),
217 },
218 Model::SmallQ8_0 => HFCoordinates {
219 repo,
220 model: "ggml-small-q8_0.bin".to_owned(),
221 },
222 Model::SmallEnQ5_1 => HFCoordinates {
223 repo,
224 model: "ggml-small.en-q5_1.bin".to_owned(),
225 },
226 Model::SmallEnQ8_0 => HFCoordinates {
227 repo,
228 model: "ggml-small.en-q8_0.bin".to_owned(),
229 },
230 Model::MediumQ5_0 => HFCoordinates {
231 repo,
232 model: "ggml-medium-q5_0.bin".to_owned(),
233 },
234 Model::MediumQ8_0 => HFCoordinates {
235 repo,
236 model: "ggml-medium-q8_0.bin".to_owned(),
237 },
238 Model::MediumEnQ5_0 => HFCoordinates {
239 repo,
240 model: "ggml-medium.en-q5_0.bin".to_owned(),
241 },
242 Model::MediumEnQ8_0 => HFCoordinates {
243 repo,
244 model: "ggml-medium.en-q8_0.bin".to_owned(),
245 },
246 Model::LargeV2Q5_0 => HFCoordinates {
247 repo,
248 model: "ggml-large-v2-q5_0.bin".to_owned(),
249 },
250 Model::LargeV2Q8_0 => HFCoordinates {
251 repo,
252 model: "ggml-large-v2-q8_0.bin".to_owned(),
253 },
254 Model::LargeV3Q5_0 => HFCoordinates {
255 repo,
256 model: "ggml-large-v3-q5_0.bin".to_owned(),
257 },
258 Model::LargeV3Turbo => HFCoordinates {
259 repo,
260 model: "ggml-large-v3-turbo.bin".to_owned(),
261 },
262 Model::LargeV3TurboQ5_0 => HFCoordinates {
263 repo,
264 model: "ggml-large-v3-turbo-q5_0.bin".to_owned(),
265 },
266 Model::LargeV3TurboQ8_0 => HFCoordinates {
267 repo,
268 model: "ggml-large-v3-turbo-q8_0.bin".to_owned(),
269 },
270 }
271 }
272
273 pub fn is_multilingual(&self) -> bool {
275 !self.to_string().contains("en")
276 }
277
278 pub fn cached(&self) -> bool {
280 let coordinates = self.hf_coordinates();
281 let cache = Cache::from_env().repo(coordinates.repo);
282 cache.get(&coordinates.model).is_some()
283 }
284
285 pub(crate) async fn internal_download_model(
286 &self,
287 force_download: bool,
288 progress: ProgressType,
289 ) -> Result<PathBuf, Error> {
290 let coordinates = self.hf_coordinates();
291
292 download_file(
293 &coordinates.model,
294 force_download,
295 progress,
296 coordinates.repo,
297 )
298 .await
299 }
300
301 pub async fn download_model(&self, force_download: bool) -> Result<PathBuf, Error> {
302 self.internal_download_model(force_download, ProgressType::ProgressBar)
303 .await
304 }
305
306 pub async fn download_model_listener(
307 &self,
308 force_download: bool,
309 tx: UnboundedSender<Event>,
310 ) -> Result<PathBuf, Error> {
311 self.internal_download_model(force_download, ProgressType::Callback(tx))
312 .await
313 }
314}