#![doc = include_str!("../../examples/extension/declarative.yaml")]
mod ops;
mod signature;
mod types;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use crate::Extension;
use crate::extension::prelude::PRELUDE_ID;
use crate::ops::OpName;
use crate::types::TypeName;
use super::{
ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionRegistryError, ExtensionSet,
PRELUDE,
};
use ops::OperationDeclaration;
use smol_str::SmolStr;
use types::TypeDeclaration;
use serde::{Deserialize, Serialize};
/// Parse a YAML declaration of one or more extensions and register them in `registry`.
///
/// The prelude is registered in `registry` if it is not already present.
///
/// # Errors
///
/// Returns [`ExtensionDeclarationError::Deserialize`] if `yaml` is not a valid
/// extension set declaration, or any error produced while building and
/// registering the declared extensions.
pub fn load_extensions(
    yaml: &str,
    registry: &mut ExtensionRegistry,
) -> Result<(), ExtensionDeclarationError> {
    serde_yaml::from_str::<ExtensionSetDeclaration>(yaml)?.add_to_registry(registry)
}
/// Read a YAML extension declaration from the file at `path` and register its
/// extensions in `registry`.
///
/// This is the file-based counterpart of [`load_extensions`].
///
/// # Errors
///
/// Returns [`ExtensionDeclarationError::InvalidFile`] if the file cannot be
/// opened, [`ExtensionDeclarationError::Deserialize`] if its contents cannot be
/// read or parsed as an extension set declaration, or any error produced while
/// building and registering the declared extensions.
pub fn load_extensions_file(
    path: &Path,
    registry: &mut ExtensionRegistry,
) -> Result<(), ExtensionDeclarationError> {
    let declaration: ExtensionSetDeclaration = serde_yaml::from_reader(File::open(path)?)?;
    declaration.add_to_registry(registry)
}
/// Top-level content of a declarative extension file.
///
/// It lists the extensions the file defines and the extensions it expects to
/// find already loaded.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
struct ExtensionSetDeclaration {
    /// Extensions declared by this file, registered in the order listed.
    extensions: Vec<ExtensionDeclaration>,
    /// Extensions that must already be present in the target registry.
    ///
    /// Defaults to empty when absent. It is omitted when serializing if it
    /// equals the default.
    #[serde(default, skip_serializing_if = "crate::utils::is_default")]
    imports: ExtensionSet,
}
/// Declaration of a single extension: its name plus the types and operations
/// it defines.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
struct ExtensionDeclaration {
    /// Identifier of the declared extension.
    name: ExtensionId,
    /// Types defined by the extension. Defaults to empty when absent, and is
    /// omitted when serializing if it equals the default.
    #[serde(default, skip_serializing_if = "crate::utils::is_default")]
    types: Vec<TypeDeclaration>,
    /// Operations defined by the extension. Defaults to empty when absent, and
    /// is omitted when serializing if it equals the default.
    #[serde(default, skip_serializing_if = "crate::utils::is_default")]
    operations: Vec<OperationDeclaration>,
}
impl ExtensionSetDeclaration {
    /// Build every declared extension and register it in `registry`.
    ///
    /// The prelude is always in scope. It is registered in `registry` if
    /// missing *before* the explicit `imports` are validated. As a result, a
    /// declaration that lists `prelude` among its imports loads successfully
    /// even into a registry that does not yet contain it.
    ///
    /// Extensions are built and registered in declaration order. Each one is
    /// added to the scope visible to the extensions declared after it.
    ///
    /// # Errors
    ///
    /// - [`ExtensionDeclarationError::MissingExtension`] if an import (other
    ///   than the prelude) is not present in `registry`.
    /// - An error from [`ExtensionRegistry::register`] if registering the
    ///   prelude or a declared extension fails.
    /// - Any error raised while building an individual extension.
    ///
    /// There is no rollback: on error, anything registered before the failure
    /// (including the prelude) remains in `registry`.
    pub fn add_to_registry(
        &self,
        registry: &mut ExtensionRegistry,
    ) -> Result<(), ExtensionDeclarationError> {
        // The prelude is auto-imported. Register it before validating the
        // imports, so that an explicit `prelude` import is satisfied by it
        // instead of being reported as missing.
        if !registry.contains(&PRELUDE_ID) {
            registry.register(PRELUDE.clone())?;
        }

        // Every other dependency must already be present in the registry.
        if let Some(missing) = self.imports.iter().find(|&id| !registry.contains(id)) {
            return Err(ExtensionDeclarationError::MissingExtension {
                ext: missing.clone(),
            });
        }

        // Extensions visible to each declaration: the explicit imports, the
        // prelude, and every extension declared earlier in this set.
        let mut scope = self.imports.clone();
        if !scope.contains(&PRELUDE_ID) {
            scope.insert(PRELUDE_ID);
        }

        for decl in &self.extensions {
            let ctx = DeclarationContext {
                scope: &scope,
                registry,
            };
            let ext = decl.make_extension(&self.imports, ctx)?;
            scope.insert(ext.name().clone());
            registry.register(ext)?;
        }

        Ok(())
    }
}
impl ExtensionDeclaration {
    /// Build the [`Extension`] described by this declaration.
    ///
    /// The declaration carries no version information, so the extension is
    /// created with version `0.0.0`. All declared types are registered first,
    /// then all declared operations, each group in declaration order.
    ///
    /// `_imports` is currently unused.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while registering a type or an
    /// operation.
    pub fn make_extension(
        &self,
        _imports: &ExtensionSet,
        ctx: DeclarationContext<'_>,
    ) -> Result<Arc<Extension>, ExtensionDeclarationError> {
        let version = crate::extension::Version::new(0, 0, 0);
        Extension::try_new_arc(self.name.clone(), version, |extension, self_ref| {
            for type_decl in self.types.iter() {
                type_decl.register(extension, ctx, self_ref)?;
            }
            for op_decl in self.operations.iter() {
                op_decl.register(extension, ctx, self_ref)?;
            }
            Ok(())
        })
    }
}
/// Read-only context handed to type and operation declarations while an
/// extension is being built.
#[derive(Clone, Copy, Debug)]
struct DeclarationContext<'a> {
    /// Extensions in scope for the declaration: its imports, the prelude, and
    /// the extensions declared before it in the same set.
    pub scope: &'a ExtensionSet,
    /// Registry holding the extensions loaded so far.
    pub registry: &'a ExtensionRegistry,
}
/// Errors that can occur while loading a declarative extension set.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ExtensionDeclarationError {
    /// The YAML input could not be deserialized into an extension set declaration.
    #[error("Error while parsing the extension set yaml: {0}")]
    Deserialize(#[from] serde_yaml::Error),
    /// Registering an extension in the registry failed.
    ///
    /// The display message does not include the underlying error. Because the
    /// field is `#[from]` (which implies `#[source]`), the underlying error is
    /// available through `std::error::Error::source`.
    #[error("Error registering the extensions.")]
    ExtensionRegistryError(#[from] ExtensionRegistryError),
    /// Adding a type or an operation to the extension under construction failed.
    #[error("Error while adding operations or types to the extension: {0}")]
    ExtensionBuildError(#[from] ExtensionBuildError),
    /// An I/O error occurred while accessing the declaration file (for
    /// example, it could not be opened).
    #[error("Invalid yaml declaration file {0}")]
    InvalidFile(#[from] std::io::Error),
    /// An extension listed as a dependency is not present in the registry.
    #[error("Missing required extension {ext}")]
    MissingExtension {
        /// The extension that could not be found.
        ext: ExtensionId,
    },
    /// An extension referenced a type that is not known.
    #[error("Extension {ext} referenced an unknown type {ty}.")]
    MissingType {
        /// The extension containing the reference.
        ext: ExtensionId,
        /// The unknown type name.
        ty: TypeName,
    },
    /// A type declaration uses a higher-order type parameter, which is not
    /// currently supported.
    #[error("Found a currently unsupported higher-order type parameter {ty} in extension {ext}")]
    ParametricTypeParameter {
        /// The extension containing the declaration.
        ext: ExtensionId,
        /// The offending type parameter.
        ty: TypeName,
    },
    /// An operation declaration is parametric, which is not currently supported.
    #[error("Found a currently unsupported parametric operation {op} in extension {ext}")]
    ParametricOperation {
        /// The extension containing the operation.
        ext: ExtensionId,
        /// The parametric operation.
        op: OpName,
    },
    /// An operation was declared without a signature, which is not currently
    /// supported.
    #[error("Operation {op} in extension {ext} has no signature. This is not currently supported.")]
    MissingSignature {
        /// The extension containing the operation.
        ext: ExtensionId,
        /// The operation lacking a signature.
        op: OpName,
    },
    /// A type name used in a declaration is not in scope.
    #[error("Type {ty} is not in scope. In extension {ext}.")]
    UnknownType {
        /// The extension containing the reference.
        ext: ExtensionId,
        /// The type name as written in the declaration.
        ty: String,
    },
    /// A signature uses an unsupported port repetition.
    #[error("Unsupported port repetition {parametric_repetition} in extension {ext}")]
    UnsupportedPortRepetition {
        /// The extension containing the signature.
        ext: crate::hugr::IdentList,
        /// The unsupported repetition, as written in the declaration.
        parametric_repetition: SmolStr,
    },
    /// An operation declares a lowering definition, which is not supported.
    #[error("Unsupported lowering definition for op {op} in extension {ext}")]
    LoweringNotSupported {
        /// The extension containing the operation.
        ext: crate::hugr::IdentList,
        /// The operation with the unsupported lowering.
        op: OpName,
    },
}
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use rstest::rstest;
    use std::path::PathBuf;
    use std::sync::Arc;

    use crate::extension::PRELUDE_REGISTRY;
    use crate::std_extensions;

    use super::*;

    /// A declaration containing a single extension with no types or operations.
    const EMPTY_YAML: &str = r#"
extensions:
- name: EmptyExt
"#;

    /// A declaration with one extension defining one type and two operations.
    const BASIC_YAML: &str = r#"
imports: [prelude]
extensions:
- name: SimpleExt
  types:
  - name: MyType
    description: A simple type with no parameters
    bound: Any
  operations:
  - name: MyOperation
    description: A simple operation with no inputs nor outputs
    signature:
      inputs: []
      outputs: []
  - name: AnotherOperation
    description: An operation from 3 qubits to 3 qubits
    signature:
      inputs: [MyType, Q, Q]
      outputs: [[MyType, 1], [Control, Q, 2]]
"#;

    /// A declaration that uses features the loader does not support yet
    /// (parametric types and parametric operations). Loading it must fail.
    const UNSUPPORTED_YAML: &str = r#"
extensions:
- name: UnsupportedExt
  types:
  - name: MyType
    description: A simple type with no parameters
    # Parametric types are not currently supported.
    params: [String]
  operations:
  - name: UnsupportedOperation
    description: An operation from 3 qubits to 3 qubits
    params:
      # Parametric operations are not currently supported.
      param1: String
    signature:
      # Type declarations will have their own syntax.
      inputs: []
      outputs: ["Array<param1>[USize]"]
"#;

    /// Example declaration shipped with the crate (the same file embedded in
    /// the module docs), relative to the crate root.
    const EXAMPLE_YAML_FILE: &str = "examples/extension/declarative.yaml";

    #[rstest]
    #[case(EMPTY_YAML, 1, 0, 0, &PRELUDE_REGISTRY)]
    #[case(BASIC_YAML, 1, 1, 2, &PRELUDE_REGISTRY)]
    fn test_decode(
        #[case] yaml: &str,
        #[case] num_declarations: usize,
        #[case] num_types: usize,
        #[case] num_operations: usize,
        #[case] dependencies: &ExtensionRegistry,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut reg = dependencies.clone();
        load_extensions(yaml, &mut reg)?;
        assert_new_extensions(
            &reg,
            dependencies,
            num_declarations,
            num_types,
            num_operations,
        );
        Ok(())
    }

    // Opening files is not supported under (isolated) miri.
    #[cfg_attr(miri, ignore)]
    #[rstest]
    #[case(EXAMPLE_YAML_FILE, 1, 1, 3, &std_extensions::STD_REG)]
    fn test_decode_file(
        #[case] yaml_file: &str,
        #[case] num_declarations: usize,
        #[case] num_types: usize,
        #[case] num_operations: usize,
        #[case] dependencies: &ExtensionRegistry,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut reg = dependencies.clone();
        load_extensions_file(&PathBuf::from(yaml_file), &mut reg)?;
        assert_new_extensions(
            &reg,
            dependencies,
            num_declarations,
            num_types,
            num_operations,
        );
        Ok(())
    }

    #[rstest]
    #[case(UNSUPPORTED_YAML, &PRELUDE_REGISTRY)]
    fn test_unsupported(
        #[case] yaml: &str,
        #[case] dependencies: &ExtensionRegistry,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut reg = dependencies.clone();

        // Parsing must succeed; only building the extensions is expected to fail.
        let ext: ExtensionSetDeclaration = serde_yaml::from_str(yaml)?;
        assert!(ext.add_to_registry(&mut reg).is_err());
        Ok(())
    }

    /// Assert that `reg` contains exactly `num_declarations` extensions beyond
    /// `dependencies` and the prelude. Those extensions must define
    /// `num_types` types and `num_operations` operations in total.
    fn assert_new_extensions(
        reg: &ExtensionRegistry,
        dependencies: &ExtensionRegistry,
        num_declarations: usize,
        num_types: usize,
        num_operations: usize,
    ) {
        let new_exts = new_extensions(reg, dependencies).collect_vec();
        assert_eq!(new_exts.len(), num_declarations);
        assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types);
        assert_eq!(
            new_exts.iter().flat_map(|e| e.operations()).count(),
            num_operations
        );
    }

    /// Extensions in `reg` that are neither in `dependencies` nor the prelude.
    fn new_extensions<'a>(
        reg: &'a ExtensionRegistry,
        dependencies: &'a ExtensionRegistry,
    ) -> impl Iterator<Item = &'a Arc<Extension>> {
        reg.iter()
            .filter(move |ext| !dependencies.contains(ext.name()) && ext.name() != &PRELUDE_ID)
    }
}