// reddb_server/storage/engine/projection.rs

//! Graph Projections for RedDB
//!
//! Provides graph projection capabilities similar to Neo4j GDS:
//! - Native projections: Copy subgraph with filtering
//! - Cypher/query-based projections: Project from traversal results
//! - Property projections: Select specific properties
//! - Aggregated relationships: Combine parallel edges
//!
//! Projections create lightweight views over the graph for efficient
//! algorithm execution without modifying the original data.

use std::collections::{HashMap, HashSet};

use super::graph_store::{GraphStore, StoredNode};

16// ============================================================================
17// Projection Filter Predicates
18// ============================================================================
19
20/// Node filter specification
21#[derive(Clone, Default)]
22pub struct NodeFilter {
23    /// Include only nodes whose category label matches one of these strings.
24    pub labels: Option<Vec<String>>,
25    /// Include only nodes with these IDs
26    pub ids: Option<HashSet<String>>,
27}
28
29impl NodeFilter {
30    /// Create an empty filter (include all nodes)
31    pub fn all() -> Self {
32        Self::default()
33    }
34
35    /// Filter by node category labels (string form).
36    pub fn with_labels<I, S>(mut self, labels: I) -> Self
37    where
38        I: IntoIterator<Item = S>,
39        S: Into<String>,
40    {
41        self.labels = Some(labels.into_iter().map(Into::into).collect());
42        self
43    }
44
45    /// Filter by node IDs
46    pub fn with_ids(mut self, ids: HashSet<String>) -> Self {
47        self.ids = Some(ids);
48        self
49    }
50
51    /// Check if a node matches this filter
52    pub fn matches(&self, node: &StoredNode) -> bool {
53        if let Some(ref labels) = self.labels {
54            if !labels.iter().any(|l| l == node.node_type.as_str()) {
55                return false;
56            }
57        }
58
59        if let Some(ref ids) = self.ids {
60            if !ids.contains(&node.id) {
61                return false;
62            }
63        }
64
65        true
66    }
67}
68
/// Edge filter specification.
///
/// Each criterion that is `Some` must hold for an edge to pass; a criterion
/// left as `None` is not applied. The default value applies no criteria.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct EdgeFilter {
    /// Include only edges whose label matches one of these strings.
    pub edge_types: Option<Vec<String>>,
    /// Minimum edge weight (inclusive).
    pub min_weight: Option<f32>,
    /// Maximum edge weight (inclusive).
    pub max_weight: Option<f32>,
}

impl EdgeFilter {
    /// Create an empty filter (include all edges).
    pub fn all() -> Self {
        Self::default()
    }

    /// Restrict the filter to edges whose label (string form) is one of
    /// `types`, replacing any label list set earlier.
    pub fn with_types<I, S>(mut self, types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.edge_types = Some(types.into_iter().map(Into::into).collect());
        self
    }

    /// Require edge weights to be at least `weight`.
    pub fn with_min_weight(mut self, weight: f32) -> Self {
        self.min_weight = Some(weight);
        self
    }

    /// Require edge weights to be at most `weight`.
    pub fn with_max_weight(mut self, weight: f32) -> Self {
        self.max_weight = Some(weight);
        self
    }

    /// Check if an edge label/weight matches this filter.
    ///
    /// Both weight bounds are inclusive. A NaN `weight` is rejected as soon as
    /// either bound is configured; without bounds the weight is not inspected.
    pub fn matches(&self, edge_label: &str, weight: f32) -> bool {
        let type_ok = match &self.edge_types {
            Some(types) => types.iter().any(|t| t == edge_label),
            None => true,
        };
        if !type_ok {
            return false;
        }

        // NaN compares false against every bound, so without this guard it
        // would slip through both range checks below.
        let bounded = self.min_weight.is_some() || self.max_weight.is_some();
        if bounded && weight.is_nan() {
            return false;
        }

        let below_min = matches!(self.min_weight, Some(min) if weight < min);
        let above_max = matches!(self.max_weight, Some(max) if weight > max);
        !(below_min || above_max)
    }
}

// ============================================================================
// Property Projection
// ============================================================================

/// Specifies which properties to include in the projection.
///
/// The derived `Default` has both flags off, i.e. it equals
/// [`PropertyProjection::minimal`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct PropertyProjection {
    /// Whether projected nodes carry their category label. When `false`,
    /// `GraphProjection::native` leaves `ProjectedNode::category` as `None`;
    /// the node's display `label` is copied either way.
    pub include_label: bool,
    /// Whether to include edge weight.
    // NOTE(review): nothing in the visible part of this file reads this flag;
    // projected edges always carry their weight. Confirm the intended effect.
    pub include_weight: bool,
}

impl PropertyProjection {
    /// Include all properties.
    pub fn all() -> Self {
        Self {
            include_label: true,
            include_weight: true,
        }
    }

    /// Create a minimal projection with every optional property switched off.
    pub fn minimal() -> Self {
        Self {
            include_label: false,
            include_weight: false,
        }
    }
}

// ============================================================================
// Edge Aggregation
// ============================================================================

/// Strategy for aggregating parallel edges.
///
/// Parallel edges are edges sharing the same `(source, target)` pair. Every
/// variant except [`AggregationStrategy::None`] collapses such a group into a
/// single edge that keeps the label of the first edge in the group.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// Keep all edges (no aggregation).
    None,
    /// Keep only one edge, use sum of weights.
    SumWeight,
    /// Keep only one edge, use average weight.
    AvgWeight,
    /// Keep only one edge, use minimum weight.
    MinWeight,
    /// Keep only one edge, use maximum weight.
    MaxWeight,
    /// Keep only one edge whose weight is the number of parallel edges.
    Count,
}

185// ============================================================================
186// Graph Projection
187// ============================================================================
188
189/// A projected view of a graph
190///
191/// Projections are lightweight copies optimized for algorithm execution.
192/// They don't modify the original graph.
193pub struct GraphProjection {
194    /// Projected nodes (id → node)
195    nodes: HashMap<String, ProjectedNode>,
196    /// Outgoing edges (source_id → [(target_id, edge_label, weight)])
197    outgoing: HashMap<String, Vec<(String, String, f32)>>,
198    /// Incoming edges (target_id → [(source_id, edge_label, weight)])
199    incoming: HashMap<String, Vec<(String, String, f32)>>,
200    /// Projection statistics
201    stats: ProjectionStats,
202}
203
204/// A projected node with minimal data for algorithms
205#[derive(Clone, Debug)]
206pub struct ProjectedNode {
207    pub id: String,
208    pub label: String,
209    /// Optional category label (string form). `None` when the projection
210    /// asked for property-only nodes.
211    pub category: Option<String>,
212}
213
214/// Statistics about the projection
215#[derive(Clone, Debug, Default)]
216pub struct ProjectionStats {
217    /// Number of nodes in projection
218    pub node_count: usize,
219    /// Number of edges in projection
220    pub edge_count: usize,
221    /// Number of nodes filtered out
222    pub nodes_filtered: usize,
223    /// Number of edges filtered out
224    pub edges_filtered: usize,
225    /// Number of edges aggregated
226    pub edges_aggregated: usize,
227}
228
229impl GraphProjection {
230    /// Create a native projection from a graph with filters
231    pub fn native(
232        graph: &GraphStore,
233        node_filter: NodeFilter,
234        edge_filter: EdgeFilter,
235        property_projection: PropertyProjection,
236        aggregation: AggregationStrategy,
237    ) -> Self {
238        let mut nodes: HashMap<String, ProjectedNode> = HashMap::new();
239        let mut outgoing: HashMap<String, Vec<(String, String, f32)>> = HashMap::new();
240        let mut incoming: HashMap<String, Vec<(String, String, f32)>> = HashMap::new();
241        let mut stats = ProjectionStats::default();
242
243        // Collect matching nodes
244        let mut node_ids: HashSet<String> = HashSet::new();
245        for node in graph.iter_nodes() {
246            if node_filter.matches(&node) {
247                let projected = ProjectedNode {
248                    id: node.id.clone(),
249                    label: node.label.clone(),
250                    category: if property_projection.include_label {
251                        Some(node.node_type.as_str().to_string())
252                    } else {
253                        None
254                    },
255                };
256                node_ids.insert(node.id.clone());
257                nodes.insert(node.id.clone(), projected);
258                stats.node_count += 1;
259            } else {
260                stats.nodes_filtered += 1;
261            }
262        }
263
264        // Collect matching edges (both endpoints must be in projection)
265        // Group edges by (source, target) for potential aggregation
266        let mut edge_groups: HashMap<(String, String), Vec<(String, f32)>> = HashMap::new();
267
268        for node_id in &node_ids {
269            for (edge_type, target, weight) in graph.outgoing_edges(node_id) {
270                if !node_ids.contains(&target) {
271                    continue;
272                }
273
274                let edge_label = edge_type.as_str().to_string();
275                if edge_filter.matches(&edge_label, weight) {
276                    let key = (node_id.clone(), target);
277                    edge_groups
278                        .entry(key)
279                        .or_default()
280                        .push((edge_label, weight));
281                } else {
282                    stats.edges_filtered += 1;
283                }
284            }
285        }
286
287        // Apply aggregation
288        for ((source, target), edges) in edge_groups {
289            match aggregation {
290                AggregationStrategy::None => {
291                    // Keep all edges
292                    for (edge_type, weight) in edges {
293                        outgoing.entry(source.clone()).or_default().push((
294                            target.clone(),
295                            edge_type.clone(),
296                            weight,
297                        ));
298                        incoming.entry(target.clone()).or_default().push((
299                            source.clone(),
300                            edge_type,
301                            weight,
302                        ));
303                        stats.edge_count += 1;
304                    }
305                }
306                _ => {
307                    // Aggregate to single edge
308                    if let Some((first_type, _)) = edges.first().cloned() {
309                        let weight = match aggregation {
310                            AggregationStrategy::SumWeight => edges.iter().map(|(_, w)| w).sum(),
311                            AggregationStrategy::AvgWeight => {
312                                let sum: f32 = edges.iter().map(|(_, w)| w).sum();
313                                sum / edges.len() as f32
314                            }
315                            AggregationStrategy::MinWeight => {
316                                edges.iter().map(|(_, w)| *w).fold(f32::INFINITY, f32::min)
317                            }
318                            AggregationStrategy::MaxWeight => edges
319                                .iter()
320                                .map(|(_, w)| *w)
321                                .fold(f32::NEG_INFINITY, f32::max),
322                            AggregationStrategy::Count => edges.len() as f32,
323                            AggregationStrategy::None => unreachable!(),
324                        };
325
326                        if edges.len() > 1 {
327                            stats.edges_aggregated += edges.len() - 1;
328                        }
329
330                        outgoing.entry(source.clone()).or_default().push((
331                            target.clone(),
332                            first_type.clone(),
333                            weight,
334                        ));
335                        incoming
336                            .entry(target)
337                            .or_default()
338                            .push((source, first_type, weight));
339                        stats.edge_count += 1;
340                    }
341                }
342            }
343        }
344
345        Self {
346            nodes,
347            outgoing,
348            incoming,
349            stats,
350        }
351    }
352
353    /// Create a projection from a list of node IDs (induced subgraph)
354    pub fn from_nodes(graph: &GraphStore, node_ids: &[String]) -> Self {
355        let id_set: HashSet<String> = node_ids.iter().cloned().collect();
356        let node_filter = NodeFilter::all().with_ids(id_set);
357        Self::native(
358            graph,
359            node_filter,
360            EdgeFilter::all(),
361            PropertyProjection::all(),
362            AggregationStrategy::None,
363        )
364    }
365
366    /// Create a projection from traversal path results
367    pub fn from_paths(graph: &GraphStore, paths: &[Vec<String>]) -> Self {
368        let mut node_ids: HashSet<String> = HashSet::new();
369        for path in paths {
370            node_ids.extend(path.iter().cloned());
371        }
372        let node_filter = NodeFilter::all().with_ids(node_ids);
373        Self::native(
374            graph,
375            node_filter,
376            EdgeFilter::all(),
377            PropertyProjection::all(),
378            AggregationStrategy::None,
379        )
380    }
381
382    /// Create an undirected projection (each edge becomes bidirectional)
383    pub fn undirected(
384        graph: &GraphStore,
385        node_filter: NodeFilter,
386        edge_filter: EdgeFilter,
387    ) -> Self {
388        let mut projection = Self::native(
389            graph,
390            node_filter,
391            edge_filter,
392            PropertyProjection::all(),
393            AggregationStrategy::SumWeight,
394        );
395
396        // Add reverse edges
397        let mut additional: Vec<(String, String, String, f32)> = Vec::new();
398
399        for (source, edges) in &projection.outgoing {
400            for (target, edge_type, weight) in edges {
401                // Check if reverse edge already exists
402                let has_reverse = projection
403                    .outgoing
404                    .get(target)
405                    .map(|e| e.iter().any(|(t, _, _)| t == source))
406                    .unwrap_or(false);
407
408                if !has_reverse {
409                    additional.push((target.clone(), source.clone(), edge_type.clone(), *weight));
410                }
411            }
412        }
413
414        for (source, target, edge_type, weight) in additional {
415            projection
416                .outgoing
417                .entry(source.clone())
418                .or_default()
419                .push((target.clone(), edge_type.clone(), weight));
420            projection
421                .incoming
422                .entry(target)
423                .or_default()
424                .push((source, edge_type, weight));
425            projection.stats.edge_count += 1;
426        }
427
428        projection
429    }
430
431    /// Get projection statistics
432    pub fn stats(&self) -> &ProjectionStats {
433        &self.stats
434    }
435
436    /// Get number of nodes
437    pub fn node_count(&self) -> usize {
438        self.nodes.len()
439    }
440
441    /// Get number of edges
442    pub fn edge_count(&self) -> usize {
443        self.stats.edge_count
444    }
445
446    /// Get a node by ID
447    pub fn get_node(&self, id: &str) -> Option<&ProjectedNode> {
448        self.nodes.get(id)
449    }
450
451    /// Check if node exists
452    pub fn has_node(&self, id: &str) -> bool {
453        self.nodes.contains_key(id)
454    }
455
456    /// Iterate over all nodes
457    pub fn iter_nodes(&self) -> impl Iterator<Item = &ProjectedNode> {
458        self.nodes.values()
459    }
460
461    /// Get node IDs
462    pub fn node_ids(&self) -> impl Iterator<Item = &String> {
463        self.nodes.keys()
464    }
465
466    /// Get outgoing edges from a node `(target_id, edge_label, weight)`.
467    pub fn outgoing(&self, node_id: &str) -> &[(String, String, f32)] {
468        self.outgoing
469            .get(node_id)
470            .map(|v| v.as_slice())
471            .unwrap_or(&[])
472    }
473
474    /// Get incoming edges to a node `(source_id, edge_label, weight)`.
475    pub fn incoming(&self, node_id: &str) -> &[(String, String, f32)] {
476        self.incoming
477            .get(node_id)
478            .map(|v| v.as_slice())
479            .unwrap_or(&[])
480    }
481
482    /// Get out-degree of a node
483    pub fn out_degree(&self, node_id: &str) -> usize {
484        self.outgoing.get(node_id).map(|v| v.len()).unwrap_or(0)
485    }
486
487    /// Get in-degree of a node
488    pub fn in_degree(&self, node_id: &str) -> usize {
489        self.incoming.get(node_id).map(|v| v.len()).unwrap_or(0)
490    }
491
492    /// Get neighbors (outgoing targets)
493    pub fn neighbors(&self, node_id: &str) -> Vec<&str> {
494        self.outgoing
495            .get(node_id)
496            .map(|edges| edges.iter().map(|(t, _, _)| t.as_str()).collect())
497            .unwrap_or_default()
498    }
499
500    /// Get neighbors with weights
501    pub fn neighbors_weighted(&self, node_id: &str) -> Vec<(&str, f32)> {
502        self.outgoing
503            .get(node_id)
504            .map(|edges| edges.iter().map(|(t, _, w)| (t.as_str(), *w)).collect())
505            .unwrap_or_default()
506    }
507
508    /// Get all neighbors (both directions)
509    pub fn all_neighbors(&self, node_id: &str) -> HashSet<&str> {
510        let mut neighbors: HashSet<&str> = HashSet::new();
511
512        if let Some(edges) = self.outgoing.get(node_id) {
513            for (target, _, _) in edges {
514                neighbors.insert(target.as_str());
515            }
516        }
517
518        if let Some(edges) = self.incoming.get(node_id) {
519            for (source, _, _) in edges {
520                neighbors.insert(source.as_str());
521            }
522        }
523
524        neighbors
525    }
526}
527
528// ============================================================================
529// Projection Builder
530// ============================================================================
531
532/// Builder for creating graph projections with fluent API
533pub struct ProjectionBuilder<'a> {
534    graph: &'a GraphStore,
535    node_filter: NodeFilter,
536    edge_filter: EdgeFilter,
537    property_projection: PropertyProjection,
538    aggregation: AggregationStrategy,
539    undirected: bool,
540}
541
542impl<'a> ProjectionBuilder<'a> {
543    /// Create a new projection builder
544    pub fn new(graph: &'a GraphStore) -> Self {
545        Self {
546            graph,
547            node_filter: NodeFilter::all(),
548            edge_filter: EdgeFilter::all(),
549            property_projection: PropertyProjection::all(),
550            aggregation: AggregationStrategy::None,
551            undirected: false,
552        }
553    }
554
555    /// Filter nodes by category label (string form).
556    pub fn with_node_labels<I, S>(mut self, labels: I) -> Self
557    where
558        I: IntoIterator<Item = S>,
559        S: Into<String>,
560    {
561        self.node_filter = self.node_filter.with_labels(labels);
562        self
563    }
564
565    /// Filter nodes by IDs
566    pub fn with_node_ids(mut self, ids: HashSet<String>) -> Self {
567        self.node_filter = self.node_filter.with_ids(ids);
568        self
569    }
570
571    /// Filter edges by category label (string form).
572    pub fn with_edge_types<I, S>(mut self, types: I) -> Self
573    where
574        I: IntoIterator<Item = S>,
575        S: Into<String>,
576    {
577        self.edge_filter = self.edge_filter.with_types(types);
578        self
579    }
580
581    /// Filter edges by minimum weight
582    pub fn with_min_weight(mut self, weight: f32) -> Self {
583        self.edge_filter = self.edge_filter.with_min_weight(weight);
584        self
585    }
586
587    /// Filter edges by maximum weight
588    pub fn with_max_weight(mut self, weight: f32) -> Self {
589        self.edge_filter = self.edge_filter.with_max_weight(weight);
590        self
591    }
592
593    /// Set edge aggregation strategy
594    pub fn aggregate(mut self, strategy: AggregationStrategy) -> Self {
595        self.aggregation = strategy;
596        self
597    }
598
599    /// Make the projection undirected
600    pub fn undirected(mut self) -> Self {
601        self.undirected = true;
602        self
603    }
604
605    /// Build the projection
606    pub fn build(self) -> GraphProjection {
607        if self.undirected {
608            GraphProjection::undirected(self.graph, self.node_filter, self.edge_filter)
609        } else {
610            GraphProjection::native(
611                self.graph,
612                self.node_filter,
613                self.edge_filter,
614                self.property_projection,
615                self.aggregation,
616            )
617        }
618    }
619}
620
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Four nodes (two `host`, two `service`) joined by five directed edges:
    /// A->B (1.0), A->C (2.0), B->C (1.5, `auth_access`), B->D (1.0) and
    /// C->D (0.5). Every edge except B->C is labelled `connects_to`.
    fn create_test_graph() -> GraphStore {
        let graph = GraphStore::new();

        let _ = graph.add_node_with_label("A", "Server A", "host");
        let _ = graph.add_node_with_label("B", "Server B", "host");
        let _ = graph.add_node_with_label("C", "DB Server", "service");
        let _ = graph.add_node_with_label("D", "Web Server", "service");

        let _ = graph.add_edge_with_label("A", "B", "connects_to", 1.0);
        let _ = graph.add_edge_with_label("A", "C", "connects_to", 2.0);
        let _ = graph.add_edge_with_label("B", "C", "auth_access", 1.5);
        let _ = graph.add_edge_with_label("B", "D", "connects_to", 1.0);
        let _ = graph.add_edge_with_label("C", "D", "connects_to", 0.5);

        graph
    }

    /// Unfiltered, unaggregated projection with all properties.
    fn full_projection(graph: &GraphStore) -> GraphProjection {
        GraphProjection::native(
            graph,
            NodeFilter::all(),
            EdgeFilter::all(),
            PropertyProjection::all(),
            AggregationStrategy::None,
        )
    }

    #[test]
    fn test_full_projection() {
        let graph = create_test_graph();
        let projection = full_projection(&graph);

        assert_eq!(projection.node_count(), 4);
        assert_eq!(projection.edge_count(), 5);
    }

    #[test]
    fn test_node_label_filter() {
        let graph = create_test_graph();
        let projection = GraphProjection::native(
            &graph,
            NodeFilter::all().with_labels(["host"]),
            EdgeFilter::all(),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        assert_eq!(projection.node_count(), 2); // A and B
        assert!(projection.has_node("A"));
        assert!(projection.has_node("B"));
        assert!(!projection.has_node("C"));
        assert!(!projection.has_node("D"));
    }

    #[test]
    fn test_edge_type_filter() {
        let graph = create_test_graph();
        let projection = GraphProjection::native(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all().with_types(["connects_to"]),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        // A->B, A->C, B->D, C->D are "connects_to"; B->C is "auth_access"
        assert_eq!(projection.edge_count(), 4);
    }

    #[test]
    fn test_weight_filter() {
        let graph = create_test_graph();
        let projection = GraphProjection::native(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all().with_min_weight(1.0),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        // Edges with weight >= 1.0: A->B(1.0), A->C(2.0), B->C(1.5), B->D(1.0)
        assert_eq!(projection.edge_count(), 4);
    }

    #[test]
    fn test_max_weight_filter() {
        let graph = create_test_graph();
        let projection = GraphProjection::native(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all().with_max_weight(1.0),
            PropertyProjection::all(),
            AggregationStrategy::None,
        );

        // Edges with weight <= 1.0: A->B(1.0), B->D(1.0), C->D(0.5)
        assert_eq!(projection.edge_count(), 3);
    }

    #[test]
    fn test_count_aggregation() {
        let graph = create_test_graph();
        let projection = GraphProjection::native(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all(),
            PropertyProjection::all(),
            AggregationStrategy::Count,
        );

        // No parallel edges in the fixture: nothing merges, every weight is 1.
        assert_eq!(projection.edge_count(), 5);
        assert_eq!(projection.stats().edges_aggregated, 0);
        let mut from_a = projection.neighbors_weighted("A");
        from_a.sort_by(|left, right| left.0.cmp(right.0));
        assert_eq!(from_a, vec![("B", 1.0), ("C", 1.0)]);
    }

    #[test]
    fn test_minimal_property_projection() {
        let graph = create_test_graph();
        let projection = GraphProjection::native(
            &graph,
            NodeFilter::all(),
            EdgeFilter::all(),
            PropertyProjection::minimal(),
            AggregationStrategy::None,
        );

        assert_eq!(projection.node_count(), 4);
        assert!(projection.iter_nodes().all(|node| node.category.is_none()));
    }

    #[test]
    fn test_projection_builder() {
        let graph = create_test_graph();
        let projection = ProjectionBuilder::new(&graph)
            .with_node_labels(["service"])
            .build();

        assert_eq!(projection.node_count(), 2); // C and D
    }

    #[test]
    fn test_undirected_projection() {
        let graph = create_test_graph();
        let projection = ProjectionBuilder::new(&graph).undirected().build();

        // Each edge should be traversable in both directions
        assert!(projection.neighbors("A").contains(&"B"));
        // Reverse edge should also exist
        let b_neighbors = projection.neighbors("B");
        assert!(b_neighbors.contains(&"A")); // Reverse of A->B

        // None of the five edges had a reverse counterpart, so each is mirrored.
        assert_eq!(projection.edge_count(), 10);
        for id in ["A", "B", "C", "D"] {
            assert_eq!(projection.out_degree(id), projection.in_degree(id));
        }
    }

    #[test]
    fn test_from_nodes() {
        let graph = create_test_graph();
        let projection = GraphProjection::from_nodes(&graph, &["A".to_string(), "B".to_string()]);

        assert_eq!(projection.node_count(), 2);
        // Only edge A->B should be included
        assert_eq!(projection.edge_count(), 1);
    }

    #[test]
    fn test_from_paths() {
        let graph = create_test_graph();
        let paths = vec![
            vec!["A".to_string(), "B".to_string()],
            vec!["B".to_string(), "D".to_string()],
        ];
        let projection = GraphProjection::from_paths(&graph, &paths);

        assert_eq!(projection.node_count(), 3); // A, B and D
        assert!(!projection.has_node("C"));
        // Induced edges among {A, B, D}: A->B and B->D
        assert_eq!(projection.edge_count(), 2);
    }

    #[test]
    fn test_neighbors() {
        let graph = create_test_graph();
        let projection = full_projection(&graph);

        let a_neighbors = projection.neighbors("A");
        assert!(a_neighbors.contains(&"B"));
        assert!(a_neighbors.contains(&"C"));
        assert_eq!(a_neighbors.len(), 2);
    }

    #[test]
    fn test_degrees() {
        let graph = create_test_graph();
        let projection = full_projection(&graph);

        assert_eq!(projection.out_degree("A"), 2); // A -> B, C
        assert_eq!(projection.in_degree("D"), 2); // B, C -> D
        assert_eq!(projection.out_degree("D"), 0); // D has no outgoing
    }
}