use crate::env;
use crate::op::OpKind;
pub mod splat_common;
/// Global policy deciding whether registered logical kernels run natively on the
/// backend or are lowered to the backend-agnostic common implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KernelDispatchPolicy {
    /// Keep a logical kernel native unless the backend's declared supported-kind
    /// list is non-empty and omits it (see `should_lower_to_common`).
    #[default]
    PreferNative,
    /// Lower every registered logical kernel to the common implementation
    /// (per-kind `force_native_kinds` overrides still apply).
    ForceCommon,
    /// Never lower logical kernels to the common implementation
    /// (per-kind `force_common_kinds` overrides still apply).
    ForceNative,
}
impl KernelDispatchPolicy {
    /// Reads the dispatch policy from the environment.
    ///
    /// `KERNEL_DISPATCH` is consulted first; if it yields nothing,
    /// `RLX_KERNEL_DISPATCH` is used instead.
    ///
    /// Before matching, the value is trimmed, ASCII-lowercased, and `-` is treated as `_`.
    /// So the following all select [`KernelDispatchPolicy::ForceCommon`]:
    /// - `common`
    /// - `force_common`
    /// - `ForceCommon`
    /// - `FORCE-COMMON`
    ///
    /// The same spellings with `native` select [`KernelDispatchPolicy::ForceNative`].
    /// Any other value, or no value at all, yields [`KernelDispatchPolicy::PreferNative`].
    pub fn from_env() -> Self {
        let raw = env::var("KERNEL_DISPATCH").or_else(|| env::var("RLX_KERNEL_DISPATCH"));
        // Normalise so shell-style spellings (upper case, dashes, stray whitespace)
        // are not silently treated as "unset" and ignored.
        let normalized = raw.map(|s| s.trim().to_ascii_lowercase().replace('-', "_"));
        match normalized.as_deref() {
            Some("common" | "force_common" | "forcecommon") => Self::ForceCommon,
            Some("native" | "force_native" | "forcenative") => Self::ForceNative,
            _ => Self::PreferNative,
        }
    }
}
/// One row of the logical-kernel registry: an op kind paired with its stable kernel name.
#[derive(Debug, Clone, Copy)]
pub struct LogicalKernelEntry {
    /// Op kind this logical kernel implements.
    pub kind: OpKind,
    /// Stable, snake_case identifier for the kernel.
    pub name: &'static str,
}
/// Returns the fixed table of op kinds that have a common (backend-agnostic)
/// implementation and are therefore eligible for lowering.
pub fn registered_logical_kernels() -> &'static [LogicalKernelEntry] {
    const REGISTRY: &[LogicalKernelEntry] = &[
        LogicalKernelEntry { kind: OpKind::GroupNorm, name: "group_norm" },
        LogicalKernelEntry { kind: OpKind::ResizeNearest2x, name: "resize_nearest_2x" },
        LogicalKernelEntry { kind: OpKind::GaussianSplatRender, name: "gaussian_splat_render" },
        LogicalKernelEntry {
            kind: OpKind::GaussianSplatRenderBackward,
            name: "gaussian_splat_render_backward",
        },
    ];
    REGISTRY
}
/// Full dispatch configuration: a global [`KernelDispatchPolicy`] plus per-kind overrides.
///
/// In `should_lower_to_common`, a kind listed in `force_native_kinds` is kept native even
/// if it also appears in `force_common_kinds`; both lists take precedence over `policy`.
#[derive(Debug, Clone, Copy, Default)]
pub struct KernelDispatchConfig {
    /// Fallback policy applied when no per-kind override matches.
    pub policy: KernelDispatchPolicy,
    /// Kinds that are always lowered to the common implementation.
    pub force_common_kinds: &'static [OpKind],
    /// Kinds that are always kept on the native path.
    pub force_native_kinds: &'static [OpKind],
}
impl KernelDispatchConfig {
    /// Builds a config with the given global policy and no per-kind overrides.
    pub fn new(policy: KernelDispatchPolicy) -> Self {
        Self {
            policy,
            force_common_kinds: &[],
            force_native_kinds: &[],
        }
    }
    /// Builds a config whose global policy comes from [`KernelDispatchPolicy::from_env`],
    /// with no per-kind overrides.
    pub fn from_env() -> Self {
        let policy = KernelDispatchPolicy::from_env();
        Self::new(policy)
    }
}
/// Decides whether `kind` should be lowered to its common (backend-agnostic) kernel.
///
/// Resolution order:
/// 1. Kinds not present in [`registered_logical_kernels`] are never lowered.
/// 2. A kind in `config.force_native_kinds` stays native. This wins even if the kind is
///    also in `config.force_common_kinds`.
/// 3. A kind in `config.force_common_kinds` is lowered.
/// 4. Otherwise `config.policy` decides:
///    - `ForceCommon` lowers.
///    - `ForceNative` keeps native.
///    - `PreferNative` lowers only when `supported` is non-empty and lacks `kind`.
///
/// NOTE(review): under `PreferNative` an empty `supported` slice keeps everything native.
/// Presumably "empty" means the backend did not declare its capabilities; confirm with callers.
pub fn should_lower_to_common(
    kind: OpKind,
    supported: &[OpKind],
    config: KernelDispatchConfig,
) -> bool {
    let registered = registered_logical_kernels()
        .iter()
        .any(|entry| entry.kind == kind);
    if !registered || config.force_native_kinds.contains(&kind) {
        return false;
    }
    if config.force_common_kinds.contains(&kind) {
        return true;
    }
    match config.policy {
        KernelDispatchPolicy::ForceNative => false,
        KernelDispatchPolicy::ForceCommon => true,
        KernelDispatchPolicy::PreferNative => {
            let backend_declared_support = !supported.is_empty();
            backend_declared_support && !supported.contains(&kind)
        }
    }
}
/// Collects the distinct op kinds in `graph` that `should_lower_to_common` would lower
/// under `supported` and `config`.
///
/// Each kind appears at most once, in the order of its first occurrence in `graph.nodes()`.
pub fn logical_kinds_in_graph(
    graph: &crate::Graph,
    supported: &[OpKind],
    config: KernelDispatchConfig,
) -> Vec<OpKind> {
    let mut lowered: Vec<OpKind> = Vec::new();
    for node in graph.nodes() {
        let kind = node.op.kind();
        // The lowering predicate has no side effects, so checking for an already-recorded
        // kind first gives the same result while skipping redundant evaluations.
        if lowered.contains(&kind) {
            continue;
        }
        if should_lower_to_common(kind, supported, config) {
            lowered.push(kind);
        }
    }
    lowered
}