use std::collections::{HashMap, HashSet};
use super::graph_store::{GraphStore, StoredNode};
/// Selects which nodes of a graph take part in a projection.
///
/// Each criterion is optional: `None` places no restriction on that dimension.
/// When both are set, a node has to satisfy both (see `NodeFilter::matches`).
#[derive(Clone, Default)]
pub struct NodeFilter {
/// Accepted node-type names; a node passes when `node_type.as_str()` equals any entry.
pub labels: Option<Vec<String>>,
/// Accepted node ids; a node passes when its `id` is contained in the set.
pub ids: Option<HashSet<String>>,
}
impl NodeFilter {
pub fn all() -> Self {
Self::default()
}
pub fn with_labels<I, S>(mut self, labels: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.labels = Some(labels.into_iter().map(Into::into).collect());
self
}
pub fn with_ids(mut self, ids: HashSet<String>) -> Self {
self.ids = Some(ids);
self
}
pub fn matches(&self, node: &StoredNode) -> bool {
if let Some(ref labels) = self.labels {
if !labels.iter().any(|l| l == node.node_type.as_str()) {
return false;
}
}
if let Some(ref ids) = self.ids {
if !ids.contains(&node.id) {
return false;
}
}
true
}
}
/// Selects which edges of a graph take part in a projection.
///
/// Each criterion is optional: `None` places no restriction on that dimension.
/// All configured criteria have to hold for an edge to match.
#[derive(Clone, Debug, Default)]
pub struct EdgeFilter {
    /// Accepted edge-type names; an edge passes when its label equals any entry.
    pub edge_types: Option<Vec<String>>,
    /// Inclusive lower bound on the edge weight.
    pub min_weight: Option<f32>,
    /// Inclusive upper bound on the edge weight.
    pub max_weight: Option<f32>,
}

impl EdgeFilter {
    /// Creates a filter without any restriction, so every edge matches.
    pub fn all() -> Self {
        Self::default()
    }

    /// Restricts the filter to the given edge types.
    ///
    /// Any type list set earlier is replaced, not extended.
    pub fn with_types<I, S>(mut self, types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.edge_types = Some(types.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the inclusive lower weight bound.
    pub fn with_min_weight(mut self, weight: f32) -> Self {
        self.min_weight = Some(weight);
        self
    }

    /// Sets the inclusive upper weight bound.
    pub fn with_max_weight(mut self, weight: f32) -> Self {
        self.max_weight = Some(weight);
        self
    }

    /// Returns `true` when an edge with the given label and weight satisfies
    /// every configured criterion.
    ///
    /// A NaN `weight` is rejected as soon as a weight bound is configured:
    /// NaN lies in no range, and plain `<` / `>` comparisons against NaN are
    /// always `false`, which would otherwise let such an edge slip through.
    /// Without weight bounds the weight is not inspected at all.
    pub fn matches(&self, edge_label: &str, weight: f32) -> bool {
        if let Some(types) = &self.edge_types {
            if !types.iter().any(|t| t == edge_label) {
                return false;
            }
        }

        let below_min = matches!(self.min_weight, Some(min) if weight.is_nan() || weight < min);
        let above_max = matches!(self.max_weight, Some(max) if weight.is_nan() || weight > max);
        !(below_min || above_max)
    }
}
/// Chooses which optional properties are carried over into a projection.
///
/// The derived `Default` leaves both flags `false`, i.e. it equals
/// [`PropertyProjection::minimal`].
#[derive(Clone, Default)]
pub struct PropertyProjection {
    /// When `true`, `GraphProjection::native` stores the node-type name in
    /// `ProjectedNode::category`; otherwise the category is `None`.
    pub include_label: bool,
    /// Flag for edge weights. NOTE(review): no code visible in this file reads
    /// it; confirm whether a consumer exists elsewhere.
    pub include_weight: bool,
}

impl PropertyProjection {
    /// Enables every optional property.
    pub fn all() -> Self {
        Self::uniform(true)
    }

    /// Disables every optional property.
    pub fn minimal() -> Self {
        Self::uniform(false)
    }

    /// Builds a projection whose flags all share the same value.
    fn uniform(enabled: bool) -> Self {
        Self {
            include_label: enabled,
            include_weight: enabled,
        }
    }
}
/// How edges sharing the same source and target are combined by
/// `GraphProjection::native`.
///
/// Every variant except `None` collapses such a group into a single edge that
/// keeps the type of the first edge in the group.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AggregationStrategy {
/// Keep every edge individually; nothing is merged.
None,
/// Merged weight is the sum of the group's weights.
SumWeight,
/// Merged weight is the arithmetic mean of the group's weights.
AvgWeight,
/// Merged weight is the smallest weight in the group.
MinWeight,
/// Merged weight is the largest weight in the group.
MaxWeight,
/// Merged weight is the number of edges in the group.
Count,
}
/// A filtered, self-contained copy of part of a graph, with adjacency lists in
/// both directions.
pub struct GraphProjection {
// Projected nodes keyed by node id.
nodes: HashMap<String, ProjectedNode>,
// Source id -> list of (target id, edge type, weight).
outgoing: HashMap<String, Vec<(String, String, f32)>>,
// Target id -> list of (source id, edge type, weight).
incoming: HashMap<String, Vec<(String, String, f32)>>,
// Counters collected while the projection was built.
stats: ProjectionStats,
}
/// A node as it is stored inside a `GraphProjection`.
#[derive(Clone, Debug)]
pub struct ProjectedNode {
/// Id copied from the source node.
pub id: String,
/// Label copied from the source node.
pub label: String,
/// Node-type name of the source node; only populated when the projection was
/// built with `PropertyProjection::include_label` set, otherwise `None`.
pub category: Option<String>,
}
/// Counters describing how a `GraphProjection` was built.
#[derive(Clone, Debug, Default)]
pub struct ProjectionStats {
/// Nodes that passed the node filter.
pub node_count: usize,
/// Edges stored in the projection, counted after aggregation; reverse edges
/// added by `GraphProjection::undirected` are included.
pub edge_count: usize,
/// Nodes rejected by the node filter.
pub nodes_filtered: usize,
/// Edges rejected by the edge filter. Edges whose target node is not part of
/// the projection are skipped without being counted here.
pub edges_filtered: usize,
/// Edges merged away by aggregation: a group of `n` parallel edges adds `n - 1`.
pub edges_aggregated: usize,
}
impl GraphProjection {
    /// Builds a directed projection of `graph`.
    ///
    /// * `node_filter` decides which nodes are kept.
    /// * `edge_filter` decides which edges between kept nodes are retained.
    ///   Edges pointing at a node outside the projection are dropped silently
    ///   and are not counted in `ProjectionStats::edges_filtered`.
    /// * `property_projection.include_label` controls whether
    ///   `ProjectedNode::category` is populated. `include_weight` is not
    ///   consulted here; weights are always stored.
    /// * `aggregation` controls how edges sharing source and target are merged;
    ///   a merged edge keeps the type of the first edge of its group.
    ///
    /// Adjacency lists are filled while iterating hash maps, so the order of
    /// the entries returned by `outgoing` / `incoming` is unspecified.
    pub fn native(
        graph: &GraphStore,
        node_filter: NodeFilter,
        edge_filter: EdgeFilter,
        property_projection: PropertyProjection,
        aggregation: AggregationStrategy,
    ) -> Self {
        let mut stats = ProjectionStats::default();

        // Pass 1: select the nodes.
        let mut nodes: HashMap<String, ProjectedNode> = HashMap::new();
        for node in graph.iter_nodes() {
            if !node_filter.matches(&node) {
                stats.nodes_filtered += 1;
                continue;
            }
            let category = if property_projection.include_label {
                Some(node.node_type.as_str().to_string())
            } else {
                None
            };
            nodes.insert(
                node.id.clone(),
                ProjectedNode {
                    id: node.id.clone(),
                    label: node.label.clone(),
                    category,
                },
            );
            stats.node_count += 1;
        }

        // Pass 2: group the surviving edges by (source, target) so that
        // parallel edges can be aggregated together. The node map itself is
        // the membership index; no separate id set is needed.
        let mut edge_groups: HashMap<(String, String), Vec<(String, f32)>> = HashMap::new();
        for node_id in nodes.keys() {
            for (edge_type, target, weight) in graph.outgoing_edges(node_id) {
                if !nodes.contains_key(&target) {
                    continue;
                }
                let edge_label = edge_type.as_str().to_string();
                if edge_filter.matches(&edge_label, weight) {
                    edge_groups
                        .entry((node_id.clone(), target))
                        .or_default()
                        .push((edge_label, weight));
                } else {
                    stats.edges_filtered += 1;
                }
            }
        }

        // Pass 3: materialise the adjacency lists.
        let mut projection = Self {
            nodes,
            outgoing: HashMap::new(),
            incoming: HashMap::new(),
            stats,
        };
        for ((source, target), edges) in edge_groups {
            if aggregation == AggregationStrategy::None {
                for (edge_type, weight) in edges {
                    projection.insert_edge(&source, &target, edge_type, weight);
                }
                continue;
            }
            if let (Some((first_type, _)), Some(weight)) = (
                edges.first(),
                Self::aggregate_weight(aggregation, &edges),
            ) {
                // `first()` returned `Some`, so the group has at least one
                // edge and the subtraction cannot underflow.
                projection.stats.edges_aggregated += edges.len() - 1;
                projection.insert_edge(&source, &target, first_type.clone(), weight);
            }
        }
        projection
    }

    /// Projects the subgraph induced by `node_ids`: those nodes plus every
    /// edge between them, with all properties and without aggregation.
    pub fn from_nodes(graph: &GraphStore, node_ids: &[String]) -> Self {
        Self::induced(graph, node_ids.iter().cloned().collect())
    }

    /// Projects the subgraph induced by every node that occurs in any of
    /// `paths`, with all properties and without aggregation.
    ///
    /// All edges between those nodes are included, not only the ones the
    /// paths actually traverse.
    pub fn from_paths(graph: &GraphStore, paths: &[Vec<String>]) -> Self {
        Self::induced(graph, paths.iter().flatten().cloned().collect())
    }

    /// Builds a projection in which every edge can be followed in both
    /// directions.
    ///
    /// The directed projection is built first with all properties and
    /// `AggregationStrategy::SumWeight`, so each ordered node pair holds at
    /// most one edge. For every edge whose reverse direction is missing, a
    /// mirrored edge with the same type and weight is then added.
    ///
    /// NOTE: when both directions already exist in the source graph they are
    /// left untouched, so their weights may differ.
    pub fn undirected(
        graph: &GraphStore,
        node_filter: NodeFilter,
        edge_filter: EdgeFilter,
    ) -> Self {
        let mut projection = Self::native(
            graph,
            node_filter,
            edge_filter,
            PropertyProjection::all(),
            AggregationStrategy::SumWeight,
        );

        let mirrored: Vec<(String, String, String, f32)> = {
            // Index the existing (source, target) pairs once, so that the
            // reverse-edge test is a hash lookup instead of a list scan.
            let mut present: HashSet<(&str, &str)> = HashSet::new();
            for (source, edges) in &projection.outgoing {
                for (target, _, _) in edges {
                    present.insert((source.as_str(), target.as_str()));
                }
            }

            let mut missing = Vec::new();
            for (source, edges) in &projection.outgoing {
                for (target, edge_type, weight) in edges {
                    if !present.contains(&(target.as_str(), source.as_str())) {
                        missing.push((
                            target.clone(),
                            source.clone(),
                            edge_type.clone(),
                            *weight,
                        ));
                    }
                }
            }
            missing
        };

        for (source, target, edge_type, weight) in mirrored {
            projection.insert_edge(&source, &target, edge_type, weight);
        }
        projection
    }

    /// Counters collected while the projection was built.
    pub fn stats(&self) -> &ProjectionStats {
        &self.stats
    }

    /// Number of nodes in the projection.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the projection (taken from the build statistics).
    pub fn edge_count(&self) -> usize {
        self.stats.edge_count
    }

    /// Looks up a projected node by id.
    pub fn get_node(&self, id: &str) -> Option<&ProjectedNode> {
        self.nodes.get(id)
    }

    /// Returns `true` when a node with this id is part of the projection.
    pub fn has_node(&self, id: &str) -> bool {
        self.nodes.contains_key(id)
    }

    /// Iterates over all projected nodes in unspecified order.
    pub fn iter_nodes(&self) -> impl Iterator<Item = &ProjectedNode> {
        self.nodes.values()
    }

    /// Iterates over the ids of all projected nodes in unspecified order.
    pub fn node_ids(&self) -> impl Iterator<Item = &String> {
        self.nodes.keys()
    }

    /// Edges leaving `node_id` as `(target, edge_type, weight)`; empty when
    /// the node is unknown or has no outgoing edges.
    pub fn outgoing(&self, node_id: &str) -> &[(String, String, f32)] {
        self.outgoing
            .get(node_id)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Edges entering `node_id` as `(source, edge_type, weight)`; empty when
    /// the node is unknown or has no incoming edges.
    pub fn incoming(&self, node_id: &str) -> &[(String, String, f32)] {
        self.incoming
            .get(node_id)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }

    /// Number of edges leaving `node_id`.
    pub fn out_degree(&self, node_id: &str) -> usize {
        self.outgoing(node_id).len()
    }

    /// Number of edges entering `node_id`.
    pub fn in_degree(&self, node_id: &str) -> usize {
        self.incoming(node_id).len()
    }

    /// Targets of the edges leaving `node_id`, one entry per edge; a target
    /// reached by several edges appears several times.
    pub fn neighbors(&self, node_id: &str) -> Vec<&str> {
        self.outgoing(node_id)
            .iter()
            .map(|(target, _, _)| target.as_str())
            .collect()
    }

    /// Like [`neighbors`](Self::neighbors), but paired with each edge's weight.
    pub fn neighbors_weighted(&self, node_id: &str) -> Vec<(&str, f32)> {
        self.outgoing(node_id)
            .iter()
            .map(|(target, _, weight)| (target.as_str(), *weight))
            .collect()
    }

    /// Distinct nodes adjacent to `node_id` through an edge in either direction.
    pub fn all_neighbors(&self, node_id: &str) -> HashSet<&str> {
        self.outgoing(node_id)
            .iter()
            .chain(self.incoming(node_id))
            .map(|(other, _, _)| other.as_str())
            .collect()
    }

    /// Projects the subgraph induced by `ids` with all properties and no aggregation.
    fn induced(graph: &GraphStore, ids: HashSet<String>) -> Self {
        Self::native(
            graph,
            NodeFilter::all().with_ids(ids),
            EdgeFilter::all(),
            PropertyProjection::all(),
            AggregationStrategy::None,
        )
    }

    /// Combines the weights of one group of parallel edges.
    ///
    /// Returns `None` for `AggregationStrategy::None` (nothing to merge) and
    /// for an empty group.
    fn aggregate_weight(strategy: AggregationStrategy, edges: &[(String, f32)]) -> Option<f32> {
        if edges.is_empty() {
            return None;
        }
        let weights = edges.iter().map(|(_, weight)| *weight);
        let combined: f32 = match strategy {
            AggregationStrategy::None => return None,
            AggregationStrategy::SumWeight => weights.sum(),
            AggregationStrategy::AvgWeight => weights.sum::<f32>() / edges.len() as f32,
            AggregationStrategy::MinWeight => weights.fold(f32::INFINITY, f32::min),
            AggregationStrategy::MaxWeight => weights.fold(f32::NEG_INFINITY, f32::max),
            AggregationStrategy::Count => edges.len() as f32,
        };
        Some(combined)
    }

    /// Records one edge in both adjacency maps and bumps the edge counter.
    fn insert_edge(&mut self, source: &str, target: &str, edge_type: String, weight: f32) {
        self.outgoing
            .entry(source.to_string())
            .or_default()
            .push((target.to_string(), edge_type.clone(), weight));
        self.incoming
            .entry(target.to_string())
            .or_default()
            .push((source.to_string(), edge_type, weight));
        self.stats.edge_count += 1;
    }
}
/// Fluent builder that collects projection settings and produces a
/// `GraphProjection` from a borrowed `GraphStore`.
pub struct ProjectionBuilder<'a> {
// Graph the projection will be built from.
graph: &'a GraphStore,
// Node selection; starts as `NodeFilter::all()`.
node_filter: NodeFilter,
// Edge selection; starts as `EdgeFilter::all()`.
edge_filter: EdgeFilter,
// Optional properties to carry over; starts as `PropertyProjection::all()`.
property_projection: PropertyProjection,
// Parallel-edge handling; starts as `AggregationStrategy::None`.
aggregation: AggregationStrategy,
// When set, `build` calls `GraphProjection::undirected` instead of `native`.
undirected: bool,
}
impl<'a> ProjectionBuilder<'a> {
    /// Starts a builder for `graph` with no filtering, all properties, no
    /// aggregation and directed edges.
    pub fn new(graph: &'a GraphStore) -> Self {
        Self {
            graph,
            node_filter: NodeFilter::all(),
            edge_filter: EdgeFilter::all(),
            property_projection: PropertyProjection::all(),
            aggregation: AggregationStrategy::None,
            undirected: false,
        }
    }

    /// Keeps only nodes whose type matches one of `labels`.
    pub fn with_node_labels<I, S>(self, labels: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            node_filter: self.node_filter.with_labels(labels),
            ..self
        }
    }

    /// Keeps only nodes whose id is contained in `ids`.
    pub fn with_node_ids(self, ids: HashSet<String>) -> Self {
        Self {
            node_filter: self.node_filter.with_ids(ids),
            ..self
        }
    }

    /// Keeps only edges whose type matches one of `types`.
    pub fn with_edge_types<I, S>(self, types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            edge_filter: self.edge_filter.with_types(types),
            ..self
        }
    }

    /// Sets the lower weight bound of the edge filter.
    pub fn with_min_weight(self, weight: f32) -> Self {
        Self {
            edge_filter: self.edge_filter.with_min_weight(weight),
            ..self
        }
    }

    /// Sets the upper weight bound of the edge filter.
    pub fn with_max_weight(self, weight: f32) -> Self {
        Self {
            edge_filter: self.edge_filter.with_max_weight(weight),
            ..self
        }
    }

    /// Chooses how parallel edges are merged in a directed build.
    pub fn aggregate(self, strategy: AggregationStrategy) -> Self {
        Self {
            aggregation: strategy,
            ..self
        }
    }

    /// Requests an undirected projection.
    pub fn undirected(self) -> Self {
        Self {
            undirected: true,
            ..self
        }
    }

    /// Consumes the builder and produces the projection.
    ///
    /// NOTE: in undirected mode only the graph and the two filters are
    /// forwarded; the configured aggregation strategy and property projection
    /// are not passed to `GraphProjection::undirected`.
    pub fn build(self) -> GraphProjection {
        let Self {
            graph,
            node_filter,
            edge_filter,
            property_projection,
            aggregation,
            undirected,
        } = self;

        if !undirected {
            return GraphProjection::native(
                graph,
                node_filter,
                edge_filter,
                property_projection,
                aggregation,
            );
        }
        GraphProjection::undirected(graph, node_filter, edge_filter)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four nodes (two `host`, two `service`) connected by five directed edges.
    fn create_test_graph() -> GraphStore {
        let graph = GraphStore::new();

        for (id, label, kind) in [
            ("A", "Server A", "host"),
            ("B", "Server B", "host"),
            ("C", "DB Server", "service"),
            ("D", "Web Server", "service"),
        ] {
            let _ = graph.add_node_with_label(id, label, kind);
        }

        for (from, to, kind, weight) in [
            ("A", "B", "connects_to", 1.0),
            ("A", "C", "connects_to", 2.0),
            ("B", "C", "auth_access", 1.5),
            ("B", "D", "connects_to", 1.0),
            ("C", "D", "connects_to", 0.5),
        ] {
            let _ = graph.add_edge_with_label(from, to, kind, weight);
        }

        graph
    }

    /// Unfiltered, non-aggregated directed projection of `graph`.
    fn project_everything(graph: &GraphStore) -> GraphProjection {
        GraphProjection::native(
            graph,
            NodeFilter::all(),
            EdgeFilter::all(),
            PropertyProjection::all(),
            AggregationStrategy::None,
        )
    }

    #[test]
    fn test_full_projection() {
        let store = create_test_graph();
        let view = project_everything(&store);

        assert_eq!(view.node_count(), 4);
        assert_eq!(view.edge_count(), 5);
    }

    #[test]
    fn test_node_label_filter() {
        let store = create_test_graph();
        let view = GraphProjection::native(
            &store,
            NodeFilter::all().with_labels(["host"]),
            EdgeFilter::all(),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        assert_eq!(view.node_count(), 2);
        assert!(view.has_node("A"));
        assert!(view.has_node("B"));
        assert!(!view.has_node("C"));
        assert!(!view.has_node("D"));
    }

    #[test]
    fn test_edge_type_filter() {
        let store = create_test_graph();
        let view = GraphProjection::native(
            &store,
            NodeFilter::all(),
            EdgeFilter::all().with_types(["connects_to"]),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        // Four of the five fixture edges are `connects_to`.
        assert_eq!(view.edge_count(), 4);
    }

    #[test]
    fn test_weight_filter() {
        let store = create_test_graph();
        let view = GraphProjection::native(
            &store,
            NodeFilter::all(),
            EdgeFilter::all().with_min_weight(1.0),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        // Only the C -> D edge (weight 0.5) falls below the bound.
        assert_eq!(view.edge_count(), 4);
    }

    #[test]
    fn test_projection_builder() {
        let store = create_test_graph();
        let view = ProjectionBuilder::new(&store)
            .with_node_labels(["service"])
            .build();

        assert_eq!(view.node_count(), 2);
    }

    #[test]
    fn test_undirected_projection() {
        let store = create_test_graph();
        let view = ProjectionBuilder::new(&store).undirected().build();

        assert!(view.neighbors("A").contains(&"B"));

        // The fixture only has A -> B; the reverse edge must have been added.
        let b_neighbors = view.neighbors("B");
        assert!(b_neighbors.contains(&"A"));
    }

    #[test]
    fn test_from_nodes() {
        let store = create_test_graph();
        let view = GraphProjection::from_nodes(&store, &["A".to_string(), "B".to_string()]);

        assert_eq!(view.node_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_neighbors() {
        let store = create_test_graph();
        let view = project_everything(&store);

        let a_neighbors = view.neighbors("A");
        assert!(a_neighbors.contains(&"B"));
        assert!(a_neighbors.contains(&"C"));
        assert_eq!(a_neighbors.len(), 2);
    }

    #[test]
    fn test_degrees() {
        let store = create_test_graph();
        let view = project_everything(&store);

        assert_eq!(view.out_degree("A"), 2);
        assert_eq!(view.in_degree("D"), 2);
        assert_eq!(view.out_degree("D"), 0);
    }
}