use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use asupersync::{Cx, Outcome};
use sqlmodel_core::error::{PoolError, PoolErrorKind};
use sqlmodel_core::{Connection, Error, Model, Value};
use crate::{Pool, PoolConfig, PooledConnection};
/// Routing hints attached to a query so a [`ShardChooser`] can decide which
/// shard(s) the query should run against.
///
/// How the hints are combined is up to the chooser; see
/// [`ShardChooser::choose_for_query`].
#[derive(Debug, Clone, Default)]
pub struct QueryHints {
    /// Explicit shard names to run against.
    pub target_shards: Option<Vec<String>>,
    /// Marks the query as a fan-out (scatter-gather) query.
    pub scatter_gather: bool,
    /// Shard key value that identifies a single owning shard.
    pub shard_key_value: Option<Value>,
    /// Free-form label for the kind of query (e.g. `"select"`).
    pub query_type: Option<String>,
}

impl QueryHints {
    /// Returns hints with nothing set (same as [`QueryHints::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the explicit list of shards to target, replacing any previous list.
    #[must_use]
    pub fn target(self, shards: Vec<String>) -> Self {
        Self {
            target_shards: Some(shards),
            ..self
        }
    }

    /// Flags the query as scatter-gather.
    #[must_use]
    pub fn scatter_gather(self) -> Self {
        Self {
            scatter_gather: true,
            ..self
        }
    }

    /// Sets the shard key value used to route the query.
    #[must_use]
    pub fn with_shard_key(self, value: Value) -> Self {
        Self {
            shard_key_value: Some(value),
            ..self
        }
    }

    /// Sets the query-type label.
    #[must_use]
    pub fn query_type(self, query_type: impl Into<String>) -> Self {
        Self {
            query_type: Some(query_type.into()),
            ..self
        }
    }
}
/// Strategy that maps shard keys and query hints to shard names.
///
/// The `Send + Sync` bound lets one chooser instance be shared across tasks.
pub trait ShardChooser: Send + Sync {
    /// Returns the name of the shard that owns a row with `shard_key`.
    fn choose_for_model(&self, shard_key: &Value) -> String;

    /// Returns the names of the shards a query described by `hints` should use.
    fn choose_for_query(&self, hints: &QueryHints) -> Vec<String>;

    /// Returns every shard name this chooser can produce.
    ///
    /// The default knows nothing about the shard layout and returns an empty
    /// list; choosers with a fixed set of shards should override it.
    fn all_shards(&self) -> Vec<String> {
        Vec::new()
    }
}
/// A [`ShardChooser`] that routes by shard key modulo a fixed shard count.
///
/// Shard names have the form `"{prefix}{index}"` with `index` in
/// `0..shard_count`; the prefix defaults to `"shard_"`.
///
/// Key handling in [`ShardChooser::choose_for_model`]:
/// - `BigInt`, `Int`, `SmallInt`: the integer's absolute value, so `n` and
///   `-n` land on the same shard.
/// - `Text`: a SipHash-1-3 digest (all-zero key) computed in this module, so
///   placement does not change when the Rust toolchain changes.
/// - any other value: index 0.
#[derive(Debug, Clone)]
pub struct ModuloShardChooser {
    shard_count: usize,
    shard_prefix: String,
}

/// Computes SipHash-1-3 with an all-zero 128-bit key over the byte stream `bytes`.
///
/// std documents `DefaultHasher`'s output as unspecified and not stable across
/// releases. Shard placement is persistent: a changed hash would silently route
/// existing rows to a different shard. The algorithm is therefore pinned here.
/// SipHash-1-3 with keys (0, 0) is what `DefaultHasher::new()` computes on current
/// toolchains, so (on 64-bit targets) string keys keep the shard they were
/// assigned by builds that hashed with `DefaultHasher`.
fn siphash13_zero_key(bytes: impl IntoIterator<Item = u8>) -> u64 {
    // One SipRound over the four-word state.
    fn sip_round(v: &mut [u64; 4]) {
        v[0] = v[0].wrapping_add(v[1]);
        v[1] = v[1].rotate_left(13) ^ v[0];
        v[0] = v[0].rotate_left(32);
        v[2] = v[2].wrapping_add(v[3]);
        v[3] = v[3].rotate_left(16) ^ v[2];
        v[0] = v[0].wrapping_add(v[3]);
        v[3] = v[3].rotate_left(21) ^ v[0];
        v[2] = v[2].wrapping_add(v[1]);
        v[1] = v[1].rotate_left(17) ^ v[2];
        v[2] = v[2].rotate_left(32);
    }

    // With k0 = k1 = 0 the initial state is just the SipHash constants.
    let mut v: [u64; 4] = [
        0x736f_6d65_7073_6575,
        0x646f_7261_6e64_6f6d,
        0x6c79_6765_6e65_7261,
        0x7465_6462_7974_6573,
    ];
    let mut len: u64 = 0;
    let mut word: u64 = 0;
    for byte in bytes {
        // Message words are little-endian 8-byte blocks.
        word |= u64::from(byte) << (8 * (len % 8));
        len += 1;
        if len % 8 == 0 {
            v[3] ^= word;
            sip_round(&mut v); // c = 1 compression round per block
            v[0] ^= word;
            word = 0;
        }
    }

    // Final block: leftover (< 8) bytes plus the total length mod 256 in the top byte.
    let last = ((len & 0xff) << 56) | word;
    v[3] ^= last;
    sip_round(&mut v);
    v[0] ^= last;

    // Finalization: d = 3 rounds.
    v[2] ^= 0xff;
    for _ in 0..3 {
        sip_round(&mut v);
    }
    v[0] ^ v[1] ^ v[2] ^ v[3]
}

impl ModuloShardChooser {
    /// Creates a chooser over `shard_count` shards named `shard_0`, `shard_1`, ...
    ///
    /// # Panics
    ///
    /// Panics if `shard_count` is 0.
    #[must_use]
    pub fn new(shard_count: usize) -> Self {
        assert!(shard_count > 0, "shard_count must be greater than 0");
        Self {
            shard_count,
            shard_prefix: "shard_".to_string(),
        }
    }

    /// Replaces the shard-name prefix (default `"shard_"`).
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.shard_prefix = prefix.into();
        self
    }

    /// Returns the number of shards keys are distributed over.
    #[must_use]
    pub fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Maps a shard key to the unsigned integer that is reduced modulo the shard count.
    ///
    /// The result is `u64` rather than `usize` so the same key routes to the same
    /// shard on 32-bit and 64-bit targets (a `usize` cast would truncate there).
    fn extract_numeric(&self, value: &Value) -> u64 {
        match value {
            // `unsigned_abs` cannot overflow, unlike `abs` on the minimum value.
            Value::BigInt(n) => n.unsigned_abs(),
            Value::Int(n) => u64::from(n.unsigned_abs()),
            Value::SmallInt(n) => u64::from(n.unsigned_abs()),
            // `str`'s `Hash` impl feeds the bytes followed by a 0xFF terminator;
            // hashing the same stream keeps placements made via `DefaultHasher`.
            Value::Text(s) => siphash13_zero_key(s.bytes().chain(std::iter::once(0xff))),
            // NOTE(review): all other variants land on index 0, which can hot-spot
            // that shard if such keys are common.
            _ => 0,
        }
    }
}

impl ShardChooser for ModuloShardChooser {
    /// Returns `"{prefix}{key % shard_count}"` for the numeric form of `shard_key`.
    fn choose_for_model(&self, shard_key: &Value) -> String {
        // `usize` -> `u64` is lossless on every target Rust supports, and the
        // remainder is formatted directly, so no narrowing cast is needed.
        let index = self.extract_numeric(shard_key) % self.shard_count as u64;
        format!("{}{}", self.shard_prefix, index)
    }

    /// Explicit `target_shards` win; otherwise a single shard chosen from
    /// `shard_key_value`; otherwise every shard.
    fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
        if let Some(ref targets) = hints.target_shards {
            return targets.clone();
        }
        if let Some(ref value) = hints.shard_key_value {
            return vec![self.choose_for_model(value)];
        }
        self.all_shards()
    }

    /// Returns `{prefix}0` through `{prefix}{shard_count - 1}` in index order.
    fn all_shards(&self) -> Vec<String> {
        (0..self.shard_count)
            .map(|i| format!("{}{}", self.shard_prefix, i))
            .collect()
    }
}
pub struct ShardedPool<C: Connection, S: ShardChooser> {
shards: HashMap<String, Pool<C>>,
chooser: Arc<S>,
}
impl<C: Connection, S: ShardChooser> ShardedPool<C, S> {
pub fn new(chooser: S) -> Self {
Self {
shards: HashMap::new(),
chooser: Arc::new(chooser),
}
}
pub fn add_shard(&mut self, name: impl Into<String>, pool: Pool<C>) {
self.shards.insert(name.into(), pool);
}
pub fn add_shard_with_config(&mut self, name: impl Into<String>, config: PoolConfig) {
self.shards.insert(name.into(), Pool::new(config));
}
pub fn chooser(&self) -> &S {
&self.chooser
}
pub fn get_shard(&self, name: &str) -> Option<&Pool<C>> {
self.shards.get(name)
}
pub fn shard_names(&self) -> Vec<String> {
self.shards.keys().cloned().collect()
}
pub fn shard_count(&self) -> usize {
self.shards.len()
}
pub fn has_shard(&self, name: &str) -> bool {
self.shards.contains_key(name)
}
#[allow(clippy::result_large_err)]
pub fn choose_for_model<M: Model>(&self, model: &M) -> Result<String, Error> {
let shard_key = model.shard_key_value().ok_or_else(|| {
Error::Pool(PoolError {
kind: PoolErrorKind::Config,
message: format!(
"Model {} has no shard key defined; add #[sqlmodel(shard_key = \"field\")]",
M::TABLE_NAME
),
source: None,
})
})?;
Ok(self.chooser.choose_for_model(&shard_key))
}
pub fn choose_for_query(&self, hints: &QueryHints) -> Vec<String> {
self.chooser.choose_for_query(hints)
}
pub async fn acquire_for_model<M, F, Fut>(
&self,
cx: &Cx,
model: &M,
factory: F,
) -> Outcome<PooledConnection<C>, Error>
where
M: Model,
F: Fn() -> Fut,
Fut: Future<Output = Outcome<C, Error>>,
{
let shard_name = match self.choose_for_model(model) {
Ok(name) => name,
Err(e) => return Outcome::Err(e),
};
self.acquire_from_shard(cx, &shard_name, factory).await
}
pub async fn acquire_from_shard<F, Fut>(
&self,
cx: &Cx,
shard_name: &str,
factory: F,
) -> Outcome<PooledConnection<C>, Error>
where
F: Fn() -> Fut,
Fut: Future<Output = Outcome<C, Error>>,
{
let Some(pool) = self.shards.get(shard_name) else {
return Outcome::Err(Error::Pool(PoolError {
kind: PoolErrorKind::Config,
message: format!(
"shard '{}' not found; available shards: {:?}",
shard_name,
self.shard_names()
),
source: None,
}));
};
pool.acquire(cx, factory).await
}
pub async fn acquire_for_query<F, Fut>(
&self,
cx: &Cx,
hints: &QueryHints,
factory: F,
) -> Result<HashMap<String, PooledConnection<C>>, Error>
where
F: Fn() -> Fut + Clone,
Fut: Future<Output = Outcome<C, Error>>,
{
let target_shards = self.choose_for_query(hints);
let mut connections = HashMap::new();
for shard_name in target_shards {
match self
.acquire_from_shard(cx, &shard_name, factory.clone())
.await
{
Outcome::Ok(conn) => {
connections.insert(shard_name, conn);
}
Outcome::Err(e) => {
tracing::warn!(shard = %shard_name, error = %e, "Failed to acquire connection from shard");
}
Outcome::Cancelled(reason) => {
tracing::debug!(shard = %shard_name, reason = ?reason, "Cancelled while acquiring from shard");
}
Outcome::Panicked(info) => {
tracing::error!(shard = %shard_name, panic = ?info, "Panic while acquiring from shard");
}
}
}
if connections.is_empty() {
return Err(Error::Pool(PoolError {
kind: PoolErrorKind::Exhausted,
message: "failed to acquire connection from any shard".to_string(),
source: None,
}));
}
Ok(connections)
}
pub fn close(&self) {
for pool in self.shards.values() {
pool.close();
}
}
pub fn is_closed(&self) -> bool {
self.shards.values().all(|p| p.is_closed())
}
pub fn stats(&self) -> ShardedPoolStats {
let mut total = ShardedPoolStats::default();
for (name, pool) in &self.shards {
let shard_stats = pool.stats();
total.per_shard.insert(name.clone(), shard_stats.clone());
total.total_connections += shard_stats.total_connections;
total.idle_connections += shard_stats.idle_connections;
total.active_connections += shard_stats.active_connections;
total.pending_requests += shard_stats.pending_requests;
total.connections_created += shard_stats.connections_created;
total.connections_closed += shard_stats.connections_closed;
total.acquires += shard_stats.acquires;
total.timeouts += shard_stats.timeouts;
}
total.shard_count = self.shards.len();
total
}
}
/// Statistics aggregated across every shard of a sharded pool.
///
/// Every counter below `per_shard` is the sum of the same-named field over the
/// per-shard [`crate::PoolStats`] snapshots stored in `per_shard`.
#[derive(Debug, Clone, Default)]
pub struct ShardedPoolStats {
    /// Number of shards included in this snapshot.
    pub shard_count: usize,
    /// Each shard's own statistics, keyed by shard name.
    pub per_shard: HashMap<String, crate::PoolStats>,
    /// Sum of per-shard `total_connections`.
    pub total_connections: usize,
    /// Sum of per-shard `idle_connections`.
    pub idle_connections: usize,
    /// Sum of per-shard `active_connections`.
    pub active_connections: usize,
    /// Sum of per-shard `pending_requests`.
    pub pending_requests: usize,
    /// Sum of per-shard `connections_created`.
    pub connections_created: u64,
    /// Sum of per-shard `connections_closed`.
    pub connections_closed: u64,
    /// Sum of per-shard `acquires`.
    pub acquires: u64,
    /// Sum of per-shard `timeouts`.
    pub timeouts: u64,
}
#[cfg(test)]
mod tests {
    //! Unit tests for query hints, the modulo shard chooser and stats defaults.

    use super::*;

    #[test]
    fn test_query_hints_builder() {
        let hints = QueryHints::new()
            .target(vec!["shard_0".to_string()])
            .with_shard_key(Value::BigInt(42))
            .query_type("select");
        assert_eq!(hints.target_shards, Some(vec!["shard_0".to_string()]));
        assert_eq!(hints.shard_key_value, Some(Value::BigInt(42)));
        assert_eq!(hints.query_type, Some("select".to_string()));
        assert!(!hints.scatter_gather);
    }

    #[test]
    fn test_query_hints_scatter_gather() {
        let hints = QueryHints::new().scatter_gather();
        assert!(hints.scatter_gather);
    }

    #[test]
    fn test_modulo_shard_chooser_new() {
        let chooser = ModuloShardChooser::new(4);
        assert_eq!(chooser.shard_count(), 4);
    }

    #[test]
    fn test_modulo_shard_chooser_with_prefix() {
        let chooser = ModuloShardChooser::new(3).with_prefix("db_");
        assert_eq!(
            chooser.choose_for_model(&Value::BigInt(0)),
            "db_0".to_string()
        );
        assert_eq!(
            chooser.choose_for_model(&Value::BigInt(1)),
            "db_1".to_string()
        );
    }

    #[test]
    fn test_modulo_shard_chooser_all_shards_with_prefix() {
        let chooser = ModuloShardChooser::new(2).with_prefix("db_");
        assert_eq!(chooser.all_shards(), vec!["db_0", "db_1"]);
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_model() {
        let chooser = ModuloShardChooser::new(3);
        assert_eq!(chooser.choose_for_model(&Value::BigInt(0)), "shard_0");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(1)), "shard_1");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(2)), "shard_2");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(3)), "shard_0");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(100)), "shard_1");
    }

    #[test]
    fn test_modulo_shard_chooser_int_types() {
        let chooser = ModuloShardChooser::new(2);
        assert_eq!(chooser.choose_for_model(&Value::Int(5)), "shard_1");
        assert_eq!(chooser.choose_for_model(&Value::SmallInt(4)), "shard_0");
    }

    #[test]
    fn test_modulo_shard_chooser_negative_values() {
        let chooser = ModuloShardChooser::new(3);
        assert_eq!(chooser.choose_for_model(&Value::BigInt(-1)), "shard_1");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(-3)), "shard_0");
    }

    #[test]
    fn test_modulo_shard_chooser_extreme_bigint_values() {
        // Must not overflow on i64::MIN. 2^63 % 3 == 2 and (2^63 - 1) % 3 == 1.
        let chooser = ModuloShardChooser::new(3);
        assert_eq!(chooser.choose_for_model(&Value::BigInt(i64::MIN)), "shard_2");
        assert_eq!(chooser.choose_for_model(&Value::BigInt(i64::MAX)), "shard_1");
    }

    #[test]
    fn test_modulo_shard_chooser_string_hash() {
        let chooser = ModuloShardChooser::new(3);
        let all = chooser.all_shards();

        let shard1 = chooser.choose_for_model(&Value::Text("user_abc".to_string()));
        let shard2 = chooser.choose_for_model(&Value::Text("user_abc".to_string()));
        assert_eq!(shard1, shard2);
        assert!(all.contains(&shard1));

        // Routing depends only on the key, not on the chooser instance.
        let other = ModuloShardChooser::new(3);
        assert_eq!(
            other.choose_for_model(&Value::Text("user_abc".to_string())),
            shard1
        );

        let shard3 = chooser.choose_for_model(&Value::Text("user_xyz".to_string()));
        assert!(all.contains(&shard3));

        // The empty string is a valid key and must still map to a real shard.
        let empty = chooser.choose_for_model(&Value::Text(String::new()));
        assert!(all.contains(&empty));
    }

    #[test]
    fn test_modulo_shard_chooser_all_shards() {
        let chooser = ModuloShardChooser::new(3);
        let all = chooser.all_shards();
        assert_eq!(all.len(), 3);
        assert!(all.contains(&"shard_0".to_string()));
        assert!(all.contains(&"shard_1".to_string()));
        assert!(all.contains(&"shard_2".to_string()));
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_with_key() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().with_shard_key(Value::BigInt(5));
        let shards = chooser.choose_for_query(&hints);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], "shard_2");
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_scatter() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().scatter_gather();
        let shards = chooser.choose_for_query(&hints);
        assert_eq!(shards.len(), 3);
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_no_hints_targets_all() {
        let chooser = ModuloShardChooser::new(3);
        let shards = chooser.choose_for_query(&QueryHints::new());
        assert_eq!(shards, chooser.all_shards());
    }

    #[test]
    fn test_modulo_shard_chooser_choose_for_query_target() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new().target(vec!["shard_1".to_string()]);
        let shards = chooser.choose_for_query(&hints);
        assert_eq!(shards, vec!["shard_1"]);
    }

    #[test]
    fn test_modulo_shard_chooser_target_overrides_shard_key() {
        let chooser = ModuloShardChooser::new(3);
        let hints = QueryHints::new()
            .target(vec!["shard_0".to_string()])
            .with_shard_key(Value::BigInt(5));
        assert_eq!(chooser.choose_for_query(&hints), vec!["shard_0"]);
    }

    #[test]
    fn test_sharded_pool_stats_default() {
        let stats = ShardedPoolStats::default();
        assert_eq!(stats.shard_count, 0);
        assert_eq!(stats.total_connections, 0);
        assert!(stats.per_shard.is_empty());
    }

    #[test]
    #[should_panic(expected = "shard_count must be greater than 0")]
    fn test_modulo_shard_chooser_zero_shards_panics() {
        let _ = ModuloShardChooser::new(0);
    }
}