use tokio::sync::broadcast;
use crate::dropper::DropHandle;
use crate::error::ShardError;
use crate::keyspace::ShardConfig;
use crate::shard::{
self, PreparedShard, ReplicationEvent, ShardHandle, ShardPersistenceConfig, ShardRequest,
ShardResponse,
};
/// Fallback capacity for each shard's request channel, used when
/// `EngineConfig::shard_channel_buffer` is left at 0.
const DEFAULT_SHARD_BUFFER: usize = 4096;
/// Configuration used to construct an [`Engine`] and its shards.
#[derive(Debug, Clone, Default)]
pub struct EngineConfig {
    /// Per-shard configuration template; the `shard_id` field is overwritten
    /// for each spawned shard with its index.
    pub shard: ShardConfig,
    /// Optional persistence settings, cloned into every shard.
    pub persistence: Option<ShardPersistenceConfig>,
    /// Optional broadcast sender handed to every shard for replication events;
    /// also retained by the engine so callers can subscribe.
    pub replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    /// Optional shared schema registry (only with the `protobuf` feature).
    #[cfg(feature = "protobuf")]
    pub schema_registry: Option<crate::schema::SharedSchemaRegistry>,
    /// Capacity of each shard's request channel; 0 selects
    /// `DEFAULT_SHARD_BUFFER` (4096).
    pub shard_channel_buffer: usize,
}
/// Routes requests across a fixed set of shards, choosing the target shard
/// by FNV-1a hashing of the key.
#[derive(Debug, Clone)]
pub struct Engine {
    // Handles to the spawned shards; vector index equals the shard id.
    shards: Vec<ShardHandle>,
    // Retained so `subscribe_replication` can hand out receivers.
    replication_tx: Option<broadcast::Sender<ReplicationEvent>>,
    // Exposed read-only via `schema_registry()` (protobuf feature only).
    #[cfg(feature = "protobuf")]
    schema_registry: Option<crate::schema::SharedSchemaRegistry>,
}
impl Engine {
    /// Creates an engine with `shard_count` shards and the default config.
    ///
    /// # Panics
    /// Panics if `shard_count` is 0 or exceeds `u16::MAX`.
    pub fn new(shard_count: usize) -> Self {
        Self::with_config(shard_count, EngineConfig::default())
    }

    /// Creates an engine by spawning `shard_count` shards from `config`.
    ///
    /// Each shard receives a clone of the shard config with its `shard_id`
    /// set to its index, plus clones of the persistence/replication settings.
    ///
    /// # Panics
    /// Panics if `shard_count` is 0 or exceeds `u16::MAX`.
    pub fn with_config(shard_count: usize, config: EngineConfig) -> Self {
        let buffer = Self::validated_buffer(shard_count, &config);
        let drop_handle = DropHandle::spawn();
        let shards = (0..shard_count)
            .map(|i| {
                let mut shard_config = config.shard.clone();
                shard_config.shard_id = i as u16;
                shard::spawn_shard(
                    buffer,
                    shard_config,
                    config.persistence.clone(),
                    Some(drop_handle.clone()),
                    config.replication_tx.clone(),
                    #[cfg(feature = "protobuf")]
                    config.schema_registry.clone(),
                )
            })
            .collect();
        Self {
            shards,
            replication_tx: config.replication_tx,
            #[cfg(feature = "protobuf")]
            schema_registry: config.schema_registry,
        }
    }

    /// Like [`Engine::with_config`], but returns the prepared (not yet
    /// running) shards alongside the engine so the caller controls startup.
    ///
    /// # Panics
    /// Panics if `shard_count` is 0 or exceeds `u16::MAX`.
    pub fn prepare(shard_count: usize, config: EngineConfig) -> (Self, Vec<PreparedShard>) {
        let buffer = Self::validated_buffer(shard_count, &config);
        let drop_handle = DropHandle::spawn();
        let mut handles = Vec::with_capacity(shard_count);
        let mut prepared = Vec::with_capacity(shard_count);
        for i in 0..shard_count {
            let mut shard_config = config.shard.clone();
            shard_config.shard_id = i as u16;
            let (handle, shard) = shard::prepare_shard(
                buffer,
                shard_config,
                config.persistence.clone(),
                Some(drop_handle.clone()),
                config.replication_tx.clone(),
                #[cfg(feature = "protobuf")]
                config.schema_registry.clone(),
            );
            handles.push(handle);
            prepared.push(shard);
        }
        let engine = Self {
            shards: handles,
            replication_tx: config.replication_tx,
            #[cfg(feature = "protobuf")]
            schema_registry: config.schema_registry,
        };
        (engine, prepared)
    }

    /// Creates an engine with one shard per available CPU core.
    pub fn with_available_cores() -> Self {
        Self::with_available_cores_config(EngineConfig::default())
    }

    /// Creates an engine with one shard per available CPU core, falling back
    /// to a single shard when parallelism cannot be queried.
    pub fn with_available_cores_config(config: EngineConfig) -> Self {
        let cores = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        Self::with_config(cores, config)
    }

    /// Returns the shared schema registry, if one was configured.
    #[cfg(feature = "protobuf")]
    pub fn schema_registry(&self) -> Option<&crate::schema::SharedSchemaRegistry> {
        self.schema_registry.as_ref()
    }

    /// Number of shards this engine routes across.
    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }

    /// Subscribes to replication events, if a replication channel was
    /// configured; returns `None` otherwise.
    pub fn subscribe_replication(&self) -> Option<broadcast::Receiver<ReplicationEvent>> {
        self.replication_tx.as_ref().map(|tx| tx.subscribe())
    }

    /// Sends `request` to the shard at `shard_idx` and awaits the response.
    ///
    /// # Errors
    /// Returns [`ShardError::Unavailable`] if `shard_idx` is out of range.
    pub async fn send_to_shard(
        &self,
        shard_idx: usize,
        request: ShardRequest,
    ) -> Result<ShardResponse, ShardError> {
        self.shard(shard_idx)?.send(request).await
    }

    /// Routes `request` to the shard that owns `key` and awaits the response.
    pub async fn route(
        &self,
        key: &str,
        request: ShardRequest,
    ) -> Result<ShardResponse, ShardError> {
        let idx = self.shard_for_key(key);
        self.shards[idx].send(request).await
    }

    /// Sends one request (built by `make_req`) to every shard, then awaits
    /// all responses in shard order.
    ///
    /// Requests are dispatched to all shards before any response is awaited,
    /// so the shards process them concurrently.
    pub async fn broadcast<F>(&self, make_req: F) -> Result<Vec<ShardResponse>, ShardError>
    where
        F: Fn() -> ShardRequest,
    {
        let mut receivers = Vec::with_capacity(self.shards.len());
        for shard in &self.shards {
            receivers.push(shard.dispatch(make_req()).await?);
        }
        let mut results = Vec::with_capacity(receivers.len());
        for rx in receivers {
            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
        }
        Ok(results)
    }

    /// Routes one request per key (built by `make_req`) to each key's shard
    /// and returns the responses in key order.
    ///
    /// Like [`Engine::broadcast`], all requests are dispatched before any
    /// response is awaited.
    pub async fn route_multi<F>(
        &self,
        keys: &[String],
        make_req: F,
    ) -> Result<Vec<ShardResponse>, ShardError>
    where
        F: Fn(String) -> ShardRequest,
    {
        let mut receivers = Vec::with_capacity(keys.len());
        for key in keys {
            let idx = self.shard_for_key(key);
            let rx = self.shards[idx].dispatch(make_req(key.clone())).await?;
            receivers.push(rx);
        }
        let mut results = Vec::with_capacity(receivers.len());
        for rx in receivers {
            results.push(rx.await.map_err(|_| ShardError::Unavailable)?);
        }
        Ok(results)
    }

    /// Returns true when both keys hash to the same shard.
    pub fn same_shard(&self, key1: &str, key2: &str) -> bool {
        self.shard_for_key(key1) == self.shard_for_key(key2)
    }

    /// Returns the index of the shard that owns `key`.
    pub fn shard_for_key(&self, key: &str) -> usize {
        shard_index(key, self.shards.len())
    }

    /// Dispatches `request` to the shard at `shard_idx` without awaiting the
    /// response; the caller awaits the returned oneshot receiver.
    ///
    /// # Errors
    /// Returns [`ShardError::Unavailable`] if `shard_idx` is out of range.
    pub async fn dispatch_to_shard(
        &self,
        shard_idx: usize,
        request: ShardRequest,
    ) -> Result<tokio::sync::oneshot::Receiver<ShardResponse>, ShardError> {
        self.shard(shard_idx)?.dispatch(request).await
    }

    /// Dispatches `request` to the shard at `shard_idx`, delivering the
    /// response on the caller-supplied reusable `reply` channel.
    ///
    /// # Errors
    /// Returns [`ShardError::Unavailable`] if `shard_idx` is out of range.
    pub async fn dispatch_reusable_to_shard(
        &self,
        shard_idx: usize,
        request: ShardRequest,
        reply: tokio::sync::mpsc::Sender<ShardResponse>,
    ) -> Result<(), ShardError> {
        self.shard(shard_idx)?.dispatch_reusable(request, reply).await
    }

    /// Dispatches a batch of requests to the shard at `shard_idx`, returning
    /// one oneshot receiver per request.
    ///
    /// # Errors
    /// Returns [`ShardError::Unavailable`] if `shard_idx` is out of range.
    pub async fn dispatch_batch_to_shard(
        &self,
        shard_idx: usize,
        requests: Vec<ShardRequest>,
    ) -> Result<Vec<tokio::sync::oneshot::Receiver<ShardResponse>>, ShardError> {
        self.shard(shard_idx)?.dispatch_batch(requests).await
    }

    /// Validates the shard count bounds and resolves the effective per-shard
    /// channel buffer size (0 in the config selects the default).
    ///
    /// # Panics
    /// Panics if `shard_count` is 0 or exceeds `u16::MAX`.
    fn validated_buffer(shard_count: usize, config: &EngineConfig) -> usize {
        assert!(shard_count > 0, "shard count must be at least 1");
        assert!(
            shard_count <= u16::MAX as usize,
            "shard count must fit in u16"
        );
        if config.shard_channel_buffer == 0 {
            DEFAULT_SHARD_BUFFER
        } else {
            config.shard_channel_buffer
        }
    }

    /// Checked shard lookup; maps an out-of-range index to `Unavailable`.
    fn shard(&self, shard_idx: usize) -> Result<&ShardHandle, ShardError> {
        self.shards.get(shard_idx).ok_or(ShardError::Unavailable)
    }
}
/// Maps `key` onto the range `[0, shard_count)` using a 64-bit FNV-1a hash.
///
/// `shard_count` must be non-zero (the engine constructors enforce this);
/// a zero count would divide by zero here.
fn shard_index(key: &str, shard_count: usize) -> usize {
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    let hash = key.bytes().fold(FNV_OFFSET, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    (hash as usize) % shard_count
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;
    use bytes::Bytes;

    /// Hashing the same key twice is deterministic.
    #[test]
    fn same_key_same_shard() {
        assert_eq!(shard_index("foo", 8), shard_index("foo", 8));
    }

    /// A spread of distinct keys should not all land on one shard.
    #[test]
    fn keys_spread_across_shards() {
        let distinct: std::collections::HashSet<usize> = (0..100)
            .map(|i| shard_index(&format!("key:{i}"), 4))
            .collect();
        assert!(distinct.len() > 1, "expected keys to spread across shards");
    }

    /// With a single shard, every key maps to index 0.
    #[test]
    fn single_shard_always_zero() {
        for key in ["anything", "other"] {
            assert_eq!(shard_index(key, 1), 0);
        }
    }

    /// Set then Get through `route` round-trips the stored value.
    #[tokio::test]
    async fn engine_round_trip() {
        let engine = Engine::new(4);
        let set_resp = engine
            .route(
                "greeting",
                ShardRequest::Set {
                    key: "greeting".into(),
                    value: Bytes::from("hello"),
                    expire: None,
                    nx: false,
                    xx: false,
                },
            )
            .await
            .unwrap();
        assert!(matches!(set_resp, ShardResponse::Ok));
        let get_resp = engine
            .route(
                "greeting",
                ShardRequest::Get {
                    key: "greeting".into(),
                },
            )
            .await
            .unwrap();
        match get_resp {
            ShardResponse::Value(Some(Value::String(data))) => {
                assert_eq!(data, Bytes::from("hello"));
            }
            other => panic!("expected Value(Some(String)), got {other:?}"),
        }
    }

    /// Del reports true only for keys that existed, across shards.
    #[tokio::test]
    async fn multi_shard_del() {
        let engine = Engine::new(4);
        let keys = ["a", "b", "c", "d"];
        for key in keys {
            engine
                .route(
                    key,
                    ShardRequest::Set {
                        key: key.to_string(),
                        value: Bytes::from("v"),
                        expire: None,
                        nx: false,
                        xx: false,
                    },
                )
                .await
                .unwrap();
        }
        let mut deleted = 0i64;
        for key in keys.iter().copied().chain(std::iter::once("missing")) {
            let resp = engine
                .route(
                    key,
                    ShardRequest::Del {
                        key: key.to_string(),
                    },
                )
                .await
                .unwrap();
            if matches!(resp, ShardResponse::Bool(true)) {
                deleted += 1;
            }
        }
        assert_eq!(deleted, 4);
    }

    /// Constructing with zero shards must panic with the documented message.
    #[test]
    #[should_panic(expected = "shard count must be at least 1")]
    fn zero_shards_panics() {
        Engine::new(0);
    }
}