cubecl_runtime/client.rs

use crate::{
    DeviceProperties,
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    runtime::Runtime,
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, ExecutionError, ExecutionMode, Handle, IoError, LaunchError,
        ProfileError, ServerCommunication, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::DerefMut;
use cubecl_common::{
    bytes::{AllocationProperty, Bytes},
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};
use cubecl_ir::StorageType;

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

/// The ComputeClient is the entry point for requesting tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
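///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest); the concrete `device` value and the
/// runtime `R` are assumed to come from the runtime crate in use:
///
/// ```ignore
/// let client = ComputeClient::<R>::load(&device);
/// let handle = client.create(Bytes::from_bytes_vec(vec![0u8; 16]));
/// let bytes = client.read_one(handle);
/// ```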
pub struct ComputeClient<R: Runtime> {
    context: DeviceContext<R::Server>,
    utilities: Arc<ServerUtilities<R::Server>>,
    stream_id: Option<StreamId>,
}

impl<R: Runtime> Clone for ComputeClient<R> {
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<R: Runtime> ComputeClient<R> {
    /// Get the info of the current backend.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }

    /// Create a new client with a new server.
    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<R::Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    /// Load the client for the given device.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<R::Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Set the stream on which the current client operates.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

    /// Given handles, returns owned resources as bytes.
    pub fn read_async(
        &self,
        handles: Vec<Handle>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        self.do_read(descriptors)
    }

    /// Given handles, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
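    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` is assumed:
    ///
    /// ```ignore
    /// let handle = client.create(Bytes::from_bytes_vec(vec![0u8; 64]));
    /// let bytes = client.read(vec![handle]).remove(0);
    /// assert_eq!(bytes.len(), 64);
    /// ```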
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
            .expect("Failed to read data from the server")
    }

    /// Given a handle, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
            .expect("Failed to read data from the server")
            .remove(0)
    }

    /// Given copy descriptors, returns owned resources as bytes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Result<Vec<Bytes>, IoError>> + Send {
        self.do_read(descriptors)
    }

    /// Given copy descriptors, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    ///
    /// The tensor must be in the same layout as created by the runtime, or stricter.
    /// Contiguous tensors are always fine; strided tensors are only valid if the strides are
    /// compatible with the ones created by the runtime (i.e. padded only on the last dimension).
    /// A way to check stride compatibility on the runtime will be added in the future.
    ///
    /// Also see [ComputeClient::create_tensor].
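    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client`, `data`, `shape` and
    /// `elem_size` are assumed. The strides returned by the allocation are reused for the read:
    ///
    /// ```ignore
    /// let alloc = client.create_tensor(data, &shape, elem_size);
    /// let descriptor =
    ///     CopyDescriptor::new(alloc.handle.clone().binding(), &shape, &alloc.strides, elem_size);
    /// let bytes = client.read_tensor(vec![descriptor]).remove(0);
    /// ```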
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
            .expect("Failed to read tensor data from the server")
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    /// See [ComputeClient::read_tensor].
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Result<Bytes, IoError>> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { Ok(fut.await?.remove(0)) }
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    /// See [ComputeClient::read_tensor].
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Given a binding, returns the underlying storage resource.
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    fn do_create_from_slices(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        slices: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(slices)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        mut data: Vec<Bytes>,
    ) -> Result<Vec<Allocation>, IoError> {
        self.staging(data.iter_mut(), true);

        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    /// Returns a resource handle containing the given data.
    ///
    /// # Notes
    ///
    /// Prefer using the more efficient [Self::create] function.
    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
        let shape = [slice.len()];

        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Returns a resource handle containing the given [Bytes].
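    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` is assumed:
    ///
    /// ```ignore
    /// let handle = client.create(Bytes::from_bytes_vec(vec![1u8, 2, 3, 4]));
    /// ```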
    pub fn create(&self, data: Bytes) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensor] for better performance.
    pub fn create_tensor_from_slice(
        &self,
        slice: &[u8],
        shape: &[usize],
        elem_size: usize,
    ) -> Allocation {
        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
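    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` is assumed. The returned
    /// strides may be padded, so use them whenever indexing or reading the tensor:
    ///
    /// ```ignore
    /// let shape = [2, 3];
    /// let data = Bytes::from_bytes_vec(vec![0u8; 2 * 3 * 4]);
    /// let alloc = client.create_tensor(data, &shape, 4); // 4-byte elements
    /// // `alloc.handle` is the resource handle, `alloc.strides` the runtime-chosen strides.
    /// ```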
    pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![bytes],
        )
        .unwrap()
        .remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensors] for better performance.
    pub fn create_tensors_from_slices(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create_from_slices(descriptors, data).unwrap()
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

    /// Reserves `shape` in the storage, and returns a tensor handle for it.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

    /// Marks the given [Bytes] as staging buffers, possibly transferring them to pinned memory
    /// for faster data transfers with the compute device.
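    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` is assumed. With
    /// `file_only = false`, regular heap buffers are staged as well. The buffers are upgraded in
    /// place when a staging allocation is available:
    ///
    /// ```ignore
    /// let mut buffers = vec![Bytes::from_bytes_vec(vec![0u8; 1024])];
    /// client.staging(buffers.iter_mut(), false);
    /// ```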
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        // Decide whether a buffer should be copied into a staging allocation.
        let needs_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match needs_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        core::mem::drop(context);

        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }

    /// Transfer data from one client to another.
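    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client_a`, `client_b` and `handle`
    /// are assumed:
    ///
    /// ```ignore
    /// let alloc = client_a.to_client(handle, &client_b);
    /// // `alloc.handle` now refers to a copy of the data on `client_b`'s device.
    /// ```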
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, src, dst_server)))]
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if R::Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    /// Transfer data from one client to another.
    ///
    /// Make sure the source descriptor can be read in a contiguous manner.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, src_descriptor, dst_server))
    )]
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if R::Server::SERVER_COMM_ENABLED {
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = R::Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) -> Result<(), LaunchError> {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                let result = unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
                result
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let (result, profile) = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.launch(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }

    /// Launches the `kernel` with the given `bindings`.
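    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client`, `kernel`, `count` and
    /// `bindings` are assumed to have been prepared by the caller:
    ///
    /// ```ignore
    /// if let Err(err) = client.launch(kernel, count, bindings) {
    ///     // Handle the launch error.
    /// }
    /// ```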
    #[track_caller]
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: Using checked execution mode.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }

    /// Launches the `kernel` with the given `bindings` without performing any bounds checks.
    ///
    /// # Safety
    ///
    /// To ensure this is safe, you must verify that your kernel:
    /// - Performs no out-of-bounds reads or writes.
    /// - Contains no infinite loops that might never terminate.
    #[track_caller]
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) -> Result<(), LaunchError> {
        // SAFETY: Caller has to uphold kernel being safe.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            )
        }
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id)
    }

    /// Wait for the completion of every task in the server.
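    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` is assumed:
    ///
    /// ```ignore
    /// cubecl_common::future::block_on(client.sync()).unwrap();
    /// ```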
    pub fn sync(&self) -> DynFut<Result<(), ExecutionError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

    /// Get the properties of the compute server.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

    /// Use a persistent memory strategy to execute the provided function.
    ///
    /// # Notes
    ///
    /// - Using this memory strategy is beneficial for storing model parameters and similar workflows.
    /// - You can call [Self::memory_cleanup()] if you want to free persistent memory.
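    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` and `params_bytes` are
    /// assumed. Allocations made inside the closure use the persistent strategy:
    ///
    /// ```ignore
    /// let handle = client.memory_persistent_allocation(params_bytes, |bytes| client.create(bytes));
    /// ```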
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

    /// Ask the client to release any memory it can.
    ///
    /// Note: results depend on what the memory allocator deems beneficial,
    /// so it's not guaranteed that any memory is freed.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

    /// Measure the execution time of some inner operations.
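    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client`, `kernel`, `count` and
    /// `bindings` are assumed. The closure's output is returned along with a [ProfileDuration]:
    ///
    /// ```ignore
    /// let (result, duration) = client
    ///     .profile(|| client.launch(kernel, count, bindings), "my_kernel")
    ///     .unwrap();
    /// ```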
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        match result {
            Ok(result) => Ok((out, result)),
            Err(err) => Err(err),
        }
    }

    /// Transfer data from one client to another.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, src_descriptor, alloc_descriptor, dst_server))
    )]
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Allocate the destination buffer.
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        let dst_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(dst_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }

    /// Returns all line sizes that are useful to perform optimal IO operations on the given element.
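    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); `client` and `storage_type` are
    /// assumed. Picking the largest supported line size:
    ///
    /// ```ignore
    /// // `storage_type` is the element's StorageType, obtained elsewhere.
    /// let line_size = client.io_optimized_line_sizes(&storage_type).max().unwrap_or(1);
    /// ```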
    pub fn io_optimized_line_sizes(&self, elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let max = (load_width / elem.size_bits()) as u8;
        let supported = R::supported_line_sizes();
        supported.iter().filter(move |v| **v <= max).cloned()
    }

    /// Returns all line sizes that are useful to perform optimal IO operations on the given element.
    /// Ignores native support and allows all line sizes. This means the returned sizes may be
    /// unrolled, and may not support dynamic indexing.
    pub fn io_optimized_line_sizes_unchecked(
        &self,
        size: usize,
    ) -> impl Iterator<Item = u8> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let size_bits = size * 8;
        let max = load_width / size_bits;
        // This makes it effectively the same as the checked version; if that doesn't work, it's an
        // unrolling problem that should be investigated separately.
        let max = usize::min(R::max_global_line_size() as usize, max);

        // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
        let num_candidates = max.trailing_zeros() + 1;

        (0..num_candidates).map(|i| 2u8.pow(i)).rev()
    }
835}