use crate::memory::MemoryManager;
use crate::profiler::Profiler;
use crate::{Buffer, BufferDescriptor, Device, Kernel, KernelDescriptor};
use torsh_core::{device::DeviceType, dtype::DType, error::TorshError};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
pub type BackendResult<T> = Result<T, TorshError>;
/// Core identity and static capability reporting for a backend implementation.
pub trait BackendCore: Send + Sync + std::fmt::Debug {
fn device_type(&self) -> DeviceType;
fn name(&self) -> &str;
fn is_available(&self) -> BackendResult<bool>;
fn capabilities(&self) -> BackendCapabilities;
fn performance_hints(&self) -> PerformanceHints;
}
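// Illustrative sketch: a minimal `BackendCore` implementation for a fictional
// backend, showing which pieces of static information the trait expects. The
// `NullBackend` type below is hypothetical and exists only for this example.
#[cfg(test)]
mod backend_core_example {
    use super::*;

    #[derive(Debug)]
    struct NullBackend;

    impl BackendCore for NullBackend {
        fn device_type(&self) -> DeviceType {
            DeviceType::Cpu
        }
        fn name(&self) -> &str {
            "null"
        }
        fn is_available(&self) -> BackendResult<bool> {
            Ok(true)
        }
        fn capabilities(&self) -> BackendCapabilities {
            BackendCapabilities::default()
        }
        fn performance_hints(&self) -> PerformanceHints {
            PerformanceHints::default()
        }
    }

    #[test]
    fn null_backend_reports_identity() {
        let backend = NullBackend;
        assert_eq!(backend.name(), "null");
        assert!(backend.is_available().unwrap());
    }
}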
#[async_trait::async_trait]
pub trait BackendLifecycle: Send + Sync {
async fn initialize(&mut self) -> BackendResult<()>;
async fn shutdown(&mut self) -> BackendResult<()>;
fn is_initialized(&self) -> bool;
}
pub trait BackendDeviceManager: Send + Sync {
fn devices(&self) -> BackendResult<Vec<Device>>;
fn default_device(&self) -> BackendResult<Device>;
fn create_device(&self, device_id: usize) -> BackendResult<Device>;
fn device_count(&self) -> BackendResult<usize>;
fn is_device_available(&self, device_id: usize) -> bool;
}
pub trait BackendResourceManager: Send + Sync {
fn create_buffer(
&self,
device: &Device,
descriptor: &BufferDescriptor,
) -> BackendResult<Buffer>;
fn create_kernel(
&self,
device: &Device,
descriptor: &KernelDescriptor,
) -> BackendResult<Kernel>;
fn memory_manager(
&self,
device: &Device,
) -> BackendResult<Box<dyn MemoryManager + Send + Sync>>;
fn profiler(&self) -> BackendResult<Box<dyn Profiler + Send + Sync>>;
fn create_scoped_buffer(
&self,
device: &Device,
descriptor: &BufferDescriptor,
) -> BackendResult<Buffer>;
}
pub trait BackendAdvancedResourceManager: Send + Sync {
fn create_resource_with_cleanup<T, F>(
&self,
device: &Device,
factory: F,
cleanup: impl FnOnce(&T) + Send + 'static,
) -> BackendResult<ManagedResource<T>>
where
T: Send + Sync + 'static,
F: FnOnce(&Device) -> BackendResult<T>;
}
#[async_trait::async_trait]
pub trait BackendExecutor: Send + Sync {
async fn synchronize(&self, device: &Device) -> BackendResult<()>;
async fn copy_buffer(
&self,
src: &Buffer,
dst: &Buffer,
src_offset: usize,
dst_offset: usize,
size: usize,
) -> BackendResult<()>;
async fn copy_to_device(
&self,
src: &[u8],
dst: &Buffer,
dst_offset: usize,
) -> BackendResult<()>;
async fn copy_from_device(
&self,
src: &Buffer,
dst: &mut [u8],
src_offset: usize,
) -> BackendResult<()>;
async fn execute_kernel(
&self,
kernel: &Kernel,
buffers: &[&Buffer],
uniform_data: &[u8],
workgroup_size: (u32, u32, u32),
workgroup_count: (u32, u32, u32),
) -> BackendResult<()>;
}
pub trait BackendOperations: Send + Sync {
fn fft_ops(&self) -> Box<dyn crate::fft::FftOps>;
fn convolution_ops(&self) -> Box<dyn crate::convolution::ConvolutionOps>;
fn rnn_ops(&self) -> Box<dyn crate::rnn::RnnOps>;
fn sparse_ops(&self) -> Box<dyn crate::sparse_ops::SparseOps<f32>>;
fn quantization_ops(&self) -> Box<dyn crate::quantization::QuantizationOps>;
fn operations_bundle(&self) -> OperationsBundle;
}
/// Umbrella trait implemented by a complete backend; the `as_*` accessors let
/// callers holding a `&dyn Backend` reach the individual sub-trait objects.
pub trait Backend:
BackendCore
+ BackendLifecycle
+ BackendDeviceManager
+ BackendResourceManager
+ BackendExecutor
+ BackendOperations
+ BackendOps
{
fn as_core(&self) -> &dyn BackendCore;
fn as_lifecycle(&mut self) -> &mut dyn BackendLifecycle;
fn as_device_manager(&self) -> &dyn BackendDeviceManager;
fn as_resource_manager(&self) -> &dyn BackendResourceManager;
fn as_executor(&self) -> &dyn BackendExecutor;
fn as_operations(&self) -> &dyn BackendOperations;
}
/// Scope-bound RAII wrapper whose cleanup closure, if any, receives the
/// resource *by value* when the wrapper is dropped.
pub struct ScopedResource<'a, T> {
resource: Option<T>,
cleanup: Option<Box<dyn FnOnce(T) + Send + 'a>>,
}
impl<'a, T> ScopedResource<'a, T> {
pub fn new(resource: T) -> Self {
Self {
resource: Some(resource),
cleanup: None,
}
}
pub fn new_with_cleanup<F>(resource: T, cleanup: F) -> Self
where
F: FnOnce(T) + Send + 'a,
{
Self {
resource: Some(resource),
cleanup: Some(Box::new(cleanup)),
}
}
pub fn get(&self) -> Option<&T> {
self.resource.as_ref()
}
pub fn get_mut(&mut self) -> Option<&mut T> {
self.resource.as_mut()
}
    pub fn take(mut self) -> Option<T> {
        // Disarm the cleanup so it does not run for a resource the caller now owns.
        self.cleanup = None;
        self.resource.take()
    }
pub fn with_resource<F, R>(&self, f: F) -> Option<R>
where
F: FnOnce(&T) -> R,
{
self.resource.as_ref().map(f)
}
pub fn is_available(&self) -> bool {
self.resource.is_some()
}
}
impl<'a, T> Drop for ScopedResource<'a, T> {
fn drop(&mut self) {
if let (Some(resource), Some(cleanup)) = (self.resource.take(), self.cleanup.take()) {
cleanup(resource);
}
}
}
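// Illustrative sketch: `ScopedResource` runs its cleanup closure when the guard
// goes out of scope, and `take()` disarms it so the caller owns the resource.
// The `Vec<u8>` stands in for a real backend resource.
#[cfg(test)]
mod scoped_resource_example {
    use super::ScopedResource;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn cleanup_runs_on_drop_but_not_after_take() {
        static CLEANED: AtomicBool = AtomicBool::new(false);
        {
            let guard = ScopedResource::new_with_cleanup(vec![0u8; 16], |_buf| {
                CLEANED.store(true, Ordering::SeqCst);
            });
            assert!(guard.is_available());
            assert_eq!(guard.with_resource(|b| b.len()), Some(16));
        }
        assert!(CLEANED.load(Ordering::SeqCst));

        // `take()` hands the resource back and skips the cleanup closure.
        static TAKEN_CLEANED: AtomicBool = AtomicBool::new(false);
        let guard = ScopedResource::new_with_cleanup(vec![1u8; 4], |_buf| {
            TAKEN_CLEANED.store(true, Ordering::SeqCst);
        });
        assert_eq!(guard.take(), Some(vec![1u8; 4]));
        assert!(!TAKEN_CLEANED.load(Ordering::SeqCst));
    }
}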
/// RAII wrapper whose cleanup closure, if any, receives the resource *by
/// reference* on drop; unlike [`ScopedResource`], the closure must be `'static`.
pub struct ManagedResource<T> {
resource: Option<T>,
cleanup: Option<Box<dyn FnOnce(&T) + Send + 'static>>,
}
impl<T> ManagedResource<T> {
pub fn new(resource: T) -> Self {
Self {
resource: Some(resource),
cleanup: None,
}
}
pub fn new_with_cleanup<F>(resource: T, cleanup: F) -> Self
where
F: FnOnce(&T) + Send + 'static,
{
Self {
resource: Some(resource),
cleanup: Some(Box::new(cleanup)),
}
}
pub fn get(&self) -> Option<&T> {
self.resource.as_ref()
}
pub fn get_mut(&mut self) -> Option<&mut T> {
self.resource.as_mut()
}
    pub fn take(mut self) -> Option<T> {
        // Disarm the cleanup so it does not run for a resource the caller now owns.
        self.cleanup = None;
        self.resource.take()
    }
pub fn with_resource<F, R>(&self, f: F) -> Option<R>
where
F: FnOnce(&T) -> R,
{
self.resource.as_ref().map(f)
}
pub fn is_available(&self) -> bool {
self.resource.is_some()
}
}
impl<T> Drop for ManagedResource<T> {
fn drop(&mut self) {
if let (Some(resource), Some(cleanup)) = (self.resource.as_ref(), self.cleanup.take()) {
cleanup(resource);
}
}
}
// SAFETY: the boxed cleanup closure is never exposed through a shared reference;
// it is only taken and invoked once in `drop`, which has exclusive access. The
// resource itself is guarded by the usual `T: Send` / `T: Sync` bounds.
unsafe impl<T: Send> Send for ManagedResource<T> {}
unsafe impl<T: Sync> Sync for ManagedResource<T> {}
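// Illustrative sketch: `ManagedResource` calls its cleanup with a *reference*
// to the resource when dropped, which suits handles that must be released but
// not consumed. The `Vec<u8>` below stands in for a real backend handle.
#[cfg(test)]
mod managed_resource_example {
    use super::ManagedResource;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn cleanup_observes_the_resource_by_reference() {
        static RELEASED_LEN: AtomicUsize = AtomicUsize::new(0);
        {
            let handle = ManagedResource::new_with_cleanup(vec![0u8; 32], |buf| {
                RELEASED_LEN.store(buf.len(), Ordering::SeqCst);
            });
            assert!(handle.is_available());
        }
        assert_eq!(RELEASED_LEN.load(Ordering::SeqCst), 32);
    }
}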
pub struct OperationsBundle {
pub fft: Box<dyn crate::fft::FftOps>,
pub convolution: Box<dyn crate::convolution::ConvolutionOps>,
pub rnn: Box<dyn crate::rnn::RnnOps>,
pub sparse: Box<dyn crate::sparse_ops::SparseOps<f32>>,
pub quantization: Box<dyn crate::quantization::QuantizationOps>,
}
impl OperationsBundle {
pub fn new(
fft: Box<dyn crate::fft::FftOps>,
convolution: Box<dyn crate::convolution::ConvolutionOps>,
rnn: Box<dyn crate::rnn::RnnOps>,
sparse: Box<dyn crate::sparse_ops::SparseOps<f32>>,
quantization: Box<dyn crate::quantization::QuantizationOps>,
) -> Self {
Self {
fft,
convolution,
rnn,
sparse,
quantization,
}
}
}
#[derive(Debug, Clone)]
pub struct BackendCapabilities {
pub max_buffer_size: usize,
pub max_compute_units: usize,
pub max_workgroup_size: (u32, u32, u32),
pub supported_dtypes: Vec<DType>,
pub supports_async: bool,
pub supports_unified_memory: bool,
pub supports_sub_buffers: bool,
pub supports_kernel_caching: bool,
pub memory_bandwidth_gbps: f32,
pub compute_throughput_gflops: f32,
pub extended_capabilities: ExtendedCapabilities,
}
#[derive(Debug, Clone)]
pub struct ExtendedCapabilities {
pub max_tensor_dims: Option<usize>,
pub precision_modes: Vec<PrecisionMode>,
pub hardware_features: Vec<HardwareFeature>,
pub memory_hierarchy: MemoryHierarchy,
pub execution_model: ExecutionModel,
pub custom_capabilities: std::collections::HashMap<String, CapabilityValue>,
}
#[derive(Debug, Clone, PartialEq)]
pub enum PrecisionMode {
F16,
F32,
F64,
Mixed,
Custom(u8),
}
#[derive(Debug, Clone, PartialEq)]
pub enum HardwareFeature {
TensorCores,
VectorUnits,
SharedMemory,
ConstantMemory,
AtomicOperations,
CooperativeGroups,
DynamicParallelism,
Custom(String),
}
#[derive(Debug, Clone, Default)]
pub struct MemoryHierarchy {
pub l1_cache_size: Option<usize>,
pub l2_cache_size: Option<usize>,
pub l3_cache_size: Option<usize>,
pub shared_memory_size: Option<usize>,
pub memory_latency_cycles: Option<u32>,
pub memory_bandwidth_per_core: Option<f32>,
}
#[derive(Debug, Clone)]
pub struct ExecutionModel {
pub supports_simd: bool,
pub supports_simt: bool,
pub supports_task_parallelism: bool,
pub supports_data_parallelism: bool,
pub max_concurrent_streams: Option<u32>,
pub supports_out_of_order: bool,
}
#[derive(Debug, Clone)]
pub enum CapabilityValue {
Bool(bool),
Int(i64),
Float(f64),
String(String),
List(Vec<CapabilityValue>),
}
impl Default for ExtendedCapabilities {
fn default() -> Self {
Self {
max_tensor_dims: Some(8),
precision_modes: vec![PrecisionMode::F32],
hardware_features: vec![],
memory_hierarchy: MemoryHierarchy::default(),
execution_model: ExecutionModel::default(),
custom_capabilities: std::collections::HashMap::new(),
}
}
}
impl Default for ExecutionModel {
fn default() -> Self {
Self {
supports_simd: false,
supports_simt: false,
supports_task_parallelism: true,
supports_data_parallelism: true,
max_concurrent_streams: Some(1),
supports_out_of_order: false,
}
}
}
impl Default for BackendCapabilities {
fn default() -> Self {
Self {
            max_buffer_size: 1024 * 1024 * 1024, // 1 GiB
            max_compute_units: 1,
max_workgroup_size: (256, 1, 1),
supported_dtypes: vec![DType::F32, DType::F64, DType::I32, DType::I64],
supports_async: false,
supports_unified_memory: false,
supports_sub_buffers: false,
supports_kernel_caching: false,
memory_bandwidth_gbps: 10.0,
compute_throughput_gflops: 100.0,
extended_capabilities: ExtendedCapabilities::default(),
}
}
}
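// Illustrative sketch: a backend would typically start from the defaults and
// override only the fields it knows better values for (the numbers and the
// "vendor" key below are made up for this example).
#[cfg(test)]
mod backend_capabilities_example {
    use super::{BackendCapabilities, CapabilityValue};

    #[test]
    fn defaults_can_be_selectively_overridden() {
        let mut caps = BackendCapabilities {
            max_compute_units: 8,
            supports_async: true,
            memory_bandwidth_gbps: 50.0,
            ..BackendCapabilities::default()
        };
        caps.extended_capabilities.custom_capabilities.insert(
            "vendor".to_string(),
            CapabilityValue::String("example".to_string()),
        );
        assert!(caps.supports_async);
        assert_eq!(caps.max_workgroup_size, (256, 1, 1));
    }
}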
#[derive(Debug, Clone)]
pub struct PerformanceHints {
pub preferred_workgroup_size: (u32, u32, u32),
pub memory_alignment: usize,
pub prefer_vectorized: bool,
pub prefer_async: bool,
pub optimal_batch_size: usize,
pub cache_kernels: bool,
}
impl Default for PerformanceHints {
fn default() -> Self {
Self {
preferred_workgroup_size: (64, 1, 1),
memory_alignment: 16,
prefer_vectorized: true,
prefer_async: false,
optimal_batch_size: 32,
cache_kernels: true,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum BackendType {
Auto,
Cpu,
Cuda,
Metal,
Rocm,
WebGpu,
}
impl std::fmt::Display for BackendType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BackendType::Auto => write!(f, "Auto"),
BackendType::Cpu => write!(f, "CPU"),
BackendType::Cuda => write!(f, "CUDA"),
BackendType::Metal => write!(f, "Metal"),
BackendType::Rocm => write!(f, "ROCm"),
BackendType::WebGpu => write!(f, "WebGPU"),
}
}
}
pub trait BackendOps: Send + Sync {
fn backend_type(&self) -> BackendType;
fn available_ops(&self) -> Vec<&str>;
fn supports_op(&self, op_name: &str) -> bool;
fn supports_fft(&self) -> bool;
fn supports_convolution(&self) -> bool;
fn supports_rnn(&self) -> bool;
fn supports_sparse(&self) -> bool;
fn supports_quantization(&self) -> bool;
fn operation_capabilities(
&self,
op_name: &str,
) -> Option<std::collections::HashMap<String, CapabilityValue>>;
}
pub trait BackendExtension: Send + Sync {
fn extension_name(&self) -> &str;
fn extension_version(&self) -> &str;
fn is_compatible_with(&self, backend: &dyn BackendCore) -> bool;
fn initialize(&mut self, backend: &dyn Backend) -> BackendResult<()>;
fn shutdown(&mut self) -> BackendResult<()>;
fn capabilities(&self) -> std::collections::HashMap<String, CapabilityValue>;
fn handle_operation(
&self,
op_name: &str,
args: &[CapabilityValue],
) -> BackendResult<CapabilityValue>;
}
/// Owns registered backend extensions and tracks which of them have been
/// initialized.
pub struct BackendExtensionRegistry {
extensions: std::collections::HashMap<String, Box<dyn BackendExtension>>,
initialized_extensions: std::collections::HashSet<String>,
}
impl BackendExtensionRegistry {
pub fn new() -> Self {
Self {
extensions: std::collections::HashMap::new(),
initialized_extensions: std::collections::HashSet::new(),
}
}
pub fn register_extension(
&mut self,
extension: Box<dyn BackendExtension>,
) -> BackendResult<()> {
let name = extension.extension_name().to_string();
if self.extensions.contains_key(&name) {
return Err(TorshError::BackendError(format!(
"Extension '{}' is already registered",
name
)));
}
self.extensions.insert(name, extension);
Ok(())
}
pub fn get_extension(&self, name: &str) -> Option<&dyn BackendExtension> {
self.extensions.get(name).map(|e| e.as_ref())
}
pub fn get_extension_mut(&mut self, name: &str) -> Option<&mut Box<dyn BackendExtension>> {
self.extensions.get_mut(name)
}
pub fn extensions(&self) -> Vec<&str> {
self.extensions.keys().map(|s| s.as_str()).collect()
}
pub fn initialize_all(&mut self, backend: &dyn Backend) -> BackendResult<Vec<String>> {
let mut failed_extensions = Vec::new();
for (name, extension) in self.extensions.iter_mut() {
if extension.is_compatible_with(backend.as_core()) {
match extension.initialize(backend) {
Ok(()) => {
self.initialized_extensions.insert(name.clone());
}
Err(e) => {
failed_extensions.push(format!("{}: {}", name, e));
}
}
}
}
if failed_extensions.is_empty() {
Ok(vec![])
} else {
Err(TorshError::BackendError(format!(
"Failed to initialize extensions: {}",
failed_extensions.join(", ")
)))
}
}
pub fn shutdown_all(&mut self) -> BackendResult<Vec<String>> {
let mut failed_extensions = Vec::new();
for (name, extension) in self.extensions.iter_mut() {
if self.initialized_extensions.contains(name) {
if let Err(e) = extension.shutdown() {
failed_extensions.push(format!("{}: {}", name, e));
} else {
self.initialized_extensions.remove(name);
}
}
}
if failed_extensions.is_empty() {
Ok(vec![])
} else {
Err(TorshError::BackendError(format!(
"Failed to shutdown extensions: {}",
failed_extensions.join(", ")
)))
}
}
pub fn remove_extension(&mut self, name: &str) -> Option<Box<dyn BackendExtension>> {
if let Some(extension) = self.extensions.get_mut(name) {
if self.initialized_extensions.contains(name) {
                // Best-effort shutdown of an initialized extension before removal.
                let _ = extension.shutdown();
                self.initialized_extensions.remove(name);
}
}
self.extensions.remove(name)
}
pub fn has_extension(&self, name: &str) -> bool {
self.extensions.contains_key(name)
}
pub fn len(&self) -> usize {
self.extensions.len()
}
pub fn is_empty(&self) -> bool {
self.extensions.is_empty()
}
}
impl Default for BackendExtensionRegistry {
fn default() -> Self {
Self::new()
}
}
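// Illustrative sketch: a do-nothing extension registered with the registry.
// `NoopExtension` is hypothetical; a real extension would expose vendor-specific
// functionality through `handle_operation`.
#[cfg(test)]
mod extension_registry_example {
    use super::*;

    struct NoopExtension;

    impl BackendExtension for NoopExtension {
        fn extension_name(&self) -> &str {
            "noop"
        }
        fn extension_version(&self) -> &str {
            "0.1.0"
        }
        fn is_compatible_with(&self, _backend: &dyn BackendCore) -> bool {
            true
        }
        fn initialize(&mut self, _backend: &dyn Backend) -> BackendResult<()> {
            Ok(())
        }
        fn shutdown(&mut self) -> BackendResult<()> {
            Ok(())
        }
        fn capabilities(&self) -> std::collections::HashMap<String, CapabilityValue> {
            std::collections::HashMap::new()
        }
        fn handle_operation(
            &self,
            op_name: &str,
            _args: &[CapabilityValue],
        ) -> BackendResult<CapabilityValue> {
            Ok(CapabilityValue::String(format!("handled {op_name}")))
        }
    }

    #[test]
    fn extensions_can_be_registered_and_removed() {
        let mut registry = BackendExtensionRegistry::new();
        registry.register_extension(Box::new(NoopExtension)).unwrap();
        assert!(registry.has_extension("noop"));
        // Re-registering the same name is rejected.
        assert!(registry.register_extension(Box::new(NoopExtension)).is_err());
        assert!(registry.remove_extension("noop").is_some());
        assert!(registry.is_empty());
    }
}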
pub trait BackendFactory: Send + Sync {
fn create(&self) -> BackendResult<Box<dyn Backend>>;
fn device_type(&self) -> DeviceType;
fn is_available(&self) -> bool;
fn priority(&self) -> u32;
fn capabilities(&self) -> BackendCapabilities;
}
/// Helper for discovering devices across all compiled-in backends.
pub struct DeviceEnumerator;
impl DeviceEnumerator {
pub fn enumerate_all_devices() -> BackendResult<Vec<(DeviceType, Vec<Device>)>> {
let mut all_devices = Vec::new();
#[cfg(feature = "cpu")]
{
if let Ok(cpu_backend) = crate::cpu::CpuBackend::new() {
if let Ok(devices) = cpu_backend.devices() {
all_devices.push((DeviceType::Cpu, devices));
}
}
}
#[cfg(feature = "cuda")]
        {
            // CUDA device enumeration is not wired into this helper yet.
        }
#[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
{
if let Ok(metal_backend) = crate::metal::MetalBackend::new() {
if let Ok(devices) = metal_backend.devices() {
all_devices.push((DeviceType::Metal(0), devices));
}
}
}
#[cfg(feature = "webgpu")]
{
let webgpu_backend = crate::webgpu::WebGpuBackend::with_default_config();
if let Ok(devices) = webgpu_backend.devices() {
all_devices.push((DeviceType::Wgpu(0), devices));
}
}
Ok(all_devices)
}
    /// Picks the most capable device, preferring CUDA, then Metal, then WebGPU,
    /// then CPU, and the highest `peak_gflops` within the selected backend.
    pub fn find_best_device() -> BackendResult<(DeviceType, Device)> {
let all_devices = Self::enumerate_all_devices()?;
if all_devices.is_empty() {
return Err(TorshError::BackendError("No devices available".to_string()));
}
let backend_priorities = [
DeviceType::Cuda(0),
DeviceType::Metal(0),
DeviceType::Wgpu(0),
DeviceType::Cpu,
];
for preferred_type in &backend_priorities {
for (device_type, devices) in &all_devices {
if Self::device_types_match(device_type, preferred_type) && !devices.is_empty() {
let best_device = devices
.iter()
.max_by(|a, b| {
a.info()
.peak_gflops
.partial_cmp(&b.info().peak_gflops)
.unwrap_or(std::cmp::Ordering::Equal)
})
.cloned()
.expect("devices should not be empty after is_empty check");
return Ok((*device_type, best_device));
}
}
}
let (device_type, devices) = &all_devices[0];
if !devices.is_empty() {
Ok((*device_type, devices[0].clone()))
} else {
Err(TorshError::BackendError(
"No usable devices found".to_string(),
))
}
}
fn device_types_match(a: &DeviceType, b: &DeviceType) -> bool {
matches!(
(a, b),
(DeviceType::Cpu, DeviceType::Cpu)
| (DeviceType::Cuda(_), DeviceType::Cuda(_))
| (DeviceType::Metal(_), DeviceType::Metal(_))
| (DeviceType::Wgpu(_), DeviceType::Wgpu(_))
)
}
pub fn get_devices_by_type(device_type: DeviceType) -> BackendResult<Vec<Device>> {
match device_type {
#[cfg(feature = "cpu")]
DeviceType::Cpu => {
let cpu_backend = crate::cpu::CpuBackend::new()?;
cpu_backend.devices()
}
#[cfg(feature = "cuda")]
            DeviceType::Cuda(_device_id) => {
                // CUDA device listing is not wired into this helper; report none.
                Ok(vec![])
            }
#[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
DeviceType::Metal(_) => {
let metal_backend = crate::metal::MetalBackend::new()?;
metal_backend.devices()
}
#[cfg(feature = "webgpu")]
DeviceType::Wgpu(_) => {
let webgpu_backend = crate::webgpu::WebGpuBackend::with_default_config();
webgpu_backend.devices()
}
#[allow(unreachable_patterns)]
_ => Err(TorshError::BackendError(format!(
"Backend type {device_type:?} not available"
))),
}
}
pub fn is_device_type_available(device_type: DeviceType) -> bool {
match device_type {
#[cfg(feature = "cpu")]
DeviceType::Cpu => true,
#[cfg(cuda_available)]
DeviceType::Cuda(device_id) => {
crate::cuda::CudaBackend::new(crate::cuda::CudaBackendConfig {
device_id: device_id as usize,
..Default::default()
})
.is_ok()
}
#[cfg(all(feature = "cuda", not(cuda_available)))]
            DeviceType::Cuda(_) => false,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
DeviceType::Metal(_) => crate::metal::MetalBackend::new().is_ok(),
#[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => true,
            #[allow(unreachable_patterns)]
_ => false,
}
}
}
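// Illustrative sketch (assumes the `cpu` feature is enabled): the CPU device
// type is always reported as available, and enumeration only collects from
// backends that construct successfully.
#[cfg(all(test, feature = "cpu"))]
mod device_enumerator_example {
    use super::*;

    #[test]
    fn cpu_is_reported_as_available() {
        assert!(DeviceEnumerator::is_device_type_available(DeviceType::Cpu));
        // `enumerate_all_devices` skips backends that fail to construct, so it
        // does not error here.
        let _devices = DeviceEnumerator::enumerate_all_devices().unwrap();
    }
}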
pub trait BackendPlugin: Send + Sync + std::fmt::Debug {
fn name(&self) -> &str;
fn version(&self) -> &str;
fn create_backend(&self) -> BackendResult<Box<dyn Backend>>;
fn is_compatible(&self) -> bool;
fn supported_device_types(&self) -> Vec<DeviceType>;
fn metadata(&self) -> PluginMetadata;
}
#[derive(Debug, Clone)]
pub struct PluginMetadata {
pub name: String,
pub version: String,
pub description: String,
pub author: String,
pub license: String,
pub supported_architectures: Vec<String>,
pub required_features: Vec<String>,
pub optional_features: Vec<String>,
}
pub trait BackendResourceMonitor: Send + Sync {
fn resource_usage(&self) -> ResourceUsage;
fn set_resource_limits(&mut self, limits: ResourceLimits) -> BackendResult<()>;
fn resource_limits(&self) -> ResourceLimits;
fn cleanup_resources(&mut self) -> BackendResult<()>;
fn resource_statistics(&self) -> ResourceStatistics;
fn enable_monitoring(&mut self) -> BackendResult<()>;
fn disable_monitoring(&mut self) -> BackendResult<()>;
fn is_monitoring_enabled(&self) -> bool;
}
#[derive(Debug, Clone)]
pub struct ResourceUsage {
pub memory_used: usize,
pub buffers_allocated: usize,
pub kernels_cached: usize,
pub active_streams: usize,
pub cpu_usage_percent: f32,
pub gpu_usage_percent: f32,
}
#[derive(Debug, Clone)]
pub struct ResourceLimits {
pub max_memory: Option<usize>,
pub max_buffers: Option<usize>,
pub max_kernels: Option<usize>,
pub max_streams: Option<usize>,
pub memory_pressure_threshold: f32,
}
#[derive(Debug, Clone)]
pub struct ResourceStatistics {
pub peak_memory_usage: usize,
pub total_allocations: u64,
pub total_deallocations: u64,
pub average_buffer_size: f32,
pub cache_hit_rate: f32,
pub allocation_failure_count: u32,
}
/// Registry of backend plugins; the first successfully registered plugin
/// becomes the default.
pub struct BackendRegistry {
backends: std::collections::HashMap<String, Box<dyn BackendPlugin>>,
default_backend: Option<String>,
}
impl BackendRegistry {
pub fn new() -> Self {
Self {
backends: std::collections::HashMap::new(),
default_backend: None,
}
}
pub fn register_plugin(&mut self, plugin: Box<dyn BackendPlugin>) -> BackendResult<()> {
let name = plugin.name().to_string();
if !plugin.is_compatible() {
return Err(TorshError::BackendError(format!(
"Plugin {name} is not compatible with current system"
)));
}
self.backends.insert(name.clone(), plugin);
if self.default_backend.is_none() {
self.default_backend = Some(name);
}
Ok(())
}
pub fn available_backends(&self) -> Vec<String> {
self.backends.keys().cloned().collect()
}
pub fn create_backend(&self, name: &str) -> BackendResult<Box<dyn Backend>> {
if let Some(plugin) = self.backends.get(name) {
plugin.create_backend()
} else {
Err(TorshError::BackendError(format!(
"Backend {name} not found"
)))
}
}
pub fn create_default_backend(&self) -> BackendResult<Box<dyn Backend>> {
if let Some(default_name) = &self.default_backend {
self.create_backend(default_name)
} else {
Err(TorshError::BackendError(
"No default backend available".to_string(),
))
}
}
pub fn set_default_backend(&mut self, name: &str) -> BackendResult<()> {
if self.backends.contains_key(name) {
self.default_backend = Some(name.to_string());
Ok(())
} else {
Err(TorshError::BackendError(format!(
"Backend {name} not found"
)))
}
}
pub fn get_plugin_metadata(&self, name: &str) -> Option<PluginMetadata> {
self.backends.get(name).map(|plugin| plugin.metadata())
}
}
impl Default for BackendRegistry {
fn default() -> Self {
Self::new()
}
}
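// Illustrative sketch: a stub plugin that registers successfully but declines
// to construct a backend. `StubPlugin` is hypothetical; a real plugin would
// return a fully implemented `Backend` from `create_backend`.
#[cfg(test)]
mod backend_registry_example {
    use super::*;

    #[derive(Debug)]
    struct StubPlugin;

    impl BackendPlugin for StubPlugin {
        fn name(&self) -> &str {
            "stub"
        }
        fn version(&self) -> &str {
            "0.1.0"
        }
        fn create_backend(&self) -> BackendResult<Box<dyn Backend>> {
            Err(TorshError::BackendError(
                "stub plugin cannot create a backend".to_string(),
            ))
        }
        fn is_compatible(&self) -> bool {
            true
        }
        fn supported_device_types(&self) -> Vec<DeviceType> {
            vec![DeviceType::Cpu]
        }
        fn metadata(&self) -> PluginMetadata {
            PluginMetadata {
                name: "stub".to_string(),
                version: "0.1.0".to_string(),
                description: "example-only plugin".to_string(),
                author: "example".to_string(),
                license: "MIT".to_string(),
                supported_architectures: vec![],
                required_features: vec![],
                optional_features: vec![],
            }
        }
    }

    #[test]
    fn first_registered_plugin_becomes_the_default() {
        let mut registry = BackendRegistry::new();
        registry.register_plugin(Box::new(StubPlugin)).unwrap();
        assert_eq!(registry.available_backends(), vec!["stub".to_string()]);
        assert!(registry.get_plugin_metadata("stub").is_some());
        // The stub refuses to build an actual backend.
        assert!(registry.create_default_backend().is_err());
    }
}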
pub trait BackendConfig: Send + Sync + Clone {
fn backend_type(&self) -> BackendType;
fn validate(&self) -> BackendResult<()>;
fn as_properties(&self) -> std::collections::HashMap<String, CapabilityValue>;
fn from_properties(
properties: &std::collections::HashMap<String, CapabilityValue>,
) -> BackendResult<Self>
where
Self: Sized;
fn merge(&mut self, other: &Self) -> BackendResult<()>;
fn default_config() -> Self
where
Self: Sized;
}
pub trait BackendBuilder<T: BackendConfig>: Send + Sync {
fn new() -> Self;
fn with_config(self, config: T) -> Self;
fn build(self) -> BackendResult<Box<dyn Backend>>;
fn config(&self) -> &T;
fn config_mut(&mut self) -> &mut T;
}
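// Illustrative sketch: a minimal `BackendConfig` for a hypothetical backend
// with a single tunable field; property round-tripping is kept deliberately
// simple.
#[cfg(test)]
mod backend_config_example {
    use super::*;

    #[derive(Clone)]
    struct StubConfig {
        threads: usize,
    }

    impl BackendConfig for StubConfig {
        fn backend_type(&self) -> BackendType {
            BackendType::Cpu
        }
        fn validate(&self) -> BackendResult<()> {
            if self.threads == 0 {
                return Err(TorshError::BackendError(
                    "threads must be non-zero".to_string(),
                ));
            }
            Ok(())
        }
        fn as_properties(&self) -> std::collections::HashMap<String, CapabilityValue> {
            let mut props = std::collections::HashMap::new();
            props.insert("threads".to_string(), CapabilityValue::Int(self.threads as i64));
            props
        }
        fn from_properties(
            properties: &std::collections::HashMap<String, CapabilityValue>,
        ) -> BackendResult<Self> {
            let threads = match properties.get("threads") {
                Some(CapabilityValue::Int(n)) => *n as usize,
                _ => 1,
            };
            Ok(Self { threads })
        }
        fn merge(&mut self, other: &Self) -> BackendResult<()> {
            self.threads = other.threads;
            Ok(())
        }
        fn default_config() -> Self {
            Self { threads: 1 }
        }
    }

    #[test]
    fn configs_round_trip_through_properties() {
        let config = StubConfig { threads: 4 };
        config.validate().unwrap();
        let restored = StubConfig::from_properties(&config.as_properties()).unwrap();
        assert_eq!(restored.threads, 4);
    }
}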
pub trait BackendErrorHandler: Send + Sync {
fn handle_error(&self, error: TorshError, context: &str) -> TorshError;
fn convert_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> TorshError;
fn recovery_suggestions(&self, error: &TorshError) -> Vec<String>;
fn log_error(&self, error: &TorshError, context: &str);
}
pub struct DefaultBackendErrorHandler {
backend_name: String,
}
impl DefaultBackendErrorHandler {
pub fn new(backend_name: String) -> Self {
Self { backend_name }
}
}
impl BackendErrorHandler for DefaultBackendErrorHandler {
fn handle_error(&self, error: TorshError, context: &str) -> TorshError {
match error {
TorshError::BackendError(msg) => TorshError::BackendError(format!(
"{}: {} (context: {})",
self.backend_name, msg, context
)),
other => other,
}
}
fn convert_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> TorshError {
TorshError::BackendError(format!("{}: {}", self.backend_name, error))
}
fn recovery_suggestions(&self, error: &TorshError) -> Vec<String> {
match error {
TorshError::BackendError(msg) if msg.contains("not available") => {
vec![
"Check if the backend is properly installed".to_string(),
"Verify system compatibility".to_string(),
"Try a different backend".to_string(),
]
}
TorshError::BackendError(msg) if msg.contains("memory") => {
vec![
"Reduce batch size or tensor dimensions".to_string(),
"Enable memory optimization".to_string(),
"Check available memory".to_string(),
]
}
_ => vec!["Contact support with error details".to_string()],
}
}
fn log_error(&self, error: &TorshError, context: &str) {
eprintln!("[{}] Error in {}: {}", self.backend_name, context, error);
}
}
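// Illustrative sketch: the default error handler prefixes backend errors with
// the backend name and execution context, and leaves other error kinds alone.
#[cfg(test)]
mod error_handler_example {
    use super::{BackendErrorHandler, DefaultBackendErrorHandler};
    use torsh_core::error::TorshError;

    #[test]
    fn backend_errors_are_annotated_with_context() {
        let handler = DefaultBackendErrorHandler::new("cpu".to_string());
        let err = handler.handle_error(
            TorshError::BackendError("allocation failed".to_string()),
            "create_buffer",
        );
        match err {
            TorshError::BackendError(msg) => {
                assert!(msg.contains("cpu"));
                assert!(msg.contains("create_buffer"));
            }
            _ => panic!("unexpected error variant"),
        }
        let hints = handler.recovery_suggestions(&TorshError::BackendError(
            "backend not available".to_string(),
        ));
        assert!(!hints.is_empty());
    }
}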
impl dyn Backend {
    /// Creates the best available backend by enumerating devices and matching
    /// the winning device type against the compiled-in backends.
    pub fn auto() -> BackendResult<Box<dyn Backend>> {
let (device_type, _device) = DeviceEnumerator::find_best_device()?;
match device_type {
#[cfg(feature = "cpu")]
DeviceType::Cpu => Ok(Box::new(crate::cpu::CpuBackend::new()?)),
#[cfg(cuda_available)]
DeviceType::Cuda(device_id) => Ok(Box::new(crate::cuda::CudaBackend::new(
crate::cuda::CudaBackendConfig {
device_id: device_id as usize,
..Default::default()
},
)?)),
#[cfg(all(feature = "cuda", not(cuda_available)))]
DeviceType::Cuda(_) => Err(TorshError::BackendError(
"CUDA backend not available on this platform".to_string(),
)),
#[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
DeviceType::Metal(_) => Ok(Box::new(crate::metal::MetalBackend::new()?)),
#[cfg(feature = "webgpu")]
DeviceType::Wgpu(_) => {
Ok(Box::new(crate::webgpu::WebGpuBackend::with_default_config()))
}
#[allow(unreachable_patterns)]
_ => Err(TorshError::BackendError(
"No suitable backend found".to_string(),
)),
}
}
}
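// Illustrative sketch (assumes the `cpu` feature so at least one backend is
// compiled in): picking a backend automatically and inspecting its identity.
#[cfg(all(test, feature = "cpu"))]
mod backend_auto_example {
    use super::Backend;

    #[test]
    fn auto_selects_some_backend() {
        let backend = <dyn Backend>::auto().expect("at least the CPU backend should be available");
        assert!(!backend.name().is_empty());
        assert!(backend.capabilities().max_buffer_size > 0);
    }
}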