#![allow(dead_code)]
use crate::error::{IoError, Result};
use crate::ml_framework::converters::MLFrameworkConverter;
use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::fs::File;
use std::path::Path;
pub struct CoreMLConverter;
impl MLFrameworkConverter for CoreMLConverter {
fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
let coreml_model = serde_json::json!({
"format": "coreml",
"specificationVersion": 5,
"description": {
"metadata": {
"userDefined": model.metadata.parameters,
"author": "SciRS2",
"license": "Apache-2.0",
"shortDescription": model.metadata.model_name.clone().unwrap_or_default()
},
"input": model.metadata.inputshapes.iter().map(|(name, shape)| {
serde_json::json!({
"name": name,
"type": {
"multiArrayType": {
"shape": shape,
"dataType": "FLOAT32"
}
}
})
}).collect::<Vec<_>>(),
"output": model.metadata.outputshapes.iter().map(|(name, shape)| {
serde_json::json!({
"name": name,
"type": {
"multiArrayType": {
"shape": shape,
"dataType": "FLOAT32"
}
}
})
}).collect::<Vec<_>>()
},
"neuralNetwork": {
"layers": [],
"preprocessing": []
},
"weights": model.weights.iter().map(|(name, tensor)| {
(name.clone(), serde_json::json!({
"shape": tensor.metadata.shape,
"floatValue": tensor.data.as_slice().expect("Operation failed").to_vec()
}))
}).collect::<serde_json::Map<String, serde_json::Value>>()
});
let file = File::create(path).map_err(IoError::Io)?;
serde_json::to_writer_pretty(file, &coreml_model)
.map_err(|e| IoError::SerializationError(e.to_string()))
}
fn load_model(&self, path: &Path) -> Result<MLModel> {
let file = File::open(path).map_err(IoError::Io)?;
let coreml_model: serde_json::Value = serde_json::from_reader(file)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let mut model = MLModel::new(MLFramework::CoreML);
if let Some(description) = coreml_model.get("description") {
if let Some(metadata) = description.get("metadata") {
if let Some(short_desc) = metadata.get("shortDescription").and_then(|v| v.as_str())
{
model.metadata.model_name = Some(short_desc.to_string());
}
}
if let Some(inputs) = description.get("input").and_then(|v| v.as_array()) {
for input in inputs {
if let Some(input_obj) = input.as_object() {
if let (Some(name), Some(shape)) = (
input_obj.get("name").and_then(|v| v.as_str()),
input_obj
.get("type")
.and_then(|t| t.get("multiArrayType"))
.and_then(|mat| mat.get("shape"))
.and_then(|s| s.as_array()),
) {
let shape_vec: Vec<usize> = shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
model
.metadata
.inputshapes
.insert(name.to_string(), shape_vec);
}
}
}
}
if let Some(outputs) = description.get("output").and_then(|v| v.as_array()) {
for output in outputs {
if let Some(output_obj) = output.as_object() {
if let (Some(name), Some(shape)) = (
output_obj.get("name").and_then(|v| v.as_str()),
output_obj
.get("type")
.and_then(|t| t.get("multiArrayType"))
.and_then(|mat| mat.get("shape"))
.and_then(|s| s.as_array()),
) {
let shape_vec: Vec<usize> = shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
model
.metadata
.outputshapes
.insert(name.to_string(), shape_vec);
}
}
}
}
}
if let Some(weights) = coreml_model.get("weights").and_then(|v| v.as_object()) {
for (name, weight_data) in weights {
let shape: Vec<usize> = serde_json::from_value(weight_data["shape"].clone())
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let data: Vec<f32> = serde_json::from_value(weight_data["floatValue"].clone())
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
.map_err(|e| IoError::Other(e.to_string()))?;
model
.weights
.insert(name.clone(), MLTensor::new(array, Some(name.clone())));
}
}
Ok(model)
}
fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
let tensor_data = serde_json::json!({
"coreml_multiarray": {
"shape": tensor.metadata.shape,
"dataType": "FLOAT32",
"floatValue": tensor.data.as_slice().expect("Operation failed").to_vec()
}
});
let file = File::create(path).map_err(IoError::Io)?;
serde_json::to_writer_pretty(file, &tensor_data)
.map_err(|e| IoError::SerializationError(e.to_string()))
}
fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
let file = File::open(path).map_err(IoError::Io)?;
let tensor_data: serde_json::Value = serde_json::from_reader(file)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
if let Some(multiarray) = tensor_data.get("coreml_multiarray") {
let shape: Vec<usize> = serde_json::from_value(multiarray["shape"].clone())
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let data: Vec<f32> = serde_json::from_value(multiarray["floatValue"].clone())
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
.map_err(|e| IoError::Other(e.to_string()))?;
return Ok(MLTensor::new(array, None));
}
Err(IoError::Other("Invalid CoreML tensor format".to_string()))
}
}