use crate::error::{CoreError, ErrorContext, ErrorLocation};
use once_cell::sync::Lazy;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex};
/// Shared, reference-counted handle to a type-erased [`FusedOp`].
type FusedOpArc = Arc<dyn FusedOp>;
/// Storage layout of the registry: a `TypeId` key mapped to the list of
/// operations registered under that key.
type FusionRegistryMap = HashMap<TypeId, Vec<FusedOpArc>>;
/// Global fusion registry. It is created lazily as an empty map and every
/// access goes through the surrounding `Mutex`.
static FUSION_REGISTRY: Lazy<Mutex<FusionRegistryMap>> = Lazy::new(|| Mutex::new(HashMap::new()));
/// A type-erased operation that can be chained inside an [`OpFusion`] and
/// possibly merged ("fused") with a neighbouring operation.
///
/// Implementors must be `Send + Sync`; they are stored behind `Arc` both in
/// [`OpFusion`] and in the global registry.
pub trait FusedOp: Send + Sync {
/// Returns the operation's name.
fn name(&self) -> &str;
/// Returns the `TypeId` this operation declares for its input.
///
/// `OpFusion::add_op` compares it with the chain's current output type and
/// rejects the operation on mismatch.
fn input_type(&self) -> TypeId;
/// Returns the `TypeId` this operation declares for its output.
///
/// `OpFusion::add_op` records it as the chain's new output type.
fn output_type(&self) -> TypeId;
/// Reports whether this operation can be fused with `other`.
///
/// `OpFusion::optimize` only calls [`FusedOp::fuse_with`] after this
/// returned `true` for the same pair.
fn can_fuse_with(&self, other: &dyn FusedOp) -> bool;
/// Builds the operation that replaces `self` and `other` in an optimized
/// chain.
///
/// NOTE(review): presumably the result must behave like `self` followed by
/// `other` and keep `other`'s output type — confirm with implementors; the
/// chain does not re-check types after fusing.
fn fuse_with(&self, other: &dyn FusedOp) -> Arc<dyn FusedOp>;
/// Runs the operation on a type-erased `input` and returns the type-erased
/// result, or a `CoreError` on failure.
///
/// `OpFusion::apply` feeds each operation's boxed result to the next one.
fn apply(&self, input: &dyn Any) -> Result<Box<dyn Any>, CoreError>;
/// Returns an `Arc` handle to this operation as a trait object.
///
/// Used by `OpFusion::optimize` when rebuilding its operation list. Whether
/// the handle refers to a fresh copy is up to the implementor.
fn clone_op(&self) -> Arc<dyn FusedOp>;
}
/// An ordered chain of [`FusedOp`]s together with the `TypeId`s the chain
/// accepts and produces.
///
/// Cloning copies the `Arc` handles, so a clone shares the underlying
/// operations with the original.
#[derive(Clone)]
pub struct OpFusion {
/// Operations in the order they are applied.
ops: Vec<Arc<dyn FusedOp>>,
/// Input `TypeId` taken from the first operation added (`()` while empty).
input_type: TypeId,
/// Output `TypeId` taken from the most recently added operation
/// (`()` while empty).
output_type: TypeId,
}
/// Debug output shows only how many operations the chain holds; the
/// operations themselves are trait objects without a `Debug` bound.
impl fmt::Debug for OpFusion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let num_ops = self.ops.len();
        let mut builder = f.debug_struct("OpFusion");
        builder.field("num_ops", &num_ops);
        builder.finish()
    }
}
impl OpFusion {
    /// Creates an empty chain.
    ///
    /// Both the input and the output type start as `TypeId::of::<()>()` and
    /// are overwritten once the first operation is added.
    pub fn new() -> Self {
        let unit = TypeId::of::<()>();
        Self {
            ops: Vec::new(),
            input_type: unit,
            output_type: unit,
        }
    }

    /// Appends `op` to the end of the chain.
    ///
    /// The first operation defines the chain's input type. Every later
    /// operation must declare an input type equal to the chain's current
    /// output type.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::ValidationError` when the chain is non-empty and
    /// `op.input_type()` differs from the current output type. The chain is
    /// left unchanged in that case.
    pub fn add_op(&mut self, op: Arc<dyn FusedOp>) -> Result<&mut Self, CoreError> {
        let chain_started = !self.ops.is_empty();
        if chain_started {
            if op.input_type() != self.output_type {
                let context =
                    ErrorContext::new("Operation input type does not match previous output type")
                        .with_location(ErrorLocation::new(file!(), line!()));
                return Err(CoreError::ValidationError(context));
            }
        } else {
            // First operation: it fixes what the whole chain accepts.
            self.input_type = op.input_type();
            self.output_type = op.output_type();
        }

        // The chain always produces whatever the newest operation produces.
        self.output_type = op.output_type();
        self.ops.push(op);
        Ok(self)
    }

    /// Merges neighbouring operations wherever the accumulated operation
    /// reports that it can fuse with the next one.
    ///
    /// Chains with fewer than two operations are returned untouched.
    /// Otherwise the operation list is rebuilt from `clone_op` / `fuse_with`
    /// results. The recorded input and output types are not modified.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    pub fn optimize(&mut self) -> Result<&mut Self, CoreError> {
        if self.ops.len() < 2 {
            return Ok(self);
        }

        let mut merged: Vec<Arc<dyn FusedOp>> = Vec::new();
        // `pending` accumulates a run of fusable operations.
        let mut pending = self.ops[0].clone_op();
        for candidate in &self.ops[1..] {
            let next_pending = if pending.can_fuse_with(candidate.as_ref()) {
                pending.fuse_with(candidate.as_ref())
            } else {
                // The run ends here: emit it and start a new one.
                merged.push(pending);
                candidate.clone_op()
            };
            pending = next_pending;
        }
        merged.push(pending);

        self.ops = merged;
        Ok(self)
    }

    /// Runs `input` through every operation in order and returns the final
    /// type-erased value.
    ///
    /// An empty chain returns the boxed `input` itself, which requires `A`
    /// to be `()`.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::ValidationError` when `A` is not the chain's
    /// recorded input type, and forwards the first error produced by any
    /// operation's `apply`.
    pub fn apply<A: 'static>(&self, input: A) -> Result<Box<dyn Any>, CoreError> {
        if self.input_type != TypeId::of::<A>() {
            let context = ErrorContext::new("Input type does not match expected type")
                .with_location(ErrorLocation::new(file!(), line!()));
            return Err(CoreError::ValidationError(context));
        }

        let seed: Box<dyn Any> = Box::new(input);
        // Thread the boxed value through the chain; stop at the first error.
        self.ops
            .iter()
            .try_fold(seed, |value: Box<dyn Any>, op| op.apply(value.as_ref()))
    }

    /// Number of operations currently in the chain.
    pub fn num_ops(&self) -> usize {
        let ops = &self.ops;
        ops.len()
    }

    /// `true` when no operation has been added.
    pub fn is_empty(&self) -> bool {
        let ops = &self.ops;
        ops.is_empty()
    }
}
impl Default for OpFusion {
fn default() -> Self {
Self::new()
}
}
/// Registers `op` in the global fusion registry under the `TypeId` of `T`.
///
/// Operations registered for the same `T` are kept in registration order.
///
/// # Errors
///
/// This implementation always returns `Ok(())`.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
#[allow(dead_code)]
pub fn register_fusion<T: 'static>(op: Arc<dyn FusedOp>) -> Result<(), CoreError> {
    let mut guard = FUSION_REGISTRY.lock().expect("Operation failed");
    guard.entry(TypeId::of::<T>()).or_default().push(op);
    Ok(())
}
/// Returns the operations registered under the `TypeId` of `T`.
///
/// The result is a fresh `Vec` of `Arc` handles (empty when nothing was
/// registered for `T`); the operations themselves are shared, not copied.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
#[allow(dead_code)]
pub fn get_fusions<T: 'static>() -> Vec<Arc<dyn FusedOp>> {
    let guard = FUSION_REGISTRY.lock().expect("Operation failed");
    let key = TypeId::of::<T>();
    guard.get(&key).cloned().unwrap_or_default()
}