Skip to main content

pir_client/
lib.rs

1//! PIR client library for private Merkle path retrieval.
2//!
3//! Provides [`PirClient`] which connects to a `pir-server` instance and
4//! retrieves circuit-ready `ImtProofData` without revealing the
5//! queried nullifier to the server.
6
7use std::{sync::Arc, time::Instant};
8
9use anyhow::{Context, Result};
10use ff::PrimeField as _;
11use imt_tree::hasher::PoseidonHasher;
12use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
13use pasta_curves::Fp;
14// Re-exported so downstream crates (e.g. zcash_voting) can reference the type
15// returned by PirClientBlocking::fetch_proof without a direct imt-tree dependency.
16pub use imt_tree::ImtProofData;
17
18mod transport;
19pub use transport::{Transport, TransportFuture, TransportResponse};
20
21use pir_types::tier0::Tier0Data;
22use pir_types::tier1::Tier1Row;
23use pir_types::tier2::Tier2Row;
24use pir_types::{
25    serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
26    TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
27};
28
29use ypir::client::YPIRClient;
30
31// ── Timing breakdown ─────────────────────────────────────────────────────────
32
/// Per-tier timing breakdown for a single YPIR query, measuring each stage
/// of the client-server round trip.
///
/// Derives `Debug` and `Clone` so callers of
/// `PirClient::fetch_proof_with_timing` can log or retain the breakdown.
#[derive(Debug, Clone)]
pub struct TierTiming {
    /// Client-side time to build the YPIR client and generate the query.
    pub gen_ms: f64,
    /// Size in bytes of the serialized query payload that was uploaded.
    pub upload_bytes: usize,
    /// Bytes of the uploaded query attributable to the SimplePIR query
    /// vector itself (`q.0` / `pqr` — the first arg to
    /// [`pir_types::serialize_ypir_query`]).
    pub upload_q_bytes: usize,
    /// Bytes of the uploaded query attributable to `pack_pub_params`
    /// (the second arg to [`pir_types::serialize_ypir_query`]). Identical
    /// across queries that share a YPIR `client_seed`.
    pub upload_pp_bytes: usize,
    /// Size in bytes of the downloaded encrypted response body.
    pub download_bytes: usize,
    /// Wall-clock time from issuing the POST until the successful response
    /// was received (upload + server compute + download).
    pub rtt_ms: f64,
    /// Client-side YPIR response decryption time.
    pub decode_ms: f64,
    /// Server-assigned request ID (`x-pir-req-id` header), if present and
    /// parseable.
    pub server_req_id: Option<u64>,
    /// Server-reported total processing time (`x-pir-server-total-ms`).
    pub server_total_ms: Option<f64>,
    /// Server-reported query validation time (`x-pir-server-validate-ms`).
    pub server_validate_ms: Option<f64>,
    /// Server-reported decode+copy time (`x-pir-server-decode-copy-ms`).
    pub server_decode_copy_ms: Option<f64>,
    /// Server-reported YPIR online computation time
    /// (`x-pir-server-compute-ms`).
    pub server_compute_ms: Option<f64>,
    /// Estimated network + queue latency: `rtt_ms - server_total_ms`,
    /// clamped at zero. `None` when the server did not report its total.
    pub net_queue_ms: Option<f64>,
    /// Time until `Transport::post` resolved, minus `server_total_ms`,
    /// clamped at zero. `None` when the server did not report its total.
    ///
    /// NOTE(review): `post` resolves with the complete response body, so
    /// this also includes the download, not just the upload. Confirm before
    /// relying on it as an upload-only figure.
    pub upload_to_server_ms: Option<f64>,
    /// `rtt_ms` minus the time until `Transport::post` resolved, clamped at
    /// zero.
    ///
    /// NOTE(review): because `post` already returns the full body, this is
    /// likely only local response handling time, not the actual download.
    pub download_from_server_ms: f64,
}
71
72/// Per-note timing breakdown covering both tier 1 and tier 2 YPIR queries.
73pub struct NoteTiming {
74    pub tier1: TierTiming,
75    pub tier2: TierTiming,
76    /// Total wall-clock time for this note's proof retrieval.
77    pub total_ms: f64,
78}
79
80// ── HTTP-based PIR client ────────────────────────────────────────────────────
81
82/// PIR client that connects to a `pir-server` instance over HTTP.
83///
84/// Downloads Tier 0 data and YPIR parameters during `connect()`, then
85/// performs private queries via `fetch_proof()`.
86pub struct PirClient {
87    server_url: String,
88    transport: Arc<dyn Transport>,
89    tier0: Tier0Data,
90    tier1_scenario: YpirScenario,
91    tier2_scenario: YpirScenario,
92    num_ranges: usize,
93    empty_hashes: [Fp; TREE_DEPTH],
94    root29: Fp,
95}
96
97/// Return the number of populated leaves in a Tier 2 row, clamped to
98/// [`TIER2_LEAVES`]. The final row may be only partially filled when
99/// `num_ranges` is not a multiple of the row size.
100#[inline]
101fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
102    let row_start = row_idx.saturating_mul(TIER2_LEAVES);
103    num_ranges.saturating_sub(row_start).min(TIER2_LEAVES)
104}
105
106// ── Shared tier-processing helpers ───────────────────────────────────────────
107
108/// Copy `siblings` into `path` starting at `offset`.
109#[inline]
110fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
111    path[offset..offset + siblings.len()].copy_from_slice(siblings);
112}
113
114/// Locate the nullifier's subtree in Tier 0, fill its siblings into `path`,
115/// and return the subtree index `s1`.
116fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
117    let s1 = tier0
118        .find_subtree(nullifier)
119        .context("nullifier not found in any Tier 0 subtree")?;
120    fill_path(path, PIR_DEPTH - TIER0_LAYERS, &tier0.extract_siblings(s1));
121    Ok(s1)
122}
123
124/// Parse a Tier 1 row, locate the nullifier's sub-subtree, fill its siblings
125/// into `path`, and return the sub-subtree index `s2`.
126fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
127    let hasher = PoseidonHasher::new();
128    let tier1 = Tier1Row::from_bytes(tier1_row)?;
129    let s2 = tier1
130        .find_sub_subtree(nullifier)
131        .context("nullifier not found in any Tier 1 sub-subtree")?;
132    fill_path(
133        path,
134        PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS,
135        &tier1.extract_siblings(s2, &hasher),
136    );
137    Ok(s2)
138}
139
140/// Parse a Tier 2 row, locate the nullifier's leaf, fill tier-2 and padding
141/// siblings into `path`, and assemble the final [`ImtProofData`].
142fn process_tier2_and_build(
143    tier2_row: &[u8],
144    t2_row_idx: usize,
145    num_ranges: usize,
146    nullifier: Fp,
147    path: &mut [Fp; TREE_DEPTH],
148    empty_hashes: &[Fp; TREE_DEPTH],
149    root29: Fp,
150) -> Result<ImtProofData> {
151    let hasher = PoseidonHasher::new();
152    let tier2 = Tier2Row::from_bytes(tier2_row)?;
153    let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
154
155    let leaf_local_idx = tier2
156        .find_leaf(nullifier, valid_leaves)
157        .context("nullifier not found in Tier 2 leaf scan")?;
158
159    fill_path(
160        path,
161        0,
162        &tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
163    );
164    // Pad from PIR depth (25) to circuit depth (29) with empty hashes.
165    fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
166
167    let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
168    let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
169
170    Ok(ImtProofData {
171        root: root29,
172        nf_bounds: [nf_lo, nf_mid, nf_hi],
173        leaf_pos: global_leaf_idx as u32,
174        path: *path,
175    })
176}
177
178impl PirClient {
179    /// Connect using a caller-provided HTTP transport.
180    pub async fn with_transport(server_url: &str, transport: Arc<dyn Transport>) -> Result<Self> {
181        let base = server_url.trim_end_matches('/');
182
183        // Download Tier 0 data, YPIR params, and root concurrently
184        let t0 = Instant::now();
185        let tier0_url = format!("{base}/tier0");
186        let tier1_url = format!("{base}/params/tier1");
187        let tier2_url = format!("{base}/params/tier2");
188        let root_url = format!("{base}/root");
189        let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
190            transport.get(&tier0_url),
191            transport.get(&tier1_url),
192            transport.get(&tier2_url),
193            transport.get(&root_url),
194        )
195        .map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
196
197        let tier0_bytes = body_for_status(tier0_resp, "GET /tier0 failed")?;
198        log::debug!(
199            "Downloaded Tier 0: {} bytes in {:.1}s",
200            tier0_bytes.len(),
201            t0.elapsed().as_secs_f64()
202        );
203        let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
204
205        let tier1_scenario: YpirScenario =
206            serde_json::from_slice(&body_for_status(tier1_resp, "GET /params/tier1 failed")?)
207                .context("parse /params/tier1 response")?;
208        let tier2_scenario: YpirScenario =
209            serde_json::from_slice(&body_for_status(tier2_resp, "GET /params/tier2 failed")?)
210                .context("parse /params/tier2 response")?;
211
212        let root_info: RootInfo =
213            serde_json::from_slice(&body_for_status(root_resp, "GET /root failed")?)
214                .context("parse /root response")?;
215        anyhow::ensure!(
216            root_info.pir_depth == PIR_DEPTH,
217            "server pir_depth {} != expected {}",
218            root_info.pir_depth,
219            PIR_DEPTH
220        );
221        let root29_bytes = hex::decode(&root_info.root29)?;
222        anyhow::ensure!(
223            root29_bytes.len() == 32,
224            "root29 hex decoded to {} bytes, expected 32",
225            root29_bytes.len()
226        );
227        let mut root29_arr = [0u8; 32];
228        root29_arr.copy_from_slice(&root29_bytes);
229        let root29 = Option::from(Fp::from_repr(root29_arr))
230            .ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
231
232        let empty_hashes = precompute_empty_hashes();
233
234        Ok(Self {
235            server_url: base.to_string(),
236            transport,
237            tier0,
238            tier1_scenario,
239            tier2_scenario,
240            num_ranges: root_info.num_ranges,
241            empty_hashes,
242            root29,
243        })
244    }
245
246    /// Perform private Merkle path retrieval for a nullifier.
247    ///
248    /// Returns circuit-ready `ImtProofData` with a 29-element path
249    /// (25 PIR siblings + 4 empty-hash padding).
250    pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
251        let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
252        Ok(proof)
253    }
254
255    /// Like [`fetch_proof`](Self::fetch_proof) but also returns the full
256    /// client+server timing breakdown for load-testing / observability.
257    pub async fn fetch_proof_with_timing(
258        &self,
259        nullifier: Fp,
260    ) -> Result<(ImtProofData, NoteTiming)> {
261        self.fetch_proof_inner(nullifier).await
262    }
263
264    /// Perform private Merkle path retrieval for multiple nullifiers in parallel.
265    ///
266    /// All queries run concurrently via `try_join_all`, sharing the same
267    /// `PirClient` (and thus the same HTTP client and Tier 0 data).
268    pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
269        log::debug!(
270            "[PIR] Starting parallel fetch for {} notes...",
271            nullifiers.len()
272        );
273        let wall_start = Instant::now();
274
275        let futures: Vec<_> = nullifiers
276            .iter()
277            .enumerate()
278            .map(|(i, &nf)| async move {
279                let (proof, timing) = self.fetch_proof_inner(nf).await?;
280                Ok::<_, anyhow::Error>((i, proof, timing))
281            })
282            .collect();
283
284        let results_with_timing = futures::future::try_join_all(futures).await?;
285        let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
286
287        print_timing_table(&results_with_timing, wall_ms);
288
289        let proofs = results_with_timing
290            .into_iter()
291            .map(|(_, proof, _)| proof)
292            .collect();
293        Ok(proofs)
294    }
295
296    /// Fetch proof and return timing breakdown.
297    ///
298    /// **Error-oracle mitigation**: the tier 2 query is always sent even when
299    /// tier 1 fails. A malicious server could craft a tier 1 response whose
300    /// decryption outcome depends on the client's secret key material (e.g. by
301    /// triggering an assert in the LWE decode path). If the client aborted
302    /// before sending the tier 2 query, the server could observe its absence
303    /// and use the binary "crash / no-crash" signal as an oracle. By
304    /// unconditionally sending a (possibly dummy) tier 2 query we ensure the
305    /// server always sees both requests and gains no information from errors.
306    async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
307        let note_start = Instant::now();
308        let mut path = [Fp::default(); TREE_DEPTH];
309
310        // Process tier 0 (plaintext, not server-controlled)
311        let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
312
313        // Process tier 1 (PIR) — capture the outcome without `?` so that a
314        // tier 2 query is always sent regardless of tier 1 success.
315        //
316        // process_tier1 is wrapped in catch_unwind so that a panic (e.g. from
317        // a debug_assert or an unexpected slice bounds violation) cannot
318        // prevent the tier 2 query from being sent. Without this, a panic
319        // here would unwind past the tier 2 dispatch and give the server an
320        // observable one-query-vs-two oracle.
321        let tier1_outcome = self
322            .ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
323            .await
324            .and_then(|(row, timing)| {
325                let mut_path = &mut path;
326                let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
327                    process_tier1(&row, nullifier, mut_path)
328                }))
329                .unwrap_or_else(|payload| {
330                    let msg = payload
331                        .downcast_ref::<String>()
332                        .map(|s| s.as_str())
333                        .or_else(|| payload.downcast_ref::<&str>().copied())
334                        .unwrap_or("unknown panic");
335                    Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
336                })?;
337                Ok((s1 * TIER1_LEAVES + s2, timing))
338            });
339
340        // Real index on success, dummy index 0 on failure. PIR hides the
341        // queried index from the server, so the dummy is indistinguishable.
342        let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
343
344        // Validate the tier 2 index before passing it to ypir_query.
345        // ypir_query has an ensure!(row_idx < num_items) that returns Err
346        // *before* sending the HTTP request — if that fires, no tier 2
347        // request reaches the server and we leak an oracle bit. A malicious
348        // server can trigger this by setting tier2 num_items too small or
349        // crafting tier 1 data that produces out-of-bounds indices. Clamp to
350        // dummy index 0 so the query always goes out; propagate the error
351        // only after both queries have been sent.
352        let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
353            Some(anyhow::anyhow!(
354                "tier2 row_idx {} >= num_items {}",
355                t2_row_idx,
356                self.tier2_scenario.num_items
357            ))
358        } else {
359            None
360        };
361        let t2_query_idx = if t2_bounds_err.is_some() {
362            0
363        } else {
364            t2_row_idx
365        };
366
367        // Always send tier 2 to void error-based oracles.
368        let tier2_result = self
369            .ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
370            .await;
371
372        // Propagate errors only after both queries have been sent.
373        let (t2_row_idx, tier1_timing) = tier1_outcome?;
374        if let Some(e) = t2_bounds_err {
375            return Err(e);
376        }
377        let (tier2_row, tier2_timing) = tier2_result?;
378
379        let proof = process_tier2_and_build(
380            &tier2_row,
381            t2_row_idx,
382            self.num_ranges,
383            nullifier,
384            &mut path,
385            &self.empty_hashes,
386            self.root29,
387        )?;
388
389        let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
390        Ok((
391            proof,
392            NoteTiming {
393                tier1: tier1_timing,
394                tier2: tier2_timing,
395                total_ms,
396            },
397        ))
398    }
399
400    /// Send a YPIR query for a tier row and return the decrypted row bytes.
401    /// This function handles the key client PIR operations:
402    /// 1. Generate keys
403    /// 2. Query
404    /// 3. Recover
405    async fn ypir_query(
406        &self,
407        scenario: &YpirScenario,
408        tier_name: &str,
409        row_idx: usize,
410        expected_row_bytes: usize,
411    ) -> Result<(Vec<u8>, TierTiming)> {
412        anyhow::ensure!(
413            row_idx < scenario.num_items,
414            "{} row_idx {} >= num_items {}",
415            tier_name,
416            row_idx,
417            scenario.num_items
418        );
419        let t0 = Instant::now();
420        let ypir_client = YPIRClient::from_db_sz(
421            scenario.num_items as u64,
422            scenario.item_size_bits as u64,
423            true,
424        );
425
426        // Generate PIR query from a fresh secret created from OsRng seed.
427        let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
428        let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
429
430        // Serialize query. `query.0` is the SimplePIR query vector
431        // (per-query); `query.1` is `pack_pub_params` (depends only on
432        // the client's `client_seed`).
433        let upload_q_bytes = query.0.as_slice().len() * std::mem::size_of::<u64>();
434        let upload_pp_bytes = query.1.as_slice().len() * std::mem::size_of::<u64>();
435        let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
436        let upload_bytes = payload.len();
437
438        // Send the request
439        let t1 = Instant::now();
440        let url = format!("{}/{}/query", self.server_url, tier_name);
441        let send_result = self.transport.post(&url, payload).await;
442        let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
443        let resp = match send_result {
444            Ok(r) => r,
445            Err(e) => {
446                log::warn!("YPIR {} send error: {:?}", tier_name, e);
447                return Err(e);
448            }
449        };
450        let server_req_id = parse_header_u64(&resp.headers, "x-pir-req-id");
451        let server_total_ms = parse_header_f64(&resp.headers, "x-pir-server-total-ms");
452        let server_validate_ms = parse_header_f64(&resp.headers, "x-pir-server-validate-ms");
453        let server_decode_copy_ms = parse_header_f64(&resp.headers, "x-pir-server-decode-copy-ms");
454        let server_compute_ms = parse_header_f64(&resp.headers, "x-pir-server-compute-ms");
455        let status = resp.status;
456        let response_bytes = resp.body;
457        if !is_success(status) {
458            anyhow::bail!(
459                "{} query failed: HTTP {} body={}",
460                tier_name,
461                status,
462                String::from_utf8_lossy(&response_bytes)
463            );
464        }
465        let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
466        let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
467        let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
468        let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
469
470        // Decode the response. Wrap in catch_unwind so that assert panics
471        // in the YPIR library (e.g. `val < lwe_q_prime` in the LWE decode
472        // path) become recoverable errors rather than process aborts. This is
473        // necessary for the error-oracle mitigation in fetch_proof_inner:
474        // a panic here must not prevent the second query from being sent.
475        let t2 = Instant::now();
476        let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
477            ypir_client.decode_response_simplepir(seed, &response_bytes)
478        }))
479        .map_err(|panic_payload| {
480            let msg = panic_payload
481                .downcast_ref::<String>()
482                .map(|s| s.as_str())
483                .or_else(|| panic_payload.downcast_ref::<&str>().copied())
484                .unwrap_or("unknown panic");
485            anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
486        })?;
487        let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
488
489        anyhow::ensure!(
490            decoded.len() >= expected_row_bytes,
491            "{} decoded response too short: {} bytes, expected >= {}",
492            tier_name,
493            decoded.len(),
494            expected_row_bytes
495        );
496        Ok((
497            decoded[..expected_row_bytes].to_vec(),
498            TierTiming {
499                gen_ms,
500                upload_bytes,
501                upload_q_bytes,
502                upload_pp_bytes,
503                download_bytes: response_bytes.len(),
504                rtt_ms,
505                decode_ms,
506                server_req_id,
507                server_total_ms,
508                server_validate_ms,
509                server_decode_copy_ms,
510                server_compute_ms,
511                net_queue_ms,
512                upload_to_server_ms,
513                download_from_server_ms,
514            },
515        ))
516    }
517}
518
/// Render a millisecond duration for the timing table.
///
/// Values of at least one second are shown in seconds with one decimal place
/// and a trailing space (`"  1.5s "`). Anything else (including NaN) is shown
/// as whole milliseconds (`"   12ms"`). The number is right-aligned in a
/// 5-character field.
fn fmt_time(ms: f64) -> String {
    match ms {
        secs_range if secs_range >= 1000.0 => format!("{:>5.1}s ", secs_range / 1000.0),
        millis => format!("{:>5.0}ms", millis),
    }
}
526
527fn fmt_opt_time(ms: Option<f64>) -> String {
528    match ms {
529        Some(v) => fmt_time(v),
530        None => "  n/a ".to_string(),
531    }
532}
533
/// True for any 2xx HTTP status code.
fn is_success(status: u16) -> bool {
    matches!(status, 200..=299)
}
537
538fn body_for_status(response: TransportResponse, context: &'static str) -> Result<Vec<u8>> {
539    if is_success(response.status) {
540        Ok(response.body)
541    } else {
542        anyhow::bail!(
543            "{}: HTTP {} body={}",
544            context,
545            response.status,
546            String::from_utf8_lossy(&response.body)
547        )
548    }
549}
550
551/// Print a detailed timing breakdown table for a batch of PIR proof fetches.
552fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
553    if !log::log_enabled!(log::Level::Debug) {
554        return;
555    }
556
557    log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
558    log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+  │ T1 decode│ T2 keygen│ T2 upload+  │ T2 decode│ Total  │");
559    log::debug!("[PIR] │     │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │        │");
560    log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
561    for &(i, _, ref t) in results {
562        log::debug!(
563            "[PIR] │  {i:>2} │  {:>6} │   {:>7}   │  {:>6} │  {:>6} │   {:>7}   │  {:>6} │{} │",
564            fmt_time(t.tier1.gen_ms),
565            fmt_time(t.tier1.rtt_ms),
566            fmt_time(t.tier1.decode_ms),
567            fmt_time(t.tier2.gen_ms),
568            fmt_time(t.tier2.rtt_ms),
569            fmt_time(t.tier2.decode_ms),
570            fmt_time(t.total_ms),
571        );
572    }
573    log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");
574    log::debug!(
575        "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB  |  Wall clock: {:.2}s",
576        results
577            .first()
578            .map(|(_, _, t)| t.tier1.upload_bytes)
579            .unwrap_or(0) as f64
580            / 1024.0,
581        results
582            .first()
583            .map(|(_, _, t)| t.tier2.upload_bytes)
584            .unwrap_or(0) as f64
585            / (1024.0 * 1024.0),
586        wall_ms / 1000.0,
587    );
588
589    for &(i, _, ref t) in results {
590        log::trace!(
591            "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
592            t.tier1.upload_bytes as f64 / 1024.0,
593            t.tier1.download_bytes as f64 / 1024.0,
594            t.tier2.upload_bytes as f64 / (1024.0 * 1024.0),
595            t.tier2.download_bytes as f64 / 1024.0,
596        );
597        log::trace!(
598            "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
599            fmt_opt_time(t.tier1.server_total_ms),
600            fmt_opt_time(t.tier1.net_queue_ms),
601            fmt_opt_time(t.tier2.server_total_ms),
602            fmt_opt_time(t.tier2.net_queue_ms),
603        );
604        log::trace!(
605            "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
606            fmt_opt_time(t.tier1.upload_to_server_ms),
607            fmt_opt_time(t.tier1.server_total_ms),
608            fmt_time(t.tier1.download_from_server_ms),
609            fmt_opt_time(t.tier2.upload_to_server_ms),
610            fmt_opt_time(t.tier2.server_total_ms),
611            fmt_time(t.tier2.download_from_server_ms),
612        );
613        log::trace!(
614            "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
615            fmt_opt_time(t.tier1.server_validate_ms),
616            fmt_opt_time(t.tier1.server_decode_copy_ms),
617            fmt_opt_time(t.tier1.server_compute_ms),
618            fmt_opt_time(t.tier2.server_validate_ms),
619            fmt_opt_time(t.tier2.server_decode_copy_ms),
620            fmt_opt_time(t.tier2.server_compute_ms),
621        );
622        log::trace!(
623            "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
624            t.tier1.server_req_id,
625            t.tier2.server_req_id
626        );
627    }
628}
629
/// Parse an HTTP response header value as `f64`.
///
/// The header name is matched ASCII-case-insensitively and the first match
/// wins. Leading and trailing whitespace around the value is ignored: HTTP
/// does not treat that optional whitespace as part of the field value, and a
/// caller-provided [`Transport`] may not strip it. Returns `None` when the
/// header is missing or its value is not a valid `f64`.
fn parse_header_f64(headers: &[(String, String)], name: &'static str) -> Option<f64> {
    headers
        .iter()
        .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
        .and_then(|(_, value)| value.trim().parse::<f64>().ok())
}
637
/// Parse an HTTP response header value as `u64`.
///
/// The header name is matched ASCII-case-insensitively and the first match
/// wins. Leading and trailing whitespace around the value is ignored: HTTP
/// does not treat that optional whitespace as part of the field value, and a
/// caller-provided [`Transport`] may not strip it. Returns `None` when the
/// header is missing or its value is not a valid `u64`.
fn parse_header_u64(headers: &[(String, String)], name: &'static str) -> Option<u64> {
    headers
        .iter()
        .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
        .and_then(|(_, value)| value.trim().parse::<u64>().ok())
}
645
646// ── Blocking wrapper ─────────────────────────────────────────────────────────
647
648/// Synchronous wrapper around [`PirClient`] for use from non-async code.
649///
650/// Owns a Tokio runtime internally so callers (e.g. zcash_voting, which must
651/// stay synchronous for the Halo2 prover) don't need to manage one.
652pub struct PirClientBlocking {
653    inner: PirClient,
654    rt: tokio::runtime::Runtime,
655}
656
657impl PirClientBlocking {
658    /// Connect to a PIR server with a caller-provided HTTP transport.
659    pub fn with_transport(server_url: &str, transport: Arc<dyn Transport>) -> Result<Self> {
660        let rt = tokio::runtime::Runtime::new()?;
661        let inner = rt.block_on(PirClient::with_transport(server_url, transport))?;
662        Ok(Self { inner, rt })
663    }
664
665    /// Perform a private Merkle path retrieval for a nullifier (blocking).
666    pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
667        self.rt.block_on(self.inner.fetch_proof(nullifier))
668    }
669
670    /// Perform private Merkle path retrieval for multiple nullifiers in parallel (blocking).
671    pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
672        self.rt.block_on(self.inner.fetch_proofs(nullifiers))
673    }
674
675    /// The depth-29 root (PIR depth 25 padded to tree depth 29).
676    pub fn root29(&self) -> Fp {
677        self.inner.root29
678    }
679}
680
681// ── Local (in-process) PIR client ────────────────────────────────────────────
682
683/// Perform a complete local PIR proof retrieval without HTTP.
684///
685/// This is used by `pir-test local` mode. It takes the tier data directly
686/// (as built by `pir-export`) and performs the YPIR operations in-process.
687pub fn fetch_proof_local(
688    tier0_data: &[u8],
689    tier1_data: &[u8],
690    tier2_data: &[u8],
691    num_ranges: usize,
692    nullifier: Fp,
693    empty_hashes: &[Fp; TREE_DEPTH],
694    root29: Fp,
695) -> Result<ImtProofData> {
696    let mut path = [Fp::default(); TREE_DEPTH];
697    let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
698
699    let s1 = process_tier0(&tier0, nullifier, &mut path)?;
700
701    // ── Tier 1: direct row lookup (no YPIR in local mode) ────────────────
702    let t1_offset = s1 * TIER1_ROW_BYTES;
703    anyhow::ensure!(
704        t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
705        "tier1 data too short: need {} bytes at offset {}, have {}",
706        TIER1_ROW_BYTES,
707        t1_offset,
708        tier1_data.len()
709    );
710    let s2 = process_tier1(
711        &tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
712        nullifier,
713        &mut path,
714    )?;
715
716    // ── Tier 2: direct row lookup (no YPIR in local mode) ────────────────
717    let t2_row_idx = s1 * TIER1_LEAVES + s2;
718    let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
719    anyhow::ensure!(
720        t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
721        "tier2 data too short: need {} bytes at offset {}, have {}",
722        TIER2_ROW_BYTES,
723        t2_offset,
724        tier2_data.len()
725    );
726
727    process_tier2_and_build(
728        &tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
729        t2_row_idx,
730        num_ranges,
731        nullifier,
732        &mut path,
733        empty_hashes,
734        root29,
735    )
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use ff::Field;
    use pasta_curves::Fp;
    use pir_export::build_ranges_with_sentinels;

    /// Build a tree and export all three tier blobs.
    struct TestFixture {
        /// Tier 0 blob as produced by `pir_export::tier0::export`.
        tier0_data: Vec<u8>,
        /// Tier 1 rows laid out back to back, `TIER1_ROW_BYTES` each.
        tier1_data: Vec<u8>,
        /// Tier 2 rows laid out back to back, `TIER2_ROW_BYTES` each.
        tier2_data: Vec<u8>,
        /// Ranges from `build_ranges_with_sentinels`, sentinels included.
        ranges: Vec<[Fp; 3]>,
        /// Empty-subtree hashes of the built tree.
        empty_hashes: [Fp; TREE_DEPTH],
        /// Expected root of the built tree.
        root29: Fp,
    }

    impl TestFixture {
        /// Builds the sentinel-padded ranges for `raw_nfs`, the PIR tree over
        /// them, and the serialized tier 0/1/2 blobs.
        fn build(raw_nfs: &[Fp]) -> Self {
            let ranges = build_ranges_with_sentinels(raw_nfs);
            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let mut tier1_data = Vec::new();
            pir_export::tier1::export(
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
                &mut tier1_data,
            )
            .unwrap();
            let mut tier2_data = Vec::new();
            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();

            Self {
                tier0_data,
                tier1_data,
                tier2_data,
                ranges,
                empty_hashes: tree.empty_hashes,
                root29: tree.root29,
            }
        }
    }

    // ── fetch_proof_local round-trip ──────────────────────────────────────

    #[test]
    fn fetch_proof_local_verifies_for_known_ranges() {
        use rand::{rngs::StdRng, SeedableRng};

        // A fresh seed per run keeps the randomized coverage. The seed is
        // echoed in every failure message so a failing input set can be
        // replayed exactly by hard-coding it here.
        let seed: u64 = rand::random();
        let mut rng = StdRng::seed_from_u64(seed);
        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
        let fix = TestFixture::build(&raw_nfs);

        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
            let value = nf_lo + Fp::one();
            let proof = fetch_proof_local(
                &fix.tier0_data,
                &fix.tier1_data,
                &fix.tier2_data,
                fix.ranges.len(),
                value,
                &fix.empty_hashes,
                fix.root29,
            )
            .unwrap_or_else(|e| {
                panic!(
                    "fetch_proof_local should succeed for a value in range \
                     (seed {seed}): {e:#}"
                )
            });
            assert!(
                proof.verify(value),
                "proof should verify for value {:?} (seed {})",
                value,
                seed,
            );
        }
    }

    #[test]
    fn fetch_proof_local_correct_root_and_path_length() {
        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let value = fix.ranges[0][0] + Fp::one(); // nf_lo + 1 is inside the range
        let proof = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data,
            fix.ranges.len(),
            value,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert_eq!(proof.root, fix.root29);
        assert_eq!(proof.path.len(), TREE_DEPTH);
    }

    // ── process_tier0 ────────────────────────────────────────────────────

    #[test]
    fn process_tier0_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        assert!(s1 < pir_types::TIER1_ROWS);

        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
        assert!(
            tier0_region.iter().any(|&v| v != Fp::default()),
            "tier0 should write at least one non-zero sibling"
        );

        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            below.iter().all(|&v| v == Fp::default()),
            "path below tier0 region should be untouched"
        );
    }

    #[test]
    fn process_tier0_handles_arbitrary_field_element() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        // Sentinel nullifiers span the field, so every non-nullifier value
        // falls in some gap range. Verify this doesn't panic and returns a
        // valid subtree index.
        let bogus = Fp::from(u64::MAX);
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
    }

    // ── process_tier1 ────────────────────────────────────────────────────

    #[test]
    fn process_tier1_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        let t1_offset = s1 * TIER1_ROW_BYTES;
        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();

        assert!(s2 < TIER1_LEAVES);

        let tier1_region = &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            tier1_region.iter().any(|&v| v != Fp::default()),
            "tier1 should write at least one non-zero sibling"
        );
    }

    // ── process_tier2_and_build ───────────────────────────────────────────

    #[test]
    fn process_tier2_and_build_produces_verifiable_proof() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0] + Fp::one();
        let mut path = [Fp::default(); TREE_DEPTH];

        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let s2 = process_tier1(
            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
            value,
            &mut path,
        )
        .unwrap();

        let t2_row_idx = s1 * TIER1_LEAVES + s2;
        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
        let proof = process_tier2_and_build(
            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
            t2_row_idx,
            fix.ranges.len(),
            value,
            &mut path,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert!(proof.verify(value));
        assert_eq!(proof.root, fix.root29);
    }

    // ── valid_leaves_for_row ──────────────────────────────────────────────

    #[test]
    fn valid_leaves_for_row_basic() {
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
        assert_eq!(valid_leaves_for_row(0, 0), 0);
        assert_eq!(valid_leaves_for_row(1, 0), 1);
        assert_eq!(valid_leaves_for_row(1, 1), 0);
    }

    // ── fetch_proof_local error paths ─────────────────────────────────────

    #[test]
    fn fetch_proof_local_rejects_truncated_tier1() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
            &fix.tier2_data,
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    #[test]
    fn fetch_proof_local_rejects_truncated_tier2() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    // ── Error-oracle mitigation ─────────────────────────────────────────

    /// In-memory [`Transport`] that serves canned responses keyed by request
    /// path and records every path it is asked for in `hits`.
    ///
    /// The `/tier1/query` and `/tier2/query` bodies are fixed filler bytes
    /// (`0xDE` / `0xAD`), not real YPIR responses. The tests below rely on the
    /// client failing to decode them.
    struct MockTransport {
        gets: std::collections::HashMap<&'static str, TransportResponse>,
        posts: std::collections::HashMap<&'static str, TransportResponse>,
        hits: std::sync::Mutex<Vec<String>>,
    }

    impl MockTransport {
        /// Builds the canned GET responses (tier 0 blob, YPIR scenarios, root
        /// info) from `tree`, plus filler-byte POST responses for both query
        /// endpoints.
        fn new(tree: &pir_export::PirTree) -> Self {
            use ff::PrimeField as _;
            use pir_types::{TIER1_ITEM_BITS, TIER1_YPIR_ROWS, TIER2_ITEM_BITS};

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let root_info = pir_types::RootInfo {
                root29: hex::encode(tree.root29.to_repr()),
                root25: hex::encode(tree.root25.to_repr()),
                num_ranges: tree.ranges.len(),
                pir_depth: PIR_DEPTH,
                height: None,
            };
            let tier1_scenario = YpirScenario {
                num_items: TIER1_YPIR_ROWS,
                item_size_bits: TIER1_ITEM_BITS,
            };
            // NOTE(review): tier 2 advertises `TIER1_YPIR_ROWS` items here. The
            // canned `/tier2/query` body is filler either way, but confirm this
            // matches the real server's tier 2 scenario if the mock ever needs
            // realistic parameters.
            let tier2_scenario = YpirScenario {
                num_items: TIER1_YPIR_ROWS,
                item_size_bits: TIER2_ITEM_BITS,
            };

            let gets = [
                ("/tier0", response(tier0_data)),
                (
                    "/params/tier1",
                    response(serde_json::to_vec(&tier1_scenario).unwrap()),
                ),
                (
                    "/params/tier2",
                    response(serde_json::to_vec(&tier2_scenario).unwrap()),
                ),
                ("/root", response(serde_json::to_vec(&root_info).unwrap())),
            ]
            .into_iter()
            .collect();
            let posts = [
                ("/tier1/query", response(vec![0xDE; 65536])),
                ("/tier2/query", response(vec![0xAD; 65536])),
            ]
            .into_iter()
            .collect();

            Self {
                gets,
                posts,
                hits: std::sync::Mutex::new(Vec::new()),
            }
        }

        /// Number of recorded requests (GET or POST) whose path equals `path`.
        fn count_hits(&self, path: &str) -> usize {
            self.hits
                .lock()
                .unwrap()
                .iter()
                .filter(|hit| hit.as_str() == path)
                .count()
        }
    }

    /// Wraps `body` in a 200 response with no headers.
    fn response(body: Vec<u8>) -> TransportResponse {
        TransportResponse {
            status: 200,
            headers: Vec::new(),
            body,
        }
    }

    /// Returns the path portion of `url`: everything from the first `/` after
    /// the (optional) `scheme://` and host, or `/` if there is none.
    fn request_path(url: &str) -> &str {
        let without_scheme = url.split_once("://").map(|(_, rest)| rest).unwrap_or(url);
        without_scheme
            .find('/')
            .map(|idx| &without_scheme[idx..])
            .unwrap_or("/")
    }

    impl Transport for MockTransport {
        fn get<'a>(&'a self, url: &'a str) -> transport::TransportFuture<'a> {
            Box::pin(async move {
                let path = request_path(url);
                self.hits.lock().unwrap().push(path.to_string());
                self.gets
                    .get(path)
                    .cloned()
                    .ok_or_else(|| anyhow::anyhow!("unexpected GET {path}"))
            })
        }

        fn post<'a>(&'a self, url: &'a str, _body: Vec<u8>) -> transport::TransportFuture<'a> {
            Box::pin(async move {
                let path = request_path(url);
                // The hit is recorded before any await, so it counts as
                // "dispatched" even if the future is later dropped.
                self.hits.lock().unwrap().push(path.to_string());
                if path == "/tier2/query" {
                    // Coarse-grained guard against `try_join_all` cancellation:
                    // all fetch_proof_inner futures should dispatch tier 2 before
                    // the first corrupted tier 1 response resolves to Err.
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                }
                self.posts
                    .get(path)
                    .cloned()
                    .ok_or_else(|| anyhow::anyhow!("unexpected POST {path}"))
            })
        }
    }

    /// Builds a small tree, wraps it in a [`MockTransport`], and connects a
    /// [`PirClient`] to it. The mock's query endpoints return filler bytes,
    /// so every tier 1 response the client receives is effectively corrupted.
    async fn corrupting_client() -> (
        PirClient,
        std::sync::Arc<MockTransport>,
        pir_export::PirTree,
    ) {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let transport = std::sync::Arc::new(MockTransport::new(&tree));
        let client = PirClient::with_transport("https://pir.example", transport.clone())
            .await
            .unwrap();
        (client, transport, tree)
    }

    /// Verify that the tier 2 query is always sent to the server even when
    /// the tier 1 response is corrupted.
    #[tokio::test]
    async fn tier2_query_sent_despite_tier1_decode_failure() {
        let (client, transport, tree) = corrupting_client().await;
        let result = client.fetch_proof(tree.ranges[0][0]).await;

        assert!(
            result.is_err(),
            "fetch_proof should fail with corrupted tier1 response"
        );
        assert_eq!(transport.count_hits("/tier1/query"), 1);
        assert_eq!(transport.count_hits("/tier2/query"), 1);
    }

    /// Asserts the K-note granularity of the error-oracle mitigation: when
    /// `fetch_proofs(K=5)` is called and every tier 1 response is corrupted,
    /// the server must still observe K tier 1 POSTs and K tier 2 POSTs.
    /// Aborting tier 2 dispatch the moment any tier 1 decode fails would
    /// re-introduce the "K vs K' tier-2 requests" oracle the per-note
    /// mitigation closes.
    #[tokio::test]
    async fn batched_tier2_queries_all_sent_despite_tier1_decode_failure() {
        const K: usize = 5;

        let (client, transport, tree) = corrupting_client().await;
        let nullifiers: Vec<Fp> = tree
            .ranges
            .iter()
            .take(K)
            .map(|r| r[0] + Fp::one())
            .collect();

        let result = client.fetch_proofs(&nullifiers).await;

        assert!(
            result.is_err(),
            "fetch_proofs should fail with corrupted tier1 responses"
        );
        assert_eq!(transport.count_hits("/tier1/query"), K);
        assert_eq!(transport.count_hits("/tier2/query"), K);
    }
}