// arrow_graph/algorithms/aggregation.rs
use arrow::record_batch::RecordBatch;
use arrow::array::{StringArray, Float64Array, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
use std::collections::HashMap;
use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
use crate::graph::ArrowGraph;
use crate::error::{GraphError, Result};

10pub struct GraphDensity;
11
12impl GraphAlgorithm for GraphDensity {
13 fn execute(&self, _graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
14 todo!("Calculate graph density metric")
15 }
16
17 fn name(&self) -> &'static str {
18 "graph_density"
19 }
20
21 fn description(&self) -> &'static str {
22 "Calculate the density of the graph"
23 }
24}
25
26pub struct TriangleCount;
27
28impl TriangleCount {
29 fn count_triangles(&self, graph: &ArrowGraph) -> Result<u64> {
31 let mut triangle_count = 0u64;
32 let node_ids: Vec<String> = graph.node_ids().cloned().collect();
33
34 for i in 0..node_ids.len() {
36 for j in (i + 1)..node_ids.len() {
37 for k in (j + 1)..node_ids.len() {
38 let node_a = &node_ids[i];
39 let node_b = &node_ids[j];
40 let node_c = &node_ids[k];
41
42 let has_ab = (graph.neighbors(node_a)
44 .map(|neighbors| neighbors.contains(node_b))
45 .unwrap_or(false)) ||
46 (graph.neighbors(node_b)
47 .map(|neighbors| neighbors.contains(node_a))
48 .unwrap_or(false));
49
50 let has_bc = (graph.neighbors(node_b)
51 .map(|neighbors| neighbors.contains(node_c))
52 .unwrap_or(false)) ||
53 (graph.neighbors(node_c)
54 .map(|neighbors| neighbors.contains(node_b))
55 .unwrap_or(false));
56
57 let has_ac = (graph.neighbors(node_a)
58 .map(|neighbors| neighbors.contains(node_c))
59 .unwrap_or(false)) ||
60 (graph.neighbors(node_c)
61 .map(|neighbors| neighbors.contains(node_a))
62 .unwrap_or(false));
63
64 if has_ab && has_bc && has_ac {
65 triangle_count += 1;
66 }
67 }
68 }
69 }
70
71 Ok(triangle_count)
72 }
73
74 fn count_triangles_per_node(&self, graph: &ArrowGraph) -> Result<HashMap<String, u64>> {
76 let mut node_triangles: HashMap<String, u64> = HashMap::new();
77 let node_ids: Vec<String> = graph.node_ids().cloned().collect();
78
79 for node_id in graph.node_ids() {
81 node_triangles.insert(node_id.clone(), 0);
82 }
83
84 for i in 0..node_ids.len() {
86 for j in (i + 1)..node_ids.len() {
87 for k in (j + 1)..node_ids.len() {
88 let node_a = &node_ids[i];
89 let node_b = &node_ids[j];
90 let node_c = &node_ids[k];
91
92 let has_ab = (graph.neighbors(node_a)
94 .map(|neighbors| neighbors.contains(node_b))
95 .unwrap_or(false)) ||
96 (graph.neighbors(node_b)
97 .map(|neighbors| neighbors.contains(node_a))
98 .unwrap_or(false));
99
100 let has_bc = (graph.neighbors(node_b)
101 .map(|neighbors| neighbors.contains(node_c))
102 .unwrap_or(false)) ||
103 (graph.neighbors(node_c)
104 .map(|neighbors| neighbors.contains(node_b))
105 .unwrap_or(false));
106
107 let has_ac = (graph.neighbors(node_a)
108 .map(|neighbors| neighbors.contains(node_c))
109 .unwrap_or(false)) ||
110 (graph.neighbors(node_c)
111 .map(|neighbors| neighbors.contains(node_a))
112 .unwrap_or(false));
113
114 if has_ab && has_bc && has_ac {
115 *node_triangles.get_mut(node_a).unwrap() += 1;
117 *node_triangles.get_mut(node_b).unwrap() += 1;
118 *node_triangles.get_mut(node_c).unwrap() += 1;
119 }
120 }
121 }
122 }
123
124 Ok(node_triangles)
125 }
126}
127
128impl GraphAlgorithm for TriangleCount {
129 fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
130 let total_triangles = self.count_triangles(graph)?;
131
132 let schema = Arc::new(Schema::new(vec![
133 Field::new("metric", DataType::Utf8, false),
134 Field::new("value", DataType::UInt64, false),
135 ]));
136
137 RecordBatch::try_new(
138 schema,
139 vec![
140 Arc::new(StringArray::from(vec!["triangle_count"])),
141 Arc::new(UInt64Array::from(vec![total_triangles])),
142 ],
143 ).map_err(GraphError::from)
144 }
145
146 fn name(&self) -> &'static str {
147 "triangle_count"
148 }
149
150 fn description(&self) -> &'static str {
151 "Count the total number of triangles in the graph"
152 }
153}
154
155pub struct ClusteringCoefficient;
156
157impl ClusteringCoefficient {
158 fn calculate_local_clustering(&self, graph: &ArrowGraph) -> Result<HashMap<String, f64>> {
160 let mut clustering: HashMap<String, f64> = HashMap::new();
161 let triangle_counter = TriangleCount;
162 let node_triangles = triangle_counter.count_triangles_per_node(graph)?;
163
164 for node_id in graph.node_ids() {
165 if let Some(neighbors) = graph.neighbors(node_id) {
166 let degree = neighbors.len();
167
168 if degree < 2 {
169 clustering.insert(node_id.clone(), 0.0);
171 } else {
172 let triangles = *node_triangles.get(node_id).unwrap_or(&0);
173 let possible_triangles = (degree * (degree - 1)) / 2;
174 let coefficient = triangles as f64 / possible_triangles as f64;
175 clustering.insert(node_id.clone(), coefficient);
176 }
177 } else {
178 clustering.insert(node_id.clone(), 0.0);
179 }
180 }
181
182 Ok(clustering)
183 }
184
185 fn calculate_global_clustering(&self, graph: &ArrowGraph) -> Result<f64> {
187 let triangle_counter = TriangleCount;
188 let total_triangles = triangle_counter.count_triangles(graph)? as f64;
189
190 let mut total_triples = 0u64;
192
193 for node_id in graph.node_ids() {
194 if let Some(neighbors) = graph.neighbors(node_id) {
195 let degree = neighbors.len();
196 if degree >= 2 {
197 total_triples += (degree * (degree - 1)) as u64 / 2;
200 }
201 }
202 }
203
204 if total_triples == 0 {
205 Ok(0.0)
206 } else {
207 let coefficient = 3.0 * total_triangles / total_triples as f64;
210 Ok(coefficient.min(1.0))
212 }
213 }
214}
215
216impl GraphAlgorithm for ClusteringCoefficient {
217 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
218 let mode: String = params.get("mode").unwrap_or("local".to_string());
219
220 match mode.as_str() {
221 "local" => {
222 let clustering = self.calculate_local_clustering(graph)?;
223
224 if clustering.is_empty() {
225 let schema = Arc::new(Schema::new(vec![
226 Field::new("node_id", DataType::Utf8, false),
227 Field::new("clustering_coefficient", DataType::Float64, false),
228 ]));
229
230 return RecordBatch::try_new(
231 schema,
232 vec![
233 Arc::new(StringArray::from(Vec::<String>::new())),
234 Arc::new(Float64Array::from(Vec::<f64>::new())),
235 ],
236 ).map_err(GraphError::from);
237 }
238
239 let mut sorted_nodes: Vec<(&String, &f64)> = clustering.iter().collect();
241 sorted_nodes.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(std::cmp::Ordering::Equal));
242
243 let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
244 let coefficients: Vec<f64> = sorted_nodes.iter().map(|(_, &coeff)| coeff).collect();
245
246 let schema = Arc::new(Schema::new(vec![
247 Field::new("node_id", DataType::Utf8, false),
248 Field::new("clustering_coefficient", DataType::Float64, false),
249 ]));
250
251 RecordBatch::try_new(
252 schema,
253 vec![
254 Arc::new(StringArray::from(node_ids)),
255 Arc::new(Float64Array::from(coefficients)),
256 ],
257 ).map_err(GraphError::from)
258 },
259 "global" => {
260 let global_coefficient = self.calculate_global_clustering(graph)?;
261
262 let schema = Arc::new(Schema::new(vec![
263 Field::new("metric", DataType::Utf8, false),
264 Field::new("value", DataType::Float64, false),
265 ]));
266
267 RecordBatch::try_new(
268 schema,
269 vec![
270 Arc::new(StringArray::from(vec!["global_clustering_coefficient"])),
271 Arc::new(Float64Array::from(vec![global_coefficient])),
272 ],
273 ).map_err(GraphError::from)
274 },
275 _ => Err(GraphError::invalid_parameter(
276 "mode must be 'local' or 'global'"
277 ))
278 }
279 }
280
281 fn name(&self) -> &'static str {
282 "clustering_coefficient"
283 }
284
285 fn description(&self) -> &'static str {
286 "Calculate local or global clustering coefficient"
287 }
288}