// cubecl_hip/compute/server.rs

1use super::storage::gpu::{GpuResource, GpuStorage};
2use crate::{
3    compute::{
4        command::{Command, write_to_cpu},
5        context::HipContext,
6        fence::Fence,
7        stream::HipStreamBackend,
8    },
9    runtime::HipCompiler,
10};
11use cubecl_common::{bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId};
12use cubecl_core::{
13    MemoryConfiguration, future,
14    ir::MemoryDeviceProperties,
15    prelude::*,
16    server::{
17        Allocation, AllocationKind, Binding, Bindings, CopyDescriptor, ExecutionError, IoError,
18        LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities,
19    },
20};
21use cubecl_runtime::{
22    compiler::CubeTask,
23    config::GlobalConfig,
24    logging::ServerLogger,
25    memory_management::{MemoryAllocationMode, MemoryUsage, offset_handles},
26    server::{self, ComputeServer},
27    storage::BindingResource,
28    stream::MultiStream,
29};
30use std::sync::Arc;
31
/// HIP implementation of [`ComputeServer`].
///
/// Owns the HIP context, the set of streams that commands are resolved on, the memory
/// alignment used when laying out allocations, and the utilities shared with clients.
#[derive(Debug)]
pub struct HipServer {
    /// HIP context; mutably borrowed by every [`Command`] this server creates and also used
    /// for profiling timestamps.
    ctx: HipContext,
    /// Streams on which commands are resolved and executed.
    streams: MultiStream<HipStreamBackend>,
    /// Alignment (in bytes) used to pad pitched rows and to space allocations in `create`.
    mem_alignment: usize,
    /// Utilities shared with the clients of this server.
    utilities: Arc<ServerUtilities<Self>>,
}
39
// SAFETY: NOTE(review): the justification is not visible in this file. `HipServer` holds a HIP
// context and stream handles; this impl presumably relies on the server being used by one
// thread at a time (the methods that issue commands all take `&mut self`) — confirm before
// relying on it.
unsafe impl Send for HipServer {}
41
42impl ComputeServer for HipServer {
43    type Kernel = Box<dyn CubeTask<HipCompiler>>;
44    type Storage = GpuStorage;
45    type Info = ();
46
47    fn logger(&self) -> Arc<ServerLogger> {
48        self.streams.logger.clone()
49    }
50
51    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
52        self.utilities.clone()
53    }
54
55    fn staging(&mut self, sizes: &[usize], stream_id: StreamId) -> Result<Vec<Bytes>, IoError> {
56        let mut command = self.command_no_inputs(stream_id);
57
58        Ok(sizes
59            .iter()
60            .map(|size| command.reserve_cpu(*size, true, None))
61            .collect())
62    }
63
64    fn create(
65        &mut self,
66        descriptors: Vec<server::AllocationDescriptor<'_>>,
67        stream_id: StreamId,
68    ) -> Result<Vec<server::Allocation>, IoError> {
69        let mut total_size = 0;
70        let mut strides = Vec::new();
71        let mut sizes = Vec::new();
72
73        for descriptor in descriptors {
74            let pitch_align = match descriptor.kind {
75                AllocationKind::Contiguous => 1,
76                AllocationKind::Optimized => self.mem_alignment,
77            };
78
79            let rank = descriptor.shape.len();
80            let width = *descriptor.shape.last().unwrap_or(&1);
81            let height: usize = descriptor.shape.iter().rev().skip(1).product();
82            let height = Ord::max(height, 1);
83            let width_bytes = width * descriptor.elem_size;
84            let pitch = width_bytes.next_multiple_of(pitch_align);
85            let size = height * pitch;
86            total_size += size.next_multiple_of(self.mem_alignment);
87
88            let mut stride = vec![1; rank];
89            if rank > 1 {
90                stride[rank - 2] = pitch / descriptor.elem_size;
91            }
92            if rank > 2 {
93                for i in (0..rank - 2).rev() {
94                    stride[i] = stride[i + 1] * descriptor.shape[i + 1];
95                }
96            }
97
98            strides.push(stride);
99            sizes.push(size);
100        }
101
102        let mem_alignment = self.mem_alignment;
103        let mut command = self.command_no_inputs(stream_id);
104
105        let handle = command.reserve(total_size as u64)?;
106        let handles = offset_handles(handle, &sizes, mem_alignment);
107
108        Ok(handles
109            .into_iter()
110            .zip(strides)
111            .map(|(handle, strides)| Allocation::new(handle, strides))
112            .collect())
113    }
114
115    fn read(
116        &mut self,
117        descriptors: Vec<server::CopyDescriptor>,
118        stream_id: StreamId,
119    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
120        let mut command = self.command(stream_id, descriptors.iter().map(|d| &d.binding));
121
122        Box::pin(command.read_async(descriptors))
123    }
124
125    fn write(
126        &mut self,
127        descriptors: Vec<(server::CopyDescriptor<'_>, Bytes)>,
128        stream_id: StreamId,
129    ) -> Result<(), IoError> {
130        let mut command = self.command(stream_id, descriptors.iter().map(|desc| &desc.0.binding));
131
132        let mut to_drop = Vec::with_capacity(descriptors.len());
133
134        for (descriptor, data) in descriptors {
135            command.write_to_gpu(descriptor, &data)?;
136            to_drop.push(data);
137        }
138
139        command.gc(to_drop);
140
141        Ok(())
142    }
143
144    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage {
145        let mut command = self.command_no_inputs(stream_id);
146        command.memory_usage()
147    }
148
149    fn memory_cleanup(&mut self, stream_id: StreamId) {
150        let mut command = self.command_no_inputs(stream_id);
151        command.memory_cleanup()
152    }
153
154    unsafe fn launch(
155        &mut self,
156        kernel: Self::Kernel,
157        count: CubeCount,
158        bindings: Bindings,
159        mode: ExecutionMode,
160        stream_id: StreamId,
161    ) -> Result<(), LaunchError> {
162        let mut kernel_id = kernel.id();
163        let logger = self.streams.logger.clone();
164        kernel_id.mode(mode);
165        let mut command = self.command(stream_id, bindings.buffers.iter());
166
167        let count = match count {
168            CubeCount::Static(x, y, z) => (x, y, z),
169            // TODO: HIP doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
170            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
171            // For now, just read the dispatch settings from the buffer.
172            CubeCount::Dynamic(binding) => {
173                let data = future::block_on(command.read_async(vec![CopyDescriptor::new(
174                    binding,
175                    &[3],
176                    &[1],
177                    4,
178                )]))
179                .unwrap();
180                let data = bytemuck::cast_slice(&data[0]);
181                assert!(
182                    data.len() == 3,
183                    "Dynamic cube count should contain 3 values"
184                );
185                (data[0], data[1], data[2])
186            }
187        };
188
189        let Bindings {
190            buffers,
191            metadata,
192            scalars,
193            tensor_maps,
194        } = bindings;
195
196        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
197
198        let info = command
199            .create_with_data(bytemuck::cast_slice(&metadata.data))
200            .unwrap();
201        let scalars: Vec<_> = scalars
202            .values()
203            .map(|s| command.create_with_data(s.data()).unwrap())
204            .collect();
205
206        let mut resources: Vec<_> = buffers
207            .into_iter()
208            .map(|b| command.resource(b).expect("Resource to exist."))
209            .collect();
210        resources.push(
211            command
212                .resource(info.clone().binding())
213                .expect("Resource to exist."),
214        );
215        resources.extend(
216            scalars
217                .into_iter()
218                .map(|s| command.resource(s.binding()).expect("Resource to exist.")),
219        );
220
221        command.kernel(kernel_id, kernel, mode, count, &resources, logger)?;
222
223        Ok(())
224    }
225
226    fn flush(&mut self, _stream_id: StreamId) {}
227
228    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
229        let mut command = self.command_no_inputs(stream_id);
230        command.sync()
231    }
232
233    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
234        if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
235            self.ctx.timestamps.error(err.into())
236        }
237
238        self.ctx.timestamps.start()
239    }
240
241    fn end_profile(
242        &mut self,
243        stream_id: StreamId,
244        token: ProfilingToken,
245    ) -> Result<ProfileDuration, ProfileError> {
246        if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
247            self.ctx.timestamps.error(err.into())
248        }
249        self.ctx.timestamps.stop(token)
250    }
251
252    fn get_resource(
253        &mut self,
254        binding: server::Binding,
255        stream_id: StreamId,
256    ) -> BindingResource<GpuResource> {
257        let mut command = self.command(stream_id, [&binding].into_iter());
258
259        BindingResource::new(
260            binding.clone(),
261            command.resource(binding).expect("Failed to find resource"),
262        )
263    }
264
265    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
266        let mut command = self.command_no_inputs(stream_id);
267        command.allocation_mode(mode)
268    }
269}
270
271impl ServerCommunication for HipServer {
272    const SERVER_COMM_ENABLED: bool = true;
273
274    #[cfg_attr(
275        feature = "tracing",
276        tracing::instrument(level = "trace", skip(server_src, server_dst, src))
277    )]
278    fn copy(
279        server_src: &mut Self,
280        server_dst: &mut Self,
281        src: CopyDescriptor<'_>,
282        stream_id_src: StreamId,
283        stream_id_dst: StreamId,
284    ) -> Result<Allocation, IoError> {
285        Self::change_server_serialized(server_src, server_dst, src, stream_id_src, stream_id_dst)
286    }
287}
288
289impl HipServer {
290    /// Create a new hip server.
291    pub(crate) fn new(
292        ctx: HipContext,
293        mem_props: MemoryDeviceProperties,
294        mem_config: MemoryConfiguration,
295        mem_alignment: usize,
296        utilities: ServerUtilities<Self>,
297    ) -> Self {
298        let config = GlobalConfig::get();
299        let max_streams = config.streaming.max_streams;
300
301        Self {
302            ctx,
303            mem_alignment,
304            streams: MultiStream::new(
305                utilities.logger.clone(),
306                HipStreamBackend::new(
307                    mem_props,
308                    mem_config,
309                    mem_alignment,
310                    utilities.logger.clone(),
311                ),
312                max_streams,
313            ),
314            utilities: Arc::new(utilities),
315        }
316    }
317
318    fn command_no_inputs(&mut self, stream_id: StreamId) -> Command<'_> {
319        self.command(stream_id, [].into_iter())
320    }
321
322    fn command<'a>(
323        &mut self,
324        stream_id: StreamId,
325        bindings: impl Iterator<Item = &'a Binding>,
326    ) -> Command<'_> {
327        let streams = self.streams.resolve(stream_id, bindings);
328
329        Command::new(&mut self.ctx, streams)
330    }
331
332    #[cfg_attr(
333        feature = "tracing",
334        tracing::instrument(level = "trace", skip(server_src, server_dst, src))
335    )]
336    fn change_server_serialized(
337        server_src: &mut Self,
338        server_dst: &mut Self,
339        src: CopyDescriptor<'_>,
340        stream_id_src: StreamId,
341        stream_id_dst: StreamId,
342    ) -> Result<Allocation, IoError> {
343        let shape = src.shape.to_vec();
344        let strides = src.strides.to_vec();
345        let elem_size = src.elem_size;
346        let binding = src.binding.clone();
347        let num_bytes = shape.iter().product::<usize>() * elem_size;
348
349        // We start by creating a command on the destination server.
350        //
351        // Here we allocate the necessary bytes using pinned memory managed by the destination
352        // server along a new GPU handle. This way, the bytes could be reused later by that server,
353        // and the lifetime of that handle is aligned with the execution order of the destination server,
354        // removing the need to keep the bytes handle alive using synchronization, which would be the
355        // case if we allocated the bytes using the source server.
356        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
357        let handle = command_dst.reserve(binding.size())?;
358        let mut bytes = command_dst.reserve_cpu(num_bytes, true, None);
359        let copy_desc = handle.copy_descriptor(&shape, &strides, elem_size);
360
361        // We need to free the command before creating another one.
362        core::mem::drop(command_dst);
363
364        // We create a command on the source server to retrieve the correct resource from the
365        // source memory pools. We also make sure the current stream is aligned with the stream of
366        // the binding, where the data was first allocated.
367        //
368        // We use the source stream to copy the data from the source server into the allocated
369        // bytes. This ensures that the source binding follows the correct execution order, meaning
370        // that we don't have to keep the source handle alive using synchronization, which would be
371        // the case if we performed the copy on the destination server.
372        let mut command_src = server_src.command(stream_id_src, [&src.binding].into_iter());
373        let resource_src = command_src.resource(binding.clone())?;
374        let stream_src = command_src.streams.current().sys;
375
376        unsafe {
377            write_to_cpu(
378                &shape,
379                &strides,
380                elem_size,
381                &mut bytes,
382                resource_src.ptr,
383                stream_src,
384            )?;
385        }
386        let fence_src = Fence::new(stream_src);
387
388        // We need to free the command before creating another one.
389        core::mem::drop(command_src);
390
391        // Finally, we recreate a new command on the destination server to write the data stored in
392        // pinned memory into the destination server. Here we need to wait for the initial copy
393        // made by the source server using an event. The synchronization is done lazily on the
394        // destination stream, which is very efficient.
395        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
396        let stream_dst = command_dst.streams.current().sys;
397
398        fence_src.wait_async(stream_dst);
399        command_dst.write_to_gpu(copy_desc, &bytes)?;
400        command_dst.gc(bytes);
401
402        // We drop the last command.
403        core::mem::drop(command_dst);
404
405        Ok(Allocation { handle, strides })
406    }
407}