use alloc::{sync::Arc, vec::Vec};
use core::fmt;
use miden_crypto::{Felt, WORD_SIZE, Word};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use winter_math::FieldElement;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use super::Kernel;
use crate::{
AdviceMap,
mast::{MastForest, MastNode, MastNodeExt, MastNodeId},
utils::ToElements,
};
/// A program: a shared MAST forest, the ID of the entrypoint node inside that forest, and the
/// kernel associated with the program.
///
/// The program hash is the digest of the entrypoint node (see [`Program::hash`]).
///
/// The constructors `new` / `with_kernel` assert that `entrypoint` exists in `mast_forest` and
/// is a procedure root.
/// NOTE(review): the optional serde `Deserialize` derive builds this struct field-by-field and
/// does not appear to re-check that invariant — confirm this is acceptable for serde inputs.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Program {
    /// MAST forest holding the program's nodes; shared via `Arc` so clones are cheap.
    mast_forest: Arc<MastForest>,
    /// ID of the entrypoint node within `mast_forest`.
    entrypoint: MastNodeId,
    /// Kernel associated with this program.
    kernel: Kernel,
}
impl Program {
pub fn new(mast_forest: Arc<MastForest>, entrypoint: MastNodeId) -> Self {
Self::with_kernel(mast_forest, entrypoint, Kernel::default())
}
pub fn with_kernel(
mast_forest: Arc<MastForest>,
entrypoint: MastNodeId,
kernel: Kernel,
) -> Self {
assert!(mast_forest.get_node_by_id(entrypoint).is_some(), "invalid entrypoint");
assert!(mast_forest.is_procedure_root(entrypoint), "entrypoint not a procedure");
Self { mast_forest, entrypoint, kernel }
}
pub fn with_advice_map(self, advice_map: AdviceMap) -> Self {
let mut mast_forest = (*self.mast_forest).clone();
mast_forest.advice_map_mut().extend(advice_map);
Self {
mast_forest: Arc::new(mast_forest),
..self
}
}
}
impl Program {
    /// Returns the program hash, i.e. the digest of the entrypoint node.
    pub fn hash(&self) -> Word {
        let entry_node = &self.mast_forest[self.entrypoint];
        entry_node.digest()
    }

    /// Returns the ID of the entrypoint node within the program's MAST forest.
    pub fn entrypoint(&self) -> MastNodeId {
        self.entrypoint
    }

    /// Returns the shared MAST forest backing this program.
    pub fn mast_forest(&self) -> &Arc<MastForest> {
        &self.mast_forest
    }

    /// Returns the kernel associated with this program.
    pub fn kernel(&self) -> &Kernel {
        &self.kernel
    }

    /// Looks up a node of the program's MAST forest by ID, returning `None` when absent.
    #[inline(always)]
    pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
        let forest: &MastForest = &self.mast_forest;
        forest.get_node_by_id(node_id)
    }

    /// Returns the ID of the procedure root whose digest equals `digest`, if the forest has one.
    #[inline(always)]
    pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
        let forest: &MastForest = &self.mast_forest;
        forest.find_procedure_root(digest)
    }

    /// Returns the number of procedures in the program's MAST forest.
    pub fn num_procedures(&self) -> u32 {
        let forest: &MastForest = &self.mast_forest;
        forest.num_procedures()
    }
}
#[cfg(feature = "std")]
impl Program {
    /// Serializes this program into the file at `path`, creating missing parent directories.
    ///
    /// # Errors
    /// Returns an error if the parent directories or the file cannot be created, if buffered
    /// output cannot be flushed, or if serialization panics with a [`std::io::Error`] payload.
    ///
    /// # Panics
    /// Panics raised during serialization whose payload is not a [`std::io::Error`] are
    /// propagated unchanged.
    pub fn write_to_file<P>(&self, path: P) -> std::io::Result<()>
    where
        P: AsRef<std::path::Path>,
    {
        let path = path.as_ref();
        if let Some(dir) = path.parent() {
            std::fs::create_dir_all(dir)?;
        }

        // `ByteWriter` exposes no fallible API, so a write failure during serialization surfaces
        // as a panic. Catch it so an `io::Error` payload can be returned as an error instead.
        std::panic::catch_unwind(|| -> std::io::Result<()> {
            let file = std::fs::File::create(path)?;
            // Serialization issues many small writes; buffer them to avoid one syscall each.
            let mut writer = std::io::BufWriter::new(file);
            self.write_into(&mut writer);
            // Flush explicitly: `BufWriter`'s `Drop` would silently discard a flush error.
            std::io::Write::flush(&mut writer)
        })
        .map_err(|payload| match payload.downcast::<std::io::Error>() {
            // Move the error out of its box. A bitwise `ptr::read` here would leave the boxed
            // original to be dropped as well, double-dropping the error.
            Ok(err) => *err,
            // Propagate unknown panics.
            Err(payload) => std::panic::resume_unwind(payload),
        })?
    }
}
impl Serializable for Program {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.mast_forest.write_into(target);
self.kernel.write_into(target);
target.write_u32(self.entrypoint.as_u32());
}
}
impl Deserializable for Program {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let mast_forest = Arc::new(source.read()?);
let kernel = source.read()?;
let entrypoint = MastNodeId::from_u32_safe(source.read_u32()?, &mast_forest)?;
if !mast_forest.is_procedure_root(entrypoint) {
return Err(DeserializationError::InvalidValue(format!(
"entrypoint {entrypoint} is not a procedure"
)));
}
Ok(Self::with_kernel(mast_forest, entrypoint, kernel))
}
}
impl crate::prettier::PrettyPrint for Program {
    /// Renders the entrypoint node inside a `begin` / `end` pair, indenting the body by 4.
    fn render(&self) -> crate::prettier::Document {
        use crate::prettier::*;

        let root_printer =
            self.mast_forest[self.entrypoint].to_pretty_print(&self.mast_forest);
        let body = root_printer.render();
        let opening = const_text("begin") + nl() + body;
        indent(4, opening) + nl() + const_text("end")
    }
}
impl fmt::Display for Program {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use crate::prettier::PrettyPrint;
self.pretty_print(f)
}
}
/// A program hash paired with a kernel.
///
/// Can be built from a [`Program`] (via `From`), serialized, and flattened into field elements
/// (via `ToElements`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProgramInfo {
    /// Hash of the program (for a [`Program`], the digest of its entrypoint node).
    program_hash: Word,
    /// Kernel associated with the program.
    kernel: Kernel,
}
impl ProgramInfo {
    /// Creates a [`ProgramInfo`] from a program hash and a kernel.
    pub const fn new(program_hash: Word, kernel: Kernel) -> Self {
        Self { kernel, program_hash }
    }

    /// Returns a reference to the stored program hash.
    pub const fn program_hash(&self) -> &Word {
        &self.program_hash
    }

    /// Returns a reference to the stored kernel.
    pub const fn kernel(&self) -> &Kernel {
        &self.kernel
    }

    /// Returns the procedure hashes of the stored kernel.
    pub fn kernel_procedures(&self) -> &[Word] {
        let kernel = self.kernel();
        kernel.proc_hashes()
    }
}
impl From<Program> for ProgramInfo {
    /// Builds a [`ProgramInfo`] from the program's hash and kernel.
    ///
    /// The program is consumed, so its kernel is moved out instead of cloned.
    fn from(program: Program) -> Self {
        let program_hash = program.hash();
        Self { program_hash, kernel: program.kernel }
    }
}
impl Serializable for ProgramInfo {
    /// Writes the program hash followed by the kernel.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { program_hash, kernel } = self;
        program_hash.write_into(target);
        kernel.write_into(target);
    }
}
impl Deserializable for ProgramInfo {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let program_hash = source.read()?;
let kernel = source.read()?;
Ok(Self { program_hash, kernel })
}
}
impl ToElements for ProgramInfo {
    /// Flattens this program info into field elements.
    ///
    /// Layout:
    /// - the elements of the program hash;
    /// - `WORD_SIZE` zero elements;
    /// - for each kernel procedure hash, in kernel order: its elements zero-padded to the next
    ///   multiple of 8, then reversed as a whole.
    fn to_elements(&self) -> Vec<Felt> {
        let proc_hashes = self.kernel.proc_hashes();
        // Each procedure hash expands to WORD_SIZE elements padded up to a multiple of 8, so
        // reserve the padded size rather than WORD_SIZE per procedure.
        let padded_proc_len = WORD_SIZE.next_multiple_of(8);
        let mut result = Vec::with_capacity(2 * WORD_SIZE + proc_hashes.len() * padded_proc_len);

        result.extend_from_slice(self.program_hash.as_elements());
        result.extend_from_slice(&[Felt::ZERO; WORD_SIZE]);

        // Reuse a single scratch buffer across procedures instead of allocating one per hash.
        let mut proc_hash_elements = Vec::with_capacity(padded_proc_len);
        for proc_hash in proc_hashes {
            proc_hash_elements.clear();
            proc_hash_elements.extend_from_slice(proc_hash.as_elements());
            pad_next_mul_8(&mut proc_hash_elements);
            proc_hash_elements.reverse();
            result.extend_from_slice(&proc_hash_elements);
        }
        result
    }
}
/// Appends `Felt::ZERO` elements to `input` until its length is a multiple of 8.
///
/// Inputs whose length is already a multiple of 8 (including empty ones) are left unchanged.
fn pad_next_mul_8(input: &mut Vec<Felt>) {
    let overhang = input.len() % 8;
    if overhang != 0 {
        let padded_len = input.len() + (8 - overhang);
        input.resize(padded_len, Felt::ZERO);
    }
}