mod tlsf;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use svod_device::Buffer;
use svod_dtype::{DType, DeviceSpec};
use svod_ir::{Op, UOp};
use tracing::{debug, trace};
use crate::schedule::Schedule;
/// Rounding granularity for planned buffer sizes.
///
/// Pool-key sizes (remap mode) and arena slot sizes (arena mode) are rounded up
/// to a multiple of this value, arena sizes are rounded to it as well, and it is
/// passed to `tlsf::TlsfAllocator::new` when a lane allocator is created.
const MIN_BLOCK_SIZE: usize = 256;
/// Strategy used by `memory_planner` to reuse buffer memory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlannerMode {
    /// No planning: the schedule's buffers are left untouched.
    Disabled,
    /// Logical buffers with disjoint lifetimes are remapped onto pooled buffers.
    Remap,
    /// Buffers are sub-allocated as views into one arena per lane.
    Arena,
}

/// Interprets a planner-mode setting string.
///
/// Surrounding whitespace is ignored and matching is case-insensitive:
/// `"0"`, `"off"`, `"none"` and `"disabled"` select [`PlannerMode::Disabled`];
/// `"remap"` and `"pool"` select [`PlannerMode::Remap`]. A missing value or any
/// other string selects [`PlannerMode::Arena`].
pub fn parse_mode(raw: Option<&str>) -> PlannerMode {
    match raw.map(|value| value.trim().to_ascii_lowercase()) {
        None => PlannerMode::Arena,
        Some(setting) => match setting.as_str() {
            "0" | "off" | "none" | "disabled" => PlannerMode::Disabled,
            "remap" | "pool" => PlannerMode::Remap,
            _ => PlannerMode::Arena,
        },
    }
}

/// Reads the planner mode from the `SVOD_MEMORY_PLANNER` environment variable.
///
/// An unset or non-unicode variable behaves like an absent setting
/// (see [`parse_mode`]).
pub fn mode_from_env() -> PlannerMode {
    let setting = std::env::var("SVOD_MEMORY_PLANNER").ok();
    parse_mode(setting.as_deref())
}
/// One way a storage is viewed by a schedule buffer: `(offset, size, dtype, shape)`.
/// A storage seen through more than one distinct view is treated as aliased.
type LogicalBufferView = (usize, usize, DType, Vec<usize>);
/// Rounds `size` up to the nearest multiple of `block_size`.
///
/// `round_up(0, b)` is `0`. Panics if `block_size` is zero (division by zero).
#[inline]
fn round_up(size: usize, block_size: usize) -> usize {
    let whole_blocks = size.div_ceil(block_size);
    whole_blocks * block_size
}
/// Compatibility class for buffer reuse in remap mode: logical buffers are only
/// served from the free pool with an equal key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferPoolKey {
    /// Device spec of the buffer's allocator.
    pub device: DeviceSpec,
    /// Element dtype of the buffer.
    pub dtype: DType,
    /// Buffer size; `analyze_liveness` fills this with the size rounded up to
    /// `MIN_BLOCK_SIZE`, so buffers of different real sizes can share a key.
    pub size: usize,
}
/// Lifetime of one optimizable logical buffer over the schedule.
#[derive(Debug, Clone)]
pub struct BufferLiveness {
    /// Index of the first schedule step that uses the buffer.
    pub first_appearance: usize,
    /// Index of the last schedule step that uses the buffer.
    pub last_appearance: usize,
    /// Reuse class used by the remap planner.
    pub pool_key: BufferPoolKey,
    /// Clone of the buffer as seen at its first occurrence; used as the default
    /// physical buffer and as the source of allocator/size information.
    pub prototype: Buffer,
}
/// A single allocation or release point on the planner's timeline.
#[derive(Debug, Clone)]
struct BufferEvent {
    /// Schedule step at which the event takes effect. The event builders place
    /// allocations at a buffer's first appearance and releases after its last
    /// appearance (later still for held copy buffers in arena mode).
    timestep: usize,
    /// `true` for an allocation, `false` for a release. Because `false < true`,
    /// sorting by `(timestep, is_alloc, ..)` processes releases before
    /// allocations at the same step.
    is_alloc: bool,
    /// Id of the logical buffer this event refers to.
    buffer_id: u64,
}
/// Ordering edge introduced by memory reuse: the step `successor_step` uses
/// memory that was previously used by a buffer whose last use is
/// `predecessor_step`. `apply_reuse_dependencies` records it by adding
/// `predecessor_step` to the successor's `instance_dependencies`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReuseDependency {
    /// Last step that used the memory before it was released.
    pub predecessor_step: usize,
    /// First step of the buffer that now reuses the memory.
    pub successor_step: usize,
}
/// A physical buffer sitting in a remap-mode free pool.
struct ReusableBuffer {
    /// The physical buffer available for reuse.
    buffer: Buffer,
    /// Last step of the logical buffer that released it; becomes the
    /// predecessor of the resulting reuse dependency.
    released_by_step: usize,
}
/// Output of `analyze_liveness`, shared by both planner modes.
struct PlannerInput {
    /// Liveness of every optimizable buffer, keyed by buffer id.
    liveness: HashMap<u64, BufferLiveness>,
    /// Every non-skipped schedule slot (in schedule order) with the id of the
    /// buffer occupying it.
    occurrences: Vec<(BufferKey, u64)>,
}
/// Position of one buffer slot in a schedule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferKey {
    /// Index of the schedule item (kernel/step).
    pub kernel_idx: usize,
    /// Index into that item's `buffers` list.
    pub buffer_idx: usize,
}
/// Result of `memory_planner`.
#[derive(Debug)]
pub struct MemoryPlannerResult {
    /// Replacement buffer for each rewritten schedule slot.
    pub buffer_replace: HashMap<BufferKey, Buffer>,
    /// Estimated bytes saved (computed from sizes rounded to `MIN_BLOCK_SIZE`).
    pub memory_saved: usize,
    /// Remap mode: number of logical buffers served from a free pool.
    /// Arena mode: number of schedule slots rewritten to arena views.
    pub buffers_reused: usize,
    /// Ordering edges required by the memory reuse.
    pub reuse_dependencies: Vec<ReuseDependency>,
}
/// Collects the ids of buffers the planner must never rewrite.
///
/// A buffer is excluded when any of the following holds:
/// - it is used by a schedule item whose AST root is not `Op::Sink`;
/// - its storage is referenced through more than one distinct
///   [`LogicalBufferView`] anywhere in the schedule (aliased storage);
/// - it is the target of a gated store, as found by
///   [`collect_masked_store_buffer_ids`].
fn collect_noopt_buffer_ids(schedule: &Schedule) -> HashSet<u64> {
    let mut views_per_storage: HashMap<u64, HashSet<LogicalBufferView>> = HashMap::new();
    let mut excluded: HashSet<u64> = HashSet::new();

    for item in schedule {
        let is_sink = matches!(item.ast.op(), Op::Sink { .. });
        for buffer in &item.buffers {
            let view = (buffer.offset(), buffer.size(), buffer.dtype(), buffer.shape().to_vec());
            views_per_storage.entry(buffer.storage_id().0).or_default().insert(view);
            // Every buffer of a non-Sink item is excluded outright.
            if !is_sink {
                excluded.insert(buffer.id().0);
            }
        }

        // Store nodes reference buffers by UOp id; translate them to buffer ids.
        let buffer_id_by_uop: HashMap<u64, u64> = item
            .buffer_uop_ids
            .iter()
            .copied()
            .zip(item.buffers.iter().map(|b| b.id().0))
            .collect();
        for node in item.ast.toposort() {
            if let Op::Store { index, .. } = node.op() {
                collect_masked_store_buffer_ids(index, &buffer_id_by_uop, &mut excluded);
            }
        }
    }

    let aliased_storages: HashSet<u64> = views_per_storage
        .into_iter()
        .filter(|(_, views)| views.len() > 1)
        .map(|(storage_id, _)| storage_id)
        .collect();
    for item in schedule {
        for buffer in &item.buffers {
            if aliased_storages.contains(&buffer.storage_id().0) {
                excluded.insert(buffer.id().0);
            }
        }
    }

    excluded
}
fn collect_masked_store_buffer_ids(
index: &Arc<UOp>,
uop_id_to_buffer_id: &HashMap<u64, u64>,
masked_store_ids: &mut HashSet<u64>,
) {
match index.op() {
Op::Index { buffer, gate: Some(_), .. } => {
if let Some(buffer_id) = uop_id_to_buffer_id.get(&buffer.buf_uop().id) {
masked_store_ids.insert(*buffer_id);
}
}
Op::Index { .. } => {}
other => {
for child in other.children() {
collect_masked_store_buffer_ids(child, uop_id_to_buffer_id, masked_store_ids);
}
}
}
}
/// Returns `true` when `buffer` must be left out of memory planning.
///
/// Skipped are: buffers on a disk device, buffers with a non-zero offset,
/// buffers that are already allocated, and buffers whose id is listed in either
/// `output_buffer_ids` or `noopt_buffer_ids`.
fn should_skip_buffer(buffer: &Buffer, output_buffer_ids: &HashSet<u64>, noopt_buffer_ids: &HashSet<u64>) -> bool {
    if buffer.allocator().device_spec().is_disk() || buffer.offset() != 0 || buffer.is_allocated() {
        return true;
    }
    let id = buffer.id().0;
    output_buffer_ids.contains(&id) || noopt_buffer_ids.contains(&id)
}
/// Computes first/last use of every optimizable buffer in `schedule`.
///
/// Buffers rejected by [`should_skip_buffer`] (including those returned by
/// [`collect_noopt_buffer_ids`]) are ignored. For each remaining slot, the
/// `(BufferKey, buffer_id)` pair is recorded in schedule order, and the buffer's
/// liveness entry is created on first sight (with a rounded [`BufferPoolKey`]
/// and a prototype clone) or widened on later sightings.
fn analyze_liveness(schedule: &Schedule, output_buffer_ids: &HashSet<u64>) -> PlannerInput {
    let noopt_buffer_ids = collect_noopt_buffer_ids(schedule);
    let mut liveness: HashMap<u64, BufferLiveness> = HashMap::new();
    let mut occurrences: Vec<(BufferKey, u64)> = Vec::new();

    for (kernel_idx, item) in schedule.iter().enumerate() {
        for (buffer_idx, buffer) in item.buffers.iter().enumerate() {
            let buffer_id = buffer.id().0;
            if should_skip_buffer(buffer, output_buffer_ids, &noopt_buffer_ids) {
                trace!(step_idx = kernel_idx, buf_idx = buffer_idx, buffer_id, "skipping buffer in memory planner");
                continue;
            }
            occurrences.push((BufferKey { kernel_idx, buffer_idx }, buffer_id));

            if let Some(existing) = liveness.get_mut(&buffer_id) {
                existing.first_appearance = existing.first_appearance.min(kernel_idx);
                existing.last_appearance = existing.last_appearance.max(kernel_idx);
            } else {
                let pool_key = BufferPoolKey {
                    device: buffer.allocator().device_spec(),
                    dtype: buffer.dtype(),
                    size: round_up(buffer.size(), MIN_BLOCK_SIZE),
                };
                liveness.insert(
                    buffer_id,
                    BufferLiveness {
                        first_appearance: kernel_idx,
                        last_appearance: kernel_idx,
                        pool_key,
                        prototype: buffer.clone(),
                    },
                );
            }
        }
    }

    debug!(num_optimizable = liveness.len(), "liveness analysis complete");
    PlannerInput { liveness, occurrences }
}
/// Builds the sorted alloc/free timeline for the remap planner.
///
/// Each buffer is allocated at its first appearance and released at
/// `last_appearance + 1`. Events are sorted by `(timestep, is_alloc, buffer_id)`,
/// so releases precede allocations at the same step and ties are deterministic.
fn build_event_timeline(liveness: &HashMap<u64, BufferLiveness>) -> Vec<BufferEvent> {
    let mut timeline: Vec<BufferEvent> = liveness
        .iter()
        .flat_map(|(&buffer_id, info)| {
            [
                BufferEvent { timestep: info.first_appearance, is_alloc: true, buffer_id },
                BufferEvent { timestep: info.last_appearance + 1, is_alloc: false, buffer_id },
            ]
        })
        .collect();
    timeline.sort_by_key(|event| (event.timestep, event.is_alloc, event.buffer_id));
    timeline
}
/// Replays the remap-mode timeline against per-[`BufferPoolKey`] free pools.
///
/// On an allocation event, a released physical buffer from the matching pool is
/// reused when one is large enough for the logical buffer. The most recently
/// released fitting buffer is preferred. Each reuse records a
/// [`ReuseDependency`] from the releasing buffer's last step to the new buffer's
/// first step. When no buffer fits, the logical buffer keeps its own prototype.
/// On a release event, the physical buffer goes back to the pool.
///
/// Returns `(buffer_replace, memory_saved, buffers_reused, reuse_dependencies)`:
/// - `buffer_replace` holds only slots whose chosen buffer differs from the original.
/// - `memory_saved` sums the rounded pool sizes of all reuses.
/// - `buffers_reused` counts the logical buffers that were served from a pool.
fn process_events(
    events: &[BufferEvent],
    liveness: &HashMap<u64, BufferLiveness>,
    occurrences: &[(BufferKey, u64)],
) -> (HashMap<BufferKey, Buffer>, usize, usize, Vec<ReuseDependency>) {
    let mut free_pools: HashMap<BufferPoolKey, Vec<ReusableBuffer>> = HashMap::new();
    let mut memory_saved: usize = 0;
    let mut buffers_reused: usize = 0;
    let mut reuse_dependencies = Vec::new();
    // Physical buffer assigned to each logical buffer id.
    let mut chosen_by_id: HashMap<u64, Buffer> = HashMap::new();
    // Physical buffers currently held by live logical buffers.
    let mut active_buffers: HashMap<u64, Buffer> = HashMap::new();

    for event in events {
        let Some(info) = liveness.get(&event.buffer_id) else {
            continue;
        };
        let pool_key = &info.pool_key;

        if event.is_alloc {
            // Pool keys group sizes rounded up to MIN_BLOCK_SIZE, so a pooled buffer may be
            // smaller than the logical buffer being allocated. Only reuse one that is large
            // enough, preferring the most recently released (LIFO) candidate.
            let needed = info.prototype.size();
            let reused = free_pools.get_mut(pool_key).and_then(|pool| {
                let pos = pool.iter().rposition(|candidate| candidate.buffer.size() >= needed)?;
                Some(pool.remove(pos))
            });

            let physical = match reused {
                Some(reused) => {
                    trace!(timestep = event.timestep, reused_buffer_id = reused.buffer.id().0, "reusing buffer from pool");
                    reuse_dependencies.push(ReuseDependency {
                        predecessor_step: reused.released_by_step,
                        successor_step: event.timestep,
                    });
                    memory_saved += pool_key.size;
                    buffers_reused += 1;
                    reused.buffer
                }
                None => info.prototype.clone(),
            };
            chosen_by_id.insert(event.buffer_id, physical.clone());
            active_buffers.insert(event.buffer_id, physical);
        } else if let Some(buffer) = active_buffers.remove(&event.buffer_id) {
            free_pools
                .entry(pool_key.clone())
                .or_default()
                .push(ReusableBuffer { buffer, released_by_step: info.last_appearance });
        }
    }

    // Only slots whose physical buffer actually changed need rewriting.
    let mut buffer_replace: HashMap<BufferKey, Buffer> = HashMap::new();
    for (key, buf_id) in occurrences {
        if let Some(chosen) = chosen_by_id.get(buf_id)
            && chosen.id().0 != *buf_id
        {
            buffer_replace.insert(*key, chosen.clone());
        }
    }
    (buffer_replace, memory_saved, buffers_reused, reuse_dependencies)
}
/// Arena lane identity: the device spec of the buffer's allocator and whether
/// the buffer is used by a `Copy` item. Each lane gets its own TLSF allocator,
/// peak tracker and arena buffer in the arena planner.
type LaneKey = (DeviceSpec, bool);
/// Arena-mode planner: sub-allocates every optimizable buffer from one shared
/// arena per lane using a TLSF allocator.
///
/// Lanes are keyed by [`LaneKey`], so buffers used by `Copy` items never share
/// an arena with other buffers. Copy buffers are also held for an extra
/// `last - first + 1` steps before their range is released.
///
/// Every schedule slot holding a planned buffer is rewritten to a view into its
/// lane's arena. When a new allocation overlaps bytes that an earlier buffer
/// released, a [`ReuseDependency`] from that buffer's last step to the new
/// buffer's first step is emitted.
///
/// Any TLSF failure aborts planning and returns an empty result (no rewrites).
/// In this mode `buffers_reused` counts rewritten slots, not distinct buffers.
fn memory_plan_arena(schedule: &Schedule, output_buffer_ids: &HashSet<u64>) -> MemoryPlannerResult {
    let empty_result = || MemoryPlannerResult {
        buffer_replace: HashMap::new(),
        memory_saved: 0,
        buffers_reused: 0,
        reuse_dependencies: Vec::new(),
    };
    let planner_input = analyze_liveness(schedule, output_buffer_ids);
    let liveness = planner_input.liveness;
    if liveness.is_empty() {
        return empty_result();
    }

    // Buffers touched by an item whose runtime effect is a Copy get their own lane.
    let mut copy_bufs: HashSet<u64> = HashSet::new();
    for item in schedule {
        let runtime_ast = crate::realize::runtime_effect_ast(&item.ast);
        if !matches!(runtime_ast.op(), Op::Copy { .. }) {
            continue;
        }
        for buffer in &item.buffers {
            let id = buffer.id().0;
            if liveness.contains_key(&id) {
                copy_bufs.insert(id);
            }
        }
    }
    let lane_key = |id: u64| -> LaneKey {
        let info = &liveness[&id];
        (info.prototype.allocator().device_spec(), copy_bufs.contains(&id))
    };
    // Copy buffers stay reserved for as many extra steps as they were live.
    let buf_hold: HashMap<u64, usize> = copy_bufs
        .iter()
        .map(|&id| {
            let info = &liveness[&id];
            (id, info.last_appearance - info.first_appearance + 1)
        })
        .collect();
    let nbytes_rounded: HashMap<u64, usize> =
        liveness.iter().map(|(&id, info)| (id, round_up(info.prototype.size(), MIN_BLOCK_SIZE))).collect();

    // Same ordering as the remap timeline, with the copy-buffer hold applied to releases.
    let mut events: Vec<BufferEvent> = Vec::with_capacity(liveness.len() * 2);
    for (&id, info) in &liveness {
        events.push(BufferEvent { timestep: info.first_appearance, is_alloc: true, buffer_id: id });
        events.push(BufferEvent {
            timestep: info.last_appearance + 1 + buf_hold.get(&id).copied().unwrap_or(0),
            is_alloc: false,
            buffer_id: id,
        });
    }
    events.sort_by_key(|e| (e.timestep, e.is_alloc, e.buffer_id));

    let total_bytes: usize = nbytes_rounded.values().sum();
    // Generous per-lane budget so fragmentation does not make TLSF fail.
    let arena_budget = total_bytes.saturating_mul(2).max(MIN_BLOCK_SIZE);
    let mut tlsfs: HashMap<LaneKey, tlsf::TlsfAllocator> = HashMap::new();
    let mut offsets: HashMap<u64, usize> = HashMap::new();
    let mut peaks: HashMap<LaneKey, usize> = HashMap::new();
    // Per lane: released byte ranges `(start, end, last step of the releasing buffer)`.
    let mut freed_ranges: HashMap<LaneKey, Vec<(usize, usize, usize)>> = HashMap::new();
    let mut reuse_dependencies: Vec<ReuseDependency> = Vec::new();

    for event in &events {
        let lane = lane_key(event.buffer_id);
        let info = &liveness[&event.buffer_id];
        let alloc =
            tlsfs.entry(lane.clone()).or_insert_with(|| tlsf::TlsfAllocator::new(arena_budget, 0, MIN_BLOCK_SIZE, 32));
        if event.is_alloc {
            let req = nbytes_rounded[&event.buffer_id];
            let off = match alloc.alloc(req, 1) {
                Ok(o) => o,
                Err(e) => {
                    tracing::warn!(?e, "arena planner: TLSF alloc failed; skipping arena rewrite");
                    return empty_result();
                }
            };
            offsets.insert(event.buffer_id, off);
            // The arena only has to cover the bytes actually viewed, not the rounded slot.
            let used_end = off + info.prototype.size();
            let peak = peaks.entry(lane.clone()).or_insert(0);
            if used_end > *peak {
                *peak = used_end;
            }
            let alloc_end = off + req;
            if let Some(ranges) = freed_ranges.get_mut(&lane) {
                consume_freed_overlaps(ranges, off, alloc_end, info.first_appearance, &mut reuse_dependencies);
            }
        } else if let Some(off) = offsets.get(&event.buffer_id).copied() {
            let req = nbytes_rounded[&event.buffer_id];
            if let Err(e) = alloc.free(off) {
                tracing::warn!(?e, "arena planner: TLSF free failed; skipping arena rewrite");
                return empty_result();
            }
            freed_ranges.entry(lane).or_default().push((off, off + req, info.last_appearance));
        }
    }

    // Any buffer of a lane can supply the allocator for that lane's arena.
    let mut lane_proto: HashMap<LaneKey, Buffer> = HashMap::with_capacity(peaks.len());
    for (&id, info) in &liveness {
        lane_proto.entry(lane_key(id)).or_insert_with(|| info.prototype.clone());
    }
    let mut arenas: HashMap<LaneKey, Buffer> = HashMap::new();
    for (lane, &peak) in &peaks {
        if peak == 0 {
            continue;
        }
        let arena_size = round_up(peak, MIN_BLOCK_SIZE);
        let prototype = lane_proto.get(lane).expect("every populated lane must have a prototype");
        let arena = Buffer::new(
            prototype.allocator_arc(),
            svod_dtype::DType::UInt8,
            vec![arena_size],
            svod_device::allocator::BufferOptions::default(),
        );
        arenas.insert(lane.clone(), arena);
    }

    let mut buffer_replace: HashMap<BufferKey, Buffer> = HashMap::new();
    let mut buffers_reused = 0usize;
    for (key, buf_id) in &planner_input.occurrences {
        let Some(&offset) = offsets.get(buf_id) else {
            continue;
        };
        let Some(arena) = arenas.get(&lane_key(*buf_id)) else {
            continue;
        };
        let info = &liveness[buf_id];
        let byte_size = info.prototype.size();
        let view = match arena.view(offset, byte_size) {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!(?e, "arena planner: view failed; skipping rewrite for one slot");
                continue;
            }
        };
        buffer_replace.insert(*key, view);
        buffers_reused += 1;
    }

    let arena_total: usize = peaks.values().map(|&p| round_up(p, MIN_BLOCK_SIZE)).sum();
    let memory_saved = total_bytes.saturating_sub(arena_total);
    debug!(
        buffers_planned = liveness.len(),
        buffers_replaced = buffers_reused,
        memory_saved_bytes = memory_saved,
        arena_count = arenas.len(),
        "arena memory planner complete"
    );
    MemoryPlannerResult { buffer_replace, memory_saved, buffers_reused, reuse_dependencies }
}

/// Emits a reuse dependency for every freed range overlapping `[off, alloc_end)`
/// and removes only the overlapped bytes from `ranges`.
///
/// Residue on either side of the new allocation is kept. A later allocation that
/// lands in that residue is therefore still ordered after the range's last user.
/// Dropping a partially overlapped range entirely would lose that edge.
fn consume_freed_overlaps(
    ranges: &mut Vec<(usize, usize, usize)>,
    off: usize,
    alloc_end: usize,
    successor_step: usize,
    reuse_dependencies: &mut Vec<ReuseDependency>,
) {
    let previous = std::mem::take(ranges);
    for (prev_off, prev_end, prev_last_step) in previous {
        let overlaps = off < prev_end && prev_off < alloc_end;
        if !overlaps {
            ranges.push((prev_off, prev_end, prev_last_step));
            continue;
        }
        reuse_dependencies.push(ReuseDependency { predecessor_step: prev_last_step, successor_step });
        if prev_off < off {
            ranges.push((prev_off, off, prev_last_step));
        }
        if alloc_end < prev_end {
            ranges.push((alloc_end, prev_end, prev_last_step));
        }
    }
}
#[allow(rustdoc::private_intra_doc_links)]
pub fn memory_planner(schedule: &Schedule, output_buffer_ids: &HashSet<u64>, mode: PlannerMode) -> MemoryPlannerResult {
let empty_result = || MemoryPlannerResult {
buffer_replace: HashMap::new(),
memory_saved: 0,
buffers_reused: 0,
reuse_dependencies: Vec::new(),
};
if matches!(mode, PlannerMode::Disabled) {
return empty_result();
}
if schedule.is_empty() {
return empty_result();
}
if matches!(mode, PlannerMode::Arena) {
return memory_plan_arena(schedule, output_buffer_ids);
}
let planner_input = analyze_liveness(schedule, output_buffer_ids);
let liveness = planner_input.liveness;
if liveness.is_empty() {
debug!("no optimizable buffers found");
return empty_result();
}
let events = build_event_timeline(&liveness);
let (buffer_replace, memory_saved, buffers_reused, reuse_dependencies) =
process_events(&events, &liveness, &planner_input.occurrences);
debug!(
buffers_analyzed = liveness.len(),
buffers_reused,
memory_saved_bytes = memory_saved,
"memory planner complete"
);
MemoryPlannerResult { buffer_replace, memory_saved, buffers_reused, reuse_dependencies }
}
pub fn apply_buffer_replacements(schedule: &mut Schedule, replacements: &HashMap<BufferKey, Buffer>) {
for (&key, replacement) in replacements {
if let Some(item) = schedule.get_mut(key.kernel_idx)
&& let Some(buffer) = item.buffers.get_mut(key.buffer_idx)
{
*buffer = replacement.clone();
}
}
}
pub fn apply_reuse_dependencies(schedule: &mut Schedule, reuse_dependencies: &[ReuseDependency]) {
for dep in reuse_dependencies {
if dep.predecessor_step == dep.successor_step {
continue;
}
debug_assert!(
dep.successor_step > dep.predecessor_step,
"reuse dependency must be forward-edge: predecessor={} >= successor={}",
dep.predecessor_step,
dep.successor_step,
);
if dep.predecessor_step >= schedule.len() {
continue;
}
let Some(successor) = schedule.get_mut(dep.successor_step) else {
continue;
};
if !successor.instance_dependencies.contains(&dep.predecessor_step) {
successor.instance_dependencies.push(dep.predecessor_step);
}
}
}
// Unit tests for this module live out of line; the `#[path]` attribute points
// the `tests` module at the test source file.
#[cfg(test)]
#[path = "../test/unit/memory_planner.rs"]
mod tests;