use crate::passes::monomorphize::mangle_name;
use crate::passes::{ReplaceTypes, replace_types::NodeTemplate};
use hugr::HugrView;
use hugr::builder::{Container, Dataflow, HugrBuilder};
use hugr::hugr::linking::OnMultiDefn;
use hugr::ops::handle::{FuncID, NodeHandle};
use hugr::{
Hugr, Node, Wire,
builder::{BuildError, DataflowHugr, FunctionBuilder},
hugr::hugrmut::HugrMut,
ops::{DataflowOpTrait, ExtensionOp},
types::TypeArg,
};
use hugr_core::Visibility;
use hugr_core::hugr::linking::NameLinkingPolicy;
use indexmap::IndexMap;
use std::{cell::RefCell, ops::Deref};
/// Hash-map key wrapping an [`ExtensionOp`].
///
/// Equality is derived, i.e. it delegates to `ExtensionOp`'s `PartialEq`.
/// `Hash` is implemented by hand from the op's extension id, unqualified name
/// and type arguments. NOTE(review): presumably the wrapper exists because
/// `ExtensionOp` itself does not implement `Hash` — confirm.
#[derive(Clone, PartialEq, Eq)]
struct OpHashWrapper(ExtensionOp);
impl From<ExtensionOp> for OpHashWrapper {
fn from(op: ExtensionOp) -> Self {
Self(op)
}
}
impl std::hash::Hash for OpHashWrapper {
    /// Feeds the wrapped op's extension id, unqualified name and type
    /// arguments, in that order, into `state`.
    ///
    /// These fields must agree with the derived `Eq`: ops that compare equal
    /// have to hash equally. NOTE(review): this relies on equal `ExtensionOp`s
    /// sharing extension id, name and args — confirm against `ExtensionOp`'s
    /// `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let op = &self.0;
        // Fully-qualified calls: method syntax (`x.hash(state)`) would need the
        // `Hash` trait imported into scope, which the prelude does not provide.
        std::hash::Hash::hash(op.extension_id(), state);
        std::hash::Hash::hash(op.unqualified_id(), state);
        std::hash::Hash::hash(op.args(), state);
    }
}
/// Collection of function definitions that implement extension operations,
/// keyed by the operation.
///
/// Entries are kept in insertion order (`IndexMap`). A value of `None` is a
/// placeholder for a function that [`OpFunctionMap::insert_with`] is still
/// building. The map sits in a `RefCell` so that entries can be added through
/// `&self`. NOTE(review): presumably this lets a function-building callback
/// register further functions on the same map — confirm with callers.
#[derive(Clone)]
pub struct OpFunctionMap {
    map: RefCell<IndexMap<OpHashWrapper, Option<Hugr>>>,
}
impl OpFunctionMap {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self {
            map: RefCell::new(IndexMap::new()),
        }
    }

    /// Builds and records a function definition implementing `op`, unless an
    /// entry for an equal op already exists (finished or still being built).
    ///
    /// The function is named `mangle_name(op.def().name(), mangle_args)` and
    /// has `op`'s signature. Its body is produced by `func_builder`, and the
    /// wires that callback returns become the function's outputs.
    ///
    /// A `None` placeholder for `op` is recorded before `func_builder` runs,
    /// and no `RefCell` borrow is held during that call. So `func_builder` may
    /// call `insert_with` on this same map; a nested request for `op` itself
    /// returns `Ok(())` immediately instead of recursing.
    ///
    /// # Errors
    ///
    /// Propagates any [`BuildError`] from creating the builder, from
    /// `func_builder`, or from finishing the function. On error the
    /// placeholder for `op` is removed again, so the map holds no entry for
    /// `op` and a later call may retry.
    pub fn insert_with<O, F>(
        &self,
        op: &ExtensionOp,
        mangle_args: &[TypeArg],
        func_builder: F,
    ) -> Result<(), BuildError>
    where
        O: IntoIterator<Item = Wire>,
        F: FnOnce(&mut FunctionBuilder<Hugr>) -> Result<O, BuildError>,
    {
        let key = OpHashWrapper::from(op.clone());
        if self.map.borrow().contains_key(&key) {
            return Ok(());
        }
        let name = mangle_name(op.def().name(), mangle_args);
        let sig = op.signature().deref().clone();
        let func_b = FunctionBuilder::new(name, sig)?;
        // Placeholder marking `op` as in progress, so reentrant calls from
        // `func_builder` do not try to build it again.
        self.map.borrow_mut().insert(key.clone(), None);
        let hugr = match Self::build_function(func_b, func_builder) {
            Ok(hugr) => hugr,
            Err(err) => {
                // A stale `None` would make later calls for `op` return
                // `Ok(())` without building anything, and would make
                // `into_function_iter` panic.
                self.map.borrow_mut().shift_remove(&key);
                return Err(err);
            }
        };
        let out = self.map.borrow_mut().insert(key, Some(hugr));
        debug_assert_eq!(out, Some(None));
        Ok(())
    }

    /// Runs `func_builder` on `func_b`, then finishes the function using the
    /// output wires it returned.
    fn build_function<O, F>(
        mut func_b: FunctionBuilder<Hugr>,
        func_builder: F,
    ) -> Result<Hugr, BuildError>
    where
        O: IntoIterator<Item = Wire>,
        F: FnOnce(&mut FunctionBuilder<Hugr>) -> Result<O, BuildError>,
    {
        let outputs = func_builder(&mut func_b)?;
        let hugr = func_b.finish_hugr_with_outputs(outputs)?;
        Ok(hugr)
    }

    /// Number of entries, including placeholders for functions still being
    /// built.
    pub fn len(&self) -> usize {
        self.map.borrow().len()
    }

    /// Whether the map has no entries (placeholders included).
    pub fn is_empty(&self) -> bool {
        self.map.borrow().is_empty()
    }

    /// Consumes the map, yielding each op with its function definition in
    /// insertion order.
    ///
    /// # Panics
    ///
    /// The returned iterator panics on reaching an entry that is still a
    /// placeholder (`None`).
    pub fn into_function_iter(self) -> impl Iterator<Item = (ExtensionOp, Hugr)> {
        self.map
            .into_inner()
            .into_iter()
            .map(|(k, v)| (k.0, v.expect("All placeholders should have been replaced")))
    }

    /// Registers on `lowerer`, for every recorded op, a replacement by the
    /// node template built by `func_as_node_template` from its function.
    ///
    /// The `_hugr` argument is currently unused.
    ///
    /// # Panics
    ///
    /// Panics if any entry is still a placeholder (see
    /// [`OpFunctionMap::into_function_iter`]).
    pub fn register_operation_replacements(
        self,
        _hugr: &mut impl HugrMut<Node = Node>,
        lowerer: &mut ReplaceTypes,
    ) {
        for (op, func_def) in self.into_function_iter() {
            lowerer.set_replace_op(&op, func_as_node_template(func_def));
        }
    }
}
impl Default for OpFunctionMap {
fn default() -> Self {
Self::new()
}
}
/// Turns a function-definition Hugr into a [`NodeTemplate`] whose replacement
/// node is a `Call` to that function.
///
/// Steps:
/// 1. Build a private wrapper function named `""` with the same signature as
///    `func_def`.
/// 2. Insert `func_def` into the wrapper's module.
/// 3. Call `func_def` on the wrapper's inputs and wire its results to the
///    wrapper's outputs.
/// 4. Make the `Call` node the entrypoint of the resulting Hugr.
///
/// Linking uses the default [`NameLinkingPolicy`] with
/// [`OnMultiDefn::UseTarget`]. NOTE(review): presumably this keeps a
/// same-named definition already present in the target (e.g. from an earlier
/// replacement) instead of duplicating it — confirm against `OnMultiDefn`.
///
/// # Panics
///
/// Panics if `func_def` has no inner function type, or if building the
/// wrapper or the call fails.
fn func_as_node_template(func_def: Hugr) -> NodeTemplate {
    let func_signature = func_def
        .inner_function_type()
        .expect("func_def should be a function definition with a signature")
        .into_owned();
    let mut b = FunctionBuilder::new_vis("", func_signature, Visibility::Private)
        .expect("creating a wrapper with func_def's signature should not fail");
    let func_node = b.module_root_builder().add_hugr(func_def).inserted_entrypoint;
    let func_id = FuncID::<true>::from(func_node);
    let call = b
        .call(&func_id, &[], b.input_wires())
        .expect("wrapper inputs should match the called function's signature");
    let mut call_hugr = b
        .finish_hugr_with_outputs(call.outputs())
        .expect("call outputs should match the wrapper's signature");
    call_hugr.set_entrypoint(call.node());
    NodeTemplate::LinkedHugr(
        Box::new(call_hugr),
        NameLinkingPolicy::default().on_multiple_defn(OnMultiDefn::UseTarget),
    )
}