rlx_runtime/
flexible_session.rs1use rlx_driver::Device;
9use rlx_ir::{Graph, GraphModule, HirModule};
10use rlx_opt::PrecisionPolicy;
11
12use crate::compiled::CompiledGraph;
13use crate::device_policy::{DevicePolicy, resolve_device};
14use crate::precision::Precision;
15use crate::session::Session;
16
17pub struct FlexibleSession {
22 device_policy: DevicePolicy,
23 precision: Precision,
24 op_policy: Option<PrecisionPolicy>,
25}
26
27impl Default for FlexibleSession {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl FlexibleSession {
34 pub fn new() -> Self {
35 Self {
36 device_policy: DevicePolicy::default(),
37 precision: Precision::F32,
38 op_policy: None,
39 }
40 }
41
42 pub fn from_env() -> Self {
43 Self {
44 device_policy: DevicePolicy::from_env(),
45 ..Self::new()
46 }
47 }
48
49 pub fn with_device_policy(mut self, policy: DevicePolicy) -> Self {
50 self.device_policy = policy;
51 self
52 }
53
54 pub fn with_precision(mut self, precision: Precision) -> Self {
55 self.precision = precision;
56 self
57 }
58
59 pub fn with_op_policy(mut self, policy: PrecisionPolicy) -> Self {
60 self.op_policy = Some(policy);
61 self
62 }
63
64 pub fn device_policy(&self) -> &DevicePolicy {
65 &self.device_policy
66 }
67
68 pub fn precision(&self) -> Precision {
69 self.precision
70 }
71
72 fn session_on(&self, device: Device) -> Session {
73 let mut s = Session::new_with_precision(device, self.precision);
74 if let Some(p) = &self.op_policy {
75 s = s.with_policy(p.clone());
76 }
77 s
78 }
79
80 pub fn compile_on(&self, graph: Graph, device: Device) -> Result<CompiledGraph, String> {
81 Ok(self.session_on(device).compile(graph))
82 }
83
84 pub fn compile_with_on(
85 &self,
86 graph: Graph,
87 device: Device,
88 options: &crate::CompileOptions,
89 ) -> Result<CompiledGraph, String> {
90 Ok(self.session_on(device).compile_with(graph, options))
91 }
92
93 pub fn compile_resolved(
94 &self,
95 graph: Graph,
96 hint: Option<Device>,
97 ) -> Result<CompiledGraph, String> {
98 let device = resolve_device(&graph, hint, &self.device_policy)?;
99 self.compile_on(graph, device)
100 }
101
102 pub fn compile_resolved_with(
103 &self,
104 graph: Graph,
105 hint: Option<Device>,
106 options: &crate::CompileOptions,
107 ) -> Result<CompiledGraph, String> {
108 let device = resolve_device(&graph, hint, &self.device_policy)?;
109 self.compile_with_on(graph, device, options)
110 }
111}
112
113impl FlexibleSession {
114 pub fn compile_hir_on(
115 &self,
116 hir: HirModule,
117 device: Device,
118 ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
119 self.session_on(device).compile_hir(hir)
120 }
121
122 pub fn compile_module_on(
123 &self,
124 module: GraphModule,
125 device: Device,
126 ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
127 self.session_on(device).compile_module(module)
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Session restricted to the CPU so device resolution is deterministic.
    fn cpu_session() -> FlexibleSession {
        FlexibleSession::new().with_device_policy(DevicePolicy::only([Device::Cpu]))
    }

    /// Two-element f32 shape used by every tensor in these tests.
    fn vec2() -> Shape {
        Shape::new(&[2], DType::F32)
    }

    #[test]
    fn compile_resolved_picks_cpu() {
        let mut graph = Graph::new("id");
        let input = graph.input("x", vec2());
        graph.set_outputs(vec![input]);

        let session = cpu_session();
        let compiled = session.compile_resolved(graph, None).expect("compile");
        assert_eq!(compiled.device(), Device::Cpu);
    }

    #[test]
    fn compile_resolved_with_matches_compile() {
        let mut graph = Graph::new("id");
        let lhs = graph.input("x", vec2());
        let rhs = graph.input("y", vec2());
        let sum = graph.add_node(
            rlx_ir::Op::Binary(rlx_ir::op::BinaryOp::Add),
            vec![lhs, rhs],
            vec2(),
        );
        graph.set_outputs(vec![sum]);

        let session = cpu_session();
        let plain = session
            .compile_resolved(graph.clone(), None)
            .expect("compile");
        let with_options = session
            .compile_resolved_with(graph, None, &crate::CompileOptions::new())
            .expect("compile_with");
        assert_eq!(plain.device(), with_options.device());
    }
}