Skip to main content

rlx_ir/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX Tensor IR — the intermediate representation for the RLX ML compiler.
17//!
18//! This IR is:
19//! - **Standalone**: no runtime, no backend, no framework coupling
20//! - **Serializable**: graphs can be saved/loaded for AOT compilation
21//! - **Optimizable**: designed for pattern-matching fusion and buffer planning
22//!
23//! The compiler pipeline has three named levels:
24//!
25//! - **HIR** ([`hir`]) — block-oriented IR for model builders (`Linear`,
26//!   `SwiGLU`, `ResidualRmsNorm`, …).
27//! - **MIR** ([`mir`]) — fused tensor DAG; input to [`rlx_opt`].
28//! - **LIR** ([`lir`]) — optimized MIR + arena buffer plan for backends.
29//!
30//! [`Graph`] is the primary DX surface. Use [`Graph::define`] for
31//! fusion-first HIR builders, or [`Graph::new`] / [`GraphModule::mir`]
32//! for primitive MIR. [`GraphModule`] tracks pipeline stage (HIR/MIR/LIR).
33//!
34//! - [`Graph`]: a DAG of tensor operations (like XLA's HloModule)
35//! - [`Node`]: a single operation with typed inputs/outputs
36//! - [`Op`]: the operation kind with parameters
37
38pub mod ad;
39pub mod async_copy;
40pub mod attention_layout;
41pub mod audio;
42pub mod const_check;
43pub mod dtype;
44pub mod dynamic;
45pub mod env;
46pub mod fft;
47pub mod graph;
48pub mod hir;
49pub mod infer;
50pub mod infer_shape;
51pub mod inspect;
52pub mod layout;
53pub mod lir;
54pub mod logical_kernel;
55pub mod measure;
56pub mod mir;
57pub mod module;
58pub mod nvfp4;
59pub mod op;
60pub mod op_registry;
61pub mod ops;
62pub mod perfetto;
63pub mod phase;
64pub mod pretty;
65pub mod provenance;
66pub mod quant;
67pub mod region_encode;
68pub mod rng;
69pub use nvfp4::{FP4_E2M1_LUT, NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
70pub mod binding_manifest;
71pub mod component;
72pub mod hir_extension;
73pub mod reflect;
74pub mod rf;
75#[cfg(feature = "serialize")]
76pub mod serialize;
77pub mod shape;
78pub mod target;
79pub mod variant;
80pub mod verify;
81
82pub use ad::AdPipelineStage;
83pub use async_copy::{AsyncCopy, BarrierToken, DoubleBuffer, SyncCopy};
84pub use attention_layout::{
85    ATTENTION_FLASH_MAX_HEAD_DIM, AttentionGeom, AttentionLaunchStrides,
86    attention_dispatch_use_row, attention_geom, attention_launch_strides, cpu_attention_bshd,
87    cpu_attention_packed_bshd_qkv, detect_packed_bshd_qkv_attention, mask_strides_bhsd,
88    mask_strides_for_shape, packed_bshd_narrow_elidable, packed_bshd_qkv_strides, strides_bhsd,
89    strides_bshd, strides_for_shape,
90};
91pub use dtype::{DType, Element, ElementSubtype, scalar_constant_bytes};
92pub use dynamic::sym;
93pub use dynamic::{
94    DimEnv, bind_graph, collect_dynamic_symbols, has_dynamic_dims, infer_bindings_from_f32_inputs,
95    infer_bindings_from_inputs, same_binding, sync_concat_shapes, sync_expand_ops,
96    sync_graph_shapes, sync_narrow_ops, sync_reshape_ops,
97};
98pub use env::{RlxEnv, RuntimeOverrides, flag, is_unset, parse_or, set, unset, var, var_os};
99pub use fft::{FftGpuPlan, FftMeta, FftNorm, fft_meta, fftn_axes_all, normalize_fftn_axes};
100pub use graph::{Graph, Node, NodeId};
101pub use hir::{FusionPolicy, HirGraphExt, HirModule, HirMut, HirNode, HirNodeId, HirOp};
102pub use infer::GraphExt;
103pub use inspect::{
104    inspect_buffer_plan, inspect_graph, inspect_graph_diff, inspect_hir, inspect_hir_stats,
105    inspect_lir, inspect_mir, inspect_mir_diff, inspect_mir_stats,
106};
107pub use layout::{Coord2, Ragged, ShapeTuple, Strides2, Strides3, Tile2, Tile3};
108pub use lir::{
109    LirBufferPlan, LirBufferSlot, LirFingerprint, LirIoManifest, LirModule, LirViewAlias,
110};
111pub use logical_kernel::{
112    KernelDispatchConfig, KernelDispatchPolicy, LogicalKernelEntry, logical_kinds_in_graph,
113    registered_logical_kernels, should_lower_to_common,
114};
115pub use measure::{CacheBuster, Tick, time_ns};
116pub use mir::{MirModule, MirNode, MirNodeId, MirOp};
117pub use module::{GraphModule, GraphStage};
118pub use op::{ChainOperand, ChainStep, Op, OpKind, RegionPrologue, TransformStep};
119pub use op_registry::{
120    JvpContext, OpExtension, OpRegistry, VjpContext, VmapContext, global_registry, lookup_op,
121    register_op,
122};
123pub use ops::attention::attention_kind_op;
124pub use phase::{Phase, PhaseSchedule, derive_phases};
125pub use provenance::{NodeOrigin, node_label, stamp_pass_origins};
126pub use quant::{QuantMap, QuantScheme};
127pub use region_encode::{
128    FK_BATCH_SINGLE_KERNEL_MAX, PrologueLaunchGrid, REGION_META_WORDS, REGION_PROLOGUE_NONE,
129    REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW, RegionNchwDims, batch_region_slice_dst_off_f32,
130    batch_region_slice_elems, batch_region_slice_shape, encode_chain_operand, encode_chain_steps,
131    encode_elementwise_region_meta, encode_prologue_tail, fk_batch_single_kernel_enabled,
132    fk_batch_use_single_launch,
133};
134pub use rng::Philox4x32;
135#[cfg(feature = "serialize")]
136pub use serialize::{hir_from_json, hir_to_json, lir_from_json, lir_to_json};
137pub use verify::{VerifyError, verify, verify_all, verify_shapes};
138
139/// Lower a HIR module to MIR, then extract the legacy [`Graph`] API surface.
140pub fn hir_to_graph(hir: HirModule) -> Result<Graph, hir::LowerError> {
141    Ok(hir.lower_to_mir()?.into_graph())
142}
143pub use binding_manifest::{BindingManifest, IoBindingEntry, WeightBlock};
144pub use component::{CompilationMode, ModelComponent};
145pub use hir_extension::{
146    HirExtensionFn, apply_hir_extensions, apply_hir_extensions_named, register_hir_extension,
147    registered_hir_extensions,
148};
149pub use reflect::{
150    BlockSpecialization, HirReflection, ManifestDiff, MirReflection, SpecializeBlockRecord,
151    layout_for_binding, layout_from_lir, probe_block_specialization, symbolic_layout_hint,
152};
153pub use rf::{
154    complex_div, const_f32, cs_degen_z_in, find_param_node, find_param_nodes, mag2, s11_from_z,
155    scalar_f32,
156};
157pub use shape::{Dim, DimBinding, Shape};
158pub use variant::{ModelPhase, ModelVariant};