Skip to main content

hanzo_engine/pipeline/loaders/
diffusion_loaders.rs

1use std::{
2    fmt::Debug,
3    path::{Path, PathBuf},
4    str::FromStr,
5};
6
7use anyhow::{Context, Result};
8use hanzo_ml::{Device, Tensor};
9
10use hanzo_quant::ShardedVarBuilder;
11use hf_hub::api::sync::ApiRepo;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::pyclass;
14
15use regex::Regex;
16use serde::Deserialize;
17
18use super::{ModelPaths, NormalLoadingMetadata};
19use crate::{
20    api_dir_list, api_get_file,
21    diffusion_models::{
22        flux::{
23            self,
24            stepper::{FluxStepper, FluxStepperConfig},
25        },
26        DiffusionGenerationParams,
27    },
28    paged_attention::AttentionImplementation,
29    pipeline::{paths::AdapterPaths, EmbeddingModulePaths},
30};
31
32pub trait DiffusionModel {
33    /// This returns a tensor of shape (bs, c, h, w), with values in [0, 255].
34    fn forward(
35        &mut self,
36        prompts: Vec<String>,
37        params: DiffusionGenerationParams,
38    ) -> hanzo_ml::Result<Tensor>;
39    fn device(&self) -> &Device;
40    fn max_seq_len(&self) -> usize;
41}
42
43pub trait DiffusionModelLoader: Send + Sync {
44    /// If the model is being loaded with `load_model_from_hf` (so manual paths not provided), this will be called.
45    fn get_model_paths(
46        &self,
47        api: &ApiRepo,
48        model_id: &Path,
49        revision: &str,
50    ) -> Result<Vec<PathBuf>>;
51    /// If the model is being loaded with `load_model_from_hf` (so manual paths not provided), this will be called.
52    fn get_config_filenames(
53        &self,
54        api: &ApiRepo,
55        model_id: &Path,
56        revision: &str,
57    ) -> Result<Vec<PathBuf>>;
58    fn force_cpu_vb(&self) -> Vec<bool>;
59    // `configs` and `vbs` should be corresponding. It is up to the implementer to maintain this invaraint.
60    fn load(
61        &self,
62        configs: Vec<String>,
63        vbs: Vec<ShardedVarBuilder>,
64        normal_loading_metadata: NormalLoadingMetadata,
65        attention_mechanism: AttentionImplementation,
66        silent: bool,
67    ) -> Result<Box<dyn DiffusionModel + Send + Sync>>;
68}
69
70#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
71#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
72/// The architecture to load the diffusion model as.
73pub enum DiffusionLoaderType {
74    #[serde(rename = "flux")]
75    Flux,
76    #[serde(rename = "flux-offloaded")]
77    FluxOffloaded,
78}
79
80impl FromStr for DiffusionLoaderType {
81    type Err = String;
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        match s {
84            "flux" => Ok(Self::Flux),
85            "flux-offloaded" => Ok(Self::FluxOffloaded),
86            a => Err(format!(
87                "Unknown architecture `{a}`. Possible architectures: `flux`."
88            )),
89        }
90    }
91}
92
93impl DiffusionLoaderType {
94    /// Auto-detect diffusion loader type from a repo file listing.
95    /// Extend this when adding new diffusion pipelines.
96    pub fn auto_detect_from_files(files: &[String]) -> Option<Self> {
97        if Self::matches_flux(files) {
98            return Some(Self::Flux);
99        }
100        None
101    }
102
103    fn matches_flux(files: &[String]) -> bool {
104        let flux_regex = Regex::new(r"^flux\\d+-(schnell|dev)\\.safetensors$");
105        let Ok(flux_regex) = flux_regex else {
106            return false;
107        };
108        let has_transformer = files.iter().any(|f| f == "transformer/config.json");
109        let has_vae = files.iter().any(|f| f == "vae/config.json");
110        let has_ae = files.iter().any(|f| f == "ae.safetensors");
111        let has_flux = files.iter().any(|f| {
112            let name = f.rsplit('/').next().unwrap_or(f);
113            flux_regex.is_match(name)
114        });
115
116        has_transformer && has_vae && has_ae && has_flux
117    }
118}
119
/// File paths that make up a diffusion model.
#[derive(Debug, Clone)]
pub struct DiffusionModelPathsInner {
    /// Paths of the configuration files.
    pub config_filenames: Vec<PathBuf>,
    /// Paths of the model files.
    // NOTE(review): presumably the weight files, in the same order as
    // `config_filenames` -- confirm against the code that builds this struct.
    pub filenames: Vec<PathBuf>,
}
125
126#[derive(Clone, Debug)]
127pub struct DiffusionModelPaths(pub DiffusionModelPathsInner);
128
129impl ModelPaths for DiffusionModelPaths {
130    fn get_config_filename(&self) -> &PathBuf {
131        unreachable!("Use `std::any::Any`.")
132    }
133    fn get_tokenizer_filename(&self) -> &PathBuf {
134        unreachable!("Use `std::any::Any`.")
135    }
136    fn get_weight_filenames(&self) -> &[PathBuf] {
137        unreachable!("Use `std::any::Any`.")
138    }
139    fn get_template_filename(&self) -> &Option<PathBuf> {
140        unreachable!("Use `std::any::Any`.")
141    }
142    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
143        unreachable!("Use `std::any::Any`.")
144    }
145    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
146        unreachable!("Use `std::any::Any`.")
147    }
148    fn get_processor_config(&self) -> &Option<PathBuf> {
149        unreachable!("Use `std::any::Any`.")
150    }
151    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
152        unreachable!("Use `std::any::Any`.")
153    }
154    fn get_adapter_paths(&self) -> &AdapterPaths {
155        unreachable!("Use `std::any::Any`.")
156    }
157    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
158        unreachable!("Use `std::any::Any`.")
159    }
160}
161
162// ======================== Flux loader
163
/// [`DiffusionLoader`] for a Flux Diffusion model.
///
/// [`DiffusionLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.DiffusionLoader.html
pub struct FluxLoader {
    /// Offload switch. It is reported as the first entry of `force_cpu_vb`
    /// and forwarded unchanged to `FluxStepper::new` when loading.
    pub(crate) offload: bool,
}
170
171impl DiffusionModelLoader for FluxLoader {
172    fn get_model_paths(
173        &self,
174        api: &ApiRepo,
175        model_id: &Path,
176        revision: &str,
177    ) -> Result<Vec<PathBuf>> {
178        let regex = Regex::new(r"^flux\d+-(schnell|dev)\.safetensors$")?;
179        let flux_name = api_dir_list!(api, model_id, true, revision)
180            .filter(|x| regex.is_match(x))
181            .nth(0)
182            .with_context(|| "Expected at least 1 .safetensors file matching the FLUX regex, please raise an issue.")?;
183        let flux_file = api_get_file!(api, &flux_name, model_id, revision);
184        let ae_file = api_get_file!(api, "ae.safetensors", model_id, revision);
185
186        // NOTE(hanzoai): disgusting way of doing this but the 0th path is the flux, 1 is ae
187        Ok(vec![flux_file, ae_file])
188    }
189    fn get_config_filenames(
190        &self,
191        api: &ApiRepo,
192        model_id: &Path,
193        revision: &str,
194    ) -> Result<Vec<PathBuf>> {
195        let flux_file = api_get_file!(api, "transformer/config.json", model_id, revision);
196        let ae_file = api_get_file!(api, "vae/config.json", model_id, revision);
197
198        // NOTE(hanzoai): disgusting way of doing this but the 0th path is the flux, 1 is ae
199        Ok(vec![flux_file, ae_file])
200    }
201    fn force_cpu_vb(&self) -> Vec<bool> {
202        vec![self.offload, false]
203    }
204    fn load(
205        &self,
206        mut configs: Vec<String>,
207        mut vbs: Vec<ShardedVarBuilder>,
208        normal_loading_metadata: NormalLoadingMetadata,
209        _attention_mechanism: AttentionImplementation,
210        silent: bool,
211    ) -> Result<Box<dyn DiffusionModel + Send + Sync>> {
212        let (vae_cfg, vae_vb) = (configs.remove(1), vbs.remove(1));
213        let (flux_cfg, flux_vb) = (configs.remove(0), vbs.remove(0));
214
215        let vae_cfg: flux::autoencoder::Config = serde_json::from_str(&vae_cfg)?;
216        let flux_cfg: flux::model::Config = serde_json::from_str(&flux_cfg)?;
217
218        let flux_dtype = flux_vb.dtype();
219        if flux_dtype != vae_vb.dtype() {
220            anyhow::bail!(
221                "Expected VAE and FLUX model VBs to be the same dtype, got {:?} and {flux_dtype:?}",
222                vae_vb.dtype()
223            );
224        }
225
226        Ok(Box::new(FluxStepper::new(
227            FluxStepperConfig::default_for_guidance(flux_cfg.guidance_embeds),
228            (flux_vb, &flux_cfg),
229            (vae_vb, &vae_cfg),
230            flux_dtype,
231            &normal_loading_metadata.real_device,
232            silent,
233            self.offload,
234        )?))
235    }
236}