1use crate::{
2 DeviceProperties,
3 config::{TypeNameFormatLevel, type_name_format},
4 kernel::KernelMetadata,
5 logging::ProfileLevel,
6 memory_management::{MemoryAllocationMode, MemoryUsage},
7 runtime::Runtime,
8 server::{
9 Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
10 CopyDescriptor, CubeCount, ExecutionError, Handle, IoError, LaunchError, ProfileError,
11 ServerCommunication, ServerUtilities,
12 },
13 storage::{BindingResource, ComputeStorage},
14};
15use alloc::format;
16use alloc::sync::Arc;
17use alloc::vec;
18use alloc::vec::Vec;
19use core::ops::DerefMut;
20use cubecl_common::{
21 ExecutionMode,
22 bytes::{AllocationProperty, Bytes},
23 device::{Device, DeviceContext},
24 future::DynFut,
25 profile::ProfileDuration,
26};
27use cubecl_ir::StorageType;
28
29#[allow(unused)]
30use cubecl_common::profile::TimingMethod;
31use cubecl_common::stream_id::StreamId;
32
/// Client used to communicate with a compute server for a given [`Runtime`].
pub struct ComputeClient<R: Runtime> {
    // Shared, device-scoped handle to the underlying compute server.
    context: DeviceContext<R::Server>,
    // Server-wide utilities (info, logger, properties) shared between clients.
    utilities: Arc<ServerUtilities<R::Server>>,
    // Explicit stream override; when `None`, the current thread's stream is used.
    stream_id: Option<StreamId>,
}
40
// Manual `Clone` impl: a `#[derive(Clone)]` would add an unnecessary
// `R: Clone` bound, while only the inner handles need to be duplicated.
impl<R: Runtime> Clone for ComputeClient<R> {
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}
50
51impl<R: Runtime> ComputeClient<R> {
    /// Returns the server-specific info associated with this client.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }
56
57 pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
59 let utilities = server.utilities();
60
61 let context = DeviceContext::<R::Server>::insert(device, server)
62 .expect("Can't create a new client on an already registered server");
63
64 Self {
65 context,
66 utilities,
67 stream_id: None,
68 }
69 }
70
71 pub fn load<D: Device>(device: &D) -> Self {
73 let context = DeviceContext::<R::Server>::locate(device);
74 let utilities = context.lock().utilities();
75
76 Self {
77 context,
78 utilities,
79 stream_id: None,
80 }
81 }
82
83 fn stream_id(&self) -> StreamId {
84 match self.stream_id {
85 Some(val) => val,
86 None => StreamId::current(),
87 }
88 }
89
    /// Forces all subsequent operations on this client handle to use the given
    /// stream instead of the current thread's stream.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not visible from this file —
    /// presumably the caller must ensure ordering/synchronization between
    /// streams remains valid; confirm with the server implementation.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }
98
99 fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
100 let stream_id = self.stream_id();
101 let mut state = self.context.lock();
102 let fut = state.read(descriptors, stream_id);
103 core::mem::drop(state);
104 fut
105 }
106
107 pub fn read_async(
109 &self,
110 handles: Vec<Handle>,
111 ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
112 let strides = [1];
113 let shapes = handles
114 .iter()
115 .map(|it| [it.size() as usize])
116 .collect::<Vec<_>>();
117 let bindings = handles
118 .into_iter()
119 .map(|it| it.binding())
120 .collect::<Vec<_>>();
121 let descriptors = bindings
122 .into_iter()
123 .zip(shapes.iter())
124 .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
125 .collect();
126
127 self.do_read(descriptors)
128 }
129
130 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
136 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
137 }
138
139 pub fn read_one(&self, handle: Handle) -> Bytes {
144 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
145 .expect("TODO")
146 .remove(0)
147 }
148
    /// Reads the tensors described by the given copy descriptors asynchronously.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        self.do_read(descriptors)
    }
156
157 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
170 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
171 }
172
173 pub fn read_one_tensor_async(
176 &self,
177 descriptor: CopyDescriptor<'_>,
178 ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
179 let fut = self.read_tensor_async(vec![descriptor]);
180
181 async { Ok(fut.await?.remove(0)) }
182 }
183
184 pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
190 self.read_tensor(vec![descriptor]).remove(0)
191 }
192
193 pub fn get_resource(
195 &self,
196 binding: Binding,
197 ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
198 let stream_id = self.stream_id();
199 self.context.lock().get_resource(binding, stream_id)
200 }
201
202 fn do_create_from_slices(
203 &self,
204 descriptors: Vec<AllocationDescriptor<'_>>,
205 slices: Vec<&[u8]>,
206 ) -> Result<Vec<Allocation>, IoError> {
207 let mut state = self.context.lock();
208 let allocations = state.create(descriptors.clone(), self.stream_id())?;
209 let descriptors = descriptors
210 .into_iter()
211 .zip(allocations.iter())
212 .zip(slices)
213 .map(|((desc, alloc), data)| {
214 (
215 CopyDescriptor::new(
216 alloc.handle.clone().binding(),
217 desc.shape,
218 &alloc.strides,
219 desc.elem_size,
220 ),
221 Bytes::from_bytes_vec(data.to_vec()),
222 )
223 })
224 .collect();
225 let stream_id = self.stream_id();
226 state.write(descriptors, stream_id)?;
227 Ok(allocations)
228 }
229
230 fn do_create(
231 &self,
232 descriptors: Vec<AllocationDescriptor<'_>>,
233 mut data: Vec<Bytes>,
234 ) -> Result<Vec<Allocation>, IoError> {
235 self.staging(data.iter_mut(), true);
236
237 let mut state = self.context.lock();
238 let allocations = state.create(descriptors.clone(), self.stream_id())?;
239 let descriptors = descriptors
240 .into_iter()
241 .zip(allocations.iter())
242 .zip(data)
243 .map(|((desc, alloc), data)| {
244 (
245 CopyDescriptor::new(
246 alloc.handle.clone().binding(),
247 desc.shape,
248 &alloc.strides,
249 desc.elem_size,
250 ),
251 data,
252 )
253 })
254 .collect();
255 let stream_id = self.stream_id();
256 state.write(descriptors, stream_id)?;
257 Ok(allocations)
258 }
259
260 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
266 let shape = [slice.len()];
267
268 self.do_create_from_slices(
269 vec![AllocationDescriptor::new(
270 AllocationKind::Contiguous,
271 &shape,
272 1,
273 )],
274 vec![slice],
275 )
276 .unwrap()
277 .remove(0)
278 .handle
279 }
280
281 pub fn create(&self, data: Bytes) -> Handle {
283 let shape = [data.len()];
284
285 self.do_create(
286 vec![AllocationDescriptor::new(
287 AllocationKind::Contiguous,
288 &shape,
289 1,
290 )],
291 vec![data],
292 )
293 .unwrap()
294 .remove(0)
295 .handle
296 }
297
298 pub fn create_tensor_from_slice(
316 &self,
317 slice: &[u8],
318 shape: &[usize],
319 elem_size: usize,
320 ) -> Allocation {
321 self.do_create_from_slices(
322 vec![AllocationDescriptor::new(
323 AllocationKind::Optimized,
324 shape,
325 elem_size,
326 )],
327 vec![slice],
328 )
329 .unwrap()
330 .remove(0)
331 }
332
333 pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
347 self.do_create(
348 vec![AllocationDescriptor::new(
349 AllocationKind::Optimized,
350 shape,
351 elem_size,
352 )],
353 vec![bytes],
354 )
355 .unwrap()
356 .remove(0)
357 }
358
359 pub fn create_tensors_from_slices(
367 &self,
368 descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
369 ) -> Vec<Allocation> {
370 let (descriptors, data) = descriptors.into_iter().unzip();
371
372 self.do_create_from_slices(descriptors, data).unwrap()
373 }
374
375 pub fn create_tensors(
379 &self,
380 descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
381 ) -> Vec<Allocation> {
382 let (descriptors, data) = descriptors.into_iter().unzip();
383
384 self.do_create(descriptors, data).unwrap()
385 }
386
387 fn do_empty(
388 &self,
389 descriptors: Vec<AllocationDescriptor<'_>>,
390 ) -> Result<Vec<Allocation>, IoError> {
391 let mut state = self.context.lock();
392 state.create(descriptors, self.stream_id())
393 }
394
395 pub fn empty(&self, size: usize) -> Handle {
397 let shape = [size];
398 let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
399 self.do_empty(vec![descriptor]).unwrap().remove(0).handle
400 }
401
402 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
405 let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
406 self.do_empty(vec![descriptor]).unwrap().remove(0)
407 }
408
409 pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
412 self.do_empty(descriptors).unwrap()
413 }
414
    /// Moves the provided buffers into server-provided staging memory.
    ///
    /// When `file_only` is `true`, only file-backed buffers are staged;
    /// otherwise native/other buffers are staged as well. Pinned buffers are
    /// never staged. Best-effort: if the server cannot provide staging
    /// buffers, the inputs are left untouched.
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        // Decide per-buffer whether a staging copy is needed.
        let has_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        let mut to_be_updated = Vec::new();
        // Collect sizes of the buffers to stage, keeping mutable references to
        // them so they can be swapped with the staging buffers afterwards.
        let sizes = bytes
            .filter_map(|b| match has_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        // Deliberate best-effort: on failure, keep the original allocations
        // instead of propagating the error to the caller.
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        core::mem::drop(context);

        // Copy each buffer into its staging counterpart, then swap so the
        // caller's `Bytes` now refers to the staging allocation.
        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }
459
460 pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
462 let shape = [src.size() as usize];
463 let src_descriptor = src.copy_descriptor(&shape, &[1], 1);
464
465 if R::Server::SERVER_COMM_ENABLED {
466 self.to_client_tensor(src_descriptor, dst_server)
467 } else {
468 let alloc_desc = AllocationDescriptor::new(
469 AllocationKind::Contiguous,
470 src_descriptor.shape,
471 src_descriptor.elem_size,
472 );
473 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
474 }
475 }
476
    /// Copies the tensor described by `src_descriptor` from this client's
    /// server to `dst_server`, using direct server-to-server communication when
    /// the server supports it, otherwise a synchronous read/write round-trip.
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if R::Server::SERVER_COMM_ENABLED {
            // The device-kind guard is acquired before both server locks —
            // presumably to enforce a global lock order between client pairs;
            // confirm against `DeviceContext::lock_device_kind`'s contract.
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = R::Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            // Release the server locks before the device-kind guard.
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }
511
    /// Dispatches `kernel` on the server, routing through the profiler when a
    /// profiling level is enabled on the logger.
    ///
    /// # Safety
    ///
    /// `mode` is forwarded to the server; the caller must uphold whatever
    /// requirements the chosen [`ExecutionMode`] imposes.
    #[track_caller]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let level = self.utilities.logger.profile_level();

        match level {
            // Fast path: no timing; launch directly and optionally log the
            // kernel name.
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_profiled(info, profile);
                }
                result
            }
            // Profiling path: wrap the launch in a profile measurement and
            // register the timing with the logger.
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let (result, profile) = self
                    .profile(
                        || unsafe {
                            // Lock inside the closure so `profile` can take the
                            // device lock first.
                            let mut state = self.context.lock();
                            state.launch(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }
559
    /// Launches `kernel` in checked execution mode on this client's current
    /// stream.
    ///
    /// # Errors
    ///
    /// Returns a [`LaunchError`] if the server fails to launch the kernel.
    #[track_caller]
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: `ExecutionMode::Checked` is the validated launch path.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }
579
    /// Launches `kernel` in unchecked execution mode on this client's current
    /// stream.
    ///
    /// # Safety
    ///
    /// The caller must uphold the requirements of [`ExecutionMode::Unchecked`]
    /// — NOTE(review): the exact guarantees skipped are defined by the server
    /// implementation, not visible from this file; presumably out-of-bounds
    /// accesses are not prevented.
    ///
    /// # Errors
    ///
    /// Returns a [`LaunchError`] if the server fails to launch the kernel.
    #[track_caller]
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            )
        }
    }
605
606 pub fn flush(&self) {
608 let stream_id = self.stream_id();
609 self.context.lock().flush(stream_id)
610 }
611
612 pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
614 let stream_id = self.stream_id();
615 let mut state = self.context.lock();
616 let fut = state.sync(stream_id);
617 core::mem::drop(state);
618 self.utilities.logger.profile_summary();
619
620 fut
621 }
622
    /// Returns the properties of the device this client is bound to.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }
627
    /// Returns mutable device properties, or `None` when the utilities are
    /// shared with other clients (`Arc::get_mut` requires unique ownership).
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }
634
635 pub fn memory_usage(&self) -> MemoryUsage {
637 self.context.lock().memory_usage(self.stream_id())
638 }
639
    /// Changes the memory allocation mode for this client's stream.
    ///
    /// # Safety
    ///
    /// NOTE(review): the safety contract is not visible from this file —
    /// presumably callers must not change the mode while other work relies on
    /// the previous one; confirm with the server implementation.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }
648
    /// Runs `func` with the allocation mode set to
    /// [`MemoryAllocationMode::Persistent`], restoring
    /// [`MemoryAllocationMode::Auto`] afterwards.
    ///
    /// The device lock is held for the whole call so other clients cannot run
    /// while the temporary mode is active.
    ///
    /// NOTE(review): if `func` panics, the mode is never restored to `Auto`
    /// (no drop guard) — confirm whether that is acceptable for the server.
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }
676
    /// Triggers the server's memory cleanup for this client's stream.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }
684
    /// Measures the duration of `func` on the server, returning its output
    /// together with a [`ProfileDuration`].
    ///
    /// The device lock is held for the whole measurement so no other client
    /// work interleaves with the profiled region.
    ///
    /// # Errors
    ///
    /// Returns a [`ProfileError`] when the server fails to start or end the
    /// profile.
    ///
    /// NOTE(review): the `profile-tracy` blocks reference `self.state`
    /// (`properties`, `gpu_client`, `epoch_time`), but this struct only has
    /// `context`/`utilities`, and `result` is reassigned without being `mut` —
    /// this feature-gated code looks stale and likely fails to compile with
    /// the feature enabled; verify with a `profile-tracy` build.
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // CPU-side tracy span covering the whole profiled call.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        // GPU-side tracy span, only when device timing is used.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.state.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .state
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        let result = self.context.lock().end_profile(self.stream_id(), token);

        // Attach timestamp uploads to the profile resolution when tracy's GPU
        // span is active.
        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.state.epoch_time;
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        match result {
            Ok(result) => Ok((out, result)),
            Err(err) => Err(err),
        }
    }
753
    /// Fallback path for cross-client copies when direct server communication
    /// is unavailable: reads the data back from this server synchronously and
    /// writes it into a fresh allocation on `dst_server`.
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // NOTE(review): the allocation and write on `dst_server` use *this*
        // client's stream id, whereas `to_client_tensor` uses
        // `dst_server.stream_id()` for the destination — confirm this is
        // intentional.
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        // Blocking round-trip: read from the source server into host memory.
        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        // Write into the destination allocation using its own strides.
        let desc_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(desc_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }
791
792 pub fn io_optimized_line_sizes(&self, elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
794 let load_width = self.properties().hardware.load_width as usize;
795 let max = (load_width / elem.size_bits()) as u8;
796 let supported = R::supported_line_sizes();
797 supported.iter().filter(move |v| **v <= max).cloned()
798 }
799
800 pub fn io_optimized_line_sizes_unchecked(
804 &self,
805 size: usize,
806 ) -> impl Iterator<Item = u8> + Clone {
807 let load_width = self.properties().hardware.load_width as usize;
808 let size_bits = size * 8;
809 let max = load_width / size_bits;
810 let max = usize::min(R::max_global_line_size() as usize, max);
813
814 let num_candidates = max.trailing_zeros() + 1;
816
817 (0..num_candidates).map(|i| 2u8.pow(i)).rev()
818 }
819}