Skip to main content

arcis_compiler/
profile_info.rs

1use crate::{
2    network_content::NetworkContent,
3    preprocess_info::PreprocessInfo,
4    profile_summary::ProfileSummary,
5};
6use std::{
7    collections::{hash_map::Entry, HashMap},
8    error::Error,
9    str::FromStr,
10};
11
/// Raw profiling counters recorded for a single circuit.
///
/// These are the non-derived values of a performance record: the weights and
/// the network size reported next to them are computed from these fields (see
/// `ProfileInfo::usize_values`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProfileInfo {
    /// Value reported in the `Gates` CSV column / `total_gates` JSON field.
    pub total_gates: usize,
    /// Value reported in the `Depth` CSV column / `network_depth` JSON field.
    pub network_depth: usize,
    /// Per-kind network counters (`bit`, `base`, `scalar`, `mersenne`, `point`).
    pub network_content: NetworkContent,
    /// Preprocessing counters (`da_bits`, triples, singlets, `pow_pairs`, ...).
    pub pre_process: PreprocessInfo,
}
19
20impl ProfileInfo {
21    pub fn csv_header() -> impl Iterator<Item = &'static str> {
22        [
23            "Weight",
24            "Gate W.",
25            "Depth W.",
26            "🛜Size W.",
27            "Prepro. W.",
28            "Gates",
29            "Depth",
30            "🛜Size",
31            "🛜Bits",
32            "🛜Base",
33            "🛜Scal.",
34            "🛜M107",
35            "🛜Pts",
36            "daBits",
37            "Triples",
38            "Singlets",
39            "Pows",
40            "BiTriples",
41            "Binglets",
42        ]
43        .into_iter()
44    }
45    pub fn json_fields() -> impl Iterator<Item = &'static str> {
46        [
47            "weight",
48            "gate_weight",
49            "depth_weight",
50            "network_size_weight",
51            "preprocess_weight",
52            "total_gates",
53            "network_depth",
54            "network_size",
55            "network_bit",
56            "network_base",
57            "network_scalar",
58            "network_mersenne",
59            "network_point",
60            "da_bits",
61            "arith_triples",
62            "arith_singlets",
63            "pow_pairs",
64            "bit_triples",
65            "bit_singlets",
66        ]
67        .into_iter()
68    }
69    pub fn usize_values(&self) -> impl Iterator<Item = usize> {
70        let summary = self.summary();
71        let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0).weight();
72        let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0).weight();
73        let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0).weight();
74        [
75            summary.weight(),
76            gate_weight,
77            depth_weight,
78            network_size_weight,
79            summary.preprocess_weight,
80            self.total_gates,
81            self.network_depth,
82            summary.network_size,
83            self.network_content.bit,
84            self.network_content.base,
85            self.network_content.scalar,
86            self.network_content.mersenne,
87            self.network_content.point,
88            self.pre_process.da_bits,
89            self.pre_process.arith_triples,
90            self.pre_process.arith_singlets,
91            self.pre_process.pow_pairs,
92            self.pre_process.bit_triples,
93            self.pre_process.bit_singlets,
94        ]
95        .into_iter()
96    }
97    pub fn string_values(&self) -> impl Iterator<Item = String> {
98        self.usize_values().map(|x| x.to_string())
99    }
100    fn summary(&self) -> ProfileSummary {
101        ProfileSummary::new(
102            self.network_depth,
103            self.total_gates,
104            self.network_content.network_size(),
105            self.pre_process.weight(),
106        )
107    }
108    pub fn weight(&self) -> usize {
109        self.summary().weight()
110    }
111    pub fn to_json(&self) -> String {
112        let hash_map =
113            std::iter::zip(Self::json_fields(), self.usize_values()).collect::<HashMap<_, _>>();
114        serde_json::to_string(&hash_map).expect("ProfileInfo::to_json() failed")
115    }
116}
117
118impl TryFrom<Vec<usize>> for ProfileInfo {
119    type Error = Box<dyn Error>;
120
121    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
122        let expected = ProfileInfo::csv_header().count();
123        if value.len() < expected {
124            return Err(format!("Expected {} records, got {}", expected, value.len()).into());
125        }
126        let res = ProfileInfo {
127            total_gates: value[5],
128            network_depth: value[6],
129            network_content: NetworkContent::try_from(&value[8..13])?,
130            pre_process: PreprocessInfo::try_from(&value[13..])?,
131        };
132        Ok(res)
133    }
134}
135
/// Sink for per-circuit [`ProfileInfo`] measurements.
pub trait PerformanceTracker {
    /// Returns whether this tracker can currently accept measurements.
    fn is_enabled(&self) -> bool;
    /// Records `profile_info` for the circuit named `circuit`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the measurement is rejected.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
}
140
/// Table of expected per-circuit performances, optionally backed by a CSV file.
#[derive(Debug)]
pub struct PerformanceStore {
    /// When set, the store is read-only: `from_csv` requires the file to
    /// exist, `to_csv` writes nothing, and `track` rejects unknown circuits
    /// and any deviation from the stored values.
    strict_mode: bool,
    /// Expected profile for each circuit, keyed by circuit name.
    circuit_performances: HashMap<String, ProfileInfo>,
    /// Set once `track` has returned an error; `to_csv` then refuses to save.
    has_errored: bool,
}
147
148impl PerformanceStore {
149    pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
150        Self {
151            strict_mode,
152            circuit_performances: circuit_depths,
153            has_errored: false,
154        }
155    }
156    pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
157        let circuit_depths = if std::fs::exists(csv_file)? {
158            let mut content = csv::Reader::from_path(csv_file)?;
159            let mut circuit_depths = HashMap::new();
160            for record in content.records() {
161                let record = record?;
162                let circuit = record[0].to_string();
163
164                let mut info_vec = Vec::with_capacity(record.len() - 1);
165
166                for s in record.iter().skip(1) {
167                    let info = usize::from_str(s)?;
168                    info_vec.push(info);
169                }
170
171                let profile_info = info_vec.try_into()?;
172
173                circuit_depths.insert(circuit, profile_info);
174            }
175            circuit_depths
176        } else {
177            if strict_mode {
178                return Err(format!("CSV performance file {} is missing.", csv_file).into());
179            }
180            HashMap::new()
181        };
182        Ok(Self::new(strict_mode, circuit_depths))
183    }
184    pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
185        if self.strict_mode {
186            return Ok(());
187        }
188        if self.has_errored {
189            return Err("Cannot save an errored performance store.".into());
190        }
191        let mut records: Vec<(&str, ProfileInfo)> = self
192            .circuit_performances
193            .iter()
194            .map(|(s, d)| (s.as_str(), *d))
195            .collect();
196        records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
197        let mut writer = csv::Writer::from_path(csv_file)?;
198        writer.write_field("Circuit Name")?;
199        writer.write_record(ProfileInfo::csv_header())?;
200        for (name, profile_info) in records {
201            writer.write_field(name)?;
202            writer.write_record(profile_info.string_values())?;
203        }
204        Ok(())
205    }
206    pub fn get_instructions(&self) -> Vec<String> {
207        self.circuit_performances.keys().cloned().collect()
208    }
209}
210
211impl PerformanceTracker for PerformanceStore {
212    fn is_enabled(&self) -> bool {
213        true
214    }
215    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
216        let entry = self.circuit_performances.entry(circuit.to_string());
217        if self.strict_mode && matches!(entry, Entry::Vacant(_)) {
218            self.has_errored = true;
219            return Err(format!("Missing circuit depth for {}", circuit).into());
220        }
221        let entry_val = entry.or_insert(profile_info);
222        if profile_info.network_depth > entry_val.network_depth {
223            self.has_errored = true;
224            return Err(format!(
225                "Computed circuit depth ({}) is greater than expected ({}) for {}.",
226                profile_info.network_depth, entry_val.network_depth, circuit
227            )
228            .into());
229        }
230        if profile_info != *entry_val {
231            if self.strict_mode {
232                self.has_errored = true;
233                return Err(format!(
234                    "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
235                    profile_info, *entry_val, circuit
236                )
237                    .into());
238            } else {
239                *entry_val = profile_info;
240            }
241        }
242
243        Ok(())
244    }
245}
246
247impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
248    fn is_enabled(&self) -> bool {
249        match self {
250            None => false,
251            Some(a) => a.is_enabled(),
252        }
253    }
254
255    fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
256        match self {
257            None => Err("Tried to track in empty performance tracker.".into()),
258            Some(a) => a.track(circuit, depth),
259        }
260    }
261}