use crate::{
network_content::NetworkContent,
preprocess_info::PreprocessInfo,
profile_summary::ProfileSummary,
};
use std::{
collections::{hash_map::Entry, HashMap},
error::Error,
str::FromStr,
};
/// Performance profile recorded for a single circuit.
///
/// Every field is exported as one or more columns of the performance CSV/JSON output built by
/// `csv_header`, `json_fields` and `usize_values`; the weight columns in that output are derived
/// from these fields rather than stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProfileInfo {
    /// Gate count (CSV column "Gates", JSON key "total_gates").
    pub total_gates: usize,
    /// Network depth (CSV column "Depth"). `PerformanceStore::track` rejects any measurement whose
    /// depth exceeds the stored expectation.
    pub network_depth: usize,
    /// Network content broken down per category (CSV columns "🛜Bits" through "🛜Pts").
    pub network_content: NetworkContent,
    /// Preprocessing material (CSV columns "daBits" through "Binglets").
    pub pre_process: PreprocessInfo,
    /// Circuit hash (CSV column "Hash"); presumably fingerprints the circuit structure — TODO confirm.
    pub circuit_hash: usize,
}
impl ProfileInfo {
    /// Number of columns shared by [`ProfileInfo::csv_header`], [`ProfileInfo::json_fields`] and
    /// [`ProfileInfo::usize_values`].
    ///
    /// The three arrays are typed with this length, so adding a column to one of them without
    /// the others becomes a compile error instead of a silently misaligned CSV/JSON row.
    const FIELD_COUNT: usize = 20;

    /// Human-readable CSV column titles, in the same order as [`ProfileInfo::usize_values`].
    pub fn csv_header() -> impl Iterator<Item = &'static str> {
        let header: [&'static str; Self::FIELD_COUNT] = [
            "Weight",
            "Gate W.",
            "Depth W.",
            "🛜Size W.",
            "Prepro. W.",
            "Gates",
            "Depth",
            "🛜Size",
            "🛜Bits",
            "🛜Base",
            "🛜Scal.",
            "🛜M107",
            "🛜Pts",
            "daBits",
            "Triples",
            "Singlets",
            "Pows",
            "BiTriples",
            "Binglets",
            "Hash",
        ];
        header.into_iter()
    }

    /// JSON keys, in the same order as [`ProfileInfo::usize_values`].
    pub fn json_fields() -> impl Iterator<Item = &'static str> {
        let fields: [&'static str; Self::FIELD_COUNT] = [
            "weight",
            "gate_weight",
            "depth_weight",
            "network_size_weight",
            "preprocess_weight",
            "total_gates",
            "network_depth",
            "network_size",
            "network_bit",
            "network_base",
            "network_scalar",
            "network_mersenne",
            "network_point",
            "da_bits",
            "arith_triples",
            "arith_singlets",
            "pow_pairs",
            "bit_triples",
            "bit_singlets",
            "hash",
        ];
        fields.into_iter()
    }

    /// All exported values, in column order.
    ///
    /// The first five entries and the "network size" entry are derived from the stored fields
    /// through [`ProfileSummary`]; the rest are copied verbatim.
    pub fn usize_values(&self) -> impl Iterator<Item = usize> {
        let summary = self.summary();
        // Each partial weight is the weight of a summary holding only that single component.
        let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0).weight();
        let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0).weight();
        let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0).weight();
        let values: [usize; Self::FIELD_COUNT] = [
            summary.weight(),
            gate_weight,
            depth_weight,
            network_size_weight,
            summary.preprocess_weight,
            self.total_gates,
            self.network_depth,
            summary.network_size,
            self.network_content.bit,
            self.network_content.base,
            self.network_content.scalar,
            self.network_content.mersenne,
            self.network_content.point,
            self.pre_process.da_bits,
            self.pre_process.arith_triples,
            self.pre_process.arith_singlets,
            self.pre_process.pow_pairs,
            self.pre_process.bit_triples,
            self.pre_process.bit_singlets,
            self.circuit_hash,
        ];
        values.into_iter()
    }

    /// [`ProfileInfo::usize_values`] rendered as decimal strings (used for CSV rows).
    pub fn string_values(&self) -> impl Iterator<Item = String> {
        self.usize_values().map(|x| x.to_string())
    }

    /// Builds the [`ProfileSummary`] (depth, gates, network size, preprocessing weight) of this profile.
    fn summary(&self) -> ProfileSummary {
        ProfileSummary::new(
            self.network_depth,
            self.total_gates,
            self.network_content.network_size(),
            self.pre_process.weight(),
        )
    }

    /// Overall weight of this profile, as computed by [`ProfileSummary::weight`].
    pub fn weight(&self) -> usize {
        self.summary().weight()
    }

    /// Serializes all exported values as a flat JSON object keyed by [`ProfileInfo::json_fields`].
    ///
    /// # Panics
    /// Panics if `serde_json` fails to serialize the map.
    pub fn to_json(&self) -> String {
        // A `BTreeMap` serializes its keys in sorted order, so the JSON text is identical from run
        // to run (a `HashMap` would emit the keys in a randomized order).
        let fields: std::collections::BTreeMap<_, _> =
            std::iter::zip(Self::json_fields(), self.usize_values()).collect();
        serde_json::to_string(&fields).expect("ProfileInfo::to_json() failed")
    }
}
impl TryFrom<Vec<usize>> for ProfileInfo {
    type Error = Box<dyn Error>;

    /// Rebuilds a `ProfileInfo` from one row of numbers laid out like [`ProfileInfo::csv_header`].
    ///
    /// Columns 0..=4 (weights) and 7 (total network size) are skipped because they are derived
    /// from the other fields. Values past the last known column are ignored.
    ///
    /// # Errors
    /// Fails when the row is shorter than the header, or when decoding the network-content or
    /// preprocessing slice fails.
    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
        let column_count = Self::csv_header().count();
        let got = value.len();
        if got < column_count {
            return Err(format!("Expected {} records, got {}", column_count, got).into());
        }

        let network_content = NetworkContent::try_from(&value[8..13])?;
        let pre_process = PreprocessInfo::try_from(&value[13..19])?;

        Ok(Self {
            total_gates: value[5],
            network_depth: value[6],
            network_content,
            pre_process,
            circuit_hash: value[19],
        })
    }
}
/// Receiver of per-circuit performance profiles.
pub trait PerformanceTracker {
    /// Returns whether this tracker is active (e.g. the `Option` wrapper reports `false` for `None`).
    fn is_enabled(&self) -> bool;
    /// Records `profile_info` as the measured profile of `circuit`.
    ///
    /// Implementations may reject the measurement by returning an error.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
}
/// In-memory table of expected circuit profiles that can be loaded from and saved to a CSV file.
#[derive(Debug)]
pub struct PerformanceStore {
    // When set, unknown or changed circuits are reported as errors and `to_csv` writes nothing.
    strict_mode: bool,
    // Expected profile for each circuit, keyed by circuit name.
    circuit_performances: HashMap<String, ProfileInfo>,
    // Set once `track` has returned an error; `to_csv` then refuses to save.
    has_errored: bool,
}
impl PerformanceStore {
pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
Self {
strict_mode,
circuit_performances: circuit_depths,
has_errored: false,
}
}
pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
let circuit_depths = if std::fs::exists(csv_file)? {
let mut content = csv::Reader::from_path(csv_file)?;
let mut circuit_depths = HashMap::new();
for record in content.records() {
let record = record?;
let circuit = record[0].to_string();
let mut info_vec = Vec::with_capacity(record.len() - 1);
for s in record.iter().skip(1) {
let info = usize::from_str(s)?;
info_vec.push(info);
}
let profile_info = info_vec.try_into()?;
circuit_depths.insert(circuit, profile_info);
}
circuit_depths
} else {
if strict_mode {
return Err(format!("CSV performance file {} is missing.", csv_file).into());
}
HashMap::new()
};
Ok(Self::new(strict_mode, circuit_depths))
}
pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
if self.strict_mode {
return Ok(());
}
if self.has_errored {
return Err("Cannot save an errored performance store.".into());
}
let mut records: Vec<(&str, ProfileInfo)> = self
.circuit_performances
.iter()
.map(|(s, d)| (s.as_str(), *d))
.collect();
records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
let mut writer = csv::Writer::from_path(csv_file)?;
writer.write_field("Circuit Name")?;
writer.write_record(ProfileInfo::csv_header())?;
for (name, profile_info) in records {
writer.write_field(name)?;
writer.write_record(profile_info.string_values())?;
}
Ok(())
}
pub fn get_instructions(&self) -> Vec<String> {
self.circuit_performances.keys().cloned().collect()
}
}
impl PerformanceTracker for PerformanceStore {
    /// A store is always active, whatever its mode.
    fn is_enabled(&self) -> bool {
        true
    }

    /// Checks `profile_info` against the stored expectation for `circuit`.
    ///
    /// - Unknown circuit: rejected in strict mode; otherwise recorded as the new expectation.
    /// - Deeper than expected: always rejected.
    /// - Any other difference: rejected in strict mode; otherwise replaces the expectation.
    ///
    /// Every rejection also marks the store as errored, which prevents saving it later.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
        let expected = match self.circuit_performances.entry(circuit.to_string()) {
            Entry::Vacant(_) if self.strict_mode => {
                self.has_errored = true;
                return Err(format!("Missing circuit depth for {}", circuit).into());
            }
            slot => slot.or_insert(profile_info),
        };

        if profile_info.network_depth > expected.network_depth {
            self.has_errored = true;
            return Err(format!(
                "Computed circuit depth ({}) is greater than expected ({}) for {}.",
                profile_info.network_depth, expected.network_depth, circuit
            )
            .into());
        }

        if profile_info == *expected {
            return Ok(());
        }

        if !self.strict_mode {
            // Non-strict runs refresh the expectation with the newly measured profile.
            *expected = profile_info;
            return Ok(());
        }

        self.has_errored = true;
        Err(format!(
            "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
            profile_info, *expected, circuit
        )
        .into())
    }
}
impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
    /// Active only when a tracker is present and that tracker reports itself as active.
    fn is_enabled(&self) -> bool {
        self.as_ref().is_some_and(|tracker| tracker.is_enabled())
    }

    /// Forwards to the wrapped tracker.
    ///
    /// # Errors
    /// Fails when there is no wrapped tracker, or when the wrapped tracker rejects the profile.
    fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
        let Some(tracker) = self.as_mut() else {
            return Err("Tried to track in empty performance tracker.".into());
        };
        tracker.track(circuit, depth)
    }
}