use crate::{FxGraph, Node, TorshResult};
use petgraph::graph::NodeIndex;
use petgraph::Direction;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::{
dtype::{CustomDTypeRegistry, CustomTensorElement, DType, ExtendedDType},
error::TorshError,
shape::Shape,
};
use torsh_tensor::Tensor;
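/// Shape and dtype metadata for a graph value, covering both standard and
/// custom (user-registered) element types.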
#[derive(Debug, Clone)]
pub struct ExtendedShapeInfo {
pub shape: Shape,
pub dtype: ExtendedDType,
}
impl ExtendedShapeInfo {
pub fn new(shape: Shape, dtype: ExtendedDType) -> Self {
Self { shape, dtype }
}
pub fn from_standard(shape: Shape, dtype: DType) -> Self {
Self {
shape,
dtype: ExtendedDType::Standard(dtype),
}
}
pub fn from_tensor(tensor: &Tensor) -> Self {
Self {
shape: tensor.shape().clone(),
dtype: ExtendedDType::Standard(tensor.dtype()),
}
}
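    /// Total size in bytes: the product of all dimensions times the element
    /// size reported by the dtype.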
pub fn byte_size(&self) -> usize {
let element_count: usize = self.shape.dims().iter().product();
element_count * self.dtype.size()
}
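    /// Whether this dtype can be promoted to `target`: standard dtypes defer
    /// to `TypePromotion`, and custom dtypes only "promote" to themselves.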
pub fn can_promote_to(&self, target: &ExtendedDType) -> bool {
match (&self.dtype, target) {
(ExtendedDType::Standard(src), ExtendedDType::Standard(dst)) => {
use torsh_core::dtype::TypePromotion;
DType::can_promote_to(*src, *dst)
}
(ExtendedDType::Custom(src_id), ExtendedDType::Custom(dst_id)) => {
src_id == dst_id
}
            _ => false,
        }
}
}
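/// Caches inferred [`ExtendedShapeInfo`] per node and implements the
/// shape/dtype inference rules used below.
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module:
///
/// ```ignore
/// let mut ctx = ExtendedShapeInferenceContext::new();
/// ctx.set_shape(
///     NodeIndex::new(0),
///     ExtendedShapeInfo::from_standard(Shape::new(vec![2, 3]), DType::F32),
/// );
/// assert_eq!(ctx.get_shape(NodeIndex::new(0)).unwrap().shape.dims(), &[2, 3]);
/// ```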
pub struct ExtendedShapeInferenceContext {
shapes: HashMap<NodeIndex, ExtendedShapeInfo>,
}
impl ExtendedShapeInferenceContext {
pub fn new() -> Self {
Self {
shapes: HashMap::new(),
}
}
pub fn set_shape(&mut self, node: NodeIndex, shape_info: ExtendedShapeInfo) {
self.shapes.insert(node, shape_info);
}
pub fn get_shape(&self, node: NodeIndex) -> Option<&ExtendedShapeInfo> {
self.shapes.get(&node)
}
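    /// Infers the shape/dtype of `node_idx`. `Input` nodes must have had
    /// their shapes recorded via [`Self::set_shape`] beforehand; other node
    /// kinds propagate their first input's shape, falling back to a scalar
    /// F32 placeholder when no input shape is known yet.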
pub fn infer_shape_extended(
&mut self,
graph: &FxGraph,
node_idx: NodeIndex,
) -> TorshResult<ExtendedShapeInfo> {
let node = graph
.get_node(node_idx)
.ok_or_else(|| TorshError::InvalidArgument("Node not found".to_string()))?;
match node {
Node::Input(_) => {
self.get_shape(node_idx).cloned().ok_or_else(|| {
TorshError::InvalidArgument("Input node shape not set".to_string())
})
}
Node::Call(op_name, _) => self.infer_operation_shape_extended(graph, node_idx, op_name),
            // Control-flow, attribute, and output nodes all propagate their
            // first input's shape; a scalar F32 placeholder is used when no
            // input shape has been inferred yet.
            Node::Conditional { .. } | Node::Loop { .. } | Node::GetAttr { .. } | Node::Output => {
                let input_shapes = self.get_input_shapes_extended(graph, node_idx)?;
                Ok(input_shapes.first().cloned().unwrap_or_else(|| {
                    ExtendedShapeInfo::from_standard(Shape::new(vec![1]), DType::F32)
                }))
            }
            // Merge additionally requires all inputs to agree on shape and
            // type compatibility.
            Node::Merge { .. } => {
                let input_shapes = self.get_input_shapes_extended(graph, node_idx)?;
                self.validate_merge_shapes_extended(&input_shapes)?;
                Ok(input_shapes.first().cloned().unwrap_or_else(|| {
                    ExtendedShapeInfo::from_standard(Shape::new(vec![1]), DType::F32)
                }))
            }
}
}
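    /// Collects the already-inferred shapes of all incoming neighbors.
    /// Predecessors whose shapes have not been inferred yet are silently
    /// skipped, so the result may contain fewer entries than the node has
    /// inputs.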
fn get_input_shapes_extended(
&self,
graph: &FxGraph,
node_idx: NodeIndex,
) -> TorshResult<Vec<ExtendedShapeInfo>> {
let predecessors: Vec<_> = graph
.graph
.neighbors_directed(node_idx, Direction::Incoming)
.collect();
let mut input_shapes = Vec::new();
for pred_idx in predecessors {
if let Some(shape_info) = self.get_shape(pred_idx) {
input_shapes.push(shape_info.clone());
}
}
Ok(input_shapes)
}
fn infer_operation_shape_extended(
&mut self,
graph: &FxGraph,
node_idx: NodeIndex,
        op_name: &str,
) -> TorshResult<ExtendedShapeInfo> {
let input_shapes = self.get_input_shapes_extended(graph, node_idx)?;
        if crate::interpreter::is_operation_registered(op_name) {
            return self.infer_custom_operation_shape(op_name, &input_shapes);
        }
        match op_name {
"add" | "sub" | "mul" | "div" => {
if input_shapes.len() >= 2 {
self.broadcast_shapes_extended(&input_shapes[0], &input_shapes[1])
} else {
Err(TorshError::InvalidArgument(
"Binary operations require 2 inputs".to_string(),
))
}
}
"relu" | "sigmoid" | "tanh" | "gelu" => {
Ok(input_shapes.first().cloned().unwrap_or_else(|| {
ExtendedShapeInfo::from_standard(Shape::new(vec![1]), DType::F32)
}))
}
"matmul" => {
if input_shapes.len() >= 2 {
self.matmul_shape_extended(&input_shapes[0], &input_shapes[1])
} else {
Err(TorshError::InvalidArgument(
"Matmul requires 2 inputs".to_string(),
))
}
}
"linear" => {
if input_shapes.len() >= 2 {
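                    // Assumes a PyTorch-style weight layout of
                    // [out_features, in_features], so the output shape is
                    // [batch_size, out_features].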
let input_dims = input_shapes[0].shape.dims();
let weight_dims = input_shapes[1].shape.dims();
if input_dims.len() >= 2 && weight_dims.len() >= 2 {
let batch_size = input_dims[0];
let out_features = weight_dims[0];
Ok(ExtendedShapeInfo::new(
Shape::new(vec![batch_size, out_features]),
input_shapes[0].dtype.clone(),
))
} else {
Err(TorshError::InvalidArgument(
"Invalid shapes for linear operation".to_string(),
))
}
} else {
Err(TorshError::InvalidArgument(
"Linear requires at least 2 inputs".to_string(),
))
}
}
"conv2d" => {
if !input_shapes.is_empty() {
let input_dims = input_shapes[0].shape.dims();
if input_dims.len() >= 4 {
Ok(input_shapes[0].clone())
} else {
Err(TorshError::InvalidArgument(
"Conv2d requires 4D input".to_string(),
))
}
} else {
Err(TorshError::InvalidArgument(
"Conv2d requires input".to_string(),
))
}
}
"cast" => {
if input_shapes.len() >= 2 {
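                    // The second input acts purely as a dtype carrier: the
                    // result keeps the first input's shape but adopts the
                    // second input's dtype.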
Ok(ExtendedShapeInfo::new(
input_shapes[0].shape.clone(),
input_shapes[1].dtype.clone(),
))
} else {
Ok(input_shapes.first().cloned().unwrap_or_else(|| {
ExtendedShapeInfo::from_standard(Shape::new(vec![1]), DType::F32)
}))
}
}
"custom_promote" => {
if input_shapes.len() >= 2 {
self.promote_custom_types(&input_shapes[0], &input_shapes[1])
} else {
Ok(input_shapes.first().cloned().unwrap_or_else(|| {
ExtendedShapeInfo::from_standard(Shape::new(vec![1]), DType::F32)
}))
}
}
_ => {
Ok(input_shapes.first().cloned().unwrap_or_else(|| {
ExtendedShapeInfo::from_standard(Shape::new(vec![1]), DType::F32)
}))
}
}
}
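    /// Fallback shape rule for interpreter-registered custom operations:
    /// without a dedicated shape function, the operation is conservatively
    /// assumed to preserve its first input's shape (scalar F32 otherwise).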
fn infer_custom_operation_shape(
&self,
_op_name: &str,
input_shapes: &[ExtendedShapeInfo],
) -> TorshResult<ExtendedShapeInfo> {
if let Some(first_shape) = input_shapes.first() {
Ok(first_shape.clone())
} else {
Ok(ExtendedShapeInfo::from_standard(
Shape::new(vec![1]),
DType::F32,
))
}
}
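    /// NumPy-style broadcasting over trailing dimensions: each pair of
    /// aligned dimensions must be equal or 1. For example, `[1, 3]` and
    /// `[2, 1]` broadcast to `[2, 3]`. The result dtype is the promotion of
    /// the two input dtypes.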
fn broadcast_shapes_extended(
&self,
shape1: &ExtendedShapeInfo,
shape2: &ExtendedShapeInfo,
) -> TorshResult<ExtendedShapeInfo> {
let result_dtype = self.promote_types(&shape1.dtype, &shape2.dtype)?;
let dims1 = shape1.shape.dims();
let dims2 = shape2.shape.dims();
let max_len = dims1.len().max(dims2.len());
let mut result_dims = vec![1; max_len];
for i in 0..max_len {
let dim1 = if i < dims1.len() {
dims1[dims1.len() - 1 - i]
} else {
1
};
let dim2 = if i < dims2.len() {
dims2[dims2.len() - 1 - i]
} else {
1
};
if dim1 == dim2 || dim1 == 1 || dim2 == 1 {
result_dims[max_len - 1 - i] = dim1.max(dim2);
} else {
return Err(TorshError::InvalidArgument(format!(
"Cannot broadcast shapes {dims1:?} and {dims2:?}"
)));
}
}
Ok(ExtendedShapeInfo::new(
Shape::new(result_dims),
result_dtype,
))
}
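    /// Matmul shape rule: `[.., m, k] x [.., k, n] -> [.., m, n]`. Batch
    /// dimensions are combined by taking the elementwise maximum; note that,
    /// unlike [`Self::broadcast_shapes_extended`], mismatched batch
    /// dimensions are not rejected here.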
fn matmul_shape_extended(
&self,
shape1: &ExtendedShapeInfo,
shape2: &ExtendedShapeInfo,
) -> TorshResult<ExtendedShapeInfo> {
let dims1 = shape1.shape.dims();
let dims2 = shape2.shape.dims();
if dims1.len() < 2 || dims2.len() < 2 {
return Err(TorshError::InvalidArgument(
"Matmul requires at least 2D tensors".to_string(),
));
}
let inner1 = dims1[dims1.len() - 1];
let inner2 = dims2[dims2.len() - 2];
if inner1 != inner2 {
return Err(TorshError::InvalidArgument(format!(
"Incompatible dimensions for matmul: {inner1} vs {inner2}"
)));
}
let mut result_dims = Vec::new();
let batch_dims1 = &dims1[..dims1.len() - 2];
let batch_dims2 = &dims2[..dims2.len() - 2];
let max_batch_len = batch_dims1.len().max(batch_dims2.len());
for i in 0..max_batch_len {
let dim1 = if i < batch_dims1.len() {
batch_dims1[batch_dims1.len() - 1 - i]
} else {
1
};
let dim2 = if i < batch_dims2.len() {
batch_dims2[batch_dims2.len() - 1 - i]
} else {
1
};
result_dims.push(dim1.max(dim2));
}
result_dims.reverse();
        result_dims.push(dims1[dims1.len() - 2]);
        result_dims.push(dims2[dims2.len() - 1]);
let result_dtype = self.promote_types(&shape1.dtype, &shape2.dtype)?;
Ok(ExtendedShapeInfo::new(
Shape::new(result_dims),
result_dtype,
))
}
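    /// Promotes two dtypes: standard pairs defer to `DType::promote_types`,
    /// identical custom types promote to themselves, and every other
    /// combination is an error.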
fn promote_types(
&self,
dtype1: &ExtendedDType,
dtype2: &ExtendedDType,
) -> TorshResult<ExtendedDType> {
match (dtype1, dtype2) {
(ExtendedDType::Standard(dt1), ExtendedDType::Standard(dt2)) => {
use torsh_core::dtype::TypePromotion;
Ok(ExtendedDType::Standard(DType::promote_types(*dt1, *dt2)))
}
(ExtendedDType::Custom(id1), ExtendedDType::Custom(id2)) => {
if id1 == id2 {
Ok(dtype1.clone())
} else {
Err(TorshError::InvalidArgument(
"Cannot promote between different custom types".to_string(),
))
}
}
_ => {
Err(TorshError::InvalidArgument(
"Cannot promote between standard and custom types".to_string(),
))
}
}
}
fn promote_custom_types(
&self,
shape1: &ExtendedShapeInfo,
shape2: &ExtendedShapeInfo,
) -> TorshResult<ExtendedShapeInfo> {
        // `std::mem::discriminant` would treat two *different* custom types
        // as equal (same enum variant), so custom type IDs are matched
        // explicitly instead.
        match (&shape1.dtype, &shape2.dtype) {
            (ExtendedDType::Standard(_), ExtendedDType::Standard(_)) => Ok(shape1.clone()),
            (ExtendedDType::Custom(id1), ExtendedDType::Custom(id2)) if id1 == id2 => {
                Ok(shape1.clone())
            }
            _ => Err(TorshError::InvalidArgument(
                "Custom type promotion not supported for these types".to_string(),
            )),
        }
}
fn validate_merge_shapes_extended(
&self,
input_shapes: &[ExtendedShapeInfo],
) -> TorshResult<()> {
if input_shapes.is_empty() {
return Ok(());
}
let first_shape = &input_shapes[0];
for shape_info in &input_shapes[1..] {
if shape_info.shape.dims() != first_shape.shape.dims() {
return Err(TorshError::InvalidArgument(
"All inputs to merge operation must have the same shape".to_string(),
));
}
            // Variant discriminants alone would let two different custom
            // types pass as compatible, so custom IDs are checked explicitly.
            let compatible = match (&shape_info.dtype, &first_shape.dtype) {
                (ExtendedDType::Standard(_), ExtendedDType::Standard(_)) => true,
                (ExtendedDType::Custom(id1), ExtendedDType::Custom(id2)) => id1 == id2,
                _ => false,
            };
            if !compatible {
                return Err(TorshError::InvalidArgument(
                    "All inputs to merge operation must have compatible types".to_string(),
                ));
            }
}
Ok(())
}
}
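/// A user-defined operation that participates in extended shape and type
/// inference alongside tensor execution.
///
/// A minimal implementor sketch; `Passthrough` is a hypothetical operation
/// shown purely for illustration (marked `ignore` because it depends on
/// crate-local types):
///
/// ```ignore
/// struct Passthrough;
///
/// impl ExtendedCustomOperation for Passthrough {
///     fn execute_extended(&self, mut inputs: Vec<Tensor>) -> TorshResult<Tensor> {
///         inputs.pop().ok_or_else(|| {
///             TorshError::InvalidArgument("passthrough needs one input".to_string())
///         })
///     }
///     fn name(&self) -> &str {
///         "passthrough"
///     }
///     fn infer_shape_extended(
///         &self,
///         input_shapes: &[ExtendedShapeInfo],
///     ) -> TorshResult<ExtendedShapeInfo> {
///         input_shapes.first().cloned().ok_or_else(|| {
///             TorshError::InvalidArgument("passthrough needs one input".to_string())
///         })
///     }
///     fn validate_types_extended(&self, _input_types: &[ExtendedDType]) -> TorshResult<()> {
///         Ok(())
///     }
///     fn infer_type_extended(&self, input_types: &[ExtendedDType]) -> TorshResult<ExtendedDType> {
///         input_types.first().cloned().ok_or_else(|| {
///             TorshError::InvalidArgument("passthrough needs one input".to_string())
///         })
///     }
/// }
/// ```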
pub trait ExtendedCustomOperation: Send + Sync {
fn execute_extended(&self, inputs: Vec<Tensor>) -> TorshResult<Tensor>;
fn name(&self) -> &str;
fn infer_shape_extended(
&self,
input_shapes: &[ExtendedShapeInfo],
) -> TorshResult<ExtendedShapeInfo>;
fn validate_types_extended(&self, input_types: &[ExtendedDType]) -> TorshResult<()>;
fn infer_type_extended(&self, input_types: &[ExtendedDType]) -> TorshResult<ExtendedDType>;
fn metadata(&self) -> Option<HashMap<String, String>> {
None
}
fn supports_custom_type(&self, _type_id: TypeId) -> bool {
        false
    }
fn supported_custom_types(&self) -> Vec<TypeId> {
Vec::new()
}
}
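/// Thread-safe, name-keyed registry of [`ExtendedCustomOperation`]s.
/// Registering a duplicate name is an error; read paths that fail to
/// acquire the lock degrade to "empty"/"not found" instead of panicking.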
#[derive(Default)]
pub struct ExtendedOperationRegistry {
operations: Arc<RwLock<HashMap<String, Box<dyn ExtendedCustomOperation>>>>,
}
impl ExtendedOperationRegistry {
pub fn new() -> Self {
Self {
operations: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn register<T: ExtendedCustomOperation + 'static>(&self, operation: T) -> TorshResult<()> {
let name = operation.name().to_string();
let mut ops = self.operations.write().map_err(|_| {
TorshError::InvalidArgument(
"Failed to acquire write lock on extended operations".to_string(),
)
})?;
if ops.contains_key(&name) {
return Err(TorshError::InvalidArgument(format!(
"Extended operation '{name}' already registered"
)));
}
ops.insert(name, Box::new(operation));
Ok(())
}
pub fn is_registered(&self, name: &str) -> bool {
if let Ok(ops) = self.operations.read() {
ops.contains_key(name)
} else {
false
}
}
pub fn list_operations(&self) -> Vec<String> {
if let Ok(ops) = self.operations.read() {
ops.keys().cloned().collect()
} else {
Vec::new()
}
}
pub fn execute(&self, name: &str, inputs: Vec<Tensor>) -> TorshResult<Tensor> {
let ops = self.operations.read().map_err(|_| {
TorshError::InvalidArgument(
"Failed to acquire read lock on extended operations".to_string(),
)
})?;
if let Some(operation) = ops.get(name) {
operation.execute_extended(inputs)
} else {
Err(TorshError::InvalidArgument(format!(
"Extended operation '{name}' not found in registry"
)))
}
}
pub fn get_operations_for_type(&self, type_id: TypeId) -> Vec<String> {
if let Ok(ops) = self.operations.read() {
ops.iter()
.filter(|(_, op)| op.supports_custom_type(type_id))
.map(|(name, _)| name.clone())
.collect()
} else {
Vec::new()
}
}
}
static GLOBAL_EXTENDED_REGISTRY: std::sync::OnceLock<ExtendedOperationRegistry> =
std::sync::OnceLock::new();
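/// Returns the process-wide extended operation registry, initializing it on
/// first use.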
pub fn global_extended_registry() -> &'static ExtendedOperationRegistry {
    GLOBAL_EXTENDED_REGISTRY.get_or_init(ExtendedOperationRegistry::new)
}
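/// Registers `operation` with the global registry. Using the hypothetical
/// `Passthrough` op sketched on [`ExtendedCustomOperation`]:
///
/// ```ignore
/// register_extended_operation(Passthrough)?;
/// assert!(global_extended_registry().is_registered("passthrough"));
/// ```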
pub fn register_extended_operation<T: ExtendedCustomOperation + 'static>(
operation: T,
) -> TorshResult<()> {
global_extended_registry().register(operation)
}
pub struct CustomTypeUtils;
impl CustomTypeUtils {
pub fn create_custom_tensor<T: CustomTensorElement>(
_data: Vec<T>,
_shape: Shape,
) -> TorshResult<Tensor> {
Err(TorshError::InvalidArgument(
"Custom tensor creation requires tensor crate support".to_string(),
))
}
pub fn register_custom_type<T: CustomTensorElement>() -> TorshResult<()> {
CustomDTypeRegistry::register::<T>()
}
pub fn is_custom_type_supported(type_id: TypeId) -> bool {
CustomDTypeRegistry::is_registered(type_id)
}
pub fn get_compatible_operations(type_id: TypeId) -> Vec<String> {
global_extended_registry().get_operations_for_type(type_id)
}
    /// Walks `graph` and inspects every `Call` node against both the
    /// interpreter and extended registries. Currently always returns
    /// `Ok(())`: unknown operations are noted but not rejected.
    pub fn validate_graph_custom_types(graph: &FxGraph) -> TorshResult<()> {
        for (_node_idx, node) in graph.nodes() {
            if let Node::Call(op_name, _) = node {
                if !crate::interpreter::is_operation_registered(op_name)
                    && !global_extended_registry().is_registered(op_name)
                {
                    // Unregistered operations are tolerated rather than
                    // rejected here: they may be provided by a backend or
                    // registered later, before execution.
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::any::TypeId;
use torsh_core::dtype::CustomInt16;
#[test]
fn test_extended_shape_info() {
let shape = Shape::new(vec![2, 3, 4]);
let dtype = ExtendedDType::Standard(DType::F32);
let shape_info = ExtendedShapeInfo::new(shape.clone(), dtype);
assert_eq!(shape_info.shape.dims(), &[2, 3, 4]);
        assert_eq!(shape_info.byte_size(), 2 * 3 * 4 * 4);
        assert!(!shape_info.dtype.is_custom());
}
#[test]
fn test_custom_type_registration() {
let type_id = TypeId::of::<CustomInt16>();
if !CustomDTypeRegistry::is_registered(type_id) {
let result = CustomTypeUtils::register_custom_type::<CustomInt16>();
if result.is_err() {
eprintln!(
"WARNING: CustomInt16 registration failed (concurrent test): {:?}",
result.err()
);
return;
}
}
assert!(CustomTypeUtils::is_custom_type_supported(type_id));
}
#[test]
fn test_extended_shape_inference_context() {
let mut context = ExtendedShapeInferenceContext::new();
let shape_info = ExtendedShapeInfo::from_standard(Shape::new(vec![2, 3]), DType::F32);
let node_idx = NodeIndex::new(0);
context.set_shape(node_idx, shape_info.clone());
let retrieved = context.get_shape(node_idx).unwrap();
assert_eq!(retrieved.shape.dims(), shape_info.shape.dims());
}
#[test]
fn test_type_promotion() {
let ctx = ExtendedShapeInferenceContext::new();
let dtype1 = ExtendedDType::Standard(DType::I32);
let dtype2 = ExtendedDType::Standard(DType::F32);
let promoted = ctx.promote_types(&dtype1, &dtype2).unwrap();
match promoted {
            ExtendedDType::Standard(DType::F32) => {}
            _ => panic!("Expected F32 promotion"),
}
}
#[test]
fn test_broadcasting_extended() {
let ctx = ExtendedShapeInferenceContext::new();
let shape1 = ExtendedShapeInfo::from_standard(Shape::new(vec![1, 3]), DType::F32);
let shape2 = ExtendedShapeInfo::from_standard(Shape::new(vec![2, 1]), DType::F32);
let result = ctx.broadcast_shapes_extended(&shape1, &shape2).unwrap();
assert_eq!(result.shape.dims(), &[2, 3]);
}
#[test]
fn test_matmul_shape_extended() {
let ctx = ExtendedShapeInferenceContext::new();
let shape1 = ExtendedShapeInfo::from_standard(Shape::new(vec![2, 3]), DType::F32);
let shape2 = ExtendedShapeInfo::from_standard(Shape::new(vec![3, 4]), DType::F32);
let result = ctx.matmul_shape_extended(&shape1, &shape2).unwrap();
assert_eq!(result.shape.dims(), &[2, 4]);
}
#[test]
fn test_custom_type_utils() {
let type_id = TypeId::of::<CustomInt16>();
if !CustomDTypeRegistry::is_registered(type_id) {
let result = CustomTypeUtils::register_custom_type::<CustomInt16>();
if result.is_err() {
eprintln!(
"WARNING: CustomInt16 registration failed (concurrent test): {:?}",
result.err()
);
return;
}
}
assert!(CustomTypeUtils::is_custom_type_supported(type_id));
let _ops = CustomTypeUtils::get_compatible_operations(type_id);
}
#[test]
fn test_extended_operation_registry() {
let registry = ExtendedOperationRegistry::new();
assert!(registry.list_operations().is_empty());
assert!(!registry.is_registered("test_op"));
}
}