kaccy_db/
sharding.rs

1//! Database sharding utilities for horizontal scaling
2
3use crate::error::{DbError, Result};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use sqlx::PgPool;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use uuid::Uuid;
11
12/// Shard key type for routing queries to appropriate shards
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub enum ShardKey {
15    /// Shard by user ID
16    UserId(Uuid),
17    /// Shard by token ID
18    TokenId(Uuid),
19    /// Shard by custom string key
20    Custom(String),
21    /// Shard by integer key
22    Integer(i64),
23}
24
25impl ShardKey {
26    /// Get the hash value for this shard key
27    pub fn hash_value(&self) -> u64 {
28        let mut hasher = std::collections::hash_map::DefaultHasher::new();
29        self.hash(&mut hasher);
30        hasher.finish()
31    }
32}
33
34/// Sharding strategy for distributing data across shards
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum ShardingStrategy {
37    /// Hash-based sharding (consistent hashing)
38    Hash,
39    /// Range-based sharding
40    Range,
41    /// Geographic sharding
42    Geographic,
43    /// Custom sharding logic
44    Custom,
45}
46
47/// Shard information
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ShardInfo {
50    /// Shard ID
51    pub id: u32,
52    /// Shard name
53    pub name: String,
54    /// Database connection string
55    pub connection_string: String,
56    /// Whether this shard is active
57    pub is_active: bool,
58    /// Shard weight for load balancing (higher = more load)
59    pub weight: u32,
60    /// Geographic region (optional)
61    pub region: Option<String>,
62    /// Range start (for range-based sharding)
63    pub range_start: Option<u64>,
64    /// Range end (for range-based sharding)
65    pub range_end: Option<u64>,
66}
67
68/// Shard pool manager
69pub struct ShardPoolManager {
70    /// Map of shard ID to connection pool
71    pools: Arc<RwLock<HashMap<u32, PgPool>>>,
72    /// Shard metadata
73    shards: Arc<RwLock<Vec<ShardInfo>>>,
74    /// Sharding strategy
75    strategy: ShardingStrategy,
76}
77
78impl ShardPoolManager {
79    /// Create a new shard pool manager
80    pub fn new(strategy: ShardingStrategy) -> Self {
81        Self {
82            pools: Arc::new(RwLock::new(HashMap::new())),
83            shards: Arc::new(RwLock::new(Vec::new())),
84            strategy,
85        }
86    }
87
88    /// Add a shard to the manager
89    pub async fn add_shard(&self, shard_info: ShardInfo) -> Result<()> {
90        let pool = PgPool::connect(&shard_info.connection_string)
91            .await
92            .map_err(|e| DbError::Connection(format!("Failed to connect to shard: {}", e)))?;
93
94        let mut pools = self.pools.write();
95        let mut shards = self.shards.write();
96
97        pools.insert(shard_info.id, pool);
98        shards.push(shard_info);
99
100        Ok(())
101    }
102
103    /// Remove a shard from the manager
104    pub async fn remove_shard(&self, shard_id: u32) -> Result<()> {
105        let mut pools = self.pools.write();
106        let mut shards = self.shards.write();
107
108        pools.remove(&shard_id);
109        shards.retain(|s| s.id != shard_id);
110
111        Ok(())
112    }
113
114    /// Get shard ID for a given shard key
115    pub fn get_shard_id(&self, key: &ShardKey) -> Result<u32> {
116        let shards = self.shards.read();
117
118        if shards.is_empty() {
119            return Err(DbError::Other("No shards configured".to_string()));
120        }
121
122        match self.strategy {
123            ShardingStrategy::Hash => {
124                let hash = key.hash_value();
125                let active_shards: Vec<_> = shards.iter().filter(|s| s.is_active).collect();
126
127                if active_shards.is_empty() {
128                    return Err(DbError::Other("No active shards available".to_string()));
129                }
130
131                let index = (hash % active_shards.len() as u64) as usize;
132                Ok(active_shards[index].id)
133            }
134            ShardingStrategy::Range => {
135                let hash = key.hash_value();
136
137                for shard in shards.iter().filter(|s| s.is_active) {
138                    if let (Some(start), Some(end)) = (shard.range_start, shard.range_end) {
139                        if hash >= start && hash < end {
140                            return Ok(shard.id);
141                        }
142                    }
143                }
144
145                Err(DbError::Other(
146                    "No shard found for key in range".to_string(),
147                ))
148            }
149            ShardingStrategy::Geographic | ShardingStrategy::Custom => {
150                // For geographic/custom, use first active shard as fallback
151                shards
152                    .iter()
153                    .find(|s| s.is_active)
154                    .map(|s| s.id)
155                    .ok_or_else(|| DbError::Other("No active shards available".to_string()))
156            }
157        }
158    }
159
160    /// Get pool for a given shard key
161    pub fn get_pool(&self, key: &ShardKey) -> Result<PgPool> {
162        let shard_id = self.get_shard_id(key)?;
163        self.get_pool_by_id(shard_id)
164    }
165
166    /// Get pool by shard ID
167    pub fn get_pool_by_id(&self, shard_id: u32) -> Result<PgPool> {
168        let pools = self.pools.read();
169        pools
170            .get(&shard_id)
171            .cloned()
172            .ok_or_else(|| DbError::Other(format!("Shard {} not found", shard_id)))
173    }
174
175    /// Get all active shard pools
176    pub fn get_all_active_pools(&self) -> Vec<(u32, PgPool)> {
177        let pools = self.pools.read();
178        let shards = self.shards.read();
179
180        shards
181            .iter()
182            .filter(|s| s.is_active)
183            .filter_map(|s| pools.get(&s.id).map(|pool| (s.id, pool.clone())))
184            .collect()
185    }
186
187    /// Get shard info by ID
188    pub fn get_shard_info(&self, shard_id: u32) -> Option<ShardInfo> {
189        let shards = self.shards.read();
190        shards.iter().find(|s| s.id == shard_id).cloned()
191    }
192
193    /// List all shards
194    pub fn list_shards(&self) -> Vec<ShardInfo> {
195        let shards = self.shards.read();
196        shards.clone()
197    }
198
199    /// Set shard active/inactive
200    pub fn set_shard_active(&self, shard_id: u32, active: bool) -> Result<()> {
201        let mut shards = self.shards.write();
202
203        if let Some(shard) = shards.iter_mut().find(|s| s.id == shard_id) {
204            shard.is_active = active;
205            Ok(())
206        } else {
207            Err(DbError::Other(format!("Shard {} not found", shard_id)))
208        }
209    }
210
211    /// Get shard count
212    pub fn shard_count(&self) -> usize {
213        let shards = self.shards.read();
214        shards.len()
215    }
216
217    /// Get active shard count
218    pub fn active_shard_count(&self) -> usize {
219        let shards = self.shards.read();
220        shards.iter().filter(|s| s.is_active).count()
221    }
222}
223
224/// Shard coordinator for cross-shard operations
225pub struct ShardCoordinator {
226    manager: Arc<ShardPoolManager>,
227}
228
229impl ShardCoordinator {
230    /// Create a new shard coordinator
231    pub fn new(manager: Arc<ShardPoolManager>) -> Self {
232        Self { manager }
233    }
234
235    /// Execute a query on a specific shard
236    pub async fn execute_on_shard<F, T>(&self, key: &ShardKey, f: F) -> Result<T>
237    where
238        F: FnOnce(
239                &PgPool,
240            )
241                -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>
242            + Send,
243        T: Send,
244    {
245        let pool = self.manager.get_pool(key)?;
246        f(&pool).await
247    }
248
249    /// Execute a query on all active shards and collect results
250    pub async fn execute_on_all_shards<F, T>(&self, f: F) -> Result<Vec<(u32, T)>>
251    where
252        F: Fn(&PgPool) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>
253            + Send
254            + Sync,
255        T: Send,
256    {
257        let pools = self.manager.get_all_active_pools();
258        let mut results = Vec::new();
259
260        for (shard_id, pool) in pools {
261            match f(&pool).await {
262                Ok(result) => results.push((shard_id, result)),
263                Err(e) => {
264                    tracing::warn!("Error executing on shard {}: {}", shard_id, e);
265                    // Continue with other shards
266                }
267            }
268        }
269
270        Ok(results)
271    }
272
273    /// Aggregate results from all shards
274    pub async fn aggregate_from_all_shards<F, T, R>(
275        &self,
276        query: F,
277        aggregator: fn(Vec<T>) -> R,
278    ) -> Result<R>
279    where
280        F: Fn(&PgPool) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>
281            + Send
282            + Sync,
283        T: Send,
284        R: Send,
285    {
286        let results = self.execute_on_all_shards(query).await?;
287        let values: Vec<T> = results.into_iter().map(|(_, v)| v).collect();
288        Ok(aggregator(values))
289    }
290}
291
292/// Consistent hashing ring for shard routing
293pub struct ConsistentHashRing {
294    /// Virtual nodes per shard
295    virtual_nodes: u32,
296    /// Ring of hash values to shard IDs
297    ring: Arc<RwLock<Vec<(u64, u32)>>>,
298}
299
300impl ConsistentHashRing {
301    /// Create a new consistent hash ring
302    pub fn new(virtual_nodes: u32) -> Self {
303        Self {
304            virtual_nodes,
305            ring: Arc::new(RwLock::new(Vec::new())),
306        }
307    }
308
309    /// Add a shard to the ring
310    pub fn add_shard(&self, shard_id: u32) {
311        let mut ring = self.ring.write();
312
313        for i in 0..self.virtual_nodes {
314            let key = format!("shard-{}-vnode-{}", shard_id, i);
315            let mut hasher = std::collections::hash_map::DefaultHasher::new();
316            key.hash(&mut hasher);
317            let hash = hasher.finish();
318            ring.push((hash, shard_id));
319        }
320
321        ring.sort_by_key(|(hash, _)| *hash);
322    }
323
324    /// Remove a shard from the ring
325    pub fn remove_shard(&self, shard_id: u32) {
326        let mut ring = self.ring.write();
327        ring.retain(|(_, id)| *id != shard_id);
328    }
329
330    /// Get shard ID for a given key
331    pub fn get_shard(&self, key: &ShardKey) -> Option<u32> {
332        let ring = self.ring.read();
333
334        if ring.is_empty() {
335            return None;
336        }
337
338        let hash = key.hash_value();
339
340        // Binary search for the first node >= hash
341        match ring.binary_search_by_key(&hash, |(h, _)| *h) {
342            Ok(idx) => Some(ring[idx].1),
343            Err(idx) => {
344                if idx >= ring.len() {
345                    // Wrap around to first node
346                    Some(ring[0].1)
347                } else {
348                    Some(ring[idx].1)
349                }
350            }
351        }
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal keys must reduce to the same routing hash.
    #[test]
    fn test_shard_key_hash() {
        let key1 = ShardKey::UserId(Uuid::nil());
        let key2 = ShardKey::UserId(Uuid::nil());
        assert_eq!(key1.hash_value(), key2.hash_value());
    }

    /// A populated ring resolves a key to one of the shards that were added,
    /// and does so deterministically.
    #[test]
    fn test_consistent_hash_ring() {
        let ring = ConsistentHashRing::new(100);
        ring.add_shard(1);
        ring.add_shard(2);
        ring.add_shard(3);

        let key = ShardKey::UserId(Uuid::nil());
        let shard = ring.get_shard(&key);

        // Only IDs 1, 2 and 3 were added; 0 would be a routing bug.
        assert!(matches!(shard, Some(1..=3)));
        // Looking the same key up again must give the same owner.
        assert_eq!(ring.get_shard(&key), shard);
    }

    /// An empty ring has no owner for any key.
    #[test]
    fn test_consistent_hash_ring_empty() {
        let ring = ConsistentHashRing::new(100);
        let key = ShardKey::UserId(Uuid::nil());
        assert_eq!(ring.get_shard(&key), None);
    }

    /// A freshly created manager has no shards and cannot route any key.
    #[test]
    fn test_shard_manager_hash_strategy() {
        let manager = ShardPoolManager::new(ShardingStrategy::Hash);
        assert_eq!(manager.shard_count(), 0);
        assert_eq!(manager.active_shard_count(), 0);
        assert!(manager.list_shards().is_empty());
        assert!(manager.get_shard_id(&ShardKey::Integer(42)).is_err());
    }
}
383        assert_eq!(manager.active_shard_count(), 0);
384    }
385}