1use std::time::Instant;
8
9use anyhow::{Context, Result};
10use ff::PrimeField as _;
11use imt_tree::hasher::PoseidonHasher;
12use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
13use pasta_curves::Fp;
14pub use imt_tree::ImtProofData;
17
18use pir_types::tier0::Tier0Data;
19use pir_types::tier1::Tier1Row;
20use pir_types::tier2::Tier2Row;
21use pir_types::{
22 serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
23 TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
24};
25
26use ypir::client::YPIRClient;
27
/// Size and latency measurements for one YPIR tier query.
///
/// Filled in by `PirClient::ypir_query` and rendered by `print_timing_table`.
/// Every `*_ms` field is in milliseconds. The `server_*` fields mirror optional
/// `x-pir-*` response headers and are `None` when a header is missing or does not
/// parse.
struct TierTiming {
    /// Client time spent building the YPIR client and generating the query.
    gen_ms: f64,
    /// Length of the serialized query body that was POSTed.
    upload_bytes: usize,
    /// Length of the raw response body.
    download_bytes: usize,
    /// Time from starting the POST until the response body was fully read.
    rtt_ms: f64,
    /// Client time spent decoding the response.
    decode_ms: f64,
    /// Value of the `x-pir-req-id` response header.
    server_req_id: Option<u64>,
    /// Value of the `x-pir-server-total-ms` response header.
    server_total_ms: Option<f64>,
    /// Value of the `x-pir-server-validate-ms` response header.
    server_validate_ms: Option<f64>,
    /// Value of the `x-pir-server-decode-copy-ms` response header.
    server_decode_copy_ms: Option<f64>,
    /// Value of the `x-pir-server-compute-ms` response header.
    server_compute_ms: Option<f64>,
    /// `rtt_ms - server_total_ms`, clamped at zero.
    net_queue_ms: Option<f64>,
    /// Time until `send()` resolved minus `server_total_ms`, clamped at zero.
    upload_to_server_ms: Option<f64>,
    /// `rtt_ms` minus the time until `send()` resolved, clamped at zero.
    download_from_server_ms: f64,
}
60
61struct NoteTiming {
63 tier1: TierTiming,
64 tier2: TierTiming,
65 total_ms: f64,
67}
68
69pub struct PirClient {
76 server_url: String,
77 http: reqwest::Client,
78 tier0: Tier0Data,
79 tier1_scenario: YpirScenario,
80 tier2_scenario: YpirScenario,
81 num_ranges: usize,
82 empty_hashes: [Fp; TREE_DEPTH],
83 root29: Fp,
84}
85
86#[inline]
90fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
91 let row_start = row_idx.saturating_mul(TIER2_LEAVES);
92 num_ranges.saturating_sub(row_start).min(TIER2_LEAVES)
93}
94
95#[inline]
99fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
100 path[offset..offset + siblings.len()].copy_from_slice(siblings);
101}
102
103fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
106 let s1 = tier0
107 .find_subtree(nullifier)
108 .context("nullifier not found in any Tier 0 subtree")?;
109 fill_path(path, PIR_DEPTH - TIER0_LAYERS, &tier0.extract_siblings(s1));
110 Ok(s1)
111}
112
113fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
116 let hasher = PoseidonHasher::new();
117 let tier1 = Tier1Row::from_bytes(tier1_row)?;
118 let s2 = tier1
119 .find_sub_subtree(nullifier)
120 .context("nullifier not found in any Tier 1 sub-subtree")?;
121 fill_path(
122 path,
123 PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS,
124 &tier1.extract_siblings(s2, &hasher),
125 );
126 Ok(s2)
127}
128
129fn process_tier2_and_build(
132 tier2_row: &[u8],
133 t2_row_idx: usize,
134 num_ranges: usize,
135 nullifier: Fp,
136 path: &mut [Fp; TREE_DEPTH],
137 empty_hashes: &[Fp; TREE_DEPTH],
138 root29: Fp,
139) -> Result<ImtProofData> {
140 let hasher = PoseidonHasher::new();
141 let tier2 = Tier2Row::from_bytes(tier2_row)?;
142 let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
143
144 let leaf_local_idx = tier2
145 .find_leaf(nullifier, valid_leaves)
146 .context("nullifier not found in Tier 2 leaf scan")?;
147
148 fill_path(
149 path,
150 0,
151 &tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
152 );
153 fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
155
156 let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
157 let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
158
159 Ok(ImtProofData {
160 root: root29,
161 nf_bounds: [nf_lo, nf_mid, nf_hi],
162 leaf_pos: global_leaf_idx as u32,
163 path: *path,
164 })
165}
166
167impl PirClient {
168 pub async fn connect(server_url: &str) -> Result<Self> {
170 let http = reqwest::Client::new();
171 let base = server_url.trim_end_matches('/');
172
173 let t0 = Instant::now();
175 let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
176 http.get(format!("{base}/tier0")).send(),
177 http.get(format!("{base}/params/tier1")).send(),
178 http.get(format!("{base}/params/tier2")).send(),
179 http.get(format!("{base}/root")).send(),
180 )
181 .map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
182
183 let tier0_bytes = tier0_resp.error_for_status()?.bytes().await?;
184 log::debug!(
185 "Downloaded Tier 0: {} bytes in {:.1}s",
186 tier0_bytes.len(),
187 t0.elapsed().as_secs_f64()
188 );
189 let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
190
191 let tier1_scenario: YpirScenario = tier1_resp
192 .error_for_status()
193 .context("GET /params/tier1 failed")?
194 .json()
195 .await?;
196 let tier2_scenario: YpirScenario = tier2_resp
197 .error_for_status()
198 .context("GET /params/tier2 failed")?
199 .json()
200 .await?;
201
202 let root_info: RootInfo = root_resp
203 .error_for_status()
204 .context("GET /root failed")?
205 .json()
206 .await?;
207 anyhow::ensure!(
208 root_info.pir_depth == PIR_DEPTH,
209 "server pir_depth {} != expected {}",
210 root_info.pir_depth,
211 PIR_DEPTH
212 );
213 let root29_bytes = hex::decode(&root_info.root29)?;
214 anyhow::ensure!(
215 root29_bytes.len() == 32,
216 "root29 hex decoded to {} bytes, expected 32",
217 root29_bytes.len()
218 );
219 let mut root29_arr = [0u8; 32];
220 root29_arr.copy_from_slice(&root29_bytes);
221 let root29 = Option::from(Fp::from_repr(root29_arr))
222 .ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
223
224 let empty_hashes = precompute_empty_hashes();
225
226 Ok(Self {
227 server_url: base.to_string(),
228 http,
229 tier0,
230 tier1_scenario,
231 tier2_scenario,
232 num_ranges: root_info.num_ranges,
233 empty_hashes,
234 root29,
235 })
236 }
237
238 pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
243 let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
244 Ok(proof)
245 }
246
247 pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
252 log::debug!(
253 "[PIR] Starting parallel fetch for {} notes...",
254 nullifiers.len()
255 );
256 let wall_start = Instant::now();
257
258 let futures: Vec<_> = nullifiers
259 .iter()
260 .enumerate()
261 .map(|(i, &nf)| async move {
262 let (proof, timing) = self.fetch_proof_inner(nf).await?;
263 Ok::<_, anyhow::Error>((i, proof, timing))
264 })
265 .collect();
266
267 let results_with_timing = futures::future::try_join_all(futures).await?;
268 let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
269
270 print_timing_table(&results_with_timing, wall_ms);
271
272 let proofs = results_with_timing
273 .into_iter()
274 .map(|(_, proof, _)| proof)
275 .collect();
276 Ok(proofs)
277 }
278
279 async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
290 let note_start = Instant::now();
291 let mut path = [Fp::default(); TREE_DEPTH];
292
293 let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
295
296 let tier1_outcome = self
305 .ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
306 .await
307 .and_then(|(row, timing)| {
308 let mut_path = &mut path;
309 let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
310 process_tier1(&row, nullifier, mut_path)
311 }))
312 .unwrap_or_else(|payload| {
313 let msg = payload
314 .downcast_ref::<String>()
315 .map(|s| s.as_str())
316 .or_else(|| payload.downcast_ref::<&str>().copied())
317 .unwrap_or("unknown panic");
318 Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
319 })?;
320 Ok((s1 * TIER1_LEAVES + s2, timing))
321 });
322
323 let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
326
327 let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
336 Some(anyhow::anyhow!(
337 "tier2 row_idx {} >= num_items {}",
338 t2_row_idx,
339 self.tier2_scenario.num_items
340 ))
341 } else {
342 None
343 };
344 let t2_query_idx = if t2_bounds_err.is_some() {
345 0
346 } else {
347 t2_row_idx
348 };
349
350 let tier2_result = self
352 .ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
353 .await;
354
355 let (t2_row_idx, tier1_timing) = tier1_outcome?;
357 if let Some(e) = t2_bounds_err {
358 return Err(e);
359 }
360 let (tier2_row, tier2_timing) = tier2_result?;
361
362 let proof = process_tier2_and_build(
363 &tier2_row,
364 t2_row_idx,
365 self.num_ranges,
366 nullifier,
367 &mut path,
368 &self.empty_hashes,
369 self.root29,
370 )?;
371
372 let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
373 Ok((
374 proof,
375 NoteTiming {
376 tier1: tier1_timing,
377 tier2: tier2_timing,
378 total_ms,
379 },
380 ))
381 }
382
383 async fn ypir_query(
389 &self,
390 scenario: &YpirScenario,
391 tier_name: &str,
392 row_idx: usize,
393 expected_row_bytes: usize,
394 ) -> Result<(Vec<u8>, TierTiming)> {
395 anyhow::ensure!(
396 row_idx < scenario.num_items,
397 "{} row_idx {} >= num_items {}",
398 tier_name,
399 row_idx,
400 scenario.num_items
401 );
402 let t0 = Instant::now();
403 let ypir_client = YPIRClient::from_db_sz(
404 scenario.num_items as u64,
405 scenario.item_size_bits as u64,
406 true,
407 );
408
409 let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
411 let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
412
413 let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
415 let upload_bytes = payload.len();
416
417 let t1 = Instant::now();
419 let url = format!("{}/{}/query", self.server_url, tier_name);
420 let send_result = self.http.post(&url).body(payload).send().await;
421 let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
422 let resp = match send_result {
423 Ok(r) => r,
424 Err(e) => {
425 log::warn!("YPIR {} send error: {:?}", tier_name, e);
426 return Err(e.into());
427 }
428 };
429 let server_req_id = parse_header_u64(resp.headers(), "x-pir-req-id");
430 let server_total_ms = parse_header_f64(resp.headers(), "x-pir-server-total-ms");
431 let server_validate_ms = parse_header_f64(resp.headers(), "x-pir-server-validate-ms");
432 let server_decode_copy_ms = parse_header_f64(resp.headers(), "x-pir-server-decode-copy-ms");
433 let server_compute_ms = parse_header_f64(resp.headers(), "x-pir-server-compute-ms");
434 let status = resp.status();
435 let response_bytes = resp.bytes().await?;
436 if !status.is_success() {
437 anyhow::bail!(
438 "{} query failed: HTTP {} body={}",
439 tier_name,
440 status,
441 String::from_utf8_lossy(&response_bytes)
442 );
443 }
444 let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
445 let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
446 let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
447 let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
448
449 let t2 = Instant::now();
455 let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
456 ypir_client.decode_response_simplepir(seed, &response_bytes)
457 }))
458 .map_err(|panic_payload| {
459 let msg = panic_payload
460 .downcast_ref::<String>()
461 .map(|s| s.as_str())
462 .or_else(|| panic_payload.downcast_ref::<&str>().copied())
463 .unwrap_or("unknown panic");
464 anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
465 })?;
466 let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
467
468 anyhow::ensure!(
469 decoded.len() >= expected_row_bytes,
470 "{} decoded response too short: {} bytes, expected >= {}",
471 tier_name,
472 decoded.len(),
473 expected_row_bytes
474 );
475 Ok((
476 decoded[..expected_row_bytes].to_vec(),
477 TierTiming {
478 gen_ms,
479 upload_bytes,
480 download_bytes: response_bytes.len(),
481 rtt_ms,
482 decode_ms,
483 server_req_id,
484 server_total_ms,
485 server_validate_ms,
486 server_decode_copy_ms,
487 server_compute_ms,
488 net_queue_ms,
489 upload_to_server_ms,
490 download_from_server_ms,
491 },
492 ))
493 }
494}
495
/// Renders a millisecond duration as a fixed-width (7-character) table cell.
///
/// Values of 1000 ms or more are shown in seconds with one decimal place.
/// Smaller values are shown as whole milliseconds.
fn fmt_time(ms: f64) -> String {
    match ms {
        m if m >= 1000.0 => format!("{:>5.1}s ", m / 1000.0),
        m => format!("{:>5.0}ms", m),
    }
}
503
504fn fmt_opt_time(ms: Option<f64>) -> String {
505 match ms {
506 Some(v) => fmt_time(v),
507 None => " n/a ".to_string(),
508 }
509}
510
511fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
513 if !log::log_enabled!(log::Level::Debug) {
514 return;
515 }
516
517 log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
518 log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+ │ T1 decode│ T2 keygen│ T2 upload+ │ T2 decode│ Total │");
519 log::debug!("[PIR] │ │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │ │");
520 log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
521 for &(i, _, ref t) in results {
522 log::debug!(
523 "[PIR] │ {i:>2} │ {:>6} │ {:>7} │ {:>6} │ {:>6} │ {:>7} │ {:>6} │{} │",
524 fmt_time(t.tier1.gen_ms),
525 fmt_time(t.tier1.rtt_ms),
526 fmt_time(t.tier1.decode_ms),
527 fmt_time(t.tier2.gen_ms),
528 fmt_time(t.tier2.rtt_ms),
529 fmt_time(t.tier2.decode_ms),
530 fmt_time(t.total_ms),
531 );
532 }
533 log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");
534 log::debug!(
535 "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB | Wall clock: {:.2}s",
536 results
537 .first()
538 .map(|(_, _, t)| t.tier1.upload_bytes)
539 .unwrap_or(0) as f64
540 / 1024.0,
541 results
542 .first()
543 .map(|(_, _, t)| t.tier2.upload_bytes)
544 .unwrap_or(0) as f64
545 / (1024.0 * 1024.0),
546 wall_ms / 1000.0,
547 );
548
549 for &(i, _, ref t) in results {
550 log::trace!(
551 "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
552 t.tier1.upload_bytes as f64 / 1024.0,
553 t.tier1.download_bytes as f64 / 1024.0,
554 t.tier2.upload_bytes as f64 / (1024.0 * 1024.0),
555 t.tier2.download_bytes as f64 / 1024.0,
556 );
557 log::trace!(
558 "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
559 fmt_opt_time(t.tier1.server_total_ms),
560 fmt_opt_time(t.tier1.net_queue_ms),
561 fmt_opt_time(t.tier2.server_total_ms),
562 fmt_opt_time(t.tier2.net_queue_ms),
563 );
564 log::trace!(
565 "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
566 fmt_opt_time(t.tier1.upload_to_server_ms),
567 fmt_opt_time(t.tier1.server_total_ms),
568 fmt_time(t.tier1.download_from_server_ms),
569 fmt_opt_time(t.tier2.upload_to_server_ms),
570 fmt_opt_time(t.tier2.server_total_ms),
571 fmt_time(t.tier2.download_from_server_ms),
572 );
573 log::trace!(
574 "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
575 fmt_opt_time(t.tier1.server_validate_ms),
576 fmt_opt_time(t.tier1.server_decode_copy_ms),
577 fmt_opt_time(t.tier1.server_compute_ms),
578 fmt_opt_time(t.tier2.server_validate_ms),
579 fmt_opt_time(t.tier2.server_decode_copy_ms),
580 fmt_opt_time(t.tier2.server_compute_ms),
581 );
582 log::trace!(
583 "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
584 t.tier1.server_req_id,
585 t.tier2.server_req_id
586 );
587 }
588}
589
590fn parse_header_f64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<f64> {
592 headers
593 .get(name)
594 .and_then(|v| v.to_str().ok())
595 .and_then(|s| s.parse::<f64>().ok())
596}
597
598fn parse_header_u64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<u64> {
600 headers
601 .get(name)
602 .and_then(|v| v.to_str().ok())
603 .and_then(|s| s.parse::<u64>().ok())
604}
605
606pub struct PirClientBlocking {
613 inner: PirClient,
614 rt: tokio::runtime::Runtime,
615}
616
617impl PirClientBlocking {
618 pub fn connect(server_url: &str) -> Result<Self> {
620 let rt = tokio::runtime::Runtime::new()?;
621 let inner = rt.block_on(PirClient::connect(server_url))?;
622 Ok(Self { inner, rt })
623 }
624
625 pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
627 self.rt.block_on(self.inner.fetch_proof(nullifier))
628 }
629
630 pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
632 self.rt.block_on(self.inner.fetch_proofs(nullifiers))
633 }
634
635 pub fn root29(&self) -> Fp {
637 self.inner.root29
638 }
639}
640
641pub fn fetch_proof_local(
648 tier0_data: &[u8],
649 tier1_data: &[u8],
650 tier2_data: &[u8],
651 num_ranges: usize,
652 nullifier: Fp,
653 empty_hashes: &[Fp; TREE_DEPTH],
654 root29: Fp,
655) -> Result<ImtProofData> {
656 let mut path = [Fp::default(); TREE_DEPTH];
657 let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
658
659 let s1 = process_tier0(&tier0, nullifier, &mut path)?;
660
661 let t1_offset = s1 * TIER1_ROW_BYTES;
663 anyhow::ensure!(
664 t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
665 "tier1 data too short: need {} bytes at offset {}, have {}",
666 TIER1_ROW_BYTES,
667 t1_offset,
668 tier1_data.len()
669 );
670 let s2 = process_tier1(
671 &tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
672 nullifier,
673 &mut path,
674 )?;
675
676 let t2_row_idx = s1 * TIER1_LEAVES + s2;
678 let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
679 anyhow::ensure!(
680 t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
681 "tier2 data too short: need {} bytes at offset {}, have {}",
682 TIER2_ROW_BYTES,
683 t2_offset,
684 tier2_data.len()
685 );
686
687 process_tier2_and_build(
688 &tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
689 t2_row_idx,
690 num_ranges,
691 nullifier,
692 &mut path,
693 empty_hashes,
694 root29,
695 )
696}
697
#[cfg(test)]
mod tests {
    use super::*;
    use ff::Field;
    use pasta_curves::Fp;
    use pir_export::build_ranges_with_sentinels;

    /// Exported Tier 0/1/2 blobs plus tree metadata for a synthetic nullifier
    /// set, produced with the `pir_export` crate.
    struct TestFixture {
        tier0_data: Vec<u8>,
        tier1_data: Vec<u8>,
        tier2_data: Vec<u8>,
        ranges: Vec<[Fp; 3]>,
        empty_hashes: [Fp; TREE_DEPTH],
        root29: Fp,
    }

    impl TestFixture {
        /// Builds the following from `raw_nfs`:
        /// * the ranges, with sentinels added;
        /// * the PIR tree;
        /// * the three tier exports.
        fn build(raw_nfs: &[Fp]) -> Self {
            let ranges = build_ranges_with_sentinels(raw_nfs);
            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let mut tier1_data = Vec::new();
            pir_export::tier1::export(
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
                &mut tier1_data,
            )
            .unwrap();
            let mut tier2_data = Vec::new();
            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();

            Self {
                tier0_data,
                tier1_data,
                tier2_data,
                ranges,
                empty_hashes: tree.empty_hashes,
                root29: tree.root29,
            }
        }
    }

    /// End-to-end local lookup: for a value just above each of the first 20
    /// range lower bounds, the returned proof must verify.
    #[test]
    fn fetch_proof_local_verifies_for_known_ranges() {
        let mut rng = rand::thread_rng();
        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
        let fix = TestFixture::build(&raw_nfs);

        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
            let value = nf_lo + Fp::one();
            let proof = fetch_proof_local(
                &fix.tier0_data,
                &fix.tier1_data,
                &fix.tier2_data,
                fix.ranges.len(),
                value,
                &fix.empty_hashes,
                fix.root29,
            )
            .expect("fetch_proof_local should succeed for a value in range");
            assert!(
                proof.verify(value),
                "proof should verify for value {:?}",
                value,
            );
        }
    }

    /// The proof carries the fixture's root and a full-depth path.
    #[test]
    fn fetch_proof_local_correct_root_and_path_length() {
        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let value = fix.ranges[0][0] + Fp::one();
        let proof = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data,
            fix.ranges.len(),
            value,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert_eq!(proof.root, fix.root29);
        // `path` is a fixed-size array, so this holds by construction.
        assert_eq!(proof.path.len(), TREE_DEPTH);
    }

    /// Tier 0 writes only its own slice of the path
    /// (`PIR_DEPTH - TIER0_LAYERS .. PIR_DEPTH`) and leaves lower levels untouched.
    #[test]
    fn process_tier0_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        assert!(s1 < pir_types::TIER1_ROWS);

        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
        assert!(
            tier0_region.iter().any(|&v| v != Fp::default()),
            "tier0 should write at least one non-zero sibling"
        );

        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            below.iter().all(|&v| v == Fp::default()),
            "path below tier0 region should be untouched"
        );
    }

    /// A value unrelated to the inserted nullifiers still resolves to some
    /// Tier 0 subtree. This presumably relies on the sentinel ranges covering
    /// the whole field; confirm against `build_ranges_with_sentinels`.
    #[test]
    fn process_tier0_handles_arbitrary_field_element() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let bogus = Fp::from(u64::MAX);
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
    }

    /// Tier 1 writes into its slice of the path, which sits directly below the
    /// Tier 0 region.
    #[test]
    fn process_tier1_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        let t1_offset = s1 * TIER1_ROW_BYTES;
        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();

        assert!(s2 < TIER1_LEAVES);

        let tier1_region = &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            tier1_region.iter().any(|&v| v != Fp::default()),
            "tier1 should write at least one non-zero sibling"
        );
    }

    /// Runs the three tier steps by hand (as `fetch_proof_local` does) and
    /// checks that the resulting proof verifies against the fixture root.
    #[test]
    fn process_tier2_and_build_produces_verifiable_proof() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0] + Fp::one();
        let mut path = [Fp::default(); TREE_DEPTH];

        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let s2 = process_tier1(
            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
            value,
            &mut path,
        )
        .unwrap();

        let t2_row_idx = s1 * TIER1_LEAVES + s2;
        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
        let proof = process_tier2_and_build(
            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
            t2_row_idx,
            fix.ranges.len(),
            value,
            &mut path,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert!(proof.verify(value));
        assert_eq!(proof.root, fix.root29);
    }

    /// Boundary cases covered:
    /// * a full row;
    /// * a one-leaf spill into the next row;
    /// * empty input;
    /// * a row past the end.
    #[test]
    fn valid_leaves_for_row_basic() {
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
        assert_eq!(valid_leaves_for_row(0, 0), 0);
        assert_eq!(valid_leaves_for_row(1, 0), 1);
        assert_eq!(valid_leaves_for_row(1, 1), 0);
    }

    /// A Tier 1 blob shorter than one row is rejected with an error, not a panic.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier1() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
            &fix.tier2_data,
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// A Tier 2 blob shorter than one row is rejected with an error, not a panic.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier2() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// Error-oracle mitigation check against a mock server.
    ///
    /// Even when the Tier 1 response is garbage (so Tier 1 decoding or
    /// processing fails), the client must still send exactly one Tier 2
    /// query before reporting the error.
    #[tokio::test]
    async fn tier2_query_sent_despite_tier1_decode_failure() {
        use ff::PrimeField as _;
        use pir_types::{TIER1_ITEM_BITS, TIER1_ROWS, TIER2_ITEM_BITS};
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let tier0_data =
            pir_export::tier0::export(&tree.root25, &tree.levels, &tree.ranges, &tree.empty_hashes);

        let root_info = pir_types::RootInfo {
            root29: hex::encode(tree.root29.to_repr()),
            root25: hex::encode(tree.root25.to_repr()),
            num_ranges: tree.ranges.len(),
            pir_depth: PIR_DEPTH,
            height: None,
        };

        let tier1_scenario = YpirScenario {
            num_items: TIER1_ROWS,
            item_size_bits: TIER1_ITEM_BITS,
        };
        // NOTE(review): tier2 `num_items` is set to TIER1_ROWS rather than the
        // real Tier 2 row count. This is presumably to keep the mock query
        // small; the fallback row 0 is always in range.
        let tier2_scenario = YpirScenario {
            num_items: TIER1_ROWS,
            item_size_bits: TIER2_ITEM_BITS,
        };

        let server = MockServer::start().await;

        // Endpoints hit by `PirClient::connect`.
        Mock::given(method("GET"))
            .and(path("/tier0"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(tier0_data))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier1_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier2_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/root"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&root_info))
            .mount(&server)
            .await;

        // Query endpoints return constant garbage bodies.
        Mock::given(method("POST"))
            .and(path("/tier1/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xDE; 65536]))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/tier2/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xAD; 65536]))
            .mount(&server)
            .await;

        let client = PirClient::connect(&server.uri()).await.unwrap();
        let nullifier = tree.ranges[0][0];
        let result = client.fetch_proof(nullifier).await;

        assert!(
            result.is_err(),
            "fetch_proof should fail with corrupted tier1 response"
        );

        // Count query requests the mock server actually received.
        let received = server.received_requests().await.unwrap();
        let tier1_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier1/query")
            .count();
        let tier2_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier2/query")
            .count();

        assert_eq!(tier1_hits, 1, "tier1 query should have been sent");
        assert_eq!(
            tier2_hits, 1,
            "tier2 query must still be sent when tier1 decode fails \
             (error-oracle mitigation)"
        );
    }
}
1060}