diffusion_rs_common/
model_source.rs

1use std::{
2    ffi::OsStr,
3    fmt::{Debug, Display},
4    fs::{self, File},
5    io::Cursor,
6    path::PathBuf,
7};
8
9use crate::{get_token, TokenSource};
10use hf_hub::{
11    api::sync::{ApiBuilder, ApiRepo},
12    Repo, RepoType,
13};
14use memmap2::Mmap;
15use zip::ZipArchive;
16
17/// Source from which to load the model. This is easiest to create with the various constructor functions.
18pub enum ModelSource {
19    ModelId(String),
20    ModelIdWithTransformer {
21        model_id: String,
22        transformer_model_id: String,
23    },
24    Dduf {
25        file: Cursor<Mmap>,
26        name: String,
27    },
28}
29
30impl Display for ModelSource {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Dduf { file: _, name } => write!(f, "dduf file: {name}"),
34            Self::ModelId(model_id) => write!(f, "model id: {model_id}"),
35            Self::ModelIdWithTransformer {
36                model_id,
37                transformer_model_id,
38            } => write!(
39                f,
40                "model id: {model_id}, transformer override: {transformer_model_id}"
41            ),
42        }
43    }
44}
45
46impl ModelSource {
47    /// Load the model from a Hugging Face model ID or a local path.
48    pub fn from_model_id<S: ToString>(model_id: S) -> Self {
49        Self::ModelId(model_id.to_string())
50    }
51
52    /// Load the transformer part of this model from a Hugging Face model ID or a local path.
53    ///
54    /// For example, this enables loading a quantized transformer model (for instance, [this](https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration))
55    /// with the same [base model](https://huggingface.co/black-forest-labs/FLUX.1-dev) as the original model ID.
56    ///
57    /// ```rust
58    /// use diffusion_rs_common::ModelSource;
59    ///
60    /// let _ = ModelSource::from_model_id("black-forest-labs/FLUX.1-dev")
61    ///     .override_transformer_model_id("sayakpaul/flux.1-dev-nf4-with-bnb-integration")?;
62    ///
63    /// # Ok::<(), anyhow::Error>(())
64    /// ```
65    pub fn override_transformer_model_id<S: ToString>(self, model_id: S) -> anyhow::Result<Self> {
66        let Self::ModelId(base_id) = self else {
67            anyhow::bail!("Expected model ID for the model source")
68        };
69        Ok(Self::ModelIdWithTransformer {
70            model_id: base_id,
71            transformer_model_id: model_id.to_string(),
72        })
73    }
74
75    /// Load a DDUF model from a .dduf file.
76    pub fn dduf<S: ToString>(filename: S) -> anyhow::Result<Self> {
77        let file = File::open(filename.to_string())?;
78        let mmap = unsafe { Mmap::map(&file)? };
79        let cursor = Cursor::new(mmap);
80        Ok(Self::Dduf {
81            file: cursor,
82            name: filename.to_string(),
83        })
84    }
85}
86
87pub enum FileLoader<'a> {
88    Api(Box<ApiRepo>),
89    ApiWithTransformer {
90        base: Box<ApiRepo>,
91        transformer: Box<ApiRepo>,
92    },
93    Dduf(ZipArchive<&'a mut Cursor<Mmap>>),
94}
95
96impl<'a> FileLoader<'a> {
97    pub fn from_model_source(
98        source: &'a mut ModelSource,
99        silent: bool,
100        token: TokenSource,
101        revision: Option<String>,
102    ) -> anyhow::Result<Self> {
103        match source {
104            ModelSource::ModelId(model_id) => {
105                let api_builder = ApiBuilder::new()
106                    .with_progress(!silent)
107                    .with_token(get_token(&token)?)
108                    .build()?;
109                let revision = revision.unwrap_or("main".to_string());
110                let api = api_builder.repo(Repo::with_revision(
111                    model_id.clone(),
112                    RepoType::Model,
113                    revision.clone(),
114                ));
115
116                Ok(Self::Api(Box::new(api)))
117            }
118            ModelSource::Dduf { file, name: _ } => Ok(Self::Dduf(ZipArchive::new(file)?)),
119            ModelSource::ModelIdWithTransformer {
120                model_id,
121                transformer_model_id,
122            } => {
123                let api_builder = ApiBuilder::new()
124                    .with_progress(!silent)
125                    .with_token(get_token(&token)?)
126                    .build()?;
127                let revision = revision.unwrap_or("main".to_string());
128                let api = api_builder.repo(Repo::with_revision(
129                    model_id.clone(),
130                    RepoType::Model,
131                    revision.clone(),
132                ));
133                let transformer_api = api_builder.repo(Repo::with_revision(
134                    transformer_model_id.clone(),
135                    RepoType::Model,
136                    revision.clone(),
137                ));
138
139                Ok(Self::ApiWithTransformer {
140                    base: Box::new(api),
141                    transformer: Box::new(transformer_api),
142                })
143            }
144        }
145    }
146
147    pub fn list_files(&mut self) -> anyhow::Result<Vec<String>> {
148        match self {
149            Self::Api(api)
150            | Self::ApiWithTransformer {
151                base: api,
152                transformer: _,
153            } => api
154                .info()
155                .map(|repo| {
156                    repo.siblings
157                        .iter()
158                        .map(|x| x.rfilename.clone())
159                        .collect::<Vec<String>>()
160                })
161                .map_err(|e| anyhow::Error::msg(e.to_string())),
162            Self::Dduf(dduf) => (0..dduf.len())
163                .map(|i| {
164                    dduf.by_index(i)
165                        .map(|x| x.name().to_string())
166                        .map_err(|e| anyhow::Error::msg(e.to_string()))
167                })
168                .collect::<anyhow::Result<Vec<_>>>(),
169        }
170    }
171
172    pub fn list_transformer_files(&self) -> anyhow::Result<Option<Vec<String>>> {
173        match self {
174            Self::Api(_) | Self::Dduf(_) => Ok(None),
175
176            Self::ApiWithTransformer {
177                base: _,
178                transformer: api,
179            } => api
180                .info()
181                .map(|repo| {
182                    repo.siblings
183                        .iter()
184                        .map(|x| x.rfilename.clone())
185                        .collect::<Vec<String>>()
186                })
187                .map(Some)
188                .map_err(|e| anyhow::Error::msg(e.to_string())),
189        }
190    }
191
192    /// Read a file.
193    ///
194    /// - If loading from a DDUF file, this returns indices to the file data instead of owned data.
195    /// - For non-DDUF model sources, a path is returned
196    /// - File data should be read with `read_to_string`
197    pub fn read_file(&mut self, name: &str, from_transformer: bool) -> anyhow::Result<FileData> {
198        if from_transformer && !matches!(self, Self::ApiWithTransformer { .. }) {
199            anyhow::bail!("This model source has no transformer files.")
200        }
201
202        match (self, from_transformer) {
203            (Self::Api(api), false)
204            | (
205                Self::ApiWithTransformer {
206                    base: api,
207                    transformer: _,
208                },
209                false,
210            ) => Ok(FileData::Path(
211                api.get(name)
212                    .map_err(|e| anyhow::Error::msg(e.to_string()))?,
213            )),
214            (
215                Self::ApiWithTransformer {
216                    base: api,
217                    transformer: _,
218                },
219                true,
220            ) => Ok(FileData::Path(
221                api.get(name)
222                    .map_err(|e| anyhow::Error::msg(e.to_string()))?,
223            )),
224            (Self::Api(_), true) => anyhow::bail!("This model source has no transformer files."),
225            (Self::Dduf(dduf), _) => {
226                let file = dduf.by_name(name)?;
227                let start = file.data_start() as usize;
228                let len = file.size() as usize;
229                let end = start + len;
230                let name = file.name().into();
231                Ok(FileData::Dduf { name, start, end })
232            }
233        }
234    }
235
236    /// Read a file, always returning owned data.
237    ///
238    /// - If loading from a DDUF file, this copies the file data.
239    /// - For non-DDUF model sources, this is equivalent to `read_file`
240    /// - File data can always be read with `read_to_string_owned`, unlike from `read_file`
241    pub fn read_file_copied(
242        &mut self,
243        name: &str,
244        from_transformer: bool,
245    ) -> anyhow::Result<FileData> {
246        if matches!(self, Self::Api(_) | Self::ApiWithTransformer { .. }) {
247            return self.read_file(name, from_transformer);
248        }
249
250        let Self::Dduf(dduf) = self else {
251            anyhow::bail!("expected dduf model source!");
252        };
253        let mut file = dduf.by_name(name)?;
254        let mut data = Vec::new();
255        std::io::copy(&mut file, &mut data)?;
256        let name = PathBuf::from(file.name().to_string());
257        Ok(FileData::DdufOwned { name, data })
258    }
259}
260
/// Result of reading a file through a `FileLoader`.
pub enum FileData {
    /// The file is available on disk at this path.
    Path(PathBuf),
    /// The file lives inside a DDUF archive and is referred to by a byte range
    /// into the memory-mapped archive rather than by owned data.
    Dduf {
        /// Name of the entry inside the archive.
        name: PathBuf,
        /// Offset of the first byte of the entry's data.
        start: usize,
        /// Offset one past the last byte of the entry's data.
        end: usize,
    },
    /// The file was copied out of a DDUF archive into an owned buffer.
    DdufOwned {
        /// Name of the entry inside the archive.
        name: PathBuf,
        /// The copied contents of the entry.
        data: Vec<u8>,
    },
}
273
274impl Debug for FileData {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        match self {
277            Self::Path(p) => write!(f, "path: {}", p.display()),
278            Self::Dduf {
279                name,
280                start: _,
281                end: _,
282            } => write!(f, "dduf: {}", name.display()),
283            Self::DdufOwned { name, data: _ } => write!(f, "dduf owned: {}", name.display()),
284        }
285    }
286}
287
288impl FileData {
289    pub fn read_to_string(&self, src: &ModelSource) -> anyhow::Result<String> {
290        match self {
291            Self::Path(p) => Ok(fs::read_to_string(p)?),
292            Self::Dduf {
293                name: _,
294                start,
295                end,
296            } => {
297                let ModelSource::Dduf { file, name: _ } = src else {
298                    anyhow::bail!("expected dduf model source!");
299                };
300                Ok(String::from_utf8(file.get_ref()[*start..*end].to_vec())?)
301            }
302            Self::DdufOwned { name: _, data } => Ok(String::from_utf8(data.to_vec())?),
303        }
304    }
305
306    pub fn read_to_string_owned(&self) -> anyhow::Result<String> {
307        match self {
308            Self::Path(p) => Ok(fs::read_to_string(p)?),
309            Self::Dduf { .. } => {
310                anyhow::bail!("dduf file data is not owned !");
311            }
312            Self::DdufOwned { name: _, data } => Ok(String::from_utf8(data.to_vec())?),
313        }
314    }
315
316    pub fn extension(&self) -> Option<&OsStr> {
317        match self {
318            Self::Path(p) => p.extension(),
319            Self::Dduf {
320                name,
321                start: _,
322                end: _,
323            } => name.extension(),
324            Self::DdufOwned { name, data: _ } => name.extension(),
325        }
326    }
327}