1use std::{
2 ffi::OsStr,
3 fmt::{Debug, Display},
4 fs::{self, File},
5 io::Cursor,
6 path::PathBuf,
7};
8
9use crate::{get_token, TokenSource};
10use hf_hub::{
11 api::sync::{ApiBuilder, ApiRepo},
12 Repo, RepoType,
13};
14use memmap2::Mmap;
15use zip::ZipArchive;
16
17pub enum ModelSource {
19 ModelId(String),
20 ModelIdWithTransformer {
21 model_id: String,
22 transformer_model_id: String,
23 },
24 Dduf {
25 file: Cursor<Mmap>,
26 name: String,
27 },
28}
29
30impl Display for ModelSource {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Dduf { file: _, name } => write!(f, "dduf file: {name}"),
34 Self::ModelId(model_id) => write!(f, "model id: {model_id}"),
35 Self::ModelIdWithTransformer {
36 model_id,
37 transformer_model_id,
38 } => write!(
39 f,
40 "model id: {model_id}, transformer override: {transformer_model_id}"
41 ),
42 }
43 }
44}
45
46impl ModelSource {
47 pub fn from_model_id<S: ToString>(model_id: S) -> Self {
49 Self::ModelId(model_id.to_string())
50 }
51
52 pub fn override_transformer_model_id<S: ToString>(self, model_id: S) -> anyhow::Result<Self> {
66 let Self::ModelId(base_id) = self else {
67 anyhow::bail!("Expected model ID for the model source")
68 };
69 Ok(Self::ModelIdWithTransformer {
70 model_id: base_id,
71 transformer_model_id: model_id.to_string(),
72 })
73 }
74
75 pub fn dduf<S: ToString>(filename: S) -> anyhow::Result<Self> {
77 let file = File::open(filename.to_string())?;
78 let mmap = unsafe { Mmap::map(&file)? };
79 let cursor = Cursor::new(mmap);
80 Ok(Self::Dduf {
81 file: cursor,
82 name: filename.to_string(),
83 })
84 }
85}
86
87pub enum FileLoader<'a> {
88 Api(Box<ApiRepo>),
89 ApiWithTransformer {
90 base: Box<ApiRepo>,
91 transformer: Box<ApiRepo>,
92 },
93 Dduf(ZipArchive<&'a mut Cursor<Mmap>>),
94}
95
96impl<'a> FileLoader<'a> {
97 pub fn from_model_source(
98 source: &'a mut ModelSource,
99 silent: bool,
100 token: TokenSource,
101 revision: Option<String>,
102 ) -> anyhow::Result<Self> {
103 match source {
104 ModelSource::ModelId(model_id) => {
105 let api_builder = ApiBuilder::new()
106 .with_progress(!silent)
107 .with_token(get_token(&token)?)
108 .build()?;
109 let revision = revision.unwrap_or("main".to_string());
110 let api = api_builder.repo(Repo::with_revision(
111 model_id.clone(),
112 RepoType::Model,
113 revision.clone(),
114 ));
115
116 Ok(Self::Api(Box::new(api)))
117 }
118 ModelSource::Dduf { file, name: _ } => Ok(Self::Dduf(ZipArchive::new(file)?)),
119 ModelSource::ModelIdWithTransformer {
120 model_id,
121 transformer_model_id,
122 } => {
123 let api_builder = ApiBuilder::new()
124 .with_progress(!silent)
125 .with_token(get_token(&token)?)
126 .build()?;
127 let revision = revision.unwrap_or("main".to_string());
128 let api = api_builder.repo(Repo::with_revision(
129 model_id.clone(),
130 RepoType::Model,
131 revision.clone(),
132 ));
133 let transformer_api = api_builder.repo(Repo::with_revision(
134 transformer_model_id.clone(),
135 RepoType::Model,
136 revision.clone(),
137 ));
138
139 Ok(Self::ApiWithTransformer {
140 base: Box::new(api),
141 transformer: Box::new(transformer_api),
142 })
143 }
144 }
145 }
146
147 pub fn list_files(&mut self) -> anyhow::Result<Vec<String>> {
148 match self {
149 Self::Api(api)
150 | Self::ApiWithTransformer {
151 base: api,
152 transformer: _,
153 } => api
154 .info()
155 .map(|repo| {
156 repo.siblings
157 .iter()
158 .map(|x| x.rfilename.clone())
159 .collect::<Vec<String>>()
160 })
161 .map_err(|e| anyhow::Error::msg(e.to_string())),
162 Self::Dduf(dduf) => (0..dduf.len())
163 .map(|i| {
164 dduf.by_index(i)
165 .map(|x| x.name().to_string())
166 .map_err(|e| anyhow::Error::msg(e.to_string()))
167 })
168 .collect::<anyhow::Result<Vec<_>>>(),
169 }
170 }
171
172 pub fn list_transformer_files(&self) -> anyhow::Result<Option<Vec<String>>> {
173 match self {
174 Self::Api(_) | Self::Dduf(_) => Ok(None),
175
176 Self::ApiWithTransformer {
177 base: _,
178 transformer: api,
179 } => api
180 .info()
181 .map(|repo| {
182 repo.siblings
183 .iter()
184 .map(|x| x.rfilename.clone())
185 .collect::<Vec<String>>()
186 })
187 .map(Some)
188 .map_err(|e| anyhow::Error::msg(e.to_string())),
189 }
190 }
191
192 pub fn read_file(&mut self, name: &str, from_transformer: bool) -> anyhow::Result<FileData> {
198 if from_transformer && !matches!(self, Self::ApiWithTransformer { .. }) {
199 anyhow::bail!("This model source has no transformer files.")
200 }
201
202 match (self, from_transformer) {
203 (Self::Api(api), false)
204 | (
205 Self::ApiWithTransformer {
206 base: api,
207 transformer: _,
208 },
209 false,
210 ) => Ok(FileData::Path(
211 api.get(name)
212 .map_err(|e| anyhow::Error::msg(e.to_string()))?,
213 )),
214 (
215 Self::ApiWithTransformer {
216 base: api,
217 transformer: _,
218 },
219 true,
220 ) => Ok(FileData::Path(
221 api.get(name)
222 .map_err(|e| anyhow::Error::msg(e.to_string()))?,
223 )),
224 (Self::Api(_), true) => anyhow::bail!("This model source has no transformer files."),
225 (Self::Dduf(dduf), _) => {
226 let file = dduf.by_name(name)?;
227 let start = file.data_start() as usize;
228 let len = file.size() as usize;
229 let end = start + len;
230 let name = file.name().into();
231 Ok(FileData::Dduf { name, start, end })
232 }
233 }
234 }
235
236 pub fn read_file_copied(
242 &mut self,
243 name: &str,
244 from_transformer: bool,
245 ) -> anyhow::Result<FileData> {
246 if matches!(self, Self::Api(_) | Self::ApiWithTransformer { .. }) {
247 return self.read_file(name, from_transformer);
248 }
249
250 let Self::Dduf(dduf) = self else {
251 anyhow::bail!("expected dduf model source!");
252 };
253 let mut file = dduf.by_name(name)?;
254 let mut data = Vec::new();
255 std::io::copy(&mut file, &mut data)?;
256 let name = PathBuf::from(file.name().to_string());
257 Ok(FileData::DdufOwned { name, data })
258 }
259}
260
/// A handle to one model file, as produced by [`FileLoader`].
pub enum FileData {
    /// A file on local disk, as returned by the Hub API.
    Path(PathBuf),
    /// The byte range `start..end` of an entry inside a DDUF source's memory map. The bytes are
    /// not copied.
    Dduf {
        /// Entry name inside the archive.
        name: PathBuf,
        /// Offset of the entry's first byte.
        start: usize,
        /// Offset one past the entry's last byte.
        end: usize,
    },
    /// A DDUF entry whose contents have been copied into memory.
    DdufOwned {
        /// Entry name inside the archive.
        name: PathBuf,
        /// The entry's bytes.
        data: Vec<u8>,
    },
}
273
274impl Debug for FileData {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 match self {
277 Self::Path(p) => write!(f, "path: {}", p.display()),
278 Self::Dduf {
279 name,
280 start: _,
281 end: _,
282 } => write!(f, "dduf: {}", name.display()),
283 Self::DdufOwned { name, data: _ } => write!(f, "dduf owned: {}", name.display()),
284 }
285 }
286}
287
288impl FileData {
289 pub fn read_to_string(&self, src: &ModelSource) -> anyhow::Result<String> {
290 match self {
291 Self::Path(p) => Ok(fs::read_to_string(p)?),
292 Self::Dduf {
293 name: _,
294 start,
295 end,
296 } => {
297 let ModelSource::Dduf { file, name: _ } = src else {
298 anyhow::bail!("expected dduf model source!");
299 };
300 Ok(String::from_utf8(file.get_ref()[*start..*end].to_vec())?)
301 }
302 Self::DdufOwned { name: _, data } => Ok(String::from_utf8(data.to_vec())?),
303 }
304 }
305
306 pub fn read_to_string_owned(&self) -> anyhow::Result<String> {
307 match self {
308 Self::Path(p) => Ok(fs::read_to_string(p)?),
309 Self::Dduf { .. } => {
310 anyhow::bail!("dduf file data is not owned !");
311 }
312 Self::DdufOwned { name: _, data } => Ok(String::from_utf8(data.to_vec())?),
313 }
314 }
315
316 pub fn extension(&self) -> Option<&OsStr> {
317 match self {
318 Self::Path(p) => p.extension(),
319 Self::Dduf {
320 name,
321 start: _,
322 end: _,
323 } => name.extension(),
324 Self::DdufOwned { name, data: _ } => name.extension(),
325 }
326 }
327}