use core::fmt;
use crate::alloc::string::String;
use crate::alloc::vec::Vec;
use crate::ir::function::{FunctionStencil, VersionMarker};
use crate::ir::Function;
use crate::machinst::{CompiledCode, CompiledCodeStencil};
use crate::result::CompileResult;
use crate::{isa::TargetIsa, timing};
use crate::{trace, CompileError, Context};
use alloc::borrow::{Cow, ToOwned as _};
use alloc::string::ToString as _;
impl Context {
pub fn compile_with_cache(
&mut self,
isa: &dyn TargetIsa,
cache_store: &mut dyn CacheKvStore,
) -> CompileResult<(&CompiledCode, bool)> {
let cache_key_hash = {
let _tt = timing::try_incremental_cache();
let cache_key_hash = compute_cache_key(isa, &mut self.func);
if let Some(blob) = cache_store.get(&cache_key_hash.0) {
match try_finish_recompile(&self.func, &blob) {
Ok(compiled_code) => {
let info = compiled_code.code_info();
if isa.flags().enable_incremental_compilation_cache_checks() {
let actual_result = self.compile(isa)?;
assert_eq!(*actual_result, compiled_code);
assert_eq!(actual_result.code_info(), info);
return Ok((actual_result, true));
}
let compiled_code = self.compiled_code.insert(compiled_code);
return Ok((compiled_code, true));
}
Err(err) => {
trace!("error when finishing recompilation: {err}");
}
}
}
cache_key_hash
};
let stencil = self.compile_stencil(isa).map_err(|err| CompileError {
inner: err,
func: &self.func,
})?;
let stencil = {
let _tt = timing::store_incremental_cache();
let (stencil, res) = serialize_compiled(stencil);
if let Ok(blob) = res {
cache_store.insert(&cache_key_hash.0, blob);
}
stencil
};
let compiled_code = self
.compiled_code
.insert(stencil.apply_params(&self.func.params));
Ok((compiled_code, false))
}
}
/// A byte-oriented key/value store backing the incremental compilation cache.
///
/// Keys are the 32-byte digests produced by `compute_cache_key`. Values are opaque blobs
/// produced by `serialize_compiled`. `compile_with_cache` treats a missing entry, or one it
/// cannot decode, as a cache miss.
pub trait CacheKvStore {
    /// Returns the blob stored under `key`, or `None` if there is no entry.
    ///
    /// The returned data may either borrow from the store or be an owned buffer.
    fn get(&self, key: &[u8]) -> Option<Cow<'_, [u8]>>;

    /// Stores `val` under `key`.
    ///
    /// This is also called for a key whose existing entry could not be decoded, so
    /// implementations are presumably expected to overwrite existing values — confirm.
    fn insert(&mut self, key: &[u8], val: Vec<u8>);
}
/// SHA-256 digest of a function's cache key, used as the lookup key in a [`CacheKvStore`].
///
/// Produced by `compute_cache_key`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct CacheKeyHash([u8; 32]);

impl fmt::Display for CacheKeyHash {
    /// Formats as `CacheKeyHash:` followed by the digest bytes in array (`Debug`) form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CacheKeyHash:{:?}", self.0)
    }
}
/// Serialized form of a cache entry: the compiled stencil plus a version marker.
///
/// `try_finish_recompile` compares `version_marker` against the current function's marker
/// and rejects mismatching entries with `RecompileError::VersionMismatch`.
// NOTE: entries are encoded with bincode (see `serialize_compiled`), which encodes fields
// positionally; reordering or changing these fields changes the on-store format.
#[derive(serde::Serialize, serde::Deserialize)]
struct CachedFunc {
    // The compiled code before `apply_params` is applied.
    stencil: CompiledCodeStencil,
    // Checked on load; see the struct-level doc.
    version_marker: VersionMarker,
}
/// The inputs hashed to identify a compilation: the function stencil and the ISA compile
/// parameters.
///
/// This value is only ever hashed, never stored; `compute_cache_key` turns it into a
/// `CacheKeyHash`. `Function::params` are not part of the key. The current function's
/// params are applied to the stencil via `apply_params` on both the hit and the miss path.
// NOTE: `#[derive(Hash)]` hashes fields in declaration order, so reordering them changes
// every digest and invalidates existing cache entries.
#[derive(Hash)]
struct CacheKey<'a> {
    // Borrowed from the function after its layout has been renumbered (see `CacheKey::new`).
    stencil: &'a FunctionStencil,
    // String snapshot of the target ISA configuration.
    parameters: CompileParameters,
}
/// String snapshot of the target ISA configuration, included in the cache key.
// NOTE: the derived `Hash` is order-dependent; reordering fields changes cache keys.
#[derive(Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
struct CompileParameters {
    // `TargetIsa::name()`.
    isa: String,
    // The target triple, rendered with `to_string`.
    triple: String,
    // The shared compiler flags, rendered with `to_string`.
    flags: String,
    // The value string of each ISA-specific flag, in the order `isa_flags()` returns them;
    // flag names are not recorded.
    isa_flags: Vec<String>,
}
impl CompileParameters {
    /// Captures `isa`'s name, target triple, shared flags and ISA-specific flag values as
    /// strings, so they can be hashed into the cache key.
    fn from_isa(isa: &dyn TargetIsa) -> Self {
        let isa_name = isa.name().to_owned();
        let triple = isa.triple().to_string();
        let flags = isa.flags().to_string();
        // Only each flag's value string is recorded, not its name.
        let isa_flags = isa
            .isa_flags()
            .into_iter()
            .map(|flag| flag.value_string())
            .collect();

        Self {
            isa: isa_name,
            triple,
            flags,
            isa_flags,
        }
    }
}
impl<'a> CacheKey<'a> {
    /// Builds the hashable key for compiling `f` with `isa`.
    ///
    /// Side effect: `f`'s layout is fully renumbered before the stencil is borrowed.
    /// Presumably this is so that equivalent functions hash identically — confirm.
    fn new(isa: &dyn TargetIsa, f: &'a mut Function) -> Self {
        let layout = &mut f.stencil.layout;
        layout.full_renumber();

        let parameters = CompileParameters::from_isa(isa);
        Self {
            stencil: &f.stencil,
            parameters,
        }
    }
}
/// Computes the SHA-256 digest identifying `func` compiled for `isa`.
///
/// The digest covers the function stencil and the ISA compile parameters. It does not cover
/// `func.params`, which are applied separately.
///
/// Note that this mutates `func`: its layout is fully renumbered before hashing.
pub fn compute_cache_key(isa: &dyn TargetIsa, func: &mut Function) -> CacheKeyHash {
    use core::hash::{Hash as _, Hasher};
    use sha2::Digest as _;

    // Lets any `Hash` implementor feed its bytes straight into a SHA-256 digest.
    struct DigestAdapter(sha2::Sha256);

    impl Hasher for DigestAdapter {
        fn write(&mut self, data: &[u8]) {
            self.0.update(data);
        }

        fn finish(&self) -> u64 {
            // A 64-bit result is meaningless here; the full digest is read via `finalize`.
            panic!("Sha256Hasher doesn't support finish!");
        }
    }

    let mut adapter = DigestAdapter(sha2::Sha256::new());
    CacheKey::new(isa, func).hash(&mut adapter);

    let digest: [u8; 32] = adapter.0.finalize().into();
    CacheKeyHash(digest)
}
/// Serializes `result` together with the current version marker using bincode.
///
/// The stencil is handed back to the caller whether or not serialization succeeded, so a
/// failed cache write never loses the compiled code.
pub fn serialize_compiled(
    result: CompiledCodeStencil,
) -> (CompiledCodeStencil, Result<Vec<u8>, bincode::Error>) {
    let entry = CachedFunc {
        stencil: result,
        version_marker: VersionMarker,
    };
    let encoded = bincode::serialize(&entry);

    let CachedFunc { stencil, .. } = entry;
    (stencil, encoded)
}
/// Reasons why a cached blob could not be turned back into compiled code.
#[derive(Debug)]
pub enum RecompileError {
    /// The entry's version marker differs from the current function's marker.
    VersionMismatch,
    /// bincode could not decode the blob.
    Deserialize(bincode::Error),
}

impl fmt::Display for RecompileError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::VersionMismatch => f.write_str("cranelift version mismatch"),
            Self::Deserialize(inner) => {
                write!(f, "bincode failed during deserialization: {inner}")
            }
        }
    }
}
/// Decodes a blob written by `serialize_compiled` and finishes it for `func`.
///
/// If the blob's version marker matches `func`'s, `func.params` are applied to the cached
/// stencil and the resulting compiled code is returned.
///
/// # Errors
///
/// - [`RecompileError::Deserialize`] if bincode cannot decode `bytes`.
/// - [`RecompileError::VersionMismatch`] if the version markers differ.
pub fn try_finish_recompile(func: &Function, bytes: &[u8]) -> Result<CompiledCode, RecompileError> {
    let cached = bincode::deserialize::<CachedFunc>(bytes).map_err(RecompileError::Deserialize)?;

    if cached.version_marker == func.stencil.version_marker {
        Ok(cached.stencil.apply_params(&func.params))
    } else {
        Err(RecompileError::VersionMismatch)
    }
}