extern crate alloc;
use alloc::vec::Vec;
use core::fmt;
use bitcoin::OutPoint;
use risc0_zkvm::sha::{Impl as Sha256Impl, Sha256};
use serde::{Deserialize, Serialize};
use taproot_assets_types as types;
use taproot_assets_types::proof::Proof;
/// Input to [`verify_proof_chain_claim`]: a proof file's version plus its proofs in chain order.
///
/// NOTE(review): this type is serde-serialized; if a positional wire format is used, field
/// order is significant — do not reorder fields without updating every producer/consumer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProofChainClaimInput {
    /// Version of the proof file the entries came from. It is copied unchanged into
    /// `ProofChainClaimOutput::proof_file_version` and is not otherwise checked.
    pub proof_file_version: u32,
    /// Proof entries, verified in order; each entry's checksum chains from the previous one.
    pub entries: Vec<ProofChainEntryInput>,
}
impl ProofChainClaimInput {
    /// Builds a claim input from a proof file.
    ///
    /// The file's version is copied as-is. For every proof, in file order, its raw
    /// `proof_bytes` and its stored `hash` (used as the entry's claimed checksum) are copied
    /// into a new entry. No verification happens here; see [`verify_proof_chain_claim`].
    pub fn from_file(file: &types::proof::File) -> Self {
        let entries = file
            .proofs
            .iter()
            .map(|p| ProofChainEntryInput {
                proof_bytes: p.proof_bytes.clone(),
                proof_checksum: p.hash,
            })
            .collect();

        Self {
            proof_file_version: file.version,
            entries,
        }
    }
}
/// One serialized proof together with the checksum claimed for it in the hash chain.
///
/// NOTE(review): serde-serialized; keep field order stable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProofChainEntryInput {
    /// Encoded proof; verification requires it to decode with `Proof::from_bytes`.
    pub proof_bytes: Vec<u8>,
    /// Claimed chained checksum: SHA-256 over the previous entry's checksum (32 zero bytes
    /// for the first entry) followed by `proof_bytes`.
    pub proof_checksum: [u8; 32],
}
/// Summary returned by [`verify_proof_chain_claim`] when every entry verifies.
///
/// NOTE(review): serde-serialized; keep field order stable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProofChainClaimOutput {
    /// `proof_file_version` copied unchanged from the input.
    pub proof_file_version: u32,
    /// Number of entries in the verified chain.
    pub proof_count: u32,
    /// Checksum of the final entry, or 32 zero bytes when there were no entries.
    pub last_proof_checksum: [u8; 32],
    /// Anchor outpoint of the final proof (anchor transaction id plus
    /// `inclusion_proof.output_index`), or `None` when there were no entries.
    pub last_proof_outpoint: Option<OutPoint>,
}
/// Reasons [`verify_proof_chain_claim`] can reject a claim.
///
/// Every `index` field is the zero-based position of the offending entry in
/// `ProofChainClaimInput::entries`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The number of entries does not fit in a `u32`; carries the actual entry count.
    TooManyProofs(usize),
    /// The recomputed chained checksum differs from the entry's claimed `proof_checksum`.
    ChecksumMismatch { index: u32 },
    /// The entry's `proof_bytes` could not be decoded as a proof.
    InvalidProofEncoding { index: u32 },
    /// The proof's `prev_out` differs from the previous entry's anchor outpoint.
    PrevOutChainMismatch { index: u32 },
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooManyProofs(count) => write!(f, "too many proofs for claim: {count}"),
Self::ChecksumMismatch { index } => {
write!(f, "proof checksum mismatch at index {index}")
}
Self::InvalidProofEncoding { index } => {
write!(f, "invalid proof encoding at index {index}")
}
Self::PrevOutChainMismatch { index } => {
write!(
f,
"proof prev_out does not match prior outpoint at index {index}"
)
}
}
}
}
/// Verifies the checksum chain and outpoint linkage of a sequence of Taproot Assets proofs.
///
/// Entries are processed in order. For the entry at zero-based position `index`:
///
/// 1. The checksum `SHA-256(prev_checksum || proof_bytes)` is recomputed, where
///    `prev_checksum` is the previous entry's `proof_checksum` (32 zero bytes for the first
///    entry). It must equal the entry's `proof_checksum`.
/// 2. `proof_bytes` must decode as a [`Proof`].
/// 3. For every entry after the first, the decoded proof's `prev_out` must equal the
///    previous proof's anchor outpoint (anchor transaction id plus
///    `inclusion_proof.output_index`). The first proof's `prev_out` is not checked.
///
/// An empty `entries` list is accepted. It yields a count of zero, an all-zero checksum and
/// no outpoint.
///
/// # Errors
///
/// - [`Error::TooManyProofs`] if the number of entries does not fit in a `u32`.
/// - [`Error::ChecksumMismatch`] if a recomputed checksum differs from the claimed one.
/// - [`Error::InvalidProofEncoding`] if an entry's bytes fail to decode as a proof.
/// - [`Error::PrevOutChainMismatch`] if a proof's `prev_out` differs from the previous
///   proof's anchor outpoint.
pub fn verify_proof_chain_claim(
    input: &ProofChainClaimInput,
) -> Result<ProofChainClaimOutput, Error> {
    let proof_count = u32::try_from(input.entries.len())
        .map_err(|_| Error::TooManyProofs(input.entries.len()))?;

    // State carried from one entry to the next. After the loop these hold the values reported
    // for the last entry, or their initial values when `entries` is empty.
    let mut prev_checksum = [0u8; 32];
    let mut prev_outpoint: Option<OutPoint> = None;

    // Pairing each entry with a `u32` counter replaces a truncating `as` cast. The length
    // check above guarantees every index fits. Entries are the zip's first iterator, so the
    // counter is never advanced past the final entry.
    for (entry, index) in input.entries.iter().zip(0u32..) {
        if hash_proof(&entry.proof_bytes, prev_checksum) != entry.proof_checksum {
            return Err(Error::ChecksumMismatch { index });
        }

        let proof = Proof::from_bytes(&entry.proof_bytes)
            .map_err(|_| Error::InvalidProofEncoding { index })?;

        // The first entry has no predecessor in this list to link against.
        if let Some(expected_prev) = prev_outpoint {
            if proof.prev_out != expected_prev {
                return Err(Error::PrevOutChainMismatch { index });
            }
        }

        prev_outpoint = Some(OutPoint {
            txid: proof.anchor_tx.compute_txid(),
            vout: proof.inclusion_proof.output_index,
        });
        prev_checksum = entry.proof_checksum;
    }

    Ok(ProofChainClaimOutput {
        proof_file_version: input.proof_file_version,
        proof_count,
        last_proof_checksum: prev_checksum,
        last_proof_outpoint: prev_outpoint,
    })
}
/// Computes one link of the proof checksum chain: SHA-256 over `prev_hash` followed by
/// `proof_bytes`.
///
/// # Panics
///
/// Panics if the digest returned by `Sha256Impl::hash_bytes` is not exactly 32 bytes long.
fn hash_proof(proof_bytes: &[u8], prev_hash: [u8; 32]) -> [u8; 32] {
    // Assemble the preimage `prev_hash || proof_bytes` as one contiguous buffer. The chained
    // slice iterators report an exact length, so `collect` allocates once.
    let preimage: Vec<u8> = prev_hash.iter().chain(proof_bytes).copied().collect();
    Sha256Impl::hash_bytes(&preimage).as_bytes().try_into().unwrap()
}