cubecl_runtime/
client.rs

use crate::{
    DeviceProperties,
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    runtime::Runtime,
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, ExecutionError, Handle, IoError, LaunchError, ProfileError,
        ServerCommunication, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::DerefMut;
use cubecl_common::{
    ExecutionMode,
    bytes::{AllocationProperty, Bytes},
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};
use cubecl_ir::StorageType;

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

/// The ComputeClient is the entry point for requesting tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
pub struct ComputeClient<R: Runtime> {
    context: DeviceContext<R::Server>,
    utilities: Arc<ServerUtilities<R::Server>>,
    stream_id: Option<StreamId>,
}

impl<R: Runtime> Clone for ComputeClient<R> {
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<R: Runtime> ComputeClient<R> {
    /// Get the info of the current backend.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }

    /// Create a new client with a new server.
    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<R::Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    /// Load the client for the given device.
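    ///
    /// # Example
    ///
    /// A minimal usage sketch; `MyRuntime` and `device` are placeholders for a concrete
    /// [Runtime] implementation and its device type, neither of which is defined in this file.
    ///
    /// ```rust,ignore
    /// let client = ComputeClient::<MyRuntime>::load(&device);
    /// let _usage = client.memory_usage();
    /// ```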
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<R::Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Set the stream on which the current client operates.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

    /// Given handles, returns the owned resources as bytes.
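    ///
    /// # Example
    ///
    /// A sketch that reads two handles in one call and blocks on the result with the helper
    /// already used in this module; `client`, `handle_a`, and `handle_b` are assumed to exist:
    ///
    /// ```rust,ignore
    /// let fut = client.read_async(vec![handle_a, handle_b]);
    /// let buffers = cubecl_common::future::block_on(fut).unwrap();
    /// assert_eq!(buffers.len(), 2);
    /// ```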
    pub fn read_async(
        &self,
        handles: Vec<Handle>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        self.do_read(descriptors)
    }

    /// Given handles, returns the owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
            .expect("Failed to read data from the compute server")
    }

    /// Given a handle, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
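    ///
    /// # Example
    ///
    /// A round-trip sketch, assuming `client` is a [ComputeClient] for some runtime:
    ///
    /// ```rust,ignore
    /// let handle = client.create_from_slice(&[1u8, 2, 3, 4]);
    /// let bytes = client.read_one(handle);
    /// assert_eq!(bytes.len(), 4);
    /// ```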
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
            .expect("Failed to read data from the compute server")
            .remove(0)
    }

    /// Given copy descriptors, returns the owned resources as bytes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        self.do_read(descriptors)
    }

    /// Given copy descriptors, returns the owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    ///
    /// The tensor must be in the same layout as created by the runtime, or a stricter one.
    /// Contiguous tensors are always fine; strided tensors are only supported if the strides are
    /// compatible with the ones created by the runtime (i.e. padded on only the last dimension).
    /// A way to check stride compatibility on the runtime will be added in the future.
    ///
    /// Also see [ComputeClient::create_tensor].
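    ///
    /// # Example
    ///
    /// A sketch of a strided read, assuming `client` and a `[2, 4]` tensor of 4-byte elements:
    ///
    /// ```rust,ignore
    /// let shape = [2, 4];
    /// let alloc = client.empty_tensor(&shape, 4);
    /// // Read it back using the strides chosen by the runtime.
    /// let descriptor = alloc.handle.copy_descriptor(&shape, &alloc.strides, 4);
    /// let bytes = client.read_tensor(vec![descriptor]).remove(0);
    /// ```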
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
            .expect("Failed to read data from the compute server")
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    /// See [ComputeClient::read_tensor]
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { Ok(fut.await?.remove(0)) }
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    /// See [ComputeClient::read_tensor]
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    fn do_create_from_slices(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        slices: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(slices)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        mut data: Vec<Bytes>,
    ) -> Result<Vec<Allocation>, IoError> {
        self.staging(data.iter_mut(), true);

        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    /// Returns a resource handle containing the given data.
    ///
    /// # Notes
    ///
    /// Prefer using the more efficient [Self::create] function.
    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
        let shape = [slice.len()];

        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Returns a resource handle containing the given [Bytes].
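    ///
    /// # Example
    ///
    /// A sketch uploading four bytes; [Bytes::from_bytes_vec] is used as elsewhere in this
    /// module:
    ///
    /// ```rust,ignore
    /// let data = Bytes::from_bytes_vec(vec![1u8, 2, 3, 4]);
    /// let handle = client.create(data);
    /// ```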
    pub fn create(&self, data: Bytes) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensor] for better performance.
    pub fn create_tensor_from_slice(
        &self,
        slice: &[u8],
        shape: &[usize],
        elem_size: usize,
    ) -> Allocation {
        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
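    ///
    /// # Example
    ///
    /// A sketch, assuming `client` and row-major data for a `[2, 3]` tensor of 4-byte elements:
    ///
    /// ```rust,ignore
    /// let data = Bytes::from_bytes_vec(vec![0u8; 2 * 3 * 4]);
    /// let alloc = client.create_tensor(data, &[2, 3], 4);
    /// // `alloc.strides` may be padded: always index with these strides, not contiguous ones.
    /// ```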
    pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![bytes],
        )
        .unwrap()
        .remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensors] for better performance.
    pub fn create_tensors_from_slices(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create_from_slices(descriptors, data).unwrap()
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
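    ///
    /// # Example
    ///
    /// A sketch allocating two tensors of 4-byte elements in one call:
    ///
    /// ```rust,ignore
    /// let a = (
    ///     AllocationDescriptor::new(AllocationKind::Optimized, &[2, 4], 4),
    ///     Bytes::from_bytes_vec(vec![0u8; 2 * 4 * 4]),
    /// );
    /// let b = (
    ///     AllocationDescriptor::new(AllocationKind::Optimized, &[8], 4),
    ///     Bytes::from_bytes_vec(vec![0u8; 8 * 4]),
    /// );
    /// let allocations = client.create_tensors(vec![a, b]);
    /// ```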
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
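    ///
    /// # Example
    ///
    /// A sketch reserving an uninitialized 256-byte buffer:
    ///
    /// ```rust,ignore
    /// let handle = client.empty(256);
    /// assert_eq!(handle.size(), 256);
    /// ```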
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

    /// Reserves `shape` in the storage, and returns a tensor handle for it.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

    /// Marks the given [Bytes] as staging buffers, possibly transferring them to pinned memory
    /// for faster data transfers with the compute device.
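    ///
    /// # Example
    ///
    /// A sketch that promotes host buffers to staging memory before uploading them:
    ///
    /// ```rust,ignore
    /// let mut buffers = vec![Bytes::from_bytes_vec(vec![0u8; 1 << 20])];
    /// client.staging(buffers.iter_mut(), false);
    /// let handle = client.create(buffers.remove(0));
    /// ```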
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        // Returns whether the buffer should be moved into a staging allocation.
        let needs_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match needs_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        core::mem::drop(context);

        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }

    /// Transfer data from one client to another.
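    ///
    /// # Example
    ///
    /// A sketch, assuming two clients of the same runtime on different devices:
    ///
    /// ```rust,ignore
    /// let handle = client_a.create_from_slice(&[0u8; 64]);
    /// let alloc = client_a.to_client(handle, &client_b);
    /// let bytes = client_b.read_one(alloc.handle);
    /// ```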
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if R::Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    /// Transfer data from one client to another.
    ///
    /// Make sure the source descriptor can be read in a contiguous manner.
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if R::Server::SERVER_COMM_ENABLED {
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = R::Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
                result
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let (result, profile) = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.launch(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }

    /// Launches the `kernel` with the given `bindings`.
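    ///
    /// # Example
    ///
    /// A sketch of handling a launch result; `kernel`, `count`, and `bindings` are assumed to
    /// have been built elsewhere, since their construction is runtime specific:
    ///
    /// ```rust,ignore
    /// match client.launch(kernel, count, bindings) {
    ///     Ok(()) => {}
    ///     Err(_launch_error) => { /* report or fall back */ }
    /// }
    /// ```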
    #[track_caller]
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: Using checked execution mode.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }

    /// Launches the `kernel` with the given `bindings` without performing any bound checks.
    ///
    /// # Safety
    ///
    /// To ensure this is safe, you must verify that your kernel:
    /// - Performs no out-of-bounds reads or writes.
    /// - Contains no infinite loops that might never terminate.
    #[track_caller]
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: Caller has to uphold kernel being safe.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            )
        }
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id)
    }

    /// Wait for the completion of every task in the server.
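    ///
    /// # Example
    ///
    /// A sketch that blocks on completion using the helper already used in this module:
    ///
    /// ```rust,ignore
    /// let result = cubecl_common::future::block_on(client.sync());
    /// assert!(result.is_ok());
    /// ```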
    pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

    /// Get the properties of the compute server, including its supported features.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

    /// Use a persistent memory strategy to execute the provided function.
    ///
    /// # Notes
    ///
    /// - Using this memory strategy is beneficial for storing model parameters and similar workflows.
    /// - You can call [Self::memory_cleanup()] if you want to free persistent memory.
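    ///
    /// # Example
    ///
    /// A sketch that allocates model parameters under the persistent strategy; the buffer
    /// contents are placeholders:
    ///
    /// ```rust,ignore
    /// let weights = client.memory_persistent_allocation((), |_| {
    ///     client.create_from_slice(&[0u8; 1024])
    /// });
    /// ```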
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

    /// Ask the client to release memory that it can release.
    ///
    /// Note: results depend on what the memory allocator deems beneficial,
    /// so it isn't guaranteed that any memory is freed.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

    /// Measure the execution time of some inner operations.
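    ///
    /// # Example
    ///
    /// A sketch that profiles a closure; the work inside the closure (e.g. kernel launches) is
    /// assumed to be set up elsewhere:
    ///
    /// ```rust,ignore
    /// let (_out, profile) = client
    ///     .profile(|| { /* launch kernels here */ }, "my_work")
    ///     .unwrap();
    /// // `profile` is a [ProfileDuration] that can be resolved asynchronously.
    /// ```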
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        // `mut` is only needed when the "profile-tracy" feature rebinds the result below.
        #[allow(unused_mut)]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        match result {
            Ok(result) => Ok((out, result)),
            Err(err) => Err(err),
        }
    }

    /// Transfer data from one client to another.
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Allocate destination
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        let dst_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(dst_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }

    /// Returns all line sizes that are useful to perform optimal IO operations on the given element.
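    ///
    /// # Example
    ///
    /// A sketch that picks the widest supported line size for a storage type `ty` (assumed to
    /// be a [StorageType] obtained elsewhere):
    ///
    /// ```rust,ignore
    /// let line_size = client.io_optimized_line_sizes(&ty).max().unwrap_or(1);
    /// ```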
    pub fn io_optimized_line_sizes(&self, elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let max = (load_width / elem.size_bits()) as u8;
        let supported = R::supported_line_sizes();
        supported.iter().filter(move |v| **v <= max).cloned()
    }

    /// Returns all line sizes that are useful to perform optimal IO operations on the given element.
    /// Ignores native support and allows all line sizes. This means the returned size may be
    /// unrolled, and may not support dynamic indexing.
    pub fn io_optimized_line_sizes_unchecked(
        &self,
        size: usize,
    ) -> impl Iterator<Item = u8> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let size_bits = size * 8;
        let max = load_width / size_bits;
        // This effectively makes it the same as the checked version; if that doesn't work, it's
        // an unroll problem that should be investigated separately (in a future PR).
        let max = usize::min(R::max_global_line_size() as usize, max);

        // If the max is 8, we want to test 1, 2, 4, 8, which is log2(8) + 1 candidates.
        let num_candidates = max.trailing_zeros() + 1;

        (0..num_candidates).map(|i| 2u8.pow(i)).rev()
    }
}