#![doc(html_logo_url = "https://parcel.pyke.io/v2/cdn/assetdelivery/diffusers/doc/diffusers-square.png")]
#![warn(missing_docs)]
#![warn(rustdoc::all)]
#![warn(clippy::correctness, clippy::suspicious, clippy::complexity, clippy::perf, clippy::style)]
#![allow(clippy::tabs_in_doc_comments)]
#[cfg(feature = "tokenizers")]
#[doc(hidden)]
pub mod clip;
pub(crate) mod config;
pub mod pipelines;
pub mod schedulers;
pub(crate) mod util;
pub use ort::Environment as OrtEnvironment;
use ort::ExecutionProvider;
pub use self::pipelines::*;
pub use self::schedulers::*;
/// Strategy the CUDA execution provider uses when it has to grow its device
/// memory arena.
///
/// The value is handed to ONNX Runtime as a string; the `From<ArenaExtendStrategy>
/// for String` conversion below defines the exact spelling of each variant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ArenaExtendStrategy {
    /// Serialized as `kNextPowerOfTwo`. ONNX Runtime documents this mode as
    /// extending the arena by progressively larger (power-of-two) amounts.
    PowerOfTwo,
    /// Serialized as `kSameAsRequested`. ONNX Runtime documents this mode as
    /// extending the arena by only the amount that was requested.
    SameAsRequested
}

impl Default for ArenaExtendStrategy {
    /// Returns `ArenaExtendStrategy::PowerOfTwo`.
    fn default() -> Self {
        ArenaExtendStrategy::PowerOfTwo
    }
}

impl From<ArenaExtendStrategy> for String {
    /// Converts the strategy into the identifier string ONNX Runtime expects.
    fn from(strategy: ArenaExtendStrategy) -> Self {
        let identifier: &'static str = match strategy {
            ArenaExtendStrategy::PowerOfTwo => "kNextPowerOfTwo",
            ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
        };
        identifier.to_owned()
    }
}
/// How cuDNN should pick a convolution algorithm for the CUDA execution
/// provider.
///
/// The value is handed to ONNX Runtime as a string; the
/// `From<CuDNNConvolutionAlgorithmSearch> for String` conversion below defines
/// the exact spelling of each variant.
///
/// Note that the [`Default`] trait implementation returns `Exhaustive`, *not*
/// the variant that happens to be named `Default`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CuDNNConvolutionAlgorithmSearch {
    /// Serialized as `EXHAUSTIVE`. ONNX Runtime documents this as an expensive,
    /// exhaustive benchmarking search.
    Exhaustive,
    /// Serialized as `HEURISTIC`. ONNX Runtime documents this as a lightweight,
    /// heuristic-based search.
    Heuristic,
    /// Serialized as `DEFAULT`. ONNX Runtime documents this as using cuDNN's
    /// default convolution algorithm without searching.
    Default
}

impl Default for CuDNNConvolutionAlgorithmSearch {
    /// Returns `CuDNNConvolutionAlgorithmSearch::Exhaustive`.
    fn default() -> Self {
        CuDNNConvolutionAlgorithmSearch::Exhaustive
    }
}

impl From<CuDNNConvolutionAlgorithmSearch> for String {
    /// Converts the search mode into the identifier string ONNX Runtime expects.
    fn from(search: CuDNNConvolutionAlgorithmSearch) -> Self {
        let identifier: &'static str = match search {
            CuDNNConvolutionAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
            CuDNNConvolutionAlgorithmSearch::Heuristic => "HEURISTIC",
            CuDNNConvolutionAlgorithmSearch::Default => "DEFAULT"
        };
        identifier.to_owned()
    }
}
/// Optional tuning knobs for the CUDA execution provider.
///
/// Every field is optional. When this struct is converted into an
/// `ExecutionProvider`, only the fields that are `Some` are forwarded as
/// provider options; `None` fields are not set at all.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CUDADeviceOptions {
    /// Arena growth strategy, forwarded as the `arena_extend_strategy` option.
    pub arena_extend_strategy: Option<ArenaExtendStrategy>,
    /// GPU memory limit, forwarded (stringified) as the `gpu_mem_limit` option.
    /// ONNX Runtime documents `gpu_mem_limit` as a size in bytes.
    pub memory_limit: Option<usize>,
    /// cuDNN convolution algorithm search mode, forwarded as the
    /// `cudnn_conv_algo_search` option.
    pub cudnn_conv_algorithm_search: Option<CuDNNConvolutionAlgorithmSearch>
}
impl From<CUDADeviceOptions> for ExecutionProvider {
fn from(val: CUDADeviceOptions) -> Self {
let mut ep = ExecutionProvider::cuda();
if let Some(arena_extend_strategy) = val.arena_extend_strategy {
ep = ep.with("arena_extend_strategy", arena_extend_strategy);
}
if let Some(memory_limit) = val.memory_limit {
ep = ep.with("gpu_mem_limit", memory_limit.to_string());
}
if let Some(cudnn_conv_algorithm_search) = val.cudnn_conv_algorithm_search {
ep = ep.with("cudnn_conv_algo_search", cudnn_conv_algorithm_search);
}
ep
}
}
/// A device (execution provider) that a model can be placed on.
///
/// Converted into an `ort` `ExecutionProvider` through the `From` impl for
/// this type. The enum is `#[non_exhaustive]`, so downstream `match`es need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum DiffusionDevice {
    /// The CPU execution provider.
    CPU,
    /// The CUDA execution provider. The first field is forwarded as the
    /// provider's `device_id`; the second holds optional tuning options, with
    /// `None` meaning `CUDADeviceOptions::default()`.
    CUDA(usize, Option<CUDADeviceOptions>),
    /// The TensorRT execution provider.
    TensorRT,
    /// DirectML. Conversion into an `ExecutionProvider` is not implemented yet
    /// and currently panics. The field is presumably a device index, but it is
    /// unused at the moment.
    DirectML(usize),
    /// The oneDNN execution provider.
    OneDNN,
    /// A caller-supplied execution provider, used as-is.
    Custom(ExecutionProvider)
}
impl From<DiffusionDevice> for ExecutionProvider {
    /// Turns a [`DiffusionDevice`] into the matching execution provider.
    ///
    /// For `CUDA`, the provider is first built from the supplied options (or
    /// their defaults) and the `device_id` option is attached afterwards.
    /// `Custom` providers are returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics for `DiffusionDevice::DirectML`, which is not implemented yet.
    fn from(device: DiffusionDevice) -> Self {
        match device {
            DiffusionDevice::CPU => Self::cpu(),
            DiffusionDevice::CUDA(device_id, options) => {
                let cuda: Self = options.unwrap_or_default().into();
                cuda.with("device_id", device_id.to_string())
            }
            DiffusionDevice::TensorRT => Self::tensorrt(),
            DiffusionDevice::DirectML(_) => todo!("sorry, not implemented yet, please open an issue"),
            DiffusionDevice::OneDNN => Self::onednn(),
            DiffusionDevice::Custom(provider) => provider
        }
    }
}
/// Per-model device assignment: one [`DiffusionDevice`] for each model slot.
///
/// Use `DiffusionDeviceControl::all` to put every slot on the same device.
/// The `Default` implementation places every slot on the CPU.
#[derive(Clone, Debug)]
pub struct DiffusionDeviceControl {
    /// Device assigned to the VAE encoder.
    pub vae_encoder: DiffusionDevice,
    /// Device assigned to the VAE decoder.
    pub vae_decoder: DiffusionDevice,
    /// Device assigned to the text encoder.
    pub text_encoder: DiffusionDevice,
    /// Device assigned to the UNet.
    pub unet: DiffusionDevice,
    /// Device assigned to the safety checker.
    pub safety_checker: DiffusionDevice
}
impl DiffusionDeviceControl {
pub fn all(device: DiffusionDevice) -> Self {
Self {
vae_encoder: device.clone(),
vae_decoder: device.clone(),
text_encoder: device.clone(),
unet: device.clone(),
safety_checker: device
}
}
}
impl Default for DiffusionDeviceControl {
    /// Places every slot on `DiffusionDevice::CPU`.
    fn default() -> Self {
        Self::all(DiffusionDevice::CPU)
    }
}