use std::hash::{Hash, Hasher};
use derive_more::{Display, Error};
use fxhash::{FxHashMap, FxHasher64};
use hugr::ops::OpType;
use hugr::{HugrView, Node};
use hugr_core::hugr::internal::PortgraphNodeMap;
use petgraph::visit::{self as pg, Walker};
use super::Circuit;
/// Structural hashing of a dataflow region.
///
/// Implemented in this module for every [`HugrView`] whose node type is
/// [`Node`], and for [`Circuit`] (which delegates to its underlying HUGR).
pub trait CircuitHash {
    /// Hash the dataflow region whose parent is `node`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return [`HashError::NotADfg`] when
    /// `node` has no input/output children, and [`HashError::CyclicCircuit`]
    /// when the region's output node could not be hashed.
    fn circuit_hash(&self, node: Node) -> Result<u64, HashError>;
}
/// A [`Circuit`] is hashed by hashing the HUGR it wraps.
impl<T> CircuitHash for Circuit<T>
where
    T: HugrView<Node = Node>,
{
    fn circuit_hash(&self, node: Node) -> Result<u64, HashError> {
        let hugr = self.hugr();
        hugr.circuit_hash(node)
    }
}
/// Blanket implementation for every HUGR view whose node type is [`Node`].
///
/// The region rooted at `node` is walked in topological order. Each child is
/// hashed with [`hash_node`], which folds in the hashes already recorded for
/// its predecessors. The final hash is the one recorded for the region's
/// output node.
impl<T> CircuitHash for T
where
    T: HugrView<Node = Node>,
{
    /// # Errors
    ///
    /// - [`HashError::NotADfg`] if `node` has no input/output children.
    /// - [`HashError::CyclicCircuit`] if the topological walk never produced a
    ///   hash for the output node.
    /// - Any error raised while hashing a nested region.
    ///
    /// # Panics
    ///
    /// Panics if the walk yields the same node twice (a broken invariant).
    fn circuit_hash(&self, node: Node) -> Result<u64, HashError> {
        let Some([_, output_node]) = self.get_io(node) else {
            return Err(HashError::NotADfg);
        };

        let mut node_hashes = HashState::default();
        let (region, node_map) = self.region_portgraph(node);
        for pg_node in pg::Topo::new(&region).iter(&region) {
            // Translate the portgraph index back into a HUGR node; named
            // `child` so it does not shadow the region parent `node`.
            let child = node_map.from_portgraph(pg_node);
            let hash = hash_node(self, child, &mut node_hashes)?;
            if node_hashes.set_hash(child, hash).is_some() {
                panic!("Hash already set for node {child}");
            }
        }

        // If the walk never reached the output node it has no hash; this is
        // reported as a cycle.
        node_hashes
            .node_hash(output_node)
            .ok_or(HashError::CyclicCircuit)
    }
}
/// Hashes recorded for the nodes of a region during a topological walk.
#[derive(Debug, Default, Clone)]
struct HashState {
    /// Hash assigned to each node visited so far.
    pub hashes: FxHashMap<Node, u64>,
}

impl HashState {
    /// Return the hash recorded for `node`, or `None` if it has not been
    /// hashed yet.
    #[inline]
    fn node_hash(&self, node: Node) -> Option<u64> {
        let entry = self.hashes.get(&node);
        entry.copied()
    }

    /// Record `hash` for `node`.
    ///
    /// Returns the previously recorded hash if `node` already had one.
    #[inline]
    fn set_hash(&mut self, node: Node, hash: u64) -> Option<u64> {
        self.hashes.insert(node, hash)
    }
}
/// Build a hashable textual representation of an operation.
///
/// - `ExtensionOp` with type arguments: `"<def name>[<args as JSON>]"`.
/// - `OpaqueOp` with type arguments: `"<qualified id>[<args as JSON>]"`.
/// - Anything else: the operation's `Display` output.
///
/// NOTE(review): the `ExtensionOp` arm uses the definition's bare name, while
/// the `OpaqueOp` arm uses `qualified_id()`. Confirm that ops sharing a name
/// across extensions are meant to collide here.
///
/// # Panics
///
/// Panics if the type arguments fail to serialize to JSON.
fn hashable_op(op: &OpType) -> impl Hash {
    match op {
        OpType::ExtensionOp(ext_op) if !ext_op.args().is_empty() => {
            let name = ext_op.def().name();
            let args = serde_json::to_string(ext_op.args()).unwrap();
            format!("{name}[{args}]")
        }
        OpType::OpaqueOp(opaque) if !opaque.args().is_empty() => {
            let id = opaque.qualified_id();
            let args = serde_json::to_string(opaque.args()).unwrap();
            format!("{id}[{args}]")
        }
        other => other.to_string(),
    }
}
/// Hash a single `node` from its operation and incoming connections.
///
/// The hash combines, in this order:
/// 1. the hash of the node's own nested region, when it has children;
/// 2. the hashable representation of its operation (see [`hashable_op`]);
/// 3. for each input port, the XOR of
///    `hash64((predecessor hash, predecessor port, input port))` over all
///    ports linked to it.
///
/// Predecessor hashes are read from `state`. A predecessor with no recorded
/// hash contributes `None` in place of its hash.
///
/// # Errors
///
/// Propagates any [`HashError`] produced while hashing a nested region.
fn hash_node(
    circ: &impl HugrView<Node = Node>,
    node: Node,
    state: &mut HashState,
) -> Result<u64, HashError> {
    let op = circ.get_optype(node);
    let mut hasher = FxHasher64::default();

    // Container nodes include the hash of their inner region. Checking for a
    // first child is enough; there is no need to count all of them.
    if circ.children(node).next().is_some() {
        circ.circuit_hash(node)?.hash(&mut hasher);
    }

    hashable_op(op).hash(&mut hasher);

    for input in circ.node_inputs(node) {
        // XOR is commutative, so the order in which linked ports are reported
        // does not affect the result.
        let input_hash = circ
            .linked_ports(node, input)
            .map(|(pred_node, pred_port)| {
                let pred_node_hash = state.node_hash(pred_node);
                fxhash::hash64(&(pred_node_hash, pred_port, input))
            })
            .fold(0, |total, hash| hash ^ total);
        input_hash.hash(&mut hasher);
    }

    Ok(hasher.finish())
}
#[derive(Debug, Display, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum HashError {
#[display("The circuit contains a cycle.")]
CyclicCircuit,
#[display("Tried to hash a non-dfg hugr.")]
NotADfg,
}
#[cfg(test)]
mod test {
    use tket_json_rs::circuit_json;

    use crate::serialize::TKETDecode;
    use crate::utils::build_simple_circuit;
    use crate::Tk2Op;

    use super::*;

    /// One-qubit pytket circuit applying `Rz(0.5)` to `q[0]`.
    const RZ_HALF_JSON: &str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#;

    /// The same circuit with the angle changed to `1.0`.
    const RZ_ONE_JSON: &str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["1.0"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#;

    /// Hash the top-level region of `circ`, panicking on failure.
    fn root_hash(circ: &Circuit) -> u64 {
        circ.circuit_hash(circ.parent()).unwrap()
    }

    /// Decode a pytket JSON circuit and hash its top-level region.
    fn hash_json(json: &str) -> u64 {
        let ser: circuit_json::SerialCircuit = serde_json::from_str(json).unwrap();
        let circ: Circuit = ser.decode().unwrap();
        root_hash(&circ)
    }

    #[test]
    fn hash_equality() {
        let h_first = build_simple_circuit(2, |circ| {
            circ.append(Tk2Op::H, [0])?;
            circ.append(Tk2Op::T, [1])?;
            circ.append(Tk2Op::CX, [0, 1])?;
            Ok(())
        })
        .unwrap();
        let t_first = build_simple_circuit(2, |circ| {
            circ.append(Tk2Op::T, [1])?;
            circ.append(Tk2Op::H, [0])?;
            circ.append(Tk2Op::CX, [0, 1])?;
            Ok(())
        })
        .unwrap();
        let flipped_cx = build_simple_circuit(2, |circ| {
            circ.append(Tk2Op::T, [1])?;
            circ.append(Tk2Op::H, [0])?;
            circ.append(Tk2Op::CX, [1, 0])?;
            Ok(())
        })
        .unwrap();

        let reference = root_hash(&h_first);
        // Reordering commuting gates must not change the hash...
        assert_eq!(reference, root_hash(&t_first));
        // ...but swapping control and target must.
        assert_ne!(reference, root_hash(&flipped_cx));
    }

    #[test]
    fn hash_constants() {
        hash_json(RZ_HALF_JSON);
    }

    #[test]
    fn hash_constants_neq() {
        assert_ne!(hash_json(RZ_HALF_JSON), hash_json(RZ_ONE_JSON));
    }
}