use sqlitegraph::{
EdgeSpec, NodeSpec, SnapshotId,
backend::{BackendDirection, GraphBackend, SqliteGraphBackend},
};
/// Builds an in-memory graph of five `"Node"` nodes (`node_1` .. `node_5`)
/// linked into one directed chain `node_1 -> node_2 -> node_3 -> node_4 -> node_5`.
///
/// Each node carries `{"id": i}` with a 1-based `i`. Each `"chain"` edge carries
/// `{"order": i}`, where `i` is the edge's 0-based position along the chain.
///
/// # Errors
/// Propagates any error from creating the backend or from inserting a node or edge.
fn create_test_graph() -> Result<SqliteGraphBackend, Box<dyn std::error::Error>> {
    let backend = SqliteGraphBackend::in_memory()?;

    // Insert all nodes first. Collecting into `Result` stops at the first failure.
    let node_ids = (1..=5)
        .map(|i| {
            backend.insert_node(NodeSpec {
                kind: "Node".to_string(),
                name: format!("node_{}", i),
                file_path: None,
                data: serde_json::json!({"id": i}),
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Link every consecutive pair. The pair's index doubles as the edge order.
    for (order, pair) in node_ids.windows(2).enumerate() {
        backend.insert_edge(EdgeSpec {
            from: pair[0],
            to: pair[1],
            edge_type: "chain".to_string(),
            data: serde_json::json!({"order": order}),
        })?;
    }

    Ok(backend)
}
/// Runs the same depth-3 BFS from node 1 twice.
///
/// Checks that the repeated (presumably cache-served) result equals the first.
/// Also checks that it covers nodes 1..=4 of the chain but not node 5, which is
/// four hops away.
#[test]
fn test_query_cache_bfs_hit_correctness() -> Result<(), Box<dyn std::error::Error>> {
    let graph = create_test_graph()?;
    let fresh = graph.bfs(SnapshotId::current(), 1, 3)?;
    let repeated = graph.bfs(SnapshotId::current(), 1, 3)?;
    assert_eq!(fresh, repeated, "Cached BFS result should match original");

    // Nodes within three hops of node 1, each paired with its failure message.
    let must_contain = [
        (1, "Should include start node"),
        (2, "Should include node 2"),
        (3, "Should include node 3"),
        (4, "Should include node 4"),
    ];
    for &(node, why) in must_contain.iter() {
        assert!(fresh.contains(&node), "{}", why);
    }
    assert!(!fresh.contains(&5), "Should not include node 5 (depth 4)");

    Ok(())
}
/// Runs the same 2-hop outgoing `k_hop` query from node 1 twice.
///
/// Checks that the repeated (presumably cache-served) result equals the first.
/// On the fixture chain, the result must contain nodes 2 and 3. It must exclude
/// the start node and anything three or more hops away.
#[test]
fn test_query_cache_k_hop_hit_correctness() -> Result<(), Box<dyn std::error::Error>> {
    let backend = create_test_graph()?;
    let result1 = backend.k_hop(SnapshotId::current(), 1, 2, BackendDirection::Outgoing)?;
    let result2 = backend.k_hop(SnapshotId::current(), 1, 2, BackendDirection::Outgoing)?;
    assert_eq!(
        result1, result2,
        "Cached k-hop result should match original"
    );

    assert!(result1.contains(&2), "Should include node 2 (depth 1)");
    assert!(result1.contains(&3), "Should include node 3 (depth 2)");
    // Upper depth bound: node 4 is three chain hops from node 1.
    assert!(!result1.contains(&4), "Should not include node 4 (depth 3)");
    assert!(!result1.contains(&1), "Should not include start node");
    Ok(())
}
/// Checks that a BFS issued after a graph mutation does not return a stale cached result.
///
/// The first depth-3 BFS from node 1 reaches only nodes 1..=4 along the chain.
/// Inserting a `1 -> 4` shortcut puts node 5 at depth 2. A correctly
/// invalidated query must therefore now include node 5.
#[test]
fn test_query_cache_mvcc_invalidation() -> Result<(), Box<dyn std::error::Error>> {
    let backend = create_test_graph()?;
    let initial_result = backend.bfs(SnapshotId::current(), 1, 3)?;
    let initial_count = initial_result.len();
    assert!(
        !initial_result.contains(&5),
        "Initial result should not contain node 5 (depth 4)"
    );

    backend.insert_edge(EdgeSpec {
        from: 1,
        to: 4,
        edge_type: "shortcut".to_string(),
        data: serde_json::json!({"direct": true}),
    })?;

    let modified_result = backend.bfs(SnapshotId::current(), 1, 3)?;
    assert_ne!(
        initial_count,
        modified_result.len(),
        "Cache should be invalidated after graph mutation"
    );
    // The specific effect of the shortcut: 1 -> 4 -> 5 is only two hops.
    assert!(
        modified_result.contains(&5),
        "Modified result should contain node 5 via the shortcut"
    );
    for node in [1, 2, 3, 4] {
        assert!(
            initial_result.contains(&node),
            "Initial result should contain node {}",
            node
        );
        assert!(
            modified_result.contains(&node),
            "Modified result should contain node {}",
            node
        );
    }
    Ok(())
}
/// Checks that BFS queries differing in depth or start node are cached
/// separately and do not return each other's results.
#[test]
fn test_query_cache_different_parameters() -> Result<(), Box<dyn std::error::Error>> {
    let graph = create_test_graph()?;
    let shallow = graph.bfs(SnapshotId::current(), 1, 2)?;
    let deep = graph.bfs(SnapshotId::current(), 1, 3)?;
    let shifted = graph.bfs(SnapshotId::current(), 2, 2)?;

    // Every pair of results with differing parameters must differ.
    let distinct_pairs = [
        (&shallow, &deep, "Different depths should produce different results"),
        (&shallow, &shifted, "Different start nodes should produce different results"),
        (&deep, &shifted, "Different start nodes should produce different results"),
    ];
    for (lhs, rhs, why) in distinct_pairs.iter() {
        assert_ne!(lhs, rhs, "{}", why);
    }

    assert!(
        shallow.len() < deep.len(),
        "Deeper search should find more nodes"
    );
    Ok(())
}
/// Runs the same shortest-path query (1 -> 4) twice.
///
/// Checks that the repeated (presumably cache-served) result equals the first.
/// On the fixture chain the only route is `1 -> 2 -> 3 -> 4`, so the path is
/// pinned exactly.
#[test]
fn test_query_cache_shortest_path() -> Result<(), Box<dyn std::error::Error>> {
    let backend = create_test_graph()?;
    let result1 = backend.shortest_path(SnapshotId::current(), 1, 4)?;
    let result2 = backend.shortest_path(SnapshotId::current(), 1, 4)?;
    assert_eq!(
        result1, result2,
        "Cached shortest path result should match original"
    );

    let path = result1.expect("Should find path from 1 to 4");
    // `first()` fails the assertion on an empty path instead of panicking on `path[0]`.
    assert_eq!(path.first(), Some(&1), "Path should start at node 1");
    assert_eq!(path.last(), Some(&4), "Path should end at node 4");
    assert_eq!(
        path,
        vec![1, 2, 3, 4],
        "The chain is the only route from node 1 to node 4"
    );
    Ok(())
}
/// Runs the same edge-type-filtered `k_hop` query twice.
///
/// Checks that the repeated (presumably cache-served) result equals the first,
/// and that only `"chain"` edges are followed.
///
/// The fixture's edges are all `"chain"`, so this test adds a direct `1 -> 5`
/// edge of another type. If the filter were ignored, node 5 would become a
/// 1-hop neighbour.
#[test]
fn test_query_cache_filtered_k_hop() -> Result<(), Box<dyn std::error::Error>> {
    let backend = create_test_graph()?;
    backend.insert_edge(EdgeSpec {
        from: 1,
        to: 5,
        edge_type: "other".to_string(),
        data: serde_json::json!({}),
    })?;

    let result1 = backend.k_hop_filtered(
        SnapshotId::current(),
        1,
        2,
        BackendDirection::Outgoing,
        &["chain"],
    )?;
    let result2 = backend.k_hop_filtered(
        SnapshotId::current(),
        1,
        2,
        BackendDirection::Outgoing,
        &["chain"],
    )?;
    assert_eq!(
        result1, result2,
        "Cached filtered k-hop result should match original"
    );

    assert!(result1.contains(&2), "Should include node 2 via chain");
    assert!(result1.contains(&3), "Should include node 3 via chain");
    assert!(
        !result1.contains(&4),
        "Should not include node 4 (three chain hops away)"
    );
    assert!(
        !result1.contains(&5),
        "Should not reach node 5 through the filtered-out edge"
    );
    Ok(())
}
/// Re-runs a depth-3 BFS from node 1 after a real graph mutation.
///
/// The mutation inserts a disconnected node, which cannot change what is
/// reachable from node 1. The post-mutation result must therefore equal the
/// pre-mutation one, whether it is recomputed or served from cache.
///
/// NOTE(review): despite the test's name, no edge is removed. No edge-removal
/// API is used anywhere in this file. TODO: exercise actual edge removal once
/// the backend's API for it is confirmed.
#[test]
fn test_query_cache_after_edge_removal() -> Result<(), Box<dyn std::error::Error>> {
    let backend = create_test_graph()?;
    let initial_result = backend.bfs(SnapshotId::current(), 1, 3)?;
    assert!(
        !initial_result.is_empty(),
        "Initial result should not be empty"
    );

    // A genuine mutation that has no outgoing connection to the chain.
    backend.insert_node(NodeSpec {
        kind: "Node".to_string(),
        name: "isolated".to_string(),
        file_path: None,
        data: serde_json::json!({"id": 6}),
    })?;

    let after_mutation_result = backend.bfs(SnapshotId::current(), 1, 3)?;
    assert!(
        !after_mutation_result.is_empty(),
        "After mutation result should not be empty"
    );
    assert_eq!(
        initial_result, after_mutation_result,
        "An unrelated mutation must not change reachability from node 1"
    );
    Ok(())
}
/// Repeatedly interleaves BFS and k-hop queries on one backend.
///
/// Checks that every repetition returns the same result as a baseline query,
/// so repeated cache hits never drift.
///
/// NOTE(review): despite the name, everything runs sequentially on one thread.
/// Real concurrent access would require the backend to be shareable across
/// threads, which is not visible from this file. Confirm before adding threads.
#[test]
fn test_query_cache_concurrent_safety() -> Result<(), Box<dyn std::error::Error>> {
    let backend = create_test_graph()?;
    let expected_bfs = backend.bfs(SnapshotId::current(), 1, 2)?;
    let expected_k_hop = backend.k_hop(SnapshotId::current(), 1, 1, BackendDirection::Outgoing)?;

    for round in 0..5 {
        let bfs = backend.bfs(SnapshotId::current(), 1, 2)?;
        let k_hop = backend.k_hop(SnapshotId::current(), 1, 1, BackendDirection::Outgoing)?;
        assert_eq!(bfs, expected_bfs, "BFS result drifted on round {}", round);
        assert_eq!(k_hop, expected_k_hop, "k-hop result drifted on round {}", round);
    }
    Ok(())
}