use std::time::Instant;
use anyhow::{Context, Result};
use ff::PrimeField as _;
use imt_tree::hasher::PoseidonHasher;
use imt_tree::tree::{precompute_empty_hashes, TREE_DEPTH};
use pasta_curves::Fp;
pub use imt_tree::ImtProofData;
use pir_types::tier0::Tier0Data;
use pir_types::tier1::Tier1Row;
use pir_types::tier2::Tier2Row;
use pir_types::{
serialize_ypir_query, RootInfo, YpirScenario, PIR_DEPTH, TIER0_LAYERS, TIER1_LAYERS,
TIER1_LEAVES, TIER1_ROW_BYTES, TIER2_LEAVES, TIER2_ROW_BYTES,
};
use ypir::client::YPIRClient;
/// Client- and server-side timing/size measurements for a single YPIR tier query.
///
/// All `*_ms` values are milliseconds. `Option` fields come from optional server
/// response headers and are `None` when the header was absent or unparseable.
#[derive(Debug, Clone, Copy)]
struct TierTiming {
    /// Time spent creating the YPIR client and generating the query.
    gen_ms: f64,
    /// Size of the serialized query body sent to the server.
    upload_bytes: usize,
    /// Size of the raw (still encrypted) response body.
    download_bytes: usize,
    /// Wall time from sending the request until the full response body was read.
    rtt_ms: f64,
    /// Time spent decrypting/decoding the response on the client.
    decode_ms: f64,
    /// Server request id from `x-pir-req-id`.
    server_req_id: Option<u64>,
    /// Server-reported total handling time (`x-pir-server-total-ms`).
    server_total_ms: Option<f64>,
    /// Server-reported validation time (`x-pir-server-validate-ms`).
    server_validate_ms: Option<f64>,
    /// Server-reported decode/copy time (`x-pir-server-decode-copy-ms`).
    server_decode_copy_ms: Option<f64>,
    /// Server-reported compute time (`x-pir-server-compute-ms`).
    server_compute_ms: Option<f64>,
    /// `rtt_ms - server_total_ms`, clamped at 0: time not spent on the server.
    net_queue_ms: Option<f64>,
    /// Approximate upload time: time-to-headers minus server time, clamped at 0.
    upload_to_server_ms: Option<f64>,
    /// Time between receiving response headers and finishing the body read.
    download_from_server_ms: f64,
}
/// Timing breakdown for one nullifier's complete proof fetch (tier 1 + tier 2).
struct NoteTiming {
    /// Measurements for the Tier 1 YPIR query.
    tier1: TierTiming,
    /// Measurements for the Tier 2 YPIR query.
    tier2: TierTiming,
    /// Wall time in milliseconds for the whole per-note fetch, including tier 0 lookup
    /// and local proof assembly.
    total_ms: f64,
}
/// Async client that fetches IMT proofs from a PIR server using YPIR queries.
///
/// Tier 0 is downloaded in full at connect time and searched locally. Tier 1 and
/// Tier 2 rows are retrieved privately through one YPIR query each.
pub struct PirClient {
    /// Server base URL with any trailing `/` removed.
    server_url: String,
    /// Shared HTTP client used for all requests.
    http: reqwest::Client,
    /// Tier 0 data downloaded from `/tier0`.
    tier0: Tier0Data,
    /// YPIR database parameters for Tier 1 (from `/params/tier1`).
    tier1_scenario: YpirScenario,
    /// YPIR database parameters for Tier 2 (from `/params/tier2`).
    tier2_scenario: YpirScenario,
    /// Number of ranges reported by `/root`; bounds the valid leaves in the last Tier 2 row.
    num_ranges: usize,
    /// Output of `precompute_empty_hashes()`, used to pad the path above `PIR_DEPTH`.
    empty_hashes: [Fp; TREE_DEPTH],
    /// Root reported by `/root`, already decoded and validated as a field element.
    root29: Fp,
}
/// Returns how many of the `TIER2_LEAVES` slots in Tier 2 row `row_idx` hold real ranges.
///
/// Every row except the last is full. The last row holds the remainder, and rows past
/// the end hold zero. Saturating arithmetic keeps huge indices from overflowing.
#[inline]
fn valid_leaves_for_row(num_ranges: usize, row_idx: usize) -> usize {
    let first_leaf_in_row = row_idx.saturating_mul(TIER2_LEAVES);
    let ranges_from_here = num_ranges.saturating_sub(first_leaf_in_row);
    std::cmp::min(ranges_from_here, TIER2_LEAVES)
}
/// Copies `siblings` into `path` starting at index `offset`.
///
/// Panics if `offset + siblings.len()` exceeds `TREE_DEPTH`.
#[inline]
fn fill_path(path: &mut [Fp; TREE_DEPTH], offset: usize, siblings: &[Fp]) {
    let end = offset + siblings.len();
    let window = &mut path[offset..end];
    window.copy_from_slice(siblings);
}
/// Finds the Tier 0 subtree containing `nullifier` and writes its siblings into `path`.
///
/// The siblings occupy `path[PIR_DEPTH - TIER0_LAYERS..]`, one entry per sibling.
/// Returns the subtree index, which is also the Tier 1 row to query.
///
/// # Errors
/// Fails if no Tier 0 subtree contains `nullifier`.
fn process_tier0(tier0: &Tier0Data, nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
    let subtree_idx = tier0
        .find_subtree(nullifier)
        .context("nullifier not found in any Tier 0 subtree")?;
    let tier0_offset = PIR_DEPTH - TIER0_LAYERS;
    let siblings = tier0.extract_siblings(subtree_idx);
    fill_path(path, tier0_offset, &siblings);
    Ok(subtree_idx)
}
/// Parses one Tier 1 row and locates the sub-subtree containing `nullifier`.
///
/// The Tier 1 siblings are written into `path` starting at
/// `PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS`, directly below the Tier 0 region.
/// Returns the sub-subtree index within this row.
///
/// # Errors
/// Fails if the row cannot be parsed or no sub-subtree contains `nullifier`.
fn process_tier1(tier1_row: &[u8], nullifier: Fp, path: &mut [Fp; TREE_DEPTH]) -> Result<usize> {
    let hasher = PoseidonHasher::new();
    let row = Tier1Row::from_bytes(tier1_row)?;
    let sub_subtree_idx = row
        .find_sub_subtree(nullifier)
        .context("nullifier not found in any Tier 1 sub-subtree")?;
    let tier1_offset = PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS;
    let siblings = row.extract_siblings(sub_subtree_idx, &hasher);
    fill_path(path, tier1_offset, &siblings);
    Ok(sub_subtree_idx)
}
fn process_tier2_and_build(
tier2_row: &[u8],
t2_row_idx: usize,
num_ranges: usize,
nullifier: Fp,
path: &mut [Fp; TREE_DEPTH],
empty_hashes: &[Fp; TREE_DEPTH],
root29: Fp,
) -> Result<ImtProofData> {
let hasher = PoseidonHasher::new();
let tier2 = Tier2Row::from_bytes(tier2_row)?;
let valid_leaves = valid_leaves_for_row(num_ranges, t2_row_idx);
let leaf_local_idx = tier2
.find_leaf(nullifier, valid_leaves)
.context("nullifier not found in Tier 2 leaf scan")?;
fill_path(
path,
0,
&tier2.extract_siblings(leaf_local_idx, valid_leaves, &hasher),
);
fill_path(path, PIR_DEPTH, &empty_hashes[PIR_DEPTH..TREE_DEPTH]);
let global_leaf_idx = t2_row_idx * TIER2_LEAVES + leaf_local_idx;
let (nf_lo, nf_mid, nf_hi) = tier2.leaf_record(leaf_local_idx);
Ok(ImtProofData {
root: root29,
nf_bounds: [nf_lo, nf_mid, nf_hi],
leaf_pos: global_leaf_idx as u32,
path: *path,
})
}
impl PirClient {
pub async fn connect(server_url: &str) -> Result<Self> {
let http = reqwest::Client::new();
let base = server_url.trim_end_matches('/');
let t0 = Instant::now();
let (tier0_resp, tier1_resp, tier2_resp, root_resp) = tokio::try_join!(
http.get(format!("{base}/tier0")).send(),
http.get(format!("{base}/params/tier1")).send(),
http.get(format!("{base}/params/tier2")).send(),
http.get(format!("{base}/root")).send(),
)
.map_err(|e| anyhow::anyhow!("connect fetch failed: {e}"))?;
let tier0_bytes = tier0_resp.error_for_status()?.bytes().await?;
log::debug!(
"Downloaded Tier 0: {} bytes in {:.1}s",
tier0_bytes.len(),
t0.elapsed().as_secs_f64()
);
let tier0 = Tier0Data::from_bytes(tier0_bytes.to_vec())?;
let tier1_scenario: YpirScenario = tier1_resp
.error_for_status()
.context("GET /params/tier1 failed")?
.json()
.await?;
let tier2_scenario: YpirScenario = tier2_resp
.error_for_status()
.context("GET /params/tier2 failed")?
.json()
.await?;
let root_info: RootInfo = root_resp
.error_for_status()
.context("GET /root failed")?
.json()
.await?;
anyhow::ensure!(
root_info.pir_depth == PIR_DEPTH,
"server pir_depth {} != expected {}",
root_info.pir_depth,
PIR_DEPTH
);
let root29_bytes = hex::decode(&root_info.root29)?;
anyhow::ensure!(
root29_bytes.len() == 32,
"root29 hex decoded to {} bytes, expected 32",
root29_bytes.len()
);
let mut root29_arr = [0u8; 32];
root29_arr.copy_from_slice(&root29_bytes);
let root29 = Option::from(Fp::from_repr(root29_arr))
.ok_or_else(|| anyhow::anyhow!("invalid root29 field element"))?;
let empty_hashes = precompute_empty_hashes();
Ok(Self {
server_url: base.to_string(),
http,
tier0,
tier1_scenario,
tier2_scenario,
num_ranges: root_info.num_ranges,
empty_hashes,
root29,
})
}
pub async fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
let (proof, _timing) = self.fetch_proof_inner(nullifier).await?;
Ok(proof)
}
pub async fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
log::debug!(
"[PIR] Starting parallel fetch for {} notes...",
nullifiers.len()
);
let wall_start = Instant::now();
let futures: Vec<_> = nullifiers
.iter()
.enumerate()
.map(|(i, &nf)| async move {
let (proof, timing) = self.fetch_proof_inner(nf).await?;
Ok::<_, anyhow::Error>((i, proof, timing))
})
.collect();
let results_with_timing = futures::future::try_join_all(futures).await?;
let wall_ms = wall_start.elapsed().as_secs_f64() * 1000.0;
print_timing_table(&results_with_timing, wall_ms);
let proofs = results_with_timing
.into_iter()
.map(|(_, proof, _)| proof)
.collect();
Ok(proofs)
}
async fn fetch_proof_inner(&self, nullifier: Fp) -> Result<(ImtProofData, NoteTiming)> {
let note_start = Instant::now();
let mut path = [Fp::default(); TREE_DEPTH];
let s1 = process_tier0(&self.tier0, nullifier, &mut path)?;
let tier1_outcome = self
.ypir_query(&self.tier1_scenario, "tier1", s1, TIER1_ROW_BYTES)
.await
.and_then(|(row, timing)| {
let mut_path = &mut path;
let s2 = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
process_tier1(&row, nullifier, mut_path)
}))
.unwrap_or_else(|payload| {
let msg = payload
.downcast_ref::<String>()
.map(|s| s.as_str())
.or_else(|| payload.downcast_ref::<&str>().copied())
.unwrap_or("unknown panic");
Err(anyhow::anyhow!("process_tier1 panicked: {}", msg))
})?;
Ok((s1 * TIER1_LEAVES + s2, timing))
});
let t2_row_idx = tier1_outcome.as_ref().map(|(idx, _)| *idx).unwrap_or(0);
let t2_bounds_err = if t2_row_idx >= self.tier2_scenario.num_items {
Some(anyhow::anyhow!(
"tier2 row_idx {} >= num_items {}",
t2_row_idx,
self.tier2_scenario.num_items
))
} else {
None
};
let t2_query_idx = if t2_bounds_err.is_some() {
0
} else {
t2_row_idx
};
let tier2_result = self
.ypir_query(&self.tier2_scenario, "tier2", t2_query_idx, TIER2_ROW_BYTES)
.await;
let (t2_row_idx, tier1_timing) = tier1_outcome?;
if let Some(e) = t2_bounds_err {
return Err(e);
}
let (tier2_row, tier2_timing) = tier2_result?;
let proof = process_tier2_and_build(
&tier2_row,
t2_row_idx,
self.num_ranges,
nullifier,
&mut path,
&self.empty_hashes,
self.root29,
)?;
let total_ms = note_start.elapsed().as_secs_f64() * 1000.0;
Ok((
proof,
NoteTiming {
tier1: tier1_timing,
tier2: tier2_timing,
total_ms,
},
))
}
async fn ypir_query(
&self,
scenario: &YpirScenario,
tier_name: &str,
row_idx: usize,
expected_row_bytes: usize,
) -> Result<(Vec<u8>, TierTiming)> {
anyhow::ensure!(
row_idx < scenario.num_items,
"{} row_idx {} >= num_items {}",
tier_name,
row_idx,
scenario.num_items
);
let t0 = Instant::now();
let ypir_client = YPIRClient::from_db_sz(
scenario.num_items as u64,
scenario.item_size_bits as u64,
true,
);
let (query, seed) = ypir_client.generate_query_simplepir(row_idx);
let gen_ms = t0.elapsed().as_secs_f64() * 1000.0;
let payload = serialize_ypir_query(query.0.as_slice(), query.1.as_slice());
let upload_bytes = payload.len();
let t1 = Instant::now();
let url = format!("{}/{}/query", self.server_url, tier_name);
let send_result = self.http.post(&url).body(payload).send().await;
let send_ms = t1.elapsed().as_secs_f64() * 1000.0;
let resp = match send_result {
Ok(r) => r,
Err(e) => {
log::warn!("YPIR {} send error: {:?}", tier_name, e);
return Err(e.into());
}
};
let server_req_id = parse_header_u64(resp.headers(), "x-pir-req-id");
let server_total_ms = parse_header_f64(resp.headers(), "x-pir-server-total-ms");
let server_validate_ms = parse_header_f64(resp.headers(), "x-pir-server-validate-ms");
let server_decode_copy_ms = parse_header_f64(resp.headers(), "x-pir-server-decode-copy-ms");
let server_compute_ms = parse_header_f64(resp.headers(), "x-pir-server-compute-ms");
let status = resp.status();
let response_bytes = resp.bytes().await?;
if !status.is_success() {
anyhow::bail!(
"{} query failed: HTTP {} body={}",
tier_name,
status,
String::from_utf8_lossy(&response_bytes)
);
}
let rtt_ms = t1.elapsed().as_secs_f64() * 1000.0;
let download_from_server_ms = (rtt_ms - send_ms).max(0.0);
let net_queue_ms = server_total_ms.map(|server_ms| (rtt_ms - server_ms).max(0.0));
let upload_to_server_ms = server_total_ms.map(|server_ms| (send_ms - server_ms).max(0.0));
let t2 = Instant::now();
let decoded = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
ypir_client.decode_response_simplepir(seed, &response_bytes)
}))
.map_err(|panic_payload| {
let msg = panic_payload
.downcast_ref::<String>()
.map(|s| s.as_str())
.or_else(|| panic_payload.downcast_ref::<&str>().copied())
.unwrap_or("unknown panic");
anyhow::anyhow!("{} response decryption panicked: {}", tier_name, msg)
})?;
let decode_ms = t2.elapsed().as_secs_f64() * 1000.0;
anyhow::ensure!(
decoded.len() >= expected_row_bytes,
"{} decoded response too short: {} bytes, expected >= {}",
tier_name,
decoded.len(),
expected_row_bytes
);
Ok((
decoded[..expected_row_bytes].to_vec(),
TierTiming {
gen_ms,
upload_bytes,
download_bytes: response_bytes.len(),
rtt_ms,
decode_ms,
server_req_id,
server_total_ms,
server_validate_ms,
server_decode_copy_ms,
server_compute_ms,
net_queue_ms,
upload_to_server_ms,
download_from_server_ms,
},
))
}
}
/// Formats a millisecond duration for the timing table.
///
/// Values of one second or more are shown in seconds with one decimal place, e.g.
/// `"  1.5s "`. Smaller values are shown as whole milliseconds, e.g. `"   42ms"`. Both
/// forms right-align the number in a 5-character field.
fn fmt_time(ms: f64) -> String {
    match ms {
        millis if millis >= 1000.0 => format!("{:>5.1}s ", millis / 1000.0),
        millis => format!("{:>5.0}ms", millis),
    }
}
fn fmt_opt_time(ms: Option<f64>) -> String {
match ms {
Some(v) => fmt_time(v),
None => " n/a ".to_string(),
}
}
/// Logs a per-note timing table at debug level, plus detailed transfer and server-stage
/// lines at trace level.
///
/// Returns immediately when debug logging is disabled, so no formatting work is done.
/// The "upload per note" summary uses the first result's sizes. An empty `results`
/// slice reports 0.
fn print_timing_table(results: &[(usize, ImtProofData, NoteTiming)], wall_ms: f64) {
    if !log::log_enabled!(log::Level::Debug) {
        return;
    }
    const KIB: f64 = 1024.0;
    const MIB: f64 = 1024.0 * 1024.0;

    log::debug!("[PIR] ┌─────┬──────────┬─────────────┬──────────┬──────────┬─────────────┬──────────┬────────┐");
    log::debug!("[PIR] │ Note│ T1 keygen│ T1 upload+ │ T1 decode│ T2 keygen│ T2 upload+ │ T2 decode│ Total │");
    log::debug!("[PIR] │ │ (client) │ server+down │ (client) │ (client) │ server+down │ (client) │ │");
    log::debug!("[PIR] ├─────┼──────────┼─────────────┼──────────┼──────────┼─────────────┼──────────┼────────┤");
    for (i, _, t) in results {
        let (t1, t2) = (&t.tier1, &t.tier2);
        log::debug!(
            "[PIR] │ {i:>2} │ {:>6} │ {:>7} │ {:>6} │ {:>6} │ {:>7} │ {:>6} │{} │",
            fmt_time(t1.gen_ms),
            fmt_time(t1.rtt_ms),
            fmt_time(t1.decode_ms),
            fmt_time(t2.gen_ms),
            fmt_time(t2.rtt_ms),
            fmt_time(t2.decode_ms),
            fmt_time(t.total_ms),
        );
    }
    log::debug!("[PIR] └─────┴──────────┴─────────────┴──────────┴──────────┴─────────────┴──────────┴────────┘");

    let first = results.first().map(|(_, _, t)| t);
    let t1_upload_kb = first.map_or(0, |t| t.tier1.upload_bytes) as f64 / KIB;
    let t2_upload_mb = first.map_or(0, |t| t.tier2.upload_bytes) as f64 / MIB;
    log::debug!(
        "[PIR] Upload per note: T1={:.0}KB T2={:.1}MB | Wall clock: {:.2}s",
        t1_upload_kb,
        t2_upload_mb,
        wall_ms / 1000.0,
    );

    for (i, _, t) in results {
        let (t1, t2) = (&t.tier1, &t.tier2);
        log::trace!(
            "[PIR] Note {i:>2} transfer: T1 up={:.0}KB down={:.0}KB | T2 up={:.1}MB down={:.0}KB",
            t1.upload_bytes as f64 / KIB,
            t1.download_bytes as f64 / KIB,
            t2.upload_bytes as f64 / MIB,
            t2.download_bytes as f64 / KIB,
        );
        log::trace!(
            "[PIR] Note {i:>2} server/net: T1 {} / {} | T2 {} / {}",
            fmt_opt_time(t1.server_total_ms),
            fmt_opt_time(t1.net_queue_ms),
            fmt_opt_time(t2.server_total_ms),
            fmt_opt_time(t2.net_queue_ms),
        );
        log::trace!(
            "[PIR] Note {i:>2} up/srv/down: T1 {} / {} / {} | T2 {} / {} / {}",
            fmt_opt_time(t1.upload_to_server_ms),
            fmt_opt_time(t1.server_total_ms),
            fmt_time(t1.download_from_server_ms),
            fmt_opt_time(t2.upload_to_server_ms),
            fmt_opt_time(t2.server_total_ms),
            fmt_time(t2.download_from_server_ms),
        );
        log::trace!(
            "[PIR] Note {i:>2} server stages: T1(v={} copy={} compute={}) T2(v={} copy={} compute={})",
            fmt_opt_time(t1.server_validate_ms),
            fmt_opt_time(t1.server_decode_copy_ms),
            fmt_opt_time(t1.server_compute_ms),
            fmt_opt_time(t2.server_validate_ms),
            fmt_opt_time(t2.server_decode_copy_ms),
            fmt_opt_time(t2.server_compute_ms),
        );
        log::trace!(
            "[PIR] Note {i:>2} req ids: T1={:?} T2={:?}",
            t1.server_req_id,
            t2.server_req_id
        );
    }
}
/// Parses response header `name` as a `T`.
///
/// Returns `None` if the header is absent, not valid visible ASCII, or does not parse
/// as `T`. Surrounding whitespace is ignored.
fn parse_header<T: std::str::FromStr>(
    headers: &reqwest::header::HeaderMap,
    name: &'static str,
) -> Option<T> {
    headers
        .get(name)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.trim().parse::<T>().ok())
}

/// Parses response header `name` as an `f64`; see `parse_header`.
fn parse_header_f64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<f64> {
    parse_header(headers, name)
}

/// Parses response header `name` as a `u64`; see `parse_header`.
fn parse_header_u64(headers: &reqwest::header::HeaderMap, name: &'static str) -> Option<u64> {
    parse_header(headers, name)
}
/// Synchronous wrapper around [`PirClient`] that owns its own Tokio runtime.
///
/// Fields drop in declaration order, so `inner` (and its HTTP client) is dropped before
/// the runtime `rt`.
pub struct PirClientBlocking {
    /// The async client whose futures are driven on `rt`.
    inner: PirClient,
    /// Runtime used to `block_on` every call.
    rt: tokio::runtime::Runtime,
}
impl PirClientBlocking {
    /// Creates a dedicated Tokio runtime and runs [`PirClient::connect`] on it.
    ///
    /// Like every method here, this uses `Runtime::block_on`, which panics if called from
    /// within an asynchronous execution context.
    pub fn connect(server_url: &str) -> Result<Self> {
        let rt = tokio::runtime::Runtime::new()?;
        let connecting = PirClient::connect(server_url);
        let inner = rt.block_on(connecting)?;
        Ok(Self { inner, rt })
    }

    /// Blocking version of [`PirClient::fetch_proof`].
    pub fn fetch_proof(&self, nullifier: Fp) -> Result<ImtProofData> {
        let pending = self.inner.fetch_proof(nullifier);
        self.rt.block_on(pending)
    }

    /// Blocking version of [`PirClient::fetch_proofs`].
    pub fn fetch_proofs(&self, nullifiers: &[Fp]) -> Result<Vec<ImtProofData>> {
        let pending = self.inner.fetch_proofs(nullifiers);
        self.rt.block_on(pending)
    }

    /// Returns the validated `root29` the server reported at connect time.
    pub fn root29(&self) -> Fp {
        let Self { inner, .. } = self;
        inner.root29
    }
}
pub fn fetch_proof_local(
tier0_data: &[u8],
tier1_data: &[u8],
tier2_data: &[u8],
num_ranges: usize,
nullifier: Fp,
empty_hashes: &[Fp; TREE_DEPTH],
root29: Fp,
) -> Result<ImtProofData> {
let mut path = [Fp::default(); TREE_DEPTH];
let tier0 = Tier0Data::from_bytes(tier0_data.to_vec())?;
let s1 = process_tier0(&tier0, nullifier, &mut path)?;
let t1_offset = s1 * TIER1_ROW_BYTES;
anyhow::ensure!(
t1_offset + TIER1_ROW_BYTES <= tier1_data.len(),
"tier1 data too short: need {} bytes at offset {}, have {}",
TIER1_ROW_BYTES,
t1_offset,
tier1_data.len()
);
let s2 = process_tier1(
&tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
nullifier,
&mut path,
)?;
let t2_row_idx = s1 * TIER1_LEAVES + s2;
let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
anyhow::ensure!(
t2_offset + TIER2_ROW_BYTES <= tier2_data.len(),
"tier2 data too short: need {} bytes at offset {}, have {}",
TIER2_ROW_BYTES,
t2_offset,
tier2_data.len()
);
process_tier2_and_build(
&tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
t2_row_idx,
num_ranges,
nullifier,
&mut path,
empty_hashes,
root29,
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ff::Field;
    use pasta_curves::Fp;
    use pir_export::build_ranges_with_sentinels;

    /// In-memory export of all three tiers for a tree built from a nullifier set.
    /// It also carries the reference values (`ranges`, `empty_hashes`, `root29`) that
    /// proofs are checked against.
    struct TestFixture {
        tier0_data: Vec<u8>,
        tier1_data: Vec<u8>,
        tier2_data: Vec<u8>,
        ranges: Vec<[Fp; 3]>,
        empty_hashes: [Fp; TREE_DEPTH],
        root29: Fp,
    }

    impl TestFixture {
        /// Builds the PIR tree for `raw_nfs` (with sentinel ranges added by
        /// `build_ranges_with_sentinels`) and exports tiers 0, 1 and 2 into byte buffers.
        fn build(raw_nfs: &[Fp]) -> Self {
            let ranges = build_ranges_with_sentinels(raw_nfs);
            let tree = pir_export::build_pir_tree(ranges.clone()).unwrap();
            let tier0_data = pir_export::tier0::export(
                &tree.root25,
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
            );
            let mut tier1_data = Vec::new();
            pir_export::tier1::export(
                &tree.levels,
                &tree.ranges,
                &tree.empty_hashes,
                &mut tier1_data,
            )
            .unwrap();
            let mut tier2_data = Vec::new();
            pir_export::tier2::export(&tree.ranges, &mut tier2_data).unwrap();
            Self {
                tier0_data,
                tier1_data,
                tier2_data,
                ranges,
                empty_hashes: tree.empty_hashes,
                root29: tree.root29,
            }
        }
    }

    /// End-to-end: for random nullifiers, a value just above each of the first 20 range
    /// lower bounds yields a proof that verifies.
    #[test]
    fn fetch_proof_local_verifies_for_known_ranges() {
        let mut rng = rand::thread_rng();
        let raw_nfs: Vec<Fp> = (0..100).map(|_| Fp::random(&mut rng)).collect();
        let fix = TestFixture::build(&raw_nfs);
        for &[nf_lo, _, _] in fix.ranges.iter().take(20) {
            let value = nf_lo + Fp::one();
            let proof = fetch_proof_local(
                &fix.tier0_data,
                &fix.tier1_data,
                &fix.tier2_data,
                fix.ranges.len(),
                value,
                &fix.empty_hashes,
                fix.root29,
            )
            .expect("fetch_proof_local should succeed for a value in range");
            assert!(
                proof.verify(value),
                "proof should verify for value {:?}",
                value,
            );
        }
    }

    /// The assembled proof carries the tree's root29 and a full-depth path.
    #[test]
    fn fetch_proof_local_correct_root_and_path_length() {
        let raw_nfs: Vec<Fp> = (1u64..=50).map(|i| Fp::from(i * 997)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let value = fix.ranges[0][0] + Fp::one(); let proof = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data,
            fix.ranges.len(),
            value,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();
        assert_eq!(proof.root, fix.root29);
        assert_eq!(proof.path.len(), TREE_DEPTH);
    }

    /// Tier 0 writes only its own slice of the path, `[PIR_DEPTH - TIER0_LAYERS, PIR_DEPTH)`.
    #[test]
    fn process_tier0_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();
        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
        let tier0_region = &path[PIR_DEPTH - TIER0_LAYERS..PIR_DEPTH];
        assert!(
            tier0_region.iter().any(|&v| v != Fp::default()),
            "tier0 should write at least one non-zero sibling"
        );
        let below = &path[..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            below.iter().all(|&v| v == Fp::default()),
            "path below tier0 region should be untouched"
        );
    }

    /// A value that is not in the nullifier set still maps to an in-range subtree.
    /// Presumably this holds because the sentinel-augmented ranges cover the whole field
    /// (TODO confirm).
    #[test]
    fn process_tier0_handles_arbitrary_field_element() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data).unwrap();
        let bogus = Fp::from(u64::MAX);
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, bogus, &mut path).unwrap();
        assert!(s1 < pir_types::TIER1_ROWS);
    }

    /// Tier 1 writes non-default siblings into the region directly below Tier 0's.
    #[test]
    fn process_tier1_fills_correct_path_region() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();
        let value = fix.ranges[0][0];
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let tier1_row = &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES];
        let s2 = process_tier1(tier1_row, value, &mut path).unwrap();
        assert!(s2 < TIER1_LEAVES);
        let tier1_region = &path[PIR_DEPTH - TIER0_LAYERS - TIER1_LAYERS..PIR_DEPTH - TIER0_LAYERS];
        assert!(
            tier1_region.iter().any(|&v| v != Fp::default()),
            "tier1 should write at least one non-zero sibling"
        );
    }

    /// Drives tiers 0 → 1 → 2 manually, without `fetch_proof_local`, and checks that
    /// the result verifies.
    #[test]
    fn process_tier2_and_build_produces_verifiable_proof() {
        let raw_nfs: Vec<Fp> = (1u64..=30).map(|i| Fp::from(i * 1013)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let tier0 = Tier0Data::from_bytes(fix.tier0_data.clone()).unwrap();
        let value = fix.ranges[0][0] + Fp::one();
        let mut path = [Fp::default(); TREE_DEPTH];
        let s1 = process_tier0(&tier0, value, &mut path).unwrap();
        let t1_offset = s1 * TIER1_ROW_BYTES;
        let s2 = process_tier1(
            &fix.tier1_data[t1_offset..t1_offset + TIER1_ROW_BYTES],
            value,
            &mut path,
        )
        .unwrap();
        let t2_row_idx = s1 * TIER1_LEAVES + s2;
        let t2_offset = t2_row_idx * TIER2_ROW_BYTES;
        let proof = process_tier2_and_build(
            &fix.tier2_data[t2_offset..t2_offset + TIER2_ROW_BYTES],
            t2_row_idx,
            fix.ranges.len(),
            value,
            &mut path,
            &fix.empty_hashes,
            fix.root29,
        )
        .unwrap();
        assert!(proof.verify(value));
        assert_eq!(proof.root, fix.root29);
    }

    /// Full rows, a partial last row, and rows past the end.
    #[test]
    fn valid_leaves_for_row_basic() {
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 0), TIER2_LEAVES);
        assert_eq!(valid_leaves_for_row(TIER2_LEAVES + 1, 1), 1);
        assert_eq!(valid_leaves_for_row(0, 0), 0);
        assert_eq!(valid_leaves_for_row(1, 0), 1);
        assert_eq!(valid_leaves_for_row(1, 1), 0);
    }

    /// A tier 1 table shorter than one row must yield an error, not a slice panic.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier1() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data[..TIER1_ROW_BYTES / 2],
            &fix.tier2_data,
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// A tier 2 table shorter than one row must yield an error, not a slice panic.
    #[test]
    fn fetch_proof_local_rejects_truncated_tier2() {
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let fix = TestFixture::build(&raw_nfs);
        let result = fetch_proof_local(
            &fix.tier0_data,
            &fix.tier1_data,
            &fix.tier2_data[..TIER2_ROW_BYTES / 2],
            fix.ranges.len(),
            fix.ranges[0][0],
            &fix.empty_hashes,
            fix.root29,
        );
        assert!(result.is_err());
    }

    /// Error-oracle regression test. The mock server returns garbage bytes for the tier 1
    /// query, so tier 1 processing fails. The client must still send exactly one tier 2
    /// query; its index falls back to 0 in `fetch_proof_inner`.
    #[tokio::test]
    async fn tier2_query_sent_despite_tier1_decode_failure() {
        use ff::PrimeField as _;
        use pir_types::{TIER1_ITEM_BITS, TIER1_ROWS, TIER2_ITEM_BITS};
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let raw_nfs: Vec<Fp> = (1u64..=10).map(|i| Fp::from(i * 7)).collect();
        let ranges = build_ranges_with_sentinels(&raw_nfs);
        let tree = pir_export::build_pir_tree(ranges).unwrap();
        let tier0_data =
            pir_export::tier0::export(&tree.root25, &tree.levels, &tree.ranges, &tree.empty_hashes);
        let root_info = pir_types::RootInfo {
            root29: hex::encode(tree.root29.to_repr()),
            root25: hex::encode(tree.root25.to_repr()),
            num_ranges: tree.ranges.len(),
            pir_depth: PIR_DEPTH,
            height: None,
        };
        let tier1_scenario = YpirScenario {
            num_items: TIER1_ROWS,
            item_size_bits: TIER1_ITEM_BITS,
        };
        // NOTE(review): num_items reuses TIER1_ROWS here; it only needs to be > 0 so the
        // fallback tier 2 query (row 0) passes the client-side bounds check.
        let tier2_scenario = YpirScenario {
            num_items: TIER1_ROWS,
            item_size_bits: TIER2_ITEM_BITS,
        };
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/tier0"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(tier0_data))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier1_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/params/tier2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&tier2_scenario))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/root"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&root_info))
            .mount(&server)
            .await;
        // Garbage "encrypted" responses; tier 1 decoding/processing should fail on these.
        Mock::given(method("POST"))
            .and(path("/tier1/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xDE; 65536]))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/tier2/query"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(vec![0xAD; 65536]))
            .mount(&server)
            .await;
        let client = PirClient::connect(&server.uri()).await.unwrap();
        let nullifier = tree.ranges[0][0];
        let result = client.fetch_proof(nullifier).await;
        assert!(
            result.is_err(),
            "fetch_proof should fail with corrupted tier1 response"
        );
        let received = server.received_requests().await.unwrap();
        let tier1_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier1/query")
            .count();
        let tier2_hits = received
            .iter()
            .filter(|r| r.url.path() == "/tier2/query")
            .count();
        assert_eq!(tier1_hits, 1, "tier1 query should have been sent");
        assert_eq!(
            tier2_hits, 1,
            "tier2 query must still be sent when tier1 decode fails \
            (error-oracle mitigation)"
        );
    }
}