cubecl_hip/compute/
server.rs

1use cubecl_cpp::formatter::format_cpp;
2use cubecl_cpp::shared::CompilationOptions;
3
4use crate::runtime::HipCompiler;
5
6use super::fence::{Fence, SyncStream};
7use super::storage::HipStorage;
8use super::{uninit_vec, HipResource};
9use cubecl_core::compute::DebugInformation;
10use cubecl_core::ir::CubeDim;
11use cubecl_core::Feature;
12use cubecl_core::{prelude::*, KernelId};
13use cubecl_hip_sys::{hiprtcResult_HIPRTC_SUCCESS, HIP_SUCCESS};
14use cubecl_runtime::debug::{DebugLogger, ProfileLevel};
15use cubecl_runtime::memory_management::MemoryUsage;
16use cubecl_runtime::storage::BindingResource;
17use cubecl_runtime::{
18    memory_management::MemoryManagement,
19    server::{self, ComputeServer},
20};
21use cubecl_runtime::{ExecutionMode, TimestampsError, TimestampsResult};
22use std::collections::HashMap;
23use std::ffi::CStr;
24use std::ffi::CString;
25use std::future::Future;
26use std::path::PathBuf;
27use std::time::Instant;
28
29#[derive(Debug)]
30pub struct HipServer {
31    ctx: HipContext,
32    logger: DebugLogger,
33}
34
35#[derive(Debug)]
36pub(crate) struct HipContext {
37    stream: cubecl_hip_sys::hipStream_t,
38    memory_management: MemoryManagement<HipStorage>,
39    module_names: HashMap<KernelId, HipCompiledKernel>,
40    timestamps: KernelTimestamps,
41    compilation_options: CompilationOptions,
42}
43
44#[derive(Debug)]
45struct HipCompiledKernel {
46    _module: cubecl_hip_sys::hipModule_t,
47    func: cubecl_hip_sys::hipFunction_t,
48    cube_dim: CubeDim,
49    shared_mem_bytes: usize,
50}
51
/// How elapsed time is measured for `sync_elapsed`.
#[derive(Debug)]
enum KernelTimestamps {
    /// Elapsed time is inferred as the time passed since `start_time`.
    Inferred { start_time: Instant },
    /// Timing is switched off.
    Disabled,
}

impl KernelTimestamps {
    /// Starts measuring from now.
    ///
    /// If measurement is already running, the existing start time is kept.
    fn enable(&mut self) {
        if let Self::Disabled = self {
            *self = Self::Inferred {
                start_time: Instant::now(),
            };
        }
    }

    /// Stops measuring and forgets any start time.
    fn disable(&mut self) {
        *self = Self::Disabled;
    }
}
73
74unsafe impl Send for HipServer {}
75
76impl HipServer {
77    fn read_sync(&mut self, binding: server::Binding) -> Vec<u8> {
78        let ctx = self.get_context();
79        let resource = ctx.memory_management.get_resource(
80            binding.memory,
81            binding.offset_start,
82            binding.offset_end,
83        );
84
85        let mut data = uninit_vec(resource.size as usize);
86        unsafe {
87            let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
88                data.as_mut_ptr() as *mut _,
89                resource.ptr,
90                resource.size as usize,
91                ctx.stream,
92            );
93            assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
94        };
95        ctx.sync();
96        data
97    }
98
99    fn read_async(
100        &mut self,
101        bindings: Vec<server::Binding>,
102    ) -> impl Future<Output = Vec<Vec<u8>>> + 'static + Send {
103        let ctx = self.get_context();
104        let mut result = Vec::with_capacity(bindings.len());
105
106        for binding in bindings {
107            let resource = ctx.memory_management.get_resource(
108                binding.memory,
109                binding.offset_start,
110                binding.offset_end,
111            );
112
113            let mut data = uninit_vec(resource.size as usize);
114            unsafe {
115                let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
116                    data.as_mut_ptr() as *mut _,
117                    resource.ptr,
118                    resource.size as usize,
119                    ctx.stream,
120                );
121                assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
122            };
123            result.push(data);
124        }
125
126        let fence = ctx.fence();
127
128        async move {
129            fence.wait();
130            result
131        }
132    }
133
134    fn sync_stream_async(&mut self) -> impl Future<Output = ()> + 'static + Send {
135        let ctx = self.get_context();
136        // We can't use a fence here because no action has been recorded on the context.
137        // We need at least one action to be recorded after the context is initialized
138        // with `cudarc::driver::result::ctx::set_current(self.ctx.context)` for the fence
139        // to have any effect. Otherwise, it seems to be ignored.
140        let sync = ctx.lazy_sync_stream();
141
142        async move {
143            sync.wait();
144        }
145    }
146}
147
148impl ComputeServer for HipServer {
149    type Kernel = Box<dyn CubeTask<HipCompiler>>;
150    type Storage = HipStorage;
151    type Feature = Feature;
152
153    fn read(
154        &mut self,
155        bindings: Vec<server::Binding>,
156    ) -> impl Future<Output = Vec<Vec<u8>>> + 'static {
157        self.read_async(bindings)
158    }
159
160    fn memory_usage(&self) -> MemoryUsage {
161        self.ctx.memory_usage()
162    }
163
164    fn create(&mut self, data: &[u8]) -> server::Handle {
165        let handle = self.empty(data.len());
166        let ctx = self.get_context();
167
168        let binding = handle.clone().binding();
169        let resource = ctx.memory_management.get_resource(
170            binding.memory,
171            binding.offset_start,
172            binding.offset_end,
173        );
174
175        unsafe {
176            let status = cubecl_hip_sys::hipMemcpyHtoDAsync(
177                resource.ptr,
178                data as *const _ as *mut _,
179                data.len(),
180                ctx.stream,
181            );
182            assert_eq!(status, HIP_SUCCESS, "Should send data to device");
183        }
184        handle
185    }
186
187    fn empty(&mut self, size: usize) -> server::Handle {
188        let ctx = self.get_context();
189        let handle = ctx.memory_management.reserve(size as u64, None);
190        server::Handle::new(handle, None, None, size as u64)
191    }
192
193    unsafe fn execute(
194        &mut self,
195        kernel: Self::Kernel,
196        count: CubeCount,
197        bindings: Vec<server::Binding>,
198        mode: ExecutionMode,
199    ) {
200        let mut kernel_id = kernel.id();
201        kernel_id.mode(mode);
202
203        let profile_level = self.logger.profile_level();
204        let profile_info = if profile_level.is_some() {
205            Some((kernel.name(), kernel_id.clone()))
206        } else {
207            None
208        };
209
210        let count = match count {
211            CubeCount::Static(x, y, z) => (x, y, z),
212            // TODO: CUDA doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
213            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
214            // For now, just read the dispatch settings from the buffer.
215            CubeCount::Dynamic(binding) => {
216                let data = self.read_sync(binding);
217                let data = bytemuck::cast_slice(&data);
218                assert!(
219                    data.len() == 3,
220                    "Dynamic cube count should contain 3 values"
221                );
222                (data[0], data[1], data[2])
223            }
224        };
225
226        let (ctx, logger) = self.get_context_with_logger();
227
228        if !ctx.module_names.contains_key(&kernel_id) {
229            ctx.compile_kernel(&kernel_id, kernel, logger, mode);
230        }
231
232        let resources = bindings
233            .into_iter()
234            .map(|binding| {
235                ctx.memory_management.get_resource(
236                    binding.memory,
237                    binding.offset_start,
238                    binding.offset_end,
239                )
240            })
241            .collect::<Vec<_>>();
242
243        if let Some(level) = profile_level {
244            ctx.sync();
245            let start = std::time::SystemTime::now();
246            ctx.execute_task(kernel_id, count, resources);
247            ctx.sync();
248
249            let (name, kernel_id) = profile_info.unwrap();
250            let info = match level {
251                ProfileLevel::Basic | ProfileLevel::Medium => {
252                    if let Some(val) = name.split("<").next() {
253                        val.split("::").last().unwrap_or(name).to_string()
254                    } else {
255                        name.to_string()
256                    }
257                }
258                cubecl_runtime::debug::ProfileLevel::Full => {
259                    format!("{name}: {kernel_id} CubeCount {count:?}")
260                }
261            };
262
263            self.logger
264                .register_profiled(info, start.elapsed().unwrap());
265        } else {
266            ctx.execute_task(kernel_id, count, resources);
267        }
268    }
269
270    fn flush(&mut self) {}
271
272    fn sync(&mut self) -> impl Future<Output = ()> + 'static {
273        self.logger.profile_summary();
274        self.sync_stream_async()
275    }
276
277    fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static {
278        self.logger.profile_summary();
279
280        let ctx = self.get_context();
281        ctx.sync();
282
283        let duration = match &mut ctx.timestamps {
284            KernelTimestamps::Inferred { start_time } => {
285                let duration = start_time.elapsed();
286                *start_time = Instant::now();
287                Ok(duration)
288            }
289            KernelTimestamps::Disabled => Err(TimestampsError::Disabled),
290        };
291
292        async move { duration }
293    }
294
295    fn get_resource(&mut self, binding: server::Binding) -> BindingResource<Self> {
296        let ctx = self.get_context();
297        BindingResource::new(
298            binding.clone(),
299            ctx.memory_management.get_resource(
300                binding.memory,
301                binding.offset_start,
302                binding.offset_end,
303            ),
304        )
305    }
306
307    fn enable_timestamps(&mut self) {
308        self.ctx.timestamps.enable();
309    }
310
311    fn disable_timestamps(&mut self) {
312        if self.logger.profile_level().is_none() {
313            self.ctx.timestamps.disable();
314        }
315    }
316}
317
318impl HipContext {
319    pub fn new(
320        memory_management: MemoryManagement<HipStorage>,
321        compilation_options: CompilationOptions,
322        stream: cubecl_hip_sys::hipStream_t,
323    ) -> Self {
324        Self {
325            memory_management,
326            module_names: HashMap::new(),
327            stream,
328            timestamps: KernelTimestamps::Disabled,
329            compilation_options,
330        }
331    }
332
333    fn fence(&mut self) -> Fence {
334        Fence::new(self.stream)
335    }
336
337    fn lazy_sync_stream(&mut self) -> SyncStream {
338        SyncStream::new(self.stream)
339    }
340
341    fn sync(&mut self) {
342        unsafe {
343            let status = cubecl_hip_sys::hipStreamSynchronize(self.stream);
344            assert_eq!(
345                status, HIP_SUCCESS,
346                "Should successfully synchronize stream"
347            );
348        };
349        self.memory_management.storage().flush();
350    }
351
352    fn memory_usage(&self) -> MemoryUsage {
353        self.memory_management.memory_usage()
354    }
355
356    fn compile_kernel(
357        &mut self,
358        kernel_id: &KernelId,
359        cube_kernel: Box<dyn CubeTask<HipCompiler>>,
360        logger: &mut DebugLogger,
361        mode: ExecutionMode,
362    ) {
363        // CubeCL compilation
364        // jitc = just-in-time compiled
365        let mut jitc_kernel = cube_kernel.compile(&self.compilation_options, mode);
366        let func_name = CString::new(jitc_kernel.entrypoint_name.clone()).unwrap();
367
368        if logger.is_activated() {
369            jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
370
371            if let Ok(formatted) = format_cpp(&jitc_kernel.source) {
372                jitc_kernel.source = formatted;
373            }
374        }
375        let jitc_kernel = logger.debug(jitc_kernel);
376
377        // Create HIP Program
378        let program = unsafe {
379            let source = CString::new(jitc_kernel.source.clone()).unwrap();
380            let mut program: cubecl_hip_sys::hiprtcProgram = std::ptr::null_mut();
381            let status = cubecl_hip_sys::hiprtcCreateProgram(
382                &mut program,
383                source.as_ptr(),
384                std::ptr::null(), // program name seems unnecessary
385                0,
386                std::ptr::null_mut(),
387                std::ptr::null_mut(),
388            );
389            assert_eq!(
390                status, hiprtcResult_HIPRTC_SUCCESS,
391                "Should create the program"
392            );
393            program
394        };
395        // Compile HIP program
396        // options
397        let include_path = include_path();
398        let include_option = format!("-I{}", include_path.display());
399        let include_option_cstr = CString::new(include_option).unwrap();
400        // needed for rocWMMA extension to compile
401        let cpp_std_option_cstr = CString::new("--std=c++17").unwrap();
402        let mut options: Vec<*const i8> =
403            vec![cpp_std_option_cstr.as_ptr(), include_option_cstr.as_ptr()];
404        unsafe {
405            let options_ptr = options.as_mut_ptr();
406            let status =
407                cubecl_hip_sys::hiprtcCompileProgram(program, options.len() as i32, options_ptr);
408            if status != hiprtcResult_HIPRTC_SUCCESS {
409                let mut log_size: usize = 0;
410                let status =
411                    cubecl_hip_sys::hiprtcGetProgramLogSize(program, &mut log_size as *mut usize);
412                assert_eq!(
413                    status, hiprtcResult_HIPRTC_SUCCESS,
414                    "Should retrieve the compilation log size"
415                );
416                let mut log_buffer = vec![0i8; log_size];
417                let status = cubecl_hip_sys::hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
418                assert_eq!(
419                    status, hiprtcResult_HIPRTC_SUCCESS,
420                    "Should retrieve the compilation log contents"
421                );
422                let log = CStr::from_ptr(log_buffer.as_ptr());
423                let mut message = "[Compilation Error] ".to_string();
424                if log_size > 0 {
425                    for line in log.to_string_lossy().split('\n') {
426                        if !line.is_empty() {
427                            message += format!("\n    {line}").as_str();
428                        }
429                    }
430                } else {
431                    message += "\n No compilation logs found!";
432                }
433                panic!("{message}\n[Source]  \n{}", jitc_kernel.source);
434            }
435            assert_eq!(
436                status, hiprtcResult_HIPRTC_SUCCESS,
437                "Should compile the program"
438            );
439        };
440        // Get HIP compiled code from program
441        let mut code_size: usize = 0;
442        unsafe {
443            let status = cubecl_hip_sys::hiprtcGetCodeSize(program, &mut code_size);
444            assert_eq!(
445                status, hiprtcResult_HIPRTC_SUCCESS,
446                "Should get size of compiled code"
447            );
448        }
449        let mut code = vec![0i8; code_size];
450        unsafe {
451            let status = cubecl_hip_sys::hiprtcGetCode(program, code.as_mut_ptr());
452            assert_eq!(
453                status, hiprtcResult_HIPRTC_SUCCESS,
454                "Should load compiled code"
455            );
456        }
457        // Create the HIP module
458        let mut module: cubecl_hip_sys::hipModule_t = std::ptr::null_mut();
459        unsafe {
460            let codeptr = code.as_ptr();
461            let status = cubecl_hip_sys::hipModuleLoadData(&mut module, codeptr as *const _);
462            assert_eq!(status, HIP_SUCCESS, "Should load compiled code into module");
463        }
464        // Retrieve the HIP module function
465        let mut func: cubecl_hip_sys::hipFunction_t = std::ptr::null_mut();
466        unsafe {
467            let status =
468                cubecl_hip_sys::hipModuleGetFunction(&mut func, module, func_name.as_ptr());
469            assert_eq!(status, HIP_SUCCESS, "Should return module function");
470        }
471
472        // register module
473        self.module_names.insert(
474            kernel_id.clone(),
475            HipCompiledKernel {
476                _module: module,
477                func,
478                cube_dim: jitc_kernel.cube_dim,
479                shared_mem_bytes: jitc_kernel.shared_mem_bytes,
480            },
481        );
482    }
483
484    fn execute_task(
485        &mut self,
486        kernel_id: KernelId,
487        dispatch_count: (u32, u32, u32),
488        resources: Vec<HipResource>,
489    ) {
490        let mut bindings = resources
491            .iter()
492            .map(|memory| memory.binding)
493            .collect::<Vec<_>>();
494
495        let kernel = self.module_names.get(&kernel_id).unwrap();
496        let cube_dim = kernel.cube_dim;
497
498        unsafe {
499            let status = cubecl_hip_sys::hipModuleLaunchKernel(
500                kernel.func,
501                dispatch_count.0,
502                dispatch_count.1,
503                dispatch_count.2,
504                cube_dim.x,
505                cube_dim.y,
506                cube_dim.z,
507                kernel.shared_mem_bytes as u32,
508                self.stream,
509                bindings.as_mut_ptr(),
510                std::ptr::null_mut(),
511            );
512            if status == cubecl_hip_sys::hipError_t_hipErrorOutOfMemory {
513                panic!("Error: Cannot launch kernel (Out of memory)\n{}", kernel_id)
514            }
515            assert_eq!(status, HIP_SUCCESS, "Should launch the kernel");
516        };
517    }
518}
519
520impl HipServer {
521    /// Create a new hip server.
522    pub(crate) fn new(mut ctx: HipContext) -> Self {
523        let logger = DebugLogger::default();
524        if logger.profile_level().is_some() {
525            ctx.timestamps.enable();
526        }
527
528        Self { ctx, logger }
529    }
530
531    fn get_context(&mut self) -> &mut HipContext {
532        self.get_context_with_logger().0
533    }
534
535    fn get_context_with_logger(&mut self) -> (&mut HipContext, &mut DebugLogger) {
536        (&mut self.ctx, &mut self.logger)
537    }
538}
539
540fn include_path() -> PathBuf {
541    let error_msg = "
542        ROCm HIP installation not found.
543        Please ensure that ROCm is installed with HIP runtimes and development libraries and the ROCM_PATH or HIP_PATH environment variable is set correctly.
544        Note: Default path is /opt/rocm which may not be correct.
545    ";
546    let path = hip_path().expect(error_msg);
547    let result = path.join("include");
548    let hip_include = result.join("hip");
549    if !hip_include.exists() || !hip_include.is_dir() {
550        panic!("{error_msg}");
551    }
552    result
553}
554
/// Resolves the ROCm installation root.
///
/// Checks `CUBECL_ROCM_PATH`, `ROCM_PATH` and `HIP_PATH`, in that order, and uses the
/// first that is set (and valid Unicode). Otherwise falls back to `/opt/rocm`, so the
/// result is always `Some`.
fn hip_path() -> Option<PathBuf> {
    const ENV_VARS: [&str; 3] = ["CUBECL_ROCM_PATH", "ROCM_PATH", "HIP_PATH"];

    let root = ENV_VARS
        .iter()
        .find_map(|name| std::env::var(name).ok())
        .map(PathBuf::from)
        // Default path (only Linux is supported for now)
        .unwrap_or_else(|| PathBuf::from("/opt/rocm"));
    Some(root)
}