use serde::de::DeserializeOwned;
use std::collections::HashMap;
use thiserror::Error;
use crate::core::NodeIndex;
use crate::hugr::Hugr;
use crate::ops::OpType;
use crate::{Node, PortIndex};
use portgraph::hierarchy::AttachError;
use portgraph::{Direction, LinkError, PortView};
use serde::{Deserialize, Deserializer, Serialize};
use self::upgrade::UpgradeError;
use super::{HugrMut, HugrView, NodeMetadataMap};
mod upgrade;
/// A serialized HUGR payload tagged with its format version.
///
/// Serialized as an internally tagged object: the variant name, lowercased,
/// is stored in a `"version"` field (e.g. `"live"` for the current format).
/// Older version tags are recognised on input so that `upgrade` can report
/// them as unsupported rather than failing to parse.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "version", rename_all = "lowercase")]
enum Versioned<SerHugr = SerHugrLatest> {
/// Version 0; accepted on input only, never serialized.
#[serde(skip_serializing)]
V0,
/// Version 1, kept as raw JSON.
V1(serde_json::Value),
/// Version 2, kept as raw JSON.
V2(serde_json::Value),
/// The current serialization format.
Live(SerHugr),
/// Fallback for any unrecognised `"version"` tag; never serialized.
#[serde(skip_serializing)]
#[serde(other)]
Unsupported,
}
impl<T> Versioned<T> {
pub fn new_latest(t: T) -> Self {
Self::Live(t)
}
}
impl<T: DeserializeOwned> Versioned<T> {
    /// Brings a versioned payload up to the latest serialization format.
    ///
    /// Only [`Versioned::Live`] payloads are currently accepted; they are
    /// returned unchanged.
    ///
    /// # Errors
    /// - [`UpgradeError::KnownVersionUnsupported`] for versions 0, 1 and 2.
    /// - [`UpgradeError::UnknownVersionUnsupported`] for an unrecognised version tag.
    fn upgrade(self) -> Result<T, UpgradeError> {
        // Deserializes raw JSON into `D`, converting the serde error into an
        // `UpgradeError`. No upgrade step calls it at present, hence the `allow`.
        #[allow(unused)]
        fn go<D: serde::de::DeserializeOwned>(v: serde_json::Value) -> Result<D, UpgradeError> {
            serde_json::from_value(v).map_err(Into::into)
        }

        let unsupported_version = match self {
            Self::Live(latest) => return Ok(latest),
            Self::V0 => "0",
            Self::V1(_) => "1",
            Self::V2(_) => "2",
            Self::Unsupported => return Err(UpgradeError::UnknownVersionUnsupported),
        };
        Err(UpgradeError::KnownVersionUnsupported(
            unsupported_version.into(),
        ))
    }
}
/// Serialized form of a single node: its parent and its operation.
///
/// The operation's fields are flattened into the node's own JSON object.
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct NodeSer {
/// Index of the parent node; the root node lists itself as its parent.
parent: Node,
/// The node's operation, flattened into this object when serialized.
#[serde(flatten)]
op: OpType,
}
/// The current (live) serialization format of a HUGR.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
struct SerHugrLatest {
/// All nodes, identified by their position in this list; the first entry is the root.
nodes: Vec<NodeSer>,
/// Edges as `[source, target]` pairs of `(node, port offset)`. A `None` offset
/// means the endpoint uses the operation's "other" port in that direction
/// instead of an explicit port index.
edges: Vec<[(Node, Option<u16>); 2]>,
/// Optional per-node metadata, indexed by the same positions as `nodes`.
#[serde(default)]
metadata: Option<Vec<Option<NodeMetadataMap>>>,
/// Describes the encoder that produced the data (e.g. `hugr-rs v<version>`);
/// ignored when deserializing.
#[serde(default)]
encoder: Option<String>,
}
/// Errors that can occur while converting a HUGR to or from its serialized form.
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum HUGRSerializationError {
/// Attaching a child node to its parent in the hierarchy failed.
#[error("Failed to attach child to parent: {0}.")]
AttachError(#[from] AttachError),
/// Linking two ports while rebuilding the graph failed.
#[error("Failed to build edge when deserializing: {0}.")]
LinkError(#[from] LinkError),
/// An edge endpoint had no port offset, and the node's operation has no
/// "other" port in that direction to fall back to.
#[error("Cannot connect an {dir:?} edge without port offset to node {node} with operation type {op_type}.")]
MissingPortOffset {
/// The node the edge was to be connected to.
node: Node,
/// Direction of the port that could not be resolved.
dir: Direction,
/// The operation of `node`.
op_type: OpType,
},
/// An edge refers to a node that does not exist in the graph.
#[error("The edge endpoint {node} is not a node in the graph.")]
UnknownEdgeNode {
/// The node referenced by the edge.
node: Node,
},
/// The first entry of the node list does not have itself (index 0) as its parent.
#[error("The first node in the node list has parent {0}, should be itself (index 0)")]
FirstNodeNotRoot(Node),
}
impl Serialize for Hugr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let shg: SerHugrLatest = self.try_into().map_err(serde::ser::Error::custom)?;
let versioned = Versioned::new_latest(shg);
versioned.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Hugr {
fn deserialize<D>(deserializer: D) -> Result<Hugr, D::Error>
where
D: Deserializer<'de>,
{
let versioned = Versioned::deserialize(deserializer)?;
let shl: SerHugrLatest = versioned.upgrade().map_err(serde::de::Error::custom)?;
shl.try_into().map_err(serde::de::Error::custom)
}
}
impl TryFrom<&Hugr> for SerHugrLatest {
type Error = HUGRSerializationError;
fn try_from(hugr: &Hugr) -> Result<Self, Self::Error> {
let mut node_rekey: HashMap<Node, Node> = HashMap::with_capacity(hugr.node_count());
for (order, node) in hugr.canonical_order(hugr.root()).enumerate() {
node_rekey.insert(node, portgraph::NodeIndex::new(order).into());
}
let mut nodes = vec![None; hugr.node_count()];
let mut metadata = vec![None; hugr.node_count()];
for n in hugr.nodes() {
let parent = node_rekey[&hugr.get_parent(n).unwrap_or(n)];
let opt = hugr.get_optype(n);
let new_node = node_rekey[&n].index();
nodes[new_node] = Some(NodeSer {
parent,
op: opt.clone(),
});
metadata[new_node].clone_from(hugr.metadata.get(n.pg_index()));
}
let nodes = nodes
.into_iter()
.collect::<Option<Vec<_>>>()
.expect("Could not reach one of the nodes");
let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| {
let op = hugr.get_optype(node);
let is_value_port = offset < op.value_port_count(dir);
let is_static_input = op.static_port(dir).map_or(false, |p| p.index() == offset);
let offset = (is_value_port || is_static_input).then_some(offset as u16);
(node_rekey[&node], offset)
};
let edges: Vec<_> = hugr
.nodes()
.flat_map(|node| {
hugr.node_ports(node, Direction::Outgoing)
.enumerate()
.flat_map(move |(src_offset, port)| {
let src = find_offset(node, src_offset, Direction::Outgoing, hugr);
hugr.linked_ports(node, port).map(move |(tgt_node, tgt)| {
let tgt = find_offset(tgt_node, tgt.index(), Direction::Incoming, hugr);
[src, tgt]
})
})
})
.collect();
let encoder = Some(format!("hugr-rs v{}", env!("CARGO_PKG_VERSION")));
Ok(Self {
nodes,
edges,
metadata: Some(metadata),
encoder,
})
}
}
impl TryFrom<SerHugrLatest> for Hugr {
type Error = HUGRSerializationError;
fn try_from(
SerHugrLatest {
nodes,
edges,
metadata,
encoder: _,
}: SerHugrLatest,
) -> Result<Self, Self::Error> {
let mut nodes = nodes.into_iter();
let NodeSer {
parent: root_parent,
op: root_type,
..
} = nodes.next().unwrap();
if root_parent.index() != 0 {
return Err(HUGRSerializationError::FirstNodeNotRoot(root_parent));
}
let mut hugr = Hugr::with_capacity(root_type, nodes.len(), edges.len() * 2);
for node_ser in nodes {
hugr.add_node_with_parent(node_ser.parent, node_ser.op);
}
if let Some(metadata) = metadata {
for (node, metadata) in metadata.into_iter().enumerate() {
if let Some(metadata) = metadata {
let node = portgraph::NodeIndex::new(node);
hugr.metadata[node] = Some(metadata);
}
}
}
let unwrap_offset = |node: Node, offset, dir, hugr: &Hugr| -> Result<usize, Self::Error> {
if !hugr.graph.contains_node(node.pg_index()) {
return Err(HUGRSerializationError::UnknownEdgeNode { node });
}
let offset = match offset {
Some(offset) => offset as usize,
None => {
let op_type = hugr.get_optype(node);
op_type
.other_port(dir)
.ok_or(HUGRSerializationError::MissingPortOffset {
node,
dir,
op_type: op_type.clone(),
})?
.index()
}
};
Ok(offset)
};
for [(src, from_offset), (dst, to_offset)] in edges {
let src_port = unwrap_offset(src, from_offset, Direction::Outgoing, &hugr)?;
let dst_port = unwrap_offset(dst, to_offset, Direction::Incoming, &hugr)?;
hugr.connect(src, src_port, dst, dst_port);
}
Ok(hugr)
}
}
#[cfg(all(test, not(miri)))]
pub mod test;