Skip to main content

rlx_runtime/
flexible_session.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// Licensed under the GNU General Public License, version 3.
5
6//! Session that defers backend choice until compile time.
7
8use rlx_driver::Device;
9use rlx_ir::{Graph, GraphModule, HirModule};
10use rlx_opt::PrecisionPolicy;
11
12use crate::compiled::CompiledGraph;
13use crate::device_policy::{DevicePolicy, resolve_device};
14use crate::precision::Precision;
15use crate::session::Session;
16
17/// Compile-time settings without a fixed [`Device`].
18///
19/// Pick the backend per graph via [`Self::compile_resolved`] or
20/// [`Self::compile_on`].
21pub struct FlexibleSession {
22    device_policy: DevicePolicy,
23    precision: Precision,
24    op_policy: Option<PrecisionPolicy>,
25}
26
27impl Default for FlexibleSession {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl FlexibleSession {
34    pub fn new() -> Self {
35        Self {
36            device_policy: DevicePolicy::default(),
37            precision: Precision::F32,
38            op_policy: None,
39        }
40    }
41
42    pub fn from_env() -> Self {
43        Self {
44            device_policy: DevicePolicy::from_env(),
45            ..Self::new()
46        }
47    }
48
49    pub fn with_device_policy(mut self, policy: DevicePolicy) -> Self {
50        self.device_policy = policy;
51        self
52    }
53
54    pub fn with_precision(mut self, precision: Precision) -> Self {
55        self.precision = precision;
56        self
57    }
58
59    pub fn with_op_policy(mut self, policy: PrecisionPolicy) -> Self {
60        self.op_policy = Some(policy);
61        self
62    }
63
64    pub fn device_policy(&self) -> &DevicePolicy {
65        &self.device_policy
66    }
67
68    pub fn precision(&self) -> Precision {
69        self.precision
70    }
71
72    fn session_on(&self, device: Device) -> Session {
73        let mut s = Session::new_with_precision(device, self.precision);
74        if let Some(p) = &self.op_policy {
75            s = s.with_policy(p.clone());
76        }
77        s
78    }
79
80    pub fn compile_on(&self, graph: Graph, device: Device) -> Result<CompiledGraph, String> {
81        Ok(self.session_on(device).compile(graph))
82    }
83
84    pub fn compile_with_on(
85        &self,
86        graph: Graph,
87        device: Device,
88        options: &crate::CompileOptions,
89    ) -> Result<CompiledGraph, String> {
90        Ok(self.session_on(device).compile_with(graph, options))
91    }
92
93    pub fn compile_resolved(
94        &self,
95        graph: Graph,
96        hint: Option<Device>,
97    ) -> Result<CompiledGraph, String> {
98        let device = resolve_device(&graph, hint, &self.device_policy)?;
99        self.compile_on(graph, device)
100    }
101
102    pub fn compile_resolved_with(
103        &self,
104        graph: Graph,
105        hint: Option<Device>,
106        options: &crate::CompileOptions,
107    ) -> Result<CompiledGraph, String> {
108        let device = resolve_device(&graph, hint, &self.device_policy)?;
109        self.compile_with_on(graph, device, options)
110    }
111}
112
113impl FlexibleSession {
114    pub fn compile_hir_on(
115        &self,
116        hir: HirModule,
117        device: Device,
118    ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
119        self.session_on(device).compile_hir(hir)
120    }
121
122    pub fn compile_module_on(
123        &self,
124        module: GraphModule,
125        device: Device,
126    ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
127        self.session_on(device).compile_module(module)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Shape used by every tensor in these tests: two `F32` elements.
    fn two_f32() -> Shape {
        Shape::new(&[2], DType::F32)
    }

    /// A session whose device policy only admits [`Device::Cpu`].
    fn cpu_only() -> FlexibleSession {
        FlexibleSession::new().with_device_policy(DevicePolicy::only([Device::Cpu]))
    }

    #[test]
    fn compile_resolved_picks_cpu() {
        let mut graph = Graph::new("id");
        let input = graph.input("x", two_f32());
        graph.set_outputs(vec![input]);

        let compiled = cpu_only().compile_resolved(graph, None).expect("compile");
        assert_eq!(compiled.device(), Device::Cpu);
    }

    #[test]
    fn compile_resolved_with_matches_compile() {
        let mut graph = Graph::new("id");
        let lhs = graph.input("x", two_f32());
        let rhs = graph.input("y", two_f32());
        let sum = graph.add_node(
            rlx_ir::Op::Binary(rlx_ir::op::BinaryOp::Add),
            vec![lhs, rhs],
            two_f32(),
        );
        graph.set_outputs(vec![sum]);

        let session = cpu_only();
        let plain = session
            .compile_resolved(graph.clone(), None)
            .expect("compile");
        let with_options = session
            .compile_resolved_with(graph, None, &crate::CompileOptions::new())
            .expect("compile_with");
        assert_eq!(plain.device(), with_options.device());
    }
}