Skip to main content

arcis_compiler/
profile_info.rs

1use crate::{
2    auxiliary_circuit_info::AuxiliaryCircuitInfo,
3    network_content::NetworkContent,
4    preprocess_info::PreprocessInfo,
5    profile_summary::ProfileSummary,
6};
7use std::{
8    collections::{hash_map::Entry, HashMap},
9    error::Error,
10    str::FromStr,
11};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14pub struct ProfileInfo {
15    pub total_gates: usize,
16    pub network_depth: usize,
17    pub network_content: NetworkContent,
18    pub pre_process: PreprocessInfo,
19    pub auxiliary_circuits: AuxiliaryCircuitInfo<usize>,
20    pub circuit_hash: usize,
21}
22
23impl ProfileInfo {
24    pub fn csv_header() -> impl Iterator<Item = &'static str> {
25        [
26            "Weight",
27            "Gate W.",
28            "Depth W.",
29            "🛜Size W.",
30            "Prepro. W.",
31            "Assoc. W.",
32            "Gates",
33            "Depth",
34            "🛜Size",
35            "🛜Bits",
36            "🛜Base",
37            "🛜Scal.",
38            "🛜M107",
39            "🛜Pts",
40            "daBits",
41            "Triples",
42            "Singlets",
43            "Pows",
44            "BiTriples",
45            "Binglets",
46            "Base🛟",
47            "Scal.🛟",
48            "Hash",
49        ]
50        .into_iter()
51    }
52    pub fn json_fields() -> impl Iterator<Item = &'static str> {
53        [
54            "weight",
55            "gate_weight",
56            "depth_weight",
57            "network_size_weight",
58            "preprocess_weight",
59            "auxiliary_circuit_weight",
60            "total_gates",
61            "network_depth",
62            "network_size",
63            "network_bit",
64            "network_base",
65            "network_scalar",
66            "network_mersenne",
67            "network_point",
68            "da_bits",
69            "arith_triples",
70            "arith_singlets",
71            "pow_pairs",
72            "bit_triples",
73            "bit_singlets",
74            "base_rescue",
75            "scalar_rescue",
76            "hash",
77        ]
78        .into_iter()
79    }
80    pub fn usize_values(&self) -> impl Iterator<Item = usize> {
81        let summary = self.summary();
82        let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0, 0).weight();
83        let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0, 0).weight();
84        let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0, 0).weight();
85        [
86            summary.weight(),
87            gate_weight,
88            depth_weight,
89            network_size_weight,
90            summary.preprocess_weight,
91            self.auxiliary_circuits.weight(),
92            self.total_gates,
93            self.network_depth,
94            summary.network_size,
95            self.network_content.bit,
96            self.network_content.base,
97            self.network_content.scalar,
98            self.network_content.mersenne,
99            self.network_content.point,
100            self.pre_process.da_bits,
101            self.pre_process.arith_triples,
102            self.pre_process.arith_singlets,
103            self.pre_process.pow_pairs,
104            self.pre_process.bit_triples,
105            self.pre_process.bit_singlets,
106            self.auxiliary_circuits.shared_rescue_base_field_key,
107            self.auxiliary_circuits.shared_rescue_scalar_field_key,
108            self.circuit_hash,
109        ]
110        .into_iter()
111    }
112    pub fn string_values(&self) -> impl Iterator<Item = String> {
113        self.usize_values().map(|x| x.to_string())
114    }
115    fn summary(&self) -> ProfileSummary {
116        ProfileSummary::new(
117            self.network_depth,
118            self.total_gates,
119            self.network_content.network_size(),
120            self.pre_process.weight(),
121            self.auxiliary_circuits.weight(),
122        )
123    }
124    pub fn weight(&self) -> usize {
125        self.summary().weight()
126    }
127    pub fn to_json(&self) -> String {
128        let hash_map = std::iter::zip(Self::json_fields(), self.usize_values())
129            .collect::<std::collections::BTreeMap<_, _>>();
130        serde_json::to_string(&hash_map).expect("ProfileInfo::to_json() failed")
131    }
132}
133
134impl TryFrom<Vec<usize>> for ProfileInfo {
135    type Error = Box<dyn Error>;
136
137    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
138        let expected = ProfileInfo::csv_header().count();
139        if value.len() < expected {
140            return Err(format!("Expected {} records, got {}", expected, value.len()).into());
141        }
142        let res = ProfileInfo {
143            total_gates: value[6],
144            network_depth: value[7],
145            network_content: NetworkContent::try_from(&value[9..14])?,
146            pre_process: PreprocessInfo::try_from(&value[14..20])?,
147            auxiliary_circuits: AuxiliaryCircuitInfo::try_from(&value[20..22])?,
148            circuit_hash: value[22],
149        };
150        Ok(res)
151    }
152}
153
154pub trait PerformanceTracker {
155    fn is_enabled(&self) -> bool;
156    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
157}
158
159#[derive(Debug)]
160pub struct PerformanceStore {
161    strict_mode: bool,
162    circuit_performances: HashMap<String, ProfileInfo>,
163    has_errored: bool,
164}
165
166impl PerformanceStore {
167    pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
168        Self {
169            strict_mode,
170            circuit_performances: circuit_depths,
171            has_errored: false,
172        }
173    }
174    pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
175        let circuit_depths = if std::fs::exists(csv_file)? {
176            let mut content = csv::Reader::from_path(csv_file)?;
177            let mut circuit_depths = HashMap::new();
178            for record in content.records() {
179                let record = record?;
180                let circuit = record[0].to_string();
181
182                let mut info_vec = Vec::with_capacity(record.len() - 1);
183
184                for s in record.iter().skip(1) {
185                    let info = usize::from_str(s)?;
186                    info_vec.push(info);
187                }
188
189                let profile_info = info_vec.try_into()?;
190
191                circuit_depths.insert(circuit, profile_info);
192            }
193            circuit_depths
194        } else {
195            if strict_mode {
196                return Err(format!("CSV performance file {} is missing.", csv_file).into());
197            }
198            HashMap::new()
199        };
200        Ok(Self::new(strict_mode, circuit_depths))
201    }
202    pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
203        if self.strict_mode {
204            return Ok(());
205        }
206        if self.has_errored {
207            return Err("Cannot save an errored performance store.".into());
208        }
209        let mut records: Vec<(&str, ProfileInfo)> = self
210            .circuit_performances
211            .iter()
212            .map(|(s, d)| (s.as_str(), *d))
213            .collect();
214        records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
215        let mut writer = csv::Writer::from_path(csv_file)?;
216        writer.write_field("Circuit Name")?;
217        writer.write_record(ProfileInfo::csv_header())?;
218        for (name, profile_info) in records {
219            writer.write_field(name)?;
220            writer.write_record(profile_info.string_values())?;
221        }
222        Ok(())
223    }
224    pub fn get_instructions(&self) -> Vec<String> {
225        self.circuit_performances.keys().cloned().collect()
226    }
227}
228
229impl PerformanceTracker for PerformanceStore {
230    fn is_enabled(&self) -> bool {
231        true
232    }
233    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
234        let entry = self.circuit_performances.entry(circuit.to_string());
235        if self.strict_mode && matches!(entry, Entry::Vacant(_)) {
236            self.has_errored = true;
237            return Err(format!("Missing circuit depth for {}", circuit).into());
238        }
239        let entry_val = entry.or_insert(profile_info);
240        if profile_info.network_depth > entry_val.network_depth {
241            self.has_errored = true;
242            return Err(format!(
243                "Computed circuit depth ({}) is greater than expected ({}) for {}.",
244                profile_info.network_depth, entry_val.network_depth, circuit
245            )
246            .into());
247        }
248        if profile_info != *entry_val {
249            if self.strict_mode {
250                self.has_errored = true;
251                return Err(format!(
252                    "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
253                    profile_info, *entry_val, circuit
254                )
255                    .into());
256            } else {
257                *entry_val = profile_info;
258            }
259        }
260
261        Ok(())
262    }
263}
264
265impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
266    fn is_enabled(&self) -> bool {
267        match self {
268            None => false,
269            Some(a) => a.is_enabled(),
270        }
271    }
272
273    fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
274        match self {
275            None => Err("Tried to track in empty performance tracker.".into()),
276            Some(a) => a.track(circuit, depth),
277        }
278    }
279}