use std::collections::VecDeque;
use std::fmt;
use std::hash::{Hash, Hasher};
use ahash::{AHashMap, AHashSet};
use serde_json::Value;
use crate::errors::SqliteGraphError;
use crate::graph::SqliteGraph;
use crate::progress::ProgressCallback;
/// The kind of memory access recorded by a [`TraceEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
/// A load from a memory location.
Read,
/// A store to a memory location.
Write,
}
impl fmt::Display for Operation {
    /// Formats the operation as its single-letter trace code: "R" or "W".
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = match self {
            Operation::Read => "R",
            Operation::Write => "W",
        };
        f.write_str(code)
    }
}
/// A single recorded memory access in a concurrent execution trace.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceEvent {
/// Unique identifier of this event within the trace.
pub event_id: i64,
/// Thread that performed the access.
pub thread_id: i64,
/// Whether the access was a read or a write.
pub operation: Operation,
/// Identifier of the memory location that was accessed.
pub memory_location: i64,
/// Logical timestamp of the event, used for happens-before ordering.
pub vector_clock: VectorClock,
}
impl TraceEvent {
    /// Builds an event from explicit parts, including a caller-supplied clock.
    pub fn new(
        event_id: i64,
        thread_id: i64,
        operation: Operation,
        memory_location: i64,
        vector_clock: VectorClock,
    ) -> Self {
        Self { event_id, thread_id, operation, memory_location, vector_clock }
    }

    /// Convenience constructor: starts from an empty clock and ticks it once
    /// for `thread_id`, modelling the first observable step of that thread.
    pub fn with_thread(
        event_id: i64,
        thread_id: i64,
        operation: Operation,
        memory_location: i64,
    ) -> Self {
        let vector_clock = VectorClock::new().incremented(thread_id);
        Self::new(event_id, thread_id, operation, memory_location, vector_clock)
    }
}
impl Hash for TraceEvent {
    /// Hashes the identifying fields only.
    ///
    /// `vector_clock` is deliberately excluded: hashing a subset of the
    /// fields compared by `PartialEq` still upholds the Hash/Eq contract
    /// (equal values hash equally), and the clock's backing map has no
    /// stable iteration order to feed a hasher.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (
            self.event_id,
            self.thread_id,
            self.operation,
            self.memory_location,
        )
            .hash(state);
    }
}
/// A vector clock: per-thread logical tick counters used to order events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VectorClock {
// Maps thread id -> ticks observed for that thread; absent entries are 0.
clocks: AHashMap<i64, u64>,
}
impl VectorClock {
    /// Creates a clock with no recorded ticks for any thread.
    pub fn new() -> Self {
        Self { clocks: AHashMap::new() }
    }

    /// True when no thread has ever been ticked.
    pub fn is_empty(&self) -> bool {
        self.clocks.is_empty()
    }

    /// Returns the recorded tick count for `thread_id` (0 when absent).
    pub fn get(&self, thread_id: i64) -> u64 {
        self.clocks.get(&thread_id).copied().unwrap_or(0)
    }

    /// Advances the logical time of `thread_id` by one tick.
    pub fn increment(&mut self, thread_id: i64) {
        let slot = self.clocks.entry(thread_id).or_default();
        *slot += 1;
    }

    /// Builder-style variant of [`VectorClock::increment`] that consumes and
    /// returns `self`.
    pub fn incremented(mut self, thread_id: i64) -> Self {
        self.increment(thread_id);
        self
    }

    /// Takes the component-wise maximum of the two clocks (message receive).
    pub fn merge(&mut self, other: &VectorClock) {
        for (&thread_id, &theirs) in other.clocks.iter() {
            let mine = self.clocks.entry(thread_id).or_default();
            if theirs > *mine {
                *mine = theirs;
            }
        }
    }

    /// Strict happens-before: every component of `self` is <= the matching
    /// component of `other`, and at least one component is strictly smaller.
    /// Threads missing from either clock count as 0, so the union of both
    /// key sets must be considered for the "strictly smaller" witness.
    pub fn happens_before(&self, other: &VectorClock) -> bool {
        if self.clocks.iter().any(|(&t, &mine)| mine > other.get(t)) {
            return false;
        }
        self.clocks.iter().any(|(&t, &mine)| mine < other.get(t))
            || other.clocks.iter().any(|(&t, &theirs)| self.get(t) < theirs)
    }

    /// Two clocks are concurrent when neither happens-before the other.
    pub fn is_concurrent(&self, other: &VectorClock) -> bool {
        !self.happens_before(other) && !other.happens_before(self)
    }

    /// Iterates over every thread id with a recorded tick.
    pub fn threads(&self) -> impl Iterator<Item = i64> + '_ {
        self.clocks.keys().copied()
    }

    /// Number of distinct threads recorded in this clock.
    pub fn len(&self) -> usize {
        self.clocks.len()
    }
}
impl Default for VectorClock {
/// Defaults to an empty clock (every component reads as 0).
fn default() -> Self {
Self::new()
}
}
/// Outcome of a happens-before race analysis over a trace.
#[derive(Debug, Clone)]
pub struct HappensBeforeResult {
/// Pairs of conflicting events found to be concurrent (potential races).
pub concurrent_pairs: Vec<(TraceEvent, TraceEvent)>,
/// Number of events that were analysed.
pub total_events: usize,
/// Convenience count; equals `concurrent_pairs.len()`.
pub conflicts_detected: usize,
}
impl HappensBeforeResult {
    /// Bundles the detected pairs with the trace size; `conflicts_detected`
    /// is derived here so it can never disagree with `concurrent_pairs`.
    fn new(concurrent_pairs: Vec<(TraceEvent, TraceEvent)>, total_events: usize) -> Self {
        let conflicts_detected = concurrent_pairs.len();
        Self {
            concurrent_pairs,
            total_events,
            conflicts_detected,
        }
    }

    /// True when at least one concurrent conflicting pair was found.
    pub fn has_races(&self) -> bool {
        !self.concurrent_pairs.is_empty()
    }

    /// Memory locations involved in at least one race.
    ///
    /// Previously only the first event of each pair was consulted. That is
    /// correct for pairs produced by `happens_before_analysis` (both events of
    /// a pair share a location), but the fields are public, so both sides are
    /// now recorded to stay correct for externally constructed results too.
    pub fn raced_locations(&self) -> AHashSet<i64> {
        let mut locations = AHashSet::new();
        for (event_a, event_b) in &self.concurrent_pairs {
            locations.insert(event_a.memory_location);
            locations.insert(event_b.memory_location);
        }
        locations
    }
}
/// Detects potential data races in `events` via vector-clock comparison.
///
/// Events are bucketed by memory location; within each bucket, every pair of
/// concurrent events in which at least one side is a write is reported as a
/// conflict. Returns `Ok` with the collected pairs and the trace size.
pub fn happens_before_analysis(
    events: &[TraceEvent],
) -> Result<HappensBeforeResult, SqliteGraphError> {
    let total_events = events.len();
    // Group events by the location they touch; a race needs a shared location.
    let mut by_location: AHashMap<i64, Vec<&TraceEvent>> = AHashMap::new();
    for event in events {
        by_location
            .entry(event.memory_location)
            .or_default()
            .push(event);
    }
    let mut concurrent_pairs = Vec::new();
    for bucket in by_location.values() {
        // Examine each unordered pair within the bucket exactly once.
        for i in 0..bucket.len() {
            for j in (i + 1)..bucket.len() {
                let (first, second) = (bucket[i], bucket[j]);
                // Read/read pairs are harmless; only pairs with a write race.
                let involves_write = first.operation == Operation::Write
                    || second.operation == Operation::Write;
                if involves_write && first.vector_clock.is_concurrent(&second.vector_clock) {
                    concurrent_pairs.push((first.clone(), second.clone()));
                }
            }
        }
    }
    Ok(HappensBeforeResult::new(concurrent_pairs, total_events))
}
/// Callback computing the traversal weight of an edge `(from, to, edge_data)`.
pub type WeightCallback = dyn Fn(i64, i64, &Value) -> f64;
/// Parameters controlling an impact-radius (blast-zone) traversal.
#[derive(Clone)]
pub struct ImpactRadiusConfig<'a> {
/// Maximum accumulated edge weight from the source.
pub max_distance: f64,
/// Optional cap on the number of hops from the source.
pub max_hops: Option<usize>,
/// Weight function applied to each traversed edge.
pub weight_fn: &'a WeightCallback,
}
/// Result of an impact-radius traversal from a source node.
#[derive(Debug, Clone)]
pub struct ImpactRadiusResult {
/// All nodes reachable within the distance/hop limits (incl. the source).
pub blast_zone: AHashSet<i64>,
/// Shortest known weighted distance from the source, per reached node.
pub distances: AHashMap<i64, f64>,
/// Nodes whose distance equals `max_distance` (within a small tolerance).
pub boundary: AHashSet<i64>,
/// Cached `blast_zone.len()`.
pub size: usize,
}
impl ImpactRadiusResult {
/// Whether `node` lies inside the blast zone.
pub fn is_affected(&self, node: i64) -> bool {
self.blast_zone.contains(&node)
}
/// Shortest known distance from the source to `node`, if it was reached.
pub fn distance_to(&self, node: i64) -> Option<f64> {
self.distances.get(&node).copied()
}
/// Whether `node` sits exactly on the boundary of the blast zone.
pub fn is_boundary(&self, node: i64) -> bool {
self.boundary.contains(&node)
}
}
/// Computes the "blast zone": all nodes reachable from `source` along
/// outgoing edges within `config.max_distance` total weight (and optionally
/// within `config.max_hops` hops).
///
/// Uses SPFA-style relaxation: a node may be re-enqueued whenever a shorter
/// path to it is discovered, so the final `distances` hold true shortest
/// weighted distances within the cap even for non-uniform weights.
///
/// # Errors
/// Propagates storage errors from `fetch_outgoing`, and returns
/// `invalid_input` when the weight callback yields a non-finite value.
pub fn impact_radius(
    graph: &SqliteGraph,
    source: i64,
    config: &ImpactRadiusConfig,
) -> Result<ImpactRadiusResult, SqliteGraphError> {
    let max_distance = config.max_distance;
    let max_hops = config.max_hops;
    let weight_fn = config.weight_fn;
    let mut distances: AHashMap<i64, f64> = AHashMap::new();
    let mut blast_zone: AHashSet<i64> = AHashSet::new();
    let mut queue: VecDeque<(i64, f64, usize)> = VecDeque::new();
    distances.insert(source, 0.0);
    queue.push_back((source, 0.0, 0));
    while let Some((node, dist, hops)) = queue.pop_front() {
        if dist > max_distance {
            // Only reachable when `max_distance` is negative: neighbors are
            // never enqueued past the cap, but the source starts at 0.0.
            continue;
        }
        blast_zone.insert(node);
        // BUG FIX: this hop-limit check previously had an empty body and was
        // a no-op, so a node already at the hop limit still paid a pointless
        // `fetch_outgoing` round-trip (all its neighbors were rejected
        // below). Record the node, then skip the expansion.
        if max_hops.is_some_and(|limit| hops >= limit) {
            continue;
        }
        let outgoing = graph.fetch_outgoing(node)?;
        for neighbor in outgoing {
            // NOTE(review): `fetch_outgoing` yields only neighbor ids, so the
            // weight callback receives an empty JSON object — confirm whether
            // real edge payloads should be loaded and passed here.
            let edge_data = &serde_json::json!({});
            let weight = weight_fn(node, neighbor, edge_data);
            if !weight.is_finite() {
                return Err(SqliteGraphError::invalid_input(format!(
                    "Invalid weight for edge {} -> {}: weight must be finite, got {}",
                    node, neighbor, weight
                )));
            }
            let new_dist = dist + weight;
            if new_dist > max_distance {
                continue;
            }
            // Relax: enqueue only when this path improves the known distance.
            // (The hop limit needs no re-check here: the guard above ensures
            // `hops < limit`, hence `hops + 1 <= limit`.)
            let should_enqueue = match distances.get(&neighbor) {
                Some(&old_dist) => new_dist < old_dist,
                None => true,
            };
            if should_enqueue {
                distances.insert(neighbor, new_dist);
                queue.push_back((neighbor, new_dist, hops + 1));
            }
        }
    }
    // The source is always affected, even when the loop body never ran
    // (e.g. negative `max_distance`).
    blast_zone.insert(source);
    distances.entry(source).or_insert(0.0);
    // Boundary = nodes sitting (within float tolerance) exactly at the cap.
    let epsilon = 1e-9;
    let boundary: AHashSet<i64> = distances
        .iter()
        .filter(|(_, dist)| (*dist - max_distance).abs() < epsilon)
        .map(|(&node, _)| node)
        .collect();
    let size = blast_zone.len();
    Ok(ImpactRadiusResult {
        blast_zone,
        distances,
        boundary,
        size,
    })
}
/// Progress-reporting variant of [`impact_radius`].
///
/// Performs the identical traversal; additionally emits a progress update
/// every 10 visited nodes, reports non-finite weights through
/// `progress.on_error` before returning the error, and signals completion
/// via `progress.on_complete`.
pub fn impact_radius_with_progress<F>(
    graph: &SqliteGraph,
    source: i64,
    config: &ImpactRadiusConfig,
    progress: &F,
) -> Result<ImpactRadiusResult, SqliteGraphError>
where
    F: ProgressCallback,
{
    let max_distance = config.max_distance;
    let max_hops = config.max_hops;
    let weight_fn = config.weight_fn;
    let mut distances: AHashMap<i64, f64> = AHashMap::new();
    let mut blast_zone: AHashSet<i64> = AHashSet::new();
    let mut queue: VecDeque<(i64, f64, usize)> = VecDeque::new();
    let mut nodes_visited = 0;
    distances.insert(source, 0.0);
    queue.push_back((source, 0.0, 0));
    while let Some((node, dist, hops)) = queue.pop_front() {
        nodes_visited += 1;
        if nodes_visited % 10 == 0 {
            progress.on_progress(
                nodes_visited,
                None,
                &format!(
                    "Impact radius: visited {}, blast_zone {}",
                    nodes_visited,
                    blast_zone.len()
                ),
            );
        }
        if dist > max_distance {
            // Only reachable when `max_distance` is negative.
            continue;
        }
        blast_zone.insert(node);
        // BUG FIX: this hop-limit check previously had an empty body and was
        // a no-op; a node at the limit still triggered a useless
        // `fetch_outgoing` whose neighbors were all rejected. Record the
        // node, then skip the expansion.
        if max_hops.is_some_and(|limit| hops >= limit) {
            continue;
        }
        let outgoing = graph.fetch_outgoing(node)?;
        for neighbor in outgoing {
            // NOTE(review): `fetch_outgoing` yields only neighbor ids, so the
            // weight callback receives an empty JSON object — confirm whether
            // real edge payloads should be loaded and passed here.
            let edge_data = &serde_json::json!({});
            let weight = weight_fn(node, neighbor, edge_data);
            if !weight.is_finite() {
                progress.on_error(&std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!(
                        "Invalid weight for edge {} -> {}: {}",
                        node, neighbor, weight
                    ),
                ));
                return Err(SqliteGraphError::invalid_input(format!(
                    "Invalid weight for edge {} -> {}: weight must be finite, got {}",
                    node, neighbor, weight
                )));
            }
            let new_dist = dist + weight;
            if new_dist > max_distance {
                continue;
            }
            // Relax: enqueue only when this path improves the known distance.
            // (Hop limit needs no re-check: the guard above ensures
            // `hops < limit`, hence `hops + 1 <= limit`.)
            let should_enqueue = match distances.get(&neighbor) {
                Some(&old_dist) => new_dist < old_dist,
                None => true,
            };
            if should_enqueue {
                distances.insert(neighbor, new_dist);
                queue.push_back((neighbor, new_dist, hops + 1));
            }
        }
    }
    progress.on_complete();
    // The source is always affected, even when the loop body never ran.
    blast_zone.insert(source);
    distances.entry(source).or_insert(0.0);
    // Boundary = nodes sitting (within float tolerance) exactly at the cap.
    let epsilon = 1e-9;
    let boundary: AHashSet<i64> = distances
        .iter()
        .filter(|(_, dist)| (*dist - max_distance).abs() < epsilon)
        .map(|(&node, _)| node)
        .collect();
    let size = blast_zone.len();
    Ok(ImpactRadiusResult {
        blast_zone,
        distances,
        boundary,
        size,
    })
}
#[cfg(test)]
mod tests {
use super::*;
// Unit tests covering: vector-clock algebra, trace-event construction,
// happens-before race detection, and the impact-radius traversal helpers.
// Uniform edge weight of 1.0 used by most traversal tests.
fn default_weight_fn(_from: i64, _to: i64, _edge_data: &Value) -> f64 {
1.0
}
// Builds a TraceEvent with an explicitly specified vector clock.
fn make_event(
event_id: i64,
thread_id: i64,
operation: Operation,
memory_location: i64,
vc_clocks: Vec<(i64, u64)>,
) -> TraceEvent {
let mut vc = VectorClock::new();
for (tid, clock) in vc_clocks {
vc.clocks.insert(tid, clock);
}
TraceEvent::new(event_id, thread_id, operation, memory_location, vc)
}
#[test]
fn test_vector_clock_new() {
let vc = VectorClock::new();
assert!(vc.is_empty(), "New vector clock should be empty");
assert_eq!(vc.len(), 0, "Length should be 0");
assert_eq!(vc.get(1), 0, "Missing thread should return 0");
}
#[test]
fn test_vector_clock_increment() {
let mut vc = VectorClock::new();
vc.increment(1);
assert_eq!(vc.get(1), 1, "First increment should set to 1");
assert!(!vc.is_empty(), "Should not be empty after increment");
vc.increment(1);
assert_eq!(vc.get(1), 2, "Second increment should set to 2");
vc.increment(2);
assert_eq!(vc.get(1), 2, "Thread 1 should still be 2");
assert_eq!(vc.get(2), 1, "Thread 2 should be 1");
}
#[test]
fn test_vector_clock_incremented() {
let vc1 = VectorClock::new().incremented(5);
assert_eq!(vc1.get(5), 1, "Thread 5 should be 1");
assert_eq!(vc1.get(1), 0, "Thread 1 should be 0");
}
#[test]
fn test_vector_clock_happens_before_simple() {
let mut vc_a = VectorClock::new();
vc_a.increment(1);
let mut vc_b = VectorClock::new();
vc_b.increment(1);
vc_b.increment(1);
assert!(vc_a.happens_before(&vc_b), "A should happen-before B");
assert!(!vc_b.happens_before(&vc_a), "B should not happen-before A");
}
#[test]
fn test_vector_clock_happens_before_partial_order() {
let mut vc_a = VectorClock::new();
vc_a.increment(1);
let mut vc_b = VectorClock::new();
vc_b.increment(1);
vc_b.increment(2);
assert!(
vc_a.happens_before(&vc_b),
"A should happen-before B (progressed on thread 2)"
);
assert!(!vc_b.happens_before(&vc_a), "B should not happen-before A");
}
#[test]
fn test_vector_clock_happens_before_equal() {
let mut vc_a = VectorClock::new();
vc_a.increment(1);
let mut vc_b = VectorClock::new();
vc_b.increment(1);
assert!(
!vc_a.happens_before(&vc_b),
"Equal clocks should not satisfy happens-before (need strict <)"
);
assert!(
!vc_b.happens_before(&vc_a),
"Equal clocks should not satisfy happens-before (need strict <)"
);
}
#[test]
fn test_vector_clock_happens_before_empty() {
let vc_empty = VectorClock::new();
let vc_nonempty = VectorClock::new().incremented(1);
assert!(
vc_empty.happens_before(&vc_nonempty),
"Empty clock should happen-before non-empty"
);
assert!(
!vc_nonempty.happens_before(&vc_empty),
"Non-empty should not happen-before empty"
);
}
#[test]
fn test_vector_clock_is_concurrent() {
let vc_a = VectorClock::new().incremented(1);
let vc_b = VectorClock::new().incremented(2);
assert!(vc_a.is_concurrent(&vc_b), "Clocks should be concurrent");
assert!(vc_b.is_concurrent(&vc_a), "Concurrency should be symmetric");
}
#[test]
fn test_vector_clock_is_concurrent_complex() {
let mut vc_a = VectorClock::new();
vc_a.increment(1);
vc_a.increment(1);
vc_a.increment(2);
let mut vc_b = VectorClock::new();
vc_b.increment(1);
vc_b.increment(2);
vc_b.increment(2);
assert!(
vc_a.is_concurrent(&vc_b),
"Should be concurrent (different ordering per thread)"
);
}
#[test]
fn test_vector_clock_is_concurrent_ordered() {
let vc_a = VectorClock::new().incremented(1);
let mut vc_b = VectorClock::new();
vc_b.increment(1);
vc_b.increment(1);
assert!(
!vc_a.is_concurrent(&vc_b),
"Ordered clocks should not be concurrent"
);
assert!(
!vc_b.is_concurrent(&vc_a),
"Ordered clocks should not be concurrent"
);
}
#[test]
fn test_vector_clock_merge() {
let mut vc_a = VectorClock::new();
vc_a.increment(1);
let vc_b = VectorClock::new().incremented(2);
vc_a.merge(&vc_b);
assert_eq!(vc_a.get(1), 1, "Thread 1 should be max(1, 0) = 1");
assert_eq!(vc_a.get(2), 1, "Thread 2 should be max(0, 1) = 1");
}
#[test]
fn test_vector_clock_merge_existing() {
let mut vc_a = VectorClock::new();
vc_a.increment(1);
let mut vc_b = VectorClock::new();
vc_b.increment(1);
vc_b.increment(1);
vc_b.increment(1);
vc_a.merge(&vc_b);
assert_eq!(vc_a.get(1), 3, "Thread 1 should be max(1, 3) = 3");
}
#[test]
fn test_vector_clock_merge_empty_into_nonempty() {
let mut vc = VectorClock::new();
vc.increment(1);
let original = vc.clone();
vc.merge(&VectorClock::new());
assert_eq!(vc.get(1), original.get(1), "Clock should be unchanged");
}
#[test]
fn test_vector_clock_merge_empty_into_empty() {
let mut vc = VectorClock::new();
vc.merge(&VectorClock::new());
assert!(vc.is_empty());
}
#[test]
fn test_trace_event_new() {
let vc = VectorClock::new().incremented(1);
let event = TraceEvent::new(10, 5, Operation::Read, 100, vc.clone());
assert_eq!(event.event_id, 10);
assert_eq!(event.thread_id, 5);
assert_eq!(event.operation, Operation::Read);
assert_eq!(event.memory_location, 100);
assert_eq!(event.vector_clock.get(1), 1);
}
#[test]
fn test_trace_event_with_thread() {
let event = TraceEvent::with_thread(1, 5, Operation::Write, 100);
assert_eq!(event.event_id, 1);
assert_eq!(event.thread_id, 5);
assert_eq!(event.operation, Operation::Write);
assert_eq!(event.memory_location, 100);
assert_eq!(event.vector_clock.get(5), 1);
}
#[test]
fn test_happens_before_empty() {
let events: Vec<TraceEvent> = vec![];
let result =
happens_before_analysis(&events).expect("Analysis should succeed on empty trace");
assert_eq!(result.total_events, 0);
assert_eq!(result.conflicts_detected, 0);
assert!(result.concurrent_pairs.is_empty());
assert!(!result.has_races());
}
#[test]
fn test_happens_before_single_event() {
let events = vec![TraceEvent::with_thread(1, 1, Operation::Write, 100)];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.total_events, 1);
assert_eq!(result.conflicts_detected, 0);
assert!(!result.has_races());
}
#[test]
fn test_happens_before_single_thread() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
make_event(2, 1, Operation::Read, 100, vec![(1, 2)]),
make_event(3, 1, Operation::Write, 100, vec![(1, 3)]),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.total_events, 3);
assert_eq!(result.conflicts_detected, 0);
assert!(!result.has_races());
}
#[test]
fn test_happens_before_concurrent_writes() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
TraceEvent::with_thread(2, 2, Operation::Write, 100),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.total_events, 2);
assert_eq!(result.conflicts_detected, 1);
assert!(result.has_races());
let (event_a, event_b) = &result.concurrent_pairs[0];
assert_eq!(event_a.thread_id, 1);
assert_eq!(event_b.thread_id, 2);
assert_eq!(event_a.memory_location, 100);
assert_eq!(event_b.memory_location, 100);
}
#[test]
fn test_happens_before_read_write_conflict() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Read, 100),
TraceEvent::with_thread(2, 2, Operation::Write, 100),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.conflicts_detected, 1);
assert!(result.has_races());
}
#[test]
fn test_happens_before_read_only_no_race() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Read, 100),
TraceEvent::with_thread(2, 2, Operation::Read, 100),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.conflicts_detected, 0);
assert!(!result.has_races());
}
#[test]
fn test_happens_before_ordered_events() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
make_event(2, 1, Operation::Write, 100, vec![(1, 2)]),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.conflicts_detected, 0);
assert!(!result.has_races());
}
#[test]
fn test_happens_before_different_locations() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
TraceEvent::with_thread(2, 2, Operation::Write, 200),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.conflicts_detected, 0);
assert!(!result.has_races());
}
#[test]
fn test_happens_before_multiple_locations() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
TraceEvent::with_thread(2, 2, Operation::Write, 100),
TraceEvent::with_thread(1, 1, Operation::Write, 200),
make_event(4, 1, Operation::Write, 200, vec![(1, 2)]),
TraceEvent::with_thread(2, 2, Operation::Read, 300),
make_event(6, 1, Operation::Read, 300, vec![(1, 1)]),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.total_events, 6);
assert_eq!(
result.conflicts_detected, 1,
"Should detect 1 race at location 100"
);
assert!(result.has_races());
let raced = result.raced_locations();
assert!(raced.contains(&100));
assert!(!raced.contains(&200));
assert!(!raced.contains(&300));
}
#[test]
fn test_happens_before_synchronized_threads() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
make_event(2, 1, Operation::Write, 100, vec![(1, 2), (2, 1)]),
make_event(3, 2, Operation::Write, 100, vec![(1, 2), (2, 2)]),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(
result.conflicts_detected, 0,
"Synchronized access should not race"
);
assert!(!result.has_races());
}
#[test]
fn test_happens_before_three_threads() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
TraceEvent::with_thread(2, 2, Operation::Write, 100),
TraceEvent::with_thread(3, 3, Operation::Write, 100),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
assert_eq!(result.conflicts_detected, 3);
assert!(result.has_races());
}
#[test]
fn test_happens_before_result_raced_locations() {
let events = vec![
TraceEvent::with_thread(1, 1, Operation::Write, 100),
TraceEvent::with_thread(2, 2, Operation::Write, 100),
TraceEvent::with_thread(1, 1, Operation::Write, 200),
TraceEvent::with_thread(2, 2, Operation::Write, 200),
TraceEvent::with_thread(1, 1, Operation::Write, 300),
TraceEvent::with_thread(2, 2, Operation::Write, 300),
];
let result = happens_before_analysis(&events).expect("Analysis should succeed");
let locations = result.raced_locations();
assert_eq!(locations.len(), 3);
assert!(locations.contains(&100));
assert!(locations.contains(&200));
assert!(locations.contains(&300));
}
#[test]
fn test_operation_display() {
assert_eq!(format!("{}", Operation::Read), "R");
assert_eq!(format!("{}", Operation::Write), "W");
}
// Builds a 5-node linear chain 0 -> 1 -> 2 -> 3 -> 4 (unit "next" edges).
fn create_impact_chain() -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..5 {
let entity = crate::GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..entity_ids.len().saturating_sub(1) {
let edge = crate::GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[i + 1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
graph
}
// Builds a 4-node diamond: 0 -> {1, 2} -> 3.
fn create_impact_diamond() -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..4 {
let entity = crate::GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3)];
for (from_idx, to_idx) in edges {
let edge = crate::GraphEdge {
id: 0,
from_id: entity_ids[from_idx],
to_id: entity_ids[to_idx],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
graph
}
// Builds two disconnected 2-node components: 0 -> 1 and 2 -> 3.
fn create_impact_disconnected() -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..4 {
let entity = crate::GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edge1 = crate::GraphEdge {
id: 0,
from_id: entity_ids[0],
to_id: entity_ids[1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge1).expect("Failed to insert edge");
let edge2 = crate::GraphEdge {
id: 0,
from_id: entity_ids[2],
to_id: entity_ids[3],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge2).expect("Failed to insert edge");
graph
}
#[test]
fn test_impact_radius_empty() {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
let config = ImpactRadiusConfig {
max_distance: 10.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result = impact_radius(&graph, 999, &config)
.expect("Impact radius should succeed on empty graph");
assert_eq!(result.size, 1, "Empty graph should have blast zone size 1");
assert!(
result.blast_zone.contains(&999),
"Source should be in blast zone"
);
assert_eq!(
*result.distances.get(&999).unwrap(),
0.0,
"Source distance should be 0"
);
assert!(
result.boundary.is_empty(),
"Empty graph should have no boundary nodes"
);
}
#[test]
fn test_impact_radius_linear_chain() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 2.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.size, 3, "Should have 3 nodes within 2 hops");
assert!(
result.blast_zone.contains(&entity_ids[0]),
"Node 0 should be in blast zone"
);
assert!(
result.blast_zone.contains(&entity_ids[1]),
"Node 1 should be in blast zone"
);
assert!(
result.blast_zone.contains(&entity_ids[2]),
"Node 2 should be in blast zone"
);
assert!(
!result.blast_zone.contains(&entity_ids[3]),
"Node 3 should NOT be in blast zone"
);
assert!(
!result.blast_zone.contains(&entity_ids[4]),
"Node 4 should NOT be in blast zone"
);
assert_eq!(
*result.distances.get(&entity_ids[0]).unwrap(),
0.0,
"Node 0 distance = 0"
);
assert_eq!(
*result.distances.get(&entity_ids[1]).unwrap(),
1.0,
"Node 1 distance = 1"
);
assert_eq!(
*result.distances.get(&entity_ids[2]).unwrap(),
2.0,
"Node 2 distance = 2"
);
}
#[test]
fn test_impact_radius_boundary_detection() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 2.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.boundary.len(), 1, "Should have 1 boundary node");
assert!(
result.boundary.contains(&entity_ids[2]),
"Node 2 should be on boundary"
);
assert!(
!result.boundary.contains(&entity_ids[0]),
"Node 0 should NOT be on boundary"
);
assert!(
!result.boundary.contains(&entity_ids[1]),
"Node 1 should NOT be on boundary"
);
}
#[test]
fn test_impact_radius_max_hops() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 10.0,
max_hops: Some(2),
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.size, 3, "Should have 3 nodes due to hop limit");
assert!(
result.blast_zone.contains(&entity_ids[0]),
"Node 0 should be in blast zone"
);
assert!(
result.blast_zone.contains(&entity_ids[1]),
"Node 1 should be in blast zone"
);
assert!(
result.blast_zone.contains(&entity_ids[2]),
"Node 2 should be in blast zone"
);
assert!(
!result.blast_zone.contains(&entity_ids[3]),
"Node 3 should NOT be in blast zone"
);
}
#[test]
fn test_impact_radius_weight_extraction() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let custom_weight_fn = |_from: i64, _to: i64, _edge_data: &Value| -> f64 { 2.0 };
let config = ImpactRadiusConfig {
max_distance: 4.0,
max_hops: None,
weight_fn: &custom_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.size, 3, "Should have 3 nodes with custom weights");
assert_eq!(
*result.distances.get(&entity_ids[1]).unwrap(),
2.0,
"Node 1 distance = 2"
);
assert_eq!(
*result.distances.get(&entity_ids[2]).unwrap(),
4.0,
"Node 2 distance = 4"
);
}
#[test]
fn test_impact_radius_unweighted() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 5.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
for i in 0..5 {
let expected_dist = i as f64;
let actual_dist = result
.distances
.get(&entity_ids[i])
.copied()
.unwrap_or(999.0);
assert_eq!(
actual_dist, expected_dist,
"Node {} should have distance {}",
i, expected_dist
);
}
}
#[test]
fn test_impact_radius_disconnected() {
let graph = create_impact_disconnected();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 10.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.size, 2, "Should only reach component 1 (2 nodes)");
assert!(
result.blast_zone.contains(&entity_ids[0]),
"Node 0 should be in blast zone"
);
assert!(
result.blast_zone.contains(&entity_ids[1]),
"Node 1 should be in blast zone"
);
assert!(
!result.blast_zone.contains(&entity_ids[2]),
"Node 2 should NOT be in blast zone (different component)"
);
assert!(
!result.blast_zone.contains(&entity_ids[3]),
"Node 3 should NOT be in blast zone (different component)"
);
}
#[test]
fn test_impact_radius_diamond() {
let graph = create_impact_diamond();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 10.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.size, 4, "Should reach all 4 nodes");
for &id in &entity_ids {
assert!(
result.blast_zone.contains(&id),
"Node {} should be in blast zone",
id
);
}
assert_eq!(
*result.distances.get(&entity_ids[0]).unwrap(),
0.0,
"Node 0 distance = 0"
);
assert_eq!(
*result.distances.get(&entity_ids[1]).unwrap(),
1.0,
"Node 1 distance = 1"
);
assert_eq!(
*result.distances.get(&entity_ids[2]).unwrap(),
1.0,
"Node 2 distance = 1"
);
assert_eq!(
*result.distances.get(&entity_ids[3]).unwrap(),
2.0,
"Node 3 distance = 2"
);
}
#[test]
fn test_impact_radius_result_is_affected() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 2.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert!(
result.is_affected(entity_ids[0]),
"Node 0 should be affected"
);
assert!(
result.is_affected(entity_ids[1]),
"Node 1 should be affected"
);
assert!(
result.is_affected(entity_ids[2]),
"Node 2 should be affected"
);
assert!(
!result.is_affected(entity_ids[3]),
"Node 3 should NOT be affected"
);
assert!(
!result.is_affected(entity_ids[4]),
"Node 4 should NOT be affected"
);
}
#[test]
fn test_impact_radius_result_distance_to() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 2.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(
result.distance_to(entity_ids[0]),
Some(0.0),
"Node 0 distance"
);
assert_eq!(
result.distance_to(entity_ids[1]),
Some(1.0),
"Node 1 distance"
);
assert_eq!(
result.distance_to(entity_ids[2]),
Some(2.0),
"Node 2 distance"
);
assert_eq!(
result.distance_to(entity_ids[3]),
None,
"Node 3 should return None"
);
assert_eq!(
result.distance_to(999),
None,
"Non-existent node should return None"
);
}
#[test]
fn test_impact_radius_result_is_boundary() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 2.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert!(
!result.is_boundary(entity_ids[0]),
"Node 0 should NOT be boundary"
);
assert!(
!result.is_boundary(entity_ids[1]),
"Node 1 should NOT be boundary"
);
assert!(
result.is_boundary(entity_ids[2]),
"Node 2 should be boundary"
);
assert!(
!result.is_boundary(entity_ids[3]),
"Node 3 should NOT be boundary"
);
}
#[test]
fn test_impact_radius_with_progress() {
use crate::progress::NoProgress;
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 3.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let progress = NoProgress;
let result_with = impact_radius_with_progress(&graph, entity_ids[0], &config, &progress)
.expect("Impact radius with progress should succeed");
let result_without =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(
result_with.size, result_without.size,
"Progress and non-progress results should match"
);
assert_eq!(
result_with.blast_zone, result_without.blast_zone,
"Blast zones should match"
);
assert_eq!(
result_with.boundary, result_without.boundary,
"Boundaries should match"
);
}
#[test]
fn test_impact_radius_zero_max_distance() {
let graph = create_impact_chain();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = ImpactRadiusConfig {
max_distance: 0.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(result.size, 1, "Should only have source node");
assert!(
result.blast_zone.contains(&entity_ids[0]),
"Source should be in blast zone"
);
assert!(
!result.blast_zone.contains(&entity_ids[1]),
"Node 1 should NOT be in blast zone"
);
}
#[test]
fn test_impact_radius_shorter_path() {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..4 {
let entity = crate::GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edge1 = crate::GraphEdge {
id: 0,
from_id: entity_ids[0],
to_id: entity_ids[1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge1).expect("Failed to insert edge");
let edge2 = crate::GraphEdge {
id: 0,
from_id: entity_ids[1],
to_id: entity_ids[3],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge2).expect("Failed to insert edge");
let edge3 = crate::GraphEdge {
id: 0,
from_id: entity_ids[0],
to_id: entity_ids[2],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge3).expect("Failed to insert edge");
let edge4 = crate::GraphEdge {
id: 0,
from_id: entity_ids[2],
to_id: entity_ids[3],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge4).expect("Failed to insert edge");
let edge5 = crate::GraphEdge {
id: 0,
from_id: entity_ids[0],
to_id: entity_ids[3],
edge_type: "direct".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge5).expect("Failed to insert edge");
let config = ImpactRadiusConfig {
max_distance: 10.0,
max_hops: None,
weight_fn: &default_weight_fn,
};
let result =
impact_radius(&graph, entity_ids[0], &config).expect("Impact radius should succeed");
assert_eq!(
*result.distances.get(&entity_ids[3]).unwrap(),
1.0,
"Node 3 should have distance 1 via direct edge"
);
}
}