1use std::{sync::Arc, time::Instant};
8
9use anyhow::{Context, Result};
10use ff::PrimeField as _;
11use imt_tree::hasher::PoseidonHasher;
12use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
13use pasta_curves::Fp;
14pub use imt_tree::ImtProofData;
17
18mod transport;
19pub use transport::{Transport, TransportFuture, TransportResponse};
20
21use pir_types::tier0::Tier0Data;
22use pir_types::tier1::Tier1Row;
23use pir_types::tier2::Tier2Row;
24use pir_types::{
25 serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
26 TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
27};
28
29use ypir::client::YPIRClient;
30
/// Client-side and server-reported timing for a single YPIR tier query.
///
/// Byte counts are exact sizes and every `*_ms` field is wall-clock milliseconds. The
/// `Option` fields are derived from optional `x-pir-*` response headers and are `None`
/// when the server omits the header or sends a value that does not parse.
#[derive(Debug, Clone, Default)]
pub struct TierTiming {
    /// Time to construct the YPIR client and generate the query.
    pub gen_ms: f64,
    /// Size of the serialized query payload that was POSTed.
    pub upload_bytes: usize,
    /// Size of the first query component (`query.0`), counted as `u64` words times 8.
    pub upload_q_bytes: usize,
    /// Size of the second query component (`query.1`), counted as `u64` words times 8.
    pub upload_pp_bytes: usize,
    /// Size of the raw (still encrypted) response body.
    pub download_bytes: usize,
    /// Time from just before the POST until the response status had been checked.
    pub rtt_ms: f64,
    /// Time spent decoding/decrypting the response.
    pub decode_ms: f64,
    /// Server request id from the `x-pir-req-id` header.
    pub server_req_id: Option<u64>,
    /// Total server processing time from the `x-pir-server-total-ms` header.
    pub server_total_ms: Option<f64>,
    /// Server validation time from the `x-pir-server-validate-ms` header.
    pub server_validate_ms: Option<f64>,
    /// Server decode/copy time from the `x-pir-server-decode-copy-ms` header.
    pub server_decode_copy_ms: Option<f64>,
    /// Server compute time from the `x-pir-server-compute-ms` header.
    pub server_compute_ms: Option<f64>,
    /// `rtt_ms - server_total_ms`, clamped at zero; `None` without a server total.
    pub net_queue_ms: Option<f64>,
    /// Time for the transport POST future to resolve minus `server_total_ms`, clamped at
    /// zero; `None` without a server total.
    pub upload_to_server_ms: Option<f64>,
    /// `rtt_ms` minus the time for the transport POST future to resolve, clamped at zero.
    pub download_from_server_ms: f64,
}

/// Timing for one complete note (Tier 1 + Tier 2) fetched by `PirClient`.
#[derive(Debug, Clone, Default)]
pub struct NoteTiming {
    /// Timing of the Tier 1 query.
    pub tier1: TierTiming,
    /// Timing of the Tier 2 query.
    pub tier2: TierTiming,
    /// End-to-end time for the note, including local Tier 0/1/2 processing.
    pub total_ms: f64,
}
79
/// Async client that assembles [`ImtProofData`] for a nullifier.
///
/// Tier 0 is downloaded once at connect time and searched locally. The Tier 1 and Tier 2
/// rows for a nullifier are then retrieved with YPIR (SimplePIR) queries generated
/// locally by `YPIRClient`.
pub struct PirClient {
    /// Server base URL with any trailing `/` removed; query URLs are built as
    /// `{server_url}/{tier}/query`.
    server_url: String,
    /// Transport used for every HTTP request.
    transport: Arc<dyn Transport>,
    /// Tier 0 data fetched from `/tier0`, searched locally for each nullifier.
    tier0: Tier0Data,
    /// YPIR database parameters for Tier 1, fetched from `/params/tier1`.
    tier1_scenario: YpirScenario,
    /// YPIR database parameters for Tier 2, fetched from `/params/tier2`.
    tier2_scenario: YpirScenario,
    /// Total range count reported by `/root`; bounds the valid leaves in each Tier 2 row.
    num_ranges: usize,
    /// Precomputed empty-subtree hashes; entries `PIR_DEPTH..TREE_DEPTH` pad every proof
    /// path above the PIR tree.
    empty_hashes: [Fp; TREE_DEPTH],
    /// Root reported by `/root`; copied into every proof as-is (not recomputed here).
    root29: Fp,
}
96
97#[inline]
101fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
102 let row_start = row_idx.saturating_mul(TIER2_LEAVES);
103 num_ranges.saturating_sub(row_start).min(TIER2_LEAVES)
104}
105
106#[inline]
110fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
111 path[offset..offset + siblings.len()].copy_from_slice(siblings);
112}
113
114fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
117 let s1 = tier0
118 .find_subtree(nullifier)
119 .context("nullifier not found in any Tier 0 subtree")?;
120 fill_path(path, PIR_DEPTH - TIER0_LAYERS, &tier0.extract_siblings(s1));
121 Ok(s1)
122}
123
124fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
127 let hasher = PoseidonHasher::new();
128 let tier1 = Tier1Row::from_bytes(tier1_row)?;
129 let s2 = tier1
130 .find_sub_subtree(nullifier)
131 .context("nullifier not found in any Tier 1 sub-subtree")?;
132 fill_path(
133 path,
134 PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS,
135 &tier1.extract_siblings(s2, &hasher),
136 );
137 Ok(s2)
138}
139
140fn process_tier2_and_build(
143 tier2_row: &[u8],
144 t2_row_idx: usize,
145 num_ranges: usize,
146 nullifier: Fp,
147 path: &mut [Fp; TREE_DEPTH],
148 empty_hashes: &[Fp; TREE_DEPTH],
149 root29: Fp,
150) -> Result<ImtProofData> {
151 let hasher = PoseidonHasher::new();
152 let tier2 = Tier2Row::from_bytes(tier2_row)?;
153 let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
154
155 let leaf_local_idx = tier2
156 .find_leaf(nullifier, valid_leaves)
157 .context("nullifier not found in Tier 2 leaf scan")?;
158
159 fill_path(
160 path,
161 0,
162 &tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
163 );
164 fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
166
167 let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
168 let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
169
170 Ok(ImtProofData {
171 root: root29,
172 nf_bounds: [nf_lo, nf_mid, nf_hi],
173 leaf_pos: global_leaf_idx as u32,
174 path: *path,
175 })
176}
177
178impl PirClient {
179 pub async fn with_transport(server_url: &str, transport: Arc<dyn Transport>) -> Result<Self> {
181 let base = server_url.trim_end_matches('/');
182
183 let t0 = Instant::now();
185 let tier0_url = format!("{base}/tier0");
186 let tier1_url = format!("{base}/params/tier1");
187 let tier2_url = format!("{base}/params/tier2");
188 let root_url = format!("{base}/root");
189 let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
190 transport.get(&tier0_url),
191 transport.get(&tier1_url),
192 transport.get(&tier2_url),
193 transport.get(&root_url),
194 )
195 .map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
196
197 let tier0_bytes = body_for_status(tier0_resp, "GET /tier0 failed")?;
198 log::debug!(
199 "Downloaded Tier 0: {} bytes in {:.1}s",
200 tier0_bytes.len(),
201 t0.elapsed().as_secs_f64()
202 );
203 let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
204
205 let tier1_scenario: YpirScenario =
206 serde_json::from_slice(&body_for_status(tier1_resp, "GET /params/tier1 failed")?)
207 .context("parse /params/tier1 response")?;
208 let tier2_scenario: YpirScenario =
209 serde_json::from_slice(&body_for_status(tier2_resp, "GET /params/tier2 failed")?)
210 .context("parse /params/tier2 response")?;
211
212 let root_info: RootInfo =
213 serde_json::from_slice(&body_for_status(root_resp, "GET /root failed")?)
214 .context("parse /root response")?;
215 anyhow::ensure!(
216 root_info.pir_depth == PIR_DEPTH,
217 "server pir_depth {} != expected {}",
218 root_info.pir_depth,
219 PIR_DEPTH
220 );
221 let root29_bytes = hex::decode(&root_info.root29)?;
222 anyhow::ensure!(
223 root29_bytes.len() == 32,
224 "root29 hex decoded to {} bytes, expected 32",
225 root29_bytes.len()
226 );
227 let mut root29_arr = [0u8; 32];
228 root29_arr.copy_from_slice(&root29_bytes);
229 let root29 = Option::from(Fp::from_repr(root29_arr))
230 .ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
231
232 let empty_hashes = precompute_empty_hashes();
233
234 Ok(Self {
235 server_url: base.to_string(),
236 transport,
237 tier0,
238 tier1_scenario,
239 tier2_scenario,
240 num_ranges: root_info.num_ranges,
241 empty_hashes,
242 root29,
243 })
244 }
245
246 pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
251 let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
252 Ok(proof)
253 }
254
255 pub async fn fetch_proof_with_timing(
258 &self,
259 nullifier: Fp,
260 ) -> Result<(ImtProofData, NoteTiming)> {
261 self.fetch_proof_inner(nullifier).await
262 }
263
264 pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
269 log::debug!(
270 "[PIR] Starting parallel fetch for {} notes...",
271 nullifiers.len()
272 );
273 let wall_start = Instant::now();
274
275 let futures: Vec<_> = nullifiers
276 .iter()
277 .enumerate()
278 .map(|(i, &nf)| async move {
279 let (proof, timing) = self.fetch_proof_inner(nf).await?;
280 Ok::<_, anyhow::Error>((i, proof, timing))
281 })
282 .collect();
283
284 let results_with_timing = futures::future::try_join_all(futures).await?;
285 let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
286
287 print_timing_table(&results_with_timing, wall_ms);
288
289 let proofs = results_with_timing
290 .into_iter()
291 .map(|(_, proof, _)| proof)
292 .collect();
293 Ok(proofs)
294 }
295
296 async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
307 let note_start = Instant::now();
308 let mut path = [Fp::default(); TREE_DEPTH];
309
310 let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
312
313 let tier1_outcome = self
322 .ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
323 .await
324 .and_then(|(row, timing)| {
325 let mut_path = &mut path;
326 let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
327 process_tier1(&row, nullifier, mut_path)
328 }))
329 .unwrap_or_else(|payload| {
330 let msg = payload
331 .downcast_ref::<String>()
332 .map(|s| s.as_str())
333 .or_else(|| payload.downcast_ref::<&str>().copied())
334 .unwrap_or("unknown panic");
335 Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
336 })?;
337 Ok((s1 * TIER1_LEAVES + s2, timing))
338 });
339
340 let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
343
344 let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
353 Some(anyhow::anyhow!(
354 "tier2 row_idx {} >= num_items {}",
355 t2_row_idx,
356 self.tier2_scenario.num_items
357 ))
358 } else {
359 None
360 };
361 let t2_query_idx = if t2_bounds_err.is_some() {
362 0
363 } else {
364 t2_row_idx
365 };
366
367 let tier2_result = self
369 .ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
370 .await;
371
372 let (t2_row_idx, tier1_timing) = tier1_outcome?;
374 if let Some(e) = t2_bounds_err {
375 return Err(e);
376 }
377 let (tier2_row, tier2_timing) = tier2_result?;
378
379 let proof = process_tier2_and_build(
380 &tier2_row,
381 t2_row_idx,
382 self.num_ranges,
383 nullifier,
384 &mut path,
385 &self.empty_hashes,
386 self.root29,
387 )?;
388
389 let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
390 Ok((
391 proof,
392 NoteTiming {
393 tier1: tier1_timing,
394 tier2: tier2_timing,
395 total_ms,
396 },
397 ))
398 }
399
400 async fn ypir_query(
406 &self,
407 scenario: &YpirScenario,
408 tier_name: &str,
409 row_idx: usize,
410 expected_row_bytes: usize,
411 ) -> Result<(Vec<u8>, TierTiming)> {
412 anyhow::ensure!(
413 row_idx < scenario.num_items,
414 "{} row_idx {} >= num_items {}",
415 tier_name,
416 row_idx,
417 scenario.num_items
418 );
419 let t0 = Instant::now();
420 let ypir_client = YPIRClient::from_db_sz(
421 scenario.num_items as u64,
422 scenario.item_size_bits as u64,
423 true,
424 );
425
426 let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
428 let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
429
430 let upload_q_bytes = query.0.as_slice().len() * std::mem::size_of::<u64>();
434 let upload_pp_bytes = query.1.as_slice().len() * std::mem::size_of::<u64>();
435 let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
436 let upload_bytes = payload.len();
437
438 let t1 = Instant::now();
440 let url = format!("{}/{}/query", self.server_url, tier_name);
441 let send_result = self.transport.post(&url, payload).await;
442 let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
443 let resp = match send_result {
444 Ok(r) => r,
445 Err(e) => {
446 log::warn!("YPIR {} send error: {:?}", tier_name, e);
447 return Err(e);
448 }
449 };
450 let server_req_id = parse_header_u64(&resp.headers, "x-pir-req-id");
451 let server_total_ms = parse_header_f64(&resp.headers, "x-pir-server-total-ms");
452 let server_validate_ms = parse_header_f64(&resp.headers, "x-pir-server-validate-ms");
453 let server_decode_copy_ms = parse_header_f64(&resp.headers, "x-pir-server-decode-copy-ms");
454 let server_compute_ms = parse_header_f64(&resp.headers, "x-pir-server-compute-ms");
455 let status = resp.status;
456 let response_bytes = resp.body;
457 if !is_success(status) {
458 anyhow::bail!(
459 "{} query failed: HTTP {} body={}",
460 tier_name,
461 status,
462 String::from_utf8_lossy(&response_bytes)
463 );
464 }
465 let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
466 let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
467 let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
468 let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
469
470 let t2 = Instant::now();
476 let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
477 ypir_client.decode_response_simplepir(seed, &response_bytes)
478 }))
479 .map_err(|panic_payload| {
480 let msg = panic_payload
481 .downcast_ref::<String>()
482 .map(|s| s.as_str())
483 .or_else(|| panic_payload.downcast_ref::<&str>().copied())
484 .unwrap_or("unknown panic");
485 anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
486 })?;
487 let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
488
489 anyhow::ensure!(
490 decoded.len() >= expected_row_bytes,
491 "{} decoded response too short: {} bytes, expected >= {}",
492 tier_name,
493 decoded.len(),
494 expected_row_bytes
495 );
496 Ok((
497 decoded[..expected_row_bytes].to_vec(),
498 TierTiming {
499 gen_ms,
500 upload_bytes,
501 upload_q_bytes,
502 upload_pp_bytes,
503 download_bytes: response_bytes.len(),
504 rtt_ms,
505 decode_ms,
506 server_req_id,
507 server_total_ms,
508 server_validate_ms,
509 server_decode_copy_ms,
510 server_compute_ms,
511 net_queue_ms,
512 upload_to_server_ms,
513 download_from_server_ms,
514 },
515 ))
516 }
517}
518
/// Formats a millisecond duration as a 7-character table cell.
///
/// Values below one second are rendered as whole milliseconds (`"   12ms"`). Larger
/// values are rendered as seconds with one decimal (`"  1.5s "`). A value too wide for
/// the 5-character numeric field widens the cell rather than being truncated.
fn fmt_time(ms: f64) -> String {
    if ms >= 1000.0 {
        format!("{:>5.1}s ", ms / 1000.0)
    } else {
        format!("{:>5.0}ms", ms)
    }
}

/// Formats an optional millisecond duration, rendering `None` as `n/a`.
///
/// The `n/a` placeholder is padded to the same 7-character width as [`fmt_time`] output,
/// so table columns stay aligned when a server timing header is missing.
fn fmt_opt_time(ms: Option<f64>) -> String {
    ms.map_or_else(|| "  n/a  ".to_string(), fmt_time)
}
533
/// Returns `true` for HTTP 2xx status codes.
fn is_success(status: u16) -> bool {
    matches!(status, 200..=299)
}
537
538fn body_for_status(response: TransportResponse, context: &'static str) -> Result<Vec<u8>> {
539 if is_success(response.status) {
540 Ok(response.body)
541 } else {
542 anyhow::bail!(
543 "{}: HTTP {} body={}",
544 context,
545 response.status,
546 String::from_utf8_lossy(&response.body)
547 )
548 }
549}
550
551fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
553 if !log::log_enabled!(log::Level::Debug) {
554 return;
555 }
556
557 log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
558 log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+ │ T1 decode│ T2 keygen│ T2 upload+ │ T2 decode│ Total │");
559 log::debug!("[PIR] │ │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │ │");
560 log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
561 for &(i, _, ref t) in results {
562 log::debug!(
563 "[PIR] │ {i:>2} │ {:>6} │ {:>7} │ {:>6} │ {:>6} │ {:>7} │ {:>6} │{} │",
564 fmt_time(t.tier1.gen_ms),
565 fmt_time(t.tier1.rtt_ms),
566 fmt_time(t.tier1.decode_ms),
567 fmt_time(t.tier2.gen_ms),
568 fmt_time(t.tier2.rtt_ms),
569 fmt_time(t.tier2.decode_ms),
570 fmt_time(t.total_ms),
571 );
572 }
573 log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");
574 log::debug!(
575 "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB | Wall clock: {:.2}s",
576 results
577 .first()
578 .map(|(_, _, t)| t.tier1.upload_bytes)
579 .unwrap_or(0) as f64
580 / 1024.0,
581 results
582 .first()
583 .map(|(_, _, t)| t.tier2.upload_bytes)
584 .unwrap_or(0) as f64
585 / (1024.0 * 1024.0),
586 wall_ms / 1000.0,
587 );
588
589 for &(i, _, ref t) in results {
590 log::trace!(
591 "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
592 t.tier1.upload_bytes as f64 / 1024.0,
593 t.tier1.download_bytes as f64 / 1024.0,
594 t.tier2.upload_bytes as f64 / (1024.0 * 1024.0),
595 t.tier2.download_bytes as f64 / 1024.0,
596 );
597 log::trace!(
598 "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
599 fmt_opt_time(t.tier1.server_total_ms),
600 fmt_opt_time(t.tier1.net_queue_ms),
601 fmt_opt_time(t.tier2.server_total_ms),
602 fmt_opt_time(t.tier2.net_queue_ms),
603 );
604 log::trace!(
605 "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
606 fmt_opt_time(t.tier1.upload_to_server_ms),
607 fmt_opt_time(t.tier1.server_total_ms),
608 fmt_time(t.tier1.download_from_server_ms),
609 fmt_opt_time(t.tier2.upload_to_server_ms),
610 fmt_opt_time(t.tier2.server_total_ms),
611 fmt_time(t.tier2.download_from_server_ms),
612 );
613 log::trace!(
614 "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
615 fmt_opt_time(t.tier1.server_validate_ms),
616 fmt_opt_time(t.tier1.server_decode_copy_ms),
617 fmt_opt_time(t.tier1.server_compute_ms),
618 fmt_opt_time(t.tier2.server_validate_ms),
619 fmt_opt_time(t.tier2.server_decode_copy_ms),
620 fmt_opt_time(t.tier2.server_compute_ms),
621 );
622 log::trace!(
623 "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
624 t.tier1.server_req_id,
625 t.tier2.server_req_id
626 );
627 }
628}
629
/// Parses the first header named `name` (matched ASCII case-insensitively) as a `T`.
///
/// Surrounding whitespace in the value is ignored. Returns `None` if the header is absent
/// or if its first occurrence does not parse; later duplicates are not consulted.
fn parse_header<T: std::str::FromStr>(headers: &[(String, String)], name: &str) -> Option<T> {
    headers
        .iter()
        .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
        .and_then(|(_, value)| value.trim().parse::<T>().ok())
}

/// Parses header `name` as a finite `f64`, e.g. a server-reported duration.
///
/// `NaN` and infinities are rejected (`None`) because they would corrupt the derived
/// timing metrics.
fn parse_header_f64(headers: &[(String, String)], name: &str) -> Option<f64> {
    parse_header::<f64>(headers, name).filter(|v| v.is_finite())
}

/// Parses header `name` as a `u64`, e.g. a server request id.
fn parse_header_u64(headers: &[(String, String)], name: &str) -> Option<u64> {
    parse_header(headers, name)
}
645
/// Blocking wrapper around [`PirClient`] that drives it on its own Tokio runtime.
pub struct PirClientBlocking {
    /// The async client every call delegates to.
    inner: PirClient,
    /// Runtime used to `block_on` each call. Fields drop in declaration order, so `inner`
    /// (and its transport) is dropped before this runtime. Keep that order unless you have
    /// checked that the transport does not need the runtime while it is being dropped.
    rt: tokio::runtime::Runtime,
}
656
657impl PirClientBlocking {
658 pub fn with_transport(server_url: &str, transport: Arc<dyn Transport>) -> Result<Self> {
660 let rt = tokio::runtime::Runtime::new()?;
661 let inner = rt.block_on(PirClient::with_transport(server_url, transport))?;
662 Ok(Self { inner, rt })
663 }
664
665 pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
667 self.rt.block_on(self.inner.fetch_proof(nullifier))
668 }
669
670 pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
672 self.rt.block_on(self.inner.fetch_proofs(nullifiers))
673 }
674
675 pub fn root29(&self) -> Fp {
677 self.inner.root29
678 }
679}
680
681pub fn fetch_proof_local(
688 tier0_data: &[u8],
689 tier1_data: &[u8],
690 tier2_data: &[u8],
691 num_ranges: usize,
692 nullifier: Fp,
693 empty_hashes: &[Fp; TREE_DEPTH],
694 root29: Fp,
695) -> Result<ImtProofData> {
696 let mut path = [Fp::default(); TREE_DEPTH];
697 let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
698
699 let s1 = process_tier0(&tier0, nullifier, &mut path)?;
700
701 let t1_offset = s1 * TIER1_ROW_BYTES;
703 anyhow::ensure!(
704 t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
705 "tier1 data too short: need {} bytes at offset {}, have {}",
706 TIER1_ROW_BYTES,
707 t1_offset,
708 tier1_data.len()
709 );
710 let s2 = process_tier1(
711 &tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
712 nullifier,
713 &mut path,
714 )?;
715
716 let t2_row_idx = s1 * TIER1_LEAVES + s2;
718 let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
719 anyhow::ensure!(
720 t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
721 "tier2 data too short: need {} bytes at offset {}, have {}",
722 TIER2_ROW_BYTES,
723 t2_offset,
724 tier2_data.len()
725 );
726
727 process_tier2_and_build(
728 &tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
729 t2_row_idx,
730 num_ranges,
731 nullifier,
732 &mut path,
733 empty_hashes,
734 root29,
735 )
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use ff::Field;
    use pasta_curves::Fp;
    use pir_export::build_ranges_with_sentinels;

    /// Exported tier blobs plus the tree metadata needed to drive the local proof path.
    struct TestFixture {
        tier0_data: Vec<u8>,
        tier1_data: Vec<u8>,
        tier2_data: Vec<u8>,
        ranges: Vec<[Fp; 3]>,
        empty_hashes: [Fp; TREE_DEPTH],
        root29: Fp,
    }

    impl TestFixture {
        /// Builds sentinel-padded ranges from `raw_nfs`, builds the PIR tree, and exports
        /// all three tiers with `pir_export`.
        fn build(raw_nfs: &[Fp]) -> Self {
            let ranges = build_ranges_with_sentinels(raw_nfs);
            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let mut tier1_data = Vec::new();
            pir_export::tier1::export(
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
                &mut tier1_data,
            )
            .unwrap();
            let mut tier2_data = Vec::new();
            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();

            Self {
                tier0_data,
                tier1_data,
                tier2_data,
                ranges,
                empty_hashes: tree.empty_hashes,
                root29: tree.root29,
            }
        }
    }

    /// For random nullifiers, a value just above each range's low bound yields a proof
    /// that verifies.
    #[test]
    fn fetch_proof_local_verifies_for_known_ranges() {
        let mut rng = rand::thread_rng();
        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
        let fix = TestFixture::build(&raw_nfs);

        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
            let value = nf_lo + Fp::one();
            let proof = fetch_proof_local(
                &fix.tier0_data,
                &fix.tier1_data,
                &fix.tier2_data,
                fix.ranges.len(),
                value,
                &fix.empty_hashes,
                fix.root29,
            )
            .expect("fetch_proof_local should succeed for a value in range");
            assert!(
                proof.verify(value),
                "proof should verify for value {:?}",
                value,
            );
        }
    }

    /// The proof carries the tree's root29 and a full-depth path.
    #[test]
    fn fetch_proof_local_correct_root_and_path_length() {
        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let value = fix.ranges[0][0] + Fp::one();
        let proof = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data,
            fix.ranges.len(),
            value,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert_eq!(proof.root, fix.root29);
        assert_eq!(proof.path.len(), TREE_DEPTH);
    }

    /// `process_tier0` writes only its own path region and leaves lower levels untouched.
    #[test]
    fn process_tier0_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        assert!(s1 < pir_types::TIER1_ROWS);

        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
        assert!(
            tier0_region.iter().any(|&v| v != Fp::default()),
            "tier0 should write at least one non-zero sibling"
        );

        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            below.iter().all(|&v| v == Fp::default()),
            "path below tier0 region should be untouched"
        );
    }

    /// A field element that is not a stored nullifier still resolves to an in-range
    /// subtree (presumably because the sentinel ranges cover the whole field).
    #[test]
    fn process_tier0_handles_arbitrary_field_element() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();

        let bogus = Fp::from(u64::MAX);
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
    }

    /// `process_tier1` writes at least one sibling into its own path region.
    #[test]
    fn process_tier1_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();

        let t1_offset = s1 * TIER1_ROW_BYTES;
        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();

        assert!(s2 < TIER1_LEAVES);

        let tier1_region = &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            tier1_region.iter().any(|&v| v != Fp::default()),
            "tier1 should write at least one non-zero sibling"
        );
    }

    /// Running the three tier steps by hand produces a proof that verifies against root29.
    #[test]
    fn process_tier2_and_build_produces_verifiable_proof() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();

        let value = fix.ranges[0][0] + Fp::one();
        let mut path = [Fp::default(); TREE_DEPTH];

        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let s2 = process_tier1(
            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
            value,
            &mut path,
        )
        .unwrap();

        let t2_row_idx = s1 * TIER1_LEAVES + s2;
        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
        let proof = process_tier2_and_build(
            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
            t2_row_idx,
            fix.ranges.len(),
            value,
            &mut path,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();

        assert!(proof.verify(value));
        assert_eq!(proof.root, fix.root29);
    }

    /// Full rows, a partial last row, and rows past the end.
    #[test]
    fn valid_leaves_for_row_basic() {
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
        assert_eq!(valid_leaves_for_row(0, 0), 0);
        assert_eq!(valid_leaves_for_row(1, 0), 1);
        assert_eq!(valid_leaves_for_row(1, 1), 0);
    }

    /// A tier1 blob shorter than one row is rejected with an error rather than a panic.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier1() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
            &fix.tier2_data,
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// A tier2 blob shorter than one row is rejected with an error rather than a panic.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier2() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);

        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// Transport that replays canned responses keyed by request path and records every
    /// path it receives (GETs and POSTs alike) in `hits`.
    struct MockTransport {
        gets: std::collections::HashMap<&'static str, TransportResponse>,
        posts: std::collections::HashMap<&'static str, TransportResponse>,
        hits: std::sync::Mutex<Vec<String>>,
    }

    impl MockTransport {
        /// Serves real Tier 0 / params / root data for `tree`, but answers both query
        /// endpoints with constant filler bytes that cannot decode to a valid row.
        fn new(tree: &pir_export::PirTree) -> Self {
            use ff::PrimeField as _;
            use pir_types::{TIER1_ITEM_BITS, TIER1_YPIR_ROWS, TIER2_ITEM_BITS};

            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let root_info = pir_types::RootInfo {
                root29: hex::encode(tree.root29.to_repr()),
                root25: hex::encode(tree.root25.to_repr()),
                num_ranges: tree.ranges.len(),
                pir_depth: PIR_DEPTH,
                height: None,
            };
            let tier1_scenario = YpirScenario {
                num_items: TIER1_YPIR_ROWS,
                item_size_bits: TIER1_ITEM_BITS,
            };
            // NOTE(review): tier2 reuses TIER1_YPIR_ROWS for num_items — confirm this is
            // intentional (e.g. to keep the test YPIR client small) rather than a typo.
            let tier2_scenario = YpirScenario {
                num_items: TIER1_YPIR_ROWS,
                item_size_bits: TIER2_ITEM_BITS,
            };

            let gets = [
                ("/tier0", response(tier0_data)),
                (
                    "/params/tier1",
                    response(serde_json::to_vec(&tier1_scenario).unwrap()),
                ),
                (
                    "/params/tier2",
                    response(serde_json::to_vec(&tier2_scenario).unwrap()),
                ),
                ("/root", response(serde_json::to_vec(&root_info).unwrap())),
            ]
            .into_iter()
            .collect();
            let posts = [
                ("/tier1/query", response(vec![0xDE; 65536])),
                ("/tier2/query", response(vec![0xAD; 65536])),
            ]
            .into_iter()
            .collect();

            Self {
                gets,
                posts,
                hits: std::sync::Mutex::new(Vec::new()),
            }
        }

        /// Number of recorded requests whose path equals `path` exactly.
        fn count_hits(&self, path: &str) -> usize {
            self.hits
                .lock()
                .unwrap()
                .iter()
                .filter(|hit| hit.as_str() == path)
                .count()
        }
    }

    /// A 200 response with no headers and the given body.
    fn response(body: Vec<u8>) -> TransportResponse {
        TransportResponse {
            status: 200,
            headers: Vec::new(),
            body,
        }
    }

    /// Returns the path portion of `url` (everything from the first `/` after the
    /// optional `scheme://`), or `"/"` if there is none.
    fn request_path(url: &str) -> &str {
        let without_scheme = url.split_once("://").map(|(_, rest)| rest).unwrap_or(url);
        without_scheme
            .find('/')
            .map(|idx| &without_scheme[idx..])
            .unwrap_or("/")
    }

    impl Transport for MockTransport {
        fn get<'a>(&'a self, url: &'a str) -> transport::TransportFuture<'a> {
            Box::pin(async move {
                let path = request_path(url);
                self.hits.lock().unwrap().push(path.to_string());
                self.gets
                    .get(path)
                    .cloned()
                    .ok_or_else(|| anyhow::anyhow!("unexpected GET {path}"))
            })
        }

        fn post<'a>(&'a self, url: &'a str, _body: Vec<u8>) -> transport::TransportFuture<'a> {
            Box::pin(async move {
                let path = request_path(url);
                self.hits.lock().unwrap().push(path.to_string());
                // Delays Tier 2 responses so that Tier 2 requests stay in flight while
                // other notes progress; presumably this lets the batched test observe
                // whether every note's Tier 2 request is issued.
                if path == "/tier2/query" {
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                }
                self.posts
                    .get(path)
                    .cloned()
                    .ok_or_else(|| anyhow::anyhow!("unexpected POST {path}"))
            })
        }
    }

    /// Connects a `PirClient` to a `MockTransport` whose query responses are garbage,
    /// returning the client, the transport (for hit counting), and the tree it serves.
    async fn corrupting_client() -> (
        PirClient,
        std::sync::Arc<MockTransport>,
        pir_export::PirTree,
    ) {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let transport = std::sync::Arc::new(MockTransport::new(&tree));
        let client = PirClient::with_transport("https://pir.example", transport.clone())
            .await
            .unwrap();
        (client, transport, tree)
    }

    /// Even when the Tier 1 response cannot be decoded, exactly one Tier 2 query is
    /// still sent.
    #[tokio::test]
    async fn tier2_query_sent_despite_tier1_decode_failure() {
        let (client, transport, tree) = corrupting_client().await;
        let result = client.fetch_proof(tree.ranges[0][0]).await;

        assert!(
            result.is_err(),
            "fetch_proof should fail with corrupted tier1 response"
        );
        assert_eq!(transport.count_hits("/tier1/query"), 1);
        assert_eq!(transport.count_hits("/tier2/query"), 1);
    }

    /// In a batch of K notes that all fail at Tier 1, every note still sends exactly one
    /// Tier 1 and one Tier 2 query.
    #[tokio::test]
    async fn batched_tier2_queries_all_sent_despite_tier1_decode_failure() {
        const K: usize = 5;

        let (client, transport, tree) = corrupting_client().await;
        let nullifiers: Vec<Fp> = tree
            .ranges
            .iter()
            .take(K)
            .map(|r| r[0] + Fp::one())
            .collect();

        let result = client.fetch_proofs(&nullifiers).await;

        assert!(
            result.is_err(),
            "fetch_proofs should fail with corrupted tier1 responses"
        );
        assert_eq!(transport.count_hits("/tier1/query"), K);
        assert_eq!(transport.count_hits("/tier2/query"), K);
    }
}