1use crate::backend::Backend;
19use crate::compiled::CompiledGraph;
20use crate::precision::Precision;
21use rlx_driver::Device;
22use rlx_ir::Graph;
23use rlx_ir::GraphModule;
24use rlx_ir::hir::HirModule;
25use rlx_opt::PrecisionPolicy;
26
27pub struct Session {
29 device: Device,
30 precision: Precision,
31 policy: Option<PrecisionPolicy>,
35}
36
37impl Session {
38 pub fn new(device: Device) -> Self {
43 Self::new_with_precision(device, Precision::F32)
44 }
45
46 pub fn new_with_precision(device: Device, precision: Precision) -> Self {
49 assert!(
50 crate::device_ext::is_available(device),
51 "device {} is not available — enable the `{}` Cargo feature",
52 device,
53 feature_name(device)
54 );
55 Self {
56 device,
57 precision,
58 policy: None,
59 }
60 }
61
62 pub fn with_policy(mut self, policy: PrecisionPolicy) -> Self {
66 self.policy = Some(policy);
67 self
68 }
69
70 pub fn device(&self) -> Device {
71 self.device
72 }
73 pub fn precision(&self) -> Precision {
74 self.precision
75 }
76 pub fn policy(&self) -> Option<&PrecisionPolicy> {
77 self.policy.as_ref()
78 }
79
80 pub fn compile(&self, graph: Graph) -> CompiledGraph {
85 self.compile_module(GraphModule::from_graph(graph))
86 .expect("compile MIR graph through fusion pipeline")
87 }
88
89 pub fn compile_graph(&self, graph: Graph) -> CompiledGraph {
91 self.compile(graph)
92 }
93
94 pub fn compile_with(&self, graph: Graph, options: &crate::CompileOptions) -> CompiledGraph {
99 self.compile_module_with(GraphModule::from_graph(graph), options)
100 .expect("compile MIR graph through fusion pipeline")
101 }
102
103 pub fn compile_hir(&self, hir: HirModule) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
105 self.compile_hir_with(hir, &self.default_options())
106 }
107
108 pub fn compile_hir_with(
110 &self,
111 hir: HirModule,
112 options: &crate::CompileOptions,
113 ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
114 let backend = self.create_backend();
115 let executable = backend.compile_hir(hir, self.device, options)?;
116 Ok(CompiledGraph::new(executable, self.device))
117 }
118
119 pub fn compile_module(
121 &self,
122 module: GraphModule,
123 ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
124 self.compile_module_with(module, &self.default_options())
125 }
126
127 pub fn compile_module_with(
129 &self,
130 module: GraphModule,
131 options: &crate::CompileOptions,
132 ) -> Result<CompiledGraph, rlx_ir::hir::LowerError> {
133 let backend = self.create_backend();
134 let executable = backend.compile_module(module, self.device, options)?;
135 Ok(CompiledGraph::new(executable, self.device))
136 }
137
138 fn default_options(&self) -> crate::CompileOptions {
139 let opts = crate::CompileOptions::new().precision(self.precision);
140 match &self.policy {
141 Some(p) => opts.policy(p.clone()),
142 None => opts,
143 }
144 }
145
146 fn create_backend(&self) -> Box<dyn Backend> {
147 crate::registry::backend_for(self.device).unwrap_or_else(|| {
151 panic!(
152 "no backend registered for device {} — enable feature `{}` \
153 (or call `rlx_runtime::register_backend` for an external backend)",
154 self.device,
155 feature_name(self.device)
156 )
157 })
158 }
159}
160
161fn feature_name(device: Device) -> &'static str {
162 match device {
163 Device::Cpu => "cpu",
164 Device::Metal => "metal",
165 Device::Mlx => "mlx",
166 Device::Ane => "ane",
167 Device::Cuda => "cuda",
168 Device::Rocm => "rocm",
169 Device::Tpu => "tpu",
170 Device::Gpu => "gpu",
171 Device::Vulkan => "vulkan",
172 Device::OpenGl => "opengl",
173 Device::DirectX => "directx",
174 Device::WebGpu => "webgpu",
175 }
176}