1use crate::{
2 DeviceProperties,
3 config::{TypeNameFormatLevel, type_name_format},
4 kernel::KernelMetadata,
5 logging::ProfileLevel,
6 memory_management::{MemoryAllocationMode, MemoryUsage},
7 runtime::Runtime,
8 server::{
9 Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
10 CopyDescriptor, CubeCount, ExecutionError, ExecutionMode, Handle, IoError, LaunchError,
11 ProfileError, ServerCommunication, ServerUtilities,
12 },
13 storage::{BindingResource, ComputeStorage},
14};
15use alloc::format;
16use alloc::sync::Arc;
17use alloc::vec;
18use alloc::vec::Vec;
19use core::ops::DerefMut;
20use cubecl_common::{
21 bytes::{AllocationProperty, Bytes},
22 device::{Device, DeviceContext},
23 future::DynFut,
24 profile::ProfileDuration,
25};
26use cubecl_ir::StorageType;
27
28#[allow(unused)]
29use cubecl_common::profile::TimingMethod;
30use cubecl_common::stream_id::StreamId;
31
32pub struct ComputeClient<R: Runtime> {
35 context: DeviceContext<R::Server>,
36 utilities: Arc<ServerUtilities<R::Server>>,
37 stream_id: Option<StreamId>,
38}
39
40impl<R: Runtime> Clone for ComputeClient<R> {
41 fn clone(&self) -> Self {
42 Self {
43 context: self.context.clone(),
44 utilities: self.utilities.clone(),
45 stream_id: self.stream_id,
46 }
47 }
48}
49
50impl<R: Runtime> ComputeClient<R> {
51 pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
53 &self.utilities.info
54 }
55
56 pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
58 let utilities = server.utilities();
59
60 let context = DeviceContext::<R::Server>::insert(device, server)
61 .expect("Can't create a new client on an already registered server");
62
63 Self {
64 context,
65 utilities,
66 stream_id: None,
67 }
68 }
69
70 pub fn load<D: Device>(device: &D) -> Self {
72 let context = DeviceContext::<R::Server>::locate(device);
73 let utilities = context.lock().utilities();
74
75 Self {
76 context,
77 utilities,
78 stream_id: None,
79 }
80 }
81
82 fn stream_id(&self) -> StreamId {
83 match self.stream_id {
84 Some(val) => val,
85 None => StreamId::current(),
86 }
87 }
88
89 pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
95 self.stream_id = Some(stream_id);
96 }
97
98 fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
99 let stream_id = self.stream_id();
100 let mut state = self.context.lock();
101 let fut = state.read(descriptors, stream_id);
102 core::mem::drop(state);
103 fut
104 }
105
106 pub fn read_async(
108 &self,
109 handles: Vec<Handle>,
110 ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
111 let strides = [1];
112 let shapes = handles
113 .iter()
114 .map(|it| [it.size() as usize])
115 .collect::<Vec<_>>();
116 let bindings = handles
117 .into_iter()
118 .map(|it| it.binding())
119 .collect::<Vec<_>>();
120 let descriptors = bindings
121 .into_iter()
122 .zip(shapes.iter())
123 .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
124 .collect();
125
126 self.do_read(descriptors)
127 }
128
129 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
135 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
136 }
137
138 pub fn read_one(&self, handle: Handle) -> Bytes {
143 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
144 .expect("TODO")
145 .remove(0)
146 }
147
148 pub fn read_tensor_async(
150 &self,
151 descriptors: Vec<CopyDescriptor<'_>>,
152 ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
153 self.do_read(descriptors)
154 }
155
156 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
169 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
170 }
171
172 pub fn read_one_tensor_async(
175 &self,
176 descriptor: CopyDescriptor<'_>,
177 ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
178 let fut = self.read_tensor_async(vec![descriptor]);
179
180 async { Ok(fut.await?.remove(0)) }
181 }
182
183 pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
189 self.read_tensor(vec![descriptor]).remove(0)
190 }
191
192 pub fn get_resource(
194 &self,
195 binding: Binding,
196 ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
197 let stream_id = self.stream_id();
198 self.context.lock().get_resource(binding, stream_id)
199 }
200
201 fn do_create_from_slices(
202 &self,
203 descriptors: Vec<AllocationDescriptor<'_>>,
204 slices: Vec<&[u8]>,
205 ) -> Result<Vec<Allocation>, IoError> {
206 let mut state = self.context.lock();
207 let allocations = state.create(descriptors.clone(), self.stream_id())?;
208 let descriptors = descriptors
209 .into_iter()
210 .zip(allocations.iter())
211 .zip(slices)
212 .map(|((desc, alloc), data)| {
213 (
214 CopyDescriptor::new(
215 alloc.handle.clone().binding(),
216 desc.shape,
217 &alloc.strides,
218 desc.elem_size,
219 ),
220 Bytes::from_bytes_vec(data.to_vec()),
221 )
222 })
223 .collect();
224 let stream_id = self.stream_id();
225 state.write(descriptors, stream_id)?;
226 Ok(allocations)
227 }
228
229 fn do_create(
230 &self,
231 descriptors: Vec<AllocationDescriptor<'_>>,
232 mut data: Vec<Bytes>,
233 ) -> Result<Vec<Allocation>, IoError> {
234 self.staging(data.iter_mut(), true);
235
236 let mut state = self.context.lock();
237 let allocations = state.create(descriptors.clone(), self.stream_id())?;
238 let descriptors = descriptors
239 .into_iter()
240 .zip(allocations.iter())
241 .zip(data)
242 .map(|((desc, alloc), data)| {
243 (
244 CopyDescriptor::new(
245 alloc.handle.clone().binding(),
246 desc.shape,
247 &alloc.strides,
248 desc.elem_size,
249 ),
250 data,
251 )
252 })
253 .collect();
254 let stream_id = self.stream_id();
255 state.write(descriptors, stream_id)?;
256 Ok(allocations)
257 }
258
259 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
265 let shape = [slice.len()];
266
267 self.do_create_from_slices(
268 vec![AllocationDescriptor::new(
269 AllocationKind::Contiguous,
270 &shape,
271 1,
272 )],
273 vec![slice],
274 )
275 .unwrap()
276 .remove(0)
277 .handle
278 }
279
280 pub fn create(&self, data: Bytes) -> Handle {
282 let shape = [data.len()];
283
284 self.do_create(
285 vec![AllocationDescriptor::new(
286 AllocationKind::Contiguous,
287 &shape,
288 1,
289 )],
290 vec![data],
291 )
292 .unwrap()
293 .remove(0)
294 .handle
295 }
296
297 pub fn create_tensor_from_slice(
315 &self,
316 slice: &[u8],
317 shape: &[usize],
318 elem_size: usize,
319 ) -> Allocation {
320 self.do_create_from_slices(
321 vec![AllocationDescriptor::new(
322 AllocationKind::Optimized,
323 shape,
324 elem_size,
325 )],
326 vec![slice],
327 )
328 .unwrap()
329 .remove(0)
330 }
331
332 pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
346 self.do_create(
347 vec![AllocationDescriptor::new(
348 AllocationKind::Optimized,
349 shape,
350 elem_size,
351 )],
352 vec![bytes],
353 )
354 .unwrap()
355 .remove(0)
356 }
357
358 pub fn create_tensors_from_slices(
366 &self,
367 descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
368 ) -> Vec<Allocation> {
369 let (descriptors, data) = descriptors.into_iter().unzip();
370
371 self.do_create_from_slices(descriptors, data).unwrap()
372 }
373
374 pub fn create_tensors(
378 &self,
379 descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
380 ) -> Vec<Allocation> {
381 let (descriptors, data) = descriptors.into_iter().unzip();
382
383 self.do_create(descriptors, data).unwrap()
384 }
385
386 fn do_empty(
387 &self,
388 descriptors: Vec<AllocationDescriptor<'_>>,
389 ) -> Result<Vec<Allocation>, IoError> {
390 let mut state = self.context.lock();
391 state.create(descriptors, self.stream_id())
392 }
393
394 pub fn empty(&self, size: usize) -> Handle {
396 let shape = [size];
397 let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
398 self.do_empty(vec![descriptor]).unwrap().remove(0).handle
399 }
400
401 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
404 let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
405 self.do_empty(vec![descriptor]).unwrap().remove(0)
406 }
407
408 pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
411 self.do_empty(descriptors).unwrap()
412 }
413
414 pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
417 where
418 I: Iterator<Item = &'a mut Bytes>,
419 {
420 let has_staging = |b: &Bytes| match b.property() {
421 AllocationProperty::Pinned => false,
422 AllocationProperty::File => true,
423 AllocationProperty::Native | AllocationProperty::Other => !file_only,
424 };
425
426 let mut to_be_updated = Vec::new();
427 let sizes = bytes
428 .filter_map(|b| match has_staging(b) {
429 true => {
430 let len = b.len();
431 to_be_updated.push(b);
432 Some(len)
433 }
434 false => None,
435 })
436 .collect::<Vec<usize>>();
437
438 if sizes.is_empty() {
439 return;
440 }
441
442 let stream_id = self.stream_id();
443 let mut context = self.context.lock();
444 let stagings = match context.staging(&sizes, stream_id) {
445 Ok(val) => val,
446 Err(_) => return,
447 };
448 core::mem::drop(context);
449
450 to_be_updated
451 .into_iter()
452 .zip(stagings)
453 .for_each(|(b, mut staging)| {
454 b.copy_into(&mut staging);
455 core::mem::swap(b, &mut staging);
456 });
457 }
458
459 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, src, dst_server)))]
461 pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
462 let shape = [src.size() as usize];
463 let src_descriptor = src.copy_descriptor(&shape, &[1], 1);
464
465 if R::Server::SERVER_COMM_ENABLED {
466 self.to_client_tensor(src_descriptor, dst_server)
467 } else {
468 let alloc_desc = AllocationDescriptor::new(
469 AllocationKind::Contiguous,
470 src_descriptor.shape,
471 src_descriptor.elem_size,
472 );
473 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
474 }
475 }
476
477 #[cfg_attr(
481 feature = "tracing",
482 tracing::instrument(skip(self, src_descriptor, dst_server))
483 )]
484 pub fn to_client_tensor(
485 &self,
486 src_descriptor: CopyDescriptor<'_>,
487 dst_server: &Self,
488 ) -> Allocation {
489 if R::Server::SERVER_COMM_ENABLED {
490 let guard = self.context.lock_device_kind();
491 let mut server_src = self.context.lock();
492 let mut server_dst = dst_server.context.lock();
493
494 let copied = R::Server::copy(
495 server_src.deref_mut(),
496 server_dst.deref_mut(),
497 src_descriptor,
498 self.stream_id(),
499 dst_server.stream_id(),
500 )
501 .unwrap();
502 core::mem::drop(server_src);
503 core::mem::drop(server_dst);
504 core::mem::drop(guard);
505 copied
506 } else {
507 let alloc_desc = AllocationDescriptor::new(
508 AllocationKind::Optimized,
509 src_descriptor.shape,
510 src_descriptor.elem_size,
511 );
512 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
513 }
514 }
515
516 #[track_caller]
517 #[cfg_attr(feature="tracing", tracing::instrument(
518 skip(self, kernel, bindings),
519 fields(
520 kernel.name = %kernel.name(),
521 kernel.id = %kernel.id(),
522 )
523 ))]
524 unsafe fn launch_inner(
525 &self,
526 kernel: <R::Server as ComputeServer>::Kernel,
527 count: CubeCount,
528 bindings: Bindings,
529 mode: ExecutionMode,
530 stream_id: StreamId,
531 ) -> Result<(), LaunchError> {
532 let level = self.utilities.logger.profile_level();
533
534 match level {
535 None | Some(ProfileLevel::ExecutionOnly) => {
536 let mut state = self.context.lock();
537 let name = kernel.name();
538
539 let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };
540
541 if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
542 let info = type_name_format(name, TypeNameFormatLevel::Balanced);
543 self.utilities.logger.register_execution(info);
544 }
545 result
546 }
547 Some(level) => {
548 let name = kernel.name();
549 let kernel_id = kernel.id();
550 let (result, profile) = self
551 .profile(
552 || unsafe {
553 let mut state = self.context.lock();
554 state.launch(kernel, count.clone(), bindings, mode, stream_id)
555 },
556 name,
557 )
558 .unwrap();
559 let info = match level {
560 ProfileLevel::Full => {
561 format!("{name}: {kernel_id} CubeCount {count:?}")
562 }
563 _ => type_name_format(name, TypeNameFormatLevel::Balanced),
564 };
565 self.utilities.logger.register_profiled(info, profile);
566 result
567 }
568 }
569 }
570
571 #[track_caller]
573 pub fn launch(
574 &self,
575 kernel: <R::Server as ComputeServer>::Kernel,
576 count: CubeCount,
577 bindings: Bindings,
578 ) -> Result<(), LaunchError> {
579 unsafe {
581 self.launch_inner(
582 kernel,
583 count,
584 bindings,
585 ExecutionMode::Checked,
586 self.stream_id(),
587 )
588 }
589 }
590
591 #[track_caller]
599 pub unsafe fn launch_unchecked(
600 &self,
601 kernel: <R::Server as ComputeServer>::Kernel,
602 count: CubeCount,
603 bindings: Bindings,
604 ) -> Result<(), LaunchError> {
605 unsafe {
607 self.launch_inner(
608 kernel,
609 count,
610 bindings,
611 ExecutionMode::Unchecked,
612 self.stream_id(),
613 )
614 }
615 }
616
617 pub fn flush(&self) {
619 let stream_id = self.stream_id();
620 self.context.lock().flush(stream_id)
621 }
622
623 pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
625 let stream_id = self.stream_id();
626 let mut state = self.context.lock();
627 let fut = state.sync(stream_id);
628 core::mem::drop(state);
629 self.utilities.logger.profile_summary();
630
631 fut
632 }
633
634 pub fn properties(&self) -> &DeviceProperties {
636 &self.utilities.properties
637 }
638
639 pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
643 Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
644 }
645
646 pub fn memory_usage(&self) -> MemoryUsage {
648 self.context.lock().memory_usage(self.stream_id())
649 }
650
651 pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
657 self.context.lock().allocation_mode(mode, self.stream_id())
658 }
659
660 pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
667 &self,
668 input: Input,
669 func: Func,
670 ) -> Output {
671 let device_guard = self.context.lock_device();
672
673 self.context
674 .lock()
675 .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());
676
677 let output = func(input);
678
679 self.context
680 .lock()
681 .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());
682
683 core::mem::drop(device_guard);
684
685 output
686 }
687
688 pub fn memory_cleanup(&self) {
693 self.context.lock().memory_cleanup(self.stream_id())
694 }
695
696 #[track_caller]
698 pub fn profile<O>(
699 &self,
700 func: impl FnOnce() -> O,
701 #[allow(unused)] func_name: &str,
702 ) -> Result<(O, ProfileDuration), ProfileError> {
703 #[cfg(feature = "profile-tracy")]
706 let location = std::panic::Location::caller();
707
708 #[cfg(feature = "profile-tracy")]
710 let _span = tracy_client::Client::running().unwrap().span_alloc(
711 None,
712 func_name,
713 location.file(),
714 location.line(),
715 0,
716 );
717
718 let device_guard = self.context.lock_device();
719
720 #[cfg(feature = "profile-tracy")]
721 let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
722 let gpu_span = self
723 .utilities
724 .gpu_client
725 .span_alloc(func_name, "profile", location.file(), location.line())
726 .unwrap();
727 Some(gpu_span)
728 } else {
729 None
730 };
731
732 let token = self.context.lock().start_profile(self.stream_id());
733
734 let out = func();
735
736 #[allow(unused_mut, reason = "Used in profile-tracy")]
737 let mut result = self.context.lock().end_profile(self.stream_id(), token);
738
739 #[cfg(feature = "profile-tracy")]
740 if let Some(mut gpu_span) = gpu_span {
741 gpu_span.end_zone();
742 let epoch = self.utilities.epoch_time;
743 result = result.map(|result| {
745 ProfileDuration::new(
746 Box::pin(async move {
747 let ticks = result.resolve().await;
748 let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
749 let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
750 gpu_span.upload_timestamp_start(start_duration);
751 gpu_span.upload_timestamp_end(end_duration);
752 ticks
753 }),
754 TimingMethod::Device,
755 )
756 });
757 }
758 core::mem::drop(device_guard);
759
760 match result {
761 Ok(result) => Ok((out, result)),
762 Err(err) => Err(err),
763 }
764 }
765
766 #[cfg_attr(
768 feature = "tracing",
769 tracing::instrument(skip(self, src_descriptor, alloc_descriptor, dst_server))
770 )]
771 fn change_client_sync(
772 &self,
773 src_descriptor: CopyDescriptor<'_>,
774 alloc_descriptor: AllocationDescriptor<'_>,
775 dst_server: &Self,
776 ) -> Allocation {
777 let shape = src_descriptor.shape;
778 let elem_size = src_descriptor.elem_size;
779 let stream_id = self.stream_id();
780
781 let alloc = dst_server
783 .context
784 .lock()
785 .create(vec![alloc_descriptor], self.stream_id())
786 .unwrap()
787 .remove(0);
788
789 let read = self.context.lock().read(vec![src_descriptor], stream_id);
790 let mut data = cubecl_common::future::block_on(read).unwrap();
791
792 let desc_descriptor = CopyDescriptor {
793 binding: alloc.handle.clone().binding(),
794 shape,
795 strides: &alloc.strides,
796 elem_size,
797 };
798
799 dst_server
800 .context
801 .lock()
802 .write(vec![(desc_descriptor, data.remove(0))], stream_id)
803 .unwrap();
804
805 alloc
806 }
807
808 pub fn io_optimized_line_sizes(&self, elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
810 let load_width = self.properties().hardware.load_width as usize;
811 let max = (load_width / elem.size_bits()) as u8;
812 let supported = R::supported_line_sizes();
813 supported.iter().filter(move |v| **v <= max).cloned()
814 }
815
816 pub fn io_optimized_line_sizes_unchecked(
820 &self,
821 size: usize,
822 ) -> impl Iterator<Item = u8> + Clone {
823 let load_width = self.properties().hardware.load_width as usize;
824 let size_bits = size * 8;
825 let max = load_width / size_bits;
826 let max = usize::min(R::max_global_line_size() as usize, max);
829
830 let num_candidates = max.trailing_zeros() + 1;
832
833 (0..num_candidates).map(|i| 2u8.pow(i)).rev()
834 }
835}