use std::collections::HashMap;
use crate::envelope::HugrUsedExtensions;
use crate::metadata;
use crate::{
HugrView, Node,
envelope::EnvelopeHeader,
ops::{DataflowOpTrait, OpType},
};
use itertools::Itertools;
use semver::Version;
/// A positional vector of optional slots; a slot is `None` until it is filled.
type OptionVec<T> = Vec<Option<T>>;

/// Resizes `vec` to exactly `n` slots.
///
/// Extra slots are truncated; missing slots are padded with `None`.
fn set_option_vec_len<T: Clone>(vec: &mut OptionVec<T>, n: usize) {
    vec.resize_with(n, || None);
}
/// Stores `value` in slot `index`.
///
/// If `index` is past the end, `vec` first grows to `index + 1` slots, padding with `None`.
fn set_option_vec_index<T: Clone>(vec: &mut OptionVec<T>, index: usize, value: T) {
    let slot_missing = vec.get(index).is_none();
    if slot_missing {
        set_option_vec_len(vec, index + 1);
    }
    vec[index] = Some(value);
}
/// Appends `items` to the vector held in `vec`.
///
/// If `vec` is `None`, it is first set to an empty vector. `vec` is an optional
/// vector, not an `OptionVec`. No `Clone` bound is needed because the items are
/// moved, never copied.
fn extend_option_vec<T>(vec: &mut Option<Vec<T>>, items: impl IntoIterator<Item = T>) {
    vec.get_or_insert_with(Vec::new).extend(items);
}
/// Summary of a package read from an envelope.
///
/// It holds the envelope header, one slot per module and one slot per
/// packaged extension. A slot is `None` until its description is filled in.
///
/// Only `Serialize` is derived; there is no `Deserialize` impl here.
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(serde::Serialize, schemars::JsonSchema)]
pub struct PackageDesc {
    /// Envelope header, written out as a string by `header_serialize`.
    #[schemars(with = "String")]
    #[serde(serialize_with = "header_serialize")]
    pub header: EnvelopeHeader,
    /// Per-module description slots.
    pub modules: OptionVec<ModuleDesc>,
    /// Per-extension description slots; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub packaged_extensions: OptionVec<ExtensionDesc>,
}
/// Serde `serialize_with` helper.
///
/// Writes an [`EnvelopeHeader`] as the string produced by its `to_string()`.
fn header_serialize<S>(header: &EnvelopeHeader, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let rendered = header.to_string();
    serializer.serialize_str(rendered.as_str())
}
impl PackageDesc {
    /// Creates a description with the given envelope `header`.
    ///
    /// Every other field takes its default value (no module or extension slots).
    pub(super) fn new(header: EnvelopeHeader) -> Self {
        Self {
            header,
            ..Self::default()
        }
    }

    /// Resizes the module slot list to exactly `n`.
    ///
    /// Extra slots are truncated; missing slots are padded with `None`.
    pub(crate) fn set_n_modules(&mut self, n: usize) {
        set_option_vec_len(&mut self.modules, n);
    }

    /// Returns a copy of the envelope header.
    pub fn header(&self) -> EnvelopeHeader {
        self.header
    }

    /// Number of module slots, including slots that are still `None`.
    pub fn n_modules(&self) -> usize {
        self.modules.len()
    }

    /// Fills module slot `index`, growing the slot list if needed.
    pub(crate) fn set_module(&mut self, index: usize, module: impl Into<ModuleDesc>) {
        let desc: ModuleDesc = module.into();
        set_option_vec_index(&mut self.modules, index, desc);
    }

    /// Fills packaged-extension slot `index`, growing the slot list if needed.
    pub(crate) fn set_packaged_extension(&mut self, index: usize, ext: impl Into<ExtensionDesc>) {
        let desc: ExtensionDesc = ext.into();
        set_option_vec_index(&mut self.packaged_extensions, index, desc);
    }

    /// Number of packaged-extension slots, including slots that are still `None`.
    pub fn n_packaged_extensions(&self) -> usize {
        self.packaged_extensions.len()
    }

    /// Joins the distinct generator strings of all filled module slots with `", "`.
    ///
    /// Strings appear in first-seen order. Returns `None` when no module records a generator.
    pub fn generator(&self) -> Option<String> {
        let distinct: Vec<&String> = self
            .modules
            .iter()
            .filter_map(Option::as_ref)
            .filter_map(|module| module.generator.as_ref())
            .unique()
            .collect();
        if distinct.is_empty() {
            None
        } else {
            Some(distinct.into_iter().join(", "))
        }
    }

    /// Iterates over every module slot, including empty (`None`) ones.
    pub fn modules(&self) -> impl Iterator<Item = &Option<ModuleDesc>> {
        self.modules.iter()
    }

    /// Iterates over the filled packaged-extension slots only.
    pub fn packaged_extensions(&self) -> impl Iterator<Item = &ExtensionDesc> {
        self.packaged_extensions.iter().filter_map(Option::as_ref)
    }
}
/// Name and semver version of an extension.
///
/// Displayed as `Extension {name} v{version}`.
#[derive(derive_more::Display, Debug, Clone, PartialEq)]
#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
#[display("Extension {name} v{version}")]
pub struct ExtensionDesc {
    /// Extension name.
    pub name: String,
    /// Extension version; described as a string in the JSON schema.
    #[schemars(with = "String")]
    pub version: Version,
}
impl ExtensionDesc {
    /// Creates a descriptor for extension `name` at `version`.
    pub fn new(name: impl ToString, version: impl Into<Version>) -> Self {
        let name = name.to_string();
        let version = version.into();
        Self { name, version }
    }

    /// Creates a descriptor for `name` with the placeholder version `0.0.0`.
    pub fn new_unversioned(name: impl ToString) -> Self {
        Self::new(name, Version::new(0, 0, 0))
    }
}
/// Builds a descriptor from anything that exposes a [`crate::Extension`],
/// copying its name and version.
impl<E: AsRef<crate::Extension>> From<&E> for ExtensionDesc {
    fn from(ext: &E) -> Self {
        let extension: &crate::Extension = ext.as_ref();
        Self::new(&extension.name, extension.version.clone())
    }
}
/// Description of the tool that generated a module.
#[derive(derive_more::Display, Debug, Clone, PartialEq)]
pub enum GeneratorDesc {
    /// A generator identified by name and an optional semver version.
    ///
    /// Displayed as `{name}-v{version}` when a version is present, otherwise as `{name}`.
    #[display("{}", match version {
        Some(v) => format!("{name}-v{v}"),
        None => name.to_string(),
    })]
    Structured {
        name: String,
        version: Option<Version>,
    },
    /// A free-form generator description, displayed verbatim.
    #[display("{description}")]
    Flat { description: String },
}
/// Serializes a [`GeneratorDesc`] as its `Display` string.
///
/// For example `name-v1.2.3`, `name`, or the flat description text.
impl serde::Serialize for GeneratorDesc {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}
impl GeneratorDesc {
    /// Creates a structured generator description with a version.
    pub fn new(name: impl ToString, version: impl Into<Version>) -> Self {
        let name = name.to_string();
        let version = Some(version.into());
        Self::Structured { name, version }
    }

    /// Creates a structured generator description without a version.
    pub fn new_unversioned(name: impl ToString) -> Self {
        let name = name.to_string();
        Self::Structured {
            name,
            version: None,
        }
    }
}
/// Deserializes a [`GeneratorDesc`] from either a string or a map.
///
/// - **String:** `"{name}-v{version}"`, where `{version}` parses as a semver
///   [`Version`], becomes [`GeneratorDesc::Structured`]. Every `-v` occurrence
///   is tried from left to right. Names that themselves contain `-v` (e.g.
///   `"my-validator-v1.2.3"`) therefore still round-trip through `Display`.
///   Any other string becomes [`GeneratorDesc::Flat`].
/// - **Map:** a map with exactly a string `name` and an optional `version`
///   becomes [`GeneratorDesc::Structured`]. The `version` may be a semver
///   string, `null`, or absent. Any other map becomes a
///   [`GeneratorDesc::Flat`] with one `key: value` line per entry, sorted by
///   key so the text is deterministic.
///
/// Errors raised by the underlying deserializer while reading map entries are
/// propagated instead of silently truncating the map.
impl<'de> serde::de::Deserialize<'de> for GeneratorDesc {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct Helper;

        impl<'vis> serde::de::Visitor<'vis> for Helper {
            type Value = GeneratorDesc;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                formatter.write_str("a generator description string or map")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // The first `-v` whose suffix parses as a semver version wins.
                // "-v" is ASCII, so both slice bounds are char boundaries.
                let structured = value.match_indices("-v").find_map(|(idx, sep)| {
                    let version = value[idx + sep.len()..].parse::<Version>().ok()?;
                    Some(GeneratorDesc::Structured {
                        name: value[..idx].to_string(),
                        version: Some(version),
                    })
                });
                Ok(structured.unwrap_or_else(|| GeneratorDesc::Flat {
                    description: value.to_string(),
                }))
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::MapAccess<'vis>,
            {
                // Propagate malformed keys/values instead of stopping early.
                let mut values: HashMap<String, serde_json::Value> = HashMap::new();
                while let Some((key, value)) = map.next_entry::<String, serde_json::Value>()? {
                    values.insert(key, value);
                }

                let name = values.get("name").and_then(serde_json::Value::as_str);
                // `Ok(None)`: no version given. `Err(())`: a version we cannot interpret.
                let version: Result<Option<Version>, ()> = match values.get("version") {
                    None | Some(serde_json::Value::Null) => Ok(None),
                    Some(raw) => raw
                        .as_str()
                        .and_then(|s| s.parse::<Version>().ok())
                        .map(Some)
                        .ok_or(()),
                };
                let only_known_fields = values.keys().all(|k| k == "name" || k == "version");

                match (only_known_fields, name, version) {
                    (true, Some(name), Ok(version)) => Ok(GeneratorDesc::Structured {
                        name: name.to_string(),
                        version,
                    }),
                    _ => Ok(GeneratorDesc::Flat {
                        // Sort by key: HashMap iteration order is unspecified.
                        description: values
                            .iter()
                            .sorted_by(|(a, _), (b, _)| a.cmp(b))
                            .map(|(k, v)| format!("{k}: {v}"))
                            .join("\n"),
                    }),
                }
            }
        }

        deserializer.deserialize_any(Helper)
    }
}
/// The entrypoint node of a module together with its operation.
///
/// NOTE(review): `optype` is serialized as a plain string via `op_serialize`,
/// but `Deserialize` is derived, so deserialization presumably expects
/// `OpType`'s own serde representation. A serialize-then-deserialize round
/// trip likely fails; confirm.
#[derive(Debug, Clone, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct Entrypoint {
    /// The entrypoint node; described as a `u32` in the JSON schema.
    #[schemars(with = "u32")]
    pub node: Node,
    /// The entrypoint's operation, rendered by `op_string` when serialized.
    #[serde(serialize_with = "op_serialize")]
    #[schemars(with = "String")]
    pub optype: OpType,
}

impl Entrypoint {
    /// Pairs an entrypoint `node` with its operation.
    pub fn new(node: Node, optype: OpType) -> Self {
        Entrypoint { optype, node }
    }
}
/// Renders an operation as a short human-readable string.
///
/// Formats by operation kind:
/// - `FuncDefn` / `FuncDecl` become `Kind(name: signature)`.
/// - `DFG` becomes `DFG(signature)`.
/// - Any other operation uses its `Display` output.
pub fn op_string(op: &OpType) -> String {
    let (kind, detail) = match op {
        OpType::FuncDefn(defn) => (
            "FuncDefn",
            func_symbol(defn.func_name(), defn.signature()),
        ),
        OpType::FuncDecl(decl) => (
            "FuncDecl",
            func_symbol(decl.func_name(), decl.signature()),
        ),
        OpType::DFG(dfg) => ("DFG", dfg.signature().to_string()),
        other => return other.to_string(),
    };
    format!("{kind}({detail})")
}
/// Serde `serialize_with` helper that writes an operation as its `op_string` text.
fn op_serialize<S>(op_type: &OpType, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let rendered = op_string(op_type);
    serializer.serialize_str(&rendered)
}
/// Summary information about a single module of a package.
///
/// Every field is optional. Unset fields are omitted when serializing and
/// default to `None` when deserializing.
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct ModuleDesc {
    /// Number of nodes in the module.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub num_nodes: Option<usize>,
    /// The module's entrypoint node and its operation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub entrypoint: Option<Entrypoint>,
    /// Extensions in the module's resolved extension registry.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub used_extensions_resolved: Option<Vec<ExtensionDesc>>,
    /// Generator string read from the module root's metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generator: Option<String>,
    /// Extensions the generator recorded as used in the module root's metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub used_extensions_generator: Option<Vec<ExtensionDesc>>,
    /// `name: signature` strings of the module's public functions.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub public_symbols: Option<Vec<String>>,
}
impl ModuleDesc {
pub fn set_num_nodes(&mut self, num_nodes: usize) {
self.num_nodes = Some(num_nodes);
}
pub fn set_entrypoint(&mut self, node: Node, optype: OpType) {
self.entrypoint = Some(Entrypoint::new(node, optype));
}
pub fn set_generator(&mut self, generator: impl ToString) {
self.generator = Some(generator.to_string());
}
pub fn set_used_extensions_generator(
&mut self,
used_extensions_metadata: impl IntoIterator<Item = ExtensionDesc>,
) {
self.used_extensions_generator = Some(used_extensions_metadata.into_iter().collect());
}
pub fn extend_used_extensions_metadata(
&mut self,
exts: impl IntoIterator<Item = ExtensionDesc>,
) {
extend_option_vec(&mut self.used_extensions_generator, exts);
}
pub fn set_used_extensions_resolved(
&mut self,
used_extensions_resolved: impl IntoIterator<Item = ExtensionDesc>,
) {
self.used_extensions_resolved = Some(used_extensions_resolved.into_iter().collect());
}
pub fn extend_used_extensions_resolved(
&mut self,
exts: impl IntoIterator<Item = ExtensionDesc>,
) {
extend_option_vec(&mut self.used_extensions_resolved, exts);
}
pub fn set_public_symbols(&mut self, symbols: impl IntoIterator<Item = String>) {
self.public_symbols = Some(symbols.into_iter().collect());
}
pub fn extend_public_symbols(&mut self, symbols: impl IntoIterator<Item = String>) {
extend_option_vec(&mut self.public_symbols, symbols);
}
pub(crate) fn load_generator(&mut self, hugr: &impl HugrView) {
if let Some(val) = hugr.get_metadata::<metadata::HugrGenerator>(hugr.module_root()) {
self.set_generator(val);
}
}
pub(crate) fn load_used_extensions_generator(
&mut self,
hugr: &impl HugrView,
) -> Result<(), serde_json::Error> {
let Some(used_exts) = hugr.get_metadata::<HugrUsedExtensions>(hugr.module_root()) else {
return Ok(()); };
self.set_used_extensions_generator(used_exts);
Ok(())
}
pub(crate) fn load_used_extensions_resolved(&mut self, hugr: &impl HugrView) {
self.set_used_extensions_resolved(
hugr.extensions()
.iter()
.map(|ext| ExtensionDesc::new(&ext.name, ext.version.clone())),
)
}
pub(crate) fn load_public_symbols(&mut self, hugr: &impl HugrView) {
let symbols = hugr
.children(hugr.module_root())
.filter_map(|n| match hugr.get_optype(n) {
OpType::FuncDecl(decl) if *decl.visibility() == crate::Visibility::Public => {
Some(func_symbol(decl.func_name(), decl.signature()))
}
OpType::FuncDefn(defn) if *defn.visibility() == crate::Visibility::Public => {
Some(func_symbol(defn.func_name(), defn.signature()))
}
_ => None,
});
self.set_public_symbols(symbols);
}
pub(crate) fn load_entrypoint(&mut self, hugr: &impl HugrView<Node = Node>) {
let node = hugr.entrypoint();
self.set_entrypoint(node, hugr.get_optype(node).clone());
}
pub(crate) fn load_num_nodes(&mut self, hugr: &impl HugrView) {
self.set_num_nodes(hugr.num_nodes());
}
pub(crate) fn load_from_hugr(&mut self, hugr: &impl HugrView<Node = Node>) {
self.load_num_nodes(hugr);
self.load_entrypoint(hugr);
self.load_generator(hugr);
self.load_used_extensions_resolved(hugr);
self.load_public_symbols(hugr);
self.load_used_extensions_generator(hugr).ok();
}
}
/// Formats a function symbol as `name: signature`.
fn func_symbol(name: &str, signature: &crate::types::PolyFuncType) -> String {
    format!("{name}: {signature}")
}
/// Builds a [`ModuleDesc`] by loading every field from `hugr` via `load_from_hugr`.
impl<H: HugrView<Node = Node>> From<&H> for ModuleDesc {
    fn from(hugr: &H) -> Self {
        let mut module_desc = Self::default();
        module_desc.load_from_hugr(hugr);
        module_desc
    }
}
// Unit tests for the package/module/extension description types.
#[cfg(test)]
mod test {
    use super::*;
    use rstest::{fixture, rstest};
    use semver::Version;

    // Fixture: a package description with no module or extension slots.
    #[fixture]
    fn empty_package_desc() -> PackageDesc {
        PackageDesc::default()
    }

    // Fixture: a module description with every field unset.
    #[fixture]
    fn empty_module_desc() -> ModuleDesc {
        ModuleDesc::default()
    }

    // Fixture: a single extension descriptor `test_ext` v1.0.0.
    #[fixture]
    fn test_extension() -> ExtensionDesc {
        ExtensionDesc::new("test_ext", Version::new(1, 0, 0))
    }

    // `new` keeps the header and starts with no slots.
    #[rstest]
    fn test_package_desc_new() {
        let header = EnvelopeHeader::default();
        let package = PackageDesc::new(header);
        assert_eq!(package.header(), header);
        assert_eq!(package.n_modules(), 0);
        assert_eq!(package.n_packaged_extensions(), 0);
    }

    // `n_modules` counts slots, so unfilled (`None`) slots are included.
    #[rstest]
    fn test_package_desc_set_n_modules(mut empty_package_desc: PackageDesc) {
        empty_package_desc.set_n_modules(5);
        assert_eq!(empty_package_desc.n_modules(), 5);
    }

    // Setting slot 0 on an empty package grows the slot list and stores the module.
    #[rstest]
    fn test_package_desc_set_module(
        mut empty_package_desc: PackageDesc,
        empty_module_desc: ModuleDesc,
    ) {
        empty_package_desc.set_module(0, empty_module_desc.clone());
        assert_eq!(
            empty_package_desc.modules().next().unwrap().as_ref(),
            Some(&empty_module_desc)
        );
    }

    #[rstest]
    fn test_package_desc_set_packaged_extension(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_package_desc.set_packaged_extension(0, test_extension.clone());
        assert_eq!(
            empty_package_desc.packaged_extensions().next(),
            Some(&test_extension)
        );
    }

    // With a single module, the package generator is that module's generator string.
    #[rstest]
    fn test_package_desc_generator(mut empty_package_desc: PackageDesc) {
        let mut module = ModuleDesc::default();
        module.set_generator("test_generator");
        empty_package_desc.set_module(0, module);
        assert_eq!(
            empty_package_desc.generator(),
            Some("test_generator".to_string())
        );
    }

    #[rstest]
    fn test_module_desc_set_num_nodes(mut empty_module_desc: ModuleDesc) {
        empty_module_desc.set_num_nodes(10);
        assert_eq!(empty_module_desc.num_nodes, Some(10));
    }

    #[rstest]
    fn test_module_desc_set_entrypoint(mut empty_module_desc: ModuleDesc) {
        let node = Node::from(portgraph::NodeIndex::new(0));
        let optype: OpType = crate::ops::DFG {
            signature: Default::default(),
        }
        .into();
        empty_module_desc.set_entrypoint(node, optype.clone());
        assert_eq!(empty_module_desc.entrypoint.as_ref().unwrap().node, node);
        assert_eq!(
            empty_module_desc.entrypoint.as_ref().unwrap().optype,
            optype
        );
    }

    // The "" case skips `set_generator` entirely, so it only checks that the
    // default generator is `None`; it does not test setting an empty string.
    #[rstest]
    #[case("test_generator", Some("test_generator".to_string()))]
    #[case("", None)]
    fn test_module_desc_generator(#[case] input: &str, #[case] expected: Option<String>) {
        let mut module = ModuleDesc::default();
        if !input.is_empty() {
            module.set_generator(input);
        }
        assert_eq!(module.generator, expected);
    }

    #[test]
    fn test_extension_desc_new() {
        let name = "test_extension";
        let version = Version::new(1, 0, 0);
        let extension = ExtensionDesc::new(name, version.clone());
        assert_eq!(extension.name, name);
        assert_eq!(extension.version, version);
    }

    #[rstest]
    fn test_package_desc_n_packaged_extensions(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        assert_eq!(empty_package_desc.n_packaged_extensions(), 0);
        empty_package_desc.set_packaged_extension(0, test_extension);
        assert_eq!(empty_package_desc.n_packaged_extensions(), 1);
    }

    // `modules()` yields every slot as `&Option<ModuleDesc>`.
    #[rstest]
    fn test_package_desc_modules_iterator(
        mut empty_package_desc: PackageDesc,
        empty_module_desc: ModuleDesc,
    ) {
        empty_package_desc.set_module(0, empty_module_desc.clone());
        let modules: Vec<_> = empty_package_desc.modules().collect();
        assert_eq!(modules.len(), 1);
        assert_eq!(modules[0].as_ref(), Some(&empty_module_desc));
    }

    // `packaged_extensions()` yields only filled slots as `&ExtensionDesc`.
    #[rstest]
    fn test_package_desc_packaged_extensions_iterator(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_package_desc.set_packaged_extension(0, test_extension.clone());
        let extensions: Vec<_> = empty_package_desc.packaged_extensions().collect();
        assert_eq!(extensions.len(), 1);
        assert_eq!(extensions[0], &test_extension);
    }

    #[rstest]
    fn test_module_desc_set_used_extensions_generator(
        mut empty_module_desc: ModuleDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_module_desc.set_used_extensions_generator(vec![test_extension.clone()]);
        assert_eq!(
            empty_module_desc
                .used_extensions_generator
                .as_ref()
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            empty_module_desc
                .used_extensions_generator
                .as_ref()
                .unwrap()[0],
            test_extension
        );
    }

    // `extend_used_extensions_metadata` appends to `used_extensions_generator`.
    #[rstest]
    fn test_module_desc_extend_used_extensions_metadata(mut empty_module_desc: ModuleDesc) {
        let extension1 = ExtensionDesc::new("test_ext1", Version::new(1, 0, 0));
        let extension2 = ExtensionDesc::new("test_ext2", Version::new(2, 0, 0));
        empty_module_desc.set_used_extensions_generator(vec![extension1.clone()]);
        empty_module_desc.extend_used_extensions_metadata(vec![extension2.clone()]);
        let extensions = empty_module_desc
            .used_extensions_generator
            .as_ref()
            .unwrap();
        assert_eq!(extensions.len(), 2);
        assert!(extensions.contains(&extension1));
        assert!(extensions.contains(&extension2));
    }

    #[rstest]
    fn test_module_desc_set_public_symbols(mut empty_module_desc: ModuleDesc) {
        let symbols = vec!["symbol1".to_string(), "symbol2".to_string()];
        empty_module_desc.set_public_symbols(symbols.clone());
        assert_eq!(empty_module_desc.public_symbols.as_ref().unwrap().len(), 2);
        assert_eq!(empty_module_desc.public_symbols.as_ref().unwrap(), &symbols);
    }

    #[rstest]
    fn test_module_desc_extend_public_symbols(mut empty_module_desc: ModuleDesc) {
        let symbols1 = vec!["symbol1".to_string()];
        let symbols2 = vec!["symbol2".to_string()];
        empty_module_desc.set_public_symbols(symbols1.clone());
        empty_module_desc.extend_public_symbols(symbols2.clone());
        let symbols = empty_module_desc.public_symbols.as_ref().unwrap();
        assert_eq!(symbols.len(), 2);
        assert!(symbols.contains(&"symbol1".to_string()));
        assert!(symbols.contains(&"symbol2".to_string()));
    }
}