// Source: optirs_gpu/memory/allocation/slab_allocator.rs
1// Slab allocator for GPU memory management
2//
3// This module implements a slab allocator optimized for fixed-size allocations.
4// Slab allocation is highly efficient for objects of the same size and provides
5// excellent cache locality and minimal fragmentation.
6
7#[allow(dead_code)]
8use std::collections::{HashMap, VecDeque};
9use std::ptr::NonNull;
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12
/// Slab allocator for fixed-size objects.
///
/// Maintains one `SlabCache` per aligned object size, all backed by a single
/// bump-style `MemoryPool`.
pub struct SlabAllocator {
    /// Size-class caches, keyed by the aligned object size they serve
    caches: HashMap<usize, SlabCache>,
    /// Statistics for the entire allocator
    /// NOTE(review): never written in the visible code — confirm intended use.
    stats: SlabStats,
    /// Allocator-wide configuration (alignment, reclamation policy, ...)
    config: SlabConfig,
    /// Memory pool that backs every slab of every cache
    memory_pool: MemoryPool,
}
24
/// Slab cache for objects of a specific size.
///
/// `partial_slabs`, `full_slabs`, and `empty_slabs` hold indices into
/// `slabs`; each slab is expected to appear in exactly one of the three lists
/// according to its occupancy.
pub struct SlabCache {
    /// Object size served by this cache, in bytes
    object_size: usize,
    /// All slabs created for this cache (the lists below index into this Vec)
    slabs: Vec<Slab>,
    /// Indices of partially filled slabs — the preferred allocation source
    partial_slabs: VecDeque<usize>,
    /// Indices of completely full slabs
    full_slabs: Vec<usize>,
    /// Indices of completely empty slabs (candidates for reclamation)
    empty_slabs: VecDeque<usize>,
    /// Allocation/deallocation counters for this cache
    stats: CacheStats,
    /// Per-cache tuning (slab sizing, coloring, ctor/dtor hooks)
    config: CacheConfig,
}
42
/// Individual slab (page) containing multiple objects of one size.
///
/// Occupancy is tracked twice: a free list of object indices gives O(1)
/// allocation, and a bitmap gives O(1) double-free detection in `deallocate`.
pub struct Slab {
    /// Base address of the slab's memory region
    base_ptr: NonNull<u8>,
    /// Size of the slab in bytes
    slab_size: usize,
    /// Size of each object in bytes
    object_size: usize,
    /// Number of objects in this slab (`slab_size / object_size`)
    object_count: usize,
    /// Indices of currently free objects
    free_objects: VecDeque<usize>,
    /// Number of currently allocated objects
    allocated_count: usize,
    /// One bit per object (1 = allocated), packed into 64-bit words
    allocation_bitmap: Vec<u64>,
    /// Slab creation time
    created_at: Instant,
    /// Time of the most recent allocation, if any
    last_alloc: Option<Instant>,
    /// Time of the most recent deallocation, if any
    last_dealloc: Option<Instant>,
    /// Access frequency counter (incremented on allocation)
    access_count: u64,
}
68
69impl Slab {
70    pub fn new(base_ptr: NonNull<u8>, slab_size: usize, object_size: usize) -> Self {
71        let object_count = slab_size / object_size;
72        let bitmap_size = object_count.div_ceil(64); // Round up to nearest 64-bit word
73
74        let mut free_objects = VecDeque::with_capacity(object_count);
75        for i in 0..object_count {
76            free_objects.push_back(i);
77        }
78
79        Self {
80            base_ptr,
81            slab_size,
82            object_size,
83            object_count,
84            free_objects,
85            allocated_count: 0,
86            allocation_bitmap: vec![0; bitmap_size],
87            created_at: Instant::now(),
88            last_alloc: None,
89            last_dealloc: None,
90            access_count: 0,
91        }
92    }
93
94    /// Allocate an object from this slab
95    pub fn allocate(&mut self) -> Option<NonNull<u8>> {
96        if let Some(object_index) = self.free_objects.pop_front() {
97            // Mark object as allocated in bitmap
98            let word_index = object_index / 64;
99            let bit_index = object_index % 64;
100            self.allocation_bitmap[word_index] |= 1u64 << bit_index;
101
102            self.allocated_count += 1;
103            self.last_alloc = Some(Instant::now());
104            self.access_count += 1;
105
106            // Calculate object address
107            let object_offset = object_index * self.object_size;
108            let object_ptr =
109                unsafe { NonNull::new_unchecked(self.base_ptr.as_ptr().add(object_offset)) };
110
111            Some(object_ptr)
112        } else {
113            None
114        }
115    }
116
117    /// Deallocate an object in this slab
118    pub fn deallocate(&mut self, ptr: NonNull<u8>) -> Result<(), SlabError> {
119        // Calculate object index from pointer
120        let ptr_addr = ptr.as_ptr() as usize;
121        let base_addr = self.base_ptr.as_ptr() as usize;
122
123        if ptr_addr < base_addr || ptr_addr >= base_addr + self.slab_size {
124            return Err(SlabError::InvalidPointer(
125                "Pointer not in this slab".to_string(),
126            ));
127        }
128
129        let offset = ptr_addr - base_addr;
130        if !offset.is_multiple_of(self.object_size) {
131            return Err(SlabError::InvalidPointer(
132                "Pointer not aligned to object boundary".to_string(),
133            ));
134        }
135
136        let object_index = offset / self.object_size;
137        if object_index >= self.object_count {
138            return Err(SlabError::InvalidPointer(
139                "Object index out of bounds".to_string(),
140            ));
141        }
142
143        // Check if object is actually allocated
144        let word_index = object_index / 64;
145        let bit_index = object_index % 64;
146        if (self.allocation_bitmap[word_index] & (1u64 << bit_index)) == 0 {
147            return Err(SlabError::DoubleFree("Object already free".to_string()));
148        }
149
150        // Mark as free
151        self.allocation_bitmap[word_index] &= !(1u64 << bit_index);
152        self.free_objects.push_back(object_index);
153        self.allocated_count -= 1;
154        self.last_dealloc = Some(Instant::now());
155
156        Ok(())
157    }
158
159    /// Check if slab is full
160    pub fn is_full(&self) -> bool {
161        self.allocated_count == self.object_count
162    }
163
164    /// Check if slab is empty
165    pub fn is_empty(&self) -> bool {
166        self.allocated_count == 0
167    }
168
169    /// Check if slab is partially filled
170    pub fn is_partial(&self) -> bool {
171        self.allocated_count > 0 && self.allocated_count < self.object_count
172    }
173
174    /// Get utilization ratio (0.0 to 1.0)
175    pub fn get_utilization(&self) -> f64 {
176        self.allocated_count as f64 / self.object_count as f64
177    }
178
179    /// Get slab statistics
180    pub fn get_stats(&self) -> SlabStats {
181        SlabStats {
182            total_objects: self.object_count,
183            allocated_objects: self.allocated_count,
184            free_objects: self.object_count - self.allocated_count,
185            utilization: self.get_utilization(),
186            access_count: self.access_count,
187            age: self.created_at.elapsed(),
188        }
189    }
190}
191
/// Memory pool for backing slab storage.
///
/// Bump-allocates from a fixed region and keeps a free list for reuse;
/// `free_slab` inserts regions in offset order and coalesces neighbors.
pub struct MemoryPool {
    /// Base address of the memory pool
    base_ptr: NonNull<u8>,
    /// Total size of the memory pool in bytes
    total_size: usize,
    /// High-water mark for bump allocation
    current_offset: usize,
    /// Freed regions available for reuse
    free_regions: VecDeque<FreeRegion>,
    /// Allocation alignment — assumed power of two by the rounding mask in
    /// `allocate_slab` (TODO confirm all callers guarantee this)
    alignment: usize,
}
205
/// A contiguous span of pool memory that has been returned and may be reused.
#[derive(Debug, Clone)]
pub struct FreeRegion {
    /// Byte offset of the region from the pool base.
    pub offset: usize,
    /// Length of the region in bytes.
    pub size: usize,
    /// When the region was released back to the pool.
    pub freed_at: Instant,
}
213
214impl MemoryPool {
215    pub fn new(base_ptr: NonNull<u8>, total_size: usize, alignment: usize) -> Self {
216        Self {
217            base_ptr,
218            total_size,
219            current_offset: 0,
220            free_regions: VecDeque::new(),
221            alignment,
222        }
223    }
224
225    /// Allocate a slab from the memory pool
226    pub fn allocate_slab(&mut self, size: usize) -> Option<NonNull<u8>> {
227        let aligned_size = (size + self.alignment - 1) & !(self.alignment - 1);
228
229        // Try to reuse a free region first
230        if let Some(region_index) = self.find_suitable_free_region(aligned_size) {
231            let region = self
232                .free_regions
233                .remove(region_index)
234                .expect("unwrap failed");
235            let ptr = unsafe { NonNull::new_unchecked(self.base_ptr.as_ptr().add(region.offset)) };
236
237            // If region is larger than needed, split it
238            if region.size > aligned_size {
239                let remaining_region = FreeRegion {
240                    offset: region.offset + aligned_size,
241                    size: region.size - aligned_size,
242                    freed_at: region.freed_at,
243                };
244                self.free_regions.push_back(remaining_region);
245            }
246
247            return Some(ptr);
248        }
249
250        // Allocate from the end of the pool
251        if self.current_offset + aligned_size <= self.total_size {
252            let ptr =
253                unsafe { NonNull::new_unchecked(self.base_ptr.as_ptr().add(self.current_offset)) };
254            self.current_offset += aligned_size;
255            Some(ptr)
256        } else {
257            None
258        }
259    }
260
261    /// Free a slab back to the memory pool
262    pub fn free_slab(&mut self, ptr: NonNull<u8>, size: usize) {
263        let base_addr = self.base_ptr.as_ptr() as usize;
264        let ptr_addr = ptr.as_ptr() as usize;
265
266        if ptr_addr >= base_addr && ptr_addr < base_addr + self.total_size {
267            let offset = ptr_addr - base_addr;
268            let region = FreeRegion {
269                offset,
270                size,
271                freed_at: Instant::now(),
272            };
273
274            // Insert in sorted order to facilitate coalescing
275            let insert_pos = self
276                .free_regions
277                .binary_search_by_key(&offset, |r| r.offset)
278                .unwrap_or_else(|pos| pos);
279
280            self.free_regions.insert(insert_pos, region);
281
282            // Try to coalesce adjacent regions
283            self.coalesce_free_regions();
284        }
285    }
286
287    fn find_suitable_free_region(&self, size: usize) -> Option<usize> {
288        self.free_regions
289            .iter()
290            .position(|region| region.size >= size)
291    }
292
293    fn coalesce_free_regions(&mut self) {
294        let mut i = 0;
295        while i < self.free_regions.len().saturating_sub(1) {
296            let current_end = self.free_regions[i].offset + self.free_regions[i].size;
297            if current_end == self.free_regions[i + 1].offset {
298                // Coalesce regions
299                let next_region = self.free_regions.remove(i + 1).expect("unwrap failed");
300                self.free_regions[i].size += next_region.size;
301            } else {
302                i += 1;
303            }
304        }
305    }
306
307    pub fn get_usage(&self) -> MemoryPoolUsage {
308        let free_size = self.free_regions.iter().map(|r| r.size).sum::<usize>();
309        let allocated_size = self.current_offset - free_size;
310
311        MemoryPoolUsage {
312            total_size: self.total_size,
313            allocated_size,
314            free_size,
315            current_offset: self.current_offset,
316            free_regions: self.free_regions.len(),
317        }
318    }
319}
320
/// Tuning parameters for a single `SlabCache`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Target number of objects per slab (drives slab sizing).
    pub objects_per_slab: usize,
    /// How many completely empty slabs to retain before reclaiming.
    pub max_empty_slabs: usize,
    /// Enable slab coloring (extra byte offset added to the slab size).
    pub enable_coloring: bool,
    /// Offset used when coloring is enabled.
    pub color_offset: usize,
    /// Run the hooks below on object allocation/free.
    pub enable_ctor_dtor: bool,
    /// Invoked on each freshly allocated object when hooks are enabled.
    pub constructor: Option<fn(*mut u8)>,
    /// Invoked on each freed object when hooks are enabled.
    pub destructor: Option<fn(*mut u8)>,
}

/// Defaults: 64 objects per slab, keep up to 3 empty slabs, coloring on
/// with zero offset, no constructor/destructor hooks.
impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            objects_per_slab: 64,
            max_empty_slabs: 3,
            enable_coloring: true,
            color_offset: 0,
            enable_ctor_dtor: false,
            constructor: None,
            destructor: None,
        }
    }
}
353
/// Allocator-wide configuration for `SlabAllocator`.
#[derive(Debug, Clone)]
pub struct SlabConfig {
    /// Default slab size in bytes.
    pub default_slab_size: usize,
    /// Alignment applied to every allocation size.
    pub alignment: usize,
    /// Collect statistics while running.
    pub enable_stats: bool,
    /// Enable debugging features.
    pub enable_debug: bool,
    /// Utilization threshold that triggers memory reclamation.
    pub reclaim_threshold: f64,
    /// Reclaim empty slabs automatically.
    pub auto_reclaim: bool,
}

/// Defaults: 4 KiB slabs, 256-byte alignment, stats on, debug off,
/// auto-reclaim at 80% utilization.
impl Default for SlabConfig {
    fn default() -> Self {
        Self {
            default_slab_size: 4096, // 4KB page size
            alignment: 256,
            enable_stats: true,
            enable_debug: false,
            reclaim_threshold: 0.8,
            auto_reclaim: true,
        }
    }
}
383
/// Per-cache counters collected while a `SlabCache` is in use.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Calls to `SlabCache::allocate`.
    pub total_allocations: u64,
    /// Successful calls to `SlabCache::deallocate`.
    pub total_deallocations: u64,
    /// Allocations served from an existing slab.
    pub cache_hits: u64,
    /// Allocations that required creating a new slab.
    pub cache_misses: u64,
    /// Slabs created for this cache.
    pub slab_allocations: u64,
    /// Slabs released back to the memory pool.
    pub slab_deallocations: u64,
    /// Objects currently live (incremented on alloc, decremented on free).
    pub objects_allocated: u64,
    /// Running count of objects freed.
    pub objects_free: u64,
    /// Mean slab utilization — not updated in the visible code; TODO confirm.
    pub average_utilization: f64,
}
397
/// Point-in-time snapshot of a single slab (see `Slab::get_stats`).
#[derive(Debug, Clone, Default)]
pub struct SlabStats {
    /// Total object capacity of the slab.
    pub total_objects: usize,
    /// Objects currently allocated.
    pub allocated_objects: usize,
    /// Objects currently free.
    pub free_objects: usize,
    /// `allocated_objects / total_objects`, in `[0.0, 1.0]`.
    pub utilization: f64,
    /// Access frequency counter copied from the slab.
    pub access_count: u64,
    /// Time elapsed since the slab was created.
    pub age: std::time::Duration,
}
408
/// Utilization snapshot of a `MemoryPool` (see `MemoryPool::get_usage`).
#[derive(Debug, Clone)]
pub struct MemoryPoolUsage {
    /// Total capacity of the pool in bytes.
    pub total_size: usize,
    /// Bytes currently handed out (bump offset minus free-list bytes).
    pub allocated_size: usize,
    /// Bytes sitting in the free list awaiting reuse.
    pub free_size: usize,
    /// Bump-allocation high-water mark.
    pub current_offset: usize,
    /// Number of entries in the free list.
    pub free_regions: usize,
}
418
419impl SlabCache {
420    pub fn new(object_size: usize, config: CacheConfig) -> Self {
421        Self {
422            object_size,
423            slabs: Vec::new(),
424            partial_slabs: VecDeque::new(),
425            full_slabs: Vec::new(),
426            empty_slabs: VecDeque::new(),
427            stats: CacheStats::default(),
428            config,
429        }
430    }
431
432    /// Allocate an object from this cache
433    pub fn allocate(&mut self, memory_pool: &mut MemoryPool) -> Result<NonNull<u8>, SlabError> {
434        self.stats.total_allocations += 1;
435
436        // Try partial slabs first
437        if let Some(&slab_index) = self.partial_slabs.front() {
438            if let Some(ptr) = self.slabs[slab_index].allocate() {
439                self.stats.cache_hits += 1;
440                self.stats.objects_allocated += 1;
441
442                // Move to full slabs if now full
443                if self.slabs[slab_index].is_full() {
444                    self.partial_slabs.pop_front();
445                    self.full_slabs.push(slab_index);
446                }
447
448                // Apply constructor if enabled
449                if self.config.enable_ctor_dtor {
450                    if let Some(ctor) = self.config.constructor {
451                        ctor(ptr.as_ptr());
452                    }
453                }
454
455                return Ok(ptr);
456            }
457        }
458
459        // Try empty slabs
460        if let Some(slab_index) = self.empty_slabs.pop_front() {
461            if let Some(ptr) = self.slabs[slab_index].allocate() {
462                self.stats.cache_hits += 1;
463                self.stats.objects_allocated += 1;
464                self.partial_slabs.push_back(slab_index);
465
466                if self.config.enable_ctor_dtor {
467                    if let Some(ctor) = self.config.constructor {
468                        ctor(ptr.as_ptr());
469                    }
470                }
471
472                return Ok(ptr);
473            }
474        }
475
476        // Need to allocate a new slab
477        self.stats.cache_misses += 1;
478        self.allocate_new_slab(memory_pool)?;
479
480        // Try allocation again with new slab
481        if let Some(&slab_index) = self.partial_slabs.back() {
482            if let Some(ptr) = self.slabs[slab_index].allocate() {
483                self.stats.objects_allocated += 1;
484
485                if self.config.enable_ctor_dtor {
486                    if let Some(ctor) = self.config.constructor {
487                        ctor(ptr.as_ptr());
488                    }
489                }
490
491                return Ok(ptr);
492            }
493        }
494
495        Err(SlabError::OutOfMemory(
496            "Failed to allocate after creating new slab".to_string(),
497        ))
498    }
499
500    /// Deallocate an object back to this cache
501    pub fn deallocate(&mut self, ptr: NonNull<u8>) -> Result<(), SlabError> {
502        // Apply destructor if enabled
503        if self.config.enable_ctor_dtor {
504            if let Some(dtor) = self.config.destructor {
505                dtor(ptr.as_ptr());
506            }
507        }
508
509        // Find which slab contains this pointer
510        let mut slab_index = None;
511        for (i, slab) in self.slabs.iter().enumerate() {
512            let base_addr = slab.base_ptr.as_ptr() as usize;
513            let ptr_addr = ptr.as_ptr() as usize;
514
515            if ptr_addr >= base_addr && ptr_addr < base_addr + slab.slab_size {
516                slab_index = Some(i);
517                break;
518            }
519        }
520
521        let slab_index = slab_index.ok_or_else(|| {
522            SlabError::InvalidPointer("Pointer not found in any slab".to_string())
523        })?;
524
525        let was_full = self.slabs[slab_index].is_full();
526        self.slabs[slab_index].deallocate(ptr)?;
527
528        self.stats.total_deallocations += 1;
529        self.stats.objects_allocated -= 1;
530        self.stats.objects_free += 1;
531
532        // Update slab lists based on new state
533        if was_full {
534            // Remove from full slabs, add to partial
535            if let Some(pos) = self.full_slabs.iter().position(|&i| i == slab_index) {
536                self.full_slabs.remove(pos);
537                self.partial_slabs.push_back(slab_index);
538            }
539        } else if self.slabs[slab_index].is_empty() {
540            // Remove from partial, add to empty
541            if let Some(pos) = self.partial_slabs.iter().position(|&i| i == slab_index) {
542                self.partial_slabs.remove(pos);
543                self.empty_slabs.push_back(slab_index);
544            }
545        }
546
547        Ok(())
548    }
549
550    fn allocate_new_slab(&mut self, memory_pool: &mut MemoryPool) -> Result<(), SlabError> {
551        let slab_size = self.calculate_slab_size();
552
553        let slab_ptr = memory_pool.allocate_slab(slab_size).ok_or_else(|| {
554            SlabError::OutOfMemory("Cannot allocate slab from memory pool".to_string())
555        })?;
556
557        let slab = Slab::new(slab_ptr, slab_size, self.object_size);
558        let slab_index = self.slabs.len();
559
560        self.slabs.push(slab);
561        self.partial_slabs.push_back(slab_index);
562        self.stats.slab_allocations += 1;
563
564        Ok(())
565    }
566
567    fn calculate_slab_size(&self) -> usize {
568        // Calculate optimal slab size based on object size and configuration
569        let objects_per_slab = self.config.objects_per_slab;
570        let base_size = objects_per_slab * self.object_size;
571
572        // Add coloring offset if enabled
573        if self.config.enable_coloring {
574            base_size + self.config.color_offset
575        } else {
576            base_size
577        }
578    }
579
580    /// Get cache statistics
581    pub fn get_stats(&self) -> &CacheStats {
582        &self.stats
583    }
584
585    /// Get detailed cache information
586    pub fn get_cache_info(&self) -> CacheInfo {
587        let total_objects = self.slabs.iter().map(|s| s.object_count).sum();
588        let allocated_objects = self.slabs.iter().map(|s| s.allocated_count).sum();
589        let average_utilization = if total_objects > 0 {
590            allocated_objects as f64 / total_objects as f64
591        } else {
592            0.0
593        };
594
595        CacheInfo {
596            object_size: self.object_size,
597            total_slabs: self.slabs.len(),
598            partial_slabs: self.partial_slabs.len(),
599            full_slabs: self.full_slabs.len(),
600            empty_slabs: self.empty_slabs.len(),
601            total_objects,
602            allocated_objects,
603            free_objects: total_objects - allocated_objects,
604            average_utilization,
605            memory_overhead: self.calculate_memory_overhead(),
606        }
607    }
608
609    fn calculate_memory_overhead(&self) -> f64 {
610        let useful_memory: usize = self
611            .slabs
612            .iter()
613            .map(|s| s.allocated_count * s.object_size)
614            .sum();
615
616        let total_memory: usize = self.slabs.iter().map(|s| s.slab_size).sum();
617
618        if total_memory > 0 {
619            1.0 - (useful_memory as f64 / total_memory as f64)
620        } else {
621            0.0
622        }
623    }
624
625    /// Reclaim empty slabs
626    pub fn reclaim_empty_slabs(&mut self, memory_pool: &mut MemoryPool) -> usize {
627        let mut reclaimed = 0;
628        let keep_count = self.config.max_empty_slabs;
629
630        while self.empty_slabs.len() > keep_count {
631            if let Some(slab_index) = self.empty_slabs.pop_front() {
632                let slab = &self.slabs[slab_index];
633                memory_pool.free_slab(slab.base_ptr, slab.slab_size);
634                reclaimed += 1;
635                self.stats.slab_deallocations += 1;
636            }
637        }
638
639        reclaimed
640    }
641}
642
/// Aggregated view of one cache's slabs (see `SlabCache::get_cache_info`).
#[derive(Debug, Clone)]
pub struct CacheInfo {
    /// Object size served by the cache, in bytes.
    pub object_size: usize,
    /// Total number of slabs the cache owns.
    pub total_slabs: usize,
    /// Slabs that are partially filled.
    pub partial_slabs: usize,
    /// Slabs that are completely full.
    pub full_slabs: usize,
    /// Slabs that are completely empty.
    pub empty_slabs: usize,
    /// Combined object capacity across all slabs.
    pub total_objects: usize,
    /// Objects currently allocated.
    pub allocated_objects: usize,
    /// Objects currently free.
    pub free_objects: usize,
    /// `allocated_objects / total_objects` (0.0 when empty).
    pub average_utilization: f64,
    /// Fraction of slab memory not occupied by live objects.
    pub memory_overhead: f64,
}
657
658impl SlabAllocator {
659    pub fn new(base_ptr: NonNull<u8>, total_size: usize, config: SlabConfig) -> Self {
660        let memory_pool = MemoryPool::new(base_ptr, total_size, config.alignment);
661
662        Self {
663            caches: HashMap::new(),
664            stats: SlabStats::default(),
665            memory_pool,
666            config,
667        }
668    }
669
670    /// Allocate object of specified size
671    pub fn allocate(&mut self, size: usize) -> Result<NonNull<u8>, SlabError> {
672        if size == 0 {
673            return Err(SlabError::InvalidSize(
674                "Cannot allocate zero bytes".to_string(),
675            ));
676        }
677
678        // Round up size to alignment boundary
679        let aligned_size = (size + self.config.alignment - 1) & !(self.config.alignment - 1);
680
681        // Get or create cache for this size
682        self.caches.entry(aligned_size).or_insert_with(|| {
683            let cache_config = CacheConfig::default();
684
685            SlabCache::new(aligned_size, cache_config)
686        });
687
688        let cache = self.caches.get_mut(&aligned_size).expect("unwrap failed");
689        cache.allocate(&mut self.memory_pool)
690    }
691
692    /// Deallocate object
693    pub fn deallocate(&mut self, ptr: NonNull<u8>, size: usize) -> Result<(), SlabError> {
694        let aligned_size = (size + self.config.alignment - 1) & !(self.config.alignment - 1);
695
696        let cache = self
697            .caches
698            .get_mut(&aligned_size)
699            .ok_or_else(|| SlabError::InvalidPointer("No cache found for this size".to_string()))?;
700
701        cache.deallocate(ptr)
702    }
703
704    /// Get allocator statistics
705    pub fn get_stats(&self) -> SlabAllocatorStats {
706        let mut total_caches = 0;
707        let mut total_slabs = 0;
708        let mut total_objects = 0;
709        let mut allocated_objects = 0;
710        let mut total_allocations = 0;
711        let mut total_deallocations = 0;
712
713        for cache in self.caches.values() {
714            total_caches += 1;
715            let info = cache.get_cache_info();
716            total_slabs += info.total_slabs;
717            total_objects += info.total_objects;
718            allocated_objects += info.allocated_objects;
719
720            let stats = cache.get_stats();
721            total_allocations += stats.total_allocations;
722            total_deallocations += stats.total_deallocations;
723        }
724
725        let memory_usage = self.memory_pool.get_usage();
726
727        SlabAllocatorStats {
728            total_caches,
729            total_slabs,
730            total_objects,
731            allocated_objects,
732            free_objects: total_objects - allocated_objects,
733            total_allocations,
734            total_deallocations,
735            memory_usage,
736            cache_efficiency: if total_allocations > 0 {
737                allocated_objects as f64 / total_allocations as f64
738            } else {
739                0.0
740            },
741        }
742    }
743
744    /// Get information about all caches
745    pub fn get_all_cache_info(&self) -> Vec<(usize, CacheInfo)> {
746        self.caches
747            .iter()
748            .map(|(&size, cache)| (size, cache.get_cache_info()))
749            .collect()
750    }
751
752    /// Reclaim memory from empty slabs
753    pub fn reclaim_memory(&mut self) -> usize {
754        let mut total_reclaimed = 0;
755
756        for cache in self.caches.values_mut() {
757            total_reclaimed += cache.reclaim_empty_slabs(&mut self.memory_pool);
758        }
759
760        total_reclaimed
761    }
762
763    /// Destroy cache for specific size
764    pub fn destroy_cache(&mut self, size: usize) -> Result<(), SlabError> {
765        let aligned_size = (size + self.config.alignment - 1) & !(self.config.alignment - 1);
766
767        if let Some(mut cache) = self.caches.remove(&aligned_size) {
768            // Reclaim all slabs from this cache
769            cache.reclaim_empty_slabs(&mut self.memory_pool);
770            Ok(())
771        } else {
772            Err(SlabError::InvalidSize("Cache not found".to_string()))
773        }
774    }
775
776    /// Get memory pool usage
777    pub fn get_memory_usage(&self) -> MemoryPoolUsage {
778        self.memory_pool.get_usage()
779    }
780}
781
// Safety: SlabAllocator stores raw NonNull<u8> pointers, which are not Send/Sync
// by default. These impls are sound under the following observations:
// 1. The allocator only performs pointer arithmetic on `base_ptr` — the visible
//    code never dereferences the pointers; they are opaque handles into an
//    externally managed region (presumably GPU memory — confirm with the owner
//    of the region passed to `new`).
// 2. Every mutating method takes `&mut self`, so cross-thread mutation requires
//    external synchronization (e.g. the Arc<Mutex<_>> used by
//    ThreadSafeSlabAllocator); the `&self` methods only read.
// 3. No thread-local state is maintained.
unsafe impl Send for SlabAllocator {}
unsafe impl Sync for SlabAllocator {}
789
/// Aggregated statistics across all caches of a `SlabAllocator`
/// (see `SlabAllocator::get_stats`).
#[derive(Debug, Clone)]
pub struct SlabAllocatorStats {
    /// Number of size-class caches currently live
    pub total_caches: usize,
    /// Total slabs across every cache
    pub total_slabs: usize,
    /// Combined object capacity across every cache
    pub total_objects: usize,
    /// Objects currently allocated
    pub allocated_objects: usize,
    /// Objects currently free
    pub free_objects: usize,
    /// Sum of per-cache allocation requests
    pub total_allocations: u64,
    /// Sum of per-cache successful deallocations
    pub total_deallocations: u64,
    /// Snapshot of the backing memory pool
    pub memory_usage: MemoryPoolUsage,
    /// Live objects per allocation request (0.0 before any request)
    pub cache_efficiency: f64,
}
803
/// Errors produced by the slab allocator.
#[derive(Debug, Clone)]
pub enum SlabError {
    /// The requested size cannot be served (e.g. zero bytes).
    InvalidSize(String),
    /// The backing memory pool could not satisfy the request.
    OutOfMemory(String),
    /// A pointer did not belong to any known slab or was misaligned.
    InvalidPointer(String),
    /// A pointer was freed twice.
    DoubleFree(String),
    /// Slab bookkeeping was found in an inconsistent state.
    CorruptedSlab(String),
}

impl std::fmt::Display for SlabError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its prefix, then format once.
        let (prefix, msg) = match self {
            SlabError::InvalidSize(msg) => ("Invalid size", msg),
            SlabError::OutOfMemory(msg) => ("Out of memory", msg),
            SlabError::InvalidPointer(msg) => ("Invalid pointer", msg),
            SlabError::DoubleFree(msg) => ("Double free", msg),
            SlabError::CorruptedSlab(msg) => ("Corrupted slab", msg),
        };
        write!(f, "{}: {}", prefix, msg)
    }
}

impl std::error::Error for SlabError {}
827
828/// Thread-safe slab allocator wrapper
829pub struct ThreadSafeSlabAllocator {
830    allocator: Arc<Mutex<SlabAllocator>>,
831}
832
833impl ThreadSafeSlabAllocator {
834    pub fn new(base_ptr: NonNull<u8>, total_size: usize, config: SlabConfig) -> Self {
835        let allocator = SlabAllocator::new(base_ptr, total_size, config);
836        Self {
837            allocator: Arc::new(Mutex::new(allocator)),
838        }
839    }
840
841    pub fn allocate(&self, size: usize) -> Result<NonNull<u8>, SlabError> {
842        let mut allocator = self.allocator.lock().expect("lock poisoned");
843        allocator.allocate(size)
844    }
845
846    pub fn deallocate(&self, ptr: NonNull<u8>, size: usize) -> Result<(), SlabError> {
847        let mut allocator = self.allocator.lock().expect("lock poisoned");
848        allocator.deallocate(ptr, size)
849    }
850
851    pub fn get_stats(&self) -> SlabAllocatorStats {
852        let allocator = self.allocator.lock().expect("lock poisoned");
853        allocator.get_stats()
854    }
855
856    pub fn reclaim_memory(&self) -> usize {
857        let mut allocator = self.allocator.lock().expect("lock poisoned");
858        allocator.reclaim_memory()
859    }
860}
861
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a zeroed heap buffer and a `NonNull` to its first byte.
    ///
    /// The buffer is `mut` and the pointer comes from `as_mut_ptr()`: the old
    /// tests cast `as_ptr()` (a `*const u8`) to `*mut u8`, which would be
    /// unsound if anything ever wrote through the resulting pointer.
    fn test_buffer(size: usize) -> (Vec<u8>, NonNull<u8>) {
        let mut memory = vec![0u8; size];
        let ptr = NonNull::new(memory.as_mut_ptr()).expect("Vec allocation is non-null");
        (memory, ptr)
    }

    #[test]
    fn test_slab_creation() {
        let size = 4096;
        let (_memory, ptr) = test_buffer(size);

        let slab = Slab::new(ptr, size, 64);
        assert_eq!(slab.object_count, size / 64);
        assert!(slab.is_empty());
        assert!(!slab.is_full());
    }

    #[test]
    fn test_slab_allocation() {
        let size = 4096;
        let (_memory, ptr) = test_buffer(size);
        let mut slab = Slab::new(ptr, size, 64);

        let alloc1 = slab.allocate().expect("first allocation should succeed");
        assert!(slab.is_partial());

        let alloc2 = slab.allocate().expect("second allocation should succeed");
        assert_ne!(alloc1, alloc2);
    }

    #[test]
    fn test_slab_deallocation() {
        let size = 4096;
        let (_memory, ptr) = test_buffer(size);
        let mut slab = Slab::new(ptr, size, 64);

        let alloc_ptr = slab.allocate().expect("allocation should succeed");
        assert!(slab.deallocate(alloc_ptr).is_ok());
        assert!(slab.is_empty());
    }

    #[test]
    fn test_memory_pool() {
        let size = 1024 * 1024;
        let (_memory, ptr) = test_buffer(size);
        let mut pool = MemoryPool::new(ptr, size, 256);

        let slab1 = pool.allocate_slab(4096).expect("first slab should fit");
        let slab2 = pool.allocate_slab(4096).expect("second slab should fit");
        assert_ne!(slab1, slab2);
    }

    #[test]
    fn test_slab_cache() {
        let size = 1024 * 1024;
        let (_memory, ptr) = test_buffer(size);
        let mut pool = MemoryPool::new(ptr, size, 256);
        let mut cache = SlabCache::new(64, CacheConfig::default());

        assert!(cache.allocate(&mut pool).is_ok());
        assert!(cache.allocate(&mut pool).is_ok());
    }

    #[test]
    fn test_slab_allocator() {
        let size = 1024 * 1024;
        let (_memory, ptr) = test_buffer(size);
        let mut allocator = SlabAllocator::new(ptr, size, SlabConfig::default());

        assert!(allocator.allocate(64).is_ok());
        assert!(allocator.allocate(128).is_ok());

        // With the default 256-byte alignment, 64 and 128 may round to the
        // same size class, so only one cache is guaranteed.
        let stats = allocator.get_stats();
        assert!(stats.total_caches >= 1);
    }

    #[test]
    fn test_thread_safe_allocator() {
        let size = 1024 * 1024;
        let (_memory, ptr) = test_buffer(size);
        let allocator = ThreadSafeSlabAllocator::new(ptr, size, SlabConfig::default());

        assert!(allocator.allocate(64).is_ok());

        let stats = allocator.get_stats();
        assert!(stats.allocated_objects > 0);
    }
}