1use crate::hash::HashRing;
9use crate::node::NodeId;
10use crate::partition::{PartitionKey, PartitionStrategy};
11use crate::shard::{Shard, ShardId};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::RwLock;
15
/// Where a request should be sent, as decided by [`ShardRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteDecision {
    /// Send to exactly one node for one shard.
    Single {
        /// Target shard.
        shard_id: ShardId,
        /// Node chosen to serve the shard.
        node_id: NodeId,
    },
    /// Send to several shards, each on its own chosen node.
    Multi {
        /// One route per involved shard.
        routes: Vec<ShardRoute>,
    },
    /// Send to every listed shard (used when no more specific route could be found).
    Broadcast {
        /// Shards to contact; may be empty when the routing table is empty.
        shards: Vec<ShardId>,
    },
    /// Send to the shard's primary node (the write path).
    Primary {
        /// Target shard.
        shard_id: ShardId,
        /// The shard's primary node.
        node_id: NodeId,
    },
    /// Any of the listed nodes may serve the request (the read path).
    AnyReplica {
        /// Target shard.
        shard_id: ShardId,
        /// Nodes holding a copy of the shard.
        candidates: Vec<NodeId>,
    },
}
47
/// One shard/node pair inside a [`RouteDecision::Multi`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardRoute {
    /// Target shard.
    pub shard_id: ShardId,
    /// Node chosen to serve the shard.
    pub node_id: NodeId,
    /// Whether `node_id` is reported as the shard's primary.
    pub is_primary: bool,
}
55
/// Shard placement table: one [`RoutingEntry`] per shard plus a change counter.
#[derive(Debug, Clone)]
pub struct RoutingTable {
    /// Entries keyed by their shard id.
    entries: HashMap<ShardId, RoutingEntry>,
    /// Incremented by every `upsert` and by every `remove` that deletes an entry.
    version: u64,
}
66
/// Placement and key-hash range of a single shard.
#[derive(Debug, Clone)]
pub struct RoutingEntry {
    /// Shard this entry describes.
    pub shard_id: ShardId,
    /// Node holding the shard's primary copy.
    pub primary: NodeId,
    /// Nodes holding replica copies.
    pub replicas: Vec<NodeId>,
    /// Lower bound (inclusive) of the key-hash range; `from_shards` defaults it to 0.
    pub key_range_start: u64,
    /// Upper bound of the key-hash range; `from_shards` defaults it to `u64::MAX`.
    /// See `RoutingTable::find_shard` for how it is compared.
    pub key_range_end: u64,
}
76
77impl RoutingTable {
78 pub fn new() -> Self {
80 Self {
81 entries: HashMap::new(),
82 version: 0,
83 }
84 }
85
86 pub fn upsert(&mut self, entry: RoutingEntry) {
88 self.entries.insert(entry.shard_id.clone(), entry);
89 self.version += 1;
90 }
91
92 pub fn remove(&mut self, shard_id: &ShardId) -> Option<RoutingEntry> {
94 let entry = self.entries.remove(shard_id);
95 if entry.is_some() {
96 self.version += 1;
97 }
98 entry
99 }
100
101 pub fn get(&self, shard_id: &ShardId) -> Option<&RoutingEntry> {
103 self.entries.get(shard_id)
104 }
105
106 pub fn find_shard(&self, key_hash: u64) -> Option<&RoutingEntry> {
108 self.entries
109 .values()
110 .find(|e| key_hash >= e.key_range_start && key_hash < e.key_range_end)
111 }
112
113 pub fn all_entries(&self) -> impl Iterator<Item = &RoutingEntry> {
115 self.entries.values()
116 }
117
118 pub fn version(&self) -> u64 {
120 self.version
121 }
122
123 pub fn len(&self) -> usize {
125 self.entries.len()
126 }
127
128 pub fn is_empty(&self) -> bool {
130 self.entries.is_empty()
131 }
132
133 pub fn from_shards(shards: &[Shard]) -> Self {
135 let mut table = Self::new();
136
137 for shard in shards {
138 table.upsert(RoutingEntry {
139 shard_id: shard.id.clone(),
140 primary: shard.primary_node.clone(),
141 replicas: shard.replica_nodes.clone(),
142 key_range_start: shard.key_range_start.unwrap_or(0),
143 key_range_end: shard.key_range_end.unwrap_or(u64::MAX),
144 });
145 }
146
147 table
148 }
149}
150
151impl Default for RoutingTable {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
/// Routes partition keys to shards and nodes.
///
/// Routing uses a range-based [`RoutingTable`], with a [`HashRing`] of all shard
/// nodes as a fallback for hashes that no range covers.
pub struct ShardRouter {
    /// Shard ranges and placements; replaced wholesale by `update_routing`.
    routing_table: RwLock<RoutingTable>,
    /// Ring built from every primary and replica node; rebuilt by `update_routing`.
    hash_ring: RwLock<HashRing>,
    /// Strategy supplied at construction; exposed read-only via `strategy()`.
    partition_strategy: PartitionStrategy,
    /// When `true`, node selection prefers `local_node` if it hosts the shard
    /// (both constructors set this to `true`).
    prefer_local: bool,
    /// Node this router runs on, if configured (see `with_local_node`).
    local_node: Option<NodeId>,
}
169
170impl ShardRouter {
171 pub fn new(strategy: PartitionStrategy) -> Self {
173 Self {
174 routing_table: RwLock::new(RoutingTable::new()),
175 hash_ring: RwLock::new(HashRing::default()),
176 partition_strategy: strategy,
177 prefer_local: true,
178 local_node: None,
179 }
180 }
181
182 pub fn with_local_node(strategy: PartitionStrategy, local_node: NodeId) -> Self {
184 Self {
185 routing_table: RwLock::new(RoutingTable::new()),
186 hash_ring: RwLock::new(HashRing::default()),
187 partition_strategy: strategy,
188 prefer_local: true,
189 local_node: Some(local_node),
190 }
191 }
192
193 pub fn update_routing(&self, shards: &[Shard]) {
195 let table = RoutingTable::from_shards(shards);
196 *self.routing_table.write().unwrap() = table;
197
198 let mut ring = self.hash_ring.write().unwrap();
200 *ring = HashRing::default();
201 for shard in shards {
202 ring.add_node(shard.primary_node.clone());
203 for replica in &shard.replica_nodes {
204 ring.add_node(replica.clone());
205 }
206 }
207 }
208
209 pub fn route(&self, key: &PartitionKey) -> RouteDecision {
211 let hash = key.hash_value();
212 let table = self.routing_table.read().unwrap();
213
214 if let Some(entry) = table.find_shard(hash) {
215 RouteDecision::Single {
216 shard_id: entry.shard_id.clone(),
217 node_id: self.select_node(entry),
218 }
219 } else {
220 let ring = self.hash_ring.read().unwrap();
222 if let Some(node) = ring.get_node(&format!("{}", hash)) {
223 RouteDecision::Single {
224 shard_id: ShardId::new(0),
225 node_id: node.clone(),
226 }
227 } else {
228 RouteDecision::Broadcast {
229 shards: table.entries.keys().cloned().collect(),
230 }
231 }
232 }
233 }
234
235 pub fn route_write(&self, key: &PartitionKey) -> RouteDecision {
237 let hash = key.hash_value();
238 let table = self.routing_table.read().unwrap();
239
240 if let Some(entry) = table.find_shard(hash) {
241 RouteDecision::Primary {
242 shard_id: entry.shard_id.clone(),
243 node_id: entry.primary.clone(),
244 }
245 } else {
246 RouteDecision::Broadcast {
247 shards: table.entries.keys().cloned().collect(),
248 }
249 }
250 }
251
252 pub fn route_read(&self, key: &PartitionKey) -> RouteDecision {
254 let hash = key.hash_value();
255 let table = self.routing_table.read().unwrap();
256
257 if let Some(entry) = table.find_shard(hash) {
258 let mut candidates = vec![entry.primary.clone()];
259 candidates.extend(entry.replicas.iter().cloned());
260
261 RouteDecision::AnyReplica {
262 shard_id: entry.shard_id.clone(),
263 candidates,
264 }
265 } else {
266 self.route(key)
267 }
268 }
269
270 pub fn route_range(&self, start_key: &PartitionKey, end_key: &PartitionKey) -> RouteDecision {
272 let start_hash = start_key.hash_value();
273 let end_hash = end_key.hash_value();
274
275 let table = self.routing_table.read().unwrap();
276 let mut routes = Vec::new();
277
278 for entry in table.all_entries() {
279 if entry.key_range_end > start_hash && entry.key_range_start < end_hash {
281 routes.push(ShardRoute {
282 shard_id: entry.shard_id.clone(),
283 node_id: self.select_node(entry),
284 is_primary: true,
285 });
286 }
287 }
288
289 if routes.is_empty() {
290 RouteDecision::Broadcast {
291 shards: table.entries.keys().cloned().collect(),
292 }
293 } else if routes.len() == 1 {
294 let route = routes.remove(0);
295 RouteDecision::Single {
296 shard_id: route.shard_id,
297 node_id: route.node_id,
298 }
299 } else {
300 RouteDecision::Multi { routes }
301 }
302 }
303
304 pub fn route_all(&self) -> RouteDecision {
306 let table = self.routing_table.read().unwrap();
307
308 let routes: Vec<_> = table
309 .all_entries()
310 .map(|entry| ShardRoute {
311 shard_id: entry.shard_id.clone(),
312 node_id: self.select_node(entry),
313 is_primary: true,
314 })
315 .collect();
316
317 if routes.is_empty() {
318 RouteDecision::Broadcast { shards: vec![] }
319 } else {
320 RouteDecision::Multi { routes }
321 }
322 }
323
324 fn select_node(&self, entry: &RoutingEntry) -> NodeId {
326 if self.prefer_local {
328 if let Some(ref local) = self.local_node {
329 if &entry.primary == local {
330 return entry.primary.clone();
331 }
332 if entry.replicas.contains(local) {
333 return local.clone();
334 }
335 }
336 }
337
338 entry.primary.clone()
340 }
341
342 pub fn routing_version(&self) -> u64 {
344 self.routing_table.read().unwrap().version()
345 }
346
347 pub fn strategy(&self) -> &PartitionStrategy {
349 &self.partition_strategy
350 }
351
352 pub fn is_initialized(&self) -> bool {
354 !self.routing_table.read().unwrap().is_empty()
355 }
356}
357
/// Lightweight, string-based SQL inspector that classifies a statement and tries to
/// extract a partition-key equality predicate from it.
pub struct QueryAnalyzer {
    /// Column names searched for `column = value` predicates, matched case-insensitively.
    partition_columns: Vec<String>,
}
366
367impl QueryAnalyzer {
368 pub fn new(partition_columns: Vec<String>) -> Self {
370 Self { partition_columns }
371 }
372
373 pub fn analyze(&self, query: &str) -> QueryRouting {
375 let query_upper = query.to_uppercase();
376
377 let query_type = if query_upper.starts_with("SELECT") {
379 QueryType::Read
380 } else if query_upper.starts_with("INSERT") {
381 QueryType::Write
382 } else if query_upper.starts_with("UPDATE") {
383 QueryType::Write
384 } else if query_upper.starts_with("DELETE") {
385 QueryType::Write
386 } else {
387 QueryType::Admin
388 };
389
390 let partition_key = self.extract_partition_key(query);
392 let requires_all_shards = partition_key.is_none() && query_type == QueryType::Read;
393
394 QueryRouting {
395 query_type,
396 partition_key,
397 requires_all_shards,
398 }
399 }
400
401 fn extract_partition_key(&self, query: &str) -> Option<PartitionKey> {
402 let query_lower = query.to_lowercase();
403
404 for col in &self.partition_columns {
405 let col_lower = col.to_lowercase();
406
407 let patterns = [
409 format!("{} = '", col_lower),
410 format!("{} ='", col_lower),
411 format!("{}='", col_lower),
412 format!("{} = ", col_lower),
413 ];
414
415 for pattern in &patterns {
416 if let Some(start) = query_lower.find(pattern) {
417 let value_start = start + pattern.len();
418 let remaining = &query[value_start..];
419
420 let value = if remaining.starts_with('\'') {
421 remaining[1..]
423 .split('\'')
424 .next()
425 .map(|s| s.to_string())
426 } else {
427 remaining
429 .split(|c: char| c.is_whitespace() || c == ')' || c == ';')
430 .next()
431 .map(|s| s.trim_matches('\'').to_string())
432 };
433
434 if let Some(v) = value {
435 if !v.is_empty() {
436 if let Ok(i) = v.parse::<i64>() {
438 return Some(PartitionKey::Int(i));
439 }
440 return Some(PartitionKey::String(v));
441 }
442 }
443 }
444 }
445 }
446
447 None
448 }
449}
450
/// Result of [`QueryAnalyzer::analyze`].
#[derive(Debug, Clone)]
pub struct QueryRouting {
    /// Statement category derived from its leading keyword.
    pub query_type: QueryType,
    /// Partition key extracted from a `column = value` predicate, if one was found.
    pub partition_key: Option<PartitionKey>,
    /// `true` for reads with no extracted partition key.
    pub requires_all_shards: bool,
}
458
/// Coarse statement category used for routing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// `SELECT` statements.
    Read,
    /// `INSERT`, `UPDATE` and `DELETE` statements.
    Write,
    /// Any other statement.
    Admin,
}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Two single-node shards splitting the hash space at `u64::MAX / 2`:
    /// shard 0 on `node1`, shard 1 on `node2`.
    fn create_test_shards() -> Vec<Shard> {
        let split = u64::MAX / 2;
        vec![
            Shard::with_range(ShardId::new(0), NodeId::new("node1"), 0, split),
            Shard::with_range(ShardId::new(1), NodeId::new("node2"), split, u64::MAX),
        ]
    }

    /// The standard fixture with `node3` added as a replica of shard 0.
    fn create_replicated_shards() -> Vec<Shard> {
        let mut shards = create_test_shards();
        shards[0].add_replica(NodeId::new("node3"));
        shards
    }

    /// A router hash-partitioned on `id` (2 partitions) and loaded with `shards`.
    fn loaded_router(shards: &[Shard]) -> ShardRouter {
        let router = ShardRouter::new(PartitionStrategy::hash(vec!["id".to_string()], 2));
        router.update_routing(shards);
        router
    }

    #[test]
    fn test_routing_table() {
        let table = RoutingTable::from_shards(&create_test_shards());

        assert_eq!(table.len(), 2);
        assert!(table.get(&ShardId::new(0)).is_some());
    }

    #[test]
    fn test_routing_table_find_shard() {
        let table = RoutingTable::from_shards(&create_test_shards());

        let low = table.find_shard(100).unwrap();
        assert_eq!(low.shard_id.as_u32(), 0);

        let high = table.find_shard(u64::MAX - 100).unwrap();
        assert_eq!(high.shard_id.as_u32(), 1);
    }

    #[test]
    fn test_shard_router() {
        let router = loaded_router(&create_test_shards());
        assert!(router.is_initialized());
    }

    #[test]
    fn test_router_route() {
        let router = loaded_router(&create_test_shards());
        let key = PartitionKey::string("test_key");

        if let RouteDecision::Single { shard_id, node_id } = router.route(&key) {
            assert!(shard_id.as_u32() < 2);
            assert!(["node1", "node2"].contains(&node_id.as_str()));
        } else {
            panic!("Expected single route");
        }
    }

    #[test]
    fn test_router_route_write() {
        let router = loaded_router(&create_replicated_shards());
        let key = PartitionKey::int(1);

        if let RouteDecision::Primary { node_id, .. } = router.route_write(&key) {
            assert!(["node1", "node2"].contains(&node_id.as_str()));
        } else {
            panic!("Expected primary route");
        }
    }

    #[test]
    fn test_router_route_read() {
        let router = loaded_router(&create_replicated_shards());
        let key = PartitionKey::int(1);

        if let RouteDecision::AnyReplica { candidates, .. } = router.route_read(&key) {
            assert!(!candidates.is_empty());
        } else {
            panic!("Expected any replica route");
        }
    }

    #[test]
    fn test_router_route_all() {
        let router = loaded_router(&create_test_shards());

        if let RouteDecision::Multi { routes } = router.route_all() {
            assert_eq!(routes.len(), 2);
        } else {
            panic!("Expected multi route");
        }
    }

    #[test]
    fn test_query_analyzer() {
        let analyzer = QueryAnalyzer::new(vec!["user_id".to_string()]);

        let select = analyzer.analyze("SELECT * FROM users WHERE user_id = 123");
        assert_eq!(select.query_type, QueryType::Read);
        assert!(select.partition_key.is_some());

        let insert = analyzer.analyze("INSERT INTO users VALUES (1, 'Alice')");
        assert_eq!(insert.query_type, QueryType::Write);
    }

    #[test]
    fn test_query_analyzer_no_key() {
        let analyzer = QueryAnalyzer::new(vec!["user_id".to_string()]);

        let scan = analyzer.analyze("SELECT * FROM users");
        assert!(scan.partition_key.is_none());
        assert!(scan.requires_all_shards);
    }

    #[test]
    fn test_route_decision_types() {
        let single = RouteDecision::Single {
            shard_id: ShardId::new(0),
            node_id: NodeId::new("node1"),
        };
        assert!(matches!(single, RouteDecision::Single { .. }), "Expected single");

        let broadcast = RouteDecision::Broadcast {
            shards: vec![ShardId::new(0), ShardId::new(1)],
        };
        if let RouteDecision::Broadcast { shards } = broadcast {
            assert_eq!(shards.len(), 2);
        } else {
            panic!("Expected broadcast");
        }
    }
}