use crate::lower_vae_ops::{LowerBatchNormInference, LowerGroupNorm, LowerResizeNearest2x};
use crate::pass::Pass;
use rlx_ir::logical_kernel::splat_common;
use rlx_ir::logical_kernel::{self, KernelDispatchConfig};
use rlx_ir::{Graph, NodeId, Op, OpKind};
use std::collections::HashMap;
/// Runs one lowering pass over `graph` for every logical op kind that
/// `logical_kernel::logical_kinds_in_graph` reports for it.
///
/// The graph is handed back untouched in two situations:
/// * `supported` is empty and `config` neither selects the `ForceCommon`
///   policy nor lists any `force_common_kinds`;
/// * `logical_kinds_in_graph` returns an empty collection.
///
/// Otherwise the reported kinds are processed in the order they are yielded,
/// each by its dedicated pass. Kinds that have no pass wired up here leave the
/// graph as it is.
pub fn lower_logical_kernels(
    graph: Graph,
    supported: &[OpKind],
    config: KernelDispatchConfig,
) -> Graph {
    // Either the global policy or a per-kind override may request the common
    // lowering even when the `supported` list is empty.
    let common_is_forced = config.policy == logical_kernel::KernelDispatchPolicy::ForceCommon
        || !config.force_common_kinds.is_empty();
    if supported.is_empty() && !common_is_forced {
        return graph;
    }

    let kinds = logical_kernel::logical_kinds_in_graph(&graph, supported, config);
    if kinds.is_empty() {
        return graph;
    }

    // Thread the graph through the passes; each pass consumes the previous
    // result and produces the next one.
    kinds.into_iter().fold(graph, |current, kind| match kind {
        OpKind::GroupNorm => LowerGroupNorm.run(current),
        OpKind::BatchNormInference => LowerBatchNormInference.run(current),
        OpKind::ResizeNearest2x => LowerResizeNearest2x.run(current),
        OpKind::GaussianSplatRender => lower_gaussian_splat_render_pass(current),
        OpKind::GaussianSplatRenderBackward => lower_gaussian_splat_backward_pass(current),
        _ => current,
    })
}
fn lower_gaussian_splat_render_pass(graph: Graph) -> Graph {
lower_gaussian_splat_nodes(graph, |g, node| {
if let Op::GaussianSplatRender {
width,
height,
tile_size: _,
radius_scale: _,
alpha_cutoff: _,
max_splat_steps: _,
transmittance_threshold: _,
max_list_entries: _,
} = &node.op
{
let inputs = &node.inputs;
splat_common::lower_gaussian_splat_render(
g,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
inputs[5],
inputs[6],
*width,
*height,
node.shape.clone(),
)
} else {
unreachable!()
}
})
}
fn lower_gaussian_splat_backward_pass(graph: Graph) -> Graph {
lower_gaussian_splat_nodes(graph, |g, node| {
if let Op::GaussianSplatRenderBackward {
width,
height,
loss_grad_clip: _,
sh_band: _,
max_anisotropy: _,
tile_size: _,
radius_scale: _,
alpha_cutoff: _,
max_splat_steps: _,
transmittance_threshold: _,
max_list_entries: _,
} = &node.op
{
let inputs = &node.inputs;
splat_common::lower_gaussian_splat_render_backward(
g,
inputs[0],
inputs[1],
inputs[2],
inputs[3],
inputs[4],
inputs[5],
inputs[6],
inputs[7],
*width,
*height,
node.shape.clone(),
)
} else {
unreachable!()
}
})
}
fn lower_gaussian_splat_nodes<F>(graph: Graph, mut lower_one: F) -> Graph
where
F: FnMut(&mut Graph, &rlx_ir::Node) -> NodeId,
{
if !graph.nodes().iter().any(|n| {
matches!(
n.op,
Op::GaussianSplatRender { .. } | Op::GaussianSplatRenderBackward { .. }
)
}) {
return graph;
}
let mut new_graph = Graph::new(&graph.name);
let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
for node in graph.nodes() {
let new_id = if matches!(
node.op,
Op::GaussianSplatRender { .. } | Op::GaussianSplatRenderBackward { .. }
) {
lower_one(&mut new_graph, node)
} else {
let inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
new_graph.add_node(node.op.clone(), inputs, node.shape.clone())
};
id_map.insert(node.id, new_id);
}
let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|i| id_map[i]).collect();
new_graph.set_outputs(new_outputs);
new_graph
}