//! CubeCL runtime compute client (`cubecl_runtime/client.rs`).
1use crate::{
2    config::{TypeNameFormatLevel, type_name_format},
3    kernel::KernelMetadata,
4    logging::ProfileLevel,
5    memory_management::{MemoryAllocationMode, MemoryUsage},
6    runtime::Runtime,
7    server::{
8        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
9        CopyDescriptor, CubeCount, ExecutionError, ExecutionMode, Handle, IoError, LaunchError,
10        ProfileError, ServerCommunication, ServerUtilities,
11    },
12    storage::{BindingResource, ComputeStorage},
13};
14use alloc::format;
15use alloc::sync::Arc;
16use alloc::vec;
17use alloc::vec::Vec;
18use core::ops::DerefMut;
19use cubecl_common::{
20    bytes::{AllocationProperty, Bytes},
21    device::{Device, DeviceContext},
22    future::DynFut,
23    profile::ProfileDuration,
24};
25use cubecl_ir::{DeviceProperties, LineSize};
26
27#[cfg(feature = "profile-tracy")]
28use alloc::boxed::Box;
29
30#[allow(unused)]
31use cubecl_common::profile::TimingMethod;
32use cubecl_common::stream_id::StreamId;
33
/// The `ComputeClient` is the entry point to require tasks from the `ComputeServer`.
/// It should be obtained for a specific device via the Compute struct.
pub struct ComputeClient<R: Runtime> {
    // Shared handle to the device's server state; locked per operation.
    context: DeviceContext<R::Server>,
    // Server-wide utilities (info, properties, logger); shared across clones.
    utilities: Arc<ServerUtilities<R::Server>>,
    // Explicit stream override; `None` falls back to `StreamId::current()`.
    stream_id: Option<StreamId>,
}
41
42impl<R: Runtime> Clone for ComputeClient<R> {
43    fn clone(&self) -> Self {
44        Self {
45            context: self.context.clone(),
46            utilities: self.utilities.clone(),
47            stream_id: self.stream_id,
48        }
49    }
50}
51
52impl<R: Runtime> ComputeClient<R> {
    /// Get the info of the current backend.
    ///
    /// The info value is provided by the server when the client is created.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }
57
58    /// Create a new client with a new server.
59    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
60        let utilities = server.utilities();
61
62        let context = DeviceContext::<R::Server>::insert(device, server)
63            .expect("Can't create a new client on an already registered server");
64
65        Self {
66            context,
67            utilities,
68            stream_id: None,
69        }
70    }
71
72    /// Load the client for the given device.
73    pub fn load<D: Device>(device: &D) -> Self {
74        let context = DeviceContext::<R::Server>::locate(device);
75        let utilities = context.lock().utilities();
76
77        Self {
78            context,
79            utilities,
80            stream_id: None,
81        }
82    }
83
84    fn stream_id(&self) -> StreamId {
85        match self.stream_id {
86            Some(val) => val,
87            None => StreamId::current(),
88        }
89    }
90
    /// Set the stream in which the current client is operating on.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    /// Overriding the stream changes which queue all subsequent operations of
    /// this client are submitted to.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }
99
100    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
101        let stream_id = self.stream_id();
102        let mut state = self.context.lock();
103        let fut = state.read(descriptors, stream_id);
104        core::mem::drop(state);
105        fut
106    }
107
108    /// Given bindings, returns owned resources as bytes.
109    pub fn read_async(
110        &self,
111        handles: Vec<Handle>,
112    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
113        let strides = [1];
114        let shapes = handles
115            .iter()
116            .map(|it| [it.size() as usize])
117            .collect::<Vec<_>>();
118        let bindings = handles
119            .into_iter()
120            .map(|it| it.binding())
121            .collect::<Vec<_>>();
122        let descriptors = bindings
123            .into_iter()
124            .zip(shapes.iter())
125            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
126            .collect();
127
128        self.do_read(descriptors)
129    }
130
131    /// Given bindings, returns owned resources as bytes.
132    ///
133    /// # Remarks
134    ///
135    /// Panics if the read operation fails.
136    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
137        cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
138    }
139
140    /// Given a binding, returns owned resource as bytes.
141    ///
142    /// # Remarks
143    /// Panics if the read operation fails.
144    pub fn read_one(&self, handle: Handle) -> Bytes {
145        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
146            .expect("TODO")
147            .remove(0)
148    }
149
    /// Given bindings, returns owned resources as bytes.
    ///
    /// Thin wrapper over the batched read; see [`ComputeClient::read_tensor`]
    /// for the layout requirements on each descriptor.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        self.do_read(descriptors)
    }
157
158    /// Given bindings, returns owned resources as bytes.
159    ///
160    /// # Remarks
161    ///
162    /// Panics if the read operation fails.
163    ///
164    /// The tensor must be in the same layout as created by the runtime, or more strict.
165    /// Contiguous tensors are always fine, strided tensors are only ok if the stride is similar to
166    /// the one created by the runtime (i.e. padded on only the last dimension). A way to check
167    /// stride compatibility on the runtime will be added in the future.
168    ///
169    /// Also see [`ComputeClient::create_tensor`].
170    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
171        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
172    }
173
174    /// Given a binding, returns owned resource as bytes.
175    /// See [`ComputeClient::read_tensor`]
176    pub fn read_one_tensor_async(
177        &self,
178        descriptor: CopyDescriptor<'_>,
179    ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
180        let fut = self.read_tensor_async(vec![descriptor]);
181
182        async { Ok(fut.await?.remove(0)) }
183    }
184
    /// Given a binding, returns owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    /// See [`ComputeClient::read_tensor`]
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        // Single-descriptor batch: the result vector has exactly one entry.
        self.read_tensor(vec![descriptor]).remove(0)
    }
193
194    /// Given a resource handle, returns the storage resource.
195    pub fn get_resource(
196        &self,
197        binding: Binding,
198    ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
199        let stream_id = self.stream_id();
200        self.context.lock().get_resource(binding, stream_id)
201    }
202
203    fn do_create_from_slices(
204        &self,
205        descriptors: Vec<AllocationDescriptor<'_>>,
206        slices: Vec<&[u8]>,
207    ) -> Result<Vec<Allocation>, IoError> {
208        let mut state = self.context.lock();
209        let allocations = state.create(descriptors.clone(), self.stream_id())?;
210        let descriptors = descriptors
211            .into_iter()
212            .zip(allocations.iter())
213            .zip(slices)
214            .map(|((desc, alloc), data)| {
215                (
216                    CopyDescriptor::new(
217                        alloc.handle.clone().binding(),
218                        desc.shape,
219                        &alloc.strides,
220                        desc.elem_size,
221                    ),
222                    Bytes::from_bytes_vec(data.to_vec()),
223                )
224            })
225            .collect();
226        let stream_id = self.stream_id();
227        state.write(descriptors, stream_id)?;
228        Ok(allocations)
229    }
230
231    fn do_create(
232        &self,
233        descriptors: Vec<AllocationDescriptor<'_>>,
234        mut data: Vec<Bytes>,
235    ) -> Result<Vec<Allocation>, IoError> {
236        self.staging(data.iter_mut(), true);
237
238        let mut state = self.context.lock();
239        let allocations = state.create(descriptors.clone(), self.stream_id())?;
240        let descriptors = descriptors
241            .into_iter()
242            .zip(allocations.iter())
243            .zip(data)
244            .map(|((desc, alloc), data)| {
245                (
246                    CopyDescriptor::new(
247                        alloc.handle.clone().binding(),
248                        desc.shape,
249                        &alloc.strides,
250                        desc.elem_size,
251                    ),
252                    data,
253                )
254            })
255            .collect();
256        let stream_id = self.stream_id();
257        state.write(descriptors, stream_id)?;
258        Ok(allocations)
259    }
260
261    /// Returns a resource handle containing the given data.
262    ///
263    /// # Notes
264    ///
265    /// Prefer using the more efficient [`Self::create`] function.
266    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
267        let shape = [slice.len()];
268
269        self.do_create_from_slices(
270            vec![AllocationDescriptor::new(
271                AllocationKind::Contiguous,
272                &shape,
273                1,
274            )],
275            vec![slice],
276        )
277        .unwrap()
278        .remove(0)
279        .handle
280    }
281
282    /// Returns a resource handle containing the given [Bytes].
283    pub fn create(&self, data: Bytes) -> Handle {
284        let shape = [data.len()];
285
286        self.do_create(
287            vec![AllocationDescriptor::new(
288                AllocationKind::Contiguous,
289                &shape,
290                1,
291            )],
292            vec![data],
293        )
294        .unwrap()
295        .remove(0)
296        .handle
297    }
298
299    /// Given a resource and shape, stores it and returns the tensor handle and strides.
300    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
301    /// should be taken when indexing.
302    ///
303    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
304    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
305    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
306    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
307    /// can load as much data as possible in a single instruction. It may be aligned even more to
308    /// also take cache lines into account.
309    ///
310    /// However, the stride must be taken into account when indexing and reading the tensor
311    /// (also see [`ComputeClient::read_tensor`]).
312    ///
313    /// # Notes
314    ///
315    /// Prefer using [`Self::create_tensor`] for better performance.
316    pub fn create_tensor_from_slice(
317        &self,
318        slice: &[u8],
319        shape: &[usize],
320        elem_size: usize,
321    ) -> Allocation {
322        self.do_create_from_slices(
323            vec![AllocationDescriptor::new(
324                AllocationKind::Optimized,
325                shape,
326                elem_size,
327            )],
328            vec![slice],
329        )
330        .unwrap()
331        .remove(0)
332    }
333
334    /// Given a resource and shape, stores it and returns the tensor handle and strides.
335    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
336    /// should be taken when indexing.
337    ///
338    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
339    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
340    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
341    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
342    /// can load as much data as possible in a single instruction. It may be aligned even more to
343    /// also take cache lines into account.
344    ///
345    /// However, the stride must be taken into account when indexing and reading the tensor
346    /// (also see [`ComputeClient::read_tensor`]).
347    pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
348        self.do_create(
349            vec![AllocationDescriptor::new(
350                AllocationKind::Optimized,
351                shape,
352                elem_size,
353            )],
354            vec![bytes],
355        )
356        .unwrap()
357        .remove(0)
358    }
359
360    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
361    /// handle, and returns the handles for them.
362    /// See [`ComputeClient::create_tensor`]
363    ///
364    /// # Notes
365    ///
366    /// Prefer using [`Self::create_tensors`] for better performance.
367    pub fn create_tensors_from_slices(
368        &self,
369        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
370    ) -> Vec<Allocation> {
371        let (descriptors, data) = descriptors.into_iter().unzip();
372
373        self.do_create_from_slices(descriptors, data).unwrap()
374    }
375
376    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
377    /// handle, and returns the handles for them.
378    /// See [`ComputeClient::create_tensor`]
379    pub fn create_tensors(
380        &self,
381        descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
382    ) -> Vec<Allocation> {
383        let (descriptors, data) = descriptors.into_iter().unzip();
384
385        self.do_create(descriptors, data).unwrap()
386    }
387
388    fn do_empty(
389        &self,
390        descriptors: Vec<AllocationDescriptor<'_>>,
391    ) -> Result<Vec<Allocation>, IoError> {
392        let mut state = self.context.lock();
393        state.create(descriptors, self.stream_id())
394    }
395
396    /// Reserves `size` bytes in the storage, and returns a handle over them.
397    pub fn empty(&self, size: usize) -> Handle {
398        let shape = [size];
399        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
400        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
401    }
402
403    /// Reserves `shape` in the storage, and returns a tensor handle for it.
404    /// See [`ComputeClient::create_tensor`]
405    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
406        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
407        self.do_empty(vec![descriptor]).unwrap().remove(0)
408    }
409
    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [`ComputeClient::create_tensor`]
    ///
    /// # Remarks
    /// Panics if the allocation fails.
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }
415
    /// Marks the given [Bytes] as being a staging buffer, maybe transferring it to pinned memory
    /// for faster data transfer with compute device.
    ///
    /// When `file_only` is true, only `File`-backed buffers are staged;
    /// `Native`/`Other` allocations are staged only when `file_only` is false.
    /// Already-pinned buffers are never re-staged. Staging is best effort: if
    /// the server can't provide staging memory, the buffers are left as-is.
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        // Decide from the current allocation property whether a buffer should
        // be moved into staging memory.
        let has_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        // Record the sizes of the buffers to stage, keeping the mutable
        // references so staged copies can be swapped in afterwards.
        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match has_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        // Best effort: on failure, keep the original allocations.
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        // Release the server lock before copying data into staging buffers.
        core::mem::drop(context);

        // Copy each payload into its staging buffer, then swap so the caller's
        // `Bytes` now points at the staged memory.
        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }
460
461    /// Transfer data from one client to another
462    #[cfg_attr(
463        feature = "tracing",
464        tracing::instrument(level = "trace", skip(self, src, dst_server))
465    )]
466    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
467        let shape = [src.size() as usize];
468        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);
469
470        if R::Server::SERVER_COMM_ENABLED {
471            self.to_client_tensor(src_descriptor, dst_server)
472        } else {
473            let alloc_desc = AllocationDescriptor::new(
474                AllocationKind::Contiguous,
475                src_descriptor.shape,
476                src_descriptor.elem_size,
477            );
478            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
479        }
480    }
481
    /// Transfer data from one client to another
    ///
    /// Make sure the source description can be read in a contiguous manner.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if R::Server::SERVER_COMM_ENABLED {
            // NOTE(review): the device-kind guard is taken before both server
            // locks — presumably to serialize cross-device transfers and avoid
            // lock-order deadlocks; confirm before changing this ordering.
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = R::Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            // Release the server locks before the device-kind guard.
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            // Fallback: read on the source client, write on the destination.
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }
520
    /// Dispatch a kernel launch, optionally timing it depending on the
    /// configured profile level.
    ///
    /// # Safety
    ///
    /// With [`ExecutionMode::Unchecked`], the caller must guarantee the kernel
    /// has no out-of-bounds accesses and always terminates.
    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let level = self.utilities.logger.profile_level();

        match level {
            // Fast path: no timing; optionally log that the kernel executed.
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                // SAFETY: upheld by the caller (see function-level contract).
                let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
                result
            }
            // Profiling path: wrap the launch in a timed profile and log it.
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let (result, profile) = self
                    .profile(
                        || unsafe {
                            // The lock is taken inside the closure so the
                            // profiler measures the whole launch.
                            let mut state = self.context.lock();
                            state.launch(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }
575
576    /// Launches the `kernel` with the given `bindings`.
577    #[track_caller]
578    pub fn launch(
579        &self,
580        kernel: <R::Server as ComputeServer>::Kernel,
581        count: CubeCount,
582        bindings: Bindings,
583    ) -> Result<(), LaunchError> {
584        // SAFETY: Using checked execution mode.
585        unsafe {
586            self.launch_inner(
587                kernel,
588                count,
589                bindings,
590                ExecutionMode::Checked,
591                self.stream_id(),
592            )
593        }
594    }
595
596    /// Launches the `kernel` with the given `bindings` without performing any bound checks.
597    ///
598    /// # Safety
599    ///
600    /// To ensure this is safe, you must verify your kernel:
601    /// - Has no out-of-bound reads and writes that can happen.
602    /// - Has no infinite loops that might never terminate.
603    #[track_caller]
604    pub unsafe fn launch_unchecked(
605        &self,
606        kernel: <R::Server as ComputeServer>::Kernel,
607        count: CubeCount,
608        bindings: Bindings,
609    ) -> Result<(), LaunchError> {
610        // SAFETY: Caller has to uphold kernel being safe.
611        unsafe {
612            self.launch_inner(
613                kernel,
614                count,
615                bindings,
616                ExecutionMode::Unchecked,
617                self.stream_id(),
618            )
619        }
620    }
621
    /// Flush all outstanding commands.
    ///
    /// Flushes the stream this client is operating on without waiting for
    /// completion (use [`Self::sync`] to wait).
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id)
    }
627
628    /// Wait for the completion of every task in the server.
629    pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
630        let stream_id = self.stream_id();
631        let mut state = self.context.lock();
632        let fut = state.sync(stream_id);
633        core::mem::drop(state);
634        self.utilities.logger.profile_summary();
635
636        fut
637    }
638
    /// Get the properties (features, hardware limits) supported by the
    /// compute server's device.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }
643
    /// # Warning
    ///
    /// For private use only.
    ///
    /// Returns `None` when the utilities are shared with another clone of this
    /// client, since `Arc::get_mut` only succeeds for a unique owner.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }
650
    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }
655
    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    /// Prefer [`Self::memory_persistent_allocation`], which restores the mode
    /// after running a closure under the device lock.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }
664
    /// Use a persistent memory strategy to execute the provided function.
    ///
    /// # Notes
    ///
    /// - Using that memory strategy is beneficial for stating model parameters and similar workflows.
    /// - You can call [`Self::memory_cleanup()`] if you want to free persistent memory.
    /// - NOTE(review): if `func` panics, the allocation mode is not restored to
    ///   `Auto`; consider a drop guard if panic safety matters here.
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        // Hold the device lock for the whole scope so the temporary allocation
        // mode is not observed/overwritten by other clients of this device.
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        // Restore the default allocation mode before releasing the device.
        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }
692
    /// Ask the client to release memory that it can release.
    ///
    /// Nb: Results will vary on what the memory allocator deems beneficial,
    /// so it's not guaranteed any memory is freed.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }
700
    /// Measure the execution time of some inner operations.
    ///
    /// Returns the closure's output together with a [`ProfileDuration`] that
    /// resolves to the measured time. The device lock is held for the whole
    /// measurement so concurrent work does not pollute the profile.
    ///
    /// # Errors
    ///
    /// Forwards any [`ProfileError`] reported by the server when ending the
    /// profile.
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        // Device timing needs a GPU span so Tracy can align device timestamps.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        match result {
            Ok(result) => Ok((out, result)),
            Err(err) => Err(err),
        }
    }
770
    /// Transfer data from one client to another
    ///
    /// Fallback used when direct server-to-server communication is disabled:
    /// allocates on the destination server, reads the source data back into
    /// host memory (blocking on the read future), then writes it into the new
    /// allocation.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Allocate destination
        // NOTE(review): both the allocation and the destination write below
        // use the *source* client's stream id — confirm this is intended
        // rather than `dst_server.stream_id()`.
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        // Blocking read of the source data into host memory.
        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        let desc_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(desc_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }
815
816    /// Returns all line sizes that are useful to perform optimal IO operation on the given element.
817    pub fn io_optimized_line_sizes(&self, size: usize) -> impl Iterator<Item = LineSize> + Clone {
818        let load_width = self.properties().hardware.load_width as usize;
819        let size_bits = size * 8;
820        let max = load_width / size_bits;
821        let max = usize::min(self.properties().hardware.max_line_size, max);
822
823        // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
824        let num_candidates = max.trailing_zeros() + 1;
825
826        (0..num_candidates).map(|i| 2usize.pow(i)).rev()
827    }
828}