1use crate::{
2 auxiliary_circuit_info::AuxiliaryCircuitInfo,
3 network_content::NetworkContent,
4 preprocess_info::PreprocessInfo,
5 profile_summary::ProfileSummary,
6};
7use std::{
8 collections::{hash_map::Entry, HashMap},
9 error::Error,
10 str::FromStr,
11};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14pub struct ProfileInfo {
15 pub total_gates: usize,
16 pub network_depth: usize,
17 pub network_content: NetworkContent,
18 pub pre_process: PreprocessInfo,
19 pub auxiliary_circuits: AuxiliaryCircuitInfo<usize>,
20 pub circuit_hash: usize,
21}
22
23impl ProfileInfo {
24 pub fn csv_header() -> impl Iterator<Item = &'static str> {
25 [
26 "Weight",
27 "Gate W.",
28 "Depth W.",
29 "🛜Size W.",
30 "Prepro. W.",
31 "Assoc. W.",
32 "Gates",
33 "Depth",
34 "🛜Size",
35 "🛜Bits",
36 "🛜Base",
37 "🛜Scal.",
38 "🛜M107",
39 "🛜Pts",
40 "daBits",
41 "Triples",
42 "Singlets",
43 "Pows",
44 "BiTriples",
45 "Binglets",
46 "Base🛟",
47 "Scal.🛟",
48 "Hash",
49 ]
50 .into_iter()
51 }
52 pub fn json_fields() -> impl Iterator<Item = &'static str> {
53 [
54 "weight",
55 "gate_weight",
56 "depth_weight",
57 "network_size_weight",
58 "preprocess_weight",
59 "auxiliary_circuit_weight",
60 "total_gates",
61 "network_depth",
62 "network_size",
63 "network_bit",
64 "network_base",
65 "network_scalar",
66 "network_mersenne",
67 "network_point",
68 "da_bits",
69 "arith_triples",
70 "arith_singlets",
71 "pow_pairs",
72 "bit_triples",
73 "bit_singlets",
74 "base_rescue",
75 "scalar_rescue",
76 "hash",
77 ]
78 .into_iter()
79 }
80 pub fn usize_values(&self) -> impl Iterator<Item = usize> {
81 let summary = self.summary();
82 let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0, 0).weight();
83 let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0, 0).weight();
84 let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0, 0).weight();
85 [
86 summary.weight(),
87 gate_weight,
88 depth_weight,
89 network_size_weight,
90 summary.preprocess_weight,
91 self.auxiliary_circuits.weight(),
92 self.total_gates,
93 self.network_depth,
94 summary.network_size,
95 self.network_content.bit,
96 self.network_content.base,
97 self.network_content.scalar,
98 self.network_content.mersenne,
99 self.network_content.point,
100 self.pre_process.da_bits,
101 self.pre_process.arith_triples,
102 self.pre_process.arith_singlets,
103 self.pre_process.pow_pairs,
104 self.pre_process.bit_triples,
105 self.pre_process.bit_singlets,
106 self.auxiliary_circuits.shared_rescue_base_field_key,
107 self.auxiliary_circuits.shared_rescue_scalar_field_key,
108 self.circuit_hash,
109 ]
110 .into_iter()
111 }
112 pub fn string_values(&self) -> impl Iterator<Item = String> {
113 self.usize_values().map(|x| x.to_string())
114 }
115 fn summary(&self) -> ProfileSummary {
116 ProfileSummary::new(
117 self.network_depth,
118 self.total_gates,
119 self.network_content.network_size(),
120 self.pre_process.weight(),
121 self.auxiliary_circuits.weight(),
122 )
123 }
124 pub fn weight(&self) -> usize {
125 self.summary().weight()
126 }
127 pub fn to_json(&self) -> String {
128 let hash_map = std::iter::zip(Self::json_fields(), self.usize_values())
129 .collect::<std::collections::BTreeMap<_, _>>();
130 serde_json::to_string(&hash_map).expect("ProfileInfo::to_json() failed")
131 }
132}
133
134impl TryFrom<Vec<usize>> for ProfileInfo {
135 type Error = Box<dyn Error>;
136
137 fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
138 let expected = ProfileInfo::csv_header().count();
139 if value.len() < expected {
140 return Err(format!("Expected {} records, got {}", expected, value.len()).into());
141 }
142 let res = ProfileInfo {
143 total_gates: value[6],
144 network_depth: value[7],
145 network_content: NetworkContent::try_from(&value[9..14])?,
146 pre_process: PreprocessInfo::try_from(&value[14..20])?,
147 auxiliary_circuits: AuxiliaryCircuitInfo::try_from(&value[20..22])?,
148 circuit_hash: value[22],
149 };
150 Ok(res)
151 }
152}
153
154pub trait PerformanceTracker {
155 fn is_enabled(&self) -> bool;
156 fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
157}
158
159#[derive(Debug)]
160pub struct PerformanceStore {
161 strict_mode: bool,
162 circuit_performances: HashMap<String, ProfileInfo>,
163 has_errored: bool,
164}
165
166impl PerformanceStore {
167 pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
168 Self {
169 strict_mode,
170 circuit_performances: circuit_depths,
171 has_errored: false,
172 }
173 }
174 pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
175 let circuit_depths = if std::fs::exists(csv_file)? {
176 let mut content = csv::Reader::from_path(csv_file)?;
177 let mut circuit_depths = HashMap::new();
178 for record in content.records() {
179 let record = record?;
180 let circuit = record[0].to_string();
181
182 let mut info_vec = Vec::with_capacity(record.len() - 1);
183
184 for s in record.iter().skip(1) {
185 let info = usize::from_str(s)?;
186 info_vec.push(info);
187 }
188
189 let profile_info = info_vec.try_into()?;
190
191 circuit_depths.insert(circuit, profile_info);
192 }
193 circuit_depths
194 } else {
195 if strict_mode {
196 return Err(format!("CSV performance file {} is missing.", csv_file).into());
197 }
198 HashMap::new()
199 };
200 Ok(Self::new(strict_mode, circuit_depths))
201 }
202 pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
203 if self.strict_mode {
204 return Ok(());
205 }
206 if self.has_errored {
207 return Err("Cannot save an errored performance store.".into());
208 }
209 let mut records: Vec<(&str, ProfileInfo)> = self
210 .circuit_performances
211 .iter()
212 .map(|(s, d)| (s.as_str(), *d))
213 .collect();
214 records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
215 let mut writer = csv::Writer::from_path(csv_file)?;
216 writer.write_field("Circuit Name")?;
217 writer.write_record(ProfileInfo::csv_header())?;
218 for (name, profile_info) in records {
219 writer.write_field(name)?;
220 writer.write_record(profile_info.string_values())?;
221 }
222 Ok(())
223 }
224 pub fn get_instructions(&self) -> Vec<String> {
225 self.circuit_performances.keys().cloned().collect()
226 }
227}
228
229impl PerformanceTracker for PerformanceStore {
230 fn is_enabled(&self) -> bool {
231 true
232 }
233 fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
234 let entry = self.circuit_performances.entry(circuit.to_string());
235 if self.strict_mode && matches!(entry, Entry::Vacant(_)) {
236 self.has_errored = true;
237 return Err(format!("Missing circuit depth for {}", circuit).into());
238 }
239 let entry_val = entry.or_insert(profile_info);
240 if profile_info.network_depth > entry_val.network_depth {
241 self.has_errored = true;
242 return Err(format!(
243 "Computed circuit depth ({}) is greater than expected ({}) for {}.",
244 profile_info.network_depth, entry_val.network_depth, circuit
245 )
246 .into());
247 }
248 if profile_info != *entry_val {
249 if self.strict_mode {
250 self.has_errored = true;
251 return Err(format!(
252 "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
253 profile_info, *entry_val, circuit
254 )
255 .into());
256 } else {
257 *entry_val = profile_info;
258 }
259 }
260
261 Ok(())
262 }
263}
264
265impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
266 fn is_enabled(&self) -> bool {
267 match self {
268 None => false,
269 Some(a) => a.is_enabled(),
270 }
271 }
272
273 fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
274 match self {
275 None => Err("Tried to track in empty performance tracker.".into()),
276 Some(a) => a.track(circuit, depth),
277 }
278 }
279}