use crate::{
    DeviceProperties,
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, Handle, IoError, ProfileError, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::{future::Future, ops::DerefMut};
use cubecl_common::{
    ExecutionMode,
    bytes::Bytes,
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

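/// The ComputeClient is the entry point to require tasks from the ComputeServer.
/// It should be obtained for a specific device via [`Self::init`] or [`Self::load`].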
pub struct ComputeClient<Server: ComputeServer> {
    context: DeviceContext<Server>,
    utilities: Arc<ServerUtilities<Server>>,
    stream_id: Option<StreamId>,
}

impl<S> Clone for ComputeClient<S>
where
    S: ComputeServer,
{
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<Server> ComputeClient<Server>
where
    Server: ComputeServer,
{
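    /// Returns the info of the server.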
    pub fn info(&self) -> &Server::Info {
        &self.utilities.info
    }

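    /// Creates a new client by registering the given server on the device.
    ///
    /// Panics if a server is already registered for this device.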
    pub fn init<D: Device>(device: &D, server: Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

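    /// Loads the client registered for the given device.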
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

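    /// Pins this client to the given stream instead of deriving it from the calling thread.
    ///
    /// # Safety
    ///
    /// The caller must guarantee correct ordering with work submitted on other streams.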
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

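    /// Reads the given handles asynchronously, resolving to their contents as bytes.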
    pub fn read_async(&self, handles: Vec<Handle>) -> impl Future<Output = Vec<Bytes>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

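    /// Reads the given handles synchronously, blocking until the data is available.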
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
    }

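    /// Reads a single handle synchronously, blocking until the data is available.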
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle])).remove(0)
    }

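    /// Reads the tensors described by the given descriptors asynchronously.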
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Vec<Bytes>> + Send {
        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

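    /// Reads the tensors described by the given descriptors synchronously.
    ///
    /// The returned bytes follow each descriptor's strides, which may include padding
    /// added by the runtime for alignment, so the data is not necessarily contiguous.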
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
    }

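    /// Reads a single tensor described by the given descriptor asynchronously.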
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Bytes> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async move { fut.await.remove(0) }
    }

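    /// Reads a single tensor described by the given descriptor synchronously.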
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

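    /// Returns the underlying storage resource for the given binding.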
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        data: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

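    /// Stores the given data and returns a handle to the new contiguous buffer.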
    pub fn create(&self, data: &[u8]) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

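    /// Stores the given tensor data and returns the allocation, including the strides
    /// chosen by the runtime.
    ///
    /// The layout is runtime-optimized ([`AllocationKind::Optimized`]), so rows may be
    /// padded for alignment; always use the returned strides rather than assuming a
    /// contiguous layout.
    ///
    /// A usage sketch (hypothetical data and client, for illustration only):
    ///
    /// ```ignore
    /// let data = [1.0f32, 2.0, 3.0, 4.0];
    /// let bytes = bytemuck::cast_slice(&data);
    /// let alloc = client.create_tensor(bytes, &[2, 2], core::mem::size_of::<f32>());
    /// // `alloc.handle` addresses the buffer; `alloc.strides` describes its layout.
    /// ```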
    pub fn create_tensor(&self, data: &[u8], shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
    }

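    /// Stores several tensors at once, returning one allocation per descriptor.
    ///
    /// Batching keeps the server locked for a single create-and-write round trip.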
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

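    /// Reserves `size` bytes of uninitialized contiguous memory, returning its handle.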
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

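    /// Reserves an uninitialized tensor with the given shape in a runtime-optimized layout.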
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

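    /// Reserves several uninitialized tensors at once, returning one allocation per descriptor.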
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

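    /// Copies a contiguous buffer to another client, using direct server-to-server
    /// communication when the server supports it.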
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

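    /// Copies a tensor to another client.
    ///
    /// Uses direct server-to-server communication when available, otherwise falls back
    /// to a blocking read on this client followed by a write on the destination.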
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if Server::SERVER_COMM_ENABLED {
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap()
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    unsafe fn execute_inner(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                unsafe { state.execute(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let profile = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.execute(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
            }
        }
    }

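    /// Executes the `kernel` over the given `bindings`.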
    #[track_caller]
    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
        unsafe {
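            // SAFETY: checked execution mode guards kernels against out-of-bounds accesses.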
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            );
        }
    }

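    /// Executes the `kernel` over the given `bindings` without bound checks.
    ///
    /// # Safety
    ///
    /// Without checks, out-of-bounds reads and writes can happen.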
    #[track_caller]
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) {
        unsafe {
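            // SAFETY: the caller upholds this function's contract that all accesses are in bounds.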
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            );
        }
    }

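    /// Flushes the current stream, submitting all queued work to the server.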
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id);
    }

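    /// Returns a future that resolves once all queued work on the current stream has completed.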
    pub fn sync(&self) -> DynFut<()> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

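    /// Returns the properties of the compute server, such as supported features.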
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

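    /// Returns a mutable reference to the server properties, or `None` if other clients
    /// still share them.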
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

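    /// Returns the current memory usage of this client.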
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

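    /// Changes the memory allocation mode of this client.
    ///
    /// # Safety
    ///
    /// This isn't thread-safe: it changes the mode for every thread using this client.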
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

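    /// Runs `func` with persistent memory allocation enabled, restoring automatic mode
    /// afterward. The device is locked for the whole call so other threads can't
    /// allocate in between.
    ///
    /// A usage sketch (hypothetical closure, for illustration only):
    ///
    /// ```ignore
    /// let weights = client.memory_persistent_allocation(bytes, |bytes| client.create(bytes));
    /// ```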
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

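    /// Asks the client to release memory that is no longer in use.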
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

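    /// Measures how long `func` takes to complete on the server.
    ///
    /// A usage sketch (hypothetical kernel launch, for illustration only):
    ///
    /// ```ignore
    /// let profile = client.profile(|| client.execute(kernel, count, bindings), "my_kernel")?;
    /// ```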
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<ProfileDuration, ProfileError> {
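        // Capture the caller's location so Tracy spans point at the user call site.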
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

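        // Open a CPU span covering the profiled function.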
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            // NOTE: assumes the Tracy GPU client lives in the shared `ServerUtilities`;
            // the original `self.state` field does not exist on this struct.
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        #[allow(unused_mut)]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        core::mem::drop(out);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
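            // Wrap the profile so device timestamps are uploaded to Tracy on resolve.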
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        result
    }

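    /// Blocking fallback used when direct server-to-server transfer isn't available.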
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

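        // Allocate the destination buffer on the target server.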
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], stream_id)
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let data = cubecl_common::future::block_on(read).unwrap();

        let dst_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(dst_descriptor, &data[0])], stream_id)
            .unwrap();

        alloc
    }
}