1use crate::{
2 network_content::NetworkContent,
3 preprocess_info::PreprocessInfo,
4 profile_summary::ProfileSummary,
5};
6use std::{
7 collections::{hash_map::Entry, HashMap},
8 error::Error,
9 str::FromStr,
10};
11
/// Performance profile of a single circuit: gate count, network depth, per-type network
/// content, preprocessing requirements, and a hash of the circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProfileInfo {
    /// Total number of gates (exported as the "Gates" CSV column).
    pub total_gates: usize,
    /// Network depth of the circuit (exported as the "Depth" CSV column).
    pub network_depth: usize,
    /// Per-type network content (bit, base, scalar, mersenne, point counts).
    pub network_content: NetworkContent,
    /// Preprocessing material (daBits, arithmetic triples/singlets, pow pairs,
    /// bit triples/singlets).
    pub pre_process: PreprocessInfo,
    /// Hash of the circuit (exported as the "Hash" column) — presumably used to detect that
    /// the circuit itself changed; TODO confirm.
    pub circuit_hash: usize,
}
20
21impl ProfileInfo {
22 pub fn csv_header() -> impl Iterator<Item = &'static str> {
23 [
24 "Weight",
25 "Gate W.",
26 "Depth W.",
27 "🛜Size W.",
28 "Prepro. W.",
29 "Gates",
30 "Depth",
31 "🛜Size",
32 "🛜Bits",
33 "🛜Base",
34 "🛜Scal.",
35 "🛜M107",
36 "🛜Pts",
37 "daBits",
38 "Triples",
39 "Singlets",
40 "Pows",
41 "BiTriples",
42 "Binglets",
43 "Hash",
44 ]
45 .into_iter()
46 }
47 pub fn json_fields() -> impl Iterator<Item = &'static str> {
48 [
49 "weight",
50 "gate_weight",
51 "depth_weight",
52 "network_size_weight",
53 "preprocess_weight",
54 "total_gates",
55 "network_depth",
56 "network_size",
57 "network_bit",
58 "network_base",
59 "network_scalar",
60 "network_mersenne",
61 "network_point",
62 "da_bits",
63 "arith_triples",
64 "arith_singlets",
65 "pow_pairs",
66 "bit_triples",
67 "bit_singlets",
68 "hash",
69 ]
70 .into_iter()
71 }
72 pub fn usize_values(&self) -> impl Iterator<Item = usize> {
73 let summary = self.summary();
74 let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0).weight();
75 let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0).weight();
76 let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0).weight();
77 [
78 summary.weight(),
79 gate_weight,
80 depth_weight,
81 network_size_weight,
82 summary.preprocess_weight,
83 self.total_gates,
84 self.network_depth,
85 summary.network_size,
86 self.network_content.bit,
87 self.network_content.base,
88 self.network_content.scalar,
89 self.network_content.mersenne,
90 self.network_content.point,
91 self.pre_process.da_bits,
92 self.pre_process.arith_triples,
93 self.pre_process.arith_singlets,
94 self.pre_process.pow_pairs,
95 self.pre_process.bit_triples,
96 self.pre_process.bit_singlets,
97 self.circuit_hash,
98 ]
99 .into_iter()
100 }
101 pub fn string_values(&self) -> impl Iterator<Item = String> {
102 self.usize_values().map(|x| x.to_string())
103 }
104 fn summary(&self) -> ProfileSummary {
105 ProfileSummary::new(
106 self.network_depth,
107 self.total_gates,
108 self.network_content.network_size(),
109 self.pre_process.weight(),
110 )
111 }
112 pub fn weight(&self) -> usize {
113 self.summary().weight()
114 }
115 pub fn to_json(&self) -> String {
116 let hash_map =
117 std::iter::zip(Self::json_fields(), self.usize_values()).collect::<HashMap<_, _>>();
118 serde_json::to_string(&hash_map).expect("ProfileInfo::to_json() failed")
119 }
120}
121
122impl TryFrom<Vec<usize>> for ProfileInfo {
123 type Error = Box<dyn Error>;
124
125 fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
126 let expected = ProfileInfo::csv_header().count();
127 if value.len() < expected {
128 return Err(format!("Expected {} records, got {}", expected, value.len()).into());
129 }
130 let res = ProfileInfo {
131 total_gates: value[5],
132 network_depth: value[6],
133 network_content: NetworkContent::try_from(&value[8..13])?,
134 pre_process: PreprocessInfo::try_from(&value[13..19])?,
135 circuit_hash: value[19],
136 };
137 Ok(res)
138 }
139}
140
/// Receives the computed profile of each circuit and checks and/or records it.
pub trait PerformanceTracker {
    /// Whether this tracker is active. Presumably callers skip profiling when it returns
    /// `false` — TODO confirm against call sites.
    fn is_enabled(&self) -> bool;
    /// Reports the computed `profile_info` for `circuit`.
    ///
    /// Returns an error when the tracker rejects the profile.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
}
145
/// A [`PerformanceTracker`] that keeps one [`ProfileInfo`] per circuit name.
///
/// It can be loaded from and saved to a CSV file.
#[derive(Debug)]
pub struct PerformanceStore {
    /// When set, tracked circuits must already be present and match their stored profile,
    /// and saving to CSV is a no-op.
    strict_mode: bool,
    /// Profile per circuit name, as loaded or as updated by `track`.
    circuit_performances: HashMap<String, ProfileInfo>,
    /// Set once a `track` call has returned an error; saving is refused afterwards.
    has_errored: bool,
}
152
153impl PerformanceStore {
154 pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
155 Self {
156 strict_mode,
157 circuit_performances: circuit_depths,
158 has_errored: false,
159 }
160 }
161 pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
162 let circuit_depths = if std::fs::exists(csv_file)? {
163 let mut content = csv::Reader::from_path(csv_file)?;
164 let mut circuit_depths = HashMap::new();
165 for record in content.records() {
166 let record = record?;
167 let circuit = record[0].to_string();
168
169 let mut info_vec = Vec::with_capacity(record.len() - 1);
170
171 for s in record.iter().skip(1) {
172 let info = usize::from_str(s)?;
173 info_vec.push(info);
174 }
175
176 let profile_info = info_vec.try_into()?;
177
178 circuit_depths.insert(circuit, profile_info);
179 }
180 circuit_depths
181 } else {
182 if strict_mode {
183 return Err(format!("CSV performance file {} is missing.", csv_file).into());
184 }
185 HashMap::new()
186 };
187 Ok(Self::new(strict_mode, circuit_depths))
188 }
189 pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
190 if self.strict_mode {
191 return Ok(());
192 }
193 if self.has_errored {
194 return Err("Cannot save an errored performance store.".into());
195 }
196 let mut records: Vec<(&str, ProfileInfo)> = self
197 .circuit_performances
198 .iter()
199 .map(|(s, d)| (s.as_str(), *d))
200 .collect();
201 records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
202 let mut writer = csv::Writer::from_path(csv_file)?;
203 writer.write_field("Circuit Name")?;
204 writer.write_record(ProfileInfo::csv_header())?;
205 for (name, profile_info) in records {
206 writer.write_field(name)?;
207 writer.write_record(profile_info.string_values())?;
208 }
209 Ok(())
210 }
211 pub fn get_instructions(&self) -> Vec<String> {
212 self.circuit_performances.keys().cloned().collect()
213 }
214}
215
216impl PerformanceTracker for PerformanceStore {
217 fn is_enabled(&self) -> bool {
218 true
219 }
220 fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
221 let entry = self.circuit_performances.entry(circuit.to_string());
222 if self.strict_mode && matches!(entry, Entry::Vacant(_)) {
223 self.has_errored = true;
224 return Err(format!("Missing circuit depth for {}", circuit).into());
225 }
226 let entry_val = entry.or_insert(profile_info);
227 if profile_info.network_depth > entry_val.network_depth {
228 self.has_errored = true;
229 return Err(format!(
230 "Computed circuit depth ({}) is greater than expected ({}) for {}.",
231 profile_info.network_depth, entry_val.network_depth, circuit
232 )
233 .into());
234 }
235 if profile_info != *entry_val {
236 if self.strict_mode {
237 self.has_errored = true;
238 return Err(format!(
239 "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
240 profile_info, *entry_val, circuit
241 )
242 .into());
243 } else {
244 *entry_val = profile_info;
245 }
246 }
247
248 Ok(())
249 }
250}
251
252impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
253 fn is_enabled(&self) -> bool {
254 match self {
255 None => false,
256 Some(a) => a.is_enabled(),
257 }
258 }
259
260 fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
261 match self {
262 None => Err("Tried to track in empty performance tracker.".into()),
263 Some(a) => a.track(circuit, depth),
264 }
265 }
266}