rlx_sam_ir/
mask_hyper_matmul_ir.rs1use anyhow::Result;
19use rlx_flow::CompileProfile;
20use rlx_ir::hir::{HirModule, HirMut};
21use rlx_ir::{DType, Graph, HirGraphExt, Shape};
22use rlx_runtime::{CompiledGraph, Device};
23
24pub struct MaskHyperMatmulCompiled {
26 graph: CompiledGraph,
27 pub num_masks: usize,
28 pub q8: usize,
29 pub spat: usize,
30}
31
32impl MaskHyperMatmulCompiled {
33 pub fn compile(num_masks: usize, q8: usize, grid: usize, device: Device) -> Result<Self> {
34 Self::compile_with_profile(num_masks, q8, grid, device, &CompileProfile::sam_encoder())
35 }
36
37 pub fn compile_with_profile(
38 num_masks: usize,
39 q8: usize,
40 grid: usize,
41 device: Device,
42 profile: &CompileProfile,
43 ) -> Result<Self> {
44 let spat = (4 * grid) * (4 * grid);
45 let f = DType::F32;
46 let mut hir = HirModule::new("mask_hyper_matmul");
47 let mut g = HirMut::new(&mut hir);
48
49 let hyper = g.input("hyper", Shape::new(&[num_masks, q8], f));
50 let up = g.input("up", Shape::new(&[q8, spat], f));
51 let masks = g.mm(hyper, up);
52
53 hir.set_outputs(vec![masks]);
54 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
55 let compiled = rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
56 Ok(Self {
57 graph: compiled,
58 num_masks,
59 q8,
60 spat,
61 })
62 }
63
64 pub fn run(&mut self, hyper_in: &[f32], up2: &[f32], masks_out: &mut [f32]) -> Result<()> {
65 let hn = self.num_masks * self.q8;
66 let un = self.q8 * self.spat;
67 let mn = self.num_masks * self.spat;
68 anyhow::ensure!(hyper_in.len() == hn, "hyper len {} ≠ {hn}", hyper_in.len());
69 anyhow::ensure!(up2.len() == un, "up2 len {} ≠ {un}", up2.len());
70 anyhow::ensure!(
71 masks_out.len() == mn,
72 "masks_out len {} ≠ {mn}",
73 masks_out.len()
74 );
75 let outs = self.graph.run(&[("hyper", hyper_in), ("up", up2)]);
76 let out = outs.into_iter().next().expect("mask_hyper_matmul output");
77 masks_out.copy_from_slice(&out);
78 Ok(())
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_runtime::Device;

    /// The compiled IR matmul must agree with the CPU BLAS `sgemm_auto`
    /// reference to within `1e-3` on every output element.
    #[test]
    fn hyper_matmul_ir_matches_blas() {
        let (nm, q8, grid) = (4usize, 32usize, 64usize);
        let side = 4 * grid;
        let spat = side * side;

        // Deterministic inputs: a linear ramp for `hyper`, a period-17 pattern for `up`.
        let hyper_len = nm * q8;
        let up_len = q8 * spat;
        let hyper: Vec<f32> = (0..hyper_len).map(|i| (i as f32) * 0.01).collect();
        let up: Vec<f32> = (0..up_len).map(|i| ((i % 17) as f32) * 0.02).collect();

        // Reference result from BLAS.
        let mut blas_out = vec![0f32; nm * spat];
        rlx_cpu::blas::sgemm_auto(&hyper, &up, &mut blas_out, nm, q8, spat);

        // Result from the compiled IR graph.
        let mut ir_out = vec![0f32; nm * spat];
        let mut compiled = MaskHyperMatmulCompiled::compile(nm, q8, grid, Device::Cpu).unwrap();
        compiled.run(&hyper, &up, &mut ir_out).unwrap();

        // Largest absolute element-wise difference between the two results.
        let fd = blas_out
            .iter()
            .zip(&ir_out)
            .fold(0f32, |worst, (a, b)| worst.max((a - b).abs()));
        assert!(fd < 1e-3, "IR vs BLAS max |Δ| = {fd:.3e}");
    }
}