Skip to main content

rlx_macros/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX proc macros for AOT model compilation.
17//!
18//! `#[rlx_model]` transforms a function that uses the RLX tracing API
19//! into an optimized, cached, zero-overhead execution path.
20//!
21//! # Usage
22//! ```rust,ignore
23//! use rlx_macros::rlx_model;
24//! use rlx_runtime::trace::*;
25//!
26//! #[rlx_model]
27//! fn my_encoder(t: &Tracer) -> Vec<TracedTensor> {
28//!     let x = t.input("x", &[4, 15, 384], DType::F32);
29//!     let w = t.param("w", &[384, 1536], DType::F32);
30//!     let b = t.param("b", &[1536], DType::F32);
31//!     let out = t.matmul(x, w);
32//!     let out = (out + b).gelu();
33//!     vec![out]
34//! }
35//!
//! // Generated: my_encoder_compiled() returns a `&'static Mutex<CompiledGraph>`;
//! // the graph is built on the first call and the cached value is reused afterwards.
38//! ```
39
40use proc_macro::TokenStream;
41use quote::quote;
42use syn::{ItemFn, parse_macro_input};
43
44mod pipeline;
45
46/// Compile-time pipeline scheduler (plan #11). See `pipeline_schedule_impl`
47/// in this crate's private `pipeline` module for the full grammar.
48///
49/// ```ignore
50/// pipeline_schedule! {
51///     name: AttentionBlock,
52///     stages: {
53///         qkv_proj => [],
54///         narrow_q => [qkv_proj],
55///         attention => [narrow_q],
56///     }
57/// }
58/// ```
59///
60/// Emits a unit struct + `ORDER`/`DEPS` const slices, with
61/// topological sort + cycle detection at compile time.
62#[proc_macro]
63pub fn pipeline_schedule(item: TokenStream) -> TokenStream {
64    pipeline::pipeline_schedule_impl(item.into()).into()
65}
66
67/// AOT compilation macro for RLX models.
68///
69/// Wraps a tracing function with a `static OnceCell` cache that:
70/// 1. On first call: traces the function → builds IR graph → fuses → compiles thunks
71/// 2. On subsequent calls: executes pre-compiled thunks (zero overhead)
72///
73/// The original function becomes the "graph builder". A new `_compiled` function
74/// is generated that manages the cache and execution.
75#[proc_macro_attribute]
76pub fn rlx_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
77    let input_fn = parse_macro_input!(item as ItemFn);
78    let fn_name = &input_fn.sig.ident;
79    let fn_vis = &input_fn.vis;
80    let fn_block = &input_fn.block;
81    let fn_inputs = &input_fn.sig.inputs;
82    let fn_output = &input_fn.sig.output;
83
84    // Generate the compiled version name
85    let compiled_name = syn::Ident::new(&format!("{fn_name}_compiled"), fn_name.span());
86
87    // The graph builder function name (original, kept for debugging)
88    let builder_name = syn::Ident::new(&format!("{fn_name}_build_graph"), fn_name.span());
89
90    let expanded = quote! {
91        /// Graph builder (the original function — builds IR graph via tracing).
92        fn #builder_name(#fn_inputs) #fn_output {
93            #fn_block
94        }
95
96        /// Compiled model — traces once, caches, executes with zero overhead.
97        ///
98        /// Returns a reference to the cached `CompiledGraph`. Call `.run()` or
99        /// `.run_raw()` to execute.
100        #fn_vis fn #compiled_name() -> &'static ::std::sync::Mutex<::rlx_runtime::CompiledGraph> {
101            use ::std::sync::{Mutex, OnceLock};
102
103            static COMPILED: OnceLock<Mutex<::rlx_runtime::CompiledGraph>> = OnceLock::new();
104
105            COMPILED.get_or_init(|| {
106                // Trace the function to build the IR graph
107                let graph = ::rlx_runtime::trace::trace(stringify!(#fn_name), |t| {
108                    #builder_name(t)
109                });
110
111                // Compile: fuse → memory plan → thunks
112                let session = ::rlx_runtime::Session::new(::rlx_runtime::Device::Cpu);
113                let compiled = session.compile(graph);
114
115                Mutex::new(compiled)
116            })
117        }
118
119        // Keep original function accessible for debugging
120        #[allow(dead_code)]
121        #input_fn
122    };
123
124    TokenStream::from(expanded)
125}