use crate::resource::types::ResourceId;
use derive_more::derive::{Display, Error};
use hugr::HugrView;
use hugr::ops::{OpTrait, OpType};
use hugr::types::Type;
use itertools::{EitherOrBoth, Itertools};
/// Error that [`ResourceFlow::map_resources`] may return to signal that it
/// cannot map resources through a node's operation.
///
/// Wraps the offending [`OpType`] and displays as
/// `Unsupported operation: <op>`. The wrapped op is explicitly not treated as
/// an error `source`.
#[derive(Debug, Display, Clone, PartialEq, Error)]
#[display("Unsupported operation: {_0}")]
pub struct UnsupportedOp(#[error(not(source))] OpType);
/// Strategy describing how resource identifiers flow from the inputs of a
/// node in a [`HugrView`] to its outputs.
pub trait ResourceFlow<H: HugrView> {
    /// Computes the resource ID (if any) carried by each output of `node`.
    ///
    /// `inputs` holds the resource ID (if any) on each input of `node`; the
    /// default implementation debug-asserts it has one entry per input type of
    /// the node's dataflow signature. Implementations in this module return one
    /// entry per output type.
    ///
    /// # Errors
    ///
    /// Returns [`UnsupportedOp`] when the implementation cannot handle the
    /// operation at `node`.
    fn map_resources(
        &self,
        node: H::Node,
        hugr: &H,
        inputs: &[Option<ResourceId>],
    ) -> Result<Vec<Option<ResourceId>>, UnsupportedOp>;

    /// Type-erases `self` into a `Box<dyn ResourceFlow<H>>`, which itself
    /// implements `ResourceFlow<H>` by forwarding to the boxed value.
    fn into_boxed<'a>(self) -> Box<dyn 'a + ResourceFlow<H>>
    where
        Self: 'a + Sized,
    {
        Box::new(self)
    }
}
/// Lets a boxed, type-erased strategy (for example one produced by
/// [`ResourceFlow::into_boxed`]) be used anywhere a `ResourceFlow<H>` is
/// expected, by delegating to the boxed value.
impl<H: HugrView> ResourceFlow<H> for Box<dyn '_ + ResourceFlow<H>> {
    fn map_resources(
        &self,
        node: H::Node,
        hugr: &H,
        inputs: &[Option<ResourceId>],
    ) -> Result<Vec<Option<ResourceId>>, UnsupportedOp> {
        // Explicitly deref through the `Box` so the call dispatches on the
        // inner trait object instead of re-entering this impl.
        let inner = &**self;
        inner.map_resources(node, hugr, inputs)
    }
}
/// Default [`ResourceFlow`] strategy, driven only by a node's dataflow
/// signature.
///
/// When the signature is resource preserving (non-copyable types line up
/// position by position), input resources are forwarded to the outputs at the
/// same positions and copyable outputs get `None`. Otherwise every output is
/// reported as carrying no resource.
#[derive(Debug, Clone, Default)]
pub struct DefaultResourceFlow;
impl DefaultResourceFlow {
    /// Returns `true` if resources can be forwarded positionally from
    /// `input_types` to `output_types`.
    ///
    /// Input and output types are paired by position:
    /// - every pair containing a non-copyable type must consist of two equal
    ///   types;
    /// - any unpaired trailing type, on whichever side is longer, must be
    ///   copyable.
    fn is_resource_preserving(input_types: &[Type], output_types: &[Type]) -> bool {
        input_types
            .iter()
            .zip_longest(output_types.iter())
            .all(|pair| match pair {
                // Two copyable types never carry a resource. Otherwise the
                // resource must pass through with its type unchanged.
                EitherOrBoth::Both(input_ty, output_ty) => {
                    (input_ty.copyable() && output_ty.copyable()) || input_ty == output_ty
                }
                // A type with no counterpart on the other side must not be linear.
                EitherOrBoth::Left(unpaired) | EitherOrBoth::Right(unpaired) => {
                    unpaired.copyable()
                }
            })
    }
}
impl<H: HugrView> ResourceFlow<H> for DefaultResourceFlow {
fn map_resources(
&self,
node: H::Node,
hugr: &H,
inputs: &[Option<ResourceId>],
) -> Result<Vec<Option<ResourceId>>, UnsupportedOp> {
let op = hugr.get_optype(node);
let signature = op.dataflow_signature().expect("dataflow op");
let input_types = signature.input_types();
let output_types = signature.output_types();
debug_assert_eq!(
inputs.len(),
input_types.len(),
"Input resource array length must match operation input count"
);
if Self::is_resource_preserving(input_types, output_types) {
Ok(retain_linear_types(inputs.to_vec(), output_types))
} else {
Ok(vec![None; output_types.len()])
}
}
}
/// Aligns `resources` with `types`, keeping resources only on non-copyable
/// types.
///
/// The result has exactly `types.len()` entries:
/// - entry `i` is `resources[i]` when `types[i]` is not copyable and that
///   index exists;
/// - otherwise entry `i` is `None`.
///
/// Surplus resources beyond `types.len()` are discarded.
fn retain_linear_types(
    resources: Vec<Option<ResourceId>>,
    types: &[Type],
) -> Vec<Option<ResourceId>> {
    let mut ids = resources.into_iter();
    types
        .iter()
        .map(|ty| {
            // Advance unconditionally so `ids` stays aligned with `types`.
            let id = ids.next().flatten();
            if ty.copyable() {
                None
            } else {
                id
            }
        })
        .collect()
}