#![allow(clippy::type_complexity)]
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use handlers::list_const;
use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy, OnMultiDefn};
use hugr_core::std_extensions::collections::array::array_type_def;
use hugr_core::std_extensions::collections::list::list_type_def;
use itertools::Either;
use thiserror::Error;
use hugr_core::builder::{
BuildError, BuildHandle, Container, Dataflow, DataflowHugr, FunctionBuilder, HugrBuilder,
};
use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::ops::constant::{OpaqueValue, Sum};
use hugr_core::ops::handle::{DataflowOpID, FuncID, NodeHandle};
use hugr_core::ops::{
AliasDefn, CFG, Call, CallIndirect, Case, Conditional, Const, DFG, DataflowBlock, ExitBlock,
ExtensionOp, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value,
};
use hugr_core::types::{
ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow,
TypeTransformer,
};
use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Visibility, Wire};
use crate::passes::composable::WithScope;
use crate::passes::{ComposablePass, PassScope};
mod linearize;
pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer};
/// A recipe for producing a node: either a single operation, or a whole Hugr
/// whose entrypoint becomes the node.
///
/// A template can be added as a fresh child of a parent node (see
/// [`NodeTemplate::add_hugr`] and [`NodeTemplate::add`]), or used by
/// [`ReplaceTypes`] to overwrite an existing node in place.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum NodeTemplate {
/// A single operation; the resulting node carries exactly this [`OpType`].
SingleOp(OpType),
/// A Hugr that is inserted without any linking; its entrypoint becomes the node.
/// When used to replace an existing node, the entrypoint must be a [`CFG`],
/// [`DFG`], [`Conditional`] or [`TailLoop`].
CompoundOp(Box<Hugr>),
/// A Hugr that is inserted via `insert_link_hugr` using the given
/// [`NameLinkingPolicy`]; its entrypoint becomes the node.
LinkedHugr(Box<Hugr>, NameLinkingPolicy),
}
impl NodeTemplate {
pub fn linked_hugr(h: impl Into<Hugr>) -> Self {
NodeTemplate::LinkedHugr(Box::new(h.into()), NameLinkingPolicy::default())
}
/// Builds a template that calls the function found at the entrypoint of `func_def`.
///
/// The entrypoint of `func_def` must be a `FuncDefn` or `FuncDecl`. Its signature
/// is instantiated with `type_args`; a private wrapper function with an empty name
/// and that instantiated signature is then built, `func_def` is added to the
/// wrapper's module, and a `Call` node passing all wrapper inputs is created.
/// The returned [`NodeTemplate::LinkedHugr`] has the `Call` node as its entrypoint
/// and uses [`OnMultiDefn::UseTarget`] for multiply-defined names.
///
/// # Errors
///
/// * [`BuildError::UnexpectedType`] if the entrypoint of `func_def` is neither a
///   function definition nor a declaration.
/// * A [`BuildError`] if instantiating the signature with `type_args` fails, or
///   if any builder step (creating the wrapper, adding the call, finishing the
///   Hugr) reports an error. These are returned rather than causing a panic.
pub fn call_to_function(
    func_def: Hugr,
    type_args: &[TypeArg],
) -> Result<NodeTemplate, BuildError> {
    let func_op = func_def.entrypoint_optype();
    let func_signature = match func_op {
        OpType::FuncDecl(decl) => decl.signature().clone(),
        OpType::FuncDefn(defn) => defn.signature().clone(),
        _ => {
            return Err(BuildError::UnexpectedType {
                node: func_def.entrypoint(),
                op_desc: "function definition or declaration",
            });
        }
    }
    .instantiate(type_args)?;
    // The wrapper only exists to host the `Call` node; its name is irrelevant.
    let mut b = FunctionBuilder::new_vis("", func_signature, Visibility::Private)?;
    let func_id = FuncID::<true>::from(
        b.module_root_builder()
            .add_hugr(func_def)
            .inserted_entrypoint,
    );
    let call = b.call(&func_id, type_args, b.input_wires())?;
    let mut call_hugr = b.finish_hugr_with_outputs(call.outputs())?;
    // Only the call itself (not the wrapper function) is the node to insert.
    call_hugr.set_entrypoint(call.node());
    Ok(NodeTemplate::LinkedHugr(
        Box::new(call_hugr),
        NameLinkingPolicy::default().on_multiple_defn(OnMultiDefn::UseTarget),
    ))
}
pub fn add_hugr(
self,
hugr: &mut impl HugrMut<Node = Node>,
parent: Node,
) -> Result<Node, BuildError> {
match self {
NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)),
NodeTemplate::CompoundOp(new_h) => {
Ok(hugr.insert_hugr(parent, *new_h).inserted_entrypoint)
}
NodeTemplate::LinkedHugr(h, pol) => {
Ok(hugr.insert_link_hugr(parent, *h, &pol)?.inserted_entrypoint)
}
}
}
pub fn add(
self,
dfb: &mut impl Dataflow,
inputs: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
match self {
NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs),
NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs),
NodeTemplate::LinkedHugr(h, pol) => dfb.add_link_hugr_with_wires(*h, &pol, inputs),
}
}
/// Overwrites the existing node `n` of `hugr` with this template, in place.
///
/// `n` keeps its identity; its optype is replaced and, for Hugr-based
/// templates, the children of the inserted entrypoint are re-parented onto `n`.
/// Afterwards the node (and its new subtree) is processed with `rt` according
/// to `opts`.
///
/// # Panics
///
/// Panics if `n` already has children, or if a `CompoundOp` entrypoint has a
/// static input port.
///
/// # Errors
///
/// Returns [`ReplaceTypesError::AddTemplateError`] if the template is not
/// usable as a replacement or linking fails, and propagates any error from
/// processing the replacement with `rt`.
fn replace<H: HugrMut<Node = Node>>(
self,
hugr: &mut H,
n: Node,
rt: &ReplaceTypes,
opts: &ReplacementOptions,
) -> Result<(), ReplaceTypesError> {
// Wraps a `BuildError` together with the node that was being replaced.
let ef = |e| ReplaceTypesError::AddTemplateError(n, Box::new(e));
// Children of the template are moved under `n` below, so it must start out empty.
assert_eq!(hugr.children(n).count(), 0);
let (new_optype, static_source, static_inport) = match self {
NodeTemplate::SingleOp(op_type) => {
// A bare op has nothing to supply a static input, so reject ops that need one.
if op_type.static_input_port().is_some() {
return Err(ef(BuildError::UnexpectedType {
node: n,
op_desc: "Replacement SingleOp without static input",
}));
}
(op_type, None, None)
}
NodeTemplate::CompoundOp(new_h) => {
let root = new_h.entrypoint_optype();
if !matches!(
root,
OpType::CFG(_) | OpType::DFG(_) | OpType::Conditional(_) | OpType::TailLoop(_)
)
{
return Err(ef(BuildError::UnexpectedType {
node: n,
op_desc: "Replacement CompoundOp not a container/dataflow node",
}));
}
assert!(root.static_input_port().is_none());
// Insert the template beneath `n`, then dissolve its entrypoint: the
// entrypoint's children become children of `n`, and its optype is
// handed back to be installed on `n`.
let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint;
let children = hugr.children(new_entrypoint).collect::<Vec<_>>();
let root_opty = hugr.remove_node(new_entrypoint);
for ch in children {
hugr.set_parent(ch, n);
}
(root_opty, None, None)
}
NodeTemplate::LinkedHugr(mut h, pol) => {
// Walk up from the entrypoint to the ancestor that sits directly below
// the module root of the template.
let mut containing_func = h.entrypoint();
while let Some(parent) = h.get_parent(containing_func)
&& !h.get_optype(parent).is_module()
{
containing_func = parent;
}
// Process every other module-level child of the template before linking;
// the subtree holding the entrypoint is processed after insertion (below).
for ch in h.children(h.module_root()).collect::<Vec<_>>() {
if ch != containing_func {
rt.process_subtree_opts(&mut h, ch, opts)?;
}
}
let new_entrypoint = hugr
.insert_link_hugr(n, *h, &pol)
.map_err(|e| ef(BuildError::from(e)))?
.inserted_entrypoint;
let children = hugr.children(new_entrypoint).collect::<Vec<_>>();
// Remember what feeds the entrypoint's static input (if anything) before
// the entrypoint node is removed.
let static_source = hugr.static_source(new_entrypoint);
let root_opty = hugr.remove_node(new_entrypoint);
let static_inport = root_opty.static_input_port();
for ch in children {
hugr.set_parent(ch, n);
}
(root_opty, static_source, static_inport)
}
};
*hugr.optype_mut(n) = new_optype;
// If the new op has a static input, make room for that port on `n` and
// reconnect the recorded static source to it.
if let Some(static_inport) = static_inport {
hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1);
if let Some(static_source) = static_source {
hugr.connect(static_source, 0, n, static_inport);
}
}
// Finally process the replaced node and everything now beneath it.
rt.process_subtree_opts(hugr, n, opts)?;
Ok(())
}
fn check_signature(
&self,
inputs: &TypeRow,
outputs: &TypeRow,
) -> Result<(), Option<Signature>> {
let sig = match self {
NodeTemplate::SingleOp(op_type) => op_type,
NodeTemplate::CompoundOp(hugr) => hugr.entrypoint_optype(),
NodeTemplate::LinkedHugr(hugr, _) => hugr.entrypoint_optype(),
}
.dataflow_signature();
if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) {
Ok(())
} else {
Err(sig.map(Cow::into_owned))
}
}
}
/// Options controlling how a replacement (type or op) is processed after it
/// has been substituted. The derived `Default` leaves both flags `false`.
#[derive(Clone, Default, PartialEq, Eq)] pub struct ReplacementOptions {
/// If set, the replacement itself is run through the [`ReplaceTypes`] pass
/// again (types are re-transformed; replaced nodes have their subtree changed).
process_recursive: bool,
/// If set, outputs of nodes inside the replacement are linearized even when
/// the pass did not change those nodes.
linearize_unchanged: bool,
}
impl ReplacementOptions {
    /// Options that request recursive processing of the replacement, with
    /// linearization of unchanged nodes turned off.
    fn recursive() -> Self {
        Self {
            process_recursive: true,
            ..Self::default()
        }
    }

    /// Returns these options with the "linearize unchanged nodes" flag set to `lin`.
    pub fn with_linearization(self, lin: bool) -> Self {
        Self {
            linearize_unchanged: lin,
            ..self
        }
    }
}
/// A configurable pass that rewrites extension types, extension ops and
/// extension constants in a Hugr according to registered replacements.
///
/// Replacements can be registered for exact instances (a specific
/// [`CustomType`] or [`ExtensionOp`]) or parametrically (for every instance of
/// a [`TypeDef`] / [`OpDef`], via a callback receiving the type arguments).
#[derive(Clone)]
pub struct ReplaceTypes {
// Exact type instance -> replacement type.
type_map: HashMap<CustomType, (Type, ReplacementOptions)>,
// Type definition -> callback computing the replacement from the (already
// transformed) type arguments; `None` from the callback means "no replacement".
param_types:
HashMap<ParametricType, (Arc<dyn Fn(&[TypeArg]) -> Option<Type>>, ReplacementOptions)>,
// Inserts copy/discard handling for non-copyable outputs that do not have
// exactly one target.
linearize: DelegatingLinearizer,
// Exact op instance -> replacement template.
op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>,
// Op definition -> callback computing the replacement template from the
// (already transformed) type arguments; `Ok(None)` means "no replacement".
param_ops: HashMap<
ParametricOp,
(
Arc<
dyn Fn(
&[TypeArg],
&ReplaceTypes,
) -> Result<Option<NodeTemplate>, ReplaceTypesError>,
>,
ReplacementOptions,
),
>,
// Exact constant type -> callback producing the replacement value.
consts: HashMap<
CustomType,
Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError>>,
>,
// Type definition -> callback producing the replacement value for constants
// of any instance of that definition; `Ok(None)` leaves the constant alone.
param_consts: HashMap<
ParametricType,
Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Option<Value>, ReplaceTypesError>>,
>,
// Where the pass runs: either derived from a `PassScope`, or an explicit
// list of region roots set via `set_regions`.
scope: Either<PassScope, Vec<Node>>,
}
impl Default for ReplaceTypes {
    /// An otherwise-empty pass that uses the default [`DelegatingLinearizer`]
    /// and comes with constant handlers for the standard array and list types.
    fn default() -> Self {
        let mut pass = Self {
            linearize: DelegatingLinearizer::default(),
            ..Self::new_empty()
        };
        pass.replace_consts_parametrized(array_type_def(), handlers::array_const);
        pass.replace_consts_parametrized(list_type_def(), list_const);
        pass
    }
}
impl TypeTransformer for ReplaceTypes {
    type Err = ReplaceTypesError;

    /// Looks up the replacement for the custom type `ct`.
    ///
    /// An exact-instance entry takes precedence over a parametric one. For a
    /// parametric entry the type arguments are transformed first and then
    /// handed to the registered callback. Returns `Ok(None)` if nothing is
    /// registered or the callback declines. If the entry's options request
    /// recursive processing, the replacement type is itself transformed
    /// before being returned.
    fn apply_custom(&self, ct: &CustomType) -> Result<Option<Type>, Self::Err> {
        let found = if let Some(exact) = self.type_map.get(ct) {
            Some(exact.clone())
        } else if let Some((dest_fn, opts)) = self.param_types.get(&ct.into()) {
            let mut nargs = ct.args().to_vec();
            for ta in &mut nargs {
                ta.transform(self)?;
            }
            dest_fn(&nargs).map(|ty| (ty, opts.clone()))
        } else {
            None
        };
        match found {
            None => Ok(None),
            Some((mut ty, opts)) => {
                if opts.process_recursive {
                    ty.transform(self)?;
                }
                Ok(Some(ty))
            }
        }
    }
}
/// An error produced while running the [`ReplaceTypes`] pass.
#[derive(Debug, Error, PartialEq)]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum ReplaceTypesError {
// A signature could not be computed/instantiated (e.g. rebuilding an op or call).
#[error(transparent)]
SignatureError(#[from] SignatureError),
// A constant-related type error.
#[error(transparent)]
ConstError(#[from] ConstTypeError),
// Linearization of a non-copyable output failed.
#[error(transparent)]
LinearizeError(#[from] LinearizeError),
// A replacement template could not be installed on the given node.
#[error("Replacement op for {0} could not be added because {1}")]
AddTemplateError(Node, Box<BuildError>),
}
impl ReplaceTypes {
#[must_use]
pub fn new_empty() -> Self {
Self {
type_map: Default::default(),
param_types: Default::default(),
linearize: DelegatingLinearizer::new_empty(),
op_map: Default::default(),
param_ops: Default::default(),
consts: Default::default(),
param_consts: Default::default(),
scope: Either::Left(PassScope::default()),
}
}
/// Registers `dest` as the replacement for the exact type instance `src`,
/// overwriting any earlier entry for `src`. The replacement is processed
/// recursively.
pub fn set_replace_type(&mut self, src: CustomType, dest: Type) {
    let entry = (dest, ReplacementOptions::recursive());
    self.type_map.insert(src, entry);
}
pub fn set_replace_parametrized_type(
&mut self,
src: &TypeDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static,
) {
self.param_types.insert(
src.into(),
(Arc::new(dest_fn), ReplacementOptions::recursive()),
);
}
pub fn linearizer_mut(&mut self) -> &mut DelegatingLinearizer {
&mut self.linearize
}
pub fn get_linearizer(&self) -> &impl Linearizer {
&self.linearize
}
/// Registers `dest` as the replacement for the exact op instance `src`,
/// overwriting any earlier entry for that op. The replacement is processed
/// recursively.
pub fn set_replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
    let key = OpHashWrapper::from(src);
    let entry = (dest, ReplacementOptions::recursive());
    self.op_map.insert(key, entry);
}
pub fn set_replace_parametrized_op(
&mut self,
src: &OpDef,
dest_fn: impl Fn(&[TypeArg], &ReplaceTypes) -> Result<Option<NodeTemplate>, ReplaceTypesError>
+ 'static,
) {
self.param_ops.insert(
src.into(),
(Arc::new(dest_fn), ReplacementOptions::recursive()),
);
}
/// Registers `const_fn` to rewrite constants whose type is exactly `src_ty`,
/// overwriting any earlier callback for that type.
pub fn replace_consts(
    &mut self,
    src_ty: CustomType,
    const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError> + 'static,
) {
    let handler: Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError>> =
        Arc::new(const_fn);
    self.consts.insert(src_ty, handler);
}
pub fn replace_consts_parametrized(
&mut self,
src_ty: &TypeDef,
const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Result<Option<Value>, ReplaceTypesError>
+ 'static,
) {
self.param_consts.insert(src_ty.into(), Arc::new(const_fn));
}
/// Restricts the pass to the subtrees rooted at the given nodes, replacing
/// any previously configured scope or region list.
pub fn set_regions(&mut self, regions: impl IntoIterator<Item = Node>) {
    let roots: Vec<Node> = regions.into_iter().collect();
    self.scope = Either::Right(roots);
}
/// Post-processes the subtree rooted at `root` as requested by `opts`.
///
/// * With recursive processing, the whole subtree goes through
///   `change_subtree` (linearizing unchanged nodes if `opts` asks for it).
/// * Otherwise, if only linearization is requested, the outputs of every
///   strict descendant of `root` are linearized.
/// * Otherwise nothing happens.
fn process_subtree_opts(
    &self,
    hugr: &mut impl HugrMut<Node = Node>,
    root: Node,
    opts: &ReplacementOptions,
) -> Result<(), ReplaceTypesError> {
    if opts.process_recursive {
        self.change_subtree(hugr, root, opts.linearize_unchanged)?;
        return Ok(());
    }
    if !opts.linearize_unchanged {
        return Ok(());
    }
    let mut descs = hugr.descendants(root);
    // The traversal yields `root` first; skip it.
    assert_eq!(descs.next(), Some(root));
    let below_root = descs.collect::<Vec<_>>();
    for n in below_root {
        self.linearize_outputs(hugr, n)?;
    }
    Ok(())
}
/// Applies the pass to `root` and all of its descendants.
///
/// The set of descendants is captured up front. `root` itself is changed but
/// never linearized; every other node has its outputs linearized if it was
/// changed, or unconditionally when `linearize_unchanged_ops` is set.
///
/// Returns whether any node was changed.
fn change_subtree(
    &self,
    hugr: &mut impl HugrMut<Node = Node>,
    root: Node,
    linearize_unchanged_ops: bool,
) -> Result<bool, ReplaceTypesError> {
    let mut pending = hugr.descendants(root).collect::<Vec<_>>().into_iter();
    assert_eq!(pending.next(), Some(root));
    let mut changed = self.change_node(hugr, root)?;
    for n in pending {
        let node_changed = self.change_node(hugr, n)?;
        changed |= node_changed;
        if node_changed || linearize_unchanged_ops {
            self.linearize_outputs(hugr, n)?;
        }
    }
    Ok(changed)
}
/// Applies the registered replacements to the single node `n`.
///
/// For most op kinds this transforms the types stored in the op in place.
/// Constants are rewritten via `change_value`. Extension ops are either
/// replaced by a registered template, or rebuilt with transformed type
/// arguments when no replacement applies.
///
/// Returns whether the node was changed.
///
/// # Panics
///
/// Panics on an `OpaqueOp`, and (via `todo!`) on any op kind not handled here.
fn change_node(
&self,
hugr: &mut impl HugrMut<Node = Node>,
n: Node,
) -> Result<bool, ReplaceTypesError> {
match hugr.optype_mut(n) {
OpType::FuncDefn(fd) => fd.signature_mut().body_mut().transform(self),
OpType::FuncDecl(fd) => fd.signature_mut().body_mut().transform(self),
OpType::LoadConstant(LoadConstant { datatype: ty })
| OpType::AliasDefn(AliasDefn { definition: ty, .. }) => ty.transform(self),
OpType::ExitBlock(ExitBlock { cfg_outputs: types })
| OpType::Input(Input { types })
| OpType::Output(Output { types }) => types.transform(self),
OpType::LoadFunction(LoadFunction {
func_sig,
type_args,
instantiation,
})
| OpType::Call(Call {
func_sig,
type_args,
instantiation,
}) => {
// Non-short-circuiting `|`: both the signature and the type args must be transformed.
let change = func_sig.body_mut().transform(self)? | type_args.transform(self)?;
if change {
// Keep the cached instantiation consistent with the updated signature/args.
let new_inst = func_sig
.instantiate(type_args)
.map_err(ReplaceTypesError::SignatureError)?;
*instantiation = new_inst;
}
Ok(change)
}
OpType::Case(Case { signature })
| OpType::CFG(CFG { signature })
| OpType::DFG(DFG { signature })
| OpType::CallIndirect(CallIndirect { signature }) => signature.transform(self),
OpType::Tag(Tag { variants, .. }) => variants.transform(self),
OpType::Conditional(Conditional {
other_inputs: row1,
outputs: row2,
sum_rows,
..
})
| OpType::DataflowBlock(DataflowBlock {
inputs: row1,
other_outputs: row2,
sum_rows,
..
}) => Ok(row1.transform(self)? | row2.transform(self)? | sum_rows.transform(self)?),
OpType::TailLoop(TailLoop {
just_inputs,
just_outputs,
rest,
..
}) => Ok(just_inputs.transform(self)?
| just_outputs.transform(self)?
| rest.transform(self)?),
OpType::Const(Const { value, .. }) => self.change_value(value),
OpType::ExtensionOp(ext_op) => Ok({
let def = ext_op.def_arc();
let mut changed = false;
// An exact-instance replacement takes precedence over a parametric one.
let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
r @ Some(_) => r.cloned(),
None => {
// Transform the type args first; the parametric callback sees the new args.
let mut args = ext_op.args().to_vec();
changed = args.transform(self)?;
let r2 = match self.param_ops.get(&def.as_ref().into()) {
None => None,
Some((rep_fn, opts)) => {
rep_fn(&args, self)?.map(|nt| (nt, opts.clone()))
}
};
// No replacement, but the args changed: rebuild the same op with the new args.
if r2.is_none() && changed {
*ext_op = ExtensionOp::new(def.clone(), args)?;
}
r2
}
};
if let Some((replacement, opts)) = replacement {
replacement.replace(hugr, n, self, &opts)?;
true
} else {
changed
}
}),
OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"),
// These ops carry nothing for this pass to transform.
OpType::AliasDecl(_) | OpType::Module(_) => Ok(false),
_ => todo!(),
}
}
/// Applies the registered constant replacements to `value`, in place.
///
/// Sum values are handled recursively: each contained value is changed, then
/// the sum type itself is transformed. Extension values are replaced if a
/// callback is registered for their exact type, or else for their type
/// definition (a parametric callback may decline by returning `Ok(None)`).
///
/// Returns whether `value` was changed.
///
/// # Errors
///
/// Propagates errors from type transformation and from the registered callbacks.
pub fn change_value(&self, value: &mut Value) -> Result<bool, ReplaceTypesError> {
match value {
Value::Sum(Sum {
values, sum_type, ..
}) => {
let mut any_change = false;
for value in values {
any_change |= self.change_value(value)?;
}
any_change |= sum_type.transform(self)?;
Ok(any_change)
}
Value::Extension { e } => Ok({
// `Some(result)` if a callback produced a replacement (or failed), `None` if
// no callback applies or the parametric callback declined.
let new_const = match e.get_type().as_type_enum() {
TypeEnum::Extension(exty) => match self.consts.get(exty) {
Some(const_fn) => Some(const_fn(e, self)),
None => self
.param_consts
.get(&exty.into())
.and_then(|const_fn| const_fn(e, self).transpose()),
},
_ => None,
};
if let Some(new_const) = new_const {
*value = new_const?;
true
} else {
false
}
}),
}
}
/// Fixes up the outputs of `n` whose type is not copyable.
///
/// Any such output that is not connected to exactly one input is disconnected
/// and handed to the linearizer, which inserts the required copy/discard
/// handling towards the former targets. Nodes without a dataflow signature
/// are left untouched.
fn linearize_outputs<H: HugrMut<Node = Node>>(
    &self,
    hugr: &mut H,
    n: H::Node,
) -> Result<(), LinearizeError> {
    let Some(sig) = hugr.get_optype(n).dataflow_signature() else {
        return Ok(());
    };
    // Take ownership so the borrow of `hugr` ends before it is mutated below.
    let sig = sig.into_owned();
    for outp in sig.output_ports() {
        if sig.out_port_type(outp).unwrap().copyable() {
            continue;
        }
        let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
        if targets.len() == 1 {
            continue;
        }
        hugr.disconnect(n, outp);
        let src = Wire::new(n, outp);
        self.linearize.insert_copy_discard(hugr, src, &targets)?;
    }
    Ok(())
}
}
impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
    type Error = ReplaceTypesError;
    type Result = bool;

    /// Runs the pass over every configured region of `hugr`.
    ///
    /// The regions are either the explicit list given to `set_regions`, or
    /// whatever root the configured [`PassScope`] yields for `hugr`.
    /// Returns whether anything was changed.
    fn run(&self, hugr: &mut H) -> Result<bool, ReplaceTypesError> {
        // Storage for the scope-derived root(s); only initialised when needed.
        let scoped_roots: Vec<Node>;
        let regions: &[Node] = match &self.scope {
            Either::Left(scope) => {
                scoped_roots = Vec::from_iter(scope.root(hugr));
                &scoped_roots
            }
            Either::Right(explicit) => explicit,
        };
        let mut changed = false;
        for &region_root in regions {
            changed |= self.change_subtree(hugr, region_root, false)?;
        }
        Ok(changed)
    }
}
impl WithScope for ReplaceTypes {
    /// Returns this pass configured to run over `scope`, discarding any
    /// previously configured scope or explicit region list.
    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
        Self {
            scope: Either::Left(scope.into()),
            ..self
        }
    }
}
pub mod handlers;
/// Hashable key identifying one exact instance of an extension op: its
/// qualified name together with its type arguments.
#[derive(Clone, Hash, PartialEq, Eq)]
struct OpHashWrapper {
    /// Qualified id of the op, rendered as a string.
    op_name: String,
    /// Type arguments of this instance.
    args: Vec<TypeArg>,
}

impl From<&ExtensionOp> for OpHashWrapper {
    fn from(op: &ExtensionOp) -> Self {
        let op_name = op.qualified_id().to_string();
        let args = op.args().to_vec();
        Self { op_name, args }
    }
}
/// Hashable key identifying a type definition by extension id and type name,
/// irrespective of type arguments.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ParametricType(ExtensionId, String);

impl From<&TypeDef> for ParametricType {
    fn from(value: &TypeDef) -> Self {
        let extension = value.extension_id().clone();
        let name = value.name().to_string();
        Self(extension, name)
    }
}

impl From<&CustomType> for ParametricType {
    /// Keys a concrete type instance by its extension and name, dropping its arguments.
    fn from(value: &CustomType) -> Self {
        let extension = value.extension().clone();
        let name = value.name().to_string();
        Self(extension, name)
    }
}
/// Hashable key identifying an op definition by extension id and op name,
/// irrespective of type arguments.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ParametricOp(ExtensionId, String);

impl From<&OpDef> for ParametricOp {
    fn from(value: &OpDef) -> Self {
        let extension = value.extension_id().clone();
        let name = value.name().to_string();
        Self(extension, name)
    }
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use crate::passes::replace_types::handlers::generic_array_const;
use hugr_core::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, endo_sig,
inout_sig,
};
use hugr_core::extension::SignatureError;
use hugr_core::extension::prelude::{
ConstUsize, Noop, UnwrapBuilder, bool_t, option_type, qb_t, usize_t,
};
use hugr_core::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use hugr_core::extension::{TypeDefBound, Version, simple_op::MakeExtensionOp};
use hugr_core::hugr::{IdentList, ValidationError, hugrmut::HugrMut};
use hugr_core::ops::constant::{CustomConst, OpaqueValue};
use hugr_core::ops::{self, ExtensionOp, OpTrait, OpType, Tag, Value, handle::NodeHandle};
use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef;
use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use hugr_core::std_extensions::collections::array::{
self, Array, ArrayKind, ArrayOpDef, GenericArrayValue, array_type, array_type_def,
};
use hugr_core::std_extensions::collections::borrow_array::{
BArrayValue, BorrowArray, borrow_array_type,
};
use hugr_core::std_extensions::collections::list::{
ListOp, ListOpInst, ListValue, list_type, list_type_def,
};
use hugr_core::types::{
EdgeKind, PolyFuncType, Signature, SumType, Term, Type, TypeArg, TypeBound, TypeRow,
};
use hugr_core::{Direction, Extension, HugrView, Port, Visibility, type_row};
use itertools::Itertools;
use rstest::rstest;
use crate::passes::{ComposablePass, mangle_name};
use super::{NodeTemplate, ReplaceTypes, handlers::list_const};
// Name of the parametric collection type declared by the test extension.
const PACKED_VEC: &str = "PackedVec";
// Name of the element-read op declared by the test extensions.
const READ: &str = "read";
/// The integer type at index 6 of `INT_TYPES`, used as the 64-bit integer type in these tests.
fn i64_t() -> Type {
    let ty = &INT_TYPES[6];
    ty.clone()
}
/// Instantiates the `read` op of `ext` with element type `t`.
fn read_op(ext: &Arc<Extension>, t: Type) -> ExtensionOp {
    let def = ext.get_op(READ).unwrap().clone();
    ExtensionOp::new(def, [t.into()]).unwrap()
}
/// Extracts the element type from a type-argument list that must consist of
/// exactly one runtime type; panics otherwise.
fn just_elem_type(args: &[TypeArg]) -> &Type {
    match args {
        [TypeArg::Runtime(ty)] => ty,
        _ => panic!("Expected just elem type"),
    }
}
/// Builds the `TestExt` extension used by these tests. It declares:
/// * the `PackedVec` type, taking one type parameter;
/// * the polymorphic `read` op, from `(PackedVec<T>, i64)` to `T`;
/// * the monomorphic `lowered_read_bool` op, from `(i64, i64)` to `bool`.
fn ext() -> Arc<Extension> {
Extension::new_arc(
IdentList::new("TestExt").unwrap(),
Version::new(0, 0, 1),
|ext, w| {
// `PackedVec` instantiated with type variable 0, for use in the signature of `read`.
let pv_of_var = ext
.add_type(
PACKED_VEC.into(),
vec![TypeBound::Linear.into()],
String::new(),
TypeDefBound::from_params(vec![0]),
w,
)
.unwrap()
.instantiate(vec![Type::new_var_use(0, TypeBound::Copyable).into()])
.unwrap();
ext.add_op(
READ.into(),
String::new(),
PolyFuncType::new(
vec![TypeBound::Copyable.into()],
Signature::new(
vec![pv_of_var.into(), i64_t()],
[Type::new_var_use(0, TypeBound::Linear)],
),
),
w,
)
.unwrap();
ext.add_op(
"lowered_read_bool".into(),
String::new(),
Signature::new(vec![i64_t(); 2], [bool_t()]),
w,
)
.unwrap();
},
)
}
/// Builds the lowered form of `read` for element type `elem_ty` inside a
/// dataflow container created by `new`.
///
/// The container has signature `(list<elem_ty>, i64) -> elem_ty`. It converts
/// the index with `itousize`, performs a list `get`, and unwraps variant 1 of
/// the resulting option. The builder is returned with its outputs set but not
/// yet finished, so callers decide how to finish it.
fn lowered_read<T: Container + Dataflow>(
elem_ty: Type,
new: impl Fn(Signature) -> Result<T, BuildError>,
) -> T {
let mut dfb = new(Signature::new(
[list_type(elem_ty.clone()), i64_t()],
[elem_ty.clone()],
))
.unwrap();
let [val, idx] = dfb.input_wires_arr();
let [idx] = dfb
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx])
.unwrap()
.outputs_arr();
let [opt] = dfb
.add_dataflow_op(
ListOp::get
.with_type(elem_ty.clone())
.to_extension_op()
.unwrap(),
[val, idx],
)
.unwrap()
.outputs_arr();
let [res] = dfb
.build_unwrap_sum(1, option_type([Type::from(elem_ty)]), opt)
.unwrap();
dfb.set_outputs([res]).unwrap();
dfb
}
/// Creates the `ReplaceTypes` configuration shared by several tests:
/// * `PackedVec<bool>` is replaced by `i64`;
/// * any other `PackedVec<T>` is replaced by `list<T>`;
/// * `read<bool>` is replaced by the single op `lowered_read_bool`;
/// * any other `read<T>` is replaced by a DFG built with `lowered_read`.
fn lowerer(ext: &Arc<Extension>) -> ReplaceTypes {
let pv = ext.get_type(PACKED_VEC).unwrap();
let mut lw = ReplaceTypes::default();
lw.set_replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t());
lw.set_replace_parametrized_type(
pv,
Box::new(|args: &[TypeArg]| Some(list_type(just_elem_type(args).clone()))),
);
lw.set_replace_op(
&read_op(ext, bool_t()),
NodeTemplate::SingleOp(
ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), [])
.unwrap()
.into(),
),
);
lw.set_replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args, _| {
Ok(Some(NodeTemplate::CompoundOp(Box::new(
lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new)
.finish_hugr()
.unwrap(),
))))
});
lw
}
// Lowers a module containing a polymorphic `id` function and a `main` that
// calls `id`, reads from a `PackedVec<i64>`, and then (inside a CFG block)
// calls `id` and reads from a `PackedVec<bool>`. Checks which extension ops
// remain after the pass.
#[test]
fn module_func_cfg_call() {
let ext = ext();
let coln = ext.get_type(PACKED_VEC).unwrap();
let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap());
let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap());
let mut mb = ModuleBuilder::new();
// Polymorphic identity function.
let sig = Signature::new_endo([Type::new_var_use(0, TypeBound::Linear)]);
let fb = mb
.define_function("id", PolyFuncType::new([TypeBound::Linear.into()], sig))
.unwrap();
let inps = fb.input_wires();
let id = fb.finish_with_outputs(inps).unwrap();
let sig = Signature::new([i64_t(), c_int.clone(), c_bool.clone()], [bool_t()]);
let mut fb = mb.define_function("main", sig).unwrap();
let [idx, indices, bools] = fb.input_wires_arr();
let [indices] = fb
.call(id.handle(), &[c_int.into()], [indices])
.unwrap()
.outputs_arr();
let [idx2] = fb
.add_dataflow_op(read_op(&ext, i64_t()), [indices, idx])
.unwrap()
.outputs_arr();
let mut cfg = fb
.cfg_builder(
[(i64_t(), idx2), (c_bool.clone(), bools)],
[bool_t()].into(),
)
.unwrap();
let mut entry = cfg.entry_builder([[bool_t()].into()], type_row![]).unwrap();
let [idx2, bools] = entry.input_wires_arr();
let [bools] = entry
.call(id.handle(), &[c_bool.into()], [bools])
.unwrap()
.outputs_arr();
let bool_read_op = entry
.add_dataflow_op(read_op(&ext, bool_t()), [bools, idx2])
.unwrap();
let [tagged] = entry
.add_dataflow_op(
OpType::Tag(Tag::new(0, vec![[bool_t()].into()])),
bool_read_op.outputs(),
)
.unwrap()
.outputs_arr();
let entry = entry.finish_with_outputs(tagged, []).unwrap();
cfg.branch(&entry, 0, &cfg.exit_block()).unwrap();
let cfg = cfg.finish_sub_container().unwrap();
fb.finish_with_outputs(cfg.outputs()).unwrap();
let mut h = mb.finish_hugr().unwrap();
assert!(lowerer(&ext).run(&mut h).unwrap());
// After lowering, only the ops of the lowered forms should remain.
let ext_ops = h
.entry_descendants()
.filter_map(|n| h.get_optype(n).as_extension_op());
assert_eq!(
ext_ops
.map(hugr_core::ops::ExtensionOp::unqualified_id)
.sorted()
.collect_vec(),
["get", "itousize", "lowered_read_bool", "panic",]
);
}
// Lowers a DFG containing a Conditional whose sum rows, other inputs and
// outputs all mention `PackedVec` types (including a nested
// `PackedVec<PackedVec<bool>>`), with a `read` inside one case.
#[test]
fn dfg_conditional_case() {
let ext = ext();
let coln = ext.get_type(PACKED_VEC).unwrap();
let pv = |t: Type| Type::new_extension(coln.instantiate([t.into()]).unwrap());
let sum_rows = [[pv(pv(bool_t())), i64_t()].into(), [pv(i64_t())].into()];
let mut dfb = DFGBuilder::new(inout_sig(
vec![Type::new_sum(sum_rows.clone()), pv(bool_t()), pv(i64_t())],
vec![pv(bool_t()), pv(i64_t())],
))
.unwrap();
let [sum, vb, vi] = dfb.input_wires_arr();
let mut cb = dfb
.conditional_builder(
(sum_rows, sum),
[(pv(bool_t()), vb), (pv(i64_t()), vi)],
vec![pv(bool_t()), pv(i64_t())].into(),
)
.unwrap();
// Case 0 reads a `PackedVec<bool>` out of the `PackedVec<PackedVec<bool>>`.
let mut case0 = cb.case_builder(0).unwrap();
let [vvb, i, _, vi0] = case0.input_wires_arr();
let [vb0] = case0
.add_dataflow_op(read_op(&ext, pv(bool_t())), [vvb, i])
.unwrap()
.outputs_arr();
case0.finish_with_outputs([vb0, vi0]).unwrap();
// Case 1 only forwards its inputs.
let case1 = cb.case_builder(1).unwrap();
let [vi, vb1, _vi1] = case1.input_wires_arr();
case1.finish_with_outputs([vb1, vi]).unwrap();
let cond = cb.finish_sub_container().unwrap();
let mut h = dfb.finish_hugr_with_outputs(cond.outputs()).unwrap();
lowerer(&ext).run(&mut h).unwrap();
let ext_ops = h
.entry_descendants()
.filter_map(|n| h.get_optype(n).as_extension_op())
.collect_vec();
assert_eq!(
ext_ops
.iter()
.map(|x| x.unqualified_id())
.sorted()
.collect_vec(),
["get", "itousize", "panic"]
);
// The single list `get` must be over `i64`, i.e. the inner `PackedVec<bool>`
// was itself replaced (by `i64`) inside the replacement.
let array_gets = ext_ops
.into_iter()
.filter_map(|e| ListOpInst::from_extension_op(e).ok())
.collect_vec();
assert_eq!(array_gets, [ListOp::get.with_type(i64_t())]);
}
// Exercises type and constant replacement on a TailLoop that loads a constant
// sum containing a list of `usize`, in three stages with progressively more
// replacements registered on the same `ReplaceTypes`.
#[test]
fn loop_const() {
let cu = |u| ConstUsize::new(u).into();
let mut tl = TailLoopBuilder::new(
[list_type(usize_t())],
[list_type(bool_t())],
[list_type(usize_t())],
)
.unwrap();
let [_, bools] = tl.input_wires_arr();
let st = SumType::new(vec![[list_type(usize_t())]; 2]);
let pred = tl.add_load_value(
Value::sum(
0,
[ListValue::new(usize_t(), [cu(1), cu(3), cu(3), cu(7)]).into()],
st,
)
.unwrap(),
);
tl.set_outputs(pred, [bools]).unwrap();
let backup = tl.finish_hugr().unwrap();
let mut lowerer = ReplaceTypes::default();
// Stage 1: lists become borrow arrays, except lists of `usize` or `i64`.
lowerer.set_replace_parametrized_type(list_type_def(), |args| {
let ty = just_elem_type(args);
(![usize_t(), i64_t()].contains(ty)).then_some(borrow_array_type(10, ty.clone()))
});
{
let mut h = backup.clone();
assert_eq!(lowerer.run(&mut h), Ok(true));
let sig = h.signature(h.entrypoint()).unwrap();
assert_eq!(
sig.input(),
&TypeRow::from(vec![list_type(usize_t()), borrow_array_type(10, bool_t())])
);
assert_eq!(sig.input(), sig.output());
}
// Stage 2: additionally replace `usize` (type and constants) by `i64`.
let usize_custom_t = usize_t().as_extension().unwrap().clone();
lowerer.set_replace_type(usize_custom_t.clone(), i64_t());
lowerer.replace_consts(usize_custom_t, |opaq, _| {
Ok(ConstInt::new_u(
6,
opaq.value().downcast_ref::<ConstUsize>().unwrap().value(),
)
.unwrap()
.into())
});
{
let mut h = backup.clone();
assert_eq!(lowerer.run(&mut h), Ok(true));
let sig = h.signature(h.entrypoint()).unwrap();
assert_eq!(
sig.input(),
&TypeRow::from(vec![list_type(i64_t()), borrow_array_type(10, bool_t())])
);
assert_eq!(sig.input(), sig.output());
let cst = h
.entry_descendants()
.filter_map(|n| h.get_optype(n).as_const())
.exactly_one()
.ok()
.unwrap();
assert_eq!(cst.get_type(), Type::new_sum(vec![[list_type(i64_t())]; 2]));
}
// Stage 3: replace every list by a borrow array of size 4, and convert list
// constants into borrow-array constants after the default list handler ran.
let mut h = backup;
lowerer.set_replace_parametrized_type(
list_type_def(),
Box::new(|args: &[TypeArg]| Some(borrow_array_type(4, just_elem_type(args).clone()))),
);
lowerer.replace_consts_parametrized(list_type_def(), |opaq, repl| {
let Some(Value::Extension { e: opaq }) = list_const(opaq, repl)? else {
panic!("Expected list value to stay a list value");
};
let lv = opaq.value().downcast_ref::<ListValue>().unwrap();
Ok(Some(
BArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into(),
))
});
lowerer.run(&mut h).unwrap();
assert_eq!(
h.get_optype(pred.node())
.as_load_constant()
.map(hugr_core::ops::LoadConstant::constant_type),
Some(&Type::new_sum(vec![
[Type::from(borrow_array_type(
4,
i64_t()
))];
2
]))
);
}
// Checks parametric replacements that decline some instances: callbacks that
// return `None` for non-option element types must leave those types/ops in
// place (apart from transforming their arguments).
#[test]
fn partial_replace() {
// A separate extension with a `read` op over standard lists.
let e = Extension::new_arc(
IdentList::new_unchecked("NoBoundsCheck"),
Version::new(0, 0, 0),
|e, w| {
let params = vec![TypeBound::Linear.into()];
let tv = Type::new_var_use(0, TypeBound::Linear);
let list_of_var = list_type(tv.clone());
e.add_op(
READ.into(),
"Like List::get but without the option".to_string(),
PolyFuncType::new(params, Signature::new([list_of_var, usize_t()], [tv])),
w,
)
.unwrap();
},
);
// Returns the single element of variant 1 if `ty` is a sum type, else `None`.
fn option_contents(ty: &Type) -> Option<Type> {
let row = ty.as_sum()?.get_variant(1).unwrap().clone();
let elem = row.into_owned().into_iter().exactly_one().unwrap();
Some(elem.try_into_type().unwrap())
}
let i32_t = || INT_TYPES[5].clone();
let opt_i32 = Type::from(option_type([i32_t()]));
let i32_custom_t = i32_t().as_extension().unwrap().clone();
let mut dfb = DFGBuilder::new(inout_sig(
vec![list_type(i32_t()), list_type(opt_i32.clone())],
vec![i32_t(), opt_i32.clone()],
))
.unwrap();
let [l_i, l_oi] = dfb.input_wires_arr();
let idx = dfb.add_load_value(ConstUsize::new(2));
let [i] = dfb
.add_dataflow_op(read_op(&e, i32_t()), [l_i, idx])
.unwrap()
.outputs_arr();
let [oi] = dfb
.add_dataflow_op(read_op(&e, opt_i32.clone()), [l_oi, idx])
.unwrap()
.outputs_arr();
let mut h = dfb.finish_hugr_with_outputs([i, oi]).unwrap();
let mut lowerer = ReplaceTypes::default();
lowerer.set_replace_type(i32_custom_t, qb_t());
// Only lists (and reads) of option types are replaced.
lowerer.set_replace_parametrized_type(list_type_def(), |args| {
option_contents(just_elem_type(args)).map(list_type)
});
lowerer.set_replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), |args, _| {
Ok(option_contents(just_elem_type(args)).map(|elem| {
NodeTemplate::SingleOp(
ListOp::get
.with_type(elem)
.to_extension_op()
.unwrap()
.into(),
)
}))
});
assert!(lowerer.run(&mut h).unwrap());
assert_eq!(
h.entrypoint_optype().dataflow_signature().unwrap().io(),
(
&vec![list_type(qb_t()); 2].into(),
&vec![qb_t(), option_type([qb_t()]).into()].into()
)
);
// One `read` stays (with updated type args); the other became a list `get`.
assert_eq!(
h.entry_descendants()
.filter_map(|n| h.get_optype(n).as_extension_op())
.map(hugr_core::ops::ExtensionOp::qualified_id)
.sorted()
.collect_vec(),
["NoBoundsCheck.read", "collections.list.get"]
);
}
// Checks that array constants need a registered constant handler: starting
// from an empty `ReplaceTypes` that only maps `usize` to an integer type, the
// result fails validation until `generic_array_const` is registered for the
// array kind under test.
#[rstest]
#[case(&[], Array)]
#[case(&[3], Array)]
#[case(&[5,7,11,13,17,19], BorrowArray)]
fn array_const<AK: ArrayKind>(#[case] vals: &[u64], #[case] _kind: AK)
where
GenericArrayValue<AK>: CustomConst,
{
let mut dfb =
DFGBuilder::new(inout_sig(type_row![], [AK::ty(vals.len() as _, usize_t())])).unwrap();
let c = dfb.add_load_value(GenericArrayValue::<AK>::new(
usize_t(),
vals.iter().map(|u| ConstUsize::new(*u).into()),
));
let backup = dfb.finish_hugr_with_outputs([c]).unwrap();
let mut repl = ReplaceTypes::new_empty();
let usize_custom_t = usize_t().as_extension().unwrap().clone();
repl.set_replace_type(usize_custom_t.clone(), INT_TYPES[6].clone());
repl.replace_consts(usize_custom_t, |cst: &OpaqueValue, _| {
let cu = cst.value().downcast_ref::<ConstUsize>().unwrap();
Ok(ConstInt::new_u(6, cu.value())?.into())
});
// Without an array constant handler, the constant and its load disagree.
let mut h = backup.clone();
repl.run(&mut h).unwrap(); assert!(
matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..})
if backup.get_optype(from).is_const() && to == c.node())
);
// With the handler registered, the lowered Hugr validates.
repl.replace_consts_parametrized(AK::type_def(), generic_array_const::<AK>);
let mut h = backup;
repl.run(&mut h).unwrap();
h.validate().unwrap();
}
// Replaces `read` ops by calls to a single polymorphic `lowered_read`
// function that already exists in the target Hugr; the replacement template
// only carries a declaration of that function.
#[rstest]
fn op_to_call_polymorphic() {
let e = ext();
let pv = e.get_type(PACKED_VEC).unwrap();
let inner = pv.instantiate([usize_t().into()]).unwrap();
let outer = pv
.instantiate([Type::new_extension(inner.clone()).into()])
.unwrap();
let mut dfb = DFGBuilder::new(inout_sig([outer.into(), i64_t()], [usize_t()])).unwrap();
// Add the public, polymorphic `lowered_read` definition to the module.
let read_func = dfb
.module_root_builder()
.add_hugr(
lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| {
FunctionBuilder::new_vis(
"lowered_read",
PolyFuncType::new([TypeBound::Copyable.into()], sig),
Visibility::Public,
)
})
.finish_hugr()
.unwrap(),
)
.inserted_entrypoint;
let [outer, idx] = dfb.input_wires_arr();
let [inner] = dfb
.add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx])
.unwrap()
.outputs_arr();
let res = dfb
.add_dataflow_op(read_op(&e, usize_t()), [inner, idx])
.unwrap();
let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap();
let read_poly = h
.get_optype(read_func)
.as_func_defn()
.unwrap()
.signature()
.clone();
let mut lw = lowerer(&e);
// Each `read` becomes a call to a declaration named `lowered_read`.
lw.set_replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| {
let mut decl_b = ModuleBuilder::new();
let decl_node = decl_b.declare("lowered_read", read_poly.clone()).unwrap();
let mut decl_hugr = decl_b.finish_hugr().unwrap();
decl_hugr.set_entrypoint(decl_node.node());
Ok(Some(
NodeTemplate::call_to_function(decl_hugr, args).unwrap(),
))
});
lw.run(&mut h).unwrap();
h.validate().unwrap();
// Both calls must target the existing definition, and no extension ops remain.
assert_eq!(h.output_neighbours(read_func).count(), 2);
assert_eq!(
h.entry_descendants()
.find(|n| h.get_optype(*n).is_extension_op()),
None
);
assert_eq!(h.children(h.module_root()).count(), 2); }
// Replaces `read` ops by calls to monomorphic `lowered_read` functions whose
// names are mangled from the type arguments; the replacement template carries
// the full definition each time. Optionally also replaces `i64` by `usize`.
#[rstest]
fn op_to_call_monomorphic(#[values(false, true)] i64_to_usize: bool) {
let e = ext();
let pv = e.get_type(PACKED_VEC).unwrap();
let inner = pv.instantiate([usize_t().into()]).unwrap();
let outer = pv
.instantiate([Type::new_extension(inner.clone()).into()])
.unwrap();
let read_outer = read_op(&e, inner.clone().into());
let mut dfb = DFGBuilder::new(inout_sig(
vec![outer.into(), inner.clone().into(), i64_t()],
vec![usize_t(); 2],
))
.unwrap();
let [outer, inner, idx] = dfb.input_wires_arr();
let res1 = dfb
.add_dataflow_op(read_op(&e, usize_t()), [inner, idx])
.unwrap();
let [inner] = dfb
.add_dataflow_op(read_outer, [outer, idx])
.unwrap()
.outputs_arr();
let res2 = dfb
.add_dataflow_op(read_op(&e, usize_t()), [inner, idx])
.unwrap();
let mut h = dfb
.finish_hugr_with_outputs(res1.outputs().chain(res2.outputs()))
.unwrap();
let mut lw = lowerer(&e);
lw.set_replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| {
Ok(Some({
let [Term::Runtime(ty)] = args else {
return Err(SignatureError::InvalidTypeArgs.into());
};
let defn_hugr = lowered_read(ty.clone(), |sig| {
FunctionBuilder::new_vis(
mangle_name("lowered_read", args),
sig,
Visibility::Public,
)
})
.finish_hugr()
.unwrap();
NodeTemplate::call_to_function(defn_hugr, &[]).unwrap()
}))
});
if i64_to_usize {
lw.set_replace_type(i64_t().as_extension().unwrap().clone(), usize_t());
lw.set_replace_op(
&ConvertOpDef::itousize
.without_log_width()
.to_extension_op()
.unwrap(),
NodeTemplate::SingleOp(Noop::new(usize_t()).into()),
);
}
lw.run(&mut h).unwrap();
h.validate().unwrap();
assert_eq!(
h.entry_descendants()
.find(|n| h.get_optype(*n).is_extension_op()),
None
);
// Expect `main` plus one lowered function per distinct element type.
assert_eq!(h.children(h.module_root()).count(), 3); for n in h.children(h.module_root()) {
let fd = h.get_optype(n).as_func_defn().unwrap();
let expected_uses_and_vis = if fd.func_name() == "main" {
(0, Visibility::Private)
} else {
// The function whose output type has type args is used once; the other twice.
let is_array = !fd.signature().body().output[0]
.as_extension()
.unwrap()
.args()
.is_empty();
(2 - (is_array as usize), Visibility::Public)
};
assert_eq!(h.output_neighbours(n).count(), expected_uses_and_vis.0);
assert_eq!(fd.visibility(), &expected_uses_and_vis.1);
}
}
/// Checks that the lowerer only touches the regions it is told to.
///
/// * With an empty region list, `run` reports no change and the Hugr is
///   equal to a clone taken beforehand.
/// * With the entrypoint as the only region, `run` reports a change and the
///   entrypoint's signature uses the list type, while the node count stays
///   the same. Validation then fails with `IncompatiblePorts` on the edge
///   from the parent's input node into the entrypoint, because the source
///   port still has the original type.
#[test]
fn regions() {
    let ext = ext();
    let coln = ext.get_type(PACKED_VEC).unwrap();
    let c_u = Type::new_extension(coln.instantiate(&[usize_t().into()]).unwrap());
    let mut h = {
        let db = DFGBuilder::new(endo_sig([c_u.clone()])).unwrap();
        let inps = db.input_wires();
        db.finish_hugr_with_outputs(inps)
    }
    .unwrap();
    let mut lowerer = lowerer(&ext);
    {
        // No regions selected: the pass must be a no-op.
        let backup = h.clone();
        lowerer.set_regions(vec![]);
        assert!(!lowerer.run(&mut h).unwrap());
        assert_eq!(h, backup);
    }
    let ep = h.entrypoint();
    // Record the node count *before* the pass, so the comparison below is not
    // a tautology.
    let nodes_before = h.num_nodes();
    lowerer.set_regions(vec![ep]);
    assert!(lowerer.run(&mut h).unwrap());
    let v_u = list_type(usize_t());
    assert_eq!(h.signature(ep).unwrap().as_ref(), &endo_sig([v_u.clone()]));
    // Only types inside the region were rewritten; no nodes added or removed.
    assert_eq!(h.num_nodes(), nodes_before);
    // The edge from the parent's input node into the entrypoint now has
    // mismatched types at its two ends, so the Hugr as a whole is invalid.
    let [f_in, _] = h.get_io(h.get_parent(ep).unwrap()).unwrap();
    assert_eq!(
        h.validate(),
        Err(ValidationError::IncompatiblePorts {
            from: f_in,
            from_port: Port::new(Direction::Outgoing, 0),
            to: ep,
            to_port: Port::new(Direction::Incoming, 0),
            from_kind: Box::new(EdgeKind::Value(c_u)),
            to_kind: Box::new(EdgeKind::Value(v_u))
        })
    );
}
/// Stacks two extra replacements on top of the base `lowerer`:
///
/// * the array type with size argument 64 is replaced by `PACKED_VEC` of the
///   same element type;
/// * the array `get` op with size argument 64 is replaced by a DFG that
///   converts the index with `ifromusize`, performs a `READ`, and wraps the
///   element with tag 1 of the option sum.
///
/// The test then builds a Hugr using array `get` on arrays of `bool` and of
/// `usize`, runs the lowerer, and requires the result to validate.
#[test]
fn compositionality() {
    let ext = ext();
    let mut lowerer = lowerer(&ext);

    // Type replacement: array of size 64 becomes a packed vec.
    let ext_for_type = ext.clone();
    lowerer.set_replace_parametrized_type(array_type_def(), move |args| {
        let [sz, ty] = args else {
            panic!("Expected two args to array")
        };
        let is_len_64 = sz == &Term::BoundedNat(64);
        is_len_64.then_some(
            ext_for_type
                .get_type(PACKED_VEC)
                .unwrap()
                .instantiate([ty.clone()])
                .unwrap()
                .into(),
        )
    });

    // Op replacement: array `get` on size 64 becomes a small DFG around `READ`.
    let ext_for_op = ext.clone();
    lowerer.set_replace_parametrized_op(
        array::EXTENSION
            .get_op(ArrayOpDef::get.opdef_id().as_str())
            .unwrap()
            .as_ref(),
        move |args, _| {
            let [sz, Term::Runtime(ty)] = args else {
                panic!("Expected two args to array-get")
            };
            // Other sizes are left for any other replacement to handle.
            if sz != &Term::BoundedNat(64) {
                return Ok(None);
            }
            let packed_ty = ext_for_op
                .get_type(PACKED_VEC)
                .unwrap()
                .instantiate([ty.clone().into()])
                .unwrap();
            let mut body = DFGBuilder::new(Signature::new(
                vec![packed_ty.clone().into(), usize_t()],
                vec![option_type([ty.clone()]).into(), packed_ty.into()],
            ))
            .unwrap();
            let [vec_in, usize_idx] = body.input_wires_arr();
            let [int_idx] = body
                .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [usize_idx])
                .unwrap()
                .outputs_arr();
            let [elem] = body
                .add_dataflow_op(read_op(&ext_for_op, ty.clone()), [vec_in, int_idx])
                .unwrap()
                .outputs_arr();
            let [wrapped_elem] = body
                .add_dataflow_op(
                    ops::Tag::new(1, vec![type_row![], [ty.clone()].into()]),
                    [elem],
                )
                .unwrap()
                .outputs_arr();
            let replacement = body
                .finish_hugr_with_outputs([wrapped_elem, vec_in])
                .unwrap();
            Ok(Some(NodeTemplate::CompoundOp(Box::new(replacement))))
        },
    );

    // Shorthands for the types used in the Hugr under test.
    let arr64 = |t| array_type(64, t);
    let optional = |t| Type::from(option_type([t]));

    let mut builder = DFGBuilder::new(Signature::new(
        vec![arr64(bool_t()), arr64(usize_t())],
        vec![
            optional(bool_t()),
            arr64(bool_t()),
            optional(usize_t()),
            arr64(usize_t()),
        ],
    ))
    .unwrap();
    let [bool_arr, usize_arr] = builder.input_wires_arr();
    let index = builder.add_load_value(ConstUsize::new(5));
    let [bool_elem, bool_arr] = builder
        .add_dataflow_op(ArrayOpDef::get.to_concrete(bool_t(), 64), [bool_arr, index])
        .unwrap()
        .outputs_arr();
    let [usize_elem, usize_arr] = builder
        .add_dataflow_op(
            ArrayOpDef::get.to_concrete(usize_t(), 64),
            [usize_arr, index],
        )
        .unwrap()
        .outputs_arr();
    let mut h = builder
        .finish_hugr_with_outputs([bool_elem, bool_arr, usize_elem, usize_arr])
        .unwrap();

    lowerer.run(&mut h).unwrap();
    h.validate().unwrap();
}
}