1use crate::error::Result;
2use crate::streaming::incremental::{IncrementalGraphProcessor, UpdateResult};
3use arrow::array::Array;
4use std::collections::HashMap;
5
/// Contract for algorithms that maintain a result of type `T` over a graph
/// receiving incremental updates via an `IncrementalGraphProcessor`.
pub trait StreamingAlgorithm<T> {
    /// Compute the result from scratch for the processor's current graph.
    fn initialize(&mut self, processor: &IncrementalGraphProcessor) -> Result<()>;

    /// Fold a batch of graph `changes` into the existing result; an
    /// implementation may fall back to full recomputation internally.
    fn update(&mut self, processor: &IncrementalGraphProcessor, changes: &UpdateResult) -> Result<()>;

    /// Borrow the current result.
    fn get_result(&self) -> &T;

    /// Discard incremental state and recompute from the current graph.
    fn recompute(&mut self, processor: &IncrementalGraphProcessor) -> Result<()>;

    /// Whether `changes` are too large or structural for an incremental update.
    fn needs_recomputation(&self, changes: &UpdateResult) -> bool;
}
24
/// Incrementally-maintained PageRank scores for a streaming graph.
#[derive(Debug, Clone)]
pub struct StreamingPageRank {
    scores: HashMap<String, f64>, // node ID -> current PageRank score
    damping_factor: f64,          // probability of following a link (vs. teleporting)
    max_iterations: usize,        // cap on power iterations per (re)computation
    tolerance: f64,               // convergence threshold on max per-node score delta
    iteration_count: usize,       // total iterations performed so far
    converged: bool,              // whether the last iteration met `tolerance`
}
35
36impl StreamingPageRank {
37 pub fn new(damping_factor: f64, max_iterations: usize, tolerance: f64) -> Self {
38 Self {
39 scores: HashMap::new(),
40 damping_factor,
41 max_iterations,
42 tolerance,
43 iteration_count: 0,
44 converged: false,
45 }
46 }
47
48 pub fn default() -> Self {
50 Self::new(0.85, 50, 1e-6)
51 }
52
53 fn iterate(&mut self, adjacency: &HashMap<String, Vec<(String, f64)>>, nodes: &[String]) -> Result<bool> {
55 let node_count = nodes.len() as f64;
56 let base_score = (1.0 - self.damping_factor) / node_count;
57
58 let mut new_scores = HashMap::new();
59
60 for node in nodes {
62 new_scores.insert(node.clone(), base_score);
63 }
64
65 for (source, targets) in adjacency {
67 let source_score = self.scores.get(source).copied().unwrap_or(1.0 / node_count);
68 let out_degree = targets.len() as f64;
69
70 if out_degree > 0.0 {
71 let contribution_per_link = self.damping_factor * source_score / out_degree;
72
73 for (target, _weight) in targets {
74 *new_scores.entry(target.clone()).or_insert(base_score) += contribution_per_link;
75 }
76 }
77 }
78
79 let mut max_change: f64 = 0.0;
81 for (node, new_score) in &new_scores {
82 let old_score = self.scores.get(node).copied().unwrap_or(1.0 / node_count);
83 let change = (new_score - old_score).abs();
84 max_change = max_change.max(change);
85 }
86
87 self.scores = new_scores;
88 self.iteration_count += 1;
89
90 let converged = max_change < self.tolerance;
91 self.converged = converged;
92
93 Ok(converged)
94 }
95
96 pub fn top_nodes(&self, k: usize) -> Vec<(String, f64)> {
98 let mut node_scores: Vec<(String, f64)> = self.scores.iter()
99 .map(|(node, score)| (node.clone(), *score))
100 .collect();
101
102 node_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
103 node_scores.truncate(k);
104 node_scores
105 }
106
107 pub fn node_score(&self, node_id: &str) -> Option<f64> {
109 self.scores.get(node_id).copied()
110 }
111}
112
113impl StreamingAlgorithm<HashMap<String, f64>> for StreamingPageRank {
114 fn initialize(&mut self, processor: &IncrementalGraphProcessor) -> Result<()> {
115 let graph = processor.graph();
117 let nodes_batch = &graph.nodes;
118 let edges_batch = &graph.edges;
119
120 let mut nodes = Vec::new();
122 if nodes_batch.num_rows() > 0 {
123 let node_ids = nodes_batch.column(0)
124 .as_any()
125 .downcast_ref::<arrow::array::StringArray>()
126 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for node IDs"))?;
127
128 for i in 0..node_ids.len() {
129 nodes.push(node_ids.value(i).to_string());
130 }
131 }
132
133 let mut adjacency: HashMap<String, Vec<(String, f64)>> = HashMap::new();
135 if edges_batch.num_rows() > 0 {
136 let source_ids = edges_batch.column(0)
137 .as_any()
138 .downcast_ref::<arrow::array::StringArray>()
139 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for source IDs"))?;
140 let target_ids = edges_batch.column(1)
141 .as_any()
142 .downcast_ref::<arrow::array::StringArray>()
143 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for target IDs"))?;
144 let weights = edges_batch.column(2)
145 .as_any()
146 .downcast_ref::<arrow::array::Float64Array>()
147 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected float64 array for weights"))?;
148
149 for i in 0..source_ids.len() {
150 let source = source_ids.value(i).to_string();
151 let target = target_ids.value(i).to_string();
152 let weight = weights.value(i);
153
154 adjacency.entry(source).or_insert_with(Vec::new).push((target, weight));
155 }
156 }
157
158 let node_count = nodes.len() as f64;
160 if node_count > 0.0 {
161 let initial_score = 1.0 / node_count;
162 for node in &nodes {
163 self.scores.insert(node.clone(), initial_score);
164 }
165
166 for _ in 0..self.max_iterations {
168 if self.iterate(&adjacency, &nodes)? {
169 break;
170 }
171 }
172 }
173
174 Ok(())
175 }
176
177 fn update(&mut self, processor: &IncrementalGraphProcessor, changes: &UpdateResult) -> Result<()> {
178 if self.needs_recomputation(changes) {
180 return self.recompute(processor);
181 }
182
183 let graph = processor.graph();
186 let nodes_batch = &graph.nodes;
187 let edges_batch = &graph.edges;
188
189 if changes.vertices_added > 0 {
191 let node_count = processor.graph().node_count() as f64;
192 let initial_score = 1.0 / node_count;
193
194 for score in self.scores.values_mut() {
196 *score *= (node_count - changes.vertices_added as f64) / node_count;
197 }
198
199 if nodes_batch.num_rows() > 0 {
201 let node_ids = nodes_batch.column(0)
202 .as_any()
203 .downcast_ref::<arrow::array::StringArray>()
204 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for node IDs"))?;
205
206 for i in 0..node_ids.len() {
207 let node = node_ids.value(i).to_string();
208 self.scores.entry(node).or_insert(initial_score);
209 }
210 }
211 }
212
213 if changes.vertices_removed > 0 {
215 let mut valid_nodes = std::collections::HashSet::new();
218 if nodes_batch.num_rows() > 0 {
219 let node_ids = nodes_batch.column(0)
220 .as_any()
221 .downcast_ref::<arrow::array::StringArray>()
222 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for node IDs"))?;
223
224 for i in 0..node_ids.len() {
225 valid_nodes.insert(node_ids.value(i).to_string());
226 }
227 }
228
229 self.scores.retain(|node, _| valid_nodes.contains(node));
230 }
231
232 if changes.edges_added > 0 || changes.edges_removed > 0 {
234 let mut adjacency: HashMap<String, Vec<(String, f64)>> = HashMap::new();
236 if edges_batch.num_rows() > 0 {
237 let source_ids = edges_batch.column(0)
238 .as_any()
239 .downcast_ref::<arrow::array::StringArray>()
240 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for source IDs"))?;
241 let target_ids = edges_batch.column(1)
242 .as_any()
243 .downcast_ref::<arrow::array::StringArray>()
244 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for target IDs"))?;
245 let weights = edges_batch.column(2)
246 .as_any()
247 .downcast_ref::<arrow::array::Float64Array>()
248 .ok_or_else(|| crate::error::GraphError::graph_construction("Expected float64 array for weights"))?;
249
250 for i in 0..source_ids.len() {
251 let source = source_ids.value(i).to_string();
252 let target = target_ids.value(i).to_string();
253 let weight = weights.value(i);
254
255 adjacency.entry(source).or_insert_with(Vec::new).push((target, weight));
256 }
257 }
258
259 let nodes: Vec<String> = self.scores.keys().cloned().collect();
260
261 let update_iterations = std::cmp::min(10, self.max_iterations);
263 for _ in 0..update_iterations {
264 if self.iterate(&adjacency, &nodes)? {
265 break;
266 }
267 }
268 }
269
270 Ok(())
271 }
272
273 fn get_result(&self) -> &HashMap<String, f64> {
274 &self.scores
275 }
276
277 fn recompute(&mut self, processor: &IncrementalGraphProcessor) -> Result<()> {
278 self.scores.clear();
279 self.iteration_count = 0;
280 self.converged = false;
281 self.initialize(processor)
282 }
283
284 fn needs_recomputation(&self, changes: &UpdateResult) -> bool {
285 let total_changes = changes.vertices_added + changes.vertices_removed +
287 changes.edges_added + changes.edges_removed;
288
289 total_changes > 10 || !self.converged
290 }
291}
292
/// Incrementally-maintained connected components for a streaming graph.
#[derive(Debug, Clone)]
pub struct StreamingConnectedComponents {
    components: HashMap<String, String>,     // node ID -> its component's representative ID
    component_sizes: HashMap<String, usize>, // representative ID -> node count of that component
}
300impl StreamingConnectedComponents {
301 pub fn new() -> Self {
302 Self {
303 components: HashMap::new(),
304 component_sizes: HashMap::new(),
305 }
306 }
307
308 pub fn component_of(&self, node_id: &str) -> Option<&String> {
310 self.components.get(node_id)
311 }
312
313 pub fn component_size(&self, node_id: &str) -> Option<usize> {
315 self.components.get(node_id)
316 .and_then(|comp_id| self.component_sizes.get(comp_id))
317 .copied()
318 }
319
320 pub fn all_components(&self) -> Vec<(String, usize)> {
322 self.component_sizes.iter()
323 .map(|(id, size)| (id.clone(), *size))
324 .collect()
325 }
326
327 pub fn component_count(&self) -> usize {
329 self.component_sizes.len()
330 }
331
332 #[allow(dead_code)]
334 fn find_root(&self, mut node: String, temp_parents: &mut HashMap<String, String>) -> String {
335 let mut path = Vec::new();
336
337 while let Some(parent) = temp_parents.get(&node).or_else(|| self.components.get(&node)) {
339 if parent == &node {
340 break; }
342 path.push(node.clone());
343 node = parent.clone();
344 }
345
346 for path_node in path {
348 temp_parents.insert(path_node, node.clone());
349 }
350
351 node
352 }
353
354 fn union_components(&mut self, node1: &str, node2: &str) {
356 let comp1 = self.components.get(node1).cloned().unwrap_or_else(|| node1.to_string());
357 let comp2 = self.components.get(node2).cloned().unwrap_or_else(|| node2.to_string());
358
359 if comp1 == comp2 {
360 return; }
362
363 let size1 = self.component_sizes.get(&comp1).copied().unwrap_or(1);
365 let size2 = self.component_sizes.get(&comp2).copied().unwrap_or(1);
366
367 let (smaller, larger, new_size) = if size1 <= size2 {
368 (comp1, comp2, size1 + size2)
369 } else {
370 (comp2, comp1, size1 + size2)
371 };
372
373 let nodes_to_update: Vec<String> = self.components.iter()
375 .filter(|(_, comp)| *comp == &smaller)
376 .map(|(node, _)| node.clone())
377 .collect();
378
379 for node in nodes_to_update {
380 self.components.insert(node, larger.clone());
381 }
382
383 self.component_sizes.insert(larger, new_size);
385 self.component_sizes.remove(&smaller);
386 }
387}
388
/// `Default` simply delegates to [`StreamingConnectedComponents::new`].
impl Default for StreamingConnectedComponents {
    fn default() -> Self {
        Self::new()
    }
}
394
impl StreamingAlgorithm<HashMap<String, String>> for StreamingConnectedComponents {
    /// Rebuild components from scratch: every node starts as its own
    /// singleton component, then each edge unions its endpoints
    /// (connectivity here ignores edge direction — both endpoints merge).
    fn initialize(&mut self, processor: &IncrementalGraphProcessor) -> Result<()> {
        let graph = processor.graph();
        let nodes_batch = &graph.nodes;
        let edges_batch = &graph.edges;

        self.components.clear();
        self.component_sizes.clear();

        // Seed: every node is a singleton component rooted at itself.
        if nodes_batch.num_rows() > 0 {
            let node_ids = nodes_batch.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for node IDs"))?;

            for i in 0..node_ids.len() {
                let node = node_ids.value(i).to_string();
                self.components.insert(node.clone(), node.clone());
                self.component_sizes.insert(node, 1);
            }
        }

        // Merge components along every edge.
        if edges_batch.num_rows() > 0 {
            let source_ids = edges_batch.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for source IDs"))?;
            let target_ids = edges_batch.column(1)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for target IDs"))?;

            for i in 0..source_ids.len() {
                let source = source_ids.value(i);
                let target = target_ids.value(i);
                self.union_components(source, target);
            }
        }

        Ok(())
    }

    /// Apply incremental changes. Additions are handled in place; any kind
    /// of removal triggers a full recomputation, because unions cannot be
    /// undone incrementally (a removed edge/vertex may split a component).
    fn update(&mut self, processor: &IncrementalGraphProcessor, changes: &UpdateResult) -> Result<()> {
        if self.needs_recomputation(changes) {
            return self.recompute(processor);
        }

        let graph = processor.graph();
        let nodes_batch = &graph.nodes;
        let edges_batch = &graph.edges;

        // New vertices: register each unknown node as its own singleton.
        if changes.vertices_added > 0 {
            if nodes_batch.num_rows() > 0 {
                let node_ids = nodes_batch.column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>()
                    .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for node IDs"))?;

                for i in 0..node_ids.len() {
                    let node = node_ids.value(i).to_string();
                    if !self.components.contains_key(&node) {
                        self.components.insert(node.clone(), node.clone());
                        self.component_sizes.insert(node, 1);
                    }
                }
            }
        }

        // Defensive: `needs_recomputation` already returns true when
        // vertices_removed > 0, so this branch is normally unreachable.
        if changes.vertices_removed > 0 {
            return self.recompute(processor);
        }

        // New edges: union the endpoints, creating singleton components
        // for any endpoint we have not seen before.
        if changes.edges_added > 0 {
            if edges_batch.num_rows() > 0 {
                let source_ids = edges_batch.column(0)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>()
                    .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for source IDs"))?;
                let target_ids = edges_batch.column(1)
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>()
                    .ok_or_else(|| crate::error::GraphError::graph_construction("Expected string array for target IDs"))?;

                for i in 0..source_ids.len() {
                    let source = source_ids.value(i);
                    let target = target_ids.value(i);

                    if !self.components.contains_key(source) {
                        self.components.insert(source.to_string(), source.to_string());
                        self.component_sizes.insert(source.to_string(), 1);
                    }
                    if !self.components.contains_key(target) {
                        self.components.insert(target.to_string(), target.to_string());
                        self.component_sizes.insert(target.to_string(), 1);
                    }

                    self.union_components(source, target);
                }
            }
        }

        // Defensive: also normally unreachable (see `needs_recomputation`).
        if changes.edges_removed > 0 {
            return self.recompute(processor);
        }

        Ok(())
    }

    /// Borrow the node -> component-representative map.
    fn get_result(&self) -> &HashMap<String, String> {
        &self.components
    }

    /// Recompute from scratch (initialize clears existing state itself).
    fn recompute(&mut self, processor: &IncrementalGraphProcessor) -> Result<()> {
        self.initialize(processor)
    }

    /// Any removal forces recomputation (components may split); otherwise
    /// recompute only when the change batch is large (> 20 total changes).
    fn needs_recomputation(&self, changes: &UpdateResult) -> bool {
        let total_changes = changes.vertices_added + changes.vertices_removed +
            changes.edges_added + changes.edges_removed;

        changes.vertices_removed > 0 || changes.edges_removed > 0 || total_changes > 20
    }
}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::ArrowGraph;
    use arrow::array::{StringArray, Float64Array};
    use arrow::record_batch::RecordBatch;
    use arrow::datatypes::{Schema, Field, DataType};
    use std::sync::Arc;

    /// Four nodes A..D with edges A->B and B->C; D starts isolated.
    fn create_test_graph() -> Result<ArrowGraph> {
        let nodes_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
        ]));
        let node_ids = StringArray::from(vec!["A", "B", "C", "D"]);
        let nodes_batch = RecordBatch::try_new(
            nodes_schema,
            vec![Arc::new(node_ids)],
        )?;

        let edges_schema = Arc::new(Schema::new(vec![
            Field::new("source", DataType::Utf8, false),
            Field::new("target", DataType::Utf8, false),
            Field::new("weight", DataType::Float64, false),
        ]));
        let sources = StringArray::from(vec!["A", "B"]);
        let targets = StringArray::from(vec!["B", "C"]);
        let weights = Float64Array::from(vec![1.0, 1.0]);
        let edges_batch = RecordBatch::try_new(
            edges_schema,
            vec![Arc::new(sources), Arc::new(targets), Arc::new(weights)],
        )?;

        ArrowGraph::new(nodes_batch, edges_batch)
    }

    #[test]
    fn test_streaming_pagerank_initialization() {
        let graph = create_test_graph().unwrap();
        let processor = IncrementalGraphProcessor::new(graph).unwrap();

        let mut pagerank = StreamingPageRank::default();
        pagerank.initialize(&processor).unwrap();

        let scores = pagerank.get_result();
        assert_eq!(scores.len(), 4);
        for node in ["A", "B", "C", "D"] {
            assert!(scores.contains_key(node));
            assert!(scores[node] > 0.0);
        }

        // Nodes with incoming edges should outrank pure sources/isolates.
        assert!(scores["B"] > scores["A"]);
        assert!(scores["C"] > scores["D"]);
    }

    #[test]
    fn test_streaming_pagerank_update() {
        let graph = create_test_graph().unwrap();
        let mut processor = IncrementalGraphProcessor::new(graph).unwrap();
        processor.set_batch_size(1); // flush each change immediately

        let mut pagerank = StreamingPageRank::default();
        pagerank.initialize(&processor).unwrap();

        let initial_scores = pagerank.get_result().clone();

        processor.add_edge("A".to_string(), "D".to_string(), 1.0).unwrap();

        let update_result = crate::streaming::incremental::UpdateResult {
            vertices_added: 0,
            vertices_removed: 0,
            edges_added: 1,
            edges_removed: 0,
            affected_components: vec![],
            recomputation_needed: false,
        };

        pagerank.update(&processor, &update_result).unwrap();

        let updated_scores = pagerank.get_result();

        // D gained an incoming edge, so its score should rise.
        assert!(updated_scores["D"] > initial_scores["D"]);
    }

    #[test]
    fn test_streaming_connected_components_initialization() {
        let graph = create_test_graph().unwrap();
        let processor = IncrementalGraphProcessor::new(graph).unwrap();

        let mut components = StreamingConnectedComponents::new();
        components.initialize(&processor).unwrap();

        let result = components.get_result();
        assert_eq!(result.len(), 4);

        // A, B, C are linked (A->B->C) and must share a component.
        let comp_a = &result["A"];
        let comp_b = &result["B"];
        let comp_c = &result["C"];
        assert_eq!(comp_a, comp_b);
        assert_eq!(comp_b, comp_c);

        // D is isolated and must sit in its own component.
        let comp_d = &result["D"];
        assert_ne!(comp_a, comp_d);

        assert_eq!(components.component_count(), 2);
    }

    #[test]
    fn test_streaming_connected_components_update() {
        let graph = create_test_graph().unwrap();
        let mut processor = IncrementalGraphProcessor::new(graph).unwrap();
        processor.set_batch_size(1); // flush each change immediately

        let mut components = StreamingConnectedComponents::new();
        components.initialize(&processor).unwrap();

        assert_eq!(components.component_count(), 2);

        // Connecting C to D should merge the two components into one.
        processor.add_edge("C".to_string(), "D".to_string(), 1.0).unwrap();

        let update_result = crate::streaming::incremental::UpdateResult {
            vertices_added: 0,
            vertices_removed: 0,
            edges_added: 1,
            edges_removed: 0,
            affected_components: vec![],
            recomputation_needed: false,
        };

        components.update(&processor, &update_result).unwrap();

        assert_eq!(components.component_count(), 1);

        let result = components.get_result();
        let comp_a = &result["A"];
        let comp_d = &result["D"];
        assert_eq!(comp_a, comp_d);
    }

    #[test]
    fn test_streaming_algorithm_recomputation() {
        let graph = create_test_graph().unwrap();
        let mut processor = IncrementalGraphProcessor::new(graph).unwrap();

        let mut pagerank = StreamingPageRank::default();
        pagerank.initialize(&processor).unwrap();

        // A change batch this large (50 total) must force a recompute.
        let large_changes = crate::streaming::incremental::UpdateResult {
            vertices_added: 15,
            vertices_removed: 5,
            edges_added: 20,
            edges_removed: 10,
            affected_components: vec![],
            recomputation_needed: true,
        };

        assert!(pagerank.needs_recomputation(&large_changes));

        pagerank.update(&processor, &large_changes).unwrap();

        // Recompute runs against the unchanged 4-node graph.
        let scores = pagerank.get_result();
        assert_eq!(scores.len(), 4);
    }

    #[test]
    fn test_pagerank_top_nodes() {
        let graph = create_test_graph().unwrap();
        let processor = IncrementalGraphProcessor::new(graph).unwrap();

        let mut pagerank = StreamingPageRank::default();
        pagerank.initialize(&processor).unwrap();

        let top_2 = pagerank.top_nodes(2);
        assert_eq!(top_2.len(), 2);

        // Results must come back in descending score order.
        assert!(top_2[0].1 >= top_2[1].1);

        assert!(pagerank.node_score("A").is_some());
        assert!(pagerank.node_score("nonexistent").is_none());
    }

    #[test]
    fn test_connected_components_queries() {
        let graph = create_test_graph().unwrap();
        let processor = IncrementalGraphProcessor::new(graph).unwrap();

        let mut components = StreamingConnectedComponents::new();
        components.initialize(&processor).unwrap();

        assert!(components.component_of("A").is_some());
        assert!(components.component_of("nonexistent").is_none());

        // {A, B, C} has size 3; {D} has size 1.
        assert!(components.component_size("A").is_some());
        assert_eq!(components.component_size("A"), Some(3));
        assert_eq!(components.component_size("D"), Some(1));

        let all_components = components.all_components();
        assert_eq!(all_components.len(), 2);

        // Component sizes must account for every node exactly once.
        let total_size: usize = all_components.iter().map(|(_, size)| size).sum();
        assert_eq!(total_size, 4);
    }
}