use std::fmt;
/// Total RAM (in MB) assumed when system RAM detection fails: 8 GB.
const FALLBACK_RAM_MB: u64 = 8 * 1024;
/// Memory (in MB) budgeted per embedding batch slot; `compute_max_batch_size`
/// divides the embedding budget by this value.
const EMBED_MB_PER_BATCH_SLOT: u64 = 32;
/// Numerator of the fraction of the memory limit used as the embedding
/// batch budget (75 / 100).
const EMBED_ARENA_BUDGET_NUM: u64 = 75;
/// Denominator paired with `EMBED_ARENA_BUDGET_NUM`.
const EMBED_ARENA_BUDGET_DEN: u64 = 100;
/// Numerator of the fraction of total RAM used as the memory limit (25 / 100).
const MEMORY_LIMIT_FRACTION_NUM: u64 = 25;
/// Denominator paired with `MEMORY_LIMIT_FRACTION_NUM`.
const MEMORY_LIMIT_FRACTION_DEN: u64 = 100;
/// Lower clamp (in MB) applied to the proportional memory limit.
const MEMORY_LIMIT_FLOOR_MB: u64 = 1_024;
/// Upper clamp (in MB) applied to the proportional memory limit.
const MEMORY_LIMIT_CEIL_MB: u64 = 65_536;
/// Chunks allowed per MB of memory limit, before clamping.
const CHUNKS_PER_MB: u64 = 50;
/// Lower clamp on the derived maximum chunk count.
const MAX_CHUNKS_FLOOR: usize = 50_000;
/// Upper clamp on the derived maximum chunk count.
const MAX_CHUNKS_CEIL: usize = 800_000;
/// Derives the memory limit (in MB) from the host's total RAM.
///
/// The limit is `MEMORY_LIMIT_FRACTION_NUM / MEMORY_LIMIT_FRACTION_DEN` of
/// `total_ram_mb`, clamped to
/// `[MEMORY_LIMIT_FLOOR_MB, MEMORY_LIMIT_CEIL_MB]`.
///
/// The multiplication saturates rather than overflowing, so absurdly large
/// inputs resolve to the ceiling instead of panicking (debug builds) or
/// wrapping around to a tiny value (release builds).
fn compute_memory_limit_mb(total_ram_mb: u64) -> usize {
    let raw = total_ram_mb.saturating_mul(MEMORY_LIMIT_FRACTION_NUM) / MEMORY_LIMIT_FRACTION_DEN;
    // After clamping the value is at most MEMORY_LIMIT_CEIL_MB (65_536), so
    // narrowing to usize cannot truncate on 32- or 64-bit targets.
    raw.clamp(MEMORY_LIMIT_FLOOR_MB, MEMORY_LIMIT_CEIL_MB) as usize
}
/// Derives the maximum chunk count from a memory limit (in MB).
///
/// The result is `memory_limit_mb * CHUNKS_PER_MB`, clamped to
/// `[MAX_CHUNKS_FLOOR, MAX_CHUNKS_CEIL]`.
///
/// `memory_limit_mb` can come from an environment override, so the
/// multiplication saturates instead of overflowing, and the clamp is applied
/// in `u64` before narrowing so the cast can never truncate.
fn compute_max_chunks(memory_limit_mb: usize) -> usize {
    let raw = (memory_limit_mb as u64).saturating_mul(CHUNKS_PER_MB);
    // Clamped value is at most MAX_CHUNKS_CEIL, which is itself a usize.
    raw.clamp(MAX_CHUNKS_FLOOR as u64, MAX_CHUNKS_CEIL as u64) as usize
}
/// Smallest batch size `compute_max_batch_size` will return.
const MIN_COMPUTED_BATCH_SIZE: usize = 32;
/// Largest batch size `compute_max_batch_size` will return.
const MAX_COMPUTED_BATCH_SIZE: usize = 512;

/// Derives the embedding batch size from a memory limit (in MB).
///
/// `EMBED_ARENA_BUDGET_NUM / EMBED_ARENA_BUDGET_DEN` of the limit is treated
/// as the batch budget, which is divided by `EMBED_MB_PER_BATCH_SLOT` and
/// clamped to `[MIN_COMPUTED_BATCH_SIZE, MAX_COMPUTED_BATCH_SIZE]`.
///
/// `memory_limit_mb` can come from an environment override, so the
/// multiplication saturates instead of overflowing, and the clamp is applied
/// in `u64` before narrowing so the cast can never truncate.
fn compute_max_batch_size(memory_limit_mb: usize) -> usize {
    let budget_mb =
        (memory_limit_mb as u64).saturating_mul(EMBED_ARENA_BUDGET_NUM) / EMBED_ARENA_BUDGET_DEN;
    let slots = budget_mb / EMBED_MB_PER_BATCH_SLOT;
    // Clamped value is at most MAX_COMPUTED_BATCH_SIZE, which is itself a usize.
    slots.clamp(
        MIN_COMPUTED_BATCH_SIZE as u64,
        MAX_COMPUTED_BATCH_SIZE as u64,
    ) as usize
}
/// Coarse host-size class derived from total RAM. It selects the per-tier
/// defaults and the hard cap on the embedding batch size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryTier {
/// Hosts with fewer than 32 whole GB of RAM.
Medium,
/// Hosts with 32 to 63 whole GB of RAM.
Large,
/// Hosts with 64 or more whole GB of RAM.
XLarge,
}
impl fmt::Display for MemoryTier {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
MemoryTier::Medium => "Medium",
MemoryTier::Large => "Large",
MemoryTier::XLarge => "XLarge",
})
}
}
impl MemoryTier {
    /// Upper bound on the embedding batch size permitted for this tier.
    pub fn batch_size_hard_cap(self) -> usize {
        match self {
            Self::Medium => 128,
            Self::Large => 256,
            Self::XLarge => 512,
        }
    }

    /// Classifies a host by its total RAM in MB.
    ///
    /// The amount is first reduced to whole GB (`total_ram_mb / 1024`,
    /// truncated): below 32 GB is `Medium`, 32 to 63 GB is `Large`, and
    /// 64 GB or more is `XLarge`.
    pub fn from_total_ram_mb(total_ram_mb: u64) -> Self {
        let whole_gb = total_ram_mb / 1024;
        if whole_gb < 32 {
            Self::Medium
        } else if whole_gb < 64 {
            Self::Large
        } else {
            Self::XLarge
        }
    }

    /// Builds the baseline limits for this tier at the given memory limit.
    ///
    /// The chunk and batch-size limits are derived from `memory_limit_mb`;
    /// the remaining values are fixed per tier.
    fn defaults(self, memory_limit_mb: usize) -> TierDefaults {
        let max_chunks = compute_max_chunks(memory_limit_mb);
        let max_batch_size = compute_max_batch_size(memory_limit_mb);
        match self {
            Self::Medium => TierDefaults {
                memory_limit_mb,
                max_chunks,
                embedding_cache: 5_000,
                max_batch_size,
                bm25_corpus_cap: 100_000,
                max_kg_nodes: 150_000,
            },
            Self::Large => TierDefaults {
                memory_limit_mb,
                max_chunks,
                embedding_cache: 10_000,
                max_batch_size,
                bm25_corpus_cap: 200_000,
                max_kg_nodes: 300_000,
            },
            Self::XLarge => TierDefaults {
                memory_limit_mb,
                max_chunks,
                embedding_cache: 20_000,
                max_batch_size,
                bm25_corpus_cap: 400_000,
                max_kg_nodes: 500_000,
            },
        }
    }
}
/// Baseline limits for a tier at a given memory limit. `MemoryPolicy` uses
/// these as the defaults that environment variables may override.
#[derive(Debug, Clone, Copy)]
struct TierDefaults {
/// Memory limit in MB that the derived values below were computed from.
memory_limit_mb: usize,
/// Maximum chunk count, derived from `memory_limit_mb`.
max_chunks: usize,
/// Embedding cache size; a fixed value per tier.
embedding_cache: usize,
/// Embedding batch size, derived from `memory_limit_mb`.
max_batch_size: usize,
/// BM25 corpus cap; a fixed value per tier.
bm25_corpus_cap: usize,
/// Maximum number of KG nodes; a fixed value per tier.
max_kg_nodes: usize,
}
/// Resolved memory policy for a host: the RAM figure it was built from, the
/// selected tier, and the effective limits after environment overrides.
#[derive(Debug, Clone, Copy)]
pub struct MemoryPolicy {
/// Total RAM in MB the policy was built from (detected or supplied).
pub total_ram_mb: u64,
/// Tier selected from `total_ram_mb`.
pub tier: MemoryTier,
/// Effective memory limit in MB (`TRUSTY_MEMORY_LIMIT_MB` overrides it).
pub memory_limit_mb: usize,
/// Effective maximum chunk count (`TRUSTY_MAX_CHUNKS` overrides it).
pub max_chunks: usize,
/// Effective embedding cache size (`TRUSTY_EMBEDDING_CACHE` overrides it).
pub embedding_cache: usize,
/// Effective embedding batch size (`TRUSTY_MAX_BATCH_SIZE` overrides it,
/// subject to the tier hard cap).
pub max_batch_size: usize,
/// Effective BM25 corpus cap (`TRUSTY_BM25_CORPUS_CAP` overrides it).
pub bm25_corpus_cap: usize,
/// Effective maximum KG node count (`TRUSTY_MAX_KG_NODES` overrides it).
pub max_kg_nodes: usize,
}
impl MemoryPolicy {
    /// Builds the policy for the current host.
    ///
    /// Total RAM comes from `detect_total_ram_mb`. When detection fails, a
    /// warning is logged and `FALLBACK_RAM_MB` is assumed instead.
    pub fn detect() -> Self {
        let total_ram_mb = detect_total_ram_mb().unwrap_or_else(|| {
            tracing::warn!(
                "memory_policy: could not detect total system RAM — \
                 falling back to {FALLBACK_RAM_MB} MB (Medium tier defaults)"
            );
            FALLBACK_RAM_MB
        });
        Self::from_total_ram_mb(total_ram_mb)
    }

    /// Builds the policy for a host with `total_ram_mb` MB of RAM.
    ///
    /// Resolution order:
    /// 1. The tier and the proportional memory limit are derived from RAM.
    /// 2. `TRUSTY_MEMORY_LIMIT_MB` may override the memory limit; the chunk
    ///    and batch-size defaults are derived from the effective limit.
    /// 3. `TRUSTY_MAX_CHUNKS`, `TRUSTY_EMBEDDING_CACHE`,
    ///    `TRUSTY_MAX_BATCH_SIZE`, `TRUSTY_BM25_CORPUS_CAP` and
    ///    `TRUSTY_MAX_KG_NODES` may override the individual values.
    /// 4. The batch size is clamped to the tier hard cap unless
    ///    `TRUSTY_MAX_BATCH_SIZE_EXPLICIT=1` is set together with a
    ///    `TRUSTY_MAX_BATCH_SIZE` value that parses as a `usize`.
    ///
    /// Side effect: the resolved values are written back to the process
    /// environment via [`MemoryPolicy::apply_to_env`], so a later call in the
    /// same process sees them as overrides.
    pub fn from_total_ram_mb(total_ram_mb: u64) -> Self {
        let tier = MemoryTier::from_total_ram_mb(total_ram_mb);
        let proportional_limit_mb = compute_memory_limit_mb(total_ram_mb);
        let d = tier.defaults(proportional_limit_mb);

        let memory_limit_mb = env_override_usize("TRUSTY_MEMORY_LIMIT_MB", d.memory_limit_mb);
        // Both derivations are pure functions of the effective limit, so
        // recomputing covers the overridden and the default case alike.
        let derived_batch_size = compute_max_batch_size(memory_limit_mb);
        let derived_max_chunks = compute_max_chunks(memory_limit_mb);

        let explicit = std::env::var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT")
            .map(|v| v == "1")
            .unwrap_or(false);
        // Only a value that parses counts as a user-supplied batch size. An
        // unparseable value falls back to the derived default, which must not
        // be allowed to bypass the tier hard cap.
        let user_batch_size_valid = std::env::var("TRUSTY_MAX_BATCH_SIZE")
            .map(|v| v.parse::<usize>().is_ok())
            .unwrap_or(false);
        let raw_batch_size = env_override_usize("TRUSTY_MAX_BATCH_SIZE", derived_batch_size);
        let batch_cap = tier.batch_size_hard_cap();

        let max_batch_size = if explicit && user_batch_size_valid {
            tracing::warn!(
                "memory_policy: TRUSTY_MAX_BATCH_SIZE_EXPLICIT=1 — honoring \
                 TRUSTY_MAX_BATCH_SIZE={} verbatim and bypassing tier {} hard cap of {}. \
                 Ensure you have measured the actual ORT transient-allocation cost per slot \
                 on your workload (defaults assume 32 MB/slot with arena disabled).",
                raw_batch_size,
                tier,
                batch_cap,
            );
            raw_batch_size
        } else if raw_batch_size > batch_cap {
            if user_batch_size_valid {
                tracing::warn!(
                    "memory_policy: TRUSTY_MAX_BATCH_SIZE={} exceeds tier {} hard cap of {}; \
                     clamping to protect against ORT transient-arena spike (issue #89). \
                     Set TRUSTY_MAX_BATCH_SIZE_EXPLICIT=1 to bypass this clamp.",
                    raw_batch_size,
                    tier,
                    batch_cap,
                );
            } else {
                // The value was derived from the memory limit, so do not
                // attribute it to an environment variable the user never set.
                tracing::warn!(
                    "memory_policy: derived batch size {} exceeds tier {} hard cap of {}; \
                     clamping to protect against ORT transient-arena spike (issue #89).",
                    raw_batch_size,
                    tier,
                    batch_cap,
                );
            }
            batch_cap
        } else {
            raw_batch_size
        };

        let policy = MemoryPolicy {
            total_ram_mb,
            tier,
            memory_limit_mb,
            max_chunks: env_override_usize("TRUSTY_MAX_CHUNKS", derived_max_chunks),
            embedding_cache: env_override_usize("TRUSTY_EMBEDDING_CACHE", d.embedding_cache),
            max_batch_size,
            bm25_corpus_cap: env_override_usize("TRUSTY_BM25_CORPUS_CAP", d.bm25_corpus_cap),
            max_kg_nodes: env_override_usize("TRUSTY_MAX_KG_NODES", d.max_kg_nodes),
        };
        policy.apply_to_env();
        policy
    }

    /// Writes every resolved limit into the process environment under its
    /// `TRUSTY_*` variable name.
    pub fn apply_to_env(&self) {
        let exports = [
            ("TRUSTY_MEMORY_LIMIT_MB", self.memory_limit_mb),
            ("TRUSTY_MAX_CHUNKS", self.max_chunks),
            ("TRUSTY_EMBEDDING_CACHE", self.embedding_cache),
            ("TRUSTY_MAX_BATCH_SIZE", self.max_batch_size),
            ("TRUSTY_BM25_CORPUS_CAP", self.bm25_corpus_cap),
            ("TRUSTY_MAX_KG_NODES", self.max_kg_nodes),
        ];
        for (name, value) in exports {
            // SAFETY: `set_var` is only sound when no other thread reads or
            // writes the environment concurrently.
            // NOTE(review): this presumably runs during single-threaded
            // startup — confirm against the callers.
            unsafe {
                std::env::set_var(name, value.to_string());
            }
        }
    }

    /// Logs the detected RAM, the tier, the proportional memory limit and
    /// every effective limit at `info` level.
    pub fn log_summary(&self) {
        let gb = self.total_ram_mb / 1024;
        // Recomputed so the log shows the RAM-proportional value even when
        // `memory_limit_mb` was overridden.
        let proportional = compute_memory_limit_mb(self.total_ram_mb);
        tracing::info!(
            "trusty-search: detected {} GB RAM → tier={} (proportional memory_limit_mb={}, 25% of RAM clamped to [{}, {}])",
            gb,
            self.tier,
            proportional,
            MEMORY_LIMIT_FLOOR_MB,
            MEMORY_LIMIT_CEIL_MB,
        );
        tracing::info!(
            " MEMORY_LIMIT_MB={} MAX_CHUNKS={} EMBEDDING_CACHE={} \
             MAX_BATCH_SIZE={} BM25_CORPUS_CAP={} MAX_KG_NODES={}",
            self.memory_limit_mb,
            self.max_chunks,
            self.embedding_cache,
            self.max_batch_size,
            self.bm25_corpus_cap,
            self.max_kg_nodes,
        );
    }
}
/// Reads `name` from the environment as a `usize`.
///
/// Returns `default` when the variable cannot be read. When it is set but
/// does not parse as a `usize`, a warning is logged and `default` is returned.
fn env_override_usize(name: &str, default: usize) -> usize {
    let Ok(v) = std::env::var(name) else {
        return default;
    };
    v.parse::<usize>().unwrap_or_else(|_| {
        tracing::warn!(
            "memory_policy: {name}={v:?} is not a valid usize; \
             using tier default ({default})"
        );
        default
    })
}
/// Returns the host's total RAM in MB, or `None` when it cannot be determined.
///
/// The implementation is selected at compile time: macOS uses
/// `detect_macos_ram_mb`, Linux uses `detect_linux_ram_mb`, and every other
/// target always returns `None`.
pub fn detect_total_ram_mb() -> Option<u64> {
#[cfg(target_os = "macos")]
{
detect_macos_ram_mb()
}
#[cfg(target_os = "linux")]
{
detect_linux_ram_mb()
}
// No detection is implemented for other operating systems.
#[cfg(not(any(target_os = "macos", target_os = "linux")))]
{
None
}
}
/// Queries total RAM on macOS by running `sysctl -n hw.memsize`.
///
/// Returns the reported byte count converted to MB, or `None` when the
/// command cannot be run, exits unsuccessfully, or prints something that is
/// not a UTF-8 unsigned integer.
#[cfg(target_os = "macos")]
fn detect_macos_ram_mb() -> Option<u64> {
    const BYTES_PER_MB: u64 = 1024 * 1024;

    let sysctl = std::process::Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .ok()
        .filter(|out| out.status.success())?;
    let stdout = String::from_utf8(sysctl.stdout).ok()?;
    stdout
        .trim()
        .parse::<u64>()
        .ok()
        .map(|bytes| bytes / BYTES_PER_MB)
}
/// Reads total RAM on Linux from the `MemTotal:` line of `/proc/meminfo`.
///
/// The first whitespace-separated field after the prefix is parsed as a kB
/// count and converted to MB. Returns `None` when the file cannot be read,
/// has no `MemTotal:` line, or the first such line has no parseable value.
#[cfg(target_os = "linux")]
fn detect_linux_ram_mb() -> Option<u64> {
    let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;
    let total_kb = meminfo
        .lines()
        .find_map(|line| line.strip_prefix("MemTotal:"))?
        .split_whitespace()
        .next()?
        .parse::<u64>()
        .ok()?;
    Some(total_kb / 1024)
}
/// Unit tests for tier selection, the derived limits, and environment overrides.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
// Held by every test that reads or mutates process environment variables,
// so those tests do not interleave when the harness runs them in parallel.
static ENV_LOCK: Mutex<()> = Mutex::new(());
/// Tier boundaries sit at 32 GB and 64 GB; everything below 32 GB is Medium.
#[test]
fn test_tier_selection() {
assert_eq!(MemoryTier::from_total_ram_mb(16 * 1024), MemoryTier::Medium);
assert_eq!(MemoryTier::from_total_ram_mb(31 * 1024), MemoryTier::Medium);
assert_eq!(MemoryTier::from_total_ram_mb(32 * 1024), MemoryTier::Large);
assert_eq!(MemoryTier::from_total_ram_mb(63 * 1024), MemoryTier::Large);
assert_eq!(MemoryTier::from_total_ram_mb(64 * 1024), MemoryTier::XLarge);
assert_eq!(
MemoryTier::from_total_ram_mb(192 * 1024),
MemoryTier::XLarge
);
assert_eq!(MemoryTier::from_total_ram_mb(15 * 1024), MemoryTier::Medium);
assert_eq!(MemoryTier::from_total_ram_mb(8 * 1024), MemoryTier::Medium);
assert_eq!(MemoryTier::from_total_ram_mb(4 * 1024), MemoryTier::Medium);
}
/// Per-tier defaults at the proportional limit for 16, 32, 64, 128 and 256 GB hosts.
#[test]
fn test_tier_defaults_table() {
let medium = MemoryTier::Medium.defaults(compute_memory_limit_mb(16 * 1024));
assert_eq!(medium.memory_limit_mb, 4_096);
assert_eq!(medium.max_chunks, 204_800);
assert_eq!(medium.max_batch_size, 96);
assert_eq!(medium.embedding_cache, 5_000);
let large = MemoryTier::Large.defaults(compute_memory_limit_mb(32 * 1024));
assert_eq!(large.memory_limit_mb, 8_192);
assert_eq!(large.max_chunks, 409_600);
assert_eq!(large.max_batch_size, 192);
let xl = MemoryTier::XLarge.defaults(compute_memory_limit_mb(64 * 1024));
assert_eq!(xl.memory_limit_mb, 16_384);
assert_eq!(xl.max_chunks, 800_000);
assert_eq!(xl.embedding_cache, 20_000);
assert_eq!(xl.max_kg_nodes, 500_000);
assert_eq!(xl.max_batch_size, 384);
let huge = MemoryTier::XLarge.defaults(compute_memory_limit_mb(128 * 1024));
assert_eq!(huge.memory_limit_mb, 32 * 1024);
let max_host = MemoryTier::XLarge.defaults(compute_memory_limit_mb(256 * 1024));
assert_eq!(max_host.memory_limit_mb, MEMORY_LIMIT_CEIL_MB as usize);
}
/// The memory limit is 25% of RAM, clamped to the floor (1 GB) and ceiling (64 GB).
#[test]
fn test_compute_memory_limit_from_ram() {
assert_eq!(compute_memory_limit_mb(16 * 1024), 4 * 1024); assert_eq!(compute_memory_limit_mb(32 * 1024), 8 * 1024); assert_eq!(compute_memory_limit_mb(64 * 1024), 16 * 1024); assert_eq!(compute_memory_limit_mb(128 * 1024), 32 * 1024); assert_eq!(compute_memory_limit_mb(192 * 1024), 48 * 1024);
assert_eq!(compute_memory_limit_mb(256 * 1024), 64 * 1024);
assert_eq!(compute_memory_limit_mb(1024 * 1024), 64 * 1024);
assert_eq!(compute_memory_limit_mb(0), 1_024);
assert_eq!(compute_memory_limit_mb(2 * 1024), 1_024); assert_eq!(compute_memory_limit_mb(4 * 1024), 1_024); }
/// The chunk limit is 50 per MB, clamped to `[MAX_CHUNKS_FLOOR, MAX_CHUNKS_CEIL]`.
#[test]
fn test_compute_max_chunks_from_limit() {
assert_eq!(compute_max_chunks(4_096), 204_800);
assert_eq!(compute_max_chunks(8_192), 409_600);
assert_eq!(compute_max_chunks(16_384), 800_000); assert_eq!(compute_max_chunks(32_768), 800_000); assert_eq!(compute_max_chunks(65_536), 800_000); assert_eq!(compute_max_chunks(0), MAX_CHUNKS_FLOOR);
assert_eq!(compute_max_chunks(500), MAX_CHUNKS_FLOOR);
}
/// Two hosts in the same tier still get memory limits proportional to their RAM.
#[test]
fn test_memory_limit_scales_proportionally_across_xlarge_hosts() {
let _g = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let prior = std::env::var("TRUSTY_MEMORY_LIMIT_MB").ok();
unsafe {
std::env::remove_var("TRUSTY_MEMORY_LIMIT_MB");
}
let p64 = MemoryPolicy::from_total_ram_mb(64 * 1024);
// `from_total_ram_mb` writes the resolved limit back to the environment,
// so it must be cleared again before building the second policy.
unsafe {
std::env::remove_var("TRUSTY_MEMORY_LIMIT_MB");
}
let p128 = MemoryPolicy::from_total_ram_mb(128 * 1024);
// Restore whatever value the variable had before the test.
unsafe {
match prior {
Some(v) => std::env::set_var("TRUSTY_MEMORY_LIMIT_MB", v),
None => std::env::remove_var("TRUSTY_MEMORY_LIMIT_MB"),
}
}
assert_eq!(p64.tier, MemoryTier::XLarge);
assert_eq!(p128.tier, MemoryTier::XLarge);
assert!(
p128.memory_limit_mb > p64.memory_limit_mb,
"128 GB host ({} MB) should have a larger memory_limit_mb than \
a 64 GB host ({} MB) — see issue #120",
p128.memory_limit_mb,
p64.memory_limit_mb,
);
assert_eq!(p64.memory_limit_mb, 16 * 1024);
assert_eq!(p128.memory_limit_mb, 32 * 1024);
}
/// The batch size is 75% of the limit divided by 32 MB per slot, clamped to `[32, 512]`.
#[test]
fn test_compute_max_batch_size_from_limit() {
assert_eq!(compute_max_batch_size(4_096), 96);
assert_eq!(compute_max_batch_size(8_192), 192);
assert_eq!(compute_max_batch_size(16_384), 384);
assert_eq!(compute_max_batch_size(0), MIN_COMPUTED_BATCH_SIZE);
assert_eq!(compute_max_batch_size(1_024), MIN_COMPUTED_BATCH_SIZE);
assert_eq!(compute_max_batch_size(64_000), MAX_COMPUTED_BATCH_SIZE);
assert_eq!(compute_max_batch_size(1_000_000), MAX_COMPUTED_BATCH_SIZE);
assert_eq!(compute_max_batch_size(22_000), MAX_COMPUTED_BATCH_SIZE);
}
/// `TRUSTY_MAX_CHUNKS` overrides the derived chunk limit verbatim.
#[test]
fn test_env_override() {
let _g = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let prior = std::env::var("TRUSTY_MAX_CHUNKS").ok();
unsafe {
std::env::set_var("TRUSTY_MAX_CHUNKS", "42");
}
let policy = MemoryPolicy::from_total_ram_mb(16 * 1024);
assert_eq!(policy.tier, MemoryTier::Medium);
assert_eq!(policy.max_chunks, 42);
// Restore whatever value the variable had before the test.
unsafe {
match prior {
Some(v) => std::env::set_var("TRUSTY_MAX_CHUNKS", v),
None => std::env::remove_var("TRUSTY_MAX_CHUNKS"),
}
}
}
/// Hard caps are 128 / 256 / 512 for Medium / Large / XLarge.
#[test]
fn test_tier_batch_size_hard_cap() {
assert_eq!(MemoryTier::Medium.batch_size_hard_cap(), 128);
assert_eq!(MemoryTier::Large.batch_size_hard_cap(), 256);
assert_eq!(MemoryTier::XLarge.batch_size_hard_cap(), 512);
}
/// Without the explicit flag, an oversized `TRUSTY_MAX_BATCH_SIZE` is clamped to the tier cap.
#[test]
fn test_batch_size_env_override_clamped_by_hard_cap() {
let _g = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let prior = std::env::var("TRUSTY_MAX_BATCH_SIZE").ok();
let prior_explicit = std::env::var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT").ok();
unsafe {
std::env::set_var("TRUSTY_MAX_BATCH_SIZE", "2048");
std::env::remove_var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT");
}
let policy = MemoryPolicy::from_total_ram_mb(16 * 1024);
assert_eq!(policy.tier, MemoryTier::Medium);
assert_eq!(
policy.max_batch_size, 128,
"Medium tier must clamp TRUSTY_MAX_BATCH_SIZE=2048 down to 128"
);
// Restore whatever values the variables had before the test.
unsafe {
match prior {
Some(v) => std::env::set_var("TRUSTY_MAX_BATCH_SIZE", v),
None => std::env::remove_var("TRUSTY_MAX_BATCH_SIZE"),
}
match prior_explicit {
Some(v) => std::env::set_var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT", v),
None => std::env::remove_var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT"),
}
}
}
/// With `TRUSTY_MAX_BATCH_SIZE_EXPLICIT=1`, the requested batch size bypasses the tier cap.
#[test]
fn test_batch_size_explicit_flag_bypasses_clamp() {
let _g = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let prior = std::env::var("TRUSTY_MAX_BATCH_SIZE").ok();
let prior_explicit = std::env::var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT").ok();
let prior_mem = std::env::var("TRUSTY_MEMORY_LIMIT_MB").ok();
unsafe {
std::env::set_var("TRUSTY_MAX_BATCH_SIZE", "512");
std::env::set_var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT", "1");
std::env::remove_var("TRUSTY_MEMORY_LIMIT_MB");
}
let policy = MemoryPolicy::from_total_ram_mb(16 * 1024);
assert_eq!(policy.tier, MemoryTier::Medium);
assert_eq!(
policy.max_batch_size, 512,
"TRUSTY_MAX_BATCH_SIZE_EXPLICIT=1 must bypass the tier hard cap"
);
// Restore whatever values the variables had before the test.
unsafe {
match prior {
Some(v) => std::env::set_var("TRUSTY_MAX_BATCH_SIZE", v),
None => std::env::remove_var("TRUSTY_MAX_BATCH_SIZE"),
}
match prior_explicit {
Some(v) => std::env::set_var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT", v),
None => std::env::remove_var("TRUSTY_MAX_BATCH_SIZE_EXPLICIT"),
}
match prior_mem {
Some(v) => std::env::set_var("TRUSTY_MEMORY_LIMIT_MB", v),
None => std::env::remove_var("TRUSTY_MEMORY_LIMIT_MB"),
}
}
}
/// When detection succeeds, the result is positive and below 4 TB (expressed in MB).
#[test]
fn test_ram_detection_returns_nonzero() {
// Detection may legitimately return `None`; only a `Some` value is checked.
if let Some(mb) = detect_total_ram_mb() {
assert!(mb > 0, "detected RAM should be > 0, got {mb}");
assert!(
mb < 4 * 1024 * 1024,
"detected RAM implausibly large: {mb} MB"
);
}
}
}