rlx_macros/lib.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX proc macros for AOT model compilation.
17//!
18//! `#[rlx_model]` transforms a function that uses the RLX tracing API
19//! into an optimized, cached, zero-overhead execution path.
20//!
21//! # Usage
22//! ```rust,ignore
23//! use rlx_macros::rlx_model;
24//! use rlx_runtime::trace::*;
25//!
26//! #[rlx_model]
27//! fn my_encoder(t: &Tracer) -> Vec<TracedTensor> {
28//! let x = t.input("x", &[4, 15, 384], DType::F32);
29//! let w = t.param("w", &[384, 1536], DType::F32);
30//! let b = t.param("b", &[1536], DType::F32);
31//! let out = t.matmul(x, w);
32//! let out = (out + b).gelu();
33//! vec![out]
34//! }
35//!
//! // Generated: `my_encoder_compiled()` returns a `&'static Mutex<CompiledGraph>`
//! // whose graph is built once (on first call) and reused on every later call.
38//! ```
39
40use proc_macro::TokenStream;
41use quote::quote;
42use syn::{ItemFn, parse_macro_input};
43
44mod lm_runner;
45mod pipeline;
46
47/// Compile-time pipeline scheduler (plan #11). See `pipeline_schedule_impl`
48/// in this crate's private `pipeline` module for the full grammar.
49///
50/// ```ignore
51/// pipeline_schedule! {
52/// name: AttentionBlock,
53/// stages: {
54/// qkv_proj => [],
55/// narrow_q => [qkv_proj],
56/// attention => [narrow_q],
57/// }
58/// }
59/// ```
60///
61/// Emits a unit struct + `ORDER`/`DEPS` const slices, with
62/// topological sort + cycle detection at compile time.
63#[proc_macro]
64pub fn pipeline_schedule(item: TokenStream) -> TokenStream {
65 pipeline::pipeline_schedule_impl(item.into()).into()
66}
67
68/// AOT compilation macro for RLX models.
69///
70/// Wraps a tracing function with a `static OnceCell` cache that:
71/// 1. On first call: traces the function → builds IR graph → fuses → compiles thunks
72/// 2. On subsequent calls: executes pre-compiled thunks (zero overhead)
73///
74/// The original function becomes the "graph builder". A new `_compiled` function
75/// is generated that manages the cache and execution.
76#[proc_macro_attribute]
77pub fn rlx_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
78 let input_fn = parse_macro_input!(item as ItemFn);
79 let fn_name = &input_fn.sig.ident;
80 let fn_vis = &input_fn.vis;
81 let fn_block = &input_fn.block;
82 let fn_inputs = &input_fn.sig.inputs;
83 let fn_output = &input_fn.sig.output;
84
85 // Generate the compiled version name
86 let compiled_name = syn::Ident::new(&format!("{fn_name}_compiled"), fn_name.span());
87
88 // The graph builder function name (original, kept for debugging)
89 let builder_name = syn::Ident::new(&format!("{fn_name}_build_graph"), fn_name.span());
90
91 let expanded = quote! {
92 /// Graph builder (the original function — builds IR graph via tracing).
93 fn #builder_name(#fn_inputs) #fn_output {
94 #fn_block
95 }
96
97 /// Compiled model — traces once, caches, executes with zero overhead.
98 ///
99 /// Returns a reference to the cached `CompiledGraph`. Call `.run()` or
100 /// `.run_raw()` to execute.
101 #fn_vis fn #compiled_name() -> &'static ::std::sync::Mutex<::rlx_runtime::CompiledGraph> {
102 use ::std::sync::{Mutex, OnceLock};
103
104 static COMPILED: OnceLock<Mutex<::rlx_runtime::CompiledGraph>> = OnceLock::new();
105
106 COMPILED.get_or_init(|| {
107 // Trace the function to build the IR graph
108 let graph = ::rlx_runtime::trace::trace(stringify!(#fn_name), |t| {
109 #builder_name(t)
110 });
111
112 // Compile: fuse → memory plan → thunks
113 let session = ::rlx_runtime::Session::new(::rlx_runtime::Device::Cpu);
114 let compiled = session.compile(graph);
115
116 Mutex::new(compiled)
117 })
118 }
119
120 // Keep original function accessible for debugging
121 #[allow(dead_code)]
122 #input_fn
123 };
124
125 TokenStream::from(expanded)
126}
127
128/// Register a per-family LM runner so [`rlx_runtime::auto_runner_name`]
129/// can route a weights file to it.
130///
131/// ```ignore
132/// rlx_macros::register_lm_runner! {
133/// family = "qwen3",
134/// description = "Qwen 3 LM",
135/// arches = ["qwen3", "qwen3moe"]
136/// }
137/// ```
138///
139/// Backed by `inventory` at startup; no per-bin `register_cli` call
140/// is needed once each family invokes this macro at the crate root.
141#[proc_macro]
142pub fn register_lm_runner(input: TokenStream) -> TokenStream {
143 lm_runner::register_lm_runner_impl(input)
144}
145
146/// `fn main()` for a per-family runner binary. Replaces the 8-line
147/// boilerplate at the top of every `rlx-<family>/src/bin/rlx_*.rs`.
148///
149/// ```ignore
150/// // src/bin/rlx_qwen3.rs
151/// rlx_macros::rlx_runner_main!(rlx_qwen3::cli::run, "rlx-qwen3");
152/// ```
153#[proc_macro]
154pub fn rlx_runner_main(input: TokenStream) -> TokenStream {
155 lm_runner::rlx_runner_main_impl(input)
156}