use super::storage::gpu::{GpuResource, GpuStorage};
use crate::{
    compute::{
        command::{Command, write_to_cpu},
        context::HipContext,
        fence::Fence,
        stream::HipStreamBackend,
    },
    runtime::HipCompiler,
};
use cubecl_common::{bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId};
use cubecl_core::{
    MemoryConfiguration, future,
    ir::MemoryDeviceProperties,
    prelude::*,
    server::{
        Allocation, AllocationKind, Binding, Bindings, CopyDescriptor, ExecutionError, IoError,
        LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities,
    },
};
use cubecl_runtime::{
    compiler::CubeTask,
    config::GlobalConfig,
    logging::ServerLogger,
    memory_management::{MemoryAllocationMode, MemoryUsage, offset_handles},
    server::{self, ComputeServer},
    storage::BindingResource,
    stream::MultiStream,
};
use std::sync::Arc;

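/// Compute server for the HIP runtime.
///
/// All device work goes through [`Command`]s that are resolved against a
/// [`StreamId`], so a single server can drive several HIP streams.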
#[derive(Debug)]
pub struct HipServer {
    ctx: HipContext,
    streams: MultiStream<HipStreamBackend>,
    mem_alignment: usize,
    utilities: Arc<ServerUtilities<Self>>,
}

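// SAFETY: the server is only ever driven through `&mut self`, so the raw HIP
// handles it owns are never accessed concurrently when moved across threads.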
unsafe impl Send for HipServer {}

impl ComputeServer for HipServer {
    type Kernel = Box<dyn CubeTask<HipCompiler>>;
    type Storage = GpuStorage;
    type Info = ();

    fn logger(&self) -> Arc<ServerLogger> {
        self.streams.logger.clone()
    }

    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

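    /// Reserves one CPU staging buffer for each requested size.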
    fn staging(&mut self, sizes: &[usize], stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
        let mut command = self.command_no_inputs(stream_id);

        Ok(sizes
            .iter()
            .map(|size| command.reserve_cpu(*size, true, None))
            .collect())
    }

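    /// Allocates memory for every descriptor in a single reservation.
    ///
    /// For `AllocationKind::Optimized` tensors, each row is padded to the
    /// device pitch alignment; `Contiguous` tensors are packed as-is. For
    /// example, with a 128-byte alignment, a `[2, 3]` tensor of `f32` gets a
    /// pitch of 128 bytes, strides `[32, 1]`, and occupies 256 bytes.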
    fn create(
        &mut self,
        descriptors: Vec<server::AllocationDescriptor<'_>>,
        stream_id: StreamId,
    ) -> Result<Vec<server::Allocation>, IoError> {
        let mut total_size = 0;
        let mut strides = Vec::new();
        let mut sizes = Vec::new();

        for descriptor in descriptors {
            let pitch_align = match descriptor.kind {
                AllocationKind::Contiguous => 1,
                AllocationKind::Optimized => self.mem_alignment,
            };

            let rank = descriptor.shape.len();
            let width = *descriptor.shape.last().unwrap_or(&1);
            let height: usize = descriptor.shape.iter().rev().skip(1).product();
            let height = Ord::max(height, 1);

            // Pad each row to the pitch alignment so every row starts on an
            // aligned address.
            let width_bytes = width * descriptor.elem_size;
            let pitch = width_bytes.next_multiple_of(pitch_align);
            let size = height * pitch;
            total_size += size.next_multiple_of(self.mem_alignment);

            // Strides are expressed in elements; the second-to-last dimension
            // steps over a full (padded) row.
            let mut stride = vec![1; rank];
            if rank > 1 {
                stride[rank - 2] = pitch / descriptor.elem_size;
            }
            if rank > 2 {
                for i in (0..rank - 2).rev() {
                    stride[i] = stride[i + 1] * descriptor.shape[i + 1];
                }
            }

            strides.push(stride);
            sizes.push(size);
        }

        let mem_alignment = self.mem_alignment;
        let mut command = self.command_no_inputs(stream_id);

        // A single reservation covers all allocations; it is then split into
        // aligned sub-handles.
        let handle = command.reserve(total_size as u64)?;
        let handles = offset_handles(handle, &sizes, mem_alignment);

        Ok(handles
            .into_iter()
            .zip(strides)
            .map(|(handle, strides)| Allocation::new(handle, strides))
            .collect())
    }

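    /// Reads the described GPU regions back into host memory.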
    fn read(
        &mut self,
        descriptors: Vec<server::CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let mut command = self.command(stream_id, descriptors.iter().map(|d| &d.binding));

        Box::pin(command.read_async(descriptors))
    }

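    /// Writes host data to the described GPU regions.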
    fn write(
        &mut self,
        descriptors: Vec<(server::CopyDescriptor<'_>, Bytes)>,
        stream_id: StreamId,
    ) -> Result<(), IoError> {
        let mut command = self.command(stream_id, descriptors.iter().map(|desc| &desc.0.binding));

        let mut to_drop = Vec::with_capacity(descriptors.len());

        for (descriptor, data) in descriptors {
            command.write_to_gpu(descriptor, &data)?;
            to_drop.push(data);
        }

        // Hand the host buffers to the command so they stay alive until the
        // copies have gone through.
        command.gc(to_drop);

        Ok(())
    }

    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage {
        let mut command = self.command_no_inputs(stream_id);
        command.memory_usage()
    }

    fn memory_cleanup(&mut self, stream_id: StreamId) {
        let mut command = self.command_no_inputs(stream_id);
        command.memory_cleanup()
    }

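    /// Compiles the kernel if needed and launches it over `count` cubes.
    ///
    /// # Safety
    ///
    /// In unchecked execution mode, the caller must guarantee that the kernel
    /// never accesses its bindings out of bounds.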
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let mut kernel_id = kernel.id();
        let logger = self.streams.logger.clone();
        kernel_id.mode(mode);
        let mut command = self.command(stream_id, bindings.buffers.iter());

        let count = match count {
            CubeCount::Static(x, y, z) => (x, y, z),
            CubeCount::Dynamic(binding) => {
                // The dynamic dispatch size lives on the device as three
                // consecutive u32 values; read them back synchronously.
                let data = future::block_on(command.read_async(vec![CopyDescriptor::new(
                    binding,
                    &[3],
                    &[1],
                    4,
                )]))
                .unwrap();
                let data = bytemuck::cast_slice(&data[0]);
                assert!(
                    data.len() == 3,
                    "Dynamic cube count should contain 3 values"
                );
                (data[0], data[1], data[2])
            }
        };

        let Bindings {
            buffers,
            metadata,
            scalars,
            tensor_maps,
        } = bindings;

        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");

        // Upload metadata and scalars as temporary GPU buffers.
        let info = command
            .create_with_data(bytemuck::cast_slice(&metadata.data))
            .unwrap();
        let scalars: Vec<_> = scalars
            .values()
            .map(|s| command.create_with_data(s.data()).unwrap())
            .collect();

        // Gather resources in the order the kernel consumes them: buffers,
        // then metadata, then scalars.
        let mut resources: Vec<_> = buffers
            .into_iter()
            .map(|b| command.resource(b).expect("Resource to exist."))
            .collect();
        resources.push(
            command
                .resource(info.clone().binding())
                .expect("Resource to exist."),
        );
        resources.extend(
            scalars
                .into_iter()
                .map(|s| command.resource(s.binding()).expect("Resource to exist.")),
        );

        command.kernel(kernel_id, kernel, mode, count, &resources, logger)?;

        Ok(())
    }

    // Work is submitted to the HIP stream as it is issued, so there is
    // nothing to flush.
    fn flush(&mut self, _stream_id: StreamId) {}

    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
        let mut command = self.command_no_inputs(stream_id);
        command.sync()
    }

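    /// Starts a profiling window. The stream is synchronized first so the
    /// window only measures work submitted after this call.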
    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
        if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
            self.ctx.timestamps.error(err.into())
        }

        self.ctx.timestamps.start()
    }

    /// Ends a profiling window, synchronizing first so all submitted work is
    /// accounted for.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
            self.ctx.timestamps.error(err.into())
        }
        self.ctx.timestamps.stop(token)
    }

    fn get_resource(
        &mut self,
        binding: server::Binding,
        stream_id: StreamId,
    ) -> BindingResource<GpuResource> {
        let mut command = self.command(stream_id, [&binding].into_iter());

        BindingResource::new(
            binding.clone(),
            command.resource(binding).expect("Failed to find resource"),
        )
    }

    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        let mut command = self.command_no_inputs(stream_id);
        command.allocation_mode(mode)
    }
}

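/// Server-to-server copies are enabled for HIP and routed through a CPU
/// staging buffer (see [`HipServer::change_server_serialized`]).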
impl ServerCommunication for HipServer {
    const SERVER_COMM_ENABLED: bool = true;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(server_src, server_dst, src))
    )]
    fn copy(
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor<'_>,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<Allocation, IoError> {
        Self::change_server_serialized(server_src, server_dst, src, stream_id_src, stream_id_dst)
    }
}

impl HipServer {
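    /// Creates a new HIP server from a context and memory settings; the
    /// maximum number of concurrent streams comes from the global config.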
    pub(crate) fn new(
        ctx: HipContext,
        mem_props: MemoryDeviceProperties,
        mem_config: MemoryConfiguration,
        mem_alignment: usize,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            ctx,
            mem_alignment,
            streams: MultiStream::new(
                utilities.logger.clone(),
                HipStreamBackend::new(
                    mem_props,
                    mem_config,
                    mem_alignment,
                    utilities.logger.clone(),
                ),
                max_streams,
            ),
            utilities: Arc::new(utilities),
        }
    }

    fn command_no_inputs(&mut self, stream_id: StreamId) -> Command<'_> {
        self.command(stream_id, [].into_iter())
    }

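    /// Builds a [`Command`] on the stream for `stream_id`, first resolving
    /// any cross-stream dependencies implied by `bindings`.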
    fn command<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) -> Command<'_> {
        let streams = self.streams.resolve(stream_id, bindings);

        Command::new(&mut self.ctx, streams)
    }

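    /// Copies `src` from one server to another through a CPU staging buffer,
    /// in three steps: reserve the destination allocation and staging memory,
    /// copy device-to-host on the source stream, then, after a fence, upload
    /// host-to-device on the destination stream.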
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(server_src, server_dst, src))
    )]
    fn change_server_serialized(
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor<'_>,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<Allocation, IoError> {
        let shape = src.shape.to_vec();
        let strides = src.strides.to_vec();
        let elem_size = src.elem_size;
        let binding = src.binding.clone();
        let num_bytes = shape.iter().product::<usize>() * elem_size;

        // Reserve the destination allocation and a CPU staging buffer.
        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
        let handle = command_dst.reserve(binding.size())?;
        let mut bytes = command_dst.reserve_cpu(num_bytes, true, None);
        let copy_desc = handle.copy_descriptor(&shape, &strides, elem_size);

        core::mem::drop(command_dst);

        // Copy device-to-host on the source stream and record a fence once
        // the copy is submitted.
        let mut command_src = server_src.command(stream_id_src, [&src.binding].into_iter());
        let resource_src = command_src.resource(binding.clone())?;
        let stream_src = command_src.streams.current().sys;

        unsafe {
            write_to_cpu(
                &shape,
                &strides,
                elem_size,
                &mut bytes,
                resource_src.ptr,
                stream_src,
            )?;
        }
        let fence_src = Fence::new(stream_src);

        core::mem::drop(command_src);

        // Make the destination stream wait on the fence, then upload the
        // staging buffer to the new allocation.
        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
        let stream_dst = command_dst.streams.current().sys;

        fence_src.wait_async(stream_dst);
        command_dst.write_to_gpu(copy_desc, &bytes)?;
        command_dst.gc(bytes);

        core::mem::drop(command_dst);

        Ok(Allocation { handle, strides })
    }
}