use std::{
future::Future,
sync::{
Mutex, OnceLock,
atomic::{AtomicUsize, Ordering},
mpsc as std_mpsc,
},
thread,
};
use datum::{StreamError, StreamResult};
use tokio::runtime::{Builder, Handle};
use tokio::sync::mpsc;
/// Default capacity for command channels.
///
/// NOTE(review): not referenced elsewhere in this section; presumably handed to
/// `command_channel` by callers — confirm.
pub(crate) const DEFAULT_COMMAND_BUFFER: usize = 64;
/// Connection threshold used by the shard-count policy when neither the
/// environment nor a test override supplies one.
pub(crate) const DEFAULT_SHARDED_MIN_CONNECTIONS: usize = 64;
/// Environment variable read by `configured_shards` to cap the number of shards.
const SHARDED_TOKIO_SHARDS_ENV: &str = "DATUM_NET_SHARDED_TOKIO_SHARDS";
/// Environment variable read by `configured_min_connections` to override the
/// connection threshold below which sharding is not used.
const SHARDED_TOKIO_MIN_CONNECTIONS_ENV: &str = "DATUM_NET_SHARDED_TOKIO_MIN_CONNECTIONS";
/// Environment variable read by `sharding_disabled` to switch sharding off.
const SHARDED_TOKIO_DISABLE_ENV: &str = "DATUM_NET_SHARDED_TOKIO_DISABLE";
/// Counts how many times `sharded_tokio_carrier_execution` handed out a shard
/// handle. It is only incremented in this section (never decremented) and is
/// exposed through `sharded_tokio_carrier_connection_count`.
static SHARDED_CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);
pub(crate) struct AsyncCommandSender<T> {
sender: mpsc::Sender<T>,
closed_message: &'static str,
}
impl<T> Clone for AsyncCommandSender<T> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
closed_message: self.closed_message,
}
}
}
impl<T> AsyncCommandSender<T> {
pub(crate) fn new(sender: mpsc::Sender<T>, closed_message: &'static str) -> Self {
Self {
sender,
closed_message,
}
}
pub(crate) fn send_blocking(&self, command: T) -> StreamResult<()> {
self.sender.blocking_send(command).map_err(|_| {
StreamError::Failed(format!("{} command channel closed", self.closed_message))
})
}
pub(crate) fn send_or_blocking(&self, command: T) -> StreamResult<()> {
match self.sender.try_send(command) {
Ok(()) => Ok(()),
Err(mpsc::error::TrySendError::Full(command)) => self.send_blocking(command),
Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Failed(format!(
"{} command channel closed",
self.closed_message
))),
}
}
pub(crate) fn try_send(&self, command: T) -> StreamResult<()> {
match self.sender.try_send(command) {
Ok(()) => Ok(()),
Err(mpsc::error::TrySendError::Full(_)) => Err(StreamError::Failed(format!(
"{} command channel full",
self.closed_message
))),
Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Failed(format!(
"{} command channel closed",
self.closed_message
))),
}
}
}
/// Creates a bounded command channel and returns its wrapped sender together
/// with the raw Tokio receiver.
///
/// A `capacity` of zero is raised to one before the channel is created.
/// `closed_message` becomes the error prefix used by the returned sender.
pub(crate) fn command_channel<T>(
    capacity: usize,
    closed_message: &'static str,
) -> (AsyncCommandSender<T>, mpsc::Receiver<T>) {
    let bounded_capacity = if capacity == 0 { 1 } else { capacity };
    let (raw_sender, receiver) = mpsc::channel(bounded_capacity);
    let sender = AsyncCommandSender::new(raw_sender, closed_message);
    (sender, receiver)
}
/// Batches demand signals so that one refill is emitted per group of consumed
/// items instead of one per item.
#[derive(Debug, Clone)]
pub(crate) struct DemandBatcher {
    /// Full demand window, reported by [`DemandBatcher::initial`].
    window: usize,
    /// Number of consumed items that triggers a refill (half the window, at least one).
    refill: usize,
    /// Items consumed since the last refill was emitted.
    consumed_since_refill: usize,
}

impl DemandBatcher {
    /// Creates a batcher for the given demand window.
    ///
    /// # Panics
    /// Panics when `window` is zero.
    pub(crate) fn new(window: usize) -> Self {
        assert!(window > 0, "demand window must be greater than zero");
        let refill = std::cmp::max(window / 2, 1);
        Self {
            window,
            refill,
            consumed_since_refill: 0,
        }
    }

    /// Returns the demand to request up front: the whole window.
    pub(crate) fn initial(&self) -> usize {
        self.window
    }

    /// Records one consumed item.
    ///
    /// Returns `Some(n)` with the number of items consumed since the previous
    /// refill once the refill threshold is reached (resetting the tally), and
    /// `None` while the threshold has not been reached yet.
    pub(crate) fn record_consumed(&mut self) -> Option<usize> {
        self.consumed_since_refill += 1;
        if self.consumed_since_refill < self.refill {
            return None;
        }
        // `take` hands back the accumulated tally and leaves zero behind.
        Some(std::mem::take(&mut self.consumed_since_refill))
    }
}
pub(crate) struct ShardedTokioCarrierExecution {
handle: Handle,
sharded: bool,
_permit: ActiveConnectionPermit,
}
impl ShardedTokioCarrierExecution {
pub(crate) fn handle(&self) -> Handle {
self.handle.clone()
}
pub(crate) fn is_sharded(&self) -> bool {
self.sharded
}
pub(crate) async fn run<T, Fut>(&self, future: Fut) -> StreamResult<T>
where
T: Send + 'static,
Fut: Future<Output = StreamResult<T>> + Send + 'static,
{
if !self.sharded {
return future.await;
}
let (sender, receiver) = tokio::sync::oneshot::channel();
self.handle.spawn(async move {
let _ = sender.send(future.await);
});
receiver.await.map_err(|_| {
StreamError::Failed("sharded Tokio carrier task ended before replying".to_owned())
})?
}
}
/// Registers a new connection on `active_connections` and picks the runtime it
/// should run on.
///
/// A shard handle is used when the shard-count policy allows sharding for the
/// current number of active connections and a shard runtime can be obtained.
/// In every other case, including a shard that fails to start, the `fallback`
/// handle is used. The returned value holds the connection permit, so the
/// counter is decremented again when it is dropped.
pub(crate) fn sharded_tokio_carrier_execution(
    fallback: Handle,
    active_connections: &'static AtomicUsize,
) -> ShardedTokioCarrierExecution {
    let permit = ActiveConnectionPermit::new(active_connections);

    // `None` covers both "policy says no sharding" and "shard startup failed".
    let shard_handle = sharded_tokio_shard_count(permit.active_connections())
        .and_then(|shards| sharded_tokio_runtime().select(shards).ok());

    match shard_handle {
        Some(handle) => {
            SHARDED_CONNECTION_COUNT.fetch_add(1, Ordering::Relaxed);
            ShardedTokioCarrierExecution {
                handle,
                sharded: true,
                _permit: permit,
            }
        }
        None => ShardedTokioCarrierExecution {
            handle: fallback,
            sharded: false,
            _permit: permit,
        },
    }
}
/// Serializes callers of `with_sharded_tokio_test_config` so that only one
/// override is installed at a time.
static SHARDED_TOKIO_TEST_GUARD: Mutex<()> = Mutex::new(());
/// Currently installed test override, if any; consulted by
/// `sharded_tokio_shard_count`.
static SHARDED_TOKIO_TEST_CONFIG: Mutex<Option<ShardedTokioTestConfig>> = Mutex::new(None);

/// Overrides applied to the shard-count policy while a test closure runs.
#[doc(hidden)]
pub struct ShardedTokioTestConfig {
    /// Replaces the detected core count (and therefore the shard cap);
    /// `None` keeps the detected value.
    pub shard_count: Option<usize>,
    /// Replaces the minimum number of active connections required before
    /// sharding is used; `None` keeps `DEFAULT_SHARDED_MIN_CONNECTIONS`.
    pub min_connections: Option<usize>,
}

/// Runs `f` with `config` installed as the shard-count override and returns
/// its result.
///
/// Callers are serialized, and the override is removed again when `f`
/// returns *or panics*, so a failing test cannot leak its configuration into
/// later callers.
///
/// # Panics
/// Panics if the configuration mutex is poisoned when installing the
/// override, and propagates any panic raised by `f`.
#[doc(hidden)]
pub fn with_sharded_tokio_test_config<F, R>(config: ShardedTokioTestConfig, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Clears the override on drop, which also covers unwinding out of `f`.
    struct ClearOverrideOnDrop;

    impl Drop for ClearOverrideOnDrop {
        fn drop(&mut self) {
            // Tolerate poisoning here: panicking inside `drop` while already
            // unwinding would abort the whole process.
            *SHARDED_TOKIO_TEST_CONFIG
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner()) = None;
        }
    }

    // A previous caller may have panicked while holding the guard; the unit
    // payload cannot be left inconsistent, so poisoning is ignored.
    let _guard = SHARDED_TOKIO_TEST_GUARD
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    *SHARDED_TOKIO_TEST_CONFIG
        .lock()
        .expect("sharded Tokio test config poisoned") = Some(config);
    // Declared after `_guard`, so it is dropped first: the override is cleared
    // before the next caller is allowed in.
    let _clear = ClearOverrideOnDrop;
    f()
}
/// Returns the current value of the sharded-connection counter, i.e. how many
/// times a shard handle has been handed out so far.
#[doc(hidden)]
pub fn sharded_tokio_carrier_connection_count() -> usize {
    let counter = &SHARDED_CONNECTION_COUNT;
    counter.load(Ordering::Relaxed)
}
pub(crate) fn sharded_tokio_shard_count(active_connections: usize) -> Option<usize> {
if let Some(ref config) = *SHARDED_TOKIO_TEST_CONFIG
.lock()
.expect("sharded Tokio test config poisoned")
{
let cores = config.shard_count.unwrap_or_else(physical_cores);
let max_shards = cores.min(cores);
let min_connections = config
.min_connections
.unwrap_or(DEFAULT_SHARDED_MIN_CONNECTIONS);
if cores < 2 || max_shards < 2 || active_connections < min_connections {
return None;
}
return Some(active_connections.min(max_shards).max(1));
}
if sharding_disabled() {
return None;
}
let cores = physical_cores();
let max_shards = configured_shards().unwrap_or(cores).min(cores);
let min_connections = configured_min_connections();
if cores < 2 || max_shards < 2 || active_connections < min_connections {
return None;
}
Some(active_connections.min(max_shards).max(1))
}
/// Guard that counts one active connection for as long as it is alive.
struct ActiveConnectionPermit {
    /// Shared counter this permit incremented and will decrement on drop.
    active_connections: &'static AtomicUsize,
}

impl ActiveConnectionPermit {
    /// Increments `active_connections` and returns the guard that will undo it.
    fn new(active_connections: &'static AtomicUsize) -> Self {
        let permit = Self { active_connections };
        permit.active_connections.fetch_add(1, Ordering::AcqRel);
        permit
    }

    /// Reads the counter's current value, which includes this permit.
    fn active_connections(&self) -> usize {
        let counter = self.active_connections;
        counter.load(Ordering::Acquire)
    }
}

impl Drop for ActiveConnectionPermit {
    /// Gives back the slot taken in [`ActiveConnectionPermit::new`].
    fn drop(&mut self) {
        let counter = self.active_connections;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
/// Lazily grown pool of shard runtimes with round-robin selection.
struct ShardedTokioRuntime {
    /// Shards started so far; only ever appended to.
    shards: Mutex<Vec<ShardRuntime>>,
    /// Monotonic ticket counter used to rotate through the shards.
    next: AtomicUsize,
}

/// One shard: a runtime handle plus the thread that drives the runtime.
struct ShardRuntime {
    handle: Handle,
    /// Kept so the join handle lives as long as the shard entry.
    _thread: thread::JoinHandle<()>,
}

impl ShardedTokioRuntime {
    /// Returns the handle of one of the first `shards` shard runtimes,
    /// rotating on every call and starting missing shards on demand.
    ///
    /// A request for zero shards is treated as a request for one.
    ///
    /// # Errors
    /// Fails when a missing shard runtime cannot be started.
    ///
    /// # Panics
    /// Panics if the shard list mutex is poisoned.
    fn select(&self, shards: usize) -> StreamResult<Handle> {
        let wanted = std::cmp::max(shards, 1);
        self.ensure_shards(wanted)?;
        let pool = self
            .shards
            .lock()
            .expect("sharded Tokio carrier runtime poisoned");
        let ticket = self.next.fetch_add(1, Ordering::Relaxed);
        // `ensure_shards` guarantees at least `wanted` entries, so the index
        // is always in range.
        let chosen = &pool[ticket % wanted];
        Ok(chosen.handle.clone())
    }

    /// Starts shard runtimes until at least `shards` of them exist.
    ///
    /// # Errors
    /// Stops at, and returns, the first shard startup failure; shards started
    /// before the failure are kept.
    ///
    /// # Panics
    /// Panics if the shard list mutex is poisoned.
    fn ensure_shards(&self, shards: usize) -> StreamResult<()> {
        let mut pool = self
            .shards
            .lock()
            .expect("sharded Tokio carrier runtime poisoned");
        for index in pool.len()..shards {
            let shard = start_shard_runtime(index)?;
            pool.push(shard);
        }
        Ok(())
    }
}
/// Returns the process-wide shard pool, creating it empty on first use.
fn sharded_tokio_runtime() -> &'static ShardedTokioRuntime {
    static RUNTIME: OnceLock<ShardedTokioRuntime> = OnceLock::new();

    let empty_pool = || ShardedTokioRuntime {
        shards: Mutex::new(Vec::new()),
        next: AtomicUsize::new(0),
    };
    RUNTIME.get_or_init(empty_pool)
}
/// Starts shard number `index`: a named thread that builds a current-thread
/// Tokio runtime, reports its handle back, and then blocks on a future that
/// never completes so the runtime stays alive.
///
/// # Errors
/// Fails when the thread cannot be spawned, when it exits before reporting,
/// or when the runtime itself cannot be built.
fn start_shard_runtime(index: usize) -> StreamResult<ShardRuntime> {
    let (startup_tx, startup_rx) = std_mpsc::sync_channel(1);

    let shard_main = move || match Builder::new_current_thread().enable_all().build() {
        Ok(runtime) => {
            let _ = startup_tx.send(Ok(runtime.handle().clone()));
            // Never resolves: the thread keeps driving the runtime forever.
            runtime.block_on(std::future::pending::<()>());
        }
        Err(error) => {
            let _ = startup_tx.send(Err(error.to_string()));
        }
    };

    let spawned = thread::Builder::new()
        .name(format!("datum-net-carrier-shard-{index}"))
        .spawn(shard_main);
    let thread = match spawned {
        Ok(join_handle) => join_handle,
        Err(error) => {
            return Err(StreamError::Failed(format!(
                "failed to spawn sharded Tokio carrier thread: {error}"
            )));
        }
    };

    // Outer failure: the thread vanished without reporting.
    let startup = startup_rx.recv().map_err(|_| {
        StreamError::Failed("sharded Tokio carrier thread exited during startup".to_owned())
    })?;
    // Inner failure: the thread reported that the runtime could not be built.
    let handle = startup.map_err(|error| {
        StreamError::Failed(format!(
            "failed to start sharded Tokio carrier runtime: {error}"
        ))
    })?;

    Ok(ShardRuntime {
        handle,
        _thread: thread,
    })
}
/// Returns the core count used by the sharding policy: the smaller of the
/// `num_cpus` physical-core figure (at least one) and the parallelism reported
/// by the standard library (one when that is unavailable).
fn physical_cores() -> usize {
    let physical = num_cpus::get_physical().max(1);
    let logical = match thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    logical.min(physical)
}
/// Reads the shard-count override from the environment.
///
/// Returns `None` when the variable is unset, unparsable, or zero.
fn configured_shards() -> Option<usize> {
    match parse_env_usize(SHARDED_TOKIO_SHARDS_ENV) {
        Some(shards) if shards > 0 => Some(shards),
        _ => None,
    }
}
/// Reads the minimum-connections threshold from the environment, falling back
/// to `DEFAULT_SHARDED_MIN_CONNECTIONS` when the variable is unset,
/// unparsable, or zero.
fn configured_min_connections() -> usize {
    match parse_env_usize(SHARDED_TOKIO_MIN_CONNECTIONS_ENV) {
        Some(threshold) if threshold > 0 => threshold,
        _ => DEFAULT_SHARDED_MIN_CONNECTIONS,
    }
}
/// Reports whether sharding has been switched off through the environment.
///
/// The variable counts as "on" when its value, after trimming surrounding
/// whitespace, equals `1`, `true`, or `yes` in any ASCII letter case. Anything
/// else — including an unset or non-Unicode value — leaves sharding enabled.
fn sharding_disabled() -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];

    std::env::var(SHARDED_TOKIO_DISABLE_ENV)
        .ok()
        .is_some_and(|value| {
            let value = value.trim();
            TRUTHY
                .iter()
                .any(|accepted| value.eq_ignore_ascii_case(accepted))
        })
}
/// Reads environment variable `name` and parses it as a `usize`.
///
/// Surrounding whitespace is ignored, so values such as `" 8\n"` (common when
/// the variable comes from a file or a quoted shell assignment) still parse.
/// Returns `None` when the variable is unset, not valid Unicode, or not a
/// non-negative integer.
fn parse_env_usize(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    raw.trim().parse().ok()
}