rlx_fusion/lower_dot_general.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Lower `Op::DotGeneral` to primitive ops (currently a plain MatMul; Transpose + Reshape lowerings are not implemented yet).
17//!
18//! DotGeneral is XLA's fully general matmul: arbitrary contracting axes,
19//! batch dimensions, etc. Implementing it as a backend primitive is a lot
20//! of work. Instead we rewrite it to MatMul at the IR level. The existing
21//! matmul kernels then handle dispatch — same code path as user-written
22//! MatMuls, including all fusion benefits.
23//!
24//! Currently handles the common pattern that matters in practice:
25//! `dot_general(lhs[m, k], rhs[k, n], lhs_contracting=[1], rhs_contracting=[0])`
26//! collapses to a plain MatMul. Other patterns (batched, non-standard
27//! contracting axes) bail out — those are future work, but the coverage
28//! report will tell us when one shows up.
29
30use crate::pass::Pass;
31use rlx_ir::*;
32use std::collections::HashMap;
33
/// IR pass that replaces `Op::DotGeneral` nodes with `Op::MatMul` when the
/// node has no batch dimensions and contracts lhs axis 1 against rhs axis 0.
/// Any other `DotGeneral` configuration is copied through unchanged.
pub struct LowerDotGeneral;
35
36impl Pass for LowerDotGeneral {
37 fn name(&self) -> &str {
38 "lower_dot_general"
39 }
40
41 fn run(&self, graph: Graph) -> Graph {
42 // Quick scan: is there anything to lower?
43 if !graph
44 .nodes()
45 .iter()
46 .any(|n| matches!(n.op, Op::DotGeneral { .. }))
47 {
48 return graph;
49 }
50
51 let mut new_graph = Graph::new(&graph.name);
52 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
53
54 for node in graph.nodes() {
55 let new_id = match &node.op {
56 Op::DotGeneral {
57 lhs_contracting,
58 rhs_contracting,
59 lhs_batch,
60 rhs_batch,
61 } => {
62 // Only the canonical 2D pattern (no batch dims, contract on
63 // lhs's last axis and rhs's first axis) reduces to a plain
64 // MatMul. For everything else, leave the node intact —
65 // the coverage report flags it as MISSING and we fix it
66 // when a model needs it.
67 if lhs_batch.is_empty()
68 && rhs_batch.is_empty()
69 && lhs_contracting.as_slice() == [1]
70 && rhs_contracting.as_slice() == [0]
71 {
72 let lhs = id_map[&node.inputs[0]];
73 let rhs = id_map[&node.inputs[1]];
74 new_graph.add_node(Op::MatMul, vec![lhs, rhs], node.shape.clone())
75 } else {
76 let inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
77 new_graph.add_node(node.op.clone(), inputs, node.shape.clone())
78 }
79 }
80 _ => {
81 let inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
82 new_graph.add_node(node.op.clone(), inputs, node.shape.clone())
83 }
84 };
85 id_map.insert(node.id, new_id);
86 }
87
88 let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|i| id_map[i]).collect();
89 new_graph.set_outputs(new_outputs);
90 new_graph
91 }
92}