Skip to main content

batuta/tui/graph/
filtering.rs

1//! Graph filtering operations
2//!
3//! Contains filtering methods for `Graph<N: Clone, E: Clone>`.
4
5use super::graph_core::Graph;
6use super::types::{Edge, Node};
7use crate::tui::graph_analytics::GraphAnalytics;
8
9// ============================================================================
10// GRAPH-005b: Filtering (Neo4j/Gephi pattern)
11// ============================================================================
12
13impl<N: Clone, E: Clone> Graph<N, E> {
14    /// Filter graph to nodes matching predicate
15    ///
16    /// Returns a new graph containing only matching nodes and their edges.
17    #[must_use]
18    pub fn filter_nodes<F>(&self, predicate: F) -> Self
19    where
20        F: Fn(&Node<N>) -> bool,
21    {
22        let mut filtered = Self::new();
23
24        // Add matching nodes
25        for node in self.nodes() {
26            if predicate(node) {
27                filtered.add_node(node.clone());
28            }
29        }
30
31        // Add edges where both endpoints exist
32        for edge in &self.edges {
33            if filtered.nodes.contains_key(&edge.from) && filtered.nodes.contains_key(&edge.to) {
34                filtered.add_edge(edge.clone());
35            }
36        }
37
38        filtered
39    }
40
41    /// Filter to nodes with minimum degree
42    #[must_use]
43    pub fn filter_by_min_degree(&self, min_degree: usize) -> Self {
44        let degrees = GraphAnalytics::degree_centrality(self);
45        let n = self.node_count();
46        let threshold = if n > 1 { min_degree as f32 / (n - 1) as f32 } else { 0.0 };
47
48        self.filter_nodes(|node| degrees.get(&node.id).unwrap_or(&0.0) >= &threshold)
49    }
50
51    /// Filter to top N nodes by importance
52    #[must_use]
53    pub fn filter_top_n(&self, n: usize) -> Self {
54        let mut nodes_by_importance: Vec<_> = self.nodes().collect();
55        nodes_by_importance.sort_by(|a, b| {
56            b.importance.partial_cmp(&a.importance).unwrap_or(std::cmp::Ordering::Equal)
57        });
58
59        let top_ids: std::collections::HashSet<_> =
60            nodes_by_importance.iter().take(n).map(|n| &n.id).collect();
61
62        self.filter_nodes(|node| top_ids.contains(&node.id))
63    }
64
65    /// Filter to nodes matching label pattern
66    #[must_use]
67    pub fn filter_by_label(&self, pattern: &str) -> Self {
68        let pattern_lower = pattern.to_lowercase();
69        self.filter_nodes(|node| {
70            node.label.as_ref().map(|l| l.to_lowercase().contains(&pattern_lower)).unwrap_or(false)
71        })
72    }
73
74    /// Filter to subgraph containing path between two nodes
75    #[must_use]
76    pub fn filter_path(&self, from: &str, to: &str) -> Self {
77        if let Some(path) = GraphAnalytics::shortest_path(self, from, to) {
78            let path_set: std::collections::HashSet<_> = path.into_iter().collect();
79            self.filter_nodes(|node| path_set.contains(&node.id))
80        } else {
81            Self::new()
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// A one-node graph uses a 0.0 threshold, so the node is kept even
    /// though `min_degree` is 1.
    #[test]
    fn test_filter_by_min_degree_single_node() {
        let mut lone: Graph<(), ()> = Graph::new();
        lone.add_node(Node::new("A", ()));

        let result = lone.filter_by_min_degree(1);

        assert_eq!(result.node_count(), 1);
    }

    /// Filtering an empty graph yields an empty graph.
    #[test]
    fn test_filter_by_min_degree_empty_graph() {
        let empty: Graph<(), ()> = Graph::new();

        let result = empty.filter_by_min_degree(0);

        assert_eq!(result.node_count(), 0);
    }

    /// Two nodes with no edges between them have no path, so the path
    /// filter returns an empty graph.
    #[test]
    fn test_filter_path_no_path() {
        let mut disconnected: Graph<(), ()> = Graph::new();
        disconnected.add_node(Node::new("A", ()));
        disconnected.add_node(Node::new("B", ()));

        let result = disconnected.filter_path("A", "B");

        assert_eq!(result.node_count(), 0);
    }

    /// A labelled node is dropped when the pattern does not occur in its label.
    #[test]
    fn test_filter_by_label_no_match() {
        let mut labelled: Graph<(), ()> = Graph::new();
        labelled.add_node(Node::new("A", ()).with_label("Hello"));

        let result = labelled.filter_by_label("xyz");

        assert_eq!(result.node_count(), 0);
    }

    /// A node that carries no label never matches a label pattern.
    #[test]
    fn test_filter_by_label_no_labels() {
        let mut unlabelled: Graph<(), ()> = Graph::new();
        unlabelled.add_node(Node::new("A", ()));

        let result = unlabelled.filter_by_label("test");

        assert_eq!(result.node_count(), 0);
    }

    /// Removing the middle node of the chain A -> B -> C leaves A and C but
    /// none of the edges, since each edge touched B.
    #[test]
    fn test_filter_nodes_with_edges() {
        let mut chain: Graph<&str, i32> = Graph::new();
        chain.add_node(Node::new("A", "a"));
        chain.add_node(Node::new("B", "b"));
        chain.add_node(Node::new("C", "c"));
        chain.add_edge(Edge::new("A", "B", 1));
        chain.add_edge(Edge::new("B", "C", 2));

        let result = chain.filter_nodes(|node| node.id != "B");

        assert_eq!(result.node_count(), 2);
        assert_eq!(result.edge_count(), 0);
    }

    /// Ties in importance must not change how many nodes are returned:
    /// asking for 3 of 5 identically ranked nodes yields exactly 3.
    #[test]
    fn test_filter_top_n_with_equal_importance() {
        let mut tied: Graph<(), ()> = Graph::new();
        for i in 0..5 {
            // No importance is set, so all five nodes rank the same.
            tied.add_node(Node::new(format!("n{}", i), ()));
        }

        let result = tied.filter_top_n(3);

        assert_eq!(result.node_count(), 3);
    }
}