arrow_graph/algorithms/
pathfinding.rs

1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, Float64Array, UInt32Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use arrow::array::builder::{ListBuilder, StringBuilder};
5use std::sync::Arc;
6use std::collections::{HashMap, BinaryHeap, VecDeque};
7use std::cmp::Ordering;
8use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
9use crate::graph::ArrowGraph;
10use crate::error::{GraphError, Result};
11
/// Priority-queue entry used by Dijkstra's algorithm.
///
/// The ordering is deliberately *reversed* on `distance` so that
/// `std::collections::BinaryHeap` (a max-heap) pops the entry with the
/// smallest distance first.
#[derive(Debug, Clone)]
struct DijkstraNode {
    /// Identifier of the node this entry refers to.
    node_id: String,
    /// Tentative distance from the search source to `node_id`.
    distance: f64,
    /// Node from which `node_id` was reached, if any.
    previous: Option<String>,
}

impl PartialEq for DijkstraNode {
    /// Equality is defined through `cmp` so that `a == b` holds exactly when
    /// `a.cmp(&b) == Ordering::Equal`, as the `Ord` contract requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraNode {}

impl Ord for DijkstraNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives a genuine total order over f64 (NaN included), which
        // `BinaryHeap` relies on. Operands are swapped to obtain min-heap behavior.
        other
            .distance
            .total_cmp(&self.distance)
            // Tie-breakers (also reversed, so the lexicographically smaller id is
            // popped first) keep the ordering deterministic and consistent with `eq`.
            .then_with(|| other.node_id.cmp(&self.node_id))
            .then_with(|| other.previous.cmp(&self.previous))
    }
}

impl PartialOrd for DijkstraNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
33
34pub struct ShortestPath;
35
36impl ShortestPath {
37    /// Dijkstra's algorithm implementation optimized for Arrow
38    fn dijkstra(
39        &self,
40        graph: &ArrowGraph,
41        source: &str,
42        target: Option<&str>,
43    ) -> Result<HashMap<String, (f64, Option<String>)>> {
44        let mut distances: HashMap<String, f64> = HashMap::new();
45        let mut previous: HashMap<String, Option<String>> = HashMap::new();
46        let mut heap = BinaryHeap::new();
47        
48        // Initialize distances
49        for node_id in graph.node_ids() {
50            let dist = if node_id == source { 0.0 } else { f64::INFINITY };
51            distances.insert(node_id.clone(), dist);
52            previous.insert(node_id.clone(), None);
53        }
54        
55        heap.push(DijkstraNode {
56            node_id: source.to_string(),
57            distance: 0.0,
58            previous: None,
59        });
60        
61        while let Some(current) = heap.pop() {
62            // Early termination if we're looking for a specific target
63            if let Some(target_node) = target {
64                if current.node_id == target_node {
65                    break;
66                }
67            }
68            
69            // Skip if we've already found a better path
70            if current.distance > *distances.get(&current.node_id).unwrap_or(&f64::INFINITY) {
71                continue;
72            }
73            
74            // Check neighbors
75            if let Some(neighbors) = graph.neighbors(&current.node_id) {
76                for neighbor in neighbors {
77                    let edge_weight = graph.edge_weight(&current.node_id, neighbor).unwrap_or(1.0);
78                    let new_distance = current.distance + edge_weight;
79                    
80                    if new_distance < *distances.get(neighbor).unwrap_or(&f64::INFINITY) {
81                        distances.insert(neighbor.clone(), new_distance);
82                        previous.insert(neighbor.clone(), Some(current.node_id.clone()));
83                        
84                        heap.push(DijkstraNode {
85                            node_id: neighbor.clone(),
86                            distance: new_distance,
87                            previous: Some(current.node_id.clone()),
88                        });
89                    }
90                }
91            }
92        }
93        
94        // Combine results
95        let mut result = HashMap::new();
96        for node_id in graph.node_ids() {
97            let dist = *distances.get(node_id).unwrap_or(&f64::INFINITY);
98            let prev = previous.get(node_id).cloned().flatten();
99            result.insert(node_id.clone(), (dist, prev));
100        }
101        
102        Ok(result)
103    }
104    
105    /// Reconstruct path from source to target
106    fn reconstruct_path(
107        &self,
108        target: &str,
109        previous: &HashMap<String, Option<String>>,
110    ) -> Vec<String> {
111        let mut path = Vec::new();
112        let mut current = Some(target.to_string());
113        
114        while let Some(node) = current {
115            path.push(node.clone());
116            current = previous.get(&node).cloned().flatten();
117        }
118        
119        path.reverse();
120        path
121    }
122}
123
124impl GraphAlgorithm for ShortestPath {
125    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
126        let source: String = params.get("source")
127            .ok_or_else(|| GraphError::invalid_parameter("source parameter required"))?;
128        
129        let target: Option<String> = params.get("target");
130        
131        match target {
132            Some(target_node) => {
133                // Single source-target shortest path
134                let results = self.dijkstra(graph, &source, Some(&target_node))?;
135                
136                if let Some((distance, _)) = results.get(&target_node) {
137                    if distance.is_infinite() {
138                        return Err(GraphError::algorithm("No path found between source and target"));
139                    }
140                    
141                    let path = self.reconstruct_path(&target_node, &results.iter()
142                        .map(|(k, (_, prev))| (k.clone(), prev.clone()))
143                        .collect());
144                    
145                    // Create Arrow RecordBatch with path and distance
146                    let schema = Arc::new(Schema::new(vec![
147                        Field::new("source", DataType::Utf8, false),
148                        Field::new("target", DataType::Utf8, false),
149                        Field::new("distance", DataType::Float64, false),
150                        Field::new("path", DataType::List(
151                            Arc::new(Field::new("item", DataType::Utf8, true))
152                        ), false),
153                    ]));
154                    
155                    // Build path array using ListBuilder
156                    let mut list_builder = ListBuilder::new(StringBuilder::new());
157                    for node in &path {
158                        list_builder.values().append_value(node);
159                    }
160                    list_builder.append(true);
161                    let path_array = list_builder.finish();
162                    
163                    RecordBatch::try_new(
164                        schema,
165                        vec![
166                            Arc::new(StringArray::from(vec![source])),
167                            Arc::new(StringArray::from(vec![target_node])),
168                            Arc::new(Float64Array::from(vec![*distance])),
169                            Arc::new(path_array),
170                        ],
171                    ).map_err(GraphError::from)
172                } else {
173                    Err(GraphError::node_not_found(target_node))
174                }
175            }
176            None => {
177                // Single source shortest paths to all nodes
178                let results = self.dijkstra(graph, &source, None)?;
179                
180                let mut targets = Vec::new();
181                let mut distances = Vec::new();
182                
183                for (node_id, (distance, _)) in results.iter() {
184                    if node_id != &source && !distance.is_infinite() {
185                        targets.push(node_id.clone());
186                        distances.push(*distance);
187                    }
188                }
189                
190                let schema = Arc::new(Schema::new(vec![
191                    Field::new("source", DataType::Utf8, false),
192                    Field::new("target", DataType::Utf8, false),
193                    Field::new("distance", DataType::Float64, false),
194                ]));
195                
196                let sources = vec![source; targets.len()];
197                
198                RecordBatch::try_new(
199                    schema,
200                    vec![
201                        Arc::new(StringArray::from(sources)),
202                        Arc::new(StringArray::from(targets)),
203                        Arc::new(Float64Array::from(distances)),
204                    ],
205                ).map_err(GraphError::from)
206            }
207        }
208    }
209    
210    fn name(&self) -> &'static str {
211        "shortest_path"
212    }
213    
214    fn description(&self) -> &'static str {
215        "Find the shortest path between nodes using Dijkstra's algorithm"
216    }
217}
218
219pub struct AllPaths;
220
221impl AllPaths {
222    /// BFS-based all paths implementation with hop limit
223    fn find_all_paths(
224        &self,
225        graph: &ArrowGraph,
226        source: &str,
227        target: &str,
228        max_hops: usize,
229    ) -> Result<Vec<Vec<String>>> {
230        let mut all_paths = Vec::new();
231        let mut queue = VecDeque::new();
232        
233        // Start with the source node
234        queue.push_back((vec![source.to_string()], 0));
235        
236        while let Some((current_path, hops)) = queue.pop_front() {
237            let current_node = current_path.last().unwrap();
238            
239            if current_node == target {
240                all_paths.push(current_path);
241                continue;
242            }
243            
244            if hops >= max_hops {
245                continue;
246            }
247            
248            if let Some(neighbors) = graph.neighbors(current_node) {
249                for neighbor in neighbors {
250                    // Avoid cycles
251                    if !current_path.contains(neighbor) {
252                        let mut new_path = current_path.clone();
253                        new_path.push(neighbor.clone());
254                        queue.push_back((new_path, hops + 1));
255                    }
256                }
257            }
258        }
259        
260        Ok(all_paths)
261    }
262}
263
264impl GraphAlgorithm for AllPaths {
265    fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
266        let source: String = params.get("source")
267            .ok_or_else(|| GraphError::invalid_parameter("source parameter required"))?;
268        
269        let target: String = params.get("target")
270            .ok_or_else(|| GraphError::invalid_parameter("target parameter required"))?;
271        
272        let max_hops: usize = params.get("max_hops").unwrap_or(10);
273        
274        let paths = self.find_all_paths(graph, &source, &target, max_hops)?;
275        
276        let schema = Arc::new(Schema::new(vec![
277            Field::new("source", DataType::Utf8, false),
278            Field::new("target", DataType::Utf8, false),
279            Field::new("path_length", DataType::UInt32, false),
280            Field::new("path", DataType::List(
281                Arc::new(Field::new("item", DataType::Utf8, true))
282            ), false),
283        ]));
284        
285        let mut sources = Vec::new();
286        let mut targets = Vec::new();
287        let mut path_lengths = Vec::new();
288        let mut path_arrays = Vec::new();
289        
290        for path in paths {
291            sources.push(source.clone());
292            targets.push(target.clone());
293            path_lengths.push(path.len() as u32 - 1); // Number of edges
294            
295            let path_values: Vec<Option<String>> = path.into_iter().map(Some).collect();
296            path_arrays.push(Some(path_values));
297        }
298        
299        // Build all paths using ListBuilder
300        let mut list_builder = ListBuilder::new(StringBuilder::new());
301        for path_values in path_arrays {
302            if let Some(path) = path_values {
303                for node in path {
304                    if let Some(node_str) = node {
305                        list_builder.values().append_value(&node_str);
306                    }
307                }
308                list_builder.append(true);
309            } else {
310                list_builder.append(false);
311            }
312        }
313        let list_array = list_builder.finish();
314        
315        RecordBatch::try_new(
316            schema,
317            vec![
318                Arc::new(StringArray::from(sources)),
319                Arc::new(StringArray::from(targets)),
320                Arc::new(UInt32Array::from(path_lengths)),
321                Arc::new(list_array),
322            ],
323        ).map_err(GraphError::from)
324    }
325    
326    fn name(&self) -> &'static str {
327        "all_paths"
328    }
329    
330    fn description(&self) -> &'static str {
331        "Find all paths between two nodes with optional hop limit"
332    }
333}