#![allow(dead_code)]
use crate::error::{IoError, Result};
use crate::ml_framework::converters::MLFrameworkConverter;
use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::fs::File;
use std::path::Path;
pub struct TensorFlowConverter;
impl MLFrameworkConverter for TensorFlowConverter {
fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
let model_dir = path.parent().unwrap_or(Path::new("."));
std::fs::create_dir_all(model_dir).map_err(IoError::Io)?;
let tf_model = serde_json::json!({
"saved_model_schema_version": 1,
"meta_graphs": [{
"meta_info_def": {
"meta_graph_version": "v2.0.0",
"tensorflow_version": "2.12.0",
"tags": ["serve"]
},
"graph_def": {
"versions": { "producer": 1982, "min_consumer": 12 }
},
"signature_def": {
"serving_default": {
"inputs": model.metadata.inputshapes,
"outputs": model.metadata.outputshapes,
"method_name": "tensorflow/serving/predict"
}
}
}],
"variables": model.weights.iter().map(|(name, tensor)| {
serde_json::json!({
"name": name,
"shape": tensor.metadata.shape,
"dtype": format!("{:?}", tensor.metadata.dtype),
"data": tensor.data.as_slice().expect("Operation failed").to_vec()
})
}).collect::<Vec<_>>()
});
let file = File::create(path).map_err(IoError::Io)?;
serde_json::to_writer_pretty(file, &tf_model)
.map_err(|e| IoError::SerializationError(e.to_string()))
}
fn load_model(&self, path: &Path) -> Result<MLModel> {
let file = File::open(path).map_err(IoError::Io)?;
let tf_model: serde_json::Value = serde_json::from_reader(file)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
let mut model = MLModel::new(MLFramework::TensorFlow);
if let Some(meta_graphs) = tf_model.get("meta_graphs").and_then(|v| v.as_array()) {
if let Some(meta_graph) = meta_graphs.first() {
if let Some(signature_def) = meta_graph
.get("signature_def")
.and_then(|v| v.get("serving_default"))
{
if let Some(inputs) = signature_def.get("inputs").and_then(|v| v.as_object()) {
for (name, input_info) in inputs {
if let Some(shape) = input_info.as_array() {
let shape_vec: Vec<usize> = shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
model.metadata.inputshapes.insert(name.clone(), shape_vec);
}
}
}
}
}
}
if let Some(variables) = tf_model.get("variables").and_then(|v| v.as_array()) {
for var in variables {
if let Some(var_obj) = var.as_object() {
if let (Some(name), Some(shape), Some(data)) = (
var_obj.get("name").and_then(|v| v.as_str()),
var_obj.get("shape").and_then(|v| v.as_array()),
var_obj.get("data").and_then(|v| v.as_array()),
) {
let shape_vec: Vec<usize> = shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
let data_vec: Vec<f32> = data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
if let Ok(array) = ArrayD::from_shape_vec(IxDyn(&shape_vec), data_vec) {
model.weights.insert(
name.to_string(),
MLTensor::new(array, Some(name.to_string())),
);
}
}
}
}
}
Ok(model)
}
fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
let tensor_data = serde_json::json!({
"tensor": {
"dtype": format!("{:?}", tensor.metadata.dtype),
"tensorshape": {
"dim": tensor.metadata.shape.iter().map(|&d| serde_json::json!({"size": d})).collect::<Vec<_>>()
},
"tensor_content": tensor.data.as_slice().expect("Operation failed")
.iter()
.flat_map(|f| f.to_le_bytes().to_vec())
.collect::<Vec<u8>>()
}
});
let file = File::create(path).map_err(IoError::Io)?;
serde_json::to_writer_pretty(file, &tensor_data)
.map_err(|e| IoError::SerializationError(e.to_string()))
}
fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
let file = File::open(path).map_err(IoError::Io)?;
let tensor_data: serde_json::Value = serde_json::from_reader(file)
.map_err(|e| IoError::SerializationError(e.to_string()))?;
if let Some(tensor) = tensor_data.get("tensor") {
let shape: Vec<usize> = tensor
.get("tensorshape")
.and_then(|ts| ts.get("dim"))
.and_then(|dims| dims.as_array())
.map(|dims| {
dims.iter()
.filter_map(|d| d.get("size").and_then(|s| s.as_u64().map(|u| u as usize)))
.collect()
})
.unwrap_or_default();
let content = tensor.get("tensor_content").and_then(|c| c.as_array());
let data: Vec<f32> = if let Some(content_array) = content {
content_array
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect()
} else {
vec![0.0; shape.iter().product()]
};
let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
.map_err(|e| IoError::Other(e.to_string()))?;
return Ok(MLTensor::new(array, None));
}
Err(IoError::Other(
"Invalid TensorFlow tensor format".to_string(),
))
}
}