use itertools::Itertools;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use crate::Extension;
use crate::extension::prelude::PRELUDE_ID;
use crate::extension::{SignatureFunc, TypeDef};
use crate::types::type_param::TypeParam;
use crate::types::{CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeRowRV};
use super::{DeclarationContext, ExtensionDeclarationError};
/// Declaration of an operation's signature: the types of its input and output ports.
///
/// Each entry is a [`SignaturePortDeclaration`], which may expand to several ports
/// when a repetition count is given.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(super) struct SignatureDeclaration {
    /// Declarations of the input ports, in order.
    inputs: Vec<SignaturePortDeclaration>,
    /// Declarations of the output ports, in order.
    outputs: Vec<SignaturePortDeclaration>,
}
impl SignatureDeclaration {
    /// Builds the polymorphic signature of an operation from this declaration.
    ///
    /// Every input and output port declaration is expanded into its concrete types,
    /// and the resulting rows are wrapped in a [`PolyFuncTypeRV`] quantified over
    /// `op_params`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while expanding a port declaration; inputs
    /// are processed before outputs.
    pub fn make_signature(
        &self,
        ext: &Extension,
        ctx: DeclarationContext<'_>,
        op_params: &[TypeParam],
    ) -> Result<SignatureFunc, ExtensionDeclarationError> {
        // Expand a list of port declarations into a single flat type row,
        // stopping at the first declaration that fails to resolve.
        let expand_row =
            |ports: &[SignaturePortDeclaration]| -> Result<TypeRowRV, ExtensionDeclarationError> {
                ports
                    .iter()
                    .map(|port| port.make_types(ext, ctx, op_params))
                    .flatten_ok()
                    .collect::<Result<Vec<Type>, _>>()
                    .map(Into::into)
            };

        let input = expand_row(&self.inputs)?;
        let output = expand_row(&self.outputs)?;
        let signature = PolyFuncTypeRV::new(op_params, FuncValueType { input, output });
        Ok(signature.into())
    }
}
/// A single entry in an input or output row of a [`SignatureDeclaration`].
///
/// Deserialized untagged, so the variants are tried in the order listed here:
/// a bare type, a `(type, repetition)` pair, or a `(description, type, repetition)`
/// triple. Variant order is therefore significant and must not be changed.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
enum SignaturePortDeclaration {
    /// A single port of the given type.
    Type(TypeDeclaration),
    /// The given type, repeated according to the repetition declaration.
    TypeRepeat(TypeDeclaration, PortRepetitionDeclaration),
    /// Like `TypeRepeat`, with a leading description string. The description is not
    /// read by the signature-building code in this module.
    DescriptionTypeRepeat(String, TypeDeclaration, PortRepetitionDeclaration),
}
impl SignaturePortDeclaration {
    /// Expands this port declaration into the concrete types it contributes to a row.
    ///
    /// The declared type is resolved once and then repeated according to the port's
    /// repetition (exactly once when no repetition is given).
    ///
    /// # Errors
    ///
    /// - [`ExtensionDeclarationError::UnsupportedPortRepetition`] if the repetition is
    ///   given by a parameter rather than a fixed count.
    /// - Any error from resolving the declared type, e.g.
    ///   [`ExtensionDeclarationError::UnknownType`].
    fn make_types(
        &self,
        ext: &Extension,
        ctx: DeclarationContext<'_>,
        op_params: &[TypeParam],
    ) -> Result<impl Iterator<Item = Type>, ExtensionDeclarationError> {
        // Validate the repetition before resolving the type, so an unsupported
        // parametric repetition is reported even if the type is also invalid.
        let count: usize = match self.repeat() {
            PortRepetitionDeclaration::Count(n) => *n,
            PortRepetitionDeclaration::Parameter(parametric_repetition) => {
                return Err(ExtensionDeclarationError::UnsupportedPortRepetition {
                    ext: ext.name().clone(),
                    parametric_repetition: parametric_repetition.clone(),
                });
            }
        };
        let custom = self.type_decl().make_type(ext, ctx, op_params)?;
        Ok(itertools::repeat_n(Type::new_extension(custom), count))
    }

    /// Returns the type declaration of this port, whichever form it was written in.
    fn type_decl(&self) -> &TypeDeclaration {
        match self {
            Self::Type(ty) | Self::TypeRepeat(ty, _) | Self::DescriptionTypeRepeat(_, ty, _) => ty,
        }
    }

    /// Returns the repetition of this port; a bare type declaration means one port.
    fn repeat(&self) -> &PortRepetitionDeclaration {
        static DEFAULT_REPEAT: PortRepetitionDeclaration = PortRepetitionDeclaration::Count(1);
        // Exhaustive on purpose: a new variant must decide its repetition explicitly
        // instead of silently falling back to the default.
        match self {
            Self::Type(_) => &DEFAULT_REPEAT,
            Self::TypeRepeat(_, repeat) | Self::DescriptionTypeRepeat(_, _, repeat) => repeat,
        }
    }
}
/// How many ports a [`SignaturePortDeclaration`] expands to.
///
/// Deserialized untagged, trying `Count` before `Parameter`: an integer becomes a
/// fixed count, and a string falls through to `Parameter`.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
enum PortRepetitionDeclaration {
    /// A fixed number of ports.
    Count(usize),
    /// A repetition given by name (presumably an operation parameter — confirm).
    /// Building a signature from this variant currently fails with
    /// `UnsupportedPortRepetition`.
    Parameter(SmolStr),
}
impl Default for PortRepetitionDeclaration {
fn default() -> Self {
PortRepetitionDeclaration::Count(1)
}
}
/// A reference to a type by name, serialized as a plain string.
///
/// The name is resolved against the prelude aliases `USize` and `Q`, the extension
/// being declared, and the extensions in the declaration scope.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
struct TypeDeclaration(
    /// The type name as written in the declaration.
    String,
);
impl TypeDeclaration {
    /// Resolves this declaration into a concrete [`CustomType`].
    ///
    /// The name is looked up with [`TypeDeclaration::resolve_type`], and the resulting
    /// type definition is instantiated with no type arguments.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionDeclarationError::UnknownType`] if the name cannot be resolved.
    ///
    /// # Panics
    ///
    /// Panics if the resolved type definition declares type parameters; parametric
    /// types cannot be referenced from a declaration yet.
    pub fn make_type(
        &self,
        ext: &Extension,
        ctx: DeclarationContext<'_>,
        _op_params: &[TypeParam],
    ) -> Result<CustomType, ExtensionDeclarationError> {
        let Some(type_def) = self.resolve_type(ext, ctx) else {
            return Err(ExtensionDeclarationError::UnknownType {
                ext: ext.name().clone(),
                ty: self.0.clone(),
            });
        };
        // NOTE(review): this panics on a user-supplied declaration that names a
        // parametric type; a dedicated error variant would be preferable.
        assert!(type_def.params().is_empty());
        // With no declared parameters, an empty argument list should always be accepted.
        let op = type_def.instantiate(&[]).unwrap();
        Ok(op)
    }

    /// Looks up the type definition named by this declaration.
    ///
    /// Candidates are checked in order:
    /// 1. the prelude aliases `USize` (prelude `usize`) and `Q` (prelude `qubit`);
    /// 2. the extension currently being declared;
    /// 3. every extension in `ctx.scope` that is present in `ctx.registry`.
    ///
    /// Returns `None` if no candidate defines the type, including when an alias
    /// is used but the prelude is missing from the registry.
    fn resolve_type<'a>(
        &'a self,
        ext: &'a Extension,
        ctx: DeclarationContext<'a>,
    ) -> Option<&'a TypeDef> {
        debug_assert!(ctx.scope.contains(&PRELUDE_ID));

        // Only consult the prelude when an alias is actually used, and report a
        // missing prelude as "unresolved" rather than panicking.
        let prelude_alias = match self.0.as_str() {
            "USize" => Some("usize"),
            "Q" => Some("qubit"),
            _ => None,
        };
        if let Some(prelude_name) = prelude_alias {
            return ctx.registry.get(&PRELUDE_ID)?.get_type(prelude_name);
        }

        // Types defined by the extension being declared take precedence over
        // types from other extensions in scope.
        ext.get_type(&self.0).or_else(|| {
            ctx.scope.iter().find_map(|scope_id| {
                ctx.registry
                    .get(scope_id)
                    .and_then(|scope_ext| scope_ext.get_type(&self.0))
            })
        })
    }
}