// oxigaf_diffusion/lib.rs
1//! # oxigaf-diffusion
2//!
3//! Multi-view diffusion model inference for GAF.
4//!
5//! Implements the full pipeline: CLIP image encoding → multi-view U-Net
6//! denoising with camera-conditioned cross-view attention → VAE decoding.
7//!
8//! ## Cargo Features
9//!
10//! This crate supports the following feature flags:
11//!
12//! - **`default`** = `["accelerate", "flash_attention"]`:
13//! Default features for CPU-only inference with optimizations
14//!
15//! - **`accelerate`**:
16//! Uses platform-native BLAS/LAPACK for CPU tensor operations
17//! - macOS: Accelerate framework
18//! - Linux: OpenBLAS or Intel MKL
19//!
20//! - **`cuda`** (platform-specific):
21//! Enables NVIDIA GPU acceleration via CUDA
22//! - Requires CUDA toolkit installed
23//! - Not available on macOS
24//!
25//! - **`metal`** (platform-specific):
26//! Enables Apple Silicon GPU acceleration via Metal
27//! - macOS only
28//! - Optimized for M1/M2/M3 chips
29//!
30//! - **`flash_attention`** (enabled by default):
//! Memory-efficient attention with O(N) memory usage instead of O(N²)
32//! - Reduces memory usage by 2-4× for large images
33//! - Tiled computation for better cache locality
34//!
35//! - **`mixed_precision`** (planned, not yet implemented):
36//! FP16/BF16 inference for reduced memory usage
37//! - Faster on GPUs with Tensor Cores
38//! - Lower memory footprint
39//!
40//! Example usage:
41//! ```toml
42//! # In Cargo.toml
43//! # For CPU-only with flash attention
44//! oxigaf-diffusion = { version = "0.1", default-features = true }
45//!
46//! # For Apple Silicon with Metal acceleration
47//! oxigaf-diffusion = { version = "0.1", features = ["metal", "flash_attention"] }
48//!
49//! # For NVIDIA GPU with CUDA
50//! oxigaf-diffusion = { version = "0.1", features = ["cuda", "flash_attention"] }
51//! ```
52
53#![deny(clippy::unwrap_used)]
54#![deny(clippy::expect_used)]
55
56pub mod attention;
57pub mod camera;
58pub mod clip;
59pub mod config;
60#[cfg(feature = "flash_attention")]
61pub mod flash_attention;
62pub mod pipeline;
63pub mod scheduler;
64pub mod unet;
65pub mod upsampler;
66pub mod vae;
67
68use std::path::PathBuf;
69use thiserror::Error;
70
/// Errors that can occur during diffusion model operations.
///
/// Variants are grouped by pipeline stage: model loading, tensor
/// operations, numerical checks, inference, pipeline orchestration,
/// I/O, and the candle backend. `std::io::Error` and
/// `candle_core::Error` convert automatically via `#[from]`, so `?`
/// works directly on fallible I/O and tensor calls.
#[derive(Debug, Error)]
pub enum DiffusionError {
    // -------------------------------------------------------------------------
    // Model Loading Errors
    // -------------------------------------------------------------------------
    /// Generic model loading error with context message.
    #[error("Model loading error: {0}")]
    ModelLoad(String),

    /// A required weight tensor is missing for the named layer
    /// (the display message reports the shape that was expected).
    #[error("Weight not found: layer '{layer}', expected shape {expected_shape:?}")]
    WeightNotFound {
        layer: String,
        expected_shape: Vec<usize>,
    },

    /// Weight shape does not match expected dimensions.
    #[error("Weight shape mismatch: layer '{layer}', expected {expected:?}, got {got:?}")]
    WeightShapeMismatch {
        layer: String,
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// Safetensors file is corrupted or invalid.
    #[error("Safetensors corrupt: {path:?}, reason: {reason}")]
    SafetensorsCorrupt { path: PathBuf, reason: String },

    // -------------------------------------------------------------------------
    // Tensor Operation Errors
    // -------------------------------------------------------------------------
    /// Tensor shape mismatch during the named operation `op`.
    #[error("Shape mismatch in '{op}': expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        op: String,
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// Data type mismatch between tensors (dtypes rendered as strings).
    #[error("Dtype mismatch: expected {expected}, got {got}")]
    DtypeMismatch { expected: String, got: String },

    /// Device mismatch between tensors (devices rendered as strings).
    #[error("Device mismatch: expected {expected}, got {got}")]
    DeviceMismatch { expected: String, got: String },

    // -------------------------------------------------------------------------
    // Numerical Errors
    // -------------------------------------------------------------------------
    /// NaN detected in tensor during computation.
    /// `timestep` is `None` when the check runs outside the denoising loop.
    #[error("NaN detected in layer '{layer}' at timestep {timestep:?}")]
    NanDetected {
        layer: String,
        timestep: Option<usize>,
    },

    /// Infinity detected in tensor during computation.
    /// `timestep` is `None` when the check runs outside the denoising loop.
    #[error("Inf detected in layer '{layer}' at timestep {timestep:?}")]
    InfDetected {
        layer: String,
        timestep: Option<usize>,
    },

    /// General numerical instability, with free-form context.
    #[error("Numerical instability: {context}")]
    NumericalInstability { context: String },

    // -------------------------------------------------------------------------
    // Inference Errors
    // -------------------------------------------------------------------------
    /// Generic inference error with context.
    #[error("Inference error: {0}")]
    Inference(String),

    /// Requested timestep exceeds the scheduler's maximum (`value` > `max`).
    #[error("Invalid timestep: {value}, max allowed: {max}")]
    InvalidTimestep { value: usize, max: usize },

    /// Invalid number of views provided.
    #[error("Invalid view count: expected {expected}, got {got}")]
    InvalidViewCount { expected: usize, got: usize },

    /// Invalid latent tensor shape.
    #[error("Invalid latent shape: expected {expected:?}, got {got:?}")]
    InvalidLatentShape {
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// Skip connection underflow during U-Net forward pass: the decoder
    /// side requested more skip tensors than were available.
    #[error(
        "Skip connection underflow: expected {expected} connections, only {available} available"
    )]
    SkipConnectionUnderflow { expected: usize, available: usize },

    // -------------------------------------------------------------------------
    // Pipeline Errors
    // -------------------------------------------------------------------------
    /// Scheduler not initialized before use.
    #[error("Scheduler not initialized: call set_timesteps() first")]
    SchedulerNotInitialized,

    /// CLIP encoding failed.
    #[error("CLIP encoding failed: {0}")]
    ClipEncodingFailed(String),

    /// VAE encoding failed.
    #[error("VAE encoding failed: {0}")]
    VaeEncodeFailed(String),

    /// VAE decoding failed.
    #[error("VAE decoding failed: {0}")]
    VaeDecodeFailed(String),

    /// U-Net forward pass failed at the given timestep.
    #[error("U-Net forward failed at timestep {timestep}: {reason}")]
    UnetForwardFailed { timestep: usize, reason: String },

    // -------------------------------------------------------------------------
    // I/O Errors
    // -------------------------------------------------------------------------
    /// I/O error during file operations (auto-converted via `#[from]`).
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),

    /// Image processing error.
    #[error("Image processing error: {0}")]
    ImageProcessingError(String),

    // -------------------------------------------------------------------------
    // Candle Backend Errors
    // -------------------------------------------------------------------------
    /// Error from candle tensor operations (auto-converted via `#[from]`).
    #[error("Candle error: {0}")]
    Candle(#[from] candle_core::Error),
}
209
210/// Result type for diffusion operations.
211pub type DiffusionResult<T> = std::result::Result<T, DiffusionError>;
212
213// Re-exports
214pub use clip::ClipImageEncoder;
215pub use config::DiffusionConfig;
216pub use pipeline::{MultiViewDiffusionPipeline, MultiViewOutput};
217pub use scheduler::{DdimScheduler, PredictionType};
218pub use unet::MultiViewUNet;
219pub use upsampler::{LatentUpsampler, UpsamplerMode};
220pub use vae::Vae;
221
222// Flash attention exports (only when feature is enabled)
223#[cfg(feature = "flash_attention")]
224pub use flash_attention::{
225 flash_attention, flash_attention_with_config, FlashAttention, FlashAttentionConfig,
226};