Skip to main content

rlx_runtime/
aot_cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7
8//! AOT cache — persist optimized LIR modules and reload for backend compile.
9
10use std::fmt;
11use std::fs;
12use std::io;
13use std::path::PathBuf;
14
15use rlx_ir::DimBinding;
16use rlx_ir::Graph;
17use rlx_ir::LirFingerprint;
18use rlx_ir::LirModule;
19use rlx_ir::hir::HirModule;
20use rlx_opt::CompileResult;
21
22use crate::stages;
23use crate::{CompileOptions, CompiledGraph, Device};
24
25/// Errors from [`AotCache`] disk / compile operations.
26#[derive(Debug)]
27pub enum AotCacheError {
28    Io(io::Error),
29    Serde(String),
30    Lower(rlx_ir::hir::LowerError),
31}
32
33impl fmt::Display for AotCacheError {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            Self::Io(e) => write!(f, "{e}"),
37            Self::Serde(e) => write!(f, "serde: {e}"),
38            Self::Lower(e) => write!(f, "{e}"),
39        }
40    }
41}
42
43impl std::error::Error for AotCacheError {}
44
45impl From<io::Error> for AotCacheError {
46    fn from(e: io::Error) -> Self {
47        Self::Io(e)
48    }
49}
50
51impl From<rlx_ir::hir::LowerError> for AotCacheError {
52    fn from(e: rlx_ir::hir::LowerError) -> Self {
53        Self::Lower(e)
54    }
55}
56
/// On-disk AOT cache for optimized LIR modules.
///
/// Each entry is stored under `root` as `<key>.lir.json` (the serialized
/// module) next to `<key>.meta.json` (its fingerprint). The value is only a
/// handle to a directory path, so it is cheap to clone and safe to print.
#[derive(Debug, Clone)]
pub struct AotCache {
    /// Directory that holds every cache file.
    root: PathBuf,
}
61
62impl AotCache {
63    pub fn new(root: impl Into<PathBuf>) -> Self {
64        Self { root: root.into() }
65    }
66
67    fn lir_path(&self, key: &str) -> PathBuf {
68        self.root.join(format!("{key}.lir.json"))
69    }
70
71    fn meta_path(&self, key: &str) -> PathBuf {
72        self.root.join(format!("{key}.meta.json"))
73    }
74
75    /// Persist an optimized LIR module. Returns its compile fingerprint.
76    pub fn put_lir(&self, key: &str, lir: &LirModule) -> io::Result<LirFingerprint> {
77        fs::create_dir_all(&self.root)?;
78        let fp = LirFingerprint::of(lir);
79        let json = rlx_ir::lir_to_json(lir)
80            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
81        fs::write(self.lir_path(key), json)?;
82        fs::write(
83            self.meta_path(key),
84            format!("{{\"fingerprint\":{}}}\n", fp.0),
85        )?;
86        Ok(fp)
87    }
88
89    /// Load a previously stored LIR module.
90    pub fn get_lir(&self, key: &str) -> io::Result<LirModule> {
91        let json = fs::read_to_string(self.lir_path(key))?;
92        rlx_ir::lir_from_json(&json)
93            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))
94    }
95
96    pub fn contains(&self, key: &str) -> bool {
97        self.lir_path(key).is_file()
98    }
99
100    /// MIR graph → fusion pipeline → cached LIR → backend executable.
101    ///
102    /// On a cache hit only the backend compile runs (fusion / vmap /
103    /// autodiff already done). Keys must be unique per graph + options
104    /// fingerprint the caller encodes in `key`.
105    pub fn compile_graph_cached(
106        &self,
107        key: &str,
108        device: Device,
109        graph: Graph,
110        options: &CompileOptions,
111    ) -> Result<CompiledGraph, AotCacheError> {
112        if self.contains(key) {
113            let lir = self.get_lir(key)?;
114            return Ok(self.compile_lir(device, lir, options));
115        }
116        let result = stages::compile_graph_stages(device, graph, options);
117        stages::maybe_log_fusion(&result.fusion);
118        self.put_lir(key, &result.lir)?;
119        Ok(self.compile_lir(device, result.lir, options))
120    }
121
122    /// Compile HIR through the fusion pipeline, cache LIR, return executable.
123    pub fn compile_hir_cached(
124        &self,
125        key: &str,
126        device: Device,
127        hir: HirModule,
128        options: &CompileOptions,
129    ) -> Result<CompiledGraph, AotCacheError> {
130        if self.contains(key) {
131            let lir = self.get_lir(key)?;
132            return Ok(self.compile_lir(device, lir, options));
133        }
134        let result = stages::compile_hir_stages(device, hir, options)?;
135        stages::maybe_log_fusion(&result.fusion);
136        self.put_lir(key, &result.lir)?;
137        Ok(self.compile_lir(device, result.lir, options))
138    }
139
140    /// Specialize a cached dynamic LIR template and persist the bound variant.
141    pub fn specialize_cached(
142        &self,
143        base_key: &str,
144        binding: &DimBinding,
145        device: Device,
146        template: &CompileResult,
147        options: &CompileOptions,
148    ) -> Result<CompiledGraph, AotCacheError> {
149        let spec_key = format!("{base_key}__{}", binding_hash(binding));
150        if self.contains(&spec_key) {
151            let lir = self.get_lir(&spec_key)?;
152            return Ok(self.compile_lir(device, lir, options));
153        }
154        let pipe = stages::pipeline_for(device, options);
155        let specialized = template.specialize(&pipe, binding);
156        self.put_lir(&spec_key, &specialized.lir)?;
157        Ok(self.compile_lir(device, specialized.lir, options))
158    }
159
160    fn compile_lir(
161        &self,
162        device: Device,
163        lir: LirModule,
164        options: &CompileOptions,
165    ) -> CompiledGraph {
166        let backend = crate::registry::backend_for(device).expect("backend registered");
167        let executable = backend.compile_lir(lir, options);
168        CompiledGraph::new(executable, device)
169    }
170}
171
172fn binding_hash(binding: &DimBinding) -> u64 {
173    use std::collections::hash_map::DefaultHasher;
174    use std::hash::{Hash, Hasher};
175    let mut h = DefaultHasher::new();
176    for (sym, size) in binding.iter() {
177        sym.hash(&mut h);
178        size.hash(&mut h);
179    }
180    h.finish()
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Compiling a tiny one-layer HIR module through the cache must leave a
    /// LIR entry on disk that can be loaded back under the same key.
    #[test]
    fn aot_lir_roundtrip_on_disk() {
        // Per-process scratch directory so parallel test binaries do not share it.
        let scratch = std::env::temp_dir().join(format!("rlx_aot_{}", std::process::id()));
        let cache = AotCache::new(&scratch);
        let options = CompileOptions::new();

        // x[1,4] · w[4,2] -> y[1,2]
        let mut module = HirModule::new("aot");
        let input = module.input("x", Shape::new(&[1, 4], DType::F32));
        let weight = module.param("w", Shape::new(&[4, 2], DType::F32));
        let output = module.linear(input, weight, None, None, Shape::new(&[1, 2], DType::F32));
        module.set_outputs(vec![output]);

        let _compiled = cache
            .compile_hir_cached("tiny", Device::Cpu, module, &options)
            .expect("compile + cache");

        assert!(cache.contains("tiny"));
        let reloaded = cache.get_lir("tiny").expect("reload LIR");
        assert_eq!(reloaded.name(), "aot");

        // Best-effort cleanup; a leftover directory must not fail the test.
        fs::remove_dir_all(&scratch).ok();
    }
}