use chrono::{DateTime, Datelike, Utc};
use std::sync::Arc;
pub trait ShardingStrategy: Send + Sync {
fn calculate(&self, timestamp: DateTime<Utc>, total_shards: u32) -> u32;
fn name(&self) -> &'static str;
fn is_valid_shard_id(&self, shard_id: u32, total_shards: u32) -> bool;
fn current_shard(&self, total_shards: u32) -> u32;
fn boxed_clone(&self) -> Box<dyn ShardingStrategy>;
}
/// Shards by calendar year: every timestamp in the same UTC year maps to the same shard.
///
/// NOTE(review): the shard id is derived with `std`'s `DefaultHasher`, whose algorithm the
/// standard library leaves unspecified and subject to change between Rust releases. A
/// toolchain upgrade could therefore remap years to different shards.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct YearlyStrategy;

impl YearlyStrategy {
    /// Creates the (stateless) yearly strategy.
    pub(crate) fn new() -> Self {
        Self
    }
}

impl ShardingStrategy for YearlyStrategy {
    /// Hashes the UTC year and reduces the hash modulo `total_shards`.
    ///
    /// Returns `0` when `total_shards` is `0` instead of panicking on a division by zero
    /// (no id is valid in that case; see `is_valid_shard_id`).
    fn calculate(&self, timestamp: DateTime<Utc>, total_shards: u32) -> u32 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // Negative (BCE) years wrap to large u32 values, which is still deterministic.
        let year = timestamp.year() as u32;
        let mut hasher = DefaultHasher::new();
        year.hash(&mut hasher);
        (hasher.finish() as u32).checked_rem(total_shards).unwrap_or(0)
    }

    fn name(&self) -> &'static str {
        "yearly"
    }

    fn is_valid_shard_id(&self, shard_id: u32, total_shards: u32) -> bool {
        shard_id < total_shards
    }

    fn current_shard(&self, total_shards: u32) -> u32 {
        self.calculate(Utc::now(), total_shards)
    }

    fn boxed_clone(&self) -> Box<dyn ShardingStrategy> {
        Box::new(*self)
    }
}
/// Shards by calendar month: every timestamp in the same UTC year-month maps to the same
/// shard.
///
/// NOTE(review): uses `std`'s `DefaultHasher`, whose algorithm is unspecified and may change
/// between Rust releases, so assignments are not guaranteed stable across toolchains.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct MonthlyStrategy;

impl MonthlyStrategy {
    /// Creates the (stateless) monthly strategy.
    pub(crate) fn new() -> Self {
        Self
    }
}

impl ShardingStrategy for MonthlyStrategy {
    /// Hashes `year * 12 + month` and reduces the hash modulo `total_shards`.
    ///
    /// Returns `0` when `total_shards` is `0` instead of panicking on a division by zero.
    fn calculate(&self, timestamp: DateTime<Utc>, total_shards: u32) -> u32 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // Wrapping arithmetic: a negative year cast to u32 is huge, and `* 12` would
        // overflow (panicking in debug builds). Release builds already wrapped, so the
        // resulting shard ids are unchanged.
        let year_month = (timestamp.year() as u32)
            .wrapping_mul(12)
            .wrapping_add(timestamp.month());
        let mut hasher = DefaultHasher::new();
        year_month.hash(&mut hasher);
        (hasher.finish() as u32).checked_rem(total_shards).unwrap_or(0)
    }

    fn name(&self) -> &'static str {
        "monthly"
    }

    fn is_valid_shard_id(&self, shard_id: u32, total_shards: u32) -> bool {
        shard_id < total_shards
    }

    fn current_shard(&self, total_shards: u32) -> u32 {
        self.calculate(Utc::now(), total_shards)
    }

    fn boxed_clone(&self) -> Box<dyn ShardingStrategy> {
        Box::new(*self)
    }
}
/// Shards by calendar day: every timestamp on the same UTC day maps to the same shard.
///
/// NOTE(review): uses `std`'s `DefaultHasher`, whose algorithm is unspecified and may change
/// between Rust releases, so assignments are not guaranteed stable across toolchains.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct DailyStrategy;

impl DailyStrategy {
    /// Creates the (stateless) daily strategy.
    pub(crate) fn new() -> Self {
        Self
    }
}

impl ShardingStrategy for DailyStrategy {
    /// Hashes the day number from `num_days_from_ce` and reduces it modulo `total_shards`.
    ///
    /// Returns `0` when `total_shards` is `0` instead of panicking on a division by zero.
    fn calculate(&self, timestamp: DateTime<Utc>, total_shards: u32) -> u32 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let days = timestamp.num_days_from_ce();
        let mut hasher = DefaultHasher::new();
        days.hash(&mut hasher);
        (hasher.finish() as u32).checked_rem(total_shards).unwrap_or(0)
    }

    fn name(&self) -> &'static str {
        "daily"
    }

    fn is_valid_shard_id(&self, shard_id: u32, total_shards: u32) -> bool {
        shard_id < total_shards
    }

    fn current_shard(&self, total_shards: u32) -> u32 {
        self.calculate(Utc::now(), total_shards)
    }

    fn boxed_clone(&self) -> Box<dyn ShardingStrategy> {
        Box::new(*self)
    }
}
/// Shards by an XxHash64 of the timestamp's full RFC 3339 string (as produced by
/// `to_rfc3339`), so timestamps that differ in any rendered component land on independent
/// shards.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct HashStrategy;

impl HashStrategy {
    /// Creates the (stateless) hash strategy.
    pub(crate) fn new() -> Self {
        Self
    }
}

impl ShardingStrategy for HashStrategy {
    /// Hashes the RFC 3339 rendering of `timestamp` and reduces it modulo `total_shards`.
    ///
    /// Returns `0` when `total_shards` is `0` instead of panicking on a division by zero.
    fn calculate(&self, timestamp: DateTime<Utc>, total_shards: u32) -> u32 {
        use std::hash::{Hash, Hasher};
        use twox_hash::XxHash64;

        let mut hasher = XxHash64::default();
        timestamp.to_rfc3339().as_bytes().hash(&mut hasher);
        let hash = hasher.finish();
        // The remainder is < total_shards <= u32::MAX, so narrowing it is lossless.
        hash.checked_rem(u64::from(total_shards))
            .map_or(0, |shard| shard as u32)
    }

    fn name(&self) -> &'static str {
        "hash"
    }

    fn is_valid_shard_id(&self, shard_id: u32, total_shards: u32) -> bool {
        shard_id < total_shards
    }

    fn current_shard(&self, total_shards: u32) -> u32 {
        self.calculate(Utc::now(), total_shards)
    }

    fn boxed_clone(&self) -> Box<dyn ShardingStrategy> {
        Box::new(*self)
    }
}
/// Builds a boxed sharding strategy from its name, compared case-insensitively.
///
/// Recognised names are `"yearly"`/`"year"`, `"monthly"`/`"month"`, `"daily"`/`"day"` and
/// `"hash"`.
///
/// NOTE(review): any other name silently falls back to [`YearlyStrategy`]. A typo in
/// configuration therefore goes unnoticed.
pub(crate) fn create_strategy(name: &str) -> Box<dyn ShardingStrategy> {
    let normalized = name.to_lowercase();
    match normalized.as_str() {
        "hash" => Box::new(HashStrategy),
        "daily" | "day" => Box::new(DailyStrategy),
        "monthly" | "month" => Box::new(MonthlyStrategy),
        "yearly" | "year" => Box::new(YearlyStrategy),
        _ => Box::new(YearlyStrategy),
    }
}
/// Metadata describing one registered shard.
#[derive(Debug, Clone)]
pub struct ShardInfo {
    /// Numeric shard id, as produced by the router's strategy.
    pub shard_id: u32,
    /// Human-readable shard name (e.g. `"db_3"`).
    pub name: String,
    /// Connection string used to reach this shard's database.
    pub connection_string: String,
}
/// Routes timestamps (optionally combined with a key) to registered shards and their
/// connection pools.
pub struct ShardRouter {
    /// Number of shards the strategy distributes over; ids are `0..total_shards`.
    total_shards: u32,
    /// Strategy used to turn a timestamp into a shard id.
    strategy: Box<dyn ShardingStrategy>,
    /// Registered shard metadata, keyed by shard id.
    shards: dashmap::DashMap<u32, ShardInfo>,
    /// Connection pools, keyed by shard id; a shard may be registered without a pool.
    pools: dashmap::DashMap<u32, Arc<crate::pool::DbPool>>,
}
impl Default for ShardRouter {
    /// An empty router using the yearly strategy over zero shards.
    ///
    /// NOTE(review): with `total_shards == 0` the id range `0..0` is empty, so timestamp
    /// routing has no valid target until a properly sized router is built.
    fn default() -> Self {
        Self::new(YearlyStrategy, 0)
    }
}
impl Clone for ShardRouter {
fn clone(&self) -> Self {
Self {
total_shards: self.total_shards,
strategy: self.strategy.boxed_clone(),
shards: dashmap::DashMap::from_iter(self.shards.iter().map(|e| (*e.key(), e.value().clone()))),
pools: dashmap::DashMap::from_iter(self.pools.iter().map(|e| (*e.key(), e.value().clone()))),
}
}
}
impl ShardRouter {
pub fn new<S: ShardingStrategy + 'static>(strategy: S, total_shards: u32) -> Self {
Self {
total_shards,
strategy: Box::new(strategy),
shards: dashmap::DashMap::new(),
pools: dashmap::DashMap::new(),
}
}
pub fn with_strategy(strategy: &str, total_shards: u32) -> Self {
Self {
total_shards,
strategy: create_strategy(strategy),
shards: dashmap::DashMap::new(),
pools: dashmap::DashMap::new(),
}
}
pub async fn with_config(config: &ShardConfig) -> Result<Self, crate::error::DbError> {
let mut router = Self::with_strategy(&config.strategy, config.total_shards);
let connections: Vec<(u32, String, String)> = config
.generate_all_connections()
.into_iter()
.map(|(shard_id, connection_string)| {
let name = format!("{}_{}", config.prefix, shard_id);
(shard_id, name, connection_string)
})
.collect();
use futures::stream::{self, StreamExt};
let pool_futures: Vec<_> = connections
.iter()
.map(|(shard_id, name, connection_string)| {
let conn_string = connection_string.clone();
let shard_id_copy = *shard_id;
let name_copy = name.clone();
async move {
let result = crate::pool::DbPool::new(&conn_string).await;
(shard_id_copy, name_copy, result)
}
})
.collect();
let mut pool_stream = stream::iter(pool_futures).buffer_unordered(config.total_shards as usize);
let mut results: Vec<(u32, String, Result<crate::pool::DbPool, crate::error::DbError>)> = Vec::new();
while let Some(result) = pool_stream.next().await {
results.push(result);
}
for (shard_id, name, result) in results {
match result {
Ok(pool) => {
router.register_shard_with_pool(
shard_id,
name,
format!("{}_{}", config.prefix, shard_id),
Arc::new(pool),
);
}
Err(e) => {
tracing::warn!("Failed to create connection pool for shard {}: {}", shard_id, e);
router.register_shard(shard_id, name, format!("{}_{}", config.prefix, shard_id));
}
}
}
Ok(router)
}
pub fn with_config_sync(config: &ShardConfig) -> Self {
let mut router = Self::with_strategy(&config.strategy, config.total_shards);
for (shard_id, connection_string) in config.generate_all_connections() {
router.register_shard(shard_id, format!("{}_{}", config.prefix, shard_id), connection_string);
}
router
}
pub fn register_shard(&mut self, shard_id: u32, name: String, connection_string: String) {
self.shards.insert(
shard_id,
ShardInfo {
shard_id,
name,
connection_string,
},
);
}
pub fn register_shard_with_pool(
&mut self,
shard_id: u32,
name: String,
connection_string: String,
pool: Arc<crate::pool::DbPool>,
) {
self.shards.insert(
shard_id,
ShardInfo {
shard_id,
name,
connection_string: connection_string.clone(),
},
);
self.pools.insert(shard_id, pool);
}
pub fn set_pool(&mut self, shard_id: u32, pool: Arc<crate::pool::DbPool>) -> Result<(), crate::error::DbError> {
if !self.shards.contains_key(&shard_id) {
return Err(crate::error::DbError::Config(format!(
"Shard {} not registered",
shard_id
)));
}
self.pools.insert(shard_id, pool);
Ok(())
}
pub fn route(&self, timestamp: DateTime<Utc>) -> Option<ShardInfo> {
let shard_id = self.strategy.calculate(timestamp, self.total_shards);
self.shards.get(&shard_id).map(|r| r.value().clone())
}
pub fn route_with_key(&self, timestamp: DateTime<Utc>, key: &str) -> Option<ShardInfo> {
let shard_id = self.calculate_shard(timestamp, key);
self.shards.get(&shard_id).map(|r| r.value().clone())
}
pub fn calculate_shard(&self, timestamp: DateTime<Utc>, key: &str) -> u32 {
if key.is_empty() {
self.strategy.calculate(timestamp, self.total_shards)
} else {
use std::hash::{Hash, Hasher};
use twox_hash::XxHash64;
let mut hasher = XxHash64::default();
timestamp.to_rfc3339().as_bytes().hash(&mut hasher);
key.as_bytes().hash(&mut hasher);
let hash = hasher.finish();
(hash % self.total_shards as u64) as u32
}
}
pub fn all_shards(&self) -> Vec<ShardInfo> {
self.shards.iter().map(|r| r.value().clone()).collect()
}
pub fn strategy_name(&self) -> &'static str {
self.strategy.name()
}
pub fn total_shards(&self) -> u32 {
self.total_shards
}
pub fn get_pool(&self, shard_id: u32) -> Option<Arc<crate::pool::DbPool>> {
self.pools.get(&shard_id).map(|r| r.value().clone())
}
pub async fn get_session(&self, shard_id: u32) -> Result<Option<crate::pool::Session>, crate::error::DbError> {
if let Some(pool) = self.pools.get(&shard_id) {
let session = pool.get_session("default").await?;
Ok(Some(session))
} else {
Ok(None)
}
}
pub async fn get_session_for_timestamp(
&self,
timestamp: DateTime<Utc>,
) -> Result<Option<crate::pool::Session>, crate::error::DbError> {
let shard_id = self.strategy.calculate(timestamp, self.total_shards);
self.get_session(shard_id).await
}
pub fn initialized_shards(&self) -> Vec<u32> {
self.pools.iter().map(|r| *r.key()).collect()
}
pub fn has_pool(&self, shard_id: u32) -> bool {
self.pools.contains_key(&shard_id)
}
pub fn pool_count(&self) -> usize {
self.pools.len()
}
pub fn remove_pool(&mut self, shard_id: u32) -> Option<Arc<crate::pool::DbPool>> {
self.pools.remove(&shard_id).map(|(_, v)| v)
}
pub fn clear_pools(&mut self) {
self.pools.clear();
}
}
/// Declarative description of a sharded deployment.
///
/// It names the strategy, the number of shards, and how each shard's name and connection
/// string are derived.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Strategy name; the router resolves it with `create_strategy` (e.g. `"yearly"`).
    pub strategy: String,
    /// Number of shards; shard ids are `0..total_shards`.
    pub total_shards: u32,
    /// Prefix used to build shard names of the form `"{prefix}_{id}"`.
    pub prefix: String,
    /// Connection-string template.
    ///
    /// Supported placeholders are `{shard}` (expands to `"{prefix}_{id}"`), `{prefix}` and
    /// `{id}`.
    pub connection_template: String,
}

impl Default for ShardConfig {
    /// Twelve yearly shards with prefix `"db"` and template `"sqlite:./data/{shard}.db"`.
    fn default() -> Self {
        Self::new("yearly", 12, "db", "sqlite:./data/{shard}.db")
    }
}

impl ShardConfig {
    /// Creates a configuration from borrowed parts (each string is copied).
    pub fn new(strategy: &str, total_shards: u32, prefix: &str, connection_template: &str) -> Self {
        Self {
            strategy: String::from(strategy),
            total_shards,
            prefix: String::from(prefix),
            connection_template: String::from(connection_template),
        }
    }

    /// Expands the connection template for `shard_id`.
    pub fn generate_connection_string(&self, shard_id: u32) -> String {
        let shard_name = format!("{}_{}", self.prefix, shard_id);
        let id_text = shard_id.to_string();
        // Substitution order matters: `{shard}` first, then `{prefix}`, then `{id}`. Any
        // placeholder text introduced by an earlier step is expanded by a later one.
        let with_shard = self.connection_template.replace("{shard}", &shard_name);
        let with_prefix = with_shard.replace("{prefix}", &self.prefix);
        with_prefix.replace("{id}", &id_text)
    }

    /// Returns `(shard_id, connection_string)` for every id in `0..total_shards`, in order.
    pub fn generate_all_connections(&self) -> Vec<(u32, String)> {
        let mut connections = Vec::with_capacity(self.total_shards as usize);
        for shard_id in 0..self.total_shards {
            connections.push((shard_id, self.generate_connection_string(shard_id)));
        }
        connections
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{TimeZone, Utc};
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Mirrors the calendar strategies: hash `value` with a fresh `DefaultHasher`, truncate
    /// the hash to `u32`, then reduce it modulo `modulus`.
    fn default_hash_mod<T: Hash>(value: T, modulus: u32) -> u32 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        (hasher.finish() as u32) % modulus
    }

    #[test]
    fn test_yearly_strategy() {
        let strategy = YearlyStrategy;
        let jan_2024 = Utc.with_ymd_and_hms(2024, 1, 15, 0, 0, 0).unwrap();
        assert_eq!(strategy.calculate(jan_2024, 12), default_hash_mod(2024u32, 12));
        assert_eq!(strategy.name(), "yearly");
    }

    #[test]
    fn test_monthly_strategy() {
        let strategy = MonthlyStrategy;
        let mar_2024 = Utc.with_ymd_and_hms(2024, 3, 15, 0, 0, 0).unwrap();
        let year_month = 2024u32 * 12 + 3;
        assert_eq!(strategy.calculate(mar_2024, 100), default_hash_mod(year_month, 100));
        assert_eq!(strategy.name(), "monthly");
    }

    #[test]
    fn test_daily_strategy() {
        let strategy = DailyStrategy;
        let jan_2024 = Utc.with_ymd_and_hms(2024, 1, 15, 0, 0, 0).unwrap();
        let expected = default_hash_mod(jan_2024.num_days_from_ce(), 100);
        assert_eq!(strategy.calculate(jan_2024, 100), expected);
        assert_eq!(strategy.name(), "daily");
    }

    #[test]
    fn test_shard_router() {
        let home = default_hash_mod(2024u32, 12);
        let neighbour = (home + 1) % 12;

        let mut router = ShardRouter::with_strategy("yearly", 12);
        router.register_shard(home, "db_2024".to_string(), "sqlite:./data/db_2024.db".to_string());
        router.register_shard(
            neighbour,
            "db_2025".to_string(),
            "sqlite:./data/db_2025.db".to_string(),
        );

        let mid_2024 = Utc.with_ymd_and_hms(2024, 6, 15, 0, 0, 0).unwrap();
        let calculated_shard = router.calculate_shard(mid_2024, "");
        match router.route(mid_2024) {
            Some(info) => assert_eq!(info.name, "db_2024"),
            None => panic!(
                "Expected shard to be Some, but got None. Calculated shard: {}",
                calculated_shard
            ),
        }
    }

    #[test]
    fn test_shard_config() {
        let config = ShardConfig::new("yearly", 12, "order", "postgresql://localhost/{shard}");
        assert_eq!(config.generate_connection_string(4), "postgresql://localhost/order_4");
        assert_eq!(config.strategy, "yearly");
    }

    #[tokio::test]
    async fn test_router_with_config() {
        let config = ShardConfig::new("yearly", 4, "data", "postgresql://localhost/{shard}");
        let router = ShardRouter::with_config(&config).await.unwrap();
        assert_eq!(router.total_shards(), 4);
        assert_eq!(router.all_shards().len(), 4);
        assert_eq!(router.strategy_name(), "yearly");
    }

    #[test]
    fn test_strategy_factory_invalid_name() {
        let now = Utc::now();
        let fallback = create_strategy("invalid_strategy_name").calculate(now, 12);
        let reference = create_strategy("default").calculate(now, 12);
        assert_eq!(fallback, reference, "Invalid strategy should fall back to default");
    }

    #[test]
    fn test_strategy_factory_aliases() {
        let now = Utc::now();
        let cases = [
            ("yearly", "year", 12, "'year' should alias to 'yearly'"),
            ("monthly", "month", 100, "'month' should alias to 'monthly'"),
            ("daily", "day", 365, "'day' should alias to 'daily'"),
        ];
        for &(canonical, alias, shards, message) in cases.iter() {
            let canonical_shard = create_strategy(canonical).calculate(now, shards);
            let alias_shard = create_strategy(alias).calculate(now, shards);
            assert_eq!(canonical_shard, alias_shard, "{}", message);
        }
    }

    #[test]
    fn test_strategy_factory_case_insensitive() {
        let now = Utc::now();
        let baseline = create_strategy("yearly").calculate(now, 12);
        for &variant in ["YEARLY", "Yearly", "yEaRlY"].iter() {
            let shard = create_strategy(variant).calculate(now, 12);
            assert_eq!(shard, baseline, "'{}' should work like 'yearly'", variant);
        }
    }
}