1use std::time::Instant;
8
9use anyhow::{Context, Result};
10use ff::PrimeField as _;
11use imt_tree::hasher::PoseidonHasher;
12use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
13use pasta_curves::Fp;
14pub use imt_tree::ImtProofData;
17
18use pir_types::tier0::Tier0Data;
19use pir_types::tier1::Tier1Row;
20use pir_types::tier2::Tier2Row;
21use pir_types::{
22 serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
23 TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
24};
25
26use ypir::client::YPIRClient;
27
/// Timing and transfer statistics for one YPIR query (one tier of one note).
///
/// Filled in by `PirClient::ypir_query`. Durations are in milliseconds. The
/// `server_*` fields come from optional `x-pir-*` response headers and are
/// `None` when the header is absent or cannot be parsed.
#[derive(Debug, Clone, Copy, Default)]
pub struct TierTiming {
    /// Client-side time to construct the YPIR client and generate the query.
    pub gen_ms: f64,
    /// Size of the serialized query payload that was POSTed.
    pub upload_bytes: usize,
    /// Size of the first query component (`query.0`), as `u64` words * 8.
    pub upload_q_bytes: usize,
    /// Size of the second query component (`query.1`), as `u64` words * 8.
    pub upload_pp_bytes: usize,
    /// Size of the response body.
    pub download_bytes: usize,
    /// Time from starting the POST until the full response body was read.
    pub rtt_ms: f64,
    /// Client-side time to decode the response.
    pub decode_ms: f64,
    /// Server request id from the `x-pir-req-id` header.
    pub server_req_id: Option<u64>,
    /// Total server-side time from `x-pir-server-total-ms`.
    pub server_total_ms: Option<f64>,
    /// Server validation time from `x-pir-server-validate-ms`.
    pub server_validate_ms: Option<f64>,
    /// Server decode/copy time from `x-pir-server-decode-copy-ms`.
    pub server_decode_copy_ms: Option<f64>,
    /// Server compute time from `x-pir-server-compute-ms`.
    pub server_compute_ms: Option<f64>,
    /// `rtt_ms - server_total_ms`, clamped at 0 (time not spent in the server).
    pub net_queue_ms: Option<f64>,
    /// Time for the request's `send()` to resolve minus `server_total_ms`,
    /// clamped at 0.
    pub upload_to_server_ms: Option<f64>,
    /// `rtt_ms` minus the time for `send()` to resolve, clamped at 0.
    pub download_from_server_ms: f64,
}

/// Per-note timing: one [`TierTiming`] per privately queried tier plus the
/// end-to-end time for the note.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoteTiming {
    /// Timing for the Tier 1 query.
    pub tier1: TierTiming,
    /// Timing for the Tier 2 query.
    pub tier2: TierTiming,
    /// Wall-clock time for the whole proof fetch of this note.
    pub total_ms: f64,
}
76
/// Async client that fetches [`ImtProofData`] for nullifiers from a tiered PIR
/// server.
///
/// Tier 0 is downloaded in full at connect time and searched locally. Tiers 1
/// and 2 are queried privately with YPIR. Construct with
/// [`PirClient::connect`] or [`PirClient::connect_with_http`].
pub struct PirClient {
    /// Server base URL with any trailing `/` removed.
    server_url: String,
    /// HTTP client used for every request.
    http: reqwest::Client,
    /// Tier 0 data fetched from `/tier0`; searched locally.
    tier0: Tier0Data,
    /// YPIR database shape for Tier 1, fetched from `/params/tier1`.
    tier1_scenario: YpirScenario,
    /// YPIR database shape for Tier 2, fetched from `/params/tier2`.
    tier2_scenario: YpirScenario,
    /// Range count reported by `/root`; bounds the valid leaves per Tier 2 row.
    num_ranges: usize,
    /// Precomputed empty-subtree hashes, used for path levels `PIR_DEPTH..TREE_DEPTH`.
    empty_hashes: [Fp; TREE_DEPTH],
    /// Root reported by `/root`; placed into every returned proof.
    root29: Fp,
}
93
94#[inline]
98fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
99 let row_start = row_idx.saturating_mul(TIER2_LEAVES);
100 num_ranges.saturating_sub(row_start).min(TIER2_LEAVES)
101}
102
103#[inline]
107fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
108 path[offset..offset + siblings.len()].copy_from_slice(siblings);
109}
110
111fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
114 let s1 = tier0
115 .find_subtree(nullifier)
116 .context("nullifier not found in any Tier 0 subtree")?;
117 fill_path(path, PIR_DEPTH - TIER0_LAYERS, &tier0.extract_siblings(s1));
118 Ok(s1)
119}
120
121fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
124 let hasher = PoseidonHasher::new();
125 let tier1 = Tier1Row::from_bytes(tier1_row)?;
126 let s2 = tier1
127 .find_sub_subtree(nullifier)
128 .context("nullifier not found in any Tier 1 sub-subtree")?;
129 fill_path(
130 path,
131 PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS,
132 &tier1.extract_siblings(s2, &hasher),
133 );
134 Ok(s2)
135}
136
137fn process_tier2_and_build(
140 tier2_row: &[u8],
141 t2_row_idx: usize,
142 num_ranges: usize,
143 nullifier: Fp,
144 path: &mut [Fp; TREE_DEPTH],
145 empty_hashes: &[Fp; TREE_DEPTH],
146 root29: Fp,
147) -> Result<ImtProofData> {
148 let hasher = PoseidonHasher::new();
149 let tier2 = Tier2Row::from_bytes(tier2_row)?;
150 let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
151
152 let leaf_local_idx = tier2
153 .find_leaf(nullifier, valid_leaves)
154 .context("nullifier not found in Tier 2 leaf scan")?;
155
156 fill_path(
157 path,
158 0,
159 &tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
160 );
161 fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
163
164 let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
165 let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
166
167 Ok(ImtProofData {
168 root: root29,
169 nf_bounds: [nf_lo, nf_mid, nf_hi],
170 leaf_pos: global_leaf_idx as u32,
171 path: *path,
172 })
173}
174
175impl PirClient {
176 pub async fn connect(server_url: &str) -> Result<Self> {
178 Self::connect_with_http(server_url, reqwest::Client::new()).await
179 }
180
181 pub async fn connect_with_http(server_url: &str, http: reqwest::Client) -> Result<Self> {
187 let base = server_url.trim_end_matches('/');
188
189 let t0 = Instant::now();
191 let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
192 http.get(format!("{base}/tier0")).send(),
193 http.get(format!("{base}/params/tier1")).send(),
194 http.get(format!("{base}/params/tier2")).send(),
195 http.get(format!("{base}/root")).send(),
196 )
197 .map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
198
199 let tier0_bytes = tier0_resp.error_for_status()?.bytes().await?;
200 log::debug!(
201 "Downloaded Tier 0: {} bytes in {:.1}s",
202 tier0_bytes.len(),
203 t0.elapsed().as_secs_f64()
204 );
205 let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
206
207 let tier1_scenario: YpirScenario = tier1_resp
208 .error_for_status()
209 .context("GET /params/tier1 failed")?
210 .json()
211 .await?;
212 let tier2_scenario: YpirScenario = tier2_resp
213 .error_for_status()
214 .context("GET /params/tier2 failed")?
215 .json()
216 .await?;
217
218 let root_info: RootInfo = root_resp
219 .error_for_status()
220 .context("GET /root failed")?
221 .json()
222 .await?;
223 anyhow::ensure!(
224 root_info.pir_depth == PIR_DEPTH,
225 "server pir_depth {} != expected {}",
226 root_info.pir_depth,
227 PIR_DEPTH
228 );
229 let root29_bytes = hex::decode(&root_info.root29)?;
230 anyhow::ensure!(
231 root29_bytes.len() == 32,
232 "root29 hex decoded to {} bytes, expected 32",
233 root29_bytes.len()
234 );
235 let mut root29_arr = [0u8; 32];
236 root29_arr.copy_from_slice(&root29_bytes);
237 let root29 = Option::from(Fp::from_repr(root29_arr))
238 .ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
239
240 let empty_hashes = precompute_empty_hashes();
241
242 Ok(Self {
243 server_url: base.to_string(),
244 http,
245 tier0,
246 tier1_scenario,
247 tier2_scenario,
248 num_ranges: root_info.num_ranges,
249 empty_hashes,
250 root29,
251 })
252 }
253
254 pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
259 let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
260 Ok(proof)
261 }
262
263 pub async fn fetch_proof_with_timing(
266 &self,
267 nullifier: Fp,
268 ) -> Result<(ImtProofData, NoteTiming)> {
269 self.fetch_proof_inner(nullifier).await
270 }
271
272 pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
277 log::debug!(
278 "[PIR] Starting parallel fetch for {} notes...",
279 nullifiers.len()
280 );
281 let wall_start = Instant::now();
282
283 let futures: Vec<_> = nullifiers
284 .iter()
285 .enumerate()
286 .map(|(i, &nf)| async move {
287 let (proof, timing) = self.fetch_proof_inner(nf).await?;
288 Ok::<_, anyhow::Error>((i, proof, timing))
289 })
290 .collect();
291
292 let results_with_timing = futures::future::try_join_all(futures).await?;
293 let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
294
295 print_timing_table(&results_with_timing, wall_ms);
296
297 let proofs = results_with_timing
298 .into_iter()
299 .map(|(_, proof, _)| proof)
300 .collect();
301 Ok(proofs)
302 }
303
304 async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
315 let note_start = Instant::now();
316 let mut path = [Fp::default(); TREE_DEPTH];
317
318 let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
320
321 let tier1_outcome = self
330 .ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
331 .await
332 .and_then(|(row, timing)| {
333 let mut_path = &mut path;
334 let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
335 process_tier1(&row, nullifier, mut_path)
336 }))
337 .unwrap_or_else(|payload| {
338 let msg = payload
339 .downcast_ref::<String>()
340 .map(|s| s.as_str())
341 .or_else(|| payload.downcast_ref::<&str>().copied())
342 .unwrap_or("unknown panic");
343 Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
344 })?;
345 Ok((s1 * TIER1_LEAVES + s2, timing))
346 });
347
348 let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
351
352 let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
361 Some(anyhow::anyhow!(
362 "tier2 row_idx {} >= num_items {}",
363 t2_row_idx,
364 self.tier2_scenario.num_items
365 ))
366 } else {
367 None
368 };
369 let t2_query_idx = if t2_bounds_err.is_some() {
370 0
371 } else {
372 t2_row_idx
373 };
374
375 let tier2_result = self
377 .ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
378 .await;
379
380 let (t2_row_idx, tier1_timing) = tier1_outcome?;
382 if let Some(e) = t2_bounds_err {
383 return Err(e);
384 }
385 let (tier2_row, tier2_timing) = tier2_result?;
386
387 let proof = process_tier2_and_build(
388 &tier2_row,
389 t2_row_idx,
390 self.num_ranges,
391 nullifier,
392 &mut path,
393 &self.empty_hashes,
394 self.root29,
395 )?;
396
397 let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
398 Ok((
399 proof,
400 NoteTiming {
401 tier1: tier1_timing,
402 tier2: tier2_timing,
403 total_ms,
404 },
405 ))
406 }
407
408 async fn ypir_query(
414 &self,
415 scenario: &YpirScenario,
416 tier_name: &str,
417 row_idx: usize,
418 expected_row_bytes: usize,
419 ) -> Result<(Vec<u8>, TierTiming)> {
420 anyhow::ensure!(
421 row_idx < scenario.num_items,
422 "{} row_idx {} >= num_items {}",
423 tier_name,
424 row_idx,
425 scenario.num_items
426 );
427 let t0 = Instant::now();
428 let ypir_client = YPIRClient::from_db_sz(
429 scenario.num_items as u64,
430 scenario.item_size_bits as u64,
431 true,
432 );
433
434 let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
436 let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
437
438 let upload_q_bytes = query.0.as_slice().len() * std::mem::size_of::<u64>();
442 let upload_pp_bytes = query.1.as_slice().len() * std::mem::size_of::<u64>();
443 let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
444 let upload_bytes = payload.len();
445
446 let t1 = Instant::now();
448 let url = format!("{}/{}/query", self.server_url, tier_name);
449 let send_result = self.http.post(&url).body(payload).send().await;
450 let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
451 let resp = match send_result {
452 Ok(r) => r,
453 Err(e) => {
454 log::warn!("YPIR {} send error: {:?}", tier_name, e);
455 return Err(e.into());
456 }
457 };
458 let server_req_id = parse_header_u64(resp.headers(), "x-pir-req-id");
459 let server_total_ms = parse_header_f64(resp.headers(), "x-pir-server-total-ms");
460 let server_validate_ms = parse_header_f64(resp.headers(), "x-pir-server-validate-ms");
461 let server_decode_copy_ms = parse_header_f64(resp.headers(), "x-pir-server-decode-copy-ms");
462 let server_compute_ms = parse_header_f64(resp.headers(), "x-pir-server-compute-ms");
463 let status = resp.status();
464 let response_bytes = resp.bytes().await?;
465 if !status.is_success() {
466 anyhow::bail!(
467 "{} query failed: HTTP {} body={}",
468 tier_name,
469 status,
470 String::from_utf8_lossy(&response_bytes)
471 );
472 }
473 let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
474 let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
475 let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
476 let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
477
478 let t2 = Instant::now();
484 let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
485 ypir_client.decode_response_simplepir(seed, &response_bytes)
486 }))
487 .map_err(|panic_payload| {
488 let msg = panic_payload
489 .downcast_ref::<String>()
490 .map(|s| s.as_str())
491 .or_else(|| panic_payload.downcast_ref::<&str>().copied())
492 .unwrap_or("unknown panic");
493 anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
494 })?;
495 let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
496
497 anyhow::ensure!(
498 decoded.len() >= expected_row_bytes,
499 "{} decoded response too short: {} bytes, expected >= {}",
500 tier_name,
501 decoded.len(),
502 expected_row_bytes
503 );
504 Ok((
505 decoded[..expected_row_bytes].to_vec(),
506 TierTiming {
507 gen_ms,
508 upload_bytes,
509 upload_q_bytes,
510 upload_pp_bytes,
511 download_bytes: response_bytes.len(),
512 rtt_ms,
513 decode_ms,
514 server_req_id,
515 server_total_ms,
516 server_validate_ms,
517 server_decode_copy_ms,
518 server_compute_ms,
519 net_queue_ms,
520 upload_to_server_ms,
521 download_from_server_ms,
522 },
523 ))
524 }
525}
526
/// Formats a millisecond duration as a 7-character table cell.
///
/// Values of one second or more render as seconds with one decimal
/// (`"  1.5s "`). Smaller values render as whole milliseconds (`"   42ms"`).
fn fmt_time(ms: f64) -> String {
    match ms {
        secs_range if secs_range >= 1000.0 => format!("{:>5.1}s ", secs_range / 1000.0),
        millis => format!("{:>5.0}ms", millis),
    }
}
534
535fn fmt_opt_time(ms: Option<f64>) -> String {
536 match ms {
537 Some(v) => fmt_time(v),
538 None => " n/a ".to_string(),
539 }
540}
541
542fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
544 if !log::log_enabled!(log::Level::Debug) {
545 return;
546 }
547
548 log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
549 log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+ │ T1 decode│ T2 keygen│ T2 upload+ │ T2 decode│ Total │");
550 log::debug!("[PIR] │ │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │ │");
551 log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
552 for &(i, _, ref t) in results {
553 log::debug!(
554 "[PIR] │ {i:>2} │ {:>6} │ {:>7} │ {:>6} │ {:>6} │ {:>7} │ {:>6} │{} │",
555 fmt_time(t.tier1.gen_ms),
556 fmt_time(t.tier1.rtt_ms),
557 fmt_time(t.tier1.decode_ms),
558 fmt_time(t.tier2.gen_ms),
559 fmt_time(t.tier2.rtt_ms),
560 fmt_time(t.tier2.decode_ms),
561 fmt_time(t.total_ms),
562 );
563 }
564 log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");
565 log::debug!(
566 "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB | Wall clock: {:.2}s",
567 results
568 .first()
569 .map(|(_, _, t)| t.tier1.upload_bytes)
570 .unwrap_or(0) as f64
571 / 1024.0,
572 results
573 .first()
574 .map(|(_, _, t)| t.tier2.upload_bytes)
575 .unwrap_or(0) as f64
576 / (1024.0 * 1024.0),
577 wall_ms / 1000.0,
578 );
579
580 for &(i, _, ref t) in results {
581 log::trace!(
582 "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
583 t.tier1.upload_bytes as f64 / 1024.0,
584 t.tier1.download_bytes as f64 / 1024.0,
585 t.tier2.upload_bytes as f64 / (1024.0 * 1024.0),
586 t.tier2.download_bytes as f64 / 1024.0,
587 );
588 log::trace!(
589 "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
590 fmt_opt_time(t.tier1.server_total_ms),
591 fmt_opt_time(t.tier1.net_queue_ms),
592 fmt_opt_time(t.tier2.server_total_ms),
593 fmt_opt_time(t.tier2.net_queue_ms),
594 );
595 log::trace!(
596 "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
597 fmt_opt_time(t.tier1.upload_to_server_ms),
598 fmt_opt_time(t.tier1.server_total_ms),
599 fmt_time(t.tier1.download_from_server_ms),
600 fmt_opt_time(t.tier2.upload_to_server_ms),
601 fmt_opt_time(t.tier2.server_total_ms),
602 fmt_time(t.tier2.download_from_server_ms),
603 );
604 log::trace!(
605 "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
606 fmt_opt_time(t.tier1.server_validate_ms),
607 fmt_opt_time(t.tier1.server_decode_copy_ms),
608 fmt_opt_time(t.tier1.server_compute_ms),
609 fmt_opt_time(t.tier2.server_validate_ms),
610 fmt_opt_time(t.tier2.server_decode_copy_ms),
611 fmt_opt_time(t.tier2.server_compute_ms),
612 );
613 log::trace!(
614 "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
615 t.tier1.server_req_id,
616 t.tier2.server_req_id
617 );
618 }
619}
620
621fn parse_header_f64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<f64> {
623 headers
624 .get(name)
625 .and_then(|v| v.to_str().ok())
626 .and_then(|s| s.parse::<f64>().ok())
627}
628
629fn parse_header_u64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<u64> {
631 headers
632 .get(name)
633 .and_then(|v| v.to_str().ok())
634 .and_then(|s| s.parse::<u64>().ok())
635}
636
/// Synchronous wrapper around [`PirClient`] that owns a dedicated Tokio runtime.
///
/// Each method drives the corresponding async call with `Runtime::block_on`.
pub struct PirClientBlocking {
    // Fields are dropped in declaration order, so `inner` is dropped before `rt`.
    /// The async client that every call is forwarded to.
    inner: PirClient,
    /// Runtime used to drive `inner`'s futures to completion.
    rt: tokio::runtime::Runtime,
}
647
648impl PirClientBlocking {
649 pub fn connect(server_url: &str) -> Result<Self> {
651 let rt = tokio::runtime::Runtime::new()?;
652 let inner = rt.block_on(PirClient::connect(server_url))?;
653 Ok(Self { inner, rt })
654 }
655
656 pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
658 self.rt.block_on(self.inner.fetch_proof(nullifier))
659 }
660
661 pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
663 self.rt.block_on(self.inner.fetch_proofs(nullifiers))
664 }
665
666 pub fn root29(&self) -> Fp {
668 self.inner.root29
669 }
670}
671
672pub fn fetch_proof_local(
679 tier0_data: &[u8],
680 tier1_data: &[u8],
681 tier2_data: &[u8],
682 num_ranges: usize,
683 nullifier: Fp,
684 empty_hashes: &[Fp; TREE_DEPTH],
685 root29: Fp,
686) -> Result<ImtProofData> {
687 let mut path = [Fp::default(); TREE_DEPTH];
688 let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
689
690 let s1 = process_tier0(&tier0, nullifier, &mut path)?;
691
692 let t1_offset = s1 * TIER1_ROW_BYTES;
694 anyhow::ensure!(
695 t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
696 "tier1 data too short: need {} bytes at offset {}, have {}",
697 TIER1_ROW_BYTES,
698 t1_offset,
699 tier1_data.len()
700 );
701 let s2 = process_tier1(
702 &tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
703 nullifier,
704 &mut path,
705 )?;
706
707 let t2_row_idx = s1 * TIER1_LEAVES + s2;
709 let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
710 anyhow::ensure!(
711 t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
712 "tier2 data too short: need {} bytes at offset {}, have {}",
713 TIER2_ROW_BYTES,
714 t2_offset,
715 tier2_data.len()
716 );
717
718 process_tier2_and_build(
719 &tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
720 t2_row_idx,
721 num_ranges,
722 nullifier,
723 &mut path,
724 empty_hashes,
725 root29,
726 )
727}
728
#[cfg(test)]
mod tests {
    // Unit tests run the proof-assembly helpers (`process_tier*`,
    // `valid_leaves_for_row`, `fetch_proof_local`) on an in-memory tree.
    // Integration tests use a `wiremock` server to check that Tier 2 queries
    // are still sent when Tier 1 responses cannot be decoded.
    use super::*;
    use ff::Field;
    use pasta_curves::Fp;
    use pir_export::build_ranges_with_sentinels;

    /// In-memory export of all three tiers plus the values needed to check proofs.
    struct TestFixture {
        tier0_data: Vec<u8>,
        tier1_data: Vec<u8>,
        tier2_data: Vec<u8>,
        ranges: Vec<[Fp; 3]>,
        empty_hashes: [Fp; TREE_DEPTH],
        root29: Fp,
    }

    impl TestFixture {
        /// Builds ranges (via `build_ranges_with_sentinels`) and the PIR tree
        /// for `raw_nfs`, then exports Tier 0/1/2 into byte buffers.
        fn build(raw_nfs: &[Fp]) -> Self {
            let ranges = build_ranges_with_sentinels(raw_nfs);
            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let mut tier1_data = Vec::new();
            pir_export::tier1::export(
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
                &mut tier1_data,
            )
            .unwrap();
            let mut tier2_data = Vec::new();
            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();

            Self {
                tier0_data,
                tier1_data,
                tier2_data,
                ranges,
                empty_hashes: tree.empty_hashes,
                root29: tree.root29,
            }
        }
    }

    /// A value one above each of the first 20 ranges' low bounds must yield a
    /// proof that verifies.
    #[test]
    fn fetch_proof_local_verifies_for_known_ranges() {
        let mut rng = rand::thread_rng();
        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
        let fix = TestFixture::build(&raw_nfs);

        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
            let value = nf_lo + Fp::one();
            let proof = fetch_proof_local(
                &fix.tier0_data,
                &fix.tier1_data,
                &fix.tier2_data,
                fix.ranges.len(),
                value,
                &fix.empty_hashes,
                fix.root29,
            )
            .expect("fetch_proof_local should succeed for a value in range");
            assert!(
                proof.verify(value),
                "proof should verify for value {:?}",
                value,
            );
        }
    }

    /// The proof carries the fixture's root and a full-depth path.
    #[test]
    fn fetch_proof_local_correct_root_and_path_length() {
        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let value = fix.ranges[0][0] + Fp::one();
        let proof = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data,
            fix.ranges.len(),
            value,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert_eq!(proof.root, fix.root29);
        assert_eq!(proof.path.len(), TREE_DEPTH);
    }

    /// Tier 0 writes only the `PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH` slots of the path.
    #[test]
    fn process_tier0_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        assert!(s1 < pir_types::TIER1_ROWS);

        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
        assert!(
            tier0_region.iter().any(|&v| v != Fp::default()),
            "tier0 should write at least one non-zero sibling"
        );

        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            below.iter().all(|&v| v == Fp::default()),
            "path below tier0 region should be untouched"
        );
    }

    /// A value unrelated to the inserted nullifiers still maps to some subtree
    /// (the test expects `find_subtree` to succeed for it).
    #[test]
    fn process_tier0_handles_arbitrary_field_element() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let bogus = Fp::from(u64::MAX);
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
    }

    /// Tier 1 writes siblings into the slots just below the Tier 0 region.
    #[test]
    fn process_tier1_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        let t1_offset = s1 * TIER1_ROW_BYTES;
        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();

        assert!(s2 < TIER1_LEAVES);

        let tier1_region = &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            tier1_region.iter().any(|&v| v != Fp::default()),
            "tier1 should write at least one non-zero sibling"
        );
    }

    /// Running all three tiers by hand produces a proof that verifies.
    #[test]
    fn process_tier2_and_build_produces_verifiable_proof() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0] + Fp::one();
        let mut path = [Fp::default(); TREE_DEPTH];

        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let s2 = process_tier1(
            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
            value,
            &mut path,
        )
        .unwrap();

        let t2_row_idx = s1 * TIER1_LEAVES + s2;
        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
        let proof = process_tier2_and_build(
            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
            t2_row_idx,
            fix.ranges.len(),
            value,
            &mut path,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert!(proof.verify(value));
        assert_eq!(proof.root, fix.root29);
    }

    /// Boundary cases: full rows, a one-leaf tail row, and rows past the end.
    #[test]
    fn valid_leaves_for_row_basic() {
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
        assert_eq!(valid_leaves_for_row(0, 0), 0);
        assert_eq!(valid_leaves_for_row(1, 0), 1);
        assert_eq!(valid_leaves_for_row(1, 1), 0);
    }

    /// A Tier 1 buffer shorter than one row must be rejected, not sliced out of bounds.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier1() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
            &fix.tier2_data,
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// A Tier 2 buffer shorter than one row must be rejected, not sliced out of bounds.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier2() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// Error-oracle mitigation for a single note: when the Tier 1 response is
    /// garbage (0xDE bytes), the client must still send exactly one Tier 2
    /// query before reporting the error.
    #[tokio::test]
    async fn tier2_query_sent_despite_tier1_decode_failure() {
        use ff::PrimeField as _;
        use pir_types::{TIER1_ITEM_BITS, TIER1_YPIR_ROWS, TIER2_ITEM_BITS};
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        // Real Tier 0 data so the local Tier 0 lookup succeeds.
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let tier0_data =
            pir_export::tier0::export(&tree.root25, &tree.levels, &tree.ranges, &tree.empty_hashes);

        let root_info = pir_types::RootInfo {
            root29: hex::encode(tree.root29.to_repr()),
            root25: hex::encode(tree.root25.to_repr()),
            num_ranges: tree.ranges.len(),
            pir_depth: PIR_DEPTH,
            height: None,
        };

        // NOTE(review): tier2 below uses TIER1_YPIR_ROWS for num_items;
        // presumably to keep test query generation small. Confirm intent.
        let tier1_scenario = YpirScenario {
            num_items: TIER1_YPIR_ROWS,
            item_size_bits: TIER1_ITEM_BITS,
        };
        let tier2_scenario = YpirScenario {
            num_items: TIER1_YPIR_ROWS,
            item_size_bits: TIER2_ITEM_BITS,
        };

        let server = MockServer::start().await;

        // Bootstrap endpoints used by `PirClient::connect`.
        Mock::given(method("GET"))
            .and(path("/tier0"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(tier0_data))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier1_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier2_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/root"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&root_info))
            .mount(&server)
            .await;

        // Query endpoints return filler bytes that are not valid YPIR responses.
        Mock::given(method("POST"))
            .and(path("/tier1/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xDE; 65536]))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/tier2/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xAD; 65536]))
            .mount(&server)
            .await;

        let client = PirClient::connect(&server.uri()).await.unwrap();
        let nullifier = tree.ranges[0][0];
        let result = client.fetch_proof(nullifier).await;

        assert!(
            result.is_err(),
            "fetch_proof should fail with corrupted tier1 response"
        );

        // Count the queries the server actually received.
        let received = server.received_requests().await.unwrap();
        let tier1_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier1/query")
            .count();
        let tier2_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier2/query")
            .count();

        assert_eq!(tier1_hits, 1, "tier1 query should have been sent");
        assert_eq!(
            tier2_hits, 1,
            "tier2 query must still be sent when tier1 decode fails \
             (error-oracle mitigation)"
        );
    }

    /// Batch-level error-oracle mitigation: with K notes whose Tier 1
    /// responses are all garbage, the server must still receive K Tier 2
    /// queries. The 50 ms delay on `/tier2/query` keeps Tier 2 requests in
    /// flight while others are still being issued.
    #[tokio::test]
    async fn batched_tier2_queries_all_sent_despite_tier1_decode_failure() {
        use ff::PrimeField as _;
        use pir_types::{TIER1_ITEM_BITS, TIER1_YPIR_ROWS, TIER2_ITEM_BITS};
        use std::time::Duration;
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        const K: usize = 5;

        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let tier0_data =
            pir_export::tier0::export(&tree.root25, &tree.levels, &tree.ranges, &tree.empty_hashes);

        let root_info = pir_types::RootInfo {
            root29: hex::encode(tree.root29.to_repr()),
            root25: hex::encode(tree.root25.to_repr()),
            num_ranges: tree.ranges.len(),
            pir_depth: PIR_DEPTH,
            height: None,
        };

        let tier1_scenario = YpirScenario {
            num_items: TIER1_YPIR_ROWS,
            item_size_bits: TIER1_ITEM_BITS,
        };
        let tier2_scenario = YpirScenario {
            num_items: TIER1_YPIR_ROWS,
            item_size_bits: TIER2_ITEM_BITS,
        };

        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/tier0"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(tier0_data))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier1_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier2_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/root"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&root_info))
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/tier1/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xDE; 65536]))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/tier2/query"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_bytes(vec![0xAD; 65536])
                    .set_delay(Duration::from_millis(50)),
            )
            .mount(&server)
            .await;

        let client = PirClient::connect(&server.uri()).await.unwrap();

        let nullifiers: Vec<Fp> = tree
            .ranges
            .iter()
            .take(K)
            .map(|r| r[0] + Fp::one())
            .collect();
        assert_eq!(nullifiers.len(), K);

        let result = client.fetch_proofs(&nullifiers).await;
        assert!(
            result.is_err(),
            "fetch_proofs should fail with corrupted tier1 responses"
        );

        let received = server.received_requests().await.unwrap();
        let tier1_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier1/query")
            .count();
        let tier2_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier2/query")
            .count();

        assert_eq!(
            tier1_hits, K,
            "all K tier1 queries should have been sent (got {})",
            tier1_hits,
        );
        assert_eq!(
            tier2_hits, K,
            "all K tier2 queries must still be sent when their tier1 \
             decodes fail (batch-level error-oracle mitigation)",
        );
    }
}