use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use std::time::Instant;
use clap::{Parser, Subcommand, ValueEnum};
use eyre::{Context, Result};
use indicatif::{ProgressBar, ProgressStyle};
use serde::Deserialize;
use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber;
use inspire::ethereum_db::EthereumStateDb;
use inspire::math::GaussianSampler;
use inspire::params::{InspireVariant, ShardConfig};
use inspire::pir::{
extract_inspiring, extract_with_variant, query, query_seeded, ClientQuery, PackingMode,
SeededClientQuery, ServerCrs, ServerResponse,
};
use inspire::rlwe::RlweSecretKey;
// Top-level command-line arguments. `about` is set explicitly below, so this is a plain
// comment (a `///` doc comment here would also become clap's long `--help` text). Each
// field's `///` comment is rendered by clap as that flag's help text.
#[derive(Parser)]
#[command(name = "inspire-client")]
#[command(about = "InsPIRe PIR client")]
#[command(version)]
struct Args {
    /// Base URL of the PIR server (endpoint paths such as /crs and /query are appended)
    #[arg(long, default_value = "http://localhost:3000")]
    server: String,
    /// Path to the JSON-encoded RLWE secret key
    #[arg(long, default_value = "inspire_data/secret_key.json")]
    secret_key: PathBuf,
    /// Location of the local Ethereum state database (state.bin), used by `storage`
    #[arg(long, default_value = "inspire_data")]
    state_path: PathBuf,
    /// InsPIRe variant: one-packing sends a full query, two-packing a seeded query
    #[arg(long, value_enum, default_value = "two-packing")]
    variant: VariantChoice,
    /// Packing mode requested from the server and used to extract the result
    #[arg(long, value_enum, default_value = "inspiring")]
    packing_mode: PackingModeChoice,
    #[command(subcommand)]
    command: Commands,
}
// Subcommands of the client. Each variant's `///` comment becomes its description in
// `--help`, and each field's `///` comment becomes that argument's help text.
#[derive(Subcommand)]
enum Commands {
    /// Find a contract storage slot in the local state database, then fetch it privately
    Storage {
        /// Contract address: 40 hex characters, optional 0x prefix
        #[arg(long)]
        address: String,
        /// Storage slot: 64 hex characters, optional 0x prefix
        #[arg(long)]
        slot: String,
    },
    /// Privately fetch the database entry at a given index
    Index {
        /// Database index of the entry to retrieve
        #[arg(long)]
        index: u64,
    },
    /// Privately fetch every index listed in a file
    Batch {
        /// File with one index per line; blank lines and '#' comment lines are skipped
        #[arg(long)]
        file: PathBuf,
    },
    /// Print the server's PIR parameters
    Params,
    /// Check that the server is reachable and print its status and version
    Health,
}
// CLI selector for the InsPIRe protocol variant (`--variant`); converted to the library's
// `InspireVariant` by `VariantChoice::inspire_variant`. Variant doc comments are shown by
// clap as possible-value help.
#[derive(Clone, Copy, Debug, ValueEnum)]
#[value(rename_all = "kebab_case")]
enum VariantChoice {
    /// One-packing: send a full (unseeded) client query
    OnePacking,
    /// Two-packing: send a seeded client query
    TwoPacking,
}
// CLI selector for the response packing mode (`--packing-mode`); converted to the
// library's `PackingMode` via `From`. Variant doc comments are shown by clap as
// possible-value help.
#[derive(Clone, Copy, Debug, ValueEnum)]
#[value(rename_all = "kebab_case")]
enum PackingModeChoice {
    /// Inspiring packing
    Inspiring,
    /// Tree packing; the query is sent without Inspiring packing keys
    Tree,
}
impl From<PackingModeChoice> for PackingMode {
    /// Translates the CLI `--packing-mode` selection into the library's [`PackingMode`].
    fn from(choice: PackingModeChoice) -> Self {
        match choice {
            PackingModeChoice::Tree => Self::Tree,
            PackingModeChoice::Inspiring => Self::Inspiring,
        }
    }
}
/// A generated PIR query ready to send, in the wire form matching the selected variant.
///
/// `query_by_index` and `batch_query` build `Full` for `VariantChoice::OnePacking`
/// (posted to `/query`) and `Seeded` for `VariantChoice::TwoPacking` (posted to
/// `/query_seeded`).
enum QueryPayload {
    /// Complete client query, sent with `send_query`.
    Full(ClientQuery),
    /// Seeded client query, sent with `send_seeded_query`.
    Seeded(SeededClientQuery),
}
impl VariantChoice {
    /// Maps this CLI selection onto the corresponding library [`InspireVariant`].
    fn inspire_variant(self) -> InspireVariant {
        match self {
            Self::TwoPacking => InspireVariant::TwoPacking,
            Self::OnePacking => InspireVariant::OnePacking,
        }
    }
}
/// JSON body returned by `GET {server}/params`.
///
/// Only `entry_count` feeds into query construction (as `ShardConfig::total_entries`);
/// every field is printed by `get_params`.
#[derive(Deserialize)]
struct ParamsResponse {
    /// Version string reported by the server.
    version: String,
    /// Ring dimension of the server's parameters.
    ring_dim: usize,
    /// Modulus, transmitted as a string.
    /// NOTE(review): presumably because it may not fit in a JSON number/u64 — confirm.
    modulus: String,
    /// Plaintext modulus.
    plaintext_modulus: u64,
    /// Gadget decomposition base.
    gadget_base: u64,
    /// Gadget decomposition length.
    gadget_len: usize,
    /// Total number of database entries.
    entry_count: u64,
    /// Number of shards the server's database is split into.
    shard_count: usize,
    /// Number of CRS vectors (printed as "CRS vectors").
    crs_a_vectors_count: usize,
}
/// JSON body returned by the `/query` and `/query_seeded` endpoints.
#[derive(Deserialize)]
struct QueryResponse {
    /// PIR response, decoded by `extract_inspiring` or `extract_with_variant`.
    response: ServerResponse,
    /// Server-reported processing time in milliseconds.
    processing_time_ms: u64,
}
/// JSON body returned by `GET {server}/health`; both fields are printed by `check_health`.
#[derive(Deserialize)]
struct HealthResponse {
    /// Server-reported health status.
    status: String,
    /// Server-reported version string.
    version: String,
}
#[tokio::main]
async fn main() -> Result<()> {
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)
.with_target(false)
.finish();
tracing::subscriber::set_global_default(subscriber)?;
let args = Args::parse();
match args.command {
Commands::Health => {
check_health(&args.server).await?;
}
Commands::Params => {
get_params(&args.server).await?;
}
Commands::Index { index } => {
query_by_index(
&args.server,
&args.secret_key,
index,
args.variant,
args.packing_mode,
)
.await?;
}
Commands::Storage { address, slot } => {
query_storage(
&args.server,
&args.secret_key,
&args.state_path,
&address,
&slot,
args.variant,
args.packing_mode,
)
.await?;
}
Commands::Batch { file } => {
batch_query(
&args.server,
&args.secret_key,
&file,
args.variant,
args.packing_mode,
)
.await?;
}
}
Ok(())
}
/// Calls `GET {server}/health` and prints the status and version the server reports.
///
/// The HTTP status code is not checked: any response whose body parses as a
/// [`HealthResponse`] is printed.
///
/// # Errors
/// Fails if the server cannot be reached or the body is not a valid health response.
async fn check_health(server: &str) -> Result<()> {
    let url = format!("{}/health", server);
    let client = reqwest::Client::new();
    let response: HealthResponse = client
        .get(&url)
        .send()
        .await
        .with_context(|| "Failed to connect to server")?
        .json()
        .await
        .with_context(|| "Failed to parse health response")?;
    println!("Server status: {}", response.status);
    println!("Server version: {}", response.version);
    Ok(())
}
async fn get_params(server: &str) -> Result<()> {
let url = format!("{}/params", server);
let client = reqwest::Client::new();
let params: ParamsResponse = client.get(&url).send().await?.json().await?;
println!("=== Server Parameters ===");
println!("Version: {}", params.version);
println!("Ring dimension: {}", params.ring_dim);
println!("Modulus: {}", params.modulus);
println!("Plaintext modulus: {}", params.plaintext_modulus);
println!("Gadget base: {}", params.gadget_base);
println!("Gadget length: {}", params.gadget_len);
println!("Entry count: {}", params.entry_count);
println!("Shard count: {}", params.shard_count);
println!("CRS vectors: {}", params.crs_a_vectors_count);
Ok(())
}
/// Privately retrieves the entry at `index` from the PIR server and prints it with timings.
///
/// The steps are:
/// - load the secret key from `sk_path`;
/// - fetch the server CRS and parameters;
/// - build a query for `variant` (one-packing posts a full [`ClientQuery`] to `/query`,
///   two-packing a [`SeededClientQuery`] to `/query_seeded`);
/// - decode the response according to `packing_mode`.
///
/// Entries are treated as 32 bytes. Each phase duration is measured once and reused for
/// the final summary.
///
/// # Errors
/// Fails if the secret key cannot be loaded, a server request fails, or query generation
/// or result extraction fails.
async fn query_by_index(
    server: &str,
    sk_path: &PathBuf,
    index: u64,
    variant: VariantChoice,
    packing_mode: PackingModeChoice,
) -> Result<()> {
    let total_start = Instant::now();

    info!("Loading secret key...");
    let secret_key = load_secret_key(sk_path)?;

    info!("Fetching CRS from server...");
    let fetch_start = Instant::now();
    let crs = fetch_crs(server).await?;
    let params_response = fetch_params(server).await?;
    info!("CRS fetch time: {:.2?}", fetch_start.elapsed());

    // 32-byte entries with `ring_dim` entries per shard.
    let shard_config = ShardConfig {
        shard_size_bytes: (crs.ring_dim() as u64) * 32,
        entry_size_bytes: 32,
        total_entries: params_response.entry_count,
    };

    info!("Generating PIR query for index {}...", index);
    let query_start = Instant::now();
    let mut sampler = GaussianSampler::new(crs.params.sigma);
    let inspire_variant = variant.inspire_variant();
    let mode: PackingMode = packing_mode.into();
    let (state, payload) = match variant {
        VariantChoice::OnePacking => {
            let (state, mut client_query) =
                query(&crs, index, &shard_config, &secret_key, &mut sampler)
                    .with_context(|| "Failed to generate query")?;
            client_query.packing_mode = mode;
            // Tree mode clears the Inspiring packing keys before sending
            // (presumably unused by the server in that mode — confirm).
            if mode == PackingMode::Tree {
                client_query.inspiring_packing_keys = None;
            }
            (state, QueryPayload::Full(client_query))
        }
        VariantChoice::TwoPacking => {
            let (state, mut seeded_query) =
                query_seeded(&crs, index, &shard_config, &secret_key, &mut sampler)
                    .with_context(|| "Failed to generate seeded query")?;
            seeded_query.packing_mode = mode;
            if mode == PackingMode::Tree {
                seeded_query.inspiring_packing_keys = None;
            }
            (state, QueryPayload::Seeded(seeded_query))
        }
    };
    // Capture the duration now: calling `query_start.elapsed()` again at summary time
    // would also count the network round-trip and extraction.
    let query_time = query_start.elapsed();
    info!("Query generation time: {:.2?}", query_time);

    info!("Sending query to server...");
    let send_start = Instant::now();
    let response = match &payload {
        QueryPayload::Full(client_query) => send_query(server, client_query).await?,
        QueryPayload::Seeded(seeded_query) => send_seeded_query(server, seeded_query).await?,
    };
    info!("Server processing time: {} ms", response.processing_time_ms);
    info!("Network round-trip: {:.2?}", send_start.elapsed());

    info!("Extracting result...");
    let extract_start = Instant::now();
    let entry = match packing_mode {
        PackingModeChoice::Inspiring => extract_inspiring(&crs, &state, &response.response, 32),
        PackingModeChoice::Tree => {
            extract_with_variant(&crs, &state, &response.response, 32, inspire_variant)
        }
    }
    .with_context(|| "Failed to extract result")?;
    let extract_time = extract_start.elapsed();
    info!("Extraction time: {:.2?}", extract_time);

    let total_time = total_start.elapsed();
    println!();
    println!("=== Query Result ===");
    println!("Index: {}", index);
    println!("Shard: {}", state.shard_id);
    println!("Local index: {}", state.local_index);
    println!("Entry (hex): 0x{}", hex_encode(&entry));
    println!();
    println!("=== Timing ===");
    println!("Query generation: {:.2?}", query_time);
    println!("Server processing: {} ms", response.processing_time_ms);
    println!("Extraction: {:.2?}", extract_time);
    println!("Total: {:.2?}", total_time);
    Ok(())
}
/// Resolves a contract storage slot to its database index using the local state database,
/// then privately fetches that entry via [`query_by_index`].
///
/// The lookup is a linear scan over every entry of the local [`EthereumStateDb`]. If no
/// entry matches, a message is printed and `Ok(())` is returned without contacting the
/// server.
///
/// # Errors
/// Fails on malformed `address`/`slot` hex, if the state database cannot be opened or read,
/// or if the subsequent PIR query fails.
async fn query_storage(
    server: &str,
    sk_path: &PathBuf,
    state_path: &Path,
    address: &str,
    slot: &str,
    variant: VariantChoice,
    packing_mode: PackingModeChoice,
) -> Result<()> {
    let wanted_address = parse_hex_address(address)?;
    let wanted_slot = parse_hex_slot(slot)?;

    info!("Loading state.bin from {}...", state_path.display());
    let eth_db = EthereumStateDb::open(state_path)
        .with_context(|| format!("Failed to load state.bin from {}", state_path.display()))?;

    info!("Searching for storage slot (linear scan)...");
    let scan_start = Instant::now();
    let mut hit: Option<u64> = None;
    for candidate in 0..eth_db.entry_count() {
        let record = eth_db.read_storage_entry(candidate)?;
        if record.address == wanted_address && record.slot == wanted_slot {
            hit = Some(candidate);
            break;
        }
    }
    info!("Search time: {:.2?}", scan_start.elapsed());

    let Some(index) = hit else {
        println!(
            "Storage slot not found: address=0x{}, slot=0x{}",
            hex_encode(&wanted_address),
            hex_encode(&wanted_slot)
        );
        println!("The storage slot is not in the database.");
        return Ok(());
    };

    info!("Found storage slot at index {}", index);
    query_by_index(server, sk_path, index, variant, packing_mode).await
}
async fn batch_query(
server: &str,
sk_path: &PathBuf,
file: &PathBuf,
variant: VariantChoice,
packing_mode: PackingModeChoice,
) -> Result<()> {
info!("Loading secret key...");
let secret_key = load_secret_key(sk_path)?;
let batch_file = File::open(file)
.with_context(|| format!("Failed to open batch file: {}", file.display()))?;
let reader = BufReader::new(batch_file);
let indices: Vec<u64> = reader
.lines()
.filter_map(|line| {
line.ok().and_then(|l| {
let trimmed = l.trim();
if trimmed.is_empty() || trimmed.starts_with('#') {
None
} else {
trimmed.parse().ok()
}
})
})
.collect();
if indices.is_empty() {
println!("No valid indices found in batch file");
return Ok(());
}
println!("Processing {} queries...", indices.len());
info!("Fetching CRS from server...");
let crs = fetch_crs(server).await?;
let params_response = fetch_params(server).await?;
let shard_config = ShardConfig {
shard_size_bytes: (crs.ring_dim() as u64) * 32,
entry_size_bytes: 32,
total_entries: params_response.entry_count,
};
let pb = ProgressBar::new(indices.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template(
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
)?
.progress_chars("#>-"),
);
let mut results = Vec::new();
let mut total_server_time = 0u64;
let inspire_variant = variant.inspire_variant();
for index in &indices {
let mut sampler = GaussianSampler::new(crs.params.sigma);
let (state, payload) = match variant {
VariantChoice::OnePacking => {
let (state, mut client_query) =
query(&crs, *index, &shard_config, &secret_key, &mut sampler)?;
let packing_mode: PackingMode = packing_mode.into();
client_query.packing_mode = packing_mode;
if packing_mode == PackingMode::Tree {
client_query.inspiring_packing_keys = None;
}
(state, QueryPayload::Full(client_query))
}
VariantChoice::TwoPacking => {
let (state, mut seeded_query) =
query_seeded(&crs, *index, &shard_config, &secret_key, &mut sampler)?;
let packing_mode: PackingMode = packing_mode.into();
seeded_query.packing_mode = packing_mode;
if packing_mode == PackingMode::Tree {
seeded_query.inspiring_packing_keys = None;
}
(state, QueryPayload::Seeded(seeded_query))
}
};
let response = match &payload {
QueryPayload::Full(client_query) => send_query(server, client_query).await?,
QueryPayload::Seeded(seeded_query) => send_seeded_query(server, seeded_query).await?,
};
total_server_time += response.processing_time_ms;
let entry = match packing_mode {
PackingModeChoice::Inspiring => extract_inspiring(&crs, &state, &response.response, 32),
PackingModeChoice::Tree => {
extract_with_variant(&crs, &state, &response.response, 32, inspire_variant)
}
}?;
results.push((*index, entry));
pb.inc(1);
}
pb.finish_with_message("Done");
println!();
println!("=== Batch Results ===");
for (index, entry) in &results {
println!("Index {}: 0x{}", index, hex_encode(entry));
}
println!();
println!("Total queries: {}", results.len());
println!("Total server time: {} ms", total_server_time);
println!(
"Average server time: {:.2} ms",
total_server_time as f64 / results.len() as f64
);
Ok(())
}
/// Downloads the server's CRS from `GET {server}/crs`.
///
/// # Errors
/// Fails if the server is unreachable, responds with a non-success HTTP status, or the body
/// cannot be deserialized into a [`ServerCrs`].
async fn fetch_crs(server: &str) -> Result<ServerCrs> {
    let url = format!("{}/crs", server);
    let client = reqwest::Client::new();
    let crs: ServerCrs = client
        .get(&url)
        .send()
        .await
        .with_context(|| "Failed to connect to server")?
        // Report HTTP errors (404, 500, ...) as such instead of as a JSON parse failure.
        .error_for_status()
        .with_context(|| "Server returned an error for CRS request")?
        .json()
        .await
        .with_context(|| "Failed to parse CRS response")?;
    Ok(crs)
}
/// Fetches the server's PIR parameters from `GET {server}/params`.
///
/// # Errors
/// Fails if the server is unreachable, responds with a non-success HTTP status, or the body
/// cannot be deserialized into a [`ParamsResponse`].
async fn fetch_params(server: &str) -> Result<ParamsResponse> {
    let url = format!("{}/params", server);
    let client = reqwest::Client::new();
    let params: ParamsResponse = client
        .get(&url)
        .send()
        .await
        .with_context(|| "Failed to connect to server")?
        // Report HTTP errors (404, 500, ...) as such instead of as a JSON parse failure.
        .error_for_status()
        .with_context(|| "Server returned an error for params request")?
        .json()
        .await
        .with_context(|| "Failed to parse params response")?;
    Ok(params)
}
fn load_secret_key(path: &PathBuf) -> Result<RlweSecretKey> {
let file = File::open(path)
.with_context(|| format!("Failed to open secret key file: {}", path.display()))?;
let reader = BufReader::new(file);
let sk: RlweSecretKey =
serde_json::from_reader(reader).with_context(|| "Failed to parse secret key")?;
Ok(sk)
}
/// Posts a full (unseeded) client query as JSON to `{server}/query` and parses the reply.
///
/// # Errors
/// Fails if the request cannot be sent, the server responds with a non-success HTTP status,
/// or the body is not a valid [`QueryResponse`].
async fn send_query(server: &str, query: &ClientQuery) -> Result<QueryResponse> {
    let url = format!("{}/query", server);
    let client = reqwest::Client::new();
    let response: QueryResponse = client
        .post(&url)
        .json(query)
        .send()
        .await
        .with_context(|| "Failed to send query")?
        // Surface server-side rejections as HTTP errors, not as JSON parse failures.
        .error_for_status()
        .with_context(|| "Server rejected query")?
        .json()
        .await
        .with_context(|| "Failed to parse query response")?;
    Ok(response)
}
/// Posts a seeded client query as JSON to `{server}/query_seeded` and parses the reply.
///
/// # Errors
/// Fails if the request cannot be sent, the server responds with a non-success HTTP status,
/// or the body is not a valid [`QueryResponse`].
async fn send_seeded_query(server: &str, query: &SeededClientQuery) -> Result<QueryResponse> {
    let url = format!("{}/query_seeded", server);
    let client = reqwest::Client::new();
    let response: QueryResponse = client
        .post(&url)
        .json(query)
        .send()
        .await
        .with_context(|| "Failed to send seeded query")?
        // Surface server-side rejections as HTTP errors, not as JSON parse failures.
        .error_for_status()
        .with_context(|| "Server rejected seeded query")?
        .json()
        .await
        .with_context(|| "Failed to parse seeded query response")?;
    Ok(response)
}
/// Parses a 20-byte contract address from hex, accepting an optional lowercase `0x` prefix.
///
/// # Errors
/// Fails unless exactly 40 bytes of text follow the optional prefix and they are valid hex.
fn parse_hex_address(s: &str) -> Result<[u8; 20]> {
    let digits = match s.strip_prefix("0x") {
        Some(rest) => rest,
        None => s,
    };
    if digits.len() == 40 {
        let decoded = hex_decode(digits)?;
        let mut address = [0u8; 20];
        address.copy_from_slice(&decoded);
        Ok(address)
    } else {
        Err(eyre::eyre!(
            "Invalid address length: expected 40 hex chars, got {}",
            digits.len()
        ))
    }
}
/// Parses a 32-byte storage slot from hex, accepting an optional lowercase `0x` prefix.
///
/// # Errors
/// Fails unless exactly 64 bytes of text follow the optional prefix and they are valid hex.
fn parse_hex_slot(s: &str) -> Result<[u8; 32]> {
    let digits = match s.strip_prefix("0x") {
        Some(rest) => rest,
        None => s,
    };
    if digits.len() == 64 {
        let decoded = hex_decode(digits)?;
        let mut slot = [0u8; 32];
        slot.copy_from_slice(&decoded);
        Ok(slot)
    } else {
        Err(eyre::eyre!(
            "Invalid slot length: expected 64 hex chars, got {}",
            digits.len()
        ))
    }
}
/// Renders `bytes` as a lowercase hexadecimal string (two digits per byte, no `0x` prefix).
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            // Writing into a `String` cannot fail, so the `fmt::Result` is safely ignored.
            let _ = write!(acc, "{:02x}", byte);
            acc
        })
}
/// Decodes a hexadecimal string into bytes. Callers strip any `0x` prefix beforehand.
///
/// # Errors
/// Returns an error prefixed with `Invalid hex:` when `hex::decode` rejects the input.
fn hex_decode(s: &str) -> Result<Vec<u8>> {
    match hex::decode(s) {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(eyre::eyre!("Invalid hex: {}", e)),
    }
}