1use derive_more::{Display, Error, From};
4use itertools::Itertools;
5use std::path::Path;
6use std::{fs, io, mem};
7
8use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
9use crate::envelope::{read_envelope, write_envelope, EnvelopeConfig, EnvelopeError};
10use crate::extension::resolution::ExtensionResolutionError;
11use crate::extension::{ExtensionId, ExtensionRegistry, PRELUDE_REGISTRY};
12use crate::hugr::internal::HugrMutInternals;
13use crate::hugr::{ExtensionError, HugrView, ValidationError};
14use crate::ops::{FuncDefn, Module, NamedOp, OpTag, OpTrait, OpType};
15use crate::{Extension, Hugr};
16
17#[derive(Debug, Default, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
18pub struct Package {
20 pub modules: Vec<Hugr>,
22 pub extensions: ExtensionRegistry,
26}
27
28impl Package {
29 pub fn new(modules: impl IntoIterator<Item = Hugr>) -> Result<Self, PackageError> {
40 let modules: Vec<Hugr> = modules.into_iter().collect();
41 let mut extensions = ExtensionRegistry::default();
42 for (idx, module) in modules.iter().enumerate() {
43 let root_op = module.get_optype(module.root());
44 if !root_op.is_module() {
45 return Err(PackageError::NonModuleHugr {
46 module_index: idx,
47 root_op: root_op.clone(),
48 });
49 }
50 extensions.extend(module.extensions());
51 }
52 Ok(Self {
53 modules,
54 extensions,
55 })
56 }
57
58 pub fn from_hugrs(modules: impl IntoIterator<Item = Hugr>) -> Result<Self, PackageError> {
69 let modules: Vec<Hugr> = modules
70 .into_iter()
71 .map(to_module_hugr)
72 .collect::<Result<_, PackageError>>()?;
73
74 let mut extensions = ExtensionRegistry::default();
75 for module in &modules {
76 extensions.extend(module.extensions());
77 }
78
79 Ok(Self {
80 modules,
81 extensions,
82 })
83 }
84
85 pub fn from_hugr(hugr: Hugr) -> Result<Self, PackageError> {
96 let mut package = Self::default();
97 let module = to_module_hugr(hugr)?;
98 package.extensions = module.extensions().clone();
99 package.modules.push(module);
100 Ok(package)
101 }
102
103 pub fn validate(&self) -> Result<(), PackageValidationError> {
107 for hugr in self.modules.iter() {
108 hugr.validate()?;
109
110 let missing_exts = hugr
111 .extensions()
112 .ids()
113 .filter(|id| !self.extensions.contains(id))
114 .cloned()
115 .collect_vec();
116 if !missing_exts.is_empty() {
117 return Err(PackageValidationError::MissingExtension {
118 missing: missing_exts,
119 available: self.extensions.ids().cloned().collect(),
120 });
121 }
122 }
123 Ok(())
124 }
125
126 pub fn load(
128 reader: impl io::BufRead,
129 extensions: Option<&ExtensionRegistry>,
130 ) -> Result<Self, EnvelopeError> {
131 let extensions = extensions.unwrap_or(&PRELUDE_REGISTRY);
132 let (_, pkg) = read_envelope(reader, extensions)?;
133 Ok(pkg)
134 }
135
136 pub fn load_str(
141 envelope: impl AsRef<str>,
142 extensions: Option<&ExtensionRegistry>,
143 ) -> Result<Self, EnvelopeError> {
144 Self::load(envelope.as_ref().as_bytes(), extensions)
145 }
146
147 pub fn store(
149 &self,
150 writer: impl io::Write,
151 config: EnvelopeConfig,
152 ) -> Result<(), EnvelopeError> {
153 write_envelope(writer, self, config)
154 }
155
156 pub fn store_str(&self, config: EnvelopeConfig) -> Result<String, EnvelopeError> {
162 if !config.format.ascii_printable() {
163 return Err(EnvelopeError::NonASCIIFormat {
164 format: config.format,
165 });
166 }
167
168 let mut buf = Vec::new();
169 self.store(&mut buf, config)?;
170 Ok(String::from_utf8(buf).expect("Envelope is valid utf8"))
171 }
172
173 #[deprecated(
179 since = "0.14.5",
180 note = "Json encoding of packages is deprecated. Use `Package::load` instead"
181 )]
182 #[cfg_attr(coverage_nightly, coverage(off))]
183 pub fn from_json_reader(
184 reader: impl io::Read,
185 extension_registry: &ExtensionRegistry,
186 ) -> Result<Self, PackageEncodingError> {
187 let val: serde_json::Value = serde_json::from_reader(reader)?;
188
189 #[derive(Debug, serde::Deserialize)]
192 struct PackageDeser {
193 pub modules: Vec<Hugr>,
194 pub extensions: Vec<Extension>,
195 }
196 let loaded_pkg = serde_json::from_value::<PackageDeser>(val.clone());
197
198 if let Ok(PackageDeser {
199 mut modules,
200 extensions: pkg_extensions,
201 }) = loaded_pkg
202 {
203 let mut pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
204 pkg_extensions,
205 &extension_registry.into(),
206 )?;
207
208 let mut combined_registry = extension_registry.clone();
210 combined_registry.extend(&pkg_extensions);
211
212 for module in &mut modules {
213 module.resolve_extension_defs(&combined_registry)?;
214 pkg_extensions.extend(module.extensions());
215 }
216
217 return Ok(Package {
218 modules,
219 extensions: pkg_extensions,
220 });
221 };
222 let pkg_load_err = loaded_pkg.unwrap_err();
223
224 if let Ok(mut hugr) = serde_json::from_value::<Hugr>(val) {
226 hugr.resolve_extension_defs(extension_registry)?;
227 if cfg!(feature = "extension_inference") {
228 hugr.infer_extensions(false)?;
229 }
230 return Ok(Package::from_hugr(hugr)?);
231 }
232
233 Err(PackageEncodingError::JsonEncoding(pkg_load_err))
235 }
236
237 #[deprecated(
241 since = "0.14.5",
242 note = "Json encoding of packages is deprecated. Use `Package::load_str` instead"
243 )]
244 #[cfg_attr(coverage_nightly, coverage(off))]
245 pub fn from_json(
246 json: impl AsRef<str>,
247 extension_registry: &ExtensionRegistry,
248 ) -> Result<Self, PackageEncodingError> {
249 #[allow(deprecated)]
250 Self::from_json_reader(json.as_ref().as_bytes(), extension_registry)
251 }
252
253 #[deprecated(
257 since = "0.14.5",
258 note = "Json encoding of packages is deprecated. Use `Package::load` instead"
259 )]
260 #[cfg_attr(coverage_nightly, coverage(off))]
261 pub fn from_json_file(
262 path: impl AsRef<Path>,
263 extension_registry: &ExtensionRegistry,
264 ) -> Result<Self, PackageEncodingError> {
265 let file = fs::File::open(path)?;
266 let reader = io::BufReader::new(file);
267 #[allow(deprecated)]
268 Self::from_json_reader(reader, extension_registry)
269 }
270
271 #[deprecated(
273 since = "0.14.5",
274 note = "Json encoding of packages is deprecated. Use `Package::store` instead"
275 )]
276 #[cfg_attr(coverage_nightly, coverage(off))]
277 pub fn to_json_writer(&self, writer: impl io::Write) -> Result<(), PackageEncodingError> {
278 serde_json::to_writer(writer, self)?;
279 Ok(())
280 }
281
282 #[deprecated(
286 since = "0.14.5",
287 note = "Json encoding of packages is deprecated. Use `Package::store_str` instead"
288 )]
289 #[cfg_attr(coverage_nightly, coverage(off))]
290 pub fn to_json(&self) -> Result<String, PackageEncodingError> {
291 let json = serde_json::to_string(self)?;
292 Ok(json)
293 }
294
295 #[deprecated(
297 since = "0.14.5",
298 note = "Json encoding of packages is deprecated. Use `Package::store` instead"
299 )]
300 #[cfg_attr(coverage_nightly, coverage(off))]
301 pub fn to_json_file(&self, path: impl AsRef<Path>) -> Result<(), PackageEncodingError> {
302 let file = fs::OpenOptions::new()
303 .write(true)
304 .truncate(true)
305 .create(true)
306 .open(path)?;
307 let writer = io::BufWriter::new(file);
308 #[allow(deprecated)]
309 self.to_json_writer(writer)
310 }
311}
312
313impl AsRef<[Hugr]> for Package {
314 fn as_ref(&self) -> &[Hugr] {
315 &self.modules
316 }
317}
318
319fn to_module_hugr(mut hugr: Hugr) -> Result<Hugr, PackageError> {
327 let root = hugr.root();
328 let root_op = hugr.get_optype(root).clone();
329 let tag = root_op.tag();
330
331 if root_op.is_module() {
333 return Ok(hugr);
334 }
335 if OpTag::ModuleOp.is_superset(tag) {
337 let new_root = hugr.add_node(Module::new().into());
338 hugr.set_root(new_root);
339 hugr.set_parent(root, new_root);
340 return Ok(hugr);
341 }
342 if OpTag::Dfg.is_superset(tag) {
344 let signature = root_op
345 .dataflow_signature()
346 .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()));
347
348 hugr.set_num_ports(root, 0, 1);
350 hugr.replace_op(
351 root,
352 FuncDefn {
353 name: "main".to_string(),
354 signature: signature.into_owned().into(),
355 },
356 )
357 .expect("Hugr accepts any root node");
358
359 let new_root = hugr.add_node(Module::new().into());
361 hugr.set_root(new_root);
362 hugr.set_parent(root, new_root);
363 return Ok(hugr);
364 }
365 if OpTag::DataflowChild.is_superset(tag) && !root_op.is_input() && !root_op.is_output() {
367 let signature = root_op
368 .dataflow_signature()
369 .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()))
370 .into_owned();
371 let mut new_hugr = ModuleBuilder::new();
372 let mut func = new_hugr.define_function("main", signature).unwrap();
373 let dataflow_node = func.add_hugr_with_wires(hugr, func.input_wires()).unwrap();
374 func.finish_with_outputs(dataflow_node.outputs()).unwrap();
375 return Ok(mem::take(new_hugr.hugr_mut()));
376 }
377 Err(PackageError::CannotWrapHugr {
379 root_op: root_op.clone(),
380 })
381}
382
383#[derive(Debug, Display, Error, PartialEq)]
385#[non_exhaustive]
386pub enum PackageError {
387 #[display("Module {module_index} in the package does not have an OpType::Module root, but {}", root_op.name())]
389 NonModuleHugr {
390 module_index: usize,
392 root_op: OpType,
394 },
395 #[display("A hugr with optype {} cannot be wrapped in a module.", root_op.name())]
397 CannotWrapHugr {
398 root_op: OpType,
400 },
401}
402
403#[derive(Debug, Display, Error, From)]
405#[non_exhaustive]
406pub enum PackageEncodingError {
407 JsonEncoding(serde_json::Error),
409 IOError(io::Error),
411 Package(PackageError),
413 ExtensionResolution(ExtensionResolutionError),
415 RuntimeExtensionResolution(ExtensionError),
417}
418
419#[derive(Debug, Display, From, Error)]
421#[non_exhaustive]
422pub enum PackageValidationError {
423 #[display("The package modules use the extension{} {} not present in the defined set. The declared extensions are {}",
425 if missing.len() > 1 {"s"} else {""},
426 missing.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(", "),
427 available.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(", "),
428 )]
429 MissingExtension {
430 missing: Vec<ExtensionId>,
432 available: Vec<ExtensionId>,
434 },
435 Validation(ValidationError),
437}
438
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::test::{
        simple_cfg_hugr, simple_dfg_hugr, simple_funcdef_hugr, simple_module_hugr,
    };
    use crate::ops::dataflow::IOTrait;
    use crate::ops::Input;

    use super::*;
    use rstest::{fixture, rstest};

    /// A hugr whose root is a bare `Input` node, which cannot be wrapped in a module.
    #[fixture]
    fn simple_input_node() -> Hugr {
        Hugr::new(Input::new(vec![]))
    }

    /// Wrapping each kind of hugr either yields a single module-rooted hugr (checked
    /// against a snapshot) or fails when `errors` is set.
    #[rstest]
    #[case::module("module", simple_module_hugr(), false)]
    #[case::funcdef("funcdef", simple_funcdef_hugr(), false)]
    #[case::dfg("dfg", simple_dfg_hugr(), false)]
    #[case::cfg("cfg", simple_cfg_hugr(), false)]
    #[case::unsupported_input("input", simple_input_node(), true)]
    #[cfg_attr(miri, ignore)]
    fn hugr_to_package(#[case] test_name: &str, #[case] hugr: Hugr, #[case] errors: bool) {
        let outcome = Package::from_hugr(hugr);
        match (&outcome, errors) {
            // Expected failure: nothing more to check.
            (Err(_), true) => {}
            (Ok(package), false) => {
                assert_eq!(package.modules.len(), 1);
                let module = &package.modules[0];
                assert!(module.get_optype(module.root()).is_module());

                insta::assert_snapshot!(test_name, module.mermaid_string());
            }
            (unexpected, _) => panic!("Unexpected result {:?}", unexpected),
        }
    }

    /// `Package::new` rejects non-module hugrs, while `Package::from_hugrs` wraps them.
    #[rstest]
    fn package_properties() {
        let module = simple_module_hugr();
        let dfg = simple_dfg_hugr();

        // The DFG at index 1 is reported as the first non-module hugr.
        assert_matches!(
            Package::new([module.clone(), dfg.clone()]),
            Err(PackageError::NonModuleHugr {
                module_index: 1,
                root_op: OpType::DFG(_),
            })
        );

        // Wrapping the DFG instead yields a valid two-module package.
        let pkg = Package::from_hugrs([module, dfg]).unwrap();
        pkg.validate().unwrap();

        assert_eq!(pkg.modules.len(), 2);
    }
}