1use super::storage::gpu::{GpuResource, GpuStorage};
2use crate::{
3 compute::{command::Command, context::HipContext, fence::Fence, stream::HipStreamBackend},
4 runtime::HipCompiler,
5};
6use cubecl_common::{bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId};
7use cubecl_core::{
8 MemoryConfiguration,
9 backtrace::BackTrace,
10 future,
11 ir::MemoryDeviceProperties,
12 prelude::*,
13 server::{
14 Binding, CopyDescriptor, KernelArguments, ProfileError, ProfilingToken,
15 ServerCommunication, ServerError, ServerUtilities, StreamErrorMode,
16 },
17};
18use cubecl_runtime::{
19 allocator::PitchedMemoryLayoutPolicy,
20 compiler::CubeTask,
21 config::{CubeClRuntimeConfig, RuntimeConfig},
22 logging::ServerLogger,
23 memory_management::{ManagedMemoryHandle, MemoryAllocationMode, MemoryUsage},
24 server::ComputeServer,
25 storage::{ComputeStorage, ManagedResource},
26 stream::MultiStream,
27};
28use std::sync::Arc;
29
#[derive(Debug)]
/// Compute server for the HIP (AMD ROCm) runtime.
///
/// Owns the HIP context (including profiling timestamps), the multi-stream
/// scheduler used to order work across logical streams, and the shared
/// server utilities handed out to clients.
pub struct HipServer {
    // HIP driver context; also owns the profiling timestamp registry
    // (see `start_profile` / `end_profile`).
    ctx: HipContext,
    // Per-stream scheduling/ordering state plus the shared logger.
    streams: MultiStream<HipStreamBackend>,
    // Shared utilities returned by `ComputeServer::utilities`.
    utilities: Arc<ServerUtilities<Self>>,
}
36
37unsafe impl Send for HipServer {}
42
impl ComputeServer for HipServer {
    type Kernel = Box<dyn CubeTask<HipCompiler>>;
    type Storage = GpuStorage;
    type MemoryLayoutPolicy = PitchedMemoryLayoutPolicy;
    type Info = ();

    /// Returns the logger shared by all streams of this server.
    fn logger(&self) -> Arc<ServerLogger> {
        self.streams.logger.clone()
    }

    /// Returns the shared server utilities (cheap `Arc` clone).
    fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }

    /// Allocates one CPU-side staging buffer per requested size.
    ///
    /// Uses `ignore: true` / `flush: false` so previously accumulated stream
    /// errors neither fail this call nor get drained here.
    fn staging(&mut self, sizes: &[usize], stream_id: StreamId) -> Result<Vec<Bytes>, ServerError> {
        let mut command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        )?;

        Ok(sizes
            .iter()
            // NOTE(review): the `true`/`None` arguments are opaque here —
            // presumably pinned-allocation flag and alignment; confirm against
            // `Command::reserve_cpu`.
            .map(|size| command.reserve_cpu(*size, true, None))
            .collect())
    }

    /// Reserves `size` bytes of GPU memory and binds the reservation to the
    /// given managed handle.
    fn initialize_memory(&mut self, memory: ManagedMemoryHandle, size: u64, stream_id: StreamId) {
        let mut command = match self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            // With `ignore: true` / `flush: false` the command path has no
            // error branch to take, so failure here would be a server bug.
            Err(err) => unreachable!("{err:?}"),
        };

        let reserved = command.reserve(size).unwrap();
        command.bind(reserved, memory);
    }

    /// Reads the given GPU regions back to host memory asynchronously.
    ///
    /// Uses `flush: true` / `ignore: false`: accumulated stream errors are
    /// drained first and turn the read into an immediate error future.
    fn read(
        &mut self,
        descriptors: Vec<CopyDescriptor>,
        stream_id: StreamId,
    ) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        match self.command(
            stream_id,
            // Register every source handle so cross-stream dependencies are
            // resolved before the copy is enqueued.
            descriptors.iter().map(|d| &d.handle),
            StreamErrorMode {
                ignore: false,
                flush: true,
            },
        ) {
            Ok(mut command) => Box::pin(command.read_async(descriptors)),
            Err(err) => Box::pin(async move { Err(err) }),
        }
    }

    /// Writes host data into the described GPU regions.
    ///
    /// Failures are recorded on the stream (via `command.error`) rather than
    /// returned; the first failure aborts the remaining writes.
    fn write(&mut self, descriptors: Vec<(CopyDescriptor, Bytes)>, stream_id: StreamId) {
        let mut command = match self.command(
            stream_id,
            descriptors.iter().map(|desc| &desc.0.handle),
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            // `ignore: true` / `flush: false` leaves no failing path in
            // `command`; see `HipServer::command`.
            Err(err) => unreachable!("{err:?}"),
        };

        for (descriptor, data) in descriptors {
            if let Err(err) = command.write_to_gpu(descriptor, data) {
                command.error(err.into());
                return;
            }
        }
    }

    /// Launches a kernel; any error is pushed onto the current stream's error
    /// list instead of being returned (errors surface on the next call that
    /// drains them, e.g. `read`/`sync`).
    ///
    /// # Safety
    /// Inherits the trait's contract: bindings must match the kernel's
    /// expectations for the given execution mode.
    unsafe fn launch(
        &mut self,
        kernel: Self::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        if let Err(err) = self.launch_checked(kernel, count, bindings, mode, stream_id) {
            // Resolve with no dependencies and no error check just to record
            // the failure on the stream.
            let mut stream = match self.streams.resolve(stream_id, [].into_iter(), false) {
                Ok(stream) => stream,
                Err(err) => unreachable!("{err:?}"),
            };
            stream.current().errors.push(err);
        }
    }

    /// Flushes deferred work on the stream: pending drops (ordered by a
    /// lazily created fence) and the GPU storage.
    fn flush(&mut self, stream_id: StreamId) -> Result<(), ServerError> {
        let mut command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: false,
                flush: true,
            },
        )?;

        let current = command.streams.current();
        // The fence is only created if the drop queue actually needs one —
        // presumably to order deferred frees after in-flight work; confirm
        // against `DropQueue::flush`.
        current.drop_queue.flush(|| Fence::new(current.sys));
        current.memory_management_gpu.storage().flush();

        Ok(())
    }

    /// Returns a future that resolves when all submitted work on the stream
    /// has completed; drains accumulated stream errors first.
    fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ServerError>> {
        let command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: false,
                flush: true,
            },
        );

        match command {
            Ok(mut command) => command.sync(),
            Err(err) => Box::pin(async { Err(err) }),
        }
    }

    /// Starts a profiling window: blocks until the stream is idle, then opens
    /// a timestamp token.
    fn start_profile(&mut self, stream_id: StreamId) -> Result<ProfilingToken, ServerError> {
        cubecl_common::future::block_on(self.sync(stream_id))?;
        Ok(self.ctx.timestamps.start())
    }

    /// Closes a profiling window: blocks until the stream is idle (recording
    /// any sync failure into the timestamp registry), then stops the token.
    fn end_profile(
        &mut self,
        stream_id: StreamId,
        token: ProfilingToken,
    ) -> Result<ProfileDuration, ProfileError> {
        if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
            self.ctx
                .timestamps
                .error(ProfileError::Server(Box::new(err)));
        }
        self.ctx.timestamps.stop(token)
    }

    /// Resolves a binding into a managed GPU resource.
    fn get_resource(
        &mut self,
        binding: Binding,
        stream_id: StreamId,
    ) -> Result<ManagedResource<GpuResource>, ServerError> {
        let mut command = self.command(
            stream_id,
            [&binding].into_iter(),
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        )?;
        // Keep the memory handle alive alongside the resource; cloned before
        // `binding` is consumed by `resource`.
        let memory = binding.memory.clone();
        let resource = command.resource(binding)?;

        Ok(ManagedResource::new(memory, resource))
    }

    /// Reports current memory usage for the stream's memory management.
    fn memory_usage(&mut self, stream_id: StreamId) -> Result<MemoryUsage, ServerError> {
        let mut command = self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: false,
                flush: false,
            },
        )?;
        Ok(command.memory_usage())
    }

    /// Best-effort memory cleanup; silently does nothing if the stream cannot
    /// be resolved.
    fn memory_cleanup(&mut self, stream_id: StreamId) {
        let mut command = match self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            Err(_) => return,
        };
        command.memory_cleanup()
    }

    /// Switches the memory allocation mode for the stream.
    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId) {
        let mut command = match self.command_no_inputs(
            stream_id,
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        ) {
            Ok(val) => val,
            // No failing path with `ignore: true` / `flush: false`.
            Err(err) => unreachable!("{err:?}"),
        };
        command.allocation_mode(mode)
    }
}
252
impl ServerCommunication for HipServer {
    // Direct server-to-server communication is not supported on HIP.
    const SERVER_COMM_ENABLED: bool = false;
}
256
impl HipServer {
    /// Creates a new HIP server.
    ///
    /// The maximum number of logical streams comes from the global runtime
    /// config; memory properties/config/alignment are forwarded to the stream
    /// backend's memory management.
    pub(crate) fn new(
        ctx: HipContext,
        mem_props: MemoryDeviceProperties,
        mem_config: MemoryConfiguration,
        mem_alignment: usize,
        is_integrated: bool,
        utilities: ServerUtilities<Self>,
    ) -> Self {
        let config = CubeClRuntimeConfig::get();
        let max_streams = config.streaming.max_streams;

        Self {
            ctx,
            streams: MultiStream::new(
                utilities.logger.clone(),
                HipStreamBackend::new(
                    mem_props,
                    mem_config,
                    mem_alignment,
                    is_integrated,
                    utilities.logger.clone(),
                ),
                max_streams,
            ),
            utilities: Arc::new(utilities),
        }
    }

    /// Shorthand for [`Self::command`] with no input bindings.
    fn command_no_inputs(
        &mut self,
        stream_id: StreamId,
        mode: StreamErrorMode,
    ) -> Result<Command<'_>, ServerError> {
        self.command(stream_id, [].into_iter(), mode)
    }

    /// Builds a [`Command`] for the stream, applying the error policy first.
    ///
    /// When `mode.flush` is set, accumulated stream errors are drained; if
    /// additionally `mode.ignore` is false, a non-empty error list aborts with
    /// `ServerError::ServerUnhealthy`. The stream is then resolved against
    /// `handles` (cross-stream dependencies), with error checking enabled iff
    /// `!mode.ignore`.
    fn command<'a>(
        &mut self,
        stream_id: StreamId,
        handles: impl Iterator<Item = &'a Binding>,
        mode: StreamErrorMode,
    ) -> Result<Command<'_>, ServerError> {
        if mode.flush {
            let errors = self.flush_errors(stream_id);

            if !mode.ignore && !errors.is_empty() {
                return Err(ServerError::ServerUnhealthy {
                    errors,
                    backtrace: BackTrace::capture(),
                });
            }
        }
        let streams = self.streams.resolve(stream_id, handles, !mode.ignore)?;

        Ok(Command::new(&mut self.ctx, streams))
    }

    /// Takes and returns all errors accumulated on the stream.
    ///
    /// If any errors are present, they are also recorded into the profiling
    /// timestamps and GPU memory management is asked to clean up.
    fn flush_errors(&mut self, stream_id: StreamId) -> Vec<ServerError> {
        let mut stream = match self.streams.resolve(stream_id, [].into_iter(), false) {
            Ok(stream) => stream,
            Err(_) => return Vec::new(),
        };
        let errors = core::mem::take(&mut stream.current().errors);

        if !errors.is_empty() {
            // NOTE(review): `alloc::format!` assumes `extern crate alloc` at
            // the crate root — confirm; `std::format!` would also work here.
            self.ctx.timestamps.error(ProfileError::Unknown {
                reason: alloc::format!("{errors:?}"),
                backtrace: BackTrace::capture(),
            });
            // `false` presumably means a non-exhaustive cleanup pass — confirm
            // against `MemoryManagement::cleanup`.
            stream.current().memory_management_gpu.cleanup(false);
        }

        // Release the stream borrow before returning.
        core::mem::drop(stream);
        errors
    }

    /// Fallible body of [`ComputeServer::launch`].
    ///
    /// Resolves the cube count (blocking on a GPU read for
    /// `CubeCount::Dynamic`), uploads the scalar info buffer, gathers all
    /// buffer resources (info appended last), and enqueues the kernel.
    fn launch_checked(
        &mut self,
        kernel: Box<dyn CubeTask<HipCompiler>>,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), ServerError> {
        let mut kernel_id = kernel.id();
        let logger = self.streams.logger.clone();
        // The execution mode is part of the compiled-kernel cache key.
        kernel_id.mode(mode);
        let mut command = self.command(
            stream_id,
            bindings.buffers.iter(),
            StreamErrorMode {
                ignore: true,
                flush: false,
            },
        )?;

        let count = match count {
            CubeCount::Static(x, y, z) => (x, y, z),
            CubeCount::Dynamic(binding) => {
                // Blocking read of three 4-byte values (shape [3], stride [1])
                // holding the dispatch dimensions.
                // NOTE(review): the `unwrap` turns a failed read into a panic
                // rather than a pushed stream error — confirm that is intended.
                let data = future::block_on(command.read_async(vec![CopyDescriptor::new(
                    binding,
                    [3].into(),
                    [1].into(),
                    4,
                )]))
                .unwrap();
                let data = bytemuck::cast_slice(&data[0]);
                assert!(
                    data.len() == 3,
                    "Dynamic cube count should contain 3 values"
                );
                (data[0], data[1], data[2])
            }
        };

        let KernelArguments {
            buffers,
            info,
            tensor_maps,
        } = bindings;

        debug_assert!(tensor_maps.is_empty(), "Can't use tensor maps on HIP");

        // Upload the scalar/info metadata to the GPU.
        let info = command
            .create_with_data(bytemuck::cast_slice(&info.data))
            .unwrap();

        let mut resources: Vec<_> = buffers
            .into_iter()
            .map(|b| command.resource(b).expect("Resource to exist."))
            .collect();

        // The info buffer is always the last kernel argument.
        resources.push(
            command
                .resource(info.binding())
                .expect("Resource to exist."),
        );

        command.kernel(kernel_id, kernel, mode, count, &resources, logger)?;

        Ok(())
    }

    /// Inherent accessor mirroring [`ComputeServer::utilities`], usable
    /// without importing the trait.
    pub(crate) fn utilities(&self) -> Arc<ServerUtilities<Self>> {
        self.utilities.clone()
    }
}
409}