Skip to main content

rlx_sam_ir/
mask_hyper_matmul_ir.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Mask logits from hypernetwork coefficients × upscaled embedding planes (`hyper @ up`).
17
18use anyhow::Result;
19use rlx_flow::CompileProfile;
20use rlx_ir::hir::{HirModule, HirMut};
21use rlx_ir::{DType, Graph, HirGraphExt, Shape};
22use rlx_runtime::{CompiledGraph, Device};
23
24/// `masks = hyper_in @ up2` with `hyper_in` `[num_masks, q8]`, `up2` `[q8, spat]` NCHW-flat.
25pub struct MaskHyperMatmulCompiled {
26    graph: CompiledGraph,
27    pub num_masks: usize,
28    pub q8: usize,
29    pub spat: usize,
30}
31
32impl MaskHyperMatmulCompiled {
33    pub fn compile(num_masks: usize, q8: usize, grid: usize, device: Device) -> Result<Self> {
34        Self::compile_with_profile(num_masks, q8, grid, device, &CompileProfile::sam_encoder())
35    }
36
37    pub fn compile_with_profile(
38        num_masks: usize,
39        q8: usize,
40        grid: usize,
41        device: Device,
42        profile: &CompileProfile,
43    ) -> Result<Self> {
44        let spat = (4 * grid) * (4 * grid);
45        let f = DType::F32;
46        let mut hir = HirModule::new("mask_hyper_matmul");
47        let mut g = HirMut::new(&mut hir);
48
49        let hyper = g.input("hyper", Shape::new(&[num_masks, q8], f));
50        let up = g.input("up", Shape::new(&[q8, spat], f));
51        let masks = g.mm(hyper, up);
52
53        hir.set_outputs(vec![masks]);
54        let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
55        let compiled = rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
56        Ok(Self {
57            graph: compiled,
58            num_masks,
59            q8,
60            spat,
61        })
62    }
63
64    pub fn run(&mut self, hyper_in: &[f32], up2: &[f32], masks_out: &mut [f32]) -> Result<()> {
65        let hn = self.num_masks * self.q8;
66        let un = self.q8 * self.spat;
67        let mn = self.num_masks * self.spat;
68        anyhow::ensure!(hyper_in.len() == hn, "hyper len {} ≠ {hn}", hyper_in.len());
69        anyhow::ensure!(up2.len() == un, "up2 len {} ≠ {un}", up2.len());
70        anyhow::ensure!(
71            masks_out.len() == mn,
72            "masks_out len {} ≠ {mn}",
73            masks_out.len()
74        );
75        let outs = self.graph.run(&[("hyper", hyper_in), ("up", up2)]);
76        let out = outs.into_iter().next().expect("mask_hyper_matmul output");
77        masks_out.copy_from_slice(&out);
78        Ok(())
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_runtime::Device;

    /// The compiled IR matmul must agree with the reference `sgemm_auto`
    /// result to within `1e-3` on every output element.
    #[test]
    fn hyper_matmul_ir_matches_blas() {
        const NUM_MASKS: usize = 4;
        const Q8: usize = 32;
        const GRID: usize = 64;

        let side = 4 * GRID;
        let spat = side * side;

        // Deterministic inputs: a linear ramp for `hyper`, a period-17 sawtooth for `up`.
        let hyper_len = NUM_MASKS * Q8;
        let mut hyper: Vec<f32> = Vec::with_capacity(hyper_len);
        for idx in 0..hyper_len {
            hyper.push(idx as f32 * 0.01);
        }
        let up_len = Q8 * spat;
        let mut up: Vec<f32> = Vec::with_capacity(up_len);
        for idx in 0..up_len {
            up.push((idx % 17) as f32 * 0.02);
        }

        // Reference result from the CPU BLAS path.
        let mut reference = vec![0f32; NUM_MASKS * spat];
        rlx_cpu::blas::sgemm_auto(&hyper, &up, &mut reference, NUM_MASKS, Q8, spat);

        // Result from the compiled IR graph.
        let mut ir_out = vec![0f32; NUM_MASKS * spat];
        let mut compiled =
            MaskHyperMatmulCompiled::compile(NUM_MASKS, Q8, GRID, Device::Cpu).unwrap();
        compiled.run(&hyper, &up, &mut ir_out).unwrap();

        // Largest absolute element-wise difference between the two results.
        let mut fd = 0f32;
        for (expected, actual) in reference.iter().zip(&ir_out) {
            fd = fd.max((expected - actual).abs());
        }
        assert!(fd < 1e-3, "IR vs BLAS max |Δ| = {fd:.3e}");
    }
}