Skip to main content

rlx_fusion/
lower_dot_general.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Lower `Op::DotGeneral` to primitive ops (MatMul + Transpose + Reshape).
17//!
18//! DotGeneral is XLA's fully general matmul: arbitrary contracting axes,
19//! batch dimensions, etc. Implementing it as a backend primitive is a lot
20//! of work. Instead we rewrite it to MatMul at the IR level. The existing
21//! matmul kernels then handle dispatch — same code path as user-written
22//! MatMuls, including all fusion benefits.
23//!
24//! Currently handles the common pattern that matters in practice:
25//!   `dot_general(lhs[m, k], rhs[k, n], lhs_contracting=[1], rhs_contracting=[0])`
26//! collapses to a plain MatMul. Other patterns (batched, non-standard
27//! contracting axes) bail out — those are future work, but the coverage
28//! report will tell us when one shows up.
29
30use crate::pass::Pass;
31use rlx_ir::*;
32use std::collections::HashMap;
33
34pub struct LowerDotGeneral;
35
36impl Pass for LowerDotGeneral {
37    fn name(&self) -> &str {
38        "lower_dot_general"
39    }
40
41    fn run(&self, graph: Graph) -> Graph {
42        // Quick scan: is there anything to lower?
43        if !graph
44            .nodes()
45            .iter()
46            .any(|n| matches!(n.op, Op::DotGeneral { .. }))
47        {
48            return graph;
49        }
50
51        let mut new_graph = Graph::new(&graph.name);
52        let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
53
54        for node in graph.nodes() {
55            let new_id = match &node.op {
56                Op::DotGeneral {
57                    lhs_contracting,
58                    rhs_contracting,
59                    lhs_batch,
60                    rhs_batch,
61                } => {
62                    // Only the canonical 2D pattern (no batch dims, contract on
63                    // lhs's last axis and rhs's first axis) reduces to a plain
64                    // MatMul. For everything else, leave the node intact —
65                    // the coverage report flags it as MISSING and we fix it
66                    // when a model needs it.
67                    if lhs_batch.is_empty()
68                        && rhs_batch.is_empty()
69                        && lhs_contracting.as_slice() == [1]
70                        && rhs_contracting.as_slice() == [0]
71                    {
72                        let lhs = id_map[&node.inputs[0]];
73                        let rhs = id_map[&node.inputs[1]];
74                        new_graph.add_node(Op::MatMul, vec![lhs, rhs], node.shape.clone())
75                    } else {
76                        let inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
77                        new_graph.add_node(node.op.clone(), inputs, node.shape.clone())
78                    }
79                }
80                _ => {
81                    let inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
82                    new_graph.add_node(node.op.clone(), inputs, node.shape.clone())
83                }
84            };
85            id_map.insert(node.id, new_id);
86        }
87
88        let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|i| id_map[i]).collect();
89        new_graph.set_outputs(new_outputs);
90        new_graph
91    }
92}