1use crate::{
2 network_content::NetworkContent,
3 preprocess_info::PreprocessInfo,
4 profile_summary::ProfileSummary,
5};
6use std::{
7 collections::{hash_map::Entry, HashMap},
8 error::Error,
9 str::FromStr,
10};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub struct ProfileInfo {
14 pub total_gates: usize,
15 pub network_depth: usize,
16 pub network_content: NetworkContent,
17 pub pre_process: PreprocessInfo,
18}
19
20impl ProfileInfo {
21 pub fn csv_header() -> impl Iterator<Item = &'static str> {
22 [
23 "Weight",
24 "Gate W.",
25 "Depth W.",
26 "🛜Size W.",
27 "Prepro. W.",
28 "Gates",
29 "Depth",
30 "🛜Size",
31 "🛜Bits",
32 "🛜Base",
33 "🛜Scal.",
34 "🛜M107",
35 "🛜Pts",
36 "daBits",
37 "Triples",
38 "Singlets",
39 "Pows",
40 "BiTriples",
41 "Binglets",
42 ]
43 .into_iter()
44 }
45 pub fn json_fields() -> impl Iterator<Item = &'static str> {
46 [
47 "weight",
48 "gate_weight",
49 "depth_weight",
50 "network_size_weight",
51 "preprocess_weight",
52 "total_gates",
53 "network_depth",
54 "network_size",
55 "network_bit",
56 "network_base",
57 "network_scalar",
58 "network_mersenne",
59 "network_point",
60 "da_bits",
61 "arith_triples",
62 "arith_singlets",
63 "pow_pairs",
64 "bit_triples",
65 "bit_singlets",
66 ]
67 .into_iter()
68 }
69 pub fn usize_values(&self) -> impl Iterator<Item = usize> {
70 let summary = self.summary();
71 let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0).weight();
72 let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0).weight();
73 let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0).weight();
74 [
75 summary.weight(),
76 gate_weight,
77 depth_weight,
78 network_size_weight,
79 summary.preprocess_weight,
80 self.total_gates,
81 self.network_depth,
82 summary.network_size,
83 self.network_content.bit,
84 self.network_content.base,
85 self.network_content.scalar,
86 self.network_content.mersenne,
87 self.network_content.point,
88 self.pre_process.da_bits,
89 self.pre_process.arith_triples,
90 self.pre_process.arith_singlets,
91 self.pre_process.pow_pairs,
92 self.pre_process.bit_triples,
93 self.pre_process.bit_singlets,
94 ]
95 .into_iter()
96 }
97 pub fn string_values(&self) -> impl Iterator<Item = String> {
98 self.usize_values().map(|x| x.to_string())
99 }
100 fn summary(&self) -> ProfileSummary {
101 ProfileSummary::new(
102 self.network_depth,
103 self.total_gates,
104 self.network_content.network_size(),
105 self.pre_process.weight(),
106 )
107 }
108 pub fn weight(&self) -> usize {
109 self.summary().weight()
110 }
111 pub fn to_json(&self) -> String {
112 let hash_map =
113 std::iter::zip(Self::json_fields(), self.usize_values()).collect::<HashMap<_, _>>();
114 serde_json::to_string(&hash_map).expect("ProfileInfo::to_json() failed")
115 }
116}
117
118impl TryFrom<Vec<usize>> for ProfileInfo {
119 type Error = Box<dyn Error>;
120
121 fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
122 let expected = ProfileInfo::csv_header().count();
123 if value.len() < expected {
124 return Err(format!("Expected {} records, got {}", expected, value.len()).into());
125 }
126 let res = ProfileInfo {
127 total_gates: value[5],
128 network_depth: value[6],
129 network_content: NetworkContent::try_from(&value[8..13])?,
130 pre_process: PreprocessInfo::try_from(&value[13..])?,
131 };
132 Ok(res)
133 }
134}
135
136pub trait PerformanceTracker {
137 fn is_enabled(&self) -> bool;
138 fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
139}
140
141#[derive(Debug)]
142pub struct PerformanceStore {
143 strict_mode: bool,
144 circuit_performances: HashMap<String, ProfileInfo>,
145 has_errored: bool,
146}
147
148impl PerformanceStore {
149 pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
150 Self {
151 strict_mode,
152 circuit_performances: circuit_depths,
153 has_errored: false,
154 }
155 }
156 pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
157 let circuit_depths = if std::fs::exists(csv_file)? {
158 let mut content = csv::Reader::from_path(csv_file)?;
159 let mut circuit_depths = HashMap::new();
160 for record in content.records() {
161 let record = record?;
162 let circuit = record[0].to_string();
163
164 let mut info_vec = Vec::with_capacity(record.len() - 1);
165
166 for s in record.iter().skip(1) {
167 let info = usize::from_str(s)?;
168 info_vec.push(info);
169 }
170
171 let profile_info = info_vec.try_into()?;
172
173 circuit_depths.insert(circuit, profile_info);
174 }
175 circuit_depths
176 } else {
177 if strict_mode {
178 return Err(format!("CSV performance file {} is missing.", csv_file).into());
179 }
180 HashMap::new()
181 };
182 Ok(Self::new(strict_mode, circuit_depths))
183 }
184 pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
185 if self.strict_mode {
186 return Ok(());
187 }
188 if self.has_errored {
189 return Err("Cannot save an errored performance store.".into());
190 }
191 let mut records: Vec<(&str, ProfileInfo)> = self
192 .circuit_performances
193 .iter()
194 .map(|(s, d)| (s.as_str(), *d))
195 .collect();
196 records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
197 let mut writer = csv::Writer::from_path(csv_file)?;
198 writer.write_field("Circuit Name")?;
199 writer.write_record(ProfileInfo::csv_header())?;
200 for (name, profile_info) in records {
201 writer.write_field(name)?;
202 writer.write_record(profile_info.string_values())?;
203 }
204 Ok(())
205 }
206 pub fn get_instructions(&self) -> Vec<String> {
207 self.circuit_performances.keys().cloned().collect()
208 }
209}
210
211impl PerformanceTracker for PerformanceStore {
212 fn is_enabled(&self) -> bool {
213 true
214 }
215 fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
216 let entry = self.circuit_performances.entry(circuit.to_string());
217 if self.strict_mode && matches!(entry, Entry::Vacant(_)) {
218 self.has_errored = true;
219 return Err(format!("Missing circuit depth for {}", circuit).into());
220 }
221 let entry_val = entry.or_insert(profile_info);
222 if profile_info.network_depth > entry_val.network_depth {
223 self.has_errored = true;
224 return Err(format!(
225 "Computed circuit depth ({}) is greater than expected ({}) for {}.",
226 profile_info.network_depth, entry_val.network_depth, circuit
227 )
228 .into());
229 }
230 if profile_info != *entry_val {
231 if self.strict_mode {
232 self.has_errored = true;
233 return Err(format!(
234 "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
235 profile_info, *entry_val, circuit
236 )
237 .into());
238 } else {
239 *entry_val = profile_info;
240 }
241 }
242
243 Ok(())
244 }
245}
246
247impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
248 fn is_enabled(&self) -> bool {
249 match self {
250 None => false,
251 Some(a) => a.is_enabled(),
252 }
253 }
254
255 fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
256 match self {
257 None => Err("Tried to track in empty performance tracker.".into()),
258 Some(a) => a.track(circuit, depth),
259 }
260 }
261}