#![allow(dead_code)]
#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(not(feature = "std"))]
extern crate alloc;
pub use torsh_autograd as autograd;
pub use torsh_core as core;
pub use torsh_tensor as tensor;
#[cfg(feature = "nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "nn")))]
pub use torsh_nn as nn;
#[cfg(feature = "optim")]
#[cfg_attr(docsrs, doc(cfg(feature = "optim")))]
pub use torsh_optim as optim;
#[cfg(feature = "data")]
#[cfg_attr(docsrs, doc(cfg(feature = "data")))]
pub use torsh_data as data;
#[cfg(feature = "functional")]
#[cfg_attr(docsrs, doc(cfg(feature = "functional")))]
pub use torsh_functional as functional;
#[cfg(feature = "text")]
#[cfg_attr(docsrs, doc(cfg(feature = "text")))]
pub use torsh_text as text;
#[cfg(feature = "vision")]
#[cfg_attr(docsrs, doc(cfg(feature = "vision")))]
pub use torsh_vision as vision;
#[cfg(feature = "sparse")]
#[cfg_attr(docsrs, doc(cfg(feature = "sparse")))]
pub use torsh_sparse as sparse;
#[cfg(feature = "quantization")]
#[cfg_attr(docsrs, doc(cfg(feature = "quantization")))]
pub use torsh_quantization as quantization;
#[cfg(feature = "special")]
#[cfg_attr(docsrs, doc(cfg(feature = "special")))]
pub use torsh_special as special;
#[cfg(feature = "linalg")]
#[cfg_attr(docsrs, doc(cfg(feature = "linalg")))]
pub use torsh_linalg as linalg;
#[cfg(feature = "profiler")]
#[cfg_attr(docsrs, doc(cfg(feature = "profiler")))]
pub use torsh_profiler as profiler;
#[cfg(feature = "distributed")]
#[cfg_attr(docsrs, doc(cfg(feature = "distributed")))]
pub use torsh_distributed as distributed;
#[cfg(feature = "jit")]
#[cfg_attr(docsrs, doc(cfg(feature = "jit")))]
pub use torsh_jit as jit;
#[cfg(feature = "fx")]
#[cfg_attr(docsrs, doc(cfg(feature = "fx")))]
pub use torsh_fx as fx;
#[cfg(feature = "hub")]
#[cfg_attr(docsrs, doc(cfg(feature = "hub")))]
pub use torsh_hub as hub;
#[cfg(feature = "backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "backend")))]
pub use torsh_backend as backend;
pub use core::{
device::{Device, DeviceType},
dtype::{DType, FloatElement, TensorElement},
error::{Result, TorshError},
shape::Shape,
storage::Storage,
};
pub use tensor::{tensor, Tensor};
pub use autograd::{backward, enable_grad, is_grad_enabled, no_grad};
pub use tensor::creation::{
arange, eye, linspace, ones, ones_like, rand, rand_like, randint, randn, randn_like, zeros,
zeros_like,
};
pub use tensor::indexing::TensorIndex;
#[cfg(feature = "functional")]
pub use functional::{
adaptive_avg_pool2d, avg_pool2d, batch_norm, binary_cross_entropy, conv1d, conv2d, conv3d,
cross_entropy, gelu, layer_norm, linear, log_softmax, max_pool2d, mse_loss, relu, sigmoid,
silu, softmax, tanh,
};
/// Full crate version string, taken from `Cargo.toml` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Major component of [`VERSION`]. Keep in sync with `Cargo.toml`.
pub const VERSION_MAJOR: u32 = 0;
/// Minor component of [`VERSION`]. Keep in sync with `Cargo.toml`.
pub const VERSION_MINOR: u32 = 1;
/// Patch component of [`VERSION`]. Keep in sync with `Cargo.toml`.
// NOTE(review): was 1, which disagreed with the packaged version "0.1.2"
// asserted by `tests::test_version`; bumped to 2 to match.
pub const VERSION_PATCH: u32 = 2;
/// Version metadata for every crate in the ToRSh workspace, plus a
/// compatibility check across the enabled crates.
pub mod version {
    use crate::{Result, TorshError};

    /// Version record for a single workspace crate.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct CrateVersion {
        /// Crate name as published (e.g. `"torsh-core"`).
        pub name: &'static str,
        /// Full version string (e.g. `"0.1.2"`).
        pub version: &'static str,
        /// Major semver component.
        pub major: u32,
        /// Minor semver component.
        pub minor: u32,
        /// Patch semver component.
        pub patch: u32,
    }

    impl CrateVersion {
        /// Creates a new `CrateVersion` record from its components.
        pub const fn new(
            name: &'static str,
            version: &'static str,
            major: u32,
            minor: u32,
            patch: u32,
        ) -> Self {
            Self {
                name,
                version,
                major,
                minor,
                patch,
            }
        }

        /// Returns `true` if `self` can interoperate with `other`.
        ///
        /// Pre-1.0 (`0.x.y`) versions must match exactly on minor AND patch:
        /// every `0.x` release is treated as potentially breaking. From 1.0
        /// on, majors must match and `self`'s minor must be at least
        /// `other`'s.
        pub fn is_compatible_with(&self, other: &CrateVersion) -> bool {
            if self.major == 0 && other.major == 0 {
                return self.minor == other.minor && self.patch == other.patch;
            }
            self.major == other.major && self.minor >= other.minor
        }
    }

    /// Builds a [`CrateVersion`] from a sub-crate's exported version
    /// constants (`VERSION`, `VERSION_MAJOR`, `VERSION_MINOR`,
    /// `VERSION_PATCH`).
    macro_rules! sub_version {
        ($name:literal, $m:ident) => {
            CrateVersion::new(
                $name,
                crate::$m::VERSION,
                crate::$m::VERSION_MAJOR,
                crate::$m::VERSION_MINOR,
                crate::$m::VERSION_PATCH,
            )
        };
    }

    /// Pushes a feature-gated sub-crate's version onto `$versions` only when
    /// the corresponding cargo feature is enabled. The `#[cfg]` attribute is
    /// evaluated after macro expansion, so disabled crates produce no code.
    macro_rules! push_feature_version {
        ($versions:ident, $feat:literal, $name:literal, $m:ident) => {
            #[cfg(feature = $feat)]
            $versions.push(sub_version!($name, $m));
        };
    }

    /// Returns version records for the facade crate and every enabled
    /// sub-crate. The facade (`"torsh"`) is always the first element; the
    /// three always-on sub-crates follow, then the feature-gated ones.
    pub fn get_crate_versions() -> Vec<CrateVersion> {
        let mut versions = vec![
            CrateVersion::new(
                "torsh",
                crate::VERSION,
                crate::VERSION_MAJOR,
                crate::VERSION_MINOR,
                crate::VERSION_PATCH,
            ),
            sub_version!("torsh-core", core),
            sub_version!("torsh-tensor", tensor),
            sub_version!("torsh-autograd", autograd),
        ];
        push_feature_version!(versions, "nn", "torsh-nn", nn);
        push_feature_version!(versions, "optim", "torsh-optim", optim);
        push_feature_version!(versions, "data", "torsh-data", data);
        push_feature_version!(versions, "functional", "torsh-functional", functional);
        push_feature_version!(versions, "text", "torsh-text", text);
        push_feature_version!(versions, "vision", "torsh-vision", vision);
        push_feature_version!(versions, "sparse", "torsh-sparse", sparse);
        push_feature_version!(versions, "quantization", "torsh-quantization", quantization);
        push_feature_version!(versions, "special", "torsh-special", special);
        push_feature_version!(versions, "linalg", "torsh-linalg", linalg);
        push_feature_version!(versions, "profiler", "torsh-profiler", profiler);
        push_feature_version!(versions, "distributed", "torsh-distributed", distributed);
        push_feature_version!(versions, "jit", "torsh-jit", jit);
        push_feature_version!(versions, "fx", "torsh-fx", fx);
        push_feature_version!(versions, "hub", "torsh-hub", hub);
        push_feature_version!(versions, "backend", "torsh-backend", backend);
        versions
    }

    /// Verifies that every enabled sub-crate is version-compatible with the
    /// facade crate (the first entry of [`get_crate_versions`]).
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::Other`] naming the first incompatible crate.
    pub fn check_version_compatibility() -> Result<()> {
        let versions = get_crate_versions();
        // The facade entry always exists: the base vec has four elements.
        let main_version = &versions[0];
        for version in &versions[1..] {
            if !main_version.is_compatible_with(version) {
                return Err(TorshError::Other(format!(
                    "Version mismatch: {} {} is not compatible with {} {}",
                    main_version.name, main_version.version, version.name, version.version
                )));
            }
        }
        Ok(())
    }

    /// Prints one line per crate version to stdout.
    pub fn print_version_info() {
        println!("ToRSh Crate Versions:");
        for version in get_crate_versions() {
            println!("  {}: {}", version.name, version.version);
        }
    }
}
/// Convenience prelude: `use torsh::prelude::*;` pulls in the most commonly
/// used types, functions, and macros from every enabled sub-crate.
///
/// Several sub-crate preludes export overlapping names, so the glob
/// re-exports allow `ambiguous_glob_reexports` individually; the last
/// unambiguous path wins at the use site.
pub mod prelude {
    // Always-on sub-crate preludes.
    #[allow(ambiguous_glob_reexports)]
    pub use crate::autograd::prelude::*;
    #[allow(ambiguous_glob_reexports)]
    pub use crate::core::prelude::*;
    #[allow(ambiguous_glob_reexports)]
    pub use crate::tensor::prelude::*;
    // Feature-gated sub-crate preludes.
    #[cfg(feature = "nn")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::nn::prelude::*;
    #[cfg(feature = "optim")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::optim::prelude::*;
    #[cfg(feature = "data")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::data::prelude::*;
    // NOTE(review): `functional` re-exports a curated list instead of a glob,
    // presumably to avoid clashing with `nn`'s functional names — confirm.
    #[cfg(feature = "functional")]
    pub use crate::functional::{conv2d, gelu, max_pool2d, relu, sigmoid, silu, softmax, tanh};
    #[cfg(feature = "text")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::text::prelude::*;
    #[cfg(feature = "vision")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::vision::prelude::*;
    #[cfg(feature = "sparse")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::sparse::prelude::*;
    #[cfg(feature = "quantization")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::quantization::prelude::*;
    #[cfg(feature = "special")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::special::prelude::*;
    #[cfg(feature = "linalg")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::linalg::prelude::*;
    #[cfg(feature = "profiler")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::profiler::prelude::*;
    #[cfg(feature = "distributed")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::distributed::prelude::*;
    #[cfg(feature = "jit")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::jit::prelude::*;
    #[cfg(feature = "fx")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::fx::prelude::*;
    #[cfg(feature = "hub")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::hub::prelude::*;
    #[cfg(feature = "backend")]
    pub use crate::backend::prelude::*;
    // Core types, autograd helpers, and tensor constructors re-exported at
    // the crate root.
    pub use crate::{DType, Device, DeviceType, Result, Shape, Tensor, TorshError};
    pub use crate::{backward, enable_grad, is_grad_enabled, no_grad};
    pub use crate::{
        arange, eye, linspace, ones, ones_like, rand, rand_like, randint, randn, randn_like, zeros,
        zeros_like,
    };
    pub use crate::tensor;
    // PyTorch-style functional namespace (see `crate::F`).
    pub use crate::F;
    pub use crate::version::{check_version_compatibility, print_version_info, CrateVersion};
    pub use crate::{VERSION, VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH};
    pub use crate::features::{
        check_feature_requirements, get_enabled_features, print_feature_info,
    };
    pub use crate::macros::{device, shape, tensor_1d, tensor_2d};
}
/// PyTorch-style functional namespace, mirroring `torch.nn.functional`
/// (conventionally imported as `F` in Python — hence the non-snake-case
/// module name).
#[allow(non_snake_case)]
pub mod F {
    #[cfg(feature = "nn")]
    #[allow(ambiguous_glob_reexports)]
    pub use crate::nn::functional::*;
    // Aliased re-exports give the snake_case functional ops the CamelCase
    // names PyTorch users expect (e.g. `F::ReLU`, `F::CrossEntropyLoss`).
    #[cfg(feature = "functional")]
    pub use crate::functional::{
        avg_pool1d as AvgPool1d,
        avg_pool2d as AvgPool2d,
        avg_pool3d as AvgPool3d,
        batch_norm as BatchNorm,
        binary_cross_entropy as BCELoss,
        conv1d as Conv1d,
        conv2d as Conv2d,
        conv3d as Conv3d,
        cross_entropy as CrossEntropyLoss,
        dropout as Dropout,
        dropout2d as Dropout2d,
        dropout3d as Dropout3d,
        elu as ELU,
        gelu as GELU,
        group_norm as GroupNorm,
        instance_norm as InstanceNorm,
        l1_loss as L1Loss,
        layer_norm as LayerNorm,
        log_softmax as LogSoftmax,
        max_pool1d as MaxPool1d,
        max_pool2d as MaxPool2d,
        max_pool3d as MaxPool3d,
        mse_loss as MSELoss,
        relu as ReLU,
        sigmoid as Sigmoid,
        silu as SiLU,
        softmax as Softmax,
        tanh as Tanh,
    };
}
/// Small convenience macros. `#[macro_export]` places each macro at the
/// crate root; the trailing `pub use` additionally exposes them under
/// `torsh::macros::{...}` for explicit imports.
pub mod macros {
    /// Builds a 1-D tensor from a comma-separated element list, e.g.
    /// `tensor_1d![1.0, 2.0, 3.0]`. Expands to a call on a borrowed slice.
    #[macro_export]
    macro_rules! tensor_1d {
        [$($x:expr),* $(,)?] => {
            $crate::tensor::creation::tensor_1d(&[$($x),*])
        };
    }
    /// Builds a 2-D tensor from nested bracketed rows, e.g.
    /// `tensor_2d![[1.0, 2.0], [3.0, 4.0]]`. Rows are collected as a slice
    /// of slices before the creation call.
    #[macro_export]
    macro_rules! tensor_2d {
        [$([$($x:expr),* $(,)?]),* $(,)?] => {{
            let rows: &[&[_]] = &[$(
                &[$($x),*]
            ),*];
            $crate::tensor::creation::tensor_2d(rows)
        }};
    }
    /// Maps a device string literal to a `DeviceType` at compile time.
    /// Only `"cpu"`, `"cuda"`, and `"cuda:0"` are recognized; any other
    /// argument hits the catch-all arm and fails the build with a hint to
    /// construct `DeviceType` directly.
    #[macro_export]
    macro_rules! device {
        ("cpu") => {
            $crate::DeviceType::Cpu
        };
        ("cuda") => {
            $crate::DeviceType::Cuda(0)
        };
        ("cuda:0") => {
            $crate::DeviceType::Cuda(0)
        };
        ($device_str:expr) => {
            compile_error!("Use DeviceType directly for complex device specifications")
        };
    }
    /// Builds a `Shape` from a bracketed dimension list, e.g. `shape![2, 3]`.
    #[macro_export]
    macro_rules! shape {
        [$($dim:expr),* $(,)?] => {
            $crate::Shape::new(vec![$($dim),*])
        };
    }
    // Re-export the crate-root macro names through this module's path.
    pub use {device, shape, tensor_1d, tensor_2d};
}
/// Introspection over the cargo features this build was compiled with.
pub mod features {
    /// Compile-time information about a single cargo feature.
    #[derive(Debug, Clone)]
    pub struct FeatureInfo {
        /// Feature name exactly as it appears in `Cargo.toml`.
        pub name: &'static str,
        /// Whether the feature was enabled for this build.
        pub enabled: bool,
        /// Human-readable description.
        pub description: &'static str,
        /// Grouping used for display and statistics.
        pub category: FeatureCategory,
    }

    /// Display grouping for [`FeatureInfo`] entries.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    pub enum FeatureCategory {
        Core,
        Module,
        Advanced,
        Backend,
        Serialization,
        Performance,
        Development,
        Compatibility,
    }

    /// Builds a `FeatureInfo` whose `enabled` flag is the compile-time
    /// `cfg!(feature = ...)` check for the same name literal, guaranteeing
    /// the name and the checked feature can never drift apart.
    macro_rules! feature_info {
        ($name:literal, $desc:literal, $cat:ident) => {
            FeatureInfo {
                name: $name,
                enabled: cfg!(feature = $name),
                description: $desc,
                category: FeatureCategory::$cat,
            }
        };
    }

    /// Returns the full feature table. Despite the name, the list contains
    /// BOTH enabled and disabled features; filter on `enabled` as needed.
    /// (The name is kept for backward compatibility with existing callers.)
    pub fn get_enabled_features() -> Vec<FeatureInfo> {
        vec![
            feature_info!("std", "Standard library support", Core),
            feature_info!("no_std", "No standard library support", Core),
            feature_info!("nn", "Neural network modules", Module),
            feature_info!("optim", "Optimization algorithms", Module),
            feature_info!("data", "Data loading and preprocessing", Module),
            feature_info!("text", "Text processing capabilities", Module),
            feature_info!("vision", "Computer vision utilities", Module),
            feature_info!("functional", "Functional operations API", Module),
            feature_info!("sparse", "Sparse tensor operations", Advanced),
            feature_info!("quantization", "Model quantization", Advanced),
            feature_info!("special", "Special mathematical functions", Advanced),
            feature_info!("linalg", "Linear algebra operations", Advanced),
            feature_info!("profiler", "Performance profiling", Advanced),
            feature_info!("distributed", "Distributed computing", Advanced),
            feature_info!("jit", "Just-in-time compilation", Advanced),
            feature_info!("fx", "Graph transformations", Advanced),
            feature_info!("hub", "Model hub integration", Advanced),
            feature_info!("backend", "Unified backend system", Backend),
            feature_info!("cuda", "CUDA GPU backend with execution engine", Backend),
            feature_info!("metal", "Apple Metal GPU backend", Backend),
            feature_info!("wgpu", "WebGPU backend", Backend),
            feature_info!("rocm", "AMD ROCm GPU backend", Backend),
            feature_info!("serialize", "Basic serialization support", Serialization),
            feature_info!("serialize-hdf5", "HDF5 serialization", Serialization),
            feature_info!("serialize-arrow", "Apache Arrow serialization", Serialization),
            feature_info!("serialize-onnx", "ONNX serialization", Serialization),
            feature_info!("simd", "SIMD optimizations", Performance),
            feature_info!("parallel", "Parallel execution", Performance),
            feature_info!("fast-math", "Fast math optimizations", Performance),
            feature_info!("debug", "Debug information", Development),
            feature_info!("trace", "Tracing support", Development),
            feature_info!("bench", "Benchmarking utilities", Development),
            feature_info!("python", "Python interoperability", Compatibility),
            feature_info!("numpy", "NumPy compatibility", Compatibility),
            feature_info!("pytorch-compat", "PyTorch compatibility", Compatibility),
        ]
    }

    /// Checks that every feature named in `required_features` is enabled in
    /// this build.
    ///
    /// # Errors
    ///
    /// Returns [`crate::TorshError::Other`] naming the first missing feature.
    pub fn check_feature_requirements(required_features: &[&str]) -> crate::Result<()> {
        let enabled_features = get_enabled_features();
        let enabled_names: std::collections::HashSet<&str> = enabled_features
            .iter()
            .filter(|f| f.enabled)
            .map(|f| f.name)
            .collect();
        for required in required_features {
            if !enabled_names.contains(required) {
                return Err(crate::TorshError::Other(format!(
                    "Required feature '{}' is not enabled",
                    required
                )));
            }
        }
        Ok(())
    }

    /// Prints the feature table to stdout, grouped by category.
    ///
    /// Category order is not deterministic: it follows `HashMap` iteration
    /// order, which is unspecified.
    pub fn print_feature_info() {
        use std::collections::HashMap;
        let features = get_enabled_features();
        let mut by_category: HashMap<FeatureCategory, Vec<&FeatureInfo>> = HashMap::new();
        for feature in &features {
            by_category
                .entry(feature.category.clone())
                .or_default()
                .push(feature);
        }
        println!("ToRSh Enabled Features:");
        for (category, features) in by_category {
            let category_name = match category {
                FeatureCategory::Core => "Core",
                FeatureCategory::Module => "Modules",
                FeatureCategory::Advanced => "Advanced",
                FeatureCategory::Backend => "Backends",
                FeatureCategory::Serialization => "Serialization",
                FeatureCategory::Performance => "Performance",
                FeatureCategory::Development => "Development",
                FeatureCategory::Compatibility => "Compatibility",
            };
            println!("\n  {}:", category_name);
            for feature in features {
                let status = if feature.enabled { "✓" } else { "✗" };
                println!("    {} {}: {}", status, feature.name, feature.description);
            }
        }
    }

    /// Returns `(enabled, total)` feature counts per category.
    pub fn get_feature_stats() -> std::collections::HashMap<FeatureCategory, (usize, usize)> {
        use std::collections::HashMap;
        let features = get_enabled_features();
        let mut stats: HashMap<FeatureCategory, (usize, usize)> = HashMap::new();
        for feature in features {
            let (enabled, total) = stats.entry(feature.category.clone()).or_insert((0, 0));
            *total += 1;
            if feature.enabled {
                *enabled += 1;
            }
        }
        stats
    }
}
/// Ensures this build of ToRSh is at least version
/// `required_major.required_minor`.
///
/// # Errors
///
/// Returns [`TorshError::Other`] when the compiled-in version
/// ([`VERSION_MAJOR`].[`VERSION_MINOR`]) is older than the requested one.
pub fn check_version(required_major: u32, required_minor: u32) -> Result<()> {
    // Lexicographic tuple comparison is exactly the major-then-minor check:
    // older major, or equal major with older minor, fails the requirement.
    if (VERSION_MAJOR, VERSION_MINOR) >= (required_major, required_minor) {
        Ok(())
    } else {
        Err(TorshError::Other(format!(
            "ToRSh version {}.{} or higher required, but got {}.{}",
            required_major, required_minor, VERSION_MAJOR, VERSION_MINOR
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Elementwise add of two 1-D tensors preserves the shape.
    #[test]
    fn test_basic_operations() {
        let a = tensor![1.0, 2.0, 3.0].unwrap();
        let b = tensor![4.0, 5.0, 6.0].unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.shape().dims(), &[3]);
    }

    /// The Cargo package version must agree with the hand-maintained
    /// `VERSION_MAJOR`/`VERSION_MINOR`/`VERSION_PATCH` constants.
    #[test]
    fn test_version() {
        // Derive the expected string from the component constants instead of
        // hard-coding a literal: the old `"0.1.2"` snapshot silently drifted
        // out of sync with the constants whenever Cargo.toml was bumped.
        assert_eq!(
            VERSION,
            format!("{}.{}.{}", VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH)
        );
        // The current version always satisfies itself...
        check_version(VERSION_MAJOR, VERSION_MINOR).unwrap();
        // ...and never satisfies a strictly newer requirement.
        check_version(VERSION_MAJOR + 1, 0).unwrap_err();
    }
}