1#[cfg(feature = "accelerate")]
51mod accelerate;
52pub mod backend;
53pub mod backprop;
54pub mod conv;
55mod convert;
56pub mod cpu;
57pub mod cpu_backend;
58#[cfg(feature = "cuda")]
59pub mod cuda_backend;
60mod custom_op;
61mod device;
62pub mod display;
63mod dtype;
64pub mod dummy_cuda_backend;
65pub mod dummy_dtype;
66mod dummy_metal_backend;
67pub mod error;
68mod indexer;
69pub mod layout;
70#[cfg(feature = "metal")]
71pub mod metal_backend;
72#[cfg(feature = "mkl")]
73mod mkl;
74pub mod npy;
75pub mod op;
76pub mod pickle;
77pub mod quantized;
78pub mod safetensors;
79pub mod scalar;
80pub mod shape;
81mod sort;
82mod storage;
83pub mod streaming;
84mod strided_index;
85mod tensor;
86mod tensor_cat;
87pub mod test_utils;
88pub mod utils;
89mod variable;
90
91#[cfg(feature = "cudnn")]
92pub use cuda_backend::cudnn;
93
94pub use cpu_backend::{CpuStorage, CpuStorageRef};
95#[cfg(feature = "ug")]
96pub use custom_op::UgIOp1;
97pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
98pub use device::{Device, DeviceLocation, NdArray};
99pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
100pub use dummy_dtype::{F4, F6E2M3, F6E3M2, F8E8M0};
101pub use error::{Context, Error, Result};
102pub use indexer::{IndexOp, TensorIndexer};
103pub use layout::Layout;
104pub use shape::{Shape, D};
105pub use storage::Storage;
106pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
107pub use strided_index::{StridedBlocks, StridedIndex};
108pub use tensor::{Tensor, TensorId};
109pub use variable::Var;
110
111#[cfg(feature = "cuda")]
112pub use cuda_backend as cuda;
113
114#[cfg(not(feature = "cuda"))]
115pub use dummy_cuda_backend as cuda;
116
117pub use cuda::{CudaDevice, CudaStorage};
118
119#[cfg(feature = "metal")]
120pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
121
122#[cfg(not(feature = "metal"))]
123pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
124
125#[cfg(feature = "mkl")]
126extern crate intel_mkl_src;
127
128#[cfg(feature = "accelerate")]
129extern crate accelerate_src;
130
/// Conversion of a value into a `(usize, usize)` pair.
///
/// Lets an API accept either a single `usize`, which is duplicated into both
/// components, or an explicit `(usize, usize)` tuple.
pub trait ToUsize2 {
    /// Consumes the value and returns it as a pair of `usize`.
    fn to_usize2(self) -> (usize, usize);
}

impl ToUsize2 for usize {
    /// Uses the same scalar for both components.
    fn to_usize2(self) -> (usize, usize) {
        let side = self;
        (side, side)
    }
}

impl ToUsize2 for (usize, usize) {
    /// Returns the tuple unchanged.
    fn to_usize2(self) -> (usize, usize) {
        let (first, second) = self;
        (first, second)
    }
}
146
147pub trait Module {
149 fn forward(&self, xs: &Tensor) -> Result<Tensor>;
150}
151
152impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
153 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
154 self(xs)
155 }
156}
157
158impl<M: Module> Module for Option<&M> {
159 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
160 match self {
161 None => Ok(xs.clone()),
162 Some(m) => m.forward(xs),
163 }
164 }
165}
166
167pub trait ModuleT {
170 fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
171}
172
173impl<M: Module> ModuleT for M {
174 fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
175 self.forward(xs)
176 }
177}