use std::sync::{self, Arc, LazyLock};
use delegate::delegate;
use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound};
use crate::ops::constant::{CustomConst, ValueName};
use crate::type_row;
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{CustomCheckFailure, Term, Type, TypeBound, TypeName};
use crate::{Extension, Wire};
use crate::{
builder::{BuildError, Dataflow},
extension::SignatureFunc,
};
use crate::{
extension::simple_op::{HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp},
ops::ExtensionOp,
};
use crate::{
extension::{
OpDef,
prelude::usize_t,
resolution::{ExtensionResolutionError, WeakExtensionRegistry},
simple_op::{OpLoadError, try_from_name},
},
ops::OpName,
types::{FuncValueType, PolyFuncTypeRV},
};
use super::array::op_builder::GenericArrayOpBuilder;
use super::array::{
Array, ArrayKind, FROM, GenericArrayClone, GenericArrayCloneDef, GenericArrayConvert,
GenericArrayConvertDef, GenericArrayDiscard, GenericArrayDiscardDef, GenericArrayOp,
GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, GenericArrayScan,
GenericArrayScanDef, GenericArrayValue, INTO,
};
pub const BORROW_ARRAY_TYPENAME: TypeName = TypeName::new_inline("borrow_array");
pub const BORROW_ARRAY_VALUENAME: TypeName = TypeName::new_inline("borrow_array");
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.borrow_arr");
pub const VERSION: semver::Version = semver::Version::new(0, 2, 0);
/// Marker type that selects the borrow-array flavour of the generic array
/// definitions, through its [`ArrayKind`] implementation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, derive_more::Display)]
pub struct BorrowArray;
impl ArrayKind for BorrowArray {
    const EXTENSION_ID: ExtensionId = EXTENSION_ID;
    const TYPE_NAME: TypeName = BORROW_ARRAY_TYPENAME;
    const VALUE_NAME: ValueName = BORROW_ARRAY_VALUENAME;

    /// The lazily built borrow array [`EXTENSION`].
    fn extension() -> &'static Arc<Extension> {
        &EXTENSION
    }

    /// Looks up the borrow array type definition in [`EXTENSION`].
    ///
    /// Panics if the type is not registered there.
    fn type_def() -> &'static TypeDef {
        let ext = Self::extension();
        ext.get_type(&Self::TYPE_NAME).unwrap()
    }
}
pub type BArrayOpDef = GenericArrayOpDef<BorrowArray>;
pub type BArrayCloneDef = GenericArrayCloneDef<BorrowArray>;
pub type BArrayDiscardDef = GenericArrayDiscardDef<BorrowArray>;
pub type BArrayRepeatDef = GenericArrayRepeatDef<BorrowArray>;
pub type BArrayScanDef = GenericArrayScanDef<BorrowArray>;
pub type BArrayToArrayDef = GenericArrayConvertDef<BorrowArray, INTO, Array>;
pub type BArrayFromArrayDef = GenericArrayConvertDef<BorrowArray, FROM, Array>;
pub type BArrayOp = GenericArrayOp<BorrowArray>;
pub type BArrayClone = GenericArrayClone<BorrowArray>;
pub type BArrayDiscard = GenericArrayDiscard<BorrowArray>;
pub type BArrayRepeat = GenericArrayRepeat<BorrowArray>;
pub type BArrayScan = GenericArrayScan<BorrowArray>;
pub type BArrayToArray = GenericArrayConvert<BorrowArray, INTO, Array>;
pub type BArrayFromArray = GenericArrayConvert<BorrowArray, FROM, Array>;
pub type BArrayValue = GenericArrayValue<BorrowArray>;
/// Operations specific to borrow arrays: they move individual elements in and
/// out of an array, or create and discard arrays whose slots are all borrowed.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    strum::EnumString,
    strum::IntoStaticStr,
    strum::EnumIter,
)]
// Variant names are the lower-case op names, so the strum derives produce the
// op identifiers directly; hence `non_camel_case_types`.
#[allow(non_camel_case_types, missing_docs)]
#[non_exhaustive]
pub enum BArrayUnsafeOpDef {
    /// Take an element from a borrow array (panicking if it was already taken before).
    borrow,
    /// Put an element into a borrow array (panicking if there is an element already).
    #[strum(serialize = "return")]
    r#return,
    /// Discard a borrow array where all elements have been borrowed.
    discard_all_borrowed,
    /// Create a new borrow array that contains no elements.
    new_all_borrowed,
    /// Test whether an element in a borrow array has been borrowed.
    is_borrowed,
}
impl BArrayUnsafeOpDef {
    /// Instantiates this operation for a borrow array holding `size` elements
    /// of type `elem_ty`.
    #[must_use]
    pub fn to_concrete(self, elem_ty: Type, size: u64) -> BArrayUnsafeOp {
        let def = self;
        BArrayUnsafeOp { def, elem_ty, size }
    }

    /// Builds this operation's polymorphic signature from the borrow array type
    /// definition `def`.
    ///
    /// The signature binds two parameters: variable 0 is the array size (a
    /// bounded natural) and variable 1 is the element type (bound
    /// [`TypeBound::Linear`]). Panics if `def` rejects those arguments.
    fn signature_from_def(&self, def: &TypeDef, _: &sync::Weak<Extension>) -> SignatureFunc {
        let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()];
        let size = TypeArg::new_var_use(0, TypeParam::max_nat_type());
        let elem = Type::new_var_use(1, TypeBound::Linear);
        let array: Type = def
            .instantiate(vec![size, elem.clone().into()])
            .unwrap()
            .into();
        let index: Type = usize_t();

        // Input and output rows of each operation.
        let (inputs, outputs): (Vec<Type>, Vec<Type>) = match self {
            Self::borrow => (vec![array.clone(), index], vec![array, elem]),
            Self::r#return => (vec![array.clone(), index, elem], vec![array]),
            Self::discard_all_borrowed => (vec![array], vec![]),
            Self::new_all_borrowed => (vec![], vec![array]),
            Self::is_borrowed => (
                vec![array.clone(), index],
                vec![array, crate::extension::prelude::bool_t()],
            ),
        };
        PolyFuncTypeRV::new(params, FuncValueType::new(inputs, outputs)).into()
    }
}
impl MakeOpDef for BArrayUnsafeOpDef {
fn opdef_id(&self) -> OpName {
<&'static str>::from(self).into()
}
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized,
{
try_from_name(op_def.name(), op_def.extension_id())
}
fn init_signature(&self, extension_ref: &sync::Weak<Extension>) -> SignatureFunc {
self.signature_from_def(
EXTENSION.get_type(&BORROW_ARRAY_TYPENAME).unwrap(),
extension_ref,
)
}
fn extension_ref(&self) -> sync::Weak<Extension> {
Arc::downgrade(&EXTENSION)
}
fn extension(&self) -> ExtensionId {
EXTENSION_ID.clone()
}
fn description(&self) -> String {
match self {
Self::borrow => {
"Take an element from a borrow array (panicking if it was already taken before)"
}
Self::r#return => {
"Put an element into a borrow array (panicking if there is an element already)"
}
Self::discard_all_borrowed => {
"Discard a borrow array where all elements have been borrowed"
}
Self::new_all_borrowed => "Create a new borrow array that contains no elements",
Self::is_borrowed => "Test whether an element in a borrow array has been borrowed",
}
.into()
}
fn add_to_extension(
&self,
extension: &mut Extension,
extension_ref: &sync::Weak<Extension>,
) -> Result<(), crate::extension::ExtensionBuildError> {
let sig = self.signature_from_def(
extension.get_type(&BORROW_ARRAY_TYPENAME).unwrap(),
extension_ref,
);
let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
self.post_opdef(def);
Ok(())
}
}
/// A [`BArrayUnsafeOpDef`] instantiated for a specific element type and array size.
#[derive(Debug, Clone, PartialEq)]
pub struct BArrayUnsafeOp {
    /// The operation being instantiated.
    pub def: BArrayUnsafeOpDef,
    /// Type of the array elements.
    pub elem_ty: Type,
    /// Number of elements in the array.
    pub size: u64,
}
impl MakeExtensionOp for BArrayUnsafeOp {
    fn op_id(&self) -> OpName {
        self.def.opdef_id()
    }

    /// Recovers the op definition from `ext_op`, then instantiates it with the
    /// op's type arguments.
    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        BArrayUnsafeOpDef::from_def(ext_op.def()).and_then(|def| def.instantiate(ext_op.args()))
    }

    /// Type arguments in signature order: the size first, then the element type.
    fn type_args(&self) -> Vec<TypeArg> {
        let size_arg: TypeArg = self.size.into();
        let elem_arg: TypeArg = self.elem_ty.clone().into();
        vec![size_arg, elem_arg]
    }
}
// A concrete `BArrayUnsafeOp` is an instance of the `BArrayUnsafeOpDef` enum.
impl HasDef for BArrayUnsafeOp {
type Def = BArrayUnsafeOpDef;
}
impl HasConcrete for BArrayUnsafeOpDef {
    type Concrete = BArrayUnsafeOp;

    /// Expects exactly `[size, element type]`, a bounded natural followed by a
    /// runtime type. Anything else yields [`SignatureError::InvalidTypeArgs`].
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
        let [Term::BoundedNat(size), Term::Runtime(elem_ty)] = type_args else {
            return Err(SignatureError::InvalidTypeArgs.into());
        };
        Ok(self.to_concrete(elem_ty.clone(), *size))
    }
}
impl MakeRegisteredOp for BArrayUnsafeOp {
    /// Always the borrow array extension's [`EXTENSION_ID`].
    fn extension_id(&self) -> ExtensionId {
        EXTENSION_ID
    }

    /// A new shared handle to the global [`EXTENSION`].
    fn extension_ref(&self) -> Arc<Extension> {
        Arc::clone(&EXTENSION)
    }
}
/// The borrow array extension, built on first access.
///
/// Registers, in order: the [`BORROW_ARRAY_TYPENAME`] type, the generic array
/// operations specialised to [`BorrowArray`], and the [`BArrayUnsafeOpDef`]
/// operations. Any registration failure panics.
pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
    Extension::new_arc(EXTENSION_ID, VERSION, |ext, ext_ref| {
        // The type is registered first because `BArrayUnsafeOpDef::add_to_extension`
        // looks it up by name when building its signatures.
        let array_params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()];
        ext.add_type(
            BORROW_ARRAY_TYPENAME,
            array_params,
            "Fixed-length borrow array".into(),
            TypeDefBound::any(),
            ext_ref,
        )
        .unwrap();

        // Generic array operations, specialised to borrow arrays.
        BArrayOpDef::load_all_ops(ext, ext_ref).unwrap();
        BArrayCloneDef::new().add_to_extension(ext, ext_ref).unwrap();
        BArrayDiscardDef::new().add_to_extension(ext, ext_ref).unwrap();
        BArrayRepeatDef::new().add_to_extension(ext, ext_ref).unwrap();
        BArrayScanDef::new().add_to_extension(ext, ext_ref).unwrap();
        BArrayToArrayDef::new().add_to_extension(ext, ext_ref).unwrap();
        BArrayFromArrayDef::new().add_to_extension(ext, ext_ref).unwrap();

        // Operations that exist only for borrow arrays.
        BArrayUnsafeOpDef::load_all_ops(ext, ext_ref).unwrap();
    })
});
// Makes borrow array values usable as custom constants, serialised under the tag
// "BArrayValue".
#[typetag::serde(name = "BArrayValue")]
impl CustomConst for BArrayValue {
// `delegate!` with `to self` forwards each listed method to `self.<method>(..)`.
// This relies on `GenericArrayValue` providing inherent methods with these names,
// because method resolution prefers inherent methods over the trait methods
// defined here. Presumably they live in `super::array`; verify before changing.
// Without them, these would recurse infinitely.
delegate! {
to self {
fn name(&self) -> ValueName;
fn validate(&self) -> Result<(), CustomCheckFailure>;
fn update_extensions(
&mut self,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError>;
fn get_type(&self) -> Type;
}
}
// Equality against an arbitrary custom constant goes through the shared downcasting
// helper. Presumably it returns false when `other` is not a `BArrayValue`; confirm
// in `crate::ops::constant`.
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
}
}
#[must_use]
pub fn borrow_array_type_def() -> &'static TypeDef {
BorrowArray::type_def()
}
#[must_use]
pub fn borrow_array_type(size: u64, element_ty: Type) -> Type {
BorrowArray::ty(size, element_ty)
}
pub fn borrow_array_type_parametric(
size: impl Into<TypeArg>,
element_ty: impl Into<TypeArg>,
) -> Result<Type, SignatureError> {
BorrowArray::ty_parametric(size, element_ty)
}
/// Builder helpers that add borrow array operations to a dataflow graph.
///
/// Every method returns a [`BuildError`] if the operation cannot be
/// instantiated or added to the graph.
pub trait BArrayOpBuilder: GenericArrayOpBuilder {
    /// Adds an operation building a borrow array of `elem_ty` elements from
    /// `values`, and returns the array wire.
    fn add_new_borrow_array(
        &mut self,
        elem_ty: Type,
        values: impl IntoIterator<Item = Wire>,
    ) -> Result<Wire, BuildError> {
        self.add_new_generic_array::<BorrowArray>(elem_ty, values)
    }

    /// Adds an operation unpacking the `size`-element borrow array `input`, and
    /// returns the resulting wires.
    fn add_borrow_array_unpack(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
    ) -> Result<Vec<Wire>, BuildError> {
        self.add_generic_array_unpack::<BorrowArray>(elem_ty, size, input)
    }

    /// Adds an operation cloning the borrow array `input`, and returns its two
    /// output wires.
    fn add_borrow_array_clone(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
    ) -> Result<(Wire, Wire), BuildError> {
        self.add_generic_array_clone::<BorrowArray>(elem_ty, size, input)
    }

    /// Adds an operation discarding the borrow array `input`.
    fn add_borrow_array_discard(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
    ) -> Result<(), BuildError> {
        self.add_generic_array_discard::<BorrowArray>(elem_ty, size, input)
    }

    /// Adds the generic `get` operation at `index` on the borrow array `input`,
    /// and returns its two output wires.
    fn add_borrow_array_get(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
        index: Wire,
    ) -> Result<(Wire, Wire), BuildError> {
        self.add_generic_array_get::<BorrowArray>(elem_ty, size, input, index)
    }

    /// Adds the generic `set` operation writing `value` at `index` of the borrow
    /// array `input`, and returns its output wire.
    fn add_borrow_array_set(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
        index: Wire,
        value: Wire,
    ) -> Result<Wire, BuildError> {
        self.add_generic_array_set::<BorrowArray>(elem_ty, size, input, index, value)
    }

    /// Adds the generic `swap` operation on positions `index1` and `index2` of the
    /// borrow array `input`, and returns its single output wire.
    fn add_borrow_array_swap(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
        index1: Wire,
        index2: Wire,
    ) -> Result<Wire, BuildError> {
        let swap =
            GenericArrayOpDef::<BorrowArray>::swap.instantiate(&[size.into(), elem_ty.into()])?;
        let node = self.add_dataflow_op(swap, vec![input, index1, index2])?;
        let [swapped] = node.outputs_arr();
        Ok(swapped)
    }

    /// Adds the generic `pop_left` operation on the borrow array `input`, and
    /// returns its output wire.
    fn add_borrow_array_pop_left(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
    ) -> Result<Wire, BuildError> {
        self.add_generic_array_pop_left::<BorrowArray>(elem_ty, size, input)
    }

    /// Adds the generic `pop_right` operation on the borrow array `input`, and
    /// returns its output wire.
    fn add_borrow_array_pop_right(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
    ) -> Result<Wire, BuildError> {
        self.add_generic_array_pop_right::<BorrowArray>(elem_ty, size, input)
    }

    /// Adds the generic `discard_empty` operation on the borrow array `input`.
    fn add_borrow_array_discard_empty(
        &mut self,
        elem_ty: Type,
        input: Wire,
    ) -> Result<(), BuildError> {
        self.add_generic_array_discard_empty::<BorrowArray>(elem_ty, input)
    }

    /// Adds a [`BArrayUnsafeOpDef::borrow`] that takes the element at `index` out of
    /// `input`, and returns `(array, element)`.
    fn add_borrow_array_borrow(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
        index: Wire,
    ) -> Result<(Wire, Wire), BuildError> {
        let borrow = BArrayUnsafeOpDef::borrow
            .instantiate(&[size.into(), elem_ty.into()])?
            .to_extension_op()
            .unwrap();
        let node = self.add_dataflow_op(borrow, vec![input, index])?;
        let [array, element] = node.outputs_arr();
        Ok((array, element))
    }

    /// Adds a [`BArrayUnsafeOpDef::r#return`] that puts `value` back at `index` of
    /// `input`, and returns the updated array.
    fn add_borrow_array_return(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
        index: Wire,
        value: Wire,
    ) -> Result<Wire, BuildError> {
        let give_back = BArrayUnsafeOpDef::r#return
            .instantiate(&[size.into(), elem_ty.into()])?
            .to_extension_op()
            .unwrap();
        let node = self.add_dataflow_op(give_back, vec![input, index, value])?;
        let [array] = node.outputs_arr();
        Ok(array)
    }

    /// Adds a [`BArrayUnsafeOpDef::discard_all_borrowed`] consuming `input`.
    fn add_discard_all_borrowed(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
    ) -> Result<(), BuildError> {
        let discard = BArrayUnsafeOpDef::discard_all_borrowed
            .instantiate(&[size.into(), elem_ty.into()])?
            .to_extension_op()
            .unwrap();
        self.add_dataflow_op(discard, vec![input])?;
        Ok(())
    }

    /// Adds a [`BArrayUnsafeOpDef::new_all_borrowed`] creating a `size`-element
    /// array with no elements present, and returns the array wire.
    fn add_new_all_borrowed(&mut self, elem_ty: Type, size: u64) -> Result<Wire, BuildError> {
        let create = BArrayUnsafeOpDef::new_all_borrowed
            .instantiate(&[size.into(), elem_ty.into()])?
            .to_extension_op()
            .unwrap();
        let node = self.add_dataflow_op(create, vec![])?;
        let [array] = node.outputs_arr();
        Ok(array)
    }

    /// Adds a [`BArrayUnsafeOpDef::is_borrowed`] query for slot `index` of
    /// `input`, and returns `(array, flag)`.
    fn add_is_borrowed(
        &mut self,
        elem_ty: Type,
        size: u64,
        input: Wire,
        index: Wire,
    ) -> Result<(Wire, Wire), BuildError> {
        let query = BArrayUnsafeOpDef::is_borrowed
            .instantiate(&[size.into(), elem_ty.into()])?
            .to_extension_op()
            .unwrap();
        let node = self.add_dataflow_op(query, vec![input, index])?;
        let [array, flag] = node.outputs_arr();
        Ok((array, flag))
    }
}
/// Every dataflow builder gets the borrow array helpers.
impl<B> BArrayOpBuilder for B where B: Dataflow {}
#[cfg(test)]
mod test {
    use strum::IntoEnumIterator;

    use crate::{
        builder::{DFGBuilder, Dataflow, DataflowHugr as _},
        extension::prelude::{ConstUsize, qb_t, usize_t},
        ops::OpType,
        std_extensions::collections::borrow_array::{
            BArrayOpBuilder, BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type,
        },
        types::Signature,
    };

    /// Each unsafe op, instantiated concretely, round-trips through [`OpType`].
    #[test]
    fn test_borrow_array_unsafe_ops() {
        BArrayUnsafeOpDef::iter().for_each(|def| {
            let original = def.to_concrete(qb_t(), 2);
            let as_optype: OpType = original.clone().into();
            let recovered: BArrayUnsafeOp = as_optype.cast().unwrap();
            assert_eq!(recovered, original);
        });
    }

    /// Borrowing an element and returning it to the same slot builds successfully.
    #[test]
    fn test_borrow_and_return() {
        let size = 22;
        let elem_ty = qb_t();
        let arr_ty = borrow_array_type(size, elem_ty.clone());
        let mut builder = DFGBuilder::new(Signature::new_endo([arr_ty])).unwrap();
        let borrow_idx = builder.add_load_value(ConstUsize::new(11));
        let return_idx = builder.add_load_value(ConstUsize::new(11));
        let [input] = builder.input_wires_arr();
        let (arr_with_hole, elem) = builder
            .add_borrow_array_borrow(elem_ty.clone(), size, input, borrow_idx)
            .unwrap();
        let arr_restored = builder
            .add_borrow_array_return(elem_ty, size, arr_with_hole, return_idx, elem)
            .unwrap();
        builder.finish_hugr_with_outputs([arr_restored]).unwrap();
    }

    /// A single-element array can be discarded once its only element is borrowed.
    #[test]
    fn test_discard_all_borrowed() {
        let size = 1;
        let elem_ty = qb_t();
        let arr_ty = borrow_array_type(size, elem_ty.clone());
        let mut builder = DFGBuilder::new(Signature::new(vec![arr_ty], vec![qb_t()])).unwrap();
        let slot = builder.add_load_value(ConstUsize::new(0));
        let [input] = builder.input_wires_arr();
        let (emptied, qubit) = builder
            .add_borrow_array_borrow(elem_ty.clone(), size, input, slot)
            .unwrap();
        builder
            .add_discard_all_borrowed(elem_ty, size, emptied)
            .unwrap();
        builder.finish_hugr_with_outputs([qubit]).unwrap();
    }

    /// A value can be returned into an array created with every slot borrowed.
    #[test]
    fn test_new_all_borrowed() {
        let size = 5;
        let elem_ty = usize_t();
        let arr_ty = borrow_array_type(size, elem_ty.clone());
        let mut builder = DFGBuilder::new(Signature::new(vec![], vec![arr_ty])).unwrap();
        let fresh = builder.add_new_all_borrowed(elem_ty.clone(), size).unwrap();
        let slot = builder.add_load_value(ConstUsize::new(3));
        let value = builder.add_load_value(ConstUsize::new(202));
        let filled = builder
            .add_borrow_array_return(elem_ty, size, fresh, slot, value)
            .unwrap();
        builder.finish_hugr_with_outputs([filled]).unwrap();
    }

    /// `is_borrowed` can query a slot that was just borrowed.
    #[test]
    fn test_is_borrowed() {
        let size = 4;
        let elem_ty = qb_t();
        let arr_ty = borrow_array_type(size, elem_ty.clone());
        let mut builder =
            DFGBuilder::new(Signature::new(vec![arr_ty.clone()], vec![qb_t(), arr_ty])).unwrap();
        let slot = builder.add_load_value(ConstUsize::new(2));
        let [input] = builder.input_wires_arr();
        let (with_hole, qubit) = builder
            .add_borrow_array_borrow(elem_ty.clone(), size, input, slot)
            .unwrap();
        let (checked, _flag) = builder
            .add_is_borrowed(elem_ty.clone(), size, with_hole, slot)
            .unwrap();
        builder
            .finish_hugr_with_outputs([qubit, checked])
            .unwrap();
    }
}