1use crate::error::{DbError, Result};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use sqlx::PgPool;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use uuid::Uuid;
11
/// Value used to decide which shard owns a piece of data.
///
/// `ShardKey::hash_value` hashes the whole key through the derived `Hash` impl below.
/// NOTE(review): the derived `Hash` includes the variant discriminant. Reordering the
/// variants, or inserting one before an existing variant, therefore changes `hash_value`
/// for existing keys and re-routes them. Append new variants at the end.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ShardKey {
    /// Key by a user's UUID.
    UserId(Uuid),
    /// Key by a token's UUID.
    TokenId(Uuid),
    /// Arbitrary caller-chosen string key.
    Custom(String),
    /// Integer key.
    Integer(i64),
}
24
25impl ShardKey {
26 pub fn hash_value(&self) -> u64 {
28 let mut hasher = std::collections::hash_map::DefaultHasher::new();
29 self.hash(&mut hasher);
30 hasher.finish()
31 }
32}
33
/// How [`ShardPoolManager`] maps a [`ShardKey`] to a shard id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardingStrategy {
    /// `hash_value % n` over the `n` active shards, taken in registration order.
    Hash,
    /// First active shard whose half-open `[range_start, range_end)` contains the key's
    /// `hash_value`.
    Range,
    /// NOTE(review): not key-aware yet; `get_shard_id` routes every key to the first
    /// active shard.
    Geographic,
    /// NOTE(review): not key-aware yet; `get_shard_id` routes every key to the first
    /// active shard.
    Custom,
}
46
/// Configuration and status of one shard registered with [`ShardPoolManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardInfo {
    /// Shard id; also the key of the manager's pool map.
    pub id: u32,
    /// Human-readable name. Not read by `ShardPoolManager`'s routing logic.
    pub name: String,
    /// Postgres connection string handed to `PgPool::connect`.
    pub connection_string: String,
    /// Inactive shards are skipped by routing and by `get_all_active_pools`.
    pub is_active: bool,
    /// NOTE(review): not read by `ShardPoolManager`'s routing logic.
    pub weight: u32,
    /// NOTE(review): not read by `ShardPoolManager`'s routing logic, including the
    /// `Geographic` strategy.
    pub region: Option<String>,
    /// Inclusive lower bound of the key-hash range served under `ShardingStrategy::Range`.
    pub range_start: Option<u64>,
    /// Exclusive upper bound of that range (so a hash of `u64::MAX` is never covered).
    /// A shard missing either bound is never chosen by `Range` routing.
    pub range_end: Option<u64>,
}
67
/// Registry of per-shard connection pools plus the strategy used to route keys to them.
pub struct ShardPoolManager {
    /// Connection pool for each registered shard id.
    pools: Arc<RwLock<HashMap<u32, PgPool>>>,
    /// Shard metadata in registration order; `Hash` routing indexes into this order.
    shards: Arc<RwLock<Vec<ShardInfo>>>,
    /// Routing strategy, fixed at construction.
    strategy: ShardingStrategy,
}
77
78impl ShardPoolManager {
79 pub fn new(strategy: ShardingStrategy) -> Self {
81 Self {
82 pools: Arc::new(RwLock::new(HashMap::new())),
83 shards: Arc::new(RwLock::new(Vec::new())),
84 strategy,
85 }
86 }
87
88 pub async fn add_shard(&self, shard_info: ShardInfo) -> Result<()> {
90 let pool = PgPool::connect(&shard_info.connection_string)
91 .await
92 .map_err(|e| DbError::Connection(format!("Failed to connect to shard: {}", e)))?;
93
94 let mut pools = self.pools.write();
95 let mut shards = self.shards.write();
96
97 pools.insert(shard_info.id, pool);
98 shards.push(shard_info);
99
100 Ok(())
101 }
102
103 pub async fn remove_shard(&self, shard_id: u32) -> Result<()> {
105 let mut pools = self.pools.write();
106 let mut shards = self.shards.write();
107
108 pools.remove(&shard_id);
109 shards.retain(|s| s.id != shard_id);
110
111 Ok(())
112 }
113
114 pub fn get_shard_id(&self, key: &ShardKey) -> Result<u32> {
116 let shards = self.shards.read();
117
118 if shards.is_empty() {
119 return Err(DbError::Other("No shards configured".to_string()));
120 }
121
122 match self.strategy {
123 ShardingStrategy::Hash => {
124 let hash = key.hash_value();
125 let active_shards: Vec<_> = shards.iter().filter(|s| s.is_active).collect();
126
127 if active_shards.is_empty() {
128 return Err(DbError::Other("No active shards available".to_string()));
129 }
130
131 let index = (hash % active_shards.len() as u64) as usize;
132 Ok(active_shards[index].id)
133 }
134 ShardingStrategy::Range => {
135 let hash = key.hash_value();
136
137 for shard in shards.iter().filter(|s| s.is_active) {
138 if let (Some(start), Some(end)) = (shard.range_start, shard.range_end) {
139 if hash >= start && hash < end {
140 return Ok(shard.id);
141 }
142 }
143 }
144
145 Err(DbError::Other(
146 "No shard found for key in range".to_string(),
147 ))
148 }
149 ShardingStrategy::Geographic | ShardingStrategy::Custom => {
150 shards
152 .iter()
153 .find(|s| s.is_active)
154 .map(|s| s.id)
155 .ok_or_else(|| DbError::Other("No active shards available".to_string()))
156 }
157 }
158 }
159
160 pub fn get_pool(&self, key: &ShardKey) -> Result<PgPool> {
162 let shard_id = self.get_shard_id(key)?;
163 self.get_pool_by_id(shard_id)
164 }
165
166 pub fn get_pool_by_id(&self, shard_id: u32) -> Result<PgPool> {
168 let pools = self.pools.read();
169 pools
170 .get(&shard_id)
171 .cloned()
172 .ok_or_else(|| DbError::Other(format!("Shard {} not found", shard_id)))
173 }
174
175 pub fn get_all_active_pools(&self) -> Vec<(u32, PgPool)> {
177 let pools = self.pools.read();
178 let shards = self.shards.read();
179
180 shards
181 .iter()
182 .filter(|s| s.is_active)
183 .filter_map(|s| pools.get(&s.id).map(|pool| (s.id, pool.clone())))
184 .collect()
185 }
186
187 pub fn get_shard_info(&self, shard_id: u32) -> Option<ShardInfo> {
189 let shards = self.shards.read();
190 shards.iter().find(|s| s.id == shard_id).cloned()
191 }
192
193 pub fn list_shards(&self) -> Vec<ShardInfo> {
195 let shards = self.shards.read();
196 shards.clone()
197 }
198
199 pub fn set_shard_active(&self, shard_id: u32, active: bool) -> Result<()> {
201 let mut shards = self.shards.write();
202
203 if let Some(shard) = shards.iter_mut().find(|s| s.id == shard_id) {
204 shard.is_active = active;
205 Ok(())
206 } else {
207 Err(DbError::Other(format!("Shard {} not found", shard_id)))
208 }
209 }
210
211 pub fn shard_count(&self) -> usize {
213 let shards = self.shards.read();
214 shards.len()
215 }
216
217 pub fn active_shard_count(&self) -> usize {
219 let shards = self.shards.read();
220 shards.iter().filter(|s| s.is_active).count()
221 }
222}
223
/// Runs caller-supplied queries on the shard(s) selected by a shared [`ShardPoolManager`].
pub struct ShardCoordinator {
    /// Source of routing decisions and connection pools.
    manager: Arc<ShardPoolManager>,
}
228
229impl ShardCoordinator {
230 pub fn new(manager: Arc<ShardPoolManager>) -> Self {
232 Self { manager }
233 }
234
235 pub async fn execute_on_shard<F, T>(&self, key: &ShardKey, f: F) -> Result<T>
237 where
238 F: FnOnce(
239 &PgPool,
240 )
241 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>
242 + Send,
243 T: Send,
244 {
245 let pool = self.manager.get_pool(key)?;
246 f(&pool).await
247 }
248
249 pub async fn execute_on_all_shards<F, T>(&self, f: F) -> Result<Vec<(u32, T)>>
251 where
252 F: Fn(&PgPool) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>
253 + Send
254 + Sync,
255 T: Send,
256 {
257 let pools = self.manager.get_all_active_pools();
258 let mut results = Vec::new();
259
260 for (shard_id, pool) in pools {
261 match f(&pool).await {
262 Ok(result) => results.push((shard_id, result)),
263 Err(e) => {
264 tracing::warn!("Error executing on shard {}: {}", shard_id, e);
265 }
267 }
268 }
269
270 Ok(results)
271 }
272
273 pub async fn aggregate_from_all_shards<F, T, R>(
275 &self,
276 query: F,
277 aggregator: fn(Vec<T>) -> R,
278 ) -> Result<R>
279 where
280 F: Fn(&PgPool) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>
281 + Send
282 + Sync,
283 T: Send,
284 R: Send,
285 {
286 let results = self.execute_on_all_shards(query).await?;
287 let values: Vec<T> = results.into_iter().map(|(_, v)| v).collect();
288 Ok(aggregator(values))
289 }
290}
291
/// Consistent-hash ring mapping [`ShardKey`]s to shard ids using virtual nodes.
pub struct ConsistentHashRing {
    /// Number of ring points placed for each shard by `add_shard`.
    virtual_nodes: u32,
    /// `(point_hash, shard_id)` pairs, kept sorted by `point_hash`.
    ring: Arc<RwLock<Vec<(u64, u32)>>>,
}
299
300impl ConsistentHashRing {
301 pub fn new(virtual_nodes: u32) -> Self {
303 Self {
304 virtual_nodes,
305 ring: Arc::new(RwLock::new(Vec::new())),
306 }
307 }
308
309 pub fn add_shard(&self, shard_id: u32) {
311 let mut ring = self.ring.write();
312
313 for i in 0..self.virtual_nodes {
314 let key = format!("shard-{}-vnode-{}", shard_id, i);
315 let mut hasher = std::collections::hash_map::DefaultHasher::new();
316 key.hash(&mut hasher);
317 let hash = hasher.finish();
318 ring.push((hash, shard_id));
319 }
320
321 ring.sort_by_key(|(hash, _)| *hash);
322 }
323
324 pub fn remove_shard(&self, shard_id: u32) {
326 let mut ring = self.ring.write();
327 ring.retain(|(_, id)| *id != shard_id);
328 }
329
330 pub fn get_shard(&self, key: &ShardKey) -> Option<u32> {
332 let ring = self.ring.read();
333
334 if ring.is_empty() {
335 return None;
336 }
337
338 let hash = key.hash_value();
339
340 match ring.binary_search_by_key(&hash, |(h, _)| *h) {
342 Ok(idx) => Some(ring[idx].1),
343 Err(idx) => {
344 if idx >= ring.len() {
345 Some(ring[0].1)
347 } else {
348 Some(ring[idx].1)
349 }
350 }
351 }
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_key_hash() {
        let key1 = ShardKey::UserId(Uuid::nil());
        let key2 = ShardKey::UserId(Uuid::nil());
        assert_eq!(key1.hash_value(), key2.hash_value());
    }

    #[test]
    fn test_consistent_hash_ring() {
        let ring = ConsistentHashRing::new(100);
        ring.add_shard(1);
        ring.add_shard(2);
        ring.add_shard(3);

        let key = ShardKey::UserId(Uuid::nil());
        let shard = ring
            .get_shard(&key)
            .expect("a non-empty ring must route every key");
        // Must be one of the registered ids; a bare `<= 3` bound would also accept 0.
        assert!([1, 2, 3].contains(&shard));
        // Routing is deterministic while the ring is unchanged.
        assert_eq!(ring.get_shard(&key), Some(shard));
    }

    #[test]
    fn test_consistent_hash_ring_remove_reroutes() {
        let ring = ConsistentHashRing::new(100);
        ring.add_shard(1);
        ring.add_shard(2);

        let key = ShardKey::Integer(42);
        let owner = ring.get_shard(&key).expect("ring has shards");
        ring.remove_shard(owner);

        let new_owner = ring.get_shard(&key).expect("one shard remains");
        assert_ne!(new_owner, owner);
        assert!([1, 2].contains(&new_owner));
    }

    #[test]
    fn test_consistent_hash_ring_empty() {
        let ring = ConsistentHashRing::new(100);
        assert_eq!(ring.get_shard(&ShardKey::Integer(1)), None);
    }

    #[test]
    fn test_shard_manager_hash_strategy() {
        let manager = ShardPoolManager::new(ShardingStrategy::Hash);
        assert_eq!(manager.shard_count(), 0);
        assert_eq!(manager.active_shard_count(), 0);
        assert!(manager.list_shards().is_empty());
        assert!(manager.get_all_active_pools().is_empty());
        // Routing with no shards must fail rather than panic.
        assert!(manager.get_shard_id(&ShardKey::Integer(1)).is_err());
    }
}