1use cubecl_core::server::{ProfileError, ProfilingToken};
2use cubecl_cpp::formatter::format_cpp;
3use cubecl_cpp::shared::CompilationOptions;
4
5use super::fence::{Fence, SyncStream};
6use super::storage::HipStorage;
7use super::{HipResource, uninit_vec};
8use crate::runtime::HipCompiler;
9use cubecl_common::future::DynFut;
10use cubecl_common::profile::ProfileDuration;
11use cubecl_core::compute::CubeTask;
12use cubecl_core::compute::DebugInformation;
13use cubecl_core::prelude::*;
14use cubecl_core::{Feature, server::Bindings};
15use cubecl_hip_sys::{HIP_SUCCESS, get_hip_include_path, hiprtcResult_HIPRTC_SUCCESS};
16use cubecl_runtime::logging::ServerLogger;
17use cubecl_runtime::memory_management::MemoryUsage;
18use cubecl_runtime::memory_management::offset_handles;
19use cubecl_runtime::storage::BindingResource;
20use cubecl_runtime::timestamp_profiler::TimestampProfiler;
21use cubecl_runtime::{
22 memory_management::MemoryManagement,
23 server::{self, ComputeServer},
24};
25use std::collections::HashMap;
26use std::ffi::CStr;
27use std::ffi::CString;
28use std::future::Future;
29use std::sync::Arc;
30
31#[cfg(feature = "compilation-cache")]
32use cubecl_common::cache::{Cache, CacheOption};
33
/// Compute server that drives a HIP device through a single [`HipContext`].
#[derive(Debug)]
pub struct HipServer {
    /// Device-side state: stream, memory management, compiled kernels and profiler.
    ctx: HipContext,
    /// Each tensor's byte size is rounded up to a multiple of this value in
    /// `empty_tensors`.
    mem_alignment: usize,
}
39
/// State needed to allocate memory, compile kernels and launch them on one HIP stream.
#[derive(Debug)]
pub(crate) struct HipContext {
    /// Stream on which host/device copies and kernel launches are enqueued.
    stream: cubecl_hip_sys::hipStream_t,
    /// Resolves bindings to device resources and reserves new memory.
    memory_management: MemoryManagement<HipStorage>,
    /// Kernels that have already been compiled and loaded, keyed by kernel id.
    module_names: HashMap<KernelId, HipCompiledKernel>,
    /// Profiler used by `start_profile`/`end_profile`; launch errors are reported to it
    /// when it is not empty.
    timestamps: TimestampProfiler,
    /// Options forwarded to the kernel compiler in `compile_kernel`.
    compilation_options: CompilationOptions,
    /// Cache of compiled binaries, keyed by `KernelId::stable_format()`.
    #[cfg(feature = "compilation-cache")]
    compilation_cache: Cache<String, CompilationCacheEntry>,
}
50
/// One entry of the compilation cache: everything needed to reload a kernel
/// without compiling it again.
#[cfg(feature = "compilation-cache")]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq, Clone)]
pub struct CompilationCacheEntry {
    /// Name of the kernel entry point looked up in the loaded module.
    entrypoint_name: String,
    /// Cube dimensions stored as `(x, y, z)`.
    cube_dim: (u32, u32, u32),
    /// Compiled code as returned by `hiprtcGetCode`.
    binary: Vec<i8>,
}
58
/// A kernel that has been loaded into a HIP module and is ready to launch.
#[derive(Debug)]
struct HipCompiledKernel {
    /// Module the function was loaded from; stored but never read in this file.
    _module: cubecl_hip_sys::hipModule_t,
    /// Function handle passed to `hipModuleLaunchKernel`.
    func: cubecl_hip_sys::hipFunction_t,
    /// Block dimensions used when launching the kernel.
    cube_dim: CubeDim,
}
65
// NOTE(review): `HipServer` owns a `HipContext`, which stores HIP handles
// (`hipStream_t`, plus a `hipModule_t`/`hipFunction_t` per compiled kernel).
// This impl asserts that moving the server to another thread is sound —
// presumably because the handles are only reached through `&mut self`;
// confirm against the HIP runtime's threading rules.
unsafe impl Send for HipServer {}
67
68impl HipServer {
69 fn read_sync(&mut self, binding: server::Binding) -> Vec<u8> {
70 let ctx = self.get_context();
71 let resource = ctx
72 .memory_management
73 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
74 .expect("Failed to find resource");
75
76 let mut data = uninit_vec(resource.size as usize);
77 unsafe {
78 let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
79 data.as_mut_ptr() as *mut _,
80 resource.ptr,
81 resource.size as usize,
82 ctx.stream,
83 );
84 assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
85 };
86 ctx.sync();
87 data
88 }
89
90 fn read_async(
91 &mut self,
92 bindings: Vec<server::Binding>,
93 ) -> impl Future<Output = Vec<Vec<u8>>> + Send + use<> {
94 let ctx = self.get_context();
95 let mut result = Vec::with_capacity(bindings.len());
96
97 for binding in bindings {
98 let resource = ctx
99 .memory_management
100 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
101 .expect("Failed to find resource");
102
103 let mut data = uninit_vec(resource.size as usize);
104 unsafe {
105 let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
106 data.as_mut_ptr() as *mut _,
107 resource.ptr,
108 resource.size as usize,
109 ctx.stream,
110 );
111 assert_eq!(status, HIP_SUCCESS, "Should copy data from device to host");
112 };
113 result.push(data);
114 }
115
116 let fence = ctx.fence();
117
118 async move {
119 fence.wait();
120 result
121 }
122 }
123
124 fn sync_stream_async(&mut self) -> impl Future<Output = ()> + Send + use<> {
125 let ctx = self.get_context();
126 let sync = ctx.lazy_sync_stream();
131
132 async move {
133 sync.wait();
134 }
135 }
136}
137
138impl ComputeServer for HipServer {
139 type Kernel = Box<dyn CubeTask<HipCompiler>>;
140 type Storage = HipStorage;
141 type Feature = Feature;
142 type Info = ();
143
144 fn read(&mut self, bindings: Vec<server::Binding>) -> DynFut<Vec<Vec<u8>>> {
145 Box::pin(self.read_async(bindings))
146 }
147
148 fn read_tensor(&mut self, bindings: Vec<server::BindingWithMeta>) -> DynFut<Vec<Vec<u8>>> {
149 let bindings = bindings.into_iter().map(|it| it.binding).collect();
150 Box::pin(self.read_async(bindings))
151 }
152
153 fn memory_usage(&self) -> MemoryUsage {
154 self.ctx.memory_usage()
155 }
156
157 fn memory_cleanup(&mut self) {
158 let ctx = self.get_context();
159 ctx.memory_management.cleanup(true);
160 }
161
162 fn create(&mut self, data: &[u8]) -> server::Handle {
163 let handle = self.empty(data.len());
164
165 let binding = handle.clone().binding();
166 self.copy_to_binding(binding, data);
167 handle
168 }
169
170 fn create_tensors(
171 &mut self,
172 data: Vec<&[u8]>,
173 shapes: Vec<&[usize]>,
174 elem_sizes: Vec<usize>,
175 ) -> Vec<(server::Handle, Vec<usize>)> {
176 let handles_strides = self.empty_tensors(shapes.clone(), elem_sizes);
177 for i in 0..data.len() {
178 let data = data[i];
179 let (handle, _) = &handles_strides[i];
180 let binding = handle.clone().binding();
181 self.copy_to_binding(binding, data);
182 }
183 handles_strides
184 }
185
186 fn empty(&mut self, size: usize) -> server::Handle {
187 let ctx = self.get_context();
188 let handle = ctx.memory_management.reserve(size as u64, None);
189 server::Handle::new(handle, None, None, size as u64)
190 }
191
192 fn empty_tensors(
193 &mut self,
194 shapes: Vec<&[usize]>,
195 elem_sizes: Vec<usize>,
196 ) -> Vec<(server::Handle, Vec<usize>)> {
197 let mut total_size = 0;
198 let mut strides = Vec::new();
199 let mut sizes = Vec::new();
200
201 for (shape, elem_size) in shapes.into_iter().zip(elem_sizes) {
202 let size =
203 (shape.iter().product::<usize>() * elem_size).next_multiple_of(self.mem_alignment);
204 strides.push(contiguous_strides(shape));
205 sizes.push(size);
206 total_size += size;
207 }
208
209 let mem_handle = self.empty(total_size);
210 let handles = offset_handles(mem_handle, &sizes);
211
212 handles.into_iter().zip(strides).collect()
213 }
214
215 unsafe fn execute(
216 &mut self,
217 kernel: Self::Kernel,
218 count: CubeCount,
219 bindings: Bindings,
220 mode: ExecutionMode,
221 logger: Arc<ServerLogger>,
222 ) {
223 let mut kernel_id = kernel.id();
224 kernel_id.mode(mode);
225
226 let count = match count {
227 CubeCount::Static(x, y, z) => (x, y, z),
228 CubeCount::Dynamic(binding) => {
232 let data = self.read_sync(binding);
233 let data = bytemuck::cast_slice(&data);
234 assert!(
235 data.len() == 3,
236 "Dynamic cube count should contain 3 values"
237 );
238 (data[0], data[1], data[2])
239 }
240 };
241
242 let Bindings {
243 buffers,
244 metadata,
245 scalars,
246 tensor_maps,
247 } = bindings;
248
249 debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
250 let info = self.create(bytemuck::cast_slice(&metadata.data));
251 let scalars: Vec<_> = scalars.values().map(|s| self.create(s.data())).collect();
252
253 let ctx = self.get_context();
254
255 if !ctx.module_names.contains_key(&kernel_id) {
256 ctx.compile_kernel(&kernel_id, kernel, mode, logger);
257 }
258
259 let mut resources: Vec<_> = buffers.into_iter().map(|b| find_resource(ctx, b)).collect();
260 resources.push(find_resource(ctx, info.clone().binding()));
261 resources.extend(scalars.into_iter().map(|s| find_resource(ctx, s.binding())));
262
263 ctx.execute_task(kernel_id, count, resources);
264 }
265
266 fn flush(&mut self) {}
267
268 fn sync(&mut self) -> DynFut<()> {
269 Box::pin(self.sync_stream_async())
270 }
271
272 fn start_profile(&mut self) -> ProfilingToken {
273 cubecl_common::future::block_on(self.sync());
274 self.ctx.timestamps.start()
275 }
276
277 fn end_profile(&mut self, token: ProfilingToken) -> Result<ProfileDuration, ProfileError> {
278 cubecl_common::future::block_on(self.sync());
279 self.ctx.timestamps.stop(token)
280 }
281
282 fn get_resource(&mut self, binding: server::Binding) -> BindingResource<HipResource> {
283 let ctx = self.get_context();
284 BindingResource::new(
285 binding.clone(),
286 ctx.memory_management
287 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
288 .expect("Can't find resource"),
289 )
290 }
291}
292
293fn find_resource(ctx: &mut HipContext, binding: server::Binding) -> HipResource {
294 ctx.memory_management
295 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
296 .expect("Failed to find resource")
297}
298
299impl HipContext {
300 pub fn new(
301 memory_management: MemoryManagement<HipStorage>,
302 compilation_options: CompilationOptions,
303 stream: cubecl_hip_sys::hipStream_t,
304 ) -> Self {
305 Self {
306 memory_management,
307 module_names: HashMap::new(),
308 stream,
309 timestamps: TimestampProfiler::default(),
310 compilation_options,
311 #[cfg(feature = "compilation-cache")]
312 compilation_cache: Cache::new("hip/compilation", CacheOption::default()),
313 }
314 }
315
316 fn fence(&mut self) -> Fence {
317 Fence::new(self.stream)
318 }
319
320 fn lazy_sync_stream(&mut self) -> SyncStream {
321 SyncStream::new(self.stream)
322 }
323
324 fn sync(&mut self) {
325 unsafe {
326 let status = cubecl_hip_sys::hipStreamSynchronize(self.stream);
327 assert_eq!(
328 status, HIP_SUCCESS,
329 "Should successfully synchronize stream"
330 );
331 };
332 self.memory_management.storage().flush();
333 }
334
335 fn memory_usage(&self) -> MemoryUsage {
336 self.memory_management.memory_usage()
337 }
338
339 fn compile_kernel(
340 &mut self,
341 kernel_id: &KernelId,
342 cube_kernel: Box<dyn CubeTask<HipCompiler>>,
343 mode: ExecutionMode,
344 logger: Arc<ServerLogger>,
345 ) {
346 #[cfg(feature = "compilation-cache")]
347 let name = kernel_id.stable_format();
348 #[cfg(feature = "compilation-cache")]
349 if let Some(entry) = self.compilation_cache.get(&name) {
350 log::trace!("Using compilation cache");
351 self.load_compiled_binary(
352 entry.binary.clone(),
353 kernel_id.clone(),
354 entry.entrypoint_name.clone(),
355 CubeDim {
356 x: entry.cube_dim.0,
357 y: entry.cube_dim.1,
358 z: entry.cube_dim.2,
359 },
360 );
361 return;
362 }
363
364 let mut jitc_kernel =
367 cube_kernel.compile(&mut Default::default(), &self.compilation_options, mode);
368
369 if logger.compilation_activated() {
370 jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
371
372 if let Ok(formatted) = format_cpp(&jitc_kernel.source) {
373 jitc_kernel.source = formatted;
374 }
375 }
376 logger.log_compilation(&jitc_kernel);
377
378 let program = unsafe {
380 let source = CString::new(jitc_kernel.source.clone()).unwrap();
381 let mut program: cubecl_hip_sys::hiprtcProgram = std::ptr::null_mut();
382 let status = cubecl_hip_sys::hiprtcCreateProgram(
383 &mut program,
384 source.as_ptr(),
385 std::ptr::null(), 0,
387 std::ptr::null_mut(),
388 std::ptr::null_mut(),
389 );
390 assert_eq!(
391 status, hiprtcResult_HIPRTC_SUCCESS,
392 "Should create the program"
393 );
394 program
395 };
396 let include_path = get_hip_include_path().unwrap();
399 let include_option = format!("-I{include_path}");
400 let include_option_cstr = CString::new(include_option).unwrap();
401 let cpp_std_option_cstr = CString::new("--std=c++17").unwrap();
403 let optimization_level = CString::new("-O3").unwrap();
404 let mut options = vec![
405 cpp_std_option_cstr.as_ptr(),
406 include_option_cstr.as_ptr(),
407 optimization_level.as_ptr(),
408 ];
409 unsafe {
410 let options_ptr = options.as_mut_ptr();
411 let status =
412 cubecl_hip_sys::hiprtcCompileProgram(program, options.len() as i32, options_ptr);
413 if status != hiprtcResult_HIPRTC_SUCCESS {
414 let mut log_size: usize = 0;
415 let status =
416 cubecl_hip_sys::hiprtcGetProgramLogSize(program, &mut log_size as *mut usize);
417 assert_eq!(
418 status, hiprtcResult_HIPRTC_SUCCESS,
419 "Should retrieve the compilation log size"
420 );
421 let mut log_buffer = vec![0; log_size];
422 let status = cubecl_hip_sys::hiprtcGetProgramLog(program, log_buffer.as_mut_ptr());
423 assert_eq!(
424 status, hiprtcResult_HIPRTC_SUCCESS,
425 "Should retrieve the compilation log contents"
426 );
427 let log = CStr::from_ptr(log_buffer.as_ptr());
428 let mut message = "[Compilation Error] ".to_string();
429 if log_size > 0 {
430 for line in log.to_string_lossy().split('\n') {
431 if !line.is_empty() {
432 message += format!("\n {line}").as_str();
433 }
434 }
435 } else {
436 message += "\n No compilation logs found!";
437 }
438 panic!("{message}\n[Source] \n{}", jitc_kernel.source);
439 }
440 assert_eq!(
441 status, hiprtcResult_HIPRTC_SUCCESS,
442 "Should compile the program"
443 );
444 };
445 let mut code_size: usize = 0;
447 unsafe {
448 let status = cubecl_hip_sys::hiprtcGetCodeSize(program, &mut code_size);
449 assert_eq!(
450 status, hiprtcResult_HIPRTC_SUCCESS,
451 "Should get size of compiled code"
452 );
453 }
454 let mut code = vec![0; code_size];
455 unsafe {
456 let status = cubecl_hip_sys::hiprtcGetCode(program, code.as_mut_ptr());
457 assert_eq!(
458 status, hiprtcResult_HIPRTC_SUCCESS,
459 "Should load compiled code"
460 );
461 }
462
463 #[cfg(feature = "compilation-cache")]
464 self.compilation_cache
465 .insert(
466 name,
467 CompilationCacheEntry {
468 entrypoint_name: jitc_kernel.entrypoint_name.clone(),
469 cube_dim: (
470 jitc_kernel.cube_dim.x,
471 jitc_kernel.cube_dim.y,
472 jitc_kernel.cube_dim.z,
473 ),
474 binary: code.clone(),
475 },
476 )
477 .unwrap();
478
479 self.load_compiled_binary(
480 code,
481 kernel_id.clone(),
482 jitc_kernel.entrypoint_name,
483 jitc_kernel.cube_dim,
484 );
485 }
486
487 fn load_compiled_binary(
488 &mut self,
489 code: Vec<i8>,
490 kernel_id: KernelId,
491 entrypoint_name: String,
492 cube_dim: CubeDim,
493 ) {
494 let func_name = CString::new(entrypoint_name.clone()).unwrap();
495
496 let mut module: cubecl_hip_sys::hipModule_t = std::ptr::null_mut();
498 unsafe {
499 let codeptr = code.as_ptr();
500 let status = cubecl_hip_sys::hipModuleLoadData(&mut module, codeptr as *const _);
501 assert_eq!(status, HIP_SUCCESS, "Should load compiled code into module");
502 }
503 let mut func: cubecl_hip_sys::hipFunction_t = std::ptr::null_mut();
505 unsafe {
506 let status =
507 cubecl_hip_sys::hipModuleGetFunction(&mut func, module, func_name.as_ptr());
508 assert_eq!(status, HIP_SUCCESS, "Should return module function");
509 }
510
511 self.module_names.insert(
513 kernel_id.clone(),
514 HipCompiledKernel {
515 _module: module,
516 func,
517 cube_dim,
518 },
519 );
520 }
521
522 fn execute_task(
523 &mut self,
524 kernel_id: KernelId,
525 dispatch_count: (u32, u32, u32),
526 resources: Vec<HipResource>,
527 ) {
528 let mut bindings = resources
529 .iter()
530 .map(|memory| memory.binding)
531 .collect::<Vec<_>>();
532
533 let kernel = self.module_names.get(&kernel_id).unwrap();
534 let cube_dim = kernel.cube_dim;
535
536 let result = unsafe {
537 let status = cubecl_hip_sys::hipModuleLaunchKernel(
538 kernel.func,
539 dispatch_count.0,
540 dispatch_count.1,
541 dispatch_count.2,
542 cube_dim.x,
543 cube_dim.y,
544 cube_dim.z,
545 0,
549 self.stream,
550 bindings.as_mut_ptr(),
551 std::ptr::null_mut(),
552 );
553 if status == cubecl_hip_sys::hipError_t_hipErrorOutOfMemory {
554 Err(LaunchError::OutOfMemory)
555 } else if status != HIP_SUCCESS {
556 Err(LaunchError::Unknown(format!(
557 "Unable to launch kernel {kernel_id:?} with status {status:?}"
558 )))
559 } else {
560 Ok(())
561 }
562 };
563
564 match result {
565 Ok(_) => {}
566 Err(err) => match self.timestamps.is_empty() {
567 true => panic!("{err:?}"),
568 false => self.timestamps.error(err.into()),
569 },
570 }
571 }
572}
573
574impl HipServer {
575 pub(crate) fn new(mem_alignment: usize, ctx: HipContext) -> Self {
577 Self { ctx, mem_alignment }
578 }
579
580 fn get_context(&mut self) -> &mut HipContext {
581 &mut self.ctx
582 }
583
584 fn copy_to_binding(&mut self, binding: server::Binding, data: &[u8]) {
585 let ctx = self.get_context();
586 let resource = ctx
587 .memory_management
588 .get_resource(binding.memory, binding.offset_start, binding.offset_end)
589 .unwrap();
590
591 unsafe {
592 let status = cubecl_hip_sys::hipMemcpyHtoDAsync(
593 resource.ptr,
594 data as *const _ as *mut _,
595 data.len(),
596 ctx.stream,
597 );
598 assert_eq!(status, HIP_SUCCESS, "Should send data to device");
599 }
600 }
601}
602
/// Computes the row-major (contiguous) strides, in elements, for `shape`.
///
/// The last dimension has stride 1 and every other stride is the product of
/// the dimensions to its right. An empty shape yields an empty vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps a rank-0 shape from underflowing `rank - 1`.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
611
/// Reasons a kernel launch can fail, as classified by `execute_task`.
#[derive(Debug)]
pub(crate) enum LaunchError {
    /// The launch returned `hipErrorOutOfMemory`.
    OutOfMemory,
    /// The launch returned any other non-success status; the message names the
    /// kernel and the status.
    Unknown(String),
}
617
618impl From<LaunchError> for ProfileError {
619 fn from(val: LaunchError) -> Self {
620 match val {
621 LaunchError::OutOfMemory => ProfileError::Unknown("Out of memory".into()),
622 LaunchError::Unknown(msg) => ProfileError::Unknown(msg),
623 }
624 }
625}