1use cubecl_cpp::formatter::format_cpp;
2use cubecl_cpp::shared::CompilationOptions;
3
4use crate::runtime::HipCompiler;
5
6use super::fence::{Fence, SyncStream};
7use super::storage::HipStorage;
8use super::{HipResource, uninit_vec};
9use cubecl_common::benchmark::ProfileDuration;
10use cubecl_core::compute::DebugInformation;
11use cubecl_core::{Feature, server::Bindings};
12use cubecl_core::{KernelId, prelude::*};
13use cubecl_hip_sys::{HIP_SUCCESS, hiprtcResult_HIPRTC_SUCCESS};
14use cubecl_runtime::debug::{DebugLogger, ProfileLevel};
15use cubecl_runtime::kernel_timestamps::KernelTimestamps;
16use cubecl_runtime::memory_management::MemoryUsage;
17use cubecl_runtime::storage::BindingResource;
18use cubecl_runtime::{
19 memory_management::MemoryManagement,
20 server::{self, ComputeServer},
21};
22use std::collections::HashMap;
23use std::ffi::CStr;
24use std::ffi::CString;
25use std::future::Future;
26use std::path::PathBuf;
27
28#[cfg(feature = "compilation-cache")]
29use cubecl_common::cache::{Cache, CacheOption};
30
31#[derive(Debug)]
32pub struct HipServer {
33 ctx: HipContext,
34 logger: DebugLogger,
35}
36
37#[derive(Debug)]
38pub(crate) struct HipContext {
39 stream: cubecl_hip_sys::hipStream_t,
40 memory_management: MemoryManagement<HipStorage>,
41 module_names: HashMap<KernelId, HipCompiledKernel>,
42 timestamps: KernelTimestamps,
43 compilation_options: CompilationOptions,
44 #[cfg(feature = "compilation-cache")]
45 compilation_cache: Cache<String, CompilationCacheEntry>,
46}
47
/// Everything needed to reload a previously compiled kernel without recompiling it.
#[cfg(feature = "compilation-cache")]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq, Clone)]
pub struct CompilationCacheEntry {
    /// Name of the kernel function to look up in the loaded module.
    entrypoint_name: String,
    /// Cube dimensions `(x, y, z)` the kernel was compiled with.
    cube_dim: (u32, u32, u32),
    /// Compiled code as returned by the runtime compiler.
    binary: Vec<i8>,
}
55
56#[derive(Debug)]
57struct HipCompiledKernel {
58 _module: cubecl_hip_sys::hipModule_t,
59 func: cubecl_hip_sys::hipFunction_t,
60 cube_dim: CubeDim,
61}
62
63unsafe impl Send for HipServer {}
64
65impl HipServer {
66 fn read_sync(&mut self, binding: server::Binding) -> Vec<u8> {
67 let ctx = self.get_context();
68 let resource = ctx
69 .memory_management
70 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
71 .expect("Failed to find resource");
72
73 let mut data = uninit_vec(resource.size as usize);
74 unsafe {
75 let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
76 data.as_mut_ptr() as *mut _,
77 resource.ptr,
78 resource.size as usize,
79 ctx.stream,
80 );
81 assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
82 };
83 ctx.sync();
84 data
85 }
86
87 fn read_async(
88 &mut self,
89 bindings: Vec<server::Binding>,
90 ) -> impl Future<Output = Vec<Vec<u8>>> + 'static + Send {
91 let ctx = self.get_context();
92 let mut result = Vec::with_capacity(bindings.len());
93
94 for binding in bindings {
95 let resource = ctx
96 .memory_management
97 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
98 .expect("Failed to find resource");
99
100 let mut data = uninit_vec(resource.size as usize);
101 unsafe {
102 let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
103 data.as_mut_ptr() as *mut _,
104 resource.ptr,
105 resource.size as usize,
106 ctx.stream,
107 );
108 assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
109 };
110 result.push(data);
111 }
112
113 let fence = ctx.fence();
114
115 async move {
116 fence.wait();
117 result
118 }
119 }
120
121 fn sync_stream_async(&mut self) -> impl Future<Output = ()> + 'static + Send {
122 let ctx = self.get_context();
123 let sync = ctx.lazy_sync_stream();
128
129 async move {
130 sync.wait();
131 }
132 }
133}
134
135impl ComputeServer for HipServer {
136 type Kernel = Box<dyn CubeTask<HipCompiler>>;
137 type Storage = HipStorage;
138 type Feature = Feature;
139 type Info = ();
140
141 fn read(
142 &mut self,
143 bindings: Vec<server::Binding>,
144 ) -> impl Future<Output = Vec<Vec<u8>>> + 'static {
145 self.read_async(bindings)
146 }
147
148 fn read_tensor(
149 &mut self,
150 bindings: Vec<server::BindingWithMeta>,
151 ) -> impl Future<Output = Vec<Vec<u8>>> + 'static {
152 let bindings = bindings.into_iter().map(|it| it.binding).collect();
153 self.read_async(bindings)
154 }
155
156 fn memory_usage(&self) -> MemoryUsage {
157 self.ctx.memory_usage()
158 }
159
160 fn memory_cleanup(&mut self) {
161 let ctx = self.get_context();
162 ctx.memory_management.cleanup(true);
163 }
164
165 fn create(&mut self, data: &[u8]) -> server::Handle {
166 let handle = self.empty(data.len());
167 let ctx = self.get_context();
168
169 let binding = handle.clone().binding();
170 let resource = ctx
171 .memory_management
172 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
173 .unwrap();
174
175 unsafe {
176 let status = cubecl_hip_sys::hipMemcpyHtoDAsync(
177 resource.ptr,
178 data as *const _ as *mut _,
179 data.len(),
180 ctx.stream,
181 );
182 assert_eq!(status, HIP_SUCCESS, "Should send data to device");
183 }
184 handle
185 }
186
187 fn create_tensor(
188 &mut self,
189 data: &[u8],
190 shape: &[usize],
191 _elem_size: usize,
192 ) -> (server::Handle, Vec<usize>) {
193 let strides = contiguous_strides(shape);
194 let handle = self.create(data);
195 (handle, strides)
196 }
197
198 fn empty(&mut self, size: usize) -> server::Handle {
199 let ctx = self.get_context();
200 let handle = ctx.memory_management.reserve(size as u64, None);
201 server::Handle::new(handle, None, None, size as u64)
202 }
203
204 fn empty_tensor(&mut self, shape: &[usize], elem_size: usize) -> (server::Handle, Vec<usize>) {
205 let strides = contiguous_strides(shape);
206 let size = shape.iter().product::<usize>() * elem_size;
207 let handle = self.empty(size);
208 (handle, strides)
209 }
210
211 unsafe fn execute(
212 &mut self,
213 kernel: Self::Kernel,
214 count: CubeCount,
215 bindings: Bindings,
216 mode: ExecutionMode,
217 ) {
218 let mut kernel_id = kernel.id();
219 kernel_id.mode(mode);
220
221 let profile_level = self.logger.profile_level();
222 let profile_info = if profile_level.is_some() {
223 Some((kernel.name(), kernel_id.clone()))
224 } else {
225 None
226 };
227
228 let count = match count {
229 CubeCount::Static(x, y, z) => (x, y, z),
230 CubeCount::Dynamic(binding) => {
234 let data = self.read_sync(binding);
235 let data = bytemuck::cast_slice(&data);
236 assert!(
237 data.len() == 3,
238 "Dynamic cube count should contain 3 values"
239 );
240 (data[0], data[1], data[2])
241 }
242 };
243
244 let Bindings {
245 buffers,
246 metadata,
247 scalars,
248 tensor_maps,
249 } = bindings;
250
251 debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
252 let info = self.create(bytemuck::cast_slice(&metadata.data));
253 let scalars: Vec<_> = scalars.values().map(|s| self.create(s.data())).collect();
254
255 let (ctx, logger) = self.get_context_with_logger();
256
257 if !ctx.module_names.contains_key(&kernel_id) {
258 ctx.compile_kernel(&kernel_id, kernel, logger, mode);
259 }
260
261 let mut resources: Vec<_> = buffers.into_iter().map(|b| find_resource(ctx, b)).collect();
262 resources.push(find_resource(ctx, info.clone().binding()));
263 resources.extend(scalars.into_iter().map(|s| find_resource(ctx, s.binding())));
264
265 if let Some(level) = profile_level {
266 ctx.sync();
267 let start = std::time::SystemTime::now();
268 ctx.execute_task(kernel_id, count, resources);
269 ctx.sync();
270
271 let (name, kernel_id) = profile_info.unwrap();
272 let info = match level {
273 ProfileLevel::Basic | ProfileLevel::Medium => {
274 if let Some(val) = name.split("<").next() {
275 val.split("::").last().unwrap_or(name).to_string()
276 } else {
277 name.to_string()
278 }
279 }
280 cubecl_runtime::debug::ProfileLevel::Full => {
281 format!("{name}: {kernel_id} CubeCount {count:?}")
282 }
283 };
284
285 self.logger
286 .register_profiled(info, start.elapsed().unwrap());
287 } else {
288 ctx.execute_task(kernel_id, count, resources);
289 }
290 }
291
292 fn flush(&mut self) {}
293
294 fn sync(&mut self) -> impl Future<Output = ()> + 'static {
295 self.logger.profile_summary();
296 self.sync_stream_async()
297 }
298
299 fn start_profile(&mut self) {
300 cubecl_common::future::block_on(self.sync());
301 self.ctx.timestamps.start();
302 }
303
304 fn end_profile(&mut self) -> ProfileDuration {
305 self.logger.profile_summary();
306 cubecl_common::future::block_on(self.sync());
307 self.ctx.timestamps.stop()
308 }
309
310 fn get_resource(&mut self, binding: server::Binding) -> BindingResource<HipResource> {
311 let ctx = self.get_context();
312 BindingResource::new(
313 binding.clone(),
314 ctx.memory_management
315 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
316 .expect("Can't find resource"),
317 )
318 }
319}
320
321fn find_resource(ctx: &mut HipContext, binding: server::Binding) -> HipResource {
322 ctx.memory_management
323 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
324 .expect("Failed to find resource")
325}
326
327impl HipContext {
328 pub fn new(
329 memory_management: MemoryManagement<HipStorage>,
330 compilation_options: CompilationOptions,
331 stream: cubecl_hip_sys::hipStream_t,
332 ) -> Self {
333 Self {
334 memory_management,
335 module_names: HashMap::new(),
336 stream,
337 timestamps: KernelTimestamps::default(),
338 compilation_options,
339 #[cfg(feature = "compilation-cache")]
340 compilation_cache: Cache::new("hip/compilation", CacheOption::default()),
341 }
342 }
343
344 fn fence(&mut self) -> Fence {
345 Fence::new(self.stream)
346 }
347
348 fn lazy_sync_stream(&mut self) -> SyncStream {
349 SyncStream::new(self.stream)
350 }
351
352 fn sync(&mut self) {
353 unsafe {
354 let status = cubecl_hip_sys::hipStreamSynchronize(self.stream);
355 assert_eq!(
356 status, HIP_SUCCESS,
357 "Should successfully synchronize stream"
358 );
359 };
360 self.memory_management.storage().flush();
361 }
362
363 fn memory_usage(&self) -> MemoryUsage {
364 self.memory_management.memory_usage()
365 }
366
367 fn compile_kernel(
368 &mut self,
369 kernel_id: &KernelId,
370 cube_kernel: Box<dyn CubeTask<HipCompiler>>,
371 logger: &mut DebugLogger,
372 mode: ExecutionMode,
373 ) {
374 #[cfg(feature = "compilation-cache")]
375 let name = kernel_id.stable_format();
376 #[cfg(feature = "compilation-cache")]
377 if let Some(entry) = self.compilation_cache.get(&name) {
378 log::trace!("Using compilation cache");
379 self.load_compiled_binary(
380 entry.binary.clone(),
381 kernel_id.clone(),
382 entry.entrypoint_name.clone(),
383 CubeDim {
384 x: entry.cube_dim.0,
385 y: entry.cube_dim.1,
386 z: entry.cube_dim.2,
387 },
388 );
389 return;
390 }
391
392 let mut jitc_kernel =
395 cube_kernel.compile(&mut Default::default(), &self.compilation_options, mode);
396
397 if logger.is_activated() {
398 jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
399
400 if let Ok(formatted) = format_cpp(&jitc_kernel.source) {
401 jitc_kernel.source = formatted;
402 }
403 }
404 let jitc_kernel = logger.debug(jitc_kernel);
405
406 let program = unsafe {
408 let source = CString::new(jitc_kernel.source.clone()).unwrap();
409 let mut program: cubecl_hip_sys::hiprtcProgram = std::ptr::null_mut();
410 let status = cubecl_hip_sys::hiprtcCreateProgram(
411 &mut program,
412 source.as_ptr(),
413 std::ptr::null(), 0,
415 std::ptr::null_mut(),
416 std::ptr::null_mut(),
417 );
418 assert_eq!(
419 status, hiprtcResult_HIPRTC_SUCCESS,
420 "Should create the program"
421 );
422 program
423 };
424 let include_path = include_path();
427 let include_option = format!("-I{}", include_path.display());
428 let include_option_cstr = CString::new(include_option).unwrap();
429 let cpp_std_option_cstr = CString::new("--std=c++17").unwrap();
431 let mut options = vec![cpp_std_option_cstr.as_ptr(), include_option_cstr.as_ptr()];
432 unsafe {
433 let options_ptr = options.as_mut_ptr();
434 let status =
435 cubecl_hip_sys::hiprtcCompileProgram(program, options.len() as i32, options_ptr);
436 if status != hiprtcResult_HIPRTC_SUCCESS {
437 let mut log_size: usize = 0;
438 let status =
439 cubecl_hip_sys::hiprtcGetProgramLogSize(program, &mut log_size as *mut usize);
440 assert_eq!(
441 status, hiprtcResult_HIPRTC_SUCCESS,
442 "Should retrieve the compilation log size"
443 );
444 let mut log_buffer = vec![0; log_size];
445 let status = cubecl_hip_sys::hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
446 assert_eq!(
447 status, hiprtcResult_HIPRTC_SUCCESS,
448 "Should retrieve the compilation log contents"
449 );
450 let log = CStr::from_ptr(log_buffer.as_ptr());
451 let mut message = "[Compilation Error] ".to_string();
452 if log_size > 0 {
453 for line in log.to_string_lossy().split('\n') {
454 if !line.is_empty() {
455 message += format!("\n {line}").as_str();
456 }
457 }
458 } else {
459 message += "\n No compilation logs found!";
460 }
461 panic!("{message}\n[Source] \n{}", jitc_kernel.source);
462 }
463 assert_eq!(
464 status, hiprtcResult_HIPRTC_SUCCESS,
465 "Should compile the program"
466 );
467 };
468 let mut code_size: usize = 0;
470 unsafe {
471 let status = cubecl_hip_sys::hiprtcGetCodeSize(program, &mut code_size);
472 assert_eq!(
473 status, hiprtcResult_HIPRTC_SUCCESS,
474 "Should get size of compiled code"
475 );
476 }
477 let mut code = vec![0; code_size];
478 unsafe {
479 let status = cubecl_hip_sys::hiprtcGetCode(program, code.as_mut_ptr());
480 assert_eq!(
481 status, hiprtcResult_HIPRTC_SUCCESS,
482 "Should load compiled code"
483 );
484 }
485
486 #[cfg(feature = "compilation-cache")]
487 self.compilation_cache
488 .insert(
489 name,
490 CompilationCacheEntry {
491 entrypoint_name: jitc_kernel.entrypoint_name.clone(),
492 cube_dim: (
493 jitc_kernel.cube_dim.x,
494 jitc_kernel.cube_dim.y,
495 jitc_kernel.cube_dim.z,
496 ),
497 binary: code.clone(),
498 },
499 )
500 .unwrap();
501
502 self.load_compiled_binary(
503 code,
504 kernel_id.clone(),
505 jitc_kernel.entrypoint_name,
506 jitc_kernel.cube_dim,
507 );
508 }
509
510 fn load_compiled_binary(
511 &mut self,
512 code: Vec<i8>,
513 kernel_id: KernelId,
514 entrypoint_name: String,
515 cube_dim: CubeDim,
516 ) {
517 let func_name = CString::new(entrypoint_name.clone()).unwrap();
518
519 let mut module: cubecl_hip_sys::hipModule_t = std::ptr::null_mut();
521 unsafe {
522 let codeptr = code.as_ptr();
523 let status = cubecl_hip_sys::hipModuleLoadData(&mut module, codeptr as *const _);
524 assert_eq!(status, HIP_SUCCESS, "Should load compiled code into module");
525 }
526 let mut func: cubecl_hip_sys::hipFunction_t = std::ptr::null_mut();
528 unsafe {
529 let status =
530 cubecl_hip_sys::hipModuleGetFunction(&mut func, module, func_name.as_ptr());
531 assert_eq!(status, HIP_SUCCESS, "Should return module function");
532 }
533
534 self.module_names.insert(
536 kernel_id.clone(),
537 HipCompiledKernel {
538 _module: module,
539 func,
540 cube_dim,
541 },
542 );
543 }
544
545 fn execute_task(
546 &mut self,
547 kernel_id: KernelId,
548 dispatch_count: (u32, u32, u32),
549 resources: Vec<HipResource>,
550 ) {
551 let mut bindings = resources
552 .iter()
553 .map(|memory| memory.binding)
554 .collect::<Vec<_>>();
555
556 let kernel = self.module_names.get(&kernel_id).unwrap();
557 let cube_dim = kernel.cube_dim;
558
559 unsafe {
560 let status = cubecl_hip_sys::hipModuleLaunchKernel(
561 kernel.func,
562 dispatch_count.0,
563 dispatch_count.1,
564 dispatch_count.2,
565 cube_dim.x,
566 cube_dim.y,
567 cube_dim.z,
568 0,
572 self.stream,
573 bindings.as_mut_ptr(),
574 std::ptr::null_mut(),
575 );
576 if status == cubecl_hip_sys::hipError_t_hipErrorOutOfMemory {
577 panic!("Error: Cannot launch kernel (Out of memory)\n{}", kernel_id)
578 }
579 assert_eq!(status, HIP_SUCCESS, "Should launch the kernel");
580 };
581 }
582}
583
584impl HipServer {
585 pub(crate) fn new(ctx: HipContext) -> Self {
587 let logger = DebugLogger::default();
588 Self { ctx, logger }
589 }
590
591 fn get_context(&mut self) -> &mut HipContext {
592 self.get_context_with_logger().0
593 }
594
595 fn get_context_with_logger(&mut self) -> (&mut HipContext, &mut DebugLogger) {
596 (&mut self.ctx, &mut self.logger)
597 }
598}
599
600fn include_path() -> PathBuf {
601 let error_msg = "
602 ROCm HIP installation not found.
603 Please ensure that ROCm is installed with HIP runtimes and development libraries and the ROCM_PATH or HIP_PATH environment variable is set correctly.
604 Note: Default path is /opt/rocm which may not be correct.
605 ";
606 let path = hip_path().expect(error_msg);
607 let result = path.join("include");
608 let hip_include = result.join("hip");
609 if !hip_include.exists() || !hip_include.is_dir() {
610 panic!("{error_msg}");
611 }
612 result
613}
614
/// Returns the root of the ROCm installation.
///
/// The first of `CUBECL_ROCM_PATH`, `ROCM_PATH` and `HIP_PATH` that is set wins;
/// when none is set, `/opt/rocm` is returned. The result is therefore always `Some`.
fn hip_path() -> Option<PathBuf> {
    let root = ["CUBECL_ROCM_PATH", "ROCM_PATH", "HIP_PATH"]
        .iter()
        .find_map(|name| std::env::var(name).ok())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("/opt/rocm"));
    Some(root)
}
628
/// Computes the strides of a contiguous, row-major layout for `shape`.
///
/// The last dimension has stride 1 and every other stride is the product of
/// the dimensions that follow it. An empty shape yields an empty vector.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps the range empty for rank 0 instead of underflowing.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}