1#[cfg(feature = "accelerate")]
51mod accelerate;
52pub mod backend;
53pub mod backprop;
54pub mod conv;
55mod convert;
56pub mod cpu;
57pub mod cpu_backend;
58#[cfg(feature = "cuda")]
59pub mod cuda_backend;
60mod custom_op;
61mod device;
62pub mod display;
63mod dtype;
64pub mod dummy_cuda_backend;
65pub mod dummy_dtype;
66mod dummy_metal_backend;
67pub mod error;
68mod indexer;
69pub mod layout;
70#[cfg(feature = "metal")]
71pub mod metal_backend;
72#[cfg(feature = "mkl")]
73mod mkl;
74pub mod nditer;
75pub mod npy;
76pub mod op;
77pub mod pickle;
78pub mod quantized;
79pub mod safetensors;
80pub mod scalar;
81pub mod shape;
82mod sort;
83mod storage;
84pub mod streaming;
85mod strided_index;
86mod tensor;
87mod tensor_cat;
88pub mod test_utils;
89pub mod utils;
90mod variable;
91
92#[cfg(feature = "cudnn")]
93pub use cuda_backend::cudnn;
94
95pub use cpu_backend::{CpuStorage, CpuStorageRef};
96#[cfg(feature = "ug")]
97pub use custom_op::UgIOp1;
98pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
99pub use device::{Device, DeviceLocation, NdArray};
100pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
101pub use dummy_dtype::{F4, F6E2M3, F6E3M2, F8E8M0};
102pub use error::{Context, Error, Result};
103pub use indexer::{IndexOp, TensorIndexer};
104pub use layout::Layout;
105pub use nditer::NdIter;
106pub use shape::{Shape, D};
107pub use storage::Storage;
108pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
109pub use strided_index::{StridedBlocks, StridedIndex};
110pub use tensor::{Tensor, TensorId};
111pub use variable::Var;
112
113#[cfg(feature = "cuda")]
114pub use cuda_backend as cuda;
115
116#[cfg(not(feature = "cuda"))]
117pub use dummy_cuda_backend as cuda;
118
119pub use cuda::{CudaDevice, CudaStorage};
120
121#[cfg(feature = "metal")]
122pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
123
124#[cfg(not(feature = "metal"))]
125pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
126
127#[cfg(feature = "mkl")]
128extern crate intel_mkl_src;
129
130#[cfg(feature = "accelerate")]
131extern crate accelerate_src;
132
/// Conversion of a value into a pair of `usize`s.
///
/// Implemented for a bare `usize` and for a `(usize, usize)` tuple, so a
/// generic API bounded on this trait can accept either a single value or an
/// explicit pair.
pub trait ToUsize2 {
    /// Consumes `self` and returns it as a `(usize, usize)` pair.
    fn to_usize2(self) -> (usize, usize);
}

impl ToUsize2 for usize {
    /// Uses the same value for both elements of the pair.
    fn to_usize2(self) -> (usize, usize) {
        (self, self)
    }
}

impl ToUsize2 for (usize, usize) {
    /// Returns the pair unchanged.
    fn to_usize2(self) -> (usize, usize) {
        self
    }
}
148
149pub trait Module {
151 fn forward(&self, xs: &Tensor) -> Result<Tensor>;
152}
153
154impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
155 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
156 self(xs)
157 }
158}
159
160impl<M: Module> Module for Option<&M> {
161 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
162 match self {
163 None => Ok(xs.clone()),
164 Some(m) => m.forward(xs),
165 }
166 }
167}
168
169pub trait ModuleT {
172 fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
173}
174
175impl<M: Module> ModuleT for M {
176 fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
177 self.forward(xs)
178 }
179}