use ferrotorch_core::{Tensor, TensorStorage};
use ferrotorch_diffusion::{VaeDecoderConfig, VaeEncoder, VaeEncoderConfig};
use ferrotorch_nn::module::Module;
fn sd_v1_5_config() -> VaeEncoderConfig {
VaeDecoderConfig::sd_v1_5()
}
/// Builds a deterministic `[b, 3, h, w]` `f32` test image.
///
/// Every column of row `y` in channel `c` holds
/// `clamp((y / h) * 2 - 1 + 0.05 * c, -1, 1)`: a top-to-bottom ramp that
/// starts at -1.0, with a small per-channel offset. All `b` batch items are
/// identical copies.
///
/// # Panics
/// Panics if `Tensor::from_storage` rejects the buffer/shape pair.
fn striped_image(b: usize, h: usize, w: usize) -> Tensor<f32> {
    // One CHW image. The value depends only on (channel, row), so each row is
    // a run of `w` copies of the same number.
    let single: Vec<f32> = (0..3usize)
        .flat_map(|channel| {
            (0..h).flat_map(move |row| {
                let ramp = (row as f32 / h as f32) * 2.0 - 1.0;
                let value = (ramp + (channel as f32) * 0.05).clamp(-1.0, 1.0);
                std::iter::repeat(value).take(w)
            })
        })
        .collect();

    // Lay `b` identical copies back-to-back to form the batch dimension.
    let data = single.repeat(b);

    Tensor::from_storage(TensorStorage::cpu(data), vec![b, 3, h, w], false)
        .expect("striped_image: tensor construction must succeed")
}
/// Runs the SD-1.5 VAE encoder on a full-resolution 512x512 input and checks:
/// - the raw `forward` output has shape `[1, 2 * latent_channels, 64, 64]` and
///   contains only finite values;
/// - `encode` yields `mean`/`logvar` of shape `[1, latent_channels, 64, 64]`;
/// - every `logvar` value is finite and inside the [-30, 20] clamp range.
///
/// Expected shapes are derived from the config rather than hard-coded, so
/// the assertions and their failure messages always agree with each other.
#[test]
#[ignore = "Allocates ~600 MB / 10 s CPU — enable with --ignored"]
fn vae_encoder_sd_scale_shape_sanity() {
    // Input resolution and spatial downsampling factor (512 / 8 = 64).
    // NOTE(review): the factor 8 is assumed for the SD-1.5 config; it is not
    // read from `VaeEncoderConfig` here.
    const IMAGE_SIZE: usize = 512;
    const DOWNSAMPLE: usize = 8;
    const LATENT_SIZE: usize = IMAGE_SIZE / DOWNSAMPLE;

    let cfg = sd_v1_5_config();
    // Copy out what the assertions need so `cfg` can be moved without a clone.
    let latent_channels = cfg.latent_channels;
    let enc = VaeEncoder::<f32>::new(cfg)
        .expect("VaeEncoder::new must succeed for the canonical SD-1.5 config");
    let x = striped_image(1, IMAGE_SIZE, IMAGE_SIZE);

    // The raw output carries two latent-sized halves (presumably mean and
    // logvar), hence 2 * latent_channels channels.
    let params_shape = [1, 2 * latent_channels, LATENT_SIZE, LATENT_SIZE];
    let params = enc
        .forward(&x)
        .expect("VaeEncoder forward must succeed at SD-1.5 scale");
    assert_eq!(
        params.shape(),
        &params_shape,
        "SD-1.5 VAE encoder must produce {params_shape:?}, got {:?}",
        params.shape()
    );

    let count_nonfinite = params
        .data()
        .expect("params must have data")
        .iter()
        .filter(|v| !v.is_finite())
        .count();
    assert_eq!(
        count_nonfinite, 0,
        "SD-1.5 VAE encoder produced {count_nonfinite} non-finite values"
    );

    let dist = enc
        .encode(&x)
        .expect("VaeEncoder::encode must succeed at SD-1.5 scale");
    let latent_shape = [1, latent_channels, LATENT_SIZE, LATENT_SIZE];
    assert_eq!(dist.mean.shape(), &latent_shape);
    assert_eq!(dist.logvar.shape(), &latent_shape);

    // logvar is expected to be clamped into [-30, 20]; NaN/inf also fail here.
    for &v in dist.logvar.data().expect("logvar must have data") {
        assert!(
            v.is_finite() && (-30.0..=20.0).contains(&v),
            "logvar value {v} outside the [-30, 20] clamp range"
        );
    }
}
/// Placeholder for numerical parity against the diffusers reference encoder.
///
/// Behavior depends on the hub registry:
/// - if `sd-v1-5-vae-encoder` is not registered, the test prints setup
///   instructions to stderr and passes;
/// - if it is registered, the test panics, because the parity harness has
///   not been written yet.
#[test]
#[ignore = "Awaits the `ferrotorch/sd-v1-5-vae-encoder` HF mirror — enable with --ignored"]
fn vae_encoder_diffusers_parity_smoke() {
    let mirror_registered =
        ferrotorch_hub::registry::get_model_info("sd-v1-5-vae-encoder").is_some();

    if mirror_registered {
        // Fail loudly so a newly published mirror cannot go unexercised.
        panic!(
            "vae_encoder_diffusers_parity_smoke: mirror present but parity harness \
             not implemented. Wire up via the same pattern as \
             conformance_pretrained_diffusion.rs."
        );
    }

    eprintln!(
        "vae_encoder_diffusers_parity_smoke: skipping — \
         `ferrotorch/sd-v1-5-vae-encoder` mirror is not yet registered. \
         To enable, (1) run `scripts/verify_vae_encoder.py --pin` to \
         produce the encoder safetensors + `_value_parity_*.bin` \
         fixtures, (2) push the mirror to HuggingFace, (3) add a \
         `ModelInfo` entry to ferrotorch-hub/src/registry.rs."
    );
}