use rlx_ir::{Op, OpKind};
use std::sync::{OnceLock, RwLock};
/// Classifies the part a `FusionFragment` plays in fusion.
///
/// `Hash` is derived alongside `Eq` so a role can be used as a key in
/// hash-based collections (for example to bucket fragments by role).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionRole {
    /// Role reported by the built-in `resize_nearest_2x` fragment. Only
    /// fragments with this role are consulted by `is_registered_transform_op`.
    TransformStep,
    /// Role reported by the built-in `elementwise_region` fragment.
    RegionCompute,
    /// Not reported by any fragment visible in this file.
    /// NOTE(review): presumably related to `rlx_ir::RegionPrologue` — confirm.
    RegionPrologue,
    /// Not reported by any fragment visible in this file.
    /// NOTE(review): intended meaning is not shown here — confirm with callers.
    BatchPlane,
}
/// A pluggable description of which ops take part in fusion, and in what role.
///
/// Implementations must be `Send + Sync` because the registry in this file
/// stores them as `&'static dyn FusionFragment` inside a shared `RwLock`.
pub trait FusionFragment: Send + Sync {
    /// Identifier of the fragment (the built-ins use names such as
    /// `"resize_nearest_2x"`).
    fn name(&self) -> &'static str;

    /// The role this fragment plays.
    fn role(&self) -> FusionRole;

    /// Op kinds accepted by the default `matches_op` implementation.
    fn op_kinds(&self) -> &'static [OpKind];

    /// Returns `true` when `op.kind()` equals one of the entries of
    /// `op_kinds()`.
    ///
    /// Implementors may override this when kind equality is not enough, e.g.
    /// to match on the fields of a specific op variant.
    fn matches_op(&self, op: &Op) -> bool {
        let accepted = self.op_kinds();
        accepted.iter().any(|candidate| op.kind() == *candidate)
    }
}
/// Built-in fragment registered by default: a `TransformStep` that matches
/// ops whose kind is `OpKind::ResizeNearest2x`.
struct BuiltinResizeNearest2x;

impl FusionFragment for BuiltinResizeNearest2x {
    fn name(&self) -> &'static str {
        "resize_nearest_2x"
    }

    fn role(&self) -> FusionRole {
        FusionRole::TransformStep
    }

    fn op_kinds(&self) -> &'static [OpKind] {
        // Named constant so the `'static` lifetime of the slice is explicit.
        const KINDS: &[OpKind] = &[OpKind::ResizeNearest2x];
        KINDS
    }
}
/// Built-in fragment registered by default: a `RegionCompute` fragment that
/// matches ops whose kind is `OpKind::ElementwiseRegion`.
struct BuiltinElementwiseRegion;

impl FusionFragment for BuiltinElementwiseRegion {
    fn name(&self) -> &'static str {
        "elementwise_region"
    }

    fn role(&self) -> FusionRole {
        FusionRole::RegionCompute
    }

    fn op_kinds(&self) -> &'static [OpKind] {
        // Named constant so the `'static` lifetime of the slice is explicit.
        const KINDS: &[OpKind] = &[OpKind::ElementwiseRegion];
        KINDS
    }
}
/// Shared fragment list; stays uninitialized until `registry()` is first called.
static REGISTRY: OnceLock<RwLock<Vec<&'static dyn FusionFragment>>> = OnceLock::new();

/// Returns the shared registry, seeding it with the two built-in fragments
/// (resize first, then elementwise) the first time it is requested.
fn registry() -> &'static RwLock<Vec<&'static dyn FusionFragment>> {
    REGISTRY.get_or_init(|| {
        let builtins: Vec<&'static dyn FusionFragment> =
            vec![&BuiltinResizeNearest2x, &BuiltinElementwiseRegion];
        RwLock::new(builtins)
    })
}
/// Appends `fragment` to the shared registry, after any fragments already
/// present.
///
/// No de-duplication is performed: registering the same fragment twice
/// stores it twice.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn register_fusion_fragment(fragment: &'static dyn FusionFragment) {
    let mut fragments = registry()
        .write()
        .expect("fusion fragment registry poisoned");
    fragments.push(fragment);
}
/// Returns a snapshot of every registered fragment, in registration order.
///
/// The returned `Vec` is a copy of the registry contents, so fragments
/// registered afterwards do not appear in it.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn fusion_fragments() -> Vec<&'static dyn FusionFragment> {
    let fragments = registry()
        .read()
        .expect("fusion fragment registry poisoned");
    Vec::clone(&fragments)
}
/// Returns `true` when at least one registered fragment whose role is
/// `FusionRole::TransformStep` matches `op`.
///
/// Works on the snapshot returned by `fusion_fragments()`, so the registry
/// lock is not held while `role`/`matches_op` run on the fragments.
pub fn is_registered_transform_op(op: &Op) -> bool {
    let snapshot = fusion_fragments();
    snapshot
        .iter()
        .filter(|fragment| fragment.role() == FusionRole::TransformStep)
        .any(|fragment| fragment.matches_op(op))
}
/// Returns `true` when `op.is_transform_eligible()` reports `true`, or
/// otherwise when a registered `TransformStep` fragment matches `op`.
pub fn transform_chain_eligible(op: &Op) -> bool {
    // The op's own check comes first, so the registry is only consulted
    // when that check fails.
    if op.is_transform_eligible() {
        return true;
    }
    is_registered_transform_op(op)
}
/// Maps a transform op to its `rlx_ir::RegionPrologue`, if it has one.
///
/// Only `Op::ResizeNearest2x` has a prologue here; every other op, including
/// ops matched solely by registered fragments, yields `None`.
pub fn prologue_for_transform_op(op: &Op) -> Option<rlx_ir::RegionPrologue> {
    match op {
        Op::ResizeNearest2x => Some(rlx_ir::RegionPrologue::ResizeNearest2x),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only fragment: declares no op kinds and instead overrides
    /// `matches_op` to accept `Op::Narrow` on axis 0.
    struct DummyCrop;

    impl FusionFragment for DummyCrop {
        fn name(&self) -> &'static str {
            "dummy_crop"
        }

        fn role(&self) -> FusionRole {
            FusionRole::TransformStep
        }

        fn op_kinds(&self) -> &'static [OpKind] {
            &[]
        }

        fn matches_op(&self, op: &Op) -> bool {
            matches!(op, Op::Narrow { axis: 0, .. })
        }
    }

    /// The built-in resize fragment is present without any registration call.
    #[test]
    fn registry_lists_builtins() {
        let has_resize = fusion_fragments()
            .iter()
            .any(|fragment| fragment.name() == "resize_nearest_2x");
        assert!(has_resize);
    }

    /// A fragment registered at runtime takes part in transform-op lookup.
    #[test]
    fn plugin_fragment_is_discovered() {
        register_fusion_fragment(&DummyCrop);
        let narrow = Op::Narrow {
            axis: 0,
            start: 0,
            len: 1,
        };
        assert!(is_registered_transform_op(&narrow));
    }
}