// oxigaf_diffusion — crate root (lib.rs)
//! # oxigaf-diffusion
//!
//! Multi-view diffusion model inference for GAF.
//!
//! Implements the full pipeline: CLIP image encoding → multi-view U-Net
//! denoising with camera-conditioned cross-view attention → VAE decoding.
//!
//! ## Cargo Features
//!
//! This crate supports the following feature flags:
//!
//! - **`default`** = `["accelerate", "flash_attention"]`:
//!   Default features for CPU-only inference with optimizations
//!
//! - **`accelerate`**:
//!   Uses platform-native BLAS/LAPACK for CPU tensor operations
//!   - macOS: Accelerate framework
//!   - Linux: OpenBLAS or Intel MKL
//!
//! - **`cuda`** (platform-specific):
//!   Enables NVIDIA GPU acceleration via CUDA
//!   - Requires CUDA toolkit installed
//!   - Not available on macOS
//!
//! - **`metal`** (platform-specific):
//!   Enables Apple Silicon GPU acceleration via Metal
//!   - macOS only
//!   - Optimized for M1/M2/M3 chips
//!
//! - **`flash_attention`** (enabled by default):
//!   Memory-efficient attention with O(N) memory usage instead of O(N²)
//!   - Reduces memory usage by 2-4× for large images
//!   - Tiled computation for better cache locality
//!
//! - **`mixed_precision`** (planned, not yet implemented):
//!   FP16/BF16 inference for reduced memory usage
//!   - Faster on GPUs with Tensor Cores
//!   - Lower memory footprint
//!
//! Example usage:
//! ```toml
//! # In Cargo.toml
//! # For CPU-only with flash attention (the default feature set)
//! oxigaf-diffusion = { version = "0.1", default-features = true }
//!
//! # For Apple Silicon with Metal acceleration
//! oxigaf-diffusion = { version = "0.1", features = ["metal", "flash_attention"] }
//!
//! # For NVIDIA GPU with CUDA
//! oxigaf-diffusion = { version = "0.1", features = ["cuda", "flash_attention"] }
//! ```
53#![deny(clippy::unwrap_used)]
54#![deny(clippy::expect_used)]
55
56pub mod attention;
57pub mod camera;
58pub mod clip;
59pub mod config;
60#[cfg(feature = "flash_attention")]
61pub mod flash_attention;
62pub mod pipeline;
63pub mod scheduler;
64pub mod unet;
65pub mod upsampler;
66pub mod vae;
67
68use std::path::PathBuf;
69use thiserror::Error;
70
71/// Errors that can occur during diffusion model operations.
72#[derive(Debug, Error)]
73pub enum DiffusionError {
74    // -------------------------------------------------------------------------
75    // Model Loading Errors
76    // -------------------------------------------------------------------------
77    /// Generic model loading error with context message.
78    #[error("Model loading error: {0}")]
79    ModelLoad(String),
80
81    /// Weight file not found at expected path.
82    #[error("Weight not found: layer '{layer}', expected shape {expected_shape:?}")]
83    WeightNotFound {
84        layer: String,
85        expected_shape: Vec<usize>,
86    },
87
88    /// Weight shape does not match expected dimensions.
89    #[error("Weight shape mismatch: layer '{layer}', expected {expected:?}, got {got:?}")]
90    WeightShapeMismatch {
91        layer: String,
92        expected: Vec<usize>,
93        got: Vec<usize>,
94    },
95
96    /// Safetensors file is corrupted or invalid.
97    #[error("Safetensors corrupt: {path:?}, reason: {reason}")]
98    SafetensorsCorrupt { path: PathBuf, reason: String },
99
100    // -------------------------------------------------------------------------
101    // Tensor Operation Errors
102    // -------------------------------------------------------------------------
103    /// Tensor shape mismatch during operation.
104    #[error("Shape mismatch in '{op}': expected {expected:?}, got {got:?}")]
105    ShapeMismatch {
106        op: String,
107        expected: Vec<usize>,
108        got: Vec<usize>,
109    },
110
111    /// Data type mismatch between tensors.
112    #[error("Dtype mismatch: expected {expected}, got {got}")]
113    DtypeMismatch { expected: String, got: String },
114
115    /// Device mismatch between tensors.
116    #[error("Device mismatch: expected {expected}, got {got}")]
117    DeviceMismatch { expected: String, got: String },
118
119    // -------------------------------------------------------------------------
120    // Numerical Errors
121    // -------------------------------------------------------------------------
122    /// NaN detected in tensor during computation.
123    #[error("NaN detected in layer '{layer}' at timestep {timestep:?}")]
124    NanDetected {
125        layer: String,
126        timestep: Option<usize>,
127    },
128
129    /// Infinity detected in tensor during computation.
130    #[error("Inf detected in layer '{layer}' at timestep {timestep:?}")]
131    InfDetected {
132        layer: String,
133        timestep: Option<usize>,
134    },
135
136    /// General numerical instability.
137    #[error("Numerical instability: {context}")]
138    NumericalInstability { context: String },
139
140    // -------------------------------------------------------------------------
141    // Inference Errors
142    // -------------------------------------------------------------------------
143    /// Generic inference error with context.
144    #[error("Inference error: {0}")]
145    Inference(String),
146
147    /// Invalid timestep value.
148    #[error("Invalid timestep: {value}, max allowed: {max}")]
149    InvalidTimestep { value: usize, max: usize },
150
151    /// Invalid number of views provided.
152    #[error("Invalid view count: expected {expected}, got {got}")]
153    InvalidViewCount { expected: usize, got: usize },
154
155    /// Invalid latent tensor shape.
156    #[error("Invalid latent shape: expected {expected:?}, got {got:?}")]
157    InvalidLatentShape {
158        expected: Vec<usize>,
159        got: Vec<usize>,
160    },
161
162    /// Skip connection underflow during U-Net forward pass.
163    #[error(
164        "Skip connection underflow: expected {expected} connections, only {available} available"
165    )]
166    SkipConnectionUnderflow { expected: usize, available: usize },
167
168    // -------------------------------------------------------------------------
169    // Pipeline Errors
170    // -------------------------------------------------------------------------
171    /// Scheduler not initialized before use.
172    #[error("Scheduler not initialized: call set_timesteps() first")]
173    SchedulerNotInitialized,
174
175    /// CLIP encoding failed.
176    #[error("CLIP encoding failed: {0}")]
177    ClipEncodingFailed(String),
178
179    /// VAE encoding failed.
180    #[error("VAE encoding failed: {0}")]
181    VaeEncodeFailed(String),
182
183    /// VAE decoding failed.
184    #[error("VAE decoding failed: {0}")]
185    VaeDecodeFailed(String),
186
187    /// U-Net forward pass failed.
188    #[error("U-Net forward failed at timestep {timestep}: {reason}")]
189    UnetForwardFailed { timestep: usize, reason: String },
190
191    // -------------------------------------------------------------------------
192    // I/O Errors
193    // -------------------------------------------------------------------------
194    /// I/O error during file operations.
195    #[error("I/O error: {0}")]
196    IoError(#[from] std::io::Error),
197
198    /// Image processing error.
199    #[error("Image processing error: {0}")]
200    ImageProcessingError(String),
201
202    // -------------------------------------------------------------------------
203    // Candle Backend Errors
204    // -------------------------------------------------------------------------
205    /// Error from candle tensor operations.
206    #[error("Candle error: {0}")]
207    Candle(#[from] candle_core::Error),
208}
209
210/// Result type for diffusion operations.
211pub type DiffusionResult<T> = std::result::Result<T, DiffusionError>;
212
213// Re-exports
214pub use clip::ClipImageEncoder;
215pub use config::DiffusionConfig;
216pub use pipeline::{MultiViewDiffusionPipeline, MultiViewOutput};
217pub use scheduler::{DdimScheduler, PredictionType};
218pub use unet::MultiViewUNet;
219pub use upsampler::{LatentUpsampler, UpsamplerMode};
220pub use vae::Vae;
221
222// Flash attention exports (only when feature is enabled)
223#[cfg(feature = "flash_attention")]
224pub use flash_attention::{
225    flash_attention, flash_attention_with_config, FlashAttention, FlashAttentionConfig,
226};