use crate::TorshResult;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::{dtype::DType, error::TorshError, shape::Shape};
use torsh_tensor::Tensor;
/// A user-defined tensor operation that can be stored in an `OperationRegistry`
/// and invoked by name.
///
/// Implementors must be `Send + Sync`. Only [`execute`](Self::execute),
/// [`name`](Self::name) and [`clone_operation`](Self::clone_operation) are required;
/// every other method has a permissive default.
pub trait CustomOperation: Send + Sync {
    /// Runs the operation on `inputs`, producing a single output tensor.
    fn execute(&self, inputs: Vec<Tensor>) -> TorshResult<Tensor>;

    /// Name under which the operation is registered and looked up.
    fn name(&self) -> &str;

    /// Returns a boxed copy of this operation; stands in for `Clone`, which cannot be
    /// called through `dyn CustomOperation`.
    fn clone_operation(&self) -> Box<dyn CustomOperation>;

    /// Optional free-form key/value metadata. The default reports none.
    fn metadata(&self) -> Option<HashMap<String, String>> {
        None
    }

    /// Checks `inputs` before [`execute`](Self::execute) is invoked by the registry.
    /// The default accepts any inputs.
    fn validate_inputs(&self, _inputs: &[Tensor]) -> TorshResult<()> {
        Ok(())
    }

    /// Infers the output shape from the input shapes.
    ///
    /// The default returns a clone of the first input shape.
    ///
    /// # Errors
    /// The default fails with `InvalidArgument` when `input_shapes` is empty.
    fn infer_shape(&self, input_shapes: &[Shape]) -> TorshResult<Shape> {
        match input_shapes {
            [first, ..] => Ok(first.clone()),
            [] => Err(TorshError::InvalidArgument(
                "No input shapes provided".to_string(),
            )),
        }
    }

    /// Checks the input element types. The default accepts any types.
    fn validate_types(&self, _input_types: &[DType]) -> TorshResult<()> {
        Ok(())
    }

    /// Infers the output element type from the input types.
    ///
    /// The default returns the first input type.
    ///
    /// # Errors
    /// The default fails with `InvalidArgument` when `input_types` is empty.
    fn infer_type(&self, input_types: &[DType]) -> TorshResult<DType> {
        match input_types {
            [first, ..] => Ok(*first),
            [] => Err(TorshError::InvalidArgument(
                "No input types provided".to_string(),
            )),
        }
    }
}
/// Shared handle to a registered operation. Cloning it only bumps a reference count,
/// which lets the registry hand an operation out and release its lock before running
/// any operation code.
type SharedOperation = Arc<dyn CustomOperation>;

/// Thread-safe registry of [`CustomOperation`]s keyed by [`CustomOperation::name`].
///
/// Lookups take a read lock; registration and clearing take a write lock. Operation
/// code (`execute`, `validate_inputs`, `metadata`, `clone_operation`) is only invoked
/// after the lock has been released, so an operation may call back into the registry
/// (including registering further operations) without deadlocking.
#[derive(Default)]
pub struct OperationRegistry {
    operations: RwLock<HashMap<String, SharedOperation>>,
}

impl OperationRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            operations: RwLock::new(HashMap::new()),
        }
    }

    /// Looks up `name` and returns a shared handle to the registered instance.
    ///
    /// The read lock is released when this returns, before the caller runs any
    /// operation code.
    fn lookup(&self, name: &str) -> TorshResult<SharedOperation> {
        let ops = self.operations.read().map_err(|_| {
            TorshError::InvalidArgument("Failed to acquire read lock on operations".to_string())
        })?;
        let found = ops.get(name).cloned();
        found.ok_or_else(|| {
            TorshError::InvalidArgument(format!("Operation '{}' not found in registry", name))
        })
    }

    /// Registers `operation` under the name reported by [`CustomOperation::name`].
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if an operation with the same name is
    /// already registered, or if the write lock is poisoned.
    pub fn register<T: CustomOperation + 'static>(&self, operation: T) -> TorshResult<()> {
        use std::collections::hash_map::Entry;

        let name = operation.name().to_string();
        let mut ops = self.operations.write().map_err(|_| {
            TorshError::InvalidArgument("Failed to acquire write lock on operations".to_string())
        })?;
        // One hash lookup serves both the duplicate check and the insertion.
        match ops.entry(name) {
            Entry::Occupied(entry) => Err(TorshError::InvalidArgument(format!(
                "Operation '{}' already registered",
                entry.key()
            ))),
            Entry::Vacant(entry) => {
                entry.insert(Arc::new(operation));
                Ok(())
            }
        }
    }

    /// Returns a boxed copy of the operation registered as `name`, produced by
    /// [`CustomOperation::clone_operation`].
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `name` is not registered or the read
    /// lock is poisoned.
    pub fn get(&self, name: &str) -> TorshResult<Box<dyn CustomOperation>> {
        let operation = self.lookup(name)?;
        Ok(operation.clone_operation())
    }

    /// Returns `true` if an operation called `name` is registered.
    ///
    /// A poisoned lock is reported as `false` rather than as an error.
    pub fn is_registered(&self, name: &str) -> bool {
        self.operations
            .read()
            .map(|ops| ops.contains_key(name))
            .unwrap_or(false)
    }

    /// Returns the names of all registered operations, sorted so the output is
    /// deterministic (the underlying `HashMap` iterates in arbitrary order).
    ///
    /// A poisoned lock is reported as an empty list rather than as an error.
    pub fn list_operations(&self) -> Vec<String> {
        let mut names = self
            .operations
            .read()
            .map(|ops| ops.keys().cloned().collect::<Vec<_>>())
            .unwrap_or_default();
        names.sort_unstable();
        names
    }

    /// Validates `inputs` with the operation registered as `name`, then executes it.
    ///
    /// The registered instance itself runs (it is not cloned). The registry lock is
    /// released before any operation code runs, so the operation may itself use the
    /// registry.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `name` is not registered or the read
    /// lock is poisoned; otherwise propagates errors from
    /// [`CustomOperation::validate_inputs`] and [`CustomOperation::execute`].
    pub fn execute(&self, name: &str, inputs: Vec<Tensor>) -> TorshResult<Tensor> {
        let operation = self.lookup(name)?;
        operation.validate_inputs(&inputs)?;
        operation.execute(inputs)
    }

    /// Returns the [`CustomOperation::metadata`] of the operation registered as `name`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `name` is not registered or the read
    /// lock is poisoned.
    pub fn get_operation_metadata(
        &self,
        name: &str,
    ) -> TorshResult<Option<HashMap<String, String>>> {
        let operation = self.lookup(name)?;
        Ok(operation.metadata())
    }

    /// Removes every registered operation.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if the write lock is poisoned.
    pub fn clear(&self) -> TorshResult<()> {
        let mut ops = self.operations.write().map_err(|_| {
            TorshError::InvalidArgument("Failed to acquire write lock on operations".to_string())
        })?;
        ops.clear();
        Ok(())
    }

    /// Returns the number of registered operations.
    ///
    /// A poisoned lock is reported as `0` rather than as an error.
    pub fn operation_count(&self) -> usize {
        self.operations.read().map(|ops| ops.len()).unwrap_or(0)
    }

    /// Runs [`CustomOperation::validate_inputs`] of the operation registered as `name`
    /// without executing it. The lock is released before validation runs.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `name` is not registered or the read
    /// lock is poisoned; otherwise propagates the validation result.
    pub fn validate_operation(&self, name: &str, inputs: &[Tensor]) -> TorshResult<()> {
        let operation = self.lookup(name)?;
        operation.validate_inputs(inputs)
    }
}
/// Lazily-initialised, process-wide registry used by the module-level helper functions.
static GLOBAL_REGISTRY: std::sync::OnceLock<OperationRegistry> = std::sync::OnceLock::new();

/// Returns the process-wide [`OperationRegistry`], creating it empty on first use.
///
/// `OnceLock::get_or_init` guarantees the constructor runs at most once, even if
/// several threads make the first call concurrently.
pub fn global_registry() -> &'static OperationRegistry {
    // Pass the constructor itself rather than a closure that only forwards to it
    // (clippy::redundant_closure).
    GLOBAL_REGISTRY.get_or_init(OperationRegistry::new)
}
/// Registers `operation` in the process-wide registry.
///
/// # Errors
/// Propagates any error from `OperationRegistry::register`, such as a name that is
/// already registered.
pub fn register_operation<T>(operation: T) -> TorshResult<()>
where
    T: CustomOperation + 'static,
{
    let registry = global_registry();
    registry.register(operation)
}
pub fn is_operation_registered(name: &str) -> bool {
global_registry().is_registered(name)
}
/// Validates and executes the operation registered as `name` in the process-wide
/// registry.
///
/// # Errors
/// Propagates any error from `OperationRegistry::execute`: unknown name, lock
/// failure, input validation failure, or an error from the operation itself.
pub fn execute_registered_operation(name: &str, inputs: Vec<Tensor>) -> TorshResult<Tensor> {
    let registry = global_registry();
    registry.execute(name, inputs)
}