use crate::{FxGraph, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::{device::DeviceType, dtype::DType, error::TorshError};
use torsh_tensor::Tensor;
type BackendInstanceCache = Arc<RwLock<HashMap<String, Arc<RwLock<Box<dyn CustomBackend>>>>>>;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BackendCapability {
BasicOps,
LinearAlgebra,
Convolution,
Recurrent,
Attention,
CustomOps,
Distributed,
Quantized,
GraphOptimization,
MemoryOptimization,
AutoGrad,
JitCompilation,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendInfo {
pub name: String,
pub version: String,
pub description: String,
pub supported_devices: Vec<String>,
pub capabilities: Vec<BackendCapability>,
pub supported_dtypes: Vec<String>,
pub vendor: String,
pub metadata: HashMap<String, String>,
}
impl BackendInfo {
pub fn new(name: String, version: String, description: String) -> Self {
Self {
name,
version,
description,
supported_devices: vec![],
capabilities: vec![],
supported_dtypes: vec![],
vendor: "Unknown".to_string(),
metadata: HashMap::new(),
}
}
pub fn with_device(mut self, device: DeviceType) -> Self {
self.supported_devices.push(format!("{device:?}"));
self
}
pub fn with_capability(mut self, capability: BackendCapability) -> Self {
self.capabilities.push(capability);
self
}
pub fn with_dtype(mut self, dtype: DType) -> Self {
self.supported_dtypes.push(format!("{dtype:?}"));
self
}
pub fn with_vendor(mut self, vendor: String) -> Self {
self.vendor = vendor;
self
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
pub fn supports_device(&self, device: DeviceType) -> bool {
let device_str = format!("{device:?}");
self.supported_devices.contains(&device_str)
}
pub fn has_capability(&self, capability: BackendCapability) -> bool {
self.capabilities.contains(&capability)
}
pub fn supports_dtype(&self, dtype: DType) -> bool {
let dtype_str = format!("{dtype:?}");
self.supported_dtypes.contains(&dtype_str)
}
}
#[derive(Debug, Clone)]
pub struct BackendContext {
pub device: DeviceType,
pub parameters: HashMap<String, String>,
pub options: HashMap<String, serde_json::Value>,
pub memory_limit: Option<usize>,
pub optimization_level: u8,
}
impl Default for BackendContext {
fn default() -> Self {
Self {
device: DeviceType::Cpu,
parameters: HashMap::new(),
options: HashMap::new(),
memory_limit: None,
optimization_level: 1,
}
}
}
impl BackendContext {
pub fn new(device: DeviceType) -> Self {
Self {
device,
..Default::default()
}
}
pub fn with_parameter(mut self, key: String, value: String) -> Self {
self.parameters.insert(key, value);
self
}
pub fn with_option(mut self, key: String, value: serde_json::Value) -> Self {
self.options.insert(key, value);
self
}
pub fn with_memory_limit(mut self, limit: usize) -> Self {
self.memory_limit = Some(limit);
self
}
pub fn with_optimization_level(mut self, level: u8) -> Self {
self.optimization_level = level.min(3);
self
}
}
#[derive(Debug)]
pub struct BackendResult {
pub outputs: Vec<Tensor>,
pub execution_time: Option<u64>,
pub memory_usage: Option<usize>,
pub metadata: HashMap<String, String>,
}
impl BackendResult {
pub fn new(outputs: Vec<Tensor>) -> Self {
Self {
outputs,
execution_time: None,
memory_usage: None,
metadata: HashMap::new(),
}
}
pub fn with_execution_time(mut self, time: u64) -> Self {
self.execution_time = Some(time);
self
}
pub fn with_memory_usage(mut self, memory: usize) -> Self {
self.memory_usage = Some(memory);
self
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
}
pub trait CustomBackend: Send + Sync {
fn info(&self) -> &BackendInfo;
fn initialize(&mut self, context: &BackendContext) -> TorshResult<()>;
fn finalize(&mut self) -> TorshResult<()>;
fn can_execute(&self, operation: &str, inputs: &[&Tensor], context: &BackendContext) -> bool;
fn execute_operation(
&self,
operation: &str,
inputs: Vec<Tensor>,
context: &BackendContext,
) -> TorshResult<BackendResult>;
fn execute_graph(
&self,
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
context: &BackendContext,
) -> TorshResult<BackendResult> {
self.execute_graph_sequential(graph, inputs, context)
}
fn optimize_graph(&self, _graph: &FxGraph, _context: &BackendContext) -> TorshResult<FxGraph> {
Ok(_graph.clone())
}
fn compile_info(
&self,
_graph: &FxGraph,
_context: &BackendContext,
) -> TorshResult<HashMap<String, String>> {
Ok(HashMap::new())
}
fn execute_graph_sequential(
&self,
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
context: &BackendContext,
) -> TorshResult<BackendResult> {
let mut interpreter = crate::interpreter::GraphInterpreter::new(context.device);
let outputs = interpreter.run(graph, inputs)?;
Ok(BackendResult::new(outputs))
}
}
pub trait BackendFactory: Send + Sync {
fn create_backend(&self) -> TorshResult<Box<dyn CustomBackend>>;
fn factory_info(&self) -> BackendInfo;
}
#[derive(Default)]
pub struct BackendRegistry {
factories: Arc<RwLock<HashMap<String, Box<dyn BackendFactory>>>>,
instances: BackendInstanceCache,
}
impl BackendRegistry {
pub fn new() -> Self {
Self {
factories: Arc::new(RwLock::new(HashMap::new())),
instances: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn register_factory<F: BackendFactory + 'static>(
&self,
name: String,
factory: F,
) -> TorshResult<()> {
let mut factories = self.factories.write().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire write lock on factories".to_string())
})?;
if factories.contains_key(&name) {
return Err(TorshError::InvalidArgument(format!(
"Backend factory '{name}' already registered"
)));
}
factories.insert(name, Box::new(factory));
Ok(())
}
pub fn get_backend(&self, name: &str) -> TorshResult<Arc<RwLock<Box<dyn CustomBackend>>>> {
{
let instances = self.instances.read().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire read lock on instances".to_string())
})?;
if let Some(instance) = instances.get(name) {
return Ok(instance.clone());
}
}
let factories = self.factories.read().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire read lock on factories".to_string())
})?;
let factory = factories.get(name).ok_or_else(|| {
TorshError::InvalidArgument(format!("Backend factory '{name}' not found"))
})?;
let backend = factory.create_backend()?;
let instance = Arc::new(RwLock::new(backend));
drop(factories);
let mut instances = self.instances.write().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire write lock on instances".to_string())
})?;
instances.insert(name.to_string(), instance.clone());
Ok(instance)
}
pub fn list_backends(&self) -> TorshResult<Vec<BackendInfo>> {
let factories = self.factories.read().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire read lock on factories".to_string())
})?;
let mut backends = Vec::new();
for factory in factories.values() {
backends.push(factory.factory_info());
}
Ok(backends)
}
pub fn find_backends_with_capability(
&self,
capability: BackendCapability,
) -> TorshResult<Vec<String>> {
let backends = self.list_backends()?;
let mut matching = Vec::new();
for backend in backends {
if backend.has_capability(capability) {
matching.push(backend.name);
}
}
Ok(matching)
}
pub fn find_backends_for_device(&self, device: DeviceType) -> TorshResult<Vec<String>> {
let backends = self.list_backends()?;
let mut matching = Vec::new();
for backend in backends {
if backend.supports_device(device) {
matching.push(backend.name);
}
}
Ok(matching)
}
pub fn unregister_backend(&self, name: &str) -> TorshResult<()> {
let mut factories = self.factories.write().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire write lock on factories".to_string())
})?;
let mut instances = self.instances.write().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire write lock on instances".to_string())
})?;
factories.remove(name);
instances.remove(name);
Ok(())
}
}
static GLOBAL_REGISTRY: std::sync::OnceLock<BackendRegistry> = std::sync::OnceLock::new();
pub fn global_registry() -> &'static BackendRegistry {
GLOBAL_REGISTRY.get_or_init(BackendRegistry::new)
}
pub fn register_backend_factory<F: BackendFactory + 'static>(
name: String,
factory: F,
) -> TorshResult<()> {
global_registry().register_factory(name, factory)
}
pub fn get_backend(name: &str) -> TorshResult<Arc<RwLock<Box<dyn CustomBackend>>>> {
global_registry().get_backend(name)
}
pub fn list_available_backends() -> TorshResult<Vec<BackendInfo>> {
global_registry().list_backends()
}
pub struct BackendExecutor {
strategy: BackendSelectionStrategy,
fallback_backend: Option<String>,
context: BackendContext,
}
#[derive(Debug, Clone)]
pub enum BackendSelectionStrategy {
Specific(String),
Auto,
FirstAvailable(Vec<BackendCapability>),
DeviceSpecific(DeviceType),
Custom(fn(&[BackendInfo], &BackendContext) -> Option<String>),
}
impl BackendExecutor {
pub fn new(strategy: BackendSelectionStrategy, context: BackendContext) -> Self {
Self {
strategy,
fallback_backend: None,
context,
}
}
pub fn with_fallback(mut self, backend_name: String) -> Self {
self.fallback_backend = Some(backend_name);
self
}
pub fn execute(
&self,
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
) -> TorshResult<BackendResult> {
let backend_name = self.select_backend(graph)?;
let backend = get_backend(&backend_name)?;
let backend_guard = backend.read().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire read lock on backend".to_string())
})?;
let optimized_graph = backend_guard.optimize_graph(graph, &self.context)?;
backend_guard.execute_graph(&optimized_graph, inputs, &self.context)
}
fn select_backend(&self, graph: &FxGraph) -> TorshResult<String> {
match &self.strategy {
BackendSelectionStrategy::Specific(name) => Ok(name.clone()),
BackendSelectionStrategy::Auto => self.auto_select_backend(graph),
BackendSelectionStrategy::FirstAvailable(capabilities) => {
self.select_first_available(capabilities)
}
BackendSelectionStrategy::DeviceSpecific(device) => {
self.select_device_specific(*device)
}
BackendSelectionStrategy::Custom(selector) => {
let backends = list_available_backends()?;
if let Some(name) = selector(&backends, &self.context) {
Ok(name)
} else {
self.get_fallback_backend()
}
}
}
}
fn auto_select_backend(&self, _graph: &FxGraph) -> TorshResult<String> {
let device_backends = global_registry().find_backends_for_device(self.context.device)?;
if !device_backends.is_empty() {
Ok(device_backends[0].clone())
} else {
self.get_fallback_backend()
}
}
fn select_first_available(&self, capabilities: &[BackendCapability]) -> TorshResult<String> {
for capability in capabilities {
let backends = global_registry().find_backends_with_capability(*capability)?;
if !backends.is_empty() {
return Ok(backends[0].clone());
}
}
self.get_fallback_backend()
}
fn select_device_specific(&self, device: DeviceType) -> TorshResult<String> {
let backends = global_registry().find_backends_for_device(device)?;
if !backends.is_empty() {
Ok(backends[0].clone())
} else {
self.get_fallback_backend()
}
}
fn get_fallback_backend(&self) -> TorshResult<String> {
if let Some(ref fallback) = self.fallback_backend {
Ok(fallback.clone())
} else {
Err(TorshError::InvalidArgument(
"No suitable backend found".to_string(),
))
}
}
}
pub struct SimpleCpuBackend {
info: BackendInfo,
initialized: bool,
}
impl SimpleCpuBackend {
pub fn new() -> Self {
let info = BackendInfo::new(
"simple_cpu".to_string(),
"1.0.0".to_string(),
"Simple CPU backend for basic operations".to_string(),
)
.with_device(DeviceType::Cpu)
.with_capability(BackendCapability::BasicOps)
.with_capability(BackendCapability::LinearAlgebra)
.with_dtype(DType::F32)
.with_dtype(DType::F64)
.with_vendor("ToRSh".to_string());
Self {
info,
initialized: false,
}
}
}
impl Default for SimpleCpuBackend {
fn default() -> Self {
Self::new()
}
}
impl CustomBackend for SimpleCpuBackend {
fn info(&self) -> &BackendInfo {
&self.info
}
fn initialize(&mut self, _context: &BackendContext) -> TorshResult<()> {
self.initialized = true;
Ok(())
}
fn finalize(&mut self) -> TorshResult<()> {
self.initialized = false;
Ok(())
}
fn can_execute(&self, operation: &str, _inputs: &[&Tensor], context: &BackendContext) -> bool {
if !self.initialized || context.device != DeviceType::Cpu {
return false;
}
matches!(operation, "add" | "mul" | "matmul" | "relu" | "sigmoid")
}
fn execute_operation(
&self,
operation: &str,
inputs: Vec<Tensor>,
_context: &BackendContext,
) -> TorshResult<BackendResult> {
if !self.initialized {
return Err(TorshError::InvalidArgument(
"Backend not initialized".to_string(),
));
}
let start_time = std::time::Instant::now();
let result = match operation {
"add" => {
if inputs.len() != 2 {
return Err(TorshError::InvalidArgument(
"Add requires 2 inputs".to_string(),
));
}
inputs[0].add_op(&inputs[1])?
}
"mul" => {
if inputs.len() != 2 {
return Err(TorshError::InvalidArgument(
"Mul requires 2 inputs".to_string(),
));
}
inputs[0].mul_op(&inputs[1])?
}
"matmul" => {
if inputs.len() != 2 {
return Err(TorshError::InvalidArgument(
"Matmul requires 2 inputs".to_string(),
));
}
inputs[0].matmul(&inputs[1])?
}
"relu" => {
if inputs.len() != 1 {
return Err(TorshError::InvalidArgument(
"ReLU requires 1 input".to_string(),
));
}
inputs[0].relu()?
}
"sigmoid" => {
if inputs.len() != 1 {
return Err(TorshError::InvalidArgument(
"Sigmoid requires 1 input".to_string(),
));
}
inputs[0].sigmoid()?
}
_ => {
return Err(TorshError::InvalidArgument(format!(
"Unsupported operation: {operation}"
)));
}
};
let execution_time = start_time.elapsed().as_micros() as u64;
Ok(BackendResult::new(vec![result])
.with_execution_time(execution_time)
.with_metadata("backend".to_string(), "simple_cpu".to_string()))
}
}
pub struct SimpleCpuBackendFactory;
impl BackendFactory for SimpleCpuBackendFactory {
fn create_backend(&self) -> TorshResult<Box<dyn CustomBackend>> {
Ok(Box::new(SimpleCpuBackend::new()))
}
fn factory_info(&self) -> BackendInfo {
SimpleCpuBackend::new().info().clone()
}
}
pub fn execute_with_auto_backend(
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
device: DeviceType,
) -> TorshResult<BackendResult> {
let context = BackendContext::new(device);
let executor = BackendExecutor::new(BackendSelectionStrategy::Auto, context);
executor.execute(graph, inputs)
}
pub fn execute_with_backend(
graph: &FxGraph,
inputs: HashMap<String, Tensor>,
backend_name: &str,
context: BackendContext,
) -> TorshResult<BackendResult> {
let executor = BackendExecutor::new(
BackendSelectionStrategy::Specific(backend_name.to_string()),
context,
);
executor.execute(graph, inputs)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tracer::ModuleTracer;
use torsh_tensor::creation::ones;
#[test]
fn test_backend_info() {
let info = BackendInfo::new(
"test_backend".to_string(),
"1.0.0".to_string(),
"Test backend".to_string(),
)
.with_device(DeviceType::Cpu)
.with_capability(BackendCapability::BasicOps)
.with_dtype(DType::F32);
assert_eq!(info.name, "test_backend");
assert!(info.supports_device(DeviceType::Cpu));
assert!(info.has_capability(BackendCapability::BasicOps));
assert!(info.supports_dtype(DType::F32));
}
#[test]
fn test_backend_context() {
let context = BackendContext::new(DeviceType::Cpu)
.with_parameter("threads".to_string(), "4".to_string())
.with_memory_limit(1024 * 1024 * 1024)
.with_optimization_level(2);
assert_eq!(context.device, DeviceType::Cpu);
assert_eq!(context.parameters.get("threads"), Some(&"4".to_string()));
assert_eq!(context.memory_limit, Some(1024 * 1024 * 1024));
assert_eq!(context.optimization_level, 2);
}
#[test]
fn test_backend_result() {
let tensor = ones(&[2, 3]).unwrap();
let result = BackendResult::new(vec![tensor])
.with_execution_time(1000)
.with_memory_usage(1024)
.with_metadata("test".to_string(), "value".to_string());
assert_eq!(result.outputs.len(), 1);
assert_eq!(result.execution_time, Some(1000));
assert_eq!(result.memory_usage, Some(1024));
assert_eq!(result.metadata.get("test"), Some(&"value".to_string()));
}
#[test]
fn test_simple_cpu_backend() {
let mut backend = SimpleCpuBackend::new();
let context = BackendContext::new(DeviceType::Cpu);
assert!(backend.initialize(&context).is_ok());
let tensor1 = ones(&[2, 3]).unwrap();
let tensor2 = ones(&[2, 3]).unwrap();
assert!(backend.can_execute("add", &[&tensor1, &tensor2], &context));
let result = backend.execute_operation("add", vec![tensor1, tensor2], &context);
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.outputs.len(), 1);
assert!(result.execution_time.is_some());
}
#[test]
fn test_backend_registry() {
let registry = BackendRegistry::new();
let factory = SimpleCpuBackendFactory;
assert!(registry
.register_factory("simple_cpu".to_string(), factory)
.is_ok());
let backends = registry.list_backends().unwrap();
assert!(!backends.is_empty());
assert_eq!(backends[0].name, "simple_cpu");
let backend = registry.get_backend("simple_cpu");
assert!(backend.is_ok());
}
#[test]
fn test_global_registry() {
let factory = SimpleCpuBackendFactory;
let register_result = register_backend_factory("test_global".to_string(), factory);
if register_result.is_ok() {
let backends = list_available_backends();
if let Ok(backends) = backends {
let found = backends.iter().any(|b| b.name == "test_global");
if found {
let backend = get_backend("test_global");
assert!(backend.is_ok());
} else {
}
} else {
}
} else {
}
}
#[test]
fn test_backend_executor() {
let factory = SimpleCpuBackendFactory;
let _ = register_backend_factory("cpu_executor_test".to_string(), factory);
let context = BackendContext::new(DeviceType::Cpu);
let executor = BackendExecutor::new(
BackendSelectionStrategy::Specific("cpu_executor_test".to_string()),
context,
);
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), ones(&[2, 3]).unwrap());
let result = executor.execute(&graph, inputs);
match result {
Ok(_) => {
}
Err(_) => {
}
}
}
#[test]
fn test_backend_capabilities_search() {
let registry = BackendRegistry::new();
let factory = SimpleCpuBackendFactory;
let _ = registry.register_factory("capability_test".to_string(), factory);
let backends = registry
.find_backends_with_capability(BackendCapability::BasicOps)
.unwrap();
assert!(!backends.is_empty());
let device_backends = registry.find_backends_for_device(DeviceType::Cpu).unwrap();
assert!(!device_backends.is_empty());
}
#[test]
fn test_backend_selection_strategies() {
let context = BackendContext::new(DeviceType::Cpu);
let strategy = BackendSelectionStrategy::Specific("test".to_string());
let _executor = BackendExecutor::new(strategy, context.clone());
let strategy = BackendSelectionStrategy::Auto;
let _executor = BackendExecutor::new(strategy, context.clone());
let strategy = BackendSelectionStrategy::FirstAvailable(vec![BackendCapability::BasicOps]);
let _executor = BackendExecutor::new(strategy, context.clone());
let strategy = BackendSelectionStrategy::DeviceSpecific(DeviceType::Cpu);
let _executor = BackendExecutor::new(strategy, context);
}
#[test]
fn test_convenience_functions() {
let factory = SimpleCpuBackendFactory;
let _ = register_backend_factory("convenience_test".to_string(), factory);
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
let graph = tracer.finalize();
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), ones(&[2, 3]).unwrap());
let _result = execute_with_auto_backend(&graph, inputs.clone(), DeviceType::Cpu);
let context = BackendContext::new(DeviceType::Cpu);
let _result = execute_with_backend(&graph, inputs, "convenience_test", context);
}
#[test]
fn test_backend_info_serialization() {
let info = BackendInfo::new("test".to_string(), "1.0".to_string(), "Test".to_string())
.with_device(DeviceType::Cpu)
.with_capability(BackendCapability::BasicOps);
let serialized = serde_json::to_string(&info).unwrap();
let deserialized: BackendInfo = serde_json::from_str(&serialized).unwrap();
assert_eq!(info.name, deserialized.name);
assert_eq!(info.version, deserialized.version);
}
}