1use crate::{
3 HugrView, Node,
4 envelope::{EnvelopeHeader, USED_EXTENSIONS_KEY},
5 ops::{DataflowOpTrait, OpType},
6};
7use itertools::Itertools;
8use semver::Version;
9
/// A table of optional slots; a `None` entry marks a slot that has not been filled in yet.
type OptionVec<T> = Vec<Option<T>>;

/// Resizes `vec` to exactly `n` slots.
///
/// Growing pads with `None`; shrinking drops the trailing slots.
fn set_option_vec_len<T: Clone>(vec: &mut OptionVec<T>, n: usize) {
    vec.resize_with(n, || None);
}

/// Stores `value` in slot `index`, first growing `vec` with `None` slots if `index` is
/// past the end.
fn set_option_vec_index<T: Clone>(vec: &mut OptionVec<T>, index: usize, value: T) {
    match vec.get_mut(index) {
        Some(slot) => *slot = Some(value),
        None => {
            set_option_vec_len(vec, index + 1);
            vec[index] = Some(value);
        }
    }
}
20
/// Appends `items` to the vector inside `vec`.
///
/// If `vec` is `None`, it is first replaced with an empty vector. The result is always
/// `Some`, even when `items` is empty.
fn extend_option_vec<T>(vec: &mut Option<Vec<T>>, items: impl IntoIterator<Item = T>) {
    vec.get_or_insert_with(Vec::new).extend(items);
}
28
/// Summary description of a package: its envelope header, one description slot per
/// module, and the extensions packaged alongside the modules.
#[derive(Debug, Clone, PartialEq, Default, serde::Serialize, schemars::JsonSchema)]
pub struct PackageDesc {
    /// The envelope header. Serialized as its `Display` string (see `header_serialize`)
    /// and described as a string in the JSON schema.
    #[serde(serialize_with = "header_serialize")]
    #[schemars(with = "String")]
    pub header: EnvelopeHeader,
    /// One slot per module in the package; a slot is `None` until that module's
    /// description has been set.
    pub modules: OptionVec<ModuleDesc>,
    /// Packaged extensions, stored by index; a slot is `None` if nothing was set there.
    /// Omitted from the serialized output when the table is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[serde(default)]
    pub packaged_extensions: OptionVec<ExtensionDesc>,
}
43
44fn header_serialize<S>(header: &EnvelopeHeader, serializer: S) -> Result<S::Ok, S::Error>
45where
46 S: serde::Serializer,
47{
48 serializer.serialize_str(&header.to_string())
49}
50
51impl PackageDesc {
52 pub(super) fn new(header: EnvelopeHeader) -> Self {
54 Self {
55 header,
56 ..Default::default()
57 }
58 }
59
60 pub(crate) fn set_n_modules(&mut self, n: usize) {
62 set_option_vec_len(&mut self.modules, n);
63 }
64
65 pub fn header(&self) -> EnvelopeHeader {
67 self.header
68 }
69
70 pub fn n_modules(&self) -> usize {
72 self.modules.len()
73 }
74
75 pub(crate) fn set_module(&mut self, index: usize, module: impl Into<ModuleDesc>) {
77 set_option_vec_index(&mut self.modules, index, module.into());
78 }
79
80 pub(crate) fn set_packaged_extension(&mut self, index: usize, ext: impl Into<ExtensionDesc>) {
82 set_option_vec_index(&mut self.packaged_extensions, index, ext.into());
83 }
84
85 pub fn n_packaged_extensions(&self) -> usize {
87 self.packaged_extensions.len()
88 }
89
90 pub fn generator(&self) -> Option<String> {
93 let generators: Vec<String> = self
94 .modules
95 .iter()
96 .flatten()
97 .flat_map(|m| &m.generator)
98 .unique()
99 .cloned()
100 .collect();
101 if generators.is_empty() {
102 return None;
103 }
104
105 Some(generators.join(", "))
106 }
107
108 pub fn modules(&self) -> impl Iterator<Item = &Option<ModuleDesc>> {
111 self.modules.iter()
112 }
113
114 pub fn packaged_extensions(&self) -> impl Iterator<Item = &ExtensionDesc> {
116 self.packaged_extensions.iter().flatten()
117 }
118}
119
/// Name and version of an extension.
///
/// Displays as `Extension {name} v{version}`.
#[derive(
    derive_more::Display,
    Debug,
    Clone,
    PartialEq,
    serde::Deserialize,
    serde::Serialize,
    schemars::JsonSchema,
)]
#[display("Extension {name} v{version}")]
pub struct ExtensionDesc {
    /// The extension's name.
    pub name: String,
    /// The extension's semantic version; described as a string in the JSON schema.
    #[schemars(with = "String")]
    pub version: Version,
}
138
139impl ExtensionDesc {
140 pub fn new(name: impl ToString, version: impl Into<Version>) -> Self {
142 Self {
143 name: name.to_string(),
144 version: version.into(),
145 }
146 }
147}
148
149impl<E: AsRef<crate::Extension>> From<&E> for ExtensionDesc {
150 fn from(ext: &E) -> Self {
151 let ext = ext.as_ref();
152 Self {
153 name: ext.name.to_string(),
154 version: ext.version.clone(),
155 }
156 }
157}
158
/// The entrypoint of a module: the entrypoint node and the operation at that node.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct Entrypoint {
    /// The entrypoint node; described as a `u32` in the JSON schema.
    #[schemars(with = "u32")]
    pub node: Node,
    /// The operation at the entrypoint node.
    ///
    /// Serialized as the human-readable string produced by [`op_string`], and described
    /// as a string in the JSON schema.
    // NOTE(review): only `serialize_with` is set, so deserialization falls back to
    // `OpType`'s own `Deserialize` impl; the string written by `op_serialize` likely
    // cannot be read back. Confirm whether round-tripping `Entrypoint` is expected.
    #[schemars(with = "String")]
    #[serde(serialize_with = "op_serialize")]
    pub optype: OpType,
}
170
171impl Entrypoint {
172 pub fn new(node: Node, optype: OpType) -> Self {
174 Self { node, optype }
175 }
176}
177
178pub fn op_string(op: &OpType) -> String {
180 match op {
181 OpType::FuncDefn(defn) => format!(
182 "FuncDefn({})",
183 func_symbol(defn.func_name(), defn.signature())
184 ),
185 OpType::FuncDecl(decl) => format!(
186 "FuncDecl({})",
187 func_symbol(decl.func_name(), decl.signature())
188 ),
189 OpType::DFG(dfg) => format!("DFG({})", dfg.signature()),
190 _ => format!("{op}"),
191 }
192}
193fn op_serialize<S>(op_type: &OpType, serializer: S) -> Result<S::Ok, S::Error>
194where
195 S: serde::Serializer,
196{
197 serializer.serialize_str(op_string(op_type).as_str())
198}
199
/// Description of a single module within a package.
///
/// Every field is optional. Unset fields are omitted when serializing and default to
/// `None` when deserializing.
#[derive(
    Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize, schemars::JsonSchema,
)]
pub struct ModuleDesc {
    /// Number of nodes in the module (filled from `HugrView::num_nodes` by `load_num_nodes`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub num_nodes: Option<usize>,
    /// The module's entrypoint node and its operation.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub entrypoint: Option<Entrypoint>,
    /// Extensions from the HUGR's own extension registry
    /// (filled by `load_used_extensions_resolved`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub used_extensions_resolved: Option<Vec<ExtensionDesc>>,
    /// Generator of the module, read from module-root metadata by `load_generator`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub generator: Option<String>,
    /// Extensions declared in module-root metadata under `USED_EXTENSIONS_KEY`
    /// (filled by `load_used_extensions_generator`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub used_extensions_generator: Option<Vec<ExtensionDesc>>,
    /// `name: signature` strings for public function declarations and definitions
    /// directly under the module root (filled by `load_public_symbols`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub public_symbols: Option<Vec<String>>,
}
230
231impl ModuleDesc {
232 pub fn set_num_nodes(&mut self, num_nodes: usize) {
234 self.num_nodes = Some(num_nodes);
235 }
236
237 pub fn set_entrypoint(&mut self, node: Node, optype: OpType) {
239 self.entrypoint = Some(Entrypoint::new(node, optype));
240 }
241
242 pub fn set_generator(&mut self, generator: impl Into<String>) {
244 self.generator = Some(generator.into());
245 }
246
247 pub fn set_used_extensions_generator(
249 &mut self,
250 used_extensions_metadata: impl IntoIterator<Item = ExtensionDesc>,
251 ) {
252 self.used_extensions_generator = Some(used_extensions_metadata.into_iter().collect());
253 }
254
255 pub fn extend_used_extensions_metadata(
257 &mut self,
258 exts: impl IntoIterator<Item = ExtensionDesc>,
259 ) {
260 extend_option_vec(&mut self.used_extensions_generator, exts);
261 }
262
263 pub fn set_used_extensions_resolved(
265 &mut self,
266 used_extensions_resolved: impl IntoIterator<Item = ExtensionDesc>,
267 ) {
268 self.used_extensions_resolved = Some(used_extensions_resolved.into_iter().collect());
269 }
270
271 pub fn extend_used_extensions_resolved(
273 &mut self,
274 exts: impl IntoIterator<Item = ExtensionDesc>,
275 ) {
276 extend_option_vec(&mut self.used_extensions_resolved, exts);
277 }
278
279 pub fn set_public_symbols(&mut self, symbols: impl IntoIterator<Item = String>) {
281 self.public_symbols = Some(symbols.into_iter().collect());
282 }
283
284 pub fn extend_public_symbols(&mut self, symbols: impl IntoIterator<Item = String>) {
286 extend_option_vec(&mut self.public_symbols, symbols);
287 }
288
289 pub(crate) fn load_generator(&mut self, hugr: &impl HugrView) {
291 if let Some(val) = hugr.get_metadata(hugr.module_root(), crate::envelope::GENERATOR_KEY) {
292 self.set_generator(super::format_generator(val));
293 }
294 }
295
296 pub(crate) fn load_used_extensions_generator(
298 &mut self,
299 hugr: &impl HugrView,
300 ) -> Result<(), serde_json::Error> {
301 let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
302 return Ok(()); };
304 let used_exts: Vec<ExtensionDesc> = serde_json::from_value(exts.clone())?;
305
306 self.set_used_extensions_generator(used_exts);
307 Ok(())
308 }
309
310 pub(crate) fn load_used_extensions_resolved(&mut self, hugr: &impl HugrView) {
312 self.set_used_extensions_resolved(
313 hugr.extensions()
314 .iter()
315 .map(|ext| ExtensionDesc::new(&ext.name, ext.version.clone())),
316 )
317 }
318
319 pub(crate) fn load_public_symbols(&mut self, hugr: &impl HugrView) {
321 let symbols = hugr
322 .children(hugr.module_root())
323 .filter_map(|n| match hugr.get_optype(n) {
324 OpType::FuncDecl(decl) if *decl.visibility() == crate::Visibility::Public => {
325 Some(func_symbol(decl.func_name(), decl.signature()))
326 }
327 OpType::FuncDefn(defn) if *defn.visibility() == crate::Visibility::Public => {
328 Some(func_symbol(defn.func_name(), defn.signature()))
329 }
330 _ => None,
331 });
332
333 self.set_public_symbols(symbols);
334 }
335
336 pub(crate) fn load_entrypoint(&mut self, hugr: &impl HugrView<Node = Node>) {
338 let node = hugr.entrypoint();
339 self.set_entrypoint(node, hugr.get_optype(node).clone());
340 }
341
342 pub(crate) fn load_num_nodes(&mut self, hugr: &impl HugrView) {
344 self.set_num_nodes(hugr.num_nodes());
345 }
346
347 pub(crate) fn load_from_hugr(&mut self, hugr: &impl HugrView<Node = Node>) {
349 self.load_num_nodes(hugr);
350 self.load_entrypoint(hugr);
351 self.load_generator(hugr);
352 self.load_used_extensions_resolved(hugr);
353 self.load_public_symbols(hugr);
354 self.load_used_extensions_generator(hugr).ok();
356 }
357}
358
359fn func_symbol(name: &str, signature: &crate::types::PolyFuncType) -> String {
360 format!("{name}: {}", signature)
361}
362impl<H: HugrView<Node = Node>> From<&H> for ModuleDesc {
363 fn from(hugr: &H) -> Self {
364 let mut desc = ModuleDesc::default();
365 desc.load_from_hugr(hugr);
366 desc
367 }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use rstest::{fixture, rstest};
    use semver::Version;

    // A `PackageDesc` with a default header and no module or extension slots.
    #[fixture]
    fn empty_package_desc() -> PackageDesc {
        PackageDesc::default()
    }

    // A `ModuleDesc` with every optional field unset.
    #[fixture]
    fn empty_module_desc() -> ModuleDesc {
        ModuleDesc::default()
    }

    // A sample extension description: `test_ext` v1.0.0.
    #[fixture]
    fn test_extension() -> ExtensionDesc {
        ExtensionDesc::new("test_ext", Version::new(1, 0, 0))
    }

    // `new` keeps the given header and starts with no module or extension slots.
    #[rstest]
    fn test_package_desc_new() {
        let header = EnvelopeHeader::default();
        let package = PackageDesc::new(header);
        assert_eq!(package.header(), header);
        assert_eq!(package.n_modules(), 0);
        assert_eq!(package.n_packaged_extensions(), 0);
    }

    // `set_n_modules` allocates the requested number of (empty) slots.
    #[rstest]
    fn test_package_desc_set_n_modules(mut empty_package_desc: PackageDesc) {
        empty_package_desc.set_n_modules(5);
        assert_eq!(empty_package_desc.n_modules(), 5);
    }

    // A module stored at index 0 is returned by the first slot of `modules()`.
    #[rstest]
    fn test_package_desc_set_module(
        mut empty_package_desc: PackageDesc,
        empty_module_desc: ModuleDesc,
    ) {
        empty_package_desc.set_module(0, empty_module_desc.clone());
        assert_eq!(
            empty_package_desc.modules().next().unwrap().as_ref(),
            Some(&empty_module_desc)
        );
    }

    // A packaged extension stored at index 0 is yielded by `packaged_extensions()`.
    #[rstest]
    fn test_package_desc_set_packaged_extension(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_package_desc.set_packaged_extension(0, test_extension.clone());
        assert_eq!(
            empty_package_desc.packaged_extensions().next(),
            Some(&test_extension)
        );
    }

    // With a single module, the package generator is that module's generator.
    #[rstest]
    fn test_package_desc_generator(mut empty_package_desc: PackageDesc) {
        let mut module = ModuleDesc::default();
        module.set_generator("test_generator");
        empty_package_desc.set_module(0, module);
        assert_eq!(
            empty_package_desc.generator(),
            Some("test_generator".to_string())
        );
    }

    #[rstest]
    fn test_module_desc_set_num_nodes(mut empty_module_desc: ModuleDesc) {
        empty_module_desc.set_num_nodes(10);
        assert_eq!(empty_module_desc.num_nodes, Some(10));
    }

    // The entrypoint stores both the node and a copy of its operation.
    #[rstest]
    fn test_module_desc_set_entrypoint(mut empty_module_desc: ModuleDesc) {
        let node = Node::from(portgraph::NodeIndex::new(0));
        let optype: OpType = crate::ops::DFG {
            signature: Default::default(),
        }
        .into();
        empty_module_desc.set_entrypoint(node, optype.clone());
        assert_eq!(empty_module_desc.entrypoint.as_ref().unwrap().node, node);
        assert_eq!(
            empty_module_desc.entrypoint.as_ref().unwrap().optype,
            optype
        );
    }

    // An empty `input` means "never call `set_generator`", so the field stays `None`.
    #[rstest]
    #[case("test_generator", Some("test_generator".to_string()))]
    #[case("", None)]
    fn test_module_desc_generator(#[case] input: &str, #[case] expected: Option<String>) {
        let mut module = ModuleDesc::default();
        if !input.is_empty() {
            module.set_generator(input);
        }
        assert_eq!(module.generator, expected);
    }

    #[test]
    fn test_extension_desc_new() {
        let name = "test_extension";
        let version = Version::new(1, 0, 0);
        let extension = ExtensionDesc::new(name, version.clone());
        assert_eq!(extension.name, name);
        assert_eq!(extension.version, version);
    }

    // The packaged-extension count grows when a slot is filled.
    #[rstest]
    fn test_package_desc_n_packaged_extensions(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        assert_eq!(empty_package_desc.n_packaged_extensions(), 0);

        empty_package_desc.set_packaged_extension(0, test_extension);
        assert_eq!(empty_package_desc.n_packaged_extensions(), 1);
    }

    #[rstest]
    fn test_package_desc_modules_iterator(
        mut empty_package_desc: PackageDesc,
        empty_module_desc: ModuleDesc,
    ) {
        empty_package_desc.set_module(0, empty_module_desc.clone());

        let modules: Vec<_> = empty_package_desc.modules().collect();
        assert_eq!(modules.len(), 1);
        assert_eq!(modules[0].as_ref(), Some(&empty_module_desc));
    }

    #[rstest]
    fn test_package_desc_packaged_extensions_iterator(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_package_desc.set_packaged_extension(0, test_extension.clone());

        let extensions: Vec<_> = empty_package_desc.packaged_extensions().collect();
        assert_eq!(extensions.len(), 1);
        assert_eq!(extensions[0], &test_extension);
    }

    #[rstest]
    fn test_module_desc_set_used_extensions_generator(
        mut empty_module_desc: ModuleDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_module_desc.set_used_extensions_generator(vec![test_extension.clone()]);

        assert_eq!(
            empty_module_desc
                .used_extensions_generator
                .as_ref()
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            empty_module_desc
                .used_extensions_generator
                .as_ref()
                .unwrap()[0],
            test_extension
        );
    }

    // `extend_used_extensions_metadata` appends to `used_extensions_generator`.
    #[rstest]
    fn test_module_desc_extend_used_extensions_metadata(mut empty_module_desc: ModuleDesc) {
        let extension1 = ExtensionDesc::new("test_ext1", Version::new(1, 0, 0));
        let extension2 = ExtensionDesc::new("test_ext2", Version::new(2, 0, 0));

        empty_module_desc.set_used_extensions_generator(vec![extension1.clone()]);
        empty_module_desc.extend_used_extensions_metadata(vec![extension2.clone()]);

        let extensions = empty_module_desc
            .used_extensions_generator
            .as_ref()
            .unwrap();
        assert_eq!(extensions.len(), 2);
        assert!(extensions.contains(&extension1));
        assert!(extensions.contains(&extension2));
    }

    #[rstest]
    fn test_module_desc_set_public_symbols(mut empty_module_desc: ModuleDesc) {
        let symbols = vec!["symbol1".to_string(), "symbol2".to_string()];
        empty_module_desc.set_public_symbols(symbols.clone());

        assert_eq!(empty_module_desc.public_symbols.as_ref().unwrap().len(), 2);
        assert_eq!(empty_module_desc.public_symbols.as_ref().unwrap(), &symbols);
    }

    // Extending after setting keeps the original symbols and appends the new ones.
    #[rstest]
    fn test_module_desc_extend_public_symbols(mut empty_module_desc: ModuleDesc) {
        let symbols1 = vec!["symbol1".to_string()];
        let symbols2 = vec!["symbol2".to_string()];

        empty_module_desc.set_public_symbols(symbols1.clone());
        empty_module_desc.extend_public_symbols(symbols2.clone());

        let symbols = empty_module_desc.public_symbols.as_ref().unwrap();
        assert_eq!(symbols.len(), 2);
        assert!(symbols.contains(&"symbol1".to_string()));
        assert!(symbols.contains(&"symbol2".to_string()));
    }
}