Skip to main content

cubecl_runtime/
client.rs

1use crate::{
2    config::{TypeNameFormatLevel, type_name_format},
3    kernel::KernelMetadata,
4    logging::ProfileLevel,
5    memory_management::{MemoryAllocationMode, MemoryUsage},
6    runtime::Runtime,
7    server::{
8        ComputeServer, CopyDescriptor, CubeCount, ExecutionMode, Handle, IoError, KernelArguments,
9        MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy, MemoryLayoutStrategy,
10        ProfileError, ReduceOperation, ServerCommunication, ServerError, ServerUtilities,
11    },
12    storage::{ComputeStorage, ManagedResource},
13};
14use alloc::{format, sync::Arc, vec, vec::Vec};
15use cubecl_common::{
16    backtrace::BackTrace,
17    bytes::{AllocationProperty, Bytes},
18    device::{Device, DeviceId},
19    device_handle::DeviceHandle,
20    future::DynFut,
21    profile::ProfileDuration,
22};
23use cubecl_ir::{DeviceProperties, ElemType, VectorSize, features::Features};
24use cubecl_zspace::Shape;
25
26#[allow(unused)]
27use cubecl_common::profile::TimingMethod;
28use cubecl_common::stream_id::StreamId;
29
/// The `ComputeClient` is the entry point to require tasks from the `ComputeServer`.
/// It should be obtained for a specific device via the Compute struct.
pub struct ComputeClient<R: Runtime> {
    /// Handle to the device-bound server this client submits work to.
    device: DeviceHandle<R::Server>,
    /// Shared server utilities (info, properties, logger, layout policy, ...).
    utilities: Arc<ServerUtilities<R::Server>>,
    /// Explicit stream override; when `None`, the thread's current stream is used.
    stream_id: Option<StreamId>,
}
37
38impl<R: Runtime> Clone for ComputeClient<R> {
39    fn clone(&self) -> Self {
40        Self {
41            device: self.device.clone(),
42            utilities: self.utilities.clone(),
43            stream_id: self.stream_id,
44        }
45    }
46}
47
48impl<R: Runtime> ComputeClient<R> {
    /// Get the info of the current backend.
    ///
    /// Returns a reference to the server-specific `Info` stored in the shared utilities.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }
53
    /// Create a new client with a new server.
    ///
    /// # Panics
    ///
    /// Panics if a server is already registered for `device`.
    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
        // Grab the shared utilities before the server is moved into the registry.
        let utilities = server.utilities();
        let context = DeviceHandle::<R::Server>::insert(device.to_id(), server)
            .expect("Can't create a new client on an already registered server");

        Self {
            device: context,
            utilities,
            stream_id: None,
        }
    }
66
    /// Load the client for the given device.
    ///
    /// # Panics
    ///
    /// Panics if the registered utilities are not of type `ServerUtilities<R::Server>`.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceHandle::<R::Server>::new(device.to_id());

        // This is safe because we now know the return type of [`DeviceHandle::utilities()`].
        let utilities = context
            .utilities()
            .downcast::<ServerUtilities<R::Server>>()
            .expect("Can downcast to `ServerUtilities`");

        Self {
            device: context,
            utilities,
            stream_id: None,
        }
    }
83
84    fn stream_id(&self) -> StreamId {
85        match self.stream_id {
86            Some(val) => val,
87            None => StreamId::current(),
88        }
89    }
90
    /// Set the stream in which the current client is operating on.
    ///
    /// All subsequent submissions from this client use `stream_id` instead of the
    /// thread's current stream.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }
99
    /// Submit a blocking read for all `descriptors` and return the future producing the bytes.
    ///
    /// The `unwrap` covers the submission channel itself, not the read: read failures are
    /// reported through the returned future's `Result`.
    fn do_read(&self, descriptors: Vec<CopyDescriptor>) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let stream_id = self.stream_id();
        self.device
            .submit_blocking(move |server| server.read(descriptors, stream_id))
            .unwrap()
    }
106
107    /// Given bindings, returns owned resources as bytes.
108    pub fn read_async(
109        &self,
110        handles: Vec<Handle>,
111    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
112        let shapes = handles
113            .iter()
114            .map(|it| [it.size_in_used() as usize].into())
115            .collect::<Vec<Shape>>();
116        let descriptors = handles
117            .into_iter()
118            .zip(shapes)
119            .map(|(handle, shape)| CopyDescriptor::new(handle.binding(), shape, [1].into(), 1))
120            .collect();
121
122        self.do_read(descriptors)
123    }
124
125    /// Given bindings, returns owned resources as bytes.
126    ///
127    /// # Remarks
128    ///
129    /// Panics if the read operation fails.
130    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
131        cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
132    }
133
134    /// Given a binding, returns owned resource as bytes.
135    pub fn read_one(&self, handle: Handle) -> Result<Bytes, ServerError> {
136        Ok(cubecl_common::reader::read_sync(self.read_async(vec![handle]))?.remove(0))
137    }
138
139    /// Given a binding, returns owned resource as bytes.
140    ///
141    /// # Remarks
142    ///
143    /// Panics if the read operation fails. Useful for tests.
144    pub fn read_one_unchecked(&self, handle: Handle) -> Bytes {
145        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
146            .unwrap()
147            .remove(0)
148    }
149
    /// Given bindings, returns owned resources as bytes.
    ///
    /// Thin delegate to the internal read path; the descriptors carry the shape,
    /// strides and element size of each tensor to read.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
        self.do_read(descriptors)
    }
157
158    /// Given bindings, returns owned resources as bytes.
159    ///
160    /// # Remarks
161    ///
162    /// Panics if the read operation fails.
163    ///
164    /// The tensor must be in the same layout as created by the runtime, or more strict.
165    /// Contiguous tensors are always fine, strided tensors are only ok if the stride is similar to
166    /// the one created by the runtime (i.e. padded on only the last dimension). A way to check
167    /// stride compatibility on the runtime will be added in the future.
168    ///
169    /// Also see [`ComputeClient::create_tensor`].
170    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor>) -> Vec<Bytes> {
171        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
172    }
173
174    /// Given a binding, returns owned resource as bytes.
175    /// See [`ComputeClient::read_tensor`]
176    pub fn read_one_tensor_async(
177        &self,
178        descriptor: CopyDescriptor,
179    ) -> impl Future<Output = Result<Bytes, ServerError>> + Send {
180        let fut = self.read_tensor_async(vec![descriptor]);
181
182        async { Ok(fut.await?.remove(0)) }
183    }
184
185    /// Given a binding, returns owned resource as bytes.
186    ///
187    /// # Remarks
188    ///
189    /// Panics if the read operation fails.
190    /// See [`ComputeClient::read_tensor`]
191    pub fn read_one_unchecked_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
192        self.read_tensor(vec![descriptor]).remove(0)
193    }
194
    /// Given a resource handle, returns the storage resource.
    ///
    /// Blocks until the server has resolved the binding. The `unwrap` covers the
    /// submission channel; resolution failures surface through the returned `Result`.
    pub fn get_resource(
        &self,
        handle: Handle,
    ) -> Result<
        ManagedResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource>,
        ServerError,
    > {
        let stream_id = self.stream_id();
        let binding = handle.binding();

        self.device
            .submit_blocking(move |state| state.get_resource(binding, stream_id))
            .unwrap()
    }
210
211    fn do_create_from_slices(
212        &self,
213        descriptors: Vec<MemoryLayoutDescriptor>,
214        slices: Vec<Vec<u8>>,
215    ) -> Result<Vec<MemoryLayout>, IoError> {
216        let stream_id = self.stream_id();
217        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
218
219        let descriptors = descriptors
220            .into_iter()
221            .zip(layouts.iter())
222            .zip(slices)
223            .map(|((desc, alloc), data)| {
224                (
225                    CopyDescriptor::new(
226                        alloc.memory.clone().binding(),
227                        desc.shape,
228                        alloc.strides.clone(),
229                        desc.elem_size,
230                    ),
231                    Bytes::from_bytes_vec(data.to_vec()),
232                )
233            })
234            .collect::<Vec<_>>();
235
236        let (size, memory) = (handle_base.size(), handle_base.memory);
237        self.device.submit(move |server| {
238            server.initialize_memory(memory, size, stream_id);
239            server.write(descriptors, stream_id);
240        });
241
242        Ok(layouts)
243    }
244
245    fn do_create(
246        &self,
247        descriptors: Vec<MemoryLayoutDescriptor>,
248        mut data: Vec<Bytes>,
249    ) -> Result<Vec<MemoryLayout>, IoError> {
250        self.staging(data.iter_mut(), true);
251
252        let stream_id = self.stream_id();
253        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
254
255        let descriptors = descriptors
256            .into_iter()
257            .zip(layouts.iter())
258            .zip(data)
259            .map(|((desc, layout), data)| {
260                (
261                    CopyDescriptor::new(
262                        layout.memory.clone().binding(),
263                        desc.shape,
264                        layout.strides.clone(),
265                        desc.elem_size,
266                    ),
267                    Bytes::from_bytes_vec(data.to_vec()),
268                )
269            })
270            .collect::<Vec<_>>();
271
272        let (size, memory) = (handle_base.size(), handle_base.memory);
273        self.device.submit(move |server| {
274            server.initialize_memory(memory, size, stream_id);
275            server.write(descriptors, stream_id);
276        });
277
278        Ok(layouts)
279    }
280
281    /// Returns a resource handle containing the given data.
282    ///
283    /// # Notes
284    ///
285    /// Prefer using the more efficient [`Self::create`] function.
286    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
287        let shape: Shape = [slice.len()].into();
288
289        self.do_create_from_slices(
290            vec![MemoryLayoutDescriptor::new(
291                MemoryLayoutStrategy::Contiguous,
292                shape,
293                1,
294            )],
295            vec![slice.to_vec()],
296        )
297        .unwrap()
298        .remove(0)
299        .memory
300    }
301
302    /// Executes a task that has exclusive access to the current device.
303    pub fn exclusive<Re: Send + 'static, F: FnOnce() -> Re + Send + 'static>(
304        &self,
305        task: F,
306    ) -> Result<Re, ServerError> {
307        // We first flush current tasks enqueued on the device.
308        self.flush()?;
309
310        // We then launch the task.
311        self.device
312            .exclusive(task)
313            .map_err(|err| ServerError::Generic {
314                reason: format!("Communication channel with the server is down: {err:?}"),
315                backtrace: BackTrace::capture(),
316            })
317    }
318
    /// Executes `task` with scoped exclusive access to the device.
    ///
    /// Unlike [`Self::exclusive`], this does not flush pending work first, and the
    /// closure may borrow from the caller's scope instead of being `'static`.
    ///
    /// # Errors
    ///
    /// Returns an error if the communication channel with the server is down.
    pub fn scoped<'a, Re: Send, F: FnOnce() -> Re + Send + 'a>(
        &'a self,
        task: F,
    ) -> Result<Re, ServerError> {
        // We then launch the task.
        self.device
            .exclusive_scoped(task)
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }
332
    /// Executes `task` with the given `input` under scoped exclusive device access.
    ///
    /// NOTE(review): despite the name, this currently only forwards to
    /// `exclusive_scoped` — no persistent allocation mode is applied (compare the
    /// commented-out variant further down in this file). Confirm intended behavior.
    ///
    /// # Errors
    ///
    /// Returns an error if the communication channel with the server is down.
    pub fn memory_persistent_allocation<
        'a,
        Re: Send,
        Input: Send,
        F: FnOnce(Input) -> Re + Send + 'a,
    >(
        &'a self,
        input: Input,
        task: F,
    ) -> Result<Re, ServerError> {
        // We then launch the task.
        self.device
            .exclusive_scoped(move || task(input))
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }
352
353    /// Returns a resource handle containing the given [Bytes].
354    pub fn create(&self, data: Bytes) -> Handle {
355        let shape = [data.len()].into();
356
357        self.do_create(
358            vec![MemoryLayoutDescriptor::new(
359                MemoryLayoutStrategy::Contiguous,
360                shape,
361                1,
362            )],
363            vec![data],
364        )
365        .unwrap()
366        .remove(0)
367        .memory
368    }
369
370    /// Given a resource and shape, stores it and returns the tensor handle and strides.
371    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
372    /// should be taken when indexing.
373    ///
374    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
375    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
376    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
377    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
378    /// can load as much data as possible in a single instruction. It may be aligned even more to
379    /// also take cache lines into account.
380    ///
381    /// However, the stride must be taken into account when indexing and reading the tensor
382    /// (also see [`ComputeClient::read_tensor`]).
383    ///
384    /// # Notes
385    ///
386    /// Prefer using [`Self::create_tensor`] for better performance.
387    pub fn create_tensor_from_slice(
388        &self,
389        slice: &[u8],
390        shape: Shape,
391        elem_size: usize,
392    ) -> MemoryLayout {
393        self.do_create_from_slices(
394            vec![MemoryLayoutDescriptor::new(
395                MemoryLayoutStrategy::Optimized,
396                shape,
397                elem_size,
398            )],
399            vec![slice.to_vec()],
400        )
401        .unwrap()
402        .remove(0)
403    }
404
405    /// Given a resource and shape, stores it and returns the tensor handle and strides.
406    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
407    /// should be taken when indexing.
408    ///
409    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
410    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
411    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
412    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
413    /// can load as much data as possible in a single instruction. It may be aligned even more to
414    /// also take cache lines into account.
415    ///
416    /// However, the stride must be taken into account when indexing and reading the tensor
417    /// (also see [`ComputeClient::read_tensor`]).
418    pub fn create_tensor(&self, bytes: Bytes, shape: Shape, elem_size: usize) -> MemoryLayout {
419        self.do_create(
420            vec![MemoryLayoutDescriptor::new(
421                MemoryLayoutStrategy::Optimized,
422                shape,
423                elem_size,
424            )],
425            vec![bytes],
426        )
427        .unwrap()
428        .remove(0)
429    }
430
431    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
432    /// handle, and returns the handles for them.
433    /// See [`ComputeClient::create_tensor`]
434    ///
435    /// # Notes
436    ///
437    /// Prefer using [`Self::create_tensors`] for better performance.
438    pub fn create_tensors_from_slices(
439        &self,
440        descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
441    ) -> Vec<MemoryLayout> {
442        let mut data = Vec::with_capacity(descriptors.len());
443        let mut descriptors_ = Vec::with_capacity(descriptors.len());
444        for (a, b) in descriptors {
445            data.push(b.to_vec());
446            descriptors_.push(a);
447        }
448
449        self.do_create_from_slices(descriptors_, data).unwrap()
450    }
451
452    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
453    /// handle, and returns the handles for them.
454    /// See [`ComputeClient::create_tensor`]
455    pub fn create_tensors(
456        &self,
457        descriptors: Vec<(MemoryLayoutDescriptor, Bytes)>,
458    ) -> Vec<MemoryLayout> {
459        let (descriptors, data) = descriptors.into_iter().unzip();
460
461        self.do_create(descriptors, data).unwrap()
462    }
463
464    fn do_empty(
465        &self,
466        descriptors: Vec<MemoryLayoutDescriptor>,
467    ) -> Result<Vec<MemoryLayout>, IoError> {
468        let stream_id = self.stream_id();
469        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
470
471        let (size, memory) = (handle_base.size(), handle_base.memory);
472        self.device.submit(move |server| {
473            server.initialize_memory(memory, size, stream_id);
474        });
475
476        Ok(layouts)
477    }
478
479    /// Reserves `size` bytes in the storage, and returns a handle over them.
480    pub fn empty(&self, size: usize) -> Handle {
481        let shape: Shape = [size].into();
482        let descriptor = MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, 1);
483        self.do_empty(vec![descriptor]).unwrap().remove(0).memory
484    }
485
486    /// Reserves `shape` in the storage, and returns a tensor handle for it.
487    /// See [`ComputeClient::create_tensor`]
488    pub fn empty_tensor(&self, shape: Shape, elem_size: usize) -> MemoryLayout {
489        let descriptor =
490            MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size);
491        self.do_empty(vec![descriptor]).unwrap().remove(0)
492    }
493
    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [`ComputeClient::create_tensor`]
    pub fn empty_tensors(&self, descriptors: Vec<MemoryLayoutDescriptor>) -> Vec<MemoryLayout> {
        self.do_empty(descriptors).unwrap()
    }
499
500    /// Marks the given [Bytes] as being a staging buffer, maybe transferring it to pinned memory
501    /// for faster data transfer with compute device.
502    ///
503    /// TODO: This blocks the compute queue, so it will drop the compute utilization.
504    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
505    where
506        I: Iterator<Item = &'a mut Bytes>,
507    {
508        let has_staging = |b: &Bytes| match b.property() {
509            AllocationProperty::Pinned => false,
510            AllocationProperty::File => true,
511            AllocationProperty::Native | AllocationProperty::Other => !file_only,
512        };
513
514        let mut to_be_updated = Vec::new();
515        let sizes = bytes
516            .filter_map(|b| match has_staging(b) {
517                true => {
518                    let len = b.len();
519                    to_be_updated.push(b);
520                    Some(len)
521                }
522                false => None,
523            })
524            .collect::<Vec<usize>>();
525
526        if sizes.is_empty() {
527            return;
528        }
529
530        let stream_id = self.stream_id();
531        let sizes = sizes.to_vec();
532        let stagings = self
533            .device
534            .submit_blocking(move |server| server.staging(&sizes, stream_id))
535            .unwrap();
536
537        let stagings = match stagings {
538            Ok(val) => val,
539            Err(_) => return,
540        };
541
542        to_be_updated
543            .into_iter()
544            .zip(stagings)
545            .for_each(|(b, mut staging)| {
546                b.copy_into(&mut staging);
547                core::mem::swap(b, &mut staging);
548            });
549    }
550
551    /// Transfer data from one client to another
552    #[cfg_attr(
553        feature = "tracing",
554        tracing::instrument(level = "trace", skip(self, src, dst_server))
555    )]
556    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Handle {
557        let shape = [src.size_in_used() as usize];
558        let src_descriptor = src.copy_descriptor(shape.into(), [1].into(), 1);
559
560        if R::Server::SERVER_COMM_ENABLED {
561            self.to_client_tensor(src_descriptor, dst_server)
562        } else {
563            let alloc_desc = MemoryLayoutDescriptor::new(
564                MemoryLayoutStrategy::Contiguous,
565                src_descriptor.shape.clone(),
566                src_descriptor.elem_size,
567            );
568            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
569                .memory
570        }
571    }
572
573    /// Wait on the communication stream.
574    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
575    pub fn sync_collective(&self) {
576        if DeviceHandle::<R::Server>::is_blocking() {
577            panic!("Can't use `sync_collective` with a blocking device handle");
578        }
579        let stream_id = self.stream_id();
580
581        self.device.submit(move |server| {
582            server.sync_collective(stream_id).unwrap();
583        });
584        // We don't actually need or want to sync the server here, but we need to make sure any
585        // task enqueued on the communication channel is done.
586        self.device.flush_queue();
587    }
588
589    /// Perform an `all_reduce` operation on the given devices.
590    #[cfg_attr(
591        feature = "tracing",
592        tracing::instrument(level = "trace", skip(self, src, dst, dtype, device_ids, op))
593    )]
594    pub fn all_reduce(
595        &self,
596        src: Handle,
597        dst: Handle,
598        dtype: ElemType,
599        device_ids: Vec<DeviceId>,
600        op: ReduceOperation,
601    ) {
602        if DeviceHandle::<R::Server>::is_blocking() {
603            panic!("Can't use `all_reduce` with a blocking device handle");
604        }
605
606        let stream_id = self.stream_id();
607        let src = src.binding();
608        let dst = dst.binding();
609
610        self.device.submit(move |server| {
611            server
612                .all_reduce(src, dst, dtype, stream_id, op, device_ids)
613                .unwrap();
614        });
615    }
616
    /// Transfer data from one client to another
    ///
    /// Make sure the source description can be read in a contiguous manner.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    pub fn to_client_tensor(&self, src_descriptor: CopyDescriptor, dst_server: &Self) -> Handle {
        if R::Server::SERVER_COMM_ENABLED {
            let stream_id_src = self.stream_id();
            let stream_id_dst = dst_server.stream_id();

            let dst_server = dst_server.clone();
            // Pre-allocate the destination handle on the destination stream.
            let handle = Handle::new(stream_id_dst, src_descriptor.handle.size_in_used());
            let handle_cloned = handle.clone();

            // Both servers are locked for the duration of the copy: the source's
            // blocking scope nests the destination's.
            // TODO: This should be made in a non-blocking API.
            self.device
                .submit_blocking_scoped(move |server_src| {
                    dst_server.device.submit_blocking_scoped(|server_dst| {
                        R::Server::copy(
                            handle_cloned,
                            server_src,
                            server_dst,
                            src_descriptor,
                            stream_id_src,
                            stream_id_dst,
                        )
                    })
                })
                .unwrap();

            handle
        } else {
            // Fallback without server-to-server communication: synchronous transfer
            // into a fresh runtime-optimized allocation.
            let alloc_desc = MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Optimized,
                src_descriptor.shape.clone(),
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
                .memory
        }
    }
660
    /// Enqueue `kernel` on the server, optionally profiling it depending on the
    /// configured profile level.
    ///
    /// # Safety
    ///
    /// `mode` controls bounds checking; with `ExecutionMode::Unchecked` the caller
    /// must guarantee the kernel has no out-of-bounds accesses or infinite loops.
    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            // No timing requested: fire-and-forget submission; only log the
            // execution name when `ExecutionOnly` is configured.
            None | Some(ProfileLevel::ExecutionOnly) => {
                let utilities = self.utilities.clone();
                self.device.submit(move |state| {
                    let name = kernel.name();
                    unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                    if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                        let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                        utilities.logger.register_execution(info);
                    }
                });
            }
            // Profiling requested: run the launch through `Self::profile`, which
            // requires a blocking submission, then log the measured duration.
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let context = self.device.clone();
                // `count` is cloned so the original stays available for the log message.
                let count_moved = count.clone();
                let (result, profile) = self
                    .profile(
                        move || {
                            context
                                .submit_blocking(move |state| unsafe {
                                    state.launch(kernel, count_moved, bindings, mode, stream_id)
                                })
                                .unwrap()
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }
720
721    /// Launches the `kernel` with the given `bindings`.
722    #[track_caller]
723    pub fn launch(
724        &self,
725        kernel: <R::Server as ComputeServer>::Kernel,
726        count: CubeCount,
727        bindings: KernelArguments,
728    ) {
729        // SAFETY: Using checked execution mode.
730        unsafe {
731            self.launch_inner(
732                kernel,
733                count,
734                bindings,
735                ExecutionMode::Checked,
736                self.stream_id(),
737            )
738        }
739    }
740
741    /// Launches the `kernel` with the given `bindings` without performing any bound checks.
742    ///
743    /// # Safety
744    ///
745    /// To ensure this is safe, you must verify your kernel:
746    /// - Has no out-of-bound reads and writes that can happen.
747    /// - Has no infinite loops that might never terminate.
748    #[track_caller]
749    pub unsafe fn launch_unchecked(
750        &self,
751        kernel: <R::Server as ComputeServer>::Kernel,
752        count: CubeCount,
753        bindings: KernelArguments,
754    ) {
755        // SAFETY: Caller has to uphold kernel being safe.
756        unsafe {
757            self.launch_inner(
758                kernel,
759                count,
760                bindings,
761                match self.utilities.check_mode {
762                    crate::config::compilation::BoundsCheckMode::Enforce => ExecutionMode::Checked,
763                    crate::config::compilation::BoundsCheckMode::Validate => {
764                        ExecutionMode::Validate
765                    }
766                    crate::config::compilation::BoundsCheckMode::Auto => ExecutionMode::Unchecked,
767                },
768                self.stream_id(),
769            )
770        }
771    }
772
    /// Flush all outstanding commands.
    ///
    /// Blocks until the server acknowledges the flush; the `unwrap` covers the
    /// submission channel, flush failures surface through the returned `Result`.
    pub fn flush(&self) -> Result<(), ServerError> {
        let stream_id = self.stream_id();

        self.device
            .submit_blocking(move |server| server.flush(stream_id))
            .unwrap()
    }
781
    /// Wait for the completion of every task in the server.
    ///
    /// Returns a future resolving when the stream's work is done. The profile
    /// summary is emitted after the sync has been enqueued, before awaiting.
    pub fn sync(&self) -> DynFut<Result<(), ServerError>> {
        let stream_id = self.stream_id();

        let fut = self
            .device
            .submit_blocking(move |server| server.sync(stream_id))
            .unwrap();

        self.utilities.logger.profile_summary();

        fut
    }
795
    /// Get the device properties supported by the compute server.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }
800
    /// Get the features supported by the compute server.
    pub fn features(&self) -> &Features {
        &self.utilities.properties.features
    }
805
    /// Mutable access to the device properties.
    ///
    /// Returns `None` when the utilities `Arc` is shared (i.e. other clients are
    /// alive), since `Arc::get_mut` only succeeds with a unique reference.
    ///
    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }
812
813    /// Get the current memory usage of this client.
814    pub fn memory_usage(&self) -> Result<MemoryUsage, ServerError> {
815        let stream_id = self.stream_id();
816        self.device
817            .submit_blocking(move |server| server.memory_usage(stream_id))
818            .unwrap()
819    }
820
    /// Get all devices of a specific type available to this runtime
    pub fn enumerate_devices(&self, type_id: u16) -> Vec<DeviceId> {
        R::enumerate_devices(type_id, self.info())
    }
825
    /// Get all devices available to this runtime
    pub fn enumerate_all_devices(&self) -> Vec<DeviceId> {
        R::enumerate_all_devices(self.info())
    }
830
    /// Get the number of devices of a specific type available to this runtime
    pub fn device_count(&self, type_id: u16) -> usize {
        self.enumerate_devices(type_id).len()
    }
835
    /// Get the total number of devices available to this runtime, across all types
    pub fn device_count_total(&self) -> usize {
        self.enumerate_all_devices().len()
    }
840
    /// Change the memory allocation mode.
    ///
    /// The change is enqueued asynchronously on the client's stream.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.allocation_mode(mode, stream_id));
    }
851
852    // TODO: Remove that or rework for performance.
853    //
854    // /// Use a persistent memory strategy to execute the provided function.
855    // ///
856    // /// # Notes
857    // ///
858    // /// - Using that memory strategy is beneficial for stating model parameters and similar workflows.
859    // /// - You can call [`Self::memory_cleanup()`] if you want to free persistent memory.
860    // pub fn memory_persistent_allocation<
861    //     Input: Send + 'static,
862    //     Output: Send + 'static,
863    //     Func: Fn(Input) -> Output + Send + 'static,
864    // >(
865    //     &self,
866    //     input: Input,
867    //     func: Func,
868    // ) -> Output {
869    //     let stream_id = self.stream_id();
870    //     self.device
871    //         .submit_blocking(move |server| {
872    //             server.allocation_mode(MemoryAllocationMode::Persistent, stream_id);
873    //             let output = func(input);
874    //             server.allocation_mode(MemoryAllocationMode::Auto, stream_id);
875    //             output
876    //         })
877    //         .unwrap()
878    // }
879
880    /// Ask the client to release memory that it can release.
881    ///
882    /// Nb: Results will vary on what the memory allocator deems beneficial,
883    /// so it's not guaranteed any memory is freed.
884    pub fn memory_cleanup(&self) {
885        let stream_id = self.stream_id();
886        self.device
887            .submit(move |server| server.memory_cleanup(stream_id));
888    }
889
890    /// Measure the execution time of some inner operations.
891    #[track_caller]
892    pub fn profile<O: Send + 'static>(
893        &self,
894        func: impl FnOnce() -> O + Send,
895        #[allow(unused)] func_name: &str,
896    ) -> Result<(O, ProfileDuration), ProfileError> {
897        // Get the outer caller. For execute() this points straight to the
898        // cube kernel. For general profiling it points to whoever calls profile.
899        #[cfg(feature = "profile-tracy")]
900        let location = std::panic::Location::caller();
901
902        // Make a CPU span. If the server has system profiling this is all you need.
903        #[cfg(feature = "profile-tracy")]
904        let _span = tracy_client::Client::running().unwrap().span_alloc(
905            None,
906            func_name,
907            location.file(),
908            location.line(),
909            0,
910        );
911
912        let stream_id = self.stream_id();
913
914        #[cfg(feature = "profile-tracy")]
915        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
916            let gpu_span = self
917                .utilities
918                .gpu_client
919                .span_alloc(func_name, "profile", location.file(), location.line())
920                .unwrap();
921            Some(gpu_span)
922        } else {
923            None
924        };
925
926        let device = self.device.clone();
927        #[allow(unused_mut, reason = "Used in profile-tracy")]
928        let mut result = self
929            .device
930            .exclusive_scoped(move || {
931                // We first get mut access to the server to create a token.
932                // Then we free to server, since it's going to be accessed in `func()`.
933                let token =
934                    match device.submit_blocking(move |server| server.start_profile(stream_id)) {
935                        Ok(token) => match token {
936                            Ok(token) => token,
937                            Err(err) => return Err(err),
938                        },
939                        Err(err) => {
940                            return Err(ServerError::Generic {
941                                reason: alloc::format!(
942                                    "Can't start profiling because of a call error: {err:?}"
943                                ),
944                                backtrace: BackTrace::capture(),
945                            });
946                        }
947                    };
948
949                // We execute `func()` which will recursibly access the server.
950                let out = func();
951
952                // Finally we get the result from the token.
953                let result = device
954                    .submit_blocking(move |server| {
955                        let mut result = server.end_profile(stream_id, token);
956
957                        match result {
958                            Ok(result) => Ok((out, result)),
959                            Err(err) => Err(err),
960                        }
961                    })
962                    .unwrap();
963
964                Ok(result)
965            })
966            .unwrap()
967            .map_err(|err| ProfileError::Unknown {
968                reason: alloc::format!("{err:?}"),
969                backtrace: BackTrace::capture(),
970            })?;
971
972        #[cfg(feature = "profile-tracy")]
973        if let Some(mut gpu_span) = gpu_span {
974            gpu_span.end_zone();
975            let epoch = self.utilities.epoch_time;
976            // Add in the work to upload the timestamp data.
977            result = result.map(|(o, result)| {
978                (
979                    o,
980                    ProfileDuration::new(
981                        alloc::boxed::Box::pin(async move {
982                            let ticks = result.resolve().await;
983                            let start_duration =
984                                ticks.start_duration_since(epoch).as_nanos() as i64;
985                            let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
986                            gpu_span.upload_timestamp_start(start_duration);
987                            gpu_span.upload_timestamp_end(end_duration);
988                            ticks
989                        }),
990                        TimingMethod::Device,
991                    ),
992                )
993            });
994        }
995
996        result
997    }
998
    /// Transfer data from one client to another
    ///
    /// Blocks while reading the source tensor back to host memory, allocates
    /// a destination layout, then submits a write of the bytes on
    /// `dst_server`'s device. Returns the destination [`MemoryLayout`].
    ///
    /// NOTE(review): the destination allocation uses `self`'s `layout_policy`
    /// and `self`'s `stream_id`, while memory is initialized and written on
    /// `dst_server` — confirm both clients are expected to share these.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor,
        alloc_descriptor: MemoryLayoutDescriptor,
        dst_server: &Self,
    ) -> MemoryLayout {
        // Keep shape/elem_size before the descriptor is moved into the read.
        let shape = src_descriptor.shape.clone();
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Submit the read on the source device and block on the returned future.
        let read = self
            .device
            .submit_blocking(move |server| server.read(vec![src_descriptor], stream_id))
            .unwrap();

        let mut data = cubecl_common::future::block_on(read).unwrap();

        // Ask the layout policy for a single destination allocation; we passed
        // one descriptor, so exactly one layout comes back.
        let (handle_base, mut layouts) = self
            .utilities
            .layout_policy
            .apply(stream_id, &[alloc_descriptor]);
        let alloc = layouts.remove(0);

        // Destination copy descriptor: same shape/elem_size as the source,
        // strides taken from the freshly allocated layout.
        let desc_descriptor = CopyDescriptor {
            handle: handle_base.clone().binding(),
            shape,
            strides: alloc.strides.clone(),
            elem_size,
        };

        // Initialize the backing memory on the destination device, then write
        // the bytes read from the source into it (single submission, in order).
        let (size, memory) = (handle_base.size(), handle_base.memory);
        dst_server.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
            server.write(vec![(desc_descriptor, data.remove(0))], stream_id)
        });

        alloc
    }
1045
1046    /// Returns all vector sizes that are useful to perform optimal IO operation on the given element.
1047    pub fn io_optimized_vector_sizes(
1048        &self,
1049        size: usize,
1050    ) -> impl Iterator<Item = VectorSize> + Clone {
1051        let load_width = self.properties().hardware.load_width as usize;
1052        let size_bits = size * 8;
1053        let max = load_width / size_bits;
1054        let max = usize::min(self.properties().hardware.max_vector_size, max);
1055
1056        // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
1057        let num_candidates = max.trailing_zeros() + 1;
1058
1059        (0..num_candidates).map(|i| 2usize.pow(i)).rev()
1060    }
1061}