// cubecl_runtime/client.rs

use crate::{
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    runtime::Runtime,
    server::{
        CommunicationId, ComputeServer, CopyDescriptor, CubeCount, ExecutionMode, Handle, IoError,
        KernelArguments, MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy,
        MemoryLayoutStrategy, ProfileError, ReduceOperation, ServerCommunication, ServerError,
        ServerUtilities,
    },
    storage::{ComputeStorage, ManagedResource},
};
use alloc::{format, sync::Arc, vec, vec::Vec};
use cubecl_common::{
    backtrace::BackTrace,
    bytes::{AllocationProperty, Bytes},
    device::{Device, DeviceId},
    device_handle::DeviceHandle,
    future::DynFut,
    profile::ProfileDuration,
};
use cubecl_ir::{DeviceProperties, ElemType, VectorSize, features::Features};
use cubecl_zspace::Shape;

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

/// The `ComputeClient` is the entry point for submitting tasks to the `ComputeServer`.
/// It should be obtained for a specific device via the Compute struct.
pub struct ComputeClient<R: Runtime> {
    device: DeviceHandle<R::Server>,
    utilities: Arc<ServerUtilities<R::Server>>,
    stream_id: Option<StreamId>,
}

impl<R: Runtime> Clone for ComputeClient<R> {
    fn clone(&self) -> Self {
        Self {
            device: self.device.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<R: Runtime> ComputeClient<R> {
    /// Get the info of the current backend.
    pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
        &self.utilities.info
    }

    /// Create a new client with a new server.
    pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
        let utilities = server.utilities();
        let context = DeviceHandle::<R::Server>::insert(device.to_id(), server)
            .expect("Can't create a new client on an already registered server");

        Self {
            device: context,
            utilities,
            stream_id: None,
        }
    }

    /// Load the client for the given device.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceHandle::<R::Server>::new(device.to_id());

        // This downcast is safe because we know the concrete type returned by
        // [`DeviceHandle::utilities()`] for this server.
        let utilities = context
            .utilities()
            .downcast::<ServerUtilities<R::Server>>()
            .expect("Can downcast to `ServerUtilities`");

        Self {
            device: context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Set the stream on which the current client operates.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor>) -> DynFut<Result<Vec<Bytes>, ServerError>> {
        let stream_id = self.stream_id();
        self.device
            .submit_blocking(move |server| server.read(descriptors, stream_id))
            .unwrap()
    }

    /// Given bindings, returns owned resources as bytes.
    pub fn read_async(
        &self,
        handles: Vec<Handle>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
        let shapes = handles
            .iter()
            .map(|it| [it.size_in_used() as usize].into())
            .collect::<Vec<Shape>>();
        let descriptors = handles
            .into_iter()
            .zip(shapes)
            .map(|(handle, shape)| CopyDescriptor::new(handle.binding(), shape, [1].into(), 1))
            .collect();

        self.do_read(descriptors)
    }

    /// Given bindings, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
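    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>`:
    ///
    /// ```ignore
    /// let handle = client.create_from_slice(&[1u8, 2, 3, 4]);
    /// let bytes = client.read(vec![handle]);
    /// assert_eq!(bytes[0].len(), 4);
    /// ```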
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
            .expect("Failed to read the handles from the server")
    }

    /// Given a binding, returns owned resource as bytes.
    pub fn read_one(&self, handle: Handle) -> Result<Bytes, ServerError> {
        Ok(cubecl_common::reader::read_sync(self.read_async(vec![handle]))?.remove(0))
    }

    /// Given a binding, returns owned resource as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails. Useful for tests.
    pub fn read_one_unchecked(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle]))
            .unwrap()
            .remove(0)
    }

    /// Given bindings, returns owned resources as bytes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor>,
    ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
        self.do_read(descriptors)
    }

    /// Given bindings, returns owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    ///
    /// The tensor must be in the same layout as created by the runtime, or a stricter one.
    /// Contiguous tensors are always fine; strided tensors are only valid if the strides match
    /// the ones created by the runtime (i.e. padded on only the last dimension). A way to check
    /// stride compatibility on the runtime will be added in the future.
    ///
    /// Also see [`ComputeClient::create_tensor`].
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
            .expect("Failed to read the tensors from the server")
    }

    /// Given a binding, returns owned resource as bytes.
    /// See [`ComputeClient::read_tensor`]
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor,
    ) -> impl Future<Output = Result<Bytes, ServerError>> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { Ok(fut.await?.remove(0)) }
    }

    /// Given a binding, returns owned resource as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    /// See [`ComputeClient::read_tensor`]
    pub fn read_one_unchecked_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(
        &self,
        handle: Handle,
    ) -> Result<
        ManagedResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource>,
        ServerError,
    > {
        let stream_id = self.stream_id();
        let binding = handle.binding();

        self.device
            .submit_blocking(move |state| state.get_resource(binding, stream_id))
            .unwrap()
    }

    fn do_create_from_slices(
        &self,
        descriptors: Vec<MemoryLayoutDescriptor>,
        slices: Vec<Vec<u8>>,
    ) -> Result<Vec<MemoryLayout>, IoError> {
        let stream_id = self.stream_id();
        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);

        let descriptors = descriptors
            .into_iter()
            .zip(layouts.iter())
            .zip(slices)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.memory.clone().binding(),
                        desc.shape,
                        alloc.strides.clone(),
                        desc.elem_size,
                    ),
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect::<Vec<_>>();

        let (size, memory) = (handle_base.size(), handle_base.memory);
        self.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
            server.write(descriptors, stream_id);
        });

        Ok(layouts)
    }

    fn do_create(
        &self,
        descriptors: Vec<MemoryLayoutDescriptor>,
        mut data: Vec<Bytes>,
    ) -> Result<Vec<MemoryLayout>, IoError> {
        self.staging(data.iter_mut(), true);

        let stream_id = self.stream_id();
        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);

        let descriptors = descriptors
            .into_iter()
            .zip(layouts.iter())
            .zip(data)
            .map(|((desc, layout), data)| {
                (
                    CopyDescriptor::new(
                        layout.memory.clone().binding(),
                        desc.shape,
                        layout.strides.clone(),
                        desc.elem_size,
                    ),
                    Bytes::from_bytes_vec(data.to_vec()),
                )
            })
            .collect::<Vec<_>>();

        let (size, memory) = (handle_base.size(), handle_base.memory);
        self.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
            server.write(descriptors, stream_id);
        });

        Ok(layouts)
    }

    /// Returns a resource handle containing the given data.
    ///
    /// # Notes
    ///
    /// Prefer using the more efficient [`Self::create`] function.
    pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
        let shape: Shape = [slice.len()].into();

        self.do_create_from_slices(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                shape,
                1,
            )],
            vec![slice.to_vec()],
        )
        .unwrap()
        .remove(0)
        .memory
    }

    /// Runs the given task while holding exclusive access to the server, ensuring no tasks from
    /// other threads are interleaved with it.
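    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>`:
    ///
    /// ```ignore
    /// let out = client.exclusive(|| {
    ///     // No task from another thread can be interleaved on the server here.
    ///     42
    /// })?;
    /// ```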
    pub fn exclusive<'a, Re: Send + 'static, F: FnOnce() -> Re + Send + 'a>(
        &'a self,
        task: F,
    ) -> Result<Re, ServerError> {
        // Launch the task while holding exclusive access to the server.
        self.device
            .exclusive(task)
            .map_err(|err| ServerError::Generic {
                reason: format!("Communication channel with the server is down: {err:?}"),
                backtrace: BackTrace::capture(),
            })
    }

    /// Runs the given task with persistent memory allocation enabled on the current stream.
    ///
    /// Allocations made while the task runs are kept persistent instead of being recycled; the
    /// allocation mode is reset to [`MemoryAllocationMode::Auto`] once the task completes.
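    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>`; `allocate_weights` is a
    /// hypothetical helper whose allocations should outlive the memory pool's reuse cycle:
    ///
    /// ```ignore
    /// let weights = client.memory_persistent_allocation((), |_| allocate_weights(&client))?;
    /// ```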
    pub fn memory_persistent_allocation<
        'a,
        Re: Send,
        Input: Send,
        F: FnOnce(Input) -> Re + Send + 'a,
    >(
        &'a self,
        input: Input,
        task: F,
    ) -> Result<Re, ServerError> {
        let stream_id = StreamId::current();

        self.device.submit(move |server| {
            server.allocation_mode(MemoryAllocationMode::Persistent, stream_id);
        });

        // All tasks created on the same stream will have persistent memory.
        let output = task(input);

        self.device.submit(move |server| {
            server.allocation_mode(MemoryAllocationMode::Auto, stream_id);
        });

        Ok(output)
    }

    /// Returns a resource handle containing the given [Bytes].
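    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>`:
    ///
    /// ```ignore
    /// let data = Bytes::from_bytes_vec(vec![0u8; 16]);
    /// let handle = client.create(data);
    /// ```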
    pub fn create(&self, data: Bytes) -> Handle {
        let shape = [data.len()].into();

        self.do_create(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .memory
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [`ComputeClient::read_tensor`]).
    ///
    /// # Notes
    ///
    /// Prefer using [`Self::create_tensor`] for better performance.
    pub fn create_tensor_from_slice(
        &self,
        slice: &[u8],
        shape: Shape,
        elem_size: usize,
    ) -> MemoryLayout {
        self.do_create_from_slices(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Optimized,
                shape,
                elem_size,
            )],
            vec![slice.to_vec()],
        )
        .unwrap()
        .remove(0)
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the stride must be taken into account when indexing and reading the tensor
    /// (also see [`ComputeClient::read_tensor`]).
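    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>` and a 2x8 tensor of 4-byte
    /// elements:
    ///
    /// ```ignore
    /// let data = Bytes::from_bytes_vec(vec![0u8; 2 * 8 * 4]);
    /// let layout = client.create_tensor(data, [2, 8].into(), 4);
    /// // `layout.strides` may be padded; use it when indexing instead of assuming contiguity.
    /// let (handle, strides) = (layout.memory, layout.strides);
    /// ```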
    pub fn create_tensor(&self, bytes: Bytes, shape: Shape, elem_size: usize) -> MemoryLayout {
        self.do_create(
            vec![MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Optimized,
                shape,
                elem_size,
            )],
            vec![bytes],
        )
        .unwrap()
        .remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [`ComputeClient::create_tensor`]
    ///
    /// # Notes
    ///
    /// Prefer using [`Self::create_tensors`] for better performance.
    pub fn create_tensors_from_slices(
        &self,
        descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
    ) -> Vec<MemoryLayout> {
        let mut data = Vec::with_capacity(descriptors.len());
        let mut descriptors_ = Vec::with_capacity(descriptors.len());
        for (a, b) in descriptors {
            data.push(b.to_vec());
            descriptors_.push(a);
        }

        self.do_create_from_slices(descriptors_, data).unwrap()
    }

    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
    /// handle, and returns the handles for them.
    /// See [`ComputeClient::create_tensor`]
    pub fn create_tensors(
        &self,
        descriptors: Vec<(MemoryLayoutDescriptor, Bytes)>,
    ) -> Vec<MemoryLayout> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<MemoryLayoutDescriptor>,
    ) -> Result<Vec<MemoryLayout>, IoError> {
        let stream_id = self.stream_id();
        let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);

        let (size, memory) = (handle_base.size(), handle_base.memory);
        self.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
        });

        Ok(layouts)
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle {
        let shape: Shape = [size].into();
        let descriptor = MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).memory
    }

    /// Reserves `shape` in the storage, and returns a tensor handle for it.
    /// See [`ComputeClient::create_tensor`]
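    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>`:
    ///
    /// ```ignore
    /// // Reserve an uninitialized 128x128 tensor of 4-byte elements.
    /// let layout = client.empty_tensor([128, 128].into(), 4);
    /// ```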
    pub fn empty_tensor(&self, shape: Shape, elem_size: usize) -> MemoryLayout {
        let descriptor =
            MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
    /// See [`ComputeClient::create_tensor`]
    pub fn empty_tensors(&self, descriptors: Vec<MemoryLayoutDescriptor>) -> Vec<MemoryLayout> {
        self.do_empty(descriptors).unwrap()
    }

    /// Marks the given [Bytes] as staging buffers, possibly transferring them to pinned memory
    /// for faster data transfers with the compute device.
    ///
    /// TODO: This blocks the compute queue, so it will drop the compute utilization.
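    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>` and `buffers` holds the
    /// [Bytes] about to be uploaded:
    ///
    /// ```ignore
    /// let mut buffers: Vec<Bytes> = vec![Bytes::from_bytes_vec(vec![0u8; 1024])];
    /// // Only file-backed buffers are moved to staging memory when `file_only` is true.
    /// client.staging(buffers.iter_mut(), true);
    /// ```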
    pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
    where
        I: Iterator<Item = &'a mut Bytes>,
    {
        let has_staging = |b: &Bytes| match b.property() {
            AllocationProperty::Pinned => false,
            AllocationProperty::File => true,
            AllocationProperty::Native | AllocationProperty::Other => !file_only,
        };

        let mut to_be_updated = Vec::new();
        let sizes = bytes
            .filter_map(|b| match has_staging(b) {
                true => {
                    let len = b.len();
                    to_be_updated.push(b);
                    Some(len)
                }
                false => None,
            })
            .collect::<Vec<usize>>();

        if sizes.is_empty() {
            return;
        }

        let stream_id = self.stream_id();
        let sizes = sizes.to_vec();
        let stagings = self
            .device
            .submit_blocking(move |server| server.staging(&sizes, stream_id))
            .unwrap();

        let stagings = match stagings {
            Ok(val) => val,
            Err(_) => return,
        };

        to_be_updated
            .into_iter()
            .zip(stagings)
            .for_each(|(b, mut staging)| {
                b.copy_into(&mut staging);
                core::mem::swap(b, &mut staging);
            });
    }

    /// Transfer data from one client to another.
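    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a mutable `client_src` and a `client_dst` of the same runtime,
    /// with `dtype` being the `ElemType` of the transferred data:
    ///
    /// ```ignore
    /// let dst_handle = client_src.to_client(src_handle, &client_dst, dtype);
    /// ```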
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst_server))
    )]
    pub fn to_client(&mut self, src: Handle, dst_server: &Self, dtype: ElemType) -> Handle {
        let shape = [src.size_in_used() as usize];
        let src_descriptor = src.copy_descriptor(shape.into(), [1].into(), 1);

        if R::Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server, dtype)
        } else {
            let alloc_desc = MemoryLayoutDescriptor::new(
                MemoryLayoutStrategy::Contiguous,
                src_descriptor.shape.clone(),
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
                .memory
        }
    }

    /// Ensure the collective communication channel is initialized for the given devices.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, device_ids))
    )]
    pub fn ensure_init_collective(&mut self, device_ids: Vec<DeviceId>) {
        let comm_id = CommunicationId::from(device_ids.clone());
        let is_comms_init = self
            .utilities
            .initialized_comms
            .read()
            .unwrap()
            .contains(&comm_id);
        if !is_comms_init {
            self.device
                .submit(move |server| server.comm_init(device_ids).unwrap());
            let mut initialized_comms = self.utilities.initialized_comms.write().unwrap();
            initialized_comms.insert(comm_id);
            // Flush immediately so other devices aren't blocked waiting on this initialization.
            self.device.flush_queue();
        }
    }

    /// Wait on the communication stream.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    pub fn sync_collective(&self) {
        if DeviceHandle::<R::Server>::is_blocking() {
            panic!("Can't use `sync_collective` with a blocking device handle");
        }
        let stream_id = self.stream_id();

        self.device.submit(move |server| {
            server.sync_collective(stream_id).unwrap();
        });

        // We don't actually need or want to sync the server here, but we need to make sure any
        // task enqueued on the communication channel is done.
        self.device.flush_queue();
    }

    /// Perform an `all_reduce` operation on the given devices.
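    ///
    /// # Example
    ///
    /// A minimal sketch, assuming matching `src`/`dst` handles on this client, an `ElemType` in
    /// `dtype`, and a `ReduceOperation` in `op`:
    ///
    /// ```ignore
    /// let devices = client.enumerate_all_devices();
    /// client.all_reduce(src, dst, dtype, devices, op);
    /// ```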
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src, dst, dtype, device_ids, op))
    )]
    pub fn all_reduce(
        &mut self,
        src: Handle,
        dst: Handle,
        dtype: ElemType,
        device_ids: Vec<DeviceId>,
        op: ReduceOperation,
    ) {
        if DeviceHandle::<R::Server>::is_blocking() {
            panic!("Can't use `all_reduce` with a blocking device handle");
        }

        let stream_id = self.stream_id();
        let src = src.binding();
        let dst = dst.binding();

        self.ensure_init_collective(device_ids.clone());

        self.device.submit(move |server| {
            server
                .all_reduce(src, dst, dtype, stream_id, op, device_ids)
                .unwrap();
        });
    }

    /// Transfer data from one client to another.
    ///
    /// Make sure the source descriptor can be read in a contiguous manner.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
    )]
    pub fn to_client_tensor(
        &mut self,
        src_descriptor: CopyDescriptor,
        dst_server: &Self,
        dtype: ElemType,
    ) -> Handle {
        let stream_id_src = self.stream_id();
        let stream_id_dst = dst_server.stream_id();

        let device_id_src = self.device.device_id();
        let device_id_dst = dst_server.device.device_id();

        let mut dst_server = dst_server.clone();
        let handle = Handle::new(stream_id_dst, src_descriptor.handle.size_in_used());
        let handle_cloned = handle.clone();

        let device_ids = vec![device_id_src, device_id_dst];
        self.ensure_init_collective(device_ids.clone());
        dst_server.ensure_init_collective(device_ids);

        self.device.submit(move |server_src| {
            server_src
                .send(src_descriptor, dtype, stream_id_src, device_id_dst)
                .unwrap()
        });

        dst_server.device.submit(move |server_dst| {
            server_dst
                .recv(handle_cloned, dtype, stream_id_dst, device_id_src)
                .unwrap();
            server_dst.sync_collective(stream_id_dst).unwrap();
        });

        // `ServerCommunication::send` and `ServerCommunication::recv` are blocking: they each wait
        // for the corresponding recv/send call to be made. We flush the operations right away so
        // that neither server ends up in a deadlock. The actual data transfer is still executed
        // asynchronously on the communication stream.
        self.device.flush_queue();
        dst_server.device.flush_queue();

        handle
    }

    #[track_caller]
    #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
        skip(self, kernel, bindings),
        fields(
            kernel.name = %kernel.name(),
            kernel.id = %kernel.id(),
        )
    ))]
    unsafe fn launch_inner(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let utilities = self.utilities.clone();
                self.device.submit(move |state| {
                    let name = kernel.name();
                    unsafe { state.launch(kernel, count, bindings, mode, stream_id) };

                    if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                        let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                        utilities.logger.register_execution(info);
                    }
                });
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let context = self.device.clone();
                let count_moved = count.clone();
                let (result, profile) = self
                    .profile(
                        move || {
                            context
                                .submit_blocking(move |state| unsafe {
                                    state.launch(kernel, count_moved, bindings, mode, stream_id)
                                })
                                .unwrap()
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
                result
            }
        }
    }

    /// Launches the `kernel` with the given `bindings`.
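    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a compiled `kernel`, a `CubeCount` in `count`, and prepared
    /// `bindings`:
    ///
    /// ```ignore
    /// client.launch(kernel, count, bindings);
    /// // Launches are asynchronous; read back or sync when the result is needed.
    /// ```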
    #[track_caller]
    pub fn launch(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
    ) {
        // SAFETY: Using checked execution mode.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            )
        }
    }

    /// Launches the `kernel` with the given `bindings` without performing any bound checks.
    ///
    /// # Safety
    ///
    /// To ensure this is safe, you must verify that your kernel:
    /// - Performs no out-of-bounds reads or writes.
    /// - Contains no infinite loops.
    #[track_caller]
    pub unsafe fn launch_unchecked(
        &self,
        kernel: <R::Server as ComputeServer>::Kernel,
        count: CubeCount,
        bindings: KernelArguments,
    ) {
        // SAFETY: Caller has to uphold kernel being safe.
        unsafe {
            self.launch_inner(
                kernel,
                count,
                bindings,
                match self.utilities.check_mode {
                    crate::config::compilation::BoundsCheckMode::Enforce => ExecutionMode::Checked,
                    crate::config::compilation::BoundsCheckMode::Validate => {
                        ExecutionMode::Validate
                    }
                    crate::config::compilation::BoundsCheckMode::Auto => ExecutionMode::Unchecked,
                },
                self.stream_id(),
            )
        }
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) -> Result<(), ServerError> {
        let stream_id = self.stream_id();

        self.device
            .submit_blocking(move |server| server.flush(stream_id))
            .unwrap()
    }

    /// Wait for the completion of every task in the server.
    pub fn sync(&self) -> DynFut<Result<(), ServerError>> {
        let stream_id = self.stream_id();

        let fut = self
            .device
            .submit_blocking(move |server| server.sync(stream_id))
            .unwrap();

        self.utilities.logger.profile_summary();

        fut
    }

    /// Get the properties of the compute server.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// Get the features supported by the compute server.
    pub fn features(&self) -> &Features {
        &self.utilities.properties.features
    }

    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> Result<MemoryUsage, ServerError> {
        let stream_id = self.stream_id();
        self.device
            .submit_blocking(move |server| server.memory_usage(stream_id))
            .unwrap()
    }

    /// Get all devices of a specific type available to this runtime.
    pub fn enumerate_devices(&self, type_id: u16) -> Vec<DeviceId> {
        R::enumerate_devices(type_id, self.info())
    }

    /// Get all devices available to this runtime.
    pub fn enumerate_all_devices(&self) -> Vec<DeviceId> {
        R::enumerate_all_devices(self.info())
    }

    /// Get the number of devices of a specific type available to this runtime.
    pub fn device_count(&self, type_id: u16) -> usize {
        self.enumerate_devices(type_id).len()
    }

    /// Get the total number of devices available to this runtime.
    pub fn device_count_total(&self) -> usize {
        self.enumerate_all_devices().len()
    }

    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.allocation_mode(mode, stream_id));
    }

    /// Ask the client to release any memory it can.
    ///
    /// Note: results depend on what the memory allocator deems beneficial,
    /// so it's not guaranteed that any memory is freed.
    pub fn memory_cleanup(&self) {
        let stream_id = self.stream_id();
        self.device
            .submit(move |server| server.memory_cleanup(stream_id));
    }

    /// Measure the execution time of some inner operations.
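    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is a `ComputeClient<R>` and `launch_work` is a
    /// hypothetical closure that enqueues work on this client:
    ///
    /// ```ignore
    /// let (output, profile) = client.profile(|| launch_work(&client), "launch_work")?;
    /// ```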
    #[track_caller]
    pub fn profile<O: Send + 'static>(
        &self,
        func: impl FnOnce() -> O + Send,
        #[allow(unused)] func_name: &str,
    ) -> Result<(O, ProfileDuration), ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let stream_id = self.stream_id();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let device = self.device.clone();
        #[allow(unused_mut, reason = "Used in profile-tracy")]
        let mut result = self
            .device
            .exclusive(move || {
                // We first get mut access to the server to create a token.
                // Then we free the server, since it's going to be accessed in `func()`.
                let token =
                    match device.submit_blocking(move |server| server.start_profile(stream_id)) {
                        Ok(token) => match token {
                            Ok(token) => token,
                            Err(err) => return Err(err),
                        },
                        Err(err) => {
                            return Err(ServerError::Generic {
                                reason: alloc::format!(
                                    "Can't start profiling because of a call error: {err:?}"
                                ),
                                backtrace: BackTrace::capture(),
                            });
                        }
                    };

                // We execute `func()` which will recursively access the server.
                let out = func();

                // Finally we get the result from the token.
                let result = device
                    .submit_blocking(move |server| {
                        let mut result = server.end_profile(stream_id, token);

                        match result {
                            Ok(result) => Ok((out, result)),
                            Err(err) => Err(err),
                        }
                    })
                    .unwrap();

                Ok(result)
            })
            .unwrap()
            .map_err(|err| ProfileError::Unknown {
                reason: alloc::format!("{err:?}"),
                backtrace: BackTrace::capture(),
            })?;

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|(o, result)| {
                (
                    o,
                    ProfileDuration::new(
                        alloc::boxed::Box::pin(async move {
                            let ticks = result.resolve().await;
                            let start_duration =
                                ticks.start_duration_since(epoch).as_nanos() as i64;
                            let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                            gpu_span.upload_timestamp_start(start_duration);
                            gpu_span.upload_timestamp_end(end_duration);
                            ticks
                        }),
                        TimingMethod::Device,
                    ),
                )
            });
        }

        result
    }

    /// Transfer data from one client to another.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            level = "trace",
            skip(self, src_descriptor, alloc_descriptor, dst_server)
        )
    )]
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor,
        alloc_descriptor: MemoryLayoutDescriptor,
        dst_server: &Self,
    ) -> MemoryLayout {
        let shape = src_descriptor.shape.clone();
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        let read = self
            .device
            .submit_blocking(move |server| server.read(vec![src_descriptor], stream_id))
            .unwrap();

        let mut data = cubecl_common::future::block_on(read).unwrap();

        let (handle_base, mut layouts) = self
            .utilities
            .layout_policy
            .apply(stream_id, &[alloc_descriptor]);
        let alloc = layouts.remove(0);

        let dst_descriptor = CopyDescriptor {
            handle: handle_base.clone().binding(),
            shape,
            strides: alloc.strides.clone(),
            elem_size,
        };

        let (size, memory) = (handle_base.size(), handle_base.memory);
        dst_server.device.submit(move |server| {
            server.initialize_memory(memory, size, stream_id);
            server.write(vec![(dst_descriptor, data.remove(0))], stream_id)
        });

        alloc
    }

    /// Returns all vector sizes that are useful to perform optimal IO operations for the given
    /// element size.
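    ///
    /// For example, for 4-byte elements on a device with a 128-bit `load_width` and
    /// `max_vector_size >= 4`, this yields the candidates 4, 2, 1 (widest first):
    ///
    /// ```ignore
    /// for width in client.io_optimized_vector_sizes(4) {
    ///     // Try the widest vectorization first.
    /// }
    /// ```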
    pub fn io_optimized_vector_sizes(
        &self,
        size: usize,
    ) -> impl Iterator<Item = VectorSize> + Clone {
        let load_width = self.properties().hardware.load_width as usize;
        let size_bits = size * 8;
        let max = load_width / size_bits;
        let max = usize::min(self.properties().hardware.max_vector_size, max);

        // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
        let num_candidates = max.trailing_zeros() + 1;

        (0..num_candidates).map(|i| 2usize.pow(i)).rev()
    }
}
1055}