Skip to main content

rlx_macros/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX proc macros for AOT model compilation.
17//!
18//! `#[rlx_model]` transforms a function that uses the RLX tracing API
19//! into an optimized, cached, zero-overhead execution path.
20//!
21//! # Usage
22//! ```rust,ignore
23//! use rlx_macros::rlx_model;
24//! use rlx_runtime::trace::*;
25//!
26//! #[rlx_model]
27//! fn my_encoder(t: &Tracer) -> Vec<TracedTensor> {
28//!     let x = t.input("x", &[4, 15, 384], DType::F32);
29//!     let w = t.param("w", &[384, 1536], DType::F32);
30//!     let b = t.param("b", &[1536], DType::F32);
31//!     let out = t.matmul(x, w);
32//!     let out = (out + b).gelu();
33//!     vec![out]
34//! }
35//!
//! // Generated: my_encoder_compiled() returns a `&'static Mutex<CompiledGraph>`
//! // whose graph is traced and compiled on the first call and reused afterwards.
38//! ```
39
40use proc_macro::TokenStream;
41use quote::quote;
42use syn::{ItemFn, parse_macro_input};
43
44mod lm_runner;
45mod pipeline;
46
47/// Compile-time pipeline scheduler (plan #11). See `pipeline_schedule_impl`
48/// in this crate's private `pipeline` module for the full grammar.
49///
50/// ```ignore
51/// pipeline_schedule! {
52///     name: AttentionBlock,
53///     stages: {
54///         qkv_proj => [],
55///         narrow_q => [qkv_proj],
56///         attention => [narrow_q],
57///     }
58/// }
59/// ```
60///
61/// Emits a unit struct + `ORDER`/`DEPS` const slices, with
62/// topological sort + cycle detection at compile time.
63#[proc_macro]
64pub fn pipeline_schedule(item: TokenStream) -> TokenStream {
65    pipeline::pipeline_schedule_impl(item.into()).into()
66}
67
68/// AOT compilation macro for RLX models.
69///
70/// Wraps a tracing function with a `static OnceCell` cache that:
71/// 1. On first call: traces the function → builds IR graph → fuses → compiles thunks
72/// 2. On subsequent calls: executes pre-compiled thunks (zero overhead)
73///
74/// The original function becomes the "graph builder". A new `_compiled` function
75/// is generated that manages the cache and execution.
76#[proc_macro_attribute]
77pub fn rlx_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
78    let input_fn = parse_macro_input!(item as ItemFn);
79    let fn_name = &input_fn.sig.ident;
80    let fn_vis = &input_fn.vis;
81    let fn_block = &input_fn.block;
82    let fn_inputs = &input_fn.sig.inputs;
83    let fn_output = &input_fn.sig.output;
84
85    // Generate the compiled version name
86    let compiled_name = syn::Ident::new(&format!("{fn_name}_compiled"), fn_name.span());
87
88    // The graph builder function name (original, kept for debugging)
89    let builder_name = syn::Ident::new(&format!("{fn_name}_build_graph"), fn_name.span());
90
91    let expanded = quote! {
92        /// Graph builder (the original function — builds IR graph via tracing).
93        fn #builder_name(#fn_inputs) #fn_output {
94            #fn_block
95        }
96
97        /// Compiled model — traces once, caches, executes with zero overhead.
98        ///
99        /// Returns a reference to the cached `CompiledGraph`. Call `.run()` or
100        /// `.run_raw()` to execute.
101        #fn_vis fn #compiled_name() -> &'static ::std::sync::Mutex<::rlx_runtime::CompiledGraph> {
102            use ::std::sync::{Mutex, OnceLock};
103
104            static COMPILED: OnceLock<Mutex<::rlx_runtime::CompiledGraph>> = OnceLock::new();
105
106            COMPILED.get_or_init(|| {
107                // Trace the function to build the IR graph
108                let graph = ::rlx_runtime::trace::trace(stringify!(#fn_name), |t| {
109                    #builder_name(t)
110                });
111
112                // Compile: fuse → memory plan → thunks
113                let session = ::rlx_runtime::Session::new(::rlx_runtime::Device::Cpu);
114                let compiled = session.compile(graph);
115
116                Mutex::new(compiled)
117            })
118        }
119
120        // Keep original function accessible for debugging
121        #[allow(dead_code)]
122        #input_fn
123    };
124
125    TokenStream::from(expanded)
126}
127
128/// Register a per-family LM runner so [`rlx_runtime::auto_runner_name`]
129/// can route a weights file to it.
130///
131/// ```ignore
132/// rlx_macros::register_lm_runner! {
133///     family = "qwen3",
134///     description = "Qwen 3 LM",
135///     arches = ["qwen3", "qwen3moe"]
136/// }
137/// ```
138///
139/// Backed by `inventory` at startup; no per-bin `register_cli` call
140/// is needed once each family invokes this macro at the crate root.
141#[proc_macro]
142pub fn register_lm_runner(input: TokenStream) -> TokenStream {
143    lm_runner::register_lm_runner_impl(input)
144}
145
146/// `fn main()` for a per-family runner binary. Replaces the 8-line
147/// boilerplate at the top of every `rlx-<family>/src/bin/rlx_*.rs`.
148///
149/// ```ignore
150/// // src/bin/rlx_qwen3.rs
151/// rlx_macros::rlx_runner_main!(rlx_qwen3::cli::run, "rlx-qwen3");
152/// ```
153#[proc_macro]
154pub fn rlx_runner_main(input: TokenStream) -> TokenStream {
155    lm_runner::rlx_runner_main_impl(input)
156}