use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::clip_text_encoder::ClipTextEncoder;
use crate::scheduler::DDIMScheduler;
use crate::unet::UNet2DConditionModel;
use crate::vae::VaeDecoder;
/// Per-step record of the intermediate tensors produced while
/// [`StableDiffusionPipeline::generate`] walks the scheduler's timesteps.
///
/// One value is pushed for every timestep visited, in visiting order.
#[derive(Debug)]
pub struct PipelineStepDump<T>
where
    T: Float,
{
    /// Zero-based position of this step in the timestep sequence.
    pub step: usize,
    /// Scheduler timestep value that was evaluated at this step.
    pub timestep: usize,
    /// UNet output obtained with the unconditional embeddings.
    pub noise_pred_uncond: Tensor<T>,
    /// UNet output obtained with the conditional embeddings.
    pub noise_pred_cond: Tensor<T>,
    /// Guided prediction: `uncond + guidance_scale * (cond - uncond)`.
    pub guided_noise: Tensor<T>,
    /// Latent returned by the scheduler's `step` call for this timestep.
    pub latent_after_step: Tensor<T>,
}
/// Owns the four components that
/// [`generate`](StableDiffusionPipeline::generate) and
/// [`encode_prompt`](StableDiffusionPipeline::encode_prompt) drive: a CLIP
/// text encoder, a conditional UNet, a VAE decoder and a DDIM scheduler.
#[derive(Debug)]
pub struct StableDiffusionPipeline<T>
where
    T: Float,
{
    /// Turns token ids into prompt embeddings (see `encode_prompt`).
    pub text_encoder: ClipTextEncoder<T>,
    /// Noise-prediction network, evaluated twice per denoising step.
    pub unet: UNet2DConditionModel<T>,
    /// Decodes the final latent into the returned image tensor.
    pub vae: VaeDecoder<T>,
    /// Supplies the timestep sequence, input scaling and the per-step update.
    pub scheduler: DDIMScheduler,
}
impl<T: Float> StableDiffusionPipeline<T> {
    /// Assembles a pipeline from already-constructed components.
    ///
    /// The components are stored as given; no validation is performed, so
    /// this currently always returns `Ok`.
    pub fn new(
        text_encoder: ClipTextEncoder<T>,
        unet: UNet2DConditionModel<T>,
        vae: VaeDecoder<T>,
        scheduler: DDIMScheduler,
    ) -> FerrotorchResult<Self> {
        let pipeline = Self {
            text_encoder,
            unet,
            vae,
            scheduler,
        };
        Ok(pipeline)
    }

    /// Runs the text encoder over `input_ids` and returns its output.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the text encoder's `forward_from_ids`
    /// returns.
    pub fn encode_prompt(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
        let encoder = &self.text_encoder;
        encoder.forward_from_ids(input_ids)
    }

    /// Builds a 1-D tensor of length `batch` in which every element is
    /// `timestep` converted to `T`.
    ///
    /// The tensor is backed by `TensorStorage::cpu`.
    /// NOTE(review): the trailing `false` passed to `from_storage` is
    /// presumably the `requires_grad` flag — confirm against `Tensor`.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` when `timestep` cannot be
    /// converted to `T`, and propagates any error from `from_storage`.
    fn timestep_tensor(timestep: usize, batch: usize) -> FerrotorchResult<Tensor<T>> {
        let value = match T::from(timestep as f64) {
            Some(value) => value,
            None => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "StableDiffusionPipeline: cannot represent timestep {timestep} as the active Float"
                    ),
                });
            }
        };
        let data = vec![value; batch];
        let shape = vec![batch];
        Tensor::<T>::from_storage(TensorStorage::cpu(data), shape, false)
    }

    /// Evaluates the UNet twice for one timestep and blends the results.
    ///
    /// The scheduler first scales `latent` for `timestep`; the UNet is then
    /// run on that input with `uncond_embeds` and afterwards with
    /// `cond_embeds`. The guided prediction is
    /// `uncond + guidance_scale * (cond - uncond)`.
    ///
    /// Returns `(noise_uncond, noise_cond, guided)`.
    ///
    /// Indexes `latent.shape()[0]` for the batch size, so `latent` must have
    /// at least one dimension; `generate` checks the rank before calling.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` when `guidance_scale`
    /// cannot be converted to `T`, and propagates errors from the scheduler,
    /// the UNet and the arithmetic ops.
    fn cfg_eval(
        &self,
        latent: &Tensor<T>,
        timestep: usize,
        cond_embeds: &Tensor<T>,
        uncond_embeds: &Tensor<T>,
        guidance_scale: f32,
    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>, Tensor<T>)> {
        let batch_size = latent.shape()[0];
        let timestep_batch = Self::timestep_tensor(timestep, batch_size)?;
        let unet_input = self.scheduler.scale_model_input(latent, timestep)?;

        // Both passes share the same scaled input and timestep tensor; only
        // the text embeddings differ. Unconditional runs first.
        let predict =
            |embeds: &Tensor<T>| self.unet.forward_t(&unet_input, &timestep_batch, embeds);
        let noise_uncond = predict(uncond_embeds)?;
        let noise_cond = predict(cond_embeds)?;

        let scale = match T::from(guidance_scale as f64) {
            Some(scale) => scale,
            None => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "StableDiffusionPipeline: cannot represent guidance_scale {guidance_scale} as the active Float"
                    ),
                });
            }
        };
        let scale_tensor = ferrotorch_core::scalar::<T>(scale)?;

        // guided = uncond + scale * (cond - uncond)
        let delta = sub(&noise_cond, &noise_uncond)?;
        let weighted_delta = mul(&delta, &scale_tensor)?;
        let guided = add(&noise_uncond, &weighted_delta)?;

        Ok((noise_uncond, noise_cond, guided))
    }

    /// Runs the full denoising loop and decodes the final latent.
    ///
    /// Steps, in order:
    /// 1. Validate that `init_latent` is rank 4 and that `cond_embeds` and
    ///    `uncond_embeds` have equal shapes.
    /// 2. Ask the scheduler for `num_inference_steps` timesteps.
    /// 3. Multiply `init_latent` by the scheduler's `init_noise_sigma`.
    /// 4. For each timestep, compute the guided noise via `cfg_eval`, pass it
    ///    to the scheduler's `step`, and record a [`PipelineStepDump`].
    /// 5. Decode the last latent with the VAE's `decode_with_scaling`.
    ///
    /// Returns the decoded image together with one dump per visited
    /// timestep.
    ///
    /// NOTE(review): the error text describes `init_latent` as
    /// `[B, 4, H, W]`, but only the rank is checked here.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::ShapeMismatch` for the two validation
    /// failures above, `FerrotorchError::InvalidArgument` when the initial
    /// noise sigma cannot be converted to `T`, and propagates any error from
    /// the scheduler, `cfg_eval`, the arithmetic ops or the VAE.
    pub fn generate(
        &mut self,
        cond_embeds: &Tensor<T>,
        uncond_embeds: &Tensor<T>,
        init_latent: &Tensor<T>,
        num_inference_steps: usize,
        guidance_scale: f32,
    ) -> FerrotorchResult<(Tensor<T>, Vec<PipelineStepDump<T>>)> {
        let rank = init_latent.ndim();
        if rank != 4 {
            let message = format!(
                "StableDiffusionPipeline::generate: expected init_latent [B, 4, H, W], got {:?}",
                init_latent.shape()
            );
            return Err(FerrotorchError::ShapeMismatch { message });
        }

        let (cond_shape, uncond_shape) = (cond_embeds.shape(), uncond_embeds.shape());
        if cond_shape != uncond_shape {
            let message = format!(
                "StableDiffusionPipeline::generate: cond_embeds shape {:?} != uncond_embeds {:?}",
                cond_shape, uncond_shape
            );
            return Err(FerrotorchError::ShapeMismatch { message });
        }

        // Copy the timesteps out so the scheduler is free to be used again
        // inside the loop.
        let schedule: Vec<usize> = self.scheduler.set_timesteps(num_inference_steps)?.to_vec();

        let sigma = self.scheduler.init_noise_sigma();
        let sigma_value = match T::from(sigma) {
            Some(sigma_value) => sigma_value,
            None => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "StableDiffusionPipeline: cannot represent init_noise_sigma {sigma} as the active Float"
                    ),
                });
            }
        };
        let sigma_tensor = ferrotorch_core::scalar::<T>(sigma_value)?;
        let mut latent = mul(init_latent, &sigma_tensor)?;

        let mut dumps: Vec<PipelineStepDump<T>> = Vec::with_capacity(num_inference_steps);
        for (step, timestep) in schedule.into_iter().enumerate() {
            let (noise_pred_uncond, noise_pred_cond, guided_noise) =
                self.cfg_eval(&latent, timestep, cond_embeds, uncond_embeds, guidance_scale)?;
            let next_latent = self.scheduler.step(&guided_noise, timestep, &latent)?;

            // The dump keeps its own handle to the new latent; the loop
            // variable takes ownership of the original afterwards.
            dumps.push(PipelineStepDump {
                step,
                timestep,
                noise_pred_uncond,
                noise_pred_cond,
                guided_noise,
                latent_after_step: next_latent.clone(),
            });
            latent = next_latent;
        }

        self.vae
            .decode_with_scaling(&latent)
            .map(|image| (image, dumps))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clip_text_encoder::ClipTextConfig;
    use crate::config::VaeDecoderConfig;
    use crate::scheduler::DDIMConfig;
    use crate::unet::UNet2DConditionModel;
    use crate::unet_config::UNet2DConditionConfig;
    use crate::vae::VaeDecoder;

    /// Builds an `f32` pipeline from the `sd_v1_5` presets, with the UNet and
    /// VAE `sample_size` fields overridden to 8.
    fn build_tiny_pipeline() -> FerrotorchResult<StableDiffusionPipeline<f32>> {
        let text_encoder = ClipTextEncoder::<f32>::new(ClipTextConfig::sd_v1_5())?;

        let mut unet_config = UNet2DConditionConfig::sd_v1_5();
        unet_config.sample_size = 8;
        let unet = UNet2DConditionModel::<f32>::new(unet_config)?;

        let mut vae_config = VaeDecoderConfig::sd_v1_5();
        vae_config.sample_size = 8;
        let vae = VaeDecoder::<f32>::new(vae_config)?;

        let scheduler = DDIMScheduler::new(DDIMConfig::sd_v1_5())?;

        StableDiffusionPipeline::new(text_encoder, unet, vae, scheduler)
    }

    /// Construction of all four components and the pipeline must succeed.
    #[test]
    fn pipeline_constructs() {
        if let Err(e) = build_tiny_pipeline() {
            panic!("pipeline construction unexpectedly failed: {e}");
        }
    }
}