Skip to main content

axonml_fusion/
lib.rs

1//! Axonml Fusion - Kernel Fusion Library
2//!
3//! Provides kernel fusion support for combining multiple operations into
4//! single optimized kernels. Common fusion patterns include:
5//!
6//! - **MatMul + Bias + Activation**: Fused dense layer
7//! - **Conv + BatchNorm + ReLU**: Fused convolution block
8//! - **Elementwise chains**: Multiple elementwise ops in one pass
9//! - **Reduction + Transform**: Softmax, LayerNorm patterns
10//!
11//! # Example
12//! ```ignore
13//! use axonml_fusion::{FusedOp, fuse_matmul_bias_relu};
14//!
15//! let fused = fuse_matmul_bias_relu(&weight, &bias);
16//! let output = fused.execute(&input);
17//! ```
18//!
19//! @version 0.1.0
20//! @author AutomataNexus Development Team
21
22#![warn(missing_docs)]
23#![warn(clippy::all)]
24#![allow(clippy::module_name_repetitions)]
25#![allow(clippy::must_use_candidate)]
26#![allow(clippy::missing_errors_doc)]
27
28pub mod error;
29pub mod patterns;
30pub mod elementwise;
31pub mod linear;
32pub mod optimizer;
33
34pub use error::{FusionError, FusionResult};
35pub use patterns::{FusionPattern, detect_patterns};
36pub use elementwise::{FusedElementwise, fuse_elementwise};
37pub use linear::{FusedLinear, fuse_matmul_bias_relu};
38pub use optimizer::{FusionOptimizer, optimize_graph};
39
40// =============================================================================
41// Fused Operation Trait
42// =============================================================================
43
44use axonml_tensor::Tensor;
45use std::fmt::Debug;
46
47/// Trait for fused operations.
48pub trait FusedOp: Debug + Send + Sync {
49    /// Executes the fused operation.
50    fn execute(&self, inputs: &[&Tensor<f32>]) -> FusionResult<Tensor<f32>>;
51
52    /// Returns the name of the fused operation.
53    fn name(&self) -> &str;
54
55    /// Returns the number of operations fused.
56    fn num_ops(&self) -> usize;
57
58    /// Returns estimated speedup from fusion.
59    fn estimated_speedup(&self) -> f32 {
60        // Default: 1.0 + 0.2 per additional op fused
61        1.0 + 0.2 * (self.num_ops() - 1) as f32
62    }
63}
64
65// =============================================================================
66// Tests
67// =============================================================================
68
69#[cfg(test)]
70mod tests {
71    #[test]
72    fn test_placeholder() {
73        assert!(true);
74    }
75}