photon_indexer/api/method/
get_validity_proof.rs1use crate::{
2    api::error::PhotonApiError,
3    common::typedefs::{hash::Hash, serializable_pubkey::SerializablePubkey},
4    ingester::persist::persisted_state_tree::{
5        get_multiple_compressed_leaf_proofs, MerkleProofWithContext,
6    },
7};
8use lazy_static::lazy_static;
9use num_bigint::BigUint;
10use reqwest::Client;
11use sea_orm::{ConnectionTrait, DatabaseBackend, DatabaseConnection, Statement, TransactionTrait};
12use serde::{Deserialize, Serialize};
13use std::str::FromStr;
14use utoipa::ToSchema;
15
16use super::{
17    get_multiple_new_address_proofs::{
18        get_multiple_new_address_proofs_helper, AddressWithTree, MerkleContextWithNewAddressProof,
19        ADDRESS_TREE_ADDRESS,
20    },
21    utils::Context,
22};
23
24lazy_static! {
25    pub static ref FIELD_SIZE: BigUint = BigUint::from_str(
26        "21888242871839275222246405745257275088548364400416034343698204186575808495616"
27    )
28    .unwrap();
29}
30
/// Modulus applied to root sequence numbers to produce the `rootIndices`
/// returned by `get_validity_proof`.
///
/// NOTE(review): `get_validity_proof` applies this modulus to address-tree roots
/// as well; confirm that both tree kinds share the same root-history length.
pub const STATE_TREE_QUEUE_SIZE: u64 = 2_400;
32
33#[derive(Serialize, Deserialize)]
34#[serde(rename_all = "camelCase")]
35struct InclusionHexInputsForProver {
36    root: String,
37    path_index: u32,
38    path_elements: Vec<String>,
39    leaf: String,
40}
41
42#[derive(Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44struct NonInclusionHexInputsForProver {
45    root: String,
46    value: String,
47    path_index: u32,
48    path_elements: Vec<String>,
49    leaf_lower_range_value: String,
50    leaf_higher_range_value: String,
51    next_index: u32,
52}
53
54fn convert_non_inclusion_merkle_proof_to_hex(
55    non_inclusion_merkle_proof_inputs: Vec<MerkleContextWithNewAddressProof>,
56) -> Vec<NonInclusionHexInputsForProver> {
57    let mut inputs: Vec<NonInclusionHexInputsForProver> = Vec::new();
58    for i in 0..non_inclusion_merkle_proof_inputs.len() {
59        let input = NonInclusionHexInputsForProver {
60            root: hash_to_hex(&non_inclusion_merkle_proof_inputs[i].root),
61            value: pubkey_to_hex(&non_inclusion_merkle_proof_inputs[i].address),
62            path_index: non_inclusion_merkle_proof_inputs[i].lowElementLeafIndex,
63            path_elements: non_inclusion_merkle_proof_inputs[i]
64                .proof
65                .iter()
66                .map(hash_to_hex)
67                .collect(),
68            next_index: non_inclusion_merkle_proof_inputs[i].nextIndex,
69            leaf_lower_range_value: pubkey_to_hex(
70                &non_inclusion_merkle_proof_inputs[i].lowerRangeAddress,
71            ),
72            leaf_higher_range_value: pubkey_to_hex(
73                &non_inclusion_merkle_proof_inputs[i].higherRangeAddress,
74            ),
75        };
76        inputs.push(input);
77    }
78    inputs
79}
80
81fn convert_inclusion_proofs_to_hex(
82    inclusion_proof_inputs: Vec<MerkleProofWithContext>,
83) -> Vec<InclusionHexInputsForProver> {
84    let mut inputs: Vec<InclusionHexInputsForProver> = Vec::new();
85    for i in 0..inclusion_proof_inputs.len() {
86        let input = InclusionHexInputsForProver {
87            root: hash_to_hex(&inclusion_proof_inputs[i].root),
88            path_index: inclusion_proof_inputs[i].leafIndex,
89            path_elements: inclusion_proof_inputs[i]
90                .proof
91                .iter()
92                .map(hash_to_hex)
93                .collect(),
94            leaf: hash_to_hex(&inclusion_proof_inputs[i].hash),
95        };
96        inputs.push(input);
97    }
98    inputs
99}
100
101#[derive(Serialize, Deserialize)]
102#[serde(rename_all = "camelCase")]
103struct HexBatchInputsForProver {
104    #[serde(
105        rename = "input-compressed-accounts",
106        skip_serializing_if = "Vec::is_empty"
107    )]
108    input_compressed_accounts: Vec<InclusionHexInputsForProver>,
109    #[serde(rename = "new-addresses", skip_serializing_if = "Vec::is_empty")]
110    new_addresses: Vec<NonInclusionHexInputsForProver>,
111}
112
113#[derive(Serialize, Deserialize, ToSchema)]
114#[serde(rename_all = "camelCase")]
115#[allow(non_snake_case)]
116pub struct CompressedProofWithContext {
117    pub compressedProof: CompressedProof,
118    roots: Vec<String>,
119    rootIndices: Vec<u64>,
120    leafIndices: Vec<u32>,
121    leaves: Vec<String>,
122    merkleTrees: Vec<String>,
123}
124
125fn hash_to_hex(hash: &Hash) -> String {
126    let bytes = hash.to_vec();
127    let hex = hex::encode(bytes);
128    format!("0x{}", hex)
129}
130
131fn pubkey_to_hex(pubkey: &SerializablePubkey) -> String {
132    let bytes = pubkey.to_bytes_vec();
133    let hex = hex::encode(bytes);
134    format!("0x{}", hex)
135}
136
137#[derive(Serialize, Deserialize, Debug)]
138struct GnarkProofJson {
139    ar: [String; 2],
140    bs: [[String; 2]; 2],
141    krs: [String; 2],
142}
143
/// Uncompressed proof points as concatenated big-endian coordinate bytes, laid
/// out by `proof_from_json_struct`. Each coordinate is normally 32 bytes.
#[derive(Debug)]
struct ProofABC {
    /// `ar[0] || ar[1]` (normally 64 bytes).
    a: Vec<u8>,
    /// `bs[0][0] || bs[0][1] || bs[1][0] || bs[1][1]` (normally 128 bytes).
    b: Vec<u8>,
    /// `krs[0] || krs[1]` (normally 64 bytes).
    c: Vec<u8>,
}
150
151#[derive(Serialize, Deserialize, ToSchema, Default)]
152pub struct CompressedProof {
153    a: Vec<u8>,
154    b: Vec<u8>,
155    c: Vec<u8>,
156}
157
158fn deserialize_hex_string_to_bytes(hex_str: &str) -> Vec<u8> {
159    let hex_str = if hex_str.starts_with("0x") {
160        &hex_str[2..]
161    } else {
162        hex_str
163    };
164
165    let hex_str = format!("{:0>64}", hex_str);
167
168    hex::decode(&hex_str).expect("Failed to decode hex string")
169}
170
171fn proof_from_json_struct(json: GnarkProofJson) -> ProofABC {
172    let proof_ax = deserialize_hex_string_to_bytes(&json.ar[0]);
173    let proof_ay = deserialize_hex_string_to_bytes(&json.ar[1]);
174    let proof_a = [proof_ax, proof_ay].concat();
175
176    let proof_bx0 = deserialize_hex_string_to_bytes(&json.bs[0][0]);
177    let proof_bx1 = deserialize_hex_string_to_bytes(&json.bs[0][1]);
178    let proof_by0 = deserialize_hex_string_to_bytes(&json.bs[1][0]);
179    let proof_by1 = deserialize_hex_string_to_bytes(&json.bs[1][1]);
180    let proof_b = [proof_bx0, proof_bx1, proof_by0, proof_by1].concat();
181
182    let proof_cx = deserialize_hex_string_to_bytes(&json.krs[0]);
183    let proof_cy = deserialize_hex_string_to_bytes(&json.krs[1]);
184    let proof_c = [proof_cx, proof_cy].concat();
185
186    ProofABC {
187        a: proof_a,
188        b: proof_b,
189        c: proof_c,
190    }
191}
192
193fn y_element_is_positive_g1(y_element: &BigUint) -> bool {
194    y_element <= &(FIELD_SIZE.clone() - y_element)
195}
196
197fn y_element_is_positive_g2(y_element1: &BigUint, y_element2: &BigUint) -> bool {
198    let field_midpoint = FIELD_SIZE.clone() / 2u32;
199
200    if y_element1 < &field_midpoint {
201        true
202    } else if y_element1 > &field_midpoint {
203        false
204    } else {
205        y_element2 < &field_midpoint
206    }
207}
208
/// Encodes a y-sign flag into bit 7 of `byte`.
///
/// When `y_is_positive` is true the byte is returned unchanged. Otherwise the
/// top bit is set; all other bits are preserved either way.
fn add_bitmask_to_byte(byte: u8, y_is_positive: bool) -> u8 {
    if y_is_positive {
        byte
    } else {
        byte | 0b1000_0000
    }
}
215
216fn negate_and_compress_proof(proof: ProofABC) -> CompressedProof {
217    let proof_a = &proof.a;
218    let proof_b = &proof.b;
219    let proof_c = &proof.c;
220
221    let a_x_element = &mut proof_a[0..32].to_vec();
222    let a_y_element = BigUint::from_bytes_be(&proof_a[32..64]);
223
224    let proof_a_is_positive = !y_element_is_positive_g1(&a_y_element);
225    a_x_element[0] = add_bitmask_to_byte(a_x_element[0], proof_a_is_positive);
226
227    let b_x_element = &mut proof_b[0..64].to_vec();
228    let b_y_element = &proof_b[64..128];
229    let b_y1_element = BigUint::from_bytes_be(&b_y_element[0..32]);
230    let b_y2_element = BigUint::from_bytes_be(&b_y_element[32..64]);
231
232    let proof_b_is_positive = y_element_is_positive_g2(&b_y1_element, &b_y2_element);
233    b_x_element[0] = add_bitmask_to_byte(b_x_element[0], proof_b_is_positive);
234
235    let c_x_element = &mut proof_c[0..32].to_vec();
236    let c_y_element = BigUint::from_bytes_be(&proof_c[32..64]);
237
238    let proof_c_is_positive = y_element_is_positive_g1(&c_y_element);
239    c_x_element[0] = add_bitmask_to_byte(c_x_element[0], proof_c_is_positive);
240
241    CompressedProof {
242        a: a_x_element.clone(),
243        b: b_x_element.clone(),
244        c: c_x_element.clone(),
245    }
246}
247
248#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
249#[serde(deny_unknown_fields, rename_all = "camelCase")]
250#[allow(non_snake_case)]
251pub struct GetValidityProofRequest {
252    #[serde(default)]
253    pub hashes: Vec<Hash>,
254    #[serde(default)]
255    pub newAddresses: Vec<SerializablePubkey>,
256    #[serde(default)]
257    pub newAddressesWithTrees: Vec<AddressWithTree>,
258}
259
260#[derive(Serialize, Deserialize, ToSchema)]
261#[serde(deny_unknown_fields, rename_all = "camelCase")]
262pub struct GetValidityProofResponse {
263    pub value: CompressedProofWithContext,
264    pub context: Context,
265}
266
267pub async fn get_validity_proof(
268    conn: &DatabaseConnection,
269    prover_url: &str,
270    mut request: GetValidityProofRequest,
271) -> Result<GetValidityProofResponse, PhotonApiError> {
272    if request.hashes.is_empty()
273        && request.newAddresses.is_empty()
274        && request.newAddressesWithTrees.is_empty()
275    {
276        return Err(PhotonApiError::UnexpectedError(
277            "No hashes or new addresses provided for proof generation".to_string(),
278        ));
279    }
280    if !request.newAddressesWithTrees.is_empty() && !request.newAddresses.is_empty() {
281        return Err(PhotonApiError::UnexpectedError(
282            "Cannot provide both newAddresses and newAddressesWithTree".to_string(),
283        ));
284    }
285    if !request.newAddresses.is_empty() {
286        request.newAddressesWithTrees = request
287            .newAddresses
288            .iter()
289            .map(|new_address| AddressWithTree {
290                address: *new_address,
291                tree: SerializablePubkey::from(ADDRESS_TREE_ADDRESS),
292            })
293            .collect();
294    }
295
296    let context = Context::extract(conn).await?;
297    let client = Client::new();
298    let tx = conn.begin().await?;
299    if tx.get_database_backend() == DatabaseBackend::Postgres {
300        tx.execute(Statement::from_string(
301            tx.get_database_backend(),
302            "SET TRANSACTION ISOLATION LEVEL REPEATABLE READ;".to_string(),
303        ))
304        .await?;
305    }
306
307    let account_proofs = match !request.hashes.is_empty() {
308        true => get_multiple_compressed_leaf_proofs(&tx, request.hashes).await?,
309        false => {
310            vec![]
311        }
312    };
313    let new_address_proofs = match !request.newAddressesWithTrees.is_empty() {
314        true => get_multiple_new_address_proofs_helper(&tx, request.newAddressesWithTrees).await?,
315        false => {
316            vec![]
317        }
318    };
319    tx.commit().await?;
320
321    let batch_inputs = HexBatchInputsForProver {
322        input_compressed_accounts: convert_inclusion_proofs_to_hex(account_proofs.clone()),
323        new_addresses: convert_non_inclusion_merkle_proof_to_hex(new_address_proofs.clone()),
324    };
325
326    let inclusion_proof_url = format!("{}/prove", prover_url);
327    let json_body = serde_json::to_string(&batch_inputs).map_err(|e| {
328        PhotonApiError::UnexpectedError(format!("Got an error while serializing the request {}", e))
329    })?;
330    let res = client
331        .post(&inclusion_proof_url)
332        .body(json_body.clone())
333        .header("Content-Type", "application/json")
334        .send()
335        .await
336        .map_err(|e| PhotonApiError::UnexpectedError(format!("Error fetching proof {}", e)))?;
337
338    if !res.status().is_success() {
339        return Err(PhotonApiError::UnexpectedError(format!(
340            "Error fetching proof {:?}",
341            res.text().await,
342        )));
343    }
344
345    let text = res
346        .text()
347        .await
348        .map_err(|e| PhotonApiError::UnexpectedError(format!("Error fetching proof {}", e)))?;
349
350    let proof: GnarkProofJson = serde_json::from_str(&text).map_err(|e| {
351        PhotonApiError::UnexpectedError(format!(
352            "Got an error while deserializing the response {}",
353            e
354        ))
355    })?;
356
357    let proof = proof_from_json_struct(proof);
358    #[allow(non_snake_case)]
360    let compressedProof = negate_and_compress_proof(proof);
361
362    let compressed_proof_with_context = CompressedProofWithContext {
363        compressedProof,
364        roots: account_proofs
365            .iter()
366            .map(|x| x.root.clone().to_string())
367            .chain(
368                new_address_proofs
369                    .iter()
370                    .map(|x| x.root.clone().to_string()),
371            )
372            .collect(),
373        rootIndices: account_proofs
374            .iter()
375            .map(|x| x.rootSeq)
376            .chain(new_address_proofs.iter().map(|x| x.rootSeq))
377            .map(|x| x % STATE_TREE_QUEUE_SIZE)
378            .collect(),
379        leafIndices: account_proofs
380            .iter()
381            .map(|x| x.leafIndex)
382            .chain(new_address_proofs.iter().map(|x| x.lowElementLeafIndex))
383            .collect(),
384        leaves: account_proofs
385            .iter()
386            .map(|x| x.hash.clone().to_string())
387            .chain(
388                new_address_proofs
389                    .iter()
390                    .map(|x| x.address.clone().to_string()),
391            )
392            .collect(),
393        merkleTrees: account_proofs
394            .iter()
395            .map(|x| x.merkleTree.clone().to_string())
396            .chain(
397                new_address_proofs
398                    .iter()
399                    .map(|x| x.merkleTree.clone().to_string()),
400            )
401            .collect(),
402    };
403    Ok(GetValidityProofResponse {
404        value: compressed_proof_with_context,
405        context,
406    })
407}