use std::collections::HashMap;
use std::time::Instant;
use petgraph::Directed;
use petgraph::algo::isomorphism;
use petgraph::graph::{DefaultIx, Graph, NodeIndex};
use crate::{errors::SqliteGraphError, graph::SqliteGraph, progress::ProgressCallback};
/// Directed petgraph graph used as the in-memory form for pattern matching.
///
/// Node weights are `i64` values (this file stores entity IDs in them) and edges carry
/// no weight.
type PgDiGraph = Graph<i64, (), Directed, DefaultIx>;
/// Optional limits applied to a subgraph pattern search.
///
/// Every field defaults to `None`, which leaves that particular limit switched off.
#[derive(Debug, Clone, Default)]
pub struct SubgraphPatternBounds {
    /// Upper limit on the number of matches collected before the search stops.
    pub max_matches: Option<usize>,
    /// Time budget for the search, in milliseconds.
    pub timeout_ms: Option<u64>,
    /// Largest pattern size (in nodes) the search accepts.
    pub max_pattern_nodes: Option<usize>,
}

impl SubgraphPatternBounds {
    /// Creates a value with no limit set; identical to `Default::default()`.
    #[inline]
    pub fn new() -> Self {
        Self {
            max_matches: None,
            timeout_ms: None,
            max_pattern_nodes: None,
        }
    }

    /// Returns `self` with the match limit replaced by `max`.
    #[inline]
    pub fn with_max_matches(self, max: usize) -> Self {
        Self {
            max_matches: Some(max),
            ..self
        }
    }

    /// Returns `self` with the time budget replaced by `timeout_ms` milliseconds.
    #[inline]
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self {
            timeout_ms: Some(timeout_ms),
            ..self
        }
    }

    /// Returns `self` with the pattern size limit replaced by `max` nodes.
    #[inline]
    pub fn with_max_pattern_nodes(self, max: usize) -> Self {
        Self {
            max_pattern_nodes: Some(max),
            ..self
        }
    }

    /// Reports whether at least one of the three limits is set.
    #[inline]
    pub fn is_bounded(&self) -> bool {
        !matches!(
            (self.max_matches, self.timeout_ms, self.max_pattern_nodes),
            (None, None, None)
        )
    }
}
/// Outcome of a subgraph pattern search.
#[derive(Debug, Clone)]
pub struct SubgraphMatchResult {
    /// One entry per match; each entry lists target entity IDs.
    pub matches: Vec<Vec<i64>>,
    /// Number of matches; the search functions in this file set it to `matches.len()`.
    pub patterns_found: usize,
    /// Elapsed time reported by the search, in milliseconds.
    pub computation_time_ms: u128,
    /// `true` when the search stopped early because a bound was reached.
    pub bounded_hit: bool,
}

impl SubgraphMatchResult {
    /// Reports whether no match was recorded.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.count() == 0
    }

    /// Number of recorded matches (the length of `matches`).
    #[inline]
    pub fn count(&self) -> usize {
        self.matches.len()
    }

    /// Borrows the first recorded match, or `None` when there are none.
    #[inline]
    pub fn first_match(&self) -> Option<&[i64]> {
        match self.matches.as_slice() {
            [head, ..] => Some(head.as_slice()),
            [] => None,
        }
    }
}
/// Converts a `SqliteGraph` into an in-memory directed petgraph graph.
///
/// Every ID returned by `all_entity_ids` becomes one node whose weight is that ID; nodes
/// are added in the order the IDs are returned. Every neighbour reported by
/// `fetch_outgoing` becomes an unweighted edge. Neighbours that are not among the listed
/// entity IDs are skipped.
///
/// # Errors
///
/// Returns the underlying error if listing the entity IDs or reading the outgoing
/// neighbours of any entity fails.
fn graph_to_petgraph(graph: &SqliteGraph) -> Result<PgDiGraph, SqliteGraphError> {
    let entity_ids = graph.all_entity_ids()?;

    // The node count is known up front; the edge count is not, so only nodes are pre-sized.
    let mut pg = PgDiGraph::with_capacity(entity_ids.len(), 0);
    let mut id_to_index: HashMap<i64, NodeIndex> = HashMap::with_capacity(entity_ids.len());
    for &id in &entity_ids {
        id_to_index.insert(id, pg.add_node(id));
    }

    for &from_id in &entity_ids {
        // A failed adjacency read must not be ignored: silently dropping edges would make
        // the pattern search run on an incomplete graph and report wrong matches.
        let outgoing = graph.fetch_outgoing(from_id)?;

        // The source index does not change per neighbour, so resolve it once.
        if let Some(&from_idx) = id_to_index.get(&from_id) {
            for &to_id in &outgoing {
                if let Some(&to_idx) = id_to_index.get(&to_id) {
                    pg.add_edge(from_idx, to_idx, ());
                }
            }
        }
    }

    Ok(pg)
}
/// Finds occurrences of `pattern` inside `graph` using petgraph's subgraph isomorphism
/// iterator.
///
/// Node and edge weights are not compared; only structure is matched.
///
/// # Arguments
///
/// * `graph` - target graph that is searched.
/// * `pattern` - graph whose shape is looked for in `graph`.
/// * `bounds` - optional limits on pattern size, number of matches and elapsed time.
///
/// # Returns
///
/// A `SubgraphMatchResult` in which every match lists target entity IDs, one per entry
/// of the mapping yielded by petgraph (per petgraph's documentation, position `i`
/// corresponds to pattern node `i`). `computation_time_ms` is measured from function
/// entry, so it includes the graph conversion.
///
/// `max_matches` and `timeout_ms` are checked only when the iterator yields a candidate
/// match, so the timeout cannot interrupt a long stretch of searching that produces no
/// match. `bounded_hit` is set only when a candidate was discarded because of a bound.
/// A search that ends with exactly `max_matches` results therefore reports `false`.
///
/// # Errors
///
/// Returns an invalid-input error when the pattern has more nodes than
/// `bounds.max_pattern_nodes`, and propagates any error from reading either graph.
pub fn find_subgraph_patterns(
    graph: &SqliteGraph,
    pattern: &SqliteGraph,
    bounds: SubgraphPatternBounds,
) -> Result<SubgraphMatchResult, SqliteGraphError> {
    let start_time = Instant::now();

    // Reject oversized patterns before paying for the graph conversion.
    let pattern_count = pattern.all_entity_ids()?.len();
    if let Some(max_nodes) = bounds.max_pattern_nodes {
        if pattern_count > max_nodes {
            return Err(SqliteGraphError::invalid_input(format!(
                "Pattern too large: {} nodes exceeds max_pattern_nodes bound of {}",
                pattern_count, max_nodes
            )));
        }
    }

    let target_pg: PgDiGraph = graph_to_petgraph(graph)?;
    let pattern_pg: PgDiGraph = graph_to_petgraph(pattern)?;

    // Lookup table from petgraph node position back to the entity ID stored as its weight.
    let target_node_ids: Vec<i64> = target_pg.node_indices().map(|ni| target_pg[ni]).collect();
    let timeout = bounds.timeout_ms.map(std::time::Duration::from_millis);

    let mut matches: Vec<Vec<i64>> = Vec::new();
    let mut bounded_hit = false;

    let pattern_ref = &pattern_pg;
    let target_ref = &target_pg;
    // Accept every node/edge pairing: matching is purely structural.
    let mut node_match = |_: &i64, _: &i64| -> bool { true };
    let mut edge_match = |_: &(), _: &()| -> bool { true };
    let iso_iter = isomorphism::subgraph_isomorphisms_iter(
        &pattern_ref,
        &target_ref,
        &mut node_match,
        &mut edge_match,
    );

    // `None` means petgraph declined to search at all; that is reported as zero matches.
    if let Some(iso_iter) = iso_iter {
        for mapping in iso_iter {
            let out_of_time = timeout.map_or(false, |limit| start_time.elapsed() >= limit);
            if out_of_time || bounds.max_matches.map_or(false, |max| matches.len() >= max) {
                bounded_hit = true;
                break;
            }

            // Translate petgraph positions to entity IDs. The `0` fallback is purely
            // defensive; positions yielded for the target are expected to be in range.
            let match_mapping: Vec<i64> = mapping
                .iter()
                .map(|&target_idx| target_node_ids.get(target_idx).copied().unwrap_or(0))
                .collect();
            matches.push(match_mapping);
        }
    }

    Ok(SubgraphMatchResult {
        patterns_found: matches.len(),
        matches,
        computation_time_ms: start_time.elapsed().as_millis(),
        bounded_hit,
    })
}
/// Same search as `find_subgraph_patterns`, reporting progress through `progress`.
///
/// Progress is reported in three steps out of three: before graph conversion, before
/// the search, and at the end. While searching, an update is sent after every tenth
/// recorded match. `on_complete` is called once before a successful return; on an error
/// return it is not called.
///
/// # Arguments
///
/// * `graph` - target graph that is searched.
/// * `pattern` - graph whose shape is looked for in `graph`.
/// * `bounds` - optional limits on pattern size, number of matches and elapsed time.
/// * `progress` - callback receiving progress updates.
///
/// # Returns
///
/// A `SubgraphMatchResult` whose matches list target entity IDs. The clock starts on
/// function entry, so `timeout_ms` and `computation_time_ms` cover the graph conversion
/// as well, the same way `find_subgraph_patterns` measures them.
///
/// `max_matches` and `timeout_ms` are checked only when the iterator yields a candidate
/// match. `bounded_hit` is set only when a candidate was discarded because of a bound.
///
/// # Errors
///
/// Returns an invalid-input error when the pattern has more nodes than
/// `bounds.max_pattern_nodes`, and propagates any error from reading either graph.
pub fn find_subgraph_patterns_with_progress<F>(
    graph: &SqliteGraph,
    pattern: &SqliteGraph,
    bounds: SubgraphPatternBounds,
    progress: &F,
) -> Result<SubgraphMatchResult, SqliteGraphError>
where
    F: ProgressCallback,
{
    // Start timing here rather than after conversion so that the time budget and the
    // reported duration mean the same thing as in `find_subgraph_patterns`.
    let start_time = Instant::now();

    progress.on_progress(0, Some(3), "Converting graphs to petgraph format");

    // Reject oversized patterns before paying for the graph conversion.
    let pattern_count = pattern.all_entity_ids()?.len();
    if let Some(max_nodes) = bounds.max_pattern_nodes {
        if pattern_count > max_nodes {
            return Err(SqliteGraphError::invalid_input(format!(
                "Pattern too large: {} nodes exceeds max_pattern_nodes bound of {}",
                pattern_count, max_nodes
            )));
        }
    }

    let target_pg: PgDiGraph = graph_to_petgraph(graph)?;
    let pattern_pg: PgDiGraph = graph_to_petgraph(pattern)?;

    // Lookup table from petgraph node position back to the entity ID stored as its weight.
    let target_node_ids: Vec<i64> = target_pg.node_indices().map(|ni| target_pg[ni]).collect();

    progress.on_progress(
        1,
        Some(3),
        &format!("Searching for patterns ({} pattern nodes)", pattern_count),
    );

    let timeout = bounds.timeout_ms.map(std::time::Duration::from_millis);
    let mut matches: Vec<Vec<i64>> = Vec::new();
    let mut bounded_hit = false;

    let pattern_ref = &pattern_pg;
    let target_ref = &target_pg;
    // Accept every node/edge pairing: matching is purely structural.
    let mut node_match = |_: &i64, _: &i64| -> bool { true };
    let mut edge_match = |_: &(), _: &()| -> bool { true };
    let iso_iter = isomorphism::subgraph_isomorphisms_iter(
        &pattern_ref,
        &target_ref,
        &mut node_match,
        &mut edge_match,
    );

    // `None` means petgraph declined to search at all; that is reported as zero matches.
    if let Some(iso_iter) = iso_iter {
        for mapping in iso_iter {
            let out_of_time = timeout.map_or(false, |limit| start_time.elapsed() >= limit);
            if out_of_time || bounds.max_matches.map_or(false, |max| matches.len() >= max) {
                bounded_hit = true;
                break;
            }

            // Translate petgraph positions to entity IDs. The `0` fallback is purely
            // defensive; positions yielded for the target are expected to be in range.
            let match_mapping: Vec<i64> = mapping
                .iter()
                .map(|&target_idx| target_node_ids.get(target_idx).copied().unwrap_or(0))
                .collect();
            matches.push(match_mapping);

            // Throttle updates so the callback is not invoked for every single match.
            if matches.len() % 10 == 0 {
                progress.on_progress(
                    2,
                    Some(3),
                    &format!("Found {} matches so far", matches.len()),
                );
            }
        }
    }

    let final_msg = if bounded_hit {
        format!(
            "Search complete: {} matches found (stopped by bounds)",
            matches.len()
        )
    } else {
        format!("Search complete: {} matches found", matches.len())
    };
    progress.on_progress(3, Some(3), &final_msg);
    progress.on_complete();

    Ok(SubgraphMatchResult {
        patterns_found: matches.len(),
        matches,
        computation_time_ms: start_time.elapsed().as_millis(),
        bounded_hit,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphEdge, GraphEntity};

    /// Builds an in-memory target graph holding `count` unconnected entities.
    fn create_test_graph_with_nodes(count: usize) -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
        for i in 0..count {
            let entity = GraphEntity {
                id: 0,
                kind: "test".to_string(),
                name: format!("test_{}", i),
                file_path: Some(format!("test_{}.rs", i)),
                data: serde_json::json!({"index": i}),
            };
            graph
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        graph
    }

    /// Builds an in-memory pattern graph holding `count` unconnected entities `p0..`.
    fn create_pattern_with_nodes(count: usize) -> SqliteGraph {
        let pattern = SqliteGraph::open_in_memory().expect("Failed to create pattern");
        for i in 0..count {
            let entity = GraphEntity {
                id: 0,
                kind: "pattern".to_string(),
                name: format!("p{}", i),
                file_path: None,
                data: serde_json::json!({}),
            };
            pattern
                .insert_entity(&entity)
                .expect("Failed to insert entity");
        }
        pattern
    }

    /// Returns the first `count` entity IDs of `graph`.
    fn get_entity_ids(graph: &SqliteGraph, count: usize) -> Vec<i64> {
        graph
            .all_entity_ids()
            .expect("Failed to get IDs")
            .into_iter()
            .take(count)
            .collect()
    }

    /// Inserts a directed edge between two entity IDs.
    ///
    /// Panics if the insert fails so that a broken fixture is reported at its source
    /// instead of surfacing later as a confusing match-count mismatch.
    fn insert_edge_between(graph: &SqliteGraph, from_id: i64, to_id: i64) {
        let edge = GraphEdge {
            id: 0,
            from_id,
            to_id,
            edge_type: "edge".to_string(),
            data: serde_json::json!({}),
        };
        graph.insert_edge(&edge).expect("Failed to insert edge");
    }

    /// Inserts a directed edge between the entities at the given positions of the
    /// graph's entity ID list.
    fn add_edge(graph: &SqliteGraph, from_idx: usize, to_idx: usize) {
        let ids: Vec<i64> = graph.all_entity_ids().expect("Failed to get IDs");
        insert_edge_between(graph, ids[from_idx], ids[to_idx]);
    }

    /// Two-node pattern `p0 -> p1`.
    fn create_chain_pattern() -> SqliteGraph {
        let pattern = create_pattern_with_nodes(2);
        add_edge(&pattern, 0, 1);
        pattern
    }

    /// Three-node directed cycle `p0 -> p1 -> p2 -> p0`.
    fn create_triangle_pattern() -> SqliteGraph {
        let pattern = create_pattern_with_nodes(3);
        for (from, to) in &[(0, 1), (1, 2), (2, 0)] {
            add_edge(&pattern, *from, *to);
        }
        pattern
    }

    #[test]
    fn test_find_subgraph_patterns_simple_chain() {
        let graph = create_test_graph_with_nodes(4);
        for (from, to) in &[(0, 1), (1, 2), (2, 3)] {
            add_edge(&graph, *from, *to);
        }
        let pattern = create_chain_pattern();
        let bounds = SubgraphPatternBounds::default();
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert_eq!(result.patterns_found, 3);
        assert_eq!(result.count(), 3);
        assert!(!result.is_empty());
        assert!(!result.bounded_hit);
    }

    #[test]
    fn test_find_subgraph_patterns_triangle() {
        let graph = create_test_graph_with_nodes(6);
        let ids = get_entity_ids(&graph, 6);
        let triangle1 = [(0, 1), (1, 2), (2, 0)];
        let triangle2 = [(3, 4), (4, 5), (5, 3)];
        for (from, to) in triangle1.iter().chain(triangle2.iter()) {
            insert_edge_between(&graph, ids[*from], ids[*to]);
        }
        let pattern = create_triangle_pattern();
        let bounds = SubgraphPatternBounds::default();
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert!(
            result.patterns_found >= 2,
            "Should find at least 2 triangle matches, found {}",
            result.patterns_found
        );
        assert!(!result.is_empty());
        assert!(!result.bounded_hit);
    }

    #[test]
    fn test_find_subgraph_patterns_max_matches() {
        let graph = create_test_graph_with_nodes(10);
        for i in 0..9 {
            add_edge(&graph, i, i + 1);
        }
        let pattern = create_chain_pattern();
        let bounds = SubgraphPatternBounds {
            max_matches: Some(2),
            ..Default::default()
        };
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert!(result.patterns_found <= 2);
        assert!(result.bounded_hit);
    }

    #[test]
    fn test_find_subgraph_patterns_timeout() {
        let graph = create_test_graph_with_nodes(10);
        for i in 0..9 {
            add_edge(&graph, i, i + 1);
        }
        let pattern = create_chain_pattern();
        let bounds = SubgraphPatternBounds {
            timeout_ms: Some(1),
            ..Default::default()
        };
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert!(result.computation_time_ms < 100);
    }

    #[test]
    fn test_find_subgraph_patterns_empty_result() {
        let graph = create_test_graph_with_nodes(3);
        for (from, to) in &[(0, 1), (1, 2)] {
            add_edge(&graph, *from, *to);
        }
        let pattern = create_triangle_pattern();
        let bounds = SubgraphPatternBounds::default();
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert_eq!(result.patterns_found, 0);
        assert!(result.is_empty());
        assert!(result.first_match().is_none());
    }

    #[test]
    fn test_find_subgraph_patterns_progress() {
        use crate::progress::NoProgress;
        let graph = create_test_graph_with_nodes(5);
        for i in 0..4 {
            add_edge(&graph, i, i + 1);
        }
        let pattern = create_chain_pattern();
        let bounds = SubgraphPatternBounds::default();
        let progress = NoProgress;
        let result =
            find_subgraph_patterns_with_progress(&graph, &pattern, bounds, &progress).unwrap();
        assert_eq!(result.patterns_found, 4);
    }

    #[test]
    fn test_subgraph_pattern_bounds_builder() {
        let bounds = SubgraphPatternBounds::new()
            .with_max_matches(100)
            .with_timeout(5000)
            .with_max_pattern_nodes(10);
        assert_eq!(bounds.max_matches, Some(100));
        assert_eq!(bounds.timeout_ms, Some(5000));
        assert_eq!(bounds.max_pattern_nodes, Some(10));
        assert!(bounds.is_bounded());
    }

    #[test]
    fn test_subgraph_match_result_helpers() {
        let result = SubgraphMatchResult {
            matches: vec![vec![1, 2], vec![2, 3]],
            patterns_found: 2,
            computation_time_ms: 100,
            bounded_hit: false,
        };
        assert!(!result.is_empty());
        assert_eq!(result.count(), 2);
        assert_eq!(result.first_match(), Some(&[1, 2][..]));
    }

    #[test]
    fn test_subgraph_match_result_empty_helpers() {
        let result = SubgraphMatchResult {
            matches: vec![],
            patterns_found: 0,
            computation_time_ms: 50,
            bounded_hit: false,
        };
        assert!(result.is_empty());
        assert_eq!(result.count(), 0);
        assert!(result.first_match().is_none());
    }

    #[test]
    fn test_max_pattern_nodes_rejection() {
        let graph = create_test_graph_with_nodes(5);
        let pattern = create_test_graph_with_nodes(15);
        let bounds = SubgraphPatternBounds {
            max_pattern_nodes: Some(10),
            ..Default::default()
        };
        let result = find_subgraph_patterns(&graph, &pattern, bounds);
        assert!(result.is_err());
    }

    #[test]
    fn test_single_node_pattern() {
        let graph = create_test_graph_with_nodes(3);
        let pattern = create_pattern_with_nodes(1);
        let bounds = SubgraphPatternBounds::default();
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert_eq!(result.patterns_found, 3);
    }

    #[test]
    fn test_pattern_larger_than_target() {
        let graph = create_test_graph_with_nodes(2);
        let pattern = create_triangle_pattern();
        let bounds = SubgraphPatternBounds::default();
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert_eq!(result.patterns_found, 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_default_bounds_unbounded() {
        let bounds = SubgraphPatternBounds::default();
        assert_eq!(bounds.max_matches, None);
        assert_eq!(bounds.timeout_ms, None);
        assert_eq!(bounds.max_pattern_nodes, None);
        assert!(!bounds.is_bounded());
    }

    #[test]
    fn test_computation_time_tracking() {
        let graph = create_test_graph_with_nodes(3);
        for (from, to) in &[(0, 1), (1, 2)] {
            add_edge(&graph, *from, *to);
        }
        let pattern = create_chain_pattern();
        let bounds = SubgraphPatternBounds::default();
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert!(result.computation_time_ms < 1000);
    }

    #[test]
    fn test_bounded_hit_flag() {
        let graph = create_test_graph_with_nodes(5);
        for i in 0..4 {
            add_edge(&graph, i, i + 1);
        }
        let pattern = create_chain_pattern();
        let bounds = SubgraphPatternBounds {
            max_matches: Some(2),
            ..Default::default()
        };
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert_eq!(result.patterns_found, 2);
        assert!(result.bounded_hit);
    }

    #[test]
    fn test_bounded_hit_flag_timeout() {
        let graph = create_test_graph_with_nodes(20);
        for i in 0..19 {
            add_edge(&graph, i, i + 1);
        }
        let pattern = create_chain_pattern();

        // Whether a 1 ms budget is exhausted depends on machine speed, so only the
        // reported duration is checked for this case.
        let bounds = SubgraphPatternBounds {
            timeout_ms: Some(1),
            ..Default::default()
        };
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert!(result.computation_time_ms < 100);

        // A zero budget is already spent when the first candidate match arrives, so the
        // flag must be set and nothing recorded, independent of timing.
        let bounds = SubgraphPatternBounds {
            timeout_ms: Some(0),
            ..Default::default()
        };
        let result = find_subgraph_patterns(&graph, &pattern, bounds).unwrap();
        assert!(result.bounded_hit);
        assert!(result.is_empty());
    }
}