1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, Float64Array, UInt32Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use arrow::array::builder::{ListBuilder, StringBuilder};
5use std::sync::Arc;
6use std::collections::{HashMap, BinaryHeap, VecDeque};
7use std::cmp::Ordering;
8use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
9use crate::graph::ArrowGraph;
10use crate::error::{GraphError, Result};
11
/// Entry in the priority queue used by the Dijkstra search.
///
/// `BinaryHeap` is a max-heap, so the ordering implemented below is reversed
/// on `distance`: the entry with the smallest tentative distance compares as
/// the greatest and is therefore popped first.
#[derive(Debug, Clone)]
struct DijkstraNode {
    /// Identifier of the node this entry refers to.
    node_id: String,
    /// Tentative distance from the search source to `node_id`.
    distance: f64,
    /// Node from which `node_id` was reached, if any.
    previous: Option<String>,
}

// Equality is defined through `cmp` so that `PartialEq`, `Eq`, `PartialOrd`
// and `Ord` always agree with each other, as the `Ord` contract requires.
// (A derived `PartialEq` would make a NaN-distance entry unequal to itself.)
impl PartialEq for DijkstraNode {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraNode {}

impl Ord for DijkstraNode {
    /// Total order: smaller `distance` ranks higher; ties are broken by
    /// `node_id` and then `previous` (smaller ranks higher) so that the
    /// result is `Equal` only for entries that are equal in every field.
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives a total order over all f64 values, NaN included,
        // unlike `partial_cmp`, which has no answer for NaN.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| other.node_id.cmp(&self.node_id))
            .then_with(|| other.previous.cmp(&self.previous))
    }
}

impl PartialOrd for DijkstraNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
33
34pub struct ShortestPath;
35
36impl ShortestPath {
37 fn dijkstra(
39 &self,
40 graph: &ArrowGraph,
41 source: &str,
42 target: Option<&str>,
43 ) -> Result<HashMap<String, (f64, Option<String>)>> {
44 let mut distances: HashMap<String, f64> = HashMap::new();
45 let mut previous: HashMap<String, Option<String>> = HashMap::new();
46 let mut heap = BinaryHeap::new();
47
48 for node_id in graph.node_ids() {
50 let dist = if node_id == source { 0.0 } else { f64::INFINITY };
51 distances.insert(node_id.clone(), dist);
52 previous.insert(node_id.clone(), None);
53 }
54
55 heap.push(DijkstraNode {
56 node_id: source.to_string(),
57 distance: 0.0,
58 previous: None,
59 });
60
61 while let Some(current) = heap.pop() {
62 if let Some(target_node) = target {
64 if current.node_id == target_node {
65 break;
66 }
67 }
68
69 if current.distance > *distances.get(¤t.node_id).unwrap_or(&f64::INFINITY) {
71 continue;
72 }
73
74 if let Some(neighbors) = graph.neighbors(¤t.node_id) {
76 for neighbor in neighbors {
77 let edge_weight = graph.edge_weight(¤t.node_id, neighbor).unwrap_or(1.0);
78 let new_distance = current.distance + edge_weight;
79
80 if new_distance < *distances.get(neighbor).unwrap_or(&f64::INFINITY) {
81 distances.insert(neighbor.clone(), new_distance);
82 previous.insert(neighbor.clone(), Some(current.node_id.clone()));
83
84 heap.push(DijkstraNode {
85 node_id: neighbor.clone(),
86 distance: new_distance,
87 previous: Some(current.node_id.clone()),
88 });
89 }
90 }
91 }
92 }
93
94 let mut result = HashMap::new();
96 for node_id in graph.node_ids() {
97 let dist = *distances.get(node_id).unwrap_or(&f64::INFINITY);
98 let prev = previous.get(node_id).cloned().flatten();
99 result.insert(node_id.clone(), (dist, prev));
100 }
101
102 Ok(result)
103 }
104
105 fn reconstruct_path(
107 &self,
108 target: &str,
109 previous: &HashMap<String, Option<String>>,
110 ) -> Vec<String> {
111 let mut path = Vec::new();
112 let mut current = Some(target.to_string());
113
114 while let Some(node) = current {
115 path.push(node.clone());
116 current = previous.get(&node).cloned().flatten();
117 }
118
119 path.reverse();
120 path
121 }
122}
123
124impl GraphAlgorithm for ShortestPath {
125 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
126 let source: String = params.get("source")
127 .ok_or_else(|| GraphError::invalid_parameter("source parameter required"))?;
128
129 let target: Option<String> = params.get("target");
130
131 match target {
132 Some(target_node) => {
133 let results = self.dijkstra(graph, &source, Some(&target_node))?;
135
136 if let Some((distance, _)) = results.get(&target_node) {
137 if distance.is_infinite() {
138 return Err(GraphError::algorithm("No path found between source and target"));
139 }
140
141 let path = self.reconstruct_path(&target_node, &results.iter()
142 .map(|(k, (_, prev))| (k.clone(), prev.clone()))
143 .collect());
144
145 let schema = Arc::new(Schema::new(vec![
147 Field::new("source", DataType::Utf8, false),
148 Field::new("target", DataType::Utf8, false),
149 Field::new("distance", DataType::Float64, false),
150 Field::new("path", DataType::List(
151 Arc::new(Field::new("item", DataType::Utf8, true))
152 ), false),
153 ]));
154
155 let mut list_builder = ListBuilder::new(StringBuilder::new());
157 for node in &path {
158 list_builder.values().append_value(node);
159 }
160 list_builder.append(true);
161 let path_array = list_builder.finish();
162
163 RecordBatch::try_new(
164 schema,
165 vec![
166 Arc::new(StringArray::from(vec![source])),
167 Arc::new(StringArray::from(vec![target_node])),
168 Arc::new(Float64Array::from(vec![*distance])),
169 Arc::new(path_array),
170 ],
171 ).map_err(GraphError::from)
172 } else {
173 Err(GraphError::node_not_found(target_node))
174 }
175 }
176 None => {
177 let results = self.dijkstra(graph, &source, None)?;
179
180 let mut targets = Vec::new();
181 let mut distances = Vec::new();
182
183 for (node_id, (distance, _)) in results.iter() {
184 if node_id != &source && !distance.is_infinite() {
185 targets.push(node_id.clone());
186 distances.push(*distance);
187 }
188 }
189
190 let schema = Arc::new(Schema::new(vec![
191 Field::new("source", DataType::Utf8, false),
192 Field::new("target", DataType::Utf8, false),
193 Field::new("distance", DataType::Float64, false),
194 ]));
195
196 let sources = vec![source; targets.len()];
197
198 RecordBatch::try_new(
199 schema,
200 vec![
201 Arc::new(StringArray::from(sources)),
202 Arc::new(StringArray::from(targets)),
203 Arc::new(Float64Array::from(distances)),
204 ],
205 ).map_err(GraphError::from)
206 }
207 }
208 }
209
210 fn name(&self) -> &'static str {
211 "shortest_path"
212 }
213
214 fn description(&self) -> &'static str {
215 "Find the shortest path between nodes using Dijkstra's algorithm"
216 }
217}
218
219pub struct AllPaths;
220
221impl AllPaths {
222 fn find_all_paths(
224 &self,
225 graph: &ArrowGraph,
226 source: &str,
227 target: &str,
228 max_hops: usize,
229 ) -> Result<Vec<Vec<String>>> {
230 let mut all_paths = Vec::new();
231 let mut queue = VecDeque::new();
232
233 queue.push_back((vec![source.to_string()], 0));
235
236 while let Some((current_path, hops)) = queue.pop_front() {
237 let current_node = current_path.last().unwrap();
238
239 if current_node == target {
240 all_paths.push(current_path);
241 continue;
242 }
243
244 if hops >= max_hops {
245 continue;
246 }
247
248 if let Some(neighbors) = graph.neighbors(current_node) {
249 for neighbor in neighbors {
250 if !current_path.contains(neighbor) {
252 let mut new_path = current_path.clone();
253 new_path.push(neighbor.clone());
254 queue.push_back((new_path, hops + 1));
255 }
256 }
257 }
258 }
259
260 Ok(all_paths)
261 }
262}
263
264impl GraphAlgorithm for AllPaths {
265 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
266 let source: String = params.get("source")
267 .ok_or_else(|| GraphError::invalid_parameter("source parameter required"))?;
268
269 let target: String = params.get("target")
270 .ok_or_else(|| GraphError::invalid_parameter("target parameter required"))?;
271
272 let max_hops: usize = params.get("max_hops").unwrap_or(10);
273
274 let paths = self.find_all_paths(graph, &source, &target, max_hops)?;
275
276 let schema = Arc::new(Schema::new(vec![
277 Field::new("source", DataType::Utf8, false),
278 Field::new("target", DataType::Utf8, false),
279 Field::new("path_length", DataType::UInt32, false),
280 Field::new("path", DataType::List(
281 Arc::new(Field::new("item", DataType::Utf8, true))
282 ), false),
283 ]));
284
285 let mut sources = Vec::new();
286 let mut targets = Vec::new();
287 let mut path_lengths = Vec::new();
288 let mut path_arrays = Vec::new();
289
290 for path in paths {
291 sources.push(source.clone());
292 targets.push(target.clone());
293 path_lengths.push(path.len() as u32 - 1); let path_values: Vec<Option<String>> = path.into_iter().map(Some).collect();
296 path_arrays.push(Some(path_values));
297 }
298
299 let mut list_builder = ListBuilder::new(StringBuilder::new());
301 for path_values in path_arrays {
302 if let Some(path) = path_values {
303 for node in path {
304 if let Some(node_str) = node {
305 list_builder.values().append_value(&node_str);
306 }
307 }
308 list_builder.append(true);
309 } else {
310 list_builder.append(false);
311 }
312 }
313 let list_array = list_builder.finish();
314
315 RecordBatch::try_new(
316 schema,
317 vec![
318 Arc::new(StringArray::from(sources)),
319 Arc::new(StringArray::from(targets)),
320 Arc::new(UInt32Array::from(path_lengths)),
321 Arc::new(list_array),
322 ],
323 ).map_err(GraphError::from)
324 }
325
326 fn name(&self) -> &'static str {
327 "all_paths"
328 }
329
330 fn description(&self) -> &'static str {
331 "Find all paths between two nodes with optional hop limit"
332 }
333}