cubecl_hip/compute/
storage.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use cubecl_hip_sys::HIP_SUCCESS;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use std::collections::HashMap;

/// Buffer storage for HIP.
/// Buffer storage for HIP.
pub struct HipStorage {
    /// Device allocations obtained from `hipMallocAsync`, keyed by storage id.
    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
    /// Ids queued by `dealloc`; actually freed by `perform_deallocations`.
    deallocations: Vec<StorageId>,
    /// Stream on which allocations and frees are enqueued.
    stream: cubecl_hip_sys::hipStream_t,
    /// Host-side slots holding the device pointers handed out by `get`.
    ///
    /// Each slot is boxed so its address stays fixed while the map grows:
    /// `HipResource::binding` points *into* these slots, and an inline value
    /// would be moved (leaving earlier bindings dangling) whenever a later
    /// `get` triggers a reallocation of the map.
    activate_slices: HashMap<ActiveResource, Box<cubecl_hip_sys::hipDeviceptr_t>>,
}

/// Key identifying a bound slice by its absolute device address
/// (allocation base + slice offset, as computed in `get`).
#[derive(new, Debug, Hash, PartialEq, Eq, Clone, Copy)]
struct ActiveResource {
    ptr: u64,
}

// SAFETY: NOTE(review): `HipStorage` only holds raw HIP handles/pointers.
// Soundness assumes the storage is used from one thread at a time (e.g. owned
// by a single compute server) — confirm against the callers.
unsafe impl Send for HipStorage {}

impl core::fmt::Debug for HipStorage {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The only field printed is the stream; label it as such and mark
        // the output as non-exhaustive since the maps are omitted.
        f.debug_struct("HipStorage")
            .field("stream", &self.stream)
            .finish_non_exhaustive()
    }
}

/// Keeps actual HIP buffer references in a hashmap with ids as key.
impl HipStorage {
    /// Create a new storage on the given stream.
    pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
        Self {
            memory: HashMap::new(),
            deallocations: Vec::new(),
            stream,
            activate_slices: HashMap::new(),
        }
    }

    /// Actually deallocates buffers tagged to be deallocated.
    ///
    /// Each queued id whose allocation is still tracked is removed from the
    /// map and freed asynchronously on `self.stream`; unknown ids are skipped.
    pub fn perform_deallocations(&mut self) {
        for id in self.deallocations.drain(..) {
            if let Some(ptr) = self.memory.remove(&id) {
                // SAFETY: `ptr` was produced by `hipMallocAsync` in `alloc` and
                // has just been removed from `memory`, so it is freed only once.
                // NOTE(review): the returned status is ignored (best-effort
                // free), matching previous behavior.
                unsafe {
                    cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
                }
            }
        }
    }

    /// Drops every binding slot handed out by `get`.
    ///
    /// Any `HipResource::binding` obtained before this call points into a
    /// freed slot afterwards and must no longer be used.
    pub fn flush(&mut self) {
        self.activate_slices.clear();
    }
}

pub type Binding = cubecl_hip_sys::hipDeviceptr_t;

/// The memory resource that can be allocated for the device.
#[derive(new, Debug)]
pub struct HipResource {
    /// Device address of the slice (allocation base + `offset`).
    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
    /// Host pointer to a storage slot containing `ptr`; stays valid until the
    /// owning `HipStorage` is flushed.
    pub binding: Binding,
    /// Offset of the slice within its allocation, in bytes.
    pub offset: u64,
    /// Size of the slice, in bytes.
    pub size: u64,
}

// SAFETY: NOTE(review): `HipResource` only carries raw pointers that are
// dereferenced by the HIP runtime; assumes callers do not share the binding
// slot across threads concurrently — confirm.
unsafe impl Send for HipResource {}

impl ComputeStorage for HipStorage {
    const ALIGNMENT: u64 = 4;

    type Resource = HipResource;

    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let base = *self
            .memory
            .get(&handle.id)
            .expect("storage handle should refer to a live allocation") as u64;

        let offset = handle.offset();
        let size = handle.size();
        let ptr = base + offset;

        // The binding needs to stay alive (and at a fixed address) until the
        // task is sent to the server. Reuse an existing slot for the same
        // slice rather than replacing it, so earlier bindings stay valid.
        let slot = self
            .activate_slices
            .entry(ActiveResource::new(ptr))
            .or_insert_with(|| Box::new(ptr as cubecl_hip_sys::hipDeviceptr_t));

        let device_ptr = **slot;
        let binding = &**slot as *const cubecl_hip_sys::hipDeviceptr_t as *mut std::ffi::c_void;

        HipResource::new(device_ptr, binding, offset, size)
    }

    fn alloc(&mut self, size: u64) -> StorageHandle {
        let id = StorageId::new();
        let bytes = usize::try_from(size).expect("allocation size should fit in usize");

        let mut dptr: *mut std::ffi::c_void = std::ptr::null_mut();
        // SAFETY: `dptr` is a valid out-pointer for the duration of the call and
        // `self.stream` is the stream this storage was created with.
        let status = unsafe { cubecl_hip_sys::hipMallocAsync(&mut dptr, bytes, self.stream) };
        assert_eq!(status, HIP_SUCCESS, "Should allocate memory");

        self.memory.insert(id, dptr);
        StorageHandle::new(id, StorageUtilization { offset: 0, size })
    }

    fn dealloc(&mut self, id: StorageId) {
        // Deferred: the buffer is freed on the next `perform_deallocations`.
        self.deallocations.push(id);
    }
}