use rlx_driver::Device;
use rlx_ir::Graph;
use rlx_ir::GraphModule;
use rlx_ir::OpKind;
use rlx_ir::hir::HirModule;
use rlx_ir::lir::LirModule;
use rlx_opt::{
CompilePipeline, CompileResult, FusionLimits, FusionOptions, FusionReport, FusionTarget,
fusion_limits_for_target,
};
use crate::CompileOptions;
/// Maps a runtime [`Device`] to the [`FusionTarget`] the optimizer should fuse for.
///
/// All WebGPU-style devices (`Gpu`, `Vulkan`, `WebGpu`) share the `Wgpu` target.
/// Any device without a dedicated mapping falls back to the CPU target.
pub fn fusion_target_for(device: Device) -> FusionTarget {
    match device {
        // Portable GPU backends all fuse for the wgpu target.
        Device::Gpu | Device::Vulkan | Device::WebGpu => FusionTarget::Wgpu,
        // Vendor / accelerator backends with their own fusion target.
        Device::Metal => FusionTarget::Metal,
        Device::Mlx => FusionTarget::Mlx,
        Device::Cuda => FusionTarget::Cuda,
        Device::Rocm => FusionTarget::Rocm,
        Device::Tpu => FusionTarget::Tpu,
        // CPU, plus the fallback for every other device.
        Device::Cpu => FusionTarget::Cpu,
        _ => FusionTarget::Cpu,
    }
}
pub fn pipeline_for(device: Device, options: &CompileOptions) -> CompilePipeline {
let target = options
.fusion_target
.unwrap_or_else(|| fusion_target_for(device));
let mut opts = options.fusion_opts;
if matches!(target, FusionTarget::Cpu) && !opts.unfuse_elementwise_regions {
opts.unfuse_elementwise_regions = true;
}
if matches!(target, FusionTarget::Metal) {
let metal_env = FusionOptions::from_metal_env();
if !rlx_ir::env::flag("RLX_METAL_NO_FUSION") {
}
if metal_env.skip_fusion {
opts.skip_fusion = true;
}
if metal_env.unfuse_elementwise_regions {
opts.unfuse_elementwise_regions = true;
}
}
let mut pipe = CompilePipeline::new(target);
pipe.opts = opts;
if pipe.opts.fusion_limits == FusionLimits::default() {
pipe.opts.fusion_limits = fusion_limits_for_target(target);
}
pipe.arena_alignment = options.arena_alignment;
pipe.assert_fusion_clean = options.assert_fusion_clean;
if let Some(ops) = options.supported_ops {
pipe.supported_ops = Some(ops);
} else if let Some(backend) = crate::registry::backend_for(device) {
let ops = backend.supported_ops();
if !ops.is_empty() {
pipe.supported_ops = Some(ops);
}
}
pipe.kernel_dispatch = options.kernel_dispatch;
pipe
}
/// Returns a copy of `options` whose `supported_ops` is forced to `supported_ops`.
///
/// Every other field is cloned unchanged from `options`.
pub fn options_with_supported_ops(
    options: &CompileOptions,
    supported_ops: &'static [OpKind],
) -> CompileOptions {
    let ops_override = Some(supported_ops);
    let mut derived = options.clone();
    derived.supported_ops = ops_override;
    derived
}
/// Compiles `graph` for `device` with the pipeline built by `pipeline_for`.
///
/// If `options.dim_binding` is set, the result is then specialized with that binding.
pub fn compile_graph_stages(
    device: Device,
    graph: Graph,
    options: &CompileOptions,
) -> CompileResult {
    let pipeline = pipeline_for(device, options);
    let compiled = pipeline.compile_graph(graph);
    maybe_specialize(compiled, &pipeline, options)
}
/// Specializes `result` against `options.dim_binding` when one is provided.
///
/// Otherwise `result` is returned untouched.
fn maybe_specialize(
    result: CompileResult,
    pipe: &CompilePipeline,
    options: &CompileOptions,
) -> CompileResult {
    if let Some(binding) = &options.dim_binding {
        result.specialize(pipe, binding)
    } else {
        result
    }
}
/// Same as `compile_graph_stages`, but with `supported_ops` overriding
/// `options.supported_ops`.
pub fn compile_graph_stages_for_backend(
    device: Device,
    graph: Graph,
    options: &CompileOptions,
    supported_ops: &'static [OpKind],
) -> CompileResult {
    compile_graph_stages(
        device,
        graph,
        &options_with_supported_ops(options, supported_ops),
    )
}
/// Compiles a HIR module for `device`, specializing the result on `options.dim_binding`
/// when set.
///
/// # Errors
/// Propagates the `LowerError` returned by `CompilePipeline::compile_hir`.
pub fn compile_hir_stages(
    device: Device,
    hir: HirModule,
    options: &CompileOptions,
) -> Result<CompileResult, rlx_ir::hir::LowerError> {
    let pipeline = pipeline_for(device, options);
    let compiled = pipeline.compile_hir(hir)?;
    Ok(maybe_specialize(compiled, &pipeline, options))
}
/// Compiles a [`GraphModule`] for `device`, specializing the result on
/// `options.dim_binding` when set.
///
/// # Errors
/// Propagates the `LowerError` returned by `CompilePipeline::compile_module`.
pub fn compile_module_stages(
    device: Device,
    module: GraphModule,
    options: &CompileOptions,
) -> Result<CompileResult, rlx_ir::hir::LowerError> {
    let pipeline = pipeline_for(device, options);
    let compiled = pipeline.compile_module(module)?;
    Ok(maybe_specialize(compiled, &pipeline, options))
}
/// Prints `report` to stderr, but only when the `RLX_FUSION_REPORT` flag is set
/// (as read by `rlx_ir::env::flag`).
pub fn maybe_log_fusion(report: &FusionReport) {
    if !rlx_ir::env::flag("RLX_FUSION_REPORT") {
        return;
    }
    eprintln!("{report}");
}
pub fn graph_from_lir(lir: LirModule) -> Graph {
lir.into_graph()
}