use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use rlx_ir::{DType, Graph, Op};
use crate::ffi::CoremlModel;
use crate::mil::{LoweredProgram, TypedParams, lower_graph};
use crate::{ChipInfo, ComputeUnits, CoremlError, Result};
/// Derives the on-disk cache key for a serialized model from its protobuf and weight blob.
///
/// Both slices are fed to the hasher with their length prefixes, so moving bytes across the
/// proto/blob boundary yields a different key. The result is the 64-bit digest rendered as
/// 16 zero-padded lowercase hex digits.
///
/// `DefaultHasher`'s algorithm is not guaranteed stable across Rust releases, so keys may
/// change after a toolchain upgrade (which only costs a cache miss).
fn content_hash(proto: &[u8], blob: &[u8]) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    for part in [proto, blob] {
        part.hash(&mut hasher);
    }
    let digest = hasher.finish();
    format!("{:016x}", digest)
}
/// Per-process sequence number embedded (with the pid) in each temporary `.mlpackage`
/// directory name so that directories created by this process do not collide.
static PKG_COUNTER: AtomicU64 = AtomicU64::new(0);
/// A graph prepared for execution through Core ML.
///
/// Lowering and model loading are deferred until `finalize` (called implicitly by `run`
/// and `compute_plan`); changing a parameter discards the loaded model so the next call
/// rebuilds it.
pub struct CoremlExecutable {
    /// Graph to lower; the constructors have already promoted integer dtypes to F32.
    graph: Graph,
    /// F32 parameters by name (from `set_param`, and `set_param_typed` with `DType::F32`).
    params: HashMap<String, Vec<f32>>,
    /// Non-F32 parameters by name, stored as raw bytes plus their dtype.
    typed_params: TypedParams,
    /// Compute units requested when the model is loaded.
    compute_units: ComputeUnits,
    /// Lowered program from the last successful `finalize`; `None` until then.
    lowered: Option<LoweredProgram>,
    /// Loaded Core ML model; `None` until `finalize` succeeds or after a parameter change.
    model: Option<CoremlModel>,
    /// Temporary `.mlpackage` directory associated with `model`; deleted in `Drop`.
    pkg_dir: Option<PathBuf>,
}
/// Rewrites every integer-typed node in `graph` to F32.
///
/// For each node whose dtype is one of I64/I32/U32/I8/U8:
/// - a `Constant` payload is decoded from little-endian integers and re-encoded as
///   little-endian `f32` bytes;
/// - the node's shape dtype is switched to F32.
///
/// Independently of the node's own dtype, any `Cast` whose target is an integer dtype is
/// redirected to F32. Integers whose magnitude exceeds f32's exact range (2^24) are rounded
/// by the `as f32` conversion.
fn promote_int_to_f32(graph: &mut Graph) {
    /// Integer dtypes handled by this pass.
    fn is_int(dt: DType) -> bool {
        matches!(dt, DType::I64 | DType::I32 | DType::U32 | DType::I8 | DType::U8)
    }

    /// Decodes little-endian integer bytes of dtype `dt` into `f32` values.
    /// Trailing bytes that do not form a whole element are ignored.
    fn decode_as_f32(bytes: &[u8], dt: DType) -> Vec<f32> {
        match dt {
            DType::I64 => bytes
                .chunks_exact(8)
                .map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()) as f32)
                .collect(),
            DType::I32 => bytes
                .chunks_exact(4)
                .map(|chunk| i32::from_le_bytes(chunk.try_into().unwrap()) as f32)
                .collect(),
            DType::U32 => bytes
                .chunks_exact(4)
                .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()) as f32)
                .collect(),
            DType::U8 => bytes.iter().map(|&byte| f32::from(byte)).collect(),
            DType::I8 => bytes.iter().map(|&byte| f32::from(byte as i8)).collect(),
            _ => unreachable!(),
        }
    }

    for node in graph.nodes_mut() {
        let dt = node.shape.dtype();
        let promote = is_int(dt);

        // `Constant` and `Cast` are distinct variants, so at most one arm can fire.
        match &mut node.op {
            Op::Constant { data } if promote => {
                let values = decode_as_f32(data, dt);
                *data = values.iter().flat_map(|v| v.to_le_bytes()).collect();
            }
            Op::Cast { to } if is_int(*to) => *to = DType::F32,
            _ => {}
        }

        if promote {
            node.shape = node.shape.clone().with_dtype(DType::F32);
        }
    }
}
impl CoremlExecutable {
    /// Creates an executable for `graph`, choosing compute units from `RLX_COREML_UNITS`.
    ///
    /// Recognised values are `cpu` (CPU only), `gpu` (CPU + GPU) and `all`. Any other value,
    /// an unset variable, or a non-UTF-8 value selects CPU + Neural Engine.
    pub fn compile(graph: Graph) -> Self {
        let units = match std::env::var("RLX_COREML_UNITS").as_deref() {
            Ok("cpu") => ComputeUnits::CpuOnly,
            Ok("gpu") => ComputeUnits::CpuAndGpu,
            Ok("all") => ComputeUnits::All,
            _ => ComputeUnits::CpuAndNeuralEngine,
        };
        Self::compile_with_units(graph, units)
    }

    /// Creates an executable for `graph` that will be loaded with `compute_units`.
    ///
    /// Integer dtypes in the graph are promoted to F32 immediately. Lowering and model loading
    /// are deferred until [`Self::finalize`].
    pub fn compile_with_units(mut graph: Graph, compute_units: ComputeUnits) -> Self {
        promote_int_to_f32(&mut graph);
        CoremlExecutable {
            graph,
            params: HashMap::new(),
            typed_params: TypedParams::new(),
            compute_units,
            lowered: None,
            model: None,
            pkg_dir: None,
        }
    }

    /// Returns a copy with the same graph, parameters and compute units, but without the
    /// lowered program, loaded model or package directory; the copy re-lowers on first use.
    pub fn clone_for_cache(&self) -> Self {
        CoremlExecutable {
            graph: self.graph.clone(),
            params: self.params.clone(),
            typed_params: self.typed_params.clone(),
            compute_units: self.compute_units,
            lowered: None,
            model: None,
            pkg_dir: None,
        }
    }

    /// Sets (or replaces) the F32 parameter `name` and discards any loaded model.
    pub fn set_param(&mut self, name: &str, data: &[f32]) {
        self.params.insert(name.to_string(), data.to_vec());
        self.invalidate();
    }

    /// Sets the parameter `name` from raw little-endian bytes of type `dtype`.
    ///
    /// F32 data is decoded and routed through [`Self::set_param`] (trailing bytes that do not
    /// form a whole `f32` are ignored); any other dtype is stored as raw bytes. Either way the
    /// loaded model is discarded.
    pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: DType) {
        if dtype == DType::F32 {
            let floats: Vec<f32> = data
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            self.set_param(name, &floats);
            return;
        }
        self.typed_params
            .insert(name.to_string(), (data.to_vec(), dtype));
        self.invalidate();
    }

    /// Drops the loaded model and lowered program, then deletes the package directory they
    /// were loaded from, so the next `finalize` rebuilds from the current parameters.
    ///
    /// Previously only the model was dropped here; the directory stayed behind and was
    /// orphaned when the next `finalize` overwrote `pkg_dir`.
    fn invalidate(&mut self) {
        // Release the model before removing the package it was loaded from (same order as Drop).
        self.model = None;
        self.lowered = None;
        if let Some(dir) = self.pkg_dir.take() {
            // Best effort: a stray temp directory is not worth failing a parameter update.
            let _ = std::fs::remove_dir_all(dir);
        }
    }

    /// Lowers the graph and loads the Core ML model if that has not been done yet.
    ///
    /// The serialized program is hashed into a cache path under
    /// `$TMPDIR/rlx-coreml-cache/<hash>.mlmodelc`, which is handed to `CoremlModel::load`. When
    /// that path already exists the `.mlpackage` is not written (presumably `load` reuses the
    /// compiled model there — confirm in `ffi`).
    ///
    /// # Errors
    /// Propagates failures from lowering, encoding, writing the package, and loading the model.
    /// On a write or load failure the freshly created package directory is removed.
    pub fn finalize(&mut self) -> Result<()> {
        if self.model.is_some() {
            return Ok(());
        }
        let lowered = lower_graph(&self.graph, &self.params, &self.typed_params)?;
        let proto_bytes = crate::mlpackage::encode_model(&lowered.model)?;

        let key = content_hash(&proto_bytes, &lowered.blob);
        let cache_dir = std::env::temp_dir().join("rlx-coreml-cache");
        let cache_path = cache_dir.join(format!("{key}.mlmodelc"));

        // pid + per-process counter give each package directory a unique name.
        let seq = PKG_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let dir = std::env::temp_dir().join(format!(
            "rlx-coreml-{pid}-{seq}-{}.mlpackage",
            sanitize(&self.graph.name)
        ));

        if !cache_path.exists() {
            let written =
                crate::mlpackage::write_mlpackage_bytes(&proto_bytes, &lowered.blob, &dir);
            if written.is_err() {
                // Don't leave a partially written package behind.
                let _ = std::fs::remove_dir_all(&dir);
            }
            written?;
        }

        let loaded = CoremlModel::load(&dir, self.compute_units.code(), Some(cache_path.as_path()));
        if loaded.is_err() {
            // `dir` is not tracked in `pkg_dir` yet, so nothing else would ever delete it.
            let _ = std::fs::remove_dir_all(&dir);
        }
        let model = loaded?;

        self.lowered = Some(lowered);
        if let Some(stale) = self.pkg_dir.replace(dir) {
            // Defensive: a package left over from a model that has already been released.
            let _ = std::fs::remove_dir_all(stale);
        }
        self.model = Some(model);
        Ok(())
    }

    /// Runs the model, finalizing it first if needed.
    ///
    /// `inputs` pairs IR input names with their data; every lowered input must be present
    /// (extra entries are ignored). Returns one buffer per lowered output, in lowered order,
    /// each sized to that output's element count.
    ///
    /// NOTE(review): input slice lengths are not checked against `io.dims` here; presumably
    /// `predict` validates them — confirm.
    ///
    /// # Errors
    /// Fails if finalization fails, an input is missing, a feature name contains NUL, or the
    /// prediction itself fails.
    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Result<Vec<Vec<f32>>> {
        self.finalize()?;
        let lowered = self.lowered.as_ref().expect("finalized");
        let mut in_args = Vec::with_capacity(lowered.inputs.len());
        for io in &lowered.inputs {
            let data = inputs
                .iter()
                .find(|(n, _)| *n == io.ir_name)
                .map(|(_, d)| *d)
                .ok_or_else(|| CoremlError::Runtime(format!("missing input '{}'", io.ir_name)))?;
            let cname = std::ffi::CString::new(io.feature_name.as_bytes())
                .map_err(|_| CoremlError::Runtime("feature name contains NUL".into()))?;
            in_args.push((cname, io.dims.clone(), data));
        }
        let mut out_bufs: Vec<Vec<f32>> = lowered
            .outputs
            .iter()
            .map(|io| vec![0.0f32; io.numel()])
            .collect();
        let mut out_args: Vec<(std::ffi::CString, &mut [f32])> = Vec::new();
        for (io, buf) in lowered.outputs.iter().zip(out_bufs.iter_mut()) {
            let cname = std::ffi::CString::new(io.feature_name.as_bytes())
                .map_err(|_| CoremlError::Runtime("feature name contains NUL".into()))?;
            out_args.push((cname, buf.as_mut_slice()));
        }
        self.model
            .as_mut()
            .expect("finalized")
            .predict(&in_args, &mut out_args)?;
        Ok(out_bufs)
    }

    /// Finalizes if needed and returns the loaded model's `compute_plan()` result.
    ///
    /// NOTE(review): the meaning of the four values is defined by `CoremlModel::compute_plan`.
    pub fn compute_plan(&mut self) -> Result<Option<[i32; 4]>> {
        self.finalize()?;
        Ok(self.model.as_mut().expect("finalized").compute_plan())
    }

    /// Returns the crate-level chip information; it does not depend on this executable.
    pub fn chip_info(&self) -> ChipInfo {
        crate::chip_info()
    }
}
impl Drop for CoremlExecutable {
    /// Releases the loaded model first, then deletes the temporary package directory (if any).
    /// Deletion errors are ignored.
    fn drop(&mut self) {
        // Model goes first so nothing still references the package while it is removed.
        drop(self.model.take());
        if let Some(pkg) = self.pkg_dir.take() {
            std::fs::remove_dir_all(&pkg).ok();
        }
    }
}
/// Makes `raw` safe to embed in a file name: each ASCII letter or digit is kept, and every
/// other character (including each non-ASCII code point) becomes a single `_`.
fn sanitize(raw: &str) -> String {
    let mut cleaned = String::with_capacity(raw.len());
    for ch in raw.chars() {
        let safe = if ch.is_ascii_alphanumeric() { ch } else { '_' };
        cleaned.push(safe);
    }
    cleaned
}