use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use thiserror::Error;
/// Errors produced by workspace acquisition and pool management.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum WorkspaceError {
/// No free workspace could satisfy the request and auto-expansion was disabled.
#[error(
"Workspace allocation failed: requested {requested} bytes, available {available} bytes"
)]
AllocationFailed { requested: usize, available: usize },
/// No workspace with the given id exists in the pool.
#[error("Workspace not found: {0}")]
NotFound(String),
/// The requested size is invalid (e.g. larger than the configured maximum).
#[error("Invalid workspace size: {0}")]
InvalidSize(usize),
/// A configured byte limit would be exceeded.
/// NOTE(review): this variant is not constructed anywhere in this file — confirm it is used elsewhere.
#[error("Workspace limit exceeded: {limit} bytes")]
LimitExceeded { limit: usize },
/// The workspace has already been acquired and not yet released.
#[error("Workspace is in use")]
InUse,
}
/// Strategy used to match a free workspace to an allocation request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllocationStrategy {
/// Pick the smallest free workspace that is large enough.
BestFit,
/// Pick any free workspace that is large enough (scan order over the backing HashMap is unspecified).
FirstFit,
/// Pop only from the size bucket the request maps to.
ExactFit,
/// Round request sizes up to the next power of two before matching.
PowerOfTwo,
}
/// Tuning parameters for a `WorkspacePool`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkspaceConfig {
/// Size in bytes of the smallest pre-allocated bucket.
pub initial_size: usize,
/// Hard cap on any single workspace size; larger requests are rejected.
pub max_size: usize,
/// Multiplier between consecutive bucket sizes (geometric growth).
pub growth_factor: f64,
/// How free workspaces are matched to allocation requests.
pub strategy: AllocationStrategy,
/// When true, the pool creates a new workspace if no free one fits.
pub auto_expand: bool,
/// When true, `defragment()` is allowed to merge free workspaces.
pub enable_defragmentation: bool,
/// Fragmentation ratio at or above which defragmentation proceeds.
pub defrag_threshold: f64,
/// Maximum number of pre-allocated size buckets.
pub num_buckets: usize,
}
impl Default for WorkspaceConfig {
    /// Baseline configuration: 1 MiB starting bucket, 1 GiB size cap,
    /// doubling bucket growth, best-fit matching, auto-expansion enabled,
    /// and defragmentation disabled.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            initial_size: MIB,
            max_size: 1024 * MIB,
            growth_factor: 2.0,
            strategy: AllocationStrategy::BestFit,
            auto_expand: true,
            enable_defragmentation: false,
            defrag_threshold: 0.5,
            num_buckets: 16,
        }
    }
}
impl WorkspaceConfig {
pub fn large_model() -> Self {
Self {
initial_size: 64 * 1024 * 1024, max_size: 8 * 1024 * 1024 * 1024, growth_factor: 1.5,
num_buckets: 32,
..Default::default()
}
}
pub fn small_model() -> Self {
Self {
initial_size: 256 * 1024, max_size: 128 * 1024 * 1024, growth_factor: 2.0,
num_buckets: 8,
..Default::default()
}
}
pub fn memory_optimized() -> Self {
Self {
initial_size: 512 * 1024, max_size: 256 * 1024 * 1024, growth_factor: 1.2,
enable_defragmentation: true,
defrag_threshold: 0.3,
..Default::default()
}
}
}
/// Bookkeeping record for a single reusable workspace. Only metadata is
/// tracked here; no backing memory buffer is visible in this file.
#[derive(Debug, Clone)]
pub struct Workspace {
// Pool-assigned identifier of the form "ws_{n}".
pub id: String,
// Capacity in bytes.
pub size: usize,
// True while acquired via `acquire()` and not yet released.
pub in_use: bool,
// Number of successful `acquire()` calls over this workspace's lifetime.
pub allocation_count: usize,
// Cumulative in-use time. NOTE(review): never updated anywhere in this
// file — it stays `Duration::ZERO`; confirm whether timing was intended.
pub total_use_time: std::time::Duration,
}
impl Workspace {
pub fn new(id: String, size: usize) -> Self {
Self {
id,
size,
in_use: false,
allocation_count: 0,
total_use_time: std::time::Duration::ZERO,
}
}
pub fn acquire(&mut self) -> Result<(), WorkspaceError> {
if self.in_use {
return Err(WorkspaceError::InUse);
}
self.in_use = true;
self.allocation_count += 1;
Ok(())
}
pub fn release(&mut self) {
self.in_use = false;
}
}
/// A pool of reusable workspaces bucketed by size.
pub struct WorkspacePool {
// Sizing and matching knobs; fixed at construction.
config: WorkspaceConfig,
// Every workspace ever created (free and in-use alike), keyed by id.
workspaces: HashMap<String, Workspace>,
// Ids of free workspaces keyed by size bucket; `next_id` feeds "ws_{n}" ids.
free_lists: HashMap<usize, VecDeque<String>>, next_id: usize,
// Aggregate counters exposed via `stats()`.
stats: WorkspaceStats,
}
impl WorkspacePool {
    /// Builds a pool and pre-allocates one free workspace per configured
    /// size bucket.
    pub fn new(config: WorkspaceConfig) -> Self {
        let mut pool = Self {
            config,
            workspaces: HashMap::new(),
            free_lists: HashMap::new(),
            next_id: 0,
            stats: WorkspaceStats::default(),
        };
        pool.preallocate_workspaces();
        pool
    }

    /// Creates one free workspace for each bucket size.
    fn preallocate_workspaces(&mut self) {
        let sizes = self.compute_bucket_sizes();
        for size in sizes {
            let _ = self.create_workspace(size);
        }
    }

    /// Geometric series of bucket sizes: `initial_size * growth_factor^k`,
    /// at most `num_buckets` entries, stopping once the next size would
    /// exceed `max_size`.
    fn compute_bucket_sizes(&self) -> Vec<usize> {
        let mut sizes = Vec::with_capacity(self.config.num_buckets);
        let mut size = self.config.initial_size;
        for _ in 0..self.config.num_buckets {
            sizes.push(size);
            size = (size as f64 * self.config.growth_factor) as usize;
            if size > self.config.max_size {
                break;
            }
        }
        sizes
    }

    /// Creates a new free workspace of `size` bytes, registers it in the
    /// appropriate free list, and returns its id.
    fn create_workspace(&mut self, size: usize) -> String {
        let id = format!("ws_{}", self.next_id);
        self.next_id += 1;
        let workspace = Workspace::new(id.clone(), size);
        self.workspaces.insert(id.clone(), workspace);
        let bucket = self.size_to_bucket(size);
        self.free_lists
            .entry(bucket)
            .or_default()
            .push_back(id.clone());
        self.stats.total_created += 1;
        self.stats.current_total_size += size;
        id
    }

    /// Maps a size to its free-list bucket key. Under `PowerOfTwo` this is
    /// the next power of two; otherwise it is the first bucket size that can
    /// hold `size`, and sizes larger than every bucket key to themselves.
    fn size_to_bucket(&self, size: usize) -> usize {
        match self.config.strategy {
            AllocationStrategy::PowerOfTwo => size.next_power_of_two(),
            _ => {
                let sizes = self.compute_bucket_sizes();
                sizes.iter().find(|&&s| s >= size).copied().unwrap_or(size)
            }
        }
    }

    /// Acquires a workspace of at least `size` bytes, creating a new one if
    /// nothing free fits and `auto_expand` is enabled.
    ///
    /// # Errors
    /// - `InvalidSize` if `size` exceeds `max_size`.
    /// - `AllocationFailed` if no workspace fits and `auto_expand` is off.
    pub fn allocate(&mut self, size: usize) -> Result<String, WorkspaceError> {
        if size > self.config.max_size {
            return Err(WorkspaceError::InvalidSize(size));
        }
        let workspace_id = match self.config.strategy {
            AllocationStrategy::BestFit => self.find_best_fit(size),
            AllocationStrategy::FirstFit => self.find_first_fit(size),
            AllocationStrategy::ExactFit => self.find_exact_fit(size),
            AllocationStrategy::PowerOfTwo => {
                let bucket_size = size.next_power_of_two();
                self.find_first_fit(bucket_size)
            }
        };
        match workspace_id {
            Some(id) => {
                self.workspaces
                    .get_mut(&id)
                    .expect("workspace id returned by a find_* helper is valid")
                    .acquire()?;
                self.stats.total_allocations += 1;
                Ok(id)
            }
            None => {
                if self.config.auto_expand {
                    let new_size = self.size_to_bucket(size);
                    let id = self.create_workspace(new_size);
                    // BUGFIX: `create_workspace` registers the new workspace
                    // in a free list; remove it before handing it out so the
                    // free lists never reference an in-use workspace (which
                    // `find_exact_fit` would otherwise pop and fail on).
                    self.remove_from_free_list(&id, new_size);
                    self.workspaces
                        .get_mut(&id)
                        .expect("workspace id from create_workspace is valid")
                        .acquire()?;
                    self.stats.total_allocations += 1;
                    self.stats.total_expansions += 1;
                    Ok(id)
                } else {
                    Err(WorkspaceError::AllocationFailed {
                        requested: size,
                        available: self.max_available_size(),
                    })
                }
            }
        }
    }

    /// Returns a workspace to the pool, making it available for reuse.
    /// Releasing an id that is already free is a no-op and returns `Ok`.
    ///
    /// # Errors
    /// Returns `NotFound` if no workspace with this id exists.
    pub fn release(&mut self, id: &str) -> Result<(), WorkspaceError> {
        let workspace_size = {
            let workspace = self
                .workspaces
                .get_mut(id)
                .ok_or_else(|| WorkspaceError::NotFound(id.to_string()))?;
            // BUGFIX: guard against double release, which previously pushed
            // a duplicate id onto the free list and could hand the same
            // workspace out to two callers.
            if !workspace.in_use {
                return Ok(());
            }
            workspace.release();
            workspace.size
        };
        self.stats.total_releases += 1;
        let bucket = self.size_to_bucket(workspace_size);
        self.free_lists
            .entry(bucket)
            .or_default()
            .push_back(id.to_string());
        Ok(())
    }

    /// O(n) scan for the smallest free workspace that can hold `size`;
    /// removes the winner from its free list before returning it.
    fn find_best_fit(&mut self, size: usize) -> Option<String> {
        let mut best_id: Option<String> = None;
        let mut best_size = usize::MAX;
        for (ws_id, workspace) in &self.workspaces {
            if !workspace.in_use && workspace.size >= size && workspace.size < best_size {
                best_id = Some(ws_id.clone());
                best_size = workspace.size;
            }
        }
        if let Some(ref id) = best_id {
            // Keep the free lists consistent with the pending acquisition.
            self.remove_from_free_list(id, best_size);
        }
        best_id
    }

    /// Returns any free workspace large enough for `size`. Iteration order
    /// of the backing HashMap is unspecified, so "first" is arbitrary.
    fn find_first_fit(&mut self, size: usize) -> Option<String> {
        let found = self
            .workspaces
            .iter()
            .find(|(_, ws)| !ws.in_use && ws.size >= size)
            .map(|(id, ws)| (id.clone(), ws.size));
        match found {
            Some((id, ws_size)) => {
                self.remove_from_free_list(&id, ws_size);
                Some(id)
            }
            None => None,
        }
    }

    /// Pops a free workspace from the bucket `size` maps to, if any.
    fn find_exact_fit(&mut self, size: usize) -> Option<String> {
        let bucket = self.size_to_bucket(size);
        self.free_lists.get_mut(&bucket).and_then(|list| list.pop_front())
    }

    /// Size of the largest currently-free workspace (0 if none are free).
    fn max_available_size(&self) -> usize {
        self.workspaces
            .values()
            .filter(|ws| !ws.in_use)
            .map(|ws| ws.size)
            .max()
            .unwrap_or(0)
    }

    /// Read-only view of the pool's aggregate counters.
    pub fn stats(&self) -> &WorkspaceStats {
        &self.stats
    }

    /// Pairwise-merges free workspaces when defragmentation is enabled and
    /// the fragmentation ratio meets `defrag_threshold`. Pairs whose merged
    /// size would exceed `max_size` are skipped.
    ///
    /// `freed_bytes` is always 0 here: merging replaces two records with one
    /// of the combined size, so no bytes are returned to the system.
    /// NOTE(review): merged workspaces bypass `create_workspace`, so
    /// `total_created` is not incremented for them — confirm intended.
    pub fn defragment(&mut self) -> DefragmentationResult {
        if !self.config.enable_defragmentation {
            return DefragmentationResult {
                freed_bytes: 0,
                merged_workspaces: 0,
            };
        }
        let fragmentation_ratio = self.compute_fragmentation_ratio();
        if fragmentation_ratio < self.config.defrag_threshold {
            return DefragmentationResult {
                freed_bytes: 0,
                merged_workspaces: 0,
            };
        }
        let mut free_blocks: Vec<(String, usize)> = self
            .workspaces
            .iter()
            .filter_map(|(id, ws)| {
                if ws.in_use {
                    None
                } else {
                    Some((id.clone(), ws.size))
                }
            })
            .collect();
        if free_blocks.len() < 2 {
            self.stats.total_defragmentations += 1;
            return DefragmentationResult {
                freed_bytes: 0,
                merged_workspaces: 0,
            };
        }
        // Merge adjacent pairs in ascending size order.
        free_blocks.sort_by_key(|(_, size)| *size);
        let freed_bytes = 0;
        let mut merged_workspaces = 0;
        let mut pair_index = 0usize;
        while pair_index + 1 < free_blocks.len() {
            let (id_a, size_a) = &free_blocks[pair_index];
            let (id_b, size_b) = &free_blocks[pair_index + 1];
            let merged_size = size_a.saturating_add(*size_b);
            if merged_size > self.config.max_size {
                pair_index += 2;
                continue;
            }
            self.remove_from_free_list(id_a, *size_a);
            self.remove_from_free_list(id_b, *size_b);
            self.workspaces.remove(id_a);
            self.workspaces.remove(id_b);
            // `current_total_size` is unchanged: a + b bytes out, a + b in.
            let merged_id = format!("ws_{}", self.next_id);
            self.next_id += 1;
            self.workspaces.insert(
                merged_id.clone(),
                Workspace::new(merged_id.clone(), merged_size),
            );
            let bucket = self.size_to_bucket(merged_size);
            self.free_lists
                .entry(bucket)
                .or_default()
                .push_back(merged_id);
            merged_workspaces += 1;
            pair_index += 2;
        }
        self.stats.total_defragmentations += 1;
        DefragmentationResult {
            freed_bytes,
            merged_workspaces,
        }
    }

    /// Removes `id` from the free list of the bucket `size` maps to.
    fn remove_from_free_list(&mut self, id: &str, size: usize) {
        let bucket = self.size_to_bucket(size);
        if let Some(list) = self.free_lists.get_mut(&bucket) {
            list.retain(|ws_id| ws_id != id);
        }
    }

    /// 1 - (largest free block / total free bytes): 0.0 when one block holds
    /// all free space (or nothing is free), approaching 1.0 as free space
    /// scatters across many blocks.
    fn compute_fragmentation_ratio(&self) -> f64 {
        let total_free = self
            .workspaces
            .values()
            .filter(|ws| !ws.in_use)
            .map(|ws| ws.size)
            .sum::<usize>();
        let max_free = self.max_available_size();
        if total_free == 0 {
            0.0
        } else {
            1.0 - (max_free as f64 / total_free as f64)
        }
    }

    /// Drops every workspace, resets the statistics, and re-runs the initial
    /// bucket pre-allocation.
    pub fn clear(&mut self) {
        self.workspaces.clear();
        self.free_lists.clear();
        self.stats = WorkspaceStats::default();
        self.preallocate_workspaces();
    }
}
/// Thread-safe handle to a `WorkspacePool`, shared via `Arc<Mutex<_>>`.
/// Clones share the same underlying pool.
pub struct SharedWorkspacePool {
inner: Arc<Mutex<WorkspacePool>>,
}
impl SharedWorkspacePool {
    /// Wraps a freshly constructed `WorkspacePool` behind an `Arc<Mutex<_>>`.
    pub fn new(config: WorkspaceConfig) -> Self {
        let pool = WorkspacePool::new(config);
        Self {
            inner: Arc::new(Mutex::new(pool)),
        }
    }

    /// Thread-safe `WorkspacePool::allocate`.
    ///
    /// # Panics
    /// Panics if the mutex has been poisoned by a panicking holder.
    pub fn allocate(&self, size: usize) -> Result<String, WorkspaceError> {
        let mut pool = self.inner.lock().expect("lock should not be poisoned");
        pool.allocate(size)
    }

    /// Thread-safe `WorkspacePool::release`.
    ///
    /// # Panics
    /// Panics if the mutex has been poisoned by a panicking holder.
    pub fn release(&self, id: &str) -> Result<(), WorkspaceError> {
        let mut pool = self.inner.lock().expect("lock should not be poisoned");
        pool.release(id)
    }

    /// Snapshot of the pool's statistics (cloned so the lock is not held).
    ///
    /// # Panics
    /// Panics if the mutex has been poisoned by a panicking holder.
    pub fn stats(&self) -> WorkspaceStats {
        let pool = self.inner.lock().expect("lock should not be poisoned");
        pool.stats().clone()
    }

    /// Thread-safe `WorkspacePool::defragment`.
    ///
    /// # Panics
    /// Panics if the mutex has been poisoned by a panicking holder.
    pub fn defragment(&self) -> DefragmentationResult {
        let mut pool = self.inner.lock().expect("lock should not be poisoned");
        pool.defragment()
    }
}
impl Clone for SharedWorkspacePool {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// Aggregate counters maintained by `WorkspacePool`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkspaceStats {
/// Workspaces created via `create_workspace` (pre-allocation + expansion).
pub total_created: usize,
/// Successful `allocate` calls.
pub total_allocations: usize,
/// Successful `release` calls.
pub total_releases: usize,
/// Allocations that required creating a new workspace (auto-expansion).
pub total_expansions: usize,
/// Completed `defragment` passes (counted once the threshold is met).
pub total_defragmentations: usize,
/// Sum in bytes of all workspace sizes currently tracked by the pool.
pub current_total_size: usize,
}
impl WorkspaceStats {
    /// Fraction of allocations served from existing workspaces rather than
    /// by expanding the pool; 0.0 when nothing has been allocated yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_allocations {
            0 => 0.0,
            n => 1.0 - (self.total_expansions as f64 / n as f64),
        }
    }

    /// Mean size in bytes over every workspace ever created; 0.0 when the
    /// pool has created none.
    pub fn avg_workspace_size(&self) -> f64 {
        match self.total_created {
            0 => 0.0,
            n => self.current_total_size as f64 / n as f64,
        }
    }
}
/// Outcome of a `WorkspacePool::defragment` pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefragmentationResult {
/// Bytes returned to the system. NOTE(review): always 0 in this file —
/// merging combines records without freeing memory.
pub freed_bytes: usize,
/// Number of merges performed (each merge combines two free workspaces).
pub merged_workspaces: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workspace_creation() {
        let ws = Workspace::new("test".to_string(), 1024);
        assert_eq!(ws.size, 1024);
        assert!(!ws.in_use);
        assert_eq!(ws.allocation_count, 0);
    }

    #[test]
    fn test_workspace_acquire_release() {
        let mut ws = Workspace::new("test".to_string(), 1024);
        assert!(ws.acquire().is_ok());
        assert!(ws.in_use);
        assert_eq!(ws.allocation_count, 1);
        // Double acquire must fail until the workspace is released.
        assert!(ws.acquire().is_err());
        ws.release();
        assert!(!ws.in_use);
        assert!(ws.acquire().is_ok());
        assert_eq!(ws.allocation_count, 2);
    }

    #[test]
    fn test_workspace_config() {
        let config = WorkspaceConfig::large_model();
        assert!(config.initial_size > WorkspaceConfig::default().initial_size);
        let config = WorkspaceConfig::small_model();
        assert!(config.max_size < WorkspaceConfig::default().max_size);
    }

    #[test]
    fn test_workspace_pool_creation() {
        let config = WorkspaceConfig::default();
        let pool = WorkspacePool::new(config);
        assert!(pool.stats().total_created > 0);
    }

    #[test]
    fn test_workspace_allocation() {
        let config = WorkspaceConfig::default();
        let mut pool = WorkspacePool::new(config);
        let id = pool.allocate(512).expect("unwrap");
        assert!(!id.is_empty());
        let workspace = pool.workspaces.get(&id).expect("unwrap");
        assert!(workspace.in_use);
        assert!(workspace.size >= 512);
    }

    #[test]
    fn test_workspace_release() {
        let config = WorkspaceConfig::default();
        let mut pool = WorkspacePool::new(config);
        let id = pool.allocate(512).expect("unwrap");
        assert!(pool.release(&id).is_ok());
        let workspace = pool.workspaces.get(&id).expect("unwrap");
        assert!(!workspace.in_use);
    }

    #[test]
    fn test_allocation_strategies() {
        for strategy in [
            AllocationStrategy::BestFit,
            AllocationStrategy::FirstFit,
            AllocationStrategy::ExactFit,
            AllocationStrategy::PowerOfTwo,
        ] {
            let config = WorkspaceConfig {
                strategy,
                ..Default::default()
            };
            let mut pool = WorkspacePool::new(config);
            let id = pool.allocate(512);
            assert!(id.is_ok());
        }
    }

    #[test]
    fn test_auto_expansion() {
        let config = WorkspaceConfig {
            initial_size: 1024,
            max_size: 1024 * 1024,
            auto_expand: true,
            num_buckets: 2,
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);
        let initial_expansions = pool.stats().total_expansions;
        // 5 KiB exceeds both pre-allocated buckets (1 KiB, 2 KiB), forcing
        // an expansion.
        let id = pool.allocate(5 * 1024);
        assert!(id.is_ok());
        assert!(pool.stats().total_expansions > initial_expansions);
    }

    #[test]
    fn test_allocation_without_expansion() {
        let config = WorkspaceConfig {
            initial_size: 1024,
            max_size: 2048,
            auto_expand: false,
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);
        // 100 KiB exceeds `max_size`, so allocation must fail. Previously
        // the result was discarded and this test asserted nothing.
        let result = pool.allocate(100 * 1024);
        assert!(matches!(result, Err(WorkspaceError::InvalidSize(_))));
    }

    #[test]
    fn test_stats_hit_rate() {
        let stats = WorkspaceStats {
            total_allocations: 10,
            total_expansions: 2,
            ..Default::default()
        };
        assert_eq!(stats.hit_rate(), 0.8);
    }

    #[test]
    fn test_shared_workspace_pool() {
        let config = WorkspaceConfig::default();
        let pool = SharedWorkspacePool::new(config);
        let id = pool.allocate(512).expect("unwrap");
        assert!(pool.release(&id).is_ok());
        let stats = pool.stats();
        assert!(stats.total_allocations > 0);
    }

    #[test]
    fn test_fragmentation_ratio() {
        let config = WorkspaceConfig::default();
        let pool = WorkspacePool::new(config);
        let ratio = pool.compute_fragmentation_ratio();
        assert!((0.0..=1.0).contains(&ratio));
    }

    #[test]
    fn test_defragmentation() {
        let config = WorkspaceConfig {
            enable_defragmentation: true,
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);
        let result = pool.defragment();
        assert_eq!(result.freed_bytes, 0);
    }

    #[test]
    fn test_defragmentation_merges_when_threshold_met() {
        let config = WorkspaceConfig {
            enable_defragmentation: true,
            defrag_threshold: 0.0,
            num_buckets: 4,
            ..Default::default()
        };
        let mut pool = WorkspacePool::new(config);
        let before = pool.workspaces.len();
        let result = pool.defragment();
        let after = pool.workspaces.len();
        assert_eq!(result.freed_bytes, 0);
        assert!(result.merged_workspaces > 0);
        assert!(after < before);
    }
}