// cubecl_runtime/client.rs

use crate::{
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    runtime::Runtime,
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, ExecutionError, ExecutionMode, Handle, IoError, LaunchError,
        ProfileError, ServerCommunication, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::DerefMut;
use cubecl_common::{
    bytes::{AllocationProperty, Bytes},
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};
use cubecl_ir::{DeviceProperties, LineSize, StorageType};

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

/// The ComputeClient is the entry point for requesting tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
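///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); `MyRuntime` and `device` stand in for a
/// concrete [Runtime] implementation and one of its devices:
///
/// ```ignore
/// // Load the client registered for the device.
/// let client = ComputeClient::<MyRuntime>::load(&device);
///
/// // Upload four bytes and read them back.
/// let handle = client.create(Bytes::from_bytes_vec(vec![1, 2, 3, 4]));
/// let bytes = client.read_one(handle);
/// assert_eq!(bytes.len(), 4);
/// ```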
pub struct ComputeClient<R: Runtime> {
    context: DeviceContext<R::Server>,
    utilities: Arc<ServerUtilities<R::Server>>,
    stream_id: Option<StreamId>,
}

impl<R: Runtime> Clone for ComputeClient<R> {
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<R: Runtime> ComputeClient<R> {
    /// Get the info of the current backend.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }

    /// Create a new client with a new server.
    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<R::Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    /// Load the client for the given device.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<R::Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Set the stream on which the current client operates.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

    /// Given handles, returns the owned resources as bytes.
    pub fn read_async(
        &self,
        handles: Vec<Handle>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        self.do_read(descriptors)
    }

    /// Given handles, returns the owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
            .expect("Failed to read data from the server")
    }

    /// Given a handle, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
            .expect("Failed to read data from the server")
            .remove(0)
    }

    /// Given copy descriptors, returns the owned resources as bytes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        self.do_read(descriptors)
    }

    /// Given copy descriptors, returns the owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    ///
    /// The tensor must be in the same layout as created by the runtime, or a stricter one.
    /// Contiguous tensors are always fine; strided tensors are only valid if the strides match
    /// the ones created by the runtime (i.e. padded on only the last dimension). A way to check
    /// stride compatibility on the runtime will be added in the future.
    ///
    /// Also see [ComputeClient::create_tensor].
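    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client` is assumed to exist and `data` to be
    /// a [Bytes] buffer holding `2 * 4` f32 values:
    ///
    /// ```ignore
    /// let shape = [2, 4];
    /// let elem_size = core::mem::size_of::<f32>();
    ///
    /// // Let the runtime pick the layout (possibly pitched), then read back using the strides
    /// // it returned.
    /// let alloc = client.create_tensor(data, &shape, elem_size);
    /// let bytes =
    ///     client.read_one_tensor(alloc.handle.copy_descriptor(&shape, &alloc.strides, elem_size));
    /// ```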
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
            .expect("Failed to read data from the server")
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    /// See [ComputeClient::read_tensor]
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { Ok(fut.await?.remove(0)) }
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    /// See [ComputeClient::read_tensor]
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    fn do_create_from_slices(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        slices: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(slices)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        mut data: Vec<Bytes>,
    ) -> Result<Vec<Allocation>, IoError> {
        self.staging(data.iter_mut(), true);

        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    /// Returns a resource handle containing the given data.
    ///
    /// # Notes
    ///
    /// Prefer using the more efficient [Self::create] function.
    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
        let shape = [slice.len()];

        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Returns a resource handle containing the given [Bytes].
    pub fn create(&self, data: Bytes) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensor] for better performance.
    pub fn create_tensor_from_slice(
        &self,
        slice: &[u8],
        shape: &[usize],
        elem_size: usize,
    ) -> Allocation {
        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
    pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![bytes],
        )
        .unwrap()
        .remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensors] for better performance.
    pub fn create_tensors_from_slices(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create_from_slices(descriptors, data).unwrap()
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
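    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client`, `lhs` and `rhs` are assumed to be
    /// [Bytes] buffers matching the shapes below:
    ///
    /// ```ignore
    /// let elem_size = core::mem::size_of::<f32>();
    /// let allocations = client.create_tensors(vec![
    ///     (AllocationDescriptor::new(AllocationKind::Optimized, &[2, 3], elem_size), lhs),
    ///     (AllocationDescriptor::new(AllocationKind::Optimized, &[3, 4], elem_size), rhs),
    /// ]);
    /// // Both tensors share a single storage buffer; each entry has its own handle and strides.
    /// let (a, b) = (&allocations[0], &allocations[1]);
    /// ```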
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

    /// Reserves `shape` in the storage, and returns a tensor handle for it.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

    /// Marks the given [Bytes] as staging buffers, possibly transferring them to pinned memory
    /// for faster data transfers with the compute device.
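    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client` and `payload` (a `Vec<u8>`) are
    /// assumed to exist:
    ///
    /// ```ignore
    /// let mut buffers = vec![Bytes::from_bytes_vec(payload)];
    /// // Move the bytes into staging (possibly pinned) memory, if the runtime provides it,
    /// // before uploading them.
    /// client.staging(buffers.iter_mut(), false);
    /// let handle = client.create(buffers.remove(0));
    /// ```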
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        let has_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match has_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        core::mem::drop(context);

        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }

    /// Transfer data from one client to another
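    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `MyRuntime`, `device_a` and `device_b` are
    /// placeholders for a concrete runtime and two of its devices:
    ///
    /// ```ignore
    /// let client_a = ComputeClient::<MyRuntime>::load(&device_a);
    /// let client_b = ComputeClient::<MyRuntime>::load(&device_b);
    ///
    /// let src = client_a.create(Bytes::from_bytes_vec(vec![0u8; 64]));
    /// // Copy the buffer to the second client; a direct server-to-server copy is used when
    /// // supported, otherwise the data goes through the host.
    /// let dst = client_a.to_client(src, &client_b);
    /// ```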
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst_server))
    )]
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if R::Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    /// Transfer data from one client to another
    ///
    /// Make sure the source descriptor can be read in a contiguous manner.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if R::Server::SERVER_COMM_ENABLED {
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = R::Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
                result
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let (result, profile) = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.launch(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }

    /// Launches the `kernel` with the given `bindings`.
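    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client`, `compiled_kernel` and `bindings`
    /// are placeholders for values produced by the kernel compilation frontend, and the one-cube
    /// dispatch is assumed to be expressed with the static [CubeCount] variant:
    ///
    /// ```ignore
    /// // Dispatch a single cube; real launches derive the cube count from the problem size.
    /// let count = CubeCount::Static(1, 1, 1);
    /// client
    ///     .launch(compiled_kernel, count, bindings)
    ///     .expect("the kernel launch should succeed");
    /// ```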
    #[track_caller]
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: Using checked execution mode.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }

    /// Launches the `kernel` with the given `bindings` without performing any bounds checks.
    ///
    /// # Safety
    ///
    /// To ensure this is safe, you must verify that your kernel:
    /// - Performs no out-of-bounds reads or writes.
    /// - Contains no loops that might never terminate.
    #[track_caller]
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: Caller has to uphold kernel being safe.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            )
        }
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id)
    }

    /// Wait for the completion of every task in the server.
    pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

    /// Get the features supported by the compute server.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

    /// Use a persistent memory strategy to execute the provided function.
    ///
    /// # Notes
    ///
    /// - Using this memory strategy is beneficial for storing model parameters and similar workflows.
    /// - You can call [Self::memory_cleanup()] if you want to free persistent memory.
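    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client` and `weight_bytes` (a [Bytes]
    /// buffer) are assumed to exist:
    ///
    /// ```ignore
    /// // Allocations made inside the closure use the persistent strategy, which suits
    /// // long-lived buffers such as model parameters.
    /// let weights = client.memory_persistent_allocation(weight_bytes, |bytes| client.create(bytes));
    /// ```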
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

    /// Ask the client to release any memory it can.
    ///
    /// Note: results vary depending on what the memory allocator deems beneficial,
    /// so it's not guaranteed that any memory is actually freed.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

    /// Measure the execution time of some inner operations.
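    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client` and `handle` are assumed to exist:
    ///
    /// ```ignore
    /// let (bytes, profile) = client
    ///     .profile(|| client.read_one(handle), "read_one")
    ///     .expect("profiling should succeed");
    /// // `profile` is a [ProfileDuration] that can be resolved later, e.g. from an async context.
    /// ```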
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        match result {
            Ok(result) => Ok((out, result)),
            Err(err) => Err(err),
        }
    }

    /// Transfer data from one client to another
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Allocate destination
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        let desc_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(desc_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }

    /// Returns all line sizes that are useful to perform optimal IO operations on the given element.
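    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); `client` is assumed to exist and `elem` to be
    /// the [StorageType] of the elements being loaded:
    ///
    /// ```ignore
    /// // Candidates are the runtime's supported line sizes, filtered so that one full line
    /// // still fits within the hardware load width.
    /// let candidates: Vec<LineSize> = client.io_optimized_line_sizes(&elem).collect();
    /// // Typically the widest candidate that also divides the innermost dimension is chosen.
    /// let widest = candidates.iter().copied().max();
    /// ```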
    pub fn io_optimized_line_sizes(
        &self,
        elem: &StorageType,
    ) -> impl Iterator<Item = LineSize> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let max = load_width / elem.size_bits();
        let supported = R::supported_line_sizes();
        supported.iter().filter(move |v| **v <= max).cloned()
    }

    /// Returns all line sizes that are useful to perform optimal IO operations on the given element.
    /// Ignores native support, and allows all line sizes. This means the returned size may be
    /// unrolled, and may not support dynamic indexing.
    pub fn io_optimized_line_sizes_unchecked(
        &self,
        size: usize,
    ) -> impl Iterator<Item = LineSize> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let size_bits = size * 8;
        let max = load_width / size_bits;
        // This makes this effectively the same as the checked version; if it doesn't work, it's a
        // problem with unrolling that should be investigated separately.
        let max = usize::min(R::max_global_line_size(), max);

        // If the max is 8, we want to test 1, 2, 4, 8, which is log2(8) + 1 candidates.
        let num_candidates = max.trailing_zeros() + 1;

        (0..num_candidates).map(|i| 2usize.pow(i)).rev()
    }
}