Skip to main content

rlx_runtime/
session.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Session — the main entry point for compiling and executing graphs.

18use crate::backend::Backend;
19use crate::compiled::CompiledGraph;
20use crate::precision::Precision;
21use rlx_driver::Device;
22use rlx_ir::Graph;
23use rlx_ir::GraphModule;
24use rlx_ir::hir::HirModule;
25use rlx_opt::PrecisionPolicy;
26
27/// A session manages graph compilation and execution on a device.
28pub struct Session {
29    device: Device,
30    precision: Precision,
31    /// Optional per-op precision policy. If set, runs AutoMixedPrecision
32    /// rewrite before backend compile. Works identically across all modes
33    /// (AOT compile, trace/JIT, proc-macro AOT) — it's just a graph pass.
34    policy: Option<PrecisionPolicy>,
35}
36
37impl Session {
38    /// Create a session for the given device at default (F32) precision.
39    ///
40    /// # Panics
41    /// Panics if the device is not available (missing feature flag).
42    pub fn new(device: Device) -> Self {
43        Self::new_with_precision(device, Precision::F32)
44    }
45
46    /// Create a session targeting a specific numeric precision.
47    /// Backends fall back to F32 if the requested precision isn't supported.
48    pub fn new_with_precision(device: Device, precision: Precision) -> Self {
49        assert!(
50            crate::device_ext::is_available(device),
51            "device {} is not available — enable the `{}` Cargo feature",
52            device,
53            feature_name(device)
54        );
55        Self {
56            device,
57            precision,
58            policy: None,
59        }
60    }
61
62    /// Builder: set a per-op precision policy. Applied as a graph rewrite
63    /// before backend compile. Same mechanism works for AOT compile, JIT
64    /// tracing, and proc-macro AOT — it's a graph pass, not a runtime mode.
65    pub fn with_policy(mut self, policy: PrecisionPolicy) -> Self {
66        self.policy = Some(policy);
67        self
68    }
69
70    pub fn device(&self) -> Device {
71        self.device
72    }
73    pub fn precision(&self) -> Precision {
74        self.precision
75    }
76    pub fn policy(&self) -> Option<&PrecisionPolicy> {
77        self.policy.as_ref()
78    }
79
80    /// Compile a MIR graph through the fusion-first pipeline (`GraphModule` → LIR).
81    ///
82    /// Prefer [`Self::compile_hir`] or [`Self::compile_module`] for new code.
83    /// This entry wraps the graph as a MIR-stage [`GraphModule`].
84    pub fn compile(&self, graph: Graph) -> CompiledGraph {
85        self.compile_module(GraphModule::from_graph(graph))
86            .expect("compile MIR graph through fusion pipeline")
87    }
88
89    /// Explicit legacy alias — same as [`Self::compile`].
90    pub fn compile_graph(&self, graph: Graph) -> CompiledGraph {
91        self.compile(graph)
92    }
93
94    /// Compile with explicit options (full control over the pipeline).
95    /// Most callers use `compile()` and configure the session via
96    /// `new_with_precision` / `with_policy`. This escape hatch is for
97    /// callers that need finer control (e.g., disable DCE for debugging).
98    pub fn compile_with(&self, graph: Graph, options: &crate::CompileOptions) -> CompiledGraph {
99        self.compile_module_with(GraphModule::from_graph(graph), options)
100            .expect("compile MIR graph through fusion pipeline")
101    }
102
103    /// Compile a fusion-first HIR module through HIR → MIR → LIR.
104    pub fn compile_hir(&self, hir: HirModule) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
105        self.compile_hir_with(hir, &self.default_options())
106    }
107
108    /// Compile HIR with explicit compile options.
109    pub fn compile_hir_with(
110        &self,
111        hir: HirModule,
112        options: &crate::CompileOptions,
113    ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
114        let backend = self.create_backend();
115        let executable = backend.compile_hir(hir, self.device, options)?;
116        Ok(CompiledGraph::new(executable, self.device))
117    }
118
119    /// Compile a [`GraphModule`] (HIR/MIR/LIR stage) through the pipeline.
120    pub fn compile_module(
121        &self,
122        module: GraphModule,
123    ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
124        self.compile_module_with(module, &self.default_options())
125    }
126
127    /// Compile a [`GraphModule`] with explicit compile options.
128    pub fn compile_module_with(
129        &self,
130        module: GraphModule,
131        options: &crate::CompileOptions,
132    ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
133        let backend = self.create_backend();
134        let executable = backend.compile_module(module, self.device, options)?;
135        Ok(CompiledGraph::new(executable, self.device))
136    }
137
138    fn default_options(&self) -> crate::CompileOptions {
139        let opts = crate::CompileOptions::new().precision(self.precision);
140        match &self.policy {
141            Some(p) => opts.policy(p.clone()),
142            None => opts,
143        }
144    }
145
146    fn create_backend(&self) -> Box<dyn Backend> {
147        // Single dispatch point: consult the registry. Backends register
148        // themselves (builtins via cfg-gated `register_builtin`; external
149        // crates via `register_backend`). No hardcoded match here.
150        crate::registry::backend_for(self.device).unwrap_or_else(|| {
151            panic!(
152                "no backend registered for device {} — enable feature `{}` \
153                 (or call `rlx_runtime::register_backend` for an external backend)",
154                self.device,
155                feature_name(self.device)
156            )
157        })
158    }
159}
160
161fn feature_name(device: Device) -> &'static str {
162    match device {
163        Device::Cpu => "cpu",
164        Device::Metal => "metal",
165        Device::Mlx => "mlx",
166        Device::Ane => "ane",
167        Device::Cuda => "cuda",
168        Device::Rocm => "rocm",
169        Device::Tpu => "tpu",
170        Device::Gpu => "gpu",
171        Device::Vulkan => "vulkan",
172        Device::OpenGl => "opengl",
173        Device::DirectX => "directx",
174        Device::WebGpu => "webgpu",
175    }
176}