use std::sync::Arc;
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
/// A shared budget of memory, measured in bytes, backed by a tokio [`Semaphore`].
///
/// Every byte of budget is one semaphore permit. Clones share the same `Arc`'d semaphore,
/// so all clones draw from a single pool.
#[derive(Debug)]
pub struct MemoryPermitting {
    /// Semaphore created with one permit per byte of budget.
    inner: Arc<Semaphore>,
    /// Total budget in bytes the semaphore was created with; also the largest single
    /// request `acquire` accepts.
    mem_in_bytes: usize,
}
impl MemoryPermitting {
    /// Largest budget that may be passed to [`MemoryPermitting::new`].
    ///
    /// Matches the limit tokio documents as `Semaphore::MAX_PERMITS` (`usize::MAX >> 3`).
    pub const MAX: usize = usize::MAX >> 3;

    /// Creates a pool with a budget of `mem_in_bytes` bytes.
    ///
    /// # Panics
    ///
    /// Panics (inside `Semaphore::new`) if `mem_in_bytes` exceeds [`Self::MAX`].
    #[must_use]
    pub fn new(mem_in_bytes: usize) -> Self {
        Self { inner: Arc::new(Semaphore::new(mem_in_bytes)), mem_in_bytes }
    }

    /// Returns the total budget of this pool in bytes, as passed to [`Self::new`].
    ///
    /// Use [`Self::available_memory`] for the amount that is currently unclaimed.
    #[must_use]
    pub fn total_memory(&self) -> usize {
        self.mem_in_bytes
    }

    /// Returns the number of bytes currently unclaimed in the pool (the semaphore's
    /// available permits).
    #[must_use]
    pub fn available_memory(&self) -> usize {
        self.inner.available_permits()
    }

    /// Reserves `mem_in_bytes` bytes from the pool, waiting until enough are free.
    ///
    /// # Errors
    ///
    /// - [`MemoryPermitError::ExceedsMaxPermittedMemory`] if the request is larger than the
    ///   pool's total budget, so it could never be granted.
    /// - [`MemoryPermitError::TriedToAcquireZero`] if `mem_in_bytes` is 0.
    /// - [`MemoryPermitError::Closed`] if the underlying semaphore has been closed.
    pub async fn acquire(&self, mem_in_bytes: usize) -> Result<MemoryPermit, MemoryPermitError> {
        if mem_in_bytes > self.mem_in_bytes {
            return Err(MemoryPermitError::ExceedsMaxPermittedMemory);
        }
        if mem_in_bytes == 0 {
            return Err(MemoryPermitError::TriedToAcquireZero);
        }
        let permits = accquire_raw(&self.inner, mem_in_bytes).await?;
        // The permit carries the pool's total budget, not the number of bytes requested.
        Ok(MemoryPermit { inner: permits, mem_in_bytes: self.mem_in_bytes })
    }
}
impl Clone for MemoryPermitting {
fn clone(&self) -> Self {
Self { inner: self.inner.clone(), mem_in_bytes: self.mem_in_bytes }
}
}
#[derive(Debug, thiserror::Error)]
pub enum MemoryPermitError {
#[error("Request a permit for 0 memory.")]
TriedToAcquireZero,
#[error("Requested memory exceeds the maximum permitted memory")]
ExceedsMaxPermittedMemory,
#[error("Split request with insufficient memory permit")]
NotEnoughMemoryToSplit,
#[error("The semaphore has been explicitly closed, this is a bug")]
Closed(#[from] AcquireError),
}
/// A reservation of bytes taken from a [`MemoryPermitting`] pool.
///
/// The bytes go back to the pool when the underlying semaphore permits are dropped.
#[derive(Debug)]
pub struct MemoryPermit {
    /// Semaphore permits backing this reservation. There may be several, because one tokio
    /// acquisition covers at most `u32::MAX` units and because of splitting/growing.
    inner: Vec<OwnedSemaphorePermit>,
    /// NOTE: this is the *total budget* of the originating pool (its cap), not the number
    /// of bytes held here — use `num_bytes()` for that.
    mem_in_bytes: usize,
}
impl MemoryPermit {
#[must_use]
pub fn num_bytes(&self) -> usize {
#[allow(clippy::pedantic)]
self.inner.iter().map(|p| p.num_permits()).sum()
}
pub fn split(&mut self, mem_in_bytes: usize) -> Result<MemoryPermit, MemoryPermitError> {
if mem_in_bytes > self.num_bytes() {
return Err(MemoryPermitError::NotEnoughMemoryToSplit);
} else if mem_in_bytes == 0 {
return Err(MemoryPermitError::TriedToAcquireZero);
}
let mut permits = Vec::new();
let mut to_acquire = mem_in_bytes;
while let Some(permit) = self.inner.last_mut() {
let num_permits = permit.num_permits();
if num_permits <= to_acquire {
to_acquire -= num_permits;
permits.push(self.inner.pop().unwrap());
} else {
permits.push(permit.split(to_acquire).unwrap());
}
}
Ok(MemoryPermit { inner: permits, mem_in_bytes: self.mem_in_bytes })
}
pub async fn increase(&mut self, mem_in_bytes: usize) -> Result<(), MemoryPermitError> {
if mem_in_bytes == 0 {
return Ok(());
}
self.inner.extend(
accquire_raw(
self.inner
.first()
.expect("We should have at least one permit, this is a bug.")
.semaphore(),
mem_in_bytes,
)
.await?,
);
Ok(())
}
pub fn release(&mut self, mem_in_bytes: usize) -> Result<(), MemoryPermitError> {
if mem_in_bytes == 0 {
return Ok(());
}
let _ = self.split(mem_in_bytes)?;
Ok(())
}
}
async fn accquire_raw(
inner: &Arc<Semaphore>,
mem_in_bytes: usize,
) -> Result<Vec<OwnedSemaphorePermit>, MemoryPermitError> {
let mut permits = Vec::new();
let mut to_acquire = mem_in_bytes;
while to_acquire > 0 {
let n = to_acquire.min(u32::MAX as usize);
let permit = inner.clone().acquire_many_owned(n as u32).await?;
permits.push(permit);
to_acquire -= n;
}
Ok(permits)
}