ferrotorch_diffusion/lib.rs
1// Crate-level lint baseline. Mirrors the ferrotorch-whisper / ferrotorch-bert
2// posture: deny unsafe code and idiom / Debug / docs problems; warn on
3// clippy's default and pedantic groups. Specific lints are allowed
4// crate-wide where the lint is consistently wrong for ML/numeric kernel code.
5
6#![deny(unsafe_code)]
7#![deny(rust_2018_idioms)]
8#![deny(missing_debug_implementations)]
9#![deny(missing_docs)]
10#![warn(clippy::all)]
11#![warn(clippy::pedantic)]
12// Casts: dimension math (`as usize`, `as f32`, `as u32`) is intrinsic
13// to tensor indexing — every kernel call would otherwise need a
14// per-call allow.
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_precision_loss)]
17#![allow(clippy::cast_sign_loss)]
18#![allow(clippy::cast_possible_wrap)]
19#![allow(clippy::cast_lossless)]
20// Builder-style accessors don't all need `#[must_use]`.
21#![allow(clippy::must_use_candidate)]
22// Identifiers like `bf16`, `f32`, `VAE`, `SD`, `SiLU` are flagged as
23// missing backticks even when they appear in code-fenced text.
24#![allow(clippy::doc_markdown)]
25// `needless_pass_by_value` would force `&VaeDecoderConfig` signatures
26// throughout, hiding intent in the API.
27#![allow(clippy::needless_pass_by_value)]
28// `unnecessary_wraps` flags `Result`-returning helpers that today
29// always succeed but are part of an extensible API surface.
30#![allow(clippy::unnecessary_wraps)]
31// `uninlined_format_args` flags `format!("x={}", x)` vs
32// `format!("x={x}")`. Both are equally clear; the fixup churn is high.
33#![allow(clippy::uninlined_format_args)]
34// `many_single_char_names` flags conventional ML kernel locals
35// (`q`, `k`, `v`, `h`).
36#![allow(clippy::many_single_char_names)]
37// `similar_names` flags variable pairs that are intentionally similar
38// (e.g. `q2` / `q_h`).
39#![allow(clippy::similar_names)]
40// `module_name_repetitions`: every type starts with `Vae` / `UNet`
41// (matching the HF / diffusers naming) — the lint would force renames
42// that lose the upstream-1:1 mapping.
43#![allow(clippy::module_name_repetitions)]
44// `too_many_lines`: the decoder / UNet forward is one cohesive sequence
45// of ops mirroring the diffusers reference; splitting it hurts
46// cross-reading.
47#![allow(clippy::too_many_lines)]
48// UNet builders take a handful of (in_c, out_c, temb, layers, heads,
49// dim_head, cross_dim, groups, …) parameters — the explicit list is
50// shorter than the struct-of-args alternative for an internal builder.
51#![allow(clippy::too_many_arguments)]
52// `items_after_statements` flags the in-test helper layout used widely.
53#![allow(clippy::items_after_statements)]
54// `redundant_else` flags `if x { return …; } else { … }`; the
55// alternative (`if x { return …; } …`) loses the structural shape.
56#![allow(clippy::redundant_else)]
57// Tensor ops naturally use `for i in 0..n { … }` over `.iter()` when
58// the index itself is used; clippy's preferred form hurts readability.
59#![allow(clippy::needless_range_loop)]
60
61//! Stable-Diffusion model composition for ferrotorch.
62//!
63//! Phase B.3 of real-artifact-driven development. This crate implements
64//! the **VAE decoder** (Phase B.3a) and the **UNet2DConditionModel**
65//! (Phase B.3b) of `runwayml/stable-diffusion-v1-5`. The VAE encoder,
66//! CLIP text encoder, DDIM scheduler, and pipeline are also provided; see
67//! [`vae_encoder`], [`clip_text_encoder`], [`scheduler`], and [`pipeline`].
68//!
69//! ## VAE decoder
70//!
71//! Mirrors `vae/config.json` — `VaeDecoder` inverts a latent
72//! `[B, 4, 64, 64]` into an image `[B, 3, 512, 512]`. See [`vae`].
73//!
74//! ## UNet2DConditionModel
75//!
76//! Mirrors `unet/config.json` — `UNet2DConditionModel` consumes
77//! `(noisy_latent [B, 4, 64, 64], timestep [B], text_embed [B, S, 768])`
78//! and returns predicted noise `[B, 4, 64, 64]`. See [`unet`].
79//!
80//! [`ResnetBlock2DTime`] (UNet flavour with time bias):
81//!
82//! ```text
83//! h = silu(norm1(x)); h = conv1(h)
84//! t = silu(temb); h = h + Linear(t).view(B, out, 1, 1)
85//! h = silu(norm2(h)); h = conv2(h)
86//! out = h + (x if in==out else conv_shortcut(x))
87//! ```
88//!
89//! [`Transformer2DModel`] (SD UNet flavour):
90//!
91//! ```text
92//! h = GroupNorm(x); h = proj_in (Conv2d k=1, [B, inner, H, W])
93//! h = flatten to [B, HW, inner]; for block in blocks: h = block(h, ehs)
94//! h = reshape back; h = proj_out (Conv2d k=1); out = h + residual
95//! ```
96//!
97//! Each `BasicTransformerBlock` is the canonical pre-LN
98//! (self-attn → cross-attn → GEGLU FF) stack.
99
100pub mod attention;
101pub mod blocks;
102pub mod clip_text_encoder;
103pub mod config;
104#[cfg(feature = "cuda")]
105pub mod gpu;
106pub mod pipeline;
107pub mod resnet_block_time;
108pub mod safetensors_loader;
109pub mod scheduler;
110pub mod time_embedding;
111pub mod unet;
112pub mod unet_config;
113pub mod vae;
114pub mod vae_encoder;
115
116pub use attention::{Attention, BasicTransformerBlock, FeedForward, Transformer2DModel};
117pub use blocks::{
118 AttnBlock2D, DownEncoderBlock2D, Downsample2D, ResnetBlock2D, UNetMidBlock2D,
119 UpDecoderBlock2D, Upsample2D,
120};
121pub use clip_text_encoder::{
122 ClipEncoder, ClipEncoderLayer, ClipMlp, ClipSelfAttention, ClipTextConfig, ClipTextEmbeddings,
123 ClipTextEncoder,
124};
125pub use config::VaeDecoderConfig;
126pub use pipeline::{PipelineStepDump, StableDiffusionPipeline};
127pub use resnet_block_time::ResnetBlock2DTime;
128pub use safetensors_loader::{
129 load_clip_text_encoder, load_unet, load_vae_decoder, load_vae_encoder, DropReport,
130};
131pub use scheduler::{BetaSchedule, DDIMConfig, DDIMScheduler, PredictionType, TimestepSpacing};
132pub use time_embedding::{TimestepEmbedding, Timesteps};
133pub use unet::{
134 AnyDownBlock, AnyUpBlock, CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UNet2DConditionModel,
135 UNetMidBlock2DCrossAttn, UpBlock2D,
136};
137pub use unet_config::UNet2DConditionConfig;
138pub use vae::{Decoder, VaeDecoder};
139pub use vae_encoder::{DiagonalGaussianDistribution, Encoder, VaeEncoder, VaeEncoderConfig};