// midenc_codegen_masm/packaging/package.rs

1use alloc::{collections::BTreeSet, fmt, sync::Arc};
2
3use miden_processor::Digest;
4use midenc_hir::{formatter::DisplayHex, ConstantData, FunctionIdent, Ident, Signature, Symbol};
5use midenc_session::{diagnostics::Report, Emit, LinkLibrary, Session};
6use serde::{Deserialize, Serialize};
7
8use super::{de, se};
9use crate::*;
10
/// A compiled package: the assembled MAST (an executable program or a library), the read-only
/// data segments its code requires, and a manifest describing its exports and the libraries it
/// links against.
///
/// NOTE(review): packages are written/read with `bitcode` via serde (see `Emit::write_to` and
/// `Package::read_from_bytes`). That encoding is presumably positional, in which case the field
/// order below is part of the on-disk format — confirm before reordering fields, and bump the
/// `1.0` format version if the layout changes.
#[derive(Serialize, Deserialize, Clone)]
pub struct Package {
    /// Name of the package
    pub name: Symbol,
    /// Content digest of the package (`Package::new` takes it from the MAST artifact's digest)
    #[serde(
        serialize_with = "se::serialize_digest",
        deserialize_with = "de::deserialize_digest"
    )]
    pub digest: Digest,
    /// The package type and MAST
    #[serde(
        serialize_with = "se::serialize_mast",
        deserialize_with = "de::deserialize_mast"
    )]
    pub mast: MastArtifact,
    /// The rodata segments required by the code in this package
    pub rodata: Vec<Rodata>,
    /// The package manifest, containing the set of exported procedures and their signatures,
    /// if known.
    pub manifest: PackageManifest,
}
33impl fmt::Debug for Package {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        f.debug_struct("Package")
36            .field("name", &self.name)
37            .field("digest", &format_args!("{}", DisplayHex::new(&self.digest.as_bytes())))
38            .field_with("rodata", |f| f.debug_list().entries(self.rodata.iter()).finish())
39            .field("manifest", &self.manifest)
40            .finish_non_exhaustive()
41    }
42}
43impl Emit for Package {
44    fn name(&self) -> Option<Symbol> {
45        Some(self.name)
46    }
47
48    fn output_type(&self, mode: midenc_session::OutputMode) -> midenc_session::OutputType {
49        use midenc_session::OutputMode;
50        match mode {
51            OutputMode::Text => self.mast.output_type(mode),
52            OutputMode::Binary => midenc_session::OutputType::Masp,
53        }
54    }
55
56    fn write_to<W: std::io::Write>(
57        &self,
58        mut writer: W,
59        mode: midenc_session::OutputMode,
60        session: &Session,
61    ) -> std::io::Result<()> {
62        use midenc_session::OutputMode;
63        match mode {
64            OutputMode::Text => self.mast.write_to(writer, mode, session),
65            OutputMode::Binary => {
66                // Write magic
67                writer.write_all(b"MASP\0")?;
68                // Write format version
69                writer.write_all(b"1.0\0")?;
70                let data = bitcode::serialize(self).map_err(std::io::Error::other)?;
71                writer.write_all(data.as_slice())
72            }
73        }
74    }
75}
76
/// Metadata describing a package's public surface and its link-time dependencies.
///
/// NOTE(review): serialized as part of [Package]; field order is presumably significant for the
/// bitcode encoding — confirm before reordering.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct PackageManifest {
    /// The set of exports in this package.
    ///
    /// Kept in a `BTreeSet`, so entries are ordered (and deduplicated) by the `Ord` impl of
    /// [PackageExport].
    pub exports: BTreeSet<PackageExport>,
    /// The libraries linked against by this package, which must be provided when executing the
    /// program.
    pub link_libraries: Vec<LinkLibrary>,
}
86
/// A procedure exported by a package: its fully-qualified identifier, its MAST digest, and
/// (when available) its type signature.
///
/// NOTE(review): `PartialEq`/`Eq` are derived and therefore compare `signature`, but the
/// hand-written `Ord` only compares `id` then `digest`. Two exports differing only in
/// `signature` thus compare `Equal` while being `!=`, which breaks the `Ord`/`Eq` consistency
/// contract; in a `BTreeSet`, inserting such a second export is a no-op. Confirm this is
/// intended.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PackageExport {
    /// Fully-qualified name (module + function) of the exported procedure
    pub id: FunctionIdent,
    /// Digest of the exported procedure in the package's MAST
    #[serde(
        serialize_with = "se::serialize_digest",
        deserialize_with = "de::deserialize_digest"
    )]
    pub digest: Digest,
    /// We don't always have a type signature for an export
    #[serde(default)]
    pub signature: Option<Signature>,
}
99impl fmt::Debug for PackageExport {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        f.debug_struct("PackageExport")
102            .field("id", &format_args!("{}", self.id.display()))
103            .field("digest", &format_args!("{}", DisplayHex::new(&self.digest.as_bytes())))
104            .field("signature", &self.signature)
105            .finish()
106    }
107}
108impl PartialOrd for PackageExport {
109    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
110        Some(self.cmp(other))
111    }
112}
113impl Ord for PackageExport {
114    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
115        self.id.cmp(&other.id).then_with(|| self.digest.cmp(&other.digest))
116    }
117}
118
/// Represents a read-only data segment, combined with its content digest
///
/// `Rodata::to_elements` packs `data` into field elements four bytes at a time,
/// little-endian, zero-padded to a whole number of words.
///
/// NOTE(review): serialized as part of [Package]; field order is presumably significant for the
/// bitcode encoding — confirm before reordering.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Rodata {
    /// The content digest computed for `data`
    #[serde(
        serialize_with = "se::serialize_digest",
        deserialize_with = "de::deserialize_digest"
    )]
    pub digest: Digest,
    /// The address at which the data for this segment begins
    pub start: NativePtr,
    /// The raw binary data for this segment
    pub data: Arc<ConstantData>,
}
133impl fmt::Debug for Rodata {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("Rodata")
136            .field("digest", &format_args!("{}", DisplayHex::new(&self.digest.as_bytes())))
137            .field("start", &self.start)
138            .field_with("data", |f| {
139                f.debug_struct("ConstantData")
140                    .field("len", &self.data.len())
141                    .finish_non_exhaustive()
142            })
143            .finish()
144    }
145}
146impl Rodata {
147    pub fn size_in_bytes(&self) -> usize {
148        self.data.len()
149    }
150
151    pub fn size_in_felts(&self) -> usize {
152        self.data.len().next_multiple_of(4) / 4
153    }
154
155    pub fn size_in_words(&self) -> usize {
156        self.size_in_felts().next_multiple_of(4) / 4
157    }
158
159    /// Attempt to convert this rodata object to its equivalent representation in felts
160    ///
161    /// The resulting felts will be in padded out to the nearest number of words, i.e. if the data
162    /// only takes up 3 felts worth of bytes, then the resulting `Vec` will contain 4 felts, so that
163    /// the total size is a valid number of words.
164    pub fn to_elements(&self) -> Result<Vec<miden_processor::Felt>, String> {
165        use miden_core::FieldElement;
166        use miden_processor::Felt;
167
168        let data = self.data.as_slice();
169        let mut felts = Vec::with_capacity(data.len() / 4);
170        let mut iter = data.iter().copied().array_chunks::<4>();
171        felts.extend(iter.by_ref().map(|bytes| Felt::new(u32::from_le_bytes(bytes) as u64)));
172        if let Some(remainder) = iter.into_remainder() {
173            let mut chunk = [0u8; 4];
174            for (i, byte) in remainder.into_iter().enumerate() {
175                chunk[i] = byte;
176            }
177            felts.push(Felt::new(u32::from_le_bytes(chunk) as u64));
178        }
179
180        let padding = (self.size_in_words() * 4).abs_diff(felts.len());
181        felts.resize(felts.len() + padding, Felt::ZERO);
182
183        Ok(felts)
184    }
185}
186
187impl Package {
188    /// Create a [Package] for a [MastArtifact], using the [MasmArtifact] from which it was
189    /// assembled, and the [Session] that was used to compile it.
190    pub fn new(mast: MastArtifact, masm: &MasmArtifact, session: &Session) -> Self {
191        let name = Symbol::intern(session.name());
192        let digest = mast.digest();
193        let link_libraries = session.options.link_libraries.clone();
194        let mut manifest = PackageManifest {
195            exports: Default::default(),
196            link_libraries,
197        };
198
199        // Gater all of the rodata segments for this package
200        let rodata = match masm {
201            MasmArtifact::Executable(ref prog) => prog.rodatas().to_vec(),
202            MasmArtifact::Library(ref lib) => lib.rodatas().to_vec(),
203        };
204
205        // Gather all of the procedure metadata for exports of this package
206        if let MastArtifact::Library(ref lib) = mast {
207            let MasmArtifact::Library(ref masm_lib) = masm else {
208                unreachable!();
209            };
210            for module_info in lib.module_infos() {
211                let module_path = module_info.path().path();
212                let masm_module = masm_lib.get(module_path.as_ref());
213                let module_span = masm_module.map(|module| module.span).unwrap_or_default();
214                for (_, proc_info) in module_info.procedures() {
215                    let proc_name = proc_info.name.as_str();
216                    let masm_function = masm_module.and_then(|module| {
217                        module.functions().find(|f| f.name.function.as_str() == proc_name)
218                    });
219                    let proc_span = masm_function.map(|f| f.span).unwrap_or_default();
220                    let id = FunctionIdent {
221                        module: Ident::new(Symbol::intern(module_path.as_ref()), module_span),
222                        function: Ident::new(Symbol::intern(proc_name), proc_span),
223                    };
224                    let digest = proc_info.digest;
225                    let signature = masm_function.map(|f| f.signature.clone());
226                    manifest.exports.insert(PackageExport {
227                        id,
228                        digest,
229                        signature,
230                    });
231                }
232            }
233        }
234
235        Self {
236            name,
237            digest,
238            mast,
239            rodata,
240            manifest,
241        }
242    }
243
244    pub fn read_from_file<P>(path: P) -> std::io::Result<Self>
245    where
246        P: AsRef<std::path::Path>,
247    {
248        let path = path.as_ref();
249        let bytes = std::fs::read(path)?;
250
251        Self::read_from_bytes(bytes).map_err(std::io::Error::other)
252    }
253
254    pub fn read_from_bytes<B>(bytes: B) -> Result<Self, Report>
255    where
256        B: AsRef<[u8]>,
257    {
258        use alloc::borrow::Cow;
259
260        let bytes = bytes.as_ref();
261
262        let bytes = bytes
263            .strip_prefix(b"MASP\0")
264            .ok_or_else(|| Report::msg("invalid package: missing header"))?;
265        let bytes = bytes.strip_prefix(b"1.0\0").ok_or_else(|| {
266            Report::msg(format!(
267                "invalid package: incorrect version, expected '1.0', got '{}'",
268                bytes.get(0..4).map(String::from_utf8_lossy).unwrap_or(Cow::Borrowed("")),
269            ))
270        })?;
271
272        bitcode::deserialize(bytes).map_err(Report::msg)
273    }
274
275    pub fn is_program(&self) -> bool {
276        matches!(self.mast, MastArtifact::Executable(_))
277    }
278
279    pub fn is_library(&self) -> bool {
280        matches!(self.mast, MastArtifact::Library(_))
281    }
282
283    pub fn unwrap_program(&self) -> Arc<miden_core::Program> {
284        match self.mast {
285            MastArtifact::Executable(ref prog) => Arc::clone(prog),
286            _ => panic!("expected package to contain a program, but got a library"),
287        }
288    }
289
290    pub fn unwrap_library(&self) -> Arc<miden_assembly::Library> {
291        match self.mast {
292            MastArtifact::Library(ref lib) => Arc::clone(lib),
293            _ => panic!("expected package to contain a library, but got an executable"),
294        }
295    }
296
297    pub fn make_executable(&self, entrypoint: &FunctionIdent) -> Result<Self, Report> {
298        use midenc_session::diagnostics::{SourceSpan, Span};
299
300        let MastArtifact::Library(ref library) = self.mast else {
301            return Err(Report::msg("expected library but got an executable"));
302        };
303
304        let module = library
305            .module_infos()
306            .find(|info| info.path().path() == entrypoint.module.as_str())
307            .ok_or_else(|| {
308                Report::msg(format!(
309                    "invalid entrypoint: library does not contain a module named '{}'",
310                    entrypoint.module.as_str()
311                ))
312            })?;
313        let name = miden_assembly::ast::ProcedureName::new_unchecked(
314            miden_assembly::ast::Ident::new_unchecked(Span::new(
315                SourceSpan::UNKNOWN,
316                Arc::from(entrypoint.function.as_str()),
317            )),
318        );
319        if let Some(digest) = module.get_procedure_digest_by_name(&name) {
320            let node_id = library.mast_forest().find_procedure_root(digest).ok_or_else(|| {
321                Report::msg(
322                    "invalid entrypoint: malformed library - procedure exported, but digest has \
323                     no node in the forest",
324                )
325            })?;
326
327            let exports = BTreeSet::from_iter(self.manifest.exports.iter().find_map(|export| {
328                if export.digest == digest {
329                    Some(export.clone())
330                } else {
331                    None
332                }
333            }));
334
335            Ok(Self {
336                name: self.name,
337                digest,
338                mast: MastArtifact::Executable(Arc::new(miden_core::Program::new(
339                    library.mast_forest().clone(),
340                    node_id,
341                ))),
342                rodata: self.rodata.clone(),
343                manifest: PackageManifest {
344                    exports,
345                    link_libraries: self.manifest.link_libraries.clone(),
346                },
347            })
348        } else {
349            Err(Report::msg(format!(
350                "invalid entrypoint: library does not export '{}'",
351                entrypoint.display()
352            )))
353        }
354    }
355}