1use crate::{
2 config::{TypeNameFormatLevel, type_name_format},
3 kernel::KernelMetadata,
4 logging::ProfileLevel,
5 memory_management::{MemoryAllocationMode, MemoryUsage},
6 runtime::Runtime,
7 server::{
8 CommunicationId, ComputeServer, CopyDescriptor, CubeCount, ExecutionMode, Handle, IoError,
9 KernelArguments, MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy,
10 MemoryLayoutStrategy, ProfileError, ReduceOperation, ServerCommunication, ServerError,
11 ServerUtilities,
12 },
13 storage::{ComputeStorage, ManagedResource},
14};
15use alloc::{format, sync::Arc, vec, vec::Vec};
16use cubecl_common::{
17 backtrace::BackTrace,
18 bytes::{AllocationProperty, Bytes},
19 device::{Device, DeviceId},
20 device_handle::DeviceHandle,
21 future::DynFut,
22 profile::ProfileDuration,
23};
24use cubecl_ir::{DeviceProperties, ElemType, VectorSize, features::Features};
25use cubecl_zspace::Shape;
26
27#[allow(unused)]
28use cubecl_common::profile::TimingMethod;
29use cubecl_common::stream_id::StreamId;
30
/// Client used to interact with a compute [server](ComputeServer) for a device.
///
/// Cloning is cheap: the device handle and server utilities are shared.
pub struct ComputeClient<R: Runtime> {
    // Handle used to submit work to the server owning the device.
    device: DeviceHandle<R::Server>,
    // Shared server-wide state (properties, logger, layout policy, ...).
    utilities: Arc<ServerUtilities<R::Server>>,
    // Optional pinned stream; when `None`, the thread-local current stream is used.
    stream_id: Option<StreamId>,
}
38
39impl<R: Runtime> Clone for ComputeClient<R> {
40 fn clone(&self) -> Self {
41 Self {
42 device: self.device.clone(),
43 utilities: self.utilities.clone(),
44 stream_id: self.stream_id,
45 }
46 }
47}
48
49impl<R: Runtime> ComputeClient<R> {
50 pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
52 &self.utilities.info
53 }
54
55 pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
57 let utilities = server.utilities();
58 let context = DeviceHandle::<R::Server>::insert(device.to_id(), server)
59 .expect("Can't create a new client on an already registered server");
60
61 Self {
62 device: context,
63 utilities,
64 stream_id: None,
65 }
66 }
67
68 pub fn load<D: Device>(device: &D) -> Self {
70 let context = DeviceHandle::<R::Server>::new(device.to_id());
71
72 let utilities = context
74 .utilities()
75 .downcast::<ServerUtilities<R::Server>>()
76 .expect("Can downcast to `ServerUtilities`");
77
78 Self {
79 device: context,
80 utilities,
81 stream_id: None,
82 }
83 }
84
85 fn stream_id(&self) -> StreamId {
86 match self.stream_id {
87 Some(val) => val,
88 None => StreamId::current(),
89 }
90 }
91
    /// Pins this client to `stream_id` instead of the thread-local stream.
    ///
    /// # Safety
    ///
    /// All subsequent submissions from this client use the given stream;
    /// presumably the caller must guarantee correct ordering between this
    /// stream and work submitted elsewhere — confirm the exact contract with
    /// the server implementation.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }
100
    /// Submits a blocking read of the given descriptors on this client's stream,
    /// returning a future resolving to the copied bytes.
    fn do_read(&self, descriptors: Vec<CopyDescriptor>) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let stream_id = self.stream_id();
        self.device
            .submit_blocking(move |server| server.read(descriptors, stream_id))
            .unwrap()
    }
107
108 pub fn read_async(
110 &self,
111 handles: Vec<Handle>,
112 ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
113 let shapes = handles
114 .iter()
115 .map(|it| [it.size_in_used() as usize].into())
116 .collect::<Vec<Shape>>();
117 let descriptors = handles
118 .into_iter()
119 .zip(shapes)
120 .map(|(handle, shape)| CopyDescriptor::new(handle.binding(), shape, [1].into(), 1))
121 .collect();
122
123 self.do_read(descriptors)
124 }
125
126 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
132 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
133 }
134
    /// Reads a single handle synchronously, returning an error if the read
    /// fails on the server.
    pub fn read_one(&self, handle: Handle) -> Result<Bytes, ServerError> {
        Ok(cubecl_common::reader::read_sync(self.read_async(vec![handle]))?.remove(0))
    }
139
140 pub fn read_one_unchecked(&self, handle: Handle) -> Bytes {
146 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
147 .unwrap()
148 .remove(0)
149 }
150
    /// Reads the given tensor descriptors asynchronously.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
        self.do_read(descriptors)
    }
158
159 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor>) -> Vec<Bytes> {
172 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
173 }
174
175 pub fn read_one_tensor_async(
178 &self,
179 descriptor: CopyDescriptor,
180 ) -> impl Future<Output = Result<Bytes, ServerError>> + Send {
181 let fut = self.read_tensor_async(vec![descriptor]);
182
183 async { Ok(fut.await?.remove(0)) }
184 }
185
    /// Reads a single tensor descriptor synchronously, panicking if the read
    /// fails on the server.
    pub fn read_one_unchecked_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }
195
    /// Resolves a handle into its underlying storage resource on the server.
    pub fn get_resource(
        &self,
        handle: Handle,
    ) -> Result<
        ManagedResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource>,
        ServerError,
    > {
        let stream_id = self.stream_id();
        let binding = handle.binding();

        self.device
            .submit_blocking(move |state| state.get_resource(binding, stream_id))
            .unwrap()
    }
211
212 fn do_create_from_slices(
213 &self,
214 descriptors: Vec<MemoryLayoutDescriptor>,
215 slices: Vec<Vec<u8>>,
216 ) -> Result<Vec<MemoryLayout>, IoError> {
217 let stream_id = self.stream_id();
218 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
219
220 let descriptors = descriptors
221 .into_iter()
222 .zip(layouts.iter())
223 .zip(slices)
224 .map(|((desc, alloc), data)| {
225 (
226 CopyDescriptor::new(
227 alloc.memory.clone().binding(),
228 desc.shape,
229 alloc.strides.clone(),
230 desc.elem_size,
231 ),
232 Bytes::from_bytes_vec(data.to_vec()),
233 )
234 })
235 .collect::<Vec<_>>();
236
237 let (size, memory) = (handle_base.size(), handle_base.memory);
238 self.device.submit(move |server| {
239 server.initialize_memory(memory, size, stream_id);
240 server.write(descriptors, stream_id);
241 });
242
243 Ok(layouts)
244 }
245
246 fn do_create(
247 &self,
248 descriptors: Vec<MemoryLayoutDescriptor>,
249 mut data: Vec<Bytes>,
250 ) -> Result<Vec<MemoryLayout>, IoError> {
251 self.staging(data.iter_mut(), true);
252
253 let stream_id = self.stream_id();
254 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
255
256 let descriptors = descriptors
257 .into_iter()
258 .zip(layouts.iter())
259 .zip(data)
260 .map(|((desc, layout), data)| {
261 (
262 CopyDescriptor::new(
263 layout.memory.clone().binding(),
264 desc.shape,
265 layout.strides.clone(),
266 desc.elem_size,
267 ),
268 Bytes::from_bytes_vec(data.to_vec()),
269 )
270 })
271 .collect::<Vec<_>>();
272
273 let (size, memory) = (handle_base.size(), handle_base.memory);
274 self.device.submit(move |server| {
275 server.initialize_memory(memory, size, stream_id);
276 server.write(descriptors, stream_id);
277 });
278
279 Ok(layouts)
280 }
281
282 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
288 let shape: Shape = [slice.len()].into();
289
290 self.do_create_from_slices(
291 vec![MemoryLayoutDescriptor::new(
292 MemoryLayoutStrategy::Contiguous,
293 shape,
294 1,
295 )],
296 vec![slice.to_vec()],
297 )
298 .unwrap()
299 .remove(0)
300 .memory
301 }
302
    /// Runs `task` while holding exclusive access to the device, returning an
    /// error if the communication channel with the server is down.
    pub fn exclusive<'a, Re: Send + 'static, F: FnOnce() -> Re + Send + 'a>(
        &'a self,
        task: F,
    ) -> Result<Re, ServerError> {
        self.device
            .exclusive(task)
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }
316
317 pub fn memory_persistent_allocation<
319 'a,
320 Re: Send,
321 Input: Send,
322 F: FnOnce(Input) -> Re + Send + 'a,
323 >(
324 &'a self,
325 input: Input,
326 task: F,
327 ) -> Result<Re, ServerError> {
328 let stream_id = StreamId::current();
329
330 self.device.submit(move |server| {
331 server.allocation_mode(MemoryAllocationMode::Persistent, stream_id);
332 });
333
334 let output = task(input);
336
337 self.device.submit(move |server| {
338 server.allocation_mode(MemoryAllocationMode::Auto, stream_id);
339 });
340
341 Ok(output)
342 }
343
344 pub fn create(&self, data: Bytes) -> Handle {
346 let shape = [data.len()].into();
347
348 self.do_create(
349 vec![MemoryLayoutDescriptor::new(
350 MemoryLayoutStrategy::Contiguous,
351 shape,
352 1,
353 )],
354 vec![data],
355 )
356 .unwrap()
357 .remove(0)
358 .memory
359 }
360
361 pub fn create_tensor_from_slice(
379 &self,
380 slice: &[u8],
381 shape: Shape,
382 elem_size: usize,
383 ) -> MemoryLayout {
384 self.do_create_from_slices(
385 vec![MemoryLayoutDescriptor::new(
386 MemoryLayoutStrategy::Optimized,
387 shape,
388 elem_size,
389 )],
390 vec![slice.to_vec()],
391 )
392 .unwrap()
393 .remove(0)
394 }
395
396 pub fn create_tensor(&self, bytes: Bytes, shape: Shape, elem_size: usize) -> MemoryLayout {
410 self.do_create(
411 vec![MemoryLayoutDescriptor::new(
412 MemoryLayoutStrategy::Optimized,
413 shape,
414 elem_size,
415 )],
416 vec![bytes],
417 )
418 .unwrap()
419 .remove(0)
420 }
421
422 pub fn create_tensors_from_slices(
430 &self,
431 descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
432 ) -> Vec<MemoryLayout> {
433 let mut data = Vec::with_capacity(descriptors.len());
434 let mut descriptors_ = Vec::with_capacity(descriptors.len());
435 for (a, b) in descriptors {
436 data.push(b.to_vec());
437 descriptors_.push(a);
438 }
439
440 self.do_create_from_slices(descriptors_, data).unwrap()
441 }
442
443 pub fn create_tensors(
447 &self,
448 descriptors: Vec<(MemoryLayoutDescriptor, Bytes)>,
449 ) -> Vec<MemoryLayout> {
450 let (descriptors, data) = descriptors.into_iter().unzip();
451
452 self.do_create(descriptors, data).unwrap()
453 }
454
    /// Allocates memory for each descriptor without writing any data.
    fn do_empty(
        &self,
        descriptors: Vec<MemoryLayoutDescriptor>,
    ) -> Result<Vec<MemoryLayout>, IoError> {
        let stream_id = self.stream_id();
        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);

        let (size, memory) = (handle_base.size(), handle_base.memory);
        self.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
        });

        Ok(layouts)
    }
469
470 pub fn empty(&self, size: usize) -> Handle {
472 let shape: Shape = [size].into();
473 let descriptor = MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, 1);
474 self.do_empty(vec![descriptor]).unwrap().remove(0).memory
475 }
476
477 pub fn empty_tensor(&self, shape: Shape, elem_size: usize) -> MemoryLayout {
480 let descriptor =
481 MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size);
482 self.do_empty(vec![descriptor]).unwrap().remove(0)
483 }
484
    /// Allocates several uninitialized tensors in one server submission.
    pub fn empty_tensors(&self, descriptors: Vec<MemoryLayoutDescriptor>) -> Vec<MemoryLayout> {
        self.do_empty(descriptors).unwrap()
    }
490
491 pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
496 where
497 I: Iterator<Item = &'a mut Bytes>,
498 {
499 let has_staging = |b: &Bytes| match b.property() {
500 AllocationProperty::Pinned => false,
501 AllocationProperty::File => true,
502 AllocationProperty::Native | AllocationProperty::Other => !file_only,
503 };
504
505 let mut to_be_updated = Vec::new();
506 let sizes = bytes
507 .filter_map(|b| match has_staging(b) {
508 true => {
509 let len = b.len();
510 to_be_updated.push(b);
511 Some(len)
512 }
513 false => None,
514 })
515 .collect::<Vec<usize>>();
516
517 if sizes.is_empty() {
518 return;
519 }
520
521 let stream_id = self.stream_id();
522 let sizes = sizes.to_vec();
523 let stagings = self
524 .device
525 .submit_blocking(move |server| server.staging(&sizes, stream_id))
526 .unwrap();
527
528 let stagings = match stagings {
529 Ok(val) => val,
530 Err(_) => return,
531 };
532
533 to_be_updated
534 .into_iter()
535 .zip(stagings)
536 .for_each(|(b, mut staging)| {
537 b.copy_into(&mut staging);
538 core::mem::swap(b, &mut staging);
539 });
540 }
541
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst_server))
    )]
    /// Transfers the contents of `src` to `dst_server`, returning the new handle.
    ///
    /// Uses direct server-to-server communication when the server supports it;
    /// otherwise falls back to a synchronous read-then-write copy.
    pub fn to_client(&mut self, src: Handle, dst_server: &Self, dtype: ElemType) -> Handle {
        // View the handle as a flat byte buffer.
        let shape = [src.size_in_used() as usize];
        let src_descriptor = src.copy_descriptor(shape.into(), [1].into(), 1);

        if R::Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server, dtype)
        } else {
            let alloc_desc = MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                src_descriptor.shape.clone(),
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
                .memory
        }
    }
563
564 #[cfg_attr(
566 feature = "tracing",
567 tracing::instrument(level = "trace", skip(self, device_ids))
568 )]
569 pub fn ensure_init_collective(&mut self, device_ids: Vec<DeviceId>) {
570 let comm_id = CommunicationId::from(device_ids.clone());
571 let is_comms_init = self
572 .utilities
573 .initialized_comms
574 .read()
575 .unwrap()
576 .contains(&comm_id);
577 if !is_comms_init {
578 self.device
579 .submit(move |server| server.comm_init(device_ids).unwrap());
580 let mut initialized_comms = self.utilities.initialized_comms.write().unwrap();
581 initialized_comms.insert(comm_id);
582 self.device.flush_queue();
584 }
585 }
586
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    /// Waits for pending collective operations on this client's stream.
    ///
    /// # Panics
    ///
    /// Panics when used with a blocking device handle, or if the server-side
    /// sync fails.
    pub fn sync_collective(&self) {
        if DeviceHandle::<R::Server>::is_blocking() {
            panic!("Can't use `sync_collective` with a blocking device handle");
        }
        let stream_id = self.stream_id();

        self.device.submit(move |server| {
            server.sync_collective(stream_id).unwrap();
        });

        // Ensure the submitted sync actually reaches the server.
        self.device.flush_queue();
    }
603
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst, dtype, device_ids, op))
    )]
    /// Performs an all-reduce of `src` into `dst` across the given devices.
    ///
    /// # Panics
    ///
    /// Panics when used with a blocking device handle, or if the server-side
    /// operation fails.
    pub fn all_reduce(
        &mut self,
        src: Handle,
        dst: Handle,
        dtype: ElemType,
        device_ids: Vec<DeviceId>,
        op: ReduceOperation,
    ) {
        if DeviceHandle::<R::Server>::is_blocking() {
            panic!("Can't use `all_reduce` with a blocking device handle");
        }

        let stream_id = self.stream_id();
        let src = src.binding();
        let dst = dst.binding();

        // Lazily set up the communicator for this set of devices.
        self.ensure_init_collective(device_ids.clone());

        self.device.submit(move |server| {
            server
                .all_reduce(src, dst, dtype, stream_id, op, device_ids)
                .unwrap();
        });
    }
633
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    /// Transfers the tensor described by `src_descriptor` to `dst_server`
    /// using server-to-server communication, returning the destination handle.
    pub fn to_client_tensor(
        &mut self,
        src_descriptor: CopyDescriptor,
        dst_server: &Self,
        dtype: ElemType,
    ) -> Handle {
        let stream_id_src = self.stream_id();
        let stream_id_dst = dst_server.stream_id();

        let device_id_src = self.device.device_id();
        let device_id_dst = dst_server.device.device_id();

        let mut dst_server = dst_server.clone();
        // The destination handle is created up front so it can be returned
        // without waiting for the transfer to complete.
        let handle = Handle::new(stream_id_dst, src_descriptor.handle.size_in_used());
        let handle_cloned = handle.clone();

        // Both sides need the communicator for this device pair.
        let device_ids = vec![device_id_src, device_id_dst];
        self.ensure_init_collective(device_ids.clone());
        dst_server.ensure_init_collective(device_ids);

        self.device.submit(move |server_src| {
            server_src
                .send(src_descriptor, dtype, stream_id_src, device_id_dst)
                .unwrap()
        });

        dst_server.device.submit(move |server_dst| {
            server_dst
                .recv(handle_cloned, dtype, stream_id_dst, device_id_src)
                .unwrap();
            server_dst.sync_collective(stream_id_dst).unwrap();
        });

        // Flush both queues so the send/recv pair can make progress.
        self.device.flush_queue();
        dst_server.device.flush_queue();

        handle
    }
682
    /// Launches a kernel, optionally recording execution or full profiling
    /// information depending on the configured profile level.
    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            // Fast path: fire-and-forget submission, optionally logging the
            // kernel name when `ExecutionOnly` profiling is active.
            None | Some(ProfileLevel::ExecutionOnly) => {
                let utilities = self.utilities.clone();
                self.device.submit(move |state| {
                    let name = kernel.name();
                    unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                    if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                        let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                        utilities.logger.register_execution(info);
                    }
                });
            }
            // Profiling path: run the launch inside `profile` (blocking) and
            // register the measured duration with the logger.
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let context = self.device.clone();
                let count_moved = count.clone();
                let (result, profile) = self
                    .profile(
                        move || {
                            context
                                .submit_blocking(move |state| unsafe {
                                    state.launch(kernel, count_moved, bindings, mode, stream_id)
                                })
                                .unwrap()
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }
742
    /// Launches `kernel` over `count` cubes with bounds checks enabled.
    #[track_caller]
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
    ) {
        unsafe {
            // Checked mode is always used here, regardless of configuration.
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }
762
    /// Launches `kernel` requesting unchecked execution.
    ///
    /// The effective mode still honors the configured bounds-check mode:
    /// `Enforce` keeps checks on, `Validate` runs validation, and only `Auto`
    /// actually executes unchecked.
    ///
    /// # Safety
    ///
    /// When running unchecked, the caller must guarantee that all kernel
    /// memory accesses are in bounds.
    #[track_caller]
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
    ) {
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                // Map the configured bounds-check mode to an execution mode.
                match self.utilities.check_mode {
                    crate::config::compilation::BoundsCheckMode::Enforce => ExecutionMode::Checked,
                    crate::config::compilation::BoundsCheckMode::Validate => {
                        ExecutionMode::Validate
                    }
                    crate::config::compilation::BoundsCheckMode::Auto => ExecutionMode::Unchecked,
                },
                self.stream_id(),
            )
        }
    }
794
    /// Flushes pending work on this client's stream to the server.
    pub fn flush(&self) -> Result<(), ServerError> {
        let stream_id = self.stream_id();

        self.device
            .submit_blocking(move |server| server.flush(stream_id))
            .unwrap()
    }
803
    /// Returns a future that resolves once all submitted work on this client's
    /// stream has completed; also emits the profiling summary.
    pub fn sync(&self) -> DynFut<Result<(), ServerError>> {
        let stream_id = self.stream_id();

        let fut = self
            .device
            .submit_blocking(move |server| server.sync(stream_id))
            .unwrap();

        self.utilities.logger.profile_summary();

        fut
    }
817
818 pub fn properties(&self) -> &DeviceProperties {
820 &self.utilities.properties
821 }
822
823 pub fn features(&self) -> &Features {
825 &self.utilities.properties.features
826 }
827
    /// Returns mutable device properties, but only when no other clone of this
    /// client's utilities exists (relies on `Arc::get_mut`).
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }
834
    /// Queries the current memory usage statistics from the server.
    pub fn memory_usage(&self) -> Result<MemoryUsage, ServerError> {
        let stream_id = self.stream_id();
        self.device
            .submit_blocking(move |server| server.memory_usage(stream_id))
            .unwrap()
    }
842
843 pub fn enumerate_devices(&self, type_id: u16) -> Vec<DeviceId> {
845 R::enumerate_devices(type_id, self.info())
846 }
847
848 pub fn enumerate_all_devices(&self) -> Vec<DeviceId> {
850 R::enumerate_all_devices(self.info())
851 }
852
853 pub fn device_count(&self, type_id: u16) -> usize {
855 self.enumerate_devices(type_id).len()
856 }
857
858 pub fn device_count_total(&self) -> usize {
860 self.enumerate_all_devices().len()
861 }
862
    /// Changes the server's memory allocation mode for this client's stream.
    ///
    /// # Safety
    ///
    /// Marked unsafe by the API; presumably changing the allocation mode can
    /// affect in-flight allocations — confirm the exact contract with the
    /// server implementation.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.allocation_mode(mode, stream_id));
    }
873
    /// Asks the server to clean up its memory on this client's stream.
    pub fn memory_cleanup(&self) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.memory_cleanup(stream_id));
    }
883
    /// Profiles `func`, returning its output together with the measured
    /// duration of the server work it submitted.
    ///
    /// Takes exclusive access to the device, brackets `func` with
    /// `start_profile`/`end_profile` on this client's stream, and (with the
    /// `profile-tracy` feature) records CPU and GPU spans in Tracy.
    #[track_caller]
    pub fn profile<O: Send + 'static>(
        &self,
        func: impl FnOnce() -> O + Send,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // CPU span covering the whole profiled region.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let stream_id = self.stream_id();

        // GPU span is only meaningful when timings come from the device.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let device = self.device.clone();
        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self
            .device
            .exclusive(move || {
                // Start profiling; a channel-level failure is converted into a
                // generic server error.
                let token =
                    match device.submit_blocking(move |server| server.start_profile(stream_id)) {
                        Ok(token) => match token {
                            Ok(token) => token,
                            Err(err) => return Err(err),
                        },
                        Err(err) => {
                            return Err(ServerError::Generic {
                                reason: alloc::format!(
                                    "Can't start profiling because of a call error: {err:?}"
                                ),
                                backtrace: BackTrace::capture(),
                            });
                        }
                    };

                // Run the profiled workload while holding exclusive access.
                let out = func();

                // Stop profiling and pair the measured duration with the output.
                let result = device
                    .submit_blocking(move |server| {
                        let mut result = server.end_profile(stream_id, token);

                        match result {
                            Ok(result) => Ok((out, result)),
                            Err(err) => Err(err),
                        }
                    })
                    .unwrap();

                Ok(result)
            })
            .unwrap()
            .map_err(|err| ProfileError::Unknown {
                reason: alloc::format!("{err:?}"),
                backtrace: BackTrace::capture(),
            })?;

        // Upload the device timestamps to Tracy once the profile resolves.
        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            result = result.map(|(o, result)| {
                (
                    o,
                    ProfileDuration::new(
                        alloc::boxed::Box::pin(async move {
                            let ticks = result.resolve().await;
                            let start_duration =
                                ticks.start_duration_since(epoch).as_nanos() as i64;
                            let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                            gpu_span.upload_timestamp_start(start_duration);
                            gpu_span.upload_timestamp_end(end_duration);
                            ticks
                        }),
                        TimingMethod::Device,
                    ),
                )
            });
        }

        result
    }
992
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    /// Copies a tensor to another client by reading it back to the host and
    /// writing it on the destination server — the fallback path when direct
    /// server-to-server communication is not available.
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor,
        alloc_descriptor: MemoryLayoutDescriptor,
        dst_server: &Self,
    ) -> MemoryLayout {
        let shape = src_descriptor.shape.clone();
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Blocking host read of the source tensor.
        let read = self
            .device
            .submit_blocking(move |server| server.read(vec![src_descriptor], stream_id))
            .unwrap();

        let mut data = cubecl_common::future::block_on(read).unwrap();

        // NOTE(review): the destination allocation uses *this* client's layout
        // policy and the write below reuses the *source* stream id — confirm
        // these shouldn't come from `dst_server` instead.
        let (handle_base, mut layouts) = self
            .utilities
            .layout_policy
            .apply(stream_id, &[alloc_descriptor]);
        let alloc = layouts.remove(0);

        let desc_descriptor = CopyDescriptor {
            handle: handle_base.clone().binding(),
            shape,
            strides: alloc.strides.clone(),
            elem_size,
        };

        let (size, memory) = (handle_base.size(), handle_base.memory);
        dst_server.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
            server.write(vec![(desc_descriptor, data.remove(0))], stream_id)
        });

        alloc
    }
1039
1040 pub fn io_optimized_vector_sizes(
1042 &self,
1043 size: usize,
1044 ) -> impl Iterator<Item = VectorSize> + Clone {
1045 let load_width = self.properties().hardware.load_width as usize;
1046 let size_bits = size * 8;
1047 let max = load_width / size_bits;
1048 let max = usize::min(self.properties().hardware.max_vector_size, max);
1049
1050 let num_candidates = max.trailing_zeros() + 1;
1052
1053 (0..num_candidates).map(|i| 2usize.pow(i)).rev()
1054 }
1055}