1use cubecl_cpp::formatter::format_cpp;
2use cubecl_cpp::shared::CompilationOptions;
3
4use crate::runtime::HipCompiler;
5
6use super::fence::{Fence, SyncStream};
7use super::storage::HipStorage;
8use super::{uninit_vec, HipResource};
9use cubecl_core::compute::DebugInformation;
10use cubecl_core::ir::CubeDim;
11use cubecl_core::Feature;
12use cubecl_core::{prelude::*, KernelId};
13use cubecl_hip_sys::{hiprtcResult_HIPRTC_SUCCESS, HIP_SUCCESS};
14use cubecl_runtime::debug::{DebugLogger, ProfileLevel};
15use cubecl_runtime::memory_management::MemoryUsage;
16use cubecl_runtime::storage::BindingResource;
17use cubecl_runtime::{
18 memory_management::MemoryManagement,
19 server::{self, ComputeServer},
20};
21use cubecl_runtime::{ExecutionMode, TimestampsError, TimestampsResult};
22use std::collections::HashMap;
23use std::ffi::CStr;
24use std::ffi::CString;
25use std::future::Future;
26use std::path::PathBuf;
27use std::time::Instant;
28
/// Compute server for the HIP (ROCm) runtime.
///
/// Owns the HIP context (stream, device memory, compiled-kernel cache) together with the
/// debug logger that is consulted when kernels are compiled, executed and profiled.
#[derive(Debug)]
pub struct HipServer {
    /// Stream, memory management, kernel cache and timing state.
    ctx: HipContext,
    /// Logger used for kernel source dumps and profiling reports.
    logger: DebugLogger,
}
34
/// Device-side state backing a `HipServer`.
#[derive(Debug)]
pub(crate) struct HipContext {
    /// HIP stream on which this server enqueues its copies and kernel launches.
    stream: cubecl_hip_sys::hipStream_t,
    /// Device buffer allocator; resolves bindings into device resources.
    memory_management: MemoryManagement<HipStorage>,
    /// Cache of compiled kernels, keyed by kernel id (the id includes the execution mode).
    module_names: HashMap<KernelId, HipCompiledKernel>,
    /// Host-side timing state reported by `sync_elapsed`.
    timestamps: KernelTimestamps,
    /// Options forwarded to the C++ code generator when compiling kernels.
    compilation_options: CompilationOptions,
}
43
/// A kernel loaded from a hiprtc-compiled code object.
#[derive(Debug)]
struct HipCompiledKernel {
    /// Module handle the code object was loaded into; stored but never read.
    /// NOTE(review): no `hipModuleUnload` is visible in this file, so modules appear to live
    /// for the rest of the process — confirm this is intended.
    _module: cubecl_hip_sys::hipModule_t,
    /// Kernel entry point resolved from `_module`.
    func: cubecl_hip_sys::hipFunction_t,
    /// Block (cube) dimensions passed to every launch of this kernel.
    cube_dim: CubeDim,
    /// Shared memory size in bytes, passed to `hipModuleLaunchKernel`.
    shared_mem_bytes: usize,
}
51
/// Strategy used to measure elapsed kernel time for `sync_elapsed`.
#[derive(Debug)]
enum KernelTimestamps {
    /// Elapsed time is inferred from host time since `start_time`.
    Inferred { start_time: Instant },
    /// Timing is turned off.
    Disabled,
}

impl KernelTimestamps {
    /// Switches to inferred timing, starting the clock now.
    ///
    /// Calling this while timing is already active keeps the existing start time.
    fn enable(&mut self) {
        if let Self::Disabled = self {
            let start_time = Instant::now();
            *self = Self::Inferred { start_time };
        }
    }

    /// Turns timing off, dropping any running start time.
    fn disable(&mut self) {
        *self = KernelTimestamps::Disabled;
    }
}
73
// SAFETY: `HipServer` is not automatically `Send` because it stores raw HIP handles (the
// module/function handles are raw pointers initialized with `std::ptr::null_mut()`, plus the
// stream handle). This impl asserts those handles may be used from whichever thread currently
// owns the server.
// NOTE(review): soundness presumably relies on the runtime only ever accessing the server from
// one thread at a time (e.g. behind a channel or mutex) — confirm against the caller.
unsafe impl Send for HipServer {}
75
76impl HipServer {
77 fn read_sync(&mut self, binding: server::Binding) -> Vec<u8> {
78 let ctx = self.get_context();
79 let resource = ctx.memory_management.get_resource(
80 binding.memory,
81 binding.offset_start,
82 binding.offset_end,
83 );
84
85 let mut data = uninit_vec(resource.size as usize);
86 unsafe {
87 let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
88 data.as_mut_ptr() as *mut _,
89 resource.ptr,
90 resource.size as usize,
91 ctx.stream,
92 );
93 assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
94 };
95 ctx.sync();
96 data
97 }
98
99 fn read_async(
100 &mut self,
101 bindings: Vec<server::Binding>,
102 ) -> impl Future<Output = Vec<Vec<u8>>> + 'static + Send {
103 let ctx = self.get_context();
104 let mut result = Vec::with_capacity(bindings.len());
105
106 for binding in bindings {
107 let resource = ctx.memory_management.get_resource(
108 binding.memory,
109 binding.offset_start,
110 binding.offset_end,
111 );
112
113 let mut data = uninit_vec(resource.size as usize);
114 unsafe {
115 let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
116 data.as_mut_ptr() as *mut _,
117 resource.ptr,
118 resource.size as usize,
119 ctx.stream,
120 );
121 assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
122 };
123 result.push(data);
124 }
125
126 let fence = ctx.fence();
127
128 async move {
129 fence.wait();
130 result
131 }
132 }
133
134 fn sync_stream_async(&mut self) -> impl Future<Output = ()> + 'static + Send {
135 let ctx = self.get_context();
136 let sync = ctx.lazy_sync_stream();
141
142 async move {
143 sync.wait();
144 }
145 }
146}
147
148impl ComputeServer for HipServer {
149 type Kernel = Box<dyn CubeTask<HipCompiler>>;
150 type Storage = HipStorage;
151 type Feature = Feature;
152
153 fn read(
154 &mut self,
155 bindings: Vec<server::Binding>,
156 ) -> impl Future<Output = Vec<Vec<u8>>> + 'static {
157 self.read_async(bindings)
158 }
159
160 fn memory_usage(&self) -> MemoryUsage {
161 self.ctx.memory_usage()
162 }
163
164 fn create(&mut self, data: &[u8]) -> server::Handle {
165 let handle = self.empty(data.len());
166 let ctx = self.get_context();
167
168 let binding = handle.clone().binding();
169 let resource = ctx.memory_management.get_resource(
170 binding.memory,
171 binding.offset_start,
172 binding.offset_end,
173 );
174
175 unsafe {
176 let status = cubecl_hip_sys::hipMemcpyHtoDAsync(
177 resource.ptr,
178 data as *const _ as *mut _,
179 data.len(),
180 ctx.stream,
181 );
182 assert_eq!(status, HIP_SUCCESS, "Should send data to device");
183 }
184 handle
185 }
186
187 fn empty(&mut self, size: usize) -> server::Handle {
188 let ctx = self.get_context();
189 let handle = ctx.memory_management.reserve(size as u64, None);
190 server::Handle::new(handle, None, None, size as u64)
191 }
192
193 unsafe fn execute(
194 &mut self,
195 kernel: Self::Kernel,
196 count: CubeCount,
197 bindings: Vec<server::Binding>,
198 mode: ExecutionMode,
199 ) {
200 let mut kernel_id = kernel.id();
201 kernel_id.mode(mode);
202
203 let profile_level = self.logger.profile_level();
204 let profile_info = if profile_level.is_some() {
205 Some((kernel.name(), kernel_id.clone()))
206 } else {
207 None
208 };
209
210 let count = match count {
211 CubeCount::Static(x, y, z) => (x, y, z),
212 CubeCount::Dynamic(binding) => {
216 let data = self.read_sync(binding);
217 let data = bytemuck::cast_slice(&data);
218 assert!(
219 data.len() == 3,
220 "Dynamic cube count should contain 3 values"
221 );
222 (data[0], data[1], data[2])
223 }
224 };
225
226 let (ctx, logger) = self.get_context_with_logger();
227
228 if !ctx.module_names.contains_key(&kernel_id) {
229 ctx.compile_kernel(&kernel_id, kernel, logger, mode);
230 }
231
232 let resources = bindings
233 .into_iter()
234 .map(|binding| {
235 ctx.memory_management.get_resource(
236 binding.memory,
237 binding.offset_start,
238 binding.offset_end,
239 )
240 })
241 .collect::<Vec<_>>();
242
243 if let Some(level) = profile_level {
244 ctx.sync();
245 let start = std::time::SystemTime::now();
246 ctx.execute_task(kernel_id, count, resources);
247 ctx.sync();
248
249 let (name, kernel_id) = profile_info.unwrap();
250 let info = match level {
251 ProfileLevel::Basic | ProfileLevel::Medium => {
252 if let Some(val) = name.split("<").next() {
253 val.split("::").last().unwrap_or(name).to_string()
254 } else {
255 name.to_string()
256 }
257 }
258 cubecl_runtime::debug::ProfileLevel::Full => {
259 format!("{name}: {kernel_id} CubeCount {count:?}")
260 }
261 };
262
263 self.logger
264 .register_profiled(info, start.elapsed().unwrap());
265 } else {
266 ctx.execute_task(kernel_id, count, resources);
267 }
268 }
269
270 fn flush(&mut self) {}
271
272 fn sync(&mut self) -> impl Future<Output = ()> + 'static {
273 self.logger.profile_summary();
274 self.sync_stream_async()
275 }
276
277 fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static {
278 self.logger.profile_summary();
279
280 let ctx = self.get_context();
281 ctx.sync();
282
283 let duration = match &mut ctx.timestamps {
284 KernelTimestamps::Inferred { start_time } => {
285 let duration = start_time.elapsed();
286 *start_time = Instant::now();
287 Ok(duration)
288 }
289 KernelTimestamps::Disabled => Err(TimestampsError::Disabled),
290 };
291
292 async move { duration }
293 }
294
295 fn get_resource(&mut self, binding: server::Binding) -> BindingResource<Self> {
296 let ctx = self.get_context();
297 BindingResource::new(
298 binding.clone(),
299 ctx.memory_management.get_resource(
300 binding.memory,
301 binding.offset_start,
302 binding.offset_end,
303 ),
304 )
305 }
306
307 fn enable_timestamps(&mut self) {
308 self.ctx.timestamps.enable();
309 }
310
311 fn disable_timestamps(&mut self) {
312 if self.logger.profile_level().is_none() {
313 self.ctx.timestamps.disable();
314 }
315 }
316}
317
318impl HipContext {
319 pub fn new(
320 memory_management: MemoryManagement<HipStorage>,
321 compilation_options: CompilationOptions,
322 stream: cubecl_hip_sys::hipStream_t,
323 ) -> Self {
324 Self {
325 memory_management,
326 module_names: HashMap::new(),
327 stream,
328 timestamps: KernelTimestamps::Disabled,
329 compilation_options,
330 }
331 }
332
333 fn fence(&mut self) -> Fence {
334 Fence::new(self.stream)
335 }
336
337 fn lazy_sync_stream(&mut self) -> SyncStream {
338 SyncStream::new(self.stream)
339 }
340
341 fn sync(&mut self) {
342 unsafe {
343 let status = cubecl_hip_sys::hipStreamSynchronize(self.stream);
344 assert_eq!(
345 status, HIP_SUCCESS,
346 "Should successfully synchronize stream"
347 );
348 };
349 self.memory_management.storage().flush();
350 }
351
352 fn memory_usage(&self) -> MemoryUsage {
353 self.memory_management.memory_usage()
354 }
355
356 fn compile_kernel(
357 &mut self,
358 kernel_id: &KernelId,
359 cube_kernel: Box<dyn CubeTask<HipCompiler>>,
360 logger: &mut DebugLogger,
361 mode: ExecutionMode,
362 ) {
363 let mut jitc_kernel = cube_kernel.compile(&self.compilation_options, mode);
366 let func_name = CString::new(jitc_kernel.entrypoint_name.clone()).unwrap();
367
368 if logger.is_activated() {
369 jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
370
371 if let Ok(formatted) = format_cpp(&jitc_kernel.source) {
372 jitc_kernel.source = formatted;
373 }
374 }
375 let jitc_kernel = logger.debug(jitc_kernel);
376
377 let program = unsafe {
379 let source = CString::new(jitc_kernel.source.clone()).unwrap();
380 let mut program: cubecl_hip_sys::hiprtcProgram = std::ptr::null_mut();
381 let status = cubecl_hip_sys::hiprtcCreateProgram(
382 &mut program,
383 source.as_ptr(),
384 std::ptr::null(), 0,
386 std::ptr::null_mut(),
387 std::ptr::null_mut(),
388 );
389 assert_eq!(
390 status, hiprtcResult_HIPRTC_SUCCESS,
391 "Should create the program"
392 );
393 program
394 };
395 let include_path = include_path();
398 let include_option = format!("-I{}", include_path.display());
399 let include_option_cstr = CString::new(include_option).unwrap();
400 let cpp_std_option_cstr = CString::new("--std=c++17").unwrap();
402 let mut options: Vec<*const i8> =
403 vec![cpp_std_option_cstr.as_ptr(), include_option_cstr.as_ptr()];
404 unsafe {
405 let options_ptr = options.as_mut_ptr();
406 let status =
407 cubecl_hip_sys::hiprtcCompileProgram(program, options.len() as i32, options_ptr);
408 if status != hiprtcResult_HIPRTC_SUCCESS {
409 let mut log_size: usize = 0;
410 let status =
411 cubecl_hip_sys::hiprtcGetProgramLogSize(program, &mut log_size as *mut usize);
412 assert_eq!(
413 status, hiprtcResult_HIPRTC_SUCCESS,
414 "Should retrieve the compilation log size"
415 );
416 let mut log_buffer = vec![0i8; log_size];
417 let status = cubecl_hip_sys::hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
418 assert_eq!(
419 status, hiprtcResult_HIPRTC_SUCCESS,
420 "Should retrieve the compilation log contents"
421 );
422 let log = CStr::from_ptr(log_buffer.as_ptr());
423 let mut message = "[Compilation Error] ".to_string();
424 if log_size > 0 {
425 for line in log.to_string_lossy().split('\n') {
426 if !line.is_empty() {
427 message += format!("\n {line}").as_str();
428 }
429 }
430 } else {
431 message += "\n No compilation logs found!";
432 }
433 panic!("{message}\n[Source] \n{}", jitc_kernel.source);
434 }
435 assert_eq!(
436 status, hiprtcResult_HIPRTC_SUCCESS,
437 "Should compile the program"
438 );
439 };
440 let mut code_size: usize = 0;
442 unsafe {
443 let status = cubecl_hip_sys::hiprtcGetCodeSize(program, &mut code_size);
444 assert_eq!(
445 status, hiprtcResult_HIPRTC_SUCCESS,
446 "Should get size of compiled code"
447 );
448 }
449 let mut code = vec![0i8; code_size];
450 unsafe {
451 let status = cubecl_hip_sys::hiprtcGetCode(program, code.as_mut_ptr());
452 assert_eq!(
453 status, hiprtcResult_HIPRTC_SUCCESS,
454 "Should load compiled code"
455 );
456 }
457 let mut module: cubecl_hip_sys::hipModule_t = std::ptr::null_mut();
459 unsafe {
460 let codeptr = code.as_ptr();
461 let status = cubecl_hip_sys::hipModuleLoadData(&mut module, codeptr as *const _);
462 assert_eq!(status, HIP_SUCCESS, "Should load compiled code into module");
463 }
464 let mut func: cubecl_hip_sys::hipFunction_t = std::ptr::null_mut();
466 unsafe {
467 let status =
468 cubecl_hip_sys::hipModuleGetFunction(&mut func, module, func_name.as_ptr());
469 assert_eq!(status, HIP_SUCCESS, "Should return module function");
470 }
471
472 self.module_names.insert(
474 kernel_id.clone(),
475 HipCompiledKernel {
476 _module: module,
477 func,
478 cube_dim: jitc_kernel.cube_dim,
479 shared_mem_bytes: jitc_kernel.shared_mem_bytes,
480 },
481 );
482 }
483
484 fn execute_task(
485 &mut self,
486 kernel_id: KernelId,
487 dispatch_count: (u32, u32, u32),
488 resources: Vec<HipResource>,
489 ) {
490 let mut bindings = resources
491 .iter()
492 .map(|memory| memory.binding)
493 .collect::<Vec<_>>();
494
495 let kernel = self.module_names.get(&kernel_id).unwrap();
496 let cube_dim = kernel.cube_dim;
497
498 unsafe {
499 let status = cubecl_hip_sys::hipModuleLaunchKernel(
500 kernel.func,
501 dispatch_count.0,
502 dispatch_count.1,
503 dispatch_count.2,
504 cube_dim.x,
505 cube_dim.y,
506 cube_dim.z,
507 kernel.shared_mem_bytes as u32,
508 self.stream,
509 bindings.as_mut_ptr(),
510 std::ptr::null_mut(),
511 );
512 if status == cubecl_hip_sys::hipError_t_hipErrorOutOfMemory {
513 panic!("Error: Cannot launch kernel (Out of memory)\n{}", kernel_id)
514 }
515 assert_eq!(status, HIP_SUCCESS, "Should launch the kernel");
516 };
517 }
518}
519
520impl HipServer {
521 pub(crate) fn new(mut ctx: HipContext) -> Self {
523 let logger = DebugLogger::default();
524 if logger.profile_level().is_some() {
525 ctx.timestamps.enable();
526 }
527
528 Self { ctx, logger }
529 }
530
531 fn get_context(&mut self) -> &mut HipContext {
532 self.get_context_with_logger().0
533 }
534
535 fn get_context_with_logger(&mut self) -> (&mut HipContext, &mut DebugLogger) {
536 (&mut self.ctx, &mut self.logger)
537 }
538}
539
540fn include_path() -> PathBuf {
541 let error_msg = "
542 ROCm HIP installation not found.
543 Please ensure that ROCm is installed with HIP runtimes and development libraries and the ROCM_PATH or HIP_PATH environment variable is set correctly.
544 Note: Default path is /opt/rocm which may not be correct.
545 ";
546 let path = hip_path().expect(error_msg);
547 let result = path.join("include");
548 let hip_include = result.join("hip");
549 if !hip_include.exists() || !hip_include.is_dir() {
550 panic!("{error_msg}");
551 }
552 result
553}
554
/// Locates the ROCm installation root.
///
/// The environment variables `CUBECL_ROCM_PATH`, `ROCM_PATH` and `HIP_PATH` are checked in
/// that order and the first one holding a non-empty value wins. An empty value is treated as
/// unset, because it would otherwise resolve to a relative (current-directory) path. Values
/// are read as `OsString`, so non-UTF-8 paths are honored instead of silently skipped. When
/// none is set, `/opt/rocm` is returned.
///
/// Always returns `Some`; the `Option` is kept for compatibility with callers.
fn hip_path() -> Option<PathBuf> {
    const DEFAULT_ROCM_PATH: &str = "/opt/rocm";
    const ROCM_PATH_VARS: [&str; 3] = ["CUBECL_ROCM_PATH", "ROCM_PATH", "HIP_PATH"];

    let root = ROCM_PATH_VARS
        .iter()
        .filter_map(|&name| std::env::var_os(name))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(DEFAULT_ROCM_PATH));

    Some(root)
}