cubecl-runtime 0.10.0-pre.3

Crate that helps create high-performance async runtimes for CubeCL.
Documentation
use super::{ManagedMemoryHandle, MemoryPool, Slice, calculate_padding};
use crate::memory_management::{BytesFormat, MemoryLocation};
use crate::storage::StorageUtilization;
use crate::{memory_management::MemoryUsage, server::IoError};
use alloc::vec;
use alloc::vec::Vec;
use cubecl_common::backtrace::BackTrace;
use hashbrown::HashMap;

/// Memory pool that keeps its slices around and hands a free slice back out only when a
/// request has exactly the same padded size as that slice.
pub struct PersistentPool {
    /// Every slice owned by the pool; a slice's position in this vector is the slice index
    /// written into its memory location.
    slices: Vec<Slice>,
    /// Maps an effective size (requested size plus padding) to the positions in `slices`
    /// of all slices with that effective size.
    sizes: HashMap<u64, Vec<usize>>,
    /// Alignment handed to `calculate_padding` when computing the padding of a request.
    alignment: u64,
    /// Largest request size the pool accepts.
    max_alloc_size: u64,
    /// Location template built from the pool position; it is copied and its `slice` field
    /// is filled in for each new allocation.
    location_base: MemoryLocation,
}

/// Human-readable report of the pool: one line per slice size with its free/full/total
/// counts, followed by the aggregated memory usage when the pool is not empty.
impl core::fmt::Display for PersistentPool {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Hash map iteration order is arbitrary; sort the buckets by size so that the
        // report is stable from one call to the next and easy to compare.
        let mut buckets: Vec<(&u64, &Vec<usize>)> = self.sizes.iter().collect();
        buckets.sort_unstable_by_key(|(size, _)| **size);

        for (size, positions) in buckets {
            let total = positions.len();
            let num_free = positions
                .iter()
                .filter(|&&pos| self.slices[pos].is_free())
                .count();
            // Every slice is either free or in use, so the second count is derived.
            let num_full = total - num_free;

            f.write_fmt(format_args!(
                "  - Slices {} =>  {num_free} free - {num_full} full - {total} total\n",
                BytesFormat::new(*size)
            ))?;
        }

        if !self.sizes.is_empty() {
            f.write_fmt(format_args!("\n{}\n", self.get_memory_usage()))?;
        }

        Ok(())
    }
}

impl PersistentPool {
    /// Creates an empty pool.
    ///
    /// # Arguments
    ///
    /// * `max_alloc_size` - Largest request size the pool will accept.
    /// * `alignment` - Alignment passed to `calculate_padding` to pad every request.
    /// * `pool_pos` - Position of this pool, stored in the base memory location that is
    ///   copied into every slice allocated by the pool.
    pub fn new(max_alloc_size: u64, alignment: u64, pool_pos: u8) -> Self {
        Self {
            slices: Vec::new(),
            sizes: HashMap::new(),
            max_alloc_size,
            alignment,
            location_base: MemoryLocation::new(pool_pos, 0, 0),
        }
    }

    /// Returns `true` when the pool currently tracks at least one slice (free or in use)
    /// whose effective size equals `size` plus its padding.
    ///
    /// This is a read-only lookup, so a shared borrow is sufficient.
    pub fn has_size(&self, size: u64) -> bool {
        let effective_size = size + calculate_padding(size, self.alignment);
        self.sizes.contains_key(&effective_size)
    }
}

impl MemoryPool for PersistentPool {
    /// Returns `true` when `size` does not exceed the pool's maximum allocation size.
    fn accept(&self, size: u64) -> bool {
        self.max_alloc_size >= size
    }

    /// Looks up the slice referenced by `binding`.
    ///
    /// # Errors
    ///
    /// Returns [`IoError::NotFound`] when the slice index stored in the binding's
    /// descriptor is out of range for this pool.
    fn find(&self, binding: &super::ManagedMemoryBinding) -> Result<&Slice, IoError> {
        let slice_index = binding.descriptor().slice();

        self.slices
            .get(slice_index)
            .ok_or_else(|| IoError::NotFound {
                backtrace: BackTrace::capture(),
                reason: alloc::format!("Memory slice {} doesn't exist", slice_index).into(),
            })
    }

    /// Tries to reuse an existing slice for a request of `size` bytes.
    ///
    /// Only slices whose effective size (size plus padding) matches exactly are
    /// considered. Returns a clone of the handle of the first free one, or `None` when
    /// no such slice exists. Nothing is allocated here.
    fn try_reserve(&mut self, size: u64) -> Option<ManagedMemoryHandle> {
        // Must be the same key as the one used by `alloc`.
        let effective_size = size + calculate_padding(size, self.alignment);
        let positions = self.sizes.get(&effective_size)?;

        positions
            .iter()
            .map(|&pos| &self.slices[pos])
            .find(|slice| slice.is_free())
            .map(|slice| slice.handle.clone())
    }

    /// Allocates a new slice of `size` bytes (plus padding) from `storage` and registers
    /// it in the pool.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `storage.alloc`.
    fn alloc<Storage: crate::storage::ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: u64,
    ) -> Result<ManagedMemoryHandle, IoError> {
        let padding = calculate_padding(size, self.alignment);
        let effective_size = size + padding;

        let storage_handle = storage.alloc(effective_size)?;
        let mut slice = Slice::new(storage_handle, padding);
        // The usable region excludes the padding.
        slice.storage.utilization = StorageUtilization { offset: 0, size };

        // The new slice goes at the end of `slices`; record that position in its location.
        let slice_pos = self.slices.len();
        let mut location = self.location_base;
        location.slice = slice_pos as u32;
        slice.descriptor().update_location(location);

        self.sizes
            .entry(effective_size)
            .or_default()
            .push(slice_pos);

        let handle = slice.handle.clone();
        self.slices.push(slice);

        Ok(handle)
    }

    /// Summarizes the pool: in-use figures only count slices that are not free, while
    /// `bytes_reserved` counts the effective size of every slice.
    fn get_memory_usage(&self) -> MemoryUsage {
        let mut usage = MemoryUsage {
            number_allocs: 0,
            bytes_in_use: 0,
            bytes_padding: 0,
            bytes_reserved: 0,
        };

        // Single pass, no temporary collection.
        for slice in &self.slices {
            usage.bytes_reserved += slice.effective_size();

            if !slice.is_free() {
                usage.number_allocs += 1;
                usage.bytes_in_use += slice.storage.size();
                usage.bytes_padding += slice.padding;
            }
        }

        usage
    }

    /// Releases every free slice back to `storage`, but only when `explicit` is `true`;
    /// otherwise this is a no-op.
    ///
    /// Surviving slices are compacted, so their positions change and their descriptors
    /// are updated with the new slice index.
    fn cleanup<Storage: crate::storage::ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        _alloc_nr: u64,
        explicit: bool,
    ) {
        if !explicit {
            return;
        }

        // We have to recompute all locations, so it's just safer to rebuild everything.
        let mut slices = Vec::new();
        let mut sizes = HashMap::<u64, Vec<usize>>::new();

        for slice in self.slices.drain(..) {
            if slice.is_free() {
                storage.dealloc(slice.storage.id);
                continue;
            }

            let slice_pos = slices.len();
            let effective_size = slice.effective_size();
            slice.descriptor().update_slice(slice_pos as u32);
            slices.push(slice);

            sizes.entry(effective_size).or_default().push(slice_pos);
        }

        self.sizes = sizes;
        self.slices = slices;
        storage.flush();
    }

    /// Rebinds the slice currently referenced by `old` to the handle `new`, and stores
    /// `cursor` on the slice. The location of `old` is copied onto `new`.
    ///
    /// # Errors
    ///
    /// Returns [`IoError::NotFound`] when the slice index stored in `old` is out of range
    /// for this pool, instead of panicking on the lookup.
    fn bind(
        &mut self,
        old: ManagedMemoryHandle,
        new: ManagedMemoryHandle,
        cursor: u64,
    ) -> Result<(), IoError> {
        let slice_index = old.descriptor().slice();
        let slice = self
            .slices
            .get_mut(slice_index)
            .ok_or_else(|| IoError::NotFound {
                backtrace: BackTrace::capture(),
                reason: alloc::format!("Memory slice {} doesn't exist", slice_index).into(),
            })?;

        new.descriptor()
            .update_location(old.descriptor().location());
        slice.cursor = cursor;
        slice.handle = new;

        Ok(())
    }
}

// Unit tests for `PersistentPool`, run against the in-memory `BytesStorage` backend.
#[cfg(test)]
mod tests {
    use crate::memory_management::memory_pool::calculate_padding;
    use crate::storage::BytesStorage;

    use super::*;

    /// `alloc` and `try_reserve` must use the same `sizes` key (`size + padding`). See #1224.
    #[test_log::test]
    fn persistent_pool_try_reserve_reuses_slice_with_padding() {
        let mut storage = BytesStorage::default();
        let alignment = 4u64;
        let mut pool = PersistentPool::new(1024 * 1024, alignment, 0);

        // Pick a size for which `calculate_padding` is non-zero; the guard below fails
        // the test early if that assumption ever stops holding.
        let size = 1025u64;
        assert_ne!(
            calculate_padding(size, alignment),
            0,
            "test needs non-zero padding so alloc vs try_reserve keys differed pre-fix"
        );

        // While the handle returned by `alloc` is alive the slice must not be handed out.
        let handle = pool.alloc(&mut storage, size).expect("alloc");
        assert!(
            pool.try_reserve(size).is_none(),
            "slice must stay reserved while the handle is alive"
        );

        // Dropping the handle is what makes the slice reusable again.
        core::mem::drop(handle);

        assert!(
            pool.try_reserve(size).is_some(),
            "freed slice should be reusable"
        );
    }

    /// Walks through the reserve / alloc / drop life cycle and checks the reported
    /// memory usage at each step.
    #[test_log::test]
    fn persistent_pool() {
        let mut storage = BytesStorage::default();
        let mut pool = PersistentPool::new(1024 * 1024, 4, 0);

        // Nothing has been allocated yet, so there is nothing to reuse.
        let result = pool.try_reserve(1024);
        assert!(result.is_none(), "No alloc yet");

        // `alloc1` keeps the whole `Result` (and the handle inside it) alive.
        let alloc1 = pool.alloc(&mut storage, 1024);
        let result = pool.try_reserve(1024);
        assert!(result.is_none(), "No free slice yet, handle1 is alive");

        // Dropping the result drops the handle, which frees the slice for reuse.
        core::mem::drop(alloc1);
        let result = pool.try_reserve(1024);
        assert!(result.is_some(), "Handle1 is free to be reused.");
        core::mem::drop(result);

        // A request with a different size must not match the existing slice.
        let result = pool.try_reserve(1025);
        assert!(result.is_none(), "Not the same size.");

        // `alloc` always creates a new slice, even though the first one is free:
        // one slice in use, two reserved.
        let alloc2 = pool.alloc(&mut storage, 1024);
        let usage = pool.get_memory_usage();
        assert_eq!(usage.bytes_in_use, 1024);
        assert_eq!(usage.bytes_reserved, 2048);

        // Reserving again picks up the free slice: both slices are now in use and no
        // new memory was reserved.
        let result = pool.try_reserve(1024);
        let usage = pool.get_memory_usage();
        assert!(result.is_some(), "Handle1 is free to be reused.");
        assert_eq!(usage.bytes_in_use, 2048);
        assert_eq!(usage.bytes_reserved, 2048);

        core::mem::drop(alloc2);
        core::mem::drop(result);

        // With every handle dropped nothing is in use, but the memory stays reserved
        // because no cleanup was requested.
        let usage = pool.get_memory_usage();
        assert_eq!(usage.bytes_in_use, 0);
        assert_eq!(usage.bytes_reserved, 2048);
    }
}