Skip to main content

cubecl_runtime/
client.rs

1use crate::{
2    config::{TypeNameFormatLevel, type_name_format},
3    kernel::KernelMetadata,
4    logging::ProfileLevel,
5    memory_management::{MemoryAllocationMode, MemoryUsage},
6    runtime::Runtime,
7    server::{
8        CommunicationId, ComputeServer, CopyDescriptor, CubeCount, ExecutionMode, Handle, IoError,
9        KernelArguments, MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy,
10        MemoryLayoutStrategy, ProfileError, ReduceOperation, ServerCommunication, ServerError,
11        ServerUtilities,
12    },
13    storage::{ComputeStorage, ManagedResource},
14};
15use alloc::{format, sync::Arc, vec, vec::Vec};
16use cubecl_common::{
17    backtrace::BackTrace,
18    bytes::{AllocationProperty, Bytes},
19    device::{Device, DeviceId},
20    device_handle::DeviceHandle,
21    future::DynFut,
22    profile::ProfileDuration,
23};
24use cubecl_ir::{DeviceProperties, ElemType, VectorSize, features::Features};
25use cubecl_zspace::Shape;
26
27#[allow(unused)]
28use cubecl_common::profile::TimingMethod;
29use cubecl_common::stream_id::StreamId;
30
31/// The `ComputeClient` is the entry point to require tasks from the `ComputeServer`.
32/// It should be obtained for a specific device via the Compute struct.
33pub struct ComputeClient<R: Runtime> {
34    device: DeviceHandle<R::Server>,
35    utilities: Arc<ServerUtilities<R::Server>>,
36    stream_id: Option<StreamId>,
37}
38
39impl<R: Runtime> Clone for ComputeClient<R> {
40    fn clone(&self) -> Self {
41        Self {
42            device: self.device.clone(),
43            utilities: self.utilities.clone(),
44            stream_id: self.stream_id,
45        }
46    }
47}
48
49impl<R: Runtime> ComputeClient<R> {
50    /// Get the info of the current backend.
51    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
52        &self.utilities.info
53    }
54
55    /// Create a new client with a new server.
56    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
57        let utilities = server.utilities();
58        let context = DeviceHandle::<R::Server>::insert(device.to_id(), server)
59            .expect("Can't create a new client on an already registered server");
60
61        Self {
62            device: context,
63            utilities,
64            stream_id: None,
65        }
66    }
67
68    /// Load the client for the given device.
69    pub fn load<D: Device>(device: &D) -> Self {
70        let context = DeviceHandle::<R::Server>::new(device.to_id());
71
72        // This is safe because we now know the return type of [`DeviceHandle::utilities()`].
73        let utilities = context
74            .utilities()
75            .downcast::<ServerUtilities<R::Server>>()
76            .expect("Can downcast to `ServerUtilities`");
77
78        Self {
79            device: context,
80            utilities,
81            stream_id: None,
82        }
83    }
84
85    fn stream_id(&self) -> StreamId {
86        match self.stream_id {
87            Some(val) => val,
88            None => StreamId::current(),
89        }
90    }
91
92    /// Set the stream in which the current client is operating on.
93    ///
94    /// # Safety
95    ///
96    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
97    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
98        self.stream_id = Some(stream_id);
99    }
100
101    fn do_read(&self, descriptors: Vec<CopyDescriptor>) -> DynFut<Result<Vec<Bytes>, ServerError>> {
102        let stream_id = self.stream_id();
103        self.device
104            .submit_blocking(move |server| server.read(descriptors, stream_id))
105            .unwrap()
106    }
107
108    /// Given bindings, returns owned resources as bytes.
109    pub fn read_async(
110        &self,
111        handles: Vec<Handle>,
112    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
113        let shapes = handles
114            .iter()
115            .map(|it| [it.size_in_used() as usize].into())
116            .collect::<Vec<Shape>>();
117        let descriptors = handles
118            .into_iter()
119            .zip(shapes)
120            .map(|(handle, shape)| CopyDescriptor::new(handle.binding(), shape, [1].into(), 1))
121            .collect();
122
123        self.do_read(descriptors)
124    }
125
126    /// Given bindings, returns owned resources as bytes.
127    ///
128    /// # Remarks
129    ///
130    /// Panics if the read operation fails.
131    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
132        cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
133    }
134
135    /// Given a binding, returns owned resource as bytes.
136    pub fn read_one(&self, handle: Handle) -> Result<Bytes, ServerError> {
137        Ok(cubecl_common::reader::read_sync(self.read_async(vec![handle]))?.remove(0))
138    }
139
140    /// Given a binding, returns owned resource as bytes.
141    ///
142    /// # Remarks
143    ///
144    /// Panics if the read operation fails. Useful for tests.
145    pub fn read_one_unchecked(&self, handle: Handle) -> Bytes {
146        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
147            .unwrap()
148            .remove(0)
149    }
150
151    /// Given bindings, returns owned resources as bytes.
152    pub fn read_tensor_async(
153        &self,
154        descriptors: Vec<CopyDescriptor>,
155    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
156        self.do_read(descriptors)
157    }
158
159    /// Given bindings, returns owned resources as bytes.
160    ///
161    /// # Remarks
162    ///
163    /// Panics if the read operation fails.
164    ///
165    /// The tensor must be in the same layout as created by the runtime, or more strict.
166    /// Contiguous tensors are always fine, strided tensors are only ok if the stride is similar to
167    /// the one created by the runtime (i.e. padded on only the last dimension). A way to check
168    /// stride compatibility on the runtime will be added in the future.
169    ///
170    /// Also see [`ComputeClient::create_tensor`].
171    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor>) -> Vec<Bytes> {
172        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
173    }
174
175    /// Given a binding, returns owned resource as bytes.
176    /// See [`ComputeClient::read_tensor`]
177    pub fn read_one_tensor_async(
178        &self,
179        descriptor: CopyDescriptor,
180    ) -> impl Future<Output = Result<Bytes, ServerError>> + Send {
181        let fut = self.read_tensor_async(vec![descriptor]);
182
183        async { Ok(fut.await?.remove(0)) }
184    }
185
186    /// Given a binding, returns owned resource as bytes.
187    ///
188    /// # Remarks
189    ///
190    /// Panics if the read operation fails.
191    /// See [`ComputeClient::read_tensor`]
192    pub fn read_one_unchecked_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
193        self.read_tensor(vec![descriptor]).remove(0)
194    }
195
196    /// Given a resource handle, returns the storage resource.
197    pub fn get_resource(
198        &self,
199        handle: Handle,
200    ) -> Result<
201        ManagedResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource>,
202        ServerError,
203    > {
204        let stream_id = self.stream_id();
205        let binding = handle.binding();
206
207        self.device
208            .submit_blocking(move |state| state.get_resource(binding, stream_id))
209            .unwrap()
210    }
211
212    fn do_create_from_slices(
213        &self,
214        descriptors: Vec<MemoryLayoutDescriptor>,
215        slices: Vec<Vec<u8>>,
216    ) -> Result<Vec<MemoryLayout>, IoError> {
217        let stream_id = self.stream_id();
218        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
219
220        // Wrap each input slice as a (pageable) `Bytes`, then run `staging` so the
221        // server can swap each one to pinned host memory on backends that benefit
222        // (e.g. CUDA: `cuMemcpyHtoDAsync` from pinned memory hits the DMA fast path
223        // at 12-25 GB/s on PCIe 4.0 vs 5-6 GB/s from pageable memory).
224        //
225        // This keeps the public API unchanged while routing single-shot uploads
226        // through the same fast path that `create` / `create_tensor` already use.
227        let mut data: Vec<Bytes> = slices.into_iter().map(Bytes::from_bytes_vec).collect();
228        self.staging(data.iter_mut(), true);
229
230        let descriptors = descriptors
231            .into_iter()
232            .zip(layouts.iter())
233            .zip(data)
234            .map(|((desc, alloc), data)| {
235                (
236                    CopyDescriptor::new(
237                        alloc.memory.clone().binding(),
238                        desc.shape,
239                        alloc.strides.clone(),
240                        desc.elem_size,
241                    ),
242                    data,
243                )
244            })
245            .collect::<Vec<_>>();
246
247        let (size, memory) = (handle_base.size(), handle_base.memory);
248        self.device.submit(move |server| {
249            server.initialize_memory(memory, size, stream_id);
250            server.write(descriptors, stream_id);
251        });
252
253        Ok(layouts)
254    }
255
256    fn do_create(
257        &self,
258        descriptors: Vec<MemoryLayoutDescriptor>,
259        mut data: Vec<Bytes>,
260    ) -> Result<Vec<MemoryLayout>, IoError> {
261        // After `staging`, each `Bytes` may have been swapped in-place to a pinned
262        // host buffer. Forward those `Bytes` to the server *as-is* — re-wrapping via
263        // `Bytes::from_bytes_vec(data.to_vec())` would allocate a fresh pageable
264        // `Vec<u8>`, demote the buffer back to `AllocationProperty::Native`, and
265        // re-trigger the slow CUDA pageable bounce on the subsequent HtoD copy.
266        self.staging(data.iter_mut(), true);
267
268        let stream_id = self.stream_id();
269        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
270
271        let descriptors = descriptors
272            .into_iter()
273            .zip(layouts.iter())
274            .zip(data)
275            .map(|((desc, layout), data)| {
276                (
277                    CopyDescriptor::new(
278                        layout.memory.clone().binding(),
279                        desc.shape,
280                        layout.strides.clone(),
281                        desc.elem_size,
282                    ),
283                    data,
284                )
285            })
286            .collect::<Vec<_>>();
287
288        let (size, memory) = (handle_base.size(), handle_base.memory);
289        self.device.submit(move |server| {
290            server.initialize_memory(memory, size, stream_id);
291            server.write(descriptors, stream_id);
292        });
293
294        Ok(layouts)
295    }
296
297    /// Returns a resource handle containing the given data.
298    ///
299    /// # Notes
300    ///
301    /// Prefer using the more efficient [`Self::create`] function.
302    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
303        let shape: Shape = [slice.len()].into();
304
305        self.do_create_from_slices(
306            vec![MemoryLayoutDescriptor::new(
307                MemoryLayoutStrategy::Contiguous,
308                shape,
309                1,
310            )],
311            vec![slice.to_vec()],
312        )
313        .unwrap()
314        .remove(0)
315        .memory
316    }
317
318    /// Reserves pinned (page-locked, on backends that support it) host buffers of
319    /// the requested sizes. The caller fills the returned [`Bytes`] (e.g. via
320    /// [`Bytes::copy_from_slice`] or by writing through `DerefMut`) and then hands
321    /// the buffers to [`Self::create`], [`Self::create_tensor`], or
322    /// [`Self::create_tensors`] to upload them to the device.
323    ///
324    /// On CUDA, pinned host memory enables direct DMA in `cuMemcpyHtoDAsync`,
325    /// reaching ~12-25 GB/s on PCIe 4.0 compared to ~5-6 GB/s from pageable
326    /// memory. On backends without an explicit pinned-memory concept this falls
327    /// back to a regular host allocation, so callers can use this API
328    /// unconditionally without regressing other backends.
329    ///
330    /// Note that pinned host memory is a limited system resource — allocate it
331    /// only for buffers that will actually be uploaded to the device, and drop
332    /// the [`Bytes`] handle as soon as the upload completes.
333    pub fn reserve_staging(&self, sizes: &[usize]) -> Vec<Bytes> {
334        if sizes.is_empty() {
335            return Vec::new();
336        }
337
338        let stream_id = self.stream_id();
339        let sizes_owned = sizes.to_vec();
340        let result = self
341            .device
342            .submit_blocking(move |server| server.staging(&sizes_owned, stream_id))
343            .unwrap();
344
345        match result {
346            Ok(stagings) => stagings,
347            // Backends may return errors if pinned memory is exhausted. Fall back
348            // to plain heap allocations so the caller always gets buffers of the
349            // requested sizes.
350            Err(_) => sizes
351                .iter()
352                .map(|&size| Bytes::from_bytes_vec(vec![0u8; size]))
353                .collect(),
354        }
355    }
356
357    /// Like [`Self::create_from_slice`], but copies the input directly into a
358    /// pinned host buffer (on backends that support it) before issuing the
359    /// device upload.
360    ///
361    /// The default [`Self::create_from_slice`] path performs two host-side
362    /// memcpys (caller `&[u8]` → pageable `Vec<u8>` → pinned [`Bytes`]) before
363    /// the device transfer. This variant skips the intermediate `Vec<u8>` and
364    /// copies straight into the pinned staging buffer, halving host-side
365    /// memory traffic for large uploads. The on-device handle is identical.
366    ///
367    /// On backends without a pinned-memory fast path this behaves the same as
368    /// [`Self::create_from_slice`].
369    pub fn create_from_slice_pinned(&self, slice: &[u8]) -> Handle {
370        let mut staging = self.reserve_staging(&[slice.len()]);
371        let mut bytes = staging.pop().expect("reserve_staging returned no buffers");
372        bytes.copy_from_slice(slice);
373        self.create(bytes)
374    }
375
376    /// Like [`Self::create_tensors_from_slices`], but copies inputs directly
377    /// into pinned host buffers before issuing the device upload. See
378    /// [`Self::create_from_slice_pinned`] for the host-side savings rationale.
379    pub fn create_tensors_from_slices_pinned(
380        &self,
381        descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
382    ) -> Vec<MemoryLayout> {
383        let sizes: Vec<usize> = descriptors.iter().map(|(_, s)| s.len()).collect();
384        let stagings = self.reserve_staging(&sizes);
385
386        let mut bytes_vec = Vec::with_capacity(descriptors.len());
387        let mut descs = Vec::with_capacity(descriptors.len());
388        for ((desc, slice), mut staging) in descriptors.into_iter().zip(stagings) {
389            staging.copy_from_slice(slice);
390            bytes_vec.push(staging);
391            descs.push(desc);
392        }
393
394        self.do_create(descs, bytes_vec).unwrap()
395    }
396
397    /// todo: docs
398    pub fn exclusive<'a, Re: Send + 'static, F: FnOnce() -> Re + Send + 'a>(
399        &'a self,
400        task: F,
401    ) -> Result<Re, ServerError> {
402        // We then launch the task.
403        self.device
404            .exclusive(task)
405            .map_err(|err| ServerError::Generic {
406                reason: format!("Communication channel with the server is down: {err:?}"),
407                backtrace: BackTrace::capture(),
408            })
409    }
410
411    /// dodo: Docs
412    pub fn memory_persistent_allocation<
413        'a,
414        Re: Send,
415        Input: Send,
416        F: FnOnce(Input) -> Re + Send + 'a,
417    >(
418        &'a self,
419        input: Input,
420        task: F,
421    ) -> Result<Re, ServerError> {
422        let stream_id = StreamId::current();
423
424        self.device.submit(move |server| {
425            server.allocation_mode(MemoryAllocationMode::Persistent, stream_id);
426        });
427
428        // All tasks created on the same stream will have persistent memory.
429        let output = task(input);
430
431        self.device.submit(move |server| {
432            server.allocation_mode(MemoryAllocationMode::Auto, stream_id);
433        });
434
435        Ok(output)
436    }
437
438    /// Returns a resource handle containing the given [Bytes].
439    pub fn create(&self, data: Bytes) -> Handle {
440        let shape = [data.len()].into();
441
442        self.do_create(
443            vec![MemoryLayoutDescriptor::new(
444                MemoryLayoutStrategy::Contiguous,
445                shape,
446                1,
447            )],
448            vec![data],
449        )
450        .unwrap()
451        .remove(0)
452        .memory
453    }
454
455    /// Given a resource and shape, stores it and returns the tensor handle and strides.
456    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
457    /// should be taken when indexing.
458    ///
459    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
460    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
461    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
462    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
463    /// can load as much data as possible in a single instruction. It may be aligned even more to
464    /// also take cache lines into account.
465    ///
466    /// However, the stride must be taken into account when indexing and reading the tensor
467    /// (also see [`ComputeClient::read_tensor`]).
468    ///
469    /// # Notes
470    ///
471    /// Prefer using [`Self::create_tensor`] for better performance.
472    pub fn create_tensor_from_slice(
473        &self,
474        slice: &[u8],
475        shape: Shape,
476        elem_size: usize,
477    ) -> MemoryLayout {
478        self.do_create_from_slices(
479            vec![MemoryLayoutDescriptor::new(
480                MemoryLayoutStrategy::Optimized,
481                shape,
482                elem_size,
483            )],
484            vec![slice.to_vec()],
485        )
486        .unwrap()
487        .remove(0)
488    }
489
490    /// Given a resource and shape, stores it and returns the tensor handle and strides.
491    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
492    /// should be taken when indexing.
493    ///
494    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
495    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
496    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
497    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
498    /// can load as much data as possible in a single instruction. It may be aligned even more to
499    /// also take cache lines into account.
500    ///
501    /// However, the stride must be taken into account when indexing and reading the tensor
502    /// (also see [`ComputeClient::read_tensor`]).
503    pub fn create_tensor(&self, bytes: Bytes, shape: Shape, elem_size: usize) -> MemoryLayout {
504        self.do_create(
505            vec![MemoryLayoutDescriptor::new(
506                MemoryLayoutStrategy::Optimized,
507                shape,
508                elem_size,
509            )],
510            vec![bytes],
511        )
512        .unwrap()
513        .remove(0)
514    }
515
516    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
517    /// handle, and returns the handles for them.
518    /// See [`ComputeClient::create_tensor`]
519    ///
520    /// # Notes
521    ///
522    /// Prefer using [`Self::create_tensors`] for better performance.
523    pub fn create_tensors_from_slices(
524        &self,
525        descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
526    ) -> Vec<MemoryLayout> {
527        let mut data = Vec::with_capacity(descriptors.len());
528        let mut descriptors_ = Vec::with_capacity(descriptors.len());
529        for (a, b) in descriptors {
530            data.push(b.to_vec());
531            descriptors_.push(a);
532        }
533
534        self.do_create_from_slices(descriptors_, data).unwrap()
535    }
536
537    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
538    /// handle, and returns the handles for them.
539    /// See [`ComputeClient::create_tensor`]
540    pub fn create_tensors(
541        &self,
542        descriptors: Vec<(MemoryLayoutDescriptor, Bytes)>,
543    ) -> Vec<MemoryLayout> {
544        let (descriptors, data) = descriptors.into_iter().unzip();
545
546        self.do_create(descriptors, data).unwrap()
547    }
548
549    fn do_empty(
550        &self,
551        descriptors: Vec<MemoryLayoutDescriptor>,
552    ) -> Result<Vec<MemoryLayout>, IoError> {
553        let stream_id = self.stream_id();
554        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
555
556        let (size, memory) = (handle_base.size(), handle_base.memory);
557        self.device.submit(move |server| {
558            server.initialize_memory(memory, size, stream_id);
559        });
560
561        Ok(layouts)
562    }
563
564    /// Reserves `size` bytes in the storage, and returns a handle over them.
565    pub fn empty(&self, size: usize) -> Handle {
566        let shape: Shape = [size].into();
567        let descriptor = MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, 1);
568        self.do_empty(vec![descriptor]).unwrap().remove(0).memory
569    }
570
571    /// Reserves `shape` in the storage, and returns a tensor handle for it.
572    /// See [`ComputeClient::create_tensor`]
573    pub fn empty_tensor(&self, shape: Shape, elem_size: usize) -> MemoryLayout {
574        let descriptor =
575            MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size);
576        self.do_empty(vec![descriptor]).unwrap().remove(0)
577    }
578
579    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
580    /// See [`ComputeClient::create_tensor`]
581    pub fn empty_tensors(&self, descriptors: Vec<MemoryLayoutDescriptor>) -> Vec<MemoryLayout> {
582        self.do_empty(descriptors).unwrap()
583    }
584
585    /// Marks the given [Bytes] as being a staging buffer, maybe transferring it to pinned memory
586    /// for faster data transfer with compute device.
587    ///
588    /// TODO: This blocks the compute queue, so it will drop the compute utilization.
589    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
590    where
591        I: Iterator<Item = &'a mut Bytes>,
592    {
593        let has_staging = |b: &Bytes| match b.property() {
594            AllocationProperty::Pinned => false,
595            AllocationProperty::File => true,
596            AllocationProperty::Native | AllocationProperty::Other => !file_only,
597        };
598
599        let mut to_be_updated = Vec::new();
600        let sizes = bytes
601            .filter_map(|b| match has_staging(b) {
602                true => {
603                    let len = b.len();
604                    to_be_updated.push(b);
605                    Some(len)
606                }
607                false => None,
608            })
609            .collect::<Vec<usize>>();
610
611        if sizes.is_empty() {
612            return;
613        }
614
615        let stream_id = self.stream_id();
616        let sizes = sizes.to_vec();
617        let stagings = self
618            .device
619            .submit_blocking(move |server| server.staging(&sizes, stream_id))
620            .unwrap();
621
622        let stagings = match stagings {
623            Ok(val) => val,
624            Err(_) => return,
625        };
626
627        to_be_updated
628            .into_iter()
629            .zip(stagings)
630            .for_each(|(b, mut staging)| {
631                b.copy_into(&mut staging);
632                core::mem::swap(b, &mut staging);
633            });
634    }
635
636    /// Transfer data from one client to another
637    #[cfg_attr(
638        feature = "tracing",
639        tracing::instrument(level = "trace", skip(self, src, dst_server))
640    )]
641    pub fn to_client(&mut self, src: Handle, dst_server: &Self, dtype: ElemType) -> Handle {
642        let shape = [src.size_in_used() as usize];
643        let src_descriptor = src.copy_descriptor(shape.into(), [1].into(), 1);
644
645        if R::Server::SERVER_COMM_ENABLED {
646            self.to_client_tensor(src_descriptor, dst_server, dtype)
647        } else {
648            let alloc_desc = MemoryLayoutDescriptor::new(
649                MemoryLayoutStrategy::Contiguous,
650                src_descriptor.shape.clone(),
651                src_descriptor.elem_size,
652            );
653            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
654                .memory
655        }
656    }
657
658    /// Perform an `all_reduce` operation on the given devices.
659    #[cfg_attr(
660        feature = "tracing",
661        tracing::instrument(level = "trace", skip(self, device_ids))
662    )]
663    pub fn ensure_init_collective(&mut self, device_ids: Vec<DeviceId>) {
664        let comm_id = CommunicationId::from(device_ids.clone());
665        let is_comms_init = self
666            .utilities
667            .initialized_comms
668            .read()
669            .unwrap()
670            .contains(&comm_id);
671        if !is_comms_init {
672            self.device
673                .submit(move |server| server.comm_init(device_ids).unwrap());
674            let mut initialized_comms = self.utilities.initialized_comms.write().unwrap();
675            initialized_comms.insert(comm_id);
676            // Flush immediately so other devices aren't blocked waiting on this initialization.
677            self.device.flush_queue();
678        }
679    }
680
681    /// Wait on the communication stream.
682    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
683    pub fn sync_collective(&self) {
684        if DeviceHandle::<R::Server>::is_blocking() {
685            panic!("Can't use `sync_collective` with a blocking device handle");
686        }
687        let stream_id = self.stream_id();
688
689        self.device.submit(move |server| {
690            server.sync_collective(stream_id).unwrap();
691        });
692
693        // We don't actually need or want to sync the server here, but we need to make sure any
694        // task enqueued on the communication channel is done.
695        self.device.flush_queue();
696    }
697
698    /// Perform an `all_reduce` operation on the given devices.
699    #[cfg_attr(
700        feature = "tracing",
701        tracing::instrument(level = "trace", skip(self, src, dst, dtype, device_ids, op))
702    )]
703    pub fn all_reduce(
704        &mut self,
705        src: Handle,
706        dst: Handle,
707        dtype: ElemType,
708        device_ids: Vec<DeviceId>,
709        op: ReduceOperation,
710    ) {
711        if DeviceHandle::<R::Server>::is_blocking() {
712            panic!("Can't use `all_reduce` with a blocking device handle");
713        }
714
715        let stream_id = self.stream_id();
716        let src = src.binding();
717        let dst = dst.binding();
718
719        self.ensure_init_collective(device_ids.clone());
720
721        self.device.submit(move |server| {
722            server
723                .all_reduce(src, dst, dtype, stream_id, op, device_ids)
724                .unwrap();
725        });
726    }
727
728    /// Transfer data from one client to another
729    ///
730    /// Make sure the source description can be read in a contiguous manner.
731    #[cfg_attr(
732        feature = "tracing",
733        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
734    )]
735    pub fn to_client_tensor(
736        &mut self,
737        src_descriptor: CopyDescriptor,
738        dst_server: &Self,
739        dtype: ElemType,
740    ) -> Handle {
741        let stream_id_src = self.stream_id();
742        let stream_id_dst = dst_server.stream_id();
743
744        let device_id_src = self.device.device_id();
745        let device_id_dst = dst_server.device.device_id();
746
747        let mut dst_server = dst_server.clone();
748        let handle = Handle::new(stream_id_dst, src_descriptor.handle.size_in_used());
749        let handle_cloned = handle.clone();
750
751        let device_ids = vec![device_id_src, device_id_dst];
752        self.ensure_init_collective(device_ids.clone());
753        dst_server.ensure_init_collective(device_ids);
754
755        self.device.submit(move |server_src| {
756            server_src
757                .send(src_descriptor, dtype, stream_id_src, device_id_dst)
758                .unwrap()
759        });
760
761        dst_server.device.submit(move |server_dst| {
762            server_dst
763                .recv(handle_cloned, dtype, stream_id_dst, device_id_src)
764                .unwrap();
765            server_dst.sync_collective(stream_id_dst).unwrap();
766        });
767
768        // `ServerCommunication::send` and`ServerCommunication::recv` are blocking: they each wait for the corresponding recv/send
769        // call to be made. We flush the operations right away so that the neither server ends up in a deadlock.
770        // The actual data transfer is still executed asynchronously on the communication stream.
771        self.device.flush_queue();
772        dst_server.device.flush_queue();
773
774        handle
775    }
776
777    #[track_caller]
778    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
779        skip(self, kernel, bindings),
780        fields(
781            kernel.name = %kernel.name(),
782            kernel.id = %kernel.id(),
783        )
784    ))]
785    unsafe fn launch_inner(
786        &self,
787        kernel: <R::Server as ComputeServer>::Kernel,
788        count: CubeCount,
789        bindings: KernelArguments,
790        mode: ExecutionMode,
791        stream_id: StreamId,
792    ) {
793        let level = self.utilities.logger.profile_level();
794
795        match level {
796            None | Some(ProfileLevel::ExecutionOnly) => {
797                let utilities = self.utilities.clone();
798                self.device.submit(move |state| {
799                    let name = kernel.name();
800                    unsafe { state.launch(kernel, count, bindings, mode, stream_id) };
801
802                    if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
803                        let info = type_name_format(name, TypeNameFormatLevel::Balanced);
804                        utilities.logger.register_execution(info);
805                    }
806                });
807            }
808            Some(level) => {
809                let name = kernel.name();
810                let kernel_id = kernel.id();
811                let context = self.device.clone();
812                let count_moved = count.clone();
813                let (result, profile) = self
814                    .profile(
815                        move || {
816                            context
817                                .submit_blocking(move |state| unsafe {
818                                    state.launch(kernel, count_moved, bindings, mode, stream_id)
819                                })
820                                .unwrap()
821                        },
822                        name,
823                    )
824                    .unwrap();
825                let info = match level {
826                    ProfileLevel::Full => {
827                        format!("{name}: {kernel_id} CubeCount {count:?}")
828                    }
829                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
830                };
831                self.utilities.logger.register_profiled(info, profile);
832                result
833            }
834        }
835    }
836
837    /// Launches the `kernel` with the given `bindings`.
838    #[track_caller]
839    pub fn launch(
840        &self,
841        kernel: <R::Server as ComputeServer>::Kernel,
842        count: CubeCount,
843        bindings: KernelArguments,
844    ) {
845        // SAFETY: Using checked execution mode.
846        unsafe {
847            self.launch_inner(
848                kernel,
849                count,
850                bindings,
851                ExecutionMode::Checked,
852                self.stream_id(),
853            )
854        }
855    }
856
857    /// Launches the `kernel` with the given `bindings` without performing any bound checks.
858    ///
859    /// # Safety
860    ///
861    /// To ensure this is safe, you must verify your kernel:
862    /// - Has no out-of-bound reads and writes that can happen.
863    /// - Has no infinite loops that might never terminate.
864    #[track_caller]
865    pub unsafe fn launch_unchecked(
866        &self,
867        kernel: <R::Server as ComputeServer>::Kernel,
868        count: CubeCount,
869        bindings: KernelArguments,
870    ) {
871        // SAFETY: Caller has to uphold kernel being safe.
872        unsafe {
873            self.launch_inner(
874                kernel,
875                count,
876                bindings,
877                match self.utilities.check_mode {
878                    crate::config::compilation::BoundsCheckMode::Enforce => ExecutionMode::Checked,
879                    crate::config::compilation::BoundsCheckMode::Validate => {
880                        ExecutionMode::Validate
881                    }
882                    crate::config::compilation::BoundsCheckMode::Auto => ExecutionMode::Unchecked,
883                },
884                self.stream_id(),
885            )
886        }
887    }
888
889    /// Flush all outstanding commands.
890    pub fn flush(&self) -> Result<(), ServerError> {
891        let stream_id = self.stream_id();
892
893        self.device
894            .submit_blocking(move |server| server.flush(stream_id))
895            .unwrap()
896    }
897
898    /// Wait for the completion of every task in the server.
899    pub fn sync(&self) -> DynFut<Result<(), ServerError>> {
900        let stream_id = self.stream_id();
901
902        let fut = self
903            .device
904            .submit_blocking(move |server| server.sync(stream_id))
905            .unwrap();
906
907        self.utilities.logger.profile_summary();
908
909        fut
910    }
911
912    /// Get the features supported by the compute server.
913    pub fn properties(&self) -> &DeviceProperties {
914        &self.utilities.properties
915    }
916
917    /// Get the features supported by the compute server.
918    pub fn features(&self) -> &Features {
919        &self.utilities.properties.features
920    }
921
922    /// # Warning
923    ///
924    /// For private use only.
925    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
926        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
927    }
928
929    /// Get the current memory usage of this client.
930    pub fn memory_usage(&self) -> Result<MemoryUsage, ServerError> {
931        let stream_id = self.stream_id();
932        self.device
933            .submit_blocking(move |server| server.memory_usage(stream_id))
934            .unwrap()
935    }
936
937    /// Get all devices of a specific type available to this runtime
938    pub fn enumerate_devices(&self, type_id: u16) -> Vec<DeviceId> {
939        R::enumerate_devices(type_id, self.info())
940    }
941
942    /// Get all devices available to this runtime
943    pub fn enumerate_all_devices(&self) -> Vec<DeviceId> {
944        R::enumerate_all_devices(self.info())
945    }
946
947    /// Get the number of devices of a specific type available to this runtime
948    pub fn device_count(&self, type_id: u16) -> usize {
949        self.enumerate_devices(type_id).len()
950    }
951
952    /// Get the number of devices of a specific type available to this runtime
953    pub fn device_count_total(&self) -> usize {
954        self.enumerate_all_devices().len()
955    }
956
957    /// Change the memory allocation mode.
958    ///
959    /// # Safety
960    ///
961    /// This function isn't thread safe and might create memory leaks.
962    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
963        let stream_id = self.stream_id();
964        self.device
965            .submit(move |server| server.allocation_mode(mode, stream_id));
966    }
967
968    /// Ask the client to release memory that it can release.
969    ///
970    /// Nb: Results will vary on what the memory allocator deems beneficial,
971    /// so it's not guaranteed any memory is freed.
972    pub fn memory_cleanup(&self) {
973        let stream_id = self.stream_id();
974        self.device
975            .submit(move |server| server.memory_cleanup(stream_id));
976    }
977
978    /// Measure the execution time of some inner operations.
979    #[track_caller]
980    pub fn profile<O: Send + 'static>(
981        &self,
982        func: impl FnOnce() -> O + Send,
983        #[allow(unused)] func_name: &str,
984    ) -> Result<(O, ProfileDuration), ProfileError> {
985        // Get the outer caller. For execute() this points straight to the
986        // cube kernel. For general profiling it points to whoever calls profile.
987        #[cfg(feature = "profile-tracy")]
988        let location = std::panic::Location::caller();
989
990        // Make a CPU span. If the server has system profiling this is all you need.
991        #[cfg(feature = "profile-tracy")]
992        let _span = tracy_client::Client::running().unwrap().span_alloc(
993            None,
994            func_name,
995            location.file(),
996            location.line(),
997            0,
998        );
999
1000        let stream_id = self.stream_id();
1001
1002        #[cfg(feature = "profile-tracy")]
1003        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
1004            let gpu_span = self
1005                .utilities
1006                .gpu_client
1007                .span_alloc(func_name, "profile", location.file(), location.line())
1008                .unwrap();
1009            Some(gpu_span)
1010        } else {
1011            None
1012        };
1013
1014        let device = self.device.clone();
1015        #[allow(unused_mut, reason = "Used in profile-tracy")]
1016        let mut result = self
1017            .device
1018            .exclusive(move || {
1019                // We first get mut access to the server to create a token.
1020                // Then we free to server, since it's going to be accessed in `func()`.
1021                let token =
1022                    match device.submit_blocking(move |server| server.start_profile(stream_id)) {
1023                        Ok(token) => match token {
1024                            Ok(token) => token,
1025                            Err(err) => return Err(err),
1026                        },
1027                        Err(err) => {
1028                            return Err(ServerError::Generic {
1029                                reason: alloc::format!(
1030                                    "Can't start profiling because of a call error: {err:?}"
1031                                ),
1032                                backtrace: BackTrace::capture(),
1033                            });
1034                        }
1035                    };
1036
1037                // We execute `func()` which will recursibly access the server.
1038                let out = func();
1039
1040                // Finally we get the result from the token.
1041                let result = device
1042                    .submit_blocking(move |server| {
1043                        let mut result = server.end_profile(stream_id, token);
1044
1045                        match result {
1046                            Ok(result) => Ok((out, result)),
1047                            Err(err) => Err(err),
1048                        }
1049                    })
1050                    .unwrap();
1051
1052                Ok(result)
1053            })
1054            .unwrap()
1055            .map_err(|err| ProfileError::Unknown {
1056                reason: alloc::format!("{err:?}"),
1057                backtrace: BackTrace::capture(),
1058            })?;
1059
1060        #[cfg(feature = "profile-tracy")]
1061        if let Some(mut gpu_span) = gpu_span {
1062            gpu_span.end_zone();
1063            let epoch = self.utilities.epoch_time;
1064            // Add in the work to upload the timestamp data.
1065            result = result.map(|(o, result)| {
1066                (
1067                    o,
1068                    ProfileDuration::new(
1069                        alloc::boxed::Box::pin(async move {
1070                            let ticks = result.resolve().await;
1071                            let start_duration =
1072                                ticks.start_duration_since(epoch).as_nanos() as i64;
1073                            let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
1074                            gpu_span.upload_timestamp_start(start_duration);
1075                            gpu_span.upload_timestamp_end(end_duration);
1076                            ticks
1077                        }),
1078                        TimingMethod::Device,
1079                    ),
1080                )
1081            });
1082        }
1083
1084        result
1085    }
1086
1087    /// Transfer data from one client to another
1088    #[cfg_attr(
1089        feature = "tracing",
1090        tracing::instrument(
1091            level = "trace",
1092            skip(self, src_descriptor, alloc_descriptor, dst_server)
1093        )
1094    )]
1095    fn change_client_sync(
1096        &self,
1097        src_descriptor: CopyDescriptor,
1098        alloc_descriptor: MemoryLayoutDescriptor,
1099        dst_server: &Self,
1100    ) -> MemoryLayout {
1101        let shape = src_descriptor.shape.clone();
1102        let elem_size = src_descriptor.elem_size;
1103        let stream_id = self.stream_id();
1104
1105        let read = self
1106            .device
1107            .submit_blocking(move |server| server.read(vec![src_descriptor], stream_id))
1108            .unwrap();
1109
1110        let mut data = cubecl_common::future::block_on(read).unwrap();
1111
1112        let (handle_base, mut layouts) = self
1113            .utilities
1114            .layout_policy
1115            .apply(stream_id, &[alloc_descriptor]);
1116        let alloc = layouts.remove(0);
1117
1118        let desc_descriptor = CopyDescriptor {
1119            handle: handle_base.clone().binding(),
1120            shape,
1121            strides: alloc.strides.clone(),
1122            elem_size,
1123        };
1124
1125        let (size, memory) = (handle_base.size(), handle_base.memory);
1126        dst_server.device.submit(move |server| {
1127            server.initialize_memory(memory, size, stream_id);
1128            server.write(vec![(desc_descriptor, data.remove(0))], stream_id)
1129        });
1130
1131        alloc
1132    }
1133
1134    /// Returns all vector sizes that are useful to perform optimal IO operation on the given element.
1135    pub fn io_optimized_vector_sizes(
1136        &self,
1137        size: usize,
1138    ) -> impl Iterator<Item = VectorSize> + Clone {
1139        let load_width = self.properties().hardware.load_width as usize;
1140        let size_bits = size * 8;
1141        let max = load_width / size_bits;
1142        let max = usize::min(self.properties().hardware.max_vector_size, max);
1143
1144        // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
1145        let num_candidates = max.trailing_zeros() + 1;
1146
1147        (0..num_candidates).map(|i| 2usize.pow(i)).rev()
1148    }
1149}