1use crate::{
2 config::{TypeNameFormatLevel, type_name_format},
3 kernel::KernelMetadata,
4 logging::ProfileLevel,
5 memory_management::{MemoryAllocationMode, MemoryUsage},
6 runtime::Runtime,
7 server::{
8 ComputeServer, CopyDescriptor, CubeCount, ExecutionMode, Handle, IoError, KernelArguments,
9 MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy, MemoryLayoutStrategy,
10 ProfileError, ReduceOperation, ServerCommunication, ServerError, ServerUtilities,
11 },
12 storage::{ComputeStorage, ManagedResource},
13};
14use alloc::{format, sync::Arc, vec, vec::Vec};
15use cubecl_common::{
16 backtrace::BackTrace,
17 bytes::{AllocationProperty, Bytes},
18 device::{Device, DeviceId},
19 device_handle::DeviceHandle,
20 future::DynFut,
21 profile::ProfileDuration,
22};
23use cubecl_ir::{DeviceProperties, ElemType, VectorSize, features::Features};
24use cubecl_zspace::Shape;
25
26#[allow(unused)]
27use cubecl_common::profile::TimingMethod;
28use cubecl_common::stream_id::StreamId;
29
/// Client used to communicate with the compute server of a device.
///
/// Cloning is cheap: the device handle and utilities are shared, so clones talk
/// to the same server.
pub struct ComputeClient<R: Runtime> {
    // Handle used to submit tasks to the device's server.
    device: DeviceHandle<R::Server>,
    // Shared, read-mostly server metadata (properties, logger, layout policy, ...).
    utilities: Arc<ServerUtilities<R::Server>>,
    // Explicitly pinned stream, if any; when `None` the stream is resolved from
    // the calling thread (see `Self::stream_id`).
    stream_id: Option<StreamId>,
}
37
// Manual `Clone` so `ComputeClient<R>` is clonable without requiring
// `R: Clone`; every field is cheap to clone (handle + `Arc` + `Copy`).
impl<R: Runtime> Clone for ComputeClient<R> {
    fn clone(&self) -> Self {
        Self {
            device: self.device.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}
47
48impl<R: Runtime> ComputeClient<R> {
    /// Returns the runtime-specific server information.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }
53
    /// Registers `server` for `device` and returns a client bound to it.
    ///
    /// # Panics
    ///
    /// Panics if a server is already registered for this device id.
    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
        // Grab the shared utilities before the server is moved into the registry.
        let utilities = server.utilities();
        let context = DeviceHandle::<R::Server>::insert(device.to_id(), server)
            .expect("Can't create a new client on an already registered server");

        Self {
            device: context,
            utilities,
            stream_id: None,
        }
    }
66
    /// Creates a client for an already-registered device.
    ///
    /// # Panics
    ///
    /// Panics if the stored utilities cannot be downcast to
    /// `ServerUtilities<R::Server>` (i.e. the device was registered with a
    /// different server type).
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceHandle::<R::Server>::new(device.to_id());

        // The registry stores utilities type-erased; recover the concrete type.
        let utilities = context
            .utilities()
            .downcast::<ServerUtilities<R::Server>>()
            .expect("Can downcast to `ServerUtilities`");

        Self {
            device: context,
            utilities,
            stream_id: None,
        }
    }
83
84 fn stream_id(&self) -> StreamId {
85 match self.stream_id {
86 Some(val) => val,
87 None => StreamId::current(),
88 }
89 }
90
    /// Pins this client to `stream_id` instead of resolving the stream from the
    /// calling thread.
    ///
    /// # Safety
    ///
    /// Bypassing per-thread stream resolution means the caller is responsible
    /// for ordering: work submitted through this client on the pinned stream
    /// must not race with work other clients submit on the same stream.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }
99
    /// Submits a read for the given copy descriptors, returning a future that
    /// resolves to one `Bytes` per descriptor.
    fn do_read(&self, descriptors: Vec<CopyDescriptor>) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let stream_id = self.stream_id();
        // `unwrap` covers the submission channel only; read errors are
        // reported through the returned future.
        self.device
            .submit_blocking(move |server| server.read(descriptors, stream_id))
            .unwrap()
    }
106
107 pub fn read_async(
109 &self,
110 handles: Vec<Handle>,
111 ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
112 let shapes = handles
113 .iter()
114 .map(|it| [it.size_in_used() as usize].into())
115 .collect::<Vec<Shape>>();
116 let descriptors = handles
117 .into_iter()
118 .zip(shapes)
119 .map(|(handle, shape)| CopyDescriptor::new(handle.binding(), shape, [1].into(), 1))
120 .collect();
121
122 self.do_read(descriptors)
123 }
124
125 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
131 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
132 }
133
    /// Reads a single handle synchronously, returning its contents or the
    /// server error.
    pub fn read_one(&self, handle: Handle) -> Result<Bytes, ServerError> {
        Ok(cubecl_common::reader::read_sync(self.read_async(vec![handle]))?.remove(0))
    }
138
139 pub fn read_one_unchecked(&self, handle: Handle) -> Bytes {
145 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
146 .unwrap()
147 .remove(0)
148 }
149
    /// Reads the tensors described by `descriptors` asynchronously, yielding
    /// one `Bytes` per descriptor.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
        self.do_read(descriptors)
    }
157
158 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor>) -> Vec<Bytes> {
171 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
172 }
173
174 pub fn read_one_tensor_async(
177 &self,
178 descriptor: CopyDescriptor,
179 ) -> impl Future<Output = Result<Bytes, ServerError>> + Send {
180 let fut = self.read_tensor_async(vec![descriptor]);
181
182 async { Ok(fut.await?.remove(0)) }
183 }
184
    /// Reads a single tensor synchronously.
    ///
    /// # Panics
    ///
    /// Panics if the read fails (see [`Self::read_tensor`]).
    pub fn read_one_unchecked_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }
194
    /// Resolves `handle` into the underlying storage resource of the server.
    pub fn get_resource(
        &self,
        handle: Handle,
    ) -> Result<
        ManagedResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource>,
        ServerError,
    > {
        let stream_id = self.stream_id();
        let binding = handle.binding();

        // `unwrap` covers the submission channel only; resolution errors come
        // back through the returned `Result`.
        self.device
            .submit_blocking(move |state| state.get_resource(binding, stream_id))
            .unwrap()
    }
210
211 fn do_create_from_slices(
212 &self,
213 descriptors: Vec<MemoryLayoutDescriptor>,
214 slices: Vec<Vec<u8>>,
215 ) -> Result<Vec<MemoryLayout>, IoError> {
216 let stream_id = self.stream_id();
217 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
218
219 let descriptors = descriptors
220 .into_iter()
221 .zip(layouts.iter())
222 .zip(slices)
223 .map(|((desc, alloc), data)| {
224 (
225 CopyDescriptor::new(
226 alloc.memory.clone().binding(),
227 desc.shape,
228 alloc.strides.clone(),
229 desc.elem_size,
230 ),
231 Bytes::from_bytes_vec(data.to_vec()),
232 )
233 })
234 .collect::<Vec<_>>();
235
236 let (size, memory) = (handle_base.size(), handle_base.memory);
237 self.device.submit(move |server| {
238 server.initialize_memory(memory, size, stream_id);
239 server.write(descriptors, stream_id);
240 });
241
242 Ok(layouts)
243 }
244
245 fn do_create(
246 &self,
247 descriptors: Vec<MemoryLayoutDescriptor>,
248 mut data: Vec<Bytes>,
249 ) -> Result<Vec<MemoryLayout>, IoError> {
250 self.staging(data.iter_mut(), true);
251
252 let stream_id = self.stream_id();
253 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
254
255 let descriptors = descriptors
256 .into_iter()
257 .zip(layouts.iter())
258 .zip(data)
259 .map(|((desc, layout), data)| {
260 (
261 CopyDescriptor::new(
262 layout.memory.clone().binding(),
263 desc.shape,
264 layout.strides.clone(),
265 desc.elem_size,
266 ),
267 Bytes::from_bytes_vec(data.to_vec()),
268 )
269 })
270 .collect::<Vec<_>>();
271
272 let (size, memory) = (handle_base.size(), handle_base.memory);
273 self.device.submit(move |server| {
274 server.initialize_memory(memory, size, stream_id);
275 server.write(descriptors, stream_id);
276 });
277
278 Ok(layouts)
279 }
280
    /// Copies `slice` to the device as a flat, contiguous buffer and returns
    /// the handle to the new allocation.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
        // 1-D shape with element size 1: raw bytes.
        let shape: Shape = [slice.len()].into();

        self.do_create_from_slices(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                shape,
                1,
            )],
            vec![slice.to_vec()],
        )
        .unwrap()
        .remove(0)
        .memory
    }
301
    /// Flushes pending work on the current stream, then runs `task` with
    /// exclusive access to the device.
    pub fn exclusive<Re: Send + 'static, F: FnOnce() -> Re + Send + 'static>(
        &self,
        task: F,
    ) -> Result<Re, ServerError> {
        // Make sure previously submitted work is dispatched before the
        // exclusive section starts.
        self.flush()?;

        self.device
            .exclusive(task)
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }
318
    /// Runs `task` with exclusive access to the device, allowing the closure
    /// to borrow from the caller's scope (no `'static` bound).
    ///
    /// Unlike [`Self::exclusive`], pending work is not flushed first.
    pub fn scoped<'a, Re: Send, F: FnOnce() -> Re + Send + 'a>(
        &'a self,
        task: F,
    ) -> Result<Re, ServerError> {
        self.device
            .exclusive_scoped(task)
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }
332
    /// Runs `task(input)` with exclusive access to the device.
    ///
    /// NOTE(review): despite the name, this currently behaves exactly like
    /// [`Self::scoped`] with an extra argument — no persistent-allocation mode
    /// is toggled here. Confirm whether the server is expected to enable
    /// persistent allocations around exclusive sections.
    pub fn memory_persistent_allocation<
        'a,
        Re: Send,
        Input: Send,
        F: FnOnce(Input) -> Re + Send + 'a,
    >(
        &'a self,
        input: Input,
        task: F,
    ) -> Result<Re, ServerError> {
        self.device
            .exclusive_scoped(move || task(input))
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }
352
    /// Uploads `data` to the device as a flat, contiguous buffer and returns
    /// the handle to the new allocation.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn create(&self, data: Bytes) -> Handle {
        // 1-D shape with element size 1: raw bytes.
        let shape = [data.len()].into();

        self.do_create(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .memory
    }
369
    /// Uploads `slice` as a tensor with the given `shape` and element size,
    /// letting the layout policy pick an optimized (possibly strided) layout.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn create_tensor_from_slice(
        &self,
        slice: &[u8],
        shape: Shape,
        elem_size: usize,
    ) -> MemoryLayout {
        self.do_create_from_slices(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Optimized,
                shape,
                elem_size,
            )],
            vec![slice.to_vec()],
        )
        .unwrap()
        .remove(0)
    }
404
    /// Uploads `bytes` as a tensor with the given `shape` and element size,
    /// letting the layout policy pick an optimized (possibly strided) layout.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn create_tensor(&self, bytes: Bytes, shape: Shape, elem_size: usize) -> MemoryLayout {
        self.do_create(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Optimized,
                shape,
                elem_size,
            )],
            vec![bytes],
        )
        .unwrap()
        .remove(0)
    }
430
431 pub fn create_tensors_from_slices(
439 &self,
440 descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
441 ) -> Vec<MemoryLayout> {
442 let mut data = Vec::with_capacity(descriptors.len());
443 let mut descriptors_ = Vec::with_capacity(descriptors.len());
444 for (a, b) in descriptors {
445 data.push(b.to_vec());
446 descriptors_.push(a);
447 }
448
449 self.do_create_from_slices(descriptors_, data).unwrap()
450 }
451
    /// Uploads several tensors at once from owned `Bytes`, returning the
    /// layouts in the same order.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn create_tensors(
        &self,
        descriptors: Vec<(MemoryLayoutDescriptor, Bytes)>,
    ) -> Vec<MemoryLayout> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }
463
    /// Allocates uninitialized memory for `descriptors` and returns the
    /// resulting layouts without writing any data.
    fn do_empty(
        &self,
        descriptors: Vec<MemoryLayoutDescriptor>,
    ) -> Result<Vec<MemoryLayout>, IoError> {
        let stream_id = self.stream_id();
        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);

        // Only the backing memory is initialized server-side; no writes occur.
        let (size, memory) = (handle_base.size(), handle_base.memory);
        self.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
        });

        Ok(layouts)
    }
478
    /// Allocates an uninitialized, contiguous buffer of `size` bytes.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn empty(&self, size: usize) -> Handle {
        let shape: Shape = [size].into();
        let descriptor = MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).memory
    }
485
    /// Allocates an uninitialized tensor with the given shape and element
    /// size, using the optimized layout strategy.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn empty_tensor(&self, shape: Shape, elem_size: usize) -> MemoryLayout {
        let descriptor =
            MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }
493
    /// Allocates several uninitialized tensors at once.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn empty_tensors(&self, descriptors: Vec<MemoryLayoutDescriptor>) -> Vec<MemoryLayout> {
        self.do_empty(descriptors).unwrap()
    }
499
500 pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
505 where
506 I: Iterator<Item = &'a mut Bytes>,
507 {
508 let has_staging = |b: &Bytes| match b.property() {
509 AllocationProperty::Pinned => false,
510 AllocationProperty::File => true,
511 AllocationProperty::Native | AllocationProperty::Other => !file_only,
512 };
513
514 let mut to_be_updated = Vec::new();
515 let sizes = bytes
516 .filter_map(|b| match has_staging(b) {
517 true => {
518 let len = b.len();
519 to_be_updated.push(b);
520 Some(len)
521 }
522 false => None,
523 })
524 .collect::<Vec<usize>>();
525
526 if sizes.is_empty() {
527 return;
528 }
529
530 let stream_id = self.stream_id();
531 let sizes = sizes.to_vec();
532 let stagings = self
533 .device
534 .submit_blocking(move |server| server.staging(&sizes, stream_id))
535 .unwrap();
536
537 let stagings = match stagings {
538 Ok(val) => val,
539 Err(_) => return,
540 };
541
542 to_be_updated
543 .into_iter()
544 .zip(stagings)
545 .for_each(|(b, mut staging)| {
546 b.copy_into(&mut staging);
547 core::mem::swap(b, &mut staging);
548 });
549 }
550
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst_server))
    )]
    /// Copies the flat buffer behind `src` to the device owned by
    /// `dst_server`, returning the destination handle.
    ///
    /// Uses the direct server-to-server path when the runtime supports it,
    /// otherwise falls back to a synchronous read-then-write round trip
    /// through the host.
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Handle {
        // Treat the source as a contiguous 1-D byte buffer.
        let shape = [src.size_in_used() as usize];
        let src_descriptor = src.copy_descriptor(shape.into(), [1].into(), 1);

        if R::Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                src_descriptor.shape.clone(),
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
                .memory
        }
    }
572
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    /// Submits a collective synchronization on the current stream and flushes
    /// the submission queue.
    ///
    /// # Panics
    ///
    /// Panics when the device handle runs in blocking mode, and (inside the
    /// server task) when the collective sync itself fails.
    pub fn sync_collective(&self) {
        if DeviceHandle::<R::Server>::is_blocking() {
            panic!("Can't use `sync_collective` with a blocking device handle");
        }
        let stream_id = self.stream_id();

        self.device.submit(move |server| {
            server.sync_collective(stream_id).unwrap();
        });
        // Ensure the submitted task is actually dispatched.
        self.device.flush_queue();
    }
588
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst, dtype, device_ids, op))
    )]
    /// Submits an all-reduce of `src` into `dst` with reduction `op` across
    /// the given devices, on the current stream.
    ///
    /// # Panics
    ///
    /// Panics when the device handle runs in blocking mode, and (inside the
    /// server task) when the all-reduce itself fails.
    pub fn all_reduce(
        &self,
        src: Handle,
        dst: Handle,
        dtype: ElemType,
        device_ids: Vec<DeviceId>,
        op: ReduceOperation,
    ) {
        if DeviceHandle::<R::Server>::is_blocking() {
            panic!("Can't use `all_reduce` with a blocking device handle");
        }

        let stream_id = self.stream_id();
        let src = src.binding();
        let dst = dst.binding();

        self.device.submit(move |server| {
            server
                .all_reduce(src, dst, dtype, stream_id, op, device_ids)
                .unwrap();
        });
    }
616
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    /// Copies the tensor described by `src_descriptor` to the device owned by
    /// `dst_server`, returning the destination handle.
    ///
    /// When the runtime supports server-to-server communication, the copy is
    /// performed directly between the two servers; otherwise it falls back to
    /// a synchronous read-then-write round trip through the host.
    pub fn to_client_tensor(&self, src_descriptor: CopyDescriptor, dst_server: &Self) -> Handle {
        if R::Server::SERVER_COMM_ENABLED {
            let stream_id_src = self.stream_id();
            let stream_id_dst = dst_server.stream_id();

            let dst_server = dst_server.clone();
            // Destination handle is pre-created on the destination stream.
            let handle = Handle::new(stream_id_dst, src_descriptor.handle.size_in_used());
            let handle_cloned = handle.clone();

            // Nested scoped submissions: both servers are held simultaneously
            // so the runtime's `copy` can move data directly between them.
            self.device
                .submit_blocking_scoped(move |server_src| {
                    dst_server.device.submit_blocking_scoped(|server_dst| {
                        R::Server::copy(
                            handle_cloned,
                            server_src,
                            server_dst,
                            src_descriptor,
                            stream_id_src,
                            stream_id_dst,
                        )
                    })
                })
                .unwrap();

            handle
        } else {
            let alloc_desc = MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Optimized,
                src_descriptor.shape.clone(),
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
                .memory
        }
    }
660
    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    /// Launches `kernel`, routing through the profiler when a profiling level
    /// is configured.
    ///
    /// With no profiling (or execution-count-only profiling) the launch is
    /// submitted fire-and-forget; otherwise the launch is run through
    /// [`Self::profile`], which blocks for the measurement.
    ///
    /// # Safety
    ///
    /// With `mode == ExecutionMode::Unchecked`, out-of-bounds reads and writes
    /// in the kernel are not guarded; the caller must guarantee all accesses
    /// are in bounds.
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                // Fast path: submit asynchronously; optionally count the
                // execution without timing it.
                let utilities = self.utilities.clone();
                self.device.submit(move |state| {
                    let name = kernel.name();
                    unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                    if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                        let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                        utilities.logger.register_execution(info);
                    }
                });
            }
            Some(level) => {
                // Profiling path: run the launch through `profile`, which
                // blocks until the measurement is available.
                let name = kernel.name();
                let kernel_id = kernel.id();
                let context = self.device.clone();
                let count_moved = count.clone();
                let (result, profile) = self
                    .profile(
                        move || {
                            context
                                .submit_blocking(move |state| unsafe {
                                    state.launch(kernel, count_moved, bindings, mode, stream_id)
                                })
                                .unwrap()
                        },
                        name,
                    )
                    .unwrap();
                // `Full` logs the kernel id and cube count; other levels keep
                // a shortened type name only.
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }
720
    #[track_caller]
    /// Launches `kernel` with bounds checks enabled (checked mode).
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
    ) {
        // SAFETY: checked mode guards out-of-bounds accesses in the kernel.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }
740
    #[track_caller]
    /// Launches `kernel` without bounds checks, subject to the configured
    /// bounds-check mode.
    ///
    /// The configured [`BoundsCheckMode`](crate::config::compilation::BoundsCheckMode)
    /// can override the request: `Enforce` forces checked execution, `Validate`
    /// runs validation, and `Auto` honors the unchecked request.
    ///
    /// # Safety
    ///
    /// When the effective mode is unchecked, out-of-bounds reads and writes in
    /// the kernel are not guarded; the caller must guarantee all accesses are
    /// in bounds for the given `count` and `bindings`.
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
    ) {
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                match self.utilities.check_mode {
                    crate::config::compilation::BoundsCheckMode::Enforce => ExecutionMode::Checked,
                    crate::config::compilation::BoundsCheckMode::Validate => {
                        ExecutionMode::Validate
                    }
                    crate::config::compilation::BoundsCheckMode::Auto => ExecutionMode::Unchecked,
                },
                self.stream_id(),
            )
        }
    }
772
    /// Flushes all pending work on the current stream to the server.
    pub fn flush(&self) -> Result<(), ServerError> {
        let stream_id = self.stream_id();

        // `unwrap` covers the submission channel only; flush errors come back
        // through the returned `Result`.
        self.device
            .submit_blocking(move |server| server.flush(stream_id))
            .unwrap()
    }
781
    /// Returns a future that resolves once all submitted work on the current
    /// stream has completed.
    ///
    /// Also emits the profiling summary, if profiling is enabled.
    pub fn sync(&self) -> DynFut<Result<(), ServerError>> {
        let stream_id = self.stream_id();

        let fut = self
            .device
            .submit_blocking(move |server| server.sync(stream_id))
            .unwrap();

        // Emit the accumulated profiling summary at sync points.
        self.utilities.logger.profile_summary();

        fut
    }
795
    /// Returns the properties of the device this client is bound to.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }
800
    /// Returns the feature set supported by the device.
    pub fn features(&self) -> &Features {
        &self.utilities.properties.features
    }
805
    /// Returns mutable access to the device properties, but only when this is
    /// the sole client referencing the shared utilities (no other clones).
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }
812
    /// Returns the current memory usage reported by the server for the
    /// current stream.
    pub fn memory_usage(&self) -> Result<MemoryUsage, ServerError> {
        let stream_id = self.stream_id();
        self.device
            .submit_blocking(move |server| server.memory_usage(stream_id))
            .unwrap()
    }
820
    /// Lists the ids of devices of the given type available to this runtime.
    pub fn enumerate_devices(&self, type_id: u16) -> Vec<DeviceId> {
        R::enumerate_devices(type_id, self.info())
    }
825
    /// Lists the ids of all devices available to this runtime.
    pub fn enumerate_all_devices(&self) -> Vec<DeviceId> {
        R::enumerate_all_devices(self.info())
    }
830
    /// Returns the number of devices of the given type.
    pub fn device_count(&self, type_id: u16) -> usize {
        self.enumerate_devices(type_id).len()
    }
835
    /// Returns the total number of devices across all types.
    pub fn device_count_total(&self) -> usize {
        self.enumerate_all_devices().len()
    }
840
    /// Switches the memory allocation mode for the current stream.
    ///
    /// # Safety
    ///
    /// Changing the allocation mode affects how subsequent allocations are
    /// managed; the caller must ensure existing handles remain valid under the
    /// new mode.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.allocation_mode(mode, stream_id));
    }
851
    /// Asks the server to release unused memory held for the current stream.
    pub fn memory_cleanup(&self) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.memory_cleanup(stream_id));
    }
889
    #[track_caller]
    /// Measures the device work enqueued by `func`, returning the closure's
    /// output together with a [`ProfileDuration`].
    ///
    /// The task runs with exclusive access to the device: a profiling token is
    /// acquired on the current stream, `func` runs (enqueuing work), then the
    /// profile is ended and paired with the output. With the `profile-tracy`
    /// feature, CPU and (for device timing) GPU tracy spans are emitted too.
    pub fn profile<O: Send + 'static>(
        &self,
        func: impl FnOnce() -> O + Send,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        // CPU-side tracy span covering the whole profiled region.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let stream_id = self.stream_id();

        // GPU-side tracy span, only meaningful when timing on the device.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let device = self.device.clone();
        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self
            .device
            .exclusive_scoped(move || {
                // Acquire the profiling token; both the submission and the
                // server-side call can fail independently.
                let token =
                    match device.submit_blocking(move |server| server.start_profile(stream_id)) {
                        Ok(token) => match token {
                            Ok(token) => token,
                            Err(err) => return Err(err),
                        },
                        Err(err) => {
                            return Err(ServerError::Generic {
                                reason: alloc::format!(
                                    "Can't start profiling because of a call error: {err:?}"
                                ),
                                backtrace: BackTrace::capture(),
                            });
                        }
                    };

                // Run the user task while profiling is active.
                let out = func();

                // Stop profiling and pair the measurement with the output.
                let result = device
                    .submit_blocking(move |server| {
                        let mut result = server.end_profile(stream_id, token);

                        match result {
                            Ok(result) => Ok((out, result)),
                            Err(err) => Err(err),
                        }
                    })
                    .unwrap();

                Ok(result)
            })
            .unwrap()
            .map_err(|err| ProfileError::Unknown {
                reason: alloc::format!("{err:?}"),
                backtrace: BackTrace::capture(),
            })?;

        // NOTE(review): after the `?` above, `result` appears to be the Ok
        // value (a tuple), yet this cfg-gated path calls `.map` on it as if it
        // were still a `Result` — verify this compiles with `profile-tracy`
        // enabled.
        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Defer timestamp upload until the device profile resolves.
            result = result.map(|(o, result)| {
                (
                    o,
                    ProfileDuration::new(
                        alloc::boxed::Box::pin(async move {
                            let ticks = result.resolve().await;
                            let start_duration =
                                ticks.start_duration_since(epoch).as_nanos() as i64;
                            let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                            gpu_span.upload_timestamp_start(start_duration);
                            gpu_span.upload_timestamp_end(end_duration);
                            ticks
                        }),
                        TimingMethod::Device,
                    ),
                )
            });
        }

        result
    }
998
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    /// Fallback device-to-device copy: synchronously reads the source tensor
    /// back to the host, allocates on the destination, and writes the data
    /// there.
    ///
    /// NOTE(review): the allocation goes through `self`'s layout policy and
    /// the *source* `stream_id` is used for the write submitted to
    /// `dst_server` — confirm both are intended for cross-client transfers.
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor,
        alloc_descriptor: MemoryLayoutDescriptor,
        dst_server: &Self,
    ) -> MemoryLayout {
        let shape = src_descriptor.shape.clone();
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Read the source tensor back to the host, blocking on the result.
        let read = self
            .device
            .submit_blocking(move |server| server.read(vec![src_descriptor], stream_id))
            .unwrap();

        let mut data = cubecl_common::future::block_on(read).unwrap();

        // Allocate the destination layout.
        let (handle_base, mut layouts) = self
            .utilities
            .layout_policy
            .apply(stream_id, &[alloc_descriptor]);
        let alloc = layouts.remove(0);

        let desc_descriptor = CopyDescriptor {
            handle: handle_base.clone().binding(),
            shape,
            strides: alloc.strides.clone(),
            elem_size,
        };

        // Initialize the destination memory and upload the host copy.
        let (size, memory) = (handle_base.size(), handle_base.memory);
        dst_server.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
            server.write(vec![(desc_descriptor, data.remove(0))], stream_id)
        });

        alloc
    }
1045
1046 pub fn io_optimized_vector_sizes(
1048 &self,
1049 size: usize,
1050 ) -> impl Iterator<Item = VectorSize> + Clone {
1051 let load_width = self.properties().hardware.load_width as usize;
1052 let size_bits = size * 8;
1053 let max = load_width / size_bits;
1054 let max = usize::min(self.properties().hardware.max_vector_size, max);
1055
1056 let num_candidates = max.trailing_zeros() + 1;
1058
1059 (0..num_candidates).map(|i| 2usize.pow(i)).rev()
1060 }
1061}