// reddb_server/storage/engine/graph_table_index.rs

//! Bidirectional Graph-Table Index
//!
//! Enables unified graph/table queries by maintaining two mappings:
//! - `node_id` → `(table_id, row_id)`
//! - `(table_id, row_id)` → `node_id`
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────┐
//! │                      GraphTableIndex                        │
//! ├─────────────────────────────────────────────────────────────┤
//! │  NodeToRow index (16 shards)   RowToNode index (16 shards)  │
//! │  ┌────┐┌────┐┌────┐ ...        ┌────┐┌────┐┌────┐ ...       │
//! │  │ S0 ││ S1 ││ S2 │            │ S0 ││ S1 ││ S2 │           │
//! │  └────┘└────┘└────┘            └────┘└────┘└────┘           │
//! │      │                             │                        │
//! │      ▼                             ▼                        │
//! │  node_id → TableRef            RowKey → node_id             │
//! └─────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Thread Safety
//!
//! Each direction is a sharded map, and every shard has its own `RwLock`:
//! - Any number of readers may hold a shard at once, and operations on
//!   different shards never contend.
//! - A writer blocks only the shard it is modifying.
//! - The shard for a key is chosen by FNV-1a hashing the key.
//!
//! The two directions are separate maps with separate locks. Operations that
//! touch both (linking and unlinking) are therefore not atomic as a whole. A
//! concurrent reader may briefly observe one direction updated before the other.

30use std::collections::HashMap;
31use std::sync::RwLock;
32
33use super::graph_store::TableRef;
34
/// How many independently locked shards each directional index is split into.
const NUM_SHARDS: usize = 16;
37
/// Computes the 64-bit FNV-1a hash of `data`.
///
/// It is only used to pick a shard, so a cheap, allocation-free hash is
/// preferred over a collision-resistant one.
fn fnv_hash(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    data.iter()
        .fold(OFFSET_BASIS, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}
50
51/// Composite key for row lookups: (table_id, row_id)
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
53pub struct RowKey {
54    pub table_id: u16,
55    pub row_id: u64,
56}
57
58impl RowKey {
59    pub fn new(table_id: u16, row_id: u64) -> Self {
60        Self { table_id, row_id }
61    }
62
63    pub fn from_table_ref(tref: &TableRef) -> Self {
64        Self {
65            table_id: tref.table_id,
66            row_id: tref.row_id,
67        }
68    }
69
70    /// Convert to bytes for hashing
71    fn to_bytes(&self) -> [u8; 10] {
72        let mut buf = [0u8; 10];
73        buf[0..2].copy_from_slice(&self.table_id.to_le_bytes());
74        buf[2..10].copy_from_slice(&self.row_id.to_le_bytes());
75        buf
76    }
77}
78
79/// Sharded index for node_id → TableRef
80struct NodeToRowIndex {
81    shards: Vec<RwLock<HashMap<String, TableRef>>>,
82}
83
84impl NodeToRowIndex {
85    fn new() -> Self {
86        let mut shards = Vec::with_capacity(NUM_SHARDS);
87        for _ in 0..NUM_SHARDS {
88            shards.push(RwLock::new(HashMap::new()));
89        }
90        Self { shards }
91    }
92
93    fn shard_for(&self, node_id: &str) -> usize {
94        (fnv_hash(node_id.as_bytes()) as usize) % NUM_SHARDS
95    }
96
97    fn insert(&self, node_id: String, table_ref: TableRef) {
98        let shard = self.shard_for(&node_id);
99        if let Ok(mut map) = self.shards[shard].write() {
100            map.insert(node_id, table_ref);
101        }
102    }
103
104    fn get(&self, node_id: &str) -> Option<TableRef> {
105        let shard = self.shard_for(node_id);
106        if let Ok(map) = self.shards[shard].read() {
107            map.get(node_id).copied()
108        } else {
109            None
110        }
111    }
112
113    fn remove(&self, node_id: &str) -> Option<TableRef> {
114        let shard = self.shard_for(node_id);
115        if let Ok(mut map) = self.shards[shard].write() {
116            map.remove(node_id)
117        } else {
118            None
119        }
120    }
121
122    fn contains(&self, node_id: &str) -> bool {
123        let shard = self.shard_for(node_id);
124        if let Ok(map) = self.shards[shard].read() {
125            map.contains_key(node_id)
126        } else {
127            false
128        }
129    }
130
131    fn len(&self) -> usize {
132        self.shards
133            .iter()
134            .filter_map(|s| s.read().ok())
135            .map(|m| m.len())
136            .sum()
137    }
138}
139
140/// Sharded index for (table_id, row_id) → node_id
141struct RowToNodeIndex {
142    shards: Vec<RwLock<HashMap<RowKey, String>>>,
143}
144
145impl RowToNodeIndex {
146    fn new() -> Self {
147        let mut shards = Vec::with_capacity(NUM_SHARDS);
148        for _ in 0..NUM_SHARDS {
149            shards.push(RwLock::new(HashMap::new()));
150        }
151        Self { shards }
152    }
153
154    fn shard_for(&self, key: &RowKey) -> usize {
155        (fnv_hash(&key.to_bytes()) as usize) % NUM_SHARDS
156    }
157
158    fn insert(&self, key: RowKey, node_id: String) {
159        let shard = self.shard_for(&key);
160        if let Ok(mut map) = self.shards[shard].write() {
161            map.insert(key, node_id);
162        }
163    }
164
165    fn get(&self, key: &RowKey) -> Option<String> {
166        let shard = self.shard_for(key);
167        if let Ok(map) = self.shards[shard].read() {
168            map.get(key).cloned()
169        } else {
170            None
171        }
172    }
173
174    fn remove(&self, key: &RowKey) -> Option<String> {
175        let shard = self.shard_for(key);
176        if let Ok(mut map) = self.shards[shard].write() {
177            map.remove(key)
178        } else {
179            None
180        }
181    }
182
183    fn contains(&self, key: &RowKey) -> bool {
184        let shard = self.shard_for(key);
185        if let Ok(map) = self.shards[shard].read() {
186            map.contains_key(key)
187        } else {
188            false
189        }
190    }
191
192    /// Get all nodes for a specific table
193    fn nodes_for_table(&self, table_id: u16) -> Vec<(u64, String)> {
194        let mut results = Vec::new();
195        for shard in &self.shards {
196            if let Ok(map) = shard.read() {
197                for (key, node_id) in map.iter() {
198                    if key.table_id == table_id {
199                        results.push((key.row_id, node_id.clone()));
200                    }
201                }
202            }
203        }
204        results
205    }
206
207    fn len(&self) -> usize {
208        self.shards
209            .iter()
210            .filter_map(|s| s.read().ok())
211            .map(|m| m.len())
212            .sum()
213    }
214}
215
216/// Bidirectional index for graph-table linkage
217///
218/// Enables efficient lookups in both directions:
219/// - From graph node to table row
220/// - From table row to graph node
221///
222/// Thread-safe with sharded locking for concurrent access.
223pub struct GraphTableIndex {
224    node_to_row: NodeToRowIndex,
225    row_to_node: RowToNodeIndex,
226}
227
228impl GraphTableIndex {
229    /// Create a new empty index
230    pub fn new() -> Self {
231        Self {
232            node_to_row: NodeToRowIndex::new(),
233            row_to_node: RowToNodeIndex::new(),
234        }
235    }
236
237    /// Link a graph node to a table row
238    ///
239    /// Creates bidirectional mapping between node_id and (table_id, row_id).
240    /// Overwrites existing mappings if present.
241    pub fn link(&self, node_id: &str, table_id: u16, row_id: u64) {
242        let table_ref = TableRef::new(table_id, row_id);
243        let row_key = RowKey::new(table_id, row_id);
244
245        self.node_to_row.insert(node_id.to_string(), table_ref);
246        self.row_to_node.insert(row_key, node_id.to_string());
247    }
248
249    /// Unlink a graph node from its table row
250    ///
251    /// Removes both directions of the mapping.
252    /// Returns the TableRef if it existed.
253    pub fn unlink_node(&self, node_id: &str) -> Option<TableRef> {
254        if let Some(table_ref) = self.node_to_row.remove(node_id) {
255            let row_key = RowKey::from_table_ref(&table_ref);
256            self.row_to_node.remove(&row_key);
257            Some(table_ref)
258        } else {
259            None
260        }
261    }
262
263    /// Unlink a table row from its graph node
264    ///
265    /// Removes both directions of the mapping.
266    /// Returns the node_id if it existed.
267    pub fn unlink_row(&self, table_id: u16, row_id: u64) -> Option<String> {
268        let row_key = RowKey::new(table_id, row_id);
269        if let Some(node_id) = self.row_to_node.remove(&row_key) {
270            self.node_to_row.remove(&node_id);
271            Some(node_id)
272        } else {
273            None
274        }
275    }
276
277    /// Get the table row for a graph node
278    pub fn get_row_for_node(&self, node_id: &str) -> Option<TableRef> {
279        self.node_to_row.get(node_id)
280    }
281
282    /// Get the graph node for a table row
283    pub fn get_node_for_row(&self, table_id: u16, row_id: u64) -> Option<String> {
284        let row_key = RowKey::new(table_id, row_id);
285        self.row_to_node.get(&row_key)
286    }
287
288    /// Check if a node is linked to a table row
289    pub fn is_node_linked(&self, node_id: &str) -> bool {
290        self.node_to_row.contains(node_id)
291    }
292
293    /// Check if a table row is linked to a graph node
294    pub fn is_row_linked(&self, table_id: u16, row_id: u64) -> bool {
295        let row_key = RowKey::new(table_id, row_id);
296        self.row_to_node.contains(&row_key)
297    }
298
299    /// Get all nodes linked to a specific table
300    ///
301    /// Returns pairs of (row_id, node_id) for the given table.
302    pub fn nodes_for_table(&self, table_id: u16) -> Vec<(u64, String)> {
303        self.row_to_node.nodes_for_table(table_id)
304    }
305
306    /// Get statistics about the index
307    pub fn stats(&self) -> GraphTableIndexStats {
308        GraphTableIndexStats {
309            node_to_row_count: self.node_to_row.len(),
310            row_to_node_count: self.row_to_node.len(),
311            num_shards: NUM_SHARDS,
312        }
313    }
314
315    /// Clear all mappings
316    pub fn clear(&self) {
317        for shard in &self.node_to_row.shards {
318            if let Ok(mut map) = shard.write() {
319                map.clear();
320            }
321        }
322        for shard in &self.row_to_node.shards {
323            if let Ok(mut map) = shard.write() {
324                map.clear();
325            }
326        }
327    }
328
329    /// Serialize to bytes for persistence
330    pub fn serialize(&self) -> Vec<u8> {
331        let mut buf = Vec::new();
332
333        // Collect all mappings
334        let mut mappings = Vec::new();
335        for shard in &self.node_to_row.shards {
336            if let Ok(map) = shard.read() {
337                for (node_id, table_ref) in map.iter() {
338                    mappings.push((node_id.clone(), *table_ref));
339                }
340            }
341        }
342
343        // Write count
344        buf.extend_from_slice(&(mappings.len() as u32).to_le_bytes());
345
346        // Write each mapping: node_id_len(2) + node_id + table_ref(10)
347        for (node_id, table_ref) in mappings {
348            let id_bytes = node_id.as_bytes();
349            buf.extend_from_slice(&(id_bytes.len() as u16).to_le_bytes());
350            buf.extend_from_slice(id_bytes);
351            buf.extend_from_slice(&table_ref.encode());
352        }
353
354        buf
355    }
356
357    /// Deserialize from bytes
358    pub fn deserialize(data: &[u8]) -> Result<Self, GraphTableIndexError> {
359        if data.len() < 4 {
360            return Err(GraphTableIndexError::InvalidData("Too short".to_string()));
361        }
362
363        let index = Self::new();
364        let count = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
365        let mut offset = 4;
366
367        for _ in 0..count {
368            if offset + 2 > data.len() {
369                return Err(GraphTableIndexError::InvalidData(
370                    "Truncated node_id length".to_string(),
371                ));
372            }
373
374            let id_len = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
375            offset += 2;
376
377            if offset + id_len + 10 > data.len() {
378                return Err(GraphTableIndexError::InvalidData(
379                    "Truncated mapping".to_string(),
380                ));
381            }
382
383            let node_id = String::from_utf8_lossy(&data[offset..offset + id_len]).to_string();
384            offset += id_len;
385
386            let table_ref = TableRef::decode(&data[offset..]).ok_or_else(|| {
387                GraphTableIndexError::InvalidData("Invalid table ref".to_string())
388            })?;
389            offset += 10;
390
391            index.link(&node_id, table_ref.table_id, table_ref.row_id);
392        }
393
394        Ok(index)
395    }
396}
397
398impl Default for GraphTableIndex {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
/// Entry counts reported by `GraphTableIndex::stats`.
#[derive(Debug, Clone, Copy)]
pub struct GraphTableIndexStats {
    /// Number of entries in the node → row direction.
    pub node_to_row_count: usize,
    /// Number of entries in the row → node direction.
    pub row_to_node_count: usize,
    /// Number of lock shards used by each direction.
    pub num_shards: usize,
}
411
/// Errors returned by `GraphTableIndex` operations.
#[derive(Debug, Clone)]
pub enum GraphTableIndexError {
    /// Input bytes could not be decoded; the message says which check failed.
    InvalidData(String),
}

impl std::fmt::Display for GraphTableIndexError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this pattern is irrefutable. Adding a variant
        // turns it into a compile error that forces this impl to be updated.
        let GraphTableIndexError::InvalidData(reason) = self;
        f.write_str("Invalid data: ")?;
        f.write_str(reason)
    }
}

impl std::error::Error for GraphTableIndexError {}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Both lookup directions resolve, and unknown keys yield `None`.
    #[test]
    fn test_link_and_lookup() {
        let idx = GraphTableIndex::new();
        idx.link("host:192.168.1.1", 1, 100);
        idx.link("service:ssh", 2, 200);

        // node → row
        let host_row = idx
            .get_row_for_node("host:192.168.1.1")
            .expect("host node should be linked");
        assert_eq!((host_row.table_id, host_row.row_id), (1, 100));

        // row → node
        let service = idx
            .get_node_for_row(2, 200)
            .expect("row (2, 200) should be linked");
        assert_eq!(service, "service:ssh");

        // misses
        assert!(idx.get_row_for_node("unknown").is_none());
        assert!(idx.get_node_for_row(99, 999).is_none());
    }

    /// Unlinking by node clears both directions.
    #[test]
    fn test_unlink() {
        let idx = GraphTableIndex::new();
        idx.link("node1", 1, 10);
        assert!(idx.is_node_linked("node1") && idx.is_row_linked(1, 10));

        let removed = idx.unlink_node("node1").expect("node1 was linked");
        assert_eq!((removed.table_id, removed.row_id), (1, 10));

        assert!(!idx.is_node_linked("node1"));
        assert!(!idx.is_row_linked(1, 10));
    }

    /// Unlinking by row clears both directions.
    #[test]
    fn test_unlink_by_row() {
        let idx = GraphTableIndex::new();
        idx.link("node2", 2, 20);

        let removed = idx.unlink_row(2, 20).expect("row (2, 20) was linked");
        assert_eq!(removed, "node2");

        assert!(!idx.is_node_linked("node2"));
        assert!(!idx.is_row_linked(2, 20));
    }

    /// `nodes_for_table` only reports rows of the requested table.
    #[test]
    fn test_nodes_for_table() {
        let idx = GraphTableIndex::new();
        let links = [
            ("host:1", 1, 100),
            ("host:2", 1, 101),
            ("host:3", 1, 102),
            ("service:1", 2, 200),
        ];
        for &(node, table, row) in links.iter() {
            idx.link(node, table, row);
        }

        assert_eq!(idx.nodes_for_table(1).len(), 3);
        assert_eq!(idx.nodes_for_table(2).len(), 1);
    }

    /// A serialize/deserialize round trip restores both directions.
    #[test]
    fn test_serialization() {
        let original = GraphTableIndex::new();
        let links = [("node:a", 1, 100), ("node:b", 2, 200), ("node:c", 1, 300)];
        for &(node, table, row) in links.iter() {
            original.link(node, table, row);
        }

        let restored = GraphTableIndex::deserialize(&original.serialize())
            .expect("round trip should succeed");

        assert_eq!(restored.stats().node_to_row_count, 3);
        assert_eq!(
            restored.get_row_for_node("node:a").map(|r| r.row_id),
            Some(100)
        );
        assert_eq!(restored.get_node_for_row(2, 200).as_deref(), Some("node:b"));
    }

    /// Concurrent writers and readers neither deadlock nor lose links.
    #[test]
    fn test_concurrent_access() {
        let index = Arc::new(GraphTableIndex::new());

        let writers = (0..10u16).map(|table| {
            let idx = Arc::clone(&index);
            thread::spawn(move || {
                for row in 0..100u64 {
                    idx.link(&format!("node:{}:{}", table, row), table, row);
                }
            })
        });

        let readers = (0..5).map(|_| {
            let idx = Arc::clone(&index);
            thread::spawn(move || {
                for table in 0..10 {
                    for row in 0..100 {
                        let _ = idx.get_row_for_node(&format!("node:{}:{}", table, row));
                    }
                }
            })
        });

        let handles: Vec<_> = writers.chain(readers).collect();
        handles
            .into_iter()
            .for_each(|h| h.join().expect("worker thread panicked"));

        assert_eq!(index.stats().node_to_row_count, 1000);
    }
}