1use crate::{
2 config::{TypeNameFormatLevel, type_name_format},
3 kernel::KernelMetadata,
4 logging::ProfileLevel,
5 memory_management::{MemoryAllocationMode, MemoryUsage},
6 runtime::Runtime,
7 server::{
8 Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
9 CopyDescriptor, CubeCount, ExecutionError, ExecutionMode, Handle, IoError, LaunchError,
10 ProfileError, ServerCommunication, ServerUtilities,
11 },
12 storage::{BindingResource, ComputeStorage},
13};
14use alloc::format;
15use alloc::sync::Arc;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::ops::DerefMut;
19use cubecl_common::{
20 bytes::{AllocationProperty, Bytes},
21 device::{Device, DeviceContext},
22 future::DynFut,
23 profile::ProfileDuration,
24};
25use cubecl_ir::{DeviceProperties, LineSize};
26
27#[cfg(feature = "profile-tracy")]
28use alloc::boxed::Box;
29
30#[allow(unused)]
31use cubecl_common::profile::TimingMethod;
32use cubecl_common::stream_id::StreamId;
33
/// Client used to communicate with a compute server on a given device.
///
/// Cheap to clone: both the device context and the utilities are shared
/// handles.
pub struct ComputeClient<R: Runtime> {
    // Shared handle to the device's server state; locked on each operation.
    context: DeviceContext<R::Server>,
    // Server-wide utilities (logger, properties, info), shared via `Arc`.
    utilities: Arc<ServerUtilities<R::Server>>,
    // Explicit stream override; when `None`, the current thread's stream is used.
    stream_id: Option<StreamId>,
}
41
42impl<R: Runtime> Clone for ComputeClient<R> {
43 fn clone(&self) -> Self {
44 Self {
45 context: self.context.clone(),
46 utilities: self.utilities.clone(),
47 stream_id: self.stream_id,
48 }
49 }
50}
51
52impl<R: Runtime> ComputeClient<R> {
    /// Returns the server-specific info attached to this client.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }
57
58 pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
60 let utilities = server.utilities();
61
62 let context = DeviceContext::<R::Server>::insert(device, server)
63 .expect("Can't create a new client on an already registered server");
64
65 Self {
66 context,
67 utilities,
68 stream_id: None,
69 }
70 }
71
72 pub fn load<D: Device>(device: &D) -> Self {
74 let context = DeviceContext::<R::Server>::locate(device);
75 let utilities = context.lock().utilities();
76
77 Self {
78 context,
79 utilities,
80 stream_id: None,
81 }
82 }
83
84 fn stream_id(&self) -> StreamId {
85 match self.stream_id {
86 Some(val) => val,
87 None => StreamId::current(),
88 }
89 }
90
    /// Pins this client to `stream_id` instead of the per-thread stream.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact safety contract is not visible here; callers
    /// are presumably responsible for keeping cross-stream ordering sound —
    /// confirm against the server documentation.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }
99
100 fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
101 let stream_id = self.stream_id();
102 let mut state = self.context.lock();
103 let fut = state.read(descriptors, stream_id);
104 core::mem::drop(state);
105 fut
106 }
107
108 pub fn read_async(
110 &self,
111 handles: Vec<Handle>,
112 ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
113 let strides = [1];
114 let shapes = handles
115 .iter()
116 .map(|it| [it.size() as usize])
117 .collect::<Vec<_>>();
118 let bindings = handles
119 .into_iter()
120 .map(|it| it.binding())
121 .collect::<Vec<_>>();
122 let descriptors = bindings
123 .into_iter()
124 .zip(shapes.iter())
125 .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
126 .collect();
127
128 self.do_read(descriptors)
129 }
130
131 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
137 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
138 }
139
140 pub fn read_one(&self, handle: Handle) -> Bytes {
145 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
146 .expect("TODO")
147 .remove(0)
148 }
149
    /// Reads the tensor data described by `descriptors` asynchronously.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        self.do_read(descriptors)
    }
157
158 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
171 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
172 }
173
174 pub fn read_one_tensor_async(
177 &self,
178 descriptor: CopyDescriptor<'_>,
179 ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
180 let fut = self.read_tensor_async(vec![descriptor]);
181
182 async { Ok(fut.await?.remove(0)) }
183 }
184
    /// Reads a single tensor, blocking until the data is available.
    ///
    /// # Panics
    ///
    /// Panics if the read fails.
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }
193
194 pub fn get_resource(
196 &self,
197 binding: Binding,
198 ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
199 let stream_id = self.stream_id();
200 self.context.lock().get_resource(binding, stream_id)
201 }
202
    /// Allocates server memory for each descriptor and uploads the matching
    /// CPU slice into it.
    ///
    /// Returns one [`Allocation`] per descriptor, or an [`IoError`] if the
    /// allocation or the write fails.
    fn do_create_from_slices(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        slices: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        // Allocate first; the returned allocations carry the strides needed
        // to build the copy descriptors below.
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(slices)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    // Copy the borrowed slice into owned bytes for the upload.
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }
230
    /// Allocates server memory for each descriptor and uploads the matching
    /// owned bytes into it.
    ///
    /// Returns one [`Allocation`] per descriptor, or an [`IoError`] if the
    /// allocation or the write fails.
    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        mut data: Vec<Bytes>,
    ) -> Result<Vec<Allocation>, IoError> {
        // Move file-backed bytes into staging buffers before uploading (see
        // `staging`); `file_only = true` leaves native allocations untouched.
        self.staging(data.iter_mut(), true);

        let mut state = self.context.lock();
        // Allocate first; the returned allocations carry the strides needed
        // to build the copy descriptors below.
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }
260
261 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
267 let shape = [slice.len()];
268
269 self.do_create_from_slices(
270 vec![AllocationDescriptor::new(
271 AllocationKind::Contiguous,
272 &shape,
273 1,
274 )],
275 vec![slice],
276 )
277 .unwrap()
278 .remove(0)
279 .handle
280 }
281
282 pub fn create(&self, data: Bytes) -> Handle {
284 let shape = [data.len()];
285
286 self.do_create(
287 vec![AllocationDescriptor::new(
288 AllocationKind::Contiguous,
289 &shape,
290 1,
291 )],
292 vec![data],
293 )
294 .unwrap()
295 .remove(0)
296 .handle
297 }
298
299 pub fn create_tensor_from_slice(
317 &self,
318 slice: &[u8],
319 shape: &[usize],
320 elem_size: usize,
321 ) -> Allocation {
322 self.do_create_from_slices(
323 vec![AllocationDescriptor::new(
324 AllocationKind::Optimized,
325 shape,
326 elem_size,
327 )],
328 vec![slice],
329 )
330 .unwrap()
331 .remove(0)
332 }
333
334 pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
348 self.do_create(
349 vec![AllocationDescriptor::new(
350 AllocationKind::Optimized,
351 shape,
352 elem_size,
353 )],
354 vec![bytes],
355 )
356 .unwrap()
357 .remove(0)
358 }
359
360 pub fn create_tensors_from_slices(
368 &self,
369 descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
370 ) -> Vec<Allocation> {
371 let (descriptors, data) = descriptors.into_iter().unzip();
372
373 self.do_create_from_slices(descriptors, data).unwrap()
374 }
375
376 pub fn create_tensors(
380 &self,
381 descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
382 ) -> Vec<Allocation> {
383 let (descriptors, data) = descriptors.into_iter().unzip();
384
385 self.do_create(descriptors, data).unwrap()
386 }
387
    /// Allocates uninitialized server memory for each descriptor.
    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }
395
396 pub fn empty(&self, size: usize) -> Handle {
398 let shape = [size];
399 let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
400 self.do_empty(vec![descriptor]).unwrap().remove(0).handle
401 }
402
403 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
406 let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
407 self.do_empty(vec![descriptor]).unwrap().remove(0)
408 }
409
    /// Allocates one empty tensor per descriptor in a single batch.
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }
415
    /// Best-effort migration of `bytes` into server staging buffers.
    ///
    /// When `file_only` is true, only file-backed bytes are migrated;
    /// otherwise native/other allocations are migrated as well. Pinned bytes
    /// are never migrated. Each migrated `Bytes` is swapped in place with a
    /// staging buffer holding a copy of its data. If the server cannot
    /// provide staging buffers, the input is left untouched.
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        // Decide which allocations are eligible for staging.
        let has_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        // Collect the lengths of every eligible buffer and keep a mutable
        // reference to each so they can be swapped below.
        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match has_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        // Best-effort: if staging buffers can't be allocated, keep the
        // original bytes unchanged.
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        // Release the server lock before copying data into the buffers.
        core::mem::drop(context);

        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                // Copy the data, then replace the original bytes with the
                // staging buffer in place.
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }
460
461 #[cfg_attr(
463 feature = "tracing",
464 tracing::instrument(level = "trace", skip(self, src, dst_server))
465 )]
466 pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
467 let shape = [src.size() as usize];
468 let src_descriptor = src.copy_descriptor(&shape, &[1], 1);
469
470 if R::Server::SERVER_COMM_ENABLED {
471 self.to_client_tensor(src_descriptor, dst_server)
472 } else {
473 let alloc_desc = AllocationDescriptor::new(
474 AllocationKind::Contiguous,
475 src_descriptor.shape,
476 src_descriptor.elem_size,
477 );
478 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
479 }
480 }
481
    /// Copies a tensor to another client.
    ///
    /// Uses the server's native server-to-server copy when
    /// `SERVER_COMM_ENABLED` is set; otherwise falls back to a blocking
    /// read/write round-trip through the host.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if R::Server::SERVER_COMM_ENABLED {
            // Lock ordering: the device-kind guard is taken before either
            // server lock — presumably to serialize cross-device copies and
            // avoid lock-order inversions; confirm against `DeviceContext`.
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = R::Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            // Release both server locks before the device-kind guard.
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }
520
    /// Shared launch path for [`Self::launch`] and [`Self::launch_unchecked`].
    ///
    /// Depending on the logger's profile level, the launch is either sent
    /// directly (optionally registering the execution) or wrapped in
    /// [`Self::profile`] so its duration can be recorded.
    ///
    /// # Safety
    ///
    /// The caller must uphold the requirements of the given `mode`; with
    /// [`ExecutionMode::Unchecked`] the server-side validation is skipped.
    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let level = self.utilities.logger.profile_level();

        match level {
            // Fast path: no profiling, or execution counting only.
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                // Grab the name before `kernel` is moved into `launch`.
                let name = kernel.name();

                let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
                result
            }
            // Profiled path: wrap the launch in a timing measurement.
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let (result, profile) = self
                    .profile(
                        || unsafe {
                            // The lock is taken inside the closure so the
                            // profile machinery can lock the context first.
                            let mut state = self.context.lock();
                            state.launch(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }
575
576 #[track_caller]
578 pub fn launch(
579 &self,
580 kernel: <R::Server as ComputeServer>::Kernel,
581 count: CubeCount,
582 bindings: Bindings,
583 ) -> Result<(), LaunchError> {
584 unsafe {
586 self.launch_inner(
587 kernel,
588 count,
589 bindings,
590 ExecutionMode::Checked,
591 self.stream_id(),
592 )
593 }
594 }
595
596 #[track_caller]
604 pub unsafe fn launch_unchecked(
605 &self,
606 kernel: <R::Server as ComputeServer>::Kernel,
607 count: CubeCount,
608 bindings: Bindings,
609 ) -> Result<(), LaunchError> {
610 unsafe {
612 self.launch_inner(
613 kernel,
614 count,
615 bindings,
616 ExecutionMode::Unchecked,
617 self.stream_id(),
618 )
619 }
620 }
621
    /// Flushes this client's stream on the server.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id)
    }
627
628 pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
630 let stream_id = self.stream_id();
631 let mut state = self.context.lock();
632 let fut = state.sync(stream_id);
633 core::mem::drop(state);
634 self.utilities.logger.profile_summary();
635
636 fut
637 }
638
    /// Returns the properties of the device this client runs on.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }
643
    /// Returns mutable device properties, or `None` when the utilities are
    /// shared with other clients (i.e. this is not the sole `Arc` reference).
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }
650
    /// Returns the memory usage reported by the server for this client's
    /// stream.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }
655
    /// Changes the memory allocation mode of the server for this stream.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is defined by the server
    /// implementation; see [`Self::memory_persistent_allocation`] for a
    /// scoped alternative that restores the mode automatically.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }
664
    /// Runs `func` with the server in [`MemoryAllocationMode::Persistent`],
    /// restoring [`MemoryAllocationMode::Auto`] afterwards.
    ///
    /// The device guard is held for the whole call, so the persistent window
    /// is not interleaved with other device-locked operations.
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        // Acquired first and released last so the mode switch below is scoped.
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        // Restore the default mode before releasing the device guard.
        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }
692
    /// Triggers the server's memory cleanup for this client's stream.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }
700
    /// Runs `func` and measures its duration on this client's stream.
    ///
    /// Returns the closure's output together with a [`ProfileDuration`], or a
    /// [`ProfileError`] if server-side profiling failed. With the
    /// `profile-tracy` feature enabled, CPU and (for device timing) GPU spans
    /// are also reported to tracy.
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // CPU span for tracy, named after the caller's location.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        // Device guard held across the whole measured region.
        let device_guard = self.context.lock_device();

        // GPU span is only created when device-side timing is in use.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        // Resolve the device timestamps asynchronously and upload them to the
        // tracy GPU span once available.
        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        match result {
            Ok(result) => Ok((out, result)),
            Err(err) => Err(err),
        }
    }
770
    /// Fallback device-to-device transfer through the host: reads the source
    /// buffer synchronously, then writes it into a fresh allocation on the
    /// destination server.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // NOTE(review): the destination allocation and write both use this
        // client's stream id rather than `dst_server.stream_id()` — confirm
        // this is intentional.
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        // Blocking read of the source buffer into host memory.
        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        // Describe the destination region using the allocation's strides.
        let desc_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(desc_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }
815
816 pub fn io_optimized_line_sizes(&self, size: usize) -> impl Iterator<Item = LineSize> + Clone {
818 let load_width = self.properties().hardware.load_width as usize;
819 let size_bits = size * 8;
820 let max = load_width / size_bits;
821 let max = usize::min(self.properties().hardware.max_line_size, max);
822
823 let num_candidates = max.trailing_zeros() + 1;
825
826 (0..num_candidates).map(|i| 2usize.pow(i)).rev()
827 }
828}