use roboticus_core::config::MemoryConfig;
/// Number of tokens granted to each memory tier by one allocation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryBudgets {
    /// Tokens for the working tier; both allocators in this file also add the
    /// unassigned remainder here.
    pub working: usize,
    /// Tokens for the episodic tier.
    pub episodic: usize,
    /// Tokens for the semantic tier.
    pub semantic: usize,
    /// Tokens for the procedural tier.
    pub procedural: usize,
    /// Tokens for the relationship tier.
    pub relationship: usize,
}
/// Splits a token budget across the memory tiers using the fixed
/// `*_budget_pct` percentages of a [`MemoryConfig`].
pub struct MemoryBudgetManager {
    /// Source of the per-tier percentages.
    config: MemoryConfig,
}

impl MemoryBudgetManager {
    /// Wraps `config`; the percentages are not validated here.
    pub fn new(config: MemoryConfig) -> Self {
        Self { config }
    }

    /// Allocates `total_tokens` according to the configured percentages.
    ///
    /// Each tier receives `floor(total_tokens * pct / 100)`. Tokens left over
    /// after flooring (or because the percentages sum to less than 100) are
    /// added to the working tier. If the percentages sum to more than 100 the
    /// budgets together can exceed `total_tokens`, because the leftover
    /// saturates at zero. NOTE(review): presumably config validation rules
    /// that out elsewhere — confirm.
    pub fn allocate_budgets(&self, total_tokens: usize) -> MemoryBudgets {
        let cfg = &self.config;
        let [working, episodic, semantic, procedural, relationship] = [
            cfg.working_budget_pct,
            cfg.episodic_budget_pct,
            cfg.semantic_budget_pct,
            cfg.procedural_budget_pct,
            cfg.relationship_budget_pct,
        ]
        .map(|share| pct(total_tokens, share));

        let assigned = working + episodic + semantic + procedural + relationship;
        let leftover = total_tokens.saturating_sub(assigned);

        MemoryBudgets {
            working: working + leftover,
            episodic,
            semantic,
            procedural,
            relationship,
        }
    }

    /// Borrows the configuration this manager was built with.
    pub fn config(&self) -> &MemoryConfig {
        &self.config
    }
}
/// Returns `percent` percent of `total`, rounded down.
///
/// The final float-to-integer `as` cast saturates, so a negative or NaN
/// result becomes 0.
fn pct(total: usize, percent: f64) -> usize {
    // Multiply before dividing, exactly as before, so rounding is unchanged.
    let scaled = total as f64 * percent;
    (scaled / 100.0).floor() as usize
}
/// Observed return on investment for each memory tier.
///
/// The adaptive allocator in this file scales each tier's ROI relative to the
/// largest one reported.
#[derive(Debug, Clone)]
pub struct TierEffectiveness {
    /// ROI of the working tier.
    pub working_roi: f64,
    /// ROI of the episodic tier.
    pub episodic_roi: f64,
    /// ROI of the semantic tier.
    pub semantic_roi: f64,
    /// ROI of the procedural tier.
    pub procedural_roi: f64,
    /// ROI of the relationship tier.
    pub relationship_roi: f64,
}

impl Default for TierEffectiveness {
    /// Every tier gets the same neutral ROI of 1.0, so no tier is boosted
    /// relative to another.
    fn default() -> Self {
        const NEUTRAL_ROI: f64 = 1.0;
        Self {
            working_roi: NEUTRAL_ROI,
            episodic_roi: NEUTRAL_ROI,
            semantic_roi: NEUTRAL_ROI,
            procedural_roi: NEUTRAL_ROI,
            relationship_roi: NEUTRAL_ROI,
        }
    }
}
pub struct AdaptiveBudgetAllocator {
base: MemoryConfig,
}
impl AdaptiveBudgetAllocator {
pub fn new(base: MemoryConfig) -> Self {
Self { base }
}
pub fn allocate(
&self,
total_tokens: usize,
effectiveness: &TierEffectiveness,
) -> MemoryBudgets {
let base = [
self.base.working_budget_pct,
self.base.episodic_budget_pct,
self.base.semantic_budget_pct,
self.base.procedural_budget_pct,
self.base.relationship_budget_pct,
];
let rois = [
effectiveness.working_roi,
effectiveness.episodic_roi,
effectiveness.semantic_roi,
effectiveness.procedural_roi,
effectiveness.relationship_roi,
];
let max_roi = rois.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
if max_roi <= 0.0 {
let mgr = MemoryBudgetManager::new(self.base.clone());
return mgr.allocate_budgets(total_tokens);
}
let mut weighted: Vec<f64> = base
.iter()
.zip(rois.iter())
.map(|(b, r)| b * (1.0 + r / max_roi))
.collect();
let total_weighted: f64 = weighted.iter().sum();
if total_weighted > 0.0 {
for w in &mut weighted {
*w = *w / total_weighted * 100.0;
}
}
const MIN_PCT: f64 = 5.0;
const MAX_PCT: f64 = 50.0;
for w in &mut weighted {
*w = w.clamp(MIN_PCT, MAX_PCT);
}
let clamped_total: f64 = weighted.iter().sum();
let diff = 100.0 - clamped_total;
if diff.abs() > 0.01 {
let unclamped_sum: f64 = weighted
.iter()
.filter(|&&w| w > MIN_PCT && w < MAX_PCT)
.sum();
if unclamped_sum > 0.0 {
for w in &mut weighted {
if *w > MIN_PCT && *w < MAX_PCT {
*w += diff * (*w / unclamped_sum);
*w = w.clamp(MIN_PCT, MAX_PCT);
}
}
}
}
let tokens: Vec<usize> = weighted.iter().map(|p| pct(total_tokens, *p)).collect();
let allocated: usize = tokens.iter().sum();
let rollover = total_tokens.saturating_sub(allocated);
MemoryBudgets {
working: tokens[0] + rollover,
episodic: tokens[1],
semantic: tokens[2],
procedural: tokens[3],
relationship: tokens[4],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base split shared by every test: 30 / 25 / 20 / 15 / 10 percent.
    fn default_config() -> MemoryConfig {
        MemoryConfig {
            working_budget_pct: 30.0,
            episodic_budget_pct: 25.0,
            semantic_budget_pct: 20.0,
            procedural_budget_pct: 15.0,
            relationship_budget_pct: 10.0,
            embedding_provider: None,
            embedding_model: None,
            hybrid_weight: 0.5,
            ann_index: false,
            adaptive_budget: false,
        }
    }

    /// Adaptive allocator seeded with the default split.
    fn adaptive() -> AdaptiveBudgetAllocator {
        AdaptiveBudgetAllocator::new(default_config())
    }

    /// Effectiveness in which every tier reports the same `roi`.
    fn uniform(roi: f64) -> TierEffectiveness {
        TierEffectiveness {
            working_roi: roi,
            episodic_roi: roi,
            semantic_roi: roi,
            procedural_roi: roi,
            relationship_roi: roi,
        }
    }

    /// Sum of all five tier budgets.
    fn sum_of(b: &MemoryBudgets) -> usize {
        [b.working, b.episodic, b.semantic, b.procedural, b.relationship]
            .iter()
            .sum()
    }

    #[test]
    fn static_allocation_matches_config() {
        let budgets = MemoryBudgetManager::new(default_config()).allocate_budgets(10_000);
        assert_eq!(
            budgets,
            MemoryBudgets {
                working: 3_000,
                episodic: 2_500,
                semantic: 2_000,
                procedural: 1_500,
                relationship: 1_000,
            }
        );
    }

    #[test]
    fn adaptive_with_uniform_roi_matches_base() {
        let b = adaptive().allocate(10_000, &uniform(1.0));
        assert_eq!(sum_of(&b), 10_000);
        assert!((2_500..=3_500).contains(&b.working));
    }

    #[test]
    fn adaptive_shifts_toward_high_roi() {
        let eff = TierEffectiveness {
            semantic_roi: 5.0,
            ..uniform(0.1)
        };
        let b = adaptive().allocate(10_000, &eff);
        assert!(
            b.semantic > 2_000,
            "semantic should exceed base 2000, got {}",
            b.semantic
        );
        assert_eq!(sum_of(&b), 10_000);
    }

    #[test]
    fn adaptive_respects_floor_constraint() {
        let eff = TierEffectiveness {
            working_roi: 100.0,
            ..uniform(0.001)
        };
        let b = adaptive().allocate(10_000, &eff);
        assert!(b.episodic >= 500, "episodic below 5% floor: {}", b.episodic);
        assert!(b.semantic >= 500, "semantic below 5% floor: {}", b.semantic);
    }

    #[test]
    fn adaptive_fallback_on_zero_roi() {
        let b = adaptive().allocate(10_000, &uniform(0.0));
        assert_eq!(b.working, 3_000);
        assert_eq!(b.episodic, 2_500);
    }

    #[test]
    fn adaptive_total_always_equals_input() {
        let alloc = adaptive();
        let cases = [
            TierEffectiveness {
                working_roi: 5.0,
                episodic_roi: 1.0,
                semantic_roi: 0.5,
                procedural_roi: 3.0,
                relationship_roi: 0.1,
            },
            TierEffectiveness::default(),
        ];
        for eff in &cases {
            for total in [99, 1_000, 10_000, 50_000] {
                let b = alloc.allocate(total, eff);
                assert_eq!(sum_of(&b), total, "total mismatch for input {total}");
            }
        }
    }
}