use std::path::{Path, PathBuf};
use prost::Message;
use crate::Result;
use crate::proto;
/// Incrementally builds the byte contents of a weight blob file.
///
/// Layout produced by this writer (all integers little-endian):
///
/// * bytes `0..64` — file header: blob count at `0..4`, the value `2` at
///   `4..8`, remaining bytes zero;
/// * for each blob — a 64-byte metadata record starting at a 64-byte aligned
///   offset, immediately followed by the raw payload.
pub struct BlobWriter {
    /// Serialized file contents; always begins with the 64-byte file header.
    buf: Vec<u8>,
    /// Number of blobs appended so far; patched into the header by `finish`.
    count: u32,
}

/// Marker written in the first four bytes of every blob metadata record.
const BLOB_SENTINEL: u32 = 0xDEAD_BEEF;
/// Data-type tag stored in the metadata record of `f32` blobs.
const BLOB_DTYPE_FLOAT32: u32 = 2;

impl Default for BlobWriter {
    /// Equivalent to [`BlobWriter::new`].
    ///
    /// A derived `Default` would produce a writer whose buffer lacks the file
    /// header: the first metadata record would then land at offset 0 and
    /// `finish` would overwrite its sentinel with the blob count.
    fn default() -> Self {
        Self::new()
    }
}

impl BlobWriter {
    /// Size in bytes of the file header and of each metadata record; also the
    /// alignment of every metadata record (and therefore of every payload).
    const ALIGNMENT: usize = 64;
    /// Value stored at bytes `4..8` of the file header.
    // NOTE(review): presumably the storage-format version — confirm against the blob spec.
    const HEADER_VERSION: u32 = 2;

    /// Creates an empty writer whose buffer already holds the file header.
    pub fn new() -> Self {
        let mut buf = vec![0u8; Self::ALIGNMENT];
        buf[4..8].copy_from_slice(&Self::HEADER_VERSION.to_le_bytes());
        BlobWriter { buf, count: 0 }
    }

    /// Zero-pads the buffer up to the next multiple of [`Self::ALIGNMENT`].
    fn align64(&mut self) {
        let rem = self.buf.len() % Self::ALIGNMENT;
        if rem != 0 {
            self.buf.resize(self.buf.len() + (Self::ALIGNMENT - rem), 0);
        }
    }

    /// Appends `data` as a float32 blob and returns the byte offset of its
    /// metadata record within the file.
    ///
    /// The metadata record stores the sentinel, the dtype tag, the payload
    /// size in bytes and the absolute payload offset (metadata offset + 64).
    pub fn write_f32(&mut self, data: &[f32]) -> u64 {
        self.align64();
        let payload_len = data.len() * std::mem::size_of::<f32>();
        let meta_off = self.buf.len() as u64;
        let data_off = meta_off + Self::ALIGNMENT as u64;
        let size = payload_len as u64;

        let mut meta = [0u8; Self::ALIGNMENT];
        meta[0..4].copy_from_slice(&BLOB_SENTINEL.to_le_bytes());
        meta[4..8].copy_from_slice(&BLOB_DTYPE_FLOAT32.to_le_bytes());
        meta[8..16].copy_from_slice(&size.to_le_bytes());
        meta[16..24].copy_from_slice(&data_off.to_le_bytes());
        self.buf.extend_from_slice(&meta);

        // Reserve once so the per-value appends below never reallocate.
        self.buf.reserve(payload_len);
        for &v in data {
            self.buf.extend_from_slice(&v.to_le_bytes());
        }
        self.count += 1;
        meta_off
    }

    /// Returns `true` while no blob has been written.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Consumes the writer and returns the finished file contents.
    ///
    /// Returns an empty vector when no blob was written; otherwise the blob
    /// count is patched into bytes `0..4` of the header first.
    pub fn finish(mut self) -> Vec<u8> {
        if self.count == 0 {
            return Vec::new();
        }
        self.buf[0..4].copy_from_slice(&self.count.to_le_bytes());
        self.buf
    }
}
/// Identifier used in `Manifest.json` both as the key of the single item
/// entry and as the `rootModelIdentifier` (see `manifest_json`).
const ROOT_ID: &str = "0000000000000000000000000000000000000000";
/// Size limit, in bytes, that `encode_model` passes to `check_model_size`:
/// the largest value representable as an `i32`.
// NOTE(review): presumably mirrors protobuf's 2 GiB message-size ceiling — confirm.
pub const PROTO_MAX_BYTES: usize = i32::MAX as usize;
/// Checks that the encoded size of `model` does not exceed `limit` bytes.
///
/// # Errors
///
/// Returns `CoremlError::TooLarge`, carrying the encoded size and the limit,
/// when the model would serialize to more than `limit` bytes.
pub fn check_model_size(model: &proto::Model, limit: usize) -> Result<()> {
    match model.encoded_len() {
        bytes if bytes > limit => Err(crate::CoremlError::TooLarge {
            what: "CoreML model.mlmodel proto".to_string(),
            bytes,
            limit,
        }),
        _ => Ok(()),
    }
}
/// Serializes `model` and writes it, together with the weight `blob`, as a
/// package directory at `package_dir`.
///
/// Returns the package path on success.
///
/// # Errors
///
/// Propagates any error from `encode_model` or `write_mlpackage_bytes`.
pub fn write_mlpackage(model: &proto::Model, blob: &[u8], package_dir: &Path) -> Result<PathBuf> {
    encode_model(model).and_then(|encoded| write_mlpackage_bytes(&encoded, blob, package_dir))
}
/// Encodes `model` into a freshly allocated byte vector.
///
/// The size is validated against `PROTO_MAX_BYTES` before any encoding work
/// is done.
///
/// # Errors
///
/// Returns the error from `check_model_size` when the model is too large, or
/// `CoremlError::Runtime` when the encoder reports a failure.
pub fn encode_model(model: &proto::Model) -> Result<Vec<u8>> {
    check_model_size(model, PROTO_MAX_BYTES)?;

    let capacity = model.encoded_len();
    let mut encoded = Vec::with_capacity(capacity);
    match model.encode(&mut encoded) {
        Ok(()) => Ok(encoded),
        Err(e) => Err(crate::CoremlError::Runtime(format!("proto encode: {e}"))),
    }
}
/// Writes an already-encoded model and its weight blob as a package
/// directory at `package_dir`, returning that path.
///
/// Anything already present at `package_dir` is removed first. The resulting
/// layout is:
///
/// * `Data/com.apple.CoreML/model.mlmodel` — `proto_bytes`;
/// * `Data/com.apple.CoreML/weights/weight.bin` — `blob`, only when non-empty;
/// * `Manifest.json` — the output of `manifest_json`.
///
/// # Errors
///
/// Propagates the first filesystem error encountered.
pub fn write_mlpackage_bytes(
    proto_bytes: &[u8],
    blob: &[u8],
    package_dir: &Path,
) -> Result<PathBuf> {
    use std::fs;

    // Start from a clean slate so stale files never leak into the new package.
    if package_dir.exists() {
        fs::remove_dir_all(package_dir)?;
    }

    let mut coreml_dir = package_dir.to_path_buf();
    coreml_dir.push("Data");
    coreml_dir.push("com.apple.CoreML");
    fs::create_dir_all(&coreml_dir)?;
    fs::write(coreml_dir.join("model.mlmodel"), proto_bytes)?;

    // The weights directory is only created when there is something to store.
    if !blob.is_empty() {
        let weights_dir = coreml_dir.join("weights");
        fs::create_dir_all(&weights_dir)?;
        fs::write(weights_dir.join("weight.bin"), blob)?;
    }

    fs::write(package_dir.join("Manifest.json"), manifest_json())?;
    Ok(package_dir.to_path_buf())
}
/// Renders the text of the package's `Manifest.json`.
///
/// The manifest declares a single item entry for `model.mlmodel`, keyed by
/// `ROOT_ID`, and names that same id as the root model identifier.
fn manifest_json() -> String {
    let root = ROOT_ID;
    // The template lines are intentionally flush-left: every byte inside the
    // raw string, including leading whitespace, ends up in the output.
    format!(
        r#"{{
"fileFormatVersion": "1.0.0",
"itemInfoEntries": {{
"{root}": {{
"author": "com.apple.CoreML",
"description": "CoreML Model Specification",
"name": "model.mlmodel",
"path": "com.apple.CoreML/model.mlmodel",
"digestType": "sha256"
}}
}},
"rootModelIdentifier": "{root}"
}}
"#
    )
}