use ferrotorch_core::{FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
use ferrotorch_gpu::{
CudaBuffer, GpuDevice, GpuError,
kernels::{gpu_add, gpu_clamp, gpu_exp, gpu_mul, gpu_scale, gpu_silu},
rng::gpu_philox_normal,
transfer::{alloc_zeros_f32, cpu_to_gpu, gpu_to_cpu},
};
use ferrotorch_nn::module::StateDict;
use crate::config::VaeDecoderConfig;
use crate::safetensors_loader::DropReport;
use crate::vae_encoder::VaeEncoder;
use super::vae::{
GpuConv2d, GpuGroupNorm, GpuMidBlock, GpuResnet, attn_forward, conv_forward, gpu_err,
group_norm_forward, pop_attn, pop_conv, pop_groupnorm, pop_resnet, resnet_forward,
};
/// Lower bound applied to the log-variance half of the encoder moments before it
/// is halved and exponentiated in the stochastic sampling path.
const LOGVAR_CLAMP_MIN: f32 = -30.0;
/// Upper bound applied to the log-variance before `exp(0.5 * logvar)`; this caps
/// the sampled standard deviation at `exp(10)`.
const LOGVAR_CLAMP_MAX: f32 = 20.0;
/// Stride-2, 3x3 convolution placed at the end of every encoder stage except the
/// last; it halves the spatial resolution.
#[derive(Debug)]
struct GpuDownsample {
    /// Convolution weights, loaded with in = out = `channels`.
    conv: GpuConv2d,
    /// Channel count the input must have; checked by `downsample_forward`.
    channels: usize,
}
/// One encoder stage: `layers_per_block` residual blocks followed by an optional
/// downsample (absent on the final stage).
#[derive(Debug)]
struct GpuDownEncoderBlock {
    /// Residual blocks, applied in order.
    resnets: Vec<GpuResnet>,
    /// Spatial downsample, `None` for the last stage.
    downsample: Option<GpuDownsample>,
}
/// VAE encoder whose weights all live on a single GPU device.
///
/// Built from a `StateDict` (see `GpuVaeEncoder::new`) or from a CPU
/// `VaeEncoder` (see `GpuVaeEncoder::from_module`).
#[derive(Debug)]
pub struct GpuVaeEncoder {
    /// Image channels -> `block_out_channels[0]`, 3x3, stride 1.
    conv_in: GpuConv2d,
    /// One entry per `block_out_channels` element.
    down_blocks: Vec<GpuDownEncoderBlock>,
    /// resnet -> attention -> resnet at the top channel width.
    mid_block: GpuMidBlock,
    /// Group norm applied before the final SiLU + `conv_out`.
    conv_norm_out: GpuGroupNorm,
    /// Top channel width -> `2 * latent_channels`, 3x3.
    conv_out: GpuConv2d,
    /// 1x1 conv over the `2 * latent_channels` moments.
    quant_conv: GpuConv2d,
    /// Shared VAE config (the decoder config type is reused for the encoder).
    config: VaeDecoderConfig,
    /// Device that owns every weight buffer and runs every kernel.
    device: GpuDevice,
}
impl GpuVaeEncoder {
pub fn new(
config: VaeDecoderConfig,
mut state: StateDict<f32>,
device: GpuDevice,
) -> FerrotorchResult<(Self, DropReport)> {
config.validate()?;
let eps = 1e-6_f32;
let groups = config.norm_num_groups;
let latent_c = config.latent_channels;
let top_c =
*config
.block_out_channels
.last()
.ok_or_else(|| FerrotorchError::InvalidArgument {
message: "GpuVaeEncoder: block_out_channels empty".into(),
})?;
let bottom_c = config.block_out_channels[0];
let resnets_per_block = config.layers_per_block;
let conv_in = pop_conv(
&mut state,
"encoder.conv_in",
config.out_channels,
bottom_c,
(3, 3),
(1, 1),
(1, 1),
&device,
)?;
let num_blocks = config.block_out_channels.len();
let mut down_blocks: Vec<GpuDownEncoderBlock> = Vec::with_capacity(num_blocks);
let mut prev_out = bottom_c;
for (i, &c) in config.block_out_channels.iter().enumerate() {
let is_final = i == num_blocks - 1;
let mut resnets = Vec::with_capacity(resnets_per_block);
for r in 0..resnets_per_block {
let in_c = if r == 0 { prev_out } else { c };
resnets.push(pop_resnet(
&mut state,
&format!("encoder.down_blocks.{i}.resnets.{r}"),
in_c,
c,
groups,
eps,
&device,
)?);
}
let downsample = if is_final {
None
} else {
let conv = pop_conv(
&mut state,
&format!("encoder.down_blocks.{i}.downsamplers.0.conv"),
c,
c,
(3, 3),
(2, 2),
(1, 1),
&device,
)?;
Some(GpuDownsample { conv, channels: c })
};
down_blocks.push(GpuDownEncoderBlock {
resnets,
downsample,
});
prev_out = c;
}
let mid_resnet0 = pop_resnet(
&mut state,
"encoder.mid_block.resnets.0",
top_c,
top_c,
groups,
eps,
&device,
)?;
let mid_attn0 = pop_attn(
&mut state,
"encoder.mid_block.attentions.0",
top_c,
groups,
eps,
&device,
)?;
let mid_resnet1 = pop_resnet(
&mut state,
"encoder.mid_block.resnets.1",
top_c,
top_c,
groups,
eps,
&device,
)?;
let mid_block = GpuMidBlock {
resnets: vec![mid_resnet0, mid_resnet1],
attentions: vec![mid_attn0],
};
let conv_norm_out = pop_groupnorm(
&mut state,
"encoder.conv_norm_out",
groups,
top_c,
eps,
&device,
)?;
let conv_out = pop_conv(
&mut state,
"encoder.conv_out",
top_c,
2 * latent_c,
(3, 3),
(1, 1),
(1, 1),
&device,
)?;
let quant_conv = pop_conv(
&mut state,
"quant_conv",
2 * latent_c,
2 * latent_c,
(1, 1),
(1, 1),
(0, 0),
&device,
)?;
let mut dropped: Vec<String> = state.keys().cloned().collect();
dropped.sort();
let report = DropReport { dropped };
Ok((
Self {
conv_in,
down_blocks,
mid_block,
conv_norm_out,
conv_out,
quant_conv,
config,
device,
},
report,
))
}
pub fn from_module(
cpu: &VaeEncoder<f32>,
device: &GpuDevice,
) -> FerrotorchResult<(Self, DropReport)> {
use ferrotorch_nn::module::Module;
let state: StateDict<f32> = cpu.state_dict();
Self::new(cpu.config.clone(), state, device.clone())
}
fn forward_to_params(
&self,
x: &CudaBuffer<f32>,
shape: [usize; 4],
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
let [b, c_in, h, w] = shape;
if c_in != self.config.out_channels {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"GpuVaeEncoder: expected input channels={}, got {}",
self.config.out_channels, c_in
),
});
}
let (mut hbuf, mut hshape) = conv_forward(&self.conv_in, x, [b, c_in, h, w], &self.device)?;
for block in &self.down_blocks {
for r in &block.resnets {
(hbuf, hshape) = resnet_forward(r, &hbuf, hshape, &self.device)?;
}
if let Some(ds) = &block.downsample {
(hbuf, hshape) = downsample_forward(ds, &hbuf, hshape, &self.device)?;
}
}
(hbuf, hshape) = resnet_forward(&self.mid_block.resnets[0], &hbuf, hshape, &self.device)?;
(hbuf, hshape) = attn_forward(&self.mid_block.attentions[0], &hbuf, hshape, &self.device)?;
(hbuf, hshape) = resnet_forward(&self.mid_block.resnets[1], &hbuf, hshape, &self.device)?;
hbuf = group_norm_forward(&self.conv_norm_out, &hbuf, hshape, &self.device)?;
hbuf = gpu_silu(&hbuf, &self.device).map_err(gpu_err)?;
(hbuf, hshape) = conv_forward(&self.conv_out, &hbuf, hshape, &self.device)?;
let (params_buf, params_shape) =
conv_forward(&self.quant_conv, &hbuf, hshape, &self.device)?;
Ok((params_buf, params_shape))
}
pub fn encode(&self, image: &Tensor<f32>) -> FerrotorchResult<Tensor<f32>> {
let (out_buf, out_shape) = self.encode_to_gpu_buf(image, false)?;
let out_data = gpu_to_cpu(&out_buf, &self.device).map_err(gpu_err)?;
Tensor::from_storage(TensorStorage::cpu(out_data), out_shape.to_vec(), false)
}
pub fn encode_mode(&self, image: &Tensor<f32>) -> FerrotorchResult<Tensor<f32>> {
let (out_buf, out_shape) = self.encode_to_gpu_buf(image, true)?;
let out_data = gpu_to_cpu(&out_buf, &self.device).map_err(gpu_err)?;
Tensor::from_storage(TensorStorage::cpu(out_data), out_shape.to_vec(), false)
}
pub fn encode_with_gpu_params_probe<F>(
&self,
image: &Tensor<f32>,
probe: F,
) -> FerrotorchResult<Tensor<f32>>
where
F: FnOnce(&CudaBuffer<f32>, [usize; 4]) -> FerrotorchResult<()>,
{
let shape = image.shape();
if shape.len() != 4 || shape[1] != self.config.out_channels {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"GpuVaeEncoder::encode: expected [B, {}, H, W], got {:?}",
self.config.out_channels, shape
),
});
}
let data = image.data()?;
let x = cpu_to_gpu(data, &self.device).map_err(gpu_err)?;
let (params_buf, params_shape) =
self.forward_to_params(&x, [shape[0], shape[1], shape[2], shape[3]])?;
probe(¶ms_buf, params_shape)?;
let (out_buf, out_shape) = diag_gauss_sample_with_scale_gpu(
¶ms_buf,
params_shape,
self.config.latent_channels,
self.config.scaling_factor as f32,
false,
&self.device,
)?;
let out_data = gpu_to_cpu(&out_buf, &self.device).map_err(gpu_err)?;
Tensor::from_storage(TensorStorage::cpu(out_data), out_shape.to_vec(), false)
}
fn encode_to_gpu_buf(
&self,
image: &Tensor<f32>,
deterministic: bool,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
let shape = image.shape();
if shape.len() != 4 || shape[1] != self.config.out_channels {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"GpuVaeEncoder::encode: expected [B, {}, H, W], got {:?}",
self.config.out_channels, shape
),
});
}
let data = image.data()?;
let x = cpu_to_gpu(data, &self.device).map_err(gpu_err)?;
let (params_buf, params_shape) =
self.forward_to_params(&x, [shape[0], shape[1], shape[2], shape[3]])?;
diag_gauss_sample_with_scale_gpu(
¶ms_buf,
params_shape,
self.config.latent_channels,
self.config.scaling_factor as f32,
deterministic,
&self.device,
)
}
}
/// Applies a stage's stride-2 downsample convolution.
///
/// The incoming channel count (`shape[1]`) must equal the channel count the
/// downsample was loaded with; otherwise a `ShapeMismatch` is returned and no
/// kernel is launched.
fn downsample_forward(
    d: &GpuDownsample,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let incoming = shape[1];
    if incoming == d.channels {
        return conv_forward(&d.conv, x, shape, device);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "downsample_forward: expected {} channels, got {}",
            d.channels, incoming
        ),
    })
}
/// Turns a diagonal-Gaussian moment buffer into a scaled latent on the GPU.
///
/// `params` holds `[B, 2 * latent_channels, H, W]` in contiguous NCHW order. For
/// batch item `i`, with `L = latent_channels * H * W`:
///
/// * the mean occupies elements `[i * 2L, i * 2L + L)`;
/// * the log-variance occupies `[i * 2L + L, (i + 1) * 2L)`.
///
/// Both halves are gathered into dense `[B, latent_channels, H, W]` buffers with
/// one device-to-device copy per batch item, so no strided kernel is needed.
///
/// * `deterministic == true` returns `mean * scaling_factor`, the distribution mode.
/// * Otherwise it returns `(mean + exp(0.5 * clamp(logvar)) * eps) * scaling_factor`.
///   Here `logvar` is clamped to `[LOGVAR_CLAMP_MIN, LOGVAR_CLAMP_MAX]` and `eps`
///   comes from `gpu_philox_normal`.
///
/// # Errors
///
/// * `ShapeMismatch` if the channel count is not `2 * latent_channels`.
/// * `ShapeMismatch` if `params` holds fewer elements than `params_shape` describes.
/// * `InvalidArgument` if the batch size is zero.
/// * Any allocation, copy, or kernel failure from the GPU helpers.
fn diag_gauss_sample_with_scale_gpu(
    params: &CudaBuffer<f32>,
    params_shape: [usize; 4],
    latent_channels: usize,
    scaling_factor: f32,
    deterministic: bool,
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let [b, c2, h, w] = params_shape;
    if c2 != 2 * latent_channels {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "diag_gauss_sample_with_scale_gpu: expected 2*{} channels, got {}",
                latent_channels, c2
            ),
        });
    }
    if b == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "diag_gauss_sample_with_scale_gpu: batch size must be >= 1, got B=0"
                .into(),
        });
    }
    // Elements in one batch item's mean (or log-variance) half.
    let per_item = latent_channels * h * w;
    let total = b * per_item;
    let required = 2 * total;
    // Surface an undersized buffer as an error instead of letting the
    // out-of-bounds `slice` calls below handle it.
    if params.len() < required {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "diag_gauss_sample_with_scale_gpu: params buffer has {} elements, \
                 but shape {:?} requires {}",
                params.len(),
                params_shape,
                required
            ),
        });
    }
    let stream = device.stream();
    let mut mean_buf = alloc_zeros_f32(total, device).map_err(gpu_err)?;
    let mut logvar_buf = alloc_zeros_f32(total, device).map_err(gpu_err)?;
    {
        let src = params.inner();
        for i in 0..b {
            let mean_start = i * 2 * per_item;
            let logvar_start = mean_start + per_item;
            let dst = i * per_item..(i + 1) * per_item;

            let mean_view = src.slice(mean_start..logvar_start);
            let mut mean_dst = mean_buf.inner_mut().slice_mut(dst.clone());
            stream
                .memcpy_dtod(&mean_view, &mut mean_dst)
                .map_err(|e| gpu_err(GpuError::from(e)))?;

            let logvar_view = src.slice(logvar_start..logvar_start + per_item);
            let mut logvar_dst = logvar_buf.inner_mut().slice_mut(dst);
            stream
                .memcpy_dtod(&logvar_view, &mut logvar_dst)
                .map_err(|e| gpu_err(GpuError::from(e)))?;
        }
    }
    let out_shape = [b, latent_channels, h, w];
    if deterministic {
        let scaled = gpu_scale(&mean_buf, scaling_factor, device).map_err(gpu_err)?;
        return Ok((scaled, out_shape));
    }
    // Reparameterisation: z = mean + exp(0.5 * logvar) * eps.
    let logvar_clamped =
        gpu_clamp(&logvar_buf, LOGVAR_CLAMP_MIN, LOGVAR_CLAMP_MAX, device).map_err(gpu_err)?;
    let eps = gpu_philox_normal(total, device).map_err(gpu_err)?;
    let half_logvar = gpu_scale(&logvar_clamped, 0.5_f32, device).map_err(gpu_err)?;
    let sigma = gpu_exp(&half_logvar, device).map_err(gpu_err)?;
    let noise = gpu_mul(&sigma, &eps, device).map_err(gpu_err)?;
    let sample = gpu_add(&mean_buf, &noise, device).map_err(gpu_err)?;
    let scaled = gpu_scale(&sample, scaling_factor, device).map_err(gpu_err)?;
    Ok((scaled, out_shape))
}
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use crate::vae_encoder::VaeEncoder;
    use ferrotorch_nn::module::Module;

    /// Minimal config: four stages (three stride-2 downsamples), so an 8x8 RGB
    /// image encodes to a 1x1 latent with `latent_channels = 4`.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Deterministic `[1, 3, 8, 8]` image.
    ///
    /// Each row is constant, rows ramp from -1.0 up to 0.75, and channel `c` is
    /// offset by `0.05 * c`, with the result clamped to [-1, 1].
    fn striped_image_tiny() -> Tensor<f32> {
        let mut data = Vec::with_capacity(3 * 8 * 8);
        for c in 0..3 {
            for y in 0..8 {
                for _ in 0..8 {
                    let base = (y as f32 / 8.0) * 2.0 - 1.0;
                    data.push((base + c as f32 * 0.05).clamp(-1.0, 1.0));
                }
            }
        }
        Tensor::from_storage(TensorStorage::cpu(data), vec![1, 3, 8, 8], false).unwrap()
    }

    #[test]
    fn gpu_encoder_mode_matches_cpu_mean_scaled_tiny() {
        // Skip on machines without a usable CUDA device.
        let Ok(device) = GpuDevice::new(0) else {
            return;
        };
        let cfg = tiny_cfg();
        let cpu = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (gpu, report) = GpuVaeEncoder::from_module(&cpu, &device).unwrap();
        assert!(
            report.dropped.is_empty(),
            "unexpected dropped keys: {:?}",
            report.dropped
        );
        let img = striped_image_tiny();
        let gpu_latent = gpu.encode_mode(&img).unwrap();
        let cpu_params = cpu.forward(&img).unwrap();
        let cpu_chunks = cpu_params.chunk(2, 1).unwrap();
        let cpu_mean = &cpu_chunks[0];
        assert_eq!(gpu_latent.shape(), cpu_mean.shape());
        let gpu_data = gpu_latent.data().unwrap();
        let cpu_data = cpu_mean.data().unwrap();
        let sf = cfg.scaling_factor as f32;
        // A NaN difference never wins a max comparison, so without this check a
        // NaN-producing kernel would slip through the tolerance assertion below.
        for (&g, &c) in gpu_data.iter().zip(cpu_data.iter()) {
            assert!(g.is_finite(), "encode_mode produced non-finite value: {g}");
            assert!(c.is_finite(), "CPU reference mean is non-finite: {c}");
        }
        let max_abs = gpu_data
            .iter()
            .zip(cpu_data.iter())
            .map(|(&g, &c)| (g - c * sf).abs())
            .fold(0.0_f32, f32::max);
        assert!(
            max_abs < 1e-3,
            "encode_mode: gpu vs (cpu_mean * scaling_factor) max_abs = {max_abs}"
        );
    }

    #[test]
    fn gpu_encoder_sample_shape_and_finite_tiny() {
        let Ok(device) = GpuDevice::new(0) else {
            return;
        };
        let cfg = tiny_cfg();
        let cpu = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (gpu, _) = GpuVaeEncoder::from_module(&cpu, &device).unwrap();
        let img = striped_image_tiny();
        let latent = gpu.encode(&img).unwrap();
        assert_eq!(latent.shape(), &[1, cfg.latent_channels, 1, 1]);
        for &v in latent.data().unwrap() {
            assert!(v.is_finite(), "GPU encode produced non-finite value: {v}");
        }
    }

    #[test]
    fn gpu_encoder_params_probe_proves_gpu_residency() {
        let Ok(device) = GpuDevice::new(0) else {
            return;
        };
        let cfg = tiny_cfg();
        let cpu = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let (gpu, _) = GpuVaeEncoder::from_module(&cpu, &device).unwrap();
        let img = striped_image_tiny();
        let probe_called = std::cell::Cell::new(false);
        let latent = gpu
            .encode_with_gpu_params_probe(&img, |params_buf: &CudaBuffer<f32>, shape| {
                probe_called.set(true);
                let [b, c2, h, w] = shape;
                let expected = b * c2 * h * w;
                assert_eq!(
                    params_buf.len(),
                    expected,
                    "params CudaBuffer len {} != expected {expected}",
                    params_buf.len()
                );
                assert_eq!(c2, 2 * cfg.latent_channels);
                let host = gpu_to_cpu(params_buf, &device).map_err(|e| {
                    FerrotorchError::InvalidArgument {
                        message: format!("probe readback failed: {e}"),
                    }
                })?;
                // With B = 1 the first half of the buffer is the mean, the second
                // half the log-variance.
                let (mean, logvar) = host.split_at(expected / 2);
                let any_diff = mean
                    .iter()
                    .zip(logvar.iter())
                    .any(|(m, lv)| (m - lv).abs() > 1e-6);
                assert!(
                    any_diff,
                    "mean and logvar halves are identical — channel-split bug?"
                );
                Ok(())
            })
            .unwrap();
        assert!(probe_called.get(), "probe callback was never invoked");
        assert_eq!(latent.shape(), &[1, cfg.latent_channels, 1, 1]);
    }
}