use std::collections::{HashMap, HashSet};
use std::fmt;
use crate::{errors::SqliteGraphError, graph::SqliteGraph, progress::ProgressCallback};
use super::subgraph_isomorphism::{SubgraphPatternBounds, find_subgraph_patterns};
/// Limits and switches controlling a graph-rewrite run.
///
/// Build one with [`RewriteBounds::new`] / [`Default`] and the `with_*` / `unlimited` helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RewriteBounds {
    /// Maximum number of pattern matches to look for and rewrite; `None` removes the cap.
    pub max_matches: Option<usize>,
    /// Whether the rewritten graph is checked for dangling edges afterwards.
    pub validate_after_rewrite: bool,
}

impl Default for RewriteBounds {
    /// At most 10 matches, with post-rewrite validation enabled.
    fn default() -> Self {
        RewriteBounds {
            validate_after_rewrite: true,
            max_matches: Some(10),
        }
    }
}

impl RewriteBounds {
    /// Same as [`RewriteBounds::default`].
    #[inline]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy capped at `max` matches.
    #[inline]
    pub fn with_max_matches(self, max: usize) -> Self {
        RewriteBounds {
            max_matches: Some(max),
            ..self
        }
    }

    /// Returns a copy with post-rewrite validation switched on or off.
    #[inline]
    pub fn with_validation(self, validate: bool) -> Self {
        RewriteBounds {
            validate_after_rewrite: validate,
            ..self
        }
    }

    /// Returns a copy with no cap on the number of matches.
    #[inline]
    pub fn unlimited(self) -> Self {
        RewriteBounds {
            max_matches: None,
            ..self
        }
    }
}
/// A single change recorded while rewriting a graph.
///
/// Deletions are recorded with node IDs of the graph being rewritten; additions are recorded
/// with IDs of the newly built graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteOperation {
    /// A node was removed.
    NodeDeleted(i64),
    /// A node was created.
    NodeAdded(i64),
    /// An edge `from -> to` was removed.
    EdgeDeleted { from: i64, to: i64 },
    /// An edge `from -> to` was created.
    EdgeAdded { from: i64, to: i64 },
}

impl fmt::Display for RewriteOperation {
    /// Renders a short human-readable description, e.g. `Deleted edge 1 -> 2`.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            RewriteOperation::NodeDeleted(node) => write!(out, "Deleted node {}", node),
            RewriteOperation::NodeAdded(node) => write!(out, "Added node {}", node),
            RewriteOperation::EdgeDeleted { from: src, to: dst } => {
                write!(out, "Deleted edge {} -> {}", src, dst)
            }
            RewriteOperation::EdgeAdded { from: src, to: dst } => {
                write!(out, "Added edge {} -> {}", src, dst)
            }
        }
    }
}
/// A rewrite rule: occurrences of `pattern` are replaced by `replacement`.
///
/// Each `(pattern_idx, replacement_idx)` pair in `interface` glues the two sides together.
/// The host node matched by the pattern node at position `pattern_idx` (in
/// `pattern.all_entity_ids()` order) is kept. It stands in for the replacement node at position
/// `replacement_idx` (in `replacement.all_entity_ids()` order). Matched host nodes that are not
/// in the interface are removed; replacement nodes that are not in the interface are created
/// fresh.
pub struct RewriteRule {
    /// Subgraph to search for.
    pub pattern: SqliteGraph,
    /// Subgraph inserted in place of each match.
    pub replacement: SqliteGraph,
    /// `(pattern node position, replacement node position)` pairs of preserved nodes.
    pub interface: Vec<(usize, usize)>,
}

impl RewriteRule {
    /// Number of preserved (interface) node pairs.
    #[inline]
    pub fn interface_size(&self) -> usize {
        self.interface.len()
    }

    /// Checks that every interface position is within the node count of its graph.
    ///
    /// Pairs are checked in order; within a pair the pattern side is checked first. The first
    /// violation is returned.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_input` error for the first out-of-range position. Also propagates
    /// failures to list either graph's entity IDs.
    fn validate_interface(&self) -> Result<(), SqliteGraphError> {
        let pattern_count = self.pattern.all_entity_ids()?.len();
        let replacement_count = self.replacement.all_entity_ids()?.len();

        for (pattern_idx, replacement_idx) in self.interface.iter().copied() {
            let message = if pattern_idx >= pattern_count {
                format!(
                    "Interface pattern index {} out of bounds (pattern has {} nodes)",
                    pattern_idx, pattern_count
                )
            } else if replacement_idx >= replacement_count {
                format!(
                    "Interface replacement index {} out of bounds (replacement has {} nodes)",
                    replacement_idx, replacement_count
                )
            } else {
                continue;
            };
            return Err(SqliteGraphError::invalid_input(message));
        }

        Ok(())
    }
}
/// Outcome of a rewrite run.
pub struct RewriteResult {
    /// The rewritten copy of the input graph.
    pub rewritten_graph: SqliteGraph,
    /// Number of pattern matches the rewrite loop processed.
    pub patterns_replaced: usize,
    /// Every recorded node/edge change, in the order the rewrites produced them.
    pub operations_applied: Vec<RewriteOperation>,
    /// Problems found by post-rewrite validation; empty when it passed or was disabled.
    pub validation_errors: Vec<String>,
}

impl RewriteResult {
    /// `true` when no validation errors were recorded.
    #[inline]
    pub fn is_valid(&self) -> bool {
        self.validation_errors.is_empty()
    }

    /// Total number of recorded operations.
    #[inline]
    pub fn operation_count(&self) -> usize {
        self.operations_applied.len()
    }

    /// Number of [`RewriteOperation::NodeAdded`] entries.
    #[inline]
    pub fn nodes_added(&self) -> usize {
        self.count_operations(|op| matches!(op, RewriteOperation::NodeAdded(_)))
    }

    /// Number of [`RewriteOperation::NodeDeleted`] entries.
    #[inline]
    pub fn nodes_deleted(&self) -> usize {
        self.count_operations(|op| matches!(op, RewriteOperation::NodeDeleted(_)))
    }

    /// Number of [`RewriteOperation::EdgeAdded`] entries.
    #[inline]
    pub fn edges_added(&self) -> usize {
        self.count_operations(|op| matches!(op, RewriteOperation::EdgeAdded { .. }))
    }

    /// Number of [`RewriteOperation::EdgeDeleted`] entries.
    #[inline]
    pub fn edges_deleted(&self) -> usize {
        self.count_operations(|op| matches!(op, RewriteOperation::EdgeDeleted { .. }))
    }

    /// Counts the recorded operations that satisfy `predicate`.
    fn count_operations<P>(&self, predicate: P) -> usize
    where
        P: Fn(&RewriteOperation) -> bool,
    {
        self.operations_applied
            .iter()
            .filter(|&op| predicate(op))
            .count()
    }
}
/// Reports edges of `graph` whose target is not one of the graph's entities.
///
/// Nodes are visited in `all_entity_ids()` order, so the messages come out in a stable order.
/// The previous version walked a `HashSet`, whose iteration order differs from run to run.
/// A node whose outgoing edges cannot be fetched is now reported instead of being silently
/// treated as valid. If the entity IDs cannot be listed at all, that single failure is the
/// only message returned.
///
/// Returns an empty vector when no problems were found.
fn validate_no_dangling_edges(graph: &SqliteGraph) -> Vec<String> {
    let node_ids = match graph.all_entity_ids() {
        Ok(ids) => ids,
        Err(e) => return vec![format!("Failed to get entity IDs: {}", e)],
    };
    // Membership set for O(1) target lookups; iteration still follows `node_ids`.
    let known: HashSet<i64> = node_ids.iter().copied().collect();

    let mut errors = Vec::new();
    for &node_id in &node_ids {
        match graph.fetch_outgoing(node_id) {
            Ok(outgoing) => {
                for &target_id in &outgoing {
                    if !known.contains(&target_id) {
                        errors.push(format!(
                            "Dangling edge: {} -> {} (target node does not exist)",
                            node_id, target_id
                        ));
                    }
                }
            }
            // The error value is not formatted because its type is not visible here;
            // the node ID is enough to locate the problem.
            Err(_) => errors.push(format!(
                "Failed to fetch outgoing edges for node {}",
                node_id
            )),
        }
    }
    errors
}
/// Builds an in-memory copy of `graph`: every readable entity plus the edges between them.
///
/// Entities are re-inserted with fresh IDs. Edges are translated through the resulting
/// old-to-new ID map. Edges are re-created with edge type `"edge"` and empty data, so the
/// original edge types and payloads are not preserved.
///
/// Entities that `get_entity` fails to read are skipped, along with their edges. Edge insertion
/// is best-effort.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the in-memory graph cannot be opened;
/// - the entity IDs of `graph` cannot be listed;
/// - an entity cannot be inserted into the copy.
fn copy_graph(graph: &SqliteGraph) -> Result<SqliteGraph, SqliteGraphError> {
    let new_graph = SqliteGraph::open_in_memory()?;
    let entity_ids = graph.all_entity_ids()?;

    // Record the ID `insert_entity` actually returns. The old code zipped the source IDs with
    // the copy's `all_entity_ids()` by position, which misaligns every later mapping as soon as
    // one entity is skipped or one insert fails, wiring edges to the wrong nodes.
    let mut old_to_new: HashMap<i64, i64> = HashMap::with_capacity(entity_ids.len());
    for &old_id in &entity_ids {
        if let Ok(entity) = graph.get_entity(old_id) {
            let new_id = new_graph.insert_entity(&crate::GraphEntity {
                id: 0,
                kind: entity.kind.clone(),
                name: entity.name.clone(),
                file_path: entity.file_path.clone(),
                data: entity.data.clone(),
            })?;
            old_to_new.insert(old_id, new_id);
        }
    }

    for &from_id in &entity_ids {
        // Entities that were not copied contribute no edges.
        let new_from = match old_to_new.get(&from_id) {
            Some(&id) => id,
            None => continue,
        };
        if let Ok(outgoing) = graph.fetch_outgoing(from_id) {
            for to_id in outgoing {
                if let Some(&new_to) = old_to_new.get(&to_id) {
                    let edge = crate::GraphEdge {
                        id: 0,
                        from_id: new_from,
                        to_id: new_to,
                        edge_type: "edge".to_string(),
                        data: serde_json::json!({}),
                    };
                    // Best-effort, as before: a failed edge insert leaves the copy without it.
                    let _ = new_graph.insert_edge(&edge);
                }
            }
        }
    }

    Ok(new_graph)
}
pub fn rewrite_graph_patterns(
graph: &SqliteGraph,
rule: &RewriteRule,
bounds: RewriteBounds,
) -> Result<RewriteResult, SqliteGraphError> {
rule.validate_interface()?;
let pattern_bounds = SubgraphPatternBounds {
max_matches: bounds.max_matches,
timeout_ms: Some(5000),
max_pattern_nodes: Some(20),
};
let match_result = find_subgraph_patterns(graph, &rule.pattern, pattern_bounds)?;
if match_result.matches.is_empty() {
return Ok(RewriteResult {
rewritten_graph: copy_graph(graph)?,
patterns_replaced: 0,
operations_applied: vec![],
validation_errors: vec![],
});
}
let mut current_graph = copy_graph(graph)?;
let mut all_operations = Vec::new();
let mut patterns_replaced = 0;
let max_rewrites = bounds.max_matches.unwrap_or(match_result.matches.len());
let rewrites_to_apply = match_result.matches.len().min(max_rewrites);
for match_idx in 0..rewrites_to_apply {
let pattern_match = &match_result.matches[match_idx];
let (new_graph, operations) =
apply_single_rewrite(¤t_graph, rule, pattern_match, patterns_replaced)?;
current_graph = new_graph;
all_operations.extend(operations);
patterns_replaced += 1;
}
let validation_errors = if bounds.validate_after_rewrite {
validate_no_dangling_edges(¤t_graph)
} else {
vec![]
};
Ok(RewriteResult {
rewritten_graph: current_graph,
patterns_replaced,
operations_applied: all_operations,
validation_errors,
})
}
/// Finds occurrences of `rule.pattern` in `graph`, rewrites them with `rule.replacement`, and
/// reports progress through `progress`.
///
/// Progress is reported on a 4-step scale:
/// - step 0: validating the rule's interface;
/// - step 1: searching for pattern matches;
/// - step 2: matches found, then once per applied rewrite;
/// - step 3: validating the rewritten graph (or "no matches found");
/// - step 4: final summary (only when there were matches).
///
/// `progress.on_complete()` is called on both successful paths. It is not called when an error
/// is returned.
///
/// At most `bounds.max_matches` matches are searched for and rewritten, or all of them when it
/// is `None`. The pattern search itself is limited to 5000 ms and 20 pattern nodes. When
/// `bounds.validate_after_rewrite` is set, the result is checked for dangling edges.
///
/// NOTE(review): matches are found once, on `graph`, but every rewrite builds a new graph whose
/// node IDs come from fresh `insert_entity` calls. Later matches are therefore applied with IDs
/// taken from `graph` against an already-rewritten graph, and may touch unrelated nodes or none
/// at all. `patterns_replaced` counts every processed match either way.
///
/// # Errors
///
/// Returns an error if the rule's interface indices are out of range, if the pattern search
/// fails, or if building a rewritten graph fails.
pub fn rewrite_graph_patterns_with_progress<F>(
    graph: &SqliteGraph,
    rule: &RewriteRule,
    bounds: RewriteBounds,
    progress: &F,
) -> Result<RewriteResult, SqliteGraphError>
where
    F: ProgressCallback,
{
    progress.on_progress(0, Some(4), "Validating rewrite rule");
    rule.validate_interface()?;

    progress.on_progress(1, Some(4), "Finding pattern matches");
    let pattern_bounds = SubgraphPatternBounds {
        max_matches: bounds.max_matches,
        timeout_ms: Some(5000),
        max_pattern_nodes: Some(20),
    };
    let match_result = find_subgraph_patterns(graph, &rule.pattern, pattern_bounds)?;
    let match_count = match_result.matches.len();
    progress.on_progress(
        2,
        Some(4),
        &format!("Found {} pattern matches", match_count),
    );

    if match_count == 0 {
        progress.on_progress(3, Some(4), "No matches found, returning original graph");
        // Copy before signalling completion so a failing copy is not reported as complete.
        let rewritten_graph = copy_graph(graph)?;
        progress.on_complete();
        return Ok(RewriteResult {
            rewritten_graph,
            patterns_replaced: 0,
            operations_applied: vec![],
            validation_errors: vec![],
        });
    }

    let rewrites_to_apply = bounds
        .max_matches
        .map_or(match_count, |max| max.min(match_count));
    let mut current_graph = copy_graph(graph)?;
    let mut all_operations = Vec::new();
    let mut patterns_replaced = 0;
    for (match_idx, pattern_match) in match_result
        .matches
        .iter()
        .take(rewrites_to_apply)
        .enumerate()
    {
        progress.on_progress(
            2,
            Some(4),
            &format!("Applying rewrite {}/{}", match_idx + 1, rewrites_to_apply),
        );
        let (new_graph, operations) =
            apply_single_rewrite(&current_graph, rule, pattern_match, patterns_replaced)?;
        current_graph = new_graph;
        all_operations.extend(operations);
        patterns_replaced += 1;
    }

    progress.on_progress(3, Some(4), "Validating rewritten graph");
    let validation_errors = if bounds.validate_after_rewrite {
        validate_no_dangling_edges(&current_graph)
    } else {
        vec![]
    };

    let final_msg = if validation_errors.is_empty() {
        format!(
            "Rewrite complete: {} patterns replaced, {} operations applied",
            patterns_replaced,
            all_operations.len()
        )
    } else {
        format!(
            "Rewrite complete with errors: {} patterns replaced, {} validation errors",
            patterns_replaced,
            validation_errors.len()
        )
    };
    progress.on_progress(4, Some(4), &final_msg);
    progress.on_complete();

    Ok(RewriteResult {
        rewritten_graph: current_graph,
        patterns_replaced,
        operations_applied: all_operations,
        validation_errors,
    })
}
fn apply_single_rewrite(
graph: &SqliteGraph,
rule: &RewriteRule,
pattern_match: &[i64],
rewrite_index: usize,
) -> Result<(SqliteGraph, Vec<RewriteOperation>), SqliteGraphError> {
let mut operations = Vec::new();
let pattern_ids = rule.pattern.all_entity_ids()?;
let replacement_ids = rule.replacement.all_entity_ids()?;
let mut interface_pattern_indices: HashSet<usize> = HashSet::new();
for &(pattern_idx, replacement_idx) in &rule.interface {
if pattern_idx < pattern_match.len() && replacement_idx < replacement_ids.len() {
interface_pattern_indices.insert(pattern_idx);
}
}
let mut non_interface_pattern_ids: HashSet<i64> = HashSet::new();
for (idx, _pattern_id) in pattern_ids.iter().enumerate() {
if idx < pattern_match.len() && !interface_pattern_indices.contains(&idx) {
let target_id = pattern_match[idx];
non_interface_pattern_ids.insert(target_id);
}
}
let new_graph = SqliteGraph::open_in_memory()?;
let mut old_to_new_id: HashMap<i64, i64> = HashMap::new();
let all_old_ids = graph.all_entity_ids()?;
for &old_id in &all_old_ids {
if !non_interface_pattern_ids.contains(&old_id) {
if let Ok(entity) = graph.get_entity(old_id) {
let new_id = new_graph.insert_entity(&crate::GraphEntity {
id: 0,
kind: entity.kind.clone(),
name: entity.name.clone(),
file_path: entity.file_path.clone(),
data: entity.data.clone(),
})?;
old_to_new_id.insert(old_id, new_id);
}
} else {
operations.push(RewriteOperation::NodeDeleted(old_id));
}
}
for &deleted_id in &non_interface_pattern_ids {
if let Ok(outgoing) = graph.fetch_outgoing(deleted_id) {
for &target_id in &outgoing {
operations.push(RewriteOperation::EdgeDeleted {
from: deleted_id,
to: target_id,
});
}
}
for &from_id in &all_old_ids {
if let Ok(outgoing) = graph.fetch_outgoing(from_id) {
if outgoing.contains(&deleted_id) {
operations.push(RewriteOperation::EdgeDeleted {
from: from_id,
to: deleted_id,
});
}
}
}
}
let mut replacement_node_map: HashMap<usize, i64> = HashMap::new();
for (idx, &replacement_id) in replacement_ids.iter().enumerate() {
let is_interface = rule.interface.iter().any(|(_, rep_idx)| *rep_idx == idx);
if !is_interface {
if let Ok(entity) = rule.replacement.get_entity(replacement_id) {
let fresh_id = new_graph.insert_entity(&crate::GraphEntity {
id: 0,
kind: entity.kind.clone(),
name: format!("{}_rewrite_{}", entity.name, rewrite_index),
file_path: entity.file_path.clone(),
data: entity.data.clone(),
})?;
replacement_node_map.insert(idx, fresh_id);
operations.push(RewriteOperation::NodeAdded(fresh_id));
}
}
}
for &from_old in &all_old_ids {
if let Some(&from_new) = old_to_new_id.get(&from_old) {
if let Ok(outgoing) = graph.fetch_outgoing(from_old) {
for to_old in outgoing {
if let Some(&to_new) = old_to_new_id.get(&to_old) {
let edge = crate::GraphEdge {
id: 0,
from_id: from_new,
to_id: to_new,
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
if new_graph.insert_edge(&edge).is_ok() {
operations.push(RewriteOperation::EdgeAdded {
from: from_new,
to: to_new,
});
}
}
}
}
}
}
if let Ok(repl_node_ids) = rule.replacement.all_entity_ids() {
for &from_repl_id in &repl_node_ids {
if let Ok(outgoing) = rule.replacement.fetch_outgoing(from_repl_id) {
for to_repl_id in outgoing {
let from_idx = repl_node_ids.iter().position(|&id| id == from_repl_id);
let to_idx = repl_node_ids.iter().position(|&id| id == to_repl_id);
if let (Some(from_i), Some(to_i)) = (from_idx, to_idx) {
let from_id = if let Some((pat_idx, _)) = rule
.interface
.iter()
.find(|(_, rep_idx)| *rep_idx == from_i)
{
if *pat_idx < pattern_match.len() {
old_to_new_id.get(&pattern_match[*pat_idx]).copied()
} else {
None
}
} else {
replacement_node_map.get(&from_i).copied()
};
let to_id = if let Some((pat_idx, _)) =
rule.interface.iter().find(|(_, rep_idx)| *rep_idx == to_i)
{
if *pat_idx < pattern_match.len() {
old_to_new_id.get(&pattern_match[*pat_idx]).copied()
} else {
None
}
} else {
replacement_node_map.get(&to_i).copied()
};
if let (Some(from), Some(to)) = (from_id, to_id) {
let edge = crate::GraphEdge {
id: 0,
from_id: from,
to_id: to,
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
if new_graph.insert_edge(&edge).is_ok() {
operations.push(RewriteOperation::EdgeAdded { from, to });
}
}
}
}
}
}
}
Ok((new_graph, operations))
}
// Unit tests for the rewrite bounds/result helpers, the dangling-edge validator, and
// end-to-end rewrites on small in-memory graphs.
#[cfg(test)]
mod tests {
use super::*;
use crate::{GraphEdge, GraphEntity};
// Builds an in-memory graph with `count` "test" entities named `test_0..`, and no edges.
fn create_test_graph_with_nodes(count: usize) -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..count {
let entity = GraphEntity {
id: 0,
kind: "test".to_string(),
name: format!("test_{}", i),
file_path: Some(format!("test_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
graph
}
// Returns the first `count` entity IDs in `all_entity_ids()` order.
fn get_entity_ids(graph: &SqliteGraph, count: usize) -> Vec<i64> {
graph
.all_entity_ids()
.expect("Failed to get IDs")
.into_iter()
.take(count)
.collect()
}
// Adds an edge between the nodes at positions `from_idx`/`to_idx` of `all_entity_ids()`
// (positions, not IDs). Insert failures are ignored.
fn add_edge(graph: &SqliteGraph, from_idx: i64, to_idx: i64) {
let ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
let edge = GraphEdge {
id: 0,
from_id: ids[from_idx as usize],
to_id: ids[to_idx as usize],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).ok();
}
#[test]
fn test_rewrite_bounds_default() {
let bounds = RewriteBounds::default();
assert_eq!(bounds.max_matches, Some(10));
assert!(bounds.validate_after_rewrite);
}
#[test]
fn test_rewrite_bounds_builder() {
let bounds = RewriteBounds::default()
.with_max_matches(100)
.with_validation(false);
assert_eq!(bounds.max_matches, Some(100));
assert!(!bounds.validate_after_rewrite);
}
#[test]
fn test_rewrite_bounds_unlimited() {
let bounds = RewriteBounds::default().unlimited();
assert_eq!(bounds.max_matches, None);
assert!(bounds.validate_after_rewrite);
}
#[test]
fn test_rewrite_operation_display() {
assert_eq!(
format!("{}", RewriteOperation::NodeDeleted(5)),
"Deleted node 5"
);
assert_eq!(
format!("{}", RewriteOperation::NodeAdded(10)),
"Added node 10"
);
assert_eq!(
format!("{}", RewriteOperation::EdgeDeleted { from: 1, to: 2 }),
"Deleted edge 1 -> 2"
);
assert_eq!(
format!("{}", RewriteOperation::EdgeAdded { from: 3, to: 4 }),
"Added edge 3 -> 4"
);
}
// The counting helpers on a hand-built result (no rewrite is run).
#[test]
fn test_rewrite_result_helpers() {
let result = RewriteResult {
rewritten_graph: SqliteGraph::open_in_memory().unwrap(),
patterns_replaced: 2,
operations_applied: vec![
RewriteOperation::NodeDeleted(1),
RewriteOperation::NodeDeleted(2),
RewriteOperation::NodeAdded(10),
RewriteOperation::EdgeDeleted { from: 1, to: 2 },
RewriteOperation::EdgeAdded { from: 3, to: 10 },
],
validation_errors: vec![],
};
assert!(result.is_valid());
assert_eq!(result.patterns_replaced, 2);
assert_eq!(result.operation_count(), 5);
assert_eq!(result.nodes_added(), 1);
assert_eq!(result.nodes_deleted(), 2);
assert_eq!(result.edges_added(), 1);
assert_eq!(result.edges_deleted(), 1);
}
#[test]
fn test_rewrite_result_with_errors() {
let result = RewriteResult {
rewritten_graph: SqliteGraph::open_in_memory().unwrap(),
patterns_replaced: 0,
operations_applied: vec![],
validation_errors: vec![
"Dangling edge: 1 -> 999".to_string(),
"Duplicate entity detected".to_string(),
],
};
assert!(!result.is_valid());
assert_eq!(result.validation_errors.len(), 2);
}
// A chain 0 -> 1 -> 2 with all endpoints present must validate cleanly.
#[test]
fn test_validate_no_dangling_edges_valid() {
let graph = create_test_graph_with_nodes(3);
let ids = get_entity_ids(&graph, 3);
for (from, to) in &[(0, 1), (1, 2)] {
let edge = GraphEdge {
id: 0,
from_id: ids[*from],
to_id: ids[*to],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).ok();
}
let errors = validate_no_dangling_edges(&graph);
assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors);
}
// NOTE(review): this test has no assertion; it only checks that the validator does not panic.
// Whether `insert_edge` accepts an edge to a nonexistent node (ID 99999) is not established
// here, so the expected error count is left unchecked.
#[test]
fn test_validate_dangling_edges_detected() {
let graph = create_test_graph_with_nodes(3);
let ids = get_entity_ids(&graph, 3);
let edge = GraphEdge {
id: 0,
from_id: ids[0],
to_id: 99999, edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).ok();
let _errors = validate_no_dangling_edges(&graph);
}
#[test]
fn test_rewrite_rule_interface_size() {
let pattern = create_test_graph_with_nodes(3);
let replacement = create_test_graph_with_nodes(2);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![(0, 0), (2, 1)],
};
assert_eq!(rule.interface_size(), 2);
}
// Pattern a -> b, replacement keeps only a (interface (0, 0)); one rewrite on a 4-node chain.
#[test]
fn test_rewrite_simple_chain_rewrite() {
let graph = create_test_graph_with_nodes(4);
add_edge(&graph, 0, 1);
add_edge(&graph, 1, 2);
add_edge(&graph, 2, 3);
let pattern = create_test_graph_with_nodes(2);
let pattern_ids = get_entity_ids(&pattern, 2);
let pattern_edge = GraphEdge {
id: 0,
from_id: pattern_ids[0],
to_id: pattern_ids[1],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
pattern.insert_edge(&pattern_edge).ok();
let replacement = create_test_graph_with_nodes(1);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![(0, 0)],
};
let bounds = RewriteBounds {
max_matches: Some(1),
validate_after_rewrite: true,
};
let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();
assert_eq!(result.patterns_replaced, 1);
assert!(result.is_valid());
}
// On 0 -> 1 -> 2 the pattern a -> b matches twice, and the matches overlap on node 1.
// `patterns_replaced` counts every processed match, so it is 2 even though the first rewrite
// removes node 1.
#[test]
fn test_rewrite_with_interface() {
let graph = create_test_graph_with_nodes(3);
add_edge(&graph, 0, 1);
add_edge(&graph, 1, 2);
let pattern = create_test_graph_with_nodes(2);
let pattern_ids = get_entity_ids(&pattern, 2);
let pattern_edge = GraphEdge {
id: 0,
from_id: pattern_ids[0],
to_id: pattern_ids[1],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
pattern.insert_edge(&pattern_edge).ok();
let replacement = create_test_graph_with_nodes(1);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![(0, 0)],
};
let bounds = RewriteBounds::default();
let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();
assert_eq!(result.patterns_replaced, 2, "Should find 2 pattern matches");
assert!(result.is_valid());
}
// A 5-node chain has 4 matches; `max_matches: Some(2)` must cap the rewrites.
#[test]
fn test_rewrite_max_matches() {
let graph = create_test_graph_with_nodes(5);
for i in 0..4 {
add_edge(&graph, i, i + 1);
}
let pattern = create_test_graph_with_nodes(2);
let pattern_ids = get_entity_ids(&pattern, 2);
let pattern_edge = GraphEdge {
id: 0,
from_id: pattern_ids[0],
to_id: pattern_ids[1],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
pattern.insert_edge(&pattern_edge).ok();
let replacement = create_test_graph_with_nodes(1);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![(0, 0)],
};
let bounds = RewriteBounds {
max_matches: Some(2),
validate_after_rewrite: true,
};
let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();
assert!(result.patterns_replaced <= 2);
assert!(result.is_valid());
}
// Single-node pattern with no interface: every host node matches and is replaced.
#[test]
fn test_rewrite_empty_pattern() {
let graph = create_test_graph_with_nodes(3);
add_edge(&graph, 0, 1);
let pattern = create_test_graph_with_nodes(1);
let replacement = create_test_graph_with_nodes(1);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![],
};
let bounds = RewriteBounds::default();
let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();
assert_eq!(
result.patterns_replaced, 3,
"Single node pattern should match all 3 nodes"
);
assert!(result.is_valid());
}
// Two disjoint edges 0 -> 1 and 2 -> 3 give two independent matches.
#[test]
fn test_rewrite_multiple_occurrences() {
let graph = create_test_graph_with_nodes(4);
add_edge(&graph, 0, 1);
add_edge(&graph, 2, 3);
let pattern = create_test_graph_with_nodes(2);
let pattern_ids = get_entity_ids(&pattern, 2);
let pattern_edge = GraphEdge {
id: 0,
from_id: pattern_ids[0],
to_id: pattern_ids[1],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
pattern.insert_edge(&pattern_edge).ok();
let replacement = create_test_graph_with_nodes(1);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![(0, 0)],
};
let bounds = RewriteBounds {
max_matches: Some(10),
validate_after_rewrite: true,
};
let result = rewrite_graph_patterns(&graph, &rule, bounds).unwrap();
assert_eq!(result.patterns_replaced, 2);
assert!(result.is_valid());
}
// NOTE(review): despite its name, this test only builds a CSE-shaped graph (two Add ops using
// x and y) and checks its node count. No rewrite rule is constructed or applied.
#[test]
fn test_rewrite_common_subexpression_elimination() {
let graph = SqliteGraph::open_in_memory().unwrap();
let add1 = graph
.insert_entity(&GraphEntity {
id: 0,
kind: "Op".to_string(),
name: "Add1".to_string(),
file_path: None,
data: serde_json::json!({"op": "add"}),
})
.unwrap();
let add2 = graph
.insert_entity(&GraphEntity {
id: 0,
kind: "Op".to_string(),
name: "Add2".to_string(),
file_path: None,
data: serde_json::json!({"op": "add"}),
})
.unwrap();
let x = graph
.insert_entity(&GraphEntity {
id: 0,
kind: "Var".to_string(),
name: "x".to_string(),
file_path: None,
data: serde_json::json!({}),
})
.unwrap();
let y = graph
.insert_entity(&GraphEntity {
id: 0,
kind: "Var".to_string(),
name: "y".to_string(),
file_path: None,
data: serde_json::json!({}),
})
.unwrap();
let _ = graph.insert_edge(&GraphEdge {
id: 0,
from_id: add1,
to_id: x,
edge_type: "uses".to_string(),
data: serde_json::json!({}),
});
let _ = graph.insert_edge(&GraphEdge {
id: 0,
from_id: add1,
to_id: y,
edge_type: "uses".to_string(),
data: serde_json::json!({}),
});
let _ = graph.insert_edge(&GraphEdge {
id: 0,
from_id: add2,
to_id: x,
edge_type: "uses".to_string(),
data: serde_json::json!({}),
});
let _ = graph.insert_edge(&GraphEdge {
id: 0,
from_id: add2,
to_id: y,
edge_type: "uses".to_string(),
data: serde_json::json!({}),
});
let original_node_count = graph.all_entity_ids().unwrap().len();
assert_eq!(original_node_count, 4);
}
// Same scenario as `test_rewrite_with_interface`, driven through the progress-reporting entry
// point with a no-op callback.
#[test]
fn test_rewrite_progress_callback() {
use crate::progress::NoProgress;
let graph = create_test_graph_with_nodes(3);
add_edge(&graph, 0, 1);
add_edge(&graph, 1, 2);
let pattern = create_test_graph_with_nodes(2);
let pattern_ids = get_entity_ids(&pattern, 2);
let pattern_edge = GraphEdge {
id: 0,
from_id: pattern_ids[0],
to_id: pattern_ids[1],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
pattern.insert_edge(&pattern_edge).ok();
let replacement = create_test_graph_with_nodes(1);
let rule = RewriteRule {
pattern,
replacement,
interface: vec![(0, 0)],
};
let progress = NoProgress;
let bounds = RewriteBounds::default();
let result =
rewrite_graph_patterns_with_progress(&graph, &rule, bounds, &progress).unwrap();
assert_eq!(result.patterns_replaced, 2, "Should find 2 pattern matches");
assert!(result.is_valid());
}
}