cubecl_hip/compute/
server.rs

1use cubecl_core::server::{ProfileError, ProfilingToken};
2use cubecl_cpp::formatter::format_cpp;
3use cubecl_cpp::shared::CompilationOptions;
4
5use super::fence::{Fence, SyncStream};
6use super::storage::HipStorage;
7use super::{HipResource, uninit_vec};
8use crate::runtime::HipCompiler;
9use cubecl_common::future::DynFut;
10use cubecl_common::profile::ProfileDuration;
11use cubecl_core::compute::CubeTask;
12use cubecl_core::compute::DebugInformation;
13use cubecl_core::prelude::*;
14use cubecl_core::{Feature, server::Bindings};
15use cubecl_hip_sys::{HIP_SUCCESS, get_hip_include_path, hiprtcResult_HIPRTC_SUCCESS};
16use cubecl_runtime::logging::ServerLogger;
17use cubecl_runtime::memory_management::MemoryUsage;
18use cubecl_runtime::memory_management::offset_handles;
19use cubecl_runtime::storage::BindingResource;
20use cubecl_runtime::timestamp_profiler::TimestampProfiler;
21use cubecl_runtime::{
22    memory_management::MemoryManagement,
23    server::{self, ComputeServer},
24};
25use std::collections::HashMap;
26use std::ffi::CStr;
27use std::ffi::CString;
28use std::future::Future;
29use std::sync::Arc;
30
31#[cfg(feature = "compilation-cache")]
32use cubecl_common::cache::{Cache, CacheOption};
33
/// Compute server that executes kernels and moves data through the HIP runtime.
34#[derive(Debug)]
35pub struct HipServer {
    /// HIP state (stream, memory management, compiled kernels) used by every operation.
36    ctx: HipContext,
    /// Byte multiple that each tensor size is rounded up to in `empty_tensors`.
37    mem_alignment: usize,
38}
39
/// State needed to talk to the HIP runtime: the stream, device memory and compiled kernels.
40#[derive(Debug)]
41pub(crate) struct HipContext {
    /// Stream on which copies and kernel launches are enqueued and which `sync` waits on.
42    stream: cubecl_hip_sys::hipStream_t,
    /// Resolves bindings to device resources and reserves new device memory.
43    memory_management: MemoryManagement<HipStorage>,
    /// Kernels already loaded into HIP modules, keyed by kernel id.
44    module_names: HashMap<KernelId, HipCompiledKernel>,
    /// Profiler started/stopped by the server; launch errors are reported to it when it is
    /// not empty.
45    timestamps: TimestampProfiler,
    /// Options forwarded to the CubeCL compiler when a kernel is compiled.
46    compilation_options: CompilationOptions,
    /// Cache of compiled binaries, keyed by the stable format of the kernel id.
47    #[cfg(feature = "compilation-cache")]
48    compilation_cache: Cache<String, CompilationCacheEntry>,
49}
50
/// Serializable record stored in the compilation cache for one compiled kernel.
51#[cfg(feature = "compilation-cache")]
52#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq, Clone)]
53pub struct CompilationCacheEntry {
    /// Name of the function looked up in the module when the binary is loaded.
54    entrypoint_name: String,
    /// Cube dimensions of the kernel, stored as `(x, y, z)`.
55    cube_dim: (u32, u32, u32),
    /// Compiled code as returned by `hiprtcGetCode`.
56    binary: Vec<i8>,
57}
58
/// A kernel that has been loaded into a HIP module and is ready to be launched.
59#[derive(Debug)]
60struct HipCompiledKernel {
    /// Module returned by `hipModuleLoadData`. It is never read in this file;
    /// NOTE(review): presumably kept so the handle stays associated with `func` — confirm.
61    _module: cubecl_hip_sys::hipModule_t,
    /// Function handle passed to `hipModuleLaunchKernel`.
62    func: cubecl_hip_sys::hipFunction_t,
    /// Block dimensions used when the kernel is launched.
63    cube_dim: CubeDim,
64}
65
// SAFETY: NOTE(review) — `HipServer` owns raw HIP handles (the stream, plus module and function
// handles of compiled kernels), which is why `Send` is not derived automatically. This impl
// asserts that moving the server to another thread is sound; the justification is not visible
// in this file and should be confirmed against the HIP runtime's threading rules.
66unsafe impl Send for HipServer {}
67
68impl HipServer {
69    fn read_sync(&mut self, binding: server::Binding) -> Vec<u8> {
70        let ctx = self.get_context();
71        let resource = ctx
72            .memory_management
73            .get_resource(binding.memory, binding.offset_start, binding.offset_end)
74            .expect("Failed to find resource");
75
76        let mut data = uninit_vec(resource.size as usize);
77        unsafe {
78            let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
79                data.as_mut_ptr() as *mut _,
80                resource.ptr,
81                resource.size as usize,
82                ctx.stream,
83            );
84            assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
85        };
86        ctx.sync();
87        data
88    }
89
90    fn read_async(
91        &mut self,
92        bindings: Vec<server::Binding>,
93    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + use<> {
94        let ctx = self.get_context();
95        let mut result = Vec::with_capacity(bindings.len());
96
97        for binding in bindings {
98            let resource = ctx
99                .memory_management
100                .get_resource(binding.memory, binding.offset_start, binding.offset_end)
101                .expect("Failed to find resource");
102
103            let mut data = uninit_vec(resource.size as usize);
104            unsafe {
105                let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
106                    data.as_mut_ptr() as *mut _,
107                    resource.ptr,
108                    resource.size as usize,
109                    ctx.stream,
110                );
111                assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
112            };
113            result.push(data);
114        }
115
116        let fence = ctx.fence();
117
118        async move {
119            fence.wait();
120            result
121        }
122    }
123
124    fn sync_stream_async(&mut self) -> impl Future<Output = ()> + Send + use<> {
125        let ctx = self.get_context();
126        // We can't use a fence here because no action has been recorded on the context.
127        // We need at least one action to be recorded after the context is initialized
128        // with `cudarc::driver::result::ctx::set_current(self.ctx.context)` for the fence
129        // to have any effect. Otherwise, it seems to be ignored.
130        let sync = ctx.lazy_sync_stream();
131
132        async move {
133            sync.wait();
134        }
135    }
136}
137
138impl ComputeServer for HipServer {
139    type Kernel = Box<dyn CubeTask<HipCompiler>>;
140    type Storage = HipStorage;
141    type Feature = Feature;
142    type Info = ();
143
144    fn read(&mut self, bindings: Vec<server::Binding>) -> DynFut<Vec<Vec<u8>>> {
145        Box::pin(self.read_async(bindings))
146    }
147
148    fn read_tensor(&mut self, bindings: Vec<server::BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
149        let bindings = bindings.into_iter().map(|it| it.binding).collect();
150        Box::pin(self.read_async(bindings))
151    }
152
153    fn memory_usage(&self) -> MemoryUsage {
154        self.ctx.memory_usage()
155    }
156
157    fn memory_cleanup(&mut self) {
158        let ctx = self.get_context();
159        ctx.memory_management.cleanup(true);
160    }
161
162    fn create(&mut self, data: &[u8]) -> server::Handle {
163        let handle = self.empty(data.len());
164
165        let binding = handle.clone().binding();
166        self.copy_to_binding(binding, data);
167        handle
168    }
169
170    fn create_tensors(
171        &mut self,
172        data: Vec<&[u8]>,
173        shapes: Vec<&[usize]>,
174        elem_sizes: Vec<usize>,
175    ) -> Vec<(server::Handle, Vec<usize>)> {
176        let handles_strides = self.empty_tensors(shapes.clone(), elem_sizes);
177        for i in 0..data.len() {
178            let data = data[i];
179            let (handle, _) = &handles_strides[i];
180            let binding = handle.clone().binding();
181            self.copy_to_binding(binding, data);
182        }
183        handles_strides
184    }
185
186    fn empty(&mut self, size: usize) -> server::Handle {
187        let ctx = self.get_context();
188        let handle = ctx.memory_management.reserve(size as u64, None);
189        server::Handle::new(handle, None, None, size as u64)
190    }
191
192    fn empty_tensors(
193        &mut self,
194        shapes: Vec<&[usize]>,
195        elem_sizes: Vec<usize>,
196    ) -> Vec<(server::Handle, Vec<usize>)> {
197        let mut total_size = 0;
198        let mut strides = Vec::new();
199        let mut sizes = Vec::new();
200
201        for (shape, elem_size) in shapes.into_iter().zip(elem_sizes) {
202            let size =
203                (shape.iter().product::<usize>() * elem_size).next_multiple_of(self.mem_alignment);
204            strides.push(contiguous_strides(shape));
205            sizes.push(size);
206            total_size += size;
207        }
208
209        let mem_handle = self.empty(total_size);
210        let handles = offset_handles(mem_handle, &sizes);
211
212        handles.into_iter().zip(strides).collect()
213    }
214
215    unsafe fn execute(
216        &mut self,
217        kernel: Self::Kernel,
218        count: CubeCount,
219        bindings: Bindings,
220        mode: ExecutionMode,
221        logger: Arc<ServerLogger>,
222    ) {
223        let mut kernel_id = kernel.id();
224        kernel_id.mode(mode);
225
226        let count = match count {
227            CubeCount::Static(x, y, z) => (x, y, z),
228            // TODO: HIP doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
229            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
230            // For now, just read the dispatch settings from the buffer.
231            CubeCount::Dynamic(binding) => {
232                let data = self.read_sync(binding);
233                let data = bytemuck::cast_slice(&data);
234                assert!(
235                    data.len() == 3,
236                    "Dynamic cube count should contain 3 values"
237                );
238                (data[0], data[1], data[2])
239            }
240        };
241
242        let Bindings {
243            buffers,
244            metadata,
245            scalars,
246            tensor_maps,
247        } = bindings;
248
249        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
250        let info = self.create(bytemuck::cast_slice(&metadata.data));
251        let scalars: Vec<_> = scalars.values().map(|s| self.create(s.data())).collect();
252
253        let ctx = self.get_context();
254
255        if !ctx.module_names.contains_key(&kernel_id) {
256            ctx.compile_kernel(&kernel_id, kernel, mode, logger);
257        }
258
259        let mut resources: Vec<_> = buffers.into_iter().map(|b| find_resource(ctx, b)).collect();
260        resources.push(find_resource(ctx, info.clone().binding()));
261        resources.extend(scalars.into_iter().map(|s| find_resource(ctx, s.binding())));
262
263        ctx.execute_task(kernel_id, count, resources);
264    }
265
266    fn flush(&mut self) {}
267
268    fn sync(&mut self) -> DynFut<()> {
269        Box::pin(self.sync_stream_async())
270    }
271
272    fn start_profile(&mut self) -> ProfilingToken {
273        cubecl_common::future::block_on(self.sync());
274        self.ctx.timestamps.start()
275    }
276
277    fn end_profile(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
278        cubecl_common::future::block_on(self.sync());
279        self.ctx.timestamps.stop(token)
280    }
281
282    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<HipResource> {
283        let ctx = self.get_context();
284        BindingResource::new(
285            binding.clone(),
286            ctx.memory_management
287                .get_resource(binding.memory, binding.offset_start, binding.offset_end)
288                .expect("Can't find resource"),
289        )
290    }
291}
292
293fn find_resource(ctx: &mut HipContext, binding: server::Binding) -> HipResource {
294    ctx.memory_management
295        .get_resource(binding.memory, binding.offset_start, binding.offset_end)
296        .expect("Failed to find resource")
297}
298
299impl HipContext {
300    pub fn new(
301        memory_management: MemoryManagement<HipStorage>,
302        compilation_options: CompilationOptions,
303        stream: cubecl_hip_sys::hipStream_t,
304    ) -> Self {
305        Self {
306            memory_management,
307            module_names: HashMap::new(),
308            stream,
309            timestamps: TimestampProfiler::default(),
310            compilation_options,
311            #[cfg(feature = "compilation-cache")]
312            compilation_cache: Cache::new("hip/compilation", CacheOption::default()),
313        }
314    }
315
316    fn fence(&mut self) -> Fence {
317        Fence::new(self.stream)
318    }
319
320    fn lazy_sync_stream(&mut self) -> SyncStream {
321        SyncStream::new(self.stream)
322    }
323
324    fn sync(&mut self) {
325        unsafe {
326            let status = cubecl_hip_sys::hipStreamSynchronize(self.stream);
327            assert_eq!(
328                status, HIP_SUCCESS,
329                "Should successfully synchronize stream"
330            );
331        };
332        self.memory_management.storage().flush();
333    }
334
335    fn memory_usage(&self) -> MemoryUsage {
336        self.memory_management.memory_usage()
337    }
338
339    fn compile_kernel(
340        &mut self,
341        kernel_id: &KernelId,
342        cube_kernel: Box<dyn CubeTask<HipCompiler>>,
343        mode: ExecutionMode,
344        logger: Arc<ServerLogger>,
345    ) {
346        #[cfg(feature = "compilation-cache")]
347        let name = kernel_id.stable_format();
348        #[cfg(feature = "compilation-cache")]
349        if let Some(entry) = self.compilation_cache.get(&name) {
350            log::trace!("Using compilation cache");
351            self.load_compiled_binary(
352                entry.binary.clone(),
353                kernel_id.clone(),
354                entry.entrypoint_name.clone(),
355                CubeDim {
356                    x: entry.cube_dim.0,
357                    y: entry.cube_dim.1,
358                    z: entry.cube_dim.2,
359                },
360            );
361            return;
362        }
363
364        // CubeCL compilation
365        // jitc = just-in-time compiled
366        let mut jitc_kernel =
367            cube_kernel.compile(&mut Default::default(), &self.compilation_options, mode);
368
369        if logger.compilation_activated() {
370            jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
371
372            if let Ok(formatted) = format_cpp(&jitc_kernel.source) {
373                jitc_kernel.source = formatted;
374            }
375        }
376        logger.log_compilation(&jitc_kernel);
377
378        // Create HIP Program
379        let program = unsafe {
380            let source = CString::new(jitc_kernel.source.clone()).unwrap();
381            let mut program: cubecl_hip_sys::hiprtcProgram = std::ptr::null_mut();
382            let status = cubecl_hip_sys::hiprtcCreateProgram(
383                &mut program,
384                source.as_ptr(),
385                std::ptr::null(), // program name seems unnecessary
386                0,
387                std::ptr::null_mut(),
388                std::ptr::null_mut(),
389            );
390            assert_eq!(
391                status, hiprtcResult_HIPRTC_SUCCESS,
392                "Should create the program"
393            );
394            program
395        };
396        // Compile HIP program
397        // options
398        let include_path = get_hip_include_path().unwrap();
399        let include_option = format!("-I{include_path}");
400        let include_option_cstr = CString::new(include_option).unwrap();
401        // needed for rocWMMA extension to compile
402        let cpp_std_option_cstr = CString::new("--std=c++17").unwrap();
403        let optimization_level = CString::new("-O3").unwrap();
404        let mut options = vec![
405            cpp_std_option_cstr.as_ptr(),
406            include_option_cstr.as_ptr(),
407            optimization_level.as_ptr(),
408        ];
409        unsafe {
410            let options_ptr = options.as_mut_ptr();
411            let status =
412                cubecl_hip_sys::hiprtcCompileProgram(program, options.len() as i32, options_ptr);
413            if status != hiprtcResult_HIPRTC_SUCCESS {
414                let mut log_size: usize = 0;
415                let status =
416                    cubecl_hip_sys::hiprtcGetProgramLogSize(program, &mut log_size as *mut usize);
417                assert_eq!(
418                    status, hiprtcResult_HIPRTC_SUCCESS,
419                    "Should retrieve the compilation log size"
420                );
421                let mut log_buffer = vec![0; log_size];
422                let status = cubecl_hip_sys::hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
423                assert_eq!(
424                    status, hiprtcResult_HIPRTC_SUCCESS,
425                    "Should retrieve the compilation log contents"
426                );
427                let log = CStr::from_ptr(log_buffer.as_ptr());
428                let mut message = "[Compilation Error] ".to_string();
429                if log_size > 0 {
430                    for line in log.to_string_lossy().split('\n') {
431                        if !line.is_empty() {
432                            message += format!("\n    {line}").as_str();
433                        }
434                    }
435                } else {
436                    message += "\n No compilation logs found!";
437                }
438                panic!("{message}\n[Source]  \n{}", jitc_kernel.source);
439            }
440            assert_eq!(
441                status, hiprtcResult_HIPRTC_SUCCESS,
442                "Should compile the program"
443            );
444        };
445        // Get HIP compiled code from program
446        let mut code_size: usize = 0;
447        unsafe {
448            let status = cubecl_hip_sys::hiprtcGetCodeSize(program, &mut code_size);
449            assert_eq!(
450                status, hiprtcResult_HIPRTC_SUCCESS,
451                "Should get size of compiled code"
452            );
453        }
454        let mut code = vec![0; code_size];
455        unsafe {
456            let status = cubecl_hip_sys::hiprtcGetCode(program, code.as_mut_ptr());
457            assert_eq!(
458                status, hiprtcResult_HIPRTC_SUCCESS,
459                "Should load compiled code"
460            );
461        }
462
463        #[cfg(feature = "compilation-cache")]
464        self.compilation_cache
465            .insert(
466                name,
467                CompilationCacheEntry {
468                    entrypoint_name: jitc_kernel.entrypoint_name.clone(),
469                    cube_dim: (
470                        jitc_kernel.cube_dim.x,
471                        jitc_kernel.cube_dim.y,
472                        jitc_kernel.cube_dim.z,
473                    ),
474                    binary: code.clone(),
475                },
476            )
477            .unwrap();
478
479        self.load_compiled_binary(
480            code,
481            kernel_id.clone(),
482            jitc_kernel.entrypoint_name,
483            jitc_kernel.cube_dim,
484        );
485    }
486
487    fn load_compiled_binary(
488        &mut self,
489        code: Vec<i8>,
490        kernel_id: KernelId,
491        entrypoint_name: String,
492        cube_dim: CubeDim,
493    ) {
494        let func_name = CString::new(entrypoint_name.clone()).unwrap();
495
496        // Create the HIP module
497        let mut module: cubecl_hip_sys::hipModule_t = std::ptr::null_mut();
498        unsafe {
499            let codeptr = code.as_ptr();
500            let status = cubecl_hip_sys::hipModuleLoadData(&mut module, codeptr as *const _);
501            assert_eq!(status, HIP_SUCCESS, "Should load compiled code into module");
502        }
503        // Retrieve the HIP module function
504        let mut func: cubecl_hip_sys::hipFunction_t = std::ptr::null_mut();
505        unsafe {
506            let status =
507                cubecl_hip_sys::hipModuleGetFunction(&mut func, module, func_name.as_ptr());
508            assert_eq!(status, HIP_SUCCESS, "Should return module function");
509        }
510
511        // register module
512        self.module_names.insert(
513            kernel_id.clone(),
514            HipCompiledKernel {
515                _module: module,
516                func,
517                cube_dim,
518            },
519        );
520    }
521
522    fn execute_task(
523        &mut self,
524        kernel_id: KernelId,
525        dispatch_count: (u32, u32, u32),
526        resources: Vec<HipResource>,
527    ) {
528        let mut bindings = resources
529            .iter()
530            .map(|memory| memory.binding)
531            .collect::<Vec<_>>();
532
533        let kernel = self.module_names.get(&kernel_id).unwrap();
534        let cube_dim = kernel.cube_dim;
535
536        let result = unsafe {
537            let status = cubecl_hip_sys::hipModuleLaunchKernel(
538                kernel.func,
539                dispatch_count.0,
540                dispatch_count.1,
541                dispatch_count.2,
542                cube_dim.x,
543                cube_dim.y,
544                cube_dim.z,
545                // Shared memory is specified statically in the kernel, and no dynamic shared
546                // memory is supported yet in the kernel, which would be that value for the
547                // current kernel launch.
548                0,
549                self.stream,
550                bindings.as_mut_ptr(),
551                std::ptr::null_mut(),
552            );
553            if status == cubecl_hip_sys::hipError_t_hipErrorOutOfMemory {
554                Err(LaunchError::OutOfMemory)
555            } else if status != HIP_SUCCESS {
556                Err(LaunchError::Unknown(format!(
557                    "Unable to launch kernel {kernel_id:?} with status {status:?}"
558                )))
559            } else {
560                Ok(())
561            }
562        };
563
564        match result {
565            Ok(_) => {}
566            Err(err) => match self.timestamps.is_empty() {
567                true => panic!("{err:?}"),
568                false => self.timestamps.error(err.into()),
569            },
570        }
571    }
572}
573
574impl HipServer {
575    /// Create a new hip server.
576    pub(crate) fn new(mem_alignment: usize, ctx: HipContext) -> Self {
577        Self { ctx, mem_alignment }
578    }
579
580    fn get_context(&mut self) -> &mut HipContext {
581        &mut self.ctx
582    }
583
584    fn copy_to_binding(&mut self, binding: server::Binding, data: &[u8]) {
585        let ctx = self.get_context();
586        let resource = ctx
587            .memory_management
588            .get_resource(binding.memory, binding.offset_start, binding.offset_end)
589            .unwrap();
590
591        unsafe {
592            let status = cubecl_hip_sys::hipMemcpyHtoDAsync(
593                resource.ptr,
594                data as *const _ as *mut _,
595                data.len(),
596                ctx.stream,
597            );
598            assert_eq!(status, HIP_SUCCESS, "Should send data to device");
599        }
600    }
601}
602
/// Computes the row-major (contiguous) strides of `shape`, expressed in elements.
///
/// The last dimension has stride 1 and every other stride is the product of the dimensions
/// that follow it. An empty shape yields an empty vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps the range empty for rank 0, where `rank - 1` would underflow.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
611
/// Reason why a kernel launch failed.
612#[derive(Debug)]
613pub(crate) enum LaunchError {
    /// The launch returned `hipErrorOutOfMemory`.
614    OutOfMemory,
    /// The launch returned any other non-success status; the message names the kernel and
    /// the status.
615    Unknown(String),
616}
617
618impl From<LaunchError> for ProfileError {
619    fn from(val: LaunchError) -> Self {
620        match val {
621            LaunchError::OutOfMemory => ProfileError::Unknown("Out of memory".into()),
622            LaunchError::Unknown(msg) => ProfileError::Unknown(msg),
623        }
624    }
625}