cubecl_hip/compute/server.rs

1use cubecl_cpp::formatter::format_cpp;
2use cubecl_cpp::shared::CompilationOptions;
3
4use crate::runtime::HipCompiler;
5
6use super::fence::{Fence, SyncStream};
7use super::storage::HipStorage;
8use super::{HipResource, uninit_vec};
9use cubecl_common::benchmark::ProfileDuration;
10use cubecl_core::compute::DebugInformation;
11use cubecl_core::{Feature, server::Bindings};
12use cubecl_core::{KernelId, prelude::*};
13use cubecl_hip_sys::{HIP_SUCCESS, hiprtcResult_HIPRTC_SUCCESS};
14use cubecl_runtime::debug::{DebugLogger, ProfileLevel};
15use cubecl_runtime::kernel_timestamps::KernelTimestamps;
16use cubecl_runtime::memory_management::MemoryUsage;
17use cubecl_runtime::storage::BindingResource;
18use cubecl_runtime::{
19    memory_management::MemoryManagement,
20    server::{self, ComputeServer},
21};
22use std::collections::HashMap;
23use std::ffi::CStr;
24use std::ffi::CString;
25use std::future::Future;
26use std::path::PathBuf;
27
28#[cfg(feature = "compilation-cache")]
29use cubecl_common::cache::{Cache, CacheOption};
30
31#[derive(Debug)]
32pub struct HipServer {
33    ctx: HipContext,
34    logger: DebugLogger,
35}
36
37#[derive(Debug)]
38pub(crate) struct HipContext {
39    stream: cubecl_hip_sys::hipStream_t,
40    memory_management: MemoryManagement<HipStorage>,
41    module_names: HashMap<KernelId, HipCompiledKernel>,
42    timestamps: KernelTimestamps,
43    compilation_options: CompilationOptions,
44    #[cfg(feature = "compilation-cache")]
45    compilation_cache: Cache<String, CompilationCacheEntry>,
46}
47
/// A compiled kernel as stored in the compilation cache.
#[cfg(feature = "compilation-cache")]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq, Clone)]
pub struct CompilationCacheEntry {
    /// Name of the kernel function looked up in the loaded module.
    entrypoint_name: String,
    /// Block dimensions `(x, y, z)` the kernel was compiled for.
    cube_dim: (u32, u32, u32),
    /// Code object produced by `hiprtcGetCode`.
    binary: Vec<i8>,
}
55
56#[derive(Debug)]
57struct HipCompiledKernel {
58    _module: cubecl_hip_sys::hipModule_t,
59    func: cubecl_hip_sys::hipFunction_t,
60    cube_dim: CubeDim,
61}
62
63unsafe impl Send for HipServer {}
64
65impl HipServer {
66    fn read_sync(&mut self, binding: server::Binding) -> Vec<u8> {
67        let ctx = self.get_context();
68        let resource = ctx
69            .memory_management
70            .get_resource(binding.memory, binding.offset_start, binding.offset_end)
71            .expect("Failed to find resource");
72
73        let mut data = uninit_vec(resource.size as usize);
74        unsafe {
75            let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
76                data.as_mut_ptr() as *mut _,
77                resource.ptr,
78                resource.size as usize,
79                ctx.stream,
80            );
81            assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
82        };
83        ctx.sync();
84        data
85    }
86
87    fn read_async(
88        &mut self,
89        bindings: Vec<server::Binding>,
90    ) -> impl Future<Output = Vec<Vec<u8>>> + 'static + Send {
91        let ctx = self.get_context();
92        let mut result = Vec::with_capacity(bindings.len());
93
94        for binding in bindings {
95            let resource = ctx
96                .memory_management
97                .get_resource(binding.memory, binding.offset_start, binding.offset_end)
98                .expect("Failed to find resource");
99
100            let mut data = uninit_vec(resource.size as usize);
101            unsafe {
102                let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
103                    data.as_mut_ptr() as *mut _,
104                    resource.ptr,
105                    resource.size as usize,
106                    ctx.stream,
107                );
108                assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
109            };
110            result.push(data);
111        }
112
113        let fence = ctx.fence();
114
115        async move {
116            fence.wait();
117            result
118        }
119    }
120
121    fn sync_stream_async(&mut self) -> impl Future<Output = ()> + 'static + Send {
122        let ctx = self.get_context();
123        // We can't use a fence here because no action has been recorded on the context.
124        // We need at least one action to be recorded after the context is initialized
125        // with `cudarc::driver::result::ctx::set_current(self.ctx.context)` for the fence
126        // to have any effect. Otherwise, it seems to be ignored.
127        let sync = ctx.lazy_sync_stream();
128
129        async move {
130            sync.wait();
131        }
132    }
133}
134
135impl ComputeServer for HipServer {
136    type Kernel = Box<dyn CubeTask<HipCompiler>>;
137    type Storage = HipStorage;
138    type Feature = Feature;
139    type Info = ();
140
141    fn read(
142        &mut self,
143        bindings: Vec<server::Binding>,
144    ) -> impl Future<Output = Vec<Vec<u8>>> + 'static {
145        self.read_async(bindings)
146    }
147
148    fn read_tensor(
149        &mut self,
150        bindings: Vec<server::BindingWithMeta>,
151    ) -> impl Future<Output = Vec<Vec<u8>>> + 'static {
152        let bindings = bindings.into_iter().map(|it| it.binding).collect();
153        self.read_async(bindings)
154    }
155
156    fn memory_usage(&self) -> MemoryUsage {
157        self.ctx.memory_usage()
158    }
159
160    fn memory_cleanup(&mut self) {
161        let ctx = self.get_context();
162        ctx.memory_management.cleanup(true);
163    }
164
165    fn create(&mut self, data: &[u8]) -> server::Handle {
166        let handle = self.empty(data.len());
167        let ctx = self.get_context();
168
169        let binding = handle.clone().binding();
170        let resource = ctx
171            .memory_management
172            .get_resource(binding.memory, binding.offset_start, binding.offset_end)
173            .unwrap();
174
175        unsafe {
176            let status = cubecl_hip_sys::hipMemcpyHtoDAsync(
177                resource.ptr,
178                data as *const _ as *mut _,
179                data.len(),
180                ctx.stream,
181            );
182            assert_eq!(status, HIP_SUCCESS, "Should send data to device");
183        }
184        handle
185    }
186
187    fn create_tensor(
188        &mut self,
189        data: &[u8],
190        shape: &[usize],
191        _elem_size: usize,
192    ) -> (server::Handle, Vec<usize>) {
193        let strides = contiguous_strides(shape);
194        let handle = self.create(data);
195        (handle, strides)
196    }
197
198    fn empty(&mut self, size: usize) -> server::Handle {
199        let ctx = self.get_context();
200        let handle = ctx.memory_management.reserve(size as u64, None);
201        server::Handle::new(handle, None, None, size as u64)
202    }
203
204    fn empty_tensor(&mut self, shape: &[usize], elem_size: usize) -> (server::Handle, Vec<usize>) {
205        let strides = contiguous_strides(shape);
206        let size = shape.iter().product::<usize>() * elem_size;
207        let handle = self.empty(size);
208        (handle, strides)
209    }
210
211    unsafe fn execute(
212        &mut self,
213        kernel: Self::Kernel,
214        count: CubeCount,
215        bindings: Bindings,
216        mode: ExecutionMode,
217    ) {
218        let mut kernel_id = kernel.id();
219        kernel_id.mode(mode);
220
221        let profile_level = self.logger.profile_level();
222        let profile_info = if profile_level.is_some() {
223            Some((kernel.name(), kernel_id.clone()))
224        } else {
225            None
226        };
227
228        let count = match count {
229            CubeCount::Static(x, y, z) => (x, y, z),
230            // TODO: CUDA doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
231            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
232            // For now, just read the dispatch settings from the buffer.
233            CubeCount::Dynamic(binding) => {
234                let data = self.read_sync(binding);
235                let data = bytemuck::cast_slice(&data);
236                assert!(
237                    data.len() == 3,
238                    "Dynamic cube count should contain 3 values"
239                );
240                (data[0], data[1], data[2])
241            }
242        };
243
244        let Bindings {
245            buffers,
246            metadata,
247            scalars,
248            tensor_maps,
249        } = bindings;
250
251        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
252        let info = self.create(bytemuck::cast_slice(&metadata.data));
253        let scalars: Vec<_> = scalars.values().map(|s| self.create(s.data())).collect();
254
255        let (ctx, logger) = self.get_context_with_logger();
256
257        if !ctx.module_names.contains_key(&kernel_id) {
258            ctx.compile_kernel(&kernel_id, kernel, logger, mode);
259        }
260
261        let mut resources: Vec<_> = buffers.into_iter().map(|b| find_resource(ctx, b)).collect();
262        resources.push(find_resource(ctx, info.clone().binding()));
263        resources.extend(scalars.into_iter().map(|s| find_resource(ctx, s.binding())));
264
265        if let Some(level) = profile_level {
266            ctx.sync();
267            let start = std::time::SystemTime::now();
268            ctx.execute_task(kernel_id, count, resources);
269            ctx.sync();
270
271            let (name, kernel_id) = profile_info.unwrap();
272            let info = match level {
273                ProfileLevel::Basic | ProfileLevel::Medium => {
274                    if let Some(val) = name.split("<").next() {
275                        val.split("::").last().unwrap_or(name).to_string()
276                    } else {
277                        name.to_string()
278                    }
279                }
280                cubecl_runtime::debug::ProfileLevel::Full => {
281                    format!("{name}: {kernel_id} CubeCount {count:?}")
282                }
283            };
284
285            self.logger
286                .register_profiled(info, start.elapsed().unwrap());
287        } else {
288            ctx.execute_task(kernel_id, count, resources);
289        }
290    }
291
292    fn flush(&mut self) {}
293
294    fn sync(&mut self) -> impl Future<Output = ()> + 'static {
295        self.logger.profile_summary();
296        self.sync_stream_async()
297    }
298
299    fn start_profile(&mut self) {
300        cubecl_common::future::block_on(self.sync());
301        self.ctx.timestamps.start();
302    }
303
304    fn end_profile(&mut self) -> ProfileDuration {
305        self.logger.profile_summary();
306        cubecl_common::future::block_on(self.sync());
307        self.ctx.timestamps.stop()
308    }
309
310    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<HipResource> {
311        let ctx = self.get_context();
312        BindingResource::new(
313            binding.clone(),
314            ctx.memory_management
315                .get_resource(binding.memory, binding.offset_start, binding.offset_end)
316                .expect("Can't find resource"),
317        )
318    }
319}
320
321fn find_resource(ctx: &mut HipContext, binding: server::Binding) -> HipResource {
322    ctx.memory_management
323        .get_resource(binding.memory, binding.offset_start, binding.offset_end)
324        .expect("Failed to find resource")
325}
326
327impl HipContext {
328    pub fn new(
329        memory_management: MemoryManagement<HipStorage>,
330        compilation_options: CompilationOptions,
331        stream: cubecl_hip_sys::hipStream_t,
332    ) -> Self {
333        Self {
334            memory_management,
335            module_names: HashMap::new(),
336            stream,
337            timestamps: KernelTimestamps::default(),
338            compilation_options,
339            #[cfg(feature = "compilation-cache")]
340            compilation_cache: Cache::new("hip/compilation", CacheOption::default()),
341        }
342    }
343
344    fn fence(&mut self) -> Fence {
345        Fence::new(self.stream)
346    }
347
348    fn lazy_sync_stream(&mut self) -> SyncStream {
349        SyncStream::new(self.stream)
350    }
351
352    fn sync(&mut self) {
353        unsafe {
354            let status = cubecl_hip_sys::hipStreamSynchronize(self.stream);
355            assert_eq!(
356                status, HIP_SUCCESS,
357                "Should successfully synchronize stream"
358            );
359        };
360        self.memory_management.storage().flush();
361    }
362
363    fn memory_usage(&self) -> MemoryUsage {
364        self.memory_management.memory_usage()
365    }
366
367    fn compile_kernel(
368        &mut self,
369        kernel_id: &KernelId,
370        cube_kernel: Box<dyn CubeTask<HipCompiler>>,
371        logger: &mut DebugLogger,
372        mode: ExecutionMode,
373    ) {
374        #[cfg(feature = "compilation-cache")]
375        let name = kernel_id.stable_format();
376        #[cfg(feature = "compilation-cache")]
377        if let Some(entry) = self.compilation_cache.get(&name) {
378            log::trace!("Using compilation cache");
379            self.load_compiled_binary(
380                entry.binary.clone(),
381                kernel_id.clone(),
382                entry.entrypoint_name.clone(),
383                CubeDim {
384                    x: entry.cube_dim.0,
385                    y: entry.cube_dim.1,
386                    z: entry.cube_dim.2,
387                },
388            );
389            return;
390        }
391
392        // CubeCL compilation
393        // jitc = just-in-time compiled
394        let mut jitc_kernel =
395            cube_kernel.compile(&mut Default::default(), &self.compilation_options, mode);
396
397        if logger.is_activated() {
398            jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
399
400            if let Ok(formatted) = format_cpp(&jitc_kernel.source) {
401                jitc_kernel.source = formatted;
402            }
403        }
404        let jitc_kernel = logger.debug(jitc_kernel);
405
406        // Create HIP Program
407        let program = unsafe {
408            let source = CString::new(jitc_kernel.source.clone()).unwrap();
409            let mut program: cubecl_hip_sys::hiprtcProgram = std::ptr::null_mut();
410            let status = cubecl_hip_sys::hiprtcCreateProgram(
411                &mut program,
412                source.as_ptr(),
413                std::ptr::null(), // program name seems unnecessary
414                0,
415                std::ptr::null_mut(),
416                std::ptr::null_mut(),
417            );
418            assert_eq!(
419                status, hiprtcResult_HIPRTC_SUCCESS,
420                "Should create the program"
421            );
422            program
423        };
424        // Compile HIP program
425        // options
426        let include_path = include_path();
427        let include_option = format!("-I{}", include_path.display());
428        let include_option_cstr = CString::new(include_option).unwrap();
429        // needed for rocWMMA extension to compile
430        let cpp_std_option_cstr = CString::new("--std=c++17").unwrap();
431        let mut options = vec![cpp_std_option_cstr.as_ptr(), include_option_cstr.as_ptr()];
432        unsafe {
433            let options_ptr = options.as_mut_ptr();
434            let status =
435                cubecl_hip_sys::hiprtcCompileProgram(program, options.len() as i32, options_ptr);
436            if status != hiprtcResult_HIPRTC_SUCCESS {
437                let mut log_size: usize = 0;
438                let status =
439                    cubecl_hip_sys::hiprtcGetProgramLogSize(program, &mut log_size as *mut usize);
440                assert_eq!(
441                    status, hiprtcResult_HIPRTC_SUCCESS,
442                    "Should retrieve the compilation log size"
443                );
444                let mut log_buffer = vec![0; log_size];
445                let status = cubecl_hip_sys::hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
446                assert_eq!(
447                    status, hiprtcResult_HIPRTC_SUCCESS,
448                    "Should retrieve the compilation log contents"
449                );
450                let log = CStr::from_ptr(log_buffer.as_ptr());
451                let mut message = "[Compilation Error] ".to_string();
452                if log_size > 0 {
453                    for line in log.to_string_lossy().split('\n') {
454                        if !line.is_empty() {
455                            message += format!("\n    {line}").as_str();
456                        }
457                    }
458                } else {
459                    message += "\n No compilation logs found!";
460                }
461                panic!("{message}\n[Source]  \n{}", jitc_kernel.source);
462            }
463            assert_eq!(
464                status, hiprtcResult_HIPRTC_SUCCESS,
465                "Should compile the program"
466            );
467        };
468        // Get HIP compiled code from program
469        let mut code_size: usize = 0;
470        unsafe {
471            let status = cubecl_hip_sys::hiprtcGetCodeSize(program, &mut code_size);
472            assert_eq!(
473                status, hiprtcResult_HIPRTC_SUCCESS,
474                "Should get size of compiled code"
475            );
476        }
477        let mut code = vec![0; code_size];
478        unsafe {
479            let status = cubecl_hip_sys::hiprtcGetCode(program, code.as_mut_ptr());
480            assert_eq!(
481                status, hiprtcResult_HIPRTC_SUCCESS,
482                "Should load compiled code"
483            );
484        }
485
486        #[cfg(feature = "compilation-cache")]
487        self.compilation_cache
488            .insert(
489                name,
490                CompilationCacheEntry {
491                    entrypoint_name: jitc_kernel.entrypoint_name.clone(),
492                    cube_dim: (
493                        jitc_kernel.cube_dim.x,
494                        jitc_kernel.cube_dim.y,
495                        jitc_kernel.cube_dim.z,
496                    ),
497                    binary: code.clone(),
498                },
499            )
500            .unwrap();
501
502        self.load_compiled_binary(
503            code,
504            kernel_id.clone(),
505            jitc_kernel.entrypoint_name,
506            jitc_kernel.cube_dim,
507        );
508    }
509
510    fn load_compiled_binary(
511        &mut self,
512        code: Vec<i8>,
513        kernel_id: KernelId,
514        entrypoint_name: String,
515        cube_dim: CubeDim,
516    ) {
517        let func_name = CString::new(entrypoint_name.clone()).unwrap();
518
519        // Create the HIP module
520        let mut module: cubecl_hip_sys::hipModule_t = std::ptr::null_mut();
521        unsafe {
522            let codeptr = code.as_ptr();
523            let status = cubecl_hip_sys::hipModuleLoadData(&mut module, codeptr as *const _);
524            assert_eq!(status, HIP_SUCCESS, "Should load compiled code into module");
525        }
526        // Retrieve the HIP module function
527        let mut func: cubecl_hip_sys::hipFunction_t = std::ptr::null_mut();
528        unsafe {
529            let status =
530                cubecl_hip_sys::hipModuleGetFunction(&mut func, module, func_name.as_ptr());
531            assert_eq!(status, HIP_SUCCESS, "Should return module function");
532        }
533
534        // register module
535        self.module_names.insert(
536            kernel_id.clone(),
537            HipCompiledKernel {
538                _module: module,
539                func,
540                cube_dim,
541            },
542        );
543    }
544
545    fn execute_task(
546        &mut self,
547        kernel_id: KernelId,
548        dispatch_count: (u32, u32, u32),
549        resources: Vec<HipResource>,
550    ) {
551        let mut bindings = resources
552            .iter()
553            .map(|memory| memory.binding)
554            .collect::<Vec<_>>();
555
556        let kernel = self.module_names.get(&kernel_id).unwrap();
557        let cube_dim = kernel.cube_dim;
558
559        unsafe {
560            let status = cubecl_hip_sys::hipModuleLaunchKernel(
561                kernel.func,
562                dispatch_count.0,
563                dispatch_count.1,
564                dispatch_count.2,
565                cube_dim.x,
566                cube_dim.y,
567                cube_dim.z,
568                // Shared memory is specified statically in the kernel, and no dynamic shared
569                // memory is supported yet in the kernel, which would be that value for the
570                // current kernel launch.
571                0,
572                self.stream,
573                bindings.as_mut_ptr(),
574                std::ptr::null_mut(),
575            );
576            if status == cubecl_hip_sys::hipError_t_hipErrorOutOfMemory {
577                panic!("Error: Cannot launch kernel (Out of memory)\n{}", kernel_id)
578            }
579            assert_eq!(status, HIP_SUCCESS, "Should launch the kernel");
580        };
581    }
582}
583
584impl HipServer {
585    /// Create a new hip server.
586    pub(crate) fn new(ctx: HipContext) -> Self {
587        let logger = DebugLogger::default();
588        Self { ctx, logger }
589    }
590
591    fn get_context(&mut self) -> &mut HipContext {
592        self.get_context_with_logger().0
593    }
594
595    fn get_context_with_logger(&mut self) -> (&mut HipContext, &mut DebugLogger) {
596        (&mut self.ctx, &mut self.logger)
597    }
598}
599
600fn include_path() -> PathBuf {
601    let error_msg = "
602        ROCm HIP installation not found.
603        Please ensure that ROCm is installed with HIP runtimes and development libraries and the ROCM_PATH or HIP_PATH environment variable is set correctly.
604        Note: Default path is /opt/rocm which may not be correct.
605    ";
606    let path = hip_path().expect(error_msg);
607    let result = path.join("include");
608    let hip_include = result.join("hip");
609    if !hip_include.exists() || !hip_include.is_dir() {
610        panic!("{error_msg}");
611    }
612    result
613}
614
/// Root directory of the ROCm installation.
///
/// The first of `CUBECL_ROCM_PATH`, `ROCM_PATH` and `HIP_PATH` that is set to valid Unicode
/// wins. Otherwise this falls back to `/opt/rocm`, so the result is always `Some`.
fn hip_path() -> Option<PathBuf> {
    const ENV_VARS: [&str; 3] = ["CUBECL_ROCM_PATH", "ROCM_PATH", "HIP_PATH"];
    let from_env = ENV_VARS
        .iter()
        .copied()
        .find_map(|key| std::env::var(key).ok());
    // Default path (only Linux is supported for now)
    let root = from_env.unwrap_or_else(|| String::from("/opt/rocm"));
    Some(PathBuf::from(root))
}
628
/// Row-major (C-contiguous) strides, in elements, for a tensor of the given `shape`.
///
/// The last dimension has stride 1, and each earlier stride is the product of all later
/// dimension sizes. A rank-0 (scalar) shape yields an empty stride list.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps rank 0 from underflowing `rank - 1`, which panicked in debug
    // builds and indexed out of bounds in release builds.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}