Skip to main content

pir_client/
lib.rs

1//! PIR client library for private Merkle path retrieval.
2//!
3//! Provides [`PirClient`] which connects to a `pir-server` instance and
4//! retrieves circuit-ready `ImtProofData` without revealing the
5//! queried nullifier to the server.
6
7use std::time::Instant;
8
9use anyhow::{Context, Result};
10use ff::PrimeField as _;
11use imt_tree::hasher::PoseidonHasher;
12use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
13use pasta_curves::Fp;
14// Re-exported so downstream crates (e.g. zcash_voting) can reference the type
15// returned by PirClientBlocking::fetch_proof without a direct imt-tree dependency.
16pub use imt_tree::ImtProofData;
17
18use pir_types::tier0::Tier0Data;
19use pir_types::tier1::Tier1Row;
20use pir_types::tier2::Tier2Row;
21use pir_types::{
22    serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
23    TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
24};
25
26use ypir::client::YPIRClient;
27
28// ── Timing breakdown ─────────────────────────────────────────────────────────
29
/// Timing and transfer statistics for one YPIR query (one tier of one note).
///
/// Client-measured values are always present. `server_*` values come from
/// optional `x-pir-*` response headers and are `None` when the header was
/// missing or unparseable. `net_queue_ms`, `upload_to_server_ms` and
/// `download_from_server_ms` are derived estimates, not direct measurements.
struct TierTiming {
    // ── Client-side work ──
    /// Time spent generating the YPIR query (ms).
    gen_ms: f64,
    /// Time spent decrypting the YPIR response (ms).
    decode_ms: f64,

    // ── Transfer sizes ──
    /// Size of the serialized query body sent to the server (bytes).
    upload_bytes: usize,
    /// Size of the encrypted response body received (bytes).
    download_bytes: usize,

    // ── Client-observed latency ──
    /// Wall-clock time from issuing the request to holding the full body (ms).
    rtt_ms: f64,
    /// Estimated body download time: `rtt_ms` minus the time `send()` took (ms).
    download_from_server_ms: f64,
    /// Estimated network + queue time: `rtt_ms` minus `server_total_ms` (ms).
    net_queue_ms: Option<f64>,
    /// Estimated upload time: the time `send()` took minus `server_total_ms` (ms).
    upload_to_server_ms: Option<f64>,

    // ── Server-reported values (response headers) ──
    /// Request ID assigned by the server.
    server_req_id: Option<u64>,
    /// Total server-side processing time (ms).
    server_total_ms: Option<f64>,
    /// Server-side query validation time (ms).
    server_validate_ms: Option<f64>,
    /// Server-side decode + copy time (ms).
    server_decode_copy_ms: Option<f64>,
    /// Server-side YPIR online computation time (ms).
    server_compute_ms: Option<f64>,
}

/// Timing for one note's proof retrieval: its two YPIR queries plus the total.
struct NoteTiming {
    /// Statistics for the Tier 1 query.
    tier1: TierTiming,
    /// Statistics for the Tier 2 query.
    tier2: TierTiming,
    /// Wall-clock time for the whole proof retrieval of this note (ms).
    total_ms: f64,
}
68
69// ── HTTP-based PIR client ────────────────────────────────────────────────────
70
71/// PIR client that connects to a `pir-server` instance over HTTP.
72///
73/// Downloads Tier 0 data and YPIR parameters during `connect()`, then
74/// performs private queries via `fetch_proof()`.
75pub struct PirClient {
76    server_url: String,
77    http: reqwest::Client,
78    tier0: Tier0Data,
79    tier1_scenario: YpirScenario,
80    tier2_scenario: YpirScenario,
81    num_ranges: usize,
82    empty_hashes: [Fp; TREE_DEPTH],
83    root29: Fp,
84}
85
86/// Return the number of populated leaves in a Tier 2 row, clamped to
87/// [`TIER2_LEAVES`]. The final row may be only partially filled when
88/// `num_ranges` is not a multiple of the row size.
89#[inline]
90fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
91    let row_start = row_idx.saturating_mul(TIER2_LEAVES);
92    num_ranges.saturating_sub(row_start).min(TIER2_LEAVES)
93}
94
95// ── Shared tier-processing helpers ───────────────────────────────────────────
96
97/// Copy `siblings` into `path` starting at `offset`.
98#[inline]
99fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
100    path[offset..offset + siblings.len()].copy_from_slice(siblings);
101}
102
103/// Locate the nullifier's subtree in Tier 0, fill its siblings into `path`,
104/// and return the subtree index `s1`.
105fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
106    let s1 = tier0
107        .find_subtree(nullifier)
108        .context("nullifier not found in any Tier 0 subtree")?;
109    fill_path(path, PIR_DEPTH - TIER0_LAYERS, &tier0.extract_siblings(s1));
110    Ok(s1)
111}
112
113/// Parse a Tier 1 row, locate the nullifier's sub-subtree, fill its siblings
114/// into `path`, and return the sub-subtree index `s2`.
115fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
116    let hasher = PoseidonHasher::new();
117    let tier1 = Tier1Row::from_bytes(tier1_row)?;
118    let s2 = tier1
119        .find_sub_subtree(nullifier)
120        .context("nullifier not found in any Tier 1 sub-subtree")?;
121    fill_path(
122        path,
123        PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS,
124        &tier1.extract_siblings(s2, &hasher),
125    );
126    Ok(s2)
127}
128
129/// Parse a Tier 2 row, locate the nullifier's leaf, fill tier-2 and padding
130/// siblings into `path`, and assemble the final [`ImtProofData`].
131fn process_tier2_and_build(
132    tier2_row: &[u8],
133    t2_row_idx: usize,
134    num_ranges: usize,
135    nullifier: Fp,
136    path: &mut [Fp; TREE_DEPTH],
137    empty_hashes: &[Fp; TREE_DEPTH],
138    root29: Fp,
139) -> Result<ImtProofData> {
140    let hasher = PoseidonHasher::new();
141    let tier2 = Tier2Row::from_bytes(tier2_row)?;
142    let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
143
144    let leaf_local_idx = tier2
145        .find_leaf(nullifier, valid_leaves)
146        .context("nullifier not found in Tier 2 leaf scan")?;
147
148    fill_path(
149        path,
150        0,
151        &tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
152    );
153    // Pad from PIR depth (25) to circuit depth (29) with empty hashes.
154    fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
155
156    let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
157    let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
158
159    Ok(ImtProofData {
160        root: root29,
161        nf_bounds: [nf_lo, nf_mid, nf_hi],
162        leaf_pos: global_leaf_idx as u32,
163        path: *path,
164    })
165}
166
167impl PirClient {
168    /// Connect to a PIR server, downloading Tier 0 data and YPIR parameters.
169    pub async fn connect(server_url: &str) -> Result<Self> {
170        let http = reqwest::Client::new();
171        let base = server_url.trim_end_matches('/');
172
173        // Download Tier 0 data, YPIR params, and root concurrently
174        let t0 = Instant::now();
175        let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
176            http.get(format!("{base}/tier0")).send(),
177            http.get(format!("{base}/params/tier1")).send(),
178            http.get(format!("{base}/params/tier2")).send(),
179            http.get(format!("{base}/root")).send(),
180        )
181        .map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
182
183        let tier0_bytes = tier0_resp.error_for_status()?.bytes().await?;
184        log::debug!(
185            "Downloaded Tier 0: {} bytes in {:.1}s",
186            tier0_bytes.len(),
187            t0.elapsed().as_secs_f64()
188        );
189        let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
190
191        let tier1_scenario: YpirScenario = tier1_resp
192            .error_for_status()
193            .context("GET /params/tier1 failed")?
194            .json()
195            .await?;
196        let tier2_scenario: YpirScenario = tier2_resp
197            .error_for_status()
198            .context("GET /params/tier2 failed")?
199            .json()
200            .await?;
201
202        let root_info: RootInfo = root_resp
203            .error_for_status()
204            .context("GET /root failed")?
205            .json()
206            .await?;
207        anyhow::ensure!(
208            root_info.pir_depth == PIR_DEPTH,
209            "server pir_depth {} != expected {}",
210            root_info.pir_depth,
211            PIR_DEPTH
212        );
213        let root29_bytes = hex::decode(&root_info.root29)?;
214        anyhow::ensure!(
215            root29_bytes.len() == 32,
216            "root29 hex decoded to {} bytes, expected 32",
217            root29_bytes.len()
218        );
219        let mut root29_arr = [0u8; 32];
220        root29_arr.copy_from_slice(&root29_bytes);
221        let root29 = Option::from(Fp::from_repr(root29_arr))
222            .ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
223
224        let empty_hashes = precompute_empty_hashes();
225
226        Ok(Self {
227            server_url: base.to_string(),
228            http,
229            tier0,
230            tier1_scenario,
231            tier2_scenario,
232            num_ranges: root_info.num_ranges,
233            empty_hashes,
234            root29,
235        })
236    }
237
238    /// Perform private Merkle path retrieval for a nullifier.
239    ///
240    /// Returns circuit-ready `ImtProofData` with a 29-element path
241    /// (25 PIR siblings + 4 empty-hash padding).
242    pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
243        let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
244        Ok(proof)
245    }
246
247    /// Perform private Merkle path retrieval for multiple nullifiers in parallel.
248    ///
249    /// All queries run concurrently via `try_join_all`, sharing the same
250    /// `PirClient` (and thus the same HTTP client and Tier 0 data).
251    pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
252        log::debug!(
253            "[PIR] Starting parallel fetch for {} notes...",
254            nullifiers.len()
255        );
256        let wall_start = Instant::now();
257
258        let futures: Vec<_> = nullifiers
259            .iter()
260            .enumerate()
261            .map(|(i, &nf)| async move {
262                let (proof, timing) = self.fetch_proof_inner(nf).await?;
263                Ok::<_, anyhow::Error>((i, proof, timing))
264            })
265            .collect();
266
267        let results_with_timing = futures::future::try_join_all(futures).await?;
268        let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
269
270        print_timing_table(&results_with_timing, wall_ms);
271
272        let proofs = results_with_timing
273            .into_iter()
274            .map(|(_, proof, _)| proof)
275            .collect();
276        Ok(proofs)
277    }
278
279    /// Fetch proof and return timing breakdown.
280    ///
281    /// **Error-oracle mitigation**: the tier 2 query is always sent even when
282    /// tier 1 fails. A malicious server could craft a tier 1 response whose
283    /// decryption outcome depends on the client's secret key material (e.g. by
284    /// triggering an assert in the LWE decode path). If the client aborted
285    /// before sending the tier 2 query, the server could observe its absence
286    /// and use the binary "crash / no-crash" signal as an oracle. By
287    /// unconditionally sending a (possibly dummy) tier 2 query we ensure the
288    /// server always sees both requests and gains no information from errors.
289    async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
290        let note_start = Instant::now();
291        let mut path = [Fp::default(); TREE_DEPTH];
292
293        // Process tier 0 (plaintext, not server-controlled)
294        let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
295
296        // Process tier 1 (PIR) — capture the outcome without `?` so that a
297        // tier 2 query is always sent regardless of tier 1 success.
298        //
299        // process_tier1 is wrapped in catch_unwind so that a panic (e.g. from
300        // a debug_assert or an unexpected slice bounds violation) cannot
301        // prevent the tier 2 query from being sent. Without this, a panic
302        // here would unwind past the tier 2 dispatch and give the server an
303        // observable one-query-vs-two oracle.
304        let tier1_outcome = self
305            .ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
306            .await
307            .and_then(|(row, timing)| {
308                let mut_path = &mut path;
309                let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
310                    process_tier1(&row, nullifier, mut_path)
311                }))
312                .unwrap_or_else(|payload| {
313                    let msg = payload
314                        .downcast_ref::<String>()
315                        .map(|s| s.as_str())
316                        .or_else(|| payload.downcast_ref::<&str>().copied())
317                        .unwrap_or("unknown panic");
318                    Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
319                })?;
320                Ok((s1 * TIER1_LEAVES + s2, timing))
321            });
322
323        // Real index on success, dummy index 0 on failure. PIR hides the
324        // queried index from the server, so the dummy is indistinguishable.
325        let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
326
327        // Validate the tier 2 index before passing it to ypir_query.
328        // ypir_query has an ensure!(row_idx < num_items) that returns Err
329        // *before* sending the HTTP request — if that fires, no tier 2
330        // request reaches the server and we leak an oracle bit. A malicious
331        // server can trigger this by setting tier2 num_items too small or
332        // crafting tier 1 data that produces out-of-bounds indices. Clamp to
333        // dummy index 0 so the query always goes out; propagate the error
334        // only after both queries have been sent.
335        let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
336            Some(anyhow::anyhow!(
337                "tier2 row_idx {} >= num_items {}",
338                t2_row_idx,
339                self.tier2_scenario.num_items
340            ))
341        } else {
342            None
343        };
344        let t2_query_idx = if t2_bounds_err.is_some() {
345            0
346        } else {
347            t2_row_idx
348        };
349
350        // Always send tier 2 to void error-based oracles.
351        let tier2_result = self
352            .ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
353            .await;
354
355        // Propagate errors only after both queries have been sent.
356        let (t2_row_idx, tier1_timing) = tier1_outcome?;
357        if let Some(e) = t2_bounds_err {
358            return Err(e);
359        }
360        let (tier2_row, tier2_timing) = tier2_result?;
361
362        let proof = process_tier2_and_build(
363            &tier2_row,
364            t2_row_idx,
365            self.num_ranges,
366            nullifier,
367            &mut path,
368            &self.empty_hashes,
369            self.root29,
370        )?;
371
372        let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
373        Ok((
374            proof,
375            NoteTiming {
376                tier1: tier1_timing,
377                tier2: tier2_timing,
378                total_ms,
379            },
380        ))
381    }
382
383    /// Send a YPIR query for a tier row and return the decrypted row bytes.
384    /// This function handles the key client PIR operations:
385    /// 1. Generate keys
386    /// 2. Query
387    /// 3. Recover
388    async fn ypir_query(
389        &self,
390        scenario: &YpirScenario,
391        tier_name: &str,
392        row_idx: usize,
393        expected_row_bytes: usize,
394    ) -> Result<(Vec<u8>, TierTiming)> {
395        anyhow::ensure!(
396            row_idx < scenario.num_items,
397            "{} row_idx {} >= num_items {}",
398            tier_name,
399            row_idx,
400            scenario.num_items
401        );
402        let t0 = Instant::now();
403        let ypir_client = YPIRClient::from_db_sz(
404            scenario.num_items as u64,
405            scenario.item_size_bits as u64,
406            true,
407        );
408
409        // Generate PIR query from a fresh secret created from OsRng seed.
410        let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
411        let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
412
413        // Serialize query
414        let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
415        let upload_bytes = payload.len();
416
417        // Send the request
418        let t1 = Instant::now();
419        let url = format!("{}/{}/query", self.server_url, tier_name);
420        let send_result = self.http.post(&url).body(payload).send().await;
421        let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
422        let resp = match send_result {
423            Ok(r) => r,
424            Err(e) => {
425                log::warn!("YPIR {} send error: {:?}", tier_name, e);
426                return Err(e.into());
427            }
428        };
429        let server_req_id = parse_header_u64(resp.headers(), "x-pir-req-id");
430        let server_total_ms = parse_header_f64(resp.headers(), "x-pir-server-total-ms");
431        let server_validate_ms = parse_header_f64(resp.headers(), "x-pir-server-validate-ms");
432        let server_decode_copy_ms = parse_header_f64(resp.headers(), "x-pir-server-decode-copy-ms");
433        let server_compute_ms = parse_header_f64(resp.headers(), "x-pir-server-compute-ms");
434        let status = resp.status();
435        let response_bytes = resp.bytes().await?;
436        if !status.is_success() {
437            anyhow::bail!(
438                "{} query failed: HTTP {} body={}",
439                tier_name,
440                status,
441                String::from_utf8_lossy(&response_bytes)
442            );
443        }
444        let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
445        let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
446        let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
447        let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
448
449        // Decode the response. Wrap in catch_unwind so that assert panics
450        // in the YPIR library (e.g. `val < lwe_q_prime` in the LWE decode
451        // path) become recoverable errors rather than process aborts. This is
452        // necessary for the error-oracle mitigation in fetch_proof_inner:
453        // a panic here must not prevent the second query from being sent.
454        let t2 = Instant::now();
455        let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
456            ypir_client.decode_response_simplepir(seed, &response_bytes)
457        }))
458        .map_err(|panic_payload| {
459            let msg = panic_payload
460                .downcast_ref::<String>()
461                .map(|s| s.as_str())
462                .or_else(|| panic_payload.downcast_ref::<&str>().copied())
463                .unwrap_or("unknown panic");
464            anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
465        })?;
466        let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
467
468        anyhow::ensure!(
469            decoded.len() >= expected_row_bytes,
470            "{} decoded response too short: {} bytes, expected >= {}",
471            tier_name,
472            decoded.len(),
473            expected_row_bytes
474        );
475        Ok((
476            decoded[..expected_row_bytes].to_vec(),
477            TierTiming {
478                gen_ms,
479                upload_bytes,
480                download_bytes: response_bytes.len(),
481                rtt_ms,
482                decode_ms,
483                server_req_id,
484                server_total_ms,
485                server_validate_ms,
486                server_decode_copy_ms,
487                server_compute_ms,
488                net_queue_ms,
489                upload_to_server_ms,
490                download_from_server_ms,
491            },
492        ))
493    }
494}
495
/// Format a duration in milliseconds as a fixed-width, 7-character cell.
///
/// Values of at least one second are shown as seconds with one decimal and
/// a trailing space (`"  1.5s "`). Everything else, including NaN, is shown
/// as whole milliseconds (`"   12ms"`).
fn fmt_time(ms: f64) -> String {
    let use_seconds = ms >= 1000.0;
    if !use_seconds {
        return format!("{:>5.0}ms", ms);
    }
    format!("{:>5.1}s ", ms / 1000.0)
}

/// Like [`fmt_time`], but renders a missing measurement as `"  n/a "`.
fn fmt_opt_time(ms: Option<f64>) -> String {
    ms.map_or_else(|| "  n/a ".to_string(), fmt_time)
}
510
511/// Print a detailed timing breakdown table for a batch of PIR proof fetches.
512fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
513    if !log::log_enabled!(log::Level::Debug) {
514        return;
515    }
516
517    log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
518    log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+  │ T1 decode│ T2 keygen│ T2 upload+  │ T2 decode│ Total  │");
519    log::debug!("[PIR] │     │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │        │");
520    log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
521    for &(i, _, ref t) in results {
522        log::debug!(
523            "[PIR] │  {i:>2} │  {:>6} │   {:>7}   │  {:>6} │  {:>6} │   {:>7}   │  {:>6} │{} │",
524            fmt_time(t.tier1.gen_ms),
525            fmt_time(t.tier1.rtt_ms),
526            fmt_time(t.tier1.decode_ms),
527            fmt_time(t.tier2.gen_ms),
528            fmt_time(t.tier2.rtt_ms),
529            fmt_time(t.tier2.decode_ms),
530            fmt_time(t.total_ms),
531        );
532    }
533    log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");
534    log::debug!(
535        "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB  |  Wall clock: {:.2}s",
536        results
537            .first()
538            .map(|(_, _, t)| t.tier1.upload_bytes)
539            .unwrap_or(0) as f64
540            / 1024.0,
541        results
542            .first()
543            .map(|(_, _, t)| t.tier2.upload_bytes)
544            .unwrap_or(0) as f64
545            / (1024.0 * 1024.0),
546        wall_ms / 1000.0,
547    );
548
549    for &(i, _, ref t) in results {
550        log::trace!(
551            "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
552            t.tier1.upload_bytes as f64 / 1024.0,
553            t.tier1.download_bytes as f64 / 1024.0,
554            t.tier2.upload_bytes as f64 / (1024.0 * 1024.0),
555            t.tier2.download_bytes as f64 / 1024.0,
556        );
557        log::trace!(
558            "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
559            fmt_opt_time(t.tier1.server_total_ms),
560            fmt_opt_time(t.tier1.net_queue_ms),
561            fmt_opt_time(t.tier2.server_total_ms),
562            fmt_opt_time(t.tier2.net_queue_ms),
563        );
564        log::trace!(
565            "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
566            fmt_opt_time(t.tier1.upload_to_server_ms),
567            fmt_opt_time(t.tier1.server_total_ms),
568            fmt_time(t.tier1.download_from_server_ms),
569            fmt_opt_time(t.tier2.upload_to_server_ms),
570            fmt_opt_time(t.tier2.server_total_ms),
571            fmt_time(t.tier2.download_from_server_ms),
572        );
573        log::trace!(
574            "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
575            fmt_opt_time(t.tier1.server_validate_ms),
576            fmt_opt_time(t.tier1.server_decode_copy_ms),
577            fmt_opt_time(t.tier1.server_compute_ms),
578            fmt_opt_time(t.tier2.server_validate_ms),
579            fmt_opt_time(t.tier2.server_decode_copy_ms),
580            fmt_opt_time(t.tier2.server_compute_ms),
581        );
582        log::trace!(
583            "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
584            t.tier1.server_req_id,
585            t.tier2.server_req_id
586        );
587    }
588}
589
590/// Parse an HTTP response header value as `f64`, returning `None` on missing or malformed values.
591fn parse_header_f64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<f64> {
592    headers
593        .get(name)
594        .and_then(|v| v.to_str().ok())
595        .and_then(|s| s.parse::<f64>().ok())
596}
597
598/// Parse an HTTP response header value as `u64`, returning `None` on missing or malformed values.
599fn parse_header_u64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<u64> {
600    headers
601        .get(name)
602        .and_then(|v| v.to_str().ok())
603        .and_then(|s| s.parse::<u64>().ok())
604}
605
606// ── Blocking wrapper ─────────────────────────────────────────────────────────
607
608/// Synchronous wrapper around [`PirClient`] for use from non-async code.
609///
610/// Owns a Tokio runtime internally so callers (e.g. zcash_voting, which must
611/// stay synchronous for the Halo2 prover) don't need to manage one.
612pub struct PirClientBlocking {
613    inner: PirClient,
614    rt: tokio::runtime::Runtime,
615}
616
617impl PirClientBlocking {
618    /// Connect to a PIR server (blocking). Downloads Tier 0 data and YPIR params.
619    pub fn connect(server_url: &str) -> Result<Self> {
620        let rt = tokio::runtime::Runtime::new()?;
621        let inner = rt.block_on(PirClient::connect(server_url))?;
622        Ok(Self { inner, rt })
623    }
624
625    /// Perform a private Merkle path retrieval for a nullifier (blocking).
626    pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
627        self.rt.block_on(self.inner.fetch_proof(nullifier))
628    }
629
630    /// Perform private Merkle path retrieval for multiple nullifiers in parallel (blocking).
631    pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
632        self.rt.block_on(self.inner.fetch_proofs(nullifiers))
633    }
634
635    /// The depth-29 root (PIR depth 25 padded to tree depth 29).
636    pub fn root29(&self) -> Fp {
637        self.inner.root29
638    }
639}
640
641// ── Local (in-process) PIR client ────────────────────────────────────────────
642
643/// Perform a complete local PIR proof retrieval without HTTP.
644///
645/// This is used by `pir-test local` mode. It takes the tier data directly
646/// (as built by `pir-export`) and performs the YPIR operations in-process.
647pub fn fetch_proof_local(
648    tier0_data: &[u8],
649    tier1_data: &[u8],
650    tier2_data: &[u8],
651    num_ranges: usize,
652    nullifier: Fp,
653    empty_hashes: &[Fp; TREE_DEPTH],
654    root29: Fp,
655) -> Result<ImtProofData> {
656    let mut path = [Fp::default(); TREE_DEPTH];
657    let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
658
659    let s1 = process_tier0(&tier0, nullifier, &mut path)?;
660
661    // ── Tier 1: direct row lookup (no YPIR in local mode) ────────────────
662    let t1_offset = s1 * TIER1_ROW_BYTES;
663    anyhow::ensure!(
664        t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
665        "tier1 data too short: need {} bytes at offset {}, have {}",
666        TIER1_ROW_BYTES,
667        t1_offset,
668        tier1_data.len()
669    );
670    let s2 = process_tier1(
671        &tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
672        nullifier,
673        &mut path,
674    )?;
675
676    // ── Tier 2: direct row lookup (no YPIR in local mode) ────────────────
677    let t2_row_idx = s1 * TIER1_LEAVES + s2;
678    let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
679    anyhow::ensure!(
680        t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
681        "tier2 data too short: need {} bytes at offset {}, have {}",
682        TIER2_ROW_BYTES,
683        t2_offset,
684        tier2_data.len()
685    );
686
687    process_tier2_and_build(
688        &tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
689        t2_row_idx,
690        num_ranges,
691        nullifier,
692        &mut path,
693        empty_hashes,
694        root29,
695    )
696}
697
698#[cfg(test)]
699mod tests {
700    use super::*;
701    use ff::Field;
702    use pasta_curves::Fp;
703    use pir_export::build_ranges_with_sentinels;
704
705    /// Build a tree and export all three tier blobs.
706    struct TestFixture {
707        tier0_data: Vec<u8>,
708        tier1_data: Vec<u8>,
709        tier2_data: Vec<u8>,
710        ranges: Vec<[Fp; 3]>,
711        empty_hashes: [Fp; TREE_DEPTH],
712        root29: Fp,
713    }
714
715    impl TestFixture {
716        fn build(raw_nfs: &[Fp]) -> Self {
717            let ranges = build_ranges_with_sentinels(raw_nfs);
718            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();
719
720            let tier0_data = pir_export::tier0::export(
721                &tree.root25,
722                &tree.levels,
723                &tree.ranges,
724                &tree.empty_hashes,
725            );
726            let mut tier1_data = Vec::new();
727            pir_export::tier1::export(
728                &tree.levels,
729                &tree.ranges,
730                &tree.empty_hashes,
731                &mut tier1_data,
732            )
733            .unwrap();
734            let mut tier2_data = Vec::new();
735            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();
736
737            Self {
738                tier0_data,
739                tier1_data,
740                tier2_data,
741                ranges,
742                empty_hashes: tree.empty_hashes,
743                root29: tree.root29,
744            }
745        }
746    }
747
748    // ── fetch_proof_local round-trip ──────────────────────────────────────
749
750    #[test]
751    fn fetch_proof_local_verifies_for_known_ranges() {
752        let mut rng = rand::thread_rng();
753        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
754        let fix = TestFixture::build(&raw_nfs);
755
756        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
757            let value = nf_lo + Fp::one();
758            let proof = fetch_proof_local(
759                &fix.tier0_data,
760                &fix.tier1_data,
761                &fix.tier2_data,
762                fix.ranges.len(),
763                value,
764                &fix.empty_hashes,
765                fix.root29,
766            )
767            .expect("fetch_proof_local should succeed for a value in range");
768            assert!(
769                proof.verify(value),
770                "proof should verify for value {:?}",
771                value,
772            );
773        }
774    }
775
776    #[test]
777    fn fetch_proof_local_correct_root_and_path_length() {
778        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
779        let fix = TestFixture::build(&raw_nfs);
780
781        let value = fix.ranges[0][0] + Fp::one(); // nf_lo + 1 is inside the range
782        let proof = fetch_proof_local(
783            &fix.tier0_data,
784            &fix.tier1_data,
785            &fix.tier2_data,
786            fix.ranges.len(),
787            value,
788            &fix.empty_hashes,
789            fix.root29,
790        )
791        .unwrap();
792
793        assert_eq!(proof.root, fix.root29);
794        assert_eq!(proof.path.len(), TREE_DEPTH);
795    }
796
797    // ── process_tier0 ────────────────────────────────────────────────────
798
799    #[test]
800    fn process_tier0_fills_correct_path_region() {
801        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
802        let fix = TestFixture::build(&raw_nfs);
803        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();
804
805        let value = fix.ranges[0][0];
806        let mut path = [Fp::default(); TREE_DEPTH];
807        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
808
809        assert!(s1 < pir_types::TIER1_ROWS);
810
811        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
812        assert!(
813            tier0_region.iter().any(|&v| v != Fp::default()),
814            "tier0 should write at least one non-zero sibling"
815        );
816
817        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
818        assert!(
819            below.iter().all(|&v| v == Fp::default()),
820            "path below tier0 region should be untouched"
821        );
822    }
823
824    #[test]
825    fn process_tier0_handles_arbitrary_field_element() {
826        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
827        let fix = TestFixture::build(&raw_nfs);
828        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();
829
830        // Sentinel nullifiers span the field, so every non-nullifier value
831        // falls in some gap range. Verify this doesn't panic and returns a
832        // valid subtree index.
833        let bogus = Fp::from(u64::MAX);
834        let mut path = [Fp::default(); TREE_DEPTH];
835        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
836        assert!(s1 < pir_types::TIER1_ROWS);
837    }
838
839    // ── process_tier1 ────────────────────────────────────────────────────
840
841    #[test]
842    fn process_tier1_fills_correct_path_region() {
843        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
844        let fix = TestFixture::build(&raw_nfs);
845        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();
846
847        let value = fix.ranges[0][0];
848        let mut path = [Fp::default(); TREE_DEPTH];
849        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
850
851        let t1_offset = s1 * TIER1_ROW_BYTES;
852        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
853        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();
854
855        assert!(s2 < TIER1_LEAVES);
856
857        let tier1_region = &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
858        assert!(
859            tier1_region.iter().any(|&v| v != Fp::default()),
860            "tier1 should write at least one non-zero sibling"
861        );
862    }
863
864    // ── process_tier2_and_build ───────────────────────────────────────────
865
866    #[test]
867    fn process_tier2_and_build_produces_verifiable_proof() {
868        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
869        let fix = TestFixture::build(&raw_nfs);
870        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();
871
872        let value = fix.ranges[0][0] + Fp::one();
873        let mut path = [Fp::default(); TREE_DEPTH];
874
875        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
876        let t1_offset = s1 * TIER1_ROW_BYTES;
877        let s2 = process_tier1(
878            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
879            value,
880            &mut path,
881        )
882        .unwrap();
883
884        let t2_row_idx = s1 * TIER1_LEAVES + s2;
885        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
886        let proof = process_tier2_and_build(
887            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
888            t2_row_idx,
889            fix.ranges.len(),
890            value,
891            &mut path,
892            &fix.empty_hashes,
893            fix.root29,
894        )
895        .unwrap();
896
897        assert!(proof.verify(value));
898        assert_eq!(proof.root, fix.root29);
899    }
900
901    // ── valid_leaves_for_row ──────────────────────────────────────────────
902
903    #[test]
904    fn valid_leaves_for_row_basic() {
905        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
906        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
907        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
908        assert_eq!(valid_leaves_for_row(0, 0), 0);
909        assert_eq!(valid_leaves_for_row(1, 0), 1);
910        assert_eq!(valid_leaves_for_row(1, 1), 0);
911    }
912
913    // ── fetch_proof_local error paths ─────────────────────────────────────
914
915    #[test]
916    fn fetch_proof_local_rejects_truncated_tier1() {
917        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
918        let fix = TestFixture::build(&raw_nfs);
919
920        let result = fetch_proof_local(
921            &fix.tier0_data,
922            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
923            &fix.tier2_data,
924            fix.ranges.len(),
925            fix.ranges[0][0],
926            &fix.empty_hashes,
927            fix.root29,
928        );
929        assert!(result.is_err());
930    }
931
932    #[test]
933    fn fetch_proof_local_rejects_truncated_tier2() {
934        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
935        let fix = TestFixture::build(&raw_nfs);
936
937        let result = fetch_proof_local(
938            &fix.tier0_data,
939            &fix.tier1_data,
940            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
941            fix.ranges.len(),
942            fix.ranges[0][0],
943            &fix.empty_hashes,
944            fix.root29,
945        );
946        assert!(result.is_err());
947    }
948
949    // ── Error-oracle mitigation ─────────────────────────────────────────
950
951    /// Verify that the tier 2 query is always sent to the server even when
952    /// the tier 1 response is corrupted.
953    ///
954    /// A malicious server could craft a tier 1 response whose decryption
955    /// outcome depends on the client's secret key material (e.g. by
956    /// triggering an assert in the LWE decode path). Without the
957    /// mitigation, a decode failure would prevent the tier 2 query from
958    /// being sent, and the server could use the absence of query 2 as a
959    /// single-bit oracle. This test asserts that both queries are always
960    /// issued regardless of tier 1 outcome.
961    #[tokio::test]
962    async fn tier2_query_sent_despite_tier1_decode_failure() {
963        use ff::PrimeField as _;
964        use pir_types::{TIER1_ITEM_BITS, TIER1_ROWS, TIER2_ITEM_BITS};
965        use wiremock::matchers::{method, path};
966        use wiremock::{Mock, MockServer, ResponseTemplate};
967
968        // Build real tier0 data so PirClient::connect() succeeds and
969        // process_tier0() produces a valid subtree index.
970        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
971        let ranges = build_ranges_with_sentinels(&raw_nfs);
972        let tree = pir_export::build_pir_tree(ranges).unwrap();
973        let tier0_data =
974            pir_export::tier0::export(&tree.root25, &tree.levels, &tree.ranges, &tree.empty_hashes);
975
976        let root_info = pir_types::RootInfo {
977            root29: hex::encode(tree.root29.to_repr()),
978            root25: hex::encode(tree.root25.to_repr()),
979            num_ranges: tree.ranges.len(),
980            pir_depth: PIR_DEPTH,
981            height: None,
982        };
983
984        // Use the real item_size_bits to satisfy YPIR's internal
985        // parameter constraints. num_items=TIER1_ROWS (2048) matches
986        // production tier1 and is large enough for any s1 value.
987        let tier1_scenario = YpirScenario {
988            num_items: TIER1_ROWS,
989            item_size_bits: TIER1_ITEM_BITS,
990        };
991        let tier2_scenario = YpirScenario {
992            num_items: TIER1_ROWS,
993            item_size_bits: TIER2_ITEM_BITS,
994        };
995
996        let server = MockServer::start().await;
997
998        // ── setup endpoints (valid data) ────────────────────────────────
999        Mock::given(method("GET"))
1000            .and(path("/tier0"))
1001            .respond_with(ResponseTemplate::new(200).set_body_bytes(tier0_data))
1002            .mount(&server)
1003            .await;
1004        Mock::given(method("GET"))
1005            .and(path("/params/tier1"))
1006            .respond_with(ResponseTemplate::new(200).set_body_json(&tier1_scenario))
1007            .mount(&server)
1008            .await;
1009        Mock::given(method("GET"))
1010            .and(path("/params/tier2"))
1011            .respond_with(ResponseTemplate::new(200).set_body_json(&tier2_scenario))
1012            .mount(&server)
1013            .await;
1014        Mock::given(method("GET"))
1015            .and(path("/root"))
1016            .respond_with(ResponseTemplate::new(200).set_body_json(&root_info))
1017            .mount(&server)
1018            .await;
1019
1020        // ── query endpoints (corrupted responses) ───────────────────────
1021        Mock::given(method("POST"))
1022            .and(path("/tier1/query"))
1023            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xDE; 65536]))
1024            .mount(&server)
1025            .await;
1026        Mock::given(method("POST"))
1027            .and(path("/tier2/query"))
1028            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xAD; 65536]))
1029            .mount(&server)
1030            .await;
1031
1032        // ── run the client ──────────────────────────────────────────────
1033        let client = PirClient::connect(&server.uri()).await.unwrap();
1034        let nullifier = tree.ranges[0][0];
1035        let result = client.fetch_proof(nullifier).await;
1036
1037        assert!(
1038            result.is_err(),
1039            "fetch_proof should fail with corrupted tier1 response"
1040        );
1041
1042        // ── verify both queries were sent ───────────────────────────────
1043        let received = server.received_requests().await.unwrap();
1044        let tier1_hits = received
1045            .iter()
1046            .filter(|r| r.url.path() == "/tier1/query")
1047            .count();
1048        let tier2_hits = received
1049            .iter()
1050            .filter(|r| r.url.path() == "/tier2/query")
1051            .count();
1052
1053        assert_eq!(tier1_hits, 1, "tier1 query should have been sent");
1054        assert_eq!(
1055            tier2_hits, 1,
1056            "tier2 query must still be sent when tier1 decode fails \
1057             (error-oracle mitigation)"
1058        );
1059    }
1060}