#![cfg(feature = "cuda")]
use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
use ferrotorch_gpu::GpuDevice;
use crate::gpu::clip::GpuClipTextEncoder;
use crate::gpu::unet::GpuUNet2DConditional;
use crate::gpu::vae::GpuVaeDecoder;
use crate::pipeline::PipelineStepDump;
use crate::scheduler::DDIMScheduler;
/// End-to-end Stable Diffusion sampling pipeline assembled from GPU components.
///
/// Bundles a CLIP text encoder, a conditional UNet, a VAE decoder and a DDIM
/// scheduler. This module is only compiled when the `cuda` feature is enabled.
#[derive(Debug)]
pub struct GpuStableDiffusionPipeline {
    /// Text encoder used by `encode_prompt` to turn token ids into embeddings.
    pub text_encoder: GpuClipTextEncoder,
    /// Noise-prediction network; `generate` evaluates it twice per step
    /// (once with unconditional and once with conditional embeddings).
    pub unet: GpuUNet2DConditional,
    /// Decoder applied to the final latent to produce the returned image tensor.
    pub vae: GpuVaeDecoder,
    /// DDIM scheduler supplying the timestep schedule, input scaling and per-step update.
    pub scheduler: DDIMScheduler,
    /// Device handle supplied at construction; not read by any method in this impl.
    // NOTE(review): presumably retained to keep the CUDA device/context alive for
    // the pipeline's lifetime — confirm against GpuDevice's drop semantics.
    _device: GpuDevice,
}
impl GpuStableDiffusionPipeline {
    /// Assembles a pipeline from already-constructed components.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; no validation is performed here.
    pub fn new(
        text_encoder: GpuClipTextEncoder,
        unet: GpuUNet2DConditional,
        vae: GpuVaeDecoder,
        scheduler: DDIMScheduler,
        device: GpuDevice,
    ) -> FerrotorchResult<Self> {
        let pipeline = Self {
            _device: device,
            scheduler,
            vae,
            unet,
            text_encoder,
        };
        Ok(pipeline)
    }

    /// Encodes a tokenized prompt into text embeddings with the CLIP encoder.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`GpuClipTextEncoder::encode`].
    pub fn encode_prompt(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<f32>> {
        let embeddings = self.text_encoder.encode(input_ids)?;
        Ok(embeddings)
    }

    /// Builds a rank-1, CPU-backed tensor of length `batch` whose every element
    /// is `timestep` converted to `f32`; this is the timestep input handed to
    /// the UNet for each sample in the batch.
    fn timestep_tensor(timestep: usize, batch: usize) -> FerrotorchResult<Tensor<f32>> {
        let fill = timestep as f32;
        let storage = TensorStorage::cpu(vec![fill; batch]);
        Tensor::<f32>::from_storage(storage, vec![batch], false)
    }

    /// Performs one classifier-free-guidance evaluation of the UNet.
    ///
    /// The latent is first passed through `DDIMScheduler::scale_model_input`,
    /// then run through the UNet twice with separate calls (no 2B batching):
    /// once with `uncond_embeds`, once with `cond_embeds`. The predictions are
    /// blended as `uncond + guidance_scale * (cond - uncond)`.
    ///
    /// Returns `(unconditional, conditional, guided)` noise predictions.
    fn cfg_eval(
        &self,
        latent: &Tensor<f32>,
        timestep: usize,
        cond_embeds: &Tensor<f32>,
        uncond_embeds: &Tensor<f32>,
        guidance_scale: f32,
    ) -> FerrotorchResult<(Tensor<f32>, Tensor<f32>, Tensor<f32>)> {
        // Leading dimension of the latent is the batch size (generate() checks rank 4).
        let t_vec = Self::timestep_tensor(timestep, latent.shape()[0])?;
        let scaled_latent = self.scheduler.scale_model_input(latent, timestep)?;

        let eps_uncond = self.unet.forward(&scaled_latent, &t_vec, uncond_embeds)?;
        let eps_cond = self.unet.forward(&scaled_latent, &t_vec, cond_embeds)?;

        let scale = ferrotorch_core::scalar::<f32>(guidance_scale)?;
        let delta = sub(&eps_cond, &eps_uncond)?;
        let eps_guided = add(&eps_uncond, &mul(&delta, &scale)?)?;
        Ok((eps_uncond, eps_cond, eps_guided))
    }

    /// Runs the full sampling loop from `init_latent` and decodes the result.
    ///
    /// 1. Checks that `init_latent` is 4-D and that `cond_embeds` and
    ///    `uncond_embeds` have identical shapes.
    /// 2. Configures the scheduler via `set_timesteps(num_inference_steps)` and
    ///    multiplies `init_latent` by the scheduler's `init_noise_sigma()`.
    /// 3. For every scheduler timestep, runs `cfg_eval`, advances the latent with
    ///    `DDIMScheduler::step`, and records a [`PipelineStepDump`] holding both
    ///    raw predictions, the guided prediction and the updated latent.
    /// 4. Decodes the final latent with the VAE.
    ///
    /// Returns the decoded image together with one dump per executed step.
    ///
    /// # Errors
    ///
    /// `FerrotorchError::ShapeMismatch` if `init_latent` is not 4-D or the two
    /// embedding tensors differ in shape; otherwise any error raised by the
    /// scheduler, the UNet, the tensor arithmetic, or the VAE decoder.
    pub fn generate(
        &mut self,
        cond_embeds: &Tensor<f32>,
        uncond_embeds: &Tensor<f32>,
        init_latent: &Tensor<f32>,
        num_inference_steps: usize,
        guidance_scale: f32,
    ) -> FerrotorchResult<(Tensor<f32>, Vec<PipelineStepDump<f32>>)> {
        let latent_rank = init_latent.ndim();
        if latent_rank != 4 {
            let message = format!(
                "GpuStableDiffusionPipeline::generate: expected init_latent \
                 [B, 4, H, W], got {:?}",
                init_latent.shape()
            );
            return Err(FerrotorchError::ShapeMismatch { message });
        }

        let (cond_shape, uncond_shape) = (cond_embeds.shape(), uncond_embeds.shape());
        if cond_shape != uncond_shape {
            let message = format!(
                "GpuStableDiffusionPipeline::generate: cond_embeds shape {:?} \
                 != uncond_embeds {:?}",
                cond_shape, uncond_shape
            );
            return Err(FerrotorchError::ShapeMismatch { message });
        }

        // Copy the schedule out so the scheduler borrow ends before the loop.
        let schedule: Vec<usize> = self.scheduler.set_timesteps(num_inference_steps)?.to_vec();
        let sigma = ferrotorch_core::scalar::<f32>(self.scheduler.init_noise_sigma() as f32)?;
        let mut latent = mul(init_latent, &sigma)?;

        let mut dumps: Vec<PipelineStepDump<f32>> = Vec::with_capacity(num_inference_steps);
        for (step, timestep) in schedule.into_iter().enumerate() {
            let (noise_pred_uncond, noise_pred_cond, guided_noise) =
                self.cfg_eval(&latent, timestep, cond_embeds, uncond_embeds, guidance_scale)?;
            latent = self.scheduler.step(&guided_noise, timestep, &latent)?;
            dumps.push(PipelineStepDump {
                step,
                timestep,
                noise_pred_uncond,
                noise_pred_cond,
                guided_noise,
                latent_after_step: latent.clone(),
            });
        }

        // NOTE(review): the latent goes to the VAE without being divided by a VAE
        // scaling factor (0.18215 for SD 1.x in reference pipelines) — confirm that
        // GpuVaeDecoder::decode applies that scaling internally.
        let image = self.vae.decode(&latent)?;
        Ok((image, dumps))
    }
}