use std::sync::atomic::{AtomicU64, Ordering};
use crate::{
Stream,
error::{Error, InvariantViolationPayload, Result},
};
use super::recommended_working_set_bytes;
/// Strategy that decides how much memory to wire, given a `baseline` and a
/// slice of `active_sizes` (bytes, as used by the policies in this module).
///
/// The `Send + Sync` supertraits allow a policy to be shared across threads.
pub trait WiredMemoryPolicy: Send + Sync {
    /// Computes the wired-memory limit for `baseline` and `active_sizes`.
    fn limit(&self, baseline: u64, active_sizes: &[u64]) -> u64;

    /// Reports whether an additional request of `new_size` may be admitted.
    ///
    /// The default implementation ignores its arguments and always admits.
    fn can_admit(&self, baseline: u64, active_sizes: &[u64], new_size: u64) -> bool {
        let _ = (baseline, active_sizes, new_size);
        true
    }

    /// Identifier of this policy; the default implementation returns `""`.
    fn id(&self) -> &str {
        ""
    }
}
/// Policy whose limit is the baseline plus the sum of all active sizes.
///
/// The result is bounded by `cap` when one is set. Otherwise it is bounded by
/// the recommended working-set size, when that query succeeds and yields a value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WiredSumPolicy {
    /// Explicit upper bound in bytes; `None` defers to the recommended working set.
    cap: Option<u64>,
}
impl WiredSumPolicy {
    /// Creates a sum policy with an optional explicit cap.
    pub fn new(cap: Option<u64>) -> Self {
        Self { cap }
    }

    /// Returns the explicit cap, if any.
    #[inline(always)]
    pub const fn cap(&self) -> Option<u64> {
        self.cap
    }

    /// Returns this policy with its cap replaced by `cap`.
    #[must_use]
    pub fn with_cap(self, cap: Option<u64>) -> Self {
        // `cap` is the only field, so a fresh value is equivalent to mutating `self`.
        Self { cap }
    }

    /// Bounds `value` by the explicit cap, or failing that by the recommended
    /// working set. If neither is available, `value` is returned unchanged.
    fn clamp(&self, value: u64) -> u64 {
        // The working-set query only runs when no explicit cap is configured.
        let bound = match self.cap {
            Some(cap) => Some(cap),
            None => recommended_working_set_bytes().ok().flatten(),
        };
        bound.map_or(value, |limit| value.min(limit))
    }
}
impl WiredMemoryPolicy for WiredSumPolicy {
    /// Returns `baseline` plus the sum of `active_sizes`, clamped by the
    /// policy's cap or the recommended working set.
    ///
    /// Every addition saturates at `u64::MAX`. A plain `Iterator::sum` would
    /// panic when overflow checks are on, and would wrap to a small value in
    /// release builds.
    fn limit(&self, baseline: u64, active_sizes: &[u64]) -> u64 {
        let active_total = active_sizes
            .iter()
            .fold(0u64, |acc, &size| acc.saturating_add(size));
        self.clamp(baseline.saturating_add(active_total))
    }

    /// Admits `new_size` only if the projected total (baseline + active sizes
    /// + `new_size`, all saturating) is not reduced by clamping, i.e. it fits
    /// under the effective bound.
    fn can_admit(&self, baseline: u64, active_sizes: &[u64], new_size: u64) -> bool {
        let active_total = active_sizes
            .iter()
            .fold(0u64, |acc, &size| acc.saturating_add(size));
        let projected = baseline
            .saturating_add(active_total)
            .saturating_add(new_size);
        self.clamp(projected) == projected
    }
}
/// Stateless policy whose limit is the larger of the baseline and the single
/// biggest active size.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct WiredMaxPolicy;
impl WiredMaxPolicy {
    /// Creates the max policy; identical to `WiredMaxPolicy::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl WiredMemoryPolicy for WiredMaxPolicy {
    /// Returns the maximum of `baseline` and every entry of `active_sizes`.
    ///
    /// An empty slice yields `baseline`. Unlike the sum and budget policies, no
    /// cap or working-set clamp is applied here.
    fn limit(&self, baseline: u64, active_sizes: &[u64]) -> u64 {
        active_sizes.iter().copied().fold(baseline, u64::max)
    }
}
/// Policy that always reports the same configured limit.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WiredFixedPolicy {
    /// Limit returned unconditionally by this policy.
    limit_bytes: u64,
}
impl WiredFixedPolicy {
    /// Creates a policy that always reports `limit_bytes`.
    pub fn new(limit_bytes: u64) -> Self {
        Self { limit_bytes }
    }

    /// Returns the configured limit.
    #[inline(always)]
    pub const fn limit_bytes(&self) -> u64 {
        self.limit_bytes
    }

    /// Returns this policy with its limit replaced by `b`.
    #[must_use]
    pub fn with_limit_bytes(self, b: u64) -> Self {
        // `limit_bytes` is the only field, so rebuilding equals mutating `self`.
        Self::new(b)
    }
}
impl WiredMemoryPolicy for WiredFixedPolicy {
    /// Returns the configured limit; `baseline` and `active_sizes` are ignored.
    fn limit(&self, baseline: u64, active_sizes: &[u64]) -> u64 {
        let _ = (baseline, active_sizes);
        let &Self { limit_bytes } = self;
        limit_bytes
    }
}
/// Policy whose limit is the baseline plus a fixed `base_bytes` budget plus the
/// sum of all active sizes.
///
/// The result is bounded by `cap` when one is set. Otherwise it is bounded by
/// the recommended working-set size, when that query succeeds and yields a value.
///
/// Each instance carries a string `id`. `new` assigns `auto-N` ids from a
/// process-wide counter, while `with_id` lets the caller choose one.
#[derive(Debug, Clone)]
pub struct WiredBudgetPolicy {
    /// Identity of the policy.
    id: String,
    /// Fixed amount added on top of the baseline.
    base_bytes: u64,
    /// Explicit upper bound in bytes; `None` defers to the recommended working set.
    cap: Option<u64>,
}
/// Source of the `N` in auto-generated `auto-N` policy ids.
static AUTO_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
impl WiredBudgetPolicy {
    /// Creates a budget policy with an auto-generated `auto-N` id.
    pub fn new(base_bytes: u64, cap: Option<u64>) -> Self {
        // Relaxed suffices: only the atomicity of the increment matters (each
        // caller gets a distinct value), not ordering with other memory.
        let n = AUTO_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
        Self::with_id(format!("auto-{n}"), base_bytes, cap)
    }

    /// Creates a budget policy with a caller-chosen id.
    pub fn with_id(id: impl Into<String>, base_bytes: u64, cap: Option<u64>) -> Self {
        let id: String = id.into();
        Self { id, base_bytes, cap }
    }

    /// Returns the policy id.
    #[inline(always)]
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Alias of [`WiredBudgetPolicy::id`].
    #[inline(always)]
    pub fn id_str(&self) -> &str {
        self.id()
    }

    /// Returns the fixed budget added on top of the baseline.
    #[inline(always)]
    pub const fn base_bytes(&self) -> u64 {
        self.base_bytes
    }

    /// Returns this policy with its budget replaced by `b`; the id is kept.
    #[must_use]
    pub fn with_base_bytes(self, b: u64) -> Self {
        Self { base_bytes: b, ..self }
    }

    /// Returns the explicit cap, if any.
    #[inline(always)]
    pub const fn cap(&self) -> Option<u64> {
        self.cap
    }

    /// Returns this policy with its cap replaced by `cap`; the id is kept.
    #[must_use]
    pub fn with_cap(self, cap: Option<u64>) -> Self {
        Self { cap, ..self }
    }

    /// Bounds `value` by the explicit cap, or failing that by the recommended
    /// working set. If neither is available, `value` is returned unchanged.
    fn clamp(&self, value: u64) -> u64 {
        // The working-set query only runs when no explicit cap is configured.
        let bound = match self.cap {
            Some(cap) => Some(cap),
            None => recommended_working_set_bytes().ok().flatten(),
        };
        bound.map_or(value, |limit| value.min(limit))
    }
}
/// Two budget policies are equal exactly when their ids match; `base_bytes`
/// and `cap` do not take part in the comparison.
impl PartialEq for WiredBudgetPolicy {
    fn eq(&self, other: &Self) -> bool {
        self.id.as_str() == other.id.as_str()
    }
}
impl Eq for WiredBudgetPolicy {}
/// Hashes only the id, keeping `Hash` consistent with `PartialEq`.
impl std::hash::Hash for WiredBudgetPolicy {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.id, state);
    }
}
impl WiredMemoryPolicy for WiredBudgetPolicy {
    /// Returns `baseline + base_bytes + sum(active_sizes)`, clamped by the
    /// policy's cap or the recommended working set.
    ///
    /// Every addition saturates at `u64::MAX`. A plain `Iterator::sum` would
    /// panic when overflow checks are on, and would wrap to a small value in
    /// release builds.
    fn limit(&self, baseline: u64, active_sizes: &[u64]) -> u64 {
        let active_total = active_sizes
            .iter()
            .fold(0u64, |acc, &size| acc.saturating_add(size));
        self.clamp(
            baseline
                .saturating_add(self.base_bytes)
                .saturating_add(active_total),
        )
    }

    /// Admits `new_size` only if the projected total (baseline + budget +
    /// active sizes + `new_size`, all saturating) is not reduced by clamping,
    /// i.e. it fits under the effective bound.
    fn can_admit(&self, baseline: u64, active_sizes: &[u64], new_size: u64) -> bool {
        let active_total = active_sizes
            .iter()
            .fold(0u64, |acc, &size| acc.saturating_add(size));
        let projected = baseline
            .saturating_add(self.base_bytes)
            .saturating_add(active_total)
            .saturating_add(new_size);
        self.clamp(projected) == projected
    }

    /// Returns this policy's id (the same value as the inherent `id`).
    fn id(&self) -> &str {
        &self.id
    }
}
/// Snapshot of memory figures plus the token/prefill parameters they were
/// recorded with.
///
/// NOTE(review): field meanings are taken from their names. Nothing in this
/// module produces a measurement yet (`tune` is a stub), so confirm units and
/// semantics with the eventual producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WiredMemoryMeasurement {
    weight_bytes: u64,
    kv_bytes: u64,
    workspace_bytes: u64,
    peak_active_bytes: u64,
    token_count: usize,
    prefill_step_size: usize,
}
impl WiredMemoryMeasurement {
    /// Creates a measurement from its raw components.
    ///
    /// `const` so measurements can be built in constant contexts, matching the
    /// `const` getters used elsewhere in this module.
    pub const fn new(
        weight_bytes: u64,
        kv_bytes: u64,
        workspace_bytes: u64,
        peak_active_bytes: u64,
        token_count: usize,
        prefill_step_size: usize,
    ) -> Self {
        Self {
            weight_bytes,
            kv_bytes,
            workspace_bytes,
            peak_active_bytes,
            token_count,
            prefill_step_size,
        }
    }

    /// Returns the `weight_bytes` component.
    #[inline(always)]
    pub const fn weight_bytes(&self) -> u64 {
        self.weight_bytes
    }

    /// Returns the `kv_bytes` component.
    #[inline(always)]
    pub const fn kv_bytes(&self) -> u64 {
        self.kv_bytes
    }

    /// Returns the `workspace_bytes` component.
    #[inline(always)]
    pub const fn workspace_bytes(&self) -> u64 {
        self.workspace_bytes
    }

    /// Returns the `peak_active_bytes` component.
    #[inline(always)]
    pub const fn peak_active_bytes(&self) -> u64 {
        self.peak_active_bytes
    }

    /// Returns the token count the measurement was taken with.
    #[inline(always)]
    pub const fn token_count(&self) -> usize {
        self.token_count
    }

    /// Returns the prefill step size the measurement was taken with.
    #[inline(always)]
    pub const fn prefill_step_size(&self) -> usize {
        self.prefill_step_size
    }

    /// Sum of weight, KV and workspace bytes, saturating at `u64::MAX`.
    ///
    /// `peak_active_bytes` is not included in this total.
    pub const fn total_bytes(&self) -> u64 {
        self.weight_bytes
            .saturating_add(self.kv_bytes)
            .saturating_add(self.workspace_bytes)
    }
}
pub fn tune(
_model_bytes: u64,
_token_count: usize,
_prefill_step_size: usize,
_streams: &[Stream],
) -> Result<WiredMemoryMeasurement> {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"WiredMemoryUtils::tune",
"not yet implemented — requires Model::prefill_only (stub); see issue #168",
)))
}