use pyo3::prelude::*;
pub mod device;
pub mod dtype;
pub mod error;
pub mod nn;
pub mod optim;
pub mod tensor;
pub mod utils;
pub use device::PyDevice;
pub use dtype::PyDType;
pub use error::TorshPyError;
pub use tensor::PyTensor;
// Primary module initializer: populates the `rstorch` Python extension module with
// the core classes, the items contributed by the sibling Rust modules, and
// `__version__`.
//
// Plain `//` comments are used on purpose: PyO3 copies `///` doc comments on a
// `#[pymodule]` function into the Python module's `__doc__`.
//
// Keep the registration order stable. Every step sets attributes on `m`, so a later
// step could presumably shadow a name added by an earlier one (TODO confirm there
// are no collisions between the register_* helpers).
#[pymodule]
fn rstorch(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Resolved at compile time from the crate's Cargo metadata.
    const VERSION: &str = env!("CARGO_PKG_VERSION");
    let py = m.py();

    // Core classes exposed directly on the top-level module.
    m.add_class::<PyTensor>()?;
    m.add_class::<PyDevice>()?;
    m.add_class::<PyDType>()?;

    // Delegate further registration to the sibling modules; what each one adds is
    // defined in that module, not here.
    nn::register_nn_module(py, m)?;
    optim::register_optim_module(py, m)?;
    tensor::register_creation_functions(m)?;
    device::register_device_constants(m)?;
    dtype::register_dtype_constants(m)?;
    error::register_error_types(m)?;

    m.add("__version__", VERSION)?;
    Ok(())
}
// Secondary entry point that exposes exactly the same contents as `rstorch`.
// PyO3 emits a separate init symbol for each `#[pymodule]` function. Keeping this
// one lets the compiled library be loaded under the name `rstorch_python` as well,
// depending on the file name it is installed as (TODO confirm which name the build
// configuration actually uses).
// `//` rather than `///` so the Python module's `__doc__` is left unchanged.
#[pymodule]
fn rstorch_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // All population logic lives in `rstorch`; any error it raises is propagated as-is.
    rstorch(m)?;
    Ok(())
}