1use std::collections::HashMap;
11use std::path::Path;
12
13use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
14use ferrotorch_nn::module::{Module, StateDict};
15use ferrotorch_serialize::load_safetensors;
16
17use crate::clip_text_encoder::{ClipTextConfig, ClipTextEncoder};
18use crate::config::VaeDecoderConfig;
19use crate::unet::UNet2DConditionModel;
20use crate::unet_config::UNet2DConditionConfig;
21use crate::vae::VaeDecoder;
22
/// Summary of the state-dict keys a loader chose not to forward to the model.
///
/// Derives `PartialEq`/`Eq` so whole reports can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DropReport {
    /// Original (un-stripped) keys that were skipped. The loaders in this
    /// file sort the list before returning it.
    pub dropped: Vec<String>,
}
36
37impl<T: Float> VaeDecoder<T> {
38 pub fn load_hf_state_dict(
58 &mut self,
59 hf_state: &StateDict<T>,
60 strict: bool,
61 ) -> FerrotorchResult<DropReport> {
62 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
63 let mut dropped: Vec<String> = Vec::new();
64
65 for (k, v) in hf_state {
66 let after_vae = k.strip_prefix("vae.").map_or_else(|| k.clone(), str::to_owned);
69 if after_vae.starts_with("post_quant_conv.") || after_vae.starts_with("decoder.") {
70 remapped.insert(after_vae, v.clone());
71 continue;
72 }
73 if strict {
74 return Err(FerrotorchError::InvalidArgument {
75 message: format!(
76 "VaeDecoder::load_hf_state_dict: key {k:?} is not under \
77 `post_quant_conv.*` / `decoder.*` (with optional `vae.` prefix) \
78 and strict mode is on. Pass strict=false to drop encoder / \
79 quant_conv keys."
80 ),
81 });
82 }
83 dropped.push(k.clone());
84 }
85 dropped.sort();
86 self.load_state_dict(&remapped, strict)?;
87 Ok(DropReport { dropped })
88 }
89}
90
91impl<T: Float> UNet2DConditionModel<T> {
96 pub fn load_hf_state_dict(
111 &mut self,
112 hf_state: &StateDict<T>,
113 strict: bool,
114 ) -> FerrotorchResult<DropReport> {
115 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
116 let mut dropped: Vec<String> = Vec::new();
117 for (k, v) in hf_state {
118 let after_unet = k.strip_prefix("unet.").map_or_else(|| k.clone(), str::to_owned);
119 let is_unet_key = after_unet.starts_with("time_embedding.")
120 || after_unet.starts_with("conv_in.")
121 || after_unet.starts_with("down_blocks.")
122 || after_unet.starts_with("mid_block.")
123 || after_unet.starts_with("up_blocks.")
124 || after_unet.starts_with("conv_norm_out.")
125 || after_unet.starts_with("conv_out.");
126 if is_unet_key {
127 remapped.insert(after_unet, v.clone());
128 continue;
129 }
130 if strict {
131 return Err(FerrotorchError::InvalidArgument {
132 message: format!(
133 "UNet2DConditionModel::load_hf_state_dict: key {k:?} is not under \
134 a UNet prefix (with optional `unet.`) and strict mode is on."
135 ),
136 });
137 }
138 dropped.push(k.clone());
139 }
140 dropped.sort();
141 self.load_state_dict(&remapped, strict)?;
142 Ok(DropReport { dropped })
143 }
144}
145
146pub fn load_unet<T: Float>(
159 weights_path: &Path,
160 cfg: UNet2DConditionConfig,
161 strict: bool,
162) -> FerrotorchResult<(UNet2DConditionModel<T>, DropReport)> {
163 let state =
164 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
165 message: format!(
166 "load_unet: failed to decode safetensors {}: {e}",
167 weights_path.display()
168 ),
169 })?;
170 let mut unet = UNet2DConditionModel::<T>::new(cfg)?;
171 let report = unet.load_hf_state_dict(&state, strict)?;
172 Ok((unet, report))
173}
174
175fn load_safetensors_clip_filtered<T: Float>(
187 weights_path: &Path,
188) -> FerrotorchResult<(StateDict<T>, bool)> {
189 use safetensors::SafeTensors;
190
191 let bytes =
192 std::fs::read(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
193 message: format!(
194 "load_safetensors_clip_filtered: failed to read {}: {e}",
195 weights_path.display()
196 ),
197 })?;
198 let st = SafeTensors::deserialize(&bytes).map_err(|e| FerrotorchError::InvalidArgument {
199 message: format!(
200 "load_safetensors_clip_filtered: failed to parse {}: {e}",
201 weights_path.display()
202 ),
203 })?;
204 let mut keep: Vec<String> = Vec::new();
205 let mut had_position_ids = false;
206 for k in st.names() {
207 let s: &str = k.as_str();
208 if s == "embeddings.position_ids" || s == "text_model.embeddings.position_ids" {
211 had_position_ids = true;
212 continue;
213 }
214 keep.push(String::from(s));
215 }
216
217 let mut subset: Vec<(String, safetensors::tensor::TensorView<'_>)> =
221 Vec::with_capacity(keep.len());
222 for k in &keep {
223 let v = st.tensor(k).map_err(|e| FerrotorchError::InvalidArgument {
224 message: format!(
225 "load_safetensors_clip_filtered: missing tensor {k:?} after filter: {e}"
226 ),
227 })?;
228 subset.push((k.clone(), v));
229 }
230 let serialized = safetensors::serialize(subset, &None).map_err(|e| {
231 FerrotorchError::InvalidArgument {
232 message: format!("load_safetensors_clip_filtered: re-serialize failed: {e}"),
233 }
234 })?;
235 let tmp = tempfile::NamedTempFile::new().map_err(|e| FerrotorchError::InvalidArgument {
236 message: format!("load_safetensors_clip_filtered: tempfile: {e}"),
237 })?;
238 std::fs::write(tmp.path(), &serialized).map_err(|e| FerrotorchError::InvalidArgument {
239 message: format!("load_safetensors_clip_filtered: tempfile write: {e}"),
240 })?;
241 let state = load_safetensors::<T>(tmp.path())?;
242 Ok((state, had_position_ids))
243}
244
245pub fn load_clip_text_encoder<T: Float>(
268 weights_path: &Path,
269 cfg: ClipTextConfig,
270 strict: bool,
271) -> FerrotorchResult<(ClipTextEncoder<T>, DropReport)> {
272 let (mut state, had_position_ids) =
273 load_safetensors_clip_filtered::<T>(weights_path).map_err(|e| {
274 FerrotorchError::InvalidArgument {
275 message: format!(
276 "load_clip_text_encoder: failed to decode safetensors {}: {e}",
277 weights_path.display()
278 ),
279 }
280 })?;
281
282 if had_position_ids {
288 let key = if state
289 .keys()
290 .any(|k| k.starts_with("text_model."))
291 {
292 "text_model.embeddings.position_ids".to_string()
293 } else {
294 "embeddings.position_ids".to_string()
295 };
296 state.insert(key, ferrotorch_core::zeros::<T>(&[1])?);
297 }
298
299 let mut enc = ClipTextEncoder::<T>::new(cfg)?;
300 let report = enc.load_hf_state_dict(&state, strict)?;
301 Ok((enc, report))
302}
303
304pub fn load_vae_decoder<T: Float>(
319 weights_path: &Path,
320 cfg: VaeDecoderConfig,
321 strict: bool,
322) -> FerrotorchResult<(VaeDecoder<T>, DropReport)> {
323 let state =
324 load_safetensors::<T>(weights_path).map_err(|e| FerrotorchError::InvalidArgument {
325 message: format!(
326 "load_vae_decoder: failed to decode safetensors {}: {e}",
327 weights_path.display()
328 ),
329 })?;
330 let mut decoder = VaeDecoder::<T>::new(cfg)?;
331 let report = decoder.load_hf_state_dict(&state, strict)?;
332 Ok((decoder, report))
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};
    use ferrotorch_serialize::save_safetensors;
    use std::path::PathBuf;

    /// Small decoder configuration shared by the tests in this module.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Saves `v`'s state dict to `model.safetensors` inside a fresh temp dir.
    ///
    /// The `TempDir` is returned so the caller keeps the file alive for as
    /// long as it holds the guard.
    fn tmp_safetensors_from(v: &VaeDecoder<f32>) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.safetensors");
        let sd = v.state_dict();
        save_safetensors(&sd, &path).unwrap();
        (dir, path)
    }

    /// Runs both decoders on the same `[1, 4, 1, 1]` input and asserts that
    /// their outputs agree element-wise.
    fn assert_same_forward(src: &VaeDecoder<f32>, dst: &VaeDecoder<f32>) {
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4]),
            vec![1, 4, 1, 1],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        let lhs = a.data().unwrap();
        let rhs = b.data().unwrap();
        // `zip` stops at the shorter side, so a size mismatch would otherwise
        // go unnoticed.
        assert_eq!(
            lhs.iter().count(),
            rhs.iter().count(),
            "decoder outputs differ in element count"
        );
        for (p, q) in lhs.iter().zip(rhs.iter()) {
            assert!((p - q).abs() < 1e-5);
        }
    }

    #[test]
    fn round_trip_safetensors_into_decoder() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let (_d, p) = tmp_safetensors_from(&src);
        let (dst, report) = load_vae_decoder::<f32>(&p, cfg, false).unwrap();
        assert!(
            report.dropped.is_empty(),
            "round-trip should have empty drop list, got {:?}",
            report.dropped
        );
        assert_same_forward(&src, &dst);
    }

    #[test]
    fn load_hf_drops_encoder_keys_nonstrict() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = v.state_dict();
        // Two keys outside `post_quant_conv.*` / `decoder.*`.
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        hf_sd.insert(
            "quant_conv.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        let rep = v.load_hf_state_dict(&hf_sd, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec![
                "encoder.conv_in.weight".to_string(),
                "quant_conv.weight".to_string(),
            ]
        );
    }

    #[test]
    fn load_hf_strict_rejects_encoder_keys() {
        let cfg = tiny_cfg();
        let mut v = VaeDecoder::<f32>::new(cfg).unwrap();
        let mut hf_sd: StateDict<f32> = HashMap::new();
        hf_sd.insert(
            "encoder.conv_in.weight".into(),
            ferrotorch_core::zeros::<f32>(&[4, 4]).unwrap(),
        );
        assert!(v.load_hf_state_dict(&hf_sd, true).is_err());
    }

    #[test]
    fn load_hf_strips_vae_prefix() {
        let cfg = tiny_cfg();
        let src = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let bare = src.state_dict();
        // Re-key every tensor under a `vae.` prefix.
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in bare {
            prefixed.insert(format!("vae.{k}"), v);
        }
        let mut dst = VaeDecoder::<f32>::new(cfg).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert!(rep.dropped.is_empty(), "got {:?}", rep.dropped);
        assert_same_forward(&src, &dst);
    }
}