1use super::storage::gpu::GpuResource;
2use super::storage::gpu::GpuStorage;
3use crate::compute::command::Command;
4use crate::compute::command::write_to_cpu;
5use crate::compute::context::HipContext;
6use crate::compute::fence::Fence;
7use crate::compute::stream::HipStreamBackend;
8use crate::runtime::HipCompiler;
9use cubecl_common::bytes::Bytes;
10use cubecl_common::future::DynFut;
11use cubecl_common::profile::ProfileDuration;
12use cubecl_common::stream_id::StreamId;
13use cubecl_core::compute::CubeTask;
14use cubecl_core::server::ServerCommunication;
15use cubecl_core::server::ServerUtilities;
16use cubecl_core::server::{
17 Allocation, AllocationKind, CopyDescriptor, IoError, ProfileError, ProfilingToken,
18};
19use cubecl_core::server::{Binding, Bindings};
20use cubecl_core::{MemoryConfiguration, future, prelude::*};
21use cubecl_runtime::config::GlobalConfig;
22use cubecl_runtime::logging::ServerLogger;
23use cubecl_runtime::memory_management::{MemoryAllocationMode, MemoryUsage};
24use cubecl_runtime::memory_management::{MemoryDeviceProperties, offset_handles};
25use cubecl_runtime::server::{self, ComputeServer};
26use cubecl_runtime::storage::BindingResource;
27use cubecl_runtime::stream::MultiStream;
28use std::sync::Arc;
29
/// Compute server for the HIP backend.
///
/// Holds the HIP context, the pool of streams through which every command is resolved,
/// the memory alignment used to lay out allocations, and the shared server utilities.
#[derive(Debug)]
pub struct HipServer {
    /// HIP context; also owns the profiling timestamps (`ctx.timestamps`).
    ctx: HipContext,
    /// Stream pool; `HipServer::command` resolves the stream(s) for a `StreamId` through it.
    streams: MultiStream<HipStreamBackend>,
    /// Alignment in bytes applied to every allocation size and to the row pitch of
    /// `AllocationKind::Optimized` allocations.
    mem_alignment: usize,
    /// Shared utilities (including the logger) returned by `ComputeServer::utilities`.
    utilities: Arc<ServerUtilities<Self>>,
}

// SAFETY: NOTE(review): `Send` is asserted manually, presumably because `HipContext` /
// `HipStreamBackend` hold raw HIP handles that are not automatically `Send`. Soundness relies
// on the server only ever being used from one thread at a time (all methods take `&mut self`);
// TODO confirm that the HIP context and streams may be moved across threads.
unsafe impl Send for HipServer {}
39
40impl ComputeServer for HipServer {
41 type Kernel = Box<dyn CubeTask<HipCompiler>>;
42 type Storage = GpuStorage;
43 type Info = ();
44
45 fn logger(&self) -> Arc<ServerLogger> {
46 self.streams.logger.clone()
47 }
48
49 fn utilities(&self) -> Arc<ServerUtilities<Self>> {
50 self.utilities.clone()
51 }
52
53 fn create(
54 &mut self,
55 descriptors: Vec<server::AllocationDescriptor<'_>>,
56 stream_id: StreamId,
57 ) -> Result<Vec<server::Allocation>, IoError> {
58 let mut total_size = 0;
59 let mut strides = Vec::new();
60 let mut sizes = Vec::new();
61
62 for descriptor in descriptors {
63 let pitch_align = match descriptor.kind {
64 AllocationKind::Contiguous => 1,
65 AllocationKind::Optimized => self.mem_alignment,
66 };
67
68 let rank = descriptor.shape.len();
69 let width = *descriptor.shape.last().unwrap_or(&1);
70 let height: usize = descriptor.shape.iter().rev().skip(1).product();
71 let height = height.max(1);
72 let width_bytes = width * descriptor.elem_size;
73 let pitch = width_bytes.next_multiple_of(pitch_align);
74 let size = height * pitch;
75 total_size += size.next_multiple_of(self.mem_alignment);
76
77 let mut stride = vec![1; rank];
78 if rank > 1 {
79 stride[rank - 2] = pitch / descriptor.elem_size;
80 }
81 if rank > 2 {
82 for i in (0..rank - 2).rev() {
83 stride[i] = stride[i + 1] * descriptor.shape[i + 1];
84 }
85 }
86
87 strides.push(stride);
88 sizes.push(size);
89 }
90
91 let mem_alignment = self.mem_alignment;
92 let mut command = self.command_no_inputs(stream_id);
93
94 let handle = command.reserve(total_size as u64)?;
95 let handles = offset_handles(handle, &sizes, mem_alignment);
96
97 Ok(handles
98 .into_iter()
99 .zip(strides)
100 .map(|(handle, strides)| Allocation::new(handle, strides))
101 .collect())
102 }
103
104 fn read(
105 &mut self,
106 descriptors: Vec<server::CopyDescriptor>,
107 stream_id: StreamId,
108 ) -> DynFut<Result<Vec<Bytes>, IoError>> {
109 let mut command = self.command(stream_id, descriptors.iter().map(|d| &d.binding));
110
111 Box::pin(command.read_async(descriptors))
112 }
113
114 fn write(
115 &mut self,
116 descriptors: Vec<(server::CopyDescriptor<'_>, &[u8])>,
117 stream_id: StreamId,
118 ) -> Result<(), IoError> {
119 let mut command = self.command(stream_id, descriptors.iter().map(|desc| &desc.0.binding));
120
121 for (descriptor, data) in descriptors {
122 command.write_to_gpu(descriptor, data)?;
123 }
124
125 Ok(())
126 }
127
128 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage {
129 let mut command = self.command_no_inputs(stream_id);
130 command.memory_usage()
131 }
132
133 fn memory_cleanup(&mut self, stream_id: StreamId) {
134 let mut command = self.command_no_inputs(stream_id);
135 command.memory_cleanup()
136 }
137
138 unsafe fn execute(
139 &mut self,
140 kernel: Self::Kernel,
141 count: CubeCount,
142 bindings: Bindings,
143 mode: ExecutionMode,
144 stream_id: StreamId,
145 ) {
146 let mut kernel_id = kernel.id();
147 let logger = self.streams.logger.clone();
148 kernel_id.mode(mode);
149 let mut command = self.command(stream_id, bindings.buffers.iter());
150
151 let count = match count {
152 CubeCount::Static(x, y, z) => (x, y, z),
153 CubeCount::Dynamic(binding) => {
157 let data = future::block_on(command.read_async(vec![CopyDescriptor::new(
158 binding,
159 &[3],
160 &[1],
161 4,
162 )]))
163 .unwrap();
164 let data = bytemuck::cast_slice(&data[0]);
165 assert!(
166 data.len() == 3,
167 "Dynamic cube count should contain 3 values"
168 );
169 (data[0], data[1], data[2])
170 }
171 };
172
173 let Bindings {
174 buffers,
175 metadata,
176 scalars,
177 tensor_maps,
178 } = bindings;
179
180 debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
181
182 let info = command
183 .create_with_data(bytemuck::cast_slice(&metadata.data))
184 .unwrap();
185 let scalars: Vec<_> = scalars
186 .values()
187 .map(|s| command.create_with_data(s.data()).unwrap())
188 .collect();
189
190 let mut resources: Vec<_> = buffers
191 .into_iter()
192 .map(|b| command.resource(b).expect("Resource to exist."))
193 .collect();
194 resources.push(
195 command
196 .resource(info.clone().binding())
197 .expect("Resource to exist."),
198 );
199 resources.extend(
200 scalars
201 .into_iter()
202 .map(|s| command.resource(s.binding()).expect("Resource to exist.")),
203 );
204
205 command.kernel(kernel_id, kernel, mode, count, &resources, logger)
206 }
207
208 fn flush(&mut self, _stream_id: StreamId) {}
209
210 fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
211 let mut command = self.command_no_inputs(stream_id);
212 command.sync()
213 }
214
215 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
216 cubecl_common::future::block_on(self.sync(stream_id));
217 self.ctx.timestamps.start()
218 }
219
220 fn end_profile(
221 &mut self,
222 stream_id: StreamId,
223 token: ProfilingToken,
224 ) -> Result<ProfileDuration, ProfileError> {
225 cubecl_common::future::block_on(self.sync(stream_id));
226 self.ctx.timestamps.stop(token)
227 }
228
229 fn get_resource(
230 &mut self,
231 binding: server::Binding,
232 stream_id: StreamId,
233 ) -> BindingResource<GpuResource> {
234 let mut command = self.command(stream_id, [&binding].into_iter());
235
236 BindingResource::new(
237 binding.clone(),
238 command.resource(binding).expect("Failed to find resource"),
239 )
240 }
241
242 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
243 let mut command = self.command_no_inputs(stream_id);
244 command.allocation_mode(mode)
245 }
246}
247
248impl ServerCommunication for HipServer {
249 const SERVER_COMM_ENABLED: bool = true;
250
251 fn copy(
252 server_src: &mut Self,
253 server_dst: &mut Self,
254 src: CopyDescriptor<'_>,
255 stream_id_src: StreamId,
256 stream_id_dst: StreamId,
257 ) -> Result<Allocation, IoError> {
258 Self::change_server_serialized(server_src, server_dst, src, stream_id_src, stream_id_dst)
259 }
260}
261
impl HipServer {
    /// Creates a server on top of an existing HIP context.
    ///
    /// The maximum number of streams is read from the global configuration
    /// (`GlobalConfig::get().streaming.max_streams`) and passed to `MultiStream::new`. The
    /// stream backend is built from `mem_props`, `mem_config` and `mem_alignment`. The stream
    /// pool and the backend both receive a clone of `utilities.logger`.
    pub(crate) fn new(
        ctx: HipContext,
        mem_props: MemoryDeviceProperties,
        mem_config: MemoryConfiguration,
        mem_alignment: usize,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            ctx,
            mem_alignment,
            streams: MultiStream::new(
                utilities.logger.clone(),
                HipStreamBackend::new(
                    mem_props,
                    mem_config,
                    mem_alignment,
                    utilities.logger.clone(),
                ),
                max_streams,
            ),
            utilities: Arc::new(utilities),
        }
    }

    /// Shorthand for [`Self::command`] when the command reads no existing bindings.
    fn command_no_inputs(&mut self, stream_id: StreamId) -> Command<'_> {
        self.command(stream_id, [].into_iter())
    }

    /// Resolves the stream(s) for `stream_id` and wraps them with the context in a `Command`.
    ///
    /// `bindings` are the buffers the command will use. They are handed to
    /// `MultiStream::resolve`. NOTE(review): presumably so it can order this command after
    /// other streams that produced those buffers; confirm against `MultiStream`.
    fn command<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) -> Command<'_> {
        let streams = self.streams.resolve(stream_id, bindings);

        Command::new(&mut self.ctx, streams)
    }

    /// Copies the tensor described by `src` from `server_src` to a new allocation on
    /// `server_dst` by staging it through a host buffer.
    ///
    /// Steps:
    /// 1. On the destination, reserve `src.binding.size()` device bytes and a host buffer
    ///    large enough for the dense tensor (`product(shape) * elem_size` bytes).
    /// 2. On the source stream, enqueue a strided device-to-host copy into the host buffer,
    ///    then create a fence on that stream.
    /// 3. Make the destination stream wait on the fence, then upload the host buffer into
    ///    the new allocation.
    ///
    /// The returned allocation reuses the source strides, which is consistent with reserving
    /// the full source binding size.
    ///
    /// # Errors
    ///
    /// Propagates `IoError`s from the reservation, the source resource lookup, and the
    /// copies in both directions.
    fn change_server_serialized(
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor<'_>,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<Allocation, IoError> {
        let shape = src.shape.to_vec();
        let strides = src.strides.to_vec();
        let elem_size = src.elem_size;
        let binding = src.binding.clone();
        let num_bytes = shape.iter().product::<usize>() * elem_size;

        // Step 1: destination allocations. NOTE(review): the `true` flag of `reserve_cpu`
        // presumably requests pinned host memory; confirm.
        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
        let handle = command_dst.reserve(binding.size())?;
        let mut bytes = command_dst.reserve_cpu(num_bytes, true, None);
        let copy_desc = handle.copy_descriptor(&shape, &strides, elem_size);

        // Release the mutable borrow of `server_dst`; it is borrowed again in step 3.
        core::mem::drop(command_dst);

        // Step 2: device-to-host copy on the source stream.
        let mut command_src = server_src.command(stream_id_src, [&src.binding].into_iter());
        let resource_src = command_src.resource(binding.clone())?;
        let stream_src = command_src.streams.current().sys;

        // SAFETY: NOTE(review): relies on `resource_src.ptr` being a live device pointer that
        // covers `shape`/`strides`/`elem_size`, and on `bytes` holding at least `num_bytes`
        // bytes (it was reserved with exactly that size above). TODO confirm `write_to_cpu`'s
        // full contract.
        unsafe {
            write_to_cpu(
                &shape,
                &strides,
                elem_size,
                &mut bytes,
                resource_src.ptr,
                stream_src,
            )?;
        }
        // Marks the point on the source stream after which the host buffer is filled.
        let fence_src = Fence::new(stream_src);

        core::mem::drop(command_src);

        // Step 3: host-to-device copy on the destination stream, ordered after the fence.
        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
        let stream_dst = command_dst.streams.current().sys;

        fence_src.wait_async(stream_dst);
        command_dst.write_to_gpu(copy_desc, &bytes)?;

        core::mem::drop(command_dst);

        // NOTE(review): `bytes` is dropped when this function returns. If `write_to_gpu` is
        // asynchronous, confirm the staging buffer cannot be freed or reused while the upload
        // is still in flight.
        Ok(Allocation { handle, strides })
    }
}
376
/// Returns the row-major (C-order) element strides of a densely packed tensor of `shape`.
///
/// The last dimension has stride 1. Each earlier stride is the product of all later
/// dimensions. An empty shape (rank 0) yields an empty stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps rank 0 from underflowing `rank - 1`.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
385
/// Failure reported for a kernel launch.
///
/// NOTE(review): the producers of this error are not in this file; the launch-failure
/// meaning is inferred from the name. Confirm at the call sites.
#[derive(Debug)]
pub(crate) enum LaunchError {
    /// The device ran out of memory.
    OutOfMemory,
    /// Any other failure, carrying a human-readable message.
    Unknown(String),
}
391
392impl From<LaunchError> for ProfileError {
393 fn from(val: LaunchError) -> Self {
394 match val {
395 LaunchError::OutOfMemory => ProfileError::Unknown("Out of memory".into()),
396 LaunchError::Unknown(msg) => ProfileError::Unknown(msg),
397 }
398 }
399}
400
/// Returns `true` when `strides` describe a row-major layout of `shape` that can be addressed
/// directly, allowing only the row stride to be padded.
///
/// The layout is accepted when all of the following hold:
/// - `strides` has exactly one entry per dimension of `shape`;
/// - the innermost stride is `1`;
/// - strides are non-increasing from the outermost to the innermost dimension;
/// - every dimension except the last two is densely packed over the next one
///   (`strides[i] == shape[i + 1] * strides[i + 1]`).
///
/// The row stride (`strides[rank - 2]`) is otherwise unconstrained, so padded ("pitched")
/// rows are accepted. A rank-0 shape with no strides is trivially valid.
pub fn valid_strides(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // A stride vector that does not match the shape cannot describe it.
    if strides.len() != rank {
        return false;
    }
    // Rank 0: nothing to validate.
    let Some(&innermost) = strides.last() else {
        return true;
    };
    if innermost != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }

    // Any increasing neighbour pair means a non-row-major dimension order.
    if strides.windows(2).any(|pair| pair[0] < pair[1]) {
        return false;
    }

    // `checked_mul` avoids overflow panics (debug) and wrap-around false positives (release).
    (0..rank - 2).all(|i| shape[i + 1].checked_mul(strides[i + 1]) == Some(strides[i]))
}