use crate::{core_ops::Tensor, TensorElement};
use scirs2_core::numeric::FromPrimitive;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::error::{Result, TorshError};
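/// Contract for user-defined tensor operations that can be registered and
/// invoked by name.
///
/// A minimal implementor sketch, assuming only the `Tensor` API already used
/// in this module (`clone`); the `Identity` name is illustrative:
///
/// ```ignore
/// struct Identity;
///
/// impl<T: TensorElement> CustomOperation<T> for Identity {
///     fn name(&self) -> &str {
///         "identity"
///     }
///     fn description(&self) -> &str {
///         "Returns its input unchanged"
///     }
///     fn forward(&self, inputs: &[Tensor<T>], _params: &OperationParams) -> Result<Vec<Tensor<T>>> {
///         Ok(vec![inputs[0].clone()])
///     }
///     fn output_shapes(
///         &self,
///         input_shapes: &[Vec<usize>],
///         _params: &OperationParams,
///     ) -> Result<Vec<Vec<usize>>> {
///         Ok(vec![input_shapes[0].clone()])
///     }
///     fn num_inputs(&self) -> usize {
///         1
///     }
///     fn num_outputs(&self) -> usize {
///         1
///     }
/// }
/// ```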
pub trait CustomOperation<T: TensorElement>: Send + Sync {
/// Unique name under which the operation is registered and looked up.
fn name(&self) -> &str;
/// Human-readable description of what the operation does.
fn description(&self) -> &str;
/// Forward pass: compute output tensors from the inputs and parameters.
fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>>;
/// Backward pass: map gradients w.r.t. the outputs to gradients w.r.t. each
/// input (`None` means no gradient). The default returns no gradients.
fn backward(
&self,
_grad_outputs: &[Tensor<T>],
inputs: &[Tensor<T>],
_outputs: &[Tensor<T>],
_params: &OperationParams,
) -> Result<Vec<Option<Tensor<T>>>> {
Ok(vec![None; inputs.len()])
}
/// Validate inputs before `forward` runs. The default implementation only
/// checks that at least one input tensor was supplied.
fn validate_inputs(&self, inputs: &[Tensor<T>], _params: &OperationParams) -> Result<()> {
if inputs.is_empty() {
return Err(TorshError::InvalidShape(
"Operation requires at least one input tensor".to_string(),
));
}
Ok(())
}
/// Infer output shapes from input shapes without running the operation.
fn output_shapes(
&self,
input_shapes: &[Vec<usize>],
params: &OperationParams,
) -> Result<Vec<Vec<usize>>>;
/// Whether this operation provides a meaningful `backward` implementation.
fn supports_autograd(&self) -> bool {
true
}
/// Exact number of input tensors; the registry helpers enforce this count.
fn num_inputs(&self) -> usize;
/// Exact number of output tensors `forward` must produce.
fn num_outputs(&self) -> usize;
}
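/// Typed key-value parameters passed to custom operations.
///
/// Values live in separate maps per type, so the same key may appear under
/// more than one type. A builder-style construction, mirroring the tests
/// below:
///
/// ```ignore
/// let params = OperationParams::new()
///     .with_int("axis", 1)
///     .with_float("scale", 2.5);
/// assert_eq!(params.get_int("axis"), Some(1));
/// assert_eq!(params.get_float("scale"), Some(2.5));
/// ```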
#[derive(Debug, Clone)]
pub struct OperationParams {
pub strings: HashMap<String, String>,
pub integers: HashMap<String, i64>,
pub floats: HashMap<String, f64>,
pub booleans: HashMap<String, bool>,
pub vectors: HashMap<String, Vec<f64>>,
pub shapes: HashMap<String, Vec<usize>>,
}
impl OperationParams {
pub fn new() -> Self {
Self {
strings: HashMap::new(),
integers: HashMap::new(),
floats: HashMap::new(),
booleans: HashMap::new(),
vectors: HashMap::new(),
shapes: HashMap::new(),
}
}
pub fn with_string(mut self, key: &str, value: &str) -> Self {
self.strings.insert(key.to_string(), value.to_string());
self
}
pub fn with_int(mut self, key: &str, value: i64) -> Self {
self.integers.insert(key.to_string(), value);
self
}
pub fn with_float(mut self, key: &str, value: f64) -> Self {
self.floats.insert(key.to_string(), value);
self
}
pub fn with_bool(mut self, key: &str, value: bool) -> Self {
self.booleans.insert(key.to_string(), value);
self
}
pub fn with_vector(mut self, key: &str, value: Vec<f64>) -> Self {
self.vectors.insert(key.to_string(), value);
self
}
pub fn with_shape(mut self, key: &str, value: Vec<usize>) -> Self {
self.shapes.insert(key.to_string(), value);
self
}
pub fn get_string(&self, key: &str) -> Option<&String> {
self.strings.get(key)
}
pub fn get_int(&self, key: &str) -> Option<i64> {
self.integers.get(key).copied()
}
pub fn get_float(&self, key: &str) -> Option<f64> {
self.floats.get(key).copied()
}
pub fn get_bool(&self, key: &str) -> Option<bool> {
self.booleans.get(key).copied()
}
pub fn get_vector(&self, key: &str) -> Option<&Vec<f64>> {
self.vectors.get(key)
}
pub fn get_shape(&self, key: &str) -> Option<&Vec<usize>> {
self.shapes.get(key)
}
}
impl Default for OperationParams {
fn default() -> Self {
Self::new()
}
}
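/// Descriptive record captured at registration time: an operation's name,
/// arity, autograd support, version, and optional authorship and tags.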
#[derive(Debug, Clone)]
pub struct OperationMetadata {
pub name: String,
pub description: String,
pub num_inputs: usize,
pub num_outputs: usize,
pub supports_autograd: bool,
pub data_type: TypeId,
pub version: String,
pub author: Option<String>,
pub tags: Vec<String>,
}
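/// Thread-safe registry of custom operations, keyed by element type and
/// operation name.
///
/// A typical round trip, assuming the `ScaleOperation` defined later in this
/// module:
///
/// ```ignore
/// let registry = CustomOperationRegistry::new();
/// registry.register::<f32>(Box::new(ScaleOperation), "1.0.0", None, vec![])?;
/// assert!(registry.is_registered::<f32>("scale"));
/// let op = registry.get::<f32>("scale").expect("registered above");
/// ```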
pub struct CustomOperationRegistry {
operations: RwLock<HashMap<(TypeId, String), Arc<dyn Any + Send + Sync>>>,
metadata: RwLock<HashMap<(TypeId, String), OperationMetadata>>,
}
impl CustomOperationRegistry {
pub fn new() -> Self {
Self {
operations: RwLock::new(HashMap::new()),
metadata: RwLock::new(HashMap::new()),
}
}
/// Register `operation` for element type `T`, recording its metadata.
/// Fails if an operation with the same name is already registered for `T`.
pub fn register<T: TensorElement + 'static>(
&self,
operation: Box<dyn CustomOperation<T>>,
version: &str,
author: Option<String>,
tags: Vec<String>,
) -> Result<()> {
let type_id = TypeId::of::<T>();
let name = operation.name().to_string();
let key = (type_id, name.clone());
let metadata = OperationMetadata {
name: name.clone(),
description: operation.description().to_string(),
num_inputs: operation.num_inputs(),
num_outputs: operation.num_outputs(),
supports_autograd: operation.supports_autograd(),
data_type: type_id,
version: version.to_string(),
author,
tags,
};
{
let mut ops = self
.operations
.write()
.expect("lock should not be poisoned");
let mut meta = self.metadata.write().expect("lock should not be poisoned");
if ops.contains_key(&key) {
return Err(TorshError::InvalidArgument(format!(
"Operation '{}' for type {:?} is already registered",
name, type_id
)));
}
// Two layers of indirection: the inner `Arc<dyn CustomOperation<T>>` is
// the handle callers receive, while the outer `Arc<dyn Any>` type-erases
// it so operations for different element types can share one map; `get`
// downcasts the outer layer back to the inner `Arc`.
let arc_op: Arc<dyn CustomOperation<T>> = Arc::from(operation);
let erased_op: Arc<dyn Any + Send + Sync> = Arc::new(arc_op);
ops.insert(key.clone(), erased_op);
meta.insert(key, metadata);
}
Ok(())
}
/// Look up a registered operation for element type `T` by name.
pub fn get<T: TensorElement + 'static>(
&self,
name: &str,
) -> Option<Arc<dyn CustomOperation<T>>> {
let type_id = TypeId::of::<T>();
let key = (type_id, name.to_string());
let ops = self.operations.read().expect("lock should not be poisoned");
ops.get(&key).and_then(|arc_any| {
arc_any
.downcast_ref::<Arc<dyn CustomOperation<T>>>()
.map(|arc_op| Arc::clone(arc_op))
})
}
pub fn get_metadata<T: TensorElement + 'static>(
&self,
name: &str,
) -> Option<OperationMetadata> {
let type_id = TypeId::of::<T>();
let key = (type_id, name.to_string());
let meta = self.metadata.read().expect("lock should not be poisoned");
meta.get(&key).cloned()
}
/// Names of all operations registered for element type `T`.
pub fn list_operations<T: TensorElement + 'static>(&self) -> Vec<String> {
let type_id = TypeId::of::<T>();
let meta = self.metadata.read().expect("lock should not be poisoned");
meta.keys()
.filter(|(tid, _)| *tid == type_id)
.map(|(_, name)| name.clone())
.collect()
}
/// Remove a registered operation; fails if it was never registered.
pub fn unregister<T: TensorElement + 'static>(&self, name: &str) -> Result<()> {
let type_id = TypeId::of::<T>();
let key = (type_id, name.to_string());
let mut ops = self
.operations
.write()
.expect("lock should not be poisoned");
let mut meta = self.metadata.write().expect("lock should not be poisoned");
if ops.remove(&key).is_none() {
return Err(TorshError::InvalidArgument(format!(
"Operation '{}' for type {:?} is not registered",
name, type_id
)));
}
meta.remove(&key);
Ok(())
}
pub fn is_registered<T: TensorElement + 'static>(&self, name: &str) -> bool {
let type_id = TypeId::of::<T>();
let key = (type_id, name.to_string());
let ops = self.operations.read().expect("lock should not be poisoned");
ops.contains_key(&key)
}
pub fn count(&self) -> usize {
let ops = self.operations.read().expect("lock should not be poisoned");
ops.len()
}
pub fn clear(&self) {
let mut ops = self
.operations
.write()
.expect("lock should not be poisoned");
let mut meta = self.metadata.write().expect("lock should not be poisoned");
ops.clear();
meta.clear();
}
}
impl Default for CustomOperationRegistry {
fn default() -> Self {
Self::new()
}
}
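/// Process-wide registry backing `Tensor::apply_custom_op`. Because it is
/// global, registrations made by one caller (or test) are visible to all
/// others until explicitly unregistered.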
static GLOBAL_REGISTRY: std::sync::LazyLock<CustomOperationRegistry> =
std::sync::LazyLock::new(CustomOperationRegistry::new);
pub fn global_registry() -> &'static CustomOperationRegistry {
&GLOBAL_REGISTRY
}
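/// Extension trait that lets custom operations be invoked directly on a
/// tensor, with the receiver passed as the first input.
///
/// A usage sketch against the global registry, assuming a `"scale"` operation
/// has been registered for `f32` (as in the tests below):
///
/// ```ignore
/// let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)?;
/// let params = OperationParams::new().with_float("scale", 3.0);
/// let outputs = tensor.apply_custom_op("scale", &[], &params)?;
/// ```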
pub trait TensorCustomOps<T: TensorElement> {
fn apply_custom_op(
&self,
op_name: &str,
other_inputs: &[&Tensor<T>],
params: &OperationParams,
) -> Result<Vec<Tensor<T>>>;
fn apply_custom_op_with_registry(
&self,
registry: &CustomOperationRegistry,
op_name: &str,
other_inputs: &[&Tensor<T>],
params: &OperationParams,
) -> Result<Vec<Tensor<T>>>;
}
impl<T: TensorElement + 'static> TensorCustomOps<T> for Tensor<T> {
fn apply_custom_op(
&self,
op_name: &str,
other_inputs: &[&Tensor<T>],
params: &OperationParams,
) -> Result<Vec<Tensor<T>>> {
self.apply_custom_op_with_registry(global_registry(), op_name, other_inputs, params)
}
fn apply_custom_op_with_registry(
&self,
registry: &CustomOperationRegistry,
op_name: &str,
other_inputs: &[&Tensor<T>],
params: &OperationParams,
) -> Result<Vec<Tensor<T>>> {
let operation = registry.get::<T>(op_name).ok_or_else(|| {
TorshError::InvalidArgument(format!(
"Custom operation '{}' is not registered for type {:?}",
op_name,
TypeId::of::<T>()
))
})?;
let mut inputs = vec![self.clone()];
inputs.extend(other_inputs.iter().map(|&t| t.clone()));
operation.validate_inputs(&inputs, params)?;
if inputs.len() != operation.num_inputs() {
return Err(TorshError::InvalidArgument(format!(
"Operation '{}' expects {} inputs, got {}",
op_name,
operation.num_inputs(),
inputs.len()
)));
}
let outputs = operation.forward(&inputs, params)?;
if outputs.len() != operation.num_outputs() {
return Err(TorshError::InvalidArgument(format!(
"Operation '{}' produced {} outputs, expected {}",
op_name,
outputs.len(),
operation.num_outputs()
)));
}
Ok(outputs)
}
}
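/// Example operation: multiplies every element by the `"scale"` float
/// parameter (default `1.0`). Its gradient is `grad_output * scale`, so it
/// also demonstrates a custom `backward`.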
pub struct ScaleOperation;
impl<T: TensorElement + Copy + std::ops::Mul<Output = T> + FromPrimitive>
CustomOperation<T> for ScaleOperation
{
fn name(&self) -> &str {
"scale"
}
fn description(&self) -> &str {
"Scales tensor elements by a constant factor"
}
fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
if inputs.len() != 1 {
return Err(TorshError::InvalidArgument(
"Scale operation requires exactly 1 input".to_string(),
));
}
let scale = params.get_float("scale").unwrap_or(1.0);
let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
})?;
let result = inputs[0].mul_scalar(scale_val)?;
Ok(vec![result])
}
fn backward(
&self,
grad_outputs: &[Tensor<T>],
_inputs: &[Tensor<T>],
_outputs: &[Tensor<T>],
params: &OperationParams,
) -> Result<Vec<Option<Tensor<T>>>> {
let scale = params.get_float("scale").unwrap_or(1.0);
let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
})?;
let grad_input = grad_outputs[0].mul_scalar(scale_val)?;
Ok(vec![Some(grad_input)])
}
fn output_shapes(
&self,
input_shapes: &[Vec<usize>],
_params: &OperationParams,
) -> Result<Vec<Vec<usize>>> {
if input_shapes.len() != 1 {
return Err(TorshError::InvalidArgument(
"Scale operation requires exactly 1 input".to_string(),
));
}
Ok(vec![input_shapes[0].clone()])
}
fn num_inputs(&self) -> usize {
1
}
fn num_outputs(&self) -> usize {
1
}
}
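/// Example operation: concatenates tensors along the `"axis"` integer
/// parameter (default `0`). Note that `num_inputs` reports 2 and the registry
/// helpers enforce input counts exactly, so only the two-tensor form is
/// reachable through `apply_custom_op`, even though `forward` itself accepts
/// any number of inputs >= 2.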
pub struct ConcatOperation;
impl<T: TensorElement + Copy> CustomOperation<T> for ConcatOperation {
fn name(&self) -> &str {
"concat"
}
fn description(&self) -> &str {
"Concatenates tensors along a specified axis"
}
fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
if inputs.len() < 2 {
return Err(TorshError::InvalidArgument(
"Concat operation requires at least 2 inputs".to_string(),
));
}
let axis = params.get_int("axis").unwrap_or(0) as usize;
let input_refs: Vec<&Tensor<T>> = inputs.iter().collect();
let result = Tensor::cat(&input_refs, axis as i32)?;
Ok(vec![result])
}
fn backward(
&self,
grad_outputs: &[Tensor<T>],
inputs: &[Tensor<T>],
_outputs: &[Tensor<T>],
params: &OperationParams,
) -> Result<Vec<Option<Tensor<T>>>> {
let axis = params.get_int("axis").unwrap_or(0) as usize;
let grad_output = &grad_outputs[0];
let mut split_sizes = Vec::new();
for input in inputs {
split_sizes.push(input.shape().dims()[axis]);
}
let mut grad_inputs = Vec::new();
let mut start = 0;
for &size in &split_sizes {
let end = start + size;
let slice = grad_output.slice_tensor(axis, start, end)?;
grad_inputs.push(Some(slice));
start = end;
}
Ok(grad_inputs)
}
fn output_shapes(
&self,
input_shapes: &[Vec<usize>],
params: &OperationParams,
) -> Result<Vec<Vec<usize>>> {
if input_shapes.len() < 2 {
return Err(TorshError::InvalidArgument(
"Concat operation requires at least 2 inputs".to_string(),
));
}
let axis = params.get_int("axis").unwrap_or(0) as usize;
let mut output_shape = input_shapes[0].clone();
if axis >= output_shape.len() {
return Err(TorshError::InvalidArgument(format!(
"Concat axis {} out of bounds for {} dimensions",
axis,
output_shape.len()
)));
}
let mut total_size = output_shape[axis];
for shape in &input_shapes[1..] {
if shape.len() != output_shape.len() {
return Err(TorshError::InvalidArgument(
"All tensors must have the same number of dimensions".to_string(),
));
}
for (i, (&dim1, &dim2)) in output_shape.iter().zip(shape.iter()).enumerate() {
if i != axis && dim1 != dim2 {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} mismatch: {} vs {}",
i, dim1, dim2
)));
}
}
total_size += shape[axis];
}
output_shape[axis] = total_size;
Ok(vec![output_shape])
}
fn num_inputs(&self) -> usize {
2
}
fn num_outputs(&self) -> usize {
1
}
fn validate_inputs(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<()> {
if inputs.len() < 2 {
return Err(TorshError::InvalidArgument(
"Concat operation requires at least 2 inputs".to_string(),
));
}
let axis = params.get_int("axis").unwrap_or(0) as usize;
let first_tensor_shape = inputs[0].shape();
let first_shape = first_tensor_shape.dims();
if axis >= first_shape.len() {
return Err(TorshError::InvalidArgument(format!(
"Concat axis {} out of bounds for {} dimensions",
axis,
first_shape.len()
)));
}
for (i, tensor) in inputs.iter().enumerate().skip(1) {
let tensor_shape = tensor.shape();
let shape = tensor_shape.dims();
if shape.len() != first_shape.len() {
return Err(TorshError::InvalidArgument(format!(
"Tensor {} has {} dimensions, expected {}",
i,
shape.len(),
first_shape.len()
)));
}
for (dim_idx, (&dim1, &dim2)) in first_shape.iter().zip(shape.iter()).enumerate() {
if dim_idx != axis && dim1 != dim2 {
return Err(TorshError::InvalidArgument(format!(
"Tensor {} dimension {} mismatch: {} vs {}",
i, dim_idx, dim1, dim2
)));
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::device::DeviceType;
#[test]
fn test_operation_params() {
let params = OperationParams::new()
.with_string("mode", "linear")
.with_int("axis", 1)
.with_float("scale", 2.5)
.with_bool("inplace", false)
.with_vector("weights", vec![1.0, 2.0, 3.0])
.with_shape("target_shape", vec![10, 20]);
assert_eq!(params.get_string("mode"), Some(&"linear".to_string()));
assert_eq!(params.get_int("axis"), Some(1));
assert_eq!(params.get_float("scale"), Some(2.5));
assert_eq!(params.get_bool("inplace"), Some(false));
assert_eq!(params.get_vector("weights"), Some(&vec![1.0, 2.0, 3.0]));
assert_eq!(params.get_shape("target_shape"), Some(&vec![10, 20]));
assert_eq!(params.get_string("nonexistent"), None);
}
#[test]
fn test_registry_operations() {
let registry = CustomOperationRegistry::new();
let scale_op = Box::new(ScaleOperation);
registry
.register::<f32>(
scale_op,
"1.0.0",
Some("Test".to_string()),
vec!["math".to_string()],
)
.expect("registration should succeed");
assert!(registry.is_registered::<f32>("scale"));
assert!(!registry.is_registered::<f32>("nonexistent"));
let metadata = registry
.get_metadata::<f32>("scale")
.expect("metadata retrieval should succeed");
assert_eq!(metadata.name, "scale");
assert_eq!(
metadata.description,
"Scales tensor elements by a constant factor"
);
assert_eq!(metadata.num_inputs, 1);
assert_eq!(metadata.num_outputs, 1);
assert_eq!(metadata.version, "1.0.0");
assert_eq!(metadata.author, Some("Test".to_string()));
assert_eq!(metadata.tags, vec!["math".to_string()]);
let ops = registry.list_operations::<f32>();
assert_eq!(ops, vec!["scale".to_string()]);
registry
.unregister::<f32>("scale")
.expect("unregister should succeed");
assert!(!registry.is_registered::<f32>("scale"));
}
#[test]
fn test_scale_operation() {
let registry = CustomOperationRegistry::new();
let scale_op = Box::new(ScaleOperation);
registry
.register::<f32>(scale_op, "1.0.0", None, vec![])
.expect("unregister should succeed");
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
.expect("tensor creation should succeed");
let params = OperationParams::new().with_float("scale", 2.0);
let results = tensor
.apply_custom_op_with_registry(&registry, "scale", &[], &params)
.expect("custom op should succeed");
assert_eq!(results.len(), 1);
let result = &results[0];
let expected_data = vec![2.0f32, 4.0, 6.0, 8.0];
assert_eq!(
result.data().expect("data retrieval should succeed"),
expected_data
);
}
#[test]
fn test_concat_operation() {
let registry = CustomOperationRegistry::new();
let concat_op = Box::new(ConcatOperation);
registry
.register::<f32>(concat_op, "1.0.0", None, vec![])
.expect("registration should succeed");
let data1 = vec![1.0f32, 2.0];
let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu)
.expect("tensor creation should succeed");
let data2 = vec![3.0f32, 4.0];
let tensor2 = Tensor::from_data(data2, vec![2], DeviceType::Cpu)
.expect("tensor creation should succeed");
let params = OperationParams::new().with_int("axis", 0);
let results = tensor1
.apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params)
.expect("custom op should succeed");
assert_eq!(results.len(), 1);
let result = &results[0];
assert_eq!(result.shape().dims(), &[4]);
let expected_data = vec![1.0f32, 2.0, 3.0, 4.0];
assert_eq!(
result.data().expect("data retrieval should succeed"),
expected_data
);
}
#[test]
fn test_operation_validation() {
let registry = CustomOperationRegistry::new();
let concat_op = Box::new(ConcatOperation);
registry
.register::<f32>(concat_op, "1.0.0", None, vec![])
.expect("registration should succeed");
let data1 = vec![1.0f32, 2.0];
let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu)
.expect("tensor creation should succeed");
let data2 = vec![3.0f32, 4.0, 5.0, 6.0];
let tensor2 = Tensor::from_data(data2, vec![2, 2], DeviceType::Cpu)
.expect("tensor creation should succeed");
let params = OperationParams::new().with_int("axis", 0);
let result =
tensor1.apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params);
assert!(result.is_err());
}
#[test]
fn test_output_shape_inference() {
let concat_op = ConcatOperation;
let input_shapes = vec![vec![3], vec![4]];
let params = OperationParams::new().with_int("axis", 0);
let output_shapes = <ConcatOperation as CustomOperation<f32>>::output_shapes(
&concat_op,
&input_shapes,
&params,
)
.expect("shape inference should succeed");
assert_eq!(output_shapes, vec![vec![7]]);
}
#[test]
fn test_error_cases() {
let registry = CustomOperationRegistry::new();
let scale_op1 = Box::new(ScaleOperation);
let scale_op2 = Box::new(ScaleOperation);
registry
.register::<f32>(scale_op1, "1.0.0", None, vec![])
.expect("registration should succeed");
let result = registry.register::<f32>(scale_op2, "1.0.0", None, vec![]);
assert!(result.is_err());
let result = registry.unregister::<f32>("nonexistent");
assert!(result.is_err());
let data = vec![1.0f32, 2.0];
let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu)
.expect("tensor creation should succeed");
let params = OperationParams::new();
let result = tensor.apply_custom_op_with_registry(&registry, "nonexistent", &[], &params);
assert!(result.is_err());
}
#[test]
fn test_global_registry() {
let registry = global_registry();
let scale_op = Box::new(ScaleOperation);
registry
.register::<f32>(scale_op, "1.0.0", None, vec![])
.expect("registration should succeed");
let data = vec![1.0f32, 2.0, 3.0];
let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
.expect("tensor creation should succeed");
let params = OperationParams::new().with_float("scale", 3.0);
let results = tensor
.apply_custom_op("scale", &[], ¶ms)
.expect("custom_op should succeed");
assert_eq!(results.len(), 1);
let expected_data = vec![3.0f32, 6.0, 9.0];
assert_eq!(
results[0].data().expect("data retrieval should succeed"),
expected_data
);
registry
.unregister::<f32>("scale")
.expect("unregister should succeed");
}
}