Skip to main content

pir_client/
lib.rs

1//! PIR client library for private Merkle path retrieval.
2//!
3//! Provides [`PirClient`] which connects to a `pir-server` instance and
4//! retrieves circuit-ready `ImtProofData` without revealing the
5//! queried nullifier to the server.
6
7use std::time::Instant;
8
9use anyhow::{Context, Result};
10use ff::PrimeField as _;
11use imt_tree::hasher::PoseidonHasher;
12use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
13use pasta_curves::Fp;
14// Re-exported so downstream crates (e.g. zcash_voting) can reference the type
15// returned by PirClientBlocking::fetch_proof without a direct imt-tree dependency.
16pub use imt_tree::ImtProofData;
17
18use pir_types::tier0::Tier0Data;
19use pir_types::tier1::Tier1Row;
20use pir_types::tier2::Tier2Row;
21use pir_types::{
22    serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
23    TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
24};
25
26use ypir::client::YPIRClient;
27
28// ── Timing breakdown ─────────────────────────────────────────────────────────
29
/// Per-tier timing breakdown for a single YPIR query, measuring each stage
/// of the client-server round trip.
#[derive(Debug, Clone)]
pub struct TierTiming {
    /// Client-side YPIR query generation time.
    pub gen_ms: f64,
    /// Size of the uploaded query payload.
    pub upload_bytes: usize,
    /// Bytes of the uploaded query attributable to the SimplePIR query
    /// vector itself (`q.0` / `pqr` — the first arg to
    /// [`pir_types::serialize_ypir_query`]).
    pub upload_q_bytes: usize,
    /// Bytes of the uploaded query attributable to `pack_pub_params`
    /// (the second arg to [`pir_types::serialize_ypir_query`]). Identical
    /// across queries that share a YPIR `client_seed`.
    pub upload_pp_bytes: usize,
    /// Size of the downloaded encrypted response.
    pub download_bytes: usize,
    /// Wall-clock round-trip time (upload + server compute + download).
    pub rtt_ms: f64,
    /// Client-side YPIR response decryption time.
    pub decode_ms: f64,
    /// Server-assigned request ID (from response header), if the server sent one.
    pub server_req_id: Option<u64>,
    /// Server-reported total processing time.
    pub server_total_ms: Option<f64>,
    /// Server-reported query validation time.
    pub server_validate_ms: Option<f64>,
    /// Server-reported decode+copy time.
    pub server_decode_copy_ms: Option<f64>,
    /// Server-reported YPIR online computation time.
    pub server_compute_ms: Option<f64>,
    /// Estimated network + queue latency (RTT minus server-reported total,
    /// clamped at 0). `None` when the server did not report its total.
    pub net_queue_ms: Option<f64>,
    /// Estimated upload-to-server latency (time until the request's `send()`
    /// completed, minus server-reported total, clamped at 0).
    pub upload_to_server_ms: Option<f64>,
    /// Estimated download-from-server latency (RTT minus the time until the
    /// request's `send()` completed, clamped at 0).
    pub download_from_server_ms: f64,
}

/// Per-note timing breakdown covering both tier 1 and tier 2 YPIR queries.
#[derive(Debug, Clone)]
pub struct NoteTiming {
    /// Timing of the tier 1 YPIR query.
    pub tier1: TierTiming,
    /// Timing of the tier 2 YPIR query.
    pub tier2: TierTiming,
    /// Total wall-clock time for this note's proof retrieval.
    pub total_ms: f64,
}
76
77// ── HTTP-based PIR client ────────────────────────────────────────────────────
78
79/// PIR client that connects to a `pir-server` instance over HTTP.
80///
81/// Downloads Tier 0 data and YPIR parameters during `connect()`, then
82/// performs private queries via `fetch_proof()`.
83pub struct PirClient {
84    server_url: String,
85    http: reqwest::Client,
86    tier0: Tier0Data,
87    tier1_scenario: YpirScenario,
88    tier2_scenario: YpirScenario,
89    num_ranges: usize,
90    empty_hashes: [Fp; TREE_DEPTH],
91    root29: Fp,
92}
93
94/// Return the number of populated leaves in a Tier 2 row, clamped to
95/// [`TIER2_LEAVES`]. The final row may be only partially filled when
96/// `num_ranges` is not a multiple of the row size.
97#[inline]
98fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
99    let row_start = row_idx.saturating_mul(TIER2_LEAVES);
100    num_ranges.saturating_sub(row_start).min(TIER2_LEAVES)
101}
102
103// ── Shared tier-processing helpers ───────────────────────────────────────────
104
105/// Copy `siblings` into `path` starting at `offset`.
106#[inline]
107fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
108    path[offset..offset + siblings.len()].copy_from_slice(siblings);
109}
110
111/// Locate the nullifier's subtree in Tier 0, fill its siblings into `path`,
112/// and return the subtree index `s1`.
113fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
114    let s1 = tier0
115        .find_subtree(nullifier)
116        .context("nullifier not found in any Tier 0 subtree")?;
117    fill_path(path, PIR_DEPTH - TIER0_LAYERS, &tier0.extract_siblings(s1));
118    Ok(s1)
119}
120
121/// Parse a Tier 1 row, locate the nullifier's sub-subtree, fill its siblings
122/// into `path`, and return the sub-subtree index `s2`.
123fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
124    let hasher = PoseidonHasher::new();
125    let tier1 = Tier1Row::from_bytes(tier1_row)?;
126    let s2 = tier1
127        .find_sub_subtree(nullifier)
128        .context("nullifier not found in any Tier 1 sub-subtree")?;
129    fill_path(
130        path,
131        PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS,
132        &tier1.extract_siblings(s2, &hasher),
133    );
134    Ok(s2)
135}
136
137/// Parse a Tier 2 row, locate the nullifier's leaf, fill tier-2 and padding
138/// siblings into `path`, and assemble the final [`ImtProofData`].
139fn process_tier2_and_build(
140    tier2_row: &[u8],
141    t2_row_idx: usize,
142    num_ranges: usize,
143    nullifier: Fp,
144    path: &mut [Fp; TREE_DEPTH],
145    empty_hashes: &[Fp; TREE_DEPTH],
146    root29: Fp,
147) -> Result<ImtProofData> {
148    let hasher = PoseidonHasher::new();
149    let tier2 = Tier2Row::from_bytes(tier2_row)?;
150    let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
151
152    let leaf_local_idx = tier2
153        .find_leaf(nullifier, valid_leaves)
154        .context("nullifier not found in Tier 2 leaf scan")?;
155
156    fill_path(
157        path,
158        0,
159        &tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
160    );
161    // Pad from PIR depth (25) to circuit depth (29) with empty hashes.
162    fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
163
164    let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
165    let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
166
167    Ok(ImtProofData {
168        root: root29,
169        nf_bounds: [nf_lo, nf_mid, nf_hi],
170        leaf_pos: global_leaf_idx as u32,
171        path: *path,
172    })
173}
174
175impl PirClient {
176    /// Connect to a PIR server, downloading Tier 0 data and YPIR parameters.
177    pub async fn connect(server_url: &str) -> Result<Self> {
178        Self::connect_with_http(server_url, reqwest::Client::new()).await
179    }
180
181    /// Like [`connect`](Self::connect) but with a caller-provided
182    /// [`reqwest::Client`]. Used by `pir-test bench-server --mode single-tls`
183    /// to force HTTP/1.1 with a single connection (no HTTP/2 stream
184    /// multiplexing) so we can isolate per-query upload bandwidth from
185    /// HTTP/2 stream contention.
186    pub async fn connect_with_http(server_url: &str, http: reqwest::Client) -> Result<Self> {
187        let base = server_url.trim_end_matches('/');
188
189        // Download Tier 0 data, YPIR params, and root concurrently
190        let t0 = Instant::now();
191        let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
192            http.get(format!("{base}/tier0")).send(),
193            http.get(format!("{base}/params/tier1")).send(),
194            http.get(format!("{base}/params/tier2")).send(),
195            http.get(format!("{base}/root")).send(),
196        )
197        .map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
198
199        let tier0_bytes = tier0_resp.error_for_status()?.bytes().await?;
200        log::debug!(
201            "Downloaded Tier 0: {} bytes in {:.1}s",
202            tier0_bytes.len(),
203            t0.elapsed().as_secs_f64()
204        );
205        let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
206
207        let tier1_scenario: YpirScenario = tier1_resp
208            .error_for_status()
209            .context("GET /params/tier1 failed")?
210            .json()
211            .await?;
212        let tier2_scenario: YpirScenario = tier2_resp
213            .error_for_status()
214            .context("GET /params/tier2 failed")?
215            .json()
216            .await?;
217
218        let root_info: RootInfo = root_resp
219            .error_for_status()
220            .context("GET /root failed")?
221            .json()
222            .await?;
223        anyhow::ensure!(
224            root_info.pir_depth == PIR_DEPTH,
225            "server pir_depth {} != expected {}",
226            root_info.pir_depth,
227            PIR_DEPTH
228        );
229        let root29_bytes = hex::decode(&root_info.root29)?;
230        anyhow::ensure!(
231            root29_bytes.len() == 32,
232            "root29 hex decoded to {} bytes, expected 32",
233            root29_bytes.len()
234        );
235        let mut root29_arr = [0u8; 32];
236        root29_arr.copy_from_slice(&root29_bytes);
237        let root29 = Option::from(Fp::from_repr(root29_arr))
238            .ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
239
240        let empty_hashes = precompute_empty_hashes();
241
242        Ok(Self {
243            server_url: base.to_string(),
244            http,
245            tier0,
246            tier1_scenario,
247            tier2_scenario,
248            num_ranges: root_info.num_ranges,
249            empty_hashes,
250            root29,
251        })
252    }
253
254    /// Perform private Merkle path retrieval for a nullifier.
255    ///
256    /// Returns circuit-ready `ImtProofData` with a 29-element path
257    /// (25 PIR siblings + 4 empty-hash padding).
258    pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
259        let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
260        Ok(proof)
261    }
262
263    /// Like [`fetch_proof`](Self::fetch_proof) but also returns the full
264    /// client+server timing breakdown for load-testing / observability.
265    pub async fn fetch_proof_with_timing(
266        &self,
267        nullifier: Fp,
268    ) -> Result<(ImtProofData, NoteTiming)> {
269        self.fetch_proof_inner(nullifier).await
270    }
271
272    /// Perform private Merkle path retrieval for multiple nullifiers in parallel.
273    ///
274    /// All queries run concurrently via `try_join_all`, sharing the same
275    /// `PirClient` (and thus the same HTTP client and Tier 0 data).
276    pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
277        log::debug!(
278            "[PIR] Starting parallel fetch for {} notes...",
279            nullifiers.len()
280        );
281        let wall_start = Instant::now();
282
283        let futures: Vec<_> = nullifiers
284            .iter()
285            .enumerate()
286            .map(|(i, &nf)| async move {
287                let (proof, timing) = self.fetch_proof_inner(nf).await?;
288                Ok::<_, anyhow::Error>((i, proof, timing))
289            })
290            .collect();
291
292        let results_with_timing = futures::future::try_join_all(futures).await?;
293        let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
294
295        print_timing_table(&results_with_timing, wall_ms);
296
297        let proofs = results_with_timing
298            .into_iter()
299            .map(|(_, proof, _)| proof)
300            .collect();
301        Ok(proofs)
302    }
303
304    /// Fetch proof and return timing breakdown.
305    ///
306    /// **Error-oracle mitigation**: the tier 2 query is always sent even when
307    /// tier 1 fails. A malicious server could craft a tier 1 response whose
308    /// decryption outcome depends on the client's secret key material (e.g. by
309    /// triggering an assert in the LWE decode path). If the client aborted
310    /// before sending the tier 2 query, the server could observe its absence
311    /// and use the binary "crash / no-crash" signal as an oracle. By
312    /// unconditionally sending a (possibly dummy) tier 2 query we ensure the
313    /// server always sees both requests and gains no information from errors.
314    async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
315        let note_start = Instant::now();
316        let mut path = [Fp::default(); TREE_DEPTH];
317
318        // Process tier 0 (plaintext, not server-controlled)
319        let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
320
321        // Process tier 1 (PIR) — capture the outcome without `?` so that a
322        // tier 2 query is always sent regardless of tier 1 success.
323        //
324        // process_tier1 is wrapped in catch_unwind so that a panic (e.g. from
325        // a debug_assert or an unexpected slice bounds violation) cannot
326        // prevent the tier 2 query from being sent. Without this, a panic
327        // here would unwind past the tier 2 dispatch and give the server an
328        // observable one-query-vs-two oracle.
329        let tier1_outcome = self
330            .ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
331            .await
332            .and_then(|(row, timing)| {
333                let mut_path = &mut path;
334                let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
335                    process_tier1(&row, nullifier, mut_path)
336                }))
337                .unwrap_or_else(|payload| {
338                    let msg = payload
339                        .downcast_ref::<String>()
340                        .map(|s| s.as_str())
341                        .or_else(|| payload.downcast_ref::<&str>().copied())
342                        .unwrap_or("unknown panic");
343                    Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
344                })?;
345                Ok((s1 * TIER1_LEAVES + s2, timing))
346            });
347
348        // Real index on success, dummy index 0 on failure. PIR hides the
349        // queried index from the server, so the dummy is indistinguishable.
350        let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
351
352        // Validate the tier 2 index before passing it to ypir_query.
353        // ypir_query has an ensure!(row_idx < num_items) that returns Err
354        // *before* sending the HTTP request — if that fires, no tier 2
355        // request reaches the server and we leak an oracle bit. A malicious
356        // server can trigger this by setting tier2 num_items too small or
357        // crafting tier 1 data that produces out-of-bounds indices. Clamp to
358        // dummy index 0 so the query always goes out; propagate the error
359        // only after both queries have been sent.
360        let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
361            Some(anyhow::anyhow!(
362                "tier2 row_idx {} >= num_items {}",
363                t2_row_idx,
364                self.tier2_scenario.num_items
365            ))
366        } else {
367            None
368        };
369        let t2_query_idx = if t2_bounds_err.is_some() {
370            0
371        } else {
372            t2_row_idx
373        };
374
375        // Always send tier 2 to void error-based oracles.
376        let tier2_result = self
377            .ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
378            .await;
379
380        // Propagate errors only after both queries have been sent.
381        let (t2_row_idx, tier1_timing) = tier1_outcome?;
382        if let Some(e) = t2_bounds_err {
383            return Err(e);
384        }
385        let (tier2_row, tier2_timing) = tier2_result?;
386
387        let proof = process_tier2_and_build(
388            &tier2_row,
389            t2_row_idx,
390            self.num_ranges,
391            nullifier,
392            &mut path,
393            &self.empty_hashes,
394            self.root29,
395        )?;
396
397        let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
398        Ok((
399            proof,
400            NoteTiming {
401                tier1: tier1_timing,
402                tier2: tier2_timing,
403                total_ms,
404            },
405        ))
406    }
407
408    /// Send a YPIR query for a tier row and return the decrypted row bytes.
409    /// This function handles the key client PIR operations:
410    /// 1. Generate keys
411    /// 2. Query
412    /// 3. Recover
413    async fn ypir_query(
414        &self,
415        scenario: &YpirScenario,
416        tier_name: &str,
417        row_idx: usize,
418        expected_row_bytes: usize,
419    ) -> Result<(Vec<u8>, TierTiming)> {
420        anyhow::ensure!(
421            row_idx < scenario.num_items,
422            "{} row_idx {} >= num_items {}",
423            tier_name,
424            row_idx,
425            scenario.num_items
426        );
427        let t0 = Instant::now();
428        let ypir_client = YPIRClient::from_db_sz(
429            scenario.num_items as u64,
430            scenario.item_size_bits as u64,
431            true,
432        );
433
434        // Generate PIR query from a fresh secret created from OsRng seed.
435        let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
436        let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
437
438        // Serialize query. `query.0` is the SimplePIR query vector
439        // (per-query); `query.1` is `pack_pub_params` (depends only on
440        // the client's `client_seed`).
441        let upload_q_bytes = query.0.as_slice().len() * std::mem::size_of::<u64>();
442        let upload_pp_bytes = query.1.as_slice().len() * std::mem::size_of::<u64>();
443        let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
444        let upload_bytes = payload.len();
445
446        // Send the request
447        let t1 = Instant::now();
448        let url = format!("{}/{}/query", self.server_url, tier_name);
449        let send_result = self.http.post(&url).body(payload).send().await;
450        let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
451        let resp = match send_result {
452            Ok(r) => r,
453            Err(e) => {
454                log::warn!("YPIR {} send error: {:?}", tier_name, e);
455                return Err(e.into());
456            }
457        };
458        let server_req_id = parse_header_u64(resp.headers(), "x-pir-req-id");
459        let server_total_ms = parse_header_f64(resp.headers(), "x-pir-server-total-ms");
460        let server_validate_ms = parse_header_f64(resp.headers(), "x-pir-server-validate-ms");
461        let server_decode_copy_ms = parse_header_f64(resp.headers(), "x-pir-server-decode-copy-ms");
462        let server_compute_ms = parse_header_f64(resp.headers(), "x-pir-server-compute-ms");
463        let status = resp.status();
464        let response_bytes = resp.bytes().await?;
465        if !status.is_success() {
466            anyhow::bail!(
467                "{} query failed: HTTP {} body={}",
468                tier_name,
469                status,
470                String::from_utf8_lossy(&response_bytes)
471            );
472        }
473        let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
474        let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
475        let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
476        let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
477
478        // Decode the response. Wrap in catch_unwind so that assert panics
479        // in the YPIR library (e.g. `val < lwe_q_prime` in the LWE decode
480        // path) become recoverable errors rather than process aborts. This is
481        // necessary for the error-oracle mitigation in fetch_proof_inner:
482        // a panic here must not prevent the second query from being sent.
483        let t2 = Instant::now();
484        let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
485            ypir_client.decode_response_simplepir(seed, &response_bytes)
486        }))
487        .map_err(|panic_payload| {
488            let msg = panic_payload
489                .downcast_ref::<String>()
490                .map(|s| s.as_str())
491                .or_else(|| panic_payload.downcast_ref::<&str>().copied())
492                .unwrap_or("unknown panic");
493            anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
494        })?;
495        let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
496
497        anyhow::ensure!(
498            decoded.len() >= expected_row_bytes,
499            "{} decoded response too short: {} bytes, expected >= {}",
500            tier_name,
501            decoded.len(),
502            expected_row_bytes
503        );
504        Ok((
505            decoded[..expected_row_bytes].to_vec(),
506            TierTiming {
507                gen_ms,
508                upload_bytes,
509                upload_q_bytes,
510                upload_pp_bytes,
511                download_bytes: response_bytes.len(),
512                rtt_ms,
513                decode_ms,
514                server_req_id,
515                server_total_ms,
516                server_validate_ms,
517                server_decode_copy_ms,
518                server_compute_ms,
519                net_queue_ms,
520                upload_to_server_ms,
521                download_from_server_ms,
522            },
523        ))
524    }
525}
526
/// Render a millisecond duration as a fixed-width (7-char) table cell.
///
/// Values of at least one second are shown in seconds with one decimal
/// (`"  1.5s "`); anything smaller is shown as whole milliseconds (`"   12ms"`).
fn fmt_time(ms: f64) -> String {
    match ms {
        m if m >= 1000.0 => format!("{:>5.1}s ", m / 1000.0),
        m => format!("{:>5.0}ms", m),
    }
}
534
535fn fmt_opt_time(ms: Option<f64>) -> String {
536    match ms {
537        Some(v) => fmt_time(v),
538        None => "  n/a ".to_string(),
539    }
540}
541
542/// Print a detailed timing breakdown table for a batch of PIR proof fetches.
543fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
544    if !log::log_enabled!(log::Level::Debug) {
545        return;
546    }
547
548    log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
549    log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+  │ T1 decode│ T2 keygen│ T2 upload+  │ T2 decode│ Total  │");
550    log::debug!("[PIR] │     │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │        │");
551    log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
552    for &(i, _, ref t) in results {
553        log::debug!(
554            "[PIR] │  {i:>2} │  {:>6} │   {:>7}   │  {:>6} │  {:>6} │   {:>7}   │  {:>6} │{} │",
555            fmt_time(t.tier1.gen_ms),
556            fmt_time(t.tier1.rtt_ms),
557            fmt_time(t.tier1.decode_ms),
558            fmt_time(t.tier2.gen_ms),
559            fmt_time(t.tier2.rtt_ms),
560            fmt_time(t.tier2.decode_ms),
561            fmt_time(t.total_ms),
562        );
563    }
564    log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");
565    log::debug!(
566        "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB  |  Wall clock: {:.2}s",
567        results
568            .first()
569            .map(|(_, _, t)| t.tier1.upload_bytes)
570            .unwrap_or(0) as f64
571            / 1024.0,
572        results
573            .first()
574            .map(|(_, _, t)| t.tier2.upload_bytes)
575            .unwrap_or(0) as f64
576            / (1024.0 * 1024.0),
577        wall_ms / 1000.0,
578    );
579
580    for &(i, _, ref t) in results {
581        log::trace!(
582            "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
583            t.tier1.upload_bytes as f64 / 1024.0,
584            t.tier1.download_bytes as f64 / 1024.0,
585            t.tier2.upload_bytes as f64 / (1024.0 * 1024.0),
586            t.tier2.download_bytes as f64 / 1024.0,
587        );
588        log::trace!(
589            "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
590            fmt_opt_time(t.tier1.server_total_ms),
591            fmt_opt_time(t.tier1.net_queue_ms),
592            fmt_opt_time(t.tier2.server_total_ms),
593            fmt_opt_time(t.tier2.net_queue_ms),
594        );
595        log::trace!(
596            "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
597            fmt_opt_time(t.tier1.upload_to_server_ms),
598            fmt_opt_time(t.tier1.server_total_ms),
599            fmt_time(t.tier1.download_from_server_ms),
600            fmt_opt_time(t.tier2.upload_to_server_ms),
601            fmt_opt_time(t.tier2.server_total_ms),
602            fmt_time(t.tier2.download_from_server_ms),
603        );
604        log::trace!(
605            "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
606            fmt_opt_time(t.tier1.server_validate_ms),
607            fmt_opt_time(t.tier1.server_decode_copy_ms),
608            fmt_opt_time(t.tier1.server_compute_ms),
609            fmt_opt_time(t.tier2.server_validate_ms),
610            fmt_opt_time(t.tier2.server_decode_copy_ms),
611            fmt_opt_time(t.tier2.server_compute_ms),
612        );
613        log::trace!(
614            "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
615            t.tier1.server_req_id,
616            t.tier2.server_req_id
617        );
618    }
619}
620
621/// Parse an HTTP response header value as `f64`, returning `None` on missing or malformed values.
622fn parse_header_f64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<f64> {
623    headers
624        .get(name)
625        .and_then(|v| v.to_str().ok())
626        .and_then(|s| s.parse::<f64>().ok())
627}
628
629/// Parse an HTTP response header value as `u64`, returning `None` on missing or malformed values.
630fn parse_header_u64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<u64> {
631    headers
632        .get(name)
633        .and_then(|v| v.to_str().ok())
634        .and_then(|s| s.parse::<u64>().ok())
635}
636
637// ── Blocking wrapper ─────────────────────────────────────────────────────────
638
639/// Synchronous wrapper around [`PirClient`] for use from non-async code.
640///
641/// Owns a Tokio runtime internally so callers (e.g. zcash_voting, which must
642/// stay synchronous for the Halo2 prover) don't need to manage one.
643pub struct PirClientBlocking {
644    inner: PirClient,
645    rt: tokio::runtime::Runtime,
646}
647
648impl PirClientBlocking {
649    /// Connect to a PIR server (blocking). Downloads Tier 0 data and YPIR params.
650    pub fn connect(server_url: &str) -> Result<Self> {
651        let rt = tokio::runtime::Runtime::new()?;
652        let inner = rt.block_on(PirClient::connect(server_url))?;
653        Ok(Self { inner, rt })
654    }
655
656    /// Perform a private Merkle path retrieval for a nullifier (blocking).
657    pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
658        self.rt.block_on(self.inner.fetch_proof(nullifier))
659    }
660
661    /// Perform private Merkle path retrieval for multiple nullifiers in parallel (blocking).
662    pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
663        self.rt.block_on(self.inner.fetch_proofs(nullifiers))
664    }
665
666    /// The depth-29 root (PIR depth 25 padded to tree depth 29).
667    pub fn root29(&self) -> Fp {
668        self.inner.root29
669    }
670}
671
672// ── Local (in-process) PIR client ────────────────────────────────────────────
673
674/// Perform a complete local PIR proof retrieval without HTTP.
675///
676/// This is used by `pir-test local` mode. It takes the tier data directly
677/// (as built by `pir-export`) and performs the YPIR operations in-process.
678pub fn fetch_proof_local(
679    tier0_data: &[u8],
680    tier1_data: &[u8],
681    tier2_data: &[u8],
682    num_ranges: usize,
683    nullifier: Fp,
684    empty_hashes: &[Fp; TREE_DEPTH],
685    root29: Fp,
686) -> Result<ImtProofData> {
687    let mut path = [Fp::default(); TREE_DEPTH];
688    let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
689
690    let s1 = process_tier0(&tier0, nullifier, &mut path)?;
691
692    // ── Tier 1: direct row lookup (no YPIR in local mode) ────────────────
693    let t1_offset = s1 * TIER1_ROW_BYTES;
694    anyhow::ensure!(
695        t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
696        "tier1 data too short: need {} bytes at offset {}, have {}",
697        TIER1_ROW_BYTES,
698        t1_offset,
699        tier1_data.len()
700    );
701    let s2 = process_tier1(
702        &tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
703        nullifier,
704        &mut path,
705    )?;
706
707    // ── Tier 2: direct row lookup (no YPIR in local mode) ────────────────
708    let t2_row_idx = s1 * TIER1_LEAVES + s2;
709    let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
710    anyhow::ensure!(
711        t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
712        "tier2 data too short: need {} bytes at offset {}, have {}",
713        TIER2_ROW_BYTES,
714        t2_offset,
715        tier2_data.len()
716    );
717
718    process_tier2_and_build(
719        &tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
720        t2_row_idx,
721        num_ranges,
722        nullifier,
723        &mut path,
724        empty_hashes,
725        root29,
726    )
727}
728
#[cfg(test)]
mod tests {
    use super::*;
    use ff::Field;
    use pasta_curves::Fp;
    use pir_export::build_ranges_with_sentinels;
    use std::time::Duration;

    /// Tier blobs and tree metadata exported from a PIR tree built over a
    /// set of raw nullifiers, used to drive the local (no-network) paths.
    struct TestFixture {
        /// Tier 0 blob as produced by `pir_export::tier0::export`.
        tier0_data: Vec<u8>,
        /// Tier 1 rows, each `TIER1_ROW_BYTES` long, laid out back to back.
        tier1_data: Vec<u8>,
        /// Tier 2 rows, each `TIER2_ROW_BYTES` long, laid out back to back.
        tier2_data: Vec<u8>,
        /// Ranges from `build_ranges_with_sentinels` that the tree was built from.
        ranges: Vec<[Fp; 3]>,
        empty_hashes: [Fp; TREE_DEPTH],
        root29: Fp,
    }

    impl TestFixture {
        /// Build ranges (with sentinels) over `raw_nfs`, build the PIR tree,
        /// and export all three tier blobs. Panics if any export step fails.
        fn build(raw_nfs: &[Fp]) -> Self {
            let ranges = build_ranges_with_sentinels(raw_nfs);
            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let mut tier1_data = Vec::new();
            pir_export::tier1::export(
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
                &mut tier1_data,
            )
            .unwrap();
            let mut tier2_data = Vec::new();
            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();

            Self {
                tier0_data,
                tier1_data,
                tier2_data,
                ranges,
                empty_hashes: tree.empty_hashes,
                root29: tree.root29,
            }
        }
    }

    // ── fetch_proof_local round-trip ──────────────────────────────────────

    #[test]
    fn fetch_proof_local_verifies_for_known_ranges() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        // Fixed seed: a failing nullifier set can be reproduced exactly,
        // which is not possible with `thread_rng()`.
        let mut rng = StdRng::seed_from_u64(0x5EED_0001);
        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
        let fix = TestFixture::build(&raw_nfs);

        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
            let value = nf_lo + Fp::one();
            let proof = fetch_proof_local(
                &fix.tier0_data,
                &fix.tier1_data,
                &fix.tier2_data,
                fix.ranges.len(),
                value,
                &fix.empty_hashes,
                fix.root29,
            )
            .expect("fetch_proof_local should succeed for a value in range");
            assert!(
                proof.verify(value),
                "proof should verify for value {:?}",
                value,
            );
        }
    }

    #[test]
    fn fetch_proof_local_correct_root_and_path_length() {
        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let value = fix.ranges[0][0] + Fp::one(); // nf_lo + 1 is inside the range
        let proof = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data,
            fix.ranges.len(),
            value,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert_eq!(proof.root, fix.root29);
        assert_eq!(proof.path.len(), TREE_DEPTH);
    }

    // ── process_tier0 ────────────────────────────────────────────────────

    #[test]
    fn process_tier0_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        assert!(s1 < pir_types::TIER1_ROWS);

        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
        assert!(
            tier0_region.iter().any(|&v| v != Fp::default()),
            "tier0 should write at least one non-zero sibling"
        );

        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            below.iter().all(|&v| v == Fp::default()),
            "path below tier0 region should be untouched"
        );
    }

    #[test]
    fn process_tier0_handles_arbitrary_field_element() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        // Sentinel nullifiers span the field, so every non-nullifier value
        // falls in some gap range. Verify this doesn't panic and returns a
        // valid subtree index.
        let bogus = Fp::from(u64::MAX);
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
    }

    // ── process_tier1 ────────────────────────────────────────────────────

    #[test]
    fn process_tier1_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        // Moving the tier0 field out is enough; the remaining fields are
        // still usable after this partial move.
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        let t1_offset = s1 * TIER1_ROW_BYTES;
        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();

        assert!(s2 < TIER1_LEAVES);

        let tier1_region =
            &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            tier1_region.iter().any(|&v| v != Fp::default()),
            "tier1 should write at least one non-zero sibling"
        );
    }

    // ── process_tier2_and_build ───────────────────────────────────────────

    #[test]
    fn process_tier2_and_build_produces_verifiable_proof() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0] + Fp::one();
        let mut path = [Fp::default(); TREE_DEPTH];

        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let s2 = process_tier1(
            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
            value,
            &mut path,
        )
        .unwrap();

        // Tier 2 rows are indexed by (tier1 row, leaf within that row).
        let t2_row_idx = s1 * TIER1_LEAVES + s2;
        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
        let proof = process_tier2_and_build(
            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
            t2_row_idx,
            fix.ranges.len(),
            value,
            &mut path,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert!(proof.verify(value));
        assert_eq!(proof.root, fix.root29);
    }

    // ── valid_leaves_for_row ──────────────────────────────────────────────

    #[test]
    fn valid_leaves_for_row_basic() {
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
        assert_eq!(valid_leaves_for_row(0, 0), 0);
        assert_eq!(valid_leaves_for_row(1, 0), 1);
        assert_eq!(valid_leaves_for_row(1, 1), 0);
    }

    // ── fetch_proof_local error paths ─────────────────────────────────────

    #[test]
    fn fetch_proof_local_rejects_truncated_tier1() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
            &fix.tier2_data,
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    #[test]
    fn fetch_proof_local_rejects_truncated_tier2() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    // ── mock pir-server helpers ───────────────────────────────────────────

    /// Start a mock `pir-server` for the error-oracle tests.
    ///
    /// The setup endpoints (`/tier0`, `/params/tier1`, `/params/tier2`,
    /// `/root`) serve valid data derived from a small tree, so that
    /// `PirClient::connect()` succeeds and `process_tier0()` produces a
    /// valid subtree index. The query endpoints (`/tier1/query`,
    /// `/tier2/query`) answer every POST with a 64 KiB constant-byte body
    /// that the client is expected to reject.
    ///
    /// `tier2_delay`, when `Some`, delays every `/tier2/query` response.
    ///
    /// Returns the running server and the lower bound (`r[0]`) of every
    /// range in the exported tree, for picking nullifiers to query.
    async fn start_corrupting_mock_server(
        tier2_delay: Option<Duration>,
    ) -> (wiremock::MockServer, Vec<Fp>) {
        use ff::PrimeField as _;
        use pir_types::{TIER1_ITEM_BITS, TIER1_YPIR_ROWS, TIER2_ITEM_BITS};
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let tier0_data =
            pir_export::tier0::export(&tree.root25, &tree.levels, &tree.ranges, &tree.empty_hashes);

        let root_info = pir_types::RootInfo {
            root29: hex::encode(tree.root29.to_repr()),
            root25: hex::encode(tree.root25.to_repr()),
            num_ranges: tree.ranges.len(),
            pir_depth: PIR_DEPTH,
            height: None,
        };

        // Use the real item_size_bits to satisfy YPIR's internal
        // parameter constraints. num_items=TIER1_YPIR_ROWS (2048) matches
        // production tier1 (which pads TIER1_ROWS=512 up to the YPIR
        // poly_len floor) and is large enough for any s1 value.
        let tier1_scenario = YpirScenario {
            num_items: TIER1_YPIR_ROWS,
            item_size_bits: TIER1_ITEM_BITS,
        };
        // NOTE(review): tier 2 also uses TIER1_YPIR_ROWS for num_items;
        // confirm this matches what the client's tier-2 query path expects.
        let tier2_scenario = YpirScenario {
            num_items: TIER1_YPIR_ROWS,
            item_size_bits: TIER2_ITEM_BITS,
        };

        let server = MockServer::start().await;

        // ── setup endpoints (valid data) ────────────────────────────────
        let setup_routes = [
            ("/tier0", ResponseTemplate::new(200).set_body_bytes(tier0_data)),
            (
                "/params/tier1",
                ResponseTemplate::new(200).set_body_json(&tier1_scenario),
            ),
            (
                "/params/tier2",
                ResponseTemplate::new(200).set_body_json(&tier2_scenario),
            ),
            ("/root", ResponseTemplate::new(200).set_body_json(&root_info)),
        ];
        for (route, response) in setup_routes {
            Mock::given(method("GET"))
                .and(path(route))
                .respond_with(response)
                .mount(&server)
                .await;
        }

        // ── query endpoints (corrupted responses) ───────────────────────
        Mock::given(method("POST"))
            .and(path("/tier1/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xDE; 65536]))
            .mount(&server)
            .await;

        let tier2_response = ResponseTemplate::new(200).set_body_bytes(vec![0xAD; 65536]);
        let tier2_response = match tier2_delay {
            Some(delay) => tier2_response.set_delay(delay),
            None => tier2_response,
        };
        Mock::given(method("POST"))
            .and(path("/tier2/query"))
            .respond_with(tier2_response)
            .mount(&server)
            .await;

        let range_lows: Vec<Fp> = tree.ranges.iter().map(|r| r[0]).collect();
        (server, range_lows)
    }

    /// Count how many `/tier1/query` and `/tier2/query` requests `server`
    /// has recorded, returned as `(tier1_hits, tier2_hits)`.
    async fn count_query_hits(server: &wiremock::MockServer) -> (usize, usize) {
        let received = server.received_requests().await.unwrap();
        let hits = |route: &str| received.iter().filter(|r| r.url.path() == route).count();
        (hits("/tier1/query"), hits("/tier2/query"))
    }

    // ── Error-oracle mitigation ─────────────────────────────────────────

    /// Verify that the tier 2 query is always sent to the server even when
    /// the tier 1 response is corrupted.
    ///
    /// A malicious server could craft a tier 1 response whose decryption
    /// outcome depends on the client's secret key material (e.g. by
    /// triggering an assert in the LWE decode path). Without the
    /// mitigation, a decode failure would prevent the tier 2 query from
    /// being sent, and the server could use the absence of query 2 as a
    /// single-bit oracle. This test asserts that both queries are always
    /// issued regardless of tier 1 outcome.
    #[tokio::test]
    async fn tier2_query_sent_despite_tier1_decode_failure() {
        let (server, range_lows) = start_corrupting_mock_server(None).await;

        // ── run the client ──────────────────────────────────────────────
        let client = PirClient::connect(&server.uri()).await.unwrap();
        let nullifier = range_lows[0];
        let result = client.fetch_proof(nullifier).await;

        assert!(
            result.is_err(),
            "fetch_proof should fail with corrupted tier1 response"
        );

        // ── verify both queries were sent ───────────────────────────────
        let (tier1_hits, tier2_hits) = count_query_hits(&server).await;
        assert_eq!(tier1_hits, 1, "tier1 query should have been sent");
        assert_eq!(
            tier2_hits, 1,
            "tier2 query must still be sent when tier1 decode fails \
             (error-oracle mitigation)"
        );
    }

    /// Multi-note analog of [`tier2_query_sent_despite_tier1_decode_failure`].
    ///
    /// Asserts the **K-note granularity** of the error-oracle mitigation:
    /// when `fetch_proofs(K=5)` is called and **every** tier 1 response is
    /// corrupted, the server must still observe **K tier 1 POSTs and
    /// K tier 2 POSTs**. Aborting tier 2 dispatch the moment any tier 1
    /// decode fails would re-introduce the "K vs K' tier-2 requests"
    /// oracle the per-note mitigation closes.
    ///
    /// The 50 ms tier 2 delay is a coarse-grained guard against
    /// `try_join_all` cancellation: it ensures all five fetch_proof_inner
    /// futures dispatch their tier 2 POST before the first one resolves
    /// to `Err`, so wiremock's request log is deterministic.
    #[tokio::test]
    async fn batched_tier2_queries_all_sent_despite_tier1_decode_failure() {
        const K: usize = 5;

        let (server, range_lows) =
            start_corrupting_mock_server(Some(Duration::from_millis(50))).await;

        let client = PirClient::connect(&server.uri()).await.unwrap();

        let nullifiers: Vec<Fp> = range_lows
            .iter()
            .take(K)
            .map(|&lo| lo + Fp::one())
            .collect();
        assert_eq!(nullifiers.len(), K);

        let result = client.fetch_proofs(&nullifiers).await;
        assert!(
            result.is_err(),
            "fetch_proofs should fail with corrupted tier1 responses"
        );

        // `assert_eq!` already prints the observed count on failure.
        let (tier1_hits, tier2_hits) = count_query_hits(&server).await;
        assert_eq!(tier1_hits, K, "all K tier1 queries should have been sent");
        assert_eq!(
            tier2_hits, K,
            "all K tier2 queries must still be sent when their tier1 \
             decodes fail (batch-level error-oracle mitigation)",
        );
    }
}