use atomic_refcell::AtomicRefCell;
use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::ptr;
use std::sync::Arc;
use std::sync::Weak;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::constants;
use crate::custom_ops::CustomOperation;
use crate::data_types::{get_size_estimation_in_bits, ArrayShape, ScalarType, Type};
use crate::data_values::Value;
use crate::errors::Result;
use crate::type_inference::{create_type_inference_worker, TypeInferenceWorker};
use crate::version::{VersionedData, DATA_VERSION};
/// One component of a [`Slice`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum SliceElement {
/// A single `i64` index.
SingleIndex(i64),
/// Three optional `i64` values; presumably `(start, stop, step)` — TODO confirm against the
/// code that consumes slices.
SubArray(Option<i64>, Option<i64>, Option<i64>),
/// An ellipsis marker; presumably stands for "all remaining dimensions" — TODO confirm.
Ellipsis,
}
/// A slice is an ordered list of [`SliceElement`]s.
pub type Slice = Vec<SliceElement>;
/// The operation performed by a graph node, together with the operation's static parameters.
///
/// Node arguments are not stored here; they are kept as node/graph dependencies of the node.
/// The `Display` implementation prints only the variant name (or the custom operation's name).
#[doc(hidden)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum Operation {
// Carries the type of the input.
Input(Type),
Add,
Subtract,
Multiply,
Dot,
Matmul,
Truncate(u64),
Sum(ArrayShape),
PermuteAxes(ArrayShape),
Get(ArrayShape),
GetSlice(Slice),
Reshape(Type),
NOP,
Random(Type),
PRF(u64, Type),
Stack(ArrayShape),
// Carries the type of the constant and its value.
Constant(Type, Value),
A2B,
B2A(ScalarType),
CreateTuple,
CreateNamedTuple(Vec<String>),
CreateVector(Type),
TupleGet(u64),
NamedTupleGet(String),
VectorGet,
Zip,
Repeat(u64),
// `Call` and `Iterate` nodes are created with a graph dependency (see `Graph::call` and
// `Graph::iterate`).
Call,
Iterate,
ArrayToVector,
VectorToArray,
// Wraps a user-defined operation.
Custom(CustomOperation),
}
impl fmt::Display for Operation {
    /// Writes the name of the operation without its parameters.
    ///
    /// A custom operation reports the name returned by the wrapped custom operation. Every
    /// other variant is rendered as the part of its `Debug` output that precedes the first `(`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Operation::Custom(custom_op) => custom_op.get_name(),
            other => {
                let debug_repr = format!("{:?}", other);
                // `split` always yields at least one piece, so the fallback is purely defensive.
                match debug_repr.split('(').next() {
                    Some(prefix) => prefix.to_owned(),
                    None => "-null-".to_owned(),
                }
            }
        };
        write!(f, "{}", name)
    }
}
/// The data shared by all handles to one node.
struct NodeBody {
// Weak handle to the owning graph (upgraded by `Node::get_graph`).
graph: WeakGraph,
// Weak handles to the nodes this node depends on.
node_dependencies: Vec<WeakNode>,
// Weak handles to the graphs this node depends on.
graph_dependencies: Vec<WeakGraph>,
// The operation performed by the node.
operation: Operation,
// Index of the node within its graph's node list (assigned in `Graph::add_node`).
id: u64,
}
/// Serializable form of a node: dependencies are stored as IDs instead of pointers.
#[derive(Serialize, Deserialize)]
struct SerializableNodeBody {
// IDs of the node dependencies.
node_dependencies: Vec<u64>,
// IDs of the graph dependencies.
graph_dependencies: Vec<u64>,
operation: Operation,
}
// Shared, interior-mutable pointer to a node body.
type NodeBodyPointer = Arc<AtomicRefCell<NodeBody>>;
/// A handle to a node of a computation graph.
///
/// Cloning a `Node` clones the pointer, not the node; equality and hashing are pointer-based.
pub struct Node {
body: NodeBodyPointer,
}
// Shared pointer to the serializable form of a node.
type SerializableNode = Arc<SerializableNodeBody>;
impl Clone for Node {
fn clone(&self) -> Self {
Node {
body: self.body.clone(),
}
}
}
impl PartialEq for Node {
    /// Two handles are equal exactly when they point to the same node body.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.body, &other.body);
        Arc::ptr_eq(lhs, rhs)
    }
}
impl Eq for Node {}
impl Hash for Node {
fn hash<H: Hasher>(&self, state: &mut H) {
ptr::hash(&*self.body, state);
}
}
impl Node {
    /// Returns the graph that owns this node.
    ///
    /// # Panics
    /// Panics if the owning graph no longer exists (the weak handle cannot be upgraded).
    pub fn get_graph(&self) -> Graph {
        let body = self.body.borrow();
        body.graph.upgrade()
    }

    /// Returns handles to the nodes this node depends on, in the stored order.
    pub fn get_node_dependencies(&self) -> Vec<Node> {
        let body = self.body.borrow();
        let mut dependencies = Vec::with_capacity(body.node_dependencies.len());
        for weak_node in &body.node_dependencies {
            dependencies.push(weak_node.upgrade());
        }
        dependencies
    }

    /// Returns handles to the graphs this node depends on, in the stored order.
    pub fn get_graph_dependencies(&self) -> Vec<Graph> {
        let body = self.body.borrow();
        let mut dependencies = Vec::with_capacity(body.graph_dependencies.len());
        for weak_graph in &body.graph_dependencies {
            dependencies.push(weak_graph.upgrade());
        }
        dependencies
    }

    /// Returns a copy of the operation performed by this node.
    pub fn get_operation(&self) -> Operation {
        let body = self.body.borrow();
        body.operation.clone()
    }

    /// Returns the ID of this node within its graph.
    pub fn get_id(&self) -> u64 {
        let body = self.body.borrow();
        body.id
    }

    /// Returns the pair `(graph ID, node ID)` identifying this node.
    pub fn get_global_id(&self) -> (u64, u64) {
        let graph_id = self.get_graph().get_id();
        (graph_id, self.get_id())
    }

    /// Builds the serializable form of this node, replacing dependencies by their IDs.
    fn make_serializable(&self) -> SerializableNode {
        let node_dependencies: Vec<u64> = self
            .get_node_dependencies()
            .iter()
            .map(|node| node.get_id())
            .collect();
        let graph_dependencies: Vec<u64> = self
            .get_graph_dependencies()
            .iter()
            .map(|graph| graph.get_id())
            .collect();
        let operation = self.get_operation();
        Arc::new(SerializableNodeBody {
            node_dependencies,
            graph_dependencies,
            operation,
        })
    }

    /// Returns the type of this node as computed by the context's type checker.
    ///
    /// # Errors
    /// Fails if the context has no type checker or if the type checker rejects the node.
    pub fn get_type(&self) -> Result<Type> {
        let context = self.get_graph().get_context();
        let mut context_body = context.body.borrow_mut();
        match context_body.type_checker.as_mut() {
            Some(tc) => tc.process_node(self.clone()),
            None => Err(runtime_error!("Type checker is not available")),
        }
    }
}
impl Node {
    /// Creates a non-owning (weak) handle to this node.
    fn downgrade(&self) -> WeakNode {
        let body = Arc::downgrade(&self.body);
        WeakNode { body }
    }
}
// Non-owning pointer to a node body.
type WeakNodeBodyPointer = Weak<AtomicRefCell<NodeBody>>;
/// A non-owning handle to a node; produced by `Node::downgrade`.
struct WeakNode {
body: WeakNodeBodyPointer,
}
impl WeakNode {
    /// Turns this weak handle back into an owning [`Node`].
    ///
    /// # Panics
    /// Panics if the node body has already been dropped.
    fn upgrade(&self) -> Node {
        let strong_body = self.body.upgrade().unwrap();
        Node { body: strong_body }
    }
}
impl Clone for WeakNode {
fn clone(&self) -> Self {
WeakNode {
body: self.body.clone(),
}
}
}
/// The data shared by all handles to one graph.
struct GraphBody {
// Set by `Graph::finalize`; `Graph::add_node` rejects new nodes once it is set.
finalized: bool,
// Nodes of the graph; a node's ID is its index in this list.
nodes: Vec<Node>,
// Weak handle to the output node, if one has been set.
output_node: Option<WeakNode>,
// Index of the graph within its context's graph list (assigned in `Context::create_graph`).
id: u64,
// Weak handle to the owning context.
context: WeakContext,
}
/// Serializable form of a graph: nodes are serializable nodes, the output node is an ID.
#[derive(Serialize, Deserialize)]
struct SerializableGraphBody {
finalized: bool,
nodes: Vec<SerializableNode>,
// ID of the output node, or `None` if no output node is set.
output_node: Option<u64>,
}
// Shared, interior-mutable pointer to a graph body.
type GraphBodyPointer = Arc<AtomicRefCell<GraphBody>>;
/// A handle to a computation graph.
///
/// Cloning a `Graph` clones the pointer, not the graph; equality and hashing are pointer-based.
pub struct Graph {
body: GraphBodyPointer,
}
// Shared pointer to the serializable form of a graph.
type SerializableGraph = Arc<SerializableGraphBody>;
impl fmt::Debug for Graph {
    /// Formats the graph as a `Graph` struct whose only field is the raw pointer to its body.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let body_pointer = self.body.as_ptr();
        let mut builder = f.debug_struct("Graph");
        builder.field("body", &body_pointer);
        builder.finish()
    }
}
impl Clone for Graph {
fn clone(&self) -> Self {
Graph {
body: self.body.clone(),
}
}
}
impl PartialEq for Graph {
    /// Two handles are equal exactly when they point to the same graph body.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.body, &other.body);
        Arc::ptr_eq(lhs, rhs)
    }
}
impl Eq for Graph {}
impl Hash for Graph {
fn hash<H: Hasher>(&self, state: &mut H) {
ptr::hash(&*self.body, state);
}
}
impl Graph {
    /// Appends a new node to this graph and returns a handle to it.
    ///
    /// # Errors
    /// Fails when
    /// * the graph is already finalized;
    /// * a node dependency is not a node of this graph;
    /// * a graph dependency belongs to another context, is not finalized, or has an ID that is
    ///   not smaller than this graph's ID;
    /// * the context has a type checker and the node fails type inference, has an invalid size,
    ///   exceeds `constants::MAX_INDIVIDUAL_NODE_SIZE`, or is rejected by the context's total
    ///   size accounting.
    ///
    /// When one of the type or size checks fails, the freshly added node is removed from the
    /// graph again before the error is returned.
    pub(crate) fn add_node(
        &self,
        node_dependencies: Vec<Node>,
        graph_dependencies: Vec<Graph>,
        operation: Operation,
    ) -> Result<Node> {
        if self.is_finalized() {
            return Err(runtime_error!("Can't add a node to a finalized graph"));
        }
        for dependency in &node_dependencies {
            if dependency.get_graph() != *self
                || dependency.get_id() >= self.body.borrow().nodes.len() as u64
                || self.body.borrow().nodes[dependency.get_id() as usize] != *dependency
            {
                return Err(runtime_error!(
                    "Can't add a node with invalid node dependencies"
                ));
            }
        }
        for dependency in &graph_dependencies {
            if dependency.get_context() != self.get_context()
                || !dependency.is_finalized()
                || dependency.get_id() >= self.get_id()
            {
                return Err(runtime_error!(
                    "Can't add a node with invalid graph dependencies"
                ));
            }
        }
        let id = self.body.borrow().nodes.len() as u64;
        let result = Node {
            body: Arc::new(AtomicRefCell::new(NodeBody {
                graph: self.downgrade(),
                node_dependencies: node_dependencies.iter().map(|n| n.downgrade()).collect(),
                graph_dependencies: graph_dependencies.iter().map(|g| g.downgrade()).collect(),
                operation,
                id,
            })),
        };
        self.body.borrow_mut().nodes.push(result.clone());

        // A shared borrow is enough to find out whether a type checker is installed; the borrow
        // is released at the end of this statement, before `get_type` borrows the context again.
        let context_has_type_checker = self.get_context().body.borrow().type_checker.is_some();
        if context_has_type_checker {
            let node_type = match result.get_type() {
                Ok(node_type) => node_type,
                Err(type_error) => {
                    self.remove_last_node(result)?;
                    return Err(type_error);
                }
            };
            let size_in_bits = match get_size_estimation_in_bits(node_type) {
                Ok(size) => size,
                Err(_) => {
                    self.remove_last_node(result)?;
                    return Err(runtime_error!("Trying to add a node with invalid size"));
                }
            };
            if size_in_bits > constants::MAX_INDIVIDUAL_NODE_SIZE {
                self.remove_last_node(result)?;
                return Err(runtime_error!(
                    "Trying to add a node larger than MAX_INDIVIDUAL_NODE_SIZE"
                ));
            }
            let total_size_check = self.get_context().try_update_total_size(result.clone());
            if let Err(size_error) = total_size_check {
                self.remove_last_node(result)?;
                return Err(size_error);
            }
        }
        Ok(result)
    }

    /// Removes `n` from this graph; `n` must be the most recently added node.
    ///
    /// The node is also unregistered from the context and, if present, from the type checker.
    ///
    /// # Errors
    /// Fails if `n` belongs to another graph, the graph has no nodes, `n` is not the last node,
    /// or unregistering the node fails.
    fn remove_last_node(&self, n: Node) -> Result<()> {
        if n.get_graph() != *self {
            return Err(runtime_error!(
                "The node to be removed from a different graph"
            ));
        }
        {
            let body = self.body.borrow();
            let last_node = body
                .nodes
                .last()
                .ok_or_else(|| runtime_error!("Nodes list is empty"))?;
            if n != *last_node {
                return Err(runtime_error!(
                    "The node to be removed is not the last node"
                ));
            }
        }
        let context = self.get_context();
        context.unregister_node(n.clone())?;
        let mut context_body = context.body.borrow_mut();
        if let Some(tc) = context_body.type_checker.as_mut() {
            tc.unregister_node(n)?;
        }
        self.body.borrow_mut().nodes.pop();
        Ok(())
    }

    /// Adds an `Operation::Input` node of the given type.
    pub fn input(&self, input_type: Type) -> Result<Node> {
        self.add_node(vec![], vec![], Operation::Input(input_type))
    }

    /// Adds an `Operation::Add` node depending on `a` and `b`.
    pub fn add(&self, a: Node, b: Node) -> Result<Node> {
        self.add_node(vec![a, b], vec![], Operation::Add)
    }

    /// Adds an `Operation::Subtract` node depending on `a` and `b`.
    pub fn subtract(&self, a: Node, b: Node) -> Result<Node> {
        self.add_node(vec![a, b], vec![], Operation::Subtract)
    }

    /// Adds an `Operation::Multiply` node depending on `a` and `b`.
    pub fn multiply(&self, a: Node, b: Node) -> Result<Node> {
        self.add_node(vec![a, b], vec![], Operation::Multiply)
    }

    /// Adds an `Operation::Dot` node depending on `a` and `b`.
    pub fn dot(&self, a: Node, b: Node) -> Result<Node> {
        self.add_node(vec![a, b], vec![], Operation::Dot)
    }

    /// Adds an `Operation::Matmul` node depending on `a` and `b`.
    pub fn matmul(&self, a: Node, b: Node) -> Result<Node> {
        self.add_node(vec![a, b], vec![], Operation::Matmul)
    }

    /// Adds an `Operation::Truncate(scale)` node depending on `a`.
    pub fn truncate(&self, a: Node, scale: u64) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::Truncate(scale))
    }

    /// Adds an `Operation::Sum(axes)` node depending on `a`.
    pub fn sum(&self, a: Node, axes: ArrayShape) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::Sum(axes))
    }

    /// Adds an `Operation::PermuteAxes(axes)` node depending on `a`.
    pub fn permute_axes(&self, a: Node, axes: ArrayShape) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::PermuteAxes(axes))
    }

    /// Adds an `Operation::Get(index)` node depending on `a`.
    pub fn get(&self, a: Node, index: ArrayShape) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::Get(index))
    }

    /// Adds an `Operation::GetSlice(slice)` node depending on `a`.
    pub fn get_slice(&self, a: Node, slice: Slice) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::GetSlice(slice))
    }

    /// Adds an `Operation::Reshape(new_type)` node depending on `a`.
    ///
    /// The size of `new_type` is checked before the node is added.
    ///
    /// # Errors
    /// Fails if the size of `new_type` cannot be estimated, if it exceeds
    /// `constants::MAX_INDIVIDUAL_NODE_SIZE`, or if adding the node fails.
    pub fn reshape(&self, a: Node, new_type: Type) -> Result<Node> {
        let size_in_bits = match get_size_estimation_in_bits(new_type.clone()) {
            Ok(size) => size,
            Err(_) => {
                return Err(runtime_error!(
                    "Trying to add a reshape node with invalid type size"
                ));
            }
        };
        if size_in_bits > constants::MAX_INDIVIDUAL_NODE_SIZE {
            return Err(runtime_error!(
                "Trying to add a reshape node larger than MAX_INDIVIDUAL_NODE_SIZE"
            ));
        }
        self.add_node(vec![a], vec![], Operation::Reshape(new_type))
    }

    /// Adds an `Operation::NOP` node depending on `a`.
    pub(crate) fn nop(&self, a: Node) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::NOP)
    }

    /// Adds an `Operation::Random(output_type)` node without dependencies.
    #[doc(hidden)]
    pub fn random(&self, output_type: Type) -> Result<Node> {
        self.add_node(vec![], vec![], Operation::Random(output_type))
    }

    /// Adds an `Operation::PRF(iv, output_type)` node depending on `key`.
    pub(crate) fn prf(&self, key: Node, iv: u64, output_type: Type) -> Result<Node> {
        self.add_node(vec![key], vec![], Operation::PRF(iv, output_type))
    }

    /// Adds an `Operation::Stack(outer_shape)` node depending on `nodes`.
    pub fn stack(&self, nodes: Vec<Node>, outer_shape: ArrayShape) -> Result<Node> {
        self.add_node(nodes, vec![], Operation::Stack(outer_shape))
    }

    /// Adds an `Operation::Constant(output_type, value)` node without dependencies.
    pub fn constant(&self, output_type: Type, value: Value) -> Result<Node> {
        self.add_node(vec![], vec![], Operation::Constant(output_type, value))
    }

    /// Adds an `Operation::A2B` node depending on `a`.
    pub fn a2b(&self, a: Node) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::A2B)
    }

    /// Adds an `Operation::B2A(scalar_type)` node depending on `a`.
    pub fn b2a(&self, a: Node, scalar_type: ScalarType) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::B2A(scalar_type))
    }

    /// Adds an `Operation::CreateTuple` node depending on `elements`.
    pub fn create_tuple(&self, elements: Vec<Node>) -> Result<Node> {
        self.add_node(elements, vec![], Operation::CreateTuple)
    }

    /// Adds an `Operation::CreateVector(element_type)` node depending on `elements`.
    pub fn create_vector(&self, element_type: Type, elements: Vec<Node>) -> Result<Node> {
        self.add_node(elements, vec![], Operation::CreateVector(element_type))
    }

    /// Adds an `Operation::CreateNamedTuple` node.
    ///
    /// The names become the operation's parameter and the nodes its dependencies, both in the
    /// order given by `elements`.
    pub fn create_named_tuple(&self, elements: Vec<(String, Node)>) -> Result<Node> {
        let (names, nodes): (Vec<String>, Vec<Node>) = elements.into_iter().unzip();
        self.add_node(nodes, vec![], Operation::CreateNamedTuple(names))
    }

    /// Adds an `Operation::TupleGet(index)` node depending on `tuple`.
    pub fn tuple_get(&self, tuple: Node, index: u64) -> Result<Node> {
        self.add_node(vec![tuple], vec![], Operation::TupleGet(index))
    }

    /// Adds an `Operation::NamedTupleGet(key)` node depending on `tuple`.
    pub fn named_tuple_get(&self, tuple: Node, key: String) -> Result<Node> {
        self.add_node(vec![tuple], vec![], Operation::NamedTupleGet(key))
    }

    /// Adds an `Operation::VectorGet` node depending on `vec` and `index`.
    pub fn vector_get(&self, vec: Node, index: Node) -> Result<Node> {
        self.add_node(vec![vec, index], vec![], Operation::VectorGet)
    }

    /// Adds an `Operation::Zip` node depending on `nodes`.
    pub fn zip(&self, nodes: Vec<Node>) -> Result<Node> {
        self.add_node(nodes, vec![], Operation::Zip)
    }

    /// Adds an `Operation::Repeat(n)` node depending on `a`.
    pub fn repeat(&self, a: Node, n: u64) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::Repeat(n))
    }

    /// Adds an `Operation::Call` node with node dependencies `arguments` and graph dependency
    /// `graph`.
    pub fn call(&self, graph: Graph, arguments: Vec<Node>) -> Result<Node> {
        self.add_node(arguments, vec![graph], Operation::Call)
    }

    /// Adds an `Operation::Iterate` node with node dependencies `state` and `input` and graph
    /// dependency `graph`.
    pub fn iterate(&self, graph: Graph, state: Node, input: Node) -> Result<Node> {
        self.add_node(vec![state, input], vec![graph], Operation::Iterate)
    }

    /// Adds an `Operation::ArrayToVector` node depending on `a`.
    pub fn array_to_vector(&self, a: Node) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::ArrayToVector)
    }

    /// Adds an `Operation::VectorToArray` node depending on `a`.
    pub fn vector_to_array(&self, a: Node) -> Result<Node> {
        self.add_node(vec![a], vec![], Operation::VectorToArray)
    }

    /// Adds an `Operation::Custom(op)` node depending on `arguments`.
    pub fn custom_op(&self, op: CustomOperation, arguments: Vec<Node>) -> Result<Node> {
        self.add_node(arguments, vec![], Operation::Custom(op))
    }

    /// Marks the graph as finalized and returns a handle to it.
    ///
    /// # Errors
    /// Fails if no output node has been set.
    pub fn finalize(&self) -> Result<Graph> {
        // Reading the output node only needs a shared borrow.
        let has_output_node = self.body.borrow().output_node.is_some();
        if !has_output_node {
            return Err(runtime_error!("Output node is not set"));
        }
        self.body.borrow_mut().finalized = true;
        Ok(self.clone())
    }

    /// Tells whether the graph has been finalized.
    pub(super) fn is_finalized(&self) -> bool {
        self.body.borrow().finalized
    }

    /// Returns an error unless the graph has been finalized.
    pub(super) fn check_finalized(&self) -> Result<()> {
        if self.is_finalized() {
            Ok(())
        } else {
            Err(runtime_error!("Graph is not finalized"))
        }
    }

    /// Returns handles to all nodes of the graph, ordered by node ID.
    pub fn get_nodes(&self) -> Vec<Node> {
        self.body.borrow().nodes.clone()
    }

    /// Sets the output node of the graph.
    ///
    /// # Errors
    /// Fails if an output node is already set or if `output_node` belongs to another graph.
    pub fn set_output_node(&self, output_node: Node) -> Result<()> {
        let already_set = self.body.borrow().output_node.is_some();
        if already_set {
            return Err(runtime_error!("Output node is already set"));
        }
        if output_node.get_graph() != *self {
            return Err(runtime_error!("Output node has to be from the same graph"));
        }
        self.body.borrow_mut().output_node = Some(output_node.downgrade());
        Ok(())
    }

    /// Returns the output node of the graph.
    ///
    /// # Errors
    /// Fails if no output node has been set.
    pub fn get_output_node(&self) -> Result<Node> {
        let body = self.body.borrow();
        match body.output_node.as_ref() {
            Some(output_node) => Ok(output_node.upgrade()),
            None => Err(runtime_error!("Output node is not set")),
        }
    }

    /// Returns the ID of the graph within its context.
    pub fn get_id(&self) -> u64 {
        self.body.borrow().id
    }

    /// Returns the number of nodes in the graph.
    pub fn get_num_nodes(&self) -> u64 {
        self.body.borrow().nodes.len() as u64
    }

    /// Returns the node with the given ID.
    ///
    /// # Errors
    /// Fails if `id` is not smaller than the number of nodes.
    pub fn get_node_by_id(&self, id: u64) -> Result<Node> {
        let body = self.body.borrow();
        if id < body.nodes.len() as u64 {
            Ok(body.nodes[id as usize].clone())
        } else {
            Err(runtime_error!("Invalid id for the node retrieval"))
        }
    }

    /// Returns the context that owns this graph.
    ///
    /// # Panics
    /// Panics if the context no longer exists (the weak handle cannot be upgraded).
    pub fn get_context(&self) -> Context {
        self.body.borrow().context.upgrade()
    }

    /// Builds the serializable form of this graph.
    fn make_serializable(&self) -> SerializableGraph {
        let output_node = self.get_output_node().ok().map(|node| node.get_id());
        let nodes = self
            .get_nodes()
            .iter()
            .map(|node| node.make_serializable())
            .collect();
        Arc::new(SerializableGraphBody {
            finalized: self.is_finalized(),
            nodes,
            output_node,
        })
    }
}
impl Graph {
    /// Creates a non-owning (weak) handle to this graph.
    fn downgrade(&self) -> WeakGraph {
        let body = Arc::downgrade(&self.body);
        WeakGraph { body }
    }
}
// Non-owning pointer to a graph body.
type WeakGraphBodyPointer = Weak<AtomicRefCell<GraphBody>>;
/// A non-owning handle to a graph; produced by `Graph::downgrade`.
struct WeakGraph {
body: WeakGraphBodyPointer,
}
impl WeakGraph {
    /// Turns this weak handle back into an owning [`Graph`].
    ///
    /// # Panics
    /// Panics if the graph body has already been dropped.
    fn upgrade(&self) -> Graph {
        let strong_body = self.body.upgrade().unwrap();
        Graph { body: strong_body }
    }
}
impl Clone for WeakGraph {
fn clone(&self) -> Self {
WeakGraph {
body: self.body.clone(),
}
}
}
/// Annotations that can be attached to a node; the context stores them per node as a list,
/// keyed by `(graph ID, node ID)`.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum NodeAnnotation {
AssociativeOperation,
Private,
// NOTE(review): the two `u64` payloads of `Send` are not explained in this file; presumably
// sender and receiver party IDs — confirm.
Send(u64, u64), PRFMultiplication,
PRFB2A,
}
/// Annotations that can be attached to a graph; the context stores them per graph as a list,
/// keyed by graph ID.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum GraphAnnotation {
AssociativeOperation,
OneBitState,
SmallState,
}
/// The data shared by all handles to one context.
struct ContextBody {
// Set by `Context::finalize`; new graphs, names and annotations are rejected once it is set.
finalized: bool,
// Graphs of the context; a graph's ID is its index in this list.
graphs: Vec<Graph>,
// Weak handle to the main graph, if one has been set.
main_graph: Option<WeakGraph>,
// Graph ID -> graph name.
graphs_names: HashMap<u64, String>,
// Graph name -> graph ID (inverse of `graphs_names`).
graphs_names_inverse: HashMap<String, u64>,
// (graph ID, node ID) -> node name.
nodes_names: HashMap<(u64, u64), String>,
// Graph ID -> (node name -> node ID); inverse of `nodes_names`, grouped by graph.
nodes_names_inverse: HashMap<u64, HashMap<String, u64>>,
// (graph ID, node ID) -> annotations of that node.
nodes_annotations: HashMap<(u64, u64), Vec<NodeAnnotation>>,
// Graph ID -> annotations of that graph.
graphs_annotations: HashMap<u64, Vec<GraphAnnotation>>,
// Running total of the size estimates (in bits) of `Input` and `Constant` nodes; updated by
// `Context::try_update_total_size`.
total_size_nodes: u64,
// Type inference worker; `None` for contexts created by `create_unchecked_context`.
type_checker: Option<TypeInferenceWorker>,
}
// Shared, interior-mutable pointer to a context body.
type ContextBodyPointer = Arc<AtomicRefCell<ContextBody>>;
/// A handle to a context, the container of computation graphs.
///
/// Cloning a `Context` clones the pointer, not the context; equality is pointer-based.
pub struct Context {
body: ContextBodyPointer,
}
impl fmt::Debug for Context {
    /// Formats the context as a `Context` struct whose only field is the raw pointer to its
    /// body.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let body_pointer = self.body.as_ptr();
        let mut builder = f.debug_struct("Context");
        builder.field("body", &body_pointer);
        builder.finish()
    }
}
/// Serializable form of a context: graphs are serializable graphs and all maps are flattened
/// into lists of `(key, value)` pairs.
#[derive(Serialize, Deserialize)]
struct SerializableContextBody {
finalized: bool,
graphs: Vec<SerializableGraph>,
// ID of the main graph, or `None` if no main graph is set.
main_graph: Option<u64>,
// (graph ID, name) pairs.
graphs_names: Vec<(u64, String)>,
// ((graph ID, node ID), name) pairs.
nodes_names: Vec<((u64, u64), String)>,
// ((graph ID, node ID), annotations) pairs.
nodes_annotations: Vec<((u64, u64), Vec<NodeAnnotation>)>,
// (graph ID, annotations) pairs.
graphs_annotations: Vec<(u64, Vec<GraphAnnotation>)>,
}
impl SerializableContextBody {
    /// Rebuilds one graph inside `context` from its serializable form.
    ///
    /// Nodes are re-added in their stored order, so every dependency ID must refer to a node
    /// (or graph) that has already been rebuilt.
    ///
    /// # Errors
    /// Fails if the graph cannot be created, a dependency or output node ID does not exist,
    /// or adding a node, setting the output node or finalizing the graph fails.
    fn recover_original_graph(
        serializable_graph: SerializableGraph,
        context: Context,
    ) -> Result<Graph> {
        let result_graph = context.create_graph()?;
        for node in &serializable_graph.nodes {
            let mut node_dependencies = Vec::with_capacity(node.node_dependencies.len());
            for id in &node.node_dependencies {
                let dependency = result_graph
                    .get_node_by_id(*id)
                    .map_err(|_| runtime_error!("Non-existent node dependency"))?;
                node_dependencies.push(dependency);
            }
            let mut graph_dependencies = Vec::with_capacity(node.graph_dependencies.len());
            for id in &node.graph_dependencies {
                let dependency = context
                    .get_graph_by_id(*id)
                    .map_err(|_| runtime_error!("Non-existent graph dependency"))?;
                graph_dependencies.push(dependency);
            }
            result_graph.add_node(
                node_dependencies,
                graph_dependencies,
                node.operation.clone(),
            )?;
        }
        if let Some(id) = serializable_graph.output_node {
            let rebuilt_output_node = result_graph
                .get_node_by_id(id)
                .map_err(|_| runtime_error!("Non-existent output node"))?;
            result_graph.set_output_node(rebuilt_output_node)?;
        }
        if serializable_graph.finalized {
            result_graph.finalize()?;
        }
        Ok(result_graph)
    }

    /// Rebuilds a full context (created with `create_context`) from its serializable form.
    ///
    /// Graphs are rebuilt first, then the main graph, names and annotations are restored and,
    /// if requested, the context is finalized.
    ///
    /// # Errors
    /// Fails if any stored ID does not refer to an existing graph or node, or if any of the
    /// rebuilding steps is rejected by the context. IDs in the annotation lists are validated
    /// as well, so malformed data yields an error rather than an out-of-bounds panic.
    fn recover_original_context(&self) -> Result<Context> {
        let result_context = create_context()?;
        for graph in &self.graphs {
            Self::recover_original_graph(graph.clone(), result_context.clone())?;
        }
        if let Some(id) = self.main_graph {
            let rebuilt_main_graph = result_context
                .get_graph_by_id(id)
                .map_err(|_| runtime_error!("Non-existent main graph"))?;
            result_context.set_main_graph(rebuilt_main_graph)?;
        }
        // Validate all name keys before any name is applied.
        for (id, _) in &self.graphs_names {
            if *id >= result_context.get_num_graphs() {
                return Err(runtime_error!("graphs_names contain an invalid ID"));
            }
        }
        for ((graph_id, node_id), _) in &self.nodes_names {
            let graph = result_context
                .get_graph_by_id(*graph_id)
                .map_err(|_| runtime_error!("nodes_names contain an invalid graph ID"))?;
            if *node_id >= graph.get_num_nodes() {
                return Err(runtime_error!("nodes_names contain an invalid node ID"));
            }
        }
        for (id, name) in &self.graphs_names {
            let current_graph = result_context.get_graph_by_id(*id)?;
            result_context.set_graph_name(current_graph, name)?;
        }
        for (global_id, name) in &self.nodes_names {
            let current_node = result_context.get_node_by_global_id(*global_id)?;
            result_context.set_node_name(current_node, name)?;
        }
        for (id, annotations) in &self.graphs_annotations {
            let current_graph = result_context
                .get_graph_by_id(*id)
                .map_err(|_| runtime_error!("graphs_annotations contain an invalid ID"))?;
            for annotation in annotations {
                result_context.add_graph_annotation(&current_graph, annotation.clone())?;
            }
        }
        for (global_id, annotations) in &self.nodes_annotations {
            let current_node = result_context
                .get_node_by_global_id(*global_id)
                .map_err(|_| runtime_error!("nodes_annotations contain an invalid ID"))?;
            for annotation in annotations {
                result_context.add_node_annotation(&current_node, annotation.clone())?;
            }
        }
        if self.finalized {
            result_context.finalize()?;
        }
        Ok(result_context)
    }
}
// Shared pointer to the serializable form of a context.
type SerializableContext = Arc<SerializableContextBody>;
impl Clone for Context {
    /// Returns another handle to the same context body; the body itself is not copied.
    fn clone(&self) -> Self {
        let body = Arc::clone(&self.body);
        Context { body }
    }
}
impl PartialEq for Context {
    /// Two handles are equal exactly when they point to the same context body.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.body, &other.body);
        Arc::ptr_eq(lhs, rhs)
    }
}
impl Eq for Context {}
impl Context {
    /// Creates a new empty graph in this context; its ID is the number of graphs created
    /// before it.
    ///
    /// # Errors
    /// Fails if the context is finalized.
    pub fn create_graph(&self) -> Result<Graph> {
        if self.is_finalized() {
            return Err(runtime_error!("Can't add a graph to a finalized context"));
        }
        let id = self.get_num_graphs();
        let graph_body = GraphBody {
            finalized: false,
            nodes: vec![],
            output_node: None,
            id,
            context: self.downgrade(),
        };
        let result = Graph {
            body: Arc::new(AtomicRefCell::new(graph_body)),
        };
        self.body.borrow_mut().graphs.push(result.clone());
        Ok(result)
    }

    /// Marks the context as finalized and returns a handle to it.
    ///
    /// # Errors
    /// Fails if some graph is not finalized or if the main graph is not set.
    pub fn finalize(&self) -> Result<Context> {
        for graph in self.get_graphs() {
            graph.check_finalized()?;
        }
        let has_main_graph = self.body.borrow().main_graph.is_some();
        if !has_main_graph {
            return Err(runtime_error!(
                "Can't finalize the context without the main graph"
            ));
        }
        self.body.borrow_mut().finalized = true;
        Ok(self.clone())
    }

    /// Sets the main graph of the context and returns a handle to the context.
    ///
    /// # Errors
    /// Fails if the main graph is already set, if `graph` belongs to another context, or if
    /// `graph` is not finalized.
    pub fn set_main_graph(&self, graph: Graph) -> Result<Context> {
        let already_set = self.body.borrow().main_graph.is_some();
        if already_set {
            return Err(runtime_error!("Main graph is already set"));
        }
        if graph.get_context() != *self {
            return Err(runtime_error!("Main graph is from the wrong context"));
        }
        graph.check_finalized()?;
        self.body.borrow_mut().main_graph = Some(graph.downgrade());
        Ok(self.clone())
    }

    /// Returns handles to all graphs of the context, ordered by graph ID.
    pub fn get_graphs(&self) -> Vec<Graph> {
        let body = self.body.borrow();
        body.graphs.clone()
    }

    /// Tells whether the context has been finalized.
    pub(super) fn is_finalized(&self) -> bool {
        let body = self.body.borrow();
        body.finalized
    }

    /// Returns an error unless the context has been finalized.
    pub fn check_finalized(&self) -> Result<()> {
        if self.is_finalized() {
            Ok(())
        } else {
            Err(runtime_error!("Context is not finalized"))
        }
    }

    /// Returns the main graph of the context.
    ///
    /// # Errors
    /// Fails if the main graph is not set.
    pub fn get_main_graph(&self) -> Result<Graph> {
        let body = self.body.borrow();
        body.main_graph
            .as_ref()
            .map(|graph| graph.upgrade())
            .ok_or_else(|| runtime_error!("main graph is not set"))
    }

    /// Returns the number of graphs in the context.
    pub fn get_num_graphs(&self) -> u64 {
        let body = self.body.borrow();
        body.graphs.len() as u64
    }

    /// Returns the graph with the given ID.
    ///
    /// # Errors
    /// Fails if `id` is not smaller than the number of graphs.
    pub fn get_graph_by_id(&self, id: u64) -> Result<Graph> {
        let body = self.body.borrow();
        if id < body.graphs.len() as u64 {
            Ok(body.graphs[id as usize].clone())
        } else {
            Err(runtime_error!("Invalid id for the graph retrieval"))
        }
    }

    /// Returns the node identified by the pair `(graph ID, node ID)`.
    pub fn get_node_by_global_id(&self, id: (u64, u64)) -> Result<Node> {
        let (graph_id, node_id) = id;
        let graph = self.get_graph_by_id(graph_id)?;
        graph.get_node_by_id(node_id)
    }

    /// Builds the serializable form of this context.
    fn make_serializable(&self) -> SerializableContext {
        let main_graph = self.get_main_graph().ok().map(|graph| graph.get_id());
        let cell = self.body.borrow();
        let finalized = self.is_finalized();
        let graphs = self
            .get_graphs()
            .iter()
            .map(|graph| graph.make_serializable())
            .collect();
        let graphs_names = cell.graphs_names.clone().into_iter().collect();
        let nodes_names = cell.nodes_names.clone().into_iter().collect();
        let graphs_annotations = cell.graphs_annotations.clone().into_iter().collect();
        let nodes_annotations = cell.nodes_annotations.clone().into_iter().collect();
        Arc::new(SerializableContextBody {
            finalized,
            graphs,
            main_graph,
            graphs_names,
            nodes_names,
            nodes_annotations,
            graphs_annotations,
        })
    }

    /// Installs a type checker for this context and type-checks all existing nodes.
    ///
    /// # Errors
    /// Fails if a type checker is already installed or if some existing node fails the check.
    fn add_type_checker(&self) -> Result<Context> {
        {
            let mut body = self.body.borrow_mut();
            if body.type_checker.is_some() {
                return Err(runtime_error!(
                    "Type checker associated with the context already exists"
                ));
            }
            body.type_checker = Some(create_type_inference_worker(self.clone()));
        }
        for node in self.get_graphs().iter().flat_map(|graph| graph.get_nodes()) {
            node.get_type()?;
        }
        Ok(self.clone())
    }

    /// Returns the running total of the accounted node sizes.
    fn get_total_size_nodes(&self) -> u64 {
        let body = self.body.borrow();
        body.total_size_nodes
    }

    /// Overwrites the running total of the accounted node sizes.
    fn set_total_size_nodes(&self, size: u64) {
        let mut body = self.body.borrow_mut();
        body.total_size_nodes = size;
    }

    /// Adds the size of `node` to the running total if `node` is an `Input` or `Constant`.
    ///
    /// Nodes with any other operation are accepted without changing the total.
    ///
    /// # Errors
    /// Fails if the node type is invalid, its size cannot be estimated, the addition overflows,
    /// or the new total exceeds `constants::MAX_TOTAL_SIZE_NODES`; the total is left untouched
    /// in all of these cases.
    fn try_update_total_size(&self, node: Node) -> Result<()> {
        let node_type = match node.get_operation() {
            Operation::Input(t) | Operation::Constant(t, _) => t,
            _ => return Ok(()),
        };
        if !node_type.is_valid() {
            return Err(runtime_error!("Node with an invalid type: {:?}", node_type));
        }
        let current_total = self.get_total_size_nodes();
        let node_size = get_size_estimation_in_bits(node_type)?;
        let new_total_size = match current_total.checked_add(node_size) {
            Some(total) => total,
            None => return Err(runtime_error!("add overflow!")),
        };
        if new_total_size > constants::MAX_TOTAL_SIZE_NODES {
            return Err(runtime_error!(
                "Can't add a node: total size of nodes exceeds MAX_TOTAL_SIZE_NODES"
            ));
        }
        self.set_total_size_nodes(new_total_size);
        Ok(())
    }

    /// Removes the name and the annotations recorded for `node`.
    ///
    /// # Errors
    /// Fails if `node` belongs to another context or if the context is finalized.
    fn unregister_node(&self, node: Node) -> Result<()> {
        if node.get_graph().get_context() != *self {
            return Err(runtime_error!(
                "The node to be unregister from a different context"
            ));
        }
        if self.is_finalized() {
            return Err(runtime_error!(
                "Can't unregister a node from a finalized context"
            ));
        }
        let node_id = node.get_id();
        let graph_id = node.get_graph().get_id();
        let key = (graph_id, node_id);
        let mut body = self.body.borrow_mut();
        let removed_name = body.nodes_names.remove(&key);
        body.nodes_annotations.remove(&key);
        // Keep the inverse name map in sync when the node had a name.
        if let (Some(name), Some(names_in_graph)) =
            (removed_name, body.nodes_names_inverse.get_mut(&graph_id))
        {
            names_in_graph.remove(&name);
        }
        Ok(())
    }

    /// Serializes the context to JSON and wraps it together with `DATA_VERSION`.
    fn to_versioned_data(&self) -> Result<VersionedData> {
        let serialized = serde_json::to_string(&self.make_serializable())?;
        VersionedData::create_versioned_data(DATA_VERSION, serialized)
    }
}
impl Serialize for Context {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let versioned_context = self
.to_versioned_data()
.expect("Error during conversion from Context into VersionedData");
versioned_context.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Context {
fn deserialize<D>(deserializer: D) -> std::result::Result<Context, D::Error>
where
D: Deserializer<'de>,
{
let versioned_context = VersionedData::deserialize(deserializer)?;
if !versioned_context.check_version(DATA_VERSION) {
Err(runtime_error!(
"Context version doesn't match the requirement"
))
.map_err(serde::de::Error::custom)
} else {
let serializable_context =
serde_json::from_str::<SerializableContext>(versioned_context.get_data_string())
.expect("Error during deserialization of SerializableContext");
serializable_context
.recover_original_context()
.map_err(serde::de::Error::custom)
}
}
}
/// Creates an empty, non-finalized context without a type checker.
pub(super) fn create_unchecked_context() -> Result<Context> {
    let empty_body = ContextBody {
        finalized: false,
        graphs: Vec::new(),
        main_graph: None,
        graphs_names: HashMap::new(),
        graphs_names_inverse: HashMap::new(),
        nodes_names: HashMap::new(),
        nodes_names_inverse: HashMap::new(),
        nodes_annotations: HashMap::new(),
        graphs_annotations: HashMap::new(),
        total_size_nodes: 0,
        type_checker: None,
    };
    let body = Arc::new(AtomicRefCell::new(empty_body));
    Ok(Context { body })
}
/// Creates an empty context with a type checker installed.
///
/// # Errors
/// Propagates failures of `create_unchecked_context` and `Context::add_type_checker`.
pub fn create_context() -> Result<Context> {
    // `add_type_checker` returns a handle to the very same context it was called on.
    create_unchecked_context()?.add_type_checker()
}
/// Structurally compares two graphs.
///
/// The graphs are equal when their finalization flags match, they have the same number of
/// nodes, corresponding nodes have equal operations and the same dependency IDs, and the output
/// node IDs (or their absence) match.
fn graphs_deep_equal(graph1: Graph, graph2: Graph) -> bool {
    let lhs = graph1.body.borrow();
    let rhs = graph2.body.borrow();
    if lhs.finalized != rhs.finalized || lhs.nodes.len() != rhs.nodes.len() {
        return false;
    }
    for (lhs_node, rhs_node) in lhs.nodes.iter().zip(rhs.nodes.iter()) {
        let lhs_node_body = lhs_node.body.borrow();
        let rhs_node_body = rhs_node.body.borrow();
        if lhs_node_body.operation != rhs_node_body.operation {
            return false;
        }

        let lhs_node_ids: Vec<u64> = lhs_node_body
            .node_dependencies
            .iter()
            .map(|dependency| dependency.upgrade().get_id())
            .collect();
        let rhs_node_ids: Vec<u64> = rhs_node_body
            .node_dependencies
            .iter()
            .map(|dependency| dependency.upgrade().get_id())
            .collect();
        if lhs_node_ids != rhs_node_ids {
            return false;
        }

        let lhs_graph_ids: Vec<u64> = lhs_node_body
            .graph_dependencies
            .iter()
            .map(|dependency| dependency.upgrade().get_id())
            .collect();
        let rhs_graph_ids: Vec<u64> = rhs_node_body
            .graph_dependencies
            .iter()
            .map(|dependency| dependency.upgrade().get_id())
            .collect();
        if lhs_graph_ids != rhs_graph_ids {
            return false;
        }
    }
    let lhs_output_id = lhs.output_node.as_ref().map(|node| node.upgrade().get_id());
    let rhs_output_id = rhs.output_node.as_ref().map(|node| node.upgrade().get_id());
    lhs_output_id == rhs_output_id
}
/// Structurally compares two contexts.
///
/// The contexts are equal when their finalization flags, graph names, node names, node
/// annotations and graph annotations match, their graphs are pairwise deep-equal, and the main
/// graph IDs (or their absence) match.
pub fn contexts_deep_equal(context1: Context, context2: Context) -> bool {
    let lhs = context1.body.borrow();
    let rhs = context2.body.borrow();
    let metadata_matches = lhs.finalized == rhs.finalized
        && lhs.graphs_names == rhs.graphs_names
        && lhs.nodes_names == rhs.nodes_names
        && lhs.nodes_annotations == rhs.nodes_annotations
        && lhs.graphs_annotations == rhs.graphs_annotations;
    if !metadata_matches || lhs.graphs.len() != rhs.graphs.len() {
        return false;
    }
    let graphs_match = lhs
        .graphs
        .iter()
        .zip(rhs.graphs.iter())
        .all(|(lhs_graph, rhs_graph)| graphs_deep_equal(lhs_graph.clone(), rhs_graph.clone()));
    if !graphs_match {
        return false;
    }
    let lhs_main_id = lhs.main_graph.as_ref().map(|graph| graph.upgrade().get_id());
    let rhs_main_id = rhs.main_graph.as_ref().map(|graph| graph.upgrade().get_id());
    lhs_main_id == rhs_main_id
}
impl Context {
/// Assigns `name` to `graph` and returns a handle to the context.
///
/// # Errors
/// Fails if `graph` belongs to another context, the context is finalized, the graph already
/// has a name, or the name is already used by a graph.
pub fn set_graph_name(&self, graph: Graph, name: &str) -> Result<Context> {
    if graph.get_context() != *self {
        return Err(runtime_error!(
            "The graph to be named is in a different context"
        ));
    }
    if self.is_finalized() {
        return Err(runtime_error!(
            "Can't set a graph name in a finalized context"
        ));
    }
    let graph_id = graph.get_id();
    let mut body = self.body.borrow_mut();
    if body.graphs_names.contains_key(&graph_id) {
        return Err(runtime_error!("Can't set the graph name twice"));
    }
    if body.graphs_names_inverse.contains_key(name) {
        return Err(runtime_error!("Graph names must be unique"));
    }
    // Record the name in both directions.
    body.graphs_names.insert(graph_id, name.to_owned());
    body.graphs_names_inverse.insert(name.to_owned(), graph_id);
    Ok(self.clone())
}
/// Returns the name assigned to `graph`.
///
/// # Errors
/// Fails if `graph` belongs to another context or has no name.
pub fn get_graph_name(&self, graph: Graph) -> Result<String> {
    if graph.get_context() != *self {
        return Err(runtime_error!("The graph is in a different context"));
    }
    let body = self.body.borrow();
    match body.graphs_names.get(&graph.get_id()) {
        Some(name) => Ok(name.clone()),
        None => Err(runtime_error!("The graph does not have a name assigned")),
    }
}
/// Returns the graph that was given the name `name`.
///
/// # Errors
/// Fails if no graph has this name.
pub fn retrieve_graph(&self, name: &str) -> Result<Graph> {
    let body = self.body.borrow();
    match body.graphs_names_inverse.get(name) {
        Some(&graph_id) => Ok(body.graphs[graph_id as usize].clone()),
        None => Err(runtime_error!("No graph with such a name exists")),
    }
}
/// Assigns `name` to `node` and returns a handle to the context.
///
/// # Errors
/// Fails if `node` belongs to another context, the context is finalized, the node already has
/// a name, or the name is already used by a node of the same graph.
pub fn set_node_name(&self, node: Node, name: &str) -> Result<Context> {
    if node.get_graph().get_context() != *self {
        return Err(runtime_error!(
            "The node to be named is in a different context"
        ));
    }
    if self.is_finalized() {
        return Err(runtime_error!(
            "Can't set a node name in a finalized context"
        ));
    }
    let node_id = node.get_id();
    let graph_id = node.get_graph().get_id();
    let key = (graph_id, node_id);
    let mut body = self.body.borrow_mut();
    if body.nodes_names.contains_key(&key) {
        return Err(runtime_error!("Can't set the node name twice"));
    }
    // Create the per-graph inverse map on first use.
    let names_in_graph = body
        .nodes_names_inverse
        .entry(graph_id)
        .or_insert_with(HashMap::new);
    if names_in_graph.contains_key(name) {
        return Err(runtime_error!(
            "Node names must be unique (within the graph)"
        ));
    }
    names_in_graph.insert(name.to_owned(), node_id);
    body.nodes_names.insert(key, name.to_owned());
    Ok(self.clone())
}
/// Returns the name assigned to `node`.
///
/// # Errors
/// Fails if `node` belongs to another context or has no name.
pub fn get_node_name(&self, node: Node) -> Result<String> {
    if node.get_graph().get_context() != *self {
        return Err(runtime_error!("The node is in a different context"));
    }
    let node_id = node.get_id();
    let graph_id = node.get_graph().get_id();
    // This is a read-only lookup, so a shared borrow suffices; an exclusive borrow would panic
    // whenever another borrow of the context body is alive.
    let cell = self.body.borrow();
    Ok(cell
        .nodes_names
        .get(&(graph_id, node_id))
        .ok_or_else(|| runtime_error!("The node is not named"))?
        .clone())
}
/// Returns the node of `graph` that was given the name `name`.
///
/// # Errors
/// Fails if `graph` belongs to another context, the graph has no named nodes, or no node of
/// the graph has this name.
pub fn retrieve_node(&self, graph: Graph, name: &str) -> Result<Node> {
    if graph.get_context() != *self {
        return Err(runtime_error!("The graph is in a different context"));
    }
    let body = self.body.borrow();
    let names_in_graph = match body.nodes_names_inverse.get(&graph.get_id()) {
        Some(names) => names,
        None => return Err(runtime_error!("The graph has no named nodes")),
    };
    let node_index = match names_in_graph.get(name) {
        Some(&node_id) => node_id as usize,
        None => return Err(runtime_error!("Node with a given name does not exist")),
    };
    let graph_body = graph.body.borrow();
    Ok(graph_body.nodes[node_index].clone())
}
/// Turns a name-to-value map into a vector of values, one per `Input` node of
/// `graph`, in the order in which `graph.get_nodes()` yields those nodes.
///
/// # Errors
///
/// Fails if the graph belongs to a different context, if a supplied name does
/// not match any named node of the graph, if an input node has no name, or if
/// no value was supplied for a named input node.
fn prepare_input_values<T: Clone>(
    &self,
    graph: Graph,
    values: HashMap<&str, T>,
) -> Result<Vec<T>> {
    if graph.get_context() != *self {
        return Err(runtime_error!("The graph is in a different context"));
    }
    let graph_id = graph.get_id();
    let cell = self.body.borrow();
    // Kept as an `Option` on purpose: the "no named nodes" error is only
    // reported when at least one value was actually supplied.
    let named_nodes = cell.nodes_names_inverse.get(&graph_id);
    for supplied_name in values.keys() {
        let known = named_nodes
            .ok_or_else(|| runtime_error!("Trying to call graph without named nodes"))?;
        if !known.contains_key(*supplied_name) {
            return Err(runtime_error!("Input with a given name is not found"));
        }
    }
    let mut ordered = vec![];
    for node in graph.get_nodes() {
        if !matches!(node.get_operation(), Operation::Input(_)) {
            continue;
        }
        let input_name = cell
            .nodes_names
            .get(&(graph_id, node.get_id()))
            .ok_or_else(|| runtime_error!("Unnamed input"))?;
        let input_value = values
            .get(input_name.as_str())
            .ok_or_else(|| runtime_error!("Unspecified input"))?;
        ordered.push(input_value.clone());
    }
    Ok(ordered)
}
/// Appends `annotation` to the annotations stored for `node` and returns a
/// clone of this context.
///
/// # Errors
///
/// Fails if the node belongs to a different context or if this context is
/// already finalized.
pub(super) fn add_node_annotation(
    &self,
    node: &Node,
    annotation: NodeAnnotation,
) -> Result<Context> {
    let owning_graph = node.get_graph();
    if owning_graph.get_context() != *self {
        return Err(runtime_error!(
            "The node to be annotated is in a different context"
        ));
    }
    if self.is_finalized() {
        return Err(runtime_error!(
            "Can't add a node annotation in a finalized context"
        ));
    }
    let key = (owning_graph.get_id(), node.get_id());
    // Create the annotation list on first use, then append to it.
    self.body
        .borrow_mut()
        .nodes_annotations
        .entry(key)
        .or_insert_with(Vec::new)
        .push(annotation);
    Ok(self.clone())
}
/// Returns a copy of the annotations stored for `node`; the result is empty
/// when the node has none.
///
/// Only a shared borrow of the context body is taken, so this lookup does not
/// conflict with other readers of the same context.
///
/// # Errors
///
/// Fails if the node belongs to a different context.
pub(super) fn get_node_annotations(&self, node: Node) -> Result<Vec<NodeAnnotation>> {
    if node.get_graph().get_context() != *self {
        return Err(runtime_error!("The node is in a different context"));
    }
    let node_id = node.get_id();
    let graph_id = node.get_graph().get_id();
    // Read-only access: `borrow()` rather than `borrow_mut()`.
    let cell = self.body.borrow();
    Ok(cell
        .nodes_annotations
        .get(&(graph_id, node_id))
        .cloned()
        .unwrap_or_default())
}
/// Appends `annotation` to the annotations stored for `graph` and returns a
/// clone of this context.
///
/// # Errors
///
/// Fails if the graph belongs to a different context or if this context is
/// already finalized.
fn add_graph_annotation(&self, graph: &Graph, annotation: GraphAnnotation) -> Result<Context> {
    if graph.get_context() != *self {
        return Err(runtime_error!(
            "The graph to be annotated is in a different context"
        ));
    }
    if self.is_finalized() {
        return Err(runtime_error!(
            "Can't set a graph annotation in a finalized context"
        ));
    }
    let graph_id = graph.get_id();
    // Create the annotation list on first use, then append to it.
    self.body
        .borrow_mut()
        .graphs_annotations
        .entry(graph_id)
        .or_insert_with(Vec::new)
        .push(annotation);
    Ok(self.clone())
}
/// Returns a copy of the annotations stored for `graph`; the result is empty
/// when the graph has none.
///
/// # Errors
///
/// Fails if the graph belongs to a different context.
fn get_graph_annotations(&self, graph: Graph) -> Result<Vec<GraphAnnotation>> {
    if graph.get_context() != *self {
        return Err(runtime_error!("The graph is in a different context"));
    }
    let cell = self.body.borrow();
    let annotations = match cell.graphs_annotations.get(&graph.get_id()) {
        Some(existing) => existing.clone(),
        None => Vec::new(),
    };
    Ok(annotations)
}
}
/// Convenience methods on `Graph` that forward to the owning context.
impl Graph {
    /// Calls `set_main_graph` on the owning context with this graph and
    /// returns a clone of this graph on success.
    pub fn set_as_main(&self) -> Result<Graph> {
        let context = self.get_context();
        context.set_main_graph(self.clone())?;
        Ok(self.clone())
    }

    /// Calls `set_graph_name` on the owning context with this graph and `name`
    /// and returns a clone of this graph on success.
    pub fn set_name(&self, name: &str) -> Result<Graph> {
        let context = self.get_context();
        context.set_graph_name(self.clone(), name)?;
        Ok(self.clone())
    }

    /// Returns the result of `get_graph_name` on the owning context.
    pub fn get_name(&self) -> Result<String> {
        let context = self.get_context();
        context.get_graph_name(self.clone())
    }

    /// Calls `add_graph_annotation` on the owning context and returns a clone
    /// of this graph on success.
    #[doc(hidden)]
    pub fn add_annotation(&self, annotation: GraphAnnotation) -> Result<Graph> {
        let context = self.get_context();
        context.add_graph_annotation(self, annotation)?;
        Ok(self.clone())
    }

    /// Returns the result of `get_graph_annotations` on the owning context.
    pub(super) fn get_annotations(&self) -> Result<Vec<GraphAnnotation>> {
        let context = self.get_context();
        context.get_graph_annotations(self.clone())
    }

    /// Returns the result of `retrieve_node` on the owning context for this
    /// graph and `name`.
    pub fn retrieve_node(&self, name: &str) -> Result<Node> {
        let context = self.get_context();
        context.retrieve_node(self.clone(), name)
    }

    /// Returns the result of `prepare_input_values` on the owning context for
    /// this graph and `values`.
    pub fn prepare_input_values<T: Clone>(&self, values: HashMap<&str, T>) -> Result<Vec<T>> {
        let context = self.get_context();
        context.prepare_input_values(self.clone(), values)
    }
}
/// Convenience methods on `Node`.
///
/// Each method forwards to the method of the same purpose on the node's graph
/// (or on that graph's context), passing a clone of this node as the first
/// node argument.
impl Node {
    /// Calls `set_node_name` on the owning context and returns a clone of this
    /// node on success.
    pub fn set_name(&self, name: &str) -> Result<Node> {
        let context = self.get_graph().get_context();
        context.set_node_name(self.clone(), name)?;
        Ok(self.clone())
    }

    /// Returns the result of `get_node_name` on the owning context.
    pub fn get_name(&self) -> Result<String> {
        let context = self.get_graph().get_context();
        context.get_node_name(self.clone())
    }

    /// Calls `add_node_annotation` on the owning context and returns a clone
    /// of this node on success.
    #[doc(hidden)]
    pub fn add_annotation(&self, annotation: NodeAnnotation) -> Result<Node> {
        let context = self.get_graph().get_context();
        context.add_node_annotation(self, annotation)?;
        Ok(self.clone())
    }

    /// Returns the result of `get_node_annotations` on the owning context.
    #[doc(hidden)]
    pub fn get_annotations(&self) -> Result<Vec<NodeAnnotation>> {
        let context = self.get_graph().get_context();
        context.get_node_annotations(self.clone())
    }

    /// Shorthand for `Graph::add` with this node and `b`.
    pub fn add(&self, b: Node) -> Result<Node> {
        let graph = self.get_graph();
        graph.add(self.clone(), b)
    }

    /// Shorthand for `Graph::subtract` with this node and `b`.
    pub fn subtract(&self, b: Node) -> Result<Node> {
        let graph = self.get_graph();
        graph.subtract(self.clone(), b)
    }

    /// Shorthand for `Graph::multiply` with this node and `b`.
    pub fn multiply(&self, b: Node) -> Result<Node> {
        let graph = self.get_graph();
        graph.multiply(self.clone(), b)
    }

    /// Shorthand for `Graph::dot` with this node and `b`.
    pub fn dot(&self, b: Node) -> Result<Node> {
        let graph = self.get_graph();
        graph.dot(self.clone(), b)
    }

    /// Shorthand for `Graph::matmul` with this node and `b`.
    pub fn matmul(&self, b: Node) -> Result<Node> {
        let graph = self.get_graph();
        graph.matmul(self.clone(), b)
    }

    /// Shorthand for `Graph::truncate` with this node and `scale`.
    pub fn truncate(&self, scale: u64) -> Result<Node> {
        let graph = self.get_graph();
        graph.truncate(self.clone(), scale)
    }

    /// Shorthand for `Graph::sum` with this node and `axes`.
    pub fn sum(&self, axes: ArrayShape) -> Result<Node> {
        let graph = self.get_graph();
        graph.sum(self.clone(), axes)
    }

    /// Shorthand for `Graph::permute_axes` with this node and `axes`.
    pub fn permute_axes(&self, axes: ArrayShape) -> Result<Node> {
        let graph = self.get_graph();
        graph.permute_axes(self.clone(), axes)
    }

    /// Shorthand for `Graph::get` with this node and `index`.
    pub fn get(&self, index: ArrayShape) -> Result<Node> {
        let graph = self.get_graph();
        graph.get(self.clone(), index)
    }

    /// Shorthand for `Graph::get_slice` with this node and `slice`.
    pub fn get_slice(&self, slice: Slice) -> Result<Node> {
        let graph = self.get_graph();
        graph.get_slice(self.clone(), slice)
    }

    /// Shorthand for `Graph::reshape` with this node and `new_type`.
    pub fn reshape(&self, new_type: Type) -> Result<Node> {
        let graph = self.get_graph();
        graph.reshape(self.clone(), new_type)
    }

    /// Shorthand for `Graph::nop` with this node.
    #[doc(hidden)]
    pub fn nop(&self) -> Result<Node> {
        let graph = self.get_graph();
        graph.nop(self.clone())
    }

    /// Shorthand for `Graph::prf` with this node, `iv` and `output_type`.
    #[doc(hidden)]
    pub fn prf(&self, iv: u64, output_type: Type) -> Result<Node> {
        let graph = self.get_graph();
        graph.prf(self.clone(), iv, output_type)
    }

    /// Shorthand for `Graph::a2b` with this node.
    pub fn a2b(&self) -> Result<Node> {
        let graph = self.get_graph();
        graph.a2b(self.clone())
    }

    /// Shorthand for `Graph::b2a` with this node and `scalar_type`.
    pub fn b2a(&self, scalar_type: ScalarType) -> Result<Node> {
        let graph = self.get_graph();
        graph.b2a(self.clone(), scalar_type)
    }

    /// Shorthand for `Graph::tuple_get` with this node and `index`.
    pub fn tuple_get(&self, index: u64) -> Result<Node> {
        let graph = self.get_graph();
        graph.tuple_get(self.clone(), index)
    }

    /// Shorthand for `Graph::named_tuple_get` with this node and `key`.
    pub fn named_tuple_get(&self, key: String) -> Result<Node> {
        let graph = self.get_graph();
        graph.named_tuple_get(self.clone(), key)
    }

    /// Shorthand for `Graph::vector_get` with this node and the `index` node.
    pub fn vector_get(&self, index: Node) -> Result<Node> {
        let graph = self.get_graph();
        graph.vector_get(self.clone(), index)
    }

    /// Shorthand for `Graph::array_to_vector` with this node.
    pub fn array_to_vector(&self) -> Result<Node> {
        let graph = self.get_graph();
        graph.array_to_vector(self.clone())
    }

    /// Shorthand for `Graph::vector_to_array` with this node.
    pub fn vector_to_array(&self) -> Result<Node> {
        let graph = self.get_graph();
        graph.vector_to_array(self.clone())
    }

    /// Shorthand for `Graph::repeat` with this node and `n`.
    pub fn repeat(&self, n: u64) -> Result<Node> {
        let graph = self.get_graph();
        graph.repeat(self.clone(), n)
    }

    /// Calls `set_output_node` on the node's graph with this node and returns
    /// a clone of this node on success.
    pub fn set_as_output(&self) -> Result<Node> {
        let graph = self.get_graph();
        graph.set_output_node(self.clone())?;
        Ok(self.clone())
    }
}
/// Gives `out_node` the same name as `in_node`, if `in_node` has one.
///
/// When `in_node.get_name()` fails nothing is copied and `Ok(())` is returned;
/// an error from `out_node.set_name` is propagated.
pub(crate) fn copy_node_name(in_node: Node, out_node: Node) -> Result<()> {
    match in_node.get_name() {
        Ok(name) => out_node.set_name(&name).map(|_| ()),
        Err(_) => Ok(()),
    }
}
impl Context {
    /// Creates a `WeakContext` that points at this context's body without
    /// holding a strong reference to it.
    pub(super) fn downgrade(&self) -> WeakContext {
        let body = Arc::downgrade(&self.body);
        WeakContext { body }
    }
}
/// Weak (non-owning) pointer to a context body.
type WeakContextBodyPointer = Weak<AtomicRefCell<ContextBody>>;
/// Handle to a context that does not hold a strong reference to its body.
pub(super) struct WeakContext {
// Weak counterpart of the pointer stored in `Context`.
body: WeakContextBodyPointer,
}
impl WeakContext {
    /// Turns this weak handle back into a `Context`.
    ///
    /// # Panics
    ///
    /// Panics if the context body has already been dropped, i.e. no strong
    /// reference to it exists any more.
    pub(super) fn upgrade(&self) -> Context {
        let body = self
            .body
            .upgrade()
            .expect("WeakContext::upgrade called after the context was dropped");
        Context { body }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::data_types::{
array_type, scalar_type, tuple_type, vector_type, BIT, UINT16, UINT64,
};
use crate::version::DATA_VERSION;
use std::rc::Rc;
#[test]
fn test_wellformed_cases() {
    // Adds one node per graph operation to a single graph of a context created
    // by `create_unchecked_context`; every call is expected to succeed.
    let context = create_unchecked_context().unwrap();
    let graph = context.create_graph().unwrap();
    let bit_a = graph.input(scalar_type(BIT)).unwrap();
    let bit_b = graph.input(scalar_type(BIT)).unwrap();
    let both_bits = || vec![bit_a.clone(), bit_b.clone()];

    // Operations taking the two scalar inputs.
    graph.add(bit_a.clone(), bit_b.clone()).unwrap();
    graph.subtract(bit_a.clone(), bit_b.clone()).unwrap();
    graph.multiply(bit_a.clone(), bit_b.clone()).unwrap();
    graph.dot(bit_a.clone(), bit_b.clone()).unwrap();
    graph.matmul(bit_a.clone(), bit_b.clone()).unwrap();
    graph.truncate(bit_a.clone(), 123).unwrap();

    // Operations taking an array input.
    let bit_array = graph.input(array_type(vec![10, 20, 30], BIT)).unwrap();
    graph.sum(bit_array.clone(), vec![1, 2]).unwrap();
    graph.permute_axes(bit_array.clone(), vec![1, 2, 0]).unwrap();
    graph.get(bit_array.clone(), vec![1, 2]).unwrap();
    let reshaped_type = array_type(vec![20, 300], BIT);
    graph.reshape(bit_array.clone(), reshaped_type).unwrap();
    graph.nop(bit_array.clone()).unwrap();

    // Randomness, stacking and constants.
    let prf_key = graph.random(array_type(vec![128], BIT)).unwrap();
    let prf_output_type = array_type(vec![10, 10], UINT64);
    graph.prf(prf_key.clone(), 0, prf_output_type).unwrap();
    graph.stack(both_bits(), vec![2, 1]).unwrap();
    let one_bit = graph
        .constant(scalar_type(BIT), Value::from_bytes(vec![1]))
        .unwrap();

    // Conversions a2b / b2a.
    let uint_array = graph.input(array_type(vec![10, 10], UINT64)).unwrap();
    let as_bits = graph.a2b(uint_array.clone()).unwrap();
    graph.b2a(as_bits.clone(), UINT64).unwrap();

    // Tuples, named tuples and vectors.
    let pair = graph.create_tuple(both_bits()).unwrap();
    let _bit_vector = graph
        .create_vector(scalar_type(BIT), both_bits())
        .unwrap();
    let named_fields = vec![
        ("Name".to_owned(), bit_a.clone()),
        ("Gender".to_owned(), bit_b.clone()),
    ];
    let named_pair = graph.create_named_tuple(named_fields).unwrap();
    graph.tuple_get(pair, 1).unwrap();
    graph
        .named_tuple_get(named_pair, "Gender".to_owned())
        .unwrap();
    let repeated = graph.repeat(one_bit.clone(), 100).unwrap();
    let zipped_args = vec![repeated.clone(), repeated.clone(), repeated.clone()];
    graph.zip(zipped_args).unwrap();
    let zero_index = graph
        .constant(scalar_type(UINT64), Value::from_bytes(vec![0; 8]))
        .unwrap();
    graph.vector_get(repeated, zero_index).unwrap();
    graph.array_to_vector(bit_a.clone()).unwrap();
    graph.vector_to_array(bit_a.clone()).unwrap();
}
// Builds three graphs in one context that use `iterate` and `call` on each
// other, finalizes them and then finalizes the context; every step must succeed.
#[test]
fn call_iterate_test() {
let context = create_unchecked_context().unwrap();
// First graph: takes a carry bit and a tuple of two bits, and outputs a tuple
// `(result_carry, result)`.
let single_bit_adder = context.create_graph().unwrap();
{
let carry = single_bit_adder.input(scalar_type(BIT)).unwrap();
let inputs = single_bit_adder
.input(tuple_type(vec![scalar_type(BIT), scalar_type(BIT)]))
.unwrap();
let a = single_bit_adder.tuple_get(inputs.clone(), 0).unwrap();
let b = single_bit_adder.tuple_get(inputs.clone(), 1).unwrap();
let ac = single_bit_adder.add(carry.clone(), a.clone()).unwrap();
let bc = single_bit_adder.add(carry.clone(), b.clone()).unwrap();
let result = single_bit_adder.add(ac.clone(), b.clone()).unwrap();
let result_carry = single_bit_adder
.add(
single_bit_adder.multiply(ac.clone(), bc.clone()).unwrap(),
carry,
)
.unwrap();
let output = single_bit_adder
.create_tuple(vec![result_carry.clone(), result.clone()])
.unwrap();
single_bit_adder.set_output_node(output).unwrap();
// The graph is finalized before it is passed to `iterate` below.
single_bit_adder.finalize().unwrap();
}
let v32 = vector_type(32, scalar_type(BIT));
// Second graph: zips two 32-element bit vectors and iterates the first graph
// over the zipped vector, starting from a constant bit.
let adder = context.create_graph().unwrap();
{
let a = adder.input(v32.clone()).unwrap();
let b = adder.input(v32.clone()).unwrap();
let azb = adder.zip(vec![a, b]).unwrap();
let c = adder
.constant(scalar_type(BIT), Value::from_bytes(vec![0]))
.unwrap();
let cr = adder.iterate(single_bit_adder, c, azb).unwrap();
// Element 1 of the tuple returned by `iterate` becomes the output.
let r = adder.tuple_get(cr, 1).unwrap();
adder.set_output_node(r).unwrap();
adder.finalize().unwrap();
}
// Third graph: calls `adder` twice, feeding the first call's result into the
// second call.
let three_adder = context.create_graph().unwrap();
let a = three_adder.input(v32.clone()).unwrap();
let b = three_adder.input(v32.clone()).unwrap();
let c = three_adder.input(v32.clone()).unwrap();
let result = three_adder
.call(
adder.clone(),
vec![three_adder.call(adder.clone(), vec![a, b]).unwrap(), c],
)
.unwrap();
three_adder.set_output_node(result).unwrap();
three_adder.finalize().unwrap();
context.set_main_graph(three_adder).unwrap();
context.finalize().unwrap();
}
// Checks that a series of invalid graph manipulations are rejected with an error.
#[test]
fn test_malformed_graphs() {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let graph2 = context.create_graph().unwrap();
let input1 = graph.input(scalar_type(UINT64)).unwrap();
let input2 = graph2.input(scalar_type(UINT64)).unwrap();
// e1: the arguments come from two different graphs.
let e1 = graph.add(input1.clone(), input2.clone());
assert!(e1.is_err());
// e2: a node constructed by hand (not through the graph API) that reuses ID 0.
let fake_node = Node {
body: Arc::new(AtomicRefCell::new(NodeBody {
graph: graph.downgrade(),
node_dependencies: vec![],
graph_dependencies: vec![],
operation: Operation::Input(scalar_type(BIT)),
id: 0,
})),
};
let e2 = graph.add(fake_node.clone(), input1.clone());
assert!(e2.is_err());
// e3: a hand-constructed node with ID 31337.
let fake_node_2 = Node {
body: Arc::new(AtomicRefCell::new(NodeBody {
graph: graph.downgrade(),
node_dependencies: vec![],
graph_dependencies: vec![],
operation: Operation::Input(scalar_type(BIT)),
id: 31337,
})),
};
let e3 = graph.add(fake_node_2.clone(), input1.clone());
assert!(e3.is_err());
graph.set_output_node(input1.clone()).unwrap();
graph.finalize().unwrap();
// e4: adding a node after the graph has been finalized.
let e4 = graph.add(input1.clone(), input1.clone());
assert!(e4.is_err());
// e5: finalizing a freshly created graph for which no output node was set.
let graph3 = context.create_graph().unwrap();
let e5 = graph3.finalize();
assert!(e5.is_err());
// e6: using a node created in `graph` as the output node of `graph3`.
let e6 = graph3.set_output_node(input1);
assert!(e6.is_err());
}
// Checks that finalizing a context / setting its main graph fails until the
// required preceding steps have been performed.
#[test]
fn test_malformed_contexts() {
let context = create_unchecked_context().unwrap();
// e1: finalizing a context that has no graphs yet.
let e1 = context.finalize();
assert!(e1.is_err());
let graph = context.create_graph().unwrap();
// e2: finalizing a graph before an output node has been set.
let e2 = graph.finalize();
assert!(e2.is_err());
graph
.set_output_node(graph.create_tuple(vec![]).unwrap())
.unwrap();
// e4: setting the main graph before that graph has been finalized.
let e4 = context.set_main_graph(graph.clone());
assert!(e4.is_err());
graph.finalize().unwrap();
// e3: finalizing the context before a main graph has been set.
let e3 = context.finalize();
assert!(e3.is_err());
context.set_main_graph(graph.clone()).unwrap();
context.finalize().unwrap();
}
// Checks that `call` and `iterate` reject callee graphs that are not yet
// finalized or that belong to a different context.
#[test]
fn test_malformed_call_iterate() {
let context1 = create_unchecked_context().unwrap();
let graph1 = context1.create_graph().unwrap();
let output = graph1.create_tuple(vec![]).unwrap();
graph1.set_output_node(output).unwrap();
let graph2 = context1.create_graph().unwrap();
// e1: `graph1` is not finalized yet.
let e1 = graph2.call(graph1.clone(), vec![]);
assert!(e1.is_err());
graph1.finalize().unwrap();
// The same call succeeds once `graph1` is finalized.
graph2.call(graph1.clone(), vec![]).unwrap();
let context2 = create_unchecked_context().unwrap();
let graph3 = context2.create_graph().unwrap();
// e2: the caller lives in `context2`, the callee in `context1`.
let e2 = graph3.call(graph1.clone(), vec![]);
assert!(e2.is_err());
// `graph4` has two empty-tuple inputs and outputs a tuple of two empty tuples.
let graph4 = context1.create_graph().unwrap();
graph4.input(tuple_type(vec![])).unwrap();
graph4.input(tuple_type(vec![])).unwrap();
let t = graph4.create_tuple(vec![]).unwrap();
let tt = graph4.create_tuple(vec![t.clone(), t.clone()]).unwrap();
graph4.set_output_node(tt).unwrap();
let graph5 = context1.create_graph().unwrap();
let es = graph5.create_tuple(vec![]).unwrap();
let v = graph5
.repeat(graph5.create_tuple(vec![]).unwrap(), 10)
.unwrap();
// e3: `graph4` is not finalized yet.
let e3 = graph5.iterate(graph4.clone(), es.clone(), v.clone());
assert!(e3.is_err());
graph4.finalize().unwrap();
// The same `iterate` succeeds once `graph4` is finalized.
graph5
.iterate(graph4.clone(), es.clone(), v.clone())
.unwrap();
let graph6 = context2.create_graph().unwrap();
let es = graph6.create_tuple(vec![]).unwrap();
let v = graph6
.repeat(graph6.create_tuple(vec![]).unwrap(), 10)
.unwrap();
// e4: the caller lives in `context2`, the iterated graph in `context1`.
let e4 = graph6.iterate(graph4.clone(), es.clone(), v.clone());
assert!(e4.is_err());
}
#[test]
fn test_graph_consistency() {
    // A finalized graph with two inputs and one addition.
    let context = create_unchecked_context().unwrap();
    let graph = context.create_graph().unwrap();
    let first_input = graph.input(scalar_type(BIT)).unwrap();
    let second_input = graph.input(scalar_type(BIT)).unwrap();
    graph.add(first_input.clone(), second_input).unwrap();
    graph.set_output_node(first_input).unwrap();
    graph.finalize().unwrap();

    // Node IDs equal their position, every node points back to the graph and
    // every dependency has a smaller ID than the node that uses it.
    let mut position: u64 = 0;
    for node in graph.get_nodes().iter() {
        assert_eq!(node.get_id(), position);
        assert!(graph == node.get_graph());
        for dependency in node.get_node_dependencies() {
            assert!(dependency.get_id() < node.get_id());
        }
        position += 1;
    }

    // The operations appear in creation order: Input, Input, Add.
    let mut operations: Vec<Operation> = Vec::new();
    for node in graph.get_nodes().iter() {
        operations.push(node.get_operation());
    }
    assert!(operations.len() == 3);
    assert!(matches!(operations[0], Operation::Input(_)), "Input expected");
    assert!(matches!(operations[1], Operation::Input(_)), "Input expected");
    assert!(matches!(operations[2], Operation::Add), "Add expected");
}
// Checks that a context can only be finalized once every one of its graphs has
// been finalized.
#[test]
fn test_unfinalized_graphs() {
let context = create_unchecked_context().unwrap();
// Finalizing a context without any graph fails.
let e = context.finalize();
assert!(e.is_err());
let graph = context.create_graph().unwrap();
let graph2 = context.create_graph().unwrap();
// Two graphs exist, none is finalized and no main graph is set.
let e = context.finalize();
assert!(e.is_err());
let i = graph2.input(scalar_type(BIT)).unwrap();
graph2.set_output_node(i).unwrap();
graph2.finalize().unwrap();
context.set_main_graph(graph2).unwrap();
// `graph2` is finalized and is the main graph, but `graph` is still open.
let e = context.finalize();
assert!(e.is_err());
let ii = graph.input(scalar_type(BIT)).unwrap();
graph.set_output_node(ii).unwrap();
graph.finalize().unwrap();
// With both graphs finalized the context can be finalized.
context.finalize().unwrap();
}
#[test]
fn test_operation_serialization() {
    // Serialize a `Constant` operation to JSON, compare it with the expected
    // text, then deserialize it and compare with the original value.
    let operation = Operation::Constant(scalar_type(BIT), Value::from_bytes(vec![1]));
    let serialized = serde_json::to_string(&operation).unwrap();
    let expected = format!("{{\"Constant\":[{{\"Scalar\":{{\"signed\":false,\"modulus\":2}}}},{{\"version\":{},\"data\":\"{{\\\"body\\\":{{\\\"Bytes\\\":[1]}}}}\"}}]}}", DATA_VERSION);
    assert_eq!(serialized, expected);
    let restored: Operation = serde_json::from_str(&serialized).unwrap();
    assert_eq!(restored, operation);
}
// Returns a list of closures, each of which builds a fresh context.
// The order of the returned closures matters: `test_context_serialize` refers
// to the resulting contexts by index, and `test_context_deep_equal` expects
// contexts produced by different closures not to be deeply equal.
fn context_generators() -> Vec<Box<dyn Fn() -> Context>> {
// 1: empty context.
let context1 = || {
let context = create_unchecked_context().unwrap();
context
};
// 2: one finalized graph (single input as output), set as main; context finalized.
let context2 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i = graph.input(scalar_type(BIT)).unwrap();
graph.set_output_node(i).unwrap();
graph.finalize().unwrap();
context.set_main_graph(graph).unwrap();
context.finalize().unwrap();
context
};
// 3: one empty graph.
let context3 = || {
let context = create_unchecked_context().unwrap();
context.create_graph().unwrap();
context
};
// 4: one finalized graph, no main graph, context not finalized.
let context4 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i = graph.input(scalar_type(BIT)).unwrap();
graph.set_output_node(i).unwrap();
graph.finalize().unwrap();
context
};
// 5: one graph with a single input node, not finalized.
let context5 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
graph.input(scalar_type(BIT)).unwrap();
context
};
// 6: one graph with a single constant node.
let context6 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
graph
.constant(scalar_type(BIT), Value::from_bytes(vec![1]))
.unwrap();
context
};
// 7: two inputs and `add(i1, i2)`.
let context7 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i1 = graph.input(scalar_type(BIT)).unwrap();
let i2 = graph.input(scalar_type(BIT)).unwrap();
graph.add(i1, i2).unwrap();
context
};
// 8: same as 7 but with the arguments swapped: `add(i2, i1)`.
let context8 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i1 = graph.input(scalar_type(BIT)).unwrap();
let i2 = graph.input(scalar_type(BIT)).unwrap();
graph.add(i2, i1).unwrap();
context
};
// 9: two finalized graphs and a third graph that calls the first one.
let context9 = || {
let context = create_unchecked_context().unwrap();
let graph1 = context.create_graph().unwrap();
let i1 = graph1.input(scalar_type(BIT)).unwrap();
graph1.set_output_node(i1).unwrap();
graph1.finalize().unwrap();
let graph2 = context.create_graph().unwrap();
let i2 = graph2.input(scalar_type(BIT)).unwrap();
graph2.set_output_node(i2).unwrap();
graph2.finalize().unwrap();
let graph3 = context.create_graph().unwrap();
let i = graph3.input(scalar_type(BIT)).unwrap();
graph3.call(graph1, vec![i]).unwrap();
context
};
// 10: same as 9 but the third graph calls the second one.
let context10 = || {
let context = create_unchecked_context().unwrap();
let graph1 = context.create_graph().unwrap();
let i1 = graph1.input(scalar_type(BIT)).unwrap();
graph1.set_output_node(i1).unwrap();
graph1.finalize().unwrap();
let graph2 = context.create_graph().unwrap();
let i2 = graph2.input(scalar_type(BIT)).unwrap();
graph2.set_output_node(i2).unwrap();
graph2.finalize().unwrap();
let graph3 = context.create_graph().unwrap();
let i = graph3.input(scalar_type(BIT)).unwrap();
graph3.call(graph2, vec![i]).unwrap();
context
};
// 11: same as 10, and the call result is set as the third graph's output.
let context11 = || {
let context = create_unchecked_context().unwrap();
let graph1 = context.create_graph().unwrap();
let i1 = graph1.input(scalar_type(BIT)).unwrap();
graph1.set_output_node(i1).unwrap();
graph1.finalize().unwrap();
let graph2 = context.create_graph().unwrap();
let i2 = graph2.input(scalar_type(BIT)).unwrap();
graph2.set_output_node(i2).unwrap();
graph2.finalize().unwrap();
let graph3 = context.create_graph().unwrap();
let i = graph3.input(scalar_type(BIT)).unwrap();
let o = graph3.call(graph2, vec![i]).unwrap();
graph3.set_output_node(o).unwrap();
context
};
// 12: same as 2 but the context itself is not finalized.
let context12 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i = graph.input(scalar_type(BIT)).unwrap();
graph.set_output_node(i).unwrap();
graph.finalize().unwrap();
context.set_main_graph(graph).unwrap();
context
};
// 13: same as 2 with the graph additionally named "main".
let context13 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i = graph.input(scalar_type(BIT)).unwrap();
graph.set_output_node(i).unwrap();
graph.finalize().unwrap();
context.set_main_graph(graph.clone()).unwrap();
context.set_graph_name(graph, "main").unwrap();
context.finalize().unwrap();
context
};
// 14: same as 13 plus a node name and one graph and one node annotation.
let context14 = || {
let context = create_unchecked_context().unwrap();
let graph = context.create_graph().unwrap();
let i = graph.input(scalar_type(BIT)).unwrap();
graph.set_output_node(i.clone()).unwrap();
graph.finalize().unwrap();
context.set_main_graph(graph.clone()).unwrap();
context.set_graph_name(graph.clone(), "main").unwrap();
context.set_node_name(i.clone(), "input").unwrap();
context
.add_graph_annotation(&graph, GraphAnnotation::AssociativeOperation)
.unwrap();
context
.add_node_annotation(&i, NodeAnnotation::AssociativeOperation)
.unwrap();
context.finalize().unwrap();
context
};
// Collect the closures in declaration order (index = number - 1).
let mut closures: Vec<Box<dyn Fn() -> Context>> = vec![];
closures.push(Box::new(context1));
closures.push(Box::new(context2));
closures.push(Box::new(context3));
closures.push(Box::new(context4));
closures.push(Box::new(context5));
closures.push(Box::new(context6));
closures.push(Box::new(context7));
closures.push(Box::new(context8));
closures.push(Box::new(context9));
closures.push(Box::new(context10));
closures.push(Box::new(context11));
closures.push(Box::new(context12));
closures.push(Box::new(context13));
closures.push(Box::new(context14));
closures
}
/// Calls `f` twice and asserts that the two resulting contexts differ under
/// `!=` but are equal according to `contexts_deep_equal`.
fn test_context_deep_equal_helper_equal<F: Fn() -> Context>(f: F) {
    let (first, second) = (f(), f());
    assert!(first != second);
    assert!(contexts_deep_equal(first, second));
}
/// Builds one context with `f1` and one with `f2` and asserts that they differ
/// both under `!=` and according to `contexts_deep_equal`.
fn test_context_deep_equal_helper_nonequal<F1: Fn() -> Context, F2: Fn() -> Context>(
    f1: F1,
    f2: F2,
) {
    let (first, second) = (f1(), f2());
    assert!(first != second);
    assert!(!contexts_deep_equal(first, second));
}
#[test]
fn test_context_deep_equal() {
    // Every generator must be deeply equal to itself and deeply different
    // from every generator that precedes it in the list.
    let generators = context_generators();
    for (position, generator) in generators.iter().enumerate() {
        test_context_deep_equal_helper_equal(generator);
        for earlier in &generators[..position] {
            test_context_deep_equal_helper_nonequal(generator, earlier);
        }
    }
}
/// Asserts that deserializing `serialized_string` into a `Context` and
/// unwrapping the result panics with a message that contains `error_msg`.
///
/// Panics itself if deserialization succeeds, if the panic message does not
/// contain `error_msg`, or if no message can be extracted from the panic.
pub fn deserialize_error_lenient(serialized_string: &str, error_msg: &str) {
    use ciphercore_utils::execute_main::extract_panic_message;
    use std::panic::catch_unwind;
    let outcome = catch_unwind(|| serde_json::from_str::<Context>(serialized_string).unwrap());
    let payload = match outcome {
        Ok(_) => panic!("Expected error not occur"),
        Err(payload) => payload,
    };
    match extract_panic_message(payload) {
        Some(msg) if msg.contains(error_msg) => {}
        Some(msg) => panic!("Undesireable panic: {}", msg),
        None => panic!("Panic of unknown type"),
    }
}
use std::{
fs::File,
io::{prelude::*, BufReader},
path::Path,
};
/// Reads the file at `filename` and returns its lines as owned strings.
///
/// Panics with "no such file" if the file cannot be opened and with
/// "Could not parse line" if a line cannot be read.
fn lines_from_file(filename: impl AsRef<Path>) -> Vec<String> {
    let reader = BufReader::new(File::open(filename).expect("no such file"));
    let mut lines = Vec::new();
    for line in reader.lines() {
        lines.push(line.expect("Could not parse line"));
    }
    lines
}
// Round-trips every context from `context_generators` through JSON and checks
// selected serializations and deserialization failures against the lines of
// `./src/test_data/version_testcase.txt`.
#[test]
#[cfg(not(feature = "nightly-features"))]
fn test_context_serialize() {
let generators = context_generators();
let contexts: Vec<Context> = generators.iter().map(|generator| generator()).collect();
let serialized_contexts: Vec<String> = contexts
.iter()
.map(|context| serde_json::to_string(context).unwrap())
.collect();
let deserialized_contexts: Vec<Context> = serialized_contexts
.iter()
.map(|serialized_context| serde_json::from_str(serialized_context).unwrap())
.collect();
assert_eq!(contexts.len(), deserialized_contexts.len());
// A deserialized context differs from the original under `!=` but must be
// deeply equal to it.
for i in 0..contexts.len() {
assert!(contexts[i] != deserialized_contexts[i]);
assert!(contexts_deep_equal(
contexts[i].clone(),
deserialized_contexts[i].clone()
));
}
// Each line of the test-case file is either an expected serialization or an
// input whose deserialization must panic with the given message.
let test_case = lines_from_file("./src/test_data/version_testcase.txt");
assert_eq!(serde_json::to_string(&contexts[0]).unwrap(), test_case[0]);
deserialize_error_lenient(&test_case[1], "Non-existent main graph");
assert_eq!(serde_json::to_string(&contexts[9]).unwrap(), test_case[2]);
deserialize_error_lenient(&test_case[3], "Non-existent node dependency");
deserialize_error_lenient(&test_case[4], "Non-existent graph dependency");
assert_eq!(serde_json::to_string(&contexts[13]).unwrap(), test_case[5]);
deserialize_error_lenient(&test_case[6], "Non-existent output node");
deserialize_error_lenient(&test_case[7], "graphs_names contain an invalid ID");
deserialize_error_lenient(&test_case[8], "nodes_names contain an invalid graph ID");
deserialize_error_lenient(&test_case[9], "nodes_names contain an invalid node ID");
deserialize_error_lenient(
&test_case[10],
"Context version doesn't match the requirement",
);
deserialize_error_lenient(
&test_case[11],
"Context version doesn't match the requirement",
);
}
use crate::data_types::INT32;
use crate::data_values::Value;
use crate::evaluators::random_evaluate;
use std::iter::FromIterator;
// Exercises graph/node naming, name lookup and `prepare_input_values`,
// including the error cases (foreign context, finalized context, duplicates).
#[test]
fn test_named_contexts() {
// helper: builds a finalized context whose main graph "main" adds the
// inputs named "a" and "b"; lookups are expected to fail before naming.
let helper = || -> Result<Context> {
let context = create_context()?;
let graph = context.create_graph()?;
let input_a = graph.input(scalar_type(INT32))?;
let input_b = graph.input(scalar_type(INT32))?;
let output = graph.add(input_a.clone(), input_b.clone())?;
graph.set_output_node(output.clone())?;
graph.finalize()?;
context.set_main_graph(graph.clone())?;
// Nothing is named yet, so every lookup fails.
assert!(context.get_graph_name(graph.clone()).is_err());
assert!(context.retrieve_graph("main").is_err());
assert!(context.get_node_name(input_a.clone()).is_err());
assert!(context.retrieve_node(graph.clone(), "a").is_err());
context.set_graph_name(graph.clone(), "main")?;
context.set_node_name(input_a.clone(), "a")?;
// "b" has not been assigned yet.
assert!(context.retrieve_node(graph.clone(), "b").is_err());
context.set_node_name(input_b.clone(), "b")?;
context.finalize()?;
// Names remain readable after the context is finalized.
assert_eq!(context.get_graph_name(graph.clone())?, "main");
assert_eq!(context.get_node_name(input_a.clone())?, "a");
assert_eq!(context.get_node_name(input_b.clone())?, "b");
assert!(context.retrieve_node(graph.clone(), "a")? == input_a.clone());
Ok(context)
};
let context = helper().unwrap();
// helper2: error cases of `prepare_input_values`, followed by an evaluation
// of the "main" graph of `context` on named inputs.
let helper2 = |context: Context| -> Result<i32> {
let other_context = create_context()?;
let other_graph = other_context.create_graph()?;
let input = other_graph.input(scalar_type(BIT))?;
let other_input = other_graph.input(scalar_type(BIT))?;
// The graph belongs to a different context.
assert!(context
.prepare_input_values::<Value>(other_graph.clone(), HashMap::new())
.is_err());
// The graph has no named nodes yet.
assert!(other_context
.prepare_input_values::<Value>(
other_graph.clone(),
HashMap::from_iter([("a", Value::from_scalar(123, INT32)?)])
)
.is_err());
other_context.set_node_name(input, "b")?;
// No node is named "a".
assert!(other_context
.prepare_input_values::<Value>(
other_graph.clone(),
HashMap::from_iter([("a", Value::from_scalar(123, INT32)?)])
)
.is_err());
// The second input node is still unnamed.
assert!(other_context
.prepare_input_values::<Value>(
other_graph.clone(),
HashMap::from_iter([("b", Value::from_scalar(123, INT32)?)])
)
.is_err());
other_context.set_node_name(other_input, "c")?;
// No value is supplied for the input named "c".
assert!(other_context
.prepare_input_values::<Value>(
other_graph,
HashMap::from_iter([("b", Value::from_scalar(123, INT32)?)])
)
.is_err());
let g = context.retrieve_graph("main")?;
let result = random_evaluate(
g.clone(),
context.prepare_input_values(
g.clone(),
HashMap::from_iter([
("a", Value::from_scalar(123, INT32)?),
("b", Value::from_scalar(456, INT32)?),
]),
)?,
)?;
let result = result.to_i32(INT32)?;
Ok(result)
};
// 123 + 456
assert_eq!(helper2(context).unwrap(), 579);
// helper3: naming and lookup calls made with a graph/node that belongs to
// another context must fail.
let helper3 = |context: Context| -> Result<()> {
let other_context = create_context()?;
let other_graph = other_context.create_graph()?;
let other_node = other_graph.input(scalar_type(BIT))?;
assert!(context
.set_graph_name(other_graph.clone(), "outside")
.is_err());
assert!(context.get_graph_name(other_graph.clone()).is_err());
assert!(context
.set_node_name(other_node.clone(), "outside")
.is_err());
assert!(context.get_node_name(other_node.clone()).is_err());
assert!(context.retrieve_node(other_graph.clone(), "a").is_err());
Ok(())
};
helper3(helper().unwrap()).unwrap();
// helper4: names cannot be set once the context is finalized.
let helper4 = || -> Result<()> {
let context = create_context()?;
let graph = context.create_graph()?;
let input = graph.input(scalar_type(BIT))?;
graph.set_output_node(input.clone())?;
graph.finalize()?;
context.set_main_graph(graph.clone())?;
context.finalize()?;
assert!(context.set_graph_name(graph, "main").is_err());
assert!(context.set_node_name(input, "input").is_err());
Ok(())
};
helper4().unwrap();
// helper5: a graph/node cannot be renamed, and a name cannot be reused.
let helper5 = || -> Result<()> {
let context = create_context()?;
let graph = context.create_graph()?;
let input = graph.input(scalar_type(BIT))?;
let other_graph = context.create_graph()?;
// Note: `other_input` is created in `graph` (not `other_graph`), so the
// duplicate-name check below applies within a single graph.
let other_input = graph.input(scalar_type(BIT))?;
context.set_graph_name(graph.clone(), "main")?;
assert!(context.set_graph_name(graph, "main3").is_err());
assert!(context.set_graph_name(other_graph, "main").is_err());
context.set_node_name(input.clone(), "input")?;
assert!(context.set_node_name(input, "input3").is_err());
assert!(context.set_node_name(other_input, "input").is_err());
Ok(())
};
helper5().unwrap();
}
#[test]
fn test_context_type_checking() {
    // In a context created by `create_context`, adding an empty-tuple input to
    // itself must fail and must leave the graph with its single input node.
    let check = || -> Result<()> {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let empty_tuple = graph.input(tuple_type(vec![]))?;
        let rejected = graph.add(empty_tuple.clone(), empty_tuple.clone());
        assert!(rejected.is_err());
        let node_count = graph.get_nodes().len();
        assert_eq!(node_count, 1);
        Ok(())
    };
    check().unwrap();
}
// Builds pairs of contexts where the first context of each pair is constructed
// through `Context`/`Graph` methods and the second through the corresponding
// `Graph`/`Node` helper methods. `test_node_graph_helpers` asserts that the two
// contexts of every pair are deeply equal.
fn generate_pair_of_equal_contexts() -> Vec<(Context, Context)> {
// Pair 1: output node and main graph set via `g.set_output_node` / `g.set_as_main` ...
let context1 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
g.set_output_node(i)?;
g.finalize()?;
g.set_as_main()?;
Ok(context)
}()
.unwrap();
// ... versus `i.set_as_output` / `context.set_main_graph`.
let context2 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
i.set_as_output()?;
g.finalize()?;
context.set_main_graph(g)?;
Ok(context)
}()
.unwrap();
// Pair 2: graph name set via `context.set_graph_name` ...
let context3 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
context.set_graph_name(g, "random graph name")?;
Ok(context)
}()
.unwrap();
// ... versus `Graph::set_name`.
let context4 = || -> Result<Context> {
let context = create_unchecked_context()?;
context.create_graph()?.set_name("random graph name")?;
Ok(context)
}()
.unwrap();
// Pair 3: node name set via `context.set_node_name` ...
let context5 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
context.set_node_name(i, "random node name")?;
Ok(context)
}()
.unwrap();
// ... versus `Node::set_name`.
let context6 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?
.set_name("random node name")?;
Ok(context)
}()
.unwrap();
// Pair 4: a sequence of operations added through `Graph` methods ...
let context7 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
let i1 = g.input(scalar_type(BIT))?;
let i2 = g.input(scalar_type(BIT))?;
g.add(i1.clone(), i2.clone())?;
g.subtract(i1.clone(), i2.clone())?;
g.multiply(i1.clone(), i2.clone())?;
g.dot(i1.clone(), i2.clone())?;
g.matmul(i1.clone(), i2.clone())?;
g.truncate(i1.clone(), 123)?;
g.sum(i1.clone(), vec![1, 4, 7])?;
g.permute_axes(i1.clone(), vec![1, 4, 7])?;
g.get(i1.clone(), vec![1, 4])?;
g.reshape(i1.clone(), array_type(vec![12, 34], BIT))?;
g.nop(i1.clone())?;
g.prf(i1.clone(), 123, scalar_type(BIT))?;
g.a2b(i1.clone())?;
g.b2a(i1.clone(), BIT)?;
g.tuple_get(i1.clone(), 0)?;
g.named_tuple_get(i1.clone(), "field name".to_owned())?;
g.vector_get(i1.clone(), i2)?;
g.array_to_vector(i1.clone())?;
g.vector_to_array(i1.clone())?;
g.repeat(i1.clone(), 123)?;
Ok(context)
}()
.unwrap();
// ... versus the same sequence, in the same order, through `Node` methods.
let context8 = || -> Result<Context> {
let context = create_unchecked_context()?;
let g = context.create_graph()?;
let i1 = g.input(scalar_type(BIT))?;
let i2 = g.input(scalar_type(BIT))?;
i1.add(i2.clone())?;
i1.subtract(i2.clone())?;
i1.multiply(i2.clone())?;
i1.dot(i2.clone())?;
i1.matmul(i2.clone())?;
i1.truncate(123)?;
i1.sum(vec![1, 4, 7])?;
i1.permute_axes(vec![1, 4, 7])?;
i1.get(vec![1, 4])?;
i1.reshape(array_type(vec![12, 34], BIT))?;
i1.nop()?;
i1.prf(123, scalar_type(BIT))?;
i1.a2b()?;
i1.b2a(BIT)?;
i1.tuple_get(0)?;
i1.named_tuple_get("field name".to_owned())?;
i1.vector_get(i2)?;
i1.array_to_vector()?;
i1.vector_to_array()?;
i1.repeat(123)?;
Ok(context)
}()
.unwrap();
let result = vec![
(context1, context2),
(context3, context4),
(context5, context6),
(context7, context8),
];
result
}
/// Checks that every pair of contexts produced by
/// `generate_pair_of_equal_contexts` is deep-equal, then exercises the
/// name-related helpers on a freshly built graph and input node.
#[test]
fn test_node_graph_helpers() {
    // Each pair is expected to describe the same context built in two ways.
    generate_pair_of_equal_contexts()
        .into_iter()
        .for_each(|(left, right)| {
            assert!(contexts_deep_equal(left, right));
        });

    let check_naming_helpers = || -> Result<()> {
        let ctx = create_context()?;
        // `set_name` is chained here, so it must hand back the graph / node.
        let graph = ctx.create_graph()?.set_name("graph name")?;
        let input_node = graph.input(scalar_type(BIT))?.set_name("node name")?;

        // Names must be readable back from the graph and from the node.
        assert_eq!(graph.get_name()?, "graph name");
        let looked_up = graph.retrieve_node("node name")?;
        assert!(looked_up == input_node);
        let node_name = input_node.get_name()?;
        assert!(node_name == "node name");

        // Input values supplied by node name must come back as a plain list.
        let named_inputs = hashmap!("node name" => Value::from_scalar(1, BIT)?);
        let prepared = graph.prepare_input_values(named_inputs)?;
        let expected = vec![Value::from_scalar(1, BIT)?];
        assert_eq!(prepared, expected);
        Ok(())
    };
    check_naming_helpers().unwrap();
}
/// Verifies that the `Display` output of an `Operation` (here accessed
/// through an `Rc`) is just the variant name, with any payload stripped.
#[test]
fn test_operation_fmt_display() {
    // (operation under test, expected rendered name)
    let cases: Vec<(Rc<Operation>, &str)> = vec![
        (Rc::new(Operation::Input(scalar_type(UINT16))), "Input"),
        (Rc::new(Operation::Add), "Add"),
        (Rc::new(Operation::Truncate(10)), "Truncate"),
        (Rc::new(Operation::Get(vec![10, 20])), "Get"),
        (Rc::new(Operation::NOP), "NOP"),
        (
            Rc::new(Operation::CreateNamedTuple(vec![
                "Name".to_string(),
                "Address".to_string(),
            ])),
            "CreateNamedTuple",
        ),
        (
            Rc::new(Operation::NamedTupleGet("Name".to_string())),
            "NamedTupleGet",
        ),
    ];
    for (operation, expected_name) in cases {
        // `to_string` goes through the same `Display` impl as `format!("{}", ..)`.
        assert_eq!(operation.to_string(), expected_name);
    }
}
/// Adds one annotation to a graph and one to a node, then checks that each
/// object reports exactly the annotation that was attached to it.
#[test]
fn test_annotations() {
    // Fallible part of the test, kept separate so `?` can be used freely.
    fn run_annotation_checks() -> Result<()> {
        let ctx = create_context()?;
        let graph = ctx.create_graph()?;
        let input_node = graph.input(scalar_type(BIT))?;

        // Attach both annotations before reading anything back.
        graph.add_annotation(GraphAnnotation::AssociativeOperation)?;
        input_node.add_annotation(NodeAnnotation::AssociativeOperation)?;

        let graph_annotations = graph.get_annotations()?;
        assert_eq!(
            graph_annotations,
            vec![GraphAnnotation::AssociativeOperation]
        );

        let node_annotations = input_node.get_annotations()?;
        assert_eq!(
            node_annotations,
            vec![NodeAnnotation::AssociativeOperation]
        );
        Ok(())
    }
    run_annotation_checks().unwrap();
}
}