datum-net 0.9.0

Network sources and sinks for Datum streams, built on datum-core
Documentation
//! Shared helpers for Tokio-owned network carrier tasks.
//!
//! The synchronous Datum stream core cannot await socket readiness directly.
//! Carrier modules use these small adapters to send bounded, batched demand or
//! command messages to a single Tokio task that owns the socket and IO buffers.

use std::{
    future::Future,
    sync::{
        Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering},
        mpsc as std_mpsc,
    },
    thread,
};

use datum::{StreamError, StreamResult};
use tokio::runtime::{Builder, Handle};
use tokio::sync::mpsc;

/// Default command-channel capacity.
///
/// NOTE(review): not referenced in this file; presumably handed to
/// `command_channel` by the carrier modules — confirm against callers.
pub(crate) const DEFAULT_COMMAND_BUFFER: usize = 64;
/// Active-connection threshold below which sharding is not attempted, used
/// when neither the environment nor a test override supplies one.
pub(crate) const DEFAULT_SHARDED_MIN_CONNECTIONS: usize = 64;

/// Environment variable with the requested shard count. Values that are not
/// positive integers are ignored, and the result is capped at the core count.
const SHARDED_TOKIO_SHARDS_ENV: &str = "DATUM_NET_SHARDED_TOKIO_SHARDS";
/// Environment variable overriding the minimum active-connection threshold.
/// Values that are not positive integers fall back to the default.
const SHARDED_TOKIO_MIN_CONNECTIONS_ENV: &str = "DATUM_NET_SHARDED_TOKIO_MIN_CONNECTIONS";
/// Environment variable that turns sharding off when set to a truthy value
/// such as `1`, `true` or `yes`.
const SHARDED_TOKIO_DISABLE_ENV: &str = "DATUM_NET_SHARDED_TOKIO_DISABLE";

/// Running total of carrier executions that were placed on a shard runtime.
/// It is only ever incremented in this file (never decremented on drop).
static SHARDED_CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Cloneable sending half of a bounded Tokio command channel.
///
/// Bundles the `mpsc::Sender` with a static label that prefixes every error
/// message produced by this wrapper.
pub(crate) struct AsyncCommandSender<T> {
    /// Channel handle used to deliver commands.
    sender: mpsc::Sender<T>,
    /// Label placed in front of the "command channel ..." error text.
    closed_message: &'static str,
}

// Implemented by hand instead of derived so that cloning the sender does not
// require `T: Clone`; only the channel handle and the label are duplicated.
impl<T> Clone for AsyncCommandSender<T> {
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        let closed_message = self.closed_message;
        Self {
            sender,
            closed_message,
        }
    }
}

impl<T> AsyncCommandSender<T> {
    pub(crate) fn new(sender: mpsc::Sender<T>, closed_message: &'static str) -> Self {
        Self {
            sender,
            closed_message,
        }
    }

    pub(crate) fn send_blocking(&self, command: T) -> StreamResult<()> {
        self.sender.blocking_send(command).map_err(|_| {
            StreamError::Failed(format!("{} command channel closed", self.closed_message))
        })
    }

    pub(crate) fn send_or_blocking(&self, command: T) -> StreamResult<()> {
        match self.sender.try_send(command) {
            Ok(()) => Ok(()),
            Err(mpsc::error::TrySendError::Full(command)) => self.send_blocking(command),
            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Failed(format!(
                "{} command channel closed",
                self.closed_message
            ))),
        }
    }

    pub(crate) fn try_send(&self, command: T) -> StreamResult<()> {
        match self.sender.try_send(command) {
            Ok(()) => Ok(()),
            Err(mpsc::error::TrySendError::Full(_)) => Err(StreamError::Failed(format!(
                "{} command channel full",
                self.closed_message
            ))),
            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Failed(format!(
                "{} command channel closed",
                self.closed_message
            ))),
        }
    }
}

/// Creates a bounded command channel and wraps its sending half.
///
/// A `capacity` of zero is raised to one before the channel is created.
/// `closed_message` becomes the error-message prefix of the returned sender.
pub(crate) fn command_channel<T>(
    capacity: usize,
    closed_message: &'static str,
) -> (AsyncCommandSender<T>, mpsc::Receiver<T>) {
    let bounded_capacity = std::cmp::max(capacity, 1);
    let (raw_sender, receiver) = mpsc::channel(bounded_capacity);
    let sender = AsyncCommandSender::new(raw_sender, closed_message);
    (sender, receiver)
}

/// Batches per-item consumption into larger demand refills.
///
/// The full window is requested up front; afterwards demand is re-issued in
/// chunks once at least half the window (minimum one item) has been consumed.
#[derive(Debug, Clone)]
pub(crate) struct DemandBatcher {
    /// Size of the initial demand.
    window: usize,
    /// Number of consumed items that triggers a refill.
    refill: usize,
    /// Items consumed since the last refill was emitted.
    consumed_since_refill: usize,
}

impl DemandBatcher {
    /// Creates a batcher for the given demand window.
    ///
    /// # Panics
    /// Panics when `window` is zero.
    pub(crate) fn new(window: usize) -> Self {
        assert!(window > 0, "demand window must be greater than zero");
        // Half the window, but never zero so a window of one still refills.
        let refill = std::cmp::max(window / 2, 1);
        Self {
            window,
            refill,
            consumed_since_refill: 0,
        }
    }

    /// Returns the demand to issue before anything has been consumed.
    pub(crate) fn initial(&self) -> usize {
        self.window
    }

    /// Records one consumed item.
    ///
    /// Returns `Some(n)` with the number of items to re-request once the
    /// refill threshold is reached, otherwise `None`.
    pub(crate) fn record_consumed(&mut self) -> Option<usize> {
        let pending = self.consumed_since_refill + 1;
        if pending < self.refill {
            self.consumed_since_refill = pending;
            return None;
        }
        self.consumed_since_refill = 0;
        Some(pending)
    }
}

/// Placement decision for one carrier: which runtime handle to use, whether
/// that handle belongs to a shard runtime, and the permit that keeps the
/// connection counted as active while this value is alive.
pub(crate) struct ShardedTokioCarrierExecution {
    /// Shard runtime handle when `sharded`, otherwise the caller's fallback.
    handle: Handle,
    /// `true` when `handle` was obtained from the shard pool.
    sharded: bool,
    /// Held only for its `Drop`, which releases the active-connection slot.
    _permit: ActiveConnectionPermit,
}

impl ShardedTokioCarrierExecution {
    /// Returns a clone of the selected runtime handle.
    pub(crate) fn handle(&self) -> Handle {
        self.handle.clone()
    }

    /// Reports whether the carrier was placed on a shard runtime.
    pub(crate) fn is_sharded(&self) -> bool {
        self.sharded
    }

    /// Drives `future` to completion and returns its result.
    ///
    /// When not sharded the future is awaited inline. When sharded it is
    /// spawned on the shard handle and its result is relayed back through a
    /// oneshot channel.
    ///
    /// # Errors
    /// Propagates the future's own error, or returns `StreamError::Failed`
    /// when the spawned task ends without sending a reply.
    pub(crate) async fn run<T, Fut>(&self, future: Fut) -> StreamResult<T>
    where
        T: Send + 'static,
        Fut: Future<Output = StreamResult<T>> + Send + 'static,
    {
        if self.sharded {
            let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
            self.handle.spawn(async move {
                // The receiver may already be gone; nothing useful to do then.
                let _ = reply_tx.send(future.await);
            });
            match reply_rx.await {
                Ok(outcome) => outcome,
                Err(_) => Err(StreamError::Failed(
                    "sharded Tokio carrier task ended before replying".to_owned(),
                )),
            }
        } else {
            future.await
        }
    }
}

/// Registers a new active connection and decides where its carrier runs.
///
/// The connection is counted on `active_connections` for as long as the
/// returned value lives. If the shard policy yields a shard count and a shard
/// handle can be obtained, the carrier is placed on that shard and the global
/// sharded-connection total is incremented; in every other case (policy says
/// no, or the shard runtime could not be started) `fallback` is used.
pub(crate) fn sharded_tokio_carrier_execution(
    fallback: Handle,
    active_connections: &'static AtomicUsize,
) -> ShardedTokioCarrierExecution {
    let permit = ActiveConnectionPermit::new(active_connections);

    // A shard-startup failure is deliberately downgraded to "use the fallback".
    let shard_handle = sharded_tokio_shard_count(permit.active_connections())
        .and_then(|shards| sharded_tokio_runtime().select(shards).ok());

    match shard_handle {
        Some(handle) => {
            SHARDED_CONNECTION_COUNT.fetch_add(1, Ordering::Relaxed);
            ShardedTokioCarrierExecution {
                handle,
                sharded: true,
                _permit: permit,
            }
        }
        None => ShardedTokioCarrierExecution {
            handle: fallback,
            sharded: false,
            _permit: permit,
        },
    }
}

/// Serializes callers of [`with_sharded_tokio_test_config`] so that only one
/// override is installed at a time.
static SHARDED_TOKIO_TEST_GUARD: Mutex<()> = Mutex::new(());
/// Currently installed override, consulted by `sharded_tokio_shard_count`.
static SHARDED_TOKIO_TEST_CONFIG: Mutex<Option<ShardedTokioTestConfig>> = Mutex::new(None);

/// Test-only override of the shard policy inputs.
#[doc(hidden)]
pub struct ShardedTokioTestConfig {
    /// Stands in for the core count (and therefore the maximum shard count);
    /// `None` falls back to the detected physical core count.
    pub shard_count: Option<usize>,
    /// Minimum active connections before sharding is attempted; `None` falls
    /// back to `DEFAULT_SHARDED_MIN_CONNECTIONS`.
    pub min_connections: Option<usize>,
}

/// Runs `f` with `config` installed as the shard-policy override and returns
/// its result.
///
/// Callers are serialized through a global guard, so overrides from
/// concurrently running tests never overlap. The override is removed when `f`
/// finishes — including when `f` panics — so a failing test cannot leak its
/// configuration into later shard decisions.
///
/// Calling this function again from inside `f` deadlocks, because the guard
/// mutex is not reentrant.
///
/// # Panics
/// Panics if the override mutex is poisoned when installing `config`, and
/// propagates any panic raised by `f`.
#[doc(hidden)]
pub fn with_sharded_tokio_test_config<F, R>(config: ShardedTokioTestConfig, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Clears the override when dropped, on both normal return and unwind.
    struct ResetOnDrop;

    impl Drop for ResetOnDrop {
        fn drop(&mut self) {
            // Recover from poisoning instead of panicking: this may already be
            // running during an unwind, where a second panic would abort.
            *SHARDED_TOKIO_TEST_CONFIG
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner()) = None;
        }
    }

    // A previous caller panicking inside `f` poisons the guard; the guarded
    // data is `()`, so it is always safe to keep going.
    let _guard = SHARDED_TOKIO_TEST_GUARD
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    *SHARDED_TOKIO_TEST_CONFIG
        .lock()
        .expect("sharded Tokio test config poisoned") = Some(config);
    // Declared after `_guard` so it drops first: the override is cleared
    // before the next serialized caller can proceed.
    let _reset = ResetOnDrop;
    f()
}

/// Returns the running total of carrier executions that were placed on a
/// shard runtime. The counter is only incremented in this file, so this is a
/// cumulative figure rather than the number of currently live connections.
#[doc(hidden)]
pub fn sharded_tokio_carrier_connection_count() -> usize {
    // Relaxed is enough: the value is a standalone statistic.
    SHARDED_CONNECTION_COUNT.load(Ordering::Relaxed)
}

pub(crate) fn sharded_tokio_shard_count(active_connections: usize) -> Option<usize> {
    if let Some(ref config) = *SHARDED_TOKIO_TEST_CONFIG
        .lock()
        .expect("sharded Tokio test config poisoned")
    {
        let cores = config.shard_count.unwrap_or_else(physical_cores);
        // The test override already stands in for the physical-core count, so the
        // `.min(cores)` clamp that the production path applies to
        // `configured_shards()` is a no-op here; kept for shape parity.
        let max_shards = cores.min(cores);
        let min_connections = config
            .min_connections
            .unwrap_or(DEFAULT_SHARDED_MIN_CONNECTIONS);
        if cores < 2 || max_shards < 2 || active_connections < min_connections {
            return None;
        }
        return Some(active_connections.min(max_shards).max(1));
    }

    if sharding_disabled() {
        return None;
    }

    let cores = physical_cores();
    let max_shards = configured_shards().unwrap_or(cores).min(cores);
    let min_connections = configured_min_connections();
    if cores < 2 || max_shards < 2 || active_connections < min_connections {
        return None;
    }

    Some(active_connections.min(max_shards).max(1))
}

/// Guard that counts one connection as active on a shared counter for as long
/// as it is alive.
struct ActiveConnectionPermit {
    /// Counter incremented on creation and decremented on drop.
    active_connections: &'static AtomicUsize,
}

impl ActiveConnectionPermit {
    /// Increments `active_connections` and returns the guard that will undo
    /// the increment when dropped.
    fn new(active_connections: &'static AtomicUsize) -> Self {
        let permit = Self { active_connections };
        permit.active_connections.fetch_add(1, Ordering::AcqRel);
        permit
    }

    /// Reads the counter's current value, which includes this permit and any
    /// other permits alive at the time of the call.
    fn active_connections(&self) -> usize {
        let counter = self.active_connections;
        counter.load(Ordering::Acquire)
    }
}

impl Drop for ActiveConnectionPermit {
    /// Releases this permit's slot on the shared counter.
    fn drop(&mut self) {
        let counter = self.active_connections;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}

/// Lazily grown pool of shard runtimes plus a round-robin cursor.
struct ShardedTokioRuntime {
    /// Started shards, in start order; the pool only ever grows.
    shards: Mutex<Vec<ShardRuntime>>,
    /// Monotonic ticket counter used to pick the next shard.
    next: AtomicUsize,
}

/// One shard: the handle of its runtime and the thread that drives it.
struct ShardRuntime {
    handle: Handle,
    /// Retained but never joined in this file.
    _thread: thread::JoinHandle<()>,
}

impl ShardedTokioRuntime {
    /// Returns the handle of the next shard, rotating over the first `shards`
    /// entries of the pool (a request for zero is treated as one).
    ///
    /// # Errors
    /// Fails when a missing shard runtime cannot be started.
    ///
    /// # Panics
    /// Panics if the pool mutex is poisoned.
    fn select(&self, shards: usize) -> StreamResult<Handle> {
        let wanted = shards.max(1);
        self.ensure_shards(wanted)?;
        let pool = self
            .shards
            .lock()
            .expect("sharded Tokio carrier runtime poisoned");
        let ticket = self.next.fetch_add(1, Ordering::Relaxed);
        // `ensure_shards` guarantees at least `wanted` entries, so the index
        // is always in bounds.
        let chosen = &pool[ticket % wanted];
        Ok(chosen.handle.clone())
    }

    /// Starts shard runtimes until the pool holds at least `shards` entries.
    ///
    /// # Errors
    /// Stops at, and returns, the first start-up failure; shards started
    /// before the failure stay in the pool.
    ///
    /// # Panics
    /// Panics if the pool mutex is poisoned.
    fn ensure_shards(&self, shards: usize) -> StreamResult<()> {
        let mut pool = self
            .shards
            .lock()
            .expect("sharded Tokio carrier runtime poisoned");
        // Each iteration appends exactly one shard, so `index` always equals
        // the pool length at the moment the shard is started.
        for index in pool.len()..shards {
            let shard = start_shard_runtime(index)?;
            pool.push(shard);
        }
        Ok(())
    }
}

/// Returns the process-wide shard pool, creating it empty on first use.
fn sharded_tokio_runtime() -> &'static ShardedTokioRuntime {
    static RUNTIME: OnceLock<ShardedTokioRuntime> = OnceLock::new();

    // No shard threads are started here; the pool grows on demand.
    RUNTIME.get_or_init(|| {
        let shards = Mutex::new(Vec::new());
        let next = AtomicUsize::new(0);
        ShardedTokioRuntime { shards, next }
    })
}

/// Spawns a named OS thread that builds a current-thread Tokio runtime and
/// then parks on it forever, and returns that runtime's handle.
///
/// The thread reports either its handle or the build error back over a
/// one-slot channel; this function blocks until that report arrives.
///
/// # Errors
/// Returns `StreamError::Failed` when the thread cannot be spawned, when it
/// exits before reporting, or when the runtime fails to build.
fn start_shard_runtime(index: usize) -> StreamResult<ShardRuntime> {
    let (ready_tx, ready_rx) = std_mpsc::sync_channel(1);

    let carrier = move || {
        let runtime = Builder::new_current_thread().enable_all().build();
        match runtime {
            Ok(runtime) => {
                let _ = ready_tx.send(Ok(runtime.handle().clone()));
                // Keep the runtime alive (and driving spawned tasks) forever.
                runtime.block_on(std::future::pending::<()>());
            }
            Err(error) => {
                let _ = ready_tx.send(Err(error.to_string()));
            }
        }
    };

    let spawned = thread::Builder::new()
        .name(format!("datum-net-carrier-shard-{index}"))
        .spawn(carrier);
    let thread = match spawned {
        Ok(thread) => thread,
        Err(error) => {
            return Err(StreamError::Failed(format!(
                "failed to spawn sharded Tokio carrier thread: {error}"
            )));
        }
    };

    let handle = match ready_rx.recv() {
        Ok(Ok(handle)) => handle,
        Ok(Err(error)) => {
            return Err(StreamError::Failed(format!(
                "failed to start sharded Tokio carrier runtime: {error}"
            )));
        }
        // The sender was dropped without a report, i.e. the thread died early.
        Err(_) => {
            return Err(StreamError::Failed(
                "sharded Tokio carrier thread exited during startup".to_owned(),
            ));
        }
    };

    Ok(ShardRuntime {
        handle,
        _thread: thread,
    })
}

/// Returns the core count used for shard sizing: the smaller of the physical
/// core count reported by `num_cpus` (at least one) and the parallelism
/// available to this process (one if it cannot be determined).
fn physical_cores() -> usize {
    let logical = match thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    };
    let physical = std::cmp::max(num_cpus::get_physical(), 1);
    std::cmp::min(physical, logical)
}

/// Reads the shard-count override from the environment.
///
/// Returns `None` when the variable is unset, unparsable, or zero.
fn configured_shards() -> Option<usize> {
    match parse_env_usize(SHARDED_TOKIO_SHARDS_ENV) {
        None | Some(0) => None,
        shards => shards,
    }
}

/// Reads the minimum-connections threshold from the environment, falling back
/// to `DEFAULT_SHARDED_MIN_CONNECTIONS` when the variable is unset,
/// unparsable, or zero.
fn configured_min_connections() -> usize {
    match parse_env_usize(SHARDED_TOKIO_MIN_CONNECTIONS_ENV) {
        Some(threshold) if threshold > 0 => threshold,
        _ => DEFAULT_SHARDED_MIN_CONNECTIONS,
    }
}

/// Reports whether sharding has been switched off through the environment.
///
/// The flag is considered set when the variable's value, ignoring surrounding
/// whitespace and ASCII case, is `1`, `true` or `yes`. This is a superset of
/// the previously accepted spellings, so `True`, `Yes` or a value with a
/// trailing newline now behave the same as `true`/`TRUE`. An unset variable,
/// a non-Unicode value, or any other text leaves sharding enabled.
fn sharding_disabled() -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];

    std::env::var(SHARDED_TOKIO_DISABLE_ENV)
        .ok()
        .is_some_and(|value| {
            let value = value.trim();
            TRUTHY
                .iter()
                .any(|truthy| value.eq_ignore_ascii_case(truthy))
        })
}

/// Reads environment variable `name` and parses it as a `usize`.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or does
/// not parse as an unsigned integer.
fn parse_env_usize(name: &str) -> Option<usize> {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<usize>().ok(),
        Err(_) => None,
    }
}