use super::core::*;
use crate::error::TensorError;
use std::collections::HashMap;
/// Serializable snapshot of a `Graph`: its nodes, its edges and its version.
///
/// Produced by `Graph::to_graph_def` and consumed by `Graph::from_graph_def`. With the
/// `serialize` feature enabled it derives serde's `Serialize`/`Deserialize`.
/// NOTE(review): keep the field order stable — the binary format written by
/// `Graph::save_to_file` may be order-sensitive.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphDef {
    /// Node definitions, one per graph node.
    pub nodes: Vec<NodeDef>,
    /// Edge definitions; their endpoints refer to `NodeDef::id` values found in `nodes`.
    pub edges: Vec<EdgeDef>,
    /// The graph's `version` value, restored verbatim when the graph is rebuilt.
    pub version: u64,
}
/// Serializable description of a single graph node.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeDef {
    /// Id of the node in the graph it was taken from. `Graph::from_graph_def` uses it
    /// only to resolve edge endpoints; the rebuilt node receives a fresh id.
    pub id: NodeId,
    /// Human-readable node name.
    pub name: String,
    /// Encoded node kind: an operation name, `"Constant"`, or
    /// `"Variable:<dtype>"` / `"Placeholder:<dtype>"` (dtype in `Debug` form).
    pub op_type: String,
    /// `Debug` rendering of the node's device.
    pub device: String,
    /// Node attributes keyed by attribute name.
    pub attributes: HashMap<String, AttributeValueDef>,
}
/// Serializable description of a single graph edge.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct EdgeDef {
    /// Id of the edge in the graph it was taken from (not reused on rebuild).
    pub id: EdgeId,
    /// Id of the source node, matching some `NodeDef::id`.
    pub from_node: NodeId,
    /// Id of the destination node, matching some `NodeDef::id`.
    pub to_node: NodeId,
    /// Copied from the graph edge's `from_output` (presumably the producer's output slot).
    pub from_output: usize,
    /// Copied from the graph edge's `to_input` (presumably the consumer's input slot).
    pub to_input: usize,
    /// `Debug` rendering of the edge dtype.
    pub dtype: String,
    /// Dimensions of the edge's shape.
    pub shape: Vec<usize>,
    /// Mirrors the graph edge's `is_control` flag.
    pub is_control: bool,
}
/// Serializable mirror of `AttributeValue`.
///
/// Each variant corresponds to the `AttributeValue` variant of the same name, with two
/// simplifications: `Shape` stores only the dimension list, and `Tensor` stores only a
/// flat `f32` buffer (no shape information).
///
/// `PartialEq` is derived so that definitions (e.g. before/after a serialization
/// round-trip) can be compared directly. `Eq` is not derived because of the float variants.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum AttributeValueDef {
    /// A string value.
    String(String),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit float.
    Float(f64),
    /// A boolean flag.
    Bool(bool),
    /// A list of signed 64-bit integers.
    IntList(Vec<i64>),
    /// A list of 64-bit floats.
    FloatList(Vec<f64>),
    /// A shape, stored as its dimension list.
    Shape(Vec<usize>),
    /// Flat tensor contents; the original tensor shape is not stored.
    Tensor(Vec<f32>),
}
impl From<AttributeValue> for AttributeValueDef {
    /// Converts a runtime attribute into its serializable form.
    ///
    /// Scalar and list payloads are moved across unchanged and a `Shape` becomes its
    /// dimension vector. A `Tensor` becomes a flat copy of its data: the tensor's shape is
    /// not recorded, and when `as_slice()` provides no slice the result is an empty buffer.
    fn from(value: AttributeValue) -> Self {
        match value {
            AttributeValue::Tensor(tensor) => {
                // No slice available → fall back to an empty buffer instead of failing
                // (this conversion is infallible). NOTE(review): this silently drops data.
                let flat = tensor.as_slice().unwrap_or(&[]);
                Self::Tensor(flat.to_vec())
            }
            AttributeValue::Shape(shape) => Self::Shape(shape.dims().to_vec()),
            AttributeValue::FloatList(values) => Self::FloatList(values),
            AttributeValue::IntList(values) => Self::IntList(values),
            AttributeValue::Bool(flag) => Self::Bool(flag),
            AttributeValue::Float(x) => Self::Float(x),
            AttributeValue::Int(n) => Self::Int(n),
            AttributeValue::String(text) => Self::String(text),
        }
    }
}
impl TryFrom<AttributeValueDef> for AttributeValue {
    type Error = TensorError;

    /// Restores a runtime attribute from its serializable form.
    ///
    /// Every variant except `Tensor` converts without failure. A `Tensor` buffer is rebuilt
    /// as a one-dimensional tensor whose single dimension is the buffer length, because the
    /// definition carries no shape.
    ///
    /// # Errors
    /// Propagates any error from `Tensor::from_vec` while rebuilding a `Tensor` variant.
    fn try_from(def: AttributeValueDef) -> Result<Self, Self::Error> {
        let converted = match def {
            AttributeValueDef::Tensor(data) => {
                use crate::tensor::Tensor;
                let dims = vec![data.len()];
                Self::Tensor(Tensor::from_vec(data, &dims)?)
            }
            AttributeValueDef::Shape(dims) => Self::Shape(crate::shape::Shape::new(dims)),
            AttributeValueDef::FloatList(values) => Self::FloatList(values),
            AttributeValueDef::IntList(values) => Self::IntList(values),
            AttributeValueDef::Bool(flag) => Self::Bool(flag),
            AttributeValueDef::Float(x) => Self::Float(x),
            AttributeValueDef::Int(n) => Self::Int(n),
            AttributeValueDef::String(text) => Self::String(text),
        };
        Ok(converted)
    }
}
impl Graph {
    /// Converts this graph into a plain, serializable [`GraphDef`].
    ///
    /// Node and edge data are copied out of `self`; the graph is not modified. Nodes and
    /// edges appear in the iteration order of `self.nodes.values()` and
    /// `self.edges.values()` respectively.
    ///
    /// The encoding is lossy:
    /// - `Variable` / `Placeholder` nodes are written as `"Variable:<dtype>"` /
    ///   `"Placeholder:<dtype>"` using the dtype's `Debug` output; their shape is dropped.
    /// - The node device and the edge dtype are stored as their `Debug` strings.
    /// - Attribute values are converted through `Into<AttributeValueDef>`.
    pub fn to_graph_def(&self) -> GraphDef {
        let nodes = self
            .nodes
            .values()
            .map(|node| NodeDef {
                id: node.id,
                name: node.name.clone(),
                op_type: match &node.op_type {
                    NodeType::Operation(op) => op.clone(),
                    // Only the dtype is encoded; the shape is not part of the definition.
                    NodeType::Variable { dtype, shape: _ } => format!("Variable:{:?}", dtype),
                    NodeType::Placeholder { dtype, shape: _ } => format!("Placeholder:{:?}", dtype),
                    NodeType::Constant => "Constant".to_string(),
                },
                device: format!("{:?}", node.device),
                attributes: node
                    .attributes
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone().into()))
                    .collect(),
            })
            .collect();
        let edges = self
            .edges
            .values()
            .map(|edge| EdgeDef {
                id: edge.id,
                from_node: edge.from_node,
                to_node: edge.to_node,
                from_output: edge.from_output,
                to_input: edge.to_input,
                dtype: format!("{:?}", edge.dtype),
                shape: edge.shape.dims().to_vec(),
                is_control: edge.is_control,
            })
            .collect();
        GraphDef {
            nodes,
            edges,
            version: self.version,
        }
    }

    /// Rebuilds a [`Graph`] from a [`GraphDef`], e.g. one produced by
    /// [`Graph::to_graph_def`] or deserialized from disk.
    ///
    /// Nodes are added in the order they appear in `graph_def.nodes` and receive fresh ids
    /// from `add_node`. Edge endpoints are translated through an old-id → new-id map, and
    /// `EdgeDef::id` is not reused.
    ///
    /// Information that is NOT recovered (the strings written by `to_graph_def` are not
    /// parsed back):
    /// - `Variable:` / `Placeholder:` nodes are recreated with `DType::Float32` and an
    ///   empty shape, whatever dtype was encoded.
    /// - Every node is placed on `Device::Cpu`.
    /// - Every edge gets `DType::Float32`; `EdgeDef::dtype` is ignored.
    ///
    /// Any `op_type` that is not `"Constant"` and has neither prefix becomes a
    /// `NodeType::Operation` with that name.
    ///
    /// # Errors
    /// Returns an error if:
    /// - two node definitions share an id;
    /// - an edge refers to a node id that is absent from `graph_def.nodes`;
    /// - an attribute fails to convert;
    /// - `add_node` or `add_edge` rejects the input.
    pub fn from_graph_def(graph_def: &GraphDef) -> Result<Self, TensorError> {
        let mut graph = Graph::new();
        // Ids stored in the definition -> ids assigned by `add_node` in the new graph.
        let mut id_mapping: HashMap<NodeId, NodeId> =
            HashMap::with_capacity(graph_def.nodes.len());
        for node_def in &graph_def.nodes {
            // A repeated id would overwrite the earlier mapping entry, orphaning that node
            // and silently attaching every edge that names the id to the later node.
            if id_mapping.contains_key(&node_def.id) {
                return Err(TensorError::invalid_argument(format!(
                    "Duplicate node id {} in graph definition",
                    node_def.id
                )));
            }
            // The dtype suffix after "Variable:" / "Placeholder:" is not parsed back;
            // Float32 and an empty shape are used instead.
            let op_type = if node_def.op_type.starts_with("Variable:") {
                NodeType::Variable {
                    dtype: crate::dtype::DType::Float32,
                    shape: crate::shape::Shape::new(vec![]),
                }
            } else if node_def.op_type.starts_with("Placeholder:") {
                NodeType::Placeholder {
                    dtype: crate::dtype::DType::Float32,
                    shape: crate::shape::Shape::new(vec![]),
                }
            } else if node_def.op_type == "Constant" {
                NodeType::Constant
            } else {
                NodeType::Operation(node_def.op_type.clone())
            };
            // The serialized device string is ignored.
            let device = crate::device::Device::Cpu;
            let attributes: Result<HashMap<String, AttributeValue>, TensorError> = node_def
                .attributes
                .iter()
                .map(|(k, v)| Ok((k.clone(), v.clone().try_into()?)))
                .collect();
            let new_id = graph.add_node(node_def.name.clone(), op_type, device, attributes?)?;
            id_mapping.insert(node_def.id, new_id);
        }

        // Translates an id from the definition into the corresponding id in `graph`.
        let resolve = |old_id: NodeId| {
            id_mapping.get(&old_id).copied().ok_or_else(|| {
                TensorError::invalid_argument(format!("Node {} not found in mapping", old_id))
            })
        };
        for edge_def in &graph_def.edges {
            let from_node = resolve(edge_def.from_node)?;
            let to_node = resolve(edge_def.to_node)?;
            // `edge_def.dtype` is ignored; every rebuilt edge is Float32.
            graph.add_edge(
                from_node,
                to_node,
                edge_def.from_output,
                edge_def.to_input,
                crate::dtype::DType::Float32,
                crate::shape::Shape::new(edge_def.shape.clone()),
                edge_def.is_control,
            )?;
        }
        // Set last so the stored version wins over whatever the insertions left behind.
        graph.version = graph_def.version;
        Ok(graph)
    }

    /// Serializes the graph (via [`Graph::to_graph_def`]) with `oxicode`'s standard
    /// configuration and writes the bytes to `path`.
    ///
    /// # Errors
    /// Both encoding failures and file-write failures are reported as
    /// `TensorError::invalid_argument`.
    #[cfg(feature = "serialize")]
    pub fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), TensorError> {
        let graph_def = self.to_graph_def();
        let serialized = oxicode::serde::encode_to_vec(&graph_def, oxicode::config::standard())
            .map_err(|e| TensorError::invalid_argument(format!("Serialization failed: {}", e)))?;
        std::fs::write(path, serialized)
            .map_err(|e| TensorError::invalid_argument(format!("Failed to write file: {}", e)))?;
        Ok(())
    }

    /// Reads `path`, decodes a [`GraphDef`] with `oxicode`'s standard configuration and
    /// rebuilds the graph with [`Graph::from_graph_def`].
    ///
    /// NOTE(review): the second element of the decode result (presumably the number of
    /// bytes consumed) is discarded, so trailing bytes in the file are not reported.
    ///
    /// # Errors
    /// Read and decode failures are reported as `TensorError::invalid_argument`; errors
    /// from `from_graph_def` are propagated unchanged.
    #[cfg(feature = "serialize")]
    pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, TensorError> {
        let data = std::fs::read(path)
            .map_err(|e| TensorError::invalid_argument(format!("Failed to read file: {}", e)))?;
        let graph_def: GraphDef =
            oxicode::serde::decode_owned_from_slice(&data, oxicode::config::standard())
                .map_err(|e| {
                    TensorError::invalid_argument(format!("Deserialization failed: {}", e))
                })?
                .0;
        Self::from_graph_def(&graph_def)
    }

    /// Renders the graph (via [`Graph::to_graph_def`]) as pretty-printed JSON.
    ///
    /// # Errors
    /// JSON encoding failures are reported as `TensorError::invalid_argument`.
    #[cfg(feature = "serialize")]
    pub fn to_json(&self) -> Result<String, TensorError> {
        let graph_def = self.to_graph_def();
        serde_json::to_string_pretty(&graph_def)
            .map_err(|e| TensorError::invalid_argument(format!("JSON serialization failed: {}", e)))
    }

    /// Parses a [`GraphDef`] from JSON and rebuilds the graph with
    /// [`Graph::from_graph_def`].
    ///
    /// # Errors
    /// JSON parse failures are reported as `TensorError::invalid_argument`; errors from
    /// `from_graph_def` are propagated unchanged.
    #[cfg(feature = "serialize")]
    pub fn from_json(json: &str) -> Result<Self, TensorError> {
        let graph_def: GraphDef = serde_json::from_str(json).map_err(|e| {
            TensorError::invalid_argument(format!("JSON deserialization failed: {}", e))
        })?;
        Self::from_graph_def(&graph_def)
    }
}