diffusion_rs_common/core/mod.rs

1#[cfg(feature = "accelerate")]
2mod accelerate;
3pub mod backend;
4pub mod backprop;
5pub mod conv;
6mod convert;
7pub mod cpu;
8pub mod cpu_backend;
9#[cfg(feature = "cuda")]
10pub mod cuda_backend;
11mod custom_op;
12mod device;
13pub mod display;
14mod dtype;
15pub mod dummy_cuda_backend;
16mod dummy_metal_backend;
17pub mod error;
18mod indexer;
19pub mod layout;
20#[cfg(feature = "metal")]
21pub mod metal_backend;
22#[cfg(feature = "mkl")]
23mod mkl;
24pub mod npy;
25pub mod op;
26pub mod pickle;
27pub mod quantized;
28pub mod safetensors;
29pub mod scalar;
30pub mod shape;
31mod sort;
32mod storage;
33pub mod streaming;
34mod strided_index;
35mod tensor;
36mod tensor_cat;
37mod tensor_indexing;
38pub mod test_utils;
39pub mod utils;
40mod variable;
41
42#[cfg(feature = "cudnn")]
43pub use cuda_backend::cudnn;
44
45pub use cpu_backend::{CpuStorage, CpuStorageRef};
46pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
47pub use device::{Device, DeviceLocation, NdArray};
48pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
49pub use error::{Context, Error, Result};
50pub use indexer::{IndexOp, TensorIndexer};
51pub use layout::Layout;
52pub use shape::{Shape, D};
53pub use storage::Storage;
54pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
55pub use strided_index::{StridedBlocks, StridedIndex};
56pub use tensor::{from_storage_no_op, Tensor, TensorId};
57pub use variable::Var;
58
59#[cfg(feature = "cuda")]
60pub use cuda_backend as cuda;
61
62#[cfg(not(feature = "cuda"))]
63pub use dummy_cuda_backend as cuda;
64
65pub use cuda::{CudaDevice, CudaStorage};
66
67#[cfg(feature = "metal")]
68pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
69
70#[cfg(not(feature = "metal"))]
71pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
72
73#[cfg(feature = "mkl")]
74extern crate intel_mkl_src;
75
76#[cfg(feature = "accelerate")]
77extern crate accelerate_src;
78
/// Converts a size specification into a `(usize, usize)` pair.
///
/// A bare `usize` is duplicated into both positions, while a `(usize, usize)` tuple is
/// returned unchanged.
pub trait ToUsize2 {
    /// Returns this value expressed as a pair of `usize`s.
    fn to_usize2(self) -> (usize, usize);
}

/// A single value is used for both components of the pair.
impl ToUsize2 for usize {
    fn to_usize2(self) -> (usize, usize) {
        let both = self;
        (both, both)
    }
}

/// A pair is passed through component-for-component.
impl ToUsize2 for (usize, usize) {
    fn to_usize2(self) -> (usize, usize) {
        let (first, second) = self;
        (first, second)
    }
}
94
95// A simple trait defining a module with forward method using a single argument.
96pub trait Module {
97    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
98}
99
100impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
101    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
102        self(xs)
103    }
104}
105
106impl<M: Module> Module for Option<&M> {
107    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
108        match self {
109            None => Ok(xs.clone()),
110            Some(m) => m.forward(xs),
111        }
112    }
113}
114
115// A trait defining a module with forward method using a single tensor argument and a flag to
116// separate the training and evaluation behaviors.
117pub trait ModuleT {
118    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
119}
120
121impl<M: Module> ModuleT for M {
122    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
123        self.forward(xs)
124    }
125}