#![cfg(feature = "cuda")]
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
use ferrotorch_gpu::{
CudaBuffer, GpuDevice, GpuError, gpu_bmm_f32, gpu_conv2d_f32, gpu_group_norm_f32,
gpu_layernorm, gpu_matmul_f32, gpu_nearest_upsample2x_f32, gpu_softmax,
kernels::{gpu_add, gpu_broadcast_add, gpu_gelu_erf, gpu_scale, gpu_silu},
transfer::{cpu_to_gpu, gpu_to_cpu},
};
use ferrotorch_nn::module::{Module, StateDict};
use crate::safetensors_loader::DropReport;
use crate::time_embedding::Timesteps;
use crate::unet::UNet2DConditionModel;
use crate::unet_config::UNet2DConditionConfig;
/// One 2-D convolution layer whose weight and bias already live on the GPU.
#[derive(Debug)]
struct GpuConv2d {
    /// Kernel weights; `pop_conv` requires exactly
    /// `out_channels * in_channels * kernel.0 * kernel.1` elements.
    weight: CudaBuffer<f32>,
    /// Bias; `pop_conv` requires exactly `out_channels` elements.
    bias: CudaBuffer<f32>,
    /// Expected channel count of the input.
    in_channels: usize,
    /// Channel count of the output.
    out_channels: usize,
    /// Kernel extent (presumably height, width — confirm against `gpu_conv2d_f32`).
    kernel: (usize, usize),
    /// Stride pair forwarded unchanged to `gpu_conv2d_f32`.
    stride: (usize, usize),
    /// Padding pair forwarded unchanged to `gpu_conv2d_f32`.
    padding: (usize, usize),
}
/// Group-normalisation parameters resident on the GPU.
#[derive(Debug)]
struct GpuGroupNorm {
    /// Per-channel scale; `pop_groupnorm` requires `num_channels` elements.
    weight: CudaBuffer<f32>,
    /// Per-channel shift; `pop_groupnorm` requires `num_channels` elements.
    bias: CudaBuffer<f32>,
    /// Number of groups passed to `gpu_group_norm_f32`.
    num_groups: usize,
    /// Channel count that `group_norm_forward` checks the input against.
    num_channels: usize,
    /// Epsilon passed to the normalisation kernel.
    eps: f32,
}
/// Layer-normalisation parameters resident on the GPU.
#[derive(Debug)]
struct GpuLayerNorm {
    /// Scale; `pop_layernorm` requires `normalized_shape` elements.
    weight: CudaBuffer<f32>,
    /// Shift; `pop_layernorm` requires `normalized_shape` elements.
    bias: CudaBuffer<f32>,
    /// Size of the normalised (last) axis; `layer_norm_forward` checks `cols` against it.
    normalized_shape: usize,
    /// Epsilon passed to `gpu_layernorm`.
    eps: f32,
}
/// Linear layer stored with a pre-transposed weight so the forward pass is a
/// plain `x @ weight_t` matmul.
#[derive(Debug)]
struct GpuLinearT {
    /// Weight laid out row-major as `[in_features, out_features]`; `pop_linear`
    /// transposes the `[out_features, in_features]` state-dict tensor on the host.
    weight_t: CudaBuffer<f32>,
    /// Optional bias of `out_features` elements.
    bias: Option<CudaBuffer<f32>>,
    /// Inner dimension of the matmul.
    in_features: usize,
    /// Output width of the matmul.
    out_features: usize,
}
/// Residual block conditioned on a time embedding (see `resnet_time_forward`).
#[derive(Debug)]
struct GpuResnetTime {
    /// Group norm applied to the block input (`in_channels`).
    norm1: GpuGroupNorm,
    /// 3x3 convolution `in_channels -> out_channels`.
    conv1: GpuConv2d,
    /// Projects the time embedding to `out_channels`.
    time_emb_proj: GpuLinearT,
    /// Group norm applied after the time embedding is added.
    norm2: GpuGroupNorm,
    /// 3x3 convolution `out_channels -> out_channels`.
    conv2: GpuConv2d,
    /// 1x1 projection of the residual; `pop_resnet_time` creates it only when
    /// `in_channels != out_channels`.
    conv_shortcut: Option<GpuConv2d>,
    /// Channel count that `resnet_time_forward` checks the input against.
    in_channels: usize,
    /// Channel count of the block output.
    out_channels: usize,
}
/// Multi-head attention projections (see `attention_forward`).
#[derive(Debug)]
struct GpuAttention {
    /// Query projection (no bias), `query_dim -> inner_dim`.
    to_q: GpuLinearT,
    /// Key projection (no bias), `kv_dim -> inner_dim`.
    to_k: GpuLinearT,
    /// Value projection (no bias), `kv_dim -> inner_dim`.
    to_v: GpuLinearT,
    /// Output projection (with bias), `inner_dim -> query_dim`.
    to_out_0: GpuLinearT,
    /// Number of attention heads.
    heads: usize,
    /// Width of each head.
    dim_head: usize,
    /// `heads * dim_head`, as computed in `pop_attention`.
    inner_dim: usize,
}
/// Gated feed-forward block (see `ff_geglu_forward`).
#[derive(Debug)]
struct GpuFeedForwardGEGLU {
    /// Input projection `dim -> 2 * dim_ff`; its output is split into a value
    /// half and a gate half.
    net_0_proj: GpuLinearT,
    /// Output projection `dim_ff -> dim`.
    net_2: GpuLinearT,
    /// Model width.
    dim: usize,
    /// Hidden width (`dim * mult` in `pop_feedforward_geglu`).
    dim_ff: usize,
}
/// One transformer block: self-attention, cross-attention and feed-forward,
/// each preceded by its own layer norm (see `basic_transformer_block_forward`).
#[derive(Debug)]
struct GpuBasicTransformerBlock {
    /// Layer norm before `attn1`.
    norm1: GpuLayerNorm,
    /// Self-attention (built with no cross-attention dimension).
    attn1: GpuAttention,
    /// Layer norm before `attn2`.
    norm2: GpuLayerNorm,
    /// Cross-attention over the encoder hidden states.
    attn2: GpuAttention,
    /// Layer norm before `ff`.
    norm3: GpuLayerNorm,
    /// Gated feed-forward network.
    ff: GpuFeedForwardGEGLU,
    /// Token width used for all three layer norms.
    dim: usize,
}
/// Spatial transformer applied to a `[B, C, H, W]` feature map
/// (see `transformer_2d_forward`).
#[derive(Debug)]
struct GpuTransformer2D {
    /// Group norm applied to the input feature map.
    norm: GpuGroupNorm,
    /// 1x1 convolution `channels -> inner_dim`.
    proj_in: GpuConv2d,
    /// Transformer blocks run in order on the flattened tokens.
    blocks: Vec<GpuBasicTransformerBlock>,
    /// 1x1 convolution `inner_dim -> channels`.
    proj_out: GpuConv2d,
    /// Channel count that the forward pass checks the input against.
    channels: usize,
    /// `heads * dim_head`, as computed in `pop_transformer_2d`.
    inner_dim: usize,
}
/// 2x upsampling stage: `upsample_forward` runs a nearest-neighbour 2x
/// upsample and then this convolution.
#[derive(Debug)]
struct GpuUpsample2D {
    /// 3x3, stride 1, padding 1 convolution (`pop_upsample`).
    conv: GpuConv2d,
    /// Channel count, identical for input and output.
    channels: usize,
}
/// Downsampling stage implemented as a single strided convolution.
#[derive(Debug)]
struct GpuDownsample2D {
    /// 3x3, stride 2, padding 1 convolution (`pop_downsample`).
    conv: GpuConv2d,
    /// Channel count, identical for input and output.
    channels: usize,
}
/// Encoder block that pairs each resnet with a spatial transformer.
#[derive(Debug)]
struct GpuCrossAttnDownBlock {
    /// Resnets, one per layer.
    resnets: Vec<GpuResnetTime>,
    /// Transformers, zipped one-to-one with `resnets` in the forward pass.
    attentions: Vec<GpuTransformer2D>,
    /// Present for every block except the last one built by `new`.
    downsampler: Option<GpuDownsample2D>,
}
/// Encoder block made of resnets only (no attention).
#[derive(Debug)]
struct GpuDownBlock {
    /// Resnets, one per layer.
    resnets: Vec<GpuResnetTime>,
    /// Present for every block except the last one built by `new`.
    downsampler: Option<GpuDownsample2D>,
}
/// Either flavour of encoder (down) block.
#[derive(Debug)]
enum AnyGpuDown {
    /// Resnets interleaved with cross-attention transformers.
    CrossAttn(GpuCrossAttnDownBlock),
    /// Resnets only.
    Plain(GpuDownBlock),
}
/// Bottleneck block; `forward` runs `resnet0`, then `attn0`, then `resnet1`.
#[derive(Debug)]
struct GpuMidBlock {
    /// First resnet (`mid_block.resnets.0`).
    resnet0: GpuResnetTime,
    /// Spatial transformer (`mid_block.attentions.0`).
    attn0: GpuTransformer2D,
    /// Second resnet (`mid_block.resnets.1`).
    resnet1: GpuResnetTime,
}
/// Decoder block that pairs each resnet with a spatial transformer.
#[derive(Debug)]
struct GpuCrossAttnUpBlock {
    /// Resnets; each consumes the running activation concatenated with one skip.
    resnets: Vec<GpuResnetTime>,
    /// Transformers, zipped one-to-one with `resnets` in the forward pass.
    attentions: Vec<GpuTransformer2D>,
    /// Present for every block except the last one built by `new`.
    upsampler: Option<GpuUpsample2D>,
}
/// Decoder block made of resnets only (no attention).
#[derive(Debug)]
struct GpuUpBlock {
    /// Resnets; each consumes the running activation concatenated with one skip.
    resnets: Vec<GpuResnetTime>,
    /// Present for every block except the last one built by `new`.
    upsampler: Option<GpuUpsample2D>,
}
/// Either flavour of decoder (up) block.
#[derive(Debug)]
enum AnyGpuUp {
    /// Resnets interleaved with cross-attention transformers.
    CrossAttn(GpuCrossAttnUpBlock),
    /// Resnets only.
    Plain(GpuUpBlock),
}
/// Conditional 2-D UNet whose parameters are held in GPU buffers.
///
/// Inputs and outputs of [`GpuUNet2DConditional::forward`] are CPU tensors;
/// they are uploaded / downloaded inside the call.
#[derive(Debug)]
pub struct GpuUNet2DConditional {
    /// Timestep encoder, evaluated on the CPU before upload.
    time_proj: Timesteps,
    /// First linear layer of the time-embedding MLP.
    time_emb_lin1: GpuLinearT,
    /// Second linear layer of the time-embedding MLP.
    time_emb_lin2: GpuLinearT,
    /// Input 3x3 convolution.
    conv_in: GpuConv2d,
    /// Encoder blocks, in execution order.
    down_blocks: Vec<AnyGpuDown>,
    /// Bottleneck block.
    mid_block: GpuMidBlock,
    /// Decoder blocks, in execution order.
    up_blocks: Vec<AnyGpuUp>,
    /// Final group norm before the output convolution.
    conv_norm_out: GpuGroupNorm,
    /// Output 3x3 convolution.
    conv_out: GpuConv2d,
    /// Configuration the model was built from.
    config: UNet2DConditionConfig,
    /// Device every buffer of this model lives on.
    device: GpuDevice,
}
impl GpuUNet2DConditional {
/// Builds the GPU model by consuming tensors from `state` and uploading them
/// to `device`.
///
/// Every tensor the architecture needs is removed from `state`; whatever keys
/// remain afterwards are returned, sorted, in the [`DropReport`].
///
/// # Errors
/// Propagates `config.validate()` failures, and returns an error when a
/// required tensor is missing, has an unexpected element count, or cannot be
/// uploaded.
pub fn new(
    config: UNet2DConditionConfig,
    mut state: StateDict<f32>,
    device: GpuDevice,
) -> FerrotorchResult<(Self, DropReport)> {
    config.validate()?;
    let groups = config.norm_num_groups;
    let temb_channels = config.time_embed_dim();
    let bocs = &config.block_out_channels;
    let num_blocks = bocs.len();
    let cross_dim = config.cross_attention_dim;
    // `attention_head_dim` is used here as the number of heads; the per-head
    // width is derived below as `channels / heads`.
    let heads = config.attention_head_dim; let transformer_layers = config.transformer_layers_per_block;
    // Epsilon for resnet / output group norms vs. the transformer's input group norm.
    let resnet_eps = 1e-5_f32;
    let transformer_eps = 1e-6_f32;
    let time_proj = Timesteps::new(bocs[0], config.flip_sin_to_cos, config.freq_shift)?;
    let time_emb_lin1 = pop_linear(
        &mut state,
        "time_embedding.linear_1",
        bocs[0],
        temb_channels,
        true,
        &device,
    )?;
    let time_emb_lin2 = pop_linear(
        &mut state,
        "time_embedding.linear_2",
        temb_channels,
        temb_channels,
        true,
        &device,
    )?;
    let conv_in = pop_conv(
        &mut state,
        "conv_in",
        config.in_channels,
        bocs[0],
        (3, 3),
        (1, 1),
        (1, 1),
        true,
        &device,
    )?;
    // ---- Encoder: one block per entry of `block_out_channels`. ----
    let mut down_blocks: Vec<AnyGpuDown> = Vec::with_capacity(num_blocks);
    // Output channels of the previous block; feeds the first resnet of the next one.
    let mut prev = bocs[0];
    for i in 0..num_blocks {
        let out_c = bocs[i];
        let is_final = i == num_blocks - 1;
        // Every block except the last ends with a downsampler.
        let add_downsample = !is_final;
        let dim_head = out_c / heads;
        let block_prefix = format!("down_blocks.{i}");
        if config.down_block_has_attn[i] {
            let mut resnets = Vec::with_capacity(config.layers_per_block);
            let mut attentions = Vec::with_capacity(config.layers_per_block);
            for j in 0..config.layers_per_block {
                // Only the first resnet of a block changes the channel count.
                let in_c = if j == 0 { prev } else { out_c };
                resnets.push(pop_resnet_time(
                    &mut state,
                    &format!("{block_prefix}.resnets.{j}"),
                    in_c,
                    out_c,
                    temb_channels,
                    groups,
                    resnet_eps,
                    &device,
                )?);
                attentions.push(pop_transformer_2d(
                    &mut state,
                    &format!("{block_prefix}.attentions.{j}"),
                    out_c,
                    heads,
                    dim_head,
                    transformer_layers,
                    cross_dim,
                    groups,
                    transformer_eps,
                    &device,
                )?);
            }
            let downsampler = if add_downsample {
                Some(pop_downsample(
                    &mut state,
                    &format!("{block_prefix}.downsamplers.0"),
                    out_c,
                    &device,
                )?)
            } else {
                None
            };
            down_blocks.push(AnyGpuDown::CrossAttn(GpuCrossAttnDownBlock {
                resnets,
                attentions,
                downsampler,
            }));
        } else {
            let mut resnets = Vec::with_capacity(config.layers_per_block);
            for j in 0..config.layers_per_block {
                let in_c = if j == 0 { prev } else { out_c };
                resnets.push(pop_resnet_time(
                    &mut state,
                    &format!("{block_prefix}.resnets.{j}"),
                    in_c,
                    out_c,
                    temb_channels,
                    groups,
                    resnet_eps,
                    &device,
                )?);
            }
            let downsampler = if add_downsample {
                Some(pop_downsample(
                    &mut state,
                    &format!("{block_prefix}.downsamplers.0"),
                    out_c,
                    &device,
                )?)
            } else {
                None
            };
            down_blocks.push(AnyGpuDown::Plain(GpuDownBlock {
                resnets,
                downsampler,
            }));
        }
        prev = out_c;
    }
    // ---- Bottleneck: resnet, transformer, resnet at the deepest width. ----
    let mid_channels = bocs[num_blocks - 1];
    let mid_dim_head = mid_channels / heads;
    let mid_resnet0 = pop_resnet_time(
        &mut state,
        "mid_block.resnets.0",
        mid_channels,
        mid_channels,
        temb_channels,
        groups,
        resnet_eps,
        &device,
    )?;
    let mid_attn0 = pop_transformer_2d(
        &mut state,
        "mid_block.attentions.0",
        mid_channels,
        heads,
        mid_dim_head,
        transformer_layers,
        cross_dim,
        groups,
        transformer_eps,
        &device,
    )?;
    let mid_resnet1 = pop_resnet_time(
        &mut state,
        "mid_block.resnets.1",
        mid_channels,
        mid_channels,
        temb_channels,
        groups,
        resnet_eps,
        &device,
    )?;
    let mid_block = GpuMidBlock {
        resnet0: mid_resnet0,
        attn0: mid_attn0,
        resnet1: mid_resnet1,
    };
    // ---- Decoder: walks `block_out_channels` in reverse. ----
    let mut up_blocks: Vec<AnyGpuUp> = Vec::with_capacity(num_blocks);
    let reversed: Vec<usize> = bocs.iter().rev().copied().collect();
    // Channel count of the activation entering the current up block.
    let mut prev_up = mid_channels;
    // Each up block has one more resnet than a down block.
    let up_layers = config.layers_per_block + 1;
    for i in 0..num_blocks {
        let out_c = reversed[i];
        // Width of the next (shallower) level, clamped at the last entry.
        let in_c = reversed[(i + 1).min(num_blocks - 1)];
        let is_final = i == num_blocks - 1;
        // Every block except the last ends with an upsampler.
        let add_upsample = !is_final;
        let dim_head = out_c / heads;
        let block_prefix = format!("up_blocks.{i}");
        if config.up_block_has_attn[i] {
            let mut resnets = Vec::with_capacity(up_layers);
            let mut attentions = Vec::with_capacity(up_layers);
            for j in 0..up_layers {
                // Channels contributed by the skip connection concatenated
                // onto this resnet's input.
                let res_skip = if j == up_layers - 1 { in_c } else { out_c };
                let resnet_in = if j == 0 {
                    prev_up + res_skip
                } else {
                    out_c + res_skip
                };
                resnets.push(pop_resnet_time(
                    &mut state,
                    &format!("{block_prefix}.resnets.{j}"),
                    resnet_in,
                    out_c,
                    temb_channels,
                    groups,
                    resnet_eps,
                    &device,
                )?);
                attentions.push(pop_transformer_2d(
                    &mut state,
                    &format!("{block_prefix}.attentions.{j}"),
                    out_c,
                    heads,
                    dim_head,
                    transformer_layers,
                    cross_dim,
                    groups,
                    transformer_eps,
                    &device,
                )?);
            }
            let upsampler = if add_upsample {
                Some(pop_upsample(
                    &mut state,
                    &format!("{block_prefix}.upsamplers.0"),
                    out_c,
                    &device,
                )?)
            } else {
                None
            };
            up_blocks.push(AnyGpuUp::CrossAttn(GpuCrossAttnUpBlock {
                resnets,
                attentions,
                upsampler,
            }));
        } else {
            let mut resnets = Vec::with_capacity(up_layers);
            for j in 0..up_layers {
                let res_skip = if j == up_layers - 1 { in_c } else { out_c };
                let resnet_in = if j == 0 {
                    prev_up + res_skip
                } else {
                    out_c + res_skip
                };
                resnets.push(pop_resnet_time(
                    &mut state,
                    &format!("{block_prefix}.resnets.{j}"),
                    resnet_in,
                    out_c,
                    temb_channels,
                    groups,
                    resnet_eps,
                    &device,
                )?);
            }
            let upsampler = if add_upsample {
                Some(pop_upsample(
                    &mut state,
                    &format!("{block_prefix}.upsamplers.0"),
                    out_c,
                    &device,
                )?)
            } else {
                None
            };
            up_blocks.push(AnyGpuUp::Plain(GpuUpBlock { resnets, upsampler }));
        }
        prev_up = out_c;
    }
    // ---- Output head. ----
    let conv_norm_out =
        pop_groupnorm(&mut state, "conv_norm_out", groups, bocs[0], resnet_eps, &device)?;
    let conv_out = pop_conv(
        &mut state,
        "conv_out",
        bocs[0],
        config.out_channels,
        (3, 3),
        (1, 1),
        (1, 1),
        true,
        &device,
    )?;
    // Anything still in `state` was not consumed by the architecture above.
    let mut dropped: Vec<String> = state.keys().cloned().collect();
    dropped.sort();
    let report = DropReport { dropped };
    Ok((
        Self {
            time_proj,
            time_emb_lin1,
            time_emb_lin2,
            conv_in,
            down_blocks,
            mid_block,
            up_blocks,
            conv_norm_out,
            conv_out,
            config,
            device,
        },
        report,
    ))
}
/// Creates a GPU model from an existing CPU model.
///
/// Takes the CPU model's state dict and a clone of its config, then defers to
/// [`GpuUNet2DConditional::new`] with a clone of `device`.
///
/// # Errors
/// Whatever [`GpuUNet2DConditional::new`] returns.
pub fn from_module(
    cpu: &UNet2DConditionModel<f32>,
    device: &GpuDevice,
) -> FerrotorchResult<(Self, DropReport)> {
    let cpu_state: StateDict<f32> = cpu.state_dict();
    let cpu_config = cpu.config.clone();
    let target = device.clone();
    Self::new(cpu_config, cpu_state, target)
}
pub fn forward(
&self,
sample: &Tensor<f32>,
timesteps: &Tensor<f32>,
encoder_hidden_states: &Tensor<f32>,
) -> FerrotorchResult<Tensor<f32>> {
let cfg = &self.config;
let s_shape = sample.shape();
if s_shape.len() != 4 || s_shape[1] != cfg.in_channels {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"GpuUNet2DConditional::forward: expected sample [B, {}, H, W], got {:?}",
cfg.in_channels, s_shape
),
});
}
if timesteps.ndim() != 1 {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"GpuUNet2DConditional::forward: expected timesteps [B], got {:?}",
timesteps.shape()
),
});
}
let eh_shape = encoder_hidden_states.shape();
if eh_shape.len() != 3 || eh_shape[2] != cfg.cross_attention_dim {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"GpuUNet2DConditional::forward: expected encoder_hidden_states \
[B, S, {}], got {:?}",
cfg.cross_attention_dim, eh_shape
),
});
}
let b = s_shape[0];
let h_in = s_shape[2];
let w_in = s_shape[3];
let s_text = eh_shape[1];
let t_enc = self.time_proj.forward_t(timesteps)?;
let t_enc_data = t_enc.data()?;
let t_enc_gpu = cpu_to_gpu(t_enc_data, &self.device).map_err(gpu_err)?;
let t1 = linear_forward(&self.time_emb_lin1, &t_enc_gpu, b, &self.device)?;
let t1_act = gpu_silu(&t1, &self.device).map_err(gpu_err)?;
let temb = linear_forward(&self.time_emb_lin2, &t1_act, b, &self.device)?;
let sample_data = sample.data()?;
let x_in = cpu_to_gpu(sample_data, &self.device).map_err(gpu_err)?;
let (h0_buf, h0_shape) = conv_forward(
&self.conv_in,
&x_in,
[b, cfg.in_channels, h_in, w_in],
&self.device,
)?;
let ehs_data = encoder_hidden_states.data()?;
let ehs_gpu = cpu_to_gpu(ehs_data, &self.device).map_err(gpu_err)?;
let mut skips: Vec<(CudaBuffer<f32>, [usize; 4])> = Vec::new();
skips.push((clone_buf(&h0_buf, &self.device)?, h0_shape));
let mut h_buf = h0_buf;
let mut h_shape = h0_shape;
for db in &self.down_blocks {
match db {
AnyGpuDown::CrossAttn(blk) => {
for (r, a) in blk.resnets.iter().zip(blk.attentions.iter()) {
let (rb, rs) =
resnet_time_forward(r, &h_buf, h_shape, &temb, b, &self.device)?;
h_buf = rb;
h_shape = rs;
let (ab, asz) = transformer_2d_forward(
a,
&h_buf,
h_shape,
&ehs_gpu,
b,
s_text,
cfg.cross_attention_dim,
&self.device,
)?;
h_buf = ab;
h_shape = asz;
skips.push((clone_buf(&h_buf, &self.device)?, h_shape));
}
if let Some(ds) = &blk.downsampler {
let (db_out, ds_shape) =
downsample_forward(ds, &h_buf, h_shape, &self.device)?;
h_buf = db_out;
h_shape = ds_shape;
skips.push((clone_buf(&h_buf, &self.device)?, h_shape));
}
}
AnyGpuDown::Plain(blk) => {
for r in &blk.resnets {
let (rb, rs) =
resnet_time_forward(r, &h_buf, h_shape, &temb, b, &self.device)?;
h_buf = rb;
h_shape = rs;
skips.push((clone_buf(&h_buf, &self.device)?, h_shape));
}
if let Some(ds) = &blk.downsampler {
let (db_out, ds_shape) =
downsample_forward(ds, &h_buf, h_shape, &self.device)?;
h_buf = db_out;
h_shape = ds_shape;
skips.push((clone_buf(&h_buf, &self.device)?, h_shape));
}
}
}
}
let (mr0, mr0_shape) = resnet_time_forward(
&self.mid_block.resnet0,
&h_buf,
h_shape,
&temb,
b,
&self.device,
)?;
let (ma0, ma0_shape) = transformer_2d_forward(
&self.mid_block.attn0,
&mr0,
mr0_shape,
&ehs_gpu,
b,
s_text,
cfg.cross_attention_dim,
&self.device,
)?;
let (mr1, mr1_shape) = resnet_time_forward(
&self.mid_block.resnet1,
&ma0,
ma0_shape,
&temb,
b,
&self.device,
)?;
h_buf = mr1;
h_shape = mr1_shape;
for ub in &self.up_blocks {
let n = match ub {
AnyGpuUp::CrossAttn(b) => b.resnets.len(),
AnyGpuUp::Plain(b) => b.resnets.len(),
};
if skips.len() < n {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"GpuUNet2DConditional: up-block needs {n} skips, only {} left",
skips.len()
),
});
}
let split_at = skips.len() - n;
let popped: Vec<(CudaBuffer<f32>, [usize; 4])> = skips.split_off(split_at);
let popped_rev: Vec<(CudaBuffer<f32>, [usize; 4])> =
popped.into_iter().rev().collect();
match ub {
AnyGpuUp::CrossAttn(blk) => {
for ((r, a), (skip_buf, skip_shape)) in blk
.resnets
.iter()
.zip(blk.attentions.iter())
.zip(popped_rev.iter())
{
let (cat_buf, cat_shape) =
cat_channels(&h_buf, h_shape, skip_buf, *skip_shape, &self.device)?;
let (rb, rs) = resnet_time_forward(
r,
&cat_buf,
cat_shape,
&temb,
b,
&self.device,
)?;
let (ab, asz) = transformer_2d_forward(
a,
&rb,
rs,
&ehs_gpu,
b,
s_text,
cfg.cross_attention_dim,
&self.device,
)?;
h_buf = ab;
h_shape = asz;
}
if let Some(up) = &blk.upsampler {
let (ub_buf, ub_shape) =
upsample_forward(up, &h_buf, h_shape, &self.device)?;
h_buf = ub_buf;
h_shape = ub_shape;
}
}
AnyGpuUp::Plain(blk) => {
for (r, (skip_buf, skip_shape)) in
blk.resnets.iter().zip(popped_rev.iter())
{
let (cat_buf, cat_shape) =
cat_channels(&h_buf, h_shape, skip_buf, *skip_shape, &self.device)?;
let (rb, rs) = resnet_time_forward(
r,
&cat_buf,
cat_shape,
&temb,
b,
&self.device,
)?;
h_buf = rb;
h_shape = rs;
}
if let Some(up) = &blk.upsampler {
let (ub_buf, ub_shape) =
upsample_forward(up, &h_buf, h_shape, &self.device)?;
h_buf = ub_buf;
h_shape = ub_shape;
}
}
}
}
h_buf = group_norm_forward(&self.conv_norm_out, &h_buf, h_shape, &self.device)?;
h_buf = gpu_silu(&h_buf, &self.device).map_err(gpu_err)?;
let (out_buf, out_shape) = conv_forward(&self.conv_out, &h_buf, h_shape, &self.device)?;
let out_data = gpu_to_cpu(&out_buf, &self.device).map_err(gpu_err)?;
Tensor::from_storage(TensorStorage::cpu(out_data), out_shape.to_vec(), false)
}
}
/// Wraps a [`GpuError`] in `FerrotorchError::InvalidArgument`, prefixing the
/// message with the model name.
fn gpu_err(e: GpuError) -> FerrotorchError {
    let message = format!("GpuUNet2DConditional GPU op failed: {e}");
    FerrotorchError::InvalidArgument { message }
}
/// Duplicates a GPU buffer by downloading it to the host and uploading the
/// copy again.
///
/// # Errors
/// Any transfer failure, mapped through `gpu_err`.
fn clone_buf(
    buf: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let staged = gpu_to_cpu(buf, device).map_err(gpu_err)?;
    let duplicate = cpu_to_gpu(&staged, device);
    duplicate.map_err(gpu_err)
}
/// Removes `key` from `state`, checks its element count, and uploads it.
///
/// # Errors
/// * `InvalidArgument` — `key` is not present in `state`.
/// * `ShapeMismatch` — the tensor does not hold exactly `expected_len` elements.
/// * Upload failures, mapped through `gpu_err`.
fn pop_tensor(
    state: &mut StateDict<f32>,
    key: &str,
    expected_len: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let tensor = match state.remove(key) {
        Some(found) => found,
        None => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("GpuUNet2DConditional: missing tensor {key:?}"),
            });
        }
    };
    let host = tensor.data()?;
    let actual_len = host.len();
    if actual_len == expected_len {
        return cpu_to_gpu(host, device).map_err(gpu_err);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "GpuUNet2DConditional: tensor {key:?} length {} != expected {expected_len}",
            actual_len
        ),
    })
}
/// Pops `{prefix}.weight` and `{prefix}.bias` from `state` and assembles a
/// [`GpuConv2d`].
///
/// Only biased convolutions are supported: `bias == false` yields
/// `InvalidArgument`. The weight is removed from `state` before that check.
///
/// # Errors
/// Missing / mis-sized tensors, upload failures, or `bias == false`.
#[allow(clippy::too_many_arguments)]
fn pop_conv(
    state: &mut StateDict<f32>,
    prefix: &str,
    in_c: usize,
    out_c: usize,
    kernel: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    bias: bool,
    device: &GpuDevice,
) -> FerrotorchResult<GpuConv2d> {
    let (kernel_a, kernel_b) = kernel;
    let weight_len = out_c * in_c * kernel_a * kernel_b;
    let weight_key = format!("{prefix}.weight");
    let weight = pop_tensor(state, &weight_key, weight_len, device)?;

    if !bias {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("GpuUNet2DConditional: conv {prefix:?} expected bias=true"),
        });
    }
    let bias_key = format!("{prefix}.bias");
    let bias = pop_tensor(state, &bias_key, out_c, device)?;

    Ok(GpuConv2d {
        weight,
        bias,
        in_channels: in_c,
        out_channels: out_c,
        kernel,
        stride,
        padding,
    })
}
/// Pops `{prefix}.weight` then `{prefix}.bias` (each `channels` long) and
/// assembles a [`GpuGroupNorm`].
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_groupnorm(
    state: &mut StateDict<f32>,
    prefix: &str,
    groups: usize,
    channels: usize,
    eps: f32,
    device: &GpuDevice,
) -> FerrotorchResult<GpuGroupNorm> {
    let scale_key = format!("{prefix}.weight");
    let weight = pop_tensor(state, &scale_key, channels, device)?;
    let shift_key = format!("{prefix}.bias");
    let bias = pop_tensor(state, &shift_key, channels, device)?;
    let norm = GpuGroupNorm {
        weight,
        bias,
        num_groups: groups,
        num_channels: channels,
        eps,
    };
    Ok(norm)
}
/// Pops `{prefix}.weight` then `{prefix}.bias` (each `features` long) and
/// assembles a [`GpuLayerNorm`].
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_layernorm(
    state: &mut StateDict<f32>,
    prefix: &str,
    features: usize,
    eps: f32,
    device: &GpuDevice,
) -> FerrotorchResult<GpuLayerNorm> {
    let scale_key = format!("{prefix}.weight");
    let weight = pop_tensor(state, &scale_key, features, device)?;
    let shift_key = format!("{prefix}.bias");
    let bias = pop_tensor(state, &shift_key, features, device)?;
    let norm = GpuLayerNorm {
        weight,
        bias,
        normalized_shape: features,
        eps,
    };
    Ok(norm)
}
/// Pops a linear layer from `state`, transposing its weight on the host.
///
/// The state-dict weight is read as row-major `[out_f, in_f]` and uploaded as
/// row-major `[in_f, out_f]`. When `bias` is true, `{prefix}.bias`
/// (`out_f` elements) is popped as well.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_linear(
    state: &mut StateDict<f32>,
    prefix: &str,
    in_f: usize,
    out_f: usize,
    bias: bool,
    device: &GpuDevice,
) -> FerrotorchResult<GpuLinearT> {
    let weight_key = format!("{prefix}.weight");
    let weight = match state.remove(&weight_key) {
        Some(found) => found,
        None => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("GpuUNet2DConditional: missing tensor {weight_key:?}"),
            });
        }
    };
    let host = weight.data()?;
    let expected = out_f * in_f;
    if host.len() != expected {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuUNet2DConditional: tensor {weight_key:?} length {} != expected {}",
                host.len(),
                expected
            ),
        });
    }

    // Emit the transposed matrix in destination order: for every input
    // column, gather that column across all output rows.
    let mut transposed = Vec::with_capacity(expected);
    for col in 0..in_f {
        transposed.extend((0..out_f).map(|row| host[row * in_f + col]));
    }
    let weight_t = cpu_to_gpu(&transposed, device).map_err(gpu_err)?;

    let bias = bias
        .then(|| pop_tensor(state, &format!("{prefix}.bias"), out_f, device))
        .transpose()?;

    Ok(GpuLinearT {
        weight_t,
        bias,
        in_features: in_f,
        out_features: out_f,
    })
}
/// Pops every tensor of one time-conditioned resnet block under `prefix`.
///
/// Sub-layers are consumed in the order `norm1`, `conv1`, `time_emb_proj`,
/// `norm2`, `conv2`, and — only when `in_c != out_c` — `conv_shortcut` (1x1).
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
#[allow(clippy::too_many_arguments)]
fn pop_resnet_time(
    state: &mut StateDict<f32>,
    prefix: &str,
    in_c: usize,
    out_c: usize,
    temb_channels: usize,
    groups: usize,
    eps: f32,
    device: &GpuDevice,
) -> FerrotorchResult<GpuResnetTime> {
    let key = |suffix: &str| format!("{prefix}.{suffix}");
    let three = (3, 3);
    let one = (1, 1);

    let norm1 = pop_groupnorm(state, &key("norm1"), groups, in_c, eps, device)?;
    let conv1 = pop_conv(state, &key("conv1"), in_c, out_c, three, one, one, true, device)?;
    let time_emb_proj =
        pop_linear(state, &key("time_emb_proj"), temb_channels, out_c, true, device)?;
    let norm2 = pop_groupnorm(state, &key("norm2"), groups, out_c, eps, device)?;
    let conv2 = pop_conv(state, &key("conv2"), out_c, out_c, three, one, one, true, device)?;

    // A projection on the residual path is only needed when the width changes.
    let conv_shortcut = if in_c != out_c {
        let shortcut = pop_conv(
            state,
            &key("conv_shortcut"),
            in_c,
            out_c,
            one,
            one,
            (0, 0),
            true,
            device,
        )?;
        Some(shortcut)
    } else {
        None
    };

    Ok(GpuResnetTime {
        norm1,
        conv1,
        time_emb_proj,
        norm2,
        conv2,
        conv_shortcut,
        in_channels: in_c,
        out_channels: out_c,
    })
}
/// Pops the four projections of an attention layer under `prefix`.
///
/// Keys and values are projected from `cross_attention_dim` when it is given,
/// otherwise from `query_dim` (self-attention). `to_q`/`to_k`/`to_v` carry no
/// bias; `to_out.0` does.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_attention(
    state: &mut StateDict<f32>,
    prefix: &str,
    query_dim: usize,
    cross_attention_dim: Option<usize>,
    heads: usize,
    dim_head: usize,
    device: &GpuDevice,
) -> FerrotorchResult<GpuAttention> {
    let inner_dim = heads * dim_head;
    let kv_dim = match cross_attention_dim {
        Some(context_dim) => context_dim,
        None => query_dim,
    };
    let key = |suffix: &str| format!("{prefix}.{suffix}");

    let to_q = pop_linear(state, &key("to_q"), query_dim, inner_dim, false, device)?;
    let to_k = pop_linear(state, &key("to_k"), kv_dim, inner_dim, false, device)?;
    let to_v = pop_linear(state, &key("to_v"), kv_dim, inner_dim, false, device)?;
    let to_out_0 = pop_linear(state, &key("to_out.0"), inner_dim, query_dim, true, device)?;

    Ok(GpuAttention {
        to_q,
        to_k,
        to_v,
        to_out_0,
        heads,
        dim_head,
        inner_dim,
    })
}
/// Pops the two linear layers of a gated feed-forward block under `prefix`.
///
/// `net.0.proj` maps `dim -> 2 * dim * mult`; `net.2` maps `dim * mult -> dim`.
/// Both carry a bias.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_feedforward_geglu(
    state: &mut StateDict<f32>,
    prefix: &str,
    dim: usize,
    mult: usize,
    device: &GpuDevice,
) -> FerrotorchResult<GpuFeedForwardGEGLU> {
    let dim_ff = dim * mult;
    let proj_key = format!("{prefix}.net.0.proj");
    let net_0_proj = pop_linear(state, &proj_key, dim, 2 * dim_ff, true, device)?;
    let out_key = format!("{prefix}.net.2");
    let net_2 = pop_linear(state, &out_key, dim_ff, dim, true, device)?;
    let block = GpuFeedForwardGEGLU {
        net_0_proj,
        net_2,
        dim,
        dim_ff,
    };
    Ok(block)
}
/// Pops one transformer block under `prefix`.
///
/// Sub-layers are consumed in the order `norm1`, `attn1` (self-attention),
/// `norm2`, `attn2` (cross-attention over `cross_dim`), `norm3`, `ff`.
/// Layer norms use epsilon `1e-5`; the feed-forward multiplier is 4.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_basic_transformer_block(
    state: &mut StateDict<f32>,
    prefix: &str,
    dim: usize,
    heads: usize,
    dim_head: usize,
    cross_dim: usize,
    device: &GpuDevice,
) -> FerrotorchResult<GpuBasicTransformerBlock> {
    let ln_eps = 1e-5_f32;
    let ff_mult = 4;
    let key = |suffix: &str| format!("{prefix}.{suffix}");

    let norm1 = pop_layernorm(state, &key("norm1"), dim, ln_eps, device)?;
    let attn1 = pop_attention(state, &key("attn1"), dim, None, heads, dim_head, device)?;
    let norm2 = pop_layernorm(state, &key("norm2"), dim, ln_eps, device)?;
    let attn2 = pop_attention(
        state,
        &key("attn2"),
        dim,
        Some(cross_dim),
        heads,
        dim_head,
        device,
    )?;
    let norm3 = pop_layernorm(state, &key("norm3"), dim, ln_eps, device)?;
    let ff = pop_feedforward_geglu(state, &key("ff"), dim, ff_mult, device)?;

    Ok(GpuBasicTransformerBlock {
        norm1,
        attn1,
        norm2,
        attn2,
        norm3,
        ff,
        dim,
    })
}
/// Pops a spatial transformer under `prefix`.
///
/// Consumes, in order: the input group norm, the 1x1 `proj_in`
/// (`in_channels -> heads * dim_head`), the 1x1 `proj_out` (back to
/// `in_channels`), and `num_layers` entries of `transformer_blocks.{j}`.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
#[allow(clippy::too_many_arguments)]
fn pop_transformer_2d(
    state: &mut StateDict<f32>,
    prefix: &str,
    in_channels: usize,
    heads: usize,
    dim_head: usize,
    num_layers: usize,
    cross_dim: usize,
    groups: usize,
    eps: f32,
    device: &GpuDevice,
) -> FerrotorchResult<GpuTransformer2D> {
    let inner_dim = heads * dim_head;
    let pointwise = (1, 1);
    let no_pad = (0, 0);

    let norm_key = format!("{prefix}.norm");
    let norm = pop_groupnorm(state, &norm_key, groups, in_channels, eps, device)?;

    let proj_in_key = format!("{prefix}.proj_in");
    let proj_in = pop_conv(
        state,
        &proj_in_key,
        in_channels,
        inner_dim,
        pointwise,
        pointwise,
        no_pad,
        true,
        device,
    )?;

    let proj_out_key = format!("{prefix}.proj_out");
    let proj_out = pop_conv(
        state,
        &proj_out_key,
        inner_dim,
        in_channels,
        pointwise,
        pointwise,
        no_pad,
        true,
        device,
    )?;

    let mut blocks = Vec::with_capacity(num_layers);
    for layer in 0..num_layers {
        let block_key = format!("{prefix}.transformer_blocks.{layer}");
        let block = pop_basic_transformer_block(
            state, &block_key, inner_dim, heads, dim_head, cross_dim, device,
        )?;
        blocks.push(block);
    }

    Ok(GpuTransformer2D {
        norm,
        proj_in,
        blocks,
        proj_out,
        channels: in_channels,
        inner_dim,
    })
}
/// Pops the `{prefix}.conv` layer of an upsampler: 3x3, stride 1, padding 1,
/// `channels -> channels`.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_upsample(
    state: &mut StateDict<f32>,
    prefix: &str,
    channels: usize,
    device: &GpuDevice,
) -> FerrotorchResult<GpuUpsample2D> {
    let conv_key = format!("{prefix}.conv");
    let unit = (1, 1);
    pop_conv(state, &conv_key, channels, channels, (3, 3), unit, unit, true, device)
        .map(|conv| GpuUpsample2D { conv, channels })
}
/// Pops the `{prefix}.conv` layer of a downsampler: 3x3, stride 2, padding 1,
/// `channels -> channels`.
///
/// # Errors
/// Missing / mis-sized tensors or upload failures.
fn pop_downsample(
    state: &mut StateDict<f32>,
    prefix: &str,
    channels: usize,
    device: &GpuDevice,
) -> FerrotorchResult<GpuDownsample2D> {
    let conv_key = format!("{prefix}.conv");
    pop_conv(state, &conv_key, channels, channels, (3, 3), (2, 2), (1, 1), true, device)
        .map(|conv| GpuDownsample2D { conv, channels })
}
/// Applies convolution `c` to `x` (described by `shape`) and returns the
/// output buffer with the shape reported by `gpu_conv2d_f32`.
///
/// The call always passes the bias, `(1, 1)` and `1` for the two arguments
/// after the padding (presumably dilation and group count — confirm against
/// `gpu_conv2d_f32`).
///
/// # Errors
/// Kernel failures, mapped through `gpu_err`.
fn conv_forward(
    c: &GpuConv2d,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let weight_shape = [c.out_channels, c.in_channels, c.kernel.0, c.kernel.1];
    gpu_conv2d_f32(
        x,
        &c.weight,
        Some(&c.bias),
        shape,
        weight_shape,
        c.stride,
        c.padding,
        (1, 1),
        1,
        device,
    )
    .map_err(gpu_err)
}
/// Applies group norm `g` to a `[B, C, H, W]` buffer.
///
/// # Errors
/// `ShapeMismatch` when `C` differs from `g.num_channels`; kernel failures,
/// mapped through `gpu_err`.
fn group_norm_forward(
    g: &GpuGroupNorm,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let [batch, channels, height, width] = shape;
    if channels == g.num_channels {
        let spatial = height * width;
        return gpu_group_norm_f32(
            x,
            &g.weight,
            &g.bias,
            batch,
            channels,
            g.num_groups,
            spatial,
            g.eps,
            device,
        )
        .map_err(gpu_err);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "GpuUNet2DConditional::group_norm: expected C={}, got {}",
            g.num_channels, channels
        ),
    })
}
/// Applies layer norm `ln` to a `rows x cols` buffer.
///
/// # Errors
/// `ShapeMismatch` when `cols` differs from `ln.normalized_shape`; kernel
/// failures, mapped through `gpu_err`.
fn layer_norm_forward(
    ln: &GpuLayerNorm,
    x: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    if cols == ln.normalized_shape {
        let normed = gpu_layernorm(x, &ln.weight, &ln.bias, rows, cols, ln.eps, device);
        return normed.map_err(gpu_err);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "GpuUNet2DConditional::layer_norm: expected cols={}, got {}",
            ln.normalized_shape, cols
        ),
    })
}
/// Computes `x @ weight_t (+ bias)` for an input of `m` rows.
///
/// The bias, when present, is broadcast from `[1, out_features]` over the
/// `[m, out_features]` product.
///
/// # Errors
/// Kernel failures, mapped through `gpu_err`.
fn linear_forward(
    lin: &GpuLinearT,
    x: &CudaBuffer<f32>,
    m: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let width = lin.out_features;
    let product = gpu_matmul_f32(x, &lin.weight_t, m, lin.in_features, width, device)
        .map_err(gpu_err)?;
    match &lin.bias {
        None => Ok(product),
        Some(bias) => gpu_broadcast_add(
            &product,
            bias,
            &[m, width],
            &[1, width],
            &[m, width],
            device,
        )
        .map_err(gpu_err),
    }
}
/// Runs one time-conditioned resnet block.
///
/// Sequence: norm1 -> SiLU -> conv1, add the projected SiLU(temb) per
/// channel, norm2 -> SiLU -> conv2, then add the residual (`x` itself, or
/// `conv_shortcut(x)` when the block has one).
///
/// # Errors
/// `ShapeMismatch` when the input channel count differs from
/// `r.in_channels`; kernel failures, mapped through `gpu_err`.
fn resnet_time_forward(
    r: &GpuResnetTime,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    temb: &CudaBuffer<f32>,
    b: usize,
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let channels_in = shape[1];
    if channels_in != r.in_channels {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuUNet2DConditional::resnet_time: expected C_in={}, got {}",
                r.in_channels, channels_in
            ),
        });
    }

    // Buffers are reassigned in place so each intermediate is released as
    // soon as its successor exists.
    let mut pre = group_norm_forward(&r.norm1, x, shape, device)?;
    pre = gpu_silu(&pre, device).map_err(gpu_err)?;
    let (mut act, mut act_shape) = conv_forward(&r.conv1, &pre, shape, device)?;

    // Time embedding: SiLU -> linear, broadcast over the spatial axis.
    let temb_act = gpu_silu(temb, device).map_err(gpu_err)?;
    let temb_proj = linear_forward(&r.time_emb_proj, &temb_act, b, device)?;
    let spatial = act_shape[2] * act_shape[3];
    let out_c = r.out_channels;
    act = gpu_broadcast_add(
        &act,
        &temb_proj,
        &[b, out_c, spatial],
        &[b, out_c, 1],
        &[b, out_c, spatial],
        device,
    )
    .map_err(gpu_err)?;

    act = group_norm_forward(&r.norm2, &act, act_shape, device)?;
    act = gpu_silu(&act, device).map_err(gpu_err)?;
    (act, act_shape) = conv_forward(&r.conv2, &act, act_shape, device)?;

    // Residual connection.
    match &r.conv_shortcut {
        Some(shortcut) => {
            let (projected, _) = conv_forward(shortcut, x, shape, device)?;
            act = gpu_add(&act, &projected, device).map_err(gpu_err)?;
        }
        None => {
            act = gpu_add(&act, x, device).map_err(gpu_err)?;
        }
    }
    Ok((act, act_shape))
}
/// Multi-head scaled dot-product attention.
///
/// `query_buf` holds `b * n` query rows. When `kv_buf` is `Some`, keys and
/// values come from it with sequence length `s`; otherwise they come from
/// `query_buf` with sequence length `n` (self-attention) and `s` is ignored.
/// Scores are scaled by `1 / sqrt(dim_head)` before the softmax.
///
/// # Errors
/// Kernel / transfer failures, mapped through `gpu_err`.
fn attention_forward(
    a: &GpuAttention,
    query_buf: &CudaBuffer<f32>,
    b: usize,
    n: usize,
    kv_buf: Option<&CudaBuffer<f32>>,
    s: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let heads = a.heads;
    let head_dim = a.dim_head;
    let (context, kv_len) = match kv_buf {
        Some(external) => (external, s),
        None => (query_buf, n),
    };
    let groups = b * heads;

    // Projections.
    let q_flat = linear_forward(&a.to_q, query_buf, b * n, device)?;
    let k_flat = linear_forward(&a.to_k, context, b * kv_len, device)?;
    let v_flat = linear_forward(&a.to_v, context, b * kv_len, device)?;

    // Split heads: [B, L, H, D] -> [B, H, L, D].
    let q_heads = reshape_bnhd_to_bhnd(&q_flat, b, n, heads, head_dim, device)?;
    let k_heads = reshape_bnhd_to_bhnd(&k_flat, b, kv_len, heads, head_dim, device)?;
    let v_heads = reshape_bnhd_to_bhnd(&v_flat, b, kv_len, heads, head_dim, device)?;

    // softmax(Q K^T / sqrt(D)) V
    let k_transposed = transpose_last_two(&k_heads, groups, kv_len, head_dim, device)?;
    let raw_scores = gpu_bmm_f32(&q_heads, &k_transposed, groups, n, head_dim, kv_len, device)
        .map_err(gpu_err)?;
    let inv_sqrt_dim = (head_dim as f64).sqrt().recip() as f32;
    let scaled_scores = gpu_scale(&raw_scores, inv_sqrt_dim, device).map_err(gpu_err)?;
    let weights = gpu_softmax(&scaled_scores, groups * n, kv_len, device).map_err(gpu_err)?;
    let context_out = gpu_bmm_f32(&weights, &v_heads, groups, n, kv_len, head_dim, device)
        .map_err(gpu_err)?;

    // Merge heads and project out.
    let merged = reshape_bhnd_to_bnhd(&context_out, b, n, heads, head_dim, device)?;
    linear_forward(&a.to_out_0, &merged, b * n, device)
}
/// Runs one transformer block on `b * n` tokens of width `blk.dim`.
///
/// Three pre-norm residual stages: self-attention, cross-attention over
/// `ehs` (sequence length `s_text`), and the gated feed-forward network.
///
/// `cross_dim` is accepted for signature parity with the callers but is not
/// needed here: the cross-attention projections already carry their input
/// width.
///
/// # Errors
/// `ShapeMismatch` from the layer norms; kernel / transfer failures, mapped
/// through `gpu_err`.
#[allow(clippy::too_many_arguments)]
fn basic_transformer_block_forward(
    blk: &GpuBasicTransformerBlock,
    x: &CudaBuffer<f32>,
    b: usize,
    n: usize,
    ehs: &CudaBuffer<f32>,
    s_text: usize,
    cross_dim: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    // Deliberately unused; keeps the parameter without an unused-variable warning.
    let _ = cross_dim;
    let dim = blk.dim;
    let rows = b * n;

    // Self-attention.
    let normed1 = layer_norm_forward(&blk.norm1, x, rows, dim, device)?;
    let attn1_out = attention_forward(&blk.attn1, &normed1, b, n, None, n, device)?;
    let x1 = gpu_add(x, &attn1_out, device).map_err(gpu_err)?;

    // Cross-attention over the encoder hidden states.
    let normed2 = layer_norm_forward(&blk.norm2, &x1, rows, dim, device)?;
    let attn2_out = attention_forward(&blk.attn2, &normed2, b, n, Some(ehs), s_text, device)?;
    let x2 = gpu_add(&x1, &attn2_out, device).map_err(gpu_err)?;

    // Feed-forward.
    let normed3 = layer_norm_forward(&blk.norm3, &x2, rows, dim, device)?;
    let ff_out = ff_geglu_forward(&blk.ff, &normed3, rows, device)?;
    gpu_add(&x2, &ff_out, device).map_err(gpu_err)
}
/// Gated feed-forward: `net_2( value * gelu_erf(gate) )`.
///
/// `net_0_proj` yields `2 * dim_ff` columns per row; the first `dim_ff` are
/// the value, the remaining `dim_ff` the gate. The split is performed on the
/// host (download, slice, upload).
///
/// # Errors
/// Kernel / transfer failures, mapped through `gpu_err`.
fn ff_geglu_forward(
    ff: &GpuFeedForwardGEGLU,
    x: &CudaBuffer<f32>,
    m: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let projected = linear_forward(&ff.net_0_proj, x, m, device)?;
    let host = gpu_to_cpu(&projected, device).map_err(gpu_err)?;

    let width = ff.dim_ff;
    let mut values = Vec::with_capacity(m * width);
    let mut gates = Vec::with_capacity(m * width);
    for row in 0..m {
        let start = row * 2 * width;
        let middle = start + width;
        values.extend_from_slice(&host[start..middle]);
        gates.extend_from_slice(&host[middle..middle + width]);
    }

    let values_gpu = cpu_to_gpu(&values, device).map_err(gpu_err)?;
    let gates_gpu = cpu_to_gpu(&gates, device).map_err(gpu_err)?;
    let gates_act = gpu_gelu_erf(&gates_gpu, device).map_err(gpu_err)?;
    let gated =
        ferrotorch_gpu::kernels::gpu_mul(&values_gpu, &gates_act, device).map_err(gpu_err)?;
    linear_forward(&ff.net_2, &gated, m, device)
}
/// Runs a spatial transformer on a `[B, C, H, W]` feature map.
///
/// Sequence: group norm -> 1x1 `proj_in` -> flatten to `[B, H*W, inner_dim]`
/// tokens -> transformer blocks -> back to `[B, inner_dim, H, W]` -> 1x1
/// `proj_out` -> add the original input `x`.
///
/// The returned shape is the input `shape`: the result is `proj_out(..) + x`,
/// an element-wise sum with `x`, so it has `C` channels even when
/// `inner_dim != C`.
///
/// # Errors
/// `ShapeMismatch` when `C` differs from `t.channels` or `b_ehs` differs from
/// the batch in `shape`; kernel / transfer failures, mapped through `gpu_err`.
#[allow(clippy::too_many_arguments)]
fn transformer_2d_forward(
    t: &GpuTransformer2D,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    ehs: &CudaBuffer<f32>,
    b_ehs: usize,
    s_text: usize,
    cross_dim: usize,
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let [b, c, h, w] = shape;
    if c != t.channels {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuUNet2DConditional::transformer_2d: expected C={}, got {}",
                t.channels, c
            ),
        });
    }
    if b != b_ehs {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuUNet2DConditional::transformer_2d: batch mismatch sample B={b} vs ehs B={b_ehs}"
            ),
        });
    }
    let hw = h * w;
    let inner = t.inner_dim;
    let normed = group_norm_forward(&t.norm, x, shape, device)?;
    let (proj_in_buf, _) = conv_forward(&t.proj_in, &normed, shape, device)?;
    let mut hidden_seq = transpose_bchw_to_bnc(&proj_in_buf, b, inner, hw, device)?;
    for block in &t.blocks {
        hidden_seq = basic_transformer_block_forward(
            block,
            &hidden_seq,
            b,
            hw,
            ehs,
            s_text,
            cross_dim,
            device,
        )?;
    }
    let hidden_back = transpose_bnc_to_bchw(&hidden_seq, b, inner, hw, device)?;
    let (proj_out_buf, _) =
        conv_forward(&t.proj_out, &hidden_back, [b, inner, h, w], device)?;
    let summed = gpu_add(&proj_out_buf, x, device).map_err(gpu_err)?;
    // `summed` has the layout of `x`, not of the `proj_in` output.
    Ok((summed, shape))
}
/// Doubles the spatial size with nearest-neighbour upsampling, then applies
/// the upsampler's convolution.
///
/// # Errors
/// `ShapeMismatch` when the channel count differs from `u.channels`; kernel
/// failures, mapped through `gpu_err`.
fn upsample_forward(
    u: &GpuUpsample2D,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let [batch, channels, height, width] = shape;
    if channels == u.channels {
        let enlarged = gpu_nearest_upsample2x_f32(x, batch, channels, height, width, device)
            .map_err(gpu_err)?;
        let enlarged_shape = [batch, channels, height * 2, width * 2];
        return conv_forward(&u.conv, &enlarged, enlarged_shape, device);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "GpuUNet2DConditional::upsample: expected C={}, got {}",
            u.channels, channels
        ),
    })
}
/// Applies the downsampler's strided convolution.
///
/// # Errors
/// `ShapeMismatch` when the channel count differs from `d.channels`; kernel
/// failures, mapped through `gpu_err`.
fn downsample_forward(
    d: &GpuDownsample2D,
    x: &CudaBuffer<f32>,
    shape: [usize; 4],
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let channels = shape[1];
    if channels == d.channels {
        return conv_forward(&d.conv, x, shape, device);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "GpuUNet2DConditional::downsample: expected C={}, got {}",
            d.channels, channels
        ),
    })
}
/// Permutes a row-major `[b, n, h, d]` buffer into `[b, h, n, d]`.
///
/// The permutation is done on the host: download, reorder, upload.
///
/// # Errors
/// Transfer failures, mapped through `gpu_err`.
fn reshape_bnhd_to_bhnd(
    x: &CudaBuffer<f32>,
    b: usize,
    n: usize,
    h: usize,
    d: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(x, device).map_err(gpu_err)?;
    let mut permuted = vec![0.0_f32; b * h * n * d];
    for batch in 0..b {
        for head in 0..h {
            for token in 0..n {
                // The innermost `d` values stay contiguous; move them as one run.
                let src = ((batch * n + token) * h + head) * d;
                let dst = ((batch * h + head) * n + token) * d;
                permuted[dst..dst + d].copy_from_slice(&host[src..src + d]);
            }
        }
    }
    cpu_to_gpu(&permuted, device).map_err(gpu_err)
}
/// Permutes a row-major `[b, h, n, d]` buffer into `[b, n, h, d]`.
///
/// The permutation is done on the host: download, reorder, upload.
///
/// # Errors
/// Transfer failures, mapped through `gpu_err`.
fn reshape_bhnd_to_bnhd(
    x: &CudaBuffer<f32>,
    b: usize,
    n: usize,
    h: usize,
    d: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(x, device).map_err(gpu_err)?;
    let mut permuted = vec![0.0_f32; b * n * h * d];
    for batch in 0..b {
        for token in 0..n {
            for head in 0..h {
                // The innermost `d` values stay contiguous; move them as one run.
                let src = ((batch * h + head) * n + token) * d;
                let dst = ((batch * n + token) * h + head) * d;
                permuted[dst..dst + d].copy_from_slice(&host[src..src + d]);
            }
        }
    }
    cpu_to_gpu(&permuted, device).map_err(gpu_err)
}
/// Transposes the last two axes of a row-major `[batch, m, n]` buffer,
/// producing `[batch, n, m]`.
///
/// The transpose is done on the host: download, reorder, upload.
///
/// # Errors
/// Transfer failures, mapped through `gpu_err`.
fn transpose_last_two(
    x: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(x, device).map_err(gpu_err)?;
    // Fill the output in destination order so it can be built by appending.
    let mut transposed = Vec::with_capacity(batch * n * m);
    for matrix in 0..batch {
        let base = matrix * m * n;
        for col in 0..n {
            transposed.extend((0..m).map(|row| host[base + row * n + col]));
        }
    }
    cpu_to_gpu(&transposed, device).map_err(gpu_err)
}
/// Rearranges a row-major `[b, c, hw]` buffer into `[b, hw, c]` order,
/// where `hw` is the flattened spatial extent supplied by the caller.
///
/// The data is fetched with `gpu_to_cpu`, rearranged on the host, and sent
/// back with `cpu_to_gpu`.
///
/// # Errors
/// Returns the `gpu_err`-mapped error of either transfer.
///
/// # Panics
/// Panics on indexing if the downloaded data holds fewer than
/// `b * c * hw` elements.
fn transpose_bchw_to_bnc(
    x: &CudaBuffer<f32>,
    b: usize,
    c: usize,
    hw: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let source = gpu_to_cpu(x, device).map_err(gpu_err)?;
    // Enumerate destination positions (batch, pixel, channel) in order and
    // gather each value from its channel-major source offset.
    let rearranged: Vec<f32> = (0..b)
        .flat_map(move |batch_idx| {
            (0..hw).flat_map(move |pixel| {
                (0..c).map(move |channel| (batch_idx * c + channel) * hw + pixel)
            })
        })
        .map(|from| source[from])
        .collect();
    cpu_to_gpu(&rearranged, device).map_err(gpu_err)
}
/// Rearranges a row-major `[b, hw, c]` buffer into `[b, c, hw]` order
/// (the inverse of `transpose_bchw_to_bnc`).
///
/// The data is fetched with `gpu_to_cpu`, rearranged on the host, and sent
/// back with `cpu_to_gpu`.
///
/// # Errors
/// Returns the `gpu_err`-mapped error of either transfer.
///
/// # Panics
/// Panics on indexing if the downloaded data holds fewer than
/// `b * hw * c` elements.
fn transpose_bnc_to_bchw(
    x: &CudaBuffer<f32>,
    b: usize,
    c: usize,
    hw: usize,
    device: &GpuDevice,
) -> FerrotorchResult<CudaBuffer<f32>> {
    let source = gpu_to_cpu(x, device).map_err(gpu_err)?;
    let mut rearranged: Vec<f32> = Vec::with_capacity(b * c * hw);
    // Fill the channel-major destination sequentially, reading each value
    // from its pixel-major location in the source.
    for batch_idx in 0..b {
        for channel in 0..c {
            for pixel in 0..hw {
                rearranged.push(source[(batch_idx * hw + pixel) * c + channel]);
            }
        }
    }
    cpu_to_gpu(&rearranged, device).map_err(gpu_err)
}
/// Concatenates two `[batch, channels, height, width]` buffers along the
/// channel axis, with `a`'s channels first.
///
/// Both inputs must agree on batch, height and width. Both are fetched with
/// `gpu_to_cpu`, interleaved per batch on the host, and the result is sent
/// back with `cpu_to_gpu`.
///
/// Returns the concatenated buffer and its shape `[batch, ca + cb, h, w]`.
///
/// # Errors
/// * `FerrotorchError::ShapeMismatch` when batch, height or width differ.
/// * The `gpu_err`-mapped error of any transfer.
///
/// # Panics
/// Panics on slice indexing if either downloaded buffer is shorter than its
/// declared shape implies.
fn cat_channels(
    a: &CudaBuffer<f32>,
    a_shape: [usize; 4],
    b: &CudaBuffer<f32>,
    b_shape: [usize; 4],
    device: &GpuDevice,
) -> FerrotorchResult<(CudaBuffer<f32>, [usize; 4])> {
    let [batch, ca, height, width] = a_shape;
    let [batch_b, cb, height_b, width_b] = b_shape;
    if (batch, height, width) != (batch_b, height_b, width_b) {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "GpuUNet2DConditional::cat_channels: shape disagree {a_shape:?} vs {b_shape:?}"
            ),
        });
    }

    let a_host = gpu_to_cpu(a, device).map_err(gpu_err)?;
    let b_host = gpu_to_cpu(b, device).map_err(gpu_err)?;

    // Per batch item, all channels of one input form a single contiguous
    // block, so the output is that block from `a` followed by the one from `b`.
    let a_block = ca * height * width;
    let b_block = cb * height * width;
    let mut merged: Vec<f32> = Vec::with_capacity(batch * (a_block + b_block));
    for item in 0..batch {
        merged.extend_from_slice(&a_host[item * a_block..(item + 1) * a_block]);
        merged.extend_from_slice(&b_host[item * b_block..(item + 1) * b_block]);
    }

    let merged_gpu = cpu_to_gpu(&merged, device).map_err(gpu_err)?;
    Ok((merged_gpu, [batch, ca + cb, height, width]))
}
#[cfg(all(test, feature = "cuda"))]
/// CPU-vs-GPU parity test for the conditional UNet forward pass.
mod tests {
    use super::*;
    use crate::unet::UNet2DConditionModel;

    /// Builds a small four-stage configuration (channels 16/32/64/64, one
    /// layer per block) so the parity test stays cheap.
    fn tiny_cfg() -> UNet2DConditionConfig {
        UNet2DConditionConfig {
            in_channels: 4,
            out_channels: 4,
            block_out_channels: vec![16, 32, 64, 64],
            layers_per_block: 1,
            attention_head_dim: 8,
            cross_attention_dim: 24,
            norm_num_groups: 4,
            sample_size: 8,
            flip_sin_to_cos: true,
            freq_shift: 0.0,
            transformer_layers_per_block: 1,
            down_block_has_attn: vec![true, true, true, false],
            up_block_has_attn: vec![false, true, true, true],
        }
    }

    /// Runs the same deterministic inputs through the CPU model and its GPU
    /// counterpart and requires the outputs to agree within `1e-3`.
    #[test]
    fn gpu_unet_matches_cpu_tiny() {
        // Without a usable CUDA device 0 there is nothing to compare; the
        // test returns early (and therefore passes) in that case.
        let Ok(device) = GpuDevice::new(0) else {
            return;
        };
        let cfg = tiny_cfg();
        let cpu = UNet2DConditionModel::<f32>::new(cfg.clone()).unwrap();
        let (gpu, report) = GpuUNet2DConditional::from_module(&cpu, &device).unwrap();
        assert!(
            report.dropped.is_empty(),
            "unexpected dropped keys: {:?}",
            report.dropped
        );

        let b = 1usize;
        let h_in = 8usize;
        let w_in = 8usize;

        // Deterministic, small-magnitude latent input.
        let sample_data: Vec<f32> = (0..b * cfg.in_channels * h_in * w_in)
            .map(|i| ((i % 7) as f32) * 0.03 - 0.05)
            .collect();
        let sample = Tensor::from_storage(
            TensorStorage::cpu(sample_data),
            vec![b, cfg.in_channels, h_in, w_in],
            false,
        )
        .unwrap();

        let timesteps =
            Tensor::from_storage(TensorStorage::cpu(vec![5.0f32]), vec![b], false).unwrap();

        // Deterministic conditioning input of shape [b, s, cross_attention_dim].
        let s = 7usize;
        let ehs_data: Vec<f32> = (0..b * s * cfg.cross_attention_dim)
            .map(|i| ((i % 11) as f32) * 0.02 - 0.07)
            .collect();
        let ehs = Tensor::from_storage(
            TensorStorage::cpu(ehs_data),
            vec![b, s, cfg.cross_attention_dim],
            false,
        )
        .unwrap();

        // Fix: the timestep argument had been mangled into `×teps` (an HTML
        // entity decode of `&timesteps`), which is not valid Rust.
        let cpu_out = cpu.forward_t(&sample, &timesteps, &ehs).unwrap();
        let gpu_out = gpu.forward(&sample, &timesteps, &ehs).unwrap();
        assert_eq!(cpu_out.shape(), gpu_out.shape());

        let cpu_data = cpu_out.data().unwrap();
        let gpu_data = gpu_out.data().unwrap();
        // Largest element-wise absolute difference between the two outputs.
        let max_abs = cpu_data
            .iter()
            .zip(gpu_data.iter())
            .map(|(a, c)| (a - c).abs())
            .fold(0.0_f32, f32::max);
        assert!(max_abs < 1e-3, "gpu vs cpu tiny UNet max_abs = {max_abs}");
    }
}