Skip to main content

aegis_replication/
router.rs

//! Aegis Shard Router
//!
//! Routes queries to the shards (and the nodes hosting them) that own their partition keys.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team
7
8use crate::hash::HashRing;
9use crate::node::NodeId;
10use crate::partition::{PartitionKey, PartitionStrategy};
11use crate::shard::{Shard, ShardId};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::RwLock;
15
16// =============================================================================
17// Route Decision
18// =============================================================================
19
20/// A routing decision for a query.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum RouteDecision {
23    /// Route to a single shard.
24    Single { shard_id: ShardId, node_id: NodeId },
25    /// Route to multiple shards (scatter-gather).
26    Multi { routes: Vec<ShardRoute> },
27    /// Route to all shards (broadcast).
28    Broadcast { shards: Vec<ShardId> },
29    /// Route to primary only.
30    Primary { shard_id: ShardId, node_id: NodeId },
31    /// Route to any replica (for read queries).
32    AnyReplica {
33        shard_id: ShardId,
34        candidates: Vec<NodeId>,
35    },
36}
37
38/// A route to a specific shard.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ShardRoute {
41    pub shard_id: ShardId,
42    pub node_id: NodeId,
43    pub is_primary: bool,
44}
45
46// =============================================================================
47// Routing Table
48// =============================================================================
49
50/// A routing table for shard lookups.
51#[derive(Debug, Clone)]
52pub struct RoutingTable {
53    entries: HashMap<ShardId, RoutingEntry>,
54    version: u64,
55}
56
57/// An entry in the routing table.
58#[derive(Debug, Clone)]
59pub struct RoutingEntry {
60    pub shard_id: ShardId,
61    pub primary: NodeId,
62    pub replicas: Vec<NodeId>,
63    pub key_range_start: u64,
64    pub key_range_end: u64,
65}
66
67impl RoutingTable {
68    /// Create a new routing table.
69    pub fn new() -> Self {
70        Self {
71            entries: HashMap::new(),
72            version: 0,
73        }
74    }
75
76    /// Add or update an entry.
77    pub fn upsert(&mut self, entry: RoutingEntry) {
78        self.entries.insert(entry.shard_id.clone(), entry);
79        self.version += 1;
80    }
81
82    /// Remove an entry.
83    pub fn remove(&mut self, shard_id: &ShardId) -> Option<RoutingEntry> {
84        let entry = self.entries.remove(shard_id);
85        if entry.is_some() {
86            self.version += 1;
87        }
88        entry
89    }
90
91    /// Get an entry.
92    pub fn get(&self, shard_id: &ShardId) -> Option<&RoutingEntry> {
93        self.entries.get(shard_id)
94    }
95
96    /// Find the shard for a key hash.
97    pub fn find_shard(&self, key_hash: u64) -> Option<&RoutingEntry> {
98        self.entries
99            .values()
100            .find(|e| key_hash >= e.key_range_start && key_hash < e.key_range_end)
101    }
102
103    /// Get all entries.
104    pub fn all_entries(&self) -> impl Iterator<Item = &RoutingEntry> {
105        self.entries.values()
106    }
107
108    /// Get the table version.
109    pub fn version(&self) -> u64 {
110        self.version
111    }
112
113    /// Get the number of entries.
114    pub fn len(&self) -> usize {
115        self.entries.len()
116    }
117
118    /// Check if the table is empty.
119    pub fn is_empty(&self) -> bool {
120        self.entries.is_empty()
121    }
122
123    /// Build from a shard manager.
124    pub fn from_shards(shards: &[Shard]) -> Self {
125        let mut table = Self::new();
126
127        for shard in shards {
128            table.upsert(RoutingEntry {
129                shard_id: shard.id.clone(),
130                primary: shard.primary_node.clone(),
131                replicas: shard.replica_nodes.clone(),
132                key_range_start: shard.key_range_start.unwrap_or(0),
133                key_range_end: shard.key_range_end.unwrap_or(u64::MAX),
134            });
135        }
136
137        table
138    }
139}
140
141impl Default for RoutingTable {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147// =============================================================================
148// Shard Router
149// =============================================================================
150
151/// Routes queries to appropriate shards.
152pub struct ShardRouter {
153    routing_table: RwLock<RoutingTable>,
154    hash_ring: RwLock<HashRing>,
155    partition_strategy: PartitionStrategy,
156    prefer_local: bool,
157    local_node: Option<NodeId>,
158}
159
160impl ShardRouter {
161    /// Create a new shard router.
162    pub fn new(strategy: PartitionStrategy) -> Self {
163        Self {
164            routing_table: RwLock::new(RoutingTable::new()),
165            hash_ring: RwLock::new(HashRing::default()),
166            partition_strategy: strategy,
167            prefer_local: true,
168            local_node: None,
169        }
170    }
171
172    /// Create a router with a local node preference.
173    pub fn with_local_node(strategy: PartitionStrategy, local_node: NodeId) -> Self {
174        Self {
175            routing_table: RwLock::new(RoutingTable::new()),
176            hash_ring: RwLock::new(HashRing::default()),
177            partition_strategy: strategy,
178            prefer_local: true,
179            local_node: Some(local_node),
180        }
181    }
182
183    /// Update the routing table from shards.
184    pub fn update_routing(&self, shards: &[Shard]) {
185        let table = RoutingTable::from_shards(shards);
186        *self
187            .routing_table
188            .write()
189            .expect("router routing_table lock poisoned") = table;
190
191        // Update hash ring
192        let mut ring = self
193            .hash_ring
194            .write()
195            .expect("router hash_ring lock poisoned");
196        *ring = HashRing::default();
197        for shard in shards {
198            ring.add_node(shard.primary_node.clone());
199            for replica in &shard.replica_nodes {
200                ring.add_node(replica.clone());
201            }
202        }
203    }
204
205    /// Route a query with a partition key.
206    pub fn route(&self, key: &PartitionKey) -> RouteDecision {
207        let hash = key.hash_value();
208        let table = self
209            .routing_table
210            .read()
211            .expect("router routing_table lock poisoned");
212
213        if let Some(entry) = table.find_shard(hash) {
214            RouteDecision::Single {
215                shard_id: entry.shard_id.clone(),
216                node_id: self.select_node(entry),
217            }
218        } else {
219            // Fallback to hash ring
220            let ring = self
221                .hash_ring
222                .read()
223                .expect("router hash_ring lock poisoned");
224            if let Some(node) = ring.get_node(&format!("{}", hash)) {
225                RouteDecision::Single {
226                    shard_id: ShardId::new(0),
227                    node_id: node.clone(),
228                }
229            } else {
230                RouteDecision::Broadcast {
231                    shards: table.entries.keys().cloned().collect(),
232                }
233            }
234        }
235    }
236
237    /// Route for a write operation (always to primary).
238    pub fn route_write(&self, key: &PartitionKey) -> RouteDecision {
239        let hash = key.hash_value();
240        let table = self
241            .routing_table
242            .read()
243            .expect("router routing_table lock poisoned");
244
245        if let Some(entry) = table.find_shard(hash) {
246            RouteDecision::Primary {
247                shard_id: entry.shard_id.clone(),
248                node_id: entry.primary.clone(),
249            }
250        } else {
251            RouteDecision::Broadcast {
252                shards: table.entries.keys().cloned().collect(),
253            }
254        }
255    }
256
257    /// Route for a read operation (can use replicas).
258    pub fn route_read(&self, key: &PartitionKey) -> RouteDecision {
259        let hash = key.hash_value();
260        let table = self
261            .routing_table
262            .read()
263            .expect("router routing_table lock poisoned");
264
265        if let Some(entry) = table.find_shard(hash) {
266            let mut candidates = vec![entry.primary.clone()];
267            candidates.extend(entry.replicas.iter().cloned());
268
269            RouteDecision::AnyReplica {
270                shard_id: entry.shard_id.clone(),
271                candidates,
272            }
273        } else {
274            drop(table);
275            self.route(key)
276        }
277    }
278
279    /// Route to multiple shards for a range query.
280    pub fn route_range(&self, start_key: &PartitionKey, end_key: &PartitionKey) -> RouteDecision {
281        let start_hash = start_key.hash_value();
282        let end_hash = end_key.hash_value();
283
284        let table = self
285            .routing_table
286            .read()
287            .expect("router routing_table lock poisoned");
288        let mut routes = Vec::new();
289
290        for entry in table.all_entries() {
291            // Check if shard overlaps with query range
292            if entry.key_range_end > start_hash && entry.key_range_start < end_hash {
293                routes.push(ShardRoute {
294                    shard_id: entry.shard_id.clone(),
295                    node_id: self.select_node(entry),
296                    is_primary: true,
297                });
298            }
299        }
300
301        if routes.is_empty() {
302            RouteDecision::Broadcast {
303                shards: table.entries.keys().cloned().collect(),
304            }
305        } else if routes.len() == 1 {
306            let route = routes.remove(0);
307            RouteDecision::Single {
308                shard_id: route.shard_id,
309                node_id: route.node_id,
310            }
311        } else {
312            RouteDecision::Multi { routes }
313        }
314    }
315
316    /// Route to all shards (for queries without partition key).
317    pub fn route_all(&self) -> RouteDecision {
318        let table = self
319            .routing_table
320            .read()
321            .expect("router routing_table lock poisoned");
322
323        let routes: Vec<_> = table
324            .all_entries()
325            .map(|entry| ShardRoute {
326                shard_id: entry.shard_id.clone(),
327                node_id: self.select_node(entry),
328                is_primary: true,
329            })
330            .collect();
331
332        if routes.is_empty() {
333            RouteDecision::Broadcast { shards: vec![] }
334        } else {
335            RouteDecision::Multi { routes }
336        }
337    }
338
339    /// Select a node for a routing entry.
340    fn select_node(&self, entry: &RoutingEntry) -> NodeId {
341        // Prefer local node if available
342        if self.prefer_local {
343            if let Some(ref local) = self.local_node {
344                if &entry.primary == local {
345                    return entry.primary.clone();
346                }
347                if entry.replicas.contains(local) {
348                    return local.clone();
349                }
350            }
351        }
352
353        // Default to primary
354        entry.primary.clone()
355    }
356
357    /// Get the current routing table version.
358    pub fn routing_version(&self) -> u64 {
359        self.routing_table
360            .read()
361            .expect("router routing_table lock poisoned")
362            .version()
363    }
364
365    /// Get the partition strategy.
366    pub fn strategy(&self) -> &PartitionStrategy {
367        &self.partition_strategy
368    }
369
370    /// Check if routing table is initialized.
371    pub fn is_initialized(&self) -> bool {
372        !self
373            .routing_table
374            .read()
375            .expect("router routing_table lock poisoned")
376            .is_empty()
377    }
378}
379
380// =============================================================================
381// Query Analyzer
382// =============================================================================
383
384/// Analyzes queries to extract partition keys.
385pub struct QueryAnalyzer {
386    partition_columns: Vec<String>,
387}
388
389impl QueryAnalyzer {
390    /// Create a new query analyzer.
391    pub fn new(partition_columns: Vec<String>) -> Self {
392        Self { partition_columns }
393    }
394
395    /// Analyze a query to extract routing information.
396    pub fn analyze(&self, query: &str) -> QueryRouting {
397        let query_upper = query.to_uppercase();
398
399        // Determine query type
400        let query_type = if query_upper.starts_with("SELECT") {
401            QueryType::Read
402        } else if query_upper.starts_with("INSERT") {
403            QueryType::Write
404        } else if query_upper.starts_with("UPDATE") {
405            QueryType::Write
406        } else if query_upper.starts_with("DELETE") {
407            QueryType::Write
408        } else {
409            QueryType::Admin
410        };
411
412        // Try to extract partition key from WHERE clause
413        let partition_key = self.extract_partition_key(query);
414        let requires_all_shards = partition_key.is_none() && query_type == QueryType::Read;
415
416        QueryRouting {
417            query_type,
418            partition_key,
419            requires_all_shards,
420        }
421    }
422
423    fn extract_partition_key(&self, query: &str) -> Option<PartitionKey> {
424        let query_lower = query.to_lowercase();
425
426        for col in &self.partition_columns {
427            let col_lower = col.to_lowercase();
428
429            // Look for "column = 'value'" or "column = value"
430            let patterns = [
431                format!("{} = '", col_lower),
432                format!("{} ='", col_lower),
433                format!("{}='", col_lower),
434                format!("{} = ", col_lower),
435            ];
436
437            for pattern in &patterns {
438                if let Some(start) = query_lower.find(pattern) {
439                    let value_start = start + pattern.len();
440                    let remaining = &query[value_start..];
441
442                    let value = if remaining.starts_with('\'') {
443                        // String value
444                        remaining[1..].split('\'').next().map(|s| s.to_string())
445                    } else {
446                        // Numeric or unquoted
447                        remaining
448                            .split(|c: char| c.is_whitespace() || c == ')' || c == ';')
449                            .next()
450                            .map(|s| s.trim_matches('\'').to_string())
451                    };
452
453                    if let Some(v) = value {
454                        if !v.is_empty() {
455                            // Try to parse as integer
456                            if let Ok(i) = v.parse::<i64>() {
457                                return Some(PartitionKey::Int(i));
458                            }
459                            return Some(PartitionKey::String(v));
460                        }
461                    }
462                }
463            }
464        }
465
466        None
467    }
468}
469
470/// Query routing information.
471#[derive(Debug, Clone)]
472pub struct QueryRouting {
473    pub query_type: QueryType,
474    pub partition_key: Option<PartitionKey>,
475    pub requires_all_shards: bool,
476}
477
/// Coarse classification of a SQL statement, used to pick a routing path.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryType {
    /// Data read (e.g. `SELECT`).
    Read,
    /// Data modification (e.g. `INSERT`, `UPDATE`, `DELETE`).
    Write,
    /// Anything not recognised as a read or a write.
    Admin,
}
485
486// =============================================================================
487// Tests
488// =============================================================================
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Two shards splitting the hash space at `u64::MAX / 2`: shard 0 on `node1`,
    /// shard 1 on `node2`.
    fn create_test_shards() -> Vec<Shard> {
        let split = u64::MAX / 2;
        vec![
            Shard::with_range(ShardId::new(0), NodeId::new("node1"), 0, split),
            Shard::with_range(ShardId::new(1), NodeId::new("node2"), split, u64::MAX),
        ]
    }

    /// The standard test shards with `node3` added as a replica of shard 0.
    fn shards_with_node3_replica() -> Vec<Shard> {
        let mut shards = create_test_shards();
        shards[0].add_replica(NodeId::new("node3"));
        shards
    }

    /// A router hash-partitioned on `id` into 2 partitions, loaded with `shards`.
    fn loaded_router(shards: &[Shard]) -> ShardRouter {
        let router = ShardRouter::new(PartitionStrategy::hash(vec!["id".to_string()], 2));
        router.update_routing(shards);
        router
    }

    #[test]
    fn test_routing_table() {
        let table = RoutingTable::from_shards(&create_test_shards());

        assert_eq!(table.len(), 2);
        assert!(table.get(&ShardId::new(0)).is_some());
    }

    #[test]
    fn test_routing_table_find_shard() {
        let table = RoutingTable::from_shards(&create_test_shards());

        let low = table.find_shard(100).map(|e| e.shard_id.as_u32());
        assert_eq!(low, Some(0));

        let high = table.find_shard(u64::MAX - 100).map(|e| e.shard_id.as_u32());
        assert_eq!(high, Some(1));
    }

    #[test]
    fn test_shard_router() {
        let router = loaded_router(&create_test_shards());
        assert!(router.is_initialized());
    }

    #[test]
    fn test_router_route() {
        let router = loaded_router(&create_test_shards());
        let key = PartitionKey::string("test_key");

        if let RouteDecision::Single { shard_id, node_id } = router.route(&key) {
            assert!(shard_id.as_u32() < 2);
            assert!(["node1", "node2"].contains(&node_id.as_str()));
        } else {
            panic!("Expected single route");
        }
    }

    #[test]
    fn test_router_route_write() {
        let router = loaded_router(&shards_with_node3_replica());
        let key = PartitionKey::int(1);

        // Writes must land on a primary, never on the node3 replica.
        if let RouteDecision::Primary { node_id, .. } = router.route_write(&key) {
            assert!(["node1", "node2"].contains(&node_id.as_str()));
        } else {
            panic!("Expected primary route");
        }
    }

    #[test]
    fn test_router_route_read() {
        let router = loaded_router(&shards_with_node3_replica());
        let key = PartitionKey::int(1);

        if let RouteDecision::AnyReplica { candidates, .. } = router.route_read(&key) {
            assert!(!candidates.is_empty());
        } else {
            panic!("Expected any replica route");
        }
    }

    #[test]
    fn test_router_route_all() {
        let router = loaded_router(&create_test_shards());

        if let RouteDecision::Multi { routes } = router.route_all() {
            assert_eq!(routes.len(), 2);
        } else {
            panic!("Expected multi route");
        }
    }

    #[test]
    fn test_query_analyzer() {
        let analyzer = QueryAnalyzer::new(vec!["user_id".to_string()]);

        let keyed_read = analyzer.analyze("SELECT * FROM users WHERE user_id = 123");
        assert_eq!(keyed_read.query_type, QueryType::Read);
        assert!(keyed_read.partition_key.is_some());

        let insert = analyzer.analyze("INSERT INTO users VALUES (1, 'Alice')");
        assert_eq!(insert.query_type, QueryType::Write);
    }

    #[test]
    fn test_query_analyzer_no_key() {
        let analyzer = QueryAnalyzer::new(vec!["user_id".to_string()]);

        let full_scan = analyzer.analyze("SELECT * FROM users");
        assert!(full_scan.partition_key.is_none());
        assert!(full_scan.requires_all_shards);
    }

    #[test]
    fn test_route_decision_types() {
        let single = RouteDecision::Single {
            shard_id: ShardId::new(0),
            node_id: NodeId::new("node1"),
        };
        let broadcast = RouteDecision::Broadcast {
            shards: vec![ShardId::new(0), ShardId::new(1)],
        };

        assert!(matches!(single, RouteDecision::Single { .. }), "Expected single");

        if let RouteDecision::Broadcast { shards } = broadcast {
            assert_eq!(shards.len(), 2);
        } else {
            panic!("Expected broadcast");
        }
    }
}