Skip to main content

sqlmodel_pool/
sharding.rs

1//! Horizontal sharding support for SQLModel Rust.
2//!
3//! This module provides infrastructure for partitioning data across multiple
4//! database shards based on a shard key.
5//!
6//! # Overview
7//!
8//! Horizontal sharding distributes rows across multiple databases based on a
9//! shard key (e.g., `user_id`, `tenant_id`). This enables:
10//!
11//! - Horizontal scalability beyond single-database limits
12//! - Data isolation between tenants/regions
13//! - Improved query performance through data locality
14//!
15//! # Example
16//!
17//! ```rust,ignore
//! use sqlmodel_pool::{Pool, PoolConfig, QueryHints, ShardChooser, ShardedPool};
19//! use sqlmodel_core::{Model, Value};
20//!
21//! // Define a shard chooser based on modulo hashing
22//! struct ModuloShardChooser {
23//!     shard_count: usize,
24//! }
25//!
26//! impl ShardChooser for ModuloShardChooser {
27//!     fn choose_for_model(&self, shard_key: &Value) -> String {
28//!         let id = match shard_key {
29//!             Value::BigInt(n) => *n as usize,
30//!             Value::Int(n) => *n as usize,
31//!             _ => 0,
32//!         };
33//!         format!("shard_{}", id % self.shard_count)
34//!     }
35//!
36//!     fn choose_for_query(&self, _hints: &QueryHints) -> Vec<String> {
37//!         // Query all shards by default
38//!         (0..self.shard_count)
39//!             .map(|i| format!("shard_{}", i))
40//!             .collect()
41//!     }
42//! }
43//!
44//! // Create sharded pool
45//! let mut sharded_pool = ShardedPool::new(ModuloShardChooser { shard_count: 3 });
46//! sharded_pool.add_shard("shard_0", pool_0);
47//! sharded_pool.add_shard("shard_1", pool_1);
48//! sharded_pool.add_shard("shard_2", pool_2);
49//!
50//! // Insert routes to correct shard based on model's shard key
51//! let order = Order { user_id: 42, ... };
//! let shard = sharded_pool.choose_for_model(&order)?; // returns Result<String, Error>
53//! ```
54
55use std::collections::HashMap;
56use std::future::Future;
57use std::sync::Arc;
58
59use asupersync::{Cx, Outcome};
60use sqlmodel_core::error::{PoolError, PoolErrorKind};
61use sqlmodel_core::{Connection, Error, Model, Value};
62
63use crate::{Pool, PoolConfig, PooledConnection};
64
65/// Hints for query routing when a specific shard key isn't available.
66///
67/// When executing queries that don't have a clear shard key (e.g., range queries,
68/// aggregations), these hints help the `ShardChooser` decide which shards to query.
69#[derive(Debug, Clone, Default)]
70pub struct QueryHints {
71    /// Specific shard names to target (if known).
72    pub target_shards: Option<Vec<String>>,
73
74    /// Whether to query all shards (scatter-gather).
75    pub scatter_gather: bool,
76
77    /// Optional shard key value extracted from query predicates.
78    pub shard_key_value: Option<Value>,
79
80    /// Query type hint (e.g., "select", "aggregate", "count").
81    pub query_type: Option<String>,
82}
83
84impl QueryHints {
85    /// Create empty hints (defaults to scatter-gather).
86    #[must_use]
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Target specific shards by name.
92    #[must_use]
93    pub fn target(mut self, shards: Vec<String>) -> Self {
94        self.target_shards = Some(shards);
95        self
96    }
97
98    /// Enable scatter-gather mode (query all shards).
99    #[must_use]
100    pub fn scatter_gather(mut self) -> Self {
101        self.scatter_gather = true;
102        self
103    }
104
105    /// Provide a shard key value for routing.
106    #[must_use]
107    pub fn with_shard_key(mut self, value: Value) -> Self {
108        self.shard_key_value = Some(value);
109        self
110    }
111
112    /// Set the query type hint.
113    #[must_use]
114    pub fn query_type(mut self, query_type: impl Into<String>) -> Self {
115        self.query_type = Some(query_type.into());
116        self
117    }
118}
119
120/// Trait for determining which shard(s) to use for operations.
121///
122/// Implement this trait to define your sharding strategy. Common strategies:
123///
124/// - **Modulo hashing**: `shard_key % shard_count`
125/// - **Range-based**: Partition by key ranges (e.g., user IDs 0-1M → shard_0)
126/// - **Consistent hashing**: Minimize rebalancing when adding/removing shards
127/// - **Tenant-based**: Map tenant IDs directly to shard names
128///
129/// # Example
130///
131/// ```rust,ignore
132/// struct TenantShardChooser {
133///     tenant_to_shard: HashMap<String, String>,
134///     default_shard: String,
135/// }
136///
137/// impl ShardChooser for TenantShardChooser {
138///     fn choose_for_model(&self, shard_key: &Value) -> String {
139///         if let Value::Text(tenant_id) = shard_key {
140///             self.tenant_to_shard
141///                 .get(tenant_id)
142///                 .cloned()
143///                 .unwrap_or_else(|| self.default_shard.clone())
144///         } else {
145///             self.default_shard.clone()
146///         }
147///     }
148///
149///     fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
150///         if let Some(Value::Text(tenant_id)) = &hints.shard_key_value {
151///             vec![self.choose_for_model(&Value::Text(tenant_id.clone()))]
152///         } else {
153///             // Query all shards
154///             self.tenant_to_shard.values().cloned().collect()
155///         }
156///     }
157/// }
158/// ```
159pub trait ShardChooser: Send + Sync {
160    /// Choose the shard for a model based on its shard key value.
161    ///
162    /// This is used for INSERT, UPDATE, and DELETE operations where the
163    /// shard key is known from the model instance.
164    ///
165    /// # Arguments
166    ///
167    /// * `shard_key` - The value of the model's shard key field
168    ///
169    /// # Returns
170    ///
171    /// The name of the shard to use (must match a shard registered in `ShardedPool`).
172    fn choose_for_model(&self, shard_key: &Value) -> String;
173
174    /// Choose which shards to query based on query hints.
175    ///
176    /// For queries where the shard key isn't directly available (e.g., range
177    /// queries, joins, aggregations), this method returns the list of shards
178    /// to query.
179    ///
180    /// # Arguments
181    ///
182    /// * `hints` - Query routing hints (target shards, shard key value, etc.)
183    ///
184    /// # Returns
185    ///
186    /// List of shard names to query. For point queries with a known shard key,
187    /// this should return a single shard. For scatter-gather, return all shards.
188    fn choose_for_query(&self, hints: &QueryHints) -> Vec<String>;
189
190    /// Get all registered shard names.
191    ///
192    /// Default implementation returns an empty vec; override if your chooser
193    /// tracks shard names internally.
194    fn all_shards(&self) -> Vec<String> {
195        vec![]
196    }
197}
198
/// A simple modulo-based shard chooser.
///
/// Routes a key to `"{prefix}{key % shard_count}"`. With the default prefix,
/// this produces shard names `shard_0`, `shard_1`, ...
///
/// Key handling:
/// - Integer keys (`BigInt`, `Int`, `SmallInt`) use their absolute value, so
///   `-n` and `n` land on the same shard.
/// - `Text` keys are hashed to an integer first.
/// - Every other value type routes to shard index `0`.
///
/// Consecutive integers map to consecutive shards, so sequential keys (e.g.
/// auto-increment IDs) spread round-robin across shards rather than piling up
/// on one. The main drawback is resharding cost: changing `shard_count` remaps
/// almost every key. Prefer consistent hashing if the shard count is expected
/// to change.
#[derive(Debug, Clone)]
pub struct ModuloShardChooser {
    /// Number of shards; guaranteed non-zero when built via `new`.
    shard_count: usize,
    /// Text prepended to the shard index to form a shard name.
    shard_prefix: String,
}
212
213impl ModuloShardChooser {
214    /// Create a new modulo shard chooser with the given number of shards.
215    ///
216    /// Shards are named `shard_0`, `shard_1`, ..., `shard_{n-1}`.
217    ///
218    /// # Panics
219    ///
220    /// Panics if `shard_count` is 0, as this would cause division by zero
221    /// when routing to shards.
222    #[must_use]
223    pub fn new(shard_count: usize) -> Self {
224        assert!(shard_count > 0, "shard_count must be greater than 0");
225        Self {
226            shard_count,
227            shard_prefix: "shard_".to_string(),
228        }
229    }
230
231    /// Set a custom prefix for shard names (default: "shard_").
232    #[must_use]
233    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
234        self.shard_prefix = prefix.into();
235        self
236    }
237
238    /// Get the shard count.
239    #[must_use]
240    pub fn shard_count(&self) -> usize {
241        self.shard_count
242    }
243
244    /// Extract a numeric value from a Value for modulo calculation.
245    ///
246    /// Truncation on 32-bit platforms is acceptable here since we only need
247    /// the value for consistent shard routing via modulo.
248    #[allow(clippy::cast_possible_truncation)]
249    fn extract_numeric(&self, value: &Value) -> usize {
250        match value {
251            Value::BigInt(n) => (*n).unsigned_abs() as usize,
252            Value::Int(n) => (*n).unsigned_abs() as usize,
253            Value::SmallInt(n) => (*n).unsigned_abs() as usize,
254            Value::Text(s) => {
255                // Hash the string for non-numeric keys
256                use std::hash::{Hash, Hasher};
257                let mut hasher = std::collections::hash_map::DefaultHasher::new();
258                s.hash(&mut hasher);
259                hasher.finish() as usize
260            }
261            _ => 0,
262        }
263    }
264}
265
266impl ShardChooser for ModuloShardChooser {
267    fn choose_for_model(&self, shard_key: &Value) -> String {
268        let n = self.extract_numeric(shard_key);
269        format!("{}{}", self.shard_prefix, n % self.shard_count)
270    }
271
272    fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
273        // If specific shards are targeted, use those
274        if let Some(ref targets) = hints.target_shards {
275            return targets.clone();
276        }
277
278        // If shard key is available, route to specific shard
279        if let Some(ref value) = hints.shard_key_value {
280            return vec![self.choose_for_model(value)];
281        }
282
283        // Default: scatter-gather to all shards
284        self.all_shards()
285    }
286
287    fn all_shards(&self) -> Vec<String> {
288        (0..self.shard_count)
289            .map(|i| format!("{}{}", self.shard_prefix, i))
290            .collect()
291    }
292}
293
294/// A sharded connection pool that routes operations to the correct shard.
295///
296/// `ShardedPool` wraps multiple `Pool` instances, one per shard, and uses
297/// a `ShardChooser` to determine which shard to use for each operation.
298///
299/// # Example
300///
301/// ```rust,ignore
302/// // Create pools for each shard
303/// let pool_0 = Pool::new(PoolConfig::new(10));
304/// let pool_1 = Pool::new(PoolConfig::new(10));
305///
306/// // Create sharded pool with modulo chooser
307/// let chooser = ModuloShardChooser::new(2);
308/// let mut sharded = ShardedPool::new(chooser);
309/// sharded.add_shard("shard_0", pool_0);
310/// sharded.add_shard("shard_1", pool_1);
311///
312/// // Acquire connection from specific shard
313/// let conn = sharded.acquire_for_model(&cx, &order, factory).await?;
314/// ```
315pub struct ShardedPool<C: Connection, S: ShardChooser> {
316    shards: HashMap<String, Pool<C>>,
317    chooser: Arc<S>,
318}
319
320impl<C: Connection, S: ShardChooser> ShardedPool<C, S> {
321    /// Create a new sharded pool with the given shard chooser.
322    pub fn new(chooser: S) -> Self {
323        Self {
324            shards: HashMap::new(),
325            chooser: Arc::new(chooser),
326        }
327    }
328
329    /// Add a shard to the pool.
330    ///
331    /// # Arguments
332    ///
333    /// * `name` - The shard name (must match names returned by the chooser)
334    /// * `pool` - The connection pool for this shard
335    pub fn add_shard(&mut self, name: impl Into<String>, pool: Pool<C>) {
336        self.shards.insert(name.into(), pool);
337    }
338
339    /// Add a shard with a new pool created from the given config.
340    pub fn add_shard_with_config(&mut self, name: impl Into<String>, config: PoolConfig) {
341        self.shards.insert(name.into(), Pool::new(config));
342    }
343
344    /// Get a reference to the shard chooser.
345    pub fn chooser(&self) -> &S {
346        &self.chooser
347    }
348
349    /// Get a reference to a specific shard's pool.
350    pub fn get_shard(&self, name: &str) -> Option<&Pool<C>> {
351        self.shards.get(name)
352    }
353
354    /// Get all shard names.
355    pub fn shard_names(&self) -> Vec<String> {
356        self.shards.keys().cloned().collect()
357    }
358
359    /// Get the number of shards.
360    pub fn shard_count(&self) -> usize {
361        self.shards.len()
362    }
363
364    /// Check if a shard exists.
365    pub fn has_shard(&self, name: &str) -> bool {
366        self.shards.contains_key(name)
367    }
368
369    /// Choose the shard for a model based on its shard key.
370    ///
371    /// Returns the shard name. Use this when you need to know the shard
372    /// without acquiring a connection.
373    #[allow(clippy::result_large_err)]
374    pub fn choose_for_model<M: Model>(&self, model: &M) -> Result<String, Error> {
375        let shard_key = model.shard_key_value().ok_or_else(|| {
376            Error::Pool(PoolError {
377                kind: PoolErrorKind::Config,
378                message: format!(
379                    "Model {} has no shard key defined; add #[sqlmodel(shard_key = \"field\")]",
380                    M::TABLE_NAME
381                ),
382                source: None,
383            })
384        })?;
385        Ok(self.chooser.choose_for_model(&shard_key))
386    }
387
388    /// Choose shards for a query based on hints.
389    pub fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
390        self.chooser.choose_for_query(hints)
391    }
392
393    /// Acquire a connection from the shard determined by the model's shard key.
394    ///
395    /// # Arguments
396    ///
397    /// * `cx` - The async context
398    /// * `model` - The model instance (must have a shard key)
399    /// * `factory` - Connection factory function
400    ///
401    /// # Errors
402    ///
403    /// Returns an error if:
404    /// - The model has no shard key
405    /// - The determined shard doesn't exist
406    /// - Connection acquisition fails
407    pub async fn acquire_for_model<M, F, Fut>(
408        &self,
409        cx: &Cx,
410        model: &M,
411        factory: F,
412    ) -> Outcome<PooledConnection<C>, Error>
413    where
414        M: Model,
415        F: Fn() -> Fut,
416        Fut: Future<Output = Outcome<C, Error>>,
417    {
418        let shard_name = match self.choose_for_model(model) {
419            Ok(name) => name,
420            Err(e) => return Outcome::Err(e),
421        };
422
423        self.acquire_from_shard(cx, &shard_name, factory).await
424    }
425
426    /// Acquire a connection from a specific shard by name.
427    ///
428    /// # Arguments
429    ///
430    /// * `cx` - The async context
431    /// * `shard_name` - The name of the shard to acquire from
432    /// * `factory` - Connection factory function
433    ///
434    /// # Errors
435    ///
436    /// Returns an error if:
437    /// - The shard doesn't exist
438    /// - Connection acquisition fails
439    pub async fn acquire_from_shard<F, Fut>(
440        &self,
441        cx: &Cx,
442        shard_name: &str,
443        factory: F,
444    ) -> Outcome<PooledConnection<C>, Error>
445    where
446        F: Fn() -> Fut,
447        Fut: Future<Output = Outcome<C, Error>>,
448    {
449        let Some(pool) = self.shards.get(shard_name) else {
450            return Outcome::Err(Error::Pool(PoolError {
451                kind: PoolErrorKind::Config,
452                message: format!(
453                    "shard '{}' not found; available shards: {:?}",
454                    shard_name,
455                    self.shard_names()
456                ),
457                source: None,
458            }));
459        };
460
461        pool.acquire(cx, factory).await
462    }
463
464    /// Acquire connections from multiple shards for scatter-gather queries.
465    ///
466    /// Returns a map of shard name to pooled connection for each successfully
467    /// acquired connection. Failed acquisitions are logged but don't fail the
468    /// entire operation.
469    ///
470    /// # Arguments
471    ///
472    /// * `cx` - The async context
473    /// * `hints` - Query routing hints
474    /// * `factory` - Connection factory function
475    pub async fn acquire_for_query<F, Fut>(
476        &self,
477        cx: &Cx,
478        hints: &QueryHints,
479        factory: F,
480    ) -> Result<HashMap<String, PooledConnection<C>>, Error>
481    where
482        F: Fn() -> Fut + Clone,
483        Fut: Future<Output = Outcome<C, Error>>,
484    {
485        let target_shards = self.choose_for_query(hints);
486        let mut connections = HashMap::new();
487
488        for shard_name in target_shards {
489            match self
490                .acquire_from_shard(cx, &shard_name, factory.clone())
491                .await
492            {
493                Outcome::Ok(conn) => {
494                    connections.insert(shard_name, conn);
495                }
496                Outcome::Err(e) => {
497                    tracing::warn!(shard = %shard_name, error = %e, "Failed to acquire connection from shard");
498                }
499                Outcome::Cancelled(reason) => {
500                    tracing::debug!(shard = %shard_name, reason = ?reason, "Cancelled while acquiring from shard");
501                }
502                Outcome::Panicked(info) => {
503                    tracing::error!(shard = %shard_name, panic = ?info, "Panic while acquiring from shard");
504                }
505            }
506        }
507
508        if connections.is_empty() {
509            return Err(Error::Pool(PoolError {
510                kind: PoolErrorKind::Exhausted,
511                message: "failed to acquire connection from any shard".to_string(),
512                source: None,
513            }));
514        }
515
516        Ok(connections)
517    }
518
519    /// Close all shards.
520    pub fn close(&self) {
521        for pool in self.shards.values() {
522            pool.close();
523        }
524    }
525
526    /// Check if all shards are closed.
527    pub fn is_closed(&self) -> bool {
528        self.shards.values().all(|p| p.is_closed())
529    }
530
531    /// Get aggregate statistics across all shards.
532    pub fn stats(&self) -> ShardedPoolStats {
533        let mut total = ShardedPoolStats::default();
534
535        for (name, pool) in &self.shards {
536            let shard_stats = pool.stats();
537            total.per_shard.insert(name.clone(), shard_stats.clone());
538            total.total_connections += shard_stats.total_connections;
539            total.idle_connections += shard_stats.idle_connections;
540            total.active_connections += shard_stats.active_connections;
541            total.pending_requests += shard_stats.pending_requests;
542            total.connections_created += shard_stats.connections_created;
543            total.connections_closed += shard_stats.connections_closed;
544            total.acquires += shard_stats.acquires;
545            total.timeouts += shard_stats.timeouts;
546        }
547
548        total.shard_count = self.shards.len();
549        total
550    }
551}
552
553/// Aggregate statistics for a sharded pool.
554#[derive(Debug, Clone, Default)]
555pub struct ShardedPoolStats {
556    /// Number of shards.
557    pub shard_count: usize,
558    /// Per-shard statistics.
559    pub per_shard: HashMap<String, crate::PoolStats>,
560    /// Total connections across all shards.
561    pub total_connections: usize,
562    /// Idle connections across all shards.
563    pub idle_connections: usize,
564    /// Active connections across all shards.
565    pub active_connections: usize,
566    /// Pending requests across all shards.
567    pub pending_requests: usize,
568    /// Total connections created across all shards.
569    pub connections_created: u64,
570    /// Total connections closed across all shards.
571    pub connections_closed: u64,
572    /// Total acquires across all shards.
573    pub acquires: u64,
574    /// Total timeouts across all shards.
575    pub timeouts: u64,
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_hints_builder() {
        let hints = QueryHints::new()
            .target(vec!["shard_0".to_string()])
            .with_shard_key(Value::BigInt(42))
            .query_type("select");

        assert_eq!(hints.target_shards, Some(vec!["shard_0".to_string()]));
        assert_eq!(hints.shard_key_value, Some(Value::BigInt(42)));
        assert_eq!(hints.query_type, Some("select".to_string()));
    }

    #[test]
    fn test_query_hints_scatter_gather() {
        let hints = QueryHints::new().scatter_gather();
        assert!(hints.scatter_gather);
    }

    #[test]
    fn test_query_hints_new_is_empty() {
        let hints = QueryHints::new();
        assert!(hints.target_shards.is_none());
        assert!(!hints.scatter_gather);
        assert!(hints.shard_key_value.is_none());
        assert!(hints.query_type.is_none());
    }

    #[test]
    fn test_modulo_shard_chooser_new() {
        let chooser = ModuloShardChooser::new(4);
        assert_eq!(chooser.shard_count(), 4);
    }

    #[test]
    fn test_modulo_shard_chooser_with_prefix() {
        let chooser = ModuloShardChooser::new(3).with_prefix("db_");
        assert_eq!(
            chooser.choose_for_model(&Value::BigInt(0)),
            "db_0".to_string()
        );
        assert_eq!(
            chooser.choose_for_model(&Value::BigInt(1)),
            "db_1".to_string()
        );
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_model() {
        let chooser = ModuloShardChooser::new(3);

        assert_eq!(chooser.choose_for_model(&Value::BigInt(0)), "shard_0");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(1)), "shard_1");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(2)), "shard_2");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(3)), "shard_0");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(100)), "shard_1");
    }

    #[test]
    fn test_modulo_shard_chooser_int_types() {
        let chooser = ModuloShardChooser::new(2);

        assert_eq!(chooser.choose_for_model(&Value::Int(5)), "shard_1");
        assert_eq!(chooser.choose_for_model(&Value::SmallInt(4)), "shard_0");
    }

    #[test]
    fn test_modulo_shard_chooser_negative_values() {
        let chooser = ModuloShardChooser::new(3);

        // Negative values should use absolute value
        assert_eq!(chooser.choose_for_model(&Value::BigInt(-1)), "shard_1");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(-3)), "shard_0");
    }

    #[test]
    fn test_modulo_shard_chooser_string_hash() {
        let chooser = ModuloShardChooser::new(3);

        // Strings should be hashed consistently
        let shard1 = chooser.choose_for_model(&Value::Text("user_abc".to_string()));
        let shard2 = chooser.choose_for_model(&Value::Text("user_abc".to_string()));
        assert_eq!(shard1, shard2);

        // Any string must still land on one of the configured shards.
        let all = chooser.all_shards();
        assert!(all.contains(&shard1));
        let other = chooser.choose_for_model(&Value::Text("user_xyz".to_string()));
        assert!(all.contains(&other));
    }

    #[test]
    fn test_modulo_shard_chooser_single_shard() {
        let chooser = ModuloShardChooser::new(1);
        let keys = [
            Value::BigInt(-7),
            Value::Int(42),
            Value::SmallInt(3),
            Value::Text("tenant".to_string()),
        ];
        for key in keys.iter() {
            assert_eq!(chooser.choose_for_model(key), "shard_0");
        }
    }

    #[test]
    fn test_modulo_shard_chooser_all_shards() {
        let chooser = ModuloShardChooser::new(3);
        let all = chooser.all_shards();

        assert_eq!(all.len(), 3);
        assert!(all.contains(&"shard_0".to_string()));
        assert!(all.contains(&"shard_1".to_string()));
        assert!(all.contains(&"shard_2".to_string()));
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_with_key() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().with_shard_key(Value::BigInt(5));

        let shards = chooser.choose_for_query(&hints);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], "shard_2"); // 5 % 3 = 2
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_scatter() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().scatter_gather();

        let shards = chooser.choose_for_query(&hints);
        assert_eq!(shards.len(), 3);
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_default_is_all_shards() {
        let chooser = ModuloShardChooser::new(3);
        assert_eq!(
            chooser.choose_for_query(&QueryHints::new()),
            chooser.all_shards()
        );
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_target() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().target(vec!["shard_1".to_string()]);

        let shards = chooser.choose_for_query(&hints);
        assert_eq!(shards, vec!["shard_1"]);
    }

    #[test]
    fn test_modulo_shard_chooser_target_takes_precedence_over_key() {
        let chooser = ModuloShardChooser::new(3);
        // Key 5 alone would route to shard_2; explicit targets must win.
        let hints = QueryHints::new()
            .with_shard_key(Value::BigInt(5))
            .target(vec!["shard_0".to_string()]);

        assert_eq!(chooser.choose_for_query(&hints), vec!["shard_0"]);
    }

    #[test]
    fn test_sharded_pool_stats_default() {
        let stats = ShardedPoolStats::default();
        assert_eq!(stats.shard_count, 0);
        assert_eq!(stats.total_connections, 0);
        assert!(stats.per_shard.is_empty());
    }

    #[test]
    #[should_panic(expected = "shard_count must be greater than 0")]
    fn test_modulo_shard_chooser_zero_shards_panics() {
        let _ = ModuloShardChooser::new(0);
    }
}