rlx_ir/const_check.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compile-time shape / rank assertions (plan #77).
17//!
18//! Borrowed from MAX's `comptime assert c.rank == 2, "c must be rank 2"`
19//! pattern. Where shapes are known at the point of macro expansion,
20//! verify them at compile time via `const fn` helpers; runtime
21//! `Shape` checks remain for genuinely-dynamic cases.
22//!
23//! The Rust spelling uses `const fn` predicates plus a small
24//! `static_assert!` macro that wraps a `const _: () = assert!(...)`
25//! evaluation. Failures surface as compile errors with the full
26//! const-evaluation chain, so the user sees exactly which check
27//! tripped.
28//!
//! These tools are most useful inside macros (e.g. a future
//! `tensor!{ shape: [8, 8] }` literal that wants to check that the
//! shape is non-empty and has the expected rank). For now they are
//! exposed as standalone building blocks.
33
/// Compile-time assert. Wraps the `const _: () = assert!(...)`
/// idiom in a terse macro so call sites read like
/// `static_assert!(cond)`.
///
/// An optional string-literal message may follow the condition, and
/// (mirroring `assert!`) a trailing comma is accepted in both forms.
///
/// ```
/// rlx_ir::static_assert!(1 + 1 == 2);
/// rlx_ir::static_assert!(usize::MAX > 0, "platform sanity");
/// rlx_ir::static_assert!(2 > 1,);
/// ```
///
/// The condition must be a `bool` expression that can be evaluated
/// in a `const` context. Failure is a compile error pointing at the
/// macro call site.
#[macro_export]
macro_rules! static_assert {
    ($cond:expr $(,)?) => {
        const _: () = assert!($cond);
    };
    ($cond:expr, $msg:literal $(,)?) => {
        const _: () = assert!($cond, $msg);
    };
}
52
/// Returns `true` when `rank` is exactly `expected`.
///
/// This is a `const fn`, so it can be used in constant expressions
/// such as the condition of a `static_assert!`.
pub const fn rank_eq(rank: usize, expected: usize) -> bool {
    expected == rank
}
57
/// Returns `true` when `rank` is greater than or equal to `min`.
///
/// This is a `const fn`, so it can be used in constant expressions.
pub const fn rank_at_least(rank: usize, min: usize) -> bool {
    min <= rank
}
62
/// Const product of a fixed-size dim array. Useful for asserting
/// a flat element count matches a structured shape at compile
/// time.
///
/// A rank-0 shape (`[]`) has exactly one element, like a scalar.
/// Any zero-sized dim makes the whole shape empty, so the result is
/// `0` regardless of how large the other dims are.
///
/// # Panics
///
/// Panics if the (non-zero) product overflows `usize`. The behavior is
/// the same everywhere: in `const` evaluation it is a compile error,
/// and at runtime it panics in both debug and release builds instead
/// of silently wrapping to a bogus count.
pub const fn shape_elements<const N: usize>(dims: [usize; N]) -> usize {
    // Detect empty shapes first so that e.g. `[usize::MAX, 2, 0]`
    // yields its true product (0) instead of tripping the overflow
    // check part-way through.
    let mut i = 0;
    while i < N {
        if dims[i] == 0 {
            return 0;
        }
        i += 1;
    }

    let mut total = 1usize;
    i = 0;
    while i < N {
        // `checked_mul` rather than `*=`: plain multiplication wraps
        // silently in release builds.
        total = match total.checked_mul(dims[i]) {
            Some(product) => product,
            None => panic!("shape_elements: element count overflows usize"),
        };
        i += 1;
    }
    total
}
75
/// Returns `true` when two equal-rank shapes are broadcast-compatible.
///
/// At every axis the two extents must either be equal, or one of
/// them must be `1`. Both shapes share the rank `N`; to compare
/// shapes of different rank, left-pad the shorter one externally
/// before calling. A pair of rank-0 shapes is trivially compatible.
pub const fn broadcastable<const N: usize>(lhs: [usize; N], rhs: [usize; N]) -> bool {
    let mut axis = 0;
    while axis < N {
        let (a, b) = (lhs[axis], rhs[axis]);
        // Differing extents are only allowed when one side is 1.
        if a != b && a != 1 && b != 1 {
            return false;
        }
        axis += 1;
    }
    true
}
92
/// Const check for the matmul rank/dim contract
/// `[m, k] @ [k, n] → [m, n]`.
///
/// Returns `true` iff the inner (contracted) dims `lhs_k` and `rhs_k`
/// agree. The outer dims `lhs_m` and `rhs_n` are part of the
/// signature but do not affect the result.
pub const fn matmul_compat(lhs_m: usize, lhs_k: usize, rhs_k: usize, rhs_n: usize) -> bool {
    // Outer dims impose no constraint; bind them so the compiler
    // does not warn about unused parameters.
    let _ = (lhs_m, rhs_n);
    rhs_k == lhs_k
}
101
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time assertions (each one fails the build if wrong).
    static_assert!(rank_eq(2, 2));
    static_assert!(!rank_eq(2, 3));
    static_assert!(rank_at_least(3, 2));
    static_assert!(!rank_at_least(1, 2));
    static_assert!(shape_elements([2, 3, 4]) == 24);
    static_assert!(shape_elements([]) == 1);
    static_assert!(broadcastable([4, 1, 8], [4, 6, 1]));
    static_assert!(!broadcastable([4, 5], [3, 5]));
    static_assert!(matmul_compat(8, 16, 16, 32));
    static_assert!(!matmul_compat(8, 16, 32, 16));

    // Runtime basic tests too — the const fns are also useful at
    // runtime for shape-inference helpers. Both the true and false
    // branch of every predicate are exercised.
    #[test]
    fn const_helpers_at_runtime() {
        assert!(rank_eq(2, 2));
        assert!(!rank_eq(2, 3));

        assert!(rank_at_least(2, 2));
        assert!(rank_at_least(3, 2));
        assert!(!rank_at_least(1, 2));

        assert_eq!(shape_elements([2, 3, 4]), 24);
        assert_eq!(shape_elements::<0>([]), 1);
        assert_eq!(shape_elements([5, 0, 3]), 0);

        assert!(broadcastable([4, 1, 8], [4, 6, 1]));
        assert!(!broadcastable([4, 5], [3, 5]));
        assert!(broadcastable::<0>([], []));

        assert!(matmul_compat(8, 16, 16, 32));
        assert!(!matmul_compat(8, 16, 32, 16));
    }
}