use crate::device::context::DEVICE_MANAGER;
use crate::{Device, Result};
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use super::config::EagerExecutionConfig;
/// Pool of reusable memory blocks, kept in a separate list per [`Device`].
///
/// Blocks are obtained from the allocator of a device context (see `allocate` / `pre_warm`)
/// and are handed back to the pool, rather than to the allocator, by `deallocate`.
#[allow(dead_code)]
pub(super) struct MemoryPool {
/// Every block the pool tracks, keyed by the device it was allocated on. The methods on
/// this type always take the write lock when touching the table.
pub(super) blocks: RwLock<HashMap<Device, Vec<MemoryBlock>>>,
/// Configuration handed to `new`. It is stored here but not read by any of the pool
/// methods visible alongside this struct.
pub(super) config: EagerExecutionConfig,
}
/// Bookkeeping record for a single allocation tracked by `MemoryPool`.
#[derive(Debug)]
#[allow(dead_code)]
pub(super) struct MemoryBlock {
/// Pointer returned by the device context's allocator for this block. The pool only
/// stores and compares it; it is never dereferenced by the pool itself.
pub(super) ptr: *mut u8,
/// Size that was requested from the allocator when the block was created.
pub(super) size: usize,
/// `true` while the block sits unused in the pool and may be handed out again.
pub(super) available: bool,
/// Moment the block was created, last handed out, or last returned to the pool.
pub(super) last_used: Instant,
}
// NOTE(review): `MemoryBlock` holds a raw `*mut u8`, which is neither `Send` nor `Sync` by
// default, so these impls are a manual promise that moving/sharing the record across threads
// is sound. In the pool code next to this type the records are only touched while holding
// the write lock on `MemoryPool::blocks`, and `ptr` is never dereferenced there. Whether the
// memory behind `ptr` may really be used from any thread depends on the allocator — confirm.
unsafe impl Send for MemoryBlock {}
unsafe impl Sync for MemoryBlock {}
#[allow(dead_code)]
impl MemoryPool {
    /// Idle time after which an available block is evicted by
    /// [`MemoryPool::cleanup_old_blocks`].
    const MAX_IDLE: Duration = Duration::from_secs(60);

    /// Creates a pool with an empty block table that stores `config`.
    pub(super) fn new(config: EagerExecutionConfig) -> Self {
        Self {
            blocks: RwLock::new(HashMap::new()),
            config,
        }
    }

    /// Returns a pointer to a block of at least `size` on `device`.
    ///
    /// The first available pooled block for `device` whose recorded size is `>= size` is
    /// marked as in use and returned. If none fits, a new block is requested from the
    /// allocator of the device's context, recorded in the pool as in use, and returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from `DEVICE_MANAGER.get_context` or from the allocator.
    ///
    /// # Panics
    ///
    /// Panics if the lock around the block table is poisoned.
    pub(super) fn allocate(&self, device: &Device, size: usize) -> Result<*mut u8> {
        // Phase 1: try to reuse a pooled block. The guard is scoped so the lock is released
        // before calling into the allocator below.
        {
            let mut blocks = self
                .blocks
                .write()
                .expect("write lock should not be poisoned");
            let reusable = blocks
                .entry(*device)
                .or_default()
                .iter_mut()
                .find(|block| block.available && block.size >= size);
            if let Some(block) = reusable {
                block.available = false;
                block.last_used = Instant::now();
                return Ok(block.ptr);
            }
        }

        // Phase 2: nothing reusable, so obtain fresh memory without holding the pool lock.
        let context = DEVICE_MANAGER.get_context(device)?;
        let ptr = context.allocator().allocate(size)?;

        // Phase 3: record the new block as already handed out.
        self.blocks
            .write()
            .expect("write lock should not be poisoned")
            .entry(*device)
            .or_default()
            .push(MemoryBlock {
                ptr,
                size,
                available: false,
                last_used: Instant::now(),
            });
        Ok(ptr)
    }

    /// Marks the block identified by `ptr` on `device` as available for reuse.
    ///
    /// The memory is kept in the pool; nothing is returned to the allocator here. A `ptr`
    /// that the pool does not track for `device` is ignored and still yields `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if the lock around the block table is poisoned.
    pub(super) fn deallocate(&self, device: &Device, ptr: *mut u8) -> Result<()> {
        let mut blocks = self
            .blocks
            .write()
            .expect("write lock should not be poisoned");
        if let Some(device_blocks) = blocks.get_mut(device) {
            // Only the first record with a matching pointer is updated.
            if let Some(block) = device_blocks.iter_mut().find(|block| block.ptr == ptr) {
                block.available = true;
                block.last_used = Instant::now();
            }
        }
        Ok(())
    }

    /// Drops every available block that has been idle for more than 60 seconds.
    ///
    /// Blocks that are currently handed out are never removed.
    ///
    /// # Panics
    ///
    /// Panics if the lock around the block table is poisoned.
    pub(super) fn cleanup_old_blocks(&self) {
        self.cleanup_blocks_idle_longer_than(Self::MAX_IDLE);
    }

    /// Removes available blocks whose `last_used` is more than `max_idle` in the past.
    fn cleanup_blocks_idle_longer_than(&self, max_idle: Duration) {
        let mut blocks = self
            .blocks
            .write()
            .expect("write lock should not be poisoned");
        // Read the clock only after the lock is held, and saturate instead of relying on
        // `duration_since`: a `last_used` later than `now` must count as "not idle" rather
        // than risk a panic.
        let now = Instant::now();
        for device_blocks in blocks.values_mut() {
            // NOTE(review): evicted records are simply dropped. No `Drop` impl or allocator
            // call that releases `ptr` is visible here — confirm the memory is freed elsewhere.
            device_blocks.retain(|block| {
                let idle = now.saturating_duration_since(block.last_used);
                !(block.available && idle > max_idle)
            });
        }
    }

    /// Allocates `num_blocks` blocks of `size` on `device` and adds them to the pool as
    /// available, so later `allocate` calls can reuse them.
    ///
    /// The allocator is called without holding the pool lock; the new blocks are published
    /// in one short critical section afterwards. If an allocation fails, the blocks obtained
    /// before the failure are still added to the pool and the error is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from `DEVICE_MANAGER.get_context` or from the allocator.
    ///
    /// # Panics
    ///
    /// Panics if the lock around the block table is poisoned.
    pub(super) fn pre_warm(&self, device: &Device, size: usize, num_blocks: usize) -> Result<()> {
        let context = DEVICE_MANAGER.get_context(device)?;

        let mut fresh = Vec::new();
        let outcome = (0..num_blocks).try_for_each(|_| -> Result<()> {
            let ptr = context.allocator().allocate(size)?;
            fresh.push(MemoryBlock {
                ptr,
                size,
                available: true,
                last_used: Instant::now(),
            });
            Ok(())
        });

        // Publish whatever was obtained, even on failure, so no allocated pointer is lost.
        self.blocks
            .write()
            .expect("write lock should not be poisoned")
            .entry(*device)
            .or_default()
            .extend(fresh);
        outcome
    }
}
/// Record that ties an operation name to a device and a memory estimate.
///
/// NOTE(review): nothing in the code visible next to this struct constructs or reads it, and
/// no `Drop` impl is visible, so what the "guard" actually enforces should be confirmed at
/// its use sites.
#[allow(dead_code)]
pub(super) struct MemoryGuard {
/// Device the estimate refers to.
pub(super) device: Device,
/// Estimated amount of memory for the operation (presumably bytes — TODO confirm).
pub(super) estimated_memory: usize,
/// Name or description of the operation this record belongs to.
pub(super) operation: String,
}