rlx_macros/lib.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX proc macros for AOT model compilation.
17//!
//! `#[rlx_model]` transforms a function that uses the RLX tracing API
//! into a cached execution path: the graph is traced and compiled once,
//! then reused on every subsequent call.
20//!
21//! # Usage
22//! ```rust,ignore
23//! use rlx_macros::rlx_model;
24//! use rlx_runtime::trace::*;
25//!
26//! #[rlx_model]
27//! fn my_encoder(t: &Tracer) -> Vec<TracedTensor> {
28//! let x = t.input("x", &[4, 15, 384], DType::F32);
29//! let w = t.param("w", &[384, 1536], DType::F32);
30//! let b = t.param("b", &[1536], DType::F32);
31//! let out = t.matmul(x, w);
32//! let out = (out + b).gelu();
33//! vec![out]
34//! }
35//!
//! // Generated: my_encoder_compiled() returns a `&'static Mutex<CompiledGraph>`;
//! // the graph is traced and compiled on the first call and reused afterwards.
38//! ```
39
40use proc_macro::TokenStream;
41use quote::quote;
42use syn::{ItemFn, parse_macro_input};
43
44mod pipeline;
45
46/// Compile-time pipeline scheduler (plan #11). See `pipeline_schedule_impl`
47/// in this crate's private `pipeline` module for the full grammar.
48///
49/// ```ignore
50/// pipeline_schedule! {
51/// name: AttentionBlock,
52/// stages: {
53/// qkv_proj => [],
54/// narrow_q => [qkv_proj],
55/// attention => [narrow_q],
56/// }
57/// }
58/// ```
59///
60/// Emits a unit struct + `ORDER`/`DEPS` const slices, with
61/// topological sort + cycle detection at compile time.
62#[proc_macro]
63pub fn pipeline_schedule(item: TokenStream) -> TokenStream {
64 pipeline::pipeline_schedule_impl(item.into()).into()
65}
66
67/// AOT compilation macro for RLX models.
68///
69/// Wraps a tracing function with a `static OnceCell` cache that:
70/// 1. On first call: traces the function → builds IR graph → fuses → compiles thunks
71/// 2. On subsequent calls: executes pre-compiled thunks (zero overhead)
72///
73/// The original function becomes the "graph builder". A new `_compiled` function
74/// is generated that manages the cache and execution.
75#[proc_macro_attribute]
76pub fn rlx_model(_attr: TokenStream, item: TokenStream) -> TokenStream {
77 let input_fn = parse_macro_input!(item as ItemFn);
78 let fn_name = &input_fn.sig.ident;
79 let fn_vis = &input_fn.vis;
80 let fn_block = &input_fn.block;
81 let fn_inputs = &input_fn.sig.inputs;
82 let fn_output = &input_fn.sig.output;
83
84 // Generate the compiled version name
85 let compiled_name = syn::Ident::new(&format!("{fn_name}_compiled"), fn_name.span());
86
87 // The graph builder function name (original, kept for debugging)
88 let builder_name = syn::Ident::new(&format!("{fn_name}_build_graph"), fn_name.span());
89
90 let expanded = quote! {
91 /// Graph builder (the original function — builds IR graph via tracing).
92 fn #builder_name(#fn_inputs) #fn_output {
93 #fn_block
94 }
95
96 /// Compiled model — traces once, caches, executes with zero overhead.
97 ///
98 /// Returns a reference to the cached `CompiledGraph`. Call `.run()` or
99 /// `.run_raw()` to execute.
100 #fn_vis fn #compiled_name() -> &'static ::std::sync::Mutex<::rlx_runtime::CompiledGraph> {
101 use ::std::sync::{Mutex, OnceLock};
102
103 static COMPILED: OnceLock<Mutex<::rlx_runtime::CompiledGraph>> = OnceLock::new();
104
105 COMPILED.get_or_init(|| {
106 // Trace the function to build the IR graph
107 let graph = ::rlx_runtime::trace::trace(stringify!(#fn_name), |t| {
108 #builder_name(t)
109 });
110
111 // Compile: fuse → memory plan → thunks
112 let session = ::rlx_runtime::Session::new(::rlx_runtime::Device::Cpu);
113 let compiled = session.compile(graph);
114
115 Mutex::new(compiled)
116 })
117 }
118
119 // Keep original function accessible for debugging
120 #[allow(dead_code)]
121 #input_fn
122 };
123
124 TokenStream::from(expanded)
125}