Skip to main content

arcis_compiler/
profile_info.rs

1use crate::{
2    network_content::NetworkContent,
3    preprocess_info::PreprocessInfo,
4    profile_summary::ProfileSummary,
5};
6use std::{
7    collections::{hash_map::Entry, HashMap},
8    error::Error,
9    str::FromStr,
10};
11
/// Measured performance profile of one circuit; one row of the performance CSV.
///
/// Derives `Copy` and `Eq` so a freshly measured profile can be compared directly against the
/// stored reference value (see `PerformanceStore`'s `track`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProfileInfo {
    /// Number of gates (the `Gates` CSV column).
    pub total_gates: usize,
    /// Network depth (the `Depth` CSV column). `PerformanceStore`'s `track` rejects any
    /// measurement whose depth exceeds the stored one.
    pub network_depth: usize,
    /// Network counters written as the bit / base / scalar / mersenne / point columns.
    pub network_content: NetworkContent,
    /// Preprocessing counters written as the daBits, triples, singlets, pows, bit-triples and
    /// bit-singlets columns.
    pub pre_process: PreprocessInfo,
    /// Circuit hash (the `Hash` CSV column).
    pub circuit_hash: usize,
}
20
21impl ProfileInfo {
22    pub fn csv_header() -> impl Iterator<Item = &'static str> {
23        [
24            "Weight",
25            "Gate W.",
26            "Depth W.",
27            "🛜Size W.",
28            "Prepro. W.",
29            "Gates",
30            "Depth",
31            "🛜Size",
32            "🛜Bits",
33            "🛜Base",
34            "🛜Scal.",
35            "🛜M107",
36            "🛜Pts",
37            "daBits",
38            "Triples",
39            "Singlets",
40            "Pows",
41            "BiTriples",
42            "Binglets",
43            "Hash",
44        ]
45        .into_iter()
46    }
47    pub fn json_fields() -> impl Iterator<Item = &'static str> {
48        [
49            "weight",
50            "gate_weight",
51            "depth_weight",
52            "network_size_weight",
53            "preprocess_weight",
54            "total_gates",
55            "network_depth",
56            "network_size",
57            "network_bit",
58            "network_base",
59            "network_scalar",
60            "network_mersenne",
61            "network_point",
62            "da_bits",
63            "arith_triples",
64            "arith_singlets",
65            "pow_pairs",
66            "bit_triples",
67            "bit_singlets",
68            "hash",
69        ]
70        .into_iter()
71    }
72    pub fn usize_values(&self) -> impl Iterator<Item = usize> {
73        let summary = self.summary();
74        let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0).weight();
75        let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0).weight();
76        let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0).weight();
77        [
78            summary.weight(),
79            gate_weight,
80            depth_weight,
81            network_size_weight,
82            summary.preprocess_weight,
83            self.total_gates,
84            self.network_depth,
85            summary.network_size,
86            self.network_content.bit,
87            self.network_content.base,
88            self.network_content.scalar,
89            self.network_content.mersenne,
90            self.network_content.point,
91            self.pre_process.da_bits,
92            self.pre_process.arith_triples,
93            self.pre_process.arith_singlets,
94            self.pre_process.pow_pairs,
95            self.pre_process.bit_triples,
96            self.pre_process.bit_singlets,
97            self.circuit_hash,
98        ]
99        .into_iter()
100    }
101    pub fn string_values(&self) -> impl Iterator<Item = String> {
102        self.usize_values().map(|x| x.to_string())
103    }
104    fn summary(&self) -> ProfileSummary {
105        ProfileSummary::new(
106            self.network_depth,
107            self.total_gates,
108            self.network_content.network_size(),
109            self.pre_process.weight(),
110        )
111    }
112    pub fn weight(&self) -> usize {
113        self.summary().weight()
114    }
115    pub fn to_json(&self) -> String {
116        let hash_map =
117            std::iter::zip(Self::json_fields(), self.usize_values()).collect::<HashMap<_, _>>();
118        serde_json::to_string(&hash_map).expect("ProfileInfo::to_json() failed")
119    }
120}
121
122impl TryFrom<Vec<usize>> for ProfileInfo {
123    type Error = Box<dyn Error>;
124
125    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
126        let expected = ProfileInfo::csv_header().count();
127        if value.len() < expected {
128            return Err(format!("Expected {} records, got {}", expected, value.len()).into());
129        }
130        let res = ProfileInfo {
131            total_gates: value[5],
132            network_depth: value[6],
133            network_content: NetworkContent::try_from(&value[8..13])?,
134            pre_process: PreprocessInfo::try_from(&value[13..19])?,
135            circuit_hash: value[19],
136        };
137        Ok(res)
138    }
139}
140
/// Receives the measured [`ProfileInfo`] of each circuit, keyed by circuit name.
pub trait PerformanceTracker {
    /// Whether this tracker actually records anything.
    fn is_enabled(&self) -> bool;
    /// Records the profile measured for `circuit`.
    ///
    /// # Errors
    /// Implementations may reject the measurement (e.g. when it disagrees with a stored
    /// reference) by returning an error.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
}
145
/// Reference profiles per circuit, loaded from and saved to a CSV file, against which freshly
/// measured profiles are checked.
#[derive(Debug)]
pub struct PerformanceStore {
    /// When set, unknown circuits and any profile change are errors, and `to_csv` never writes.
    strict_mode: bool,
    /// Stored reference profile for each circuit name.
    circuit_performances: HashMap<String, ProfileInfo>,
    /// Set once `track` has rejected a measurement; `to_csv` then refuses to save.
    has_errored: bool,
}
152
153impl PerformanceStore {
154    pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
155        Self {
156            strict_mode,
157            circuit_performances: circuit_depths,
158            has_errored: false,
159        }
160    }
161    pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
162        let circuit_depths = if std::fs::exists(csv_file)? {
163            let mut content = csv::Reader::from_path(csv_file)?;
164            let mut circuit_depths = HashMap::new();
165            for record in content.records() {
166                let record = record?;
167                let circuit = record[0].to_string();
168
169                let mut info_vec = Vec::with_capacity(record.len() - 1);
170
171                for s in record.iter().skip(1) {
172                    let info = usize::from_str(s)?;
173                    info_vec.push(info);
174                }
175
176                let profile_info = info_vec.try_into()?;
177
178                circuit_depths.insert(circuit, profile_info);
179            }
180            circuit_depths
181        } else {
182            if strict_mode {
183                return Err(format!("CSV performance file {} is missing.", csv_file).into());
184            }
185            HashMap::new()
186        };
187        Ok(Self::new(strict_mode, circuit_depths))
188    }
189    pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
190        if self.strict_mode {
191            return Ok(());
192        }
193        if self.has_errored {
194            return Err("Cannot save an errored performance store.".into());
195        }
196        let mut records: Vec<(&str, ProfileInfo)> = self
197            .circuit_performances
198            .iter()
199            .map(|(s, d)| (s.as_str(), *d))
200            .collect();
201        records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
202        let mut writer = csv::Writer::from_path(csv_file)?;
203        writer.write_field("Circuit Name")?;
204        writer.write_record(ProfileInfo::csv_header())?;
205        for (name, profile_info) in records {
206            writer.write_field(name)?;
207            writer.write_record(profile_info.string_values())?;
208        }
209        Ok(())
210    }
211    pub fn get_instructions(&self) -> Vec<String> {
212        self.circuit_performances.keys().cloned().collect()
213    }
214}
215
216impl PerformanceTracker for PerformanceStore {
217    fn is_enabled(&self) -> bool {
218        true
219    }
220    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
221        let entry = self.circuit_performances.entry(circuit.to_string());
222        if self.strict_mode && matches!(entry, Entry::Vacant(_)) {
223            self.has_errored = true;
224            return Err(format!("Missing circuit depth for {}", circuit).into());
225        }
226        let entry_val = entry.or_insert(profile_info);
227        if profile_info.network_depth > entry_val.network_depth {
228            self.has_errored = true;
229            return Err(format!(
230                "Computed circuit depth ({}) is greater than expected ({}) for {}.",
231                profile_info.network_depth, entry_val.network_depth, circuit
232            )
233            .into());
234        }
235        if profile_info != *entry_val {
236            if self.strict_mode {
237                self.has_errored = true;
238                return Err(format!(
239                    "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
240                    profile_info, *entry_val, circuit
241                )
242                    .into());
243            } else {
244                *entry_val = profile_info;
245            }
246        }
247
248        Ok(())
249    }
250}
251
252impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
253    fn is_enabled(&self) -> bool {
254        match self {
255            None => false,
256            Some(a) => a.is_enabled(),
257        }
258    }
259
260    fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
261        match self {
262            None => Err("Tried to track in empty performance tracker.".into()),
263            Some(a) => a.track(circuit, depth),
264        }
265    }
266}