1use qssm_le::{Commitment, LatticeProof, RqPoly, N};
4use qssm_ms::GhostMirrorProof;
5use serde::{Deserialize, Serialize};
6
7use crate::context::Proof;
8
/// Protocol version stamped into every [`ProofBundle`] by [`ProofBundle::from_proof`].
///
/// [`ProofBundle::to_proof`] rejects bundles whose `protocol_version` differs from this value.
pub const PROTOCOL_VERSION: u32 = 1;

/// Version stamped into `ProofBundle::version` by [`ProofBundle::from_proof`].
///
/// [`ProofBundle::to_proof`] rejects bundles whose `version` differs from this value.
pub(crate) const PROOF_BUNDLE_VERSION: u32 = 1;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(deny_unknown_fields)]
22#[non_exhaustive]
23pub struct ProofBundle {
24 pub version: u32,
25 pub protocol_version: u32,
26 pub ms_root_hex: String,
28 pub ms_n: u8,
29 pub ms_k: u8,
30 pub ms_bit_at_k: u8,
31 pub ms_opened_salt_hex: String,
32 pub ms_path_hex: Vec<String>,
33 pub ms_challenge_hex: String,
34 pub le_commitment_coeffs: Vec<u32>,
36 pub le_proof_t_coeffs: Vec<u32>,
37 pub le_proof_z_coeffs: Vec<u32>,
38 pub le_challenge_seed_hex: String,
39 pub external_entropy_hex: String,
41 pub external_entropy_included: bool,
42 pub value: u64,
44 pub target: u64,
45 pub binding_entropy_hex: String,
46}
47
48#[derive(Debug, thiserror::Error)]
50#[non_exhaustive]
51pub enum WireFormatError {
52 #[error("unsupported bundle version {0} (expected {PROOF_BUNDLE_VERSION})")]
53 UnsupportedVersion(u32),
54 #[error("hex decode failed for field `{field}`: {source}")]
55 HexDecode {
56 field: &'static str,
57 source: hex::FromHexError,
58 },
59 #[error("wrong byte length for `{field}`: expected {expected}, got {got}")]
60 BadLength {
61 field: &'static str,
62 expected: usize,
63 got: usize,
64 },
65 #[error("wrong coefficient count for `{field}`: expected {expected}, got {got}")]
66 BadCoeffCount {
67 field: &'static str,
68 expected: usize,
69 got: usize,
70 },
71 #[error("invalid MS proof field: {0}")]
72 InvalidMsProofField(#[from] qssm_ms::MsError),
73}
74
75impl ProofBundle {
76 #[must_use]
78 pub fn from_proof(p: &Proof) -> Self {
79 Self {
80 version: PROOF_BUNDLE_VERSION,
81 protocol_version: PROTOCOL_VERSION,
82 ms_root_hex: hex::encode(p.ms_root),
83 ms_n: p.ms_proof.n(),
84 ms_k: p.ms_proof.k(),
85 ms_bit_at_k: p.ms_proof.bit_at_k(),
86 ms_opened_salt_hex: hex::encode(p.ms_proof.opened_salt()),
87 ms_path_hex: p.ms_proof.path().iter().map(hex::encode).collect(),
88 ms_challenge_hex: hex::encode(p.ms_proof.challenge()),
89 le_commitment_coeffs: p.le_commitment.0 .0.to_vec(),
90 le_proof_t_coeffs: p.le_proof.t.0.to_vec(),
91 le_proof_z_coeffs: p.le_proof.z.0.to_vec(),
92 le_challenge_seed_hex: hex::encode(p.le_proof.challenge_seed),
93 external_entropy_hex: hex::encode(p.external_entropy),
94 external_entropy_included: p.external_entropy_included,
95 value: p.value,
96 target: p.target,
97 binding_entropy_hex: hex::encode(p.binding_entropy),
98 }
99 }
100
101 pub fn to_proof(&self) -> Result<Proof, WireFormatError> {
103 if self.version != PROOF_BUNDLE_VERSION {
104 return Err(WireFormatError::UnsupportedVersion(self.version));
105 }
106 if self.protocol_version != PROTOCOL_VERSION {
107 return Err(WireFormatError::UnsupportedVersion(self.protocol_version));
108 }
109 Ok(Proof {
110 ms_root: decode_hash(&self.ms_root_hex, "ms_root_hex")?,
111 ms_proof: GhostMirrorProof::new(
112 self.ms_n,
113 self.ms_k,
114 self.ms_bit_at_k,
115 decode_hash(&self.ms_opened_salt_hex, "ms_opened_salt_hex")?,
116 self.ms_path_hex
117 .iter()
118 .map(|h| decode_hash(h, "ms_path_hex"))
119 .collect::<Result<Vec<_>, _>>()?,
120 decode_hash(&self.ms_challenge_hex, "ms_challenge_hex")?,
121 )?,
122 le_commitment: Commitment(RqPoly(vec_to_poly(
123 &self.le_commitment_coeffs,
124 "le_commitment_coeffs",
125 )?)),
126 le_proof: LatticeProof {
127 t: RqPoly(vec_to_poly(&self.le_proof_t_coeffs, "le_proof_t_coeffs")?),
128 z: RqPoly(vec_to_poly(&self.le_proof_z_coeffs, "le_proof_z_coeffs")?),
129 challenge_seed: decode_hash(&self.le_challenge_seed_hex, "le_challenge_seed_hex")?,
130 },
131 external_entropy: decode_hash(&self.external_entropy_hex, "external_entropy_hex")?,
132 external_entropy_included: self.external_entropy_included,
133 value: self.value,
134 target: self.target,
135 binding_entropy: decode_hash(&self.binding_entropy_hex, "binding_entropy_hex")?,
136 })
137 }
138}
139
140fn decode_hash(hex_str: &str, field: &'static str) -> Result<[u8; 32], WireFormatError> {
141 let bytes =
142 hex::decode(hex_str).map_err(|source| WireFormatError::HexDecode { field, source })?;
143 <[u8; 32]>::try_from(bytes.as_slice()).map_err(|_| WireFormatError::BadLength {
144 field,
145 expected: 32,
146 got: bytes.len(),
147 })
148}
149
150fn vec_to_poly(v: &[u32], field: &'static str) -> Result<[u32; N], WireFormatError> {
151 <[u32; N]>::try_from(v).map_err(|_| WireFormatError::BadCoeffCount {
152 field,
153 expected: N,
154 got: v.len(),
155 })
156}