arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    network_content::NetworkContent,
    preprocess_info::PreprocessInfo,
    profile_summary::ProfileSummary,
};
use std::{
    collections::{hash_map::Entry, HashMap},
    error::Error,
    str::FromStr,
};

/// Profiling figures recorded for a single circuit.
///
/// The fields are the raw values from which the CSV row / JSON object is
/// built (see `ProfileInfo::usize_values`); the weight columns are derived
/// from them and are not stored here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProfileInfo {
    /// Value reported in the "Gates" CSV column and the `total_gates` JSON field.
    pub total_gates: usize,
    /// Value reported in the "Depth" CSV column and the `network_depth` JSON field.
    pub network_depth: usize,
    /// Per-kind network counters (`bit`, `base`, `scalar`, `mersenne`, `point`).
    pub network_content: NetworkContent,
    /// Pre-processing counters (`da_bits`, `arith_triples`, `arith_singlets`,
    /// `pow_pairs`, `bit_triples`, `bit_singlets`).
    pub pre_process: PreprocessInfo,
    /// Value reported in the "Hash" CSV column and the `hash` JSON field.
    // NOTE(review): presumably a hash identifying the compiled circuit — confirm with the producer.
    pub circuit_hash: usize,
}

impl ProfileInfo {
    /// Column labels used for the CSV representation of a profile.
    ///
    /// The labels are in the same order as the values produced by
    /// [`ProfileInfo::usize_values`]. The leading "Circuit Name" column is
    /// not part of this list; it is written separately by the CSV writer.
    pub fn csv_header() -> impl Iterator<Item = &'static str> {
        [
            "Weight",
            "Gate W.",
            "Depth W.",
            "🛜Size W.",
            "Prepro. W.",
            "Gates",
            "Depth",
            "🛜Size",
            "🛜Bits",
            "🛜Base",
            "🛜Scal.",
            "🛜M107",
            "🛜Pts",
            "daBits",
            "Triples",
            "Singlets",
            "Pows",
            "BiTriples",
            "Binglets",
            "Hash",
        ]
        .into_iter()
    }

    /// Key names used for the JSON representation of a profile.
    ///
    /// The keys are in the same order as the values produced by
    /// [`ProfileInfo::usize_values`], so the two iterators can be zipped.
    pub fn json_fields() -> impl Iterator<Item = &'static str> {
        [
            "weight",
            "gate_weight",
            "depth_weight",
            "network_size_weight",
            "preprocess_weight",
            "total_gates",
            "network_depth",
            "network_size",
            "network_bit",
            "network_base",
            "network_scalar",
            "network_mersenne",
            "network_point",
            "da_bits",
            "arith_triples",
            "arith_singlets",
            "pow_pairs",
            "bit_triples",
            "bit_singlets",
            "hash",
        ]
        .into_iter()
    }

    /// Every reported figure, derived weights first and raw counters after,
    /// in the order of [`ProfileInfo::csv_header`] / [`ProfileInfo::json_fields`].
    pub fn usize_values(&self) -> impl Iterator<Item = usize> {
        let summary = self.summary();
        // Each partial weight is the weight of a summary in which only the
        // corresponding component is non-zero.
        let depth_weight = ProfileSummary::new(self.network_depth, 0, 0, 0).weight();
        let gate_weight = ProfileSummary::new(0, self.total_gates, 0, 0).weight();
        let network_size_weight = ProfileSummary::new(0, 0, summary.network_size, 0).weight();
        [
            summary.weight(),
            gate_weight,
            depth_weight,
            network_size_weight,
            summary.preprocess_weight,
            self.total_gates,
            self.network_depth,
            summary.network_size,
            self.network_content.bit,
            self.network_content.base,
            self.network_content.scalar,
            self.network_content.mersenne,
            self.network_content.point,
            self.pre_process.da_bits,
            self.pre_process.arith_triples,
            self.pre_process.arith_singlets,
            self.pre_process.pow_pairs,
            self.pre_process.bit_triples,
            self.pre_process.bit_singlets,
            self.circuit_hash,
        ]
        .into_iter()
    }

    /// Same as [`ProfileInfo::usize_values`], with every value rendered as a
    /// decimal string (the form written to CSV).
    pub fn string_values(&self) -> impl Iterator<Item = String> {
        self.usize_values().map(|x| x.to_string())
    }

    /// Builds the summary (depth, gates, network size, pre-processing weight)
    /// from which the weights are computed.
    fn summary(&self) -> ProfileSummary {
        ProfileSummary::new(
            self.network_depth,
            self.total_gates,
            self.network_content.network_size(),
            self.pre_process.weight(),
        )
    }

    /// Overall weight of this profile, as computed by its summary.
    pub fn weight(&self) -> usize {
        self.summary().weight()
    }

    /// Serializes the profile as a flat JSON object mapping each name from
    /// [`ProfileInfo::json_fields`] to its value.
    ///
    /// Keys are emitted in sorted order so that the output is identical from
    /// one run to the next (a `HashMap` would yield an arbitrary order).
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize the map.
    pub fn to_json(&self) -> String {
        let fields: std::collections::BTreeMap<&'static str, usize> =
            std::iter::zip(Self::json_fields(), self.usize_values()).collect();
        serde_json::to_string(&fields).expect("ProfileInfo::to_json() failed")
    }
}

impl TryFrom<Vec<usize>> for ProfileInfo {
    type Error = Box<dyn Error>;

    /// Rebuilds a profile from a row of values laid out like
    /// `ProfileInfo::csv_header()`.
    ///
    /// Only the raw counters are read back (gates, depth, network content,
    /// pre-processing counters and hash); the weight and network-size
    /// positions are ignored.
    ///
    /// # Errors
    ///
    /// Fails when the row holds fewer values than there are CSV columns, or
    /// when the network / pre-processing slices are rejected by their own
    /// conversions.
    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
        let column_count = ProfileInfo::csv_header().count();
        let provided = value.len();
        if provided < column_count {
            return Err(format!("Expected {} records, got {}", column_count, provided).into());
        }

        // Positions follow the CSV header: 5 = gates, 6 = depth,
        // 8..13 = network content, 13..19 = pre-processing, 19 = hash.
        let network_content = NetworkContent::try_from(&value[8..13])?;
        let pre_process = PreprocessInfo::try_from(&value[13..19])?;

        Ok(ProfileInfo {
            total_gates: value[5],
            network_depth: value[6],
            network_content,
            pre_process,
            circuit_hash: value[19],
        })
    }
}

/// Sink for per-circuit profiling results.
pub trait PerformanceTracker {
    /// Reports whether this tracker is active.
    ///
    /// Of the implementations in this file, `PerformanceStore` always answers
    /// `true` and `Option<T>` answers `false` when it is `None`.
    fn is_enabled(&self) -> bool;
    /// Submits the profile computed for `circuit`.
    ///
    /// Returns an error when the implementation rejects the profile.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>>;
}

/// In-memory table of expected circuit profiles, loadable from and savable
/// to a CSV file.
#[derive(Debug)]
pub struct PerformanceStore {
    /// When `true`, tracking rejects unknown circuits and any profile that
    /// differs from the stored one, and `to_csv` writes nothing.
    strict_mode: bool,
    /// Stored profile for each circuit name.
    circuit_performances: HashMap<String, ProfileInfo>,
    /// Set once `track` has returned an error; `to_csv` then refuses to save.
    has_errored: bool,
}

impl PerformanceStore {
    /// Creates a store from an already-built table of profiles.
    ///
    /// The store starts in the non-errored state.
    pub fn new(strict_mode: bool, circuit_depths: HashMap<String, ProfileInfo>) -> Self {
        Self {
            strict_mode,
            circuit_performances: circuit_depths,
            has_errored: false,
        }
    }

    /// Loads a store from the CSV file at `csv_file`.
    ///
    /// The first field of every record is the circuit name; the remaining
    /// fields are parsed as `usize` and converted into a [`ProfileInfo`].
    /// When the file does not exist, a non-strict store starts empty.
    ///
    /// # Errors
    ///
    /// Fails when the file is missing in strict mode, when it cannot be read
    /// or parsed as CSV, when a record has no circuit name, when a field is
    /// not a valid `usize`, or when a row cannot be converted to a profile.
    pub fn from_csv(strict_mode: bool, csv_file: &str) -> Result<Self, Box<dyn Error>> {
        if !std::fs::exists(csv_file)? {
            if strict_mode {
                return Err(format!("CSV performance file {} is missing.", csv_file).into());
            }
            return Ok(Self::new(strict_mode, HashMap::new()));
        }

        let mut content = csv::Reader::from_path(csv_file)?;
        let mut circuit_depths = HashMap::new();
        for record in content.records() {
            let record = record?;
            // Checked access: a malformed record yields an error, not a panic.
            let circuit = record
                .get(0)
                .ok_or("CSV performance record is missing the circuit name.")?
                .to_string();

            let info_vec = record
                .iter()
                .skip(1)
                .map(usize::from_str)
                .collect::<Result<Vec<usize>, _>>()?;

            let profile_info = ProfileInfo::try_from(info_vec)?;
            circuit_depths.insert(circuit, profile_info);
        }
        Ok(Self::new(strict_mode, circuit_depths))
    }

    /// Writes the store to `csv_file`, one row per circuit.
    ///
    /// Rows are sorted by circuit name, with the circuits whose network depth
    /// is zero placed after all the others. In strict mode nothing is written
    /// and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Fails when a previous `track` call errored, or when the file cannot be
    /// created, written or flushed.
    pub fn to_csv(&self, csv_file: &str) -> Result<(), Box<dyn Error>> {
        if self.strict_mode {
            return Ok(());
        }
        if self.has_errored {
            return Err("Cannot save an errored performance store.".into());
        }
        let mut records: Vec<(&str, ProfileInfo)> = self
            .circuit_performances
            .iter()
            .map(|(s, d)| (s.as_str(), *d))
            .collect();
        records.sort_by_key(|(a, b)| (b.network_depth == 0, *a));
        let mut writer = csv::Writer::from_path(csv_file)?;
        // `write_field` starts the record and `write_record` completes it,
        // so the name and the values end up on the same row.
        writer.write_field("Circuit Name")?;
        writer.write_record(ProfileInfo::csv_header())?;
        for (name, profile_info) in records {
            writer.write_field(name)?;
            writer.write_record(profile_info.string_values())?;
        }
        // Flush explicitly: the flush performed on drop discards I/O errors,
        // which would report success for a truncated file.
        writer.flush()?;
        Ok(())
    }

    /// Names of all the circuits currently held in the store, in arbitrary order.
    pub fn get_instructions(&self) -> Vec<String> {
        self.circuit_performances.keys().cloned().collect()
    }
}

impl PerformanceTracker for PerformanceStore {
    /// A store always tracks.
    fn is_enabled(&self) -> bool {
        true
    }

    /// Compares `profile_info` with the profile stored for `circuit`.
    ///
    /// * Unknown circuit: an error in strict mode, otherwise the profile is stored.
    /// * Computed depth greater than the stored one: always an error.
    /// * Any other difference: an error in strict mode, otherwise the stored
    ///   profile is overwritten.
    ///
    /// Every error also marks the store as errored.
    fn track(&mut self, circuit: &str, profile_info: ProfileInfo) -> Result<(), Box<dyn Error>> {
        let strict = self.strict_mode;

        let expected = match self.circuit_performances.entry(circuit.to_string()) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(_) if strict => {
                self.has_errored = true;
                return Err(format!("Missing circuit depth for {}", circuit).into());
            }
            Entry::Vacant(slot) => slot.insert(profile_info),
        };

        let expected_depth = expected.network_depth;
        if profile_info.network_depth > expected_depth {
            self.has_errored = true;
            return Err(format!(
                "Computed circuit depth ({}) is greater than expected ({}) for {}.",
                profile_info.network_depth, expected_depth, circuit
            )
            .into());
        }

        if profile_info == *expected {
            return Ok(());
        }

        if !strict {
            // Non-strict runs refresh the stored profile.
            *expected = profile_info;
            return Ok(());
        }

        self.has_errored = true;
        Err(format!(
            "Computed circuit performance ({:?}) is different than expected ({:?}) for {}. Please run the same test with `TEST_CIRCUIT= cargo test --all-features`.",
            profile_info, *expected, circuit
        )
        .into())
    }
}

impl<T: PerformanceTracker> PerformanceTracker for Option<T> {
    /// `None` is never enabled; `Some` defers to the wrapped tracker.
    fn is_enabled(&self) -> bool {
        self.as_ref().is_some_and(|tracker| tracker.is_enabled())
    }

    /// Forwards to the wrapped tracker, or fails when there is none.
    fn track(&mut self, circuit: &str, depth: ProfileInfo) -> Result<(), Box<dyn Error>> {
        let Some(tracker) = self else {
            return Err("Tried to track in empty performance tracker.".into());
        };
        tracker.track(circuit, depth)
    }
}