use crate::graph::{Graph, TensorID};
use crate::op::OpError;
use crate::Float;
use scirs2_core::ndarray::{Array, IxDyn, Zip};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
/// Type-erased element-wise kernel: takes a slice of input arrays and
/// produces one output array or an `OpError`. `Send + Sync` so compiled
/// kernels can be shared across threads (they are run via `par_for_each`).
type KernelFunction<F> =
    Box<dyn Fn(&[&Array<F, IxDyn>]) -> Result<Array<F, IxDyn>, OpError> + Send + Sync>;
/// Result of compiling a fusion chain into a callable kernel.
type KernelResult<F> = Result<KernelFunction<F>, OpError>;
/// An element-wise operation eligible for loop fusion.
#[derive(Debug, Clone, PartialEq)]
pub enum FusableOperation<F: Float> {
    /// Element-wise addition (binary).
    Add,
    /// Element-wise subtraction (binary).
    Sub,
    /// Element-wise multiplication (binary).
    Mul,
    /// Element-wise division (binary).
    Div,
    /// A unary function applied to every element.
    UnaryFunc(UnaryFunction<F>),
    /// A binary function with one constant scalar operand.
    ScalarOp(F, BinaryFunction),
}
/// Unary element-wise functions supported inside fused kernels
/// (see `FusedKernel::apply_unary_function` for the exact semantics).
#[derive(Debug, Clone, PartialEq)]
pub enum UnaryFunction<F: Float> {
    /// `max(x, 0)`.
    ReLU,
    /// `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// `x * x`.
    Square,
    /// Square root.
    Sqrt,
    /// `e^x`.
    Exp,
    /// Natural logarithm.
    Log,
    /// Absolute value.
    Abs,
    /// Arbitrary user-supplied function pointer; `PartialEq` compares the
    /// pointer itself, not the function's behavior.
    Custom(fn(F) -> F),
}
/// Binary functions pairing an element with a constant scalar
/// (see `FusedKernel::apply_scalar_operation`).
#[derive(Debug, Clone, PartialEq)]
pub enum BinaryFunction {
    /// `value + scalar`.
    AddScalar,
    /// `value * scalar`.
    MulScalar,
    /// `value.powf(scalar)`.
    Pow,
}
/// An ordered sequence of fusable element-wise operations, together with the
/// shapes flowing through it and a heuristic benefit score.
#[derive(Debug, Clone)]
pub struct FusionChain<F: Float> {
    /// Operations in execution order.
    operations: Vec<FusableOperation<F>>,
    /// Input shape recorded for each operation, parallel to `operations`.
    inputshapes: Vec<Vec<usize>>,
    /// Shape produced by the most recently added operation.
    outputshape: Vec<usize>,
    /// Accumulated heuristic benefit (see `estimate_benefit`).
    performance_benefit: f64,
}
impl<F: Float> Default for FusionChain<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> FusionChain<F> {
    /// Creates an empty chain with no operations and zero benefit.
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            inputshapes: Vec::new(),
            outputshape: Vec::new(),
            performance_benefit: 0.0,
        }
    }

    /// Appends `op` to the chain, recording `inputshape` and accumulating the
    /// estimated performance benefit. After the call, the chain's output
    /// shape is `inputshape`.
    pub fn add_operation(&mut self, op: FusableOperation<F>, inputshape: Vec<usize>) {
        if self.outputshape.is_empty() {
            // First operation: seed the output shape so the benefit estimate
            // below has a non-empty element count to work with.
            self.outputshape = inputshape.clone();
        }
        // Estimated against the *current* output shape, i.e. the shape
        // produced by the previous operation in the chain.
        self.performance_benefit += self.estimate_benefit(&op);
        self.operations.push(op);
        // `clone_from` reuses `outputshape`'s existing allocation where
        // possible, avoiding the extra clone the previous version performed.
        self.outputshape.clone_from(&inputshape);
        self.inputshapes.push(inputshape);
    }

    /// Rough benefit model: larger tensors gain more from fusing because each
    /// fused op avoids one full read/write pass over memory. The per-element
    /// weights are heuristic, not measured.
    fn estimate_benefit(&self, op: &FusableOperation<F>) -> f64 {
        // Note: the product over an empty shape is 1 (empty iterator),
        // so a freshly created chain still yields a tiny non-zero estimate.
        let elements = self.outputshape.iter().product::<usize>() as f64;
        match op {
            FusableOperation::Add
            | FusableOperation::Sub
            | FusableOperation::Mul
            | FusableOperation::Div => elements * 0.5,
            FusableOperation::UnaryFunc(_) => elements * 0.3,
            FusableOperation::ScalarOp(_, _) => elements * 0.7,
        }
    }

    /// A chain is worth fusing once it has at least two operations and the
    /// accumulated benefit crosses a fixed heuristic threshold.
    pub fn is_worthwhile(&self) -> bool {
        self.operations.len() >= 2 && self.performance_benefit > 1000.0
    }

    /// Number of operations currently in the chain.
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` if no operations have been added yet.
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}
/// Identifies and stores fusable operation chains for a computation graph.
pub struct LoopFusionOptimizer<F: Float> {
    /// Chains that passed the `is_worthwhile` filter during analysis.
    fusion_chains: Vec<FusionChain<F>>,
    /// Maps a tensor id to the index of the chain that absorbed it.
    /// NOTE(review): currently cleared but never populated — confirm intent.
    fusion_mapping: HashMap<TensorID, usize>,
    /// Running statistics from the most recent `analyze_graph` call.
    stats: FusionStats<F>,
}
impl<F: Float> Default for LoopFusionOptimizer<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> LoopFusionOptimizer<F> {
pub fn new() -> Self {
Self {
fusion_chains: Vec::new(),
fusion_mapping: HashMap::new(),
stats: FusionStats::default(),
}
}
pub fn analyze_graph(&mut self, graph: &Graph<F>) -> Result<(), OpError> {
self.fusion_chains.clear();
self.fusion_mapping.clear();
let element_wise_ops = LoopFusionOptimizer::<F>::find_element_wise_operations(graph);
let chains = self.identify_fusion_chains(&element_wise_ops, graph);
for chain in chains {
if chain.is_worthwhile() {
self.fusion_chains.push(chain);
}
}
self.stats.chains_identified = self.fusion_chains.len();
self.stats.total_operations_fused =
self.fusion_chains.iter().map(|chain| chain.len()).sum();
Ok(())
}
fn find_element_wise_operations(selfgraph: &Graph<F>) -> Vec<TensorID> {
Vec::new()
}
fn identify_fusion_chains(
&self,
operations: &[TensorID],
graph: &Graph<F>,
) -> Vec<FusionChain<F>> {
let mut chains = Vec::new();
let mut visited = HashSet::new();
for &op_idx in operations {
if visited.contains(&op_idx) {
continue;
}
let chain = self.build_fusion_chain(op_idx, graph, &mut visited);
if !chain.is_empty() {
chains.push(chain);
}
}
chains
}
fn build_fusion_chain(
&self,
start_op: TensorID,
graph: &Graph<F>,
visited: &mut HashSet<TensorID>,
) -> FusionChain<F> {
let mut chain = FusionChain::new();
let mut current_op = start_op;
loop {
if visited.contains(¤t_op) {
break;
}
if let Some(fusableop) = LoopFusionOptimizer::<F>::classify_operation(current_op, graph)
{
visited.insert(current_op);
chain.add_operation(fusableop, vec![100]);
if let Some(nextop) =
LoopFusionOptimizer::<F>::find_next_fusable_operation(current_op, graph)
{
current_op = nextop;
} else {
break;
}
} else {
break;
}
}
chain
}
fn classify_operation(op_idx: TensorID, graph: &Graph<F>) -> Option<FusableOperation<F>> {
Some(FusableOperation::Add)
}
fn find_next_fusable_operation(current_op: TensorID, graph: &Graph<F>) -> Option<TensorID> {
None
}
pub fn apply_fusion(&self) -> Result<Vec<FusedKernel<F>>, OpError> {
let mut fused_kernels = Vec::new();
for chain in &self.fusion_chains {
let kernel = self.create_fused_kernel(chain)?;
fused_kernels.push(kernel);
}
Ok(fused_kernels)
}
fn create_fused_kernel(&self, chain: &FusionChain<F>) -> Result<FusedKernel<F>, OpError> {
FusedKernel::from_chain(chain.clone())
}
pub fn get_stats(&self) -> &FusionStats<F> {
&self.stats
}
}
/// A fusion chain compiled into a single callable element-wise kernel.
pub struct FusedKernel<F: Float> {
    /// The chain this kernel was compiled from.
    chain: FusionChain<F>,
    /// The compiled element-wise closure (see `compile_kernel`).
    kernel_func: KernelFunction<F>,
}
impl<F: Float> FusedKernel<F> {
    /// Compiles `chain` into an executable kernel.
    pub fn from_chain(chain: FusionChain<F>) -> Result<Self, OpError> {
        let kernel_func = Self::compile_kernel(&chain)?;
        Ok(Self { chain, kernel_func })
    }

    /// Builds a closure that applies every operation in the chain to each
    /// element of the first input array in one (parallel) pass.
    ///
    /// NOTE(review): the binary variants (`Add`/`Sub`/`Mul`/`Div`) currently
    /// act as identity because only a single input stream is threaded through
    /// the loop; only `UnaryFunc` and `ScalarOp` transform the value.
    fn compile_kernel(chain: &FusionChain<F>) -> KernelResult<F> {
        let operations = chain.operations.clone();
        let _outputshape = chain.outputshape.clone();
        Ok(Box::new(
            move |inputs: &[&Array<F, IxDyn>]| -> Result<Array<F, IxDyn>, OpError> {
                if inputs.is_empty() {
                    return Err(OpError::RuntimeError(
                        "No input arrays provided".to_string(),
                    ));
                }
                let input = inputs[0];
                let mut result = Array::zeros(input.raw_dim());
                // One fused pass: each element flows through the whole chain,
                // avoiding one intermediate array per fused operation.
                Zip::from(&mut result)
                    .and(input)
                    .par_for_each(|output, &input_val| {
                        let mut value = input_val;
                        for op in &operations {
                            value = match op {
                                // Identity placeholders — see note above.
                                FusableOperation::Add => value,
                                FusableOperation::Mul => value,
                                FusableOperation::UnaryFunc(func) => {
                                    Self::apply_unary_function(value, func)
                                }
                                FusableOperation::ScalarOp(scalar, func) => {
                                    Self::apply_scalar_operation(value, *scalar, func)
                                }
                                // Sub / Div: also identity until multi-input
                                // fusion is implemented.
                                _ => value,
                            };
                        }
                        *output = value;
                    });
                Ok(result)
            },
        ))
    }

    /// Applies one unary element-wise function to `value`.
    pub fn apply_unary_function(value: F, func: &UnaryFunction<F>) -> F {
        match func {
            UnaryFunction::ReLU => {
                if value > F::zero() {
                    value
                } else {
                    F::zero()
                }
            }
            UnaryFunction::Sigmoid => {
                let one = F::one();
                one / (one + (-value).exp())
            }
            UnaryFunction::Tanh => value.tanh(),
            UnaryFunction::Square => value * value,
            UnaryFunction::Sqrt => value.sqrt(),
            UnaryFunction::Exp => value.exp(),
            UnaryFunction::Log => value.ln(),
            UnaryFunction::Abs => value.abs(),
            UnaryFunction::Custom(f) => f(value),
        }
    }

    /// Applies one scalar binary function: `value <op> scalar`.
    pub fn apply_scalar_operation(value: F, scalar: F, func: &BinaryFunction) -> F {
        match func {
            BinaryFunction::AddScalar => value + scalar,
            BinaryFunction::MulScalar => value * scalar,
            BinaryFunction::Pow => value.powf(scalar),
        }
    }

    /// Runs the compiled kernel on `inputs`.
    pub fn execute(&self, inputs: &[&Array<F, IxDyn>]) -> Result<Array<F, IxDyn>, OpError> {
        (self.kernel_func)(inputs)
    }

    /// The chain this kernel was compiled from.
    pub fn get_chain(&self) -> &FusionChain<F> {
        &self.chain
    }

    /// Heuristic speedup estimate: fusing `n` ops eliminates `n - 1` of `n`
    /// intermediate memory passes, scaled by an assumed 2x bandwidth weight.
    pub fn estimate_speedup(&self) -> f64 {
        let num_ops = self.chain.len() as f64;
        // Guard the empty chain: without it the ratio below is (0-1)/0 = -inf.
        if num_ops < 1.0 {
            return 1.0;
        }
        let memory_reduction = (num_ops - 1.0) / num_ops;
        1.0 + memory_reduction * 2.0
    }
}
#[derive(Debug, Clone)]
pub struct FusionStats<F: crate::Float> {
pub chains_identified: usize,
pub total_operations_fused: usize,
pub memory_bandwidth_reduction: f64,
pub estimated_speedup: f64,
_phantom: std::marker::PhantomData<F>,
}
impl<F: crate::Float> FusionStats<F> {
pub fn calculate_memory_reduction(&mut self, original_ops: usize) {
if original_ops > 0 {
self.memory_bandwidth_reduction =
(original_ops - self.chains_identified) as f64 / original_ops as f64 * 100.0;
}
}
pub fn calculate_speedup(&mut self, kernels: &[FusedKernel<F>]) {
if !kernels.is_empty() {
self.estimated_speedup = kernels
.iter()
.map(|kernel| kernel.estimate_speedup())
.sum::<f64>()
/ kernels.len() as f64;
}
}
}
impl<F: crate::Float> Default for FusionStats<F> {
fn default() -> Self {
Self {
chains_identified: 0,
total_operations_fused: 0,
memory_bandwidth_reduction: 0.0,
estimated_speedup: 0.0,
_phantom: std::marker::PhantomData,
}
}
}
/// High-level entry point tying together analysis, kernel compilation, and
/// execution, governed by a `FusionConfig`.
pub struct LoopFusionManager<F: Float> {
    /// Performs graph analysis and chain identification.
    optimizer: LoopFusionOptimizer<F>,
    /// Kernels compiled by the most recent `optimize_graph` call.
    kernels: Vec<FusedKernel<F>>,
    /// Behavior switches for the fusion pass.
    config: FusionConfig,
}
impl<F: Float> Default for LoopFusionManager<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> LoopFusionManager<F> {
    /// Creates a manager with the default [`FusionConfig`].
    pub fn new() -> Self {
        Self {
            optimizer: LoopFusionOptimizer::new(),
            kernels: Vec::new(),
            config: FusionConfig::default(),
        }
    }

    /// Creates a manager with an explicit configuration.
    pub fn with_config(config: FusionConfig) -> Self {
        Self {
            optimizer: LoopFusionOptimizer::new(),
            kernels: Vec::new(),
            config,
        }
    }

    /// Analyzes `graph` and compiles fused kernels; a no-op when fusion is
    /// disabled in the configuration.
    pub fn optimize_graph(&mut self, graph: &Graph<F>) -> Result<(), OpError> {
        if !self.config.enable_fusion {
            return Ok(());
        }
        self.optimizer.analyze_graph(graph)?;
        self.kernels = self.optimizer.apply_fusion()?;
        Ok(())
    }

    /// Runs the fused kernel at `kernel_id` on `inputs`.
    ///
    /// # Errors
    /// Returns a `RuntimeError` when `kernel_id` is out of range, plus any
    /// error produced by the kernel itself.
    pub fn execute_fused_operation(
        &self,
        kernel_id: usize,
        inputs: &[&Array<F, IxDyn>],
    ) -> Result<Array<F, IxDyn>, OpError> {
        // Checked lookup replaces the manual bounds test + panic-capable index.
        let kernel = self
            .kernels
            .get(kernel_id)
            .ok_or_else(|| OpError::RuntimeError("Invalid kernel ID".to_string()))?;
        kernel.execute(inputs)
    }

    /// Statistics from the most recent analysis.
    pub fn get_stats(&self) -> &FusionStats<F> {
        self.optimizer.get_stats()
    }

    /// Number of compiled fused kernels.
    pub fn num_kernels(&self) -> usize {
        self.kernels.len()
    }

    /// Whether fusion is enabled in the current configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enable_fusion
    }
}
/// Tunable knobs controlling when and how loop fusion is applied.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Master switch for the whole fusion pass.
    pub enable_fusion: bool,
    /// Minimum number of operations before a chain is considered.
    pub min_chain_length: usize,
    /// Upper bound on operations fused into a single kernel.
    pub max_chain_length: usize,
    /// Smallest tensor (element count) worth fusing.
    pub min_tensor_size: usize,
    /// Allow parallel execution inside fused kernels.
    pub enable_parallel_fusion: bool,
}

impl Default for FusionConfig {
    /// Conservative defaults: fusion on, chains of 2..=10 ops, tensors of at
    /// least 1000 elements, parallel kernels enabled.
    fn default() -> Self {
        FusionConfig {
            enable_fusion: true,
            enable_parallel_fusion: true,
            min_chain_length: 2,
            max_chain_length: 10,
            min_tensor_size: 1000,
        }
    }
}
/// Process-wide fusion manager for `f32` graphs, created on first use.
static FUSION_MANAGER: std::sync::OnceLock<Arc<Mutex<LoopFusionManager<f32>>>> =
    std::sync::OnceLock::new();

/// Returns a handle to the global fusion manager, initializing it lazily.
#[allow(dead_code)]
pub fn init_fusion_manager() -> Arc<Mutex<LoopFusionManager<f32>>> {
    let manager = FUSION_MANAGER.get_or_init(|| Arc::new(Mutex::new(LoopFusionManager::new())));
    // Explicit `Arc::clone` makes the cheap refcount bump obvious.
    Arc::clone(manager)
}
/// Replaces the global fusion manager with one built from `config`.
///
/// # Errors
/// Returns a `RuntimeError` if the global lock is poisoned.
#[allow(dead_code)]
pub fn configure_fusion(config: FusionConfig) -> Result<(), OpError> {
    let handle = init_fusion_manager();
    let mut guard = handle
        .lock()
        .map_err(|_| OpError::RuntimeError("Lock error".to_string()))?;
    *guard = LoopFusionManager::with_config(config);
    Ok(())
}
#[allow(dead_code)]
pub fn set_fusion_enabled(enabled: bool) -> Result<(), OpError> {
let config = FusionConfig {
enable_fusion: enabled,
..Default::default()
};
configure_fusion(config)
}
/// Reports whether fusion is currently enabled on the global manager.
/// A poisoned lock is treated as "disabled" rather than propagated.
#[allow(dead_code)]
pub fn is_fusion_enabled() -> bool {
    let manager = init_fusion_manager();
    // `map`/`unwrap_or` replaces the needless `let result = ...; result`
    // pattern of the original (clippy::let_and_return).
    manager
        .lock()
        .map(|guard| guard.is_enabled())
        .unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[allow(unused_imports)]
    use scirs2_core::ndarray::Array1;

    /// Two ops over 10_000-element shapes clear the worthwhile threshold
    /// (benefit 0.5 + 0.3 per element, well above 1000.0).
    #[test]
    fn test_fusion_chain_creation() {
        let mut chain = FusionChain::<f32>::new();
        assert!(chain.is_empty());
        chain.add_operation(FusableOperation::Add, vec![10000]);
        chain.add_operation(
            FusableOperation::UnaryFunc(UnaryFunction::ReLU),
            vec![10000],
        );
        assert_eq!(chain.len(), 2);
        assert!(chain.is_worthwhile());
    }

    /// Spot-checks each unary function at simple exact-float values.
    #[test]
    fn test_unary_functions() {
        let value = 2.0f32;
        assert_eq!(
            FusedKernel::<f32>::apply_unary_function(value, &UnaryFunction::Square),
            4.0
        );
        assert_eq!(
            FusedKernel::<f32>::apply_unary_function(-1.0, &UnaryFunction::ReLU),
            0.0
        );
        assert_eq!(
            FusedKernel::<f32>::apply_unary_function(1.0, &UnaryFunction::ReLU),
            1.0
        );
        assert_eq!(
            FusedKernel::<f32>::apply_unary_function(4.0, &UnaryFunction::Sqrt),
            2.0
        );
        assert_eq!(
            FusedKernel::<f32>::apply_unary_function(-2.0, &UnaryFunction::Abs),
            2.0
        );
    }

    /// Checks the three scalar binary functions: 3+2, 3*2, 3^2.
    #[test]
    fn test_scalar_operations() {
        let value = 3.0f32;
        let scalar = 2.0f32;
        assert_eq!(
            FusedKernel::<f32>::apply_scalar_operation(value, scalar, &BinaryFunction::AddScalar),
            5.0
        );
        assert_eq!(
            FusedKernel::<f32>::apply_scalar_operation(value, scalar, &BinaryFunction::MulScalar),
            6.0
        );
        assert_eq!(
            FusedKernel::<f32>::apply_scalar_operation(value, scalar, &BinaryFunction::Pow),
            9.0
        );
    }

    /// End-to-end: x -> x^2 -> *2 applied elementwise by the fused kernel.
    #[test]
    fn test_fused_kernel_creation() {
        let mut chain = FusionChain::new();
        chain.add_operation(FusableOperation::UnaryFunc(UnaryFunction::Square), vec![5]);
        chain.add_operation(
            FusableOperation::ScalarOp(2.0, BinaryFunction::MulScalar),
            vec![5],
        );
        let kernel = FusedKernel::from_chain(chain).expect("Operation failed");
        let input = Array::from_shape_vec(IxDyn(&[5]), vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .expect("Operation failed");
        let result = kernel.execute(&[&input]).expect("Operation failed");
        let expected = vec![2.0, 8.0, 18.0, 32.0, 50.0];
        assert_eq!(result.as_slice().expect("Operation failed"), &expected);
    }

    /// A disabled config propagates through `with_config` to `is_enabled`.
    #[test]
    fn test_fusion_config() {
        let config = FusionConfig {
            enable_fusion: false,
            min_chain_length: 3,
            max_chain_length: 15,
            min_tensor_size: 5000,
            enable_parallel_fusion: false,
        };
        let manager: LoopFusionManager<f32> = LoopFusionManager::with_config(config.clone());
        assert!(!manager.is_enabled());
    }

    /// 25 original ops collapsed into 5 chains => 80% bandwidth reduction.
    #[test]
    fn test_fusion_stats() {
        let mut stats: FusionStats<f32> = FusionStats {
            chains_identified: 5,
            total_operations_fused: 20,
            ..Default::default()
        };
        stats.calculate_memory_reduction(25);
        assert_eq!(stats.memory_bandwidth_reduction, 80.0);
    }

    /// Round-trips the enable flag through the global manager.
    #[test]
    fn test_global_fusion_manager() {
        set_fusion_enabled(true).expect("Operation failed");
        assert!(is_fusion_enabled());
        set_fusion_enabled(false).expect("Operation failed");
        assert!(!is_fusion_enabled());
    }

    /// Four-op chain: x -> x^2 -> relu -> *3 -> +1, e.g. (-2)^2*3+1 = 13.
    /// Also checks the speedup estimate stays in its documented (1, 4] range
    /// (limit of 1 + 2*(n-1)/n as n grows).
    #[test]
    fn test_complex_fused_chain() {
        let mut chain = FusionChain::new();
        chain.add_operation(FusableOperation::UnaryFunc(UnaryFunction::Square), vec![4]);
        chain.add_operation(FusableOperation::UnaryFunc(UnaryFunction::ReLU), vec![4]);
        chain.add_operation(
            FusableOperation::ScalarOp(3.0, BinaryFunction::MulScalar),
            vec![4],
        );
        chain.add_operation(
            FusableOperation::ScalarOp(1.0, BinaryFunction::AddScalar),
            vec![4],
        );
        let kernel = FusedKernel::from_chain(chain).expect("Operation failed");
        let input = Array::from_shape_vec(IxDyn(&[4]), vec![-2.0, -1.0, 1.0, 2.0])
            .expect("Operation failed");
        let result = kernel.execute(&[&input]).expect("Operation failed");
        let expected = vec![13.0, 4.0, 4.0, 13.0];
        assert_eq!(result.as_slice().expect("Operation failed"), &expected);
        let speedup = kernel.estimate_speedup();
        assert!(speedup > 1.0 && speedup <= 4.0);
    }
}