use std::fmt;
/// RAM (MB) assumed when detection fails; 8 GiB selects the Medium tier.
const FALLBACK_RAM_MB: u64 = 8 * 1024;
/// Memory budget (MB) charged per embedding batch slot when deriving a batch size.
const EMBED_MB_PER_BATCH_SLOT: u64 = 55;
/// Numerator of the fraction (75/100) of the memory limit budgeted for batch slots.
const EMBED_ARENA_BUDGET_NUM: u64 = 75;
/// Denominator of the budget fraction.
const EMBED_ARENA_BUDGET_DEN: u64 = 100;
/// Smallest batch size `compute_max_batch_size` will return.
const MIN_COMPUTED_BATCH_SIZE: usize = 32;
/// Largest batch size `compute_max_batch_size` will return.
const MAX_COMPUTED_BATCH_SIZE: usize = 512;

/// Derives an embedding batch size from a memory limit in MB.
///
/// Budgets 75% of `memory_limit_mb`, divides it by `EMBED_MB_PER_BATCH_SLOT`, and clamps the
/// result to `MIN_COMPUTED_BATCH_SIZE..=MAX_COMPUTED_BATCH_SIZE`. Never panics, even for
/// absurdly large limits.
fn compute_max_batch_size(memory_limit_mb: usize) -> usize {
    // The limit can come straight from TRUSTY_MEMORY_LIMIT_MB. A huge value (e.g. usize::MAX)
    // would overflow a plain multiply: a panic in debug builds, a wrapped budget in release.
    let limit_mb = u64::try_from(memory_limit_mb).unwrap_or(u64::MAX);
    let budget_mb = limit_mb.saturating_mul(EMBED_ARENA_BUDGET_NUM) / EMBED_ARENA_BUDGET_DEN;
    let slots = budget_mb / EMBED_MB_PER_BATCH_SLOT;
    // Clamp while still in u64 so the narrowing cast is lossless on every target.
    slots.clamp(MIN_COMPUTED_BATCH_SIZE as u64, MAX_COMPUTED_BATCH_SIZE as u64) as usize
}
/// Coarse machine-size class picked from total RAM; each tier carries its own defaults.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryTier {
    /// Fewer than 32 whole GiB of RAM.
    Medium,
    /// 32 to 63 whole GiB of RAM.
    Large,
    /// 64 whole GiB of RAM or more.
    XLarge,
}

impl fmt::Display for MemoryTier {
    /// Writes the bare variant name: `Medium`, `Large` or `XLarge`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Medium => "Medium",
            Self::Large => "Large",
            Self::XLarge => "XLarge",
        };
        f.write_str(label)
    }
}
impl MemoryTier {
    /// Absolute ceiling for the embedding batch size in this tier.
    ///
    /// `MemoryPolicy::from_total_ram_mb` clamps to this value even when an override asks for more.
    pub fn batch_size_hard_cap(self) -> usize {
        match self {
            Self::Medium => 64,
            Self::Large => 128,
            Self::XLarge => 256,
        }
    }

    /// Classifies a machine by its total RAM in MB.
    ///
    /// The RAM is first floored to whole GiB. Below 32 selects `Medium`, 32–63 selects `Large`,
    /// and 64 or more selects `XLarge`.
    ///
    /// NOTE(review): because of the floor, a host reporting slightly under a boundary (e.g.
    /// 31.9 GiB) lands in the lower tier — confirm the reported totals on target hosts clear it.
    pub fn from_total_ram_mb(total_ram_mb: u64) -> Self {
        let whole_gib = total_ram_mb / 1024;
        if whole_gib < 32 {
            Self::Medium
        } else if whole_gib < 64 {
            Self::Large
        } else {
            Self::XLarge
        }
    }

    /// Baseline limits for this tier, before any environment overrides.
    ///
    /// The batch size is derived from the tier's memory limit via `compute_max_batch_size`.
    fn defaults(self) -> TierDefaults {
        // Columns: memory_limit_mb, max_chunks, embedding_cache, bm25_corpus_cap, max_kg_nodes.
        let [memory_limit_mb, max_chunks, embedding_cache, bm25_corpus_cap, max_kg_nodes]: [usize; 5] =
            match self {
                Self::Medium => [4_096, 200_000, 5_000, 100_000, 150_000],
                Self::Large => [8_192, 400_000, 10_000, 200_000, 300_000],
                Self::XLarge => [16_384, 800_000, 20_000, 400_000, 500_000],
            };
        let max_batch_size = compute_max_batch_size(memory_limit_mb);
        TierDefaults {
            memory_limit_mb,
            max_chunks,
            embedding_cache,
            max_batch_size,
            bm25_corpus_cap,
            max_kg_nodes,
        }
    }
}
/// Per-tier baseline values produced by `MemoryTier::defaults`, before env overrides apply.
#[derive(Clone, Copy, Debug)]
struct TierDefaults {
    /// Memory limit in MB.
    memory_limit_mb: usize,
    /// Maximum number of chunks.
    max_chunks: usize,
    /// Embedding cache size.
    embedding_cache: usize,
    /// Batch size derived from `memory_limit_mb`.
    max_batch_size: usize,
    /// BM25 corpus cap.
    bm25_corpus_cap: usize,
    /// Maximum number of knowledge-graph nodes.
    max_kg_nodes: usize,
}
/// Resolved memory limits for this process: tier defaults with any `TRUSTY_*` overrides applied.
#[derive(Clone, Copy, Debug)]
pub struct MemoryPolicy {
    /// Total RAM in MB that the policy was built from.
    pub total_ram_mb: u64,
    /// Tier selected from `total_ram_mb`.
    pub tier: MemoryTier,
    /// Memory limit in MB (`TRUSTY_MEMORY_LIMIT_MB`, else the tier default).
    pub memory_limit_mb: usize,
    /// Chunk cap (`TRUSTY_MAX_CHUNKS`, else the tier default).
    pub max_chunks: usize,
    /// Embedding cache size (`TRUSTY_EMBEDDING_CACHE`, else the tier default).
    pub embedding_cache: usize,
    /// Embedding batch size; when built by `from_total_ram_mb`, never above the tier hard cap.
    pub max_batch_size: usize,
    /// BM25 corpus cap (`TRUSTY_BM25_CORPUS_CAP`, else the tier default).
    pub bm25_corpus_cap: usize,
    /// Knowledge-graph node cap (`TRUSTY_MAX_KG_NODES`, else the tier default).
    pub max_kg_nodes: usize,
}
impl MemoryPolicy {
pub fn detect() -> Self {
let total_ram_mb = detect_total_ram_mb().unwrap_or_else(|| {
tracing::warn!(
"memory_policy: could not detect total system RAM — \
falling back to {FALLBACK_RAM_MB} MB (Medium tier defaults)"
);
FALLBACK_RAM_MB
});
Self::from_total_ram_mb(total_ram_mb)
}
pub fn from_total_ram_mb(total_ram_mb: u64) -> Self {
let tier = MemoryTier::from_total_ram_mb(total_ram_mb);
let d = tier.defaults();
let memory_limit_mb = env_override_usize("TRUSTY_MEMORY_LIMIT_MB", d.memory_limit_mb);
let derived_batch_size = if memory_limit_mb == d.memory_limit_mb {
d.max_batch_size
} else {
compute_max_batch_size(memory_limit_mb)
};
let raw_batch_size = env_override_usize("TRUSTY_MAX_BATCH_SIZE", derived_batch_size);
let batch_cap = tier.batch_size_hard_cap();
let max_batch_size = if raw_batch_size > batch_cap {
tracing::warn!(
"memory_policy: TRUSTY_MAX_BATCH_SIZE={} exceeds tier {} hard cap of {}; \
clamping to protect against ORT transient-arena spike (issue #89)",
raw_batch_size,
tier,
batch_cap,
);
batch_cap
} else {
raw_batch_size
};
let policy = MemoryPolicy {
total_ram_mb,
tier,
memory_limit_mb,
max_chunks: env_override_usize("TRUSTY_MAX_CHUNKS", d.max_chunks),
embedding_cache: env_override_usize("TRUSTY_EMBEDDING_CACHE", d.embedding_cache),
max_batch_size,
bm25_corpus_cap: env_override_usize("TRUSTY_BM25_CORPUS_CAP", d.bm25_corpus_cap),
max_kg_nodes: env_override_usize("TRUSTY_MAX_KG_NODES", d.max_kg_nodes),
};
policy.apply_to_env();
policy
}
pub fn apply_to_env(&self) {
unsafe {
std::env::set_var("TRUSTY_MEMORY_LIMIT_MB", self.memory_limit_mb.to_string());
std::env::set_var("TRUSTY_MAX_CHUNKS", self.max_chunks.to_string());
std::env::set_var("TRUSTY_EMBEDDING_CACHE", self.embedding_cache.to_string());
std::env::set_var("TRUSTY_MAX_BATCH_SIZE", self.max_batch_size.to_string());
std::env::set_var("TRUSTY_BM25_CORPUS_CAP", self.bm25_corpus_cap.to_string());
std::env::set_var("TRUSTY_MAX_KG_NODES", self.max_kg_nodes.to_string());
}
}
pub fn log_summary(&self) {
let gb = self.total_ram_mb / 1024;
tracing::info!("trusty-search: detected {} GB RAM → tier={}", gb, self.tier);
tracing::info!(
" MEMORY_LIMIT_MB={} MAX_CHUNKS={} EMBEDDING_CACHE={} \
MAX_BATCH_SIZE={} BM25_CORPUS_CAP={} MAX_KG_NODES={}",
self.memory_limit_mb,
self.max_chunks,
self.embedding_cache,
self.max_batch_size,
self.bm25_corpus_cap,
self.max_kg_nodes,
);
}
}
fn env_override_usize(name: &str, default: usize) -> usize {
match std::env::var(name) {
Ok(v) => match v.parse::<usize>() {
Ok(n) => n,
Err(_) => {
tracing::warn!(
"memory_policy: {name}={v:?} is not a valid usize; \
using tier default ({default})"
);
default
}
},
Err(_) => default,
}
}
/// Returns the RAM available to this process in MB, or `None` if it cannot be determined.
///
/// * macOS: physical memory from `sysctl -n hw.memsize`.
/// * Linux: `MemTotal` from `/proc/meminfo`, lowered to the cgroup memory limit when one is
///   set. This way a memory-limited container is sized by its limit, not by the host's RAM.
/// * Other platforms: always `None`.
pub fn detect_total_ram_mb() -> Option<u64> {
    #[cfg(target_os = "macos")]
    {
        detect_macos_ram_mb()
    }
    #[cfg(target_os = "linux")]
    {
        detect_linux_ram_mb()
    }
    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
    {
        None
    }
}

/// Physical memory in MB as reported by `sysctl -n hw.memsize` (a byte count).
#[cfg(target_os = "macos")]
fn detect_macos_ram_mb() -> Option<u64> {
    use std::process::Command;
    let output = Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let text = String::from_utf8(output.stdout).ok()?;
    let bytes: u64 = text.trim().parse().ok()?;
    Some(bytes / (1024 * 1024))
}

/// `MemTotal` from `/proc/meminfo`, capped by the cgroup memory limit when one applies.
#[cfg(target_os = "linux")]
fn detect_linux_ram_mb() -> Option<u64> {
    let total_mb = read_meminfo_total_mb()?;
    Some(match read_cgroup_limit_mb() {
        Some(limit_mb) => total_mb.min(limit_mb),
        None => total_mb,
    })
}

/// Parses the first `MemTotal:` line of `/proc/meminfo` (a kB value) into MB.
#[cfg(target_os = "linux")]
fn read_meminfo_total_mb() -> Option<u64> {
    let text = std::fs::read_to_string("/proc/meminfo").ok()?;
    let rest = text.lines().find_map(|line| line.strip_prefix("MemTotal:"))?;
    let kb: u64 = rest.split_whitespace().next()?.parse().ok()?;
    Some(kb / 1024)
}

/// Memory limit in MB of this process's cgroup, if one is visible and numeric.
///
/// Tries the cgroup v2 file first, then cgroup v1. The v2 value `"max"` (unlimited) fails to
/// parse and counts as "no limit". The huge v1 "unlimited" sentinel is absorbed by the
/// caller's `min` against `MemTotal`. Missing files (e.g. on a bare host) also yield `None`.
#[cfg(target_os = "linux")]
fn read_cgroup_limit_mb() -> Option<u64> {
    const CANDIDATES: [&str; 2] = [
        "/sys/fs/cgroup/memory.max",
        "/sys/fs/cgroup/memory/memory.limit_in_bytes",
    ];
    CANDIDATES.iter().find_map(|path| {
        let text = std::fs::read_to_string(path).ok()?;
        let bytes: u64 = text.trim().parse().ok()?;
        let mb = bytes / (1024 * 1024);
        // Ignore degenerate sub-MB limits rather than reporting 0 MB of RAM.
        (mb > 0).then_some(mb)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Every variable `MemoryPolicy::from_total_ram_mb` reads and `apply_to_env` writes.
    const POLICY_ENV_VARS: [&str; 6] = [
        "TRUSTY_MEMORY_LIMIT_MB",
        "TRUSTY_MAX_CHUNKS",
        "TRUSTY_EMBEDDING_CACHE",
        "TRUSTY_MAX_BATCH_SIZE",
        "TRUSTY_BM25_CORPUS_CAP",
        "TRUSTY_MAX_KG_NODES",
    ];

    /// Serializes tests that touch the process environment.
    ///
    /// The harness runs tests on parallel threads, and `from_total_ram_mb` reads and (via
    /// `apply_to_env`) writes all of `POLICY_ENV_VARS`. Without this lock, tests observe each
    /// other's values, and concurrent `set_var` calls are unsound.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning left behind by a failed test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Holds `ENV_LOCK` and restores every policy variable when dropped, even on panic.
    struct EnvSandbox {
        saved: Vec<(&'static str, Option<OsString>)>,
        // Declared last so the lock is released only after `Drop::drop` has restored the env.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvSandbox {
        /// Locks the env, snapshots and clears all policy variables, then applies `overrides`.
        fn new(overrides: &[(&str, &str)]) -> Self {
            let lock = lock_env();
            let saved = POLICY_ENV_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            // SAFETY: ENV_LOCK is held, and every test in this module that reads or writes
            // the environment holds it too.
            unsafe {
                for name in POLICY_ENV_VARS {
                    std::env::remove_var(name);
                }
                for &(name, value) in overrides {
                    std::env::set_var(name, value);
                }
            }
            EnvSandbox { saved, _lock: lock }
        }
    }

    impl Drop for EnvSandbox {
        fn drop(&mut self) {
            // SAFETY: `self._lock` is still held while `drop` runs.
            unsafe {
                for (name, value) in &self.saved {
                    match value {
                        Some(v) => std::env::set_var(name, v),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    #[test]
    fn test_tier_selection() {
        assert_eq!(MemoryTier::from_total_ram_mb(16 * 1024), MemoryTier::Medium);
        assert_eq!(MemoryTier::from_total_ram_mb(31 * 1024), MemoryTier::Medium);
        assert_eq!(MemoryTier::from_total_ram_mb(32 * 1024), MemoryTier::Large);
        assert_eq!(MemoryTier::from_total_ram_mb(63 * 1024), MemoryTier::Large);
        assert_eq!(MemoryTier::from_total_ram_mb(64 * 1024), MemoryTier::XLarge);
        assert_eq!(
            MemoryTier::from_total_ram_mb(192 * 1024),
            MemoryTier::XLarge
        );
        assert_eq!(MemoryTier::from_total_ram_mb(15 * 1024), MemoryTier::Medium);
        assert_eq!(MemoryTier::from_total_ram_mb(8 * 1024), MemoryTier::Medium);
        assert_eq!(MemoryTier::from_total_ram_mb(4 * 1024), MemoryTier::Medium);
    }

    #[test]
    fn test_tier_defaults_table() {
        let medium = MemoryTier::Medium.defaults();
        assert_eq!(medium.memory_limit_mb, 4_096);
        assert_eq!(medium.max_chunks, 200_000);
        assert_eq!(medium.max_batch_size, 55);
        let large = MemoryTier::Large.defaults();
        assert_eq!(large.memory_limit_mb, 8_192);
        assert_eq!(large.max_batch_size, 111);
        let xl = MemoryTier::XLarge.defaults();
        assert_eq!(xl.memory_limit_mb, 16_384);
        assert_eq!(xl.max_chunks, 800_000);
        assert_eq!(xl.embedding_cache, 20_000);
        assert_eq!(xl.max_kg_nodes, 500_000);
        assert_eq!(xl.max_batch_size, 223);
    }

    #[test]
    fn test_compute_max_batch_size_from_limit() {
        assert_eq!(compute_max_batch_size(4_096), 55);
        assert_eq!(compute_max_batch_size(8_192), 111);
        assert_eq!(compute_max_batch_size(16_384), 223);
        assert_eq!(compute_max_batch_size(0), MIN_COMPUTED_BATCH_SIZE);
        assert_eq!(compute_max_batch_size(1_024), MIN_COMPUTED_BATCH_SIZE);
        assert_eq!(compute_max_batch_size(64_000), MAX_COMPUTED_BATCH_SIZE);
        assert_eq!(compute_max_batch_size(1_000_000), MAX_COMPUTED_BATCH_SIZE);
    }

    #[test]
    fn test_env_override() {
        let _env = EnvSandbox::new(&[("TRUSTY_MAX_CHUNKS", "42")]);
        let policy = MemoryPolicy::from_total_ram_mb(16 * 1024);
        assert_eq!(policy.tier, MemoryTier::Medium);
        assert_eq!(policy.max_chunks, 42);
    }

    #[test]
    fn test_tier_batch_size_hard_cap() {
        assert_eq!(MemoryTier::Medium.batch_size_hard_cap(), 64);
        assert_eq!(MemoryTier::Large.batch_size_hard_cap(), 128);
        assert_eq!(MemoryTier::XLarge.batch_size_hard_cap(), 256);
    }

    #[test]
    fn test_batch_size_env_override_clamped_by_hard_cap() {
        let _env = EnvSandbox::new(&[("TRUSTY_MAX_BATCH_SIZE", "2048")]);
        let policy = MemoryPolicy::from_total_ram_mb(16 * 1024);
        assert_eq!(policy.tier, MemoryTier::Medium);
        assert_eq!(
            policy.max_batch_size, 64,
            "Medium tier must clamp TRUSTY_MAX_BATCH_SIZE=2048 down to 64"
        );
    }

    #[test]
    fn test_ram_detection_returns_nonzero() {
        // Detection may spawn `sysctl` (macOS), which reads the environment.
        let _env = lock_env();
        if let Some(mb) = detect_total_ram_mb() {
            assert!(mb > 0, "detected RAM should be > 0, got {mb}");
            assert!(
                mb < 4 * 1024 * 1024,
                "detected RAM implausibly large: {mb} MB"
            );
        }
    }
}