1use std::collections::{BTreeSet, HashSet};
41use std::sync::atomic::{AtomicUsize, Ordering};
42use std::sync::{Arc, Mutex};
43
44use crate::buffer::CudaBuffer;
45use crate::device::GpuDevice;
46use crate::error::GpuResult;
47
/// Minimum granularity of a block: every request is rounded up to a
/// multiple of this (see `round_size`).
pub const MIN_BLOCK_SIZE: usize = 512;

/// Requests at or below this size (1 MiB) are served from the small pool.
pub const SMALL_SIZE: usize = 1 << 20;

/// Driver allocation size backing small-pool requests (2 MiB).
pub const SMALL_BUFFER: usize = 2 << 20;

/// Requests at or above this size (10 MiB) are rounded to `ROUND_LARGE`
/// instead of using `LARGE_BUFFER` (see `get_allocation_size`).
pub const MIN_LARGE_ALLOC: usize = 10 << 20;

/// Driver allocation size backing mid-sized requests (20 MiB).
pub const LARGE_BUFFER: usize = 20 << 20;

/// Rounding granularity for large driver allocations (2 MiB).
pub const ROUND_LARGE: usize = 2 << 20;

/// Opaque identifier of a CUDA stream; ordering/hashing derive from the
/// wrapped integer so it can serve as a pool-key component.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(pub usize);

/// Process-wide counter handing out unique block ids.
static NEXT_BLOCK_ID: AtomicUsize = AtomicUsize::new(0);
89
/// Metadata for one contiguous region of device memory tracked by the
/// caching allocator.
///
/// Blocks carved from the same driver allocation are linked through
/// `prev`/`next` so physically adjacent free blocks can be coalesced.
#[derive(Debug)]
pub struct Block {
    /// Unique id assigned from `NEXT_BLOCK_ID`; disambiguates pool keys.
    pub(crate) id: usize,
    /// Device ordinal the memory belongs to.
    pub device: usize,
    /// Size of the region in bytes.
    pub size: usize,
    /// Device address, stored as an integer.
    pub ptr: usize,
    /// Stream the block was allocated on; pools only hand a free block
    /// back to this same stream (see `BlockPool::find_free_block`).
    pub stream: StreamId,
    /// Extra streams recorded as users of this block; a non-empty set
    /// prevents coalescing (see `AllocatorState::try_merge`).
    pub stream_uses: HashSet<StreamId>,
    /// Whether the block is currently handed out to a caller.
    pub allocated: bool,
    /// Index of the physically preceding block in the same driver
    /// allocation, if this block was split off from one.
    pub prev: Option<usize>,
    /// Index of the physically following block, if any.
    pub next: Option<usize>,
    /// Size class: true if this block belongs to the small pool.
    pub in_small_pool: bool,
}
125
126impl Block {
127 pub fn new(
129 device: usize,
130 size: usize,
131 ptr: usize,
132 stream: StreamId,
133 in_small_pool: bool,
134 ) -> Self {
135 Self {
136 id: NEXT_BLOCK_ID.fetch_add(1, Ordering::Relaxed),
137 device,
138 size,
139 ptr,
140 stream,
141 stream_uses: HashSet::new(),
142 allocated: false,
143 prev: None,
144 next: None,
145 in_small_pool,
146 }
147 }
148
149 pub fn is_split(&self) -> bool {
151 self.prev.is_some() || self.next.is_some()
152 }
153}
154
/// Ordering key for entries in a `BlockPool`.
///
/// The derived ordering compares fields top-to-bottom — (stream, size,
/// ptr, id) — so a range scan starting from a zeroed (ptr, id) search
/// key visits the smallest sufficient block of a given stream first
/// (best-fit).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct BlockKey {
    // Primary sort: all keys for one stream are contiguous in the set.
    stream: StreamId,
    // Secondary sort: block size in bytes, enabling best-fit.
    size: usize,
    // Tie-breakers so every block has a distinct key.
    ptr: usize,
    id: usize,
}
170
171impl BlockKey {
172 fn from_block(b: &Block) -> Self {
173 Self {
174 stream: b.stream,
175 size: b.size,
176 ptr: b.ptr,
177 id: b.id,
178 }
179 }
180
181 fn search(stream: StreamId, size: usize) -> Self {
183 Self {
184 stream,
185 size,
186 ptr: 0,
187 id: 0,
188 }
189 }
190}
191
192pub(crate) struct BlockPool {
203 free_blocks: BTreeSet<(BlockKey, usize)>, pub is_small: bool,
207}
208
209impl BlockPool {
210 pub fn new(is_small: bool) -> Self {
212 Self {
213 free_blocks: BTreeSet::new(),
214 is_small,
215 }
216 }
217
218 #[cfg(test)]
220 pub fn insert(&mut self, block_idx: usize, block: &Block) {
221 self.free_blocks
222 .insert((BlockKey::from_block(block), block_idx));
223 }
224
225 pub fn insert_key(&mut self, block_idx: usize, key: BlockKey) {
227 self.free_blocks.insert((key, block_idx));
228 }
229
230 pub fn remove_key(&mut self, block_idx: usize, key: BlockKey) {
232 self.free_blocks.remove(&(key, block_idx));
233 }
234
235 pub fn find_free_block(&self, stream: StreamId, size: usize) -> Option<usize> {
237 let search = (BlockKey::search(stream, size), 0);
238 if let Some(&(key, idx)) = self.free_blocks.range(search..).next() {
239 if key.stream == stream {
240 return Some(idx);
242 }
243 }
244 None
245 }
246
247 pub fn len(&self) -> usize {
249 self.free_blocks.len()
250 }
251
252 pub fn clear(&mut self) {
254 self.free_blocks.clear();
255 }
256}
257
/// Mutable cache bookkeeping, guarded by `CudaAllocator::state`.
pub(crate) struct AllocatorState {
    /// Arena of every block ever created; indices into this Vec are the
    /// block handles used throughout the allocator. Entries are never
    /// removed — merged-away blocks become zero-size stubs.
    pub(crate) blocks: Vec<Block>,
    /// Free blocks of the small size class.
    pub(crate) small_pool: BlockPool,
    /// Free blocks of the large size class.
    pub(crate) large_pool: BlockPool,
    /// Total bytes registered via `cache_insert` (allocated + cached).
    pub(crate) reserved_bytes: usize,
    /// Bytes currently handed out to callers.
    pub(crate) allocated_bytes: usize,
    /// High-water mark of `allocated_bytes`.
    pub(crate) peak_bytes: usize,
    /// Cache hits counted in `cache_find`.
    pub(crate) hits: usize,
    /// Cache misses counted in `cache_insert`.
    pub(crate) misses: usize,
}
283
impl AllocatorState {
    /// Fresh, empty state: no blocks, empty pools, zeroed counters.
    fn new() -> Self {
        Self {
            blocks: Vec::new(),
            small_pool: BlockPool::new(true),
            large_pool: BlockPool::new(false),
            reserved_bytes: 0,
            allocated_bytes: 0,
            peak_bytes: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Selects the small or large pool by size class.
    pub(crate) fn get_pool_mut(&mut self, is_small: bool) -> &mut BlockPool {
        let pool = if is_small {
            &mut self.small_pool
        } else {
            &mut self.large_pool
        };
        debug_assert_eq!(pool.is_small, is_small, "pool size-class mismatch");
        pool
    }

    /// Stores `block` in the arena and returns its index — the handle
    /// used everywhere else. Blocks are never removed from the arena.
    pub(crate) fn add_block(&mut self, block: Block) -> usize {
        let idx = self.blocks.len();
        self.blocks.push(block);
        idx
    }

    /// Whether serving `size` bytes from `block_idx` should split the
    /// remainder off into a new free block.
    ///
    /// Small-pool blocks split whenever at least `MIN_BLOCK_SIZE` bytes
    /// would remain; large-pool blocks only split when more than
    /// `SMALL_SIZE` would remain, so the large pool is not littered with
    /// small fragments.
    pub(crate) fn should_split(&self, block_idx: usize, size: usize) -> bool {
        let block = &self.blocks[block_idx];
        let remaining = block.size - size;
        if block.in_small_pool {
            remaining >= MIN_BLOCK_SIZE
        } else {
            remaining > SMALL_SIZE
        }
    }

    /// Splits `block_idx` so it keeps exactly `size` bytes; the tail
    /// becomes a new free block, linked as the physical successor and
    /// inserted into the matching pool.
    pub(crate) fn split_block(&mut self, block_idx: usize, size: usize) {
        let remaining_size = self.blocks[block_idx].size - size;
        let remaining_ptr = self.blocks[block_idx].ptr + size;
        let stream = self.blocks[block_idx].stream;
        let device = self.blocks[block_idx].device;
        let is_small = self.blocks[block_idx].in_small_pool;
        let old_next = self.blocks[block_idx].next;

        // New tail block, spliced between `block_idx` and its old next.
        let mut remainder = Block::new(device, remaining_size, remaining_ptr, stream, is_small);
        remainder.prev = Some(block_idx);
        remainder.next = old_next;

        let rem_idx = self.add_block(remainder);

        self.blocks[block_idx].size = size;
        self.blocks[block_idx].next = Some(rem_idx);

        if let Some(old_next_idx) = old_next {
            self.blocks[old_next_idx].prev = Some(rem_idx);
        }

        // The remainder is immediately available for reuse.
        let rem_key = BlockKey::from_block(&self.blocks[rem_idx]);
        let pool = self.get_pool_mut(is_small);
        pool.insert_key(rem_idx, rem_key);
    }

    /// Tries to absorb the physically adjacent `neighbor_idx` into
    /// `block_idx`; returns the number of bytes absorbed (0 if no merge).
    ///
    /// A neighbor is only merged when it is free and has no recorded
    /// cross-stream uses. The absorbed entry is removed from its pool and
    /// left in the arena as a zero-size, unlinked stub.
    pub(crate) fn try_merge(&mut self, block_idx: usize, neighbor_idx: Option<usize>) -> usize {
        let Some(nbr_idx) = neighbor_idx else {
            return 0;
        };

        if self.blocks[nbr_idx].allocated || !self.blocks[nbr_idx].stream_uses.is_empty() {
            return 0;
        }

        let is_small = self.blocks[nbr_idx].in_small_pool;
        let subsumed_size = self.blocks[nbr_idx].size;

        // Remove the neighbor's pool entry before its key fields change.
        let nbr_key = BlockKey::from_block(&self.blocks[nbr_idx]);
        {
            let pool = self.get_pool_mut(is_small);
            pool.remove_key(nbr_idx, nbr_key);
        }

        if self.blocks[block_idx].prev == Some(nbr_idx) {
            // Neighbor precedes us: grow left, adopting its ptr and prev link.
            let nbr_prev = self.blocks[nbr_idx].prev;
            self.blocks[block_idx].ptr = self.blocks[nbr_idx].ptr;
            self.blocks[block_idx].size += subsumed_size;
            self.blocks[block_idx].prev = nbr_prev;
            if let Some(pp) = nbr_prev {
                self.blocks[pp].next = Some(block_idx);
            }
        } else {
            // Otherwise the neighbor follows us: grow right, adopting its
            // next link.
            let nbr_next = self.blocks[nbr_idx].next;
            self.blocks[block_idx].size += subsumed_size;
            self.blocks[block_idx].next = nbr_next;
            if let Some(nn) = nbr_next {
                self.blocks[nn].prev = Some(block_idx);
            }
        }

        // Tombstone the absorbed block; arena entries are never deleted.
        self.blocks[nbr_idx].size = 0;
        self.blocks[nbr_idx].prev = None;
        self.blocks[nbr_idx].next = None;

        subsumed_size
    }

    /// Returns `block_idx` to the cache: clears its allocation state,
    /// coalesces with free neighbors, and inserts the (possibly grown)
    /// block into the appropriate pool.
    pub(crate) fn free_block(&mut self, block_idx: usize) {
        self.blocks[block_idx].allocated = false;
        self.blocks[block_idx].stream_uses.clear();
        let size = self.blocks[block_idx].size;
        self.allocated_bytes = self.allocated_bytes.saturating_sub(size);

        // Coalesce in both directions; merging may change ptr/size/links.
        let prev = self.blocks[block_idx].prev;
        let next = self.blocks[block_idx].next;
        self.try_merge(block_idx, prev);
        self.try_merge(block_idx, next);

        // The pool key must be computed only after merging, once ptr and
        // size are final.
        let is_small = self.blocks[block_idx].in_small_pool;
        let merged_key = BlockKey::from_block(&self.blocks[block_idx]);
        let pool = self.get_pool_mut(is_small);
        pool.insert_key(block_idx, merged_key);
    }

    /// Bytes reserved from the driver but not currently handed out.
    pub(crate) fn cached_bytes(&self) -> usize {
        self.reserved_bytes.saturating_sub(self.allocated_bytes)
    }
}
442
443pub fn round_size(size: usize) -> usize {
452 if size < MIN_BLOCK_SIZE {
453 return MIN_BLOCK_SIZE;
454 }
455 (size + MIN_BLOCK_SIZE - 1) & !(MIN_BLOCK_SIZE - 1)
457}
458
459pub fn get_allocation_size(size: usize) -> usize {
464 if size <= SMALL_SIZE {
465 SMALL_BUFFER
466 } else if size < MIN_LARGE_ALLOC {
467 LARGE_BUFFER
468 } else {
469 (size + ROUND_LARGE - 1) & !(ROUND_LARGE - 1)
471 }
472}
473
/// A caching device-memory allocator: driver allocations are registered
/// as blocks, split, coalesced, and reused instead of being released on
/// every free.
pub struct CudaAllocator {
    /// The device this allocator serves.
    device: Arc<GpuDevice>,
    /// Cache bookkeeping (block arena, pools, counters) behind a mutex.
    pub(crate) state: Mutex<AllocatorState>,
    // Lock-free counters mirroring live/peak bytes for the direct
    // alloc_zeros/alloc_copy/free path, so the stats getters never need
    // to take `state`.
    allocated_bytes_atomic: AtomicUsize,
    peak_bytes_atomic: AtomicUsize,
}
496
impl CudaAllocator {
    /// Creates an allocator bound to `device`, with empty cache state and
    /// zeroed statistics.
    pub fn new(device: Arc<GpuDevice>) -> Self {
        Self {
            device,
            state: Mutex::new(AllocatorState::new()),
            allocated_bytes_atomic: AtomicUsize::new(0),
            peak_bytes_atomic: AtomicUsize::new(0),
        }
    }

    /// Allocates `count` zero-initialized elements of `T` directly from
    /// the driver and records the bytes in the live/peak counters.
    ///
    /// # Errors
    /// Propagates any driver allocation failure.
    #[cfg(feature = "cuda")]
    pub fn alloc_zeros<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let bytes = count.saturating_mul(std::mem::size_of::<T>());
        let slice = self.device.stream().alloc_zeros::<T>(count)?;

        // fetch_add returns the previous total, so `prev + bytes` is the
        // new total; fold it into the peak with fetch_max.
        let prev = self
            .allocated_bytes_atomic
            .fetch_add(bytes, Ordering::Relaxed);
        self.peak_bytes_atomic
            .fetch_max(prev + bytes, Ordering::Relaxed);

        Ok(CudaBuffer {
            data: Some(slice),
            len: count,
            alloc_len: count,
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Copies `data` host-to-device into a fresh driver allocation and
    /// records the bytes in the live/peak counters.
    ///
    /// # Errors
    /// Propagates any driver allocation/copy failure.
    #[cfg(feature = "cuda")]
    pub fn alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let bytes = data.len().saturating_mul(std::mem::size_of::<T>());
        let slice = self.device.stream().clone_htod(data)?;

        let prev = self
            .allocated_bytes_atomic
            .fetch_add(bytes, Ordering::Relaxed);
        self.peak_bytes_atomic
            .fetch_max(prev + bytes, Ordering::Relaxed);

        Ok(CudaBuffer {
            data: Some(slice),
            len: data.len(),
            alloc_len: data.len(),
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Subtracts the buffer's bytes from the live counter and drops it.
    ///
    /// NOTE(review): bytes are computed from `buffer.len()`, while the
    /// allocation paths here set `len == alloc_len`; if those two can
    /// diverge elsewhere this would miscount — confirm against
    /// `CudaBuffer`'s contract.
    pub fn free<T>(&self, buffer: CudaBuffer<T>) {
        let bytes = buffer
            .len()
            .checked_mul(std::mem::size_of::<T>())
            .unwrap_or(0);
        self.allocated_bytes_atomic
            .fetch_sub(bytes, Ordering::Relaxed);
        drop(buffer);
    }

    /// Bytes currently allocated through the direct alloc path.
    #[inline]
    pub fn memory_allocated(&self) -> usize {
        self.allocated_bytes_atomic.load(Ordering::Relaxed)
    }

    /// High-water mark of allocated bytes since creation or the last
    /// `reset_peak_stats`.
    #[inline]
    pub fn max_memory_allocated(&self) -> usize {
        self.peak_bytes_atomic.load(Ordering::Relaxed)
    }

    /// Bytes the cache has registered from the driver; 0 if the state
    /// mutex is poisoned.
    pub fn memory_reserved(&self) -> usize {
        self.state.lock().map(|s| s.reserved_bytes).unwrap_or(0)
    }

    /// Resets the peak counter to the current allocated total.
    pub fn reset_peak_stats(&self) {
        let current = self.allocated_bytes_atomic.load(Ordering::Relaxed);
        self.peak_bytes_atomic.store(current, Ordering::Relaxed);
    }

    /// Clears both free pools and shrinks the reserved count down to the
    /// allocated total.
    ///
    /// Only bookkeeping is touched here: entries stay in the block arena
    /// and no driver memory is released by this method. No-op if the
    /// state mutex is poisoned.
    pub fn empty_cache(&self) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        state.small_pool.clear();
        state.large_pool.clear();

        state.reserved_bytes = state.allocated_bytes;
    }

    /// The device this allocator serves.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Marks `block_idx` as used by `stream`, which prevents the block
    /// from being coalesced (see `AllocatorState::try_merge`).
    /// Out-of-range indices and a poisoned mutex are ignored.
    pub fn record_stream_on_block(&self, block_idx: usize, stream: StreamId) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if block_idx < state.blocks.len() {
            state.blocks[block_idx].stream_uses.insert(stream);
        }
    }

    /// Total number of blocks ever created (including merged-away stubs).
    pub fn block_count(&self) -> usize {
        self.state.lock().map(|s| s.blocks.len()).unwrap_or(0)
    }

    /// Number of free entries across both pools.
    pub fn free_block_count(&self) -> usize {
        self.state
            .lock()
            .map(|s| s.small_pool.len() + s.large_pool.len())
            .unwrap_or(0)
    }

    /// Returns `(hits, misses)` as counted by `cache_find`/`cache_insert`.
    pub fn cache_stats(&self) -> (usize, usize) {
        self.state
            .lock()
            .map(|s| (s.hits, s.misses))
            .unwrap_or((0, 0))
    }

    /// Bytes held in the cache but not currently allocated.
    pub fn cached_bytes(&self) -> usize {
        self.state.lock().map(|s| s.cached_bytes()).unwrap_or(0)
    }

    /// Tries to serve `size` bytes for `stream` from the cache.
    ///
    /// On a hit: removes the block from its pool, splits off any large
    /// remainder, marks it allocated, updates counters, and returns
    /// `(block_idx, actual_size)`. Returns `None` on a miss or a poisoned
    /// mutex.
    pub fn cache_find(&self, size: usize, stream: StreamId) -> Option<(usize, usize)> {
        let rounded = round_size(size);
        let is_small = rounded <= SMALL_SIZE;

        let Ok(mut state) = self.state.lock() else {
            return None;
        };

        let block_idx = {
            let pool = state.get_pool_mut(is_small);
            pool.find_free_block(stream, rounded)?
        };

        // Remove the pool entry before the block's key fields can change.
        let key = BlockKey::from_block(&state.blocks[block_idx]);
        state.get_pool_mut(is_small).remove_key(block_idx, key);

        if state.should_split(block_idx, rounded) {
            state.split_block(block_idx, rounded);
        }

        state.blocks[block_idx].allocated = true;
        // Post-split size: may exceed `rounded` if the remainder was too
        // small to split off.
        let actual_size = state.blocks[block_idx].size;
        state.allocated_bytes += actual_size;
        if state.allocated_bytes > state.peak_bytes {
            state.peak_bytes = state.allocated_bytes;
        }
        state.hits += 1;

        Some((block_idx, actual_size))
    }

    /// Registers a fresh driver allocation of `driver_alloc_size` bytes
    /// at `ptr` as a block serving `requested_size` on `stream`, splitting
    /// any remainder into the cache. Returns `(block_idx, actual_size)`.
    ///
    /// NOTE(review): on a poisoned mutex this returns index 0, which can
    /// alias a real block handle — confirm callers treat that path as
    /// fatal rather than freeing index 0 later.
    pub fn cache_insert(
        &self,
        requested_size: usize,
        driver_alloc_size: usize,
        ptr: usize,
        stream: StreamId,
    ) -> (usize, usize) {
        let rounded = round_size(requested_size);
        let is_small = rounded <= SMALL_SIZE;

        let Ok(mut state) = self.state.lock() else {
            return (0, driver_alloc_size);
        };

        let mut block = Block::new(
            self.device.ordinal(),
            driver_alloc_size,
            ptr,
            stream,
            is_small,
        );
        block.allocated = true;
        let block_idx = state.add_block(block);

        state.reserved_bytes += driver_alloc_size;

        if state.should_split(block_idx, rounded) {
            state.split_block(block_idx, rounded);
        }

        let actual_size = state.blocks[block_idx].size;
        state.allocated_bytes += actual_size;
        if state.allocated_bytes > state.peak_bytes {
            state.peak_bytes = state.allocated_bytes;
        }
        state.misses += 1;

        (block_idx, actual_size)
    }

    /// Returns an allocated block to the cache, coalescing with free
    /// neighbors. Unknown indices and already-free blocks are ignored.
    pub fn cache_free(&self, block_idx: usize) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if block_idx < state.blocks.len() && state.blocks[block_idx].allocated {
            state.free_block(block_idx);
        }
    }

    /// Size to request from the driver for a user request of `size`
    /// bytes (request rounding followed by size-class bucketing).
    pub fn driver_alloc_size(size: usize) -> usize {
        get_allocation_size(round_size(size))
    }
}
806
807impl std::fmt::Debug for CudaAllocator {
808 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
809 f.debug_struct("CudaAllocator")
810 .field("device_ordinal", &self.device.ordinal())
811 .field(
812 "allocated_bytes",
813 &self.allocated_bytes_atomic.load(Ordering::Relaxed),
814 )
815 .field(
816 "peak_bytes",
817 &self.peak_bytes_atomic.load(Ordering::Relaxed),
818 )
819 .field("cached_bytes", &self.cached_bytes())
820 .finish()
821 }
822}
823
// Stubs compiled when the `cuda` feature is disabled: every allocation
// attempt reports the missing feature instead of touching a device.
#[cfg(not(feature = "cuda"))]
impl CudaAllocator {
    /// Always fails with `NoCudaFeature`; device allocation requires the
    /// `cuda` feature.
    pub fn alloc_zeros<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }

    /// Always fails with `NoCudaFeature`; host-to-device copy requires
    /// the `cuda` feature.
    pub fn alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
        Err(crate::error::GpuError::NoCudaFeature)
    }
}
840
#[cfg(test)]
mod tests {
    use super::*;

    // ---- size rounding and size-class bucketing ----

    #[test]
    fn round_size_minimum() {
        assert_eq!(round_size(0), MIN_BLOCK_SIZE);
        assert_eq!(round_size(1), MIN_BLOCK_SIZE);
        assert_eq!(round_size(511), MIN_BLOCK_SIZE);
        assert_eq!(round_size(512), MIN_BLOCK_SIZE);
    }

    #[test]
    fn round_size_multiples() {
        assert_eq!(round_size(513), 1024);
        assert_eq!(round_size(1024), 1024);
        assert_eq!(round_size(1025), 1536);
    }

    #[test]
    fn alloc_size_small() {
        assert_eq!(get_allocation_size(512), SMALL_BUFFER);
        assert_eq!(get_allocation_size(SMALL_SIZE), SMALL_BUFFER);
    }

    #[test]
    fn alloc_size_mid() {
        assert_eq!(get_allocation_size(SMALL_SIZE + 1), LARGE_BUFFER);
        assert_eq!(get_allocation_size(MIN_LARGE_ALLOC - 1), LARGE_BUFFER);
    }

    #[test]
    fn alloc_size_large() {
        assert_eq!(get_allocation_size(MIN_LARGE_ALLOC), MIN_LARGE_ALLOC);
        assert_eq!(
            get_allocation_size(MIN_LARGE_ALLOC + 1),
            MIN_LARGE_ALLOC + ROUND_LARGE
        );
    }

    // Arbitrary stream id shared by the pool tests below.
    fn make_stream() -> StreamId {
        StreamId(42)
    }

    // ---- pool lookup behavior ----

    #[test]
    fn block_pool_insert_find() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 4096, 0x1000, stream, true);
        let idx = state.add_block(block);
        state.small_pool.insert(idx, &state.blocks[idx]);

        let found = state.small_pool.find_free_block(stream, 512);
        assert_eq!(found, Some(idx));
    }

    #[test]
    fn block_pool_respects_stream() {
        let mut state = AllocatorState::new();
        let stream_a = StreamId(1);
        let stream_b = StreamId(2);

        let block = Block::new(0, 4096, 0x1000, stream_a, true);
        let idx = state.add_block(block);
        state.small_pool.insert(idx, &state.blocks[idx]);

        // A free block on stream A must not be served to stream B.
        assert!(state.small_pool.find_free_block(stream_b, 512).is_none());

        assert_eq!(state.small_pool.find_free_block(stream_a, 512), Some(idx));
    }

    #[test]
    fn block_pool_finds_smallest_fit() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let b1 = Block::new(0, 4096, 0x1000, stream, true);
        let i1 = state.add_block(b1);
        state.small_pool.insert(i1, &state.blocks[i1]);

        let b2 = Block::new(0, 1024, 0x2000, stream, true);
        let i2 = state.add_block(b2);
        state.small_pool.insert(i2, &state.blocks[i2]);

        // Best-fit: the 1024-byte block wins over the 4096-byte one.
        let found = state.small_pool.find_free_block(stream, 768);
        assert_eq!(found, Some(i2));
    }

    // ---- splitting and coalescing ----

    #[test]
    fn split_block_creates_remainder() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 8192, 0x1000, stream, true);
        let idx = state.add_block(block);

        state.split_block(idx, 1024);

        assert_eq!(state.blocks[idx].size, 1024);
        let rem_idx = state.blocks[idx].next.unwrap();
        assert_eq!(state.blocks[rem_idx].size, 8192 - 1024);
        assert_eq!(state.blocks[rem_idx].ptr, 0x1000 + 1024);
        assert_eq!(state.blocks[rem_idx].prev, Some(idx));

        // The remainder is immediately reusable from the pool.
        let found = state.small_pool.find_free_block(stream, 1024);
        assert_eq!(found, Some(rem_idx));
    }

    #[test]
    fn coalesce_merges_adjacent_blocks() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        // Three physically adjacent blocks: a (free), b (allocated), c (free).
        let a = Block::new(0, 2048, 0x1000, stream, true);
        let a_idx = state.add_block(a);

        let b = Block::new(0, 2048, 0x1000 + 2048, stream, true);
        let b_idx = state.add_block(b);

        let c = Block::new(0, 4096, 0x1000 + 4096, stream, true);
        let c_idx = state.add_block(c);

        state.blocks[a_idx].next = Some(b_idx);
        state.blocks[b_idx].prev = Some(a_idx);
        state.blocks[b_idx].next = Some(c_idx);
        state.blocks[c_idx].prev = Some(b_idx);

        state.blocks[b_idx].allocated = true;
        state.blocks[b_idx].size = 2048;
        state.allocated_bytes = 2048;

        state.small_pool.insert(a_idx, &state.blocks[a_idx]);
        state.small_pool.insert(c_idx, &state.blocks[c_idx]);

        // Freeing b should absorb both neighbors into one block.
        state.free_block(b_idx);

        assert_eq!(state.blocks[b_idx].size, 2048 + 2048 + 4096);
        assert_eq!(state.blocks[b_idx].ptr, 0x1000);
        assert!(!state.blocks[b_idx].allocated);
    }

    #[test]
    fn should_split_small_pool() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 2048, 0x1000, stream, true);
        let idx = state.add_block(block);

        // 1024 remaining >= MIN_BLOCK_SIZE -> split.
        assert!(state.should_split(idx, 1024));

        // 248 remaining < MIN_BLOCK_SIZE -> keep whole.
        assert!(!state.should_split(idx, 1800));
    }

    #[test]
    fn should_split_large_pool() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let block = Block::new(0, 4 * 1024 * 1024, 0x1000, stream, false);
        let idx = state.add_block(block);

        // 2 MiB remaining > SMALL_SIZE -> split.
        assert!(state.should_split(idx, 2 * 1024 * 1024));

        // 0.5 MiB remaining <= SMALL_SIZE -> keep whole.
        assert!(!state.should_split(idx, 3 * 1024 * 1024 + 512 * 1024));
    }

    // ---- cross-stream use tracking ----

    #[test]
    fn stream_uses_prevent_reuse() {
        let stream = make_stream();
        let mut block = Block::new(0, 4096, 0x1000, stream, true);

        assert!(block.stream_uses.is_empty());
        block.stream_uses.insert(StreamId(99));

        assert!(!block.stream_uses.is_empty());
    }

    #[test]
    fn stream_uses_prevent_merge() {
        let mut state = AllocatorState::new();
        let stream = make_stream();

        let a = Block::new(0, 2048, 0x1000, stream, true);
        let a_idx = state.add_block(a);

        // b is marked as in use by another stream, so it must not merge.
        let mut b = Block::new(0, 2048, 0x1000 + 2048, stream, true);
        b.stream_uses.insert(StreamId(99)); let b_idx = state.add_block(b);

        state.blocks[a_idx].next = Some(b_idx);
        state.blocks[b_idx].prev = Some(a_idx);

        state.small_pool.insert(b_idx, &state.blocks[b_idx]);

        let merged = state.try_merge(a_idx, Some(b_idx));
        assert_eq!(merged, 0);
        assert_eq!(state.blocks[a_idx].size, 2048); }

    // ---- allocator-level cache behavior (skipped without a device) ----

    #[test]
    fn cache_find_and_insert_roundtrip() {
        let device = Arc::new(match GpuDevice::new(0) {
            Ok(d) => d,
            Err(_) => return, });
        let alloc = CudaAllocator::new(device);
        let stream = StreamId(1);

        let (idx, actual) = alloc.cache_insert(2048, 4096, 0x1000, stream);
        assert!(actual <= 4096);
        // Insert counts as a miss.
        assert_eq!(alloc.cache_stats().1, 1); alloc.cache_free(idx);

        // Free then re-find: a hit on the cached block.
        let found = alloc.cache_find(512, stream);
        assert!(found.is_some());
        assert_eq!(alloc.cache_stats().0, 1); }

    #[test]
    fn empty_cache_clears_pools() {
        let device = Arc::new(match GpuDevice::new(0) {
            Ok(d) => d,
            Err(_) => return,
        });
        let alloc = CudaAllocator::new(device);
        let stream = StreamId(1);

        alloc.cache_insert(1024, 4096, 0x1000, stream);
        {
            let state = alloc.state.lock().unwrap();
            assert!(!state.blocks.is_empty());
        }

        alloc.cache_free(0);
        assert!(alloc.free_block_count() > 0);

        alloc.empty_cache();
        assert_eq!(alloc.free_block_count(), 0);
    }

    // Tests requiring real CUDA allocations, only built with the feature.
    #[cfg(feature = "cuda")]
    mod cuda_tests {
        use super::*;

        fn make_allocator() -> CudaAllocator {
            let device = GpuDevice::new(0).expect("CUDA device 0");
            CudaAllocator::new(Arc::new(device))
        }

        #[test]
        fn new_allocator_starts_at_zero() {
            let alloc = make_allocator();
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(alloc.max_memory_allocated(), 0);
        }

        #[test]
        fn empty_cache_is_harmless() {
            let alloc = make_allocator();
            alloc.empty_cache();
        }

        #[test]
        fn debug_impl() {
            let alloc = make_allocator();
            let s = format!("{alloc:?}");
            assert!(s.contains("CudaAllocator"));
            assert!(s.contains("allocated_bytes"));
        }

        #[test]
        fn alloc_increases_allocated_bytes() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(256).expect("alloc_zeros");
            assert_eq!(alloc.memory_allocated(), 256 * std::mem::size_of::<f32>());
            assert_eq!(
                alloc.max_memory_allocated(),
                256 * std::mem::size_of::<f32>()
            );
            alloc.free(buf);
        }

        #[test]
        fn free_decreases_allocated_bytes() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(128).expect("alloc_zeros");
            let expected = 128 * std::mem::size_of::<f32>();
            assert_eq!(alloc.memory_allocated(), expected);

            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }

        #[test]
        fn peak_tracks_maximum() {
            let alloc = make_allocator();

            let buf1 = alloc.alloc_zeros::<f32>(100).expect("alloc 1");
            let buf2 = alloc.alloc_zeros::<f32>(200).expect("alloc 2");
            let peak_after_two = alloc.max_memory_allocated();

            // Peak must hold steady while live usage drops.
            alloc.free(buf1);
            assert_eq!(alloc.max_memory_allocated(), peak_after_two);
            assert!(alloc.memory_allocated() < peak_after_two);

            alloc.free(buf2);
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(alloc.max_memory_allocated(), peak_after_two);
        }

        #[test]
        fn reset_peak_stats_lowers_peak() {
            let alloc = make_allocator();

            let buf = alloc.alloc_zeros::<f32>(512).expect("alloc");
            let high = alloc.max_memory_allocated();
            alloc.free(buf);

            assert_eq!(alloc.max_memory_allocated(), high);

            // After reset the peak snaps to the (now zero) live total.
            alloc.reset_peak_stats();
            assert_eq!(alloc.max_memory_allocated(), 0);
        }

        #[test]
        fn alloc_copy_tracks_bytes() {
            let alloc = make_allocator();
            let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
            let buf = alloc.alloc_copy(&data).expect("alloc_copy");
            assert_eq!(alloc.memory_allocated(), 4 * std::mem::size_of::<f64>());
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }

        #[test]
        fn zero_element_alloc() {
            let alloc = make_allocator();
            let buf = alloc.alloc_zeros::<f32>(0).expect("alloc_zeros empty");
            assert_eq!(alloc.memory_allocated(), 0);
            assert_eq!(buf.len(), 0);
            assert!(buf.is_empty());
            alloc.free(buf);
            assert_eq!(alloc.memory_allocated(), 0);
        }
    }
}