use super::*;
use crate::category::core::{Dtype, NdArrayType, ScalarOp, Shape, TensorOp};
use open_hypergraphs::lax::OpenHypergraph;
/// Prints every step of an SSA decomposition to stdout, one step per line.
///
/// All steps are rendered with their `Display` implementation, joined with
/// newlines, and emitted in a single `println!`. An empty slice therefore
/// prints one blank line.
fn print_ssa(ssa: &[SSA<NdArrayType, TensorOp>]) {
    let rendered_steps: Vec<String> = ssa.iter().map(ToString::to_string).collect();
    println!("{}", rendered_steps.join("\n"));
}
/// A single elementwise negation on a 2x2 `f32` array should decompose into
/// exactly one SSA step.
#[test]
fn test_simple_operation_ssa() {
    // Every array in this test is `f32`; only the dimensions vary.
    let f32_array = |dims| NdArrayType {
        shape: Shape(dims),
        dtype: Dtype::F32,
    };

    let mut graph = OpenHypergraph::empty();

    // Nodes are created input-first, matching the boundary order below.
    let x = graph.new_node(f32_array(vec![2, 2]));
    let neg_x = graph.new_node(f32_array(vec![2, 2]));

    let _ = graph.new_edge(
        TensorOp::Map(ScalarOp::Neg),
        lax::Hyperedge {
            sources: vec![x],
            targets: vec![neg_x],
        },
    );

    graph.sources = vec![x];
    graph.targets = vec![neg_x];

    let steps = ssa(graph.to_strict()).expect("cycle found");

    println!("SSA Decomposition:");
    print_ssa(&steps);

    assert_eq!(steps.len(), 1);
}
/// Computes `(A @ B) + C` with `A: 2x3`, `B: 3x4` and `C: 2x4`, all `f32`.
/// The SSA decomposition should contain exactly two steps: the matmul and
/// the pointwise add.
#[test]
fn test_matmul_and_pointwise_sum_ssa() {
    // Every array in this test is `f32`; only the dimensions vary.
    let f32_array = |dims| NdArrayType {
        shape: Shape(dims),
        dtype: Dtype::F32,
    };

    let mut graph = OpenHypergraph::empty();

    // Node creation order: the three inputs, then the intermediate, then the output.
    let lhs = graph.new_node(f32_array(vec![2, 3]));
    let rhs = graph.new_node(f32_array(vec![3, 4]));
    let addend = graph.new_node(f32_array(vec![2, 4]));
    let product = graph.new_node(f32_array(vec![2, 4]));
    let total = graph.new_node(f32_array(vec![2, 4]));

    // product = lhs @ rhs
    let _matmul_edge = graph.new_edge(
        TensorOp::MatMul,
        lax::Hyperedge {
            sources: vec![lhs, rhs],
            targets: vec![product],
        },
    );

    // total = product + addend (elementwise)
    let _sum_edge = graph.new_edge(
        TensorOp::Map(ScalarOp::Add),
        lax::Hyperedge {
            sources: vec![product, addend],
            targets: vec![total],
        },
    );

    graph.sources = vec![lhs, rhs, addend];
    graph.targets = vec![total];

    let steps = ssa(graph.to_strict()).expect("cycle found");

    println!("SSA Decomposition:");
    print_ssa(&steps);

    assert_eq!(steps.len(), 2);
}