1use crate::hash::HashRing;
9use crate::node::NodeId;
10use crate::partition::{PartitionKey, PartitionStrategy};
11use crate::shard::{Shard, ShardId};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::RwLock;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum RouteDecision {
23 Single { shard_id: ShardId, node_id: NodeId },
25 Multi { routes: Vec<ShardRoute> },
27 Broadcast { shards: Vec<ShardId> },
29 Primary { shard_id: ShardId, node_id: NodeId },
31 AnyReplica {
33 shard_id: ShardId,
34 candidates: Vec<NodeId>,
35 },
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ShardRoute {
41 pub shard_id: ShardId,
42 pub node_id: NodeId,
43 pub is_primary: bool,
44}
45
46#[derive(Debug, Clone)]
52pub struct RoutingTable {
53 entries: HashMap<ShardId, RoutingEntry>,
54 version: u64,
55}
56
57#[derive(Debug, Clone)]
59pub struct RoutingEntry {
60 pub shard_id: ShardId,
61 pub primary: NodeId,
62 pub replicas: Vec<NodeId>,
63 pub key_range_start: u64,
64 pub key_range_end: u64,
65}
66
67impl RoutingTable {
68 pub fn new() -> Self {
70 Self {
71 entries: HashMap::new(),
72 version: 0,
73 }
74 }
75
76 pub fn upsert(&mut self, entry: RoutingEntry) {
78 self.entries.insert(entry.shard_id.clone(), entry);
79 self.version += 1;
80 }
81
82 pub fn remove(&mut self, shard_id: &ShardId) -> Option<RoutingEntry> {
84 let entry = self.entries.remove(shard_id);
85 if entry.is_some() {
86 self.version += 1;
87 }
88 entry
89 }
90
91 pub fn get(&self, shard_id: &ShardId) -> Option<&RoutingEntry> {
93 self.entries.get(shard_id)
94 }
95
96 pub fn find_shard(&self, key_hash: u64) -> Option<&RoutingEntry> {
98 self.entries
99 .values()
100 .find(|e| key_hash >= e.key_range_start && key_hash < e.key_range_end)
101 }
102
103 pub fn all_entries(&self) -> impl Iterator<Item = &RoutingEntry> {
105 self.entries.values()
106 }
107
108 pub fn version(&self) -> u64 {
110 self.version
111 }
112
113 pub fn len(&self) -> usize {
115 self.entries.len()
116 }
117
118 pub fn is_empty(&self) -> bool {
120 self.entries.is_empty()
121 }
122
123 pub fn from_shards(shards: &[Shard]) -> Self {
125 let mut table = Self::new();
126
127 for shard in shards {
128 table.upsert(RoutingEntry {
129 shard_id: shard.id.clone(),
130 primary: shard.primary_node.clone(),
131 replicas: shard.replica_nodes.clone(),
132 key_range_start: shard.key_range_start.unwrap_or(0),
133 key_range_end: shard.key_range_end.unwrap_or(u64::MAX),
134 });
135 }
136
137 table
138 }
139}
140
141impl Default for RoutingTable {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147pub struct ShardRouter {
153 routing_table: RwLock<RoutingTable>,
154 hash_ring: RwLock<HashRing>,
155 partition_strategy: PartitionStrategy,
156 prefer_local: bool,
157 local_node: Option<NodeId>,
158}
159
160impl ShardRouter {
161 pub fn new(strategy: PartitionStrategy) -> Self {
163 Self {
164 routing_table: RwLock::new(RoutingTable::new()),
165 hash_ring: RwLock::new(HashRing::default()),
166 partition_strategy: strategy,
167 prefer_local: true,
168 local_node: None,
169 }
170 }
171
172 pub fn with_local_node(strategy: PartitionStrategy, local_node: NodeId) -> Self {
174 Self {
175 routing_table: RwLock::new(RoutingTable::new()),
176 hash_ring: RwLock::new(HashRing::default()),
177 partition_strategy: strategy,
178 prefer_local: true,
179 local_node: Some(local_node),
180 }
181 }
182
183 pub fn update_routing(&self, shards: &[Shard]) {
185 let table = RoutingTable::from_shards(shards);
186 *self
187 .routing_table
188 .write()
189 .expect("router routing_table lock poisoned") = table;
190
191 let mut ring = self
193 .hash_ring
194 .write()
195 .expect("router hash_ring lock poisoned");
196 *ring = HashRing::default();
197 for shard in shards {
198 ring.add_node(shard.primary_node.clone());
199 for replica in &shard.replica_nodes {
200 ring.add_node(replica.clone());
201 }
202 }
203 }
204
205 pub fn route(&self, key: &PartitionKey) -> RouteDecision {
207 let hash = key.hash_value();
208 let table = self
209 .routing_table
210 .read()
211 .expect("router routing_table lock poisoned");
212
213 if let Some(entry) = table.find_shard(hash) {
214 RouteDecision::Single {
215 shard_id: entry.shard_id.clone(),
216 node_id: self.select_node(entry),
217 }
218 } else {
219 let ring = self
221 .hash_ring
222 .read()
223 .expect("router hash_ring lock poisoned");
224 if let Some(node) = ring.get_node(&format!("{}", hash)) {
225 RouteDecision::Single {
226 shard_id: ShardId::new(0),
227 node_id: node.clone(),
228 }
229 } else {
230 RouteDecision::Broadcast {
231 shards: table.entries.keys().cloned().collect(),
232 }
233 }
234 }
235 }
236
237 pub fn route_write(&self, key: &PartitionKey) -> RouteDecision {
239 let hash = key.hash_value();
240 let table = self
241 .routing_table
242 .read()
243 .expect("router routing_table lock poisoned");
244
245 if let Some(entry) = table.find_shard(hash) {
246 RouteDecision::Primary {
247 shard_id: entry.shard_id.clone(),
248 node_id: entry.primary.clone(),
249 }
250 } else {
251 RouteDecision::Broadcast {
252 shards: table.entries.keys().cloned().collect(),
253 }
254 }
255 }
256
257 pub fn route_read(&self, key: &PartitionKey) -> RouteDecision {
259 let hash = key.hash_value();
260 let table = self
261 .routing_table
262 .read()
263 .expect("router routing_table lock poisoned");
264
265 if let Some(entry) = table.find_shard(hash) {
266 let mut candidates = vec![entry.primary.clone()];
267 candidates.extend(entry.replicas.iter().cloned());
268
269 RouteDecision::AnyReplica {
270 shard_id: entry.shard_id.clone(),
271 candidates,
272 }
273 } else {
274 drop(table);
275 self.route(key)
276 }
277 }
278
279 pub fn route_range(&self, start_key: &PartitionKey, end_key: &PartitionKey) -> RouteDecision {
281 let start_hash = start_key.hash_value();
282 let end_hash = end_key.hash_value();
283
284 let table = self
285 .routing_table
286 .read()
287 .expect("router routing_table lock poisoned");
288 let mut routes = Vec::new();
289
290 for entry in table.all_entries() {
291 if entry.key_range_end > start_hash && entry.key_range_start < end_hash {
293 routes.push(ShardRoute {
294 shard_id: entry.shard_id.clone(),
295 node_id: self.select_node(entry),
296 is_primary: true,
297 });
298 }
299 }
300
301 if routes.is_empty() {
302 RouteDecision::Broadcast {
303 shards: table.entries.keys().cloned().collect(),
304 }
305 } else if routes.len() == 1 {
306 let route = routes.remove(0);
307 RouteDecision::Single {
308 shard_id: route.shard_id,
309 node_id: route.node_id,
310 }
311 } else {
312 RouteDecision::Multi { routes }
313 }
314 }
315
316 pub fn route_all(&self) -> RouteDecision {
318 let table = self
319 .routing_table
320 .read()
321 .expect("router routing_table lock poisoned");
322
323 let routes: Vec<_> = table
324 .all_entries()
325 .map(|entry| ShardRoute {
326 shard_id: entry.shard_id.clone(),
327 node_id: self.select_node(entry),
328 is_primary: true,
329 })
330 .collect();
331
332 if routes.is_empty() {
333 RouteDecision::Broadcast { shards: vec![] }
334 } else {
335 RouteDecision::Multi { routes }
336 }
337 }
338
339 fn select_node(&self, entry: &RoutingEntry) -> NodeId {
341 if self.prefer_local {
343 if let Some(ref local) = self.local_node {
344 if &entry.primary == local {
345 return entry.primary.clone();
346 }
347 if entry.replicas.contains(local) {
348 return local.clone();
349 }
350 }
351 }
352
353 entry.primary.clone()
355 }
356
357 pub fn routing_version(&self) -> u64 {
359 self.routing_table
360 .read()
361 .expect("router routing_table lock poisoned")
362 .version()
363 }
364
365 pub fn strategy(&self) -> &PartitionStrategy {
367 &self.partition_strategy
368 }
369
370 pub fn is_initialized(&self) -> bool {
372 !self
373 .routing_table
374 .read()
375 .expect("router routing_table lock poisoned")
376 .is_empty()
377 }
378}
379
/// String-matching SQL inspector that classifies statements and pulls out partition-key
/// values from `column = value` predicates.
pub struct QueryAnalyzer {
    /// Columns searched for equality predicates, tried in this order.
    partition_columns: Vec<String>,
}
388
389impl QueryAnalyzer {
390 pub fn new(partition_columns: Vec<String>) -> Self {
392 Self { partition_columns }
393 }
394
395 pub fn analyze(&self, query: &str) -> QueryRouting {
397 let query_upper = query.to_uppercase();
398
399 let query_type = if query_upper.starts_with("SELECT") {
401 QueryType::Read
402 } else if query_upper.starts_with("INSERT") {
403 QueryType::Write
404 } else if query_upper.starts_with("UPDATE") {
405 QueryType::Write
406 } else if query_upper.starts_with("DELETE") {
407 QueryType::Write
408 } else {
409 QueryType::Admin
410 };
411
412 let partition_key = self.extract_partition_key(query);
414 let requires_all_shards = partition_key.is_none() && query_type == QueryType::Read;
415
416 QueryRouting {
417 query_type,
418 partition_key,
419 requires_all_shards,
420 }
421 }
422
423 fn extract_partition_key(&self, query: &str) -> Option<PartitionKey> {
424 let query_lower = query.to_lowercase();
425
426 for col in &self.partition_columns {
427 let col_lower = col.to_lowercase();
428
429 let patterns = [
431 format!("{} = '", col_lower),
432 format!("{} ='", col_lower),
433 format!("{}='", col_lower),
434 format!("{} = ", col_lower),
435 ];
436
437 for pattern in &patterns {
438 if let Some(start) = query_lower.find(pattern) {
439 let value_start = start + pattern.len();
440 let remaining = &query[value_start..];
441
442 let value = if remaining.starts_with('\'') {
443 remaining[1..].split('\'').next().map(|s| s.to_string())
445 } else {
446 remaining
448 .split(|c: char| c.is_whitespace() || c == ')' || c == ';')
449 .next()
450 .map(|s| s.trim_matches('\'').to_string())
451 };
452
453 if let Some(v) = value {
454 if !v.is_empty() {
455 if let Ok(i) = v.parse::<i64>() {
457 return Some(PartitionKey::Int(i));
458 }
459 return Some(PartitionKey::String(v));
460 }
461 }
462 }
463 }
464 }
465
466 None
467 }
468}
469
470#[derive(Debug, Clone)]
472pub struct QueryRouting {
473 pub query_type: QueryType,
474 pub partition_key: Option<PartitionKey>,
475 pub requires_all_shards: bool,
476}
477
/// Coarse statement classification produced by [`QueryAnalyzer`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// `SELECT` statements.
    Read,
    /// `INSERT`, `UPDATE` and `DELETE` statements.
    Write,
    /// Any other statement.
    Admin,
}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Two primary-only shards split at `u64::MAX / 2`: shard 0 on `node1` presumably
    /// owning the lower half, shard 1 on `node2` the upper half.
    fn create_test_shards() -> Vec<Shard> {
        let midpoint = u64::MAX / 2;
        let lower = Shard::with_range(ShardId::new(0), NodeId::new("node1"), 0, midpoint);
        let upper = Shard::with_range(ShardId::new(1), NodeId::new("node2"), midpoint, u64::MAX);
        vec![lower, upper]
    }

    /// Hash partitioning on `id` with two partitions.
    fn hash_strategy() -> PartitionStrategy {
        PartitionStrategy::hash(vec!["id".to_string()], 2)
    }

    /// A router without a local node, loaded with `shards`.
    fn loaded_router(shards: &[Shard]) -> ShardRouter {
        let router = ShardRouter::new(hash_strategy());
        router.update_routing(shards);
        router
    }

    #[test]
    fn test_routing_table() {
        let table = RoutingTable::from_shards(&create_test_shards());

        assert_eq!(table.len(), 2);
        assert!(table.get(&ShardId::new(0)).is_some());
    }

    #[test]
    fn test_routing_table_find_shard() {
        let table = RoutingTable::from_shards(&create_test_shards());

        let low = table.find_shard(100).unwrap();
        assert_eq!(low.shard_id.as_u32(), 0);

        let high = table.find_shard(u64::MAX - 100).unwrap();
        assert_eq!(high.shard_id.as_u32(), 1);
    }

    #[test]
    fn test_shard_router() {
        let router = loaded_router(&create_test_shards());
        assert!(router.is_initialized());
    }

    #[test]
    fn test_router_route() {
        let router = loaded_router(&create_test_shards());

        match router.route(&PartitionKey::string("test_key")) {
            RouteDecision::Single { shard_id, node_id } => {
                assert!(shard_id.as_u32() < 2);
                assert!(["node1", "node2"].contains(&node_id.as_str()));
            }
            _ => panic!("Expected single route"),
        }
    }

    #[test]
    fn test_router_route_write() {
        let mut shards = create_test_shards();
        shards[0].add_replica(NodeId::new("node3"));
        let router = loaded_router(&shards);

        match router.route_write(&PartitionKey::int(1)) {
            RouteDecision::Primary { node_id, .. } => {
                assert!(["node1", "node2"].contains(&node_id.as_str()));
            }
            _ => panic!("Expected primary route"),
        }
    }

    #[test]
    fn test_router_route_read() {
        let mut shards = create_test_shards();
        shards[0].add_replica(NodeId::new("node3"));
        let router = loaded_router(&shards);

        match router.route_read(&PartitionKey::int(1)) {
            RouteDecision::AnyReplica { candidates, .. } => {
                assert!(!candidates.is_empty());
            }
            _ => panic!("Expected any replica route"),
        }
    }

    #[test]
    fn test_router_route_all() {
        let router = loaded_router(&create_test_shards());

        match router.route_all() {
            RouteDecision::Multi { routes } => assert_eq!(routes.len(), 2),
            _ => panic!("Expected multi route"),
        }
    }

    #[test]
    fn test_query_analyzer() {
        let analyzer = QueryAnalyzer::new(vec!["user_id".to_string()]);

        let select = analyzer.analyze("SELECT * FROM users WHERE user_id = 123");
        assert_eq!(select.query_type, QueryType::Read);
        assert!(select.partition_key.is_some());

        let insert = analyzer.analyze("INSERT INTO users VALUES (1, 'Alice')");
        assert_eq!(insert.query_type, QueryType::Write);
    }

    #[test]
    fn test_query_analyzer_no_key() {
        let analyzer = QueryAnalyzer::new(vec!["user_id".to_string()]);

        let routing = analyzer.analyze("SELECT * FROM users");
        assert!(routing.partition_key.is_none());
        assert!(routing.requires_all_shards);
    }

    #[test]
    fn test_route_decision_types() {
        let single = RouteDecision::Single {
            shard_id: ShardId::new(0),
            node_id: NodeId::new("node1"),
        };
        let broadcast = RouteDecision::Broadcast {
            shards: vec![ShardId::new(0), ShardId::new(1)],
        };

        match single {
            RouteDecision::Single { .. } => {}
            _ => panic!("Expected single"),
        }

        match broadcast {
            RouteDecision::Broadcast { shards } => assert_eq!(shards.len(), 2),
            _ => panic!("Expected broadcast"),
        }
    }
}