// loadwise-core 0.1.0
//
// Core traits, strategies, and in-memory stores for loadwise.
// Documentation
//! Per-node multi-dimensional quota tracking.

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::RwLock;
use std::time::{Duration, Instant};

use crate::window::FixedWindowCounter;

/// Configuration for a single quota dimension (e.g., "rpm" or "tpm").
///
/// One dimension pairs a name with a time window and the maximum number of
/// units that may be consumed inside that window. Built via the generated
/// builder (`QuotaDimension::builder()`).
#[derive(Clone, bon::Builder)]
pub struct QuotaDimension {
    /// Dimension name (e.g., "rpm", "tpm").
    ///
    /// This is the key callers pass when recording or querying usage.
    pub name: String,
    /// Time window for this dimension.
    pub window: Duration,
    /// Maximum allowed units within the window.
    ///
    /// A limit of `0` is skipped by `QuotaTracker::pressure`.
    pub limit: u64,
    /// Number of time buckets within the window (higher = more precision, more memory).
    /// Passed straight to the window counter when the node is registered.
    /// **Default: 6.**
    #[builder(default = 6)]
    pub resolution: usize,
}

/// Per-node quota configuration.
///
/// Handed to `QuotaTracker::register`, which creates one window counter per
/// listed dimension.
#[derive(Clone, bon::Builder)]
pub struct QuotaConfig {
    /// Quota dimensions to track for this node.
    ///
    /// Lookups by name stop at the first match, so a later dimension that
    /// reuses an earlier name is never consulted by name-based queries.
    pub dimensions: Vec<QuotaDimension>,
}

/// Internal per-node state.
///
/// Created by `QuotaTracker::register` and stored in the tracker's map.
struct NodeQuota {
    /// One entry per configured dimension: `(dimension_name, limit, counter)`.
    /// The counter accumulates the usage recorded against that dimension.
    dimensions: Vec<(String, u64, FixedWindowCounter)>,
    /// If set, node is temporarily exhausted until this instant.
    /// A deadline in the past is treated as "not exhausted" by readers and is
    /// reset to `None` the next time usage is recorded.
    exhausted_until: Option<Instant>,
}

/// Tracks rate/quota usage for a set of nodes.
///
/// Each registered node owns one window counter per quota dimension plus an
/// optional "exhausted until" deadline. Nodes that were never registered are
/// treated as unconstrained: `has_capacity` returns `true`, `remaining`
/// returns `u64::MAX`, and `pressure` returns `0.0`.
///
/// Thread-safe — all methods take `&self` and synchronise internally via [`RwLock`].
///
/// # Examples
///
/// ```
/// # extern crate loadwise_core as loadwise;
/// use loadwise::quota::{QuotaTracker, QuotaConfig, QuotaDimension};
/// use std::time::Duration;
///
/// let tracker = QuotaTracker::<String>::new();
///
/// tracker.register(&"key-1".into(), QuotaConfig::builder()
///     .dimensions(vec![
///         QuotaDimension::builder()
///             .name("rpm".into())
///             .window(Duration::from_secs(60))
///             .limit(100)
///             .build(),
///     ])
///     .build());
///
/// tracker.record_usage(&"key-1".into(), &[("rpm", 1)]);
/// assert_eq!(tracker.remaining(&"key-1".into(), "rpm"), 99);
/// ```
pub struct QuotaTracker<Id: Eq + Hash + Clone> {
    /// Registered nodes keyed by ID, guarded by a single lock.
    nodes: RwLock<HashMap<Id, NodeQuota>>,
}

impl<Id: Eq + Hash + Clone> std::fmt::Debug for QuotaTracker<Id> {
    /// Formats the tracker as `QuotaTracker { tracked_nodes: N, .. }`.
    ///
    /// Only the number of registered nodes is shown; per-node state is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Debug` is routinely invoked from logging and panic paths. Panicking
        // here because some other thread poisoned the lock could turn an
        // ordinary panic into a process abort, so recover the guard instead:
        // reading the map's length is safe regardless of poisoning.
        let count = self
            .nodes
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len();
        f.debug_struct("QuotaTracker")
            .field("tracked_nodes", &count)
            .finish_non_exhaustive()
    }
}

impl<Id: Eq + Hash + Clone> QuotaTracker<Id> {
    /// Creates a new, empty tracker.
    pub fn new() -> Self {
        Self {
            nodes: RwLock::new(HashMap::new()),
        }
    }

    /// Register a node with its quota configuration. If already registered, replaces config.
    pub fn register(&self, id: &Id, config: QuotaConfig) {
        let dimensions = config
            .dimensions
            .into_iter()
            .map(|d| {
                let counter = FixedWindowCounter::new(d.window, d.resolution);
                (d.name, d.limit, counter)
            })
            .collect();

        let mut nodes = self.nodes.write().unwrap();
        nodes.insert(
            id.clone(),
            NodeQuota {
                dimensions,
                exhausted_until: None,
            },
        );
    }

    /// Record usage. `amounts` is a slice of `(dimension_name, amount)` pairs.
    /// Unknown dimensions are silently ignored.
    ///
    /// Returns `true` if the node was registered (usage recorded),
    /// `false` if the node ID was not found.
    pub fn record_usage(&self, id: &Id, amounts: &[(&str, u64)]) -> bool {
        let mut nodes = self.nodes.write().unwrap();
        let Some(node) = nodes.get_mut(id) else {
            return false;
        };

        // Clear expired exhaustion deadline while we hold the write lock.
        if let Some(t) = node.exhausted_until {
            if Instant::now() >= t {
                node.exhausted_until = None;
            }
        }

        for &(dim_name, amount) in amounts {
            for (name, _, counter) in &node.dimensions {
                if name == dim_name {
                    counter.record(amount);
                    break;
                }
            }
        }
        true
    }

    /// Mark a node as exhausted for `duration` (e.g., from a 429 Retry-After header).
    pub fn mark_exhausted(&self, id: &Id, duration: Duration) {
        let mut nodes = self.nodes.write().unwrap();
        if let Some(node) = nodes.get_mut(id) {
            node.exhausted_until = Some(Instant::now() + duration);
        }
    }

    /// Whether the node has capacity for the estimated usage across ALL dimensions.
    /// Returns `false` if the node is marked exhausted or any dimension would exceed its limit.
    /// Returns `true` for unregistered nodes (unknown = no quota constraint).
    pub fn has_capacity(&self, id: &Id, estimated: &[(&str, u64)]) -> bool {
        let nodes = self.nodes.read().unwrap();
        let Some(node) = nodes.get(id) else {
            return true;
        };

        // Check exhaustion (stale deadlines are treated as not-exhausted).
        if let Some(t) = node.exhausted_until {
            if Instant::now() < t {
                return false;
            }
        }

        for &(dim_name, est) in estimated {
            for (name, limit, counter) in &node.dimensions {
                if name == dim_name {
                    if counter.remaining(*limit) < est {
                        return false;
                    }
                    break;
                }
            }
        }

        true
    }

    /// Remaining capacity in a specific dimension. Returns `u64::MAX` for unknown nodes/dimensions.
    pub fn remaining(&self, id: &Id, dimension: &str) -> u64 {
        let nodes = self.nodes.read().unwrap();
        let Some(node) = nodes.get(id) else {
            return u64::MAX;
        };

        for (name, limit, counter) in &node.dimensions {
            if name == dimension {
                return counter.remaining(*limit);
            }
        }

        u64::MAX
    }

    /// Pressure score across all dimensions: `0.0` = fully idle, `1.0` = at least one dimension full.
    /// Computed as `max(usage_i / limit_i)` across all dimensions.
    /// Dimensions with `limit == 0` are skipped (treated as unconstrained).
    /// Returns `0.0` for unregistered nodes. Returns `1.0` if the node is marked exhausted.
    pub fn pressure(&self, id: &Id) -> f64 {
        let nodes = self.nodes.read().unwrap();
        let Some(node) = nodes.get(id) else {
            return 0.0;
        };

        // Check exhaustion (stale deadlines are treated as not-exhausted).
        if let Some(t) = node.exhausted_until {
            if Instant::now() < t {
                return 1.0;
            }
        }

        let mut max_pressure: f64 = 0.0;
        for (_, limit, counter) in &node.dimensions {
            if *limit == 0 {
                continue;
            }
            let usage = counter.sum() as f64 / *limit as f64;
            if usage > max_pressure {
                max_pressure = usage;
            }
        }

        max_pressure.min(1.0)
    }
}

impl<Id: Eq + Hash + Clone> Default for QuotaTracker<Id> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Name of the single dimension every test node is registered with.
    const DIM: &str = "rpm";

    /// Builds a tracker holding one node (`id`) with a 60-second "rpm"
    /// dimension capped at `limit`, and returns it together with the node ID.
    fn make_tracker_with_node(id: &str, limit: u64) -> (QuotaTracker<String>, String) {
        let rpm = QuotaDimension::builder()
            .name("rpm".into())
            .window(Duration::from_secs(60))
            .limit(limit)
            .build();
        let config = QuotaConfig::builder().dimensions(vec![rpm]).build();

        let node_id = id.to_string();
        let tracker = QuotaTracker::<String>::new();
        tracker.register(&node_id, config);
        (tracker, node_id)
    }

    /// Float comparison used by the pressure assertions.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn basic_register_and_record() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        // Usage accumulates across calls.
        tracker.record_usage(&id, &[(DIM, 10)]);
        assert_eq!(tracker.remaining(&id, DIM), 90);

        tracker.record_usage(&id, &[(DIM, 5)]);
        assert_eq!(tracker.remaining(&id, DIM), 85);
    }

    #[test]
    fn pressure_increases_with_usage() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        assert!(approx_eq(tracker.pressure(&id), 0.0));

        tracker.record_usage(&id, &[(DIM, 50)]);
        assert!(approx_eq(tracker.pressure(&id), 0.5));

        tracker.record_usage(&id, &[(DIM, 50)]);
        assert!(approx_eq(tracker.pressure(&id), 1.0));
    }

    #[test]
    fn has_capacity_with_estimation() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.record_usage(&id, &[(DIM, 90)]);

        // Exactly the remaining amount fits; one more unit does not.
        let fits = tracker.has_capacity(&id, &[(DIM, 10)]);
        assert!(fits);
        let overflows = !tracker.has_capacity(&id, &[(DIM, 11)]);
        assert!(overflows);
    }

    #[test]
    fn mark_exhausted_blocks_capacity() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.mark_exhausted(&id, Duration::from_secs(60));

        assert!(!tracker.has_capacity(&id, &[(DIM, 1)]));
        assert!(approx_eq(tracker.pressure(&id), 1.0));
    }

    #[test]
    fn exhausted_expires() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        let request = [(DIM, 1)];

        tracker.mark_exhausted(&id, Duration::from_millis(20));
        assert!(!tracker.has_capacity(&id, &request));

        // Wait past the deadline; the stale mark must no longer block.
        std::thread::sleep(Duration::from_millis(30));
        assert!(tracker.has_capacity(&id, &request));
    }

    #[test]
    fn unknown_node_has_unlimited_capacity() {
        let tracker = QuotaTracker::<String>::new();
        let unknown = String::from("unknown");

        assert!(tracker.has_capacity(&unknown, &[(DIM, 999)]));
        assert_eq!(tracker.remaining(&unknown, DIM), u64::MAX);
        assert!(approx_eq(tracker.pressure(&unknown), 0.0));
    }
}