use super::*;
#[test]
fn test_sage_aggregation_default() {
    // The `Default` impl for the aggregation enum must pick mean aggregation.
    let default_agg = SAGEAggregation::default();
    assert_eq!(default_agg, SAGEAggregation::Mean);
}
#[test]
fn test_gnn_large_graph() {
    // 100-node ring (n -> (n + 1) % 100). The first 50 nodes also get a
    // "chord" n -> n + 50. Edges are emitted in the same interleaved order:
    // ring edge first, then the chord (if any), per node.
    let edges: Vec<_> = (0..100)
        .flat_map(|n| std::iter::once([n, (n + 1) % 100]).chain((n < 50).then_some([n, n + 50])))
        .collect();

    let layer = GCNConv::new(16, 8);
    let features = create_test_tensor(&[100, 16], 1);
    let graph = AdjacencyMatrix::from_edge_index(&edges, 100);
    let output = layer.forward(&features, &graph);

    // One row per node, out_features (8) columns.
    assert_eq!(output.shape(), &[100, 8]);
}
#[test]
fn test_gat_multiple_heads_attention() {
    // 4 heads with 2 output features each; with concatenation the output
    // width is 4 * 2 = 8.
    let layer = GATConv::new(4, 2, 4);

    // Four nodes with one-hot features: row k has 1.0 in column k.
    let features = Tensor::new(
        &[
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ],
        &[4, 4],
    );
    // Edges 0->1, 1->2, 2->3, 3->0.
    let cycle = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2], [2, 3], [3, 0]], 4);

    let output = layer.forward(&features, &cycle);
    assert_eq!(output.shape(), &[4, 8]);

    // At least one pair of consecutive output values must differ, i.e. the
    // flattened output is not constant.
    let values = output.data();
    let has_variance = values
        .iter()
        .zip(values.iter().skip(1))
        .any(|(prev, next)| (prev - next).abs() > 1e-6);
    assert!(has_variance, "Multi-head output should have variance");
}
#[test]
fn test_sage_lstm_aggregation() {
    let layer = SAGEConv::new(4, 2).with_aggregation(SAGEAggregation::Lstm);
    // The builder must record the requested aggregation mode.
    assert_eq!(layer.aggregation(), SAGEAggregation::Lstm);

    // Three nodes whose features are 1..=12 laid out row by row.
    let features = Tensor::new(
        &[
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
        ],
        &[3, 4],
    );
    // Edges 0->1, 1->2, 2->0.
    let cycle = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2], [2, 0]], 3);

    let output = layer.forward(&features, &cycle);
    assert_eq!(output.shape(), &[3, 2]);
}
#[test]
fn test_sage_without_bias() {
    // Dropping the bias must not change the output dimensions.
    let layer = SAGEConv::new(4, 2).without_bias();
    let features = create_test_tensor(&[3, 4], 1);
    let path = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3);

    let output = layer.forward(&features, &path);
    assert_eq!(output.shape(), &[3, 2]);
}
#[test]
fn test_sage_without_bias_no_bias_added() {
    // Two 4 -> 2 SAGE layers, one with the default bias and one without,
    // applied to two one-hot nodes connected in both directions.
    //
    // NOTE(review): this only compares output shapes; it does not verify
    // that no bias term is added (e.g. by comparing output values).
    let biased = SAGEConv::new(4, 2);
    let unbiased = SAGEConv::new(4, 2).without_bias();

    let features = Tensor::new(&[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 4]);
    let pair = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 0]], 2);

    let biased_out = biased.forward(&features, &pair);
    let unbiased_out = unbiased.forward(&features, &pair);
    assert_eq!(biased_out.shape(), unbiased_out.shape());
}
#[test]
fn test_adjacency_matrix_add_self_loops_idempotent() {
    // Adding self-loops to a graph that already has them must not add duplicates.
    let looped_once = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3).add_self_loops();
    assert!(looped_once.has_self_loops());
    let edges_after_first = looped_once.num_edges();

    let looped_twice = looped_once.add_self_loops();
    assert!(looped_twice.has_self_loops());
    assert_eq!(looped_twice.num_edges(), edges_after_first);
}
#[test]
fn test_adjacency_matrix_add_self_loops_with_weights() {
    // Two weighted edges (0->1 with weight 0.5, 1->2 with weight 2.0) on 3 nodes.
    // Adding self-loops appends one loop per node: 2 + 3 = 5 edges.
    let adj = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3)
        .with_weights(vec![0.5, 2.0])
        .add_self_loops();
    assert!(adj.has_self_loops());
    assert_eq!(adj.num_edges(), 5);

    let dense = adj.to_dense();

    // The inserted self-loops carry unit weight.
    for node in 0..3 {
        assert!(
            (dense.get(node, node) - 1.0).abs() < 0.01,
            "self-loop on node {} should have weight 1.0",
            node
        );
    }

    // The user-supplied weights must survive self-loop insertion. Edge
    // src->tgt is read at dense (src, tgt).
    assert!((dense.get(0, 1) - 0.5).abs() < 0.01, "weight of edge 0->1 should be 0.5");
    assert!((dense.get(1, 2) - 2.0).abs() < 0.01, "weight of edge 1->2 should be 2.0");
}
#[test]
fn test_gcn_weight_accessor() {
    // A 4 -> 2 GCN layer exposes a weight of shape (in_features, out_features).
    let layer = GCNConv::new(4, 2);
    assert_eq!(layer.weight().shape(), &[4, 2]);
}
#[test]
fn test_gcn_bias_accessor() {
    // A freshly constructed GCN layer has a bias with one entry per output feature.
    let gcn = GCNConv::new(4, 2);
    // A single `expect` replaces `is_some()` + `unwrap()`; the message names the invariant.
    let bias = gcn.bias().expect("GCNConv::new should create a bias by default");
    assert_eq!(bias.shape(), &[2]);
}
#[test]
fn test_gcn_without_bias_accessor() {
    // `without_bias()` must leave the layer with no bias parameter.
    let layer = GCNConv::new(4, 2).without_bias();
    let bias = layer.bias();
    assert!(bias.is_none());
}
#[test]
fn test_adjacency_matrix_edge_accessors() {
    // Edges 0->2 ... no: edges are [0, 1] and [2, 3]; sources and targets
    // are exposed separately, in edge order.
    let graph = AdjacencyMatrix::from_edge_index(&[[0, 1], [2, 3]], 4);
    let (sources, targets) = (graph.edge_src(), graph.edge_tgt());
    assert_eq!(sources, &[0, 2]);
    assert_eq!(targets, &[1, 3]);
}
#[test]
fn test_adjacency_matrix_neighbors_isolated_node() {
    // Node 2 appears in no edge, so its neighbour list must be empty.
    let graph = AdjacencyMatrix::from_edge_index(&[[0, 1]], 3);
    let isolated = graph.neighbors(2);
    assert!(isolated.is_empty());
}
#[test]
fn test_adjacency_matrix_degrees_skip_out_of_bounds() {
    // Edges 0->1, 1->5, 5->0 on a 3-node graph; node index 5 is out of range.
    let adj = AdjacencyMatrix::from_coo(vec![0, 1, 5], vec![1, 5, 0], 3);

    // In-degree counts targets. 0->1 counts for node 1. 5->0 still counts
    // for node 0: an out-of-range source does not drop an in-range target.
    // The out-of-range target of 1->5 is skipped. Node 2 receives nothing.
    let in_deg = adj.in_degrees();
    assert_eq!(in_deg[1], 1.0);
    assert_eq!(in_deg[0], 1.0);
    assert_eq!(in_deg[2], 0.0);

    // Out-degree counts sources. 0->1 and 1->5 count for nodes 0 and 1.
    // The out-of-range source of 5->0 is skipped. Node 2 has no outgoing edges.
    let out_deg = adj.out_degrees();
    assert_eq!(out_deg[0], 1.0);
    assert_eq!(out_deg[1], 1.0);
    assert_eq!(out_deg[2], 0.0);
}
#[test]
fn test_gat_with_dropout() {
    // 2 heads x 4 output features -> 8 output columns per node.
    // Configuring dropout must not change the output shape.
    let layer = GATConv::new(8, 4, 2).with_dropout(0.5);
    let features = create_test_tensor(&[3, 8], 1);
    let path = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3);

    let output = layer.forward(&features, &path);
    assert_eq!(output.shape(), &[3, 8]);
}
#[test]
fn test_gat_without_concat_with_bias() {
    // With concatenation disabled, the 3 heads are reduced to a single
    // out_features-wide (2-column) output rather than 3 * 2 = 6 columns.
    // The bias is left at its default.
    let layer = GATConv::new(4, 2, 3).without_concat();

    // Three one-hot nodes: row k has 1.0 in column k.
    let features = Tensor::new(
        &[
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
        ],
        &[3, 4],
    );
    // Edges 0->1, 1->2, 2->0.
    let cycle = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2], [2, 0]], 3);

    let output = layer.forward(&features, &cycle);
    assert_eq!(output.shape(), &[3, 2]);
}
#[test]
fn test_gat_without_concat_without_bias() {
    // No concatenation and no bias: output width is just out_features (2).
    let layer = GATConv::new(4, 2, 2).without_concat().without_bias();
    let features = create_test_tensor(&[3, 4], 1);
    let path = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3);

    let output = layer.forward(&features, &path);
    assert_eq!(output.shape(), &[3, 2]);
}
#[test]
fn test_adjacency_matrix_debug_clone() {
    let original = AdjacencyMatrix::from_edge_index(&[[0, 1]], 2);

    // The Debug output must name the type.
    let rendered = format!("{:?}", original);
    assert!(rendered.contains("AdjacencyMatrix"));

    // A clone keeps the node and edge counts.
    let duplicate = original.clone();
    assert_eq!(duplicate.num_nodes(), 2);
    assert_eq!(duplicate.num_edges(), 1);
}
#[test]
fn test_gcn_conv_debug_clone() {
    let layer = GCNConv::new(4, 2);

    // The Debug output must name the type.
    let rendered = format!("{:?}", layer);
    assert!(rendered.contains("GCNConv"));

    // A clone keeps the layer dimensions.
    let duplicate = layer.clone();
    assert_eq!(duplicate.in_features(), 4);
    assert_eq!(duplicate.out_features(), 2);
}
#[test]
fn test_sage_conv_debug_clone() {
    let layer = SAGEConv::new(4, 2);

    // The Debug output must name the type.
    let rendered = format!("{:?}", layer);
    assert!(rendered.contains("SAGEConv"));

    // A clone keeps the layer dimensions.
    let duplicate = layer.clone();
    assert_eq!(duplicate.in_features(), 4);
    assert_eq!(duplicate.out_features(), 2);
}
#[test]
fn test_gat_conv_debug_clone() {
    let layer = GATConv::new(4, 2, 3);

    // The Debug output must name the type.
    let rendered = format!("{:?}", layer);
    assert!(rendered.contains("GATConv"));

    // A clone keeps the dimensions and the head count.
    let duplicate = layer.clone();
    assert_eq!(duplicate.in_features(), 4);
    assert_eq!(duplicate.out_features(), 2);
    assert_eq!(duplicate.num_heads(), 3);
}
#[test]
fn test_sage_aggregation_debug_clone() {
    let agg = SAGEAggregation::Max;

    // The Debug output must name the variant.
    let debug_str = format!("{:?}", agg);
    assert!(debug_str.contains("Max"));

    // Call `clone()` explicitly. A plain `let cloned = agg;` is a move/copy
    // and never exercises the `Clone` impl this test is named after. The lint
    // is allowed because the enum may also be `Copy`, and Clone is the point here.
    #[allow(clippy::clone_on_copy)]
    let cloned = agg.clone();
    assert_eq!(cloned, SAGEAggregation::Max);
    assert_eq!(cloned, agg);
}
#[test]
fn test_gcn_forward_with_preexisting_self_loops() {
    // The input graph already contains self-loops. Forward must still
    // produce one output row per node.
    let looped = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3).add_self_loops();
    let layer = GCNConv::new(4, 2);
    let features = create_test_tensor(&[3, 4], 1);

    let output = layer.forward(&features, &looped);
    assert_eq!(output.shape(), &[3, 2]);
}
#[test]
fn test_gat_forward_with_preexisting_self_loops() {
    // Self-loops are already present in the input graph.
    // 2 heads x 2 features -> 4 output columns per node.
    let looped = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3).add_self_loops();
    let layer = GATConv::new(4, 2, 2);
    let features = create_test_tensor(&[3, 4], 1);

    let output = layer.forward(&features, &looped);
    assert_eq!(output.shape(), &[3, 4]);
}
#[test]
fn test_adjacency_matrix_to_dense_ignores_out_of_bounds() {
    // Edge 5->0 references a node outside 0..3. The dense matrix must stay
    // 3x3 and still contain the valid edge 0->1.
    let graph = AdjacencyMatrix::from_coo(vec![0, 5], vec![1, 0], 3);
    let dense = graph.to_dense();

    assert_eq!((dense.n_rows(), dense.n_cols()), (3, 3));
    let valid_entry = dense.get(0, 1);
    assert!((valid_entry - 1.0).abs() < 0.01);
}
#[test]
fn test_sage_forward_isolated_nodes() {
    // Only the edge 0->1 exists; nodes 2 and 3 are isolated. Max aggregation
    // must still yield a row for every node.
    let layer = SAGEConv::new(4, 2).with_aggregation(SAGEAggregation::Max);
    let features = create_test_tensor(&[4, 4], 1);
    let sparse = AdjacencyMatrix::from_edge_index(&[[0, 1]], 4);

    let output = layer.forward(&features, &sparse);
    assert_eq!(output.shape(), &[4, 2]);
}
#[test]
fn test_gat_leaky_relu_negative() {
    // All-negative input features with a custom negative slope of 0.2.
    // Presumably this drives attention scores into LeakyReLU's negative
    // branch (TODO confirm). Only the output shape is checked.
    let layer = GATConv::new(4, 2, 1).with_negative_slope(0.2);
    let features = Tensor::new(&[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0], &[2, 4]);
    let pair = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 0]], 2);

    let output = layer.forward(&features, &pair);
    assert_eq!(output.shape(), &[2, 2]);
}
#[test]
fn test_gcn_forward_with_edge_weights() {
    // Weighted edges 0->1 (0.5) and 1->2 (2.0). Forward must accept the
    // weighted graph and keep one output row per node.
    let weighted =
        AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 2]], 3).with_weights(vec![0.5, 2.0]);
    let layer = GCNConv::new(4, 2);
    let features = create_test_tensor(&[3, 4], 1);

    let output = layer.forward(&features, &weighted);
    assert_eq!(output.shape(), &[3, 2]);
}