1use crate::{
2 config::{TypeNameFormatLevel, type_name_format},
3 kernel::KernelMetadata,
4 logging::ProfileLevel,
5 memory_management::{MemoryAllocationMode, MemoryUsage},
6 runtime::Runtime,
7 server::{
8 Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
9 CopyDescriptor, CubeCount, ExecutionError, ExecutionMode, Handle, IoError, LaunchError,
10 ProfileError, ServerCommunication, ServerUtilities,
11 },
12 storage::{BindingResource, ComputeStorage},
13};
14use alloc::format;
15use alloc::sync::Arc;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::ops::DerefMut;
19use cubecl_common::{
20 bytes::{AllocationProperty, Bytes},
21 device::{Device, DeviceContext},
22 future::DynFut,
23 profile::ProfileDuration,
24};
25use cubecl_ir::{DeviceProperties, LineSize, StorageType};
26
27#[allow(unused)]
28use cubecl_common::profile::TimingMethod;
29use cubecl_common::stream_id::StreamId;
30
31pub struct ComputeClient<R: Runtime> {
34 context: DeviceContext<R::Server>,
35 utilities: Arc<ServerUtilities<R::Server>>,
36 stream_id: Option<StreamId>,
37}
38
39impl<R: Runtime> Clone for ComputeClient<R> {
40 fn clone(&self) -> Self {
41 Self {
42 context: self.context.clone(),
43 utilities: self.utilities.clone(),
44 stream_id: self.stream_id,
45 }
46 }
47}
48
49impl<R: Runtime> ComputeClient<R> {
50 pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
52 &self.utilities.info
53 }
54
55 pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
57 let utilities = server.utilities();
58
59 let context = DeviceContext::<R::Server>::insert(device, server)
60 .expect("Can't create a new client on an already registered server");
61
62 Self {
63 context,
64 utilities,
65 stream_id: None,
66 }
67 }
68
69 pub fn load<D: Device>(device: &D) -> Self {
71 let context = DeviceContext::<R::Server>::locate(device);
72 let utilities = context.lock().utilities();
73
74 Self {
75 context,
76 utilities,
77 stream_id: None,
78 }
79 }
80
81 fn stream_id(&self) -> StreamId {
82 match self.stream_id {
83 Some(val) => val,
84 None => StreamId::current(),
85 }
86 }
87
88 pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
94 self.stream_id = Some(stream_id);
95 }
96
97 fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
98 let stream_id = self.stream_id();
99 let mut state = self.context.lock();
100 let fut = state.read(descriptors, stream_id);
101 core::mem::drop(state);
102 fut
103 }
104
105 pub fn read_async(
107 &self,
108 handles: Vec<Handle>,
109 ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
110 let strides = [1];
111 let shapes = handles
112 .iter()
113 .map(|it| [it.size() as usize])
114 .collect::<Vec<_>>();
115 let bindings = handles
116 .into_iter()
117 .map(|it| it.binding())
118 .collect::<Vec<_>>();
119 let descriptors = bindings
120 .into_iter()
121 .zip(shapes.iter())
122 .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
123 .collect();
124
125 self.do_read(descriptors)
126 }
127
128 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
134 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
135 }
136
137 pub fn read_one(&self, handle: Handle) -> Bytes {
142 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
143 .expect("TODO")
144 .remove(0)
145 }
146
147 pub fn read_tensor_async(
149 &self,
150 descriptors: Vec<CopyDescriptor<'_>>,
151 ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
152 self.do_read(descriptors)
153 }
154
155 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
168 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
169 }
170
171 pub fn read_one_tensor_async(
174 &self,
175 descriptor: CopyDescriptor<'_>,
176 ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
177 let fut = self.read_tensor_async(vec![descriptor]);
178
179 async { Ok(fut.await?.remove(0)) }
180 }
181
182 pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
188 self.read_tensor(vec![descriptor]).remove(0)
189 }
190
191 pub fn get_resource(
193 &self,
194 binding: Binding,
195 ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
196 let stream_id = self.stream_id();
197 self.context.lock().get_resource(binding, stream_id)
198 }
199
200 fn do_create_from_slices(
201 &self,
202 descriptors: Vec<AllocationDescriptor<'_>>,
203 slices: Vec<&[u8]>,
204 ) -> Result<Vec<Allocation>, IoError> {
205 let mut state = self.context.lock();
206 let allocations = state.create(descriptors.clone(), self.stream_id())?;
207 let descriptors = descriptors
208 .into_iter()
209 .zip(allocations.iter())
210 .zip(slices)
211 .map(|((desc, alloc), data)| {
212 (
213 CopyDescriptor::new(
214 alloc.handle.clone().binding(),
215 desc.shape,
216 &alloc.strides,
217 desc.elem_size,
218 ),
219 Bytes::from_bytes_vec(data.to_vec()),
220 )
221 })
222 .collect();
223 let stream_id = self.stream_id();
224 state.write(descriptors, stream_id)?;
225 Ok(allocations)
226 }
227
228 fn do_create(
229 &self,
230 descriptors: Vec<AllocationDescriptor<'_>>,
231 mut data: Vec<Bytes>,
232 ) -> Result<Vec<Allocation>, IoError> {
233 self.staging(data.iter_mut(), true);
234
235 let mut state = self.context.lock();
236 let allocations = state.create(descriptors.clone(), self.stream_id())?;
237 let descriptors = descriptors
238 .into_iter()
239 .zip(allocations.iter())
240 .zip(data)
241 .map(|((desc, alloc), data)| {
242 (
243 CopyDescriptor::new(
244 alloc.handle.clone().binding(),
245 desc.shape,
246 &alloc.strides,
247 desc.elem_size,
248 ),
249 data,
250 )
251 })
252 .collect();
253 let stream_id = self.stream_id();
254 state.write(descriptors, stream_id)?;
255 Ok(allocations)
256 }
257
258 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
264 let shape = [slice.len()];
265
266 self.do_create_from_slices(
267 vec![AllocationDescriptor::new(
268 AllocationKind::Contiguous,
269 &shape,
270 1,
271 )],
272 vec![slice],
273 )
274 .unwrap()
275 .remove(0)
276 .handle
277 }
278
279 pub fn create(&self, data: Bytes) -> Handle {
281 let shape = [data.len()];
282
283 self.do_create(
284 vec![AllocationDescriptor::new(
285 AllocationKind::Contiguous,
286 &shape,
287 1,
288 )],
289 vec![data],
290 )
291 .unwrap()
292 .remove(0)
293 .handle
294 }
295
296 pub fn create_tensor_from_slice(
314 &self,
315 slice: &[u8],
316 shape: &[usize],
317 elem_size: usize,
318 ) -> Allocation {
319 self.do_create_from_slices(
320 vec![AllocationDescriptor::new(
321 AllocationKind::Optimized,
322 shape,
323 elem_size,
324 )],
325 vec![slice],
326 )
327 .unwrap()
328 .remove(0)
329 }
330
331 pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
345 self.do_create(
346 vec![AllocationDescriptor::new(
347 AllocationKind::Optimized,
348 shape,
349 elem_size,
350 )],
351 vec![bytes],
352 )
353 .unwrap()
354 .remove(0)
355 }
356
357 pub fn create_tensors_from_slices(
365 &self,
366 descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
367 ) -> Vec<Allocation> {
368 let (descriptors, data) = descriptors.into_iter().unzip();
369
370 self.do_create_from_slices(descriptors, data).unwrap()
371 }
372
373 pub fn create_tensors(
377 &self,
378 descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
379 ) -> Vec<Allocation> {
380 let (descriptors, data) = descriptors.into_iter().unzip();
381
382 self.do_create(descriptors, data).unwrap()
383 }
384
385 fn do_empty(
386 &self,
387 descriptors: Vec<AllocationDescriptor<'_>>,
388 ) -> Result<Vec<Allocation>, IoError> {
389 let mut state = self.context.lock();
390 state.create(descriptors, self.stream_id())
391 }
392
393 pub fn empty(&self, size: usize) -> Handle {
395 let shape = [size];
396 let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
397 self.do_empty(vec![descriptor]).unwrap().remove(0).handle
398 }
399
400 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
403 let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
404 self.do_empty(vec![descriptor]).unwrap().remove(0)
405 }
406
407 pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
410 self.do_empty(descriptors).unwrap()
411 }
412
413 pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
416 where
417 I: Iterator<Item = &'a mut Bytes>,
418 {
419 let has_staging = |b: &Bytes| match b.property() {
420 AllocationProperty::Pinned => false,
421 AllocationProperty::File => true,
422 AllocationProperty::Native | AllocationProperty::Other => !file_only,
423 };
424
425 let mut to_be_updated = Vec::new();
426 let sizes = bytes
427 .filter_map(|b| match has_staging(b) {
428 true => {
429 let len = b.len();
430 to_be_updated.push(b);
431 Some(len)
432 }
433 false => None,
434 })
435 .collect::<Vec<usize>>();
436
437 if sizes.is_empty() {
438 return;
439 }
440
441 let stream_id = self.stream_id();
442 let mut context = self.context.lock();
443 let stagings = match context.staging(&sizes, stream_id) {
444 Ok(val) => val,
445 Err(_) => return,
446 };
447 core::mem::drop(context);
448
449 to_be_updated
450 .into_iter()
451 .zip(stagings)
452 .for_each(|(b, mut staging)| {
453 b.copy_into(&mut staging);
454 core::mem::swap(b, &mut staging);
455 });
456 }
457
458 #[cfg_attr(
460 feature = "tracing",
461 tracing::instrument(level = "trace", skip(self, src, dst_server))
462 )]
463 pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
464 let shape = [src.size() as usize];
465 let src_descriptor = src.copy_descriptor(&shape, &[1], 1);
466
467 if R::Server::SERVER_COMM_ENABLED {
468 self.to_client_tensor(src_descriptor, dst_server)
469 } else {
470 let alloc_desc = AllocationDescriptor::new(
471 AllocationKind::Contiguous,
472 src_descriptor.shape,
473 src_descriptor.elem_size,
474 );
475 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
476 }
477 }
478
479 #[cfg_attr(
483 feature = "tracing",
484 tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
485 )]
486 pub fn to_client_tensor(
487 &self,
488 src_descriptor: CopyDescriptor<'_>,
489 dst_server: &Self,
490 ) -> Allocation {
491 if R::Server::SERVER_COMM_ENABLED {
492 let guard = self.context.lock_device_kind();
493 let mut server_src = self.context.lock();
494 let mut server_dst = dst_server.context.lock();
495
496 let copied = R::Server::copy(
497 server_src.deref_mut(),
498 server_dst.deref_mut(),
499 src_descriptor,
500 self.stream_id(),
501 dst_server.stream_id(),
502 )
503 .unwrap();
504 core::mem::drop(server_src);
505 core::mem::drop(server_dst);
506 core::mem::drop(guard);
507 copied
508 } else {
509 let alloc_desc = AllocationDescriptor::new(
510 AllocationKind::Optimized,
511 src_descriptor.shape,
512 src_descriptor.elem_size,
513 );
514 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
515 }
516 }
517
518 #[track_caller]
519 #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
520 skip(self, kernel, bindings),
521 fields(
522 kernel.name = %kernel.name(),
523 kernel.id = %kernel.id(),
524 )
525 ))]
526 unsafe fn launch_inner(
527 &self,
528 kernel: <R::Server as ComputeServer>::Kernel,
529 count: CubeCount,
530 bindings: Bindings,
531 mode: ExecutionMode,
532 stream_id: StreamId,
533 ) -> Result<(), LaunchError> {
534 let level = self.utilities.logger.profile_level();
535
536 match level {
537 None | Some(ProfileLevel::ExecutionOnly) => {
538 let mut state = self.context.lock();
539 let name = kernel.name();
540
541 let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };
542
543 if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
544 let info = type_name_format(name, TypeNameFormatLevel::Balanced);
545 self.utilities.logger.register_execution(info);
546 }
547 result
548 }
549 Some(level) => {
550 let name = kernel.name();
551 let kernel_id = kernel.id();
552 let (result, profile) = self
553 .profile(
554 || unsafe {
555 let mut state = self.context.lock();
556 state.launch(kernel, count.clone(), bindings, mode, stream_id)
557 },
558 name,
559 )
560 .unwrap();
561 let info = match level {
562 ProfileLevel::Full => {
563 format!("{name}: {kernel_id} CubeCount {count:?}")
564 }
565 _ => type_name_format(name, TypeNameFormatLevel::Balanced),
566 };
567 self.utilities.logger.register_profiled(info, profile);
568 result
569 }
570 }
571 }
572
573 #[track_caller]
575 pub fn launch(
576 &self,
577 kernel: <R::Server as ComputeServer>::Kernel,
578 count: CubeCount,
579 bindings: Bindings,
580 ) -> Result<(), LaunchError> {
581 unsafe {
583 self.launch_inner(
584 kernel,
585 count,
586 bindings,
587 ExecutionMode::Checked,
588 self.stream_id(),
589 )
590 }
591 }
592
593 #[track_caller]
601 pub unsafe fn launch_unchecked(
602 &self,
603 kernel: <R::Server as ComputeServer>::Kernel,
604 count: CubeCount,
605 bindings: Bindings,
606 ) -> Result<(), LaunchError> {
607 unsafe {
609 self.launch_inner(
610 kernel,
611 count,
612 bindings,
613 ExecutionMode::Unchecked,
614 self.stream_id(),
615 )
616 }
617 }
618
619 pub fn flush(&self) {
621 let stream_id = self.stream_id();
622 self.context.lock().flush(stream_id)
623 }
624
625 pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
627 let stream_id = self.stream_id();
628 let mut state = self.context.lock();
629 let fut = state.sync(stream_id);
630 core::mem::drop(state);
631 self.utilities.logger.profile_summary();
632
633 fut
634 }
635
636 pub fn properties(&self) -> &DeviceProperties {
638 &self.utilities.properties
639 }
640
641 pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
645 Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
646 }
647
648 pub fn memory_usage(&self) -> MemoryUsage {
650 self.context.lock().memory_usage(self.stream_id())
651 }
652
653 pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
659 self.context.lock().allocation_mode(mode, self.stream_id())
660 }
661
662 pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
669 &self,
670 input: Input,
671 func: Func,
672 ) -> Output {
673 let device_guard = self.context.lock_device();
674
675 self.context
676 .lock()
677 .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());
678
679 let output = func(input);
680
681 self.context
682 .lock()
683 .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());
684
685 core::mem::drop(device_guard);
686
687 output
688 }
689
690 pub fn memory_cleanup(&self) {
695 self.context.lock().memory_cleanup(self.stream_id())
696 }
697
698 #[track_caller]
700 pub fn profile<O>(
701 &self,
702 func: impl FnOnce() -> O,
703 #[allow(unused)] func_name: &str,
704 ) -> Result<(O, ProfileDuration), ProfileError> {
705 #[cfg(feature = "profile-tracy")]
708 let location = std::panic::Location::caller();
709
710 #[cfg(feature = "profile-tracy")]
712 let _span = tracy_client::Client::running().unwrap().span_alloc(
713 None,
714 func_name,
715 location.file(),
716 location.line(),
717 0,
718 );
719
720 let device_guard = self.context.lock_device();
721
722 #[cfg(feature = "profile-tracy")]
723 let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
724 let gpu_span = self
725 .utilities
726 .gpu_client
727 .span_alloc(func_name, "profile", location.file(), location.line())
728 .unwrap();
729 Some(gpu_span)
730 } else {
731 None
732 };
733
734 let token = self.context.lock().start_profile(self.stream_id());
735
736 let out = func();
737
738 #[allow(unused_mut, reason = "Used in profile-tracy")]
739 let mut result = self.context.lock().end_profile(self.stream_id(), token);
740
741 #[cfg(feature = "profile-tracy")]
742 if let Some(mut gpu_span) = gpu_span {
743 gpu_span.end_zone();
744 let epoch = self.utilities.epoch_time;
745 result = result.map(|result| {
747 ProfileDuration::new(
748 Box::pin(async move {
749 let ticks = result.resolve().await;
750 let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
751 let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
752 gpu_span.upload_timestamp_start(start_duration);
753 gpu_span.upload_timestamp_end(end_duration);
754 ticks
755 }),
756 TimingMethod::Device,
757 )
758 });
759 }
760 core::mem::drop(device_guard);
761
762 match result {
763 Ok(result) => Ok((out, result)),
764 Err(err) => Err(err),
765 }
766 }
767
768 #[cfg_attr(
770 feature = "tracing",
771 tracing::instrument(
772 level = "trace",
773 skip(self, src_descriptor, alloc_descriptor, dst_server)
774 )
775 )]
776 fn change_client_sync(
777 &self,
778 src_descriptor: CopyDescriptor<'_>,
779 alloc_descriptor: AllocationDescriptor<'_>,
780 dst_server: &Self,
781 ) -> Allocation {
782 let shape = src_descriptor.shape;
783 let elem_size = src_descriptor.elem_size;
784 let stream_id = self.stream_id();
785
786 let alloc = dst_server
788 .context
789 .lock()
790 .create(vec![alloc_descriptor], self.stream_id())
791 .unwrap()
792 .remove(0);
793
794 let read = self.context.lock().read(vec![src_descriptor], stream_id);
795 let mut data = cubecl_common::future::block_on(read).unwrap();
796
797 let desc_descriptor = CopyDescriptor {
798 binding: alloc.handle.clone().binding(),
799 shape,
800 strides: &alloc.strides,
801 elem_size,
802 };
803
804 dst_server
805 .context
806 .lock()
807 .write(vec![(desc_descriptor, data.remove(0))], stream_id)
808 .unwrap();
809
810 alloc
811 }
812
813 pub fn io_optimized_line_sizes(
815 &self,
816 elem: &StorageType,
817 ) -> impl Iterator<Item = LineSize> + Clone {
818 let load_width = self.properties().hardware.load_width as usize;
819 let max = load_width / elem.size_bits();
820 let supported = R::supported_line_sizes();
821 supported.iter().filter(move |v| **v <= max).cloned()
822 }
823
824 pub fn io_optimized_line_sizes_unchecked(
828 &self,
829 size: usize,
830 ) -> impl Iterator<Item = LineSize> + Clone {
831 let load_width = self.properties().hardware.load_width as usize;
832 let size_bits = size * 8;
833 let max = load_width / size_bits;
834 let max = usize::min(R::max_global_line_size(), max);
837
838 let num_candidates = max.trailing_zeros() + 1;
840
841 (0..num_candidates).map(|i| 2usize.pow(i)).rev()
842 }
843}