extern crate ciphercore_base;
use ciphercore_base::data_types::{
get_size_in_bits, Type, BIT, INT128, INT16, INT32, INT64, INT8, UINT128, UINT16, UINT32,
UINT64, UINT8,
};
use ciphercore_base::errors::Result;
use ciphercore_base::evaluators::simple_evaluator::SimpleEvaluator;
use ciphercore_base::graphs::{Context, Graph, Node, NodeAnnotation, Operation};
use ciphercore_base::inline::inline_common::DepthOptimizationLevel;
use ciphercore_base::inline::inline_ops::{InlineConfig, InlineMode};
use ciphercore_base::mpc::mpc_compiler::prepare_context;
use ciphercore_utils::execute_main::execute_main;
use std::fs;
use std::collections::HashMap;
use clap::{ArgAction, ArgEnum, Parser};
/// Display data for one graph input, collected by `print_stats` while it
/// walks the graph and printed in the "Inputs" section of the report.
struct InputInfo {
/// Name of the input node, or `"unnamed"` when the node has no name.
name: String,
/// `Display` rendering of the input node's type.
type_string: String,
}
/// Tells whether `node` is a `NOP` node carrying at least one
/// `NodeAnnotation::Send` annotation.
///
/// Annotations are only fetched for `NOP` nodes; every other operation yields
/// `Ok(false)` immediately.
///
/// # Errors
/// Propagates any error returned by `Node::get_annotations`.
fn is_network_node(node: Node) -> Result<bool> {
    if !matches!(node.get_operation(), Operation::NOP) {
        return Ok(false);
    }
    let has_send = node
        .get_annotations()?
        .iter()
        .any(|annotation| matches!(annotation, NodeAnnotation::Send(_, _)));
    Ok(has_send)
}
/// Number of items held by a value of type `t`: the product of the shape
/// dimensions for an array type, and `1` for anything else.
fn get_buffer_length_in_items(t: Type) -> u64 {
    if !t.is_array() {
        return 1;
    }
    t.get_shape().iter().product()
}
/// Estimates how many scalar arithmetic operations evaluating `node` takes.
///
/// * `Add` / `Subtract` / `Multiply`: one per item of the node's own type.
/// * `Truncate` / `Sum` / `PRF`: one per item of the first dependency's type.
/// * `Random(t)`: one per item of `t`.
/// * `Dot` / `Matmul`: for two rank-1 operands, the length of the first
///   operand; otherwise the contracted dimension of the second operand
///   (second-to-last axis, or the only axis when it is rank 1) multiplied by
///   the number of items in the result.
/// * Anything else counts as a single operation.
///
/// # Errors
/// Propagates errors from `Node::get_type`.
///
/// # Panics
/// Indexes the dependency list and the operand shapes directly, so it panics
/// if a node has fewer dependencies (or a shorter shape) than its operation
/// needs here.
fn calculate_integer_operations(node: Node) -> Result<u64> {
    let ops = match node.get_operation() {
        Operation::Add | Operation::Subtract | Operation::Multiply => {
            get_buffer_length_in_items(node.get_type()?)
        }
        Operation::Truncate(_) | Operation::Sum(_) | Operation::PRF(_, _) => {
            let first_input = node.get_node_dependencies()[0].clone();
            get_buffer_length_in_items(first_input.get_type()?)
        }
        Operation::Random(t) => get_buffer_length_in_items(t),
        Operation::Dot | Operation::Matmul => {
            let operands = node.get_node_dependencies();
            let (lhs, rhs) = (&operands[0], &operands[1]);
            let lhs_type = lhs.get_type()?;
            let rhs_type = rhs.get_type()?;
            let lhs_shape = lhs_type.get_shape();
            let rhs_shape = rhs_type.get_shape();
            let output_items = get_buffer_length_in_items(node.get_type()?);
            match (lhs_shape.len(), rhs_shape.len()) {
                // Vector-by-vector product: one multiplication per element.
                (1, 1) => lhs_shape[0],
                // The contracted axis of the right operand is its second-to-last one.
                (_, rhs_rank) if rhs_rank > 1 => rhs_shape[rhs_rank - 2] * output_items,
                // Rank-1 right operand: its only axis is the contracted one.
                _ => rhs_shape[0] * output_items,
            }
        }
        _ => 1,
    };
    Ok(ops)
}
/// Computes the number of sequential network rounds needed by `graph`.
///
/// Nodes are visited in the order returned by `Graph::get_nodes`. For each
/// node the round depth is the largest depth among its dependencies (0 when it
/// has none, or when a dependency has not been seen), plus one if the node
/// itself is a network node according to `is_network_node`. The result is the
/// largest depth over all nodes.
///
/// Returns `Ok(0)` for a graph without nodes instead of panicking.
///
/// # Errors
/// Propagates errors from `is_network_node`.
fn calculate_network_rounds(graph: Graph) -> Result<u32> {
    let mut depth_of = HashMap::<Node, u32>::new();
    // Tracked while iterating so that an empty graph needs no special case.
    let mut max_depth: u32 = 0;
    for node in graph.get_nodes() {
        let own_round = u32::from(is_network_node(node.clone())?);
        let deepest_dependency = node
            .get_node_dependencies()
            .iter()
            .map(|dep| depth_of.get(dep).copied().unwrap_or(0))
            .max()
            .unwrap_or(0);
        let depth = deepest_dependency + own_round;
        max_depth = max_depth.max(depth);
        depth_of.insert(node, depth);
    }
    Ok(max_depth)
}
/// Base used to build the binary (1024-based) unit sizes in `format_bits`.
const TWO: u64 = 2;

/// Renders a size given in bits as a human-readable byte count.
///
/// The bit count is first converted to whole bytes (integer division by 8).
/// Values of at least 2^30, 2^20 or 2^10 bytes are shown with two decimals and
/// the suffix `GB`, `MB` or `KB` respectively; smaller values are shown as an
/// exact integer followed by `B`.
fn format_bits(t_in_bits: u64) -> String {
    let byte_count = t_in_bits / 8;
    // Ordered from the largest unit down so the first match wins.
    let units = [
        (TWO.pow(30), "GB"),
        (TWO.pow(20), "MB"),
        (TWO.pow(10), "KB"),
    ];
    for &(unit_size, suffix) in units.iter() {
        if byte_count >= unit_size {
            return format!("{:.2}{}", byte_count as f32 / unit_size as f32, suffix);
        }
    }
    format!("{byte_count}B")
}
/// Base used to build the decimal unit sizes in `format_operations`.
const TEN: u64 = 10;

/// Renders an operation count in a human-readable form.
///
/// Counts of at least 10^9, 10^6 or 10^3 are scaled and shown with two
/// decimals followed by `G Ops`, `M Ops` or `K Ops`; smaller counts are shown
/// exactly, followed by ` Ops`.
fn format_operations(ops: u64) -> String {
    let giga = TEN.pow(9);
    let mega = TEN.pow(6);
    let kilo = TEN.pow(3);
    let scaled =
        |unit: u64, prefix: char| format!("{:.2}{} Ops", ops as f32 / unit as f32, prefix);
    match ops {
        n if n >= giga => scaled(giga, 'G'),
        n if n >= mega => scaled(mega, 'M'),
        n if n >= kilo => scaled(kilo, 'K'),
        n => format!("{n} Ops"),
    }
}
/// Memory estimates returned by `calculate_ram_stats`; both values are in bits.
struct RamStats {
/// Largest value reached by the running total of node sizes, where a node's
/// size is removed from the total once all of its dependents have been visited.
peak_ram_bits: u64,
/// Sum of the sizes of all nodes in the graph.
total_ram_bits: u64,
}
/// Estimates the memory needed to evaluate `graph` node by node.
///
/// Nodes are visited in the order returned by `Graph::get_nodes`. A node's
/// value is counted as allocated when the node is visited and as released once
/// every node that depends on it has been visited. Nodes without dependents
/// are never released.
///
/// The peak is sampled right after a node's output is allocated and *before*
/// its inputs are released, because inputs and output have to coexist while
/// the node is being evaluated.
///
/// Node ids are used as indices into the per-node vectors; this assumes ids
/// are dense and match positions in `get_nodes()` — TODO confirm.
///
/// # Errors
/// Propagates errors from `Node::get_type` and `get_size_in_bits`.
fn calculate_ram_stats(graph: Graph) -> Result<RamStats> {
    let nodes = graph.get_nodes();

    let mut node_size = Vec::with_capacity(nodes.len());
    for node in nodes.iter() {
        node_size.push(get_size_in_bits(node.get_type()?)?);
    }

    // How many not-yet-visited uses each node still has (one per dependency edge).
    let mut remaining_dependents = vec![0usize; nodes.len()];
    for node in nodes.iter() {
        for dep in node.get_node_dependencies() {
            remaining_dependents[dep.get_id() as usize] += 1;
        }
    }

    // Explicit types keep the `max` call below independent of integer inference.
    let mut max_ram: u64 = 0;
    let mut cur_ram: u64 = 0;
    let mut total_ram: u64 = 0;
    for node in nodes.iter() {
        let size = node_size[node.get_id() as usize];
        cur_ram += size;
        total_ram += size;
        // Inputs are still alive at this point, so this is the true high-water mark.
        max_ram = max_ram.max(cur_ram);
        for dep in node.get_node_dependencies() {
            let dep_id = dep.get_id() as usize;
            remaining_dependents[dep_id] -= 1;
            if remaining_dependents[dep_id] == 0 {
                cur_ram -= node_size[dep_id];
            }
        }
    }

    Ok(RamStats {
        peak_ram_bits: max_ram,
        total_ram_bits: total_ram,
    })
}
/// Prints a statistics report for `graph` to standard output.
///
/// The report contains, in order: the graph inputs (name and type), the output
/// node (name and type), the number of network rounds, the network traffic,
/// peak and total RAM, the number of integer arithmetic operations (overall
/// and split by scalar bit width), the total number of nodes, and a table of
/// how many times each operation occurs.
///
/// The operation table is sorted by descending count; operations with the same
/// count are ordered by name so that the output is identical from run to run.
///
/// # Errors
/// Propagates errors from node accessors (`get_name`, `get_type`,
/// `get_annotations`), `get_size_in_bits`, `Graph::get_output_node`,
/// `calculate_network_rounds` and `calculate_ram_stats`.
pub(crate) fn print_stats(graph: Graph) -> Result<()> {
    let mut cnt = HashMap::<String, u64>::new();
    let mut inputs = Vec::<InputInfo>::new();
    let mut network_traffic_in_bits = 0;
    let mut total_integer_operations = 0;
    let mut total_bit_operations = 0;
    let mut total_8bits_operations = 0;
    let mut total_16bits_operations = 0;
    let mut total_32bits_operations = 0;
    let mut total_64bits_operations = 0;
    let mut total_128bits_operations = 0;

    for node in graph.get_nodes() {
        let op = node.get_operation();
        *cnt.entry(format!("{op}")).or_insert(0) += 1;
        match op {
            Operation::Input(_) => {
                inputs.push(InputInfo {
                    name: node.get_name()?.unwrap_or_else(|| "unnamed".to_owned()),
                    type_string: format!("{}", node.get_type()?),
                });
            }
            Operation::NOP => {
                // Only NOP nodes annotated with `Send` contribute to traffic.
                if is_network_node(node.clone())? {
                    network_traffic_in_bits += get_size_in_bits(node.get_type()?)?;
                }
            }
            Operation::Add
            | Operation::Subtract
            | Operation::Multiply
            | Operation::Dot
            | Operation::Matmul
            | Operation::Truncate(_)
            | Operation::Sum(_) => {
                // Operations are attributed to the scalar type of the node's result.
                let st = node.get_type()?.get_scalar_type();
                let ops = calculate_integer_operations(node.clone())?;
                match st {
                    BIT => total_bit_operations += ops,
                    UINT8 | INT8 => total_8bits_operations += ops,
                    UINT16 | INT16 => total_16bits_operations += ops,
                    UINT32 | INT32 => total_32bits_operations += ops,
                    UINT64 | INT64 => total_64bits_operations += ops,
                    UINT128 | INT128 => total_128bits_operations += ops,
                };
                total_integer_operations += ops;
            }
            _ => {}
        }
    }

    let network_rounds = calculate_network_rounds(graph.clone())?;

    // `HashMap` iteration order is unspecified, so break count ties by name to
    // keep the printed table deterministic.
    let mut entries: Vec<(String, u64)> = cnt.into_iter().collect();
    entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    println!("-------Stats--------");
    println!("Inputs: ");
    for (i, input) in inputs.iter().enumerate() {
        println!(" {}. Name:{}", i + 1, input.name);
        println!(" Type:{}", input.type_string);
    }

    let output_node = graph.get_output_node()?;
    let output_name = output_node
        .get_name()?
        .unwrap_or_else(|| "unnamed".to_owned());
    let output_type = format!("{}", output_node.get_type()?);
    println!("Output: ");
    println!(" Name:{output_name}");
    println!(" Type:{output_type}");

    println!("Network rounds: {network_rounds}");
    println!("Network traffic: {}", format_bits(network_traffic_in_bits));

    let ram_stats = calculate_ram_stats(graph.clone())?;
    println!("Peak RAM: {}", format_bits(ram_stats.peak_ram_bits));
    println!("Total RAM: {}", format_bits(ram_stats.total_ram_bits));

    println!(
        "Total number of integer arithmetic operations: {}",
        format_operations(total_integer_operations)
    );
    println!(
        " Total number of 1-bit arithmetic operations: {}",
        format_operations(total_bit_operations)
    );
    println!(
        " Total number of 8-bit arithmetic operations: {}",
        format_operations(total_8bits_operations)
    );
    println!(
        " Total number of 16-bit arithmetic operations: {}",
        format_operations(total_16bits_operations)
    );
    println!(
        " Total number of 32-bit arithmetic operations: {}",
        format_operations(total_32bits_operations)
    );
    println!(
        " Total number of 64-bit arithmetic operations: {}",
        format_operations(total_64bits_operations)
    );
    println!(
        " Total number of 128-bit arithmetic operations: {}",
        format_operations(total_128bits_operations)
    );

    println!("Total operations: {}", graph.get_nodes().len());
    println!("Operations: ");
    for (op_name, count) in entries {
        println!(" {:<10}\t{}", op_name, count);
    }
    Ok(())
}
// Command-line choice of inlining mode; `get_inline_mode` maps each variant to
// the corresponding `InlineMode` value.
// NOTE: plain `//` comments are used on purpose here, because clap's derive
// macros turn `///` doc comments into help text.
#[derive(Debug, ArgEnum, Clone)]
enum InlineModeArg {
// Mapped to `InlineMode::Simple`.
Simple,
// Mapped to `InlineMode::DepthOptimized(DepthOptimizationLevel::Default)`.
DepthOptimizedDefault,
// Mapped to `InlineMode::DepthOptimized(DepthOptimizationLevel::Extreme)`.
DepthOptimizedExtreme,
}
// Command-line arguments of this binary, parsed by clap in `main`.
// NOTE: plain `//` comments are used on purpose here, because clap's derive
// macros turn `///` doc comments into help text.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about=None)]
struct Args {
// Path of the file that `main` reads and deserializes as a JSON `Context`.
#[clap(value_parser)]
input_path: String,
// When true, `main` runs `prepare_context` on the loaded context before
// computing the statistics.
#[clap(action = ArgAction::SetTrue)]
prepare: bool,
// Inlining mode used as `InlineConfig::default_mode` when preparing the
// context; `get_inline_mode` falls back to `InlineMode::Simple` when absent.
#[clap(arg_enum, value_parser)]
inline_mode: Option<InlineModeArg>,
}
/// Creates the `SimpleEvaluator` (constructed with `None`) that `main` passes
/// to `prepare_context`.
///
/// # Errors
/// Propagates any error returned by `SimpleEvaluator::new`.
fn get_evaluator() -> Result<SimpleEvaluator> {
    let evaluator = SimpleEvaluator::new(None)?;
    Ok(evaluator)
}
/// Translates the optional command-line inlining mode into an `InlineMode`.
///
/// A missing value is treated the same as `InlineModeArg::Simple`.
fn get_inline_mode(mode_val: Option<InlineModeArg>) -> InlineMode {
    match mode_val {
        None | Some(InlineModeArg::Simple) => InlineMode::Simple,
        Some(InlineModeArg::DepthOptimizedDefault) => {
            InlineMode::DepthOptimized(DepthOptimizationLevel::Default)
        }
        Some(InlineModeArg::DepthOptimizedExtreme) => {
            InlineMode::DepthOptimized(DepthOptimizationLevel::Extreme)
        }
    }
}
/// Entry point: loads a serialized `Context` from the path given on the
/// command line, optionally runs `prepare_context` on it, and prints the
/// statistics of its main graph.
///
/// All failures, including malformed JSON input, are returned from the closure
/// handed to `execute_main` rather than causing a panic.
fn main() {
    env_logger::init();
    execute_main(|| -> Result<()> {
        let args = Args::parse();
        let serialized_context = fs::read_to_string(&args.input_path)?;
        // A malformed file is a user error, not a bug, so report it through the
        // regular error path. `serde_json::Error` converts into `std::io::Error`,
        // which `?` already accepts here (see the read above).
        let context = serde_json::from_str::<Context>(&serialized_context)
            .map_err(std::io::Error::from)?;
        let context = if args.prepare {
            let evaluator = get_evaluator()?;
            let inline_config = InlineConfig {
                default_mode: get_inline_mode(args.inline_mode),
                ..Default::default()
            };
            prepare_context(context, inline_config, evaluator, false)?
        } else {
            context
        };
        eprintln!("Calculating stats...");
        print_stats(context.get_main_graph()?)?;
        Ok(())
    });
}