1use std::collections::{HashMap, HashSet};
13
14use super::graph_store::{GraphStore, StoredNode};
15
/// Predicate selecting which nodes of a [`GraphStore`] enter a projection.
///
/// Every criterion is optional: `None` means "do not filter on this criterion". A node must
/// satisfy every criterion that is set.
#[derive(Clone, Debug, Default)]
pub struct NodeFilter {
    /// Accepted node types (compared against the node's `node_type`); `None` accepts any type.
    pub labels: Option<Vec<String>>,
    /// Accepted node ids; `None` accepts any id.
    pub ids: Option<HashSet<String>>,
}
28
29impl NodeFilter {
30 pub fn all() -> Self {
32 Self::default()
33 }
34
35 pub fn with_labels<I, S>(mut self, labels: I) -> Self
37 where
38 I: IntoIterator<Item = S>,
39 S: Into<String>,
40 {
41 self.labels = Some(labels.into_iter().map(Into::into).collect());
42 self
43 }
44
45 pub fn with_ids(mut self, ids: HashSet<String>) -> Self {
47 self.ids = Some(ids);
48 self
49 }
50
51 pub fn matches(&self, node: &StoredNode) -> bool {
53 if let Some(ref labels) = self.labels {
54 if !labels.iter().any(|l| l == node.node_type.as_str()) {
55 return false;
56 }
57 }
58
59 if let Some(ref ids) = self.ids {
60 if !ids.contains(&node.id) {
61 return false;
62 }
63 }
64
65 true
66 }
67}
68
/// Predicate selecting which edges enter a projection.
///
/// Every criterion is optional: `None` means "do not filter on this criterion". An edge must
/// satisfy every criterion that is set.
#[derive(Clone, Debug, Default)]
pub struct EdgeFilter {
    /// Accepted edge types; `None` accepts any type.
    pub edge_types: Option<Vec<String>>,
    /// Inclusive lower bound on the edge weight; `None` means unbounded.
    pub min_weight: Option<f32>,
    /// Inclusive upper bound on the edge weight; `None` means unbounded.
    pub max_weight: Option<f32>,
}
79
80impl EdgeFilter {
81 pub fn all() -> Self {
83 Self::default()
84 }
85
86 pub fn with_types<I, S>(mut self, types: I) -> Self
88 where
89 I: IntoIterator<Item = S>,
90 S: Into<String>,
91 {
92 self.edge_types = Some(types.into_iter().map(Into::into).collect());
93 self
94 }
95
96 pub fn with_min_weight(mut self, weight: f32) -> Self {
98 self.min_weight = Some(weight);
99 self
100 }
101
102 pub fn with_max_weight(mut self, weight: f32) -> Self {
104 self.max_weight = Some(weight);
105 self
106 }
107
108 pub fn matches(&self, edge_label: &str, weight: f32) -> bool {
110 if let Some(ref types) = self.edge_types {
111 if !types.iter().any(|t| t == edge_label) {
112 return false;
113 }
114 }
115
116 if let Some(min) = self.min_weight {
118 if weight < min {
119 return false;
120 }
121 }
122
123 if let Some(max) = self.max_weight {
124 if weight > max {
125 return false;
126 }
127 }
128
129 true
130 }
131}
132
/// Selects which optional properties are copied into a projection.
///
/// The derived `Default` disables every property, which is the same as
/// `PropertyProjection::minimal()`. `PropertyProjection::all()` enables them all.
#[derive(Clone, Debug, Default)]
pub struct PropertyProjection {
    /// When `true`, each projected node's `category` is filled with its node type.
    pub include_label: bool,
    /// Presumably controls whether edge weights are kept.
    /// NOTE(review): `GraphProjection::native` does not read this flag; projected edges always
    /// carry their weight. Confirm whether anything else consumes it.
    pub include_weight: bool,
}
145
146impl PropertyProjection {
147 pub fn all() -> Self {
149 Self {
150 include_label: true,
151 include_weight: true,
152 }
153 }
154
155 pub fn minimal() -> Self {
157 Self {
158 include_label: false,
159 include_weight: false,
160 }
161 }
162}
163
/// How parallel edges (several edges sharing the same source and target) are combined when a
/// projection is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Keep every parallel edge as a separate edge.
    None,
    /// Collapse into one edge weighted by the sum of the weights.
    SumWeight,
    /// Collapse into one edge weighted by the arithmetic mean of the weights.
    AvgWeight,
    /// Collapse into one edge weighted by the smallest weight.
    MinWeight,
    /// Collapse into one edge weighted by the largest weight.
    MaxWeight,
    /// Collapse into one edge whose weight is the number of edges collapsed.
    Count,
}
184
185pub struct GraphProjection {
194 nodes: HashMap<String, ProjectedNode>,
196 outgoing: HashMap<String, Vec<(String, String, f32)>>,
198 incoming: HashMap<String, Vec<(String, String, f32)>>,
200 stats: ProjectionStats,
202}
203
/// A node as it appears in a projection.
#[derive(Debug, Clone)]
pub struct ProjectedNode {
    /// Node id, copied from the source graph.
    pub id: String,
    /// Node label, copied from the source graph.
    pub label: String,
    /// Node type as a string, or `None` when the projection was built with
    /// `PropertyProjection::include_label` disabled.
    pub category: Option<String>,
}
213
/// Counters gathered while a projection is built.
#[derive(Debug, Clone, Default)]
pub struct ProjectionStats {
    /// Number of nodes accepted by the node filter.
    pub node_count: usize,
    /// Number of edges stored in the projection, after aggregation and including any reverse
    /// edges added for undirected projections.
    pub edge_count: usize,
    /// Number of nodes rejected by the node filter.
    pub nodes_filtered: usize,
    /// Number of edges between accepted nodes that the edge filter rejected.
    pub edges_filtered: usize,
    /// Number of edges removed by merging parallel edges into one.
    pub edges_aggregated: usize,
}
228
229impl GraphProjection {
230 pub fn native(
232 graph: &GraphStore,
233 node_filter: NodeFilter,
234 edge_filter: EdgeFilter,
235 property_projection: PropertyProjection,
236 aggregation: AggregationStrategy,
237 ) -> Self {
238 let mut nodes: HashMap<String, ProjectedNode> = HashMap::new();
239 let mut outgoing: HashMap<String, Vec<(String, String, f32)>> = HashMap::new();
240 let mut incoming: HashMap<String, Vec<(String, String, f32)>> = HashMap::new();
241 let mut stats = ProjectionStats::default();
242
243 let mut node_ids: HashSet<String> = HashSet::new();
245 for node in graph.iter_nodes() {
246 if node_filter.matches(&node) {
247 let projected = ProjectedNode {
248 id: node.id.clone(),
249 label: node.label.clone(),
250 category: if property_projection.include_label {
251 Some(node.node_type.as_str().to_string())
252 } else {
253 None
254 },
255 };
256 node_ids.insert(node.id.clone());
257 nodes.insert(node.id.clone(), projected);
258 stats.node_count += 1;
259 } else {
260 stats.nodes_filtered += 1;
261 }
262 }
263
264 let mut edge_groups: HashMap<(String, String), Vec<(String, f32)>> = HashMap::new();
267
268 for node_id in &node_ids {
269 for (edge_type, target, weight) in graph.outgoing_edges(node_id) {
270 if !node_ids.contains(&target) {
271 continue;
272 }
273
274 let edge_label = edge_type.as_str().to_string();
275 if edge_filter.matches(&edge_label, weight) {
276 let key = (node_id.clone(), target);
277 edge_groups
278 .entry(key)
279 .or_default()
280 .push((edge_label, weight));
281 } else {
282 stats.edges_filtered += 1;
283 }
284 }
285 }
286
287 for ((source, target), edges) in edge_groups {
289 match aggregation {
290 AggregationStrategy::None => {
291 for (edge_type, weight) in edges {
293 outgoing.entry(source.clone()).or_default().push((
294 target.clone(),
295 edge_type.clone(),
296 weight,
297 ));
298 incoming.entry(target.clone()).or_default().push((
299 source.clone(),
300 edge_type,
301 weight,
302 ));
303 stats.edge_count += 1;
304 }
305 }
306 _ => {
307 if let Some((first_type, _)) = edges.first().cloned() {
309 let weight = match aggregation {
310 AggregationStrategy::SumWeight => edges.iter().map(|(_, w)| w).sum(),
311 AggregationStrategy::AvgWeight => {
312 let sum: f32 = edges.iter().map(|(_, w)| w).sum();
313 sum / edges.len() as f32
314 }
315 AggregationStrategy::MinWeight => {
316 edges.iter().map(|(_, w)| *w).fold(f32::INFINITY, f32::min)
317 }
318 AggregationStrategy::MaxWeight => edges
319 .iter()
320 .map(|(_, w)| *w)
321 .fold(f32::NEG_INFINITY, f32::max),
322 AggregationStrategy::Count => edges.len() as f32,
323 AggregationStrategy::None => unreachable!(),
324 };
325
326 if edges.len() > 1 {
327 stats.edges_aggregated += edges.len() - 1;
328 }
329
330 outgoing.entry(source.clone()).or_default().push((
331 target.clone(),
332 first_type.clone(),
333 weight,
334 ));
335 incoming
336 .entry(target)
337 .or_default()
338 .push((source, first_type, weight));
339 stats.edge_count += 1;
340 }
341 }
342 }
343 }
344
345 Self {
346 nodes,
347 outgoing,
348 incoming,
349 stats,
350 }
351 }
352
353 pub fn from_nodes(graph: &GraphStore, node_ids: &[String]) -> Self {
355 let id_set: HashSet<String> = node_ids.iter().cloned().collect();
356 let node_filter = NodeFilter::all().with_ids(id_set);
357 Self::native(
358 graph,
359 node_filter,
360 EdgeFilter::all(),
361 PropertyProjection::all(),
362 AggregationStrategy::None,
363 )
364 }
365
366 pub fn from_paths(graph: &GraphStore, paths: &[Vec<String>]) -> Self {
368 let mut node_ids: HashSet<String> = HashSet::new();
369 for path in paths {
370 node_ids.extend(path.iter().cloned());
371 }
372 let node_filter = NodeFilter::all().with_ids(node_ids);
373 Self::native(
374 graph,
375 node_filter,
376 EdgeFilter::all(),
377 PropertyProjection::all(),
378 AggregationStrategy::None,
379 )
380 }
381
382 pub fn undirected(
384 graph: &GraphStore,
385 node_filter: NodeFilter,
386 edge_filter: EdgeFilter,
387 ) -> Self {
388 let mut projection = Self::native(
389 graph,
390 node_filter,
391 edge_filter,
392 PropertyProjection::all(),
393 AggregationStrategy::SumWeight,
394 );
395
396 let mut additional: Vec<(String, String, String, f32)> = Vec::new();
398
399 for (source, edges) in &projection.outgoing {
400 for (target, edge_type, weight) in edges {
401 let has_reverse = projection
403 .outgoing
404 .get(target)
405 .map(|e| e.iter().any(|(t, _, _)| t == source))
406 .unwrap_or(false);
407
408 if !has_reverse {
409 additional.push((target.clone(), source.clone(), edge_type.clone(), *weight));
410 }
411 }
412 }
413
414 for (source, target, edge_type, weight) in additional {
415 projection
416 .outgoing
417 .entry(source.clone())
418 .or_default()
419 .push((target.clone(), edge_type.clone(), weight));
420 projection
421 .incoming
422 .entry(target)
423 .or_default()
424 .push((source, edge_type, weight));
425 projection.stats.edge_count += 1;
426 }
427
428 projection
429 }
430
431 pub fn stats(&self) -> &ProjectionStats {
433 &self.stats
434 }
435
436 pub fn node_count(&self) -> usize {
438 self.nodes.len()
439 }
440
441 pub fn edge_count(&self) -> usize {
443 self.stats.edge_count
444 }
445
446 pub fn get_node(&self, id: &str) -> Option<&ProjectedNode> {
448 self.nodes.get(id)
449 }
450
451 pub fn has_node(&self, id: &str) -> bool {
453 self.nodes.contains_key(id)
454 }
455
456 pub fn iter_nodes(&self) -> impl Iterator<Item = &ProjectedNode> {
458 self.nodes.values()
459 }
460
461 pub fn node_ids(&self) -> impl Iterator<Item = &String> {
463 self.nodes.keys()
464 }
465
466 pub fn outgoing(&self, node_id: &str) -> &[(String, String, f32)] {
468 self.outgoing
469 .get(node_id)
470 .map(|v| v.as_slice())
471 .unwrap_or(&[])
472 }
473
474 pub fn incoming(&self, node_id: &str) -> &[(String, String, f32)] {
476 self.incoming
477 .get(node_id)
478 .map(|v| v.as_slice())
479 .unwrap_or(&[])
480 }
481
482 pub fn out_degree(&self, node_id: &str) -> usize {
484 self.outgoing.get(node_id).map(|v| v.len()).unwrap_or(0)
485 }
486
487 pub fn in_degree(&self, node_id: &str) -> usize {
489 self.incoming.get(node_id).map(|v| v.len()).unwrap_or(0)
490 }
491
492 pub fn neighbors(&self, node_id: &str) -> Vec<&str> {
494 self.outgoing
495 .get(node_id)
496 .map(|edges| edges.iter().map(|(t, _, _)| t.as_str()).collect())
497 .unwrap_or_default()
498 }
499
500 pub fn neighbors_weighted(&self, node_id: &str) -> Vec<(&str, f32)> {
502 self.outgoing
503 .get(node_id)
504 .map(|edges| edges.iter().map(|(t, _, w)| (t.as_str(), *w)).collect())
505 .unwrap_or_default()
506 }
507
508 pub fn all_neighbors(&self, node_id: &str) -> HashSet<&str> {
510 let mut neighbors: HashSet<&str> = HashSet::new();
511
512 if let Some(edges) = self.outgoing.get(node_id) {
513 for (target, _, _) in edges {
514 neighbors.insert(target.as_str());
515 }
516 }
517
518 if let Some(edges) = self.incoming.get(node_id) {
519 for (source, _, _) in edges {
520 neighbors.insert(source.as_str());
521 }
522 }
523
524 neighbors
525 }
526}
527
528pub struct ProjectionBuilder<'a> {
534 graph: &'a GraphStore,
535 node_filter: NodeFilter,
536 edge_filter: EdgeFilter,
537 property_projection: PropertyProjection,
538 aggregation: AggregationStrategy,
539 undirected: bool,
540}
541
542impl<'a> ProjectionBuilder<'a> {
543 pub fn new(graph: &'a GraphStore) -> Self {
545 Self {
546 graph,
547 node_filter: NodeFilter::all(),
548 edge_filter: EdgeFilter::all(),
549 property_projection: PropertyProjection::all(),
550 aggregation: AggregationStrategy::None,
551 undirected: false,
552 }
553 }
554
555 pub fn with_node_labels<I, S>(mut self, labels: I) -> Self
557 where
558 I: IntoIterator<Item = S>,
559 S: Into<String>,
560 {
561 self.node_filter = self.node_filter.with_labels(labels);
562 self
563 }
564
565 pub fn with_node_ids(mut self, ids: HashSet<String>) -> Self {
567 self.node_filter = self.node_filter.with_ids(ids);
568 self
569 }
570
571 pub fn with_edge_types<I, S>(mut self, types: I) -> Self
573 where
574 I: IntoIterator<Item = S>,
575 S: Into<String>,
576 {
577 self.edge_filter = self.edge_filter.with_types(types);
578 self
579 }
580
581 pub fn with_min_weight(mut self, weight: f32) -> Self {
583 self.edge_filter = self.edge_filter.with_min_weight(weight);
584 self
585 }
586
587 pub fn with_max_weight(mut self, weight: f32) -> Self {
589 self.edge_filter = self.edge_filter.with_max_weight(weight);
590 self
591 }
592
593 pub fn aggregate(mut self, strategy: AggregationStrategy) -> Self {
595 self.aggregation = strategy;
596 self
597 }
598
599 pub fn undirected(mut self) -> Self {
601 self.undirected = true;
602 self
603 }
604
605 pub fn build(self) -> GraphProjection {
607 if self.undirected {
608 GraphProjection::undirected(self.graph, self.node_filter, self.edge_filter)
609 } else {
610 GraphProjection::native(
611 self.graph,
612 self.node_filter,
613 self.edge_filter,
614 self.property_projection,
615 self.aggregation,
616 )
617 }
618 }
619}
620
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Four nodes (hosts `A`, `B`; services `C`, `D`) and five directed edges, with no parallel
    /// edges and no mutual pairs.
    fn create_test_graph() -> GraphStore {
        let graph = GraphStore::new();

        let _ = graph.add_node_with_label("A", "Server A", "host");
        let _ = graph.add_node_with_label("B", "Server B", "host");
        let _ = graph.add_node_with_label("C", "DB Server", "service");
        let _ = graph.add_node_with_label("D", "Web Server", "service");

        let _ = graph.add_edge_with_label("A", "B", "connects_to", 1.0);
        let _ = graph.add_edge_with_label("A", "C", "connects_to", 2.0);
        let _ = graph.add_edge_with_label("B", "C", "auth_access", 1.5);
        let _ = graph.add_edge_with_label("B", "D", "connects_to", 1.0);
        let _ = graph.add_edge_with_label("C", "D", "connects_to", 0.5);

        graph
    }

    /// Projects `graph` keeping every property and without aggregating parallel edges.
    fn project(
        graph: &GraphStore,
        node_filter: NodeFilter,
        edge_filter: EdgeFilter,
    ) -> GraphProjection {
        GraphProjection::native(
            graph,
            node_filter,
            edge_filter,
            PropertyProjection::all(),
            AggregationStrategy::None,
        )
    }

    #[test]
    fn test_full_projection() {
        let graph = create_test_graph();
        let projection = project(&graph, NodeFilter::all(), EdgeFilter::all());

        assert_eq!(projection.node_count(), 4);
        assert_eq!(projection.edge_count(), 5);
        let stats = projection.stats();
        assert_eq!(stats.nodes_filtered, 0);
        assert_eq!(stats.edges_filtered, 0);
        assert_eq!(stats.edges_aggregated, 0);
    }

    #[test]
    fn test_node_label_filter() {
        let graph = create_test_graph();
        let projection = project(&graph, NodeFilter::all().with_labels(["host"]), EdgeFilter::all());

        assert_eq!(projection.node_count(), 2);
        assert!(projection.has_node("A"));
        assert!(projection.has_node("B"));
        assert!(!projection.has_node("C"));
        assert!(!projection.has_node("D"));
        assert_eq!(projection.stats().nodes_filtered, 2);
        // Only A -> B has both endpoints among the hosts; edges towards dropped nodes are not
        // counted as filtered.
        assert_eq!(projection.edge_count(), 1);
        assert_eq!(projection.stats().edges_filtered, 0);
    }

    #[test]
    fn test_edge_type_filter() {
        let graph = create_test_graph();
        let projection = project(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all().with_types(["connects_to"]),
        );

        assert_eq!(projection.edge_count(), 4);
        assert_eq!(projection.stats().edges_filtered, 1);
        assert_eq!(projection.neighbors("B"), vec!["D"]);
    }

    #[test]
    fn test_weight_filter() {
        let graph = create_test_graph();
        let projection = project(&graph, NodeFilter::all(), EdgeFilter::all().with_min_weight(1.0));

        assert_eq!(projection.edge_count(), 4);
        assert_eq!(projection.stats().edges_filtered, 1);
        assert_eq!(projection.out_degree("C"), 0);
    }

    #[test]
    fn test_projection_builder() {
        let graph = create_test_graph();
        let projection = ProjectionBuilder::new(&graph)
            .with_node_labels(["service"])
            .build();

        assert_eq!(projection.node_count(), 2);
        assert_eq!(projection.edge_count(), 1);
    }

    #[test]
    fn test_builder_filters() {
        let graph = create_test_graph();
        let projection = ProjectionBuilder::new(&graph)
            .with_edge_types(["connects_to"])
            .with_max_weight(1.0)
            .build();
        // A -> B, B -> D and C -> D survive; A -> C (weight) and B -> C (type) are filtered.
        assert_eq!(projection.edge_count(), 3);
        assert_eq!(projection.stats().edges_filtered, 2);

        let ids: HashSet<String> = ["B", "C", "D"].iter().map(|s| s.to_string()).collect();
        let projection = ProjectionBuilder::new(&graph).with_node_ids(ids).build();
        assert_eq!(projection.node_count(), 3);
        assert_eq!(projection.edge_count(), 3);
    }

    #[test]
    fn test_undirected_projection() {
        let graph = create_test_graph();
        let projection = ProjectionBuilder::new(&graph).undirected().build();

        assert!(projection.neighbors("A").contains(&"B"));
        assert!(projection.neighbors("B").contains(&"A"));
        let d_neighbors = projection.neighbors("D");
        assert!(d_neighbors.contains(&"B"));
        assert!(d_neighbors.contains(&"C"));
        // The test graph has no mutual pairs, so each of its 5 edges gains a reverse edge.
        assert_eq!(projection.edge_count(), 10);
    }

    #[test]
    fn test_from_nodes() {
        let graph = create_test_graph();
        let projection = GraphProjection::from_nodes(&graph, &["A".to_string(), "B".to_string()]);

        assert_eq!(projection.node_count(), 2);
        assert_eq!(projection.edge_count(), 1);
        assert_eq!(projection.neighbors("A"), vec!["B"]);
    }

    #[test]
    fn test_from_paths() {
        let graph = create_test_graph();
        let paths = vec![
            vec!["A".to_string(), "C".to_string()],
            vec!["C".to_string(), "D".to_string()],
        ];
        let projection = GraphProjection::from_paths(&graph, &paths);

        assert_eq!(projection.node_count(), 3);
        assert!(!projection.has_node("B"));
        // Induced subgraph on {A, C, D}: A -> C and C -> D.
        assert_eq!(projection.edge_count(), 2);
    }

    #[test]
    fn test_neighbors() {
        let graph = create_test_graph();
        let projection = project(&graph, NodeFilter::all(), EdgeFilter::all());

        let a_neighbors = projection.neighbors("A");
        assert!(a_neighbors.contains(&"B"));
        assert!(a_neighbors.contains(&"C"));
        assert_eq!(a_neighbors.len(), 2);
    }

    #[test]
    fn test_degrees() {
        let graph = create_test_graph();
        let projection = project(&graph, NodeFilter::all(), EdgeFilter::all());

        assert_eq!(projection.out_degree("A"), 2);
        assert_eq!(projection.in_degree("A"), 0);
        assert_eq!(projection.in_degree("D"), 2);
        assert_eq!(projection.out_degree("D"), 0);
    }

    #[test]
    fn test_all_neighbors_and_incoming() {
        let graph = create_test_graph();
        let projection = project(&graph, NodeFilter::all(), EdgeFilter::all());

        let expected: HashSet<&str> = ["A", "C", "D"].iter().copied().collect();
        assert_eq!(projection.all_neighbors("B"), expected);

        let mut into_c: Vec<(&str, &str, f32)> = projection
            .incoming("C")
            .iter()
            .map(|(source, edge_type, weight)| (source.as_str(), edge_type.as_str(), *weight))
            .collect();
        into_c.sort_by(|a, b| a.0.cmp(b.0));
        assert_eq!(
            into_c,
            vec![("A", "connects_to", 2.0), ("B", "auth_access", 1.5)]
        );
    }

    #[test]
    fn test_count_aggregation_without_parallel_edges() {
        let graph = create_test_graph();
        let projection = ProjectionBuilder::new(&graph)
            .aggregate(AggregationStrategy::Count)
            .build();

        assert_eq!(projection.edge_count(), 5);
        assert_eq!(projection.stats().edges_aggregated, 0);
        let weighted = projection.neighbors_weighted("A");
        assert_eq!(weighted.len(), 2);
        assert!(weighted.iter().all(|&(_, w)| w == 1.0));
    }

    #[test]
    fn test_property_projection() {
        let graph = create_test_graph();

        let minimal = GraphProjection::native(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all(),
            PropertyProjection::minimal(),
            AggregationStrategy::None,
        );
        assert_eq!(minimal.get_node("A").and_then(|n| n.category.as_deref()), None);

        let full = project(&graph, NodeFilter::all(), EdgeFilter::all());
        assert_eq!(
            full.get_node("A").and_then(|n| n.category.as_deref()),
            Some("host")
        );
    }
}