arrow_graph/algorithms/
aggregation.rs

1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, Float64Array, UInt64Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use std::collections::HashMap;
6use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
7use crate::graph::ArrowGraph;
8use crate::error::{GraphError, Result};
9
10pub struct GraphDensity;
11
12impl GraphAlgorithm for GraphDensity {
13    fn execute(&self, _graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
14        todo!("Calculate graph density metric")
15    }
16    
17    fn name(&self) -> &'static str {
18        "graph_density"
19    }
20    
21    fn description(&self) -> &'static str {
22        "Calculate the density of the graph"
23    }
24}
25
26pub struct TriangleCount;
27
28impl TriangleCount {
29    /// Count triangles in the graph using node enumeration method
30    fn count_triangles(&self, graph: &ArrowGraph) -> Result<u64> {
31        let mut triangle_count = 0u64;
32        let node_ids: Vec<String> = graph.node_ids().cloned().collect();
33        
34        // For each triple of nodes, check if they form a triangle
35        for i in 0..node_ids.len() {
36            for j in (i + 1)..node_ids.len() {
37                for k in (j + 1)..node_ids.len() {
38                    let node_a = &node_ids[i];
39                    let node_b = &node_ids[j];
40                    let node_c = &node_ids[k];
41                    
42                    // Check if all three edges exist (undirected): A-B, B-C, C-A
43                    let has_ab = (graph.neighbors(node_a)
44                        .map(|neighbors| neighbors.contains(node_b))
45                        .unwrap_or(false)) ||
46                        (graph.neighbors(node_b)
47                        .map(|neighbors| neighbors.contains(node_a))
48                        .unwrap_or(false));
49                    
50                    let has_bc = (graph.neighbors(node_b)
51                        .map(|neighbors| neighbors.contains(node_c))
52                        .unwrap_or(false)) ||
53                        (graph.neighbors(node_c)
54                        .map(|neighbors| neighbors.contains(node_b))
55                        .unwrap_or(false));
56                    
57                    let has_ac = (graph.neighbors(node_a)
58                        .map(|neighbors| neighbors.contains(node_c))
59                        .unwrap_or(false)) ||
60                        (graph.neighbors(node_c)
61                        .map(|neighbors| neighbors.contains(node_a))
62                        .unwrap_or(false));
63                    
64                    if has_ab && has_bc && has_ac {
65                        triangle_count += 1;
66                    }
67                }
68            }
69        }
70        
71        Ok(triangle_count)
72    }
73    
74    /// Count triangles per node for local clustering coefficient
75    fn count_triangles_per_node(&self, graph: &ArrowGraph) -> Result<HashMap<String, u64>> {
76        let mut node_triangles: HashMap<String, u64> = HashMap::new();
77        let node_ids: Vec<String> = graph.node_ids().cloned().collect();
78        
79        // Initialize all nodes with 0 triangles
80        for node_id in graph.node_ids() {
81            node_triangles.insert(node_id.clone(), 0);
82        }
83        
84        // Find all triangles and count them for each participating node
85        for i in 0..node_ids.len() {
86            for j in (i + 1)..node_ids.len() {
87                for k in (j + 1)..node_ids.len() {
88                    let node_a = &node_ids[i];
89                    let node_b = &node_ids[j];
90                    let node_c = &node_ids[k];
91                    
92                    // Check if all three edges exist (undirected): A-B, B-C, C-A
93                    let has_ab = (graph.neighbors(node_a)
94                        .map(|neighbors| neighbors.contains(node_b))
95                        .unwrap_or(false)) ||
96                        (graph.neighbors(node_b)
97                        .map(|neighbors| neighbors.contains(node_a))
98                        .unwrap_or(false));
99                    
100                    let has_bc = (graph.neighbors(node_b)
101                        .map(|neighbors| neighbors.contains(node_c))
102                        .unwrap_or(false)) ||
103                        (graph.neighbors(node_c)
104                        .map(|neighbors| neighbors.contains(node_b))
105                        .unwrap_or(false));
106                    
107                    let has_ac = (graph.neighbors(node_a)
108                        .map(|neighbors| neighbors.contains(node_c))
109                        .unwrap_or(false)) ||
110                        (graph.neighbors(node_c)
111                        .map(|neighbors| neighbors.contains(node_a))
112                        .unwrap_or(false));
113                    
114                    if has_ab && has_bc && has_ac {
115                        // Triangle found, increment count for all three nodes
116                        *node_triangles.get_mut(node_a).unwrap() += 1;
117                        *node_triangles.get_mut(node_b).unwrap() += 1;
118                        *node_triangles.get_mut(node_c).unwrap() += 1;
119                    }
120                }
121            }
122        }
123        
124        Ok(node_triangles)
125    }
126}
127
128impl GraphAlgorithm for TriangleCount {
129    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
130        let total_triangles = self.count_triangles(graph)?;
131        
132        let schema = Arc::new(Schema::new(vec![
133            Field::new("metric", DataType::Utf8, false),
134            Field::new("value", DataType::UInt64, false),
135        ]));
136        
137        RecordBatch::try_new(
138            schema,
139            vec![
140                Arc::new(StringArray::from(vec!["triangle_count"])),
141                Arc::new(UInt64Array::from(vec![total_triangles])),
142            ],
143        ).map_err(GraphError::from)
144    }
145    
146    fn name(&self) -> &'static str {
147        "triangle_count"
148    }
149    
150    fn description(&self) -> &'static str {
151        "Count the total number of triangles in the graph"
152    }
153}
154
155pub struct ClusteringCoefficient;
156
157impl ClusteringCoefficient {
158    /// Calculate local clustering coefficient for each node
159    fn calculate_local_clustering(&self, graph: &ArrowGraph) -> Result<HashMap<String, f64>> {
160        let mut clustering: HashMap<String, f64> = HashMap::new();
161        let triangle_counter = TriangleCount;
162        let node_triangles = triangle_counter.count_triangles_per_node(graph)?;
163        
164        for node_id in graph.node_ids() {
165            if let Some(neighbors) = graph.neighbors(node_id) {
166                let degree = neighbors.len();
167                
168                if degree < 2 {
169                    // Nodes with degree < 2 cannot form triangles
170                    clustering.insert(node_id.clone(), 0.0);
171                } else {
172                    let triangles = *node_triangles.get(node_id).unwrap_or(&0);
173                    let possible_triangles = (degree * (degree - 1)) / 2;
174                    let coefficient = triangles as f64 / possible_triangles as f64;
175                    clustering.insert(node_id.clone(), coefficient);
176                }
177            } else {
178                clustering.insert(node_id.clone(), 0.0);
179            }
180        }
181        
182        Ok(clustering)
183    }
184    
185    /// Calculate global clustering coefficient (transitivity)
186    fn calculate_global_clustering(&self, graph: &ArrowGraph) -> Result<f64> {
187        let triangle_counter = TriangleCount;
188        let total_triangles = triangle_counter.count_triangles(graph)? as f64;
189        
190        // Count total number of connected triples (paths of length 2)
191        let mut total_triples = 0u64;
192        
193        for node_id in graph.node_ids() {
194            if let Some(neighbors) = graph.neighbors(node_id) {
195                let degree = neighbors.len();
196                if degree >= 2 {
197                    // Number of connected triples centered at this node
198                    // Each pair of neighbors forms a triple with this node as center
199                    total_triples += (degree * (degree - 1)) as u64 / 2;
200                }
201            }
202        }
203        
204        if total_triples == 0 {
205            Ok(0.0)
206        } else {
207            // Global clustering coefficient = 3 * triangles / triples
208            // Note: Each triangle contributes 3 triples, so we multiply by 3
209            let coefficient = 3.0 * total_triangles / total_triples as f64;
210            // Ensure coefficient is within valid range [0, 1]
211            Ok(coefficient.min(1.0))
212        }
213    }
214}
215
216impl GraphAlgorithm for ClusteringCoefficient {
217    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
218        let mode: String = params.get("mode").unwrap_or("local".to_string());
219        
220        match mode.as_str() {
221            "local" => {
222                let clustering = self.calculate_local_clustering(graph)?;
223                
224                if clustering.is_empty() {
225                    let schema = Arc::new(Schema::new(vec![
226                        Field::new("node_id", DataType::Utf8, false),
227                        Field::new("clustering_coefficient", DataType::Float64, false),
228                    ]));
229                    
230                    return RecordBatch::try_new(
231                        schema,
232                        vec![
233                            Arc::new(StringArray::from(Vec::<String>::new())),
234                            Arc::new(Float64Array::from(Vec::<f64>::new())),
235                        ],
236                    ).map_err(GraphError::from);
237                }
238                
239                // Sort by clustering coefficient (descending)
240                let mut sorted_nodes: Vec<(&String, &f64)> = clustering.iter().collect();
241                sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
242                
243                let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
244                let coefficients: Vec<f64> = sorted_nodes.iter().map(|(_, &coeff)| coeff).collect();
245                
246                let schema = Arc::new(Schema::new(vec![
247                    Field::new("node_id", DataType::Utf8, false),
248                    Field::new("clustering_coefficient", DataType::Float64, false),
249                ]));
250                
251                RecordBatch::try_new(
252                    schema,
253                    vec![
254                        Arc::new(StringArray::from(node_ids)),
255                        Arc::new(Float64Array::from(coefficients)),
256                    ],
257                ).map_err(GraphError::from)
258            },
259            "global" => {
260                let global_coefficient = self.calculate_global_clustering(graph)?;
261                
262                let schema = Arc::new(Schema::new(vec![
263                    Field::new("metric", DataType::Utf8, false),
264                    Field::new("value", DataType::Float64, false),
265                ]));
266                
267                RecordBatch::try_new(
268                    schema,
269                    vec![
270                        Arc::new(StringArray::from(vec!["global_clustering_coefficient"])),
271                        Arc::new(Float64Array::from(vec![global_coefficient])),
272                    ],
273                ).map_err(GraphError::from)
274            },
275            _ => Err(GraphError::invalid_parameter(
276                "mode must be 'local' or 'global'"
277            ))
278        }
279    }
280    
281    fn name(&self) -> &'static str {
282        "clustering_coefficient"
283    }
284    
285    fn description(&self) -> &'static str {
286        "Calculate local or global clustering coefficient"
287    }
288}