use std::collections::HashMap;
use super::ExpertId;
/// Memory tier an expert's weights can be placed in.
///
/// The explicit discriminants are the tier's position in per-tier arrays
/// (see [`MemoryTier::index`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryTier {
    Sram = 0,
    Dram = 1,
    Storage = 2,
}

impl MemoryTier {
    /// Human-readable labels, in discriminant order.
    const LABELS: [&'static str; 3] = [
        "SRAM (L2/L3 Cache)",
        "DRAM (Main Memory)",
        "Storage (Flash/NVMe)",
    ];

    /// Human-readable label of the tier.
    pub fn name(&self) -> &'static str {
        Self::LABELS[self.index()]
    }

    /// Position of the tier (0 = SRAM, 1 = DRAM, 2 = Storage), for indexing
    /// per-tier arrays.
    pub fn index(&self) -> usize {
        *self as usize
    }
}
/// Per-expert access statistics used to rank experts for SRAM residency.
///
/// `Default` yields all-zero statistics for expert 0, unpinned.
#[derive(Debug, Clone, Default)]
pub struct SramExpertAffinity {
    /// Expert these statistics belong to.
    pub expert_id: ExpertId,
    /// Number of recorded accesses.
    pub access_count: usize,
    /// Logical timestamp of the most recent access; `0` means never accessed.
    pub last_access: u64,
    /// Smoothed router weight observed on access.
    pub avg_router_weight: f32,
    /// Selection counter. NOTE(review): nothing in this file ever decays it,
    /// so it currently tracks `access_count` — confirm the intended semantics.
    pub recent_selections: usize,
    /// When `true`, the expert is excluded from tier-change suggestions.
    pub pinned: bool,
}

impl SramExpertAffinity {
    /// Fresh (zeroed, unpinned) statistics for `expert_id`.
    pub fn new(expert_id: ExpertId) -> Self {
        Self {
            expert_id,
            ..Self::default()
        }
    }

    /// Ranking score; higher means a better SRAM candidate.
    ///
    /// Sum of:
    /// * `ln(access_count + 1)` (frequency);
    /// * `0` if never accessed, else `1 / (1 + 0.001 / last_access)`. For any
    ///   access this lies between about 0.999 and 1, so in practice it only
    ///   separates "accessed" from "never accessed";
    /// * `2 * avg_router_weight`.
    pub fn priority_score(&self) -> f32 {
        let frequency = (self.access_count as f32 + 1.0).ln();
        let recency = match self.last_access {
            0 => 0.0,
            stamp => 1.0 / (1.0 + 0.001 / stamp as f32),
        };
        let routing = self.avg_router_weight * 2.0;
        frequency + recency + routing
    }
}
/// Memory budget of a target device, in bytes and in whole-expert slots per tier.
#[derive(Debug, Clone)]
pub struct HardwareConfig {
    /// SRAM bytes available for expert weights.
    pub sram_bytes: usize,
    /// DRAM bytes the mapper may use for expert weights.
    pub dram_budget_bytes: usize,
    /// Number of experts that fit in SRAM.
    pub sram_expert_slots: usize,
    /// Number of experts that fit in the DRAM budget.
    pub dram_expert_slots: usize,
    /// Size of a single expert's weights, in bytes.
    pub expert_size_bytes: usize,
}

impl Default for HardwareConfig {
    /// 8 MiB SRAM, 4 GiB DRAM budget and 34 MB experts.
    ///
    /// NOTE(review): the slot counts (2 SRAM, 8 DRAM) are fixed rather than
    /// derived from the byte sizes; `HardwareConfig::new` with the same sizes
    /// would give 0 SRAM and 126 DRAM slots. Confirm this is intended.
    fn default() -> Self {
        Self {
            sram_bytes: 8 * 1024 * 1024,
            // 4 GiB does not fit a 32-bit `usize` (e.g. wasm32), where computing it
            // directly in `usize` is a compile-time `arithmetic_overflow` error.
            // Clamp instead; on 64-bit targets this is exactly 4 GiB.
            dram_budget_bytes: (4u64 * 1024 * 1024 * 1024).min(usize::MAX as u64) as usize,
            sram_expert_slots: 2,
            dram_expert_slots: 8,
            expert_size_bytes: 34_000_000,
        }
    }
}

impl HardwareConfig {
    /// Builds a config whose slot counts are the number of whole experts that
    /// fit in each byte budget.
    ///
    /// An `expert_size_bytes` of 0 is treated as 1 to avoid dividing by zero.
    pub fn new(sram_bytes: usize, dram_budget_bytes: usize, expert_size_bytes: usize) -> Self {
        let per_expert = expert_size_bytes.max(1);
        Self {
            sram_bytes,
            dram_budget_bytes,
            sram_expert_slots: sram_bytes / per_expert,
            dram_expert_slots: dram_budget_bytes / per_expert,
            expert_size_bytes,
        }
    }

    /// Combined SRAM + DRAM byte budget.
    ///
    /// Saturates at `usize::MAX` instead of overflowing, for example for a DRAM
    /// budget of `usize::MAX` used as "unlimited".
    pub fn total_budget(&self) -> usize {
        self.sram_bytes.saturating_add(self.dram_budget_bytes)
    }

    /// Combined SRAM + DRAM expert slots, saturating at `usize::MAX`.
    pub fn total_slots(&self) -> usize {
        self.sram_expert_slots.saturating_add(self.dram_expert_slots)
    }
}
/// Built-in hardware profiles used to derive a [`HardwareConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HardwarePreset {
    /// 2 MiB SRAM, 6 GiB DRAM budget.
    RaspberryPi5,
    /// 4 MiB SRAM, 2 GiB DRAM budget.
    Mobile,
    /// 32 MiB SRAM, 12 GiB DRAM budget.
    Desktop,
    /// 2 MiB SRAM, 1 GiB DRAM budget.
    WasmBrowser,
    /// Uses `HardwareConfig::default()`.
    Custom,
}

impl HardwarePreset {
    /// Returns this preset's config, with slot counts sized for `expert_size_bytes`.
    ///
    /// `Custom` returns `HardwareConfig::default()` unchanged and ignores
    /// `expert_size_bytes`.
    pub fn default_config(&self, expert_size_bytes: usize) -> HardwareConfig {
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;

        // Budgets of 4 GiB or more do not fit a 32-bit `usize` (e.g. wasm32), where
        // computing them directly in `usize` is a compile-time overflow error.
        // Compute in u64 and clamp; on 64-bit targets the value is exact.
        fn to_usize(bytes: u64) -> usize {
            bytes.min(usize::MAX as u64) as usize
        }

        let (sram_bytes, dram_budget) = match self {
            HardwarePreset::RaspberryPi5 => (2 * MIB, 6 * GIB),
            HardwarePreset::Mobile => (4 * MIB, 2 * GIB),
            HardwarePreset::Desktop => (32 * MIB, 12 * GIB),
            HardwarePreset::WasmBrowser => (2 * MIB, GIB),
            HardwarePreset::Custom => return HardwareConfig::default(),
        };
        HardwareConfig::new(to_usize(sram_bytes), to_usize(dram_budget), expert_size_bytes)
    }

    /// Human-readable name of the preset.
    pub fn name(&self) -> &'static str {
        match self {
            HardwarePreset::RaspberryPi5 => "Raspberry Pi 5",
            HardwarePreset::Mobile => "Mobile Device",
            HardwarePreset::Desktop => "Desktop Workstation",
            HardwarePreset::WasmBrowser => "WASM Browser",
            HardwarePreset::Custom => "Custom",
        }
    }
}
pub struct SramMapper {
config: HardwareConfig,
num_experts: usize,
tier_map: Vec<MemoryTier>,
affinity: Vec<SramExpertAffinity>,
tier_latency: [u64; 3],
sram_used: usize,
dram_used: usize,
access_counter: u64,
}
impl SramMapper {
pub fn from_preset(
preset: HardwarePreset,
num_experts: usize,
expert_size_bytes: usize,
) -> Self {
let config = preset.default_config(expert_size_bytes);
Self::from_config(config, num_experts)
}
pub fn from_config(config: HardwareConfig, num_experts: usize) -> Self {
let tier_map = vec![MemoryTier::Storage; num_experts];
let affinity = (0..num_experts).map(SramExpertAffinity::new).collect();
let tier_latency = [0, 0, 100];
Self {
config,
num_experts,
tier_map,
affinity,
tier_latency,
sram_used: 0,
dram_used: 0,
access_counter: 0,
}
}
pub fn assign_tier(&mut self, expert_id: ExpertId, tier: MemoryTier) -> bool {
if expert_id >= self.num_experts {
return false;
}
let old_tier = self.tier_map[expert_id];
match old_tier {
MemoryTier::Sram => {
if self.sram_used > 0 {
self.sram_used -= 1;
}
}
MemoryTier::Dram => {
if self.dram_used > 0 {
self.dram_used -= 1;
}
}
MemoryTier::Storage => {}
}
match tier {
MemoryTier::Sram => self.sram_used += 1,
MemoryTier::Dram => self.dram_used += 1,
MemoryTier::Storage => {}
}
self.tier_map[expert_id] = tier;
true
}
pub fn get_tier(&self, expert_id: ExpertId) -> MemoryTier {
self.tier_map
.get(expert_id)
.copied()
.unwrap_or(MemoryTier::Storage)
}
pub fn estimate_paging_latency(&self, expert_id: ExpertId) -> u64 {
let tier = self.get_tier(expert_id);
self.tier_latency[tier.index()]
}
pub fn sram_capacity(&self) -> usize {
self.config.sram_expert_slots
}
pub fn dram_capacity(&self) -> usize {
self.config.dram_expert_slots
}
pub fn sram_used(&self) -> usize {
self.sram_used
}
pub fn dram_used(&self) -> usize {
self.dram_used
}
pub fn sram_available(&self) -> usize {
self.config.sram_expert_slots.saturating_sub(self.sram_used)
}
pub fn dram_available(&self) -> usize {
self.config.dram_expert_slots.saturating_sub(self.dram_used)
}
pub fn record_access(&mut self, expert_id: ExpertId, router_weight: f32) {
if expert_id >= self.num_experts {
return;
}
self.access_counter += 1;
let affinity = &mut self.affinity[expert_id];
affinity.access_count += 1;
affinity.last_access = self.access_counter;
affinity.recent_selections += 1;
let alpha = 0.1;
affinity.avg_router_weight =
alpha * router_weight + (1.0 - alpha) * affinity.avg_router_weight;
}
pub fn suggest_eviction_tier(
&self,
_affinity_data: &SramExpertAffinity,
) -> Vec<(ExpertId, MemoryTier)> {
self.suggest_tier_changes()
}
pub fn suggest_tier_changes(&self) -> Vec<(ExpertId, MemoryTier)> {
let mut suggestions = Vec::new();
let mut experts: Vec<(ExpertId, f32, MemoryTier)> = self
.affinity
.iter()
.enumerate()
.filter(|(_, aff)| !aff.pinned)
.map(|(id, aff)| (id, aff.priority_score(), self.tier_map[id]))
.collect();
experts.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let sram_available = self.sram_available();
let mut promoted_to_sram = 0;
for &(expert_id, _priority, current_tier) in &experts {
if promoted_to_sram >= sram_available {
break;
}
if current_tier != MemoryTier::Sram {
suggestions.push((expert_id, MemoryTier::Sram));
promoted_to_sram += 1;
}
}
for &(expert_id, _priority, current_tier) in experts.iter().rev() {
if current_tier == MemoryTier::Sram
&& suggestions.iter().all(|(id, _)| *id != expert_id)
{
if self.dram_available() > 0 {
suggestions.push((expert_id, MemoryTier::Dram));
} else {
suggestions.push((expert_id, MemoryTier::Storage));
}
}
}
suggestions
}
pub fn pin(&mut self, expert_id: ExpertId) {
if expert_id < self.num_experts {
self.affinity[expert_id].pinned = true;
}
}
pub fn unpin(&mut self, expert_id: ExpertId) {
if expert_id < self.num_experts {
self.affinity[expert_id].pinned = false;
}
}
pub fn config(&self) -> &HardwareConfig {
&self.config
}
pub fn num_experts(&self) -> usize {
self.num_experts
}
pub fn set_tier_latencies(&mut self, sram_us: u64, dram_us: u64, storage_us: u64) {
self.tier_latency = [sram_us, dram_us, storage_us];
}
pub fn experts_in_tier(&self, tier: MemoryTier) -> Vec<ExpertId> {
self.tier_map
.iter()
.enumerate()
.filter(|(_, &t)| t == tier)
.map(|(id, _)| id)
.collect()
}
pub fn get_affinity(&self, expert_id: ExpertId) -> Option<&SramExpertAffinity> {
self.affinity.get(expert_id)
}
pub fn reset_affinity(&mut self) {
for (id, aff) in self.affinity.iter_mut().enumerate() {
*aff = SramExpertAffinity::new(id);
}
self.access_counter = 0;
}
pub fn tier_summary(&self) -> HashMap<MemoryTier, usize> {
let mut summary = HashMap::new();
summary.insert(MemoryTier::Sram, 0);
summary.insert(MemoryTier::Dram, 0);
summary.insert(MemoryTier::Storage, 0);
for &tier in &self.tier_map {
*summary.entry(tier).or_insert(0) += 1;
}
summary
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: usize = 1024 * 1024;
    const GIB: usize = 1024 * MIB;

    /// 16 MiB SRAM, 4 GiB DRAM, 4 MiB experts: 4 SRAM slots and 1024 DRAM slots.
    fn small_config() -> HardwareConfig {
        HardwareConfig::new(16 * MIB, 4 * GIB, 4 * MIB)
    }

    #[test]
    fn test_from_preset_raspberry_pi() {
        let expert_size = 34_000_000;
        let mapper = SramMapper::from_preset(HardwarePreset::RaspberryPi5, 8, expert_size);
        assert_eq!(mapper.num_experts(), 8);
        // 2 MiB of SRAM cannot hold a single 34 MB expert.
        assert_eq!(mapper.sram_capacity(), 0);
        assert!(mapper.dram_capacity() > 0);
        for i in 0..8 {
            assert_eq!(mapper.get_tier(i), MemoryTier::Storage);
        }
    }

    #[test]
    fn test_from_preset_raspberry_pi_small_experts() {
        let expert_size = 500_000;
        let mapper = SramMapper::from_preset(HardwarePreset::RaspberryPi5, 8, expert_size);
        assert_eq!(mapper.sram_capacity(), 4);
        assert!(mapper.dram_capacity() > 1000);
    }

    #[test]
    fn test_from_preset_mobile() {
        let mapper = SramMapper::from_preset(HardwarePreset::Mobile, 8, MIB);
        assert_eq!(mapper.sram_capacity(), 4);
        assert_eq!(mapper.dram_capacity(), 2048);
        assert_eq!(mapper.num_experts(), 8);
    }

    #[test]
    fn test_from_preset_desktop() {
        let mapper = SramMapper::from_preset(HardwarePreset::Desktop, 16, 8 * MIB);
        assert_eq!(mapper.sram_capacity(), 4);
        assert_eq!(mapper.dram_capacity(), 1536);
        assert_eq!(mapper.num_experts(), 16);
    }

    #[test]
    fn test_tier_assignment() {
        let mut mapper = SramMapper::from_config(small_config(), 8);
        assert_eq!(mapper.get_tier(0), MemoryTier::Storage);
        assert_eq!(mapper.sram_used(), 0);
        assert_eq!(mapper.dram_used(), 0);

        assert!(mapper.assign_tier(0, MemoryTier::Sram));
        assert_eq!(mapper.get_tier(0), MemoryTier::Sram);
        assert_eq!(mapper.sram_used(), 1);

        assert!(mapper.assign_tier(1, MemoryTier::Dram));
        assert_eq!(mapper.get_tier(1), MemoryTier::Dram);
        assert_eq!(mapper.dram_used(), 1);

        assert!(mapper.assign_tier(0, MemoryTier::Dram));
        assert_eq!(mapper.get_tier(0), MemoryTier::Dram);
        assert_eq!(mapper.sram_used(), 0);
        assert_eq!(mapper.dram_used(), 2);

        assert!(mapper.assign_tier(1, MemoryTier::Storage));
        assert_eq!(mapper.get_tier(1), MemoryTier::Storage);
        assert_eq!(mapper.dram_used(), 1);

        // Out-of-range ids are rejected and leave the counters untouched.
        assert!(!mapper.assign_tier(100, MemoryTier::Sram));
        assert_eq!(mapper.sram_used(), 0);
        assert_eq!(mapper.dram_used(), 1);
    }

    #[test]
    fn test_paging_latency_estimates() {
        let mut mapper = SramMapper::from_config(small_config(), 4);
        mapper.set_tier_latencies(1, 10, 200);
        mapper.assign_tier(0, MemoryTier::Sram);
        mapper.assign_tier(1, MemoryTier::Dram);
        mapper.assign_tier(2, MemoryTier::Storage);
        assert_eq!(mapper.estimate_paging_latency(0), 1);
        assert_eq!(mapper.estimate_paging_latency(1), 10);
        assert_eq!(mapper.estimate_paging_latency(2), 200);
        // Never-assigned experts start in Storage.
        assert_eq!(mapper.estimate_paging_latency(3), 200);
        // Out-of-range ids are reported as Storage.
        assert_eq!(mapper.estimate_paging_latency(100), 200);
    }

    #[test]
    fn test_capacity_calculations() {
        let config = HardwareConfig::new(32 * MIB, 8 * GIB, 8 * MIB);
        let mapper = SramMapper::from_config(config, 16);
        assert_eq!(mapper.sram_capacity(), 4);
        assert_eq!(mapper.dram_capacity(), 1024);
        assert_eq!(mapper.config().total_slots(), 1028);
        assert_eq!(mapper.config().total_budget(), 32 * MIB + 8 * GIB);
        assert_eq!(mapper.sram_available(), 4);
        assert_eq!(mapper.dram_available(), 1024);
    }

    #[test]
    fn test_eviction_suggestions() {
        let mut mapper = SramMapper::from_config(small_config(), 8);
        for _ in 0..10 {
            mapper.record_access(0, 0.8);
        }
        for _ in 0..5 {
            mapper.record_access(1, 0.6);
        }
        mapper.record_access(2, 0.3);
        assert_eq!(mapper.sram_available(), 4);

        let suggestions = mapper.suggest_tier_changes();
        // Nothing is in SRAM yet, so every suggestion is a promotion filling
        // the 4 free slots, and the three accessed experts are among them.
        assert_eq!(suggestions.len(), 4);
        assert!(suggestions.iter().all(|&(_, tier)| tier == MemoryTier::Sram));
        for id in 0..3 {
            assert!(suggestions.contains(&(id, MemoryTier::Sram)));
        }
    }

    #[test]
    fn test_custom_config() {
        let config = HardwareConfig {
            sram_bytes: 64 * MIB,
            dram_budget_bytes: 16 * GIB,
            sram_expert_slots: 8,
            dram_expert_slots: 200,
            expert_size_bytes: 8 * MIB,
        };
        let mapper = SramMapper::from_config(config, 32);
        assert_eq!(mapper.sram_capacity(), 8);
        assert_eq!(mapper.dram_capacity(), 200);
        assert_eq!(mapper.num_experts(), 32);
        assert_eq!(mapper.config().expert_size_bytes, 8 * MIB);
    }

    #[test]
    fn test_affinity_tracking() {
        let mut mapper = SramMapper::from_config(HardwareConfig::default(), 4);
        mapper.record_access(0, 0.9);
        mapper.record_access(0, 0.8);
        mapper.record_access(1, 0.5);

        let aff0 = mapper.get_affinity(0).unwrap();
        assert_eq!(aff0.access_count, 2);
        assert_eq!(aff0.last_access, 2);
        assert!(aff0.avg_router_weight > 0.0);

        let aff1 = mapper.get_affinity(1).unwrap();
        assert_eq!(aff1.access_count, 1);
        assert_eq!(aff1.last_access, 3);

        mapper.reset_affinity();
        let aff0_reset = mapper.get_affinity(0).unwrap();
        assert_eq!(aff0_reset.access_count, 0);
        assert_eq!(aff0_reset.last_access, 0);
    }

    #[test]
    fn test_pin_unpin() {
        let mut mapper = SramMapper::from_config(HardwareConfig::default(), 4);
        mapper.pin(0);
        assert!(mapper.get_affinity(0).unwrap().pinned);
        mapper.unpin(0);
        assert!(!mapper.get_affinity(0).unwrap().pinned);
    }

    #[test]
    fn test_experts_in_tier() {
        let mut mapper = SramMapper::from_config(small_config(), 8);
        mapper.assign_tier(0, MemoryTier::Sram);
        mapper.assign_tier(1, MemoryTier::Sram);
        mapper.assign_tier(2, MemoryTier::Dram);
        mapper.assign_tier(3, MemoryTier::Dram);
        mapper.assign_tier(4, MemoryTier::Dram);

        let sram_experts = mapper.experts_in_tier(MemoryTier::Sram);
        assert_eq!(sram_experts.len(), 2);
        assert!(sram_experts.contains(&0));
        assert!(sram_experts.contains(&1));

        let dram_experts = mapper.experts_in_tier(MemoryTier::Dram);
        assert_eq!(dram_experts.len(), 3);

        // The remaining 8 - 5 experts were never assigned and stay in Storage.
        let storage_experts = mapper.experts_in_tier(MemoryTier::Storage);
        assert_eq!(storage_experts.len(), 3);
    }

    #[test]
    fn test_tier_summary() {
        let mut mapper = SramMapper::from_config(small_config(), 8);
        mapper.assign_tier(0, MemoryTier::Sram);
        mapper.assign_tier(1, MemoryTier::Dram);
        mapper.assign_tier(2, MemoryTier::Dram);
        let summary = mapper.tier_summary();
        assert_eq!(*summary.get(&MemoryTier::Sram).unwrap(), 1);
        assert_eq!(*summary.get(&MemoryTier::Dram).unwrap(), 2);
        assert_eq!(*summary.get(&MemoryTier::Storage).unwrap(), 5);
    }

    #[test]
    fn test_memory_tier_properties() {
        assert_eq!(MemoryTier::Sram.name(), "SRAM (L2/L3 Cache)");
        assert_eq!(MemoryTier::Dram.name(), "DRAM (Main Memory)");
        assert_eq!(MemoryTier::Storage.name(), "Storage (Flash/NVMe)");
        assert_eq!(MemoryTier::Sram.index(), 0);
        assert_eq!(MemoryTier::Dram.index(), 1);
        assert_eq!(MemoryTier::Storage.index(), 2);
    }

    #[test]
    fn test_hardware_preset_names() {
        assert_eq!(HardwarePreset::RaspberryPi5.name(), "Raspberry Pi 5");
        assert_eq!(HardwarePreset::Mobile.name(), "Mobile Device");
        assert_eq!(HardwarePreset::Desktop.name(), "Desktop Workstation");
        assert_eq!(HardwarePreset::WasmBrowser.name(), "WASM Browser");
        assert_eq!(HardwarePreset::Custom.name(), "Custom");
    }

    #[test]
    fn test_expert_affinity_priority_score() {
        let mut aff = SramExpertAffinity::new(0);
        let initial_score = aff.priority_score();
        aff.access_count = 100;
        aff.avg_router_weight = 0.9;
        let high_score = aff.priority_score();
        assert!(high_score > initial_score);
    }

    #[test]
    fn test_wasm_browser_preset() {
        let mapper = SramMapper::from_preset(HardwarePreset::WasmBrowser, 8, 2 * MIB);
        assert_eq!(mapper.sram_capacity(), 1);
        assert_eq!(mapper.dram_capacity(), 512);
    }

    #[test]
    fn test_out_of_range_expert_id() {
        let mapper = SramMapper::from_config(HardwareConfig::default(), 4);
        assert_eq!(mapper.get_tier(100), MemoryTier::Storage);
        // Default Storage latency set by `from_config`.
        assert_eq!(mapper.estimate_paging_latency(100), 100);
    }

    #[test]
    fn test_record_access_out_of_range() {
        let mut mapper = SramMapper::from_config(HardwareConfig::default(), 4);
        mapper.record_access(100, 0.5);
        // The call must neither panic nor touch any in-range expert's statistics.
        assert!(mapper.get_affinity(100).is_none());
        for id in 0..4 {
            let aff = mapper.get_affinity(id).unwrap();
            assert_eq!(aff.access_count, 0);
            assert_eq!(aff.last_access, 0);
            assert_eq!(aff.recent_selections, 0);
        }
    }
}