#![cfg_attr(not(feature = "std"), no_std)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::new_without_default)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::implicit_saturating_sub)]
#![allow(clippy::unwrap_or_default)]
#![allow(clippy::manual_div_ceil)]
#![allow(clippy::wrong_self_convention)]
#![allow(clippy::type_complexity)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::inherent_to_string)]
#![allow(clippy::derivable_impls)]
#![allow(clippy::needless_borrows_for_generic_args)]
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::mut_from_ref)]
#![allow(clippy::missing_transmute_annotations)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::redundant_closure)]
#![allow(clippy::manual_flatten)]
#![allow(clippy::useless_conversion)]
#![allow(clippy::identity_op)]
#![allow(clippy::len_without_is_empty)]
#![allow(dead_code)]
#[cfg(not(feature = "std"))]
extern crate alloc;
/// Unified error type shared by all backend implementations in this crate.
///
/// Variants mirror the broad failure categories a backend can report;
/// `Display` renders each with a short human-readable prefix.
#[derive(Debug, Clone)]
pub enum BackendError {
    /// A caller-supplied argument was invalid for the requested operation.
    InvalidArgument(String),
    /// The backend does not support the requested operation.
    UnsupportedOperation(String),
    /// A quantization/dequantization step failed.
    QuantizationError(String),
    /// A buffer was missing, mis-sized, or otherwise unusable.
    InvalidBuffer { message: String },
    /// A generic runtime failure inside the backend.
    Runtime { message: String },
    /// Memory (or other resource) allocation failed.
    AllocationFailed(String),
    /// A device/stream synchronization operation failed.
    SynchronizationFailed(String),
}
// Uses `core::fmt` rather than `std::fmt` so this impl also compiles when
// the crate is built with `no_std` (see the `cfg_attr` at the top of the
// file); under `std` both paths name the exact same trait, so behavior is
// unchanged for std builds.
impl core::fmt::Display for BackendError {
    /// Formats the error as a human-readable message with a
    /// category-specific prefix.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            BackendError::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
            BackendError::UnsupportedOperation(msg) => write!(f, "Unsupported operation: {}", msg),
            BackendError::QuantizationError(msg) => write!(f, "Quantization error: {}", msg),
            BackendError::InvalidBuffer { message } => write!(f, "Invalid buffer: {}", message),
            BackendError::Runtime { message } => write!(f, "Runtime error: {}", message),
            BackendError::AllocationFailed(msg) => write!(f, "Allocation failed: {}", msg),
            BackendError::SynchronizationFailed(msg) => {
                write!(f, "Synchronization failed: {}", msg)
            }
        }
    }
}
// `std::error::Error` lives in `std`, so the impl is gated on the `std`
// feature; with `no_std` enabled the `std::` path would not resolve at all.
// Behavior under std builds is unchanged.
#[cfg(feature = "std")]
impl std::error::Error for BackendError {}
pub mod adaptive_kernel_selection;
pub mod backend;
pub mod buffer;
pub mod convolution;
pub mod cross_backend_transfer;
pub mod cross_backend_validation;
pub mod deadlock_prevention;
pub mod device;
pub mod error;
pub mod fft;
pub mod hardware_optimization_tests;
pub mod introspection;
pub mod jit_compiler;
pub mod kernel;
pub mod kernel_generation;
pub mod memory;
pub mod memory_defrag;
pub mod memory_profiler;
pub mod performance_modeling;
pub mod performance_tuning;
pub mod profiler;
pub mod property_tests;
pub mod quantization;
pub mod rnn;
pub mod sparse_ops;
pub mod unified_memory_pool;
pub mod version_compat;
pub mod zero_copy;
#[cfg(feature = "cpu")]
pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
pub mod metal;
#[cfg(feature = "rocm")]
pub mod rocm;
#[cfg(feature = "webgpu")]
pub mod webgpu;
pub use adaptive_kernel_selection::{
AdaptiveKernelSelector, AdaptiveSelectionConfig, BenchmarkResult, BenchmarkResults,
CustomKernel, HybridConfig, KernelCharacteristics, KernelConstraints, KernelExecutor,
KernelImplementation, KernelInputs, KernelOutputs, KernelParameter, KernelPerformanceRecord,
KernelRegistry, KernelSelection, KernelUsageStats, KernelVariant, MLBasedConfig, MLModelType,
MLTrainingParams, PerformanceTracker, ResourceRequirements, ScalabilityCharacteristics,
ScalingBehavior, ScoreBasedConfig, SelectionAccuracyTracker, SelectionAlgorithm,
SelectionReason, SelectionStatistics,
};
pub use backend::{
Backend, BackendCapabilities, BackendCore, BackendDeviceManager, BackendExecutor,
BackendExtension, BackendExtensionRegistry, BackendFactory, BackendLifecycle,
BackendOperations, BackendOps, BackendPlugin, BackendRegistry, BackendResourceManager,
BackendType, CapabilityValue, DeviceEnumerator, ExecutionModel, ExtendedCapabilities,
HardwareFeature, MemoryHierarchy, OperationsBundle, PerformanceHints, PluginMetadata,
PrecisionMode, ResourceLimits, ResourceStatistics, ResourceUsage, ScopedResource,
};
pub use buffer::{Buffer, BufferDescriptor, BufferHandle, BufferUsage, BufferView, MemoryLocation};
pub type BufferError = BackendError;
pub use convolution::{
algorithms as conv_algorithms, ConvolutionAlgorithm, ConvolutionConfig, ConvolutionOps,
ConvolutionPerformanceHints, ConvolutionType, DefaultConvolutionOps, PaddingMode,
};
pub use cross_backend_transfer::CrossBackendTransferManager;
pub use cross_backend_validation::{
compare_f32_values, compare_f64_values, run_cross_backend_validation, CrossBackendValidator,
};
pub use device::{
Device, DeviceConfiguration, DeviceDiscovery, DeviceFeature, DeviceInfo, DeviceManager,
DevicePerformanceInfo, DeviceRequirements, DeviceType, DeviceUtils,
};
pub use error::{BackendResult, ErrorCategory, ErrorContext, ErrorSeverity};
pub use fft::{
convenience as fft_convenience, DefaultFftExecutor, DefaultFftOps, FftDirection, FftExecutor,
FftNormalization, FftOps, FftPlan, FftType,
};
pub use hardware_optimization_tests::{
run_hardware_optimization_tests, run_lightweight_hardware_tests, HardwareOptimizationTester,
};
pub use kernel::{Kernel, KernelDescriptor, KernelHandle, KernelLaunchConfig, KernelMetadata};
pub use memory::{
AccessPattern, AllocationHint, AllocationLifetime, AllocationStrategy, CompactionResult,
DefragmentationPolicy, DefragmentationPriority, DefragmentationResult, DefragmentationStrategy,
FragmentationInfo, FragmentationSeverity, FreeListPool, LeakReport, LeakSeverity, LeakType,
MemoryAdvice, MemoryManager, MemoryManagerFactory, MemoryPool, MemoryPoolConfig, MemoryStats,
PoolStats,
};
pub use memory_defrag::{
CompactionPlan, DefragmentationManager, DefragmentationRequest, DefragmentationStats,
DefragmentationTask, MemoryBlock, MemoryLayout, TaskStatus,
};
pub use memory_profiler::{
AccessType, AllocationContext, AllocationUsageStats, HintSeverity, MemoryAllocation,
MemoryPressureEvent, MemoryProfiler, MemoryProfilerConfig, MemorySnapshot, MemoryType,
PerformanceHint, PerformanceHintType, PressureLevel,
};
pub use performance_modeling::{
AnomalyDetector, AnomalySeverity, AnomalyType, ComplexityClass, CorrelationAnalyzer,
CorrelationResult, EnvironmentalFactors, ModelAccuracy, ModelComplexity, ModelTrainingResult,
PatternType, PerformanceAnomaly, PerformanceCharacteristics, PerformanceMeasurement,
PerformanceModel, PerformanceReport, PerformanceSample, PerformanceTrend, RealtimeStatistics,
RuntimeMonitor, RuntimePerformanceModeler, TrendDirection, WorkloadPattern,
};
pub use performance_tuning::{
analyze_workload_optimization_opportunities,
create_default_constraints,
create_default_system_state,
create_energy_budget_constraints,
create_image_processing_workload,
create_ml_inference_workload,
create_ml_training_workload,
create_performance_optimized_system_state,
create_power_efficient_system_state,
create_realtime_constraints,
create_sample_workload,
create_throughput_constraints,
new_coordinator,
recommend_backend,
AccessPattern as PerfAccessPattern,
ActualPerformance,
BackendTuningStrategy,
DataType,
GlobalPerformanceStats,
MemoryAllocationStrategy,
NumaTopologyState,
OperationType,
OptimizationLevel,
PerformanceFeedback,
PerformancePrediction,
PerformanceTuningCoordinator,
PowerEfficiencyMode,
PowerState,
SchedulingStrategy,
StrategyMetrics,
SystemState,
ThermalState,
TuningConstraints,
TuningParameters,
TuningRecommendation,
TuningValue,
WorkloadCharacteristics,
};
pub use profiler::{Profiler, ProfilerEvent, ProfilerStats, SimpleProfiler};
pub use quantization::{
CalibrationMethod, QuantizationCalibrator, QuantizationHardwareFeatures, QuantizationOps,
QuantizationParams, QuantizationScheme, QuantizedDType, QuantizedTensor, SimdQuantizationOps,
};
pub use rnn::{
activations as rnn_activations, cells as rnn_cells, DefaultRnnOps, RnnActivation, RnnCellType,
RnnConfig, RnnDirection, RnnOps, RnnOutput, RnnPerformanceHints,
};
pub use sparse_ops::{
DefaultSparseOps, SparseFormat, SparseFormatConverter, SparseMatrix, SparseOperation,
SparseOps, SparseOptimizationHints,
};
pub use unified_memory_pool::{
CpuMemoryPool, CudaMemoryPool, MetalMemoryPool, RocmMemoryPool, UnifiedMemoryPool,
WebGpuMemoryPool,
};
pub use version_compat::{
BackendDependency, CompatibilityReport, DependencyStatus, Version, VersionCompatibilityChecker,
VersionError, VersionErrorContextExt, VersionRange,
};
pub use zero_copy::{
TransferDirection, TransferMode, ZeroCopyCapabilities, ZeroCopyManager, ZeroCopyStats,
ZeroCopyTransfer,
};
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const VERSION_MAJOR: u32 = 0;
pub const VERSION_MINOR: u32 = 1;
pub const VERSION_PATCH: u32 = 0;
/// Reports whether CUDA acceleration is available at runtime.
///
/// Delegates to `cuda::is_available`.
/// NOTE(review): despite the generic name, this crate-level probe is
/// CUDA-specific — confirm callers expect that.
#[cfg(feature = "cuda")]
pub fn is_available() -> bool {
    cuda::is_available()
}
/// CUDA availability probe for builds without the `cuda` feature:
/// always `false`.
#[cfg(not(feature = "cuda"))]
pub fn is_available() -> bool {
    false
}
#[cfg(feature = "cpu")]
pub use cpu::{prepare_tensor_data, prepare_tensor_data_mut, SciRS2CpuBackend};
use torsh_core::error::TorshError;

// `string::String` added so `BackendError`'s `String` fields resolve under
// `no_std` (alloc's String is not in the core prelude); no effect on std builds.
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
/// Fluent builder for constructing a boxed [`Backend`].
///
/// Configure via the chained setters, then call `build()`; `BackendType::Auto`
/// (the default) probes GPU backends first and falls back to CPU.
pub struct BackendBuilder {
    // Which backend to construct; `Auto` triggers automatic selection.
    backend_type: BackendType,
    // Target device index (forwarded to the CUDA and WebGPU builds).
    device_id: usize,
    // Optional memory-pool configuration forwarded to the chosen backend.
    memory_pool_config: Option<MemoryPoolConfig>,
    // Worker-thread override (consumed by the CPU backend build only).
    num_threads: Option<usize>,
    // NOTE(review): stored but not read by any `build_*` path in this file —
    // confirm whether a backend is supposed to consume it.
    enable_profiling: bool,
}
impl Default for BackendBuilder {
fn default() -> Self {
Self::new()
}
}
impl BackendBuilder {
    /// Creates a builder with defaults: automatic backend selection,
    /// device 0, no memory-pool configuration, library-chosen thread
    /// count, and profiling disabled.
    pub fn new() -> Self {
        Self {
            backend_type: BackendType::Auto,
            device_id: 0,
            memory_pool_config: None,
            num_threads: None,
            enable_profiling: false,
        }
    }
    /// Selects which backend `build()` should construct.
    pub fn backend_type(mut self, backend_type: BackendType) -> Self {
        self.backend_type = backend_type;
        self
    }
    /// Selects the target device index (used by the CUDA and WebGPU builds).
    pub fn device_id(mut self, device_id: usize) -> Self {
        self.device_id = device_id;
        self
    }
    /// Supplies a memory-pool configuration to forward to the backend.
    pub fn memory_pool(mut self, config: MemoryPoolConfig) -> Self {
        self.memory_pool_config = Some(config);
        self
    }
    /// Overrides the worker-thread count (honored by the CPU build only).
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = Some(num_threads);
        self
    }
    /// Toggles profiling.
    /// NOTE(review): the flag is stored but not consulted by any build
    /// path in this file — confirm where it is consumed.
    pub fn enable_profiling(mut self, enable: bool) -> Self {
        self.enable_profiling = enable;
        self
    }
    /// Consumes the builder and constructs the selected backend.
    ///
    /// # Errors
    /// Returns an error when the requested backend is not compiled in,
    /// not yet implemented (ROCm), or fails to initialize.
    pub fn build(self) -> BackendResult<Box<dyn Backend>> {
        match self.backend_type {
            BackendType::Auto => Self::auto_select(self),
            BackendType::Cpu => Self::build_cpu(self),
            BackendType::Cuda => Self::build_cuda(self),
            BackendType::Metal => Self::build_metal(self),
            BackendType::Rocm => Self::build_rocm(self),
            BackendType::WebGpu => Self::build_webgpu(self),
        }
    }
    /// Tries the compiled-in GPU backends in preference order
    /// (CUDA, Metal, ROCm, WebGPU) and falls back to the CPU backend
    /// if none of them builds successfully.
    fn auto_select(builder: Self) -> BackendResult<Box<dyn Backend>> {
        #[cfg(feature = "cuda")]
        if let Ok(backend) = Self::build_cuda(builder.clone()) {
            return Ok(backend);
        }
        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
        if let Ok(backend) = Self::build_metal(builder.clone()) {
            return Ok(backend);
        }
        #[cfg(feature = "rocm")]
        if let Ok(backend) = Self::build_rocm(builder.clone()) {
            return Ok(backend);
        }
        #[cfg(feature = "webgpu")]
        if let Ok(backend) = Self::build_webgpu(builder.clone()) {
            return Ok(backend);
        }
        Self::build_cpu(builder)
    }
    /// Builds the CPU backend, forwarding the thread-count and
    /// memory-pool settings when present.
    #[cfg(feature = "cpu")]
    fn build_cpu(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut cpu_builder = cpu::CpuBackend::builder();
        if let Some(num_threads) = builder.num_threads {
            cpu_builder = cpu_builder.num_threads(num_threads);
        }
        if let Some(pool_config) = builder.memory_pool_config {
            cpu_builder = cpu_builder.memory_pool(pool_config);
        }
        Ok(Box::new(cpu_builder.build()?))
    }
    /// Stub used when the `cpu` feature is disabled.
    #[cfg(not(feature = "cpu"))]
    fn build_cpu(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("CPU backend not enabled".into()))
    }
    /// Builds the CUDA backend targeting `builder.device_id`.
    #[cfg(feature = "cuda")]
    fn build_cuda(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut cuda_builder = cuda::CudaBackend::builder();
        cuda_builder = cuda_builder.device(builder.device_id);
        if let Some(pool_config) = builder.memory_pool_config {
            cuda_builder = cuda_builder.memory_pool(pool_config);
        }
        Ok(Box::new(cuda_builder.build()?))
    }
    /// Stub used when the `cuda` feature is disabled.
    #[cfg(not(feature = "cuda"))]
    fn build_cuda(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("CUDA backend not enabled".into()))
    }
    /// Builds the Metal backend (compiled only for Apple Silicon macOS).
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    fn build_metal(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut metal_builder = metal::MetalBackend::builder();
        if let Some(pool_config) = builder.memory_pool_config {
            metal_builder = metal_builder.memory_pool(pool_config);
        }
        Ok(Box::new(metal_builder.build()?))
    }
    /// Stub used when Metal support is not compiled in.
    #[cfg(not(all(feature = "metal", target_os = "macos", target_arch = "aarch64")))]
    fn build_metal(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("Metal backend not enabled".into()))
    }
    /// ROCm is feature-gated but the backend itself is not implemented
    /// yet, so even with the feature enabled this always errors.
    #[cfg(feature = "rocm")]
    fn build_rocm(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError(
            "ROCm backend not yet implemented".into(),
        ))
    }
    /// Stub used when the `rocm` feature is disabled.
    #[cfg(not(feature = "rocm"))]
    fn build_rocm(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("ROCm backend not enabled".into()))
    }
    /// Builds the WebGPU backend, mapping the pool's `max_size` to the
    /// maximum buffer size and enabling the pipeline cache.
    #[cfg(feature = "webgpu")]
    fn build_webgpu(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut webgpu_builder = webgpu::WebGpuBackendBuilder::new();
        webgpu_builder = webgpu_builder.device_id(builder.device_id);
        if let Some(pool_config) = builder.memory_pool_config {
            if let Some(max_size) = pool_config.max_size {
                webgpu_builder = webgpu_builder.max_buffer_size(max_size as u64);
            }
        }
        webgpu_builder = webgpu_builder.enable_pipeline_cache(true);
        // Note: unlike the other backends, this `build()` is used without
        // `?`, i.e. it is infallible here.
        Ok(Box::new(webgpu_builder.build()))
    }
    /// Stub used when the `webgpu` feature is disabled.
    #[cfg(not(feature = "webgpu"))]
    fn build_webgpu(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError(
            "WebGPU backend not enabled".into(),
        ))
    }
}
impl Clone for BackendBuilder {
fn clone(&self) -> Self {
Self {
backend_type: self.backend_type,
device_id: self.device_id,
memory_pool_config: self.memory_pool_config.clone(),
num_threads: self.num_threads,
enable_profiling: self.enable_profiling,
}
}
}
/// Builds a backend using automatic selection (GPU backends first,
/// CPU fallback). Shorthand for building with `BackendType::Auto`.
pub fn auto() -> BackendResult<Box<dyn Backend>> {
    BackendBuilder::default().build()
}
/// Builds the CPU backend with default settings.
pub fn cpu() -> BackendResult<Box<dyn Backend>> {
    BackendBuilder::default()
        .backend_type(BackendType::Cpu)
        .build()
}
/// Builds the CUDA backend with default settings.
///
/// # Errors
/// Fails when the `cuda` feature is disabled or initialization fails.
pub fn cuda() -> BackendResult<Box<dyn Backend>> {
    let builder = BackendBuilder::new().backend_type(BackendType::Cuda);
    builder.build()
}
/// Builds the Metal backend with default settings.
///
/// # Errors
/// Fails when Metal support is not compiled in or initialization fails.
pub fn metal() -> BackendResult<Box<dyn Backend>> {
    let builder = BackendBuilder::new().backend_type(BackendType::Metal);
    builder.build()
}
/// Lists the backend types that are both compiled in and currently usable.
///
/// The CPU backend is included whenever its feature is enabled; GPU
/// backends are additionally probed at runtime via their module-level
/// `is_available()` checks. `BackendType::Auto` is never included.
#[allow(clippy::vec_init_then_push)]
pub fn available_backends() -> Vec<BackendType> {
    let mut backends = vec![];
    #[cfg(feature = "cpu")]
    backends.push(BackendType::Cpu);
    #[cfg(feature = "cuda")]
    if cuda::is_available() {
        backends.push(BackendType::Cuda);
    }
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    if metal::is_available() {
        backends.push(BackendType::Metal);
    }
    #[cfg(feature = "rocm")]
    if rocm::is_available() {
        backends.push(BackendType::Rocm);
    }
    #[cfg(feature = "webgpu")]
    if webgpu::is_available() {
        backends.push(BackendType::WebGpu);
    }
    backends
}
/// Enumerates devices for every compiled-in, runtime-available backend.
///
/// Returns one `(BackendType, Vec<Device>)` entry per backend that could
/// be constructed and report devices. Backends that fail to build are
/// silently skipped (best effort), so the result can be empty when
/// nothing is available.
pub fn enumerate_all_devices() -> BackendResult<Vec<(BackendType, Vec<Device>)>> {
    let mut all_devices = Vec::new();
    #[cfg(feature = "cpu")]
    {
        match cpu() {
            Ok(backend) => {
                if let Ok(devices) = backend.devices() {
                    all_devices.push((BackendType::Cpu, devices));
                }
            }
            Err(_) => {
                // Best effort: a CPU backend failure just omits the entry.
            }
        }
    }
    #[cfg(feature = "cuda")]
    if cuda::is_available() {
        for device_id in 0..cuda::device_count().unwrap_or(0) {
            match BackendBuilder::new()
                .backend_type(BackendType::Cuda)
                .device_id(device_id as usize)
                .build()
            {
                Ok(backend) => {
                    if let Ok(devices) = backend.devices() {
                        all_devices.push((BackendType::Cuda, devices));
                        // One successful backend is enough — `devices()`
                        // presumably reports all CUDA devices; TODO confirm.
                        break;
                    }
                }
                // Try the next device id if this one fails to build.
                Err(_) => continue,
            }
        }
    }
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    if metal::is_available() {
        match BackendBuilder::new()
            .backend_type(BackendType::Metal)
            .build()
        {
            Ok(backend) => {
                if let Ok(devices) = backend.devices() {
                    if !devices.is_empty() {
                        all_devices.push((BackendType::Metal, devices));
                    }
                }
            }
            Err(_) => {
                // Best effort: skip Metal if the backend fails to build.
            }
        }
    }
    #[cfg(feature = "webgpu")]
    if webgpu::is_available() {
        match BackendBuilder::new()
            .backend_type(BackendType::WebGpu)
            .build()
        {
            Ok(backend) => {
                if let Ok(devices) = backend.devices() {
                    if !devices.is_empty() {
                        all_devices.push((BackendType::WebGpu, devices));
                    }
                }
            }
            Err(_) => {
                // Best effort: skip WebGPU if the backend fails to build.
            }
        }
    }
    Ok(all_devices)
}
/// Finds the most suitable `(BackendType, Device)` pair.
///
/// Resolution order:
/// 1. the first device matched by `selector` (or the default selector),
/// 2. otherwise the first device of the highest-preference backend
///    (CUDA > Metal > WebGPU > CPU),
/// 3. otherwise the first device of the first enumerated backend.
///
/// # Errors
/// Returns an error when no devices are available at all.
pub fn find_best_device(
    selector: Option<device::DeviceSelector>,
) -> BackendResult<(BackendType, Device)> {
    let all_devices = enumerate_all_devices()?;
    if all_devices.is_empty() {
        return Err(TorshError::BackendError("No devices available".into()));
    }
    let selector = selector.unwrap_or_default();
    // First pass: the first device, in enumeration order, that the
    // selector accepts.
    let matched = all_devices.iter().find_map(|(backend_type, devices)| {
        devices
            .iter()
            .find(|device| selector.matches(device))
            .map(|device| (*backend_type, device.clone()))
    });
    if let Some(found) = matched {
        return Ok(found);
    }
    // Second pass: fall back to a fixed backend-preference ordering.
    let preference_order = [
        BackendType::Cuda,
        BackendType::Metal,
        BackendType::WebGpu,
        BackendType::Cpu,
    ];
    for preferred in preference_order {
        let hit = all_devices
            .iter()
            .find(|(backend_type, devices)| *backend_type == preferred && !devices.is_empty());
        if let Some((backend_type, devices)) = hit {
            return Ok((*backend_type, devices[0].clone()));
        }
    }
    // Last resort: whatever backend was enumerated first.
    let (backend_type, devices) = &all_devices[0];
    Ok((*backend_type, devices[0].clone()))
}
/// Returns how many devices are available for `backend_type`.
///
/// Disabled or unavailable backends report `Ok(0)` rather than an error;
/// `BackendType::Auto` sums the counts of every available backend.
pub fn device_count(backend_type: BackendType) -> BackendResult<usize> {
    match backend_type {
        // The CPU backend always exposes exactly one logical device.
        BackendType::Cpu => Ok(1),
        #[cfg(feature = "cuda")]
        BackendType::Cuda => {
            if cuda::is_available() {
                Ok(cuda::device_count().unwrap_or(0) as usize)
            } else {
                Ok(0)
            }
        }
        #[cfg(not(feature = "cuda"))]
        BackendType::Cuda => Ok(0),
        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
        BackendType::Metal => {
            if metal::is_available() {
                Ok(metal::device_count().unwrap_or(0))
            } else {
                Ok(0)
            }
        }
        #[cfg(not(all(feature = "metal", target_os = "macos", target_arch = "aarch64")))]
        BackendType::Metal => Ok(0),
        #[cfg(feature = "webgpu")]
        BackendType::WebGpu => {
            if webgpu::is_available() {
                Ok(webgpu::device_count().unwrap_or(0))
            } else {
                Ok(0)
            }
        }
        #[cfg(not(feature = "webgpu"))]
        BackendType::WebGpu => Ok(0),
        // ROCm backend is not implemented yet (see `BackendBuilder::build_rocm`),
        // so it reports zero devices unconditionally.
        BackendType::Rocm => Ok(0),
        BackendType::Auto => {
            let mut total = 0;
            for backend in available_backends() {
                // `available_backends()` never yields `Auto`, but guard
                // against it anyway to avoid recursion.
                if backend != BackendType::Auto {
                    total += device_count(backend)?;
                }
            }
            Ok(total)
        }
    }
}
/// Convenience prelude re-exporting the most commonly used items.
///
/// Intended usage: `use <crate>::prelude::*;`.
pub mod prelude {
    pub use crate::{
        auto,
        available_backends,
        compare_f32_values,
        compare_f64_values,
        cpu,
        cuda,
        device_count,
        enumerate_all_devices,
        find_best_device,
        metal,
        run_cross_backend_validation,
        run_hardware_optimization_tests,
        run_lightweight_hardware_tests,
        AdaptiveKernelSelector,
        Backend,
        BackendBuilder,
        BackendCapabilities,
        BackendOps,
        BackendPlugin,
        BackendRegistry,
        BackendResourceManager,
        BackendResult,
        BackendType,
        BenchmarkResult,
        Buffer,
        CompactionPlan,
        CrossBackendValidator,
        DefragmentationManager,
        DefragmentationStats,
        Device,
        ExecutionModel,
        ExtendedCapabilities,
        HardwareFeature,
        HardwareOptimizationTester,
        KernelImplementation,
        KernelSelection,
        KernelVariant,
        MemoryHierarchy,
        MemoryPool,
        OperationType,
        PerformanceMeasurement,
        PerformancePrediction,
        PerformanceReport,
        PerformanceTrend,
        PerformanceTuningCoordinator,
        PluginMetadata,
        PrecisionMode,
        ResourceLimits,
        ResourceStatistics,
        ResourceUsage,
        RuntimePerformanceModeler,
        SelectionAlgorithm,
        TransferDirection,
        TransferMode,
        TuningParameters,
        TuningRecommendation,
        WorkloadCharacteristics,
        ZeroCopyCapabilities,
        ZeroCopyManager,
        ZeroCopyStats,
        ZeroCopyTransfer,
        VERSION,
        VERSION_MAJOR,
        VERSION_MINOR,
        VERSION_PATCH,
    };
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_backend_builder() {
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.device_id(0);
let result = builder.build();
if let Err(e) = &result {
eprintln!("Backend build failed: {:?}", e);
}
assert!(result.is_ok());
}
#[test]
fn test_available_backends() {
let backends = available_backends();
assert!(!backends.is_empty());
assert!(backends.contains(&BackendType::Cpu));
}
#[test]
fn test_device_count() {
assert_eq!(device_count(BackendType::Cpu).unwrap(), 1);
let auto_count = device_count(BackendType::Auto).unwrap();
assert!(auto_count >= 1);
for backend_type in available_backends() {
if backend_type != BackendType::Auto {
let count = device_count(backend_type).unwrap();
assert!(count < usize::MAX); }
}
}
#[test]
fn test_enumerate_all_devices() {
let devices = enumerate_all_devices().unwrap();
assert!(!devices.is_empty());
let has_cpu = devices
.iter()
.any(|(backend_type, _)| *backend_type == BackendType::Cpu);
assert!(has_cpu);
for (backend_type, device_list) in &devices {
assert!(
!device_list.is_empty(),
"Backend {:?} should have at least one device",
backend_type
);
}
}
#[test]
fn test_find_best_device() {
let (backend_type, device) = find_best_device(None).unwrap();
assert!(matches!(
backend_type,
BackendType::Cpu | BackendType::Cuda | BackendType::Metal | BackendType::WebGpu
));
assert!(!device.name().is_empty());
}
#[test]
fn test_find_best_device_with_selector() {
use crate::device::{DeviceSelector, DeviceType};
let selector = DeviceSelector::new().with_device_type(DeviceType::Cpu);
let result = find_best_device(Some(selector));
assert!(result.is_ok());
let (backend_type, device) = result.unwrap();
assert_eq!(backend_type, BackendType::Cpu);
assert_eq!(device.device_type(), torsh_core::device::DeviceType::Cpu);
}
#[test]
fn test_unified_error_handling() {
use crate::error::{conversion, ErrorContext};
let context = ErrorContext::new("test_operation")
.with_backend("TestBackend")
.with_device("test:0")
.with_details("test details");
let formatted = context.format();
assert!(formatted.contains("test_operation"));
assert!(formatted.contains("backend: TestBackend"));
assert!(formatted.contains("device: test:0"));
assert!(formatted.contains("details: test details"));
let cuda_error =
conversion::cuda_error_with_context("Test CUDA error", "test_kernel", Some(0));
let error_str = cuda_error.to_string();
assert!(error_str.contains("CUDA"));
assert!(error_str.contains("test_kernel"));
assert!(error_str.contains("cuda:0"));
let cpu_error = conversion::cpu_error_with_context("Test CPU error", "test_operation");
let error_str = cpu_error.to_string();
assert!(error_str.contains("CPU"));
assert!(error_str.contains("test_operation"));
let memory_error =
conversion::memory_error_with_context("Out of memory", 1024, "CUDA", Some("cuda:0"));
let error_str = memory_error.to_string();
assert!(error_str.contains("memory_allocation"));
assert!(error_str.contains("1024 bytes"));
assert!(error_str.contains("CUDA"));
assert!(error_str.contains("cuda:0"));
}
#[test]
fn test_error_context_extension() {
use torsh_core::error::TorshError;
let result: Result<(), TorshError> =
Err(TorshError::ComputeError("Test error".to_string()));
let with_context = crate::error::ErrorContextExt::with_operation(result, "test_operation");
assert!(with_context.is_err());
let error_str = with_context.unwrap_err().to_string();
assert!(error_str.contains("test_operation"));
assert!(error_str.contains("Test error"));
}
#[test]
fn test_invalid_device_id_error() {
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.device_id(999);
let backend = builder.build().unwrap();
let result = backend.create_device(999);
assert!(result.is_err());
let error_str = result.unwrap_err().to_string();
assert!(error_str.contains("999"));
assert!(error_str.contains("not found"));
}
#[test]
fn test_backend_builder_invalid_thread_count() {
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.num_threads(0);
let result = builder.build();
assert!(result.is_ok());
}
#[test]
fn test_backend_builder_extreme_thread_count() {
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.num_threads(10000);
let result = builder.build();
if let Err(ref e) = result {
eprintln!("Backend build failed with extreme thread count: {:?}", e);
}
assert!(result.is_ok());
}
#[test]
fn test_unavailable_backend_selection() {
#[cfg(not(feature = "cuda"))]
{
let builder = BackendBuilder::new().backend_type(BackendType::Cuda);
let result = builder.build();
assert!(result.is_err());
let error_str = result.unwrap_err().to_string();
assert!(error_str.contains("not enabled"));
}
#[cfg(not(feature = "metal"))]
{
let builder = BackendBuilder::new().backend_type(BackendType::Metal);
let result = builder.build();
assert!(result.is_err());
let error_str = result.unwrap_err().to_string();
assert!(error_str.contains("not enabled"));
}
}
#[test]
fn test_device_count_edge_cases() {
#[cfg(not(feature = "cuda"))]
{
let count = device_count(BackendType::Cuda).unwrap();
assert_eq!(count, 0);
}
#[cfg(not(feature = "metal"))]
{
let count = device_count(BackendType::Metal).unwrap();
assert_eq!(count, 0);
}
let count = device_count(BackendType::Rocm).unwrap();
assert_eq!(count, 0);
}
#[test]
fn test_find_best_device_no_match() {
use crate::device::{DeviceSelector, DeviceType};
let selector = DeviceSelector::new().with_device_type(DeviceType::Cuda);
let result = find_best_device(Some(selector));
assert!(result.is_ok());
}
#[test]
fn test_memory_pool_config_edge_cases() {
let config = MemoryPoolConfig::new(0); assert_eq!(config.initial_size, 0);
let config = MemoryPoolConfig::new(usize::MAX); assert_eq!(config.initial_size, usize::MAX);
let config = MemoryPoolConfig::new(1024).with_growth_factor(0.0);
assert_eq!(config.growth_factor, 0.0);
let config = MemoryPoolConfig::new(1024).with_growth_factor(-1.0);
assert_eq!(config.growth_factor, -1.0); }
#[test]
fn test_memory_pool_config_alignment_edge_cases() {
let config = MemoryPoolConfig::new(1024).with_alignment(0);
assert_eq!(config.alignment, 0);
let config = MemoryPoolConfig::new(1024).with_alignment(1);
assert_eq!(config.alignment, 1);
let config = MemoryPoolConfig::new(1024).with_alignment(4096);
assert_eq!(config.alignment, 4096); }
#[test]
fn test_error_handling_with_long_messages() {
use crate::error::conversion;
let long_message = "x".repeat(10000);
let error = conversion::cpu_error_with_context(long_message.clone(), "test_operation");
let error_str = error.to_string();
assert!(error_str.contains(&long_message));
assert!(error_str.len() > 10000);
}
#[test]
fn test_error_handling_with_special_characters() {
use crate::error::conversion;
let special_message = "Error: 測試 ñoño 🚀 \n\t\r";
let error = conversion::cpu_error_with_context(special_message, "test_unicode_operation");
let error_str = error.to_string();
assert!(error_str.contains("測試"));
assert!(error_str.contains("🚀"));
}
#[test]
fn test_concurrent_backend_creation() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
let success_count = Arc::new(AtomicUsize::new(0));
let error_count = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..10 {
let success_count = Arc::clone(&success_count);
let error_count = Arc::clone(&error_count);
let handle = thread::spawn(move || {
let builder = BackendBuilder::new().backend_type(BackendType::Cpu);
match builder.build() {
Ok(_) => success_count.fetch_add(1, Ordering::Relaxed),
Err(_) => error_count.fetch_add(1, Ordering::Relaxed),
};
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
let successes = success_count.load(Ordering::Relaxed);
assert!(
successes > 0,
"No backend creation succeeded in concurrent test"
);
}
#[test]
fn test_backend_memory_pressure_simulation() {
let backend = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.memory_pool(MemoryPoolConfig::new(1024)) .build()
.unwrap();
let device = backend.default_device().unwrap();
assert!(!device.name().is_empty());
}
#[test]
fn test_enumerate_devices_consistency() {
let devices1 = enumerate_all_devices().unwrap();
let devices2 = enumerate_all_devices().unwrap();
assert_eq!(devices1.len(), devices2.len());
let backend_types1: std::collections::HashSet<_> =
devices1.iter().map(|(bt, _)| *bt).collect();
let backend_types2: std::collections::HashSet<_> =
devices2.iter().map(|(bt, _)| *bt).collect();
assert_eq!(backend_types1, backend_types2);
}
#[test]
fn test_device_selector_empty_criteria() {
use crate::device::DeviceSelector;
let selector = DeviceSelector::new();
let result = find_best_device(Some(selector));
assert!(result.is_ok());
}
#[test]
fn test_backend_builder_chain_operations() {
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.device_id(0)
.num_threads(4)
.memory_pool(MemoryPoolConfig::new(1024 * 1024))
.enable_profiling(true);
let result = builder.build();
assert!(result.is_ok());
}
#[test]
fn test_auto_backend_selection_fallback() {
let builder = BackendBuilder::new().backend_type(BackendType::Auto);
let result = builder.build();
assert!(result.is_ok());
let backend = result.unwrap();
let devices = backend.devices().unwrap();
assert!(!devices.is_empty());
}
#[test]
fn test_memory_pool_zero_max_size() {
let config = MemoryPoolConfig::new(1024).with_max_size(0);
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.memory_pool(config);
let result = builder.build();
match result {
Ok(_) => {
}
Err(_) => {
}
}
}
#[test]
fn test_memory_pool_negative_growth_factor() {
let config = MemoryPoolConfig::new(1024).with_growth_factor(-0.5);
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.memory_pool(config);
let result = builder.build();
match result {
Ok(_) => {
}
Err(_) => {
}
}
}
#[test]
fn test_device_selector_with_conflicting_criteria() {
use crate::device::{DeviceSelector, DeviceType};
let selector = DeviceSelector::new()
.with_device_type(DeviceType::Cpu)
.with_device_type(DeviceType::Cuda);
let result = find_best_device(Some(selector));
assert!(result.is_ok());
}
#[test]
fn test_backend_builder_cloning_with_modifications() {
let original_builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.num_threads(2);
let mut cloned_builder = original_builder.clone();
cloned_builder = cloned_builder.num_threads(4);
let original_result = original_builder.build();
let cloned_result = cloned_builder.build();
assert!(original_result.is_ok());
assert!(cloned_result.is_ok());
}
#[test]
fn test_error_context_with_empty_strings() {
use crate::error::ErrorContext;
let context = ErrorContext::new("")
.with_backend("")
.with_device("")
.with_details("");
let formatted = context.format();
assert!(!formatted.is_empty());
}
#[test]
fn test_error_context_with_null_characters() {
use crate::error::ErrorContext;
let context = ErrorContext::new("op\0eration")
.with_backend("back\0end")
.with_device("dev\0ice")
.with_details("deta\0ils");
let formatted = context.format();
assert!(!formatted.is_empty());
}
#[test]
fn test_memory_manager_extreme_alignment() {
let config = MemoryPoolConfig::new(1024).with_alignment(usize::MAX);
let builder = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.memory_pool(config);
let result = builder.build();
match result {
Ok(_) => {
}
Err(_) => {
}
}
}
#[test]
fn test_backend_resource_cleanup() {
let backend = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.build()
.unwrap();
let _device = backend.default_device().unwrap();
let _devices = backend.devices().unwrap();
drop(backend);
}
#[test]
fn test_available_backends_consistency() {
let backends1 = available_backends();
let backends2 = available_backends();
assert_eq!(backends1, backends2);
assert!(backends1.contains(&BackendType::Cpu));
assert!(!backends1.contains(&BackendType::Auto));
}
#[test]
fn test_device_count_consistency() {
for backend_type in available_backends() {
let count1 = device_count(backend_type).unwrap();
let count2 = device_count(backend_type).unwrap();
assert_eq!(
count1, count2,
"Device count should be consistent for {:?}",
backend_type
);
}
}
#[test]
fn test_enumerate_devices_with_no_backends() {
let devices = enumerate_all_devices().unwrap();
assert!(!devices.is_empty());
}
#[test]
fn test_backend_capability_reporting() {
let backend = BackendBuilder::new()
.backend_type(BackendType::Cpu)
.build()
.unwrap();
let capabilities = backend.capabilities();
assert!(!capabilities.supported_dtypes.is_empty());
assert!(capabilities
.supported_dtypes
.contains(&torsh_core::DType::F32));
assert!(capabilities
.supported_dtypes
.contains(&torsh_core::DType::F64));
}
#[test]
fn test_error_recovery_and_retry_logic() {
    // Try to build a single-threaded CPU backend up to three times, backing
    // off briefly between failures. If every attempt fails, we only require
    // that the final error renders a non-empty message.
    let max_retries = 3;
    let mut attempt = 0;
    loop {
        match BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .num_threads(1)
            .build()
        {
            Ok(_) => break,
            Err(e) => {
                attempt += 1;
                if attempt >= max_retries {
                    assert!(!e.to_string().is_empty());
                    break;
                }
                // Small backoff before retrying.
                std::thread::sleep(std::time::Duration::from_millis(10));
            }
        }
    }
}
#[test]
fn test_backend_performance_hints() {
    // Performance hints must suggest a positive, sanely bounded batch size.
    let cpu = BackendBuilder::new()
        .backend_type(BackendType::Cpu)
        .build()
        .unwrap();
    let batch = cpu.performance_hints().optimal_batch_size;
    assert!(batch > 0);
    // Upper bound of 1 MiB elements keeps the hint in a plausible range.
    assert!(batch <= 1024 * 1024);
}
#[test]
fn test_cross_backend_type_compatibility() {
    // The CPU backend must always be buildable.
    let cpu_result = BackendBuilder::new().backend_type(BackendType::Cpu).build();
    assert!(cpu_result.is_ok());
    // CUDA may legitimately fail on machines without the hardware, so either
    // outcome is acceptable — the build must simply not panic. Replaces a
    // match with two empty arms.
    #[cfg(feature = "cuda")]
    {
        let _ = BackendBuilder::new()
            .backend_type(BackendType::Cuda)
            .build();
    }
}
#[test]
fn test_backend_state_isolation() {
    // Two backends with different thread configurations must coexist without
    // interfering; each should still hand out a usable default device.
    let two_threads = BackendBuilder::new()
        .backend_type(BackendType::Cpu)
        .num_threads(2)
        .build()
        .unwrap();
    let four_threads = BackendBuilder::new()
        .backend_type(BackendType::Cpu)
        .num_threads(4)
        .build()
        .unwrap();
    let dev_a = two_threads.default_device().unwrap();
    let dev_b = four_threads.default_device().unwrap();
    assert!(!dev_a.name().is_empty());
    assert!(!dev_b.name().is_empty());
}
#[test]
fn test_profiling_enablement() {
    // Building with profiling either enabled or disabled must yield a working
    // backend whose default device reports a non-empty name.
    for profiling in [true, false] {
        let backend = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .enable_profiling(profiling)
            .build()
            .unwrap();
        let device = backend.default_device().unwrap();
        assert!(!device.name().is_empty());
    }
}
#[test]
#[ignore = "Requires CUDA hardware - run with --ignored flag"]
fn test_cross_backend_validation_integration() {
    use crate::cross_backend_validation::{
        run_cross_backend_validation, CrossBackendValidator,
    };
    // The validator must see at least one backend. The individual validation
    // steps are best-effort: a failure only logs a warning, never fails the test.
    let validator = CrossBackendValidator::new();
    assert!(!validator.available_backends().is_empty());
    if let Err(e) = validator.validate_device_creation() {
        eprintln!("Device creation validation warning: {}", e);
    }
    if let Err(e) = validator.validate_capabilities_consistency() {
        eprintln!("Capabilities consistency validation warning: {}", e);
    }
    if let Err(e) = run_cross_backend_validation() {
        eprintln!("Cross-backend validation warning: {}", e);
    }
}
#[test]
fn test_floating_point_comparison_utilities() {
    use crate::cross_backend_validation::{compare_f32_values, compare_f64_values};
    // Table-driven checks: (lhs, rhs, tolerance, expected result).
    // Note: these helpers treat NaN as equal to NaN (unlike IEEE `==`), and
    // infinities must match in sign to compare equal.
    let f32_cases: &[(f32, f32, f32, bool)] = &[
        (1.0, 1.0, 1e-6, true),
        (1.0, 1.0000005, 1e-6, true),
        (1.0, 1.1, 1e-6, false),
        (f32::NAN, f32::NAN, 1e-6, true),
        (f32::INFINITY, f32::INFINITY, 1e-6, true),
        (f32::INFINITY, f32::NEG_INFINITY, 1e-6, false),
    ];
    for &(a, b, tol, expected) in f32_cases {
        assert_eq!(compare_f32_values(a, b, tol), expected);
    }
    let f64_cases: &[(f64, f64, f64, bool)] = &[
        (1.0, 1.0, 1e-11, true),
        (1.0, 1.00000000001, 1.1e-11, true),
        (1.0, 1.1, 1e-11, false),
        (f64::NAN, f64::NAN, 1e-12, true),
        (f64::INFINITY, f64::INFINITY, 1e-12, true),
        (f64::INFINITY, f64::NEG_INFINITY, 1e-12, false),
    ];
    for &(a, b, tol, expected) in f64_cases {
        assert_eq!(compare_f64_values(a, b, tol), expected);
    }
}
#[test]
fn test_hardware_optimization_integration() {
    use crate::hardware_optimization_tests::{
        run_lightweight_hardware_tests, HardwareOptimizationTester,
    };
    // A fresh tester enables the SIMD and platform checks but not the (slow)
    // performance suite.
    let tester = HardwareOptimizationTester::new();
    assert!(tester.simd_tests_enabled);
    assert!(tester.platform_tests_enabled);
    assert!(!tester.performance_tests_enabled);
    // The lightweight suite is best-effort: a failure only logs a warning.
    if let Err(e) = run_lightweight_hardware_tests() {
        eprintln!("Hardware optimization tests warning: {}", e);
    }
}
#[test]
fn test_hardware_optimization_tester_configuration() {
    use crate::hardware_optimization_tests::HardwareOptimizationTester;
    // The tester's feature toggles are plain public fields and must round-trip
    // whatever the caller assigns.
    let mut configured = HardwareOptimizationTester::new();
    configured.performance_tests_enabled = false;
    configured.platform_tests_enabled = true;
    configured.simd_tests_enabled = false;
    assert!(configured.platform_tests_enabled);
    assert!(!configured.simd_tests_enabled);
    assert!(!configured.performance_tests_enabled);
}
}