// cubecl_runtime/client.rs

use crate::{
    DeviceProperties,
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, Handle, IoError, ProfileError, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::DerefMut;
use cubecl_common::{
    ExecutionMode,
    bytes::Bytes,
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;
/// The ComputeClient is the entry point for requesting tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
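///
/// A minimal sketch of obtaining and using a client, assuming some runtime's
/// server and device types (hypothetical names; not compiled as a doctest):
///
/// ```rust,ignore
/// // `MyServer` and `MyDevice` stand in for a concrete runtime's types.
/// let device = MyDevice::default();
/// let client = ComputeClient::<MyServer>::load(&device);
///
/// let handle = client.create(&[0u8; 64]);
/// let bytes = client.read_one(handle);
/// assert_eq!(bytes.len(), 64);
/// ```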
pub struct ComputeClient<Server: ComputeServer> {
    context: DeviceContext<Server>,
    utilities: Arc<ServerUtilities<Server>>,
    stream_id: Option<StreamId>,
}

impl<S> Clone for ComputeClient<S>
where
    S: ComputeServer,
{
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<Server> ComputeClient<Server>
where
    Server: ComputeServer,
{
    /// Get the info of the current backend.
    pub fn info(&self) -> &Server::Info {
        &self.utilities.info
    }

    /// Create a new client with a new server.
    pub fn init<D: Device>(device: &D, server: Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    /// Load the client for the given device.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Set the stream on which the current client operates.
    ///
    /// # Safety
    ///
    /// This is highly unsafe and should probably only be used by the CubeCL/Burn projects for now.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

    /// Given handles, returns the owned resources as bytes.
    pub fn read_async(&self, handles: Vec<Handle>) -> impl Future<Output = Vec<Bytes>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

    /// Given handles, returns the owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
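    ///
    /// A minimal usage sketch, assuming `client` is a client for some runtime
    /// (not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let a = client.create(&[1u8, 2, 3, 4]);
    /// let b = client.create(&[5u8, 6, 7, 8]);
    /// // Results come back in the same order as the handles.
    /// let bytes = client.read(vec![a, b]);
    /// assert_eq!(&bytes[0][..], &[1, 2, 3, 4]);
    /// ```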
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
    }

    /// Given a handle, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle])).remove(0)
    }

    /// Given copy descriptors, returns the owned resources as bytes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Vec<Bytes>> + Send {
        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

    /// Given copy descriptors, returns the owned resources as bytes.
    ///
    /// # Remarks
    ///
    /// Panics if the read operation fails.
    ///
    /// The tensor must be in the same layout as created by the runtime, or a stricter one.
    /// Contiguous tensors are always fine; strided tensors are only ok if the strides match
    /// what the runtime would have produced (i.e. padded only on the last dimension). A way
    /// to check stride compatibility on the runtime will be added in the future.
    ///
    /// Also see [ComputeClient::create_tensor].
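    ///
    /// A sketch of a strided read that reuses the strides chosen by the runtime
    /// at allocation time (assumes `client`; not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let shape = [2, 8];
    /// let data = vec![0u8; 2 * 8 * 4];
    /// let alloc = client.create_tensor(&data, &shape, 4);
    /// // Build a descriptor from the allocation's own strides.
    /// let desc = alloc.handle.copy_descriptor(&shape, &alloc.strides, 4);
    /// let bytes = client.read_tensor(vec![desc]).remove(0);
    /// ```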
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    /// See [ComputeClient::read_tensor]
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Bytes> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { fut.await.remove(0) }
    }

    /// Given a copy descriptor, returns the owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    /// See [ComputeClient::read_tensor]
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        data: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

    /// Given a slice of data, stores it and returns the resource handle.
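    ///
    /// A minimal round trip, assuming `client` is a client for some runtime
    /// (sketch; not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let handle = client.create(&[0u8; 16]);
    /// assert_eq!(handle.size(), 16);
    /// let bytes = client.read_one(handle);
    /// assert_eq!(bytes.len(), 16);
    /// ```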
    pub fn create(&self, data: &[u8]) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

    /// Given a resource and shape, stores it and returns the tensor handle and strides.
    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
    /// should be taken when indexing.
    ///
    /// Currently the tensor may be either contiguous (most runtimes) or "pitched", to use the CUDA
    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
    /// can load as much data as possible in a single instruction. It may be aligned even more to
    /// also take cache lines into account.
    ///
    /// However, the strides must be taken into account when indexing and reading the tensor
    /// (also see [ComputeClient::read_tensor]).
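    ///
    /// A sketch of how a pitched runtime might lay out a small tensor; the exact
    /// strides are runtime-dependent (assumes `client`; not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let data = [0u8; 3 * 3 * 4]; // 3x3 tensor of 4-byte elements
    /// let alloc = client.create_tensor(&data, &[3, 3], 4);
    /// // A contiguous runtime returns strides [3, 1]. A pitched runtime may pad
    /// // each 12-byte row to 16 bytes (4 elements) and return [4, 1] instead.
    /// println!("strides = {:?}", alloc.strides);
    /// ```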
    pub fn create_tensor(&self, data: &[u8], shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
    }

    /// Reserves storage for all the given descriptors in a single storage buffer, copies the
    /// corresponding data into each allocation, and returns the allocations.
    /// See [ComputeClient::create_tensor]
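    ///
    /// Sketch: allocating two tensors in one call (assumes `client`; not
    /// compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let a = vec![0u8; 4 * 4];
    /// let b = vec![0u8; 2 * 4];
    /// let descriptors = vec![
    ///     (AllocationDescriptor::new(AllocationKind::Optimized, &[4], 4), a.as_slice()),
    ///     (AllocationDescriptor::new(AllocationKind::Optimized, &[2], 4), b.as_slice()),
    /// ];
    /// let allocations = client.create_tensors(descriptors);
    /// assert_eq!(allocations.len(), 2);
    /// ```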
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

    /// Reserves `shape` in the storage, and returns a tensor handle for it.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Reserves storage for all the given descriptors in a single storage buffer, and returns
    /// the allocations for them.
    /// See [ComputeClient::create_tensor]
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

    /// Transfers data from this client to another.
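    ///
    /// Sketch: copying a buffer to a second device's client; assumes two clients
    /// of the same server type (not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let src = client_a.create(&[1u8, 2, 3, 4]);
    /// let alloc = client_a.to_client(src, &client_b);
    /// let bytes = client_b.read_one(alloc.handle);
    /// assert_eq!(&bytes[..], &[1, 2, 3, 4]);
    /// ```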
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    /// Transfers data from this client to another.
    ///
    /// Make sure the source descriptor can be read in a contiguous manner.
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if Server::SERVER_COMM_ENABLED {
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap()
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    unsafe fn execute_inner(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                unsafe { state.execute(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                let profile = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.execute(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
            }
        }
    }

    /// Executes the `kernel` over the given `bindings`.
    #[track_caller]
    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
        // SAFETY: Using checked execution mode.
        unsafe {
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            );
        }
    }

    /// Executes the `kernel` over the given `bindings` without performing any bounds checks.
    ///
    /// # Safety
    ///
    /// To ensure this is safe, you must verify that your kernel:
    /// - Performs no out-of-bounds reads or writes.
    /// - Contains no infinite loops.
    #[track_caller]
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) {
        // SAFETY: Caller has to uphold kernel being safe.
        unsafe {
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            );
        }
    }

    /// Flush all outstanding commands.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id);
    }

    /// Wait for the completion of every task in the server.
    pub fn sync(&self) -> DynFut<()> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

    /// Get the properties of the compute server, including supported features.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// # Warning
    ///
    /// For private use only.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Get the current memory usage of this client.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

    /// Change the memory allocation mode.
    ///
    /// # Safety
    ///
    /// This function isn't thread-safe and might create memory leaks.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

    /// Use a persistent memory strategy to execute the provided function.
    ///
    /// # Notes
    ///
    /// - This memory strategy is beneficial for storing model parameters and similar
    ///   long-lived data.
    /// - You can call [Self::memory_cleanup()] if you want to free persistent memory.
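    ///
    /// Sketch: keeping weights in persistent memory; `weight_bytes` is a
    /// hypothetical `Vec<u8>` (not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let weights = client.memory_persistent_allocation(weight_bytes, |bytes: Vec<u8>| {
    ///     // Allocations made while the closure runs use the persistent strategy.
    ///     client.create(&bytes)
    /// });
    /// ```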
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

    /// Ask the client to release any memory that it can.
    ///
    /// Note: results depend on what the memory allocator deems beneficial,
    /// so it isn't guaranteed that any memory is freed.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

    /// Measure the execution time of some inner operations.
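    ///
    /// Sketch: timing a block of enqueued work (assumes `client`; not compiled
    /// as a doctest):
    ///
    /// ```rust,ignore
    /// let profile = client
    ///     .profile(|| {
    ///         // Enqueue the kernels to be measured here.
    ///     }, "my_operation")
    ///     .unwrap();
    /// // Resolving may wait on the device for the timestamps.
    /// let ticks = cubecl_common::future::block_on(profile.resolve());
    /// ```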
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<ProfileDuration, ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        // `result` is only mutated when the tracy feature rewraps it below.
        #[cfg_attr(not(feature = "profile-tracy"), allow(unused_mut))]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        core::mem::drop(out);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        result
    }

    /// Transfers data from one client to another by reading from the source and
    /// writing to the destination, blocking on the transfer.
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        // Allocate the destination buffer.
        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let data = cubecl_common::future::block_on(read).unwrap();

        let dst_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(dst_descriptor, &data[0])], stream_id)
            .unwrap();

        alloc
    }
}