hanzo_engine/pipeline/loaders/
diffusion_loaders.rs1use std::{
2 fmt::Debug,
3 path::{Path, PathBuf},
4 str::FromStr,
5};
6
7use anyhow::{Context, Result};
8use hanzo_ml::{Device, Tensor};
9
10use hanzo_quant::ShardedVarBuilder;
11use hf_hub::api::sync::ApiRepo;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::pyclass;
14
15use regex::Regex;
16use serde::Deserialize;
17
18use super::{ModelPaths, NormalLoadingMetadata};
19use crate::{
20 api_dir_list, api_get_file,
21 diffusion_models::{
22 flux::{
23 self,
24 stepper::{FluxStepper, FluxStepperConfig},
25 },
26 DiffusionGenerationParams,
27 },
28 paged_attention::AttentionImplementation,
29 pipeline::{paths::AdapterPaths, EmbeddingModulePaths},
30};
31
32pub trait DiffusionModel {
33 fn forward(
35 &mut self,
36 prompts: Vec<String>,
37 params: DiffusionGenerationParams,
38 ) -> hanzo_ml::Result<Tensor>;
39 fn device(&self) -> &Device;
40 fn max_seq_len(&self) -> usize;
41}
42
43pub trait DiffusionModelLoader: Send + Sync {
44 fn get_model_paths(
46 &self,
47 api: &ApiRepo,
48 model_id: &Path,
49 revision: &str,
50 ) -> Result<Vec<PathBuf>>;
51 fn get_config_filenames(
53 &self,
54 api: &ApiRepo,
55 model_id: &Path,
56 revision: &str,
57 ) -> Result<Vec<PathBuf>>;
58 fn force_cpu_vb(&self) -> Vec<bool>;
59 fn load(
61 &self,
62 configs: Vec<String>,
63 vbs: Vec<ShardedVarBuilder>,
64 normal_loading_metadata: NormalLoadingMetadata,
65 attention_mechanism: AttentionImplementation,
66 silent: bool,
67 ) -> Result<Box<dyn DiffusionModel + Send + Sync>>;
68}
69
70#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
71#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
72pub enum DiffusionLoaderType {
74 #[serde(rename = "flux")]
75 Flux,
76 #[serde(rename = "flux-offloaded")]
77 FluxOffloaded,
78}
79
80impl FromStr for DiffusionLoaderType {
81 type Err = String;
82 fn from_str(s: &str) -> Result<Self, Self::Err> {
83 match s {
84 "flux" => Ok(Self::Flux),
85 "flux-offloaded" => Ok(Self::FluxOffloaded),
86 a => Err(format!(
87 "Unknown architecture `{a}`. Possible architectures: `flux`."
88 )),
89 }
90 }
91}
92
93impl DiffusionLoaderType {
94 pub fn auto_detect_from_files(files: &[String]) -> Option<Self> {
97 if Self::matches_flux(files) {
98 return Some(Self::Flux);
99 }
100 None
101 }
102
103 fn matches_flux(files: &[String]) -> bool {
104 let flux_regex = Regex::new(r"^flux\\d+-(schnell|dev)\\.safetensors$");
105 let Ok(flux_regex) = flux_regex else {
106 return false;
107 };
108 let has_transformer = files.iter().any(|f| f == "transformer/config.json");
109 let has_vae = files.iter().any(|f| f == "vae/config.json");
110 let has_ae = files.iter().any(|f| f == "ae.safetensors");
111 let has_flux = files.iter().any(|f| {
112 let name = f.rsplit('/').next().unwrap_or(f);
113 flux_regex.is_match(name)
114 });
115
116 has_transformer && has_vae && has_ae && has_flux
117 }
118}
119
/// The file paths that make up a diffusion model.
#[derive(Clone, Debug)]
pub struct DiffusionModelPathsInner {
    /// Paths of the model's configuration files.
    pub config_filenames: Vec<PathBuf>,
    /// Paths of the model's remaining files.
    ///
    /// NOTE(review): presumably the weight files — confirm where this struct
    /// is populated.
    pub filenames: Vec<PathBuf>,
}
125
126#[derive(Clone, Debug)]
127pub struct DiffusionModelPaths(pub DiffusionModelPathsInner);
128
129impl ModelPaths for DiffusionModelPaths {
130 fn get_config_filename(&self) -> &PathBuf {
131 unreachable!("Use `std::any::Any`.")
132 }
133 fn get_tokenizer_filename(&self) -> &PathBuf {
134 unreachable!("Use `std::any::Any`.")
135 }
136 fn get_weight_filenames(&self) -> &[PathBuf] {
137 unreachable!("Use `std::any::Any`.")
138 }
139 fn get_template_filename(&self) -> &Option<PathBuf> {
140 unreachable!("Use `std::any::Any`.")
141 }
142 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
143 unreachable!("Use `std::any::Any`.")
144 }
145 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
146 unreachable!("Use `std::any::Any`.")
147 }
148 fn get_processor_config(&self) -> &Option<PathBuf> {
149 unreachable!("Use `std::any::Any`.")
150 }
151 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
152 unreachable!("Use `std::any::Any`.")
153 }
154 fn get_adapter_paths(&self) -> &AdapterPaths {
155 unreachable!("Use `std::any::Any`.")
156 }
157 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
158 unreachable!("Use `std::any::Any`.")
159 }
160}
161
/// [`DiffusionModelLoader`] for the FLUX architecture.
pub struct FluxLoader {
    /// Offloading switch.
    ///
    /// It is reported as the first entry of `force_cpu_vb` and passed through
    /// to `FluxStepper::new` when the model is built.
    pub(crate) offload: bool,
}
170
171impl DiffusionModelLoader for FluxLoader {
172 fn get_model_paths(
173 &self,
174 api: &ApiRepo,
175 model_id: &Path,
176 revision: &str,
177 ) -> Result<Vec<PathBuf>> {
178 let regex = Regex::new(r"^flux\d+-(schnell|dev)\.safetensors$")?;
179 let flux_name = api_dir_list!(api, model_id, true, revision)
180 .filter(|x| regex.is_match(x))
181 .nth(0)
182 .with_context(|| "Expected at least 1 .safetensors file matching the FLUX regex, please raise an issue.")?;
183 let flux_file = api_get_file!(api, &flux_name, model_id, revision);
184 let ae_file = api_get_file!(api, "ae.safetensors", model_id, revision);
185
186 Ok(vec![flux_file, ae_file])
188 }
189 fn get_config_filenames(
190 &self,
191 api: &ApiRepo,
192 model_id: &Path,
193 revision: &str,
194 ) -> Result<Vec<PathBuf>> {
195 let flux_file = api_get_file!(api, "transformer/config.json", model_id, revision);
196 let ae_file = api_get_file!(api, "vae/config.json", model_id, revision);
197
198 Ok(vec![flux_file, ae_file])
200 }
201 fn force_cpu_vb(&self) -> Vec<bool> {
202 vec![self.offload, false]
203 }
204 fn load(
205 &self,
206 mut configs: Vec<String>,
207 mut vbs: Vec<ShardedVarBuilder>,
208 normal_loading_metadata: NormalLoadingMetadata,
209 _attention_mechanism: AttentionImplementation,
210 silent: bool,
211 ) -> Result<Box<dyn DiffusionModel + Send + Sync>> {
212 let (vae_cfg, vae_vb) = (configs.remove(1), vbs.remove(1));
213 let (flux_cfg, flux_vb) = (configs.remove(0), vbs.remove(0));
214
215 let vae_cfg: flux::autoencoder::Config = serde_json::from_str(&vae_cfg)?;
216 let flux_cfg: flux::model::Config = serde_json::from_str(&flux_cfg)?;
217
218 let flux_dtype = flux_vb.dtype();
219 if flux_dtype != vae_vb.dtype() {
220 anyhow::bail!(
221 "Expected VAE and FLUX model VBs to be the same dtype, got {:?} and {flux_dtype:?}",
222 vae_vb.dtype()
223 );
224 }
225
226 Ok(Box::new(FluxStepper::new(
227 FluxStepperConfig::default_for_guidance(flux_cfg.guidance_embeds),
228 (flux_vb, &flux_cfg),
229 (vae_vb, &vae_cfg),
230 flux_dtype,
231 &normal_loading_metadata.real_device,
232 silent,
233 self.offload,
234 )?))
235 }
236}