1use clio::Input;
4use hugr::envelope::{EnvelopeError, read_envelope};
5use hugr::extension::ExtensionRegistry;
6use hugr::package::Package;
7use hugr::{Extension, Hugr};
8use std::io::{BufReader, Read};
9use std::path::PathBuf;
10
11use crate::CliError;
12
13#[derive(Debug, clap::Args)]
15pub struct HugrInputArgs {
16 #[arg(value_parser, default_value = "-", help_heading = "Input")]
18 pub input: Input,
19
20 #[arg(
22 long,
23 help_heading = "Input",
24 help = "Don't use standard extensions when validating hugrs. Prelude is still used."
25 )]
26 pub no_std: bool,
27 #[arg(
29 short,
30 long,
31 help_heading = "Input",
32 help = "Paths to serialised extensions to validate against."
33 )]
34 pub extensions: Vec<PathBuf>,
35 #[clap(long, help_heading = "Input")]
39 pub hugr_json: bool,
40}
41
42impl HugrInputArgs {
43 pub fn get_package(&mut self) -> Result<Package, CliError> {
49 let extensions = self.load_extensions()?;
50 let buffer = BufReader::new(&mut self.input);
51 match read_envelope(buffer, &extensions) {
52 Ok((_, pkg)) => Ok(pkg),
53 Err(EnvelopeError::MagicNumber { .. }) => Err(CliError::NotAnEnvelope),
54 Err(e) => Err(CliError::Envelope(e)),
55 }
56 }
57
58 pub fn get_hugr(&mut self) -> Result<Hugr, CliError> {
65 let extensions = self.load_extensions()?;
66 let mut buffer = BufReader::new(&mut self.input);
67
68 const PREPEND: &str = r#"HUGRiHJv?@{"modules": ["#;
70 const APPEND: &str = r#"],"extensions": []}"#;
71
72 let mut envelope = PREPEND.to_string();
73 buffer.read_to_string(&mut envelope)?;
74 envelope.push_str(APPEND);
75
76 let hugr = Hugr::load_str(envelope, Some(&extensions))?;
77 Ok(hugr)
78 }
79
80 pub fn load_extensions(&self) -> Result<ExtensionRegistry, CliError> {
85 let mut reg = if self.no_std {
86 hugr::extension::PRELUDE_REGISTRY.to_owned()
87 } else {
88 hugr::std_extensions::STD_REG.to_owned()
89 };
90
91 for ext in &self.extensions {
92 let f = std::fs::File::open(ext)?;
93 let ext: Extension = serde_json::from_reader(f)?;
94 reg.register_updated(ext);
95 }
96
97 Ok(reg)
98 }
99}