cubecl_runtime/
client.rs

1use crate::{
2    DeviceProperties,
3    channel::ComputeChannel,
4    config::{TypeNameFormatLevel, type_name_format},
5    kernel::KernelMetadata,
6    logging::{ProfileLevel, ServerLogger},
7    memory_management::MemoryUsage,
8    server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle, ProfileError},
9    storage::{BindingResource, ComputeStorage},
10};
11use alloc::format;
12use alloc::sync::Arc;
13use alloc::vec;
14use alloc::vec::Vec;
15use cubecl_common::{ExecutionMode, profile::ProfileDuration};
16
17#[allow(unused)]
18use cubecl_common::profile::TimingMethod;
19
20#[cfg(multi_threading)]
21use cubecl_common::stream_id::StreamId;
22
23/// The ComputeClient is the entry point to require tasks from the ComputeServer.
24/// It should be obtained for a specific device via the Compute struct.
pub struct ComputeClient<Server: ComputeServer, Channel> {
    // Transport used to submit commands to the server; cloned per client handle.
    channel: Channel,
    // State shared by all clones of this client (properties, logger, profiling lock).
    state: Arc<ComputeClientState<Server>>,
}
29
// Shared, immutable-after-construction state behind the client's `Arc`.
// NOTE: `#[derive(new)]` generates a constructor from field order, so field
// order here is part of the (crate-internal) construction contract.
#[derive(new)]
struct ComputeClientState<Server: ComputeServer> {
    // Reference instant used to normalize device timestamps for tracy uploads.
    #[cfg(feature = "profile-tracy")]
    epoch_time: web_time::Instant,

    // Tracy GPU timing context associated with this client.
    #[cfg(feature = "profile-tracy")]
    gpu_client: tracy_client::GpuContext,

    properties: DeviceProperties<Server::Feature>,
    info: Server::Info,
    logger: Arc<ServerLogger>,

    // Stream currently holding exclusive profiling rights; `None` when no
    // profiling is in flight. Guarded by a spin lock (usable in no-std).
    #[cfg(multi_threading)]
    current_profiling: spin::RwLock<Option<StreamId>>,
}
45
46impl<S, C> Clone for ComputeClient<S, C>
47where
48    S: ComputeServer,
49    C: ComputeChannel<S>,
50{
51    fn clone(&self) -> Self {
52        Self {
53            channel: self.channel.clone(),
54            state: self.state.clone(),
55        }
56    }
57}
58
59impl<Server, Channel> ComputeClient<Server, Channel>
60where
61    Server: ComputeServer,
62    Channel: ComputeChannel<Server>,
63{
64    /// Get the info of the current backend.
65    pub fn info(&self) -> &Server::Info {
66        &self.state.info
67    }
68
    /// Create a new client from a communication `channel`, the device
    /// `properties` reported by the server, and backend-specific `info`.
    ///
    /// With the `profile-tracy` feature enabled, this also starts a tracy
    /// client and registers a GPU timing context; device timestamps are later
    /// normalized to nanoseconds relative to `epoch_time` (see the 0 epoch and
    /// 1.0 period passed below).
    pub fn new(
        channel: Channel,
        properties: DeviceProperties<Server::Feature>,
        info: Server::Info,
    ) -> Self {
        let logger = ServerLogger::default();

        // Start a tracy client if needed.
        #[cfg(feature = "profile-tracy")]
        let client = tracy_client::Client::start();

        let state = ComputeClientState {
            properties,
            logger: Arc::new(logger),
            #[cfg(multi_threading)]
            current_profiling: spin::RwLock::new(None),
            // Create the GPU client if needed.
            #[cfg(feature = "profile-tracy")]
            gpu_client: client
                .clone()
                .new_gpu_context(
                    Some(&format!("{info:?}")),
                    // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway).
                    tracy_client::GpuContextType::Invalid,
                    0,   // Timestamps are manually aligned to this epoch so start at 0.
                    1.0, // Timestamps are manually converted to be nanoseconds so period is 1.
                )
                .unwrap(),
            #[cfg(feature = "profile-tracy")]
            epoch_time: web_time::Instant::now(),
            info,
        };

        Self {
            channel,
            state: Arc::new(state),
        }
    }
108
109    /// Given bindings, returns owned resources as bytes.
110    pub fn read_async(
111        &self,
112        bindings: Vec<Binding>,
113    ) -> impl Future<Output = Vec<Vec<u8>>> + Send + use<Server, Channel> {
114        self.profile_guard();
115
116        self.channel.read(bindings)
117    }
118
119    /// Given bindings, returns owned resources as bytes.
120    ///
121    /// # Remarks
122    ///
123    /// Panics if the read operation fails.
124    pub fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
125        self.profile_guard();
126
127        cubecl_common::reader::read_sync(self.channel.read(bindings))
128    }
129
130    /// Given a binding, returns owned resource as bytes.
131    ///
132    /// # Remarks
133    /// Panics if the read operation fails.
134    pub fn read_one(&self, binding: Binding) -> Vec<u8> {
135        self.profile_guard();
136
137        cubecl_common::reader::read_sync(self.channel.read([binding].into())).remove(0)
138    }
139
140    /// Given bindings, returns owned resources as bytes.
141    pub async fn read_tensor_async(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
142        self.profile_guard();
143
144        self.channel.read_tensor(bindings).await
145    }
146
147    /// Given bindings, returns owned resources as bytes.
148    ///
149    /// # Remarks
150    ///
151    /// Panics if the read operation fails.
152    ///
153    /// The tensor must be in the same layout as created by the runtime, or more strict.
154    /// Contiguous tensors are always fine, strided tensors are only ok if the stride is similar to
155    /// the one created by the runtime (i.e. padded on only the last dimension). A way to check
156    /// stride compatibility on the runtime will be added in the future.
157    ///
158    /// Also see [ComputeClient::create_tensor].
159    pub fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
160        self.profile_guard();
161
162        cubecl_common::reader::read_sync(self.channel.read_tensor(bindings))
163    }
164
165    /// Given a binding, returns owned resource as bytes.
166    /// See [ComputeClient::read_tensor]
167    pub async fn read_one_tensor_async(&self, binding: BindingWithMeta) -> Vec<u8> {
168        self.profile_guard();
169
170        self.channel.read_tensor([binding].into()).await.remove(0)
171    }
172
173    /// Given a binding, returns owned resource as bytes.
174    ///
175    /// # Remarks
176    /// Panics if the read operation fails.
177    /// See [ComputeClient::read_tensor]
178    pub fn read_one_tensor(&self, binding: BindingWithMeta) -> Vec<u8> {
179        self.read_tensor(vec![binding]).remove(0)
180    }
181
182    /// Given a resource handle, returns the storage resource.
183    pub fn get_resource(
184        &self,
185        binding: Binding,
186    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
187        self.profile_guard();
188
189        self.channel.get_resource(binding)
190    }
191
192    /// Given a resource, stores it and returns the resource handle.
193    pub fn create(&self, data: &[u8]) -> Handle {
194        self.profile_guard();
195
196        self.channel.create(data)
197    }
198
199    /// Given a resource and shape, stores it and returns the tensor handle and strides.
200    /// This may or may not return contiguous strides. The layout is up to the runtime, and care
201    /// should be taken when indexing.
202    ///
203    /// Currently the tensor may either be contiguous (most runtimes), or "pitched", to use the CUDA
204    /// terminology. This means the last (contiguous) dimension is padded to fit a certain alignment,
205    /// and the strides are adjusted accordingly. This can make memory accesses significantly faster
206    /// since all rows are aligned to at least 16 bytes (the maximum load width), meaning the GPU
207    /// can load as much data as possible in a single instruction. It may be aligned even more to
208    /// also take cache lines into account.
209    ///
210    /// However, the stride must be taken into account when indexing and reading the tensor
211    /// (also see [ComputeClient::read_tensor]).
212    pub fn create_tensor(
213        &self,
214        data: &[u8],
215        shape: &[usize],
216        elem_size: usize,
217    ) -> (Handle, Vec<usize>) {
218        self.channel
219            .create_tensors(vec![data], vec![shape], vec![elem_size])
220            .pop()
221            .unwrap()
222    }
223
224    /// Reserves all `shapes` in a single storage buffer, copies the corresponding `data` into each
225    /// handle, and returns the handles for them.
226    /// See [ComputeClient::create_tensor]
227    pub fn create_tensors(
228        &self,
229        data: Vec<&[u8]>,
230        shapes: Vec<&[usize]>,
231        elem_size: Vec<usize>,
232    ) -> Vec<(Handle, Vec<usize>)> {
233        self.profile_guard();
234
235        self.channel.create_tensors(data, shapes, elem_size)
236    }
237
238    /// Reserves `size` bytes in the storage, and returns a handle over them.
239    pub fn empty(&self, size: usize) -> Handle {
240        self.profile_guard();
241
242        self.channel.empty(size)
243    }
244
245    /// Reserves `shape` in the storage, and returns a tensor handle for it.
246    /// See [ComputeClient::create_tensor]
247    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
248        self.channel
249            .empty_tensors(vec![shape], vec![elem_size])
250            .pop()
251            .unwrap()
252    }
253
254    /// Reserves all `shapes` in a single storage buffer, and returns the handles for them.
255    /// See [ComputeClient::create_tensor]
256    pub fn empty_tensors(
257        &self,
258        shapes: Vec<&[usize]>,
259        elem_size: Vec<usize>,
260    ) -> Vec<(Handle, Vec<usize>)> {
261        self.profile_guard();
262
263        self.channel.empty_tensors(shapes, elem_size)
264    }
265
    /// Dispatch `kernel` over `count` cubes with the given `bindings`, routed
    /// through the logger's profile level.
    ///
    /// With no profiling (or execution-only logging) the kernel is submitted
    /// directly; otherwise execution is wrapped in [`Self::profile`] and the
    /// measured duration is registered with the logger.
    ///
    /// # Safety
    /// With `ExecutionMode::Unchecked`, the caller must guarantee the kernel
    /// performs no out-of-bounds accesses and always terminates.
    #[track_caller]
    unsafe fn execute_inner(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
    ) {
        let level = self.state.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                self.profile_guard();

                // Grab the name before `kernel` is moved into `execute`.
                let name = kernel.name();

                unsafe {
                    self.channel
                        .execute(kernel, count, bindings, mode, self.state.logger.clone())
                };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.state.logger.register_execution(info);
                }
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                // `profile` acquires the profiling guard itself, so it is not
                // taken here.
                let profile = self
                    .profile(
                        || unsafe {
                            self.channel.execute(
                                kernel,
                                count.clone(),
                                bindings,
                                mode,
                                self.state.logger.clone(),
                            )
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.state.logger.register_profiled(info, profile);
            }
        }
    }
319
320    /// Executes the `kernel` over the given `bindings`.
321    #[track_caller]
322    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
323        // SAFETY: Using checked execution mode.
324        unsafe {
325            self.execute_inner(kernel, count, bindings, ExecutionMode::Checked);
326        }
327    }
328
329    /// Executes the `kernel` over the given `bindings` without performing any bound checks.
330    ///
331    /// # Safety
332    ///
333    /// To ensure this is safe, you must verify your kernel:
334    /// - Has no out-of-bound reads and writes that can happen.
335    /// - Has no infinite loops that might never terminate.
336    #[track_caller]
337    pub unsafe fn execute_unchecked(
338        &self,
339        kernel: Server::Kernel,
340        count: CubeCount,
341        bindings: Bindings,
342    ) {
343        // SAFETY: Caller has to uphold kernel being safe.
344        unsafe {
345            self.execute_inner(kernel, count, bindings, ExecutionMode::Unchecked);
346        }
347    }
348
349    /// Flush all outstanding commands.
350    pub fn flush(&self) {
351        self.profile_guard();
352
353        self.channel.flush();
354    }
355
356    /// Wait for the completion of every task in the server.
357    pub async fn sync(&self) {
358        self.profile_guard();
359
360        self.channel.sync().await;
361        self.state.logger.profile_summary();
362    }
363
364    /// Get the features supported by the compute server.
365    pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
366        &self.state.properties
367    }
368
369    /// # Warning
370    ///
371    /// For private use only.
372    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties<Server::Feature>> {
373        Arc::get_mut(&mut self.state).map(|state| &mut state.properties)
374    }
375
376    /// Get the current memory usage of this client.
377    pub fn memory_usage(&self) -> MemoryUsage {
378        self.profile_guard();
379
380        self.channel.memory_usage()
381    }
382
383    /// Ask the client to release memory that it can release.
384    ///
385    /// Nb: Results will vary on what the memory allocator deems beneficial,
386    /// so it's not guaranteed any memory is freed.
387    pub fn memory_cleanup(&self) {
388        self.profile_guard();
389
390        self.channel.memory_cleanup()
391    }
392
    /// Measure the execution time of some inner operations.
    ///
    /// Acquires exclusive profiling rights for the current stream (when
    /// `multi_threading` is enabled), brackets `func` between the channel's
    /// `start_profile`/`end_profile`, and returns the resulting duration.
    /// With `profile-tracy`, a CPU span is opened for the call and — when the
    /// server reports device timing — a GPU span is uploaded once the device
    /// timestamps resolve.
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<ProfileDuration, ProfileError> {
        // Get the outer caller. For execute() this points straight to the
        // cube kernel. For general profiling it points to whoever calls profile.
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Make a CPU span. If the server has system profiling this is all you need.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        // Take exclusive profiling ownership; `None` means this stream already owns it.
        #[cfg(multi_threading)]
        let stream_id = self.profile_acquire();

        // Only open a GPU span when the server measures time on-device.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.state.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .state
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.channel.start_profile();

        let out = func();

        #[allow(unused_mut)]
        let mut result = self.channel.end_profile(token);

        // Drop the closure's output only after the profile window is closed.
        core::mem::drop(out);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.state.epoch_time;
            // Add in the work to upload the timestamp data.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }

        // Release profiling ownership (no-op when `stream_id` is `None`).
        #[cfg(multi_threading)]
        self.profile_release(stream_id);

        result
    }
464
    // No-op: without `multi_threading` there is no concurrent stream whose
    // profiling this call would need to wait on.
    #[cfg(not(multi_threading))]
    fn profile_guard(&self) {}
467
    // Block the current thread while a *different* stream is profiling, so
    // unrelated work does not pollute an in-flight profile. Returns
    // immediately when no profiling is active or when this stream owns the
    // profiling guard; otherwise polls the guard every 10 ms.
    #[cfg(multi_threading)]
    fn profile_guard(&self) {
        let current = self.state.current_profiling.read();

        if let Some(current_stream_id) = current.as_ref() {
            let stream_id = StreamId::current();

            // The profiling stream itself may keep issuing commands.
            if current_stream_id == &stream_id {
                return;
            }

            // Release the read lock before sleeping so the profiling stream
            // can acquire the write lock and finish.
            core::mem::drop(current);

            loop {
                std::thread::sleep(core::time::Duration::from_millis(10));

                let current = self.state.current_profiling.read();
                match current.as_ref() {
                    Some(current_stream_id) => {
                        // NOTE(review): ownership moving to this stream while
                        // waiting presumably means the profile now covers us.
                        if current_stream_id == &stream_id {
                            return;
                        }
                    }
                    None => {
                        return;
                    }
                }
            }
        }
    }
498
    // Acquire exclusive profiling rights for the current stream.
    //
    // Returns `Some(stream_id)` when this call took ownership (and must later
    // call `profile_release` with it), or `None` when this stream already owns
    // the guard (re-entrant profiling — nothing to release). If another stream
    // owns the guard, polls every 10 ms until it is released.
    #[cfg(multi_threading)]
    fn profile_acquire(&self) -> Option<StreamId> {
        let stream_id = StreamId::current();
        let mut current = self.state.current_profiling.write();

        match current.as_mut() {
            Some(current_stream_id) => {
                // Re-entrant: this stream already holds the guard.
                if current_stream_id == &stream_id {
                    return None;
                }

                // Release the write lock before sleeping so the owner can release.
                core::mem::drop(current);

                loop {
                    std::thread::sleep(core::time::Duration::from_millis(10));

                    let mut current = self.state.current_profiling.write();

                    match current.as_mut() {
                        Some(current_stream_id) => {
                            if current_stream_id == &stream_id {
                                return None;
                            }
                        }
                        None => {
                            // Guard was released; claim it for this stream.
                            *current = Some(stream_id);
                            return Some(stream_id);
                        }
                    }
                }
            }
            None => {
                *current = Some(stream_id);
                Some(stream_id)
            }
        }
    }
536
537    #[cfg(multi_threading)]
538    fn profile_release(&self, stream_id: Option<StreamId>) {
539        let stream_id = match stream_id {
540            Some(val) => val,
541            None => return, // No releasing
542        };
543        let mut current = self.state.current_profiling.write();
544
545        match current.as_mut() {
546            Some(current_stream_id) => {
547                if current_stream_id != &stream_id {
548                    panic!("Can't release a different profiling guard.");
549                } else {
550                    *current = None;
551                }
552            }
553            None => panic!("Can't release an empty profiling guard"),
554        }
555    }
556}