cubecl_runtime/
client.rs

use crate::{
    DeviceProperties,
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, Handle, IoError, ProfileError, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::DerefMut;
use cubecl_common::{
    ExecutionMode,
    bytes::{AllocationProperty, Bytes},
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

/// The ComputeClient is the entry point for submitting tasks to the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
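///
/// # Example
///
/// A minimal sketch of obtaining a client; `WgpuDevice` and `WgpuServer` are
/// placeholders for whatever device and server types the chosen runtime provides.
///
/// ```ignore
/// let device = WgpuDevice::default();
/// let client = ComputeClient::<WgpuServer>::load(&device);
/// ```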
pub struct ComputeClient<Server: ComputeServer> {
    context: DeviceContext<Server>,
    utilities: Arc<ServerUtilities<Server>>,
    stream_id: Option<StreamId>,
}

impl<S> Clone for ComputeClient<S>
where
    S: ComputeServer,
{
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<Server> ComputeClient<Server>
where
    Server: ComputeServer,
{
    /// Get the info of the current backend.
    pub fn info(&self) -> &Server::Info {
        &self.utilities.info
    }

    /// Create a new client with a new server.
    pub fn init<D: Device>(device: &D, server: Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    /// Load the client for the given device.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Set the stream on which the current client operates.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

    /// Given handles, returns owned resources as bytes.
    pub fn read_async(&self, handles: Vec<Handle>) -> impl Future<Output = Vec<Bytes>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

    /// Given handles, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
    }

    /// Given a handle, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
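    ///
    /// # Example
    ///
    /// A minimal sketch of a write/read round trip, assuming `client` is an already
    /// initialized [ComputeClient].
    ///
    /// ```ignore
    /// let handle = client.create_from_slice(&[1u8, 2, 3, 4]);
    /// let bytes = client.read_one(handle);
    /// assert_eq!(bytes.len(), 4);
    /// ```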
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle])).remove(0)
    }

    /// Given copy descriptors, returns owned resources as bytes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Vec<Bytes>> + Send {
        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

    /// Given copy descriptors, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    ///
    /// The tensor must be in the same layout as created by the runtime, or a stricter one.
    /// Contiguous tensors are always fine; strided tensors are only valid if the strides match
    /// what the runtime would have created (i.e. padded only on the last dimension). A way to
    /// check stride compatibility on the runtime will be added in the future.
    ///
    /// Also see [ComputeClient::create_tensor].
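    ///
    /// # Example
    ///
    /// A minimal sketch that reads a tensor back with the strides returned at creation
    /// time, assuming `client` and the raw `data` ([Bytes]) are already available.
    ///
    /// ```ignore
    /// let shape = [2, 3];
    /// let elem_size = core::mem::size_of::<f32>();
    /// let alloc = client.create_tensor(data, &shape, elem_size);
    /// let descriptor = alloc.handle.copy_descriptor(&shape, &alloc.strides, elem_size);
    /// let bytes = client.read_tensor(vec![descriptor]).remove(0);
    /// ```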
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    /// See [ComputeClient::read_tensor].
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Bytes> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { fut.await.remove(0) }
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    /// See [ComputeClient::read_tensor].
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    fn do_create_from_slices(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        slices: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(slices)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        mut data: Vec<Bytes>,
    ) -> Result<Vec<Allocation>, IoError> {
        self.staging(data.iter_mut(), true);

        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    /// Returns a resource handle containing the given data.
    ///
    /// # Notes
    ///
    /// Prefer using the more efficient [Self::create] function.
    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
        let shape = [slice.len()];

        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Returns a resource handle containing the given [Bytes].
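    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is already initialized; `Bytes::from_bytes_vec`
    /// is the same constructor used elsewhere in this module.
    ///
    /// ```ignore
    /// let data = Bytes::from_bytes_vec(vec![0u8; 64]);
    /// let handle = client.create(data);
    /// ```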
    pub fn create(&self, data: Bytes) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensor] for better performance.
    pub fn create_tensor_from_slice(
        &self,
        slice: &[u8],
        shape: &[usize],
        elem_size: usize,
    ) -> Allocation {
        self.do_create_from_slices(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![slice],
        )
        .unwrap()
        .remove(0)
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
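    ///
    /// # Example
    ///
    /// A minimal sketch showing why the returned strides matter when indexing, assuming
    /// `client` and `data` are available and strides are counted in elements.
    ///
    /// ```ignore
    /// let shape = [4, 4];
    /// let alloc = client.create_tensor(data, &shape, 4);
    /// // Linear offset (in elements) of the element at row `r`, column `c`.
    /// let offset = |r: usize, c: usize| r * alloc.strides[0] + c * alloc.strides[1];
    /// ```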
    pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![bytes],
        )
        .unwrap()
        .remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    ///
    /// # Notes
    ///
    /// Prefer using [Self::create_tensors] for better performance.
    pub fn create_tensors_from_slices(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create_from_slices(descriptors, data).unwrap()
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
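    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is already initialized.
    ///
    /// ```ignore
    /// // Reserve an uninitialized 1 KiB buffer, e.g. for a kernel output.
    /// let output = client.empty(1024);
    /// ```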
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

    /// Reserves `shape` in the storage, and returns a tensor handle for it.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

    /// Marks the given [Bytes] as staging buffers, possibly transferring them to pinned memory
    /// for faster data transfers with the compute device.
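    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `bytes_a` and `bytes_b` are host-side [Bytes] about to
    /// be uploaded; passing `false` also makes regular (non-file) allocations eligible.
    ///
    /// ```ignore
    /// client.staging([&mut bytes_a, &mut bytes_b].into_iter(), false);
    /// ```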
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        let has_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match has_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let mut context = self.context.lock();
        let stagings = match context.staging(&sizes, stream_id) {
            Ok(val) => val,
            Err(_) => return,
        };
        core::mem::drop(context);

        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }

    /// Transfer data from one client to another
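    ///
    /// # Example
    ///
    /// A minimal sketch of a device-to-device transfer, assuming `client_0` and
    /// `client_1` are clients of the same server type on two different devices.
    ///
    /// ```ignore
    /// let handle = client_0.create_from_slice(&data);
    /// let alloc = client_0.to_client(handle, &client_1);
    /// let bytes = client_1.read_one(alloc.handle);
    /// ```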
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    /// Transfer data from one client to another
    ///
    /// Make sure the source descriptor can be read in a contiguous manner.
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if Server::SERVER_COMM_ENABLED {
            let guard = self.context.lock_device_kind();
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            let copied = Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap();
            core::mem::drop(server_src);
            core::mem::drop(server_dst);
            core::mem::drop(guard);
            copied
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    unsafe fn execute_inner(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                unsafe { state.execute(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let profile = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.execute(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
            }
        }
    }

    /// Executes the `kernel` over the given `bindings`.
    #[track_caller]
    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
        // SAFETY: Using checked execution mode.
        unsafe {
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            );
        }
    }

    /// Executes the `kernel` over the given `bindings` without performing any bounds checks.
    ///
    /// # Safety
    ///
    /// To ensure this is safe, you must verify that your kernel:
    /// - Performs no out-of-bounds reads or writes.
    /// - Contains no loops that might never terminate.
    #[track_caller]
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) {
        // SAFETY: Caller has to uphold kernel being safe.
        unsafe {
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            );
        }
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id);
    }

    /// Wait for the completion of every task in the server.
    pub fn sync(&self) -> DynFut<()> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

    /// Get the properties of the compute server, including its supported features.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

    /// Use a persistent memory strategy to execute the provided function.
    ///
    /// # Notes
    ///
    /// - Using this memory strategy is beneficial for storing model parameters and similar workflows.
    /// - You can call [Self::memory_cleanup()] if you want to free persistent memory.
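    ///
    /// # Example
    ///
    /// A minimal sketch: allocations made inside the closure (e.g. model parameters) use
    /// the persistent strategy, assuming `client` and `weight_bytes` are available.
    ///
    /// ```ignore
    /// let weights = client.memory_persistent_allocation(weight_bytes, |bytes| {
    ///     client.create(bytes)
    /// });
    /// ```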
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

    /// Ask the client to release whatever memory it can.
    ///
    /// Nb: Results depend on what the memory allocator deems beneficial,
    /// so it's not guaranteed that any memory is freed.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

    /// Measure the execution time of some inner operations.
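    ///
    /// # Example
    ///
    /// A minimal sketch that profiles a single kernel launch, assuming `client`,
    /// `kernel`, `count` and `bindings` are already prepared.
    ///
    /// ```ignore
    /// let profile = client
    ///     .profile(|| client.execute(kernel, count, bindings), "my_kernel")
    ///     .unwrap();
    /// // The returned `ProfileDuration` resolves asynchronously once the device work completes.
    /// ```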
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<ProfileDuration, ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.state.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .state
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        // `mut` is needed because the `profile-tracy` block below remaps `result`.
        #[allow(unused_mut)]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        core::mem::drop(out);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.state.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        result
    }

    /// Transfer data from one client to another
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Allocate destination
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let mut data = cubecl_common::future::block_on(read).unwrap();

        let desc_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(desc_descriptor, data.remove(0))], stream_id)
            .unwrap();

        alloc
    }
}