1use std::fmt;
11use std::fs;
12use std::io;
13use std::path::PathBuf;
14
15use rlx_ir::DimBinding;
16use rlx_ir::Graph;
17use rlx_ir::LirFingerprint;
18use rlx_ir::LirModule;
19use rlx_ir::hir::HirModule;
20use rlx_opt::CompileResult;
21
22use crate::stages;
23use crate::{CompileOptions, CompiledGraph, Device};
24
25#[derive(Debug)]
27pub enum AotCacheError {
28 Io(io::Error),
29 Serde(String),
30 Lower(rlx_ir::hir::LowerError),
31}
32
33impl fmt::Display for AotCacheError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 Self::Io(e) => write!(f, "{e}"),
37 Self::Serde(e) => write!(f, "serde: {e}"),
38 Self::Lower(e) => write!(f, "{e}"),
39 }
40 }
41}
42
43impl std::error::Error for AotCacheError {}
44
45impl From<io::Error> for AotCacheError {
46 fn from(e: io::Error) -> Self {
47 Self::Io(e)
48 }
49}
50
51impl From<rlx_ir::hir::LowerError> for AotCacheError {
52 fn from(e: rlx_ir::hir::LowerError) -> Self {
53 Self::Lower(e)
54 }
55}
56
/// Directory-backed cache of lowered LIR modules, addressed by caller-chosen
/// string keys.
///
/// Holds nothing but the root path, so it is cheap to clone and safe to
/// debug-print.
#[derive(Debug, Clone)]
pub struct AotCache {
    /// Directory holding the `{key}.lir.json` / `{key}.meta.json` entry files.
    root: PathBuf,
}
61
62impl AotCache {
63 pub fn new(root: impl Into<PathBuf>) -> Self {
64 Self { root: root.into() }
65 }
66
67 fn lir_path(&self, key: &str) -> PathBuf {
68 self.root.join(format!("{key}.lir.json"))
69 }
70
71 fn meta_path(&self, key: &str) -> PathBuf {
72 self.root.join(format!("{key}.meta.json"))
73 }
74
75 pub fn put_lir(&self, key: &str, lir: &LirModule) -> io::Result<LirFingerprint> {
77 fs::create_dir_all(&self.root)?;
78 let fp = LirFingerprint::of(lir);
79 let json = rlx_ir::lir_to_json(lir)
80 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
81 fs::write(self.lir_path(key), json)?;
82 fs::write(
83 self.meta_path(key),
84 format!("{{\"fingerprint\":{}}}\n", fp.0),
85 )?;
86 Ok(fp)
87 }
88
89 pub fn get_lir(&self, key: &str) -> io::Result<LirModule> {
91 let json = fs::read_to_string(self.lir_path(key))?;
92 rlx_ir::lir_from_json(&json)
93 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))
94 }
95
96 pub fn contains(&self, key: &str) -> bool {
97 self.lir_path(key).is_file()
98 }
99
100 pub fn compile_graph_cached(
106 &self,
107 key: &str,
108 device: Device,
109 graph: Graph,
110 options: &CompileOptions,
111 ) -> Result<CompiledGraph, AotCacheError> {
112 if self.contains(key) {
113 let lir = self.get_lir(key)?;
114 return Ok(self.compile_lir(device, lir, options));
115 }
116 let result = stages::compile_graph_stages(device, graph, options);
117 stages::maybe_log_fusion(&result.fusion);
118 self.put_lir(key, &result.lir)?;
119 Ok(self.compile_lir(device, result.lir, options))
120 }
121
122 pub fn compile_hir_cached(
124 &self,
125 key: &str,
126 device: Device,
127 hir: HirModule,
128 options: &CompileOptions,
129 ) -> Result<CompiledGraph, AotCacheError> {
130 if self.contains(key) {
131 let lir = self.get_lir(key)?;
132 return Ok(self.compile_lir(device, lir, options));
133 }
134 let result = stages::compile_hir_stages(device, hir, options)?;
135 stages::maybe_log_fusion(&result.fusion);
136 self.put_lir(key, &result.lir)?;
137 Ok(self.compile_lir(device, result.lir, options))
138 }
139
140 pub fn specialize_cached(
142 &self,
143 base_key: &str,
144 binding: &DimBinding,
145 device: Device,
146 template: &CompileResult,
147 options: &CompileOptions,
148 ) -> Result<CompiledGraph, AotCacheError> {
149 let spec_key = format!("{base_key}__{}", binding_hash(binding));
150 if self.contains(&spec_key) {
151 let lir = self.get_lir(&spec_key)?;
152 return Ok(self.compile_lir(device, lir, options));
153 }
154 let pipe = stages::pipeline_for(device, options);
155 let specialized = template.specialize(&pipe, binding);
156 self.put_lir(&spec_key, &specialized.lir)?;
157 Ok(self.compile_lir(device, specialized.lir, options))
158 }
159
160 fn compile_lir(
161 &self,
162 device: Device,
163 lir: LirModule,
164 options: &CompileOptions,
165 ) -> CompiledGraph {
166 let backend = crate::registry::backend_for(device).expect("backend registered");
167 let executable = backend.compile_lir(lir, options);
168 CompiledGraph::new(executable, device)
169 }
170}
171
172fn binding_hash(binding: &DimBinding) -> u64 {
173 use std::collections::hash_map::DefaultHasher;
174 use std::hash::{Hash, Hasher};
175 let mut h = DefaultHasher::new();
176 for (sym, size) in binding.iter() {
177 sym.hash(&mut h);
178 size.hash(&mut h);
179 }
180 h.finish()
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Shape;

    /// Builds a one-layer `linear(x, w)` HIR module named `name`.
    fn tiny_hir(name: &'static str) -> HirModule {
        let mut hir = HirModule::new(name);
        let x = hir.input("x", Shape::new(&[1, 4], DType::F32));
        let w = hir.param("w", Shape::new(&[4, 2], DType::F32));
        let y = hir.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
        hir.set_outputs(vec![y]);
        hir
    }

    #[test]
    fn aot_lir_roundtrip_on_disk() {
        let dir = std::env::temp_dir().join(format!("rlx_aot_roundtrip_{}", std::process::id()));
        // Start empty: a leftover directory from an earlier run that happened to
        // get the same PID would otherwise turn the first compile into a hit.
        fs::remove_dir_all(&dir).ok();
        let cache = AotCache::new(&dir);
        let opts = CompileOptions::new();

        // Miss: compiles and populates the cache.
        let _compiled = cache
            .compile_hir_cached("tiny", Device::Cpu, tiny_hir("aot"), &opts)
            .expect("compile + cache");
        assert!(cache.contains("tiny"));
        let lir = cache.get_lir("tiny").expect("reload LIR");
        assert_eq!(lir.name(), "aot");

        // Hit: a different module under the same key must be served from disk
        // without recompiling, so the stored LIR keeps its original name.
        let _cached = cache
            .compile_hir_cached("tiny", Device::Cpu, tiny_hir("other"), &opts)
            .expect("cache hit");
        let reloaded = cache.get_lir("tiny").expect("reload LIR after hit");
        assert_eq!(reloaded.name(), "aot");

        fs::remove_dir_all(&dir).ok();
    }
}
207}