//! HIP compute server (`cubecl_hip/compute/server.rs`).
1use super::storage::gpu::{GpuResource, GpuStorage};
2use crate::{
3    compute::{command::Command, context::HipContext, fence::Fence, stream::HipStreamBackend},
4    runtime::HipCompiler,
5};
6use cubecl_common::{bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId};
7use cubecl_core::{
8    MemoryConfiguration,
9    backtrace::BackTrace,
10    future,
11    ir::MemoryDeviceProperties,
12    prelude::*,
13    server::{
14        Binding, CopyDescriptor, KernelArguments, ProfileError, ProfilingToken,
15        ServerCommunication, ServerError, ServerUtilities, StreamErrorMode,
16    },
17};
18use cubecl_runtime::{
19    allocator::PitchedMemoryLayoutPolicy,
20    compiler::CubeTask,
21    config::{CubeClRuntimeConfig, RuntimeConfig},
22    logging::ServerLogger,
23    memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
24    server::ComputeServer,
25    storage::{ComputeStorage, ManagedResource},
26    stream::MultiStream,
27};
28use std::sync::Arc;
29
/// Compute server for the HIP (AMD ROCm) runtime.
///
/// Owns the HIP device context and a multi-stream scheduler; all work is
/// funneled through short-lived [`Command`] objects created by
/// `Self::command` / `Self::command_no_inputs`.
#[derive(Debug)]
pub struct HipServer {
    // Device context; also holds the profiling timestamp registry
    // (`ctx.timestamps`) used by `start_profile` / `end_profile`.
    ctx: HipContext,
    // Per-stream state: memory management, drop queues, and the per-stream
    // error queues drained by `flush_errors`.
    streams: MultiStream<HipStreamBackend>,
    // Shared utilities (logger, ...) handed out to clients via `utilities()`.
    utilities: Arc<ServerUtilities<Self>>,
}
36
// SAFETY: `HipServer` is only accessed from one thread at a time via the `DeviceHandle`
// (which serializes access through either a mutex or a dedicated runner thread depending
// on the selected channel feature). The HIP context and streams it manages are never
// shared across threads without synchronization.
//
// NOTE(review): soundness rests on `DeviceHandle` remaining the *sole* access
// path to the server — re-verify this impl if new entry points are added.
unsafe impl Send for HipServer {}
42
impl ComputeServer for HipServer {
    type Kernel = Box<dyn CubeTask<HipCompiler>>;
    type Storage = GpuStorage;
    type MemoryLayoutPolicy = PitchedMemoryLayoutPolicy;
    type Info = ();

    // Note on `StreamErrorMode` throughout this impl (see `Self::command`):
    // `flush: true` drains the stream's accumulated error queue first, and
    // `ignore: false` turns any drained error into a `ServerUnhealthy` failure.
    // Methods using `{ ignore: true, flush: false }` are best-effort and do not
    // surface previously recorded stream errors.

    /// Returns the logger shared with the multi-stream scheduler.
    fn logger(&self) -> Arc<ServerLogger> {
        self.streams.logger.clone()
    }

    /// Returns the shared server utilities.
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Reserves one CPU-side staging buffer per entry in `sizes`.
    ///
    /// Best-effort with respect to prior stream errors (`ignore: true`).
    // The `true` flag to `reserve_cpu` presumably requests pinned/host-visible
    // memory — TODO confirm against `Command::reserve_cpu`.
    fn staging(&mut self, sizes: &[usize], stream_id: StreamId) -> Result<Vec<Bytes>, ServerError> {
        let mut command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        )?;

        Ok(sizes
            .iter()
            .map(|size| command.reserve_cpu(*size, true, None))
            .collect())
    }

    /// Reserves `size` bytes of GPU memory and binds the reservation to the
    /// provided managed handle.
    ///
    /// # Panics
    ///
    /// Panics if the reservation fails (`unwrap`) or if the command cannot be
    /// created — with `{ ignore: true, flush: false }` the latter is presumed
    /// impossible, hence `unreachable!` (TODO confirm `resolve` cannot fail
    /// in this mode).
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId) {
        let mut command = match self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            Err(err) => unreachable!("{err:?}"),
        };

        let reserved = command.reserve(size).unwrap();
        command.bind(reserved, memory);
    }

    /// Asynchronously reads the regions described by `descriptors` back to
    /// host memory.
    ///
    /// Fails fast (`ignore: false`, `flush: true`): any previously recorded
    /// stream error is returned instead of starting the read.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        match self.command(
            stream_id,
            descriptors.iter().map(|d| &d.handle),
            StreamErrorMode {
                ignore: false,
                flush: true,
            },
        ) {
            Ok(mut command) => Box::pin(command.read_async(descriptors)),
            // Surface the command-creation error through the returned future.
            Err(err) => Box::pin(async move { Err(err) }),
        }
    }

    /// Copies each `(descriptor, bytes)` pair from host to GPU.
    ///
    /// Best-effort: on the first failed copy the error is recorded on the
    /// command (to be surfaced by a later fail-fast operation) and the
    /// remaining writes are skipped.
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId) {
        let mut command = match self.command(
            stream_id,
            descriptors.iter().map(|desc| &desc.0.handle),
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            Err(err) => unreachable!("{err:?}"),
        };

        for (descriptor, data) in descriptors {
            if let Err(err) = command.write_to_gpu(descriptor, data) {
                command.error(err.into());
                return;
            }
        }
    }

    /// Launches `kernel` with the given dispatch `count` and `bindings`.
    ///
    /// Delegates to [`Self::launch_checked`]; on failure the error is pushed
    /// onto the current stream's error queue rather than returned, so it
    /// surfaces on the next fail-fast operation (e.g. `read`/`sync`).
    ///
    /// # Safety
    ///
    /// See [`ComputeServer::launch`] for the caller contract implied by
    /// `mode` (checked vs. unchecked execution).
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        if let Err(err) = self.launch_checked(kernel, count, bindings, mode, stream_id) {
            let mut stream = match self.streams.resolve(stream_id, [].into_iter(), false) {
                Ok(stream) => stream,
                Err(err) => unreachable!("{err:?}"),
            };
            stream.current().errors.push(err);
        }
    }

    /// Flushes pending deferred work on the stream.
    ///
    /// Order matters: the drop queue is flushed first (guarded by a freshly
    /// created [`Fence`] on the stream), then the GPU storage is flushed.
    /// Fails fast if the stream already holds errors.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        let mut command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: false,
                flush: true,
            },
        )?;

        let current = command.streams.current();
        current.drop_queue.flush(|| Fence::new(current.sys));
        current.memory_management_gpu.storage().flush();

        Ok(())
    }

    /// Returns a future that resolves once all work queued on the stream has
    /// completed, or immediately with an error if the stream is unhealthy.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>> {
        let command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: false,
                flush: true,
            },
        );

        match command {
            Ok(mut command) => command.sync(),
            Err(err) => Box::pin(async { Err(err) }),
        }
    }

    /// Starts a profiling window after fully draining the stream, so prior
    /// work is not attributed to the new window.
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError> {
        cubecl_common::future::block_on(self.sync(stream_id))?;
        Ok(self.ctx.timestamps.start())
    }

    /// Ends the profiling window identified by `token`.
    ///
    /// If draining the stream fails, the failure is recorded on the timestamp
    /// registry first so the resulting profile is flagged as unreliable.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
            self.ctx
                .timestamps
                .error(ProfileError::Server(Box::new(err)));
        }
        self.ctx.timestamps.stop(token)
    }

    /// Resolves `binding` to its underlying GPU resource, paired with a clone
    /// of the memory handle that keeps the allocation alive.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<GpuResource>, ServerError> {
        let mut command = self.command(
            stream_id,
            [&binding].into_iter(),
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        )?;
        // Clone the handle before `binding` is consumed by `resource`.
        let memory = binding.memory.clone();
        let resource = command.resource(binding)?;

        Ok(ManagedResource::new(memory, resource))
    }

    /// Reports current memory usage for the stream's memory management.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError> {
        let mut command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: false,
                flush: false,
            },
        )?;
        Ok(command.memory_usage())
    }

    /// Best-effort memory cleanup; a server already in error is left alone.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        let mut command = match self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            // Server is in error.
            Err(_) => return,
        };
        command.memory_cleanup()
    }

    /// Switches the memory allocation mode (see [`MemoryAllocationMode`]).
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        let mut command = match self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            Err(err) => unreachable!("{err:?}"),
        };
        command.allocation_mode(mode)
    }
}
252
// Direct server-to-server communication is disabled for the HIP backend;
// transfers between servers presumably fall back to the generic (host-side)
// path — TODO confirm against `ServerCommunication`'s default behavior.
impl ServerCommunication for HipServer {
    const SERVER_COMM_ENABLED: bool = false;
}
256
impl HipServer {
    /// Create a new hip server.
    ///
    /// * `ctx` - HIP device context (also hosts the profiling timestamps).
    /// * `mem_props` / `mem_config` / `mem_alignment` - memory-management
    ///   configuration forwarded to the stream backend.
    /// * `is_integrated` - whether the GPU shares memory with the host
    ///   (integrated APU) — forwarded to the backend; TODO confirm its exact
    ///   effect there.
    /// * `utilities` - shared utilities; the logger is cloned into both the
    ///   multi-stream scheduler and the backend.
    pub(crate) fn new(
        ctx: HipContext,
        mem_props: MemoryDeviceProperties,
        mem_config: MemoryConfiguration,
        mem_alignment: usize,
        is_integrated: bool,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        // The stream-count cap comes from the global runtime configuration.
        let config = CubeClRuntimeConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            ctx,
            streams: MultiStream::new(
                utilities.logger.clone(),
                HipStreamBackend::new(
                    mem_props,
                    mem_config,
                    mem_alignment,
                    is_integrated,
                    utilities.logger.clone(),
                ),
                max_streams,
            ),
            utilities: Arc::new(utilities),
        }
    }

    /// Shorthand for [`Self::command`] with no input bindings.
    fn command_no_inputs(
        &mut self,
        stream_id: StreamId,
        mode: StreamErrorMode,
    ) -> Result<Command<'_>, ServerError> {
        self.command(stream_id, [].into_iter(), mode)
    }

    /// Creates a [`Command`] on the given stream, honoring `mode`:
    ///
    /// * `mode.flush` - drain the stream's accumulated error queue first;
    /// * `mode.ignore` - when `false`, any drained error aborts with
    ///   `ServerUnhealthy`, and stream resolution is performed strictly.
    fn command<'a>(
        &mut self,
        stream_id: StreamId,
        handles: impl Iterator<Item = &'a Binding>,
        mode: StreamErrorMode,
    ) -> Result<Command<'_>, ServerError> {
        if mode.flush {
            let errors = self.flush_errors(stream_id);

            if !mode.ignore && !errors.is_empty() {
                return Err(ServerError::ServerUnhealthy {
                    errors,
                    backtrace: BackTrace::capture(),
                });
            }
        }
        // The third argument makes resolution strict when errors are not ignored.
        let streams = self.streams.resolve(stream_id, handles, !mode.ignore)?;

        Ok(Command::new(&mut self.ctx, streams))
    }

    /// Drains and returns the stream's accumulated errors.
    ///
    /// If any errors were pending, the profiling registry is poisoned (so any
    /// in-flight profile is marked wrong) and GPU memory management is asked
    /// to clean up before the errors are handed back.
    fn flush_errors(&mut self, stream_id: StreamId) -> Vec<ServerError> {
        let mut stream = match self.streams.resolve(stream_id, [].into_iter(), false) {
            Ok(stream) => stream,
            Err(_) => return Vec::new(),
        };
        let errors = core::mem::take(&mut stream.current().errors);

        // It is very important to tag current profiles as being wrong.
        if !errors.is_empty() {
            self.ctx.timestamps.error(ProfileError::Unknown {
                reason: alloc::format!("{errors:?}"),
                backtrace: BackTrace::capture(),
            });
            stream.current().memory_management_gpu.cleanup(false);
        }

        core::mem::drop(stream);
        errors
    }

    /// Fallible core of [`ComputeServer::launch`]: resolves the cube count,
    /// materializes all kernel resources, and enqueues the kernel.
    fn launch_checked(
        &mut self,
        kernel: Box<dyn CubeTask<HipCompiler>>,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), ServerError> {
        // The execution mode is part of the kernel's cache identity.
        let mut kernel_id = kernel.id();
        let logger = self.streams.logger.clone();
        kernel_id.mode(mode);
        let mut command = self.command(
            stream_id,
            bindings.buffers.iter(),
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        )?;

        let count = match count {
            CubeCount::Static(x, y, z) => (x, y, z),
            // TODO: HIP doesn't have an exact equivalent of dynamic dispatch. Instead, kernels are free to launch other kernels.
            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
            // For now, just read the dispatch settings from the buffer.
            // NOTE: this blocks the caller on a GPU -> host read of the three
            // 4-byte dispatch values.
            CubeCount::Dynamic(binding) => {
                let data = future::block_on(command.read_async(vec![CopyDescriptor::new(
                    binding,
                    [3].into(),
                    [1].into(),
                    4,
                )]))
                .unwrap();
                let data = bytemuck::cast_slice(&data[0]);
                assert!(
                    data.len() == 3,
                    "Dynamic cube count should contain 3 values"
                );
                (data[0], data[1], data[2])
            }
        };

        let KernelArguments {
            buffers,
            info,
            tensor_maps,
        } = bindings;

        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");

        // Scalar/info data is uploaded as one extra buffer, appended last.
        let info = command
            .create_with_data(bytemuck::cast_slice(&info.data))
            .unwrap();

        let mut resources: Vec<_> = buffers
            .into_iter()
            .map(|b| command.resource(b).expect("Resource to exist."))
            .collect();

        resources.push(
            command
                .resource(info.binding())
                .expect("Resource to exist."),
        );

        command.kernel(kernel_id, kernel, mode, count, &resources, logger)?;

        Ok(())
    }

    /// Inherent accessor mirroring [`ComputeServer::utilities`], usable
    /// without importing the trait.
    pub(crate) fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }
}
409}