cubecl_hip/compute/server.rs

1use super::storage::gpu::GpuResource;
2use super::storage::gpu::GpuStorage;
3use crate::compute::command::Command;
4use crate::compute::command::write_to_cpu;
5use crate::compute::context::HipContext;
6use crate::compute::fence::Fence;
7use crate::compute::stream::HipStreamBackend;
8use crate::runtime::HipCompiler;
9use cubecl_common::bytes::Bytes;
10use cubecl_common::future::DynFut;
11use cubecl_common::profile::ProfileDuration;
12use cubecl_common::stream_id::StreamId;
13use cubecl_core::compute::CubeTask;
14use cubecl_core::server::ServerCommunication;
15use cubecl_core::server::ServerUtilities;
16use cubecl_core::server::{
17    Allocation, AllocationKind, CopyDescriptor, IoError, ProfileError, ProfilingToken,
18};
19use cubecl_core::server::{Binding, Bindings};
20use cubecl_core::{MemoryConfiguration, future, prelude::*};
21use cubecl_runtime::config::GlobalConfig;
22use cubecl_runtime::logging::ServerLogger;
23use cubecl_runtime::memory_management::{MemoryAllocationMode, MemoryUsage};
24use cubecl_runtime::memory_management::{MemoryDeviceProperties, offset_handles};
25use cubecl_runtime::server::{self, ComputeServer};
26use cubecl_runtime::storage::BindingResource;
27use cubecl_runtime::stream::MultiStream;
28use std::sync::Arc;
29
/// Compute server for the HIP backend.
///
/// Work is routed through `Command`s created on a stream resolved from a `StreamId`
/// (see `HipServer::command`).
#[derive(Debug)]
pub struct HipServer {
    /// HIP context; also holds the profiling timestamps used by `start_profile`/`end_profile`.
    ctx: HipContext,
    /// Streams resolved per `StreamId`; also holds the logger returned by `logger()`.
    streams: MultiStream<HipStreamBackend>,
    /// Alignment in bytes: every allocation size in `create` is rounded up to it, and it is the
    /// row-pitch alignment for `AllocationKind::Optimized` allocations.
    mem_alignment: usize,
    /// Shared server utilities, handed out by `utilities()`.
    utilities: Arc<ServerUtilities<Self>>,
}

// SAFETY: NOTE(review): `Send` is asserted manually, presumably because some field (e.g. raw HIP
// handles inside `HipContext` or the stream backend) is not automatically `Send`. Confirm that
// every HIP resource owned by this server may be moved to, and used from, another thread.
unsafe impl Send for HipServer {}
39
40impl ComputeServer for HipServer {
41    type Kernel = Box<dyn CubeTask<HipCompiler>>;
42    type Storage = GpuStorage;
43    type Info = ();
44
45    fn logger(&self) -> Arc<ServerLogger> {
46        self.streams.logger.clone()
47    }
48
49    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
50        self.utilities.clone()
51    }
52
53    fn create(
54        &mut self,
55        descriptors: Vec<server::AllocationDescriptor<'_>>,
56        stream_id: StreamId,
57    ) -> Result<Vec<server::Allocation>, IoError> {
58        let mut total_size = 0;
59        let mut strides = Vec::new();
60        let mut sizes = Vec::new();
61
62        for descriptor in descriptors {
63            let pitch_align = match descriptor.kind {
64                AllocationKind::Contiguous => 1,
65                AllocationKind::Optimized => self.mem_alignment,
66            };
67
68            let rank = descriptor.shape.len();
69            let width = *descriptor.shape.last().unwrap_or(&1);
70            let height: usize = descriptor.shape.iter().rev().skip(1).product();
71            let height = height.max(1);
72            let width_bytes = width * descriptor.elem_size;
73            let pitch = width_bytes.next_multiple_of(pitch_align);
74            let size = height * pitch;
75            total_size += size.next_multiple_of(self.mem_alignment);
76
77            let mut stride = vec![1; rank];
78            if rank > 1 {
79                stride[rank - 2] = pitch / descriptor.elem_size;
80            }
81            if rank > 2 {
82                for i in (0..rank - 2).rev() {
83                    stride[i] = stride[i + 1] * descriptor.shape[i + 1];
84                }
85            }
86
87            strides.push(stride);
88            sizes.push(size);
89        }
90
91        let mem_alignment = self.mem_alignment;
92        let mut command = self.command_no_inputs(stream_id);
93
94        let handle = command.reserve(total_size as u64)?;
95        let handles = offset_handles(handle, &sizes, mem_alignment);
96
97        Ok(handles
98            .into_iter()
99            .zip(strides)
100            .map(|(handle, strides)| Allocation::new(handle, strides))
101            .collect())
102    }
103
104    fn read(
105        &mut self,
106        descriptors: Vec<server::CopyDescriptor>,
107        stream_id: StreamId,
108    ) -> DynFut<Result<Vec<Bytes>, IoError>> {
109        let mut command = self.command(stream_id, descriptors.iter().map(|d| &d.binding));
110
111        Box::pin(command.read_async(descriptors))
112    }
113
114    fn write(
115        &mut self,
116        descriptors: Vec<(server::CopyDescriptor<'_>, &[u8])>,
117        stream_id: StreamId,
118    ) -> Result<(), IoError> {
119        let mut command = self.command(stream_id, descriptors.iter().map(|desc| &desc.0.binding));
120
121        for (descriptor, data) in descriptors {
122            command.write_to_gpu(descriptor, data)?;
123        }
124
125        Ok(())
126    }
127
128    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage {
129        let mut command = self.command_no_inputs(stream_id);
130        command.memory_usage()
131    }
132
133    fn memory_cleanup(&mut self, stream_id: StreamId) {
134        let mut command = self.command_no_inputs(stream_id);
135        command.memory_cleanup()
136    }
137
138    unsafe fn execute(
139        &mut self,
140        kernel: Self::Kernel,
141        count: CubeCount,
142        bindings: Bindings,
143        mode: ExecutionMode,
144        stream_id: StreamId,
145    ) {
146        let mut kernel_id = kernel.id();
147        let logger = self.streams.logger.clone();
148        kernel_id.mode(mode);
149        let mut command = self.command(stream_id, bindings.buffers.iter());
150
151        let count = match count {
152            CubeCount::Static(x, y, z) => (x, y, z),
153            // TODO: HIP doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
154            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
155            // For now, just read the dispatch settings from the buffer.
156            CubeCount::Dynamic(binding) => {
157                let data = future::block_on(command.read_async(vec![CopyDescriptor::new(
158                    binding,
159                    &[3],
160                    &[1],
161                    4,
162                )]))
163                .unwrap();
164                let data = bytemuck::cast_slice(&data[0]);
165                assert!(
166                    data.len() == 3,
167                    "Dynamic cube count should contain 3 values"
168                );
169                (data[0], data[1], data[2])
170            }
171        };
172
173        let Bindings {
174            buffers,
175            metadata,
176            scalars,
177            tensor_maps,
178        } = bindings;
179
180        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");
181
182        let info = command
183            .create_with_data(bytemuck::cast_slice(&metadata.data))
184            .unwrap();
185        let scalars: Vec<_> = scalars
186            .values()
187            .map(|s| command.create_with_data(s.data()).unwrap())
188            .collect();
189
190        let mut resources: Vec<_> = buffers
191            .into_iter()
192            .map(|b| command.resource(b).expect("Resource to exist."))
193            .collect();
194        resources.push(
195            command
196                .resource(info.clone().binding())
197                .expect("Resource to exist."),
198        );
199        resources.extend(
200            scalars
201                .into_iter()
202                .map(|s| command.resource(s.binding()).expect("Resource to exist.")),
203        );
204
205        command.kernel(kernel_id, kernel, mode, count, &resources, logger)
206    }
207
208    fn flush(&mut self, _stream_id: StreamId) {}
209
210    fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
211        let mut command = self.command_no_inputs(stream_id);
212        command.sync()
213    }
214
215    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
216        cubecl_common::future::block_on(self.sync(stream_id));
217        self.ctx.timestamps.start()
218    }
219
220    fn end_profile(
221        &mut self,
222        stream_id: StreamId,
223        token: ProfilingToken,
224    ) -> Result<ProfileDuration, ProfileError> {
225        cubecl_common::future::block_on(self.sync(stream_id));
226        self.ctx.timestamps.stop(token)
227    }
228
229    fn get_resource(
230        &mut self,
231        binding: server::Binding,
232        stream_id: StreamId,
233    ) -> BindingResource<GpuResource> {
234        let mut command = self.command(stream_id, [&binding].into_iter());
235
236        BindingResource::new(
237            binding.clone(),
238            command.resource(binding).expect("Failed to find resource"),
239        )
240    }
241
242    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
243        let mut command = self.command_no_inputs(stream_id);
244        command.allocation_mode(mode)
245    }
246}
247
248impl ServerCommunication for HipServer {
249    const SERVER_COMM_ENABLED: bool = true;
250
251    fn copy(
252        server_src: &mut Self,
253        server_dst: &mut Self,
254        src: CopyDescriptor<'_>,
255        stream_id_src: StreamId,
256        stream_id_dst: StreamId,
257    ) -> Result<Allocation, IoError> {
258        Self::change_server_serialized(server_src, server_dst, src, stream_id_src, stream_id_dst)
259    }
260}
261
impl HipServer {
    /// Create a new hip server.
    ///
    /// - `ctx`: HIP context stored on the server (also used for profiling timestamps).
    /// - `mem_props`, `mem_config`: forwarded to the `HipStreamBackend` memory setup.
    /// - `mem_alignment`: forwarded to the stream backend and kept for allocation padding.
    /// - `utilities`: its logger is shared with the stream set and backend before the utilities
    ///   are wrapped in an `Arc`.
    ///
    /// The maximum number of streams comes from the global configuration
    /// (`streaming.max_streams`).
    pub(crate) fn new(
        ctx: HipContext,
        mem_props: MemoryDeviceProperties,
        mem_config: MemoryConfiguration,
        mem_alignment: usize,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let config = GlobalConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            ctx,
            mem_alignment,
            streams: MultiStream::new(
                utilities.logger.clone(),
                HipStreamBackend::new(
                    mem_props,
                    mem_config,
                    mem_alignment,
                    utilities.logger.clone(),
                ),
                max_streams,
            ),
            utilities: Arc::new(utilities),
        }
    }

    /// Creates a command on the stream for `stream_id` without declaring any input bindings.
    fn command_no_inputs(&mut self, stream_id: StreamId) -> Command<'_> {
        self.command(stream_id, [].into_iter())
    }

    /// Creates a command on the stream resolved for `stream_id`.
    ///
    /// The bindings the command will use are passed to `MultiStream::resolve` so the current
    /// stream can be aligned with the streams where those bindings were allocated. The returned
    /// command mutably borrows `self`, so it must be dropped before another command is created.
    fn command<'a>(
        &mut self,
        stream_id: StreamId,
        bindings: impl Iterator<Item = &'a Binding>,
    ) -> Command<'_> {
        let streams = self.streams.resolve(stream_id, bindings);

        Command::new(&mut self.ctx, streams)
    }

    /// Copies the data described by `src` from `server_src` to a new allocation on `server_dst`.
    ///
    /// Steps:
    /// 1. Reserve the destination GPU handle and a host staging buffer on `server_dst`.
    /// 2. Enqueue a device-to-host copy on the source stream and record a fence after it.
    /// 3. Make the destination stream wait on that fence, then enqueue the host-to-device write.
    ///
    /// The returned allocation reuses the source strides.
    ///
    /// # Errors
    /// Propagates any [`IoError`] from reserving, resolving the source resource, or copying.
    fn change_server_serialized(
        server_src: &mut Self,
        server_dst: &mut Self,
        src: CopyDescriptor<'_>,
        stream_id_src: StreamId,
        stream_id_dst: StreamId,
    ) -> Result<Allocation, IoError> {
        let shape = src.shape.to_vec();
        let strides = src.strides.to_vec();
        let elem_size = src.elem_size;
        let binding = src.binding.clone();
        // Staging size assumes the data is packed densely on the host (element count × size).
        let num_bytes = shape.iter().product::<usize>() * elem_size;

        // We start by creating a command on the destination server.
        //
        // Here we allocate the necessary bytes using pinned memory managed by the destination
        // server along a new GPU handle. This way, the bytes could be reused later by that server,
        // and the lifetime of that handle is aligned with the execution order of the destination server,
        // removing the need to keep the bytes handle alive using synchronization, which would be the
        // case if we allocated the bytes using the source server.
        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
        let handle = command_dst.reserve(binding.size())?;
        let mut bytes = command_dst.reserve_cpu(num_bytes, true, None);
        let copy_desc = handle.copy_descriptor(&shape, &strides, elem_size);

        // We need to free the command before creating another one.
        core::mem::drop(command_dst);

        // We create a command on the source server to retrieve the correct resource from the
        // source memory pools. We also make sure the current stream is aligned with the stream of
        // the binding, where the data was first allocated.
        //
        // We use the source stream to copy the data from the source server into the allocated
        // bytes. This ensures that the source binding follows the correct execution order, meaning
        // that we don't have to keep the source handle alive using synchronization, which would be
        // the case if we performed the copy on the destination server.
        let mut command_src = server_src.command(stream_id_src, [&src.binding].into_iter());
        let resource_src = command_src.resource(binding.clone())?;
        let stream_src = command_src.streams.current().sys;

        // SAFETY: NOTE(review): the source pointer comes from a resource resolved on the source
        // stream and the copy is enqueued on that same stream; `bytes` must remain valid until the
        // copy completes, presumably guaranteed by the destination-managed staging buffer
        // described above. TODO confirm against `write_to_cpu`'s contract.
        unsafe {
            write_to_cpu(
                &shape,
                &strides,
                elem_size,
                &mut bytes,
                resource_src.ptr,
                stream_src,
            )?;
        }
        // Recorded after the device-to-host copy so the destination stream can wait for it.
        let fence_src = Fence::new(stream_src);

        // We need to free the command before creating another one.
        core::mem::drop(command_src);

        // Finally, we recreate a new command on the destination server to write the data stored in
        // pinned memory into the destination server. Here we need to wait for the initial copy
        // made by the source server using an event. The synchronization is done lazily on the
        // destination stream, which is very efficient.
        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
        let stream_dst = command_dst.streams.current().sys;

        fence_src.wait_async(stream_dst);
        command_dst.write_to_gpu(copy_desc, &bytes)?;

        // We drop the last command.
        core::mem::drop(command_dst);

        Ok(Allocation { handle, strides })
    }
}
376
/// Computes dense row-major strides (in elements) for `shape`.
///
/// The last dimension has stride 1 and each earlier stride is the product of all later
/// dimensions. An empty shape yields an empty stride vector; the previous `0..rank - 1` range
/// underflowed `usize` in that case.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Walk from the innermost dimension outwards; `1..len` is empty for rank 0 and rank 1.
    for i in (1..shape.len()).rev() {
        strides[i - 1] = strides[i] * shape[i];
    }
    strides
}
385
386#[derive(Debug)]
387pub(crate) enum LaunchError {
388    OutOfMemory,
389    Unknown(String),
390}
391
392impl From<LaunchError> for ProfileError {
393    fn from(val: LaunchError) -> Self {
394        match val {
395            LaunchError::OutOfMemory => ProfileError::Unknown("Out of memory".into()),
396            LaunchError::Unknown(msg) => ProfileError::Unknown(msg),
397        }
398    }
399}
400
/// Returns `true` when `strides` describe a row-major layout of `shape` that is dense except,
/// possibly, for the row stride.
///
/// Requirements checked:
/// - `strides` has one entry per dimension of `shape` (otherwise `false`);
/// - the innermost stride is 1;
/// - strides are non-increasing from outer to inner dimensions;
/// - every stride before the second-to-last one equals the next dimension times the next stride.
///
/// The second-to-last (row) stride is only constrained by the ordering. This admits padded
/// (pitched) rows such as those produced for optimized allocations. NOTE: it is not checked to be
/// at least the row width.
///
/// An empty shape with empty strides is considered valid. The previous implementation panicked
/// on rank 0 because `strides[rank - 1]` was evaluated before the rank check.
pub fn valid_strides(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if strides.len() != rank {
        return false;
    }
    let Some(&innermost) = strides.last() else {
        // Rank 0: nothing to validate.
        return true;
    };
    if innermost != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }

    // Non-increasing check, equivalent to "sorted descending" without allocating.
    if !strides.windows(2).all(|pair| pair[0] >= pair[1]) {
        return false;
    }

    // Outer dimensions must be dense over the dimension they contain; the row stride
    // (index rank - 2) is deliberately left free.
    (0..rank - 2).all(|i| strides[i] == shape[i + 1] * strides[i + 1])
}