1use crate::{
2 DeviceProperties,
3 channel::ComputeChannel,
4 config::{TypeNameFormatLevel, type_name_format},
5 kernel::KernelMetadata,
6 logging::{ProfileLevel, ServerLogger},
7 memory_management::MemoryUsage,
8 server::{Binding, BindingWithMeta, Bindings, ComputeServer, CubeCount, Handle, ProfileError},
9 storage::{BindingResource, ComputeStorage},
10};
11use alloc::format;
12use alloc::sync::Arc;
13use alloc::vec;
14use alloc::vec::Vec;
15use cubecl_common::{ExecutionMode, profile::ProfileDuration};
16
17#[allow(unused)]
18use cubecl_common::profile::TimingMethod;
19
20#[cfg(multi_threading)]
21use cubecl_common::stream_id::StreamId;
22
/// Client-side handle used to send requests to a compute server through a channel.
///
/// The `Clone` implementation clones the channel and shares the same state through an
/// [`Arc`], so every clone observes the same properties, info, logger and profiling lock.
pub struct ComputeClient<Server: ComputeServer, Channel> {
    /// Channel every request made through this client is forwarded to.
    channel: Channel,
    /// State shared between all clones of this client.
    state: Arc<ComputeClientState<Server>>,
}
29
/// State shared by every clone of a [`ComputeClient`].
#[derive(new)]
struct ComputeClientState<Server: ComputeServer> {
    /// Reference instant; profiled start/end times are converted to durations since this
    /// instant before being uploaded to Tracy.
    #[cfg(feature = "profile-tracy")]
    epoch_time: web_time::Instant,

    /// Tracy GPU context used to allocate GPU spans while profiling.
    #[cfg(feature = "profile-tracy")]
    gpu_client: tracy_client::GpuContext,

    /// Device properties, exposed through `ComputeClient::properties`.
    properties: DeviceProperties<Server::Feature>,
    /// Server-specific information, exposed through `ComputeClient::info`.
    info: Server::Info,
    /// Logger handed to the channel on every execution; also records executions and profiles.
    logger: Arc<ServerLogger>,

    /// Stream currently holding the profiling lock, or `None` when no profile is running.
    #[cfg(multi_threading)]
    current_profiling: spin::RwLock<Option<StreamId>>,
}
45
46impl<S, C> Clone for ComputeClient<S, C>
47where
48 S: ComputeServer,
49 C: ComputeChannel<S>,
50{
51 fn clone(&self) -> Self {
52 Self {
53 channel: self.channel.clone(),
54 state: self.state.clone(),
55 }
56 }
57}
58
59impl<Server, Channel> ComputeClient<Server, Channel>
60where
61 Server: ComputeServer,
62 Channel: ComputeChannel<Server>,
63{
64 pub fn info(&self) -> &Server::Info {
66 &self.state.info
67 }
68
/// Creates a new client from a channel, the device properties and the server info.
///
/// A default [`ServerLogger`] is created and stored in the shared state.
///
/// # Panics
///
/// With the `profile-tracy` feature enabled, panics if the Tracy GPU context cannot be
/// created.
pub fn new(
    channel: Channel,
    properties: DeviceProperties<Server::Feature>,
    info: Server::Info,
) -> Self {
    let logger = ServerLogger::default();

    // Start the Tracy client; it is used below to create the GPU context.
    #[cfg(feature = "profile-tracy")]
    let client = tracy_client::Client::start();

    let state = ComputeClientState {
        properties,
        logger: Arc::new(logger),
        // No stream is profiling when the client is created.
        #[cfg(multi_threading)]
        current_profiling: spin::RwLock::new(None),
        // The GPU context is labelled with the debug representation of `info`.
        #[cfg(feature = "profile-tracy")]
        gpu_client: client
            .clone()
            .new_gpu_context(
                Some(&format!("{info:?}")),
                tracy_client::GpuContextType::Invalid,
                // NOTE(review): presumably the initial GPU timestamp and the timestamp
                // period expected by tracy-client — confirm against its documentation.
                0,
                1.0,
            )
            .unwrap(),
        // Captured once; later used as the origin for timestamps uploaded to Tracy.
        #[cfg(feature = "profile-tracy")]
        epoch_time: web_time::Instant::now(),
        info,
    };

    Self {
        channel,
        state: Arc::new(state),
    }
}
108
109 pub fn read_async(
111 &self,
112 bindings: Vec<Binding>,
113 ) -> impl Future<Output = Vec<Vec<u8>>> + Send + use<Server, Channel> {
114 self.profile_guard();
115
116 self.channel.read(bindings)
117 }
118
119 pub fn read(&self, bindings: Vec<Binding>) -> Vec<Vec<u8>> {
125 self.profile_guard();
126
127 cubecl_common::reader::read_sync(self.channel.read(bindings))
128 }
129
130 pub fn read_one(&self, binding: Binding) -> Vec<u8> {
135 self.profile_guard();
136
137 cubecl_common::reader::read_sync(self.channel.read([binding].into())).remove(0)
138 }
139
140 pub async fn read_tensor_async(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
142 self.profile_guard();
143
144 self.channel.read_tensor(bindings).await
145 }
146
147 pub fn read_tensor(&self, bindings: Vec<BindingWithMeta>) -> Vec<Vec<u8>> {
160 self.profile_guard();
161
162 cubecl_common::reader::read_sync(self.channel.read_tensor(bindings))
163 }
164
165 pub async fn read_one_tensor_async(&self, binding: BindingWithMeta) -> Vec<u8> {
168 self.profile_guard();
169
170 self.channel.read_tensor([binding].into()).await.remove(0)
171 }
172
173 pub fn read_one_tensor(&self, binding: BindingWithMeta) -> Vec<u8> {
179 self.read_tensor(vec![binding]).remove(0)
180 }
181
182 pub fn get_resource(
184 &self,
185 binding: Binding,
186 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
187 self.profile_guard();
188
189 self.channel.get_resource(binding)
190 }
191
192 pub fn create(&self, data: &[u8]) -> Handle {
194 self.profile_guard();
195
196 self.channel.create(data)
197 }
198
199 pub fn create_tensor(
213 &self,
214 data: &[u8],
215 shape: &[usize],
216 elem_size: usize,
217 ) -> (Handle, Vec<usize>) {
218 self.channel
219 .create_tensors(vec![data], vec![shape], vec![elem_size])
220 .pop()
221 .unwrap()
222 }
223
224 pub fn create_tensors(
228 &self,
229 data: Vec<&[u8]>,
230 shapes: Vec<&[usize]>,
231 elem_size: Vec<usize>,
232 ) -> Vec<(Handle, Vec<usize>)> {
233 self.profile_guard();
234
235 self.channel.create_tensors(data, shapes, elem_size)
236 }
237
238 pub fn empty(&self, size: usize) -> Handle {
240 self.profile_guard();
241
242 self.channel.empty(size)
243 }
244
245 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> (Handle, Vec<usize>) {
248 self.channel
249 .empty_tensors(vec![shape], vec![elem_size])
250 .pop()
251 .unwrap()
252 }
253
254 pub fn empty_tensors(
257 &self,
258 shapes: Vec<&[usize]>,
259 elem_size: Vec<usize>,
260 ) -> Vec<(Handle, Vec<usize>)> {
261 self.profile_guard();
262
263 self.channel.empty_tensors(shapes, elem_size)
264 }
265
266 #[track_caller]
267 unsafe fn execute_inner(
268 &self,
269 kernel: Server::Kernel,
270 count: CubeCount,
271 bindings: Bindings,
272 mode: ExecutionMode,
273 ) {
274 let level = self.state.logger.profile_level();
275
276 match level {
277 None | Some(ProfileLevel::ExecutionOnly) => {
278 self.profile_guard();
279
280 let name = kernel.name();
281
282 unsafe {
283 self.channel
284 .execute(kernel, count, bindings, mode, self.state.logger.clone())
285 };
286
287 if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
288 let info = type_name_format(name, TypeNameFormatLevel::Balanced);
289 self.state.logger.register_execution(info);
290 }
291 }
292 Some(level) => {
293 let name = kernel.name();
294 let kernel_id = kernel.id();
295 let profile = self
296 .profile(
297 || unsafe {
298 self.channel.execute(
299 kernel,
300 count.clone(),
301 bindings,
302 mode,
303 self.state.logger.clone(),
304 )
305 },
306 name,
307 )
308 .unwrap();
309 let info = match level {
310 ProfileLevel::Full => {
311 format!("{name}: {kernel_id} CubeCount {count:?}")
312 }
313 _ => type_name_format(name, TypeNameFormatLevel::Balanced),
314 };
315 self.state.logger.register_profiled(info, profile);
316 }
317 }
318 }
319
320 #[track_caller]
322 pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
323 unsafe {
325 self.execute_inner(kernel, count, bindings, ExecutionMode::Checked);
326 }
327 }
328
329 #[track_caller]
337 pub unsafe fn execute_unchecked(
338 &self,
339 kernel: Server::Kernel,
340 count: CubeCount,
341 bindings: Bindings,
342 ) {
343 unsafe {
345 self.execute_inner(kernel, count, bindings, ExecutionMode::Unchecked);
346 }
347 }
348
349 pub fn flush(&self) {
351 self.profile_guard();
352
353 self.channel.flush();
354 }
355
356 pub async fn sync(&self) {
358 self.profile_guard();
359
360 self.channel.sync().await;
361 self.state.logger.profile_summary();
362 }
363
364 pub fn properties(&self) -> &DeviceProperties<Server::Feature> {
366 &self.state.properties
367 }
368
369 pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties<Server::Feature>> {
373 Arc::get_mut(&mut self.state).map(|state| &mut state.properties)
374 }
375
376 pub fn memory_usage(&self) -> MemoryUsage {
378 self.profile_guard();
379
380 self.channel.memory_usage()
381 }
382
383 pub fn memory_cleanup(&self) {
388 self.profile_guard();
389
390 self.channel.memory_cleanup()
391 }
392
/// Profiles the work performed by `func` and returns the measured duration.
///
/// The profile is started on the channel, `func` is run, and the profile is ended before
/// the value returned by `func` is dropped.
///
/// `func_name` is only used to label Tracy spans when the `profile-tracy` feature is
/// enabled. With `multi_threading`, the profiling lock is acquired before the profile
/// starts and released after it ends.
///
/// # Errors
///
/// Returns the error reported by the channel's `end_profile`, if any.
///
/// # Panics
///
/// With `profile-tracy`, panics if no Tracy client is running or if a GPU span cannot be
/// allocated. With `multi_threading`, releasing the lock panics if it is not held by the
/// stream that acquired it.
#[track_caller]
pub fn profile<O>(
    &self,
    func: impl FnOnce() -> O,
    #[allow(unused)] func_name: &str,
) -> Result<ProfileDuration, ProfileError> {
    // Thanks to `#[track_caller]`, this is the location of the code calling `profile`.
    #[cfg(feature = "profile-tracy")]
    let location = std::panic::Location::caller();

    // CPU-side span; it stays alive until the end of this function.
    #[cfg(feature = "profile-tracy")]
    let _span = tracy_client::Client::running().unwrap().span_alloc(
        None,
        func_name,
        location.file(),
        location.line(),
        0,
    );

    // Take the profiling lock; `None` means this stream already held it (nested profile).
    // NOTE(review): if `func` panics, `profile_release` below is never reached and the
    // lock stays held — a drop guard may be worth considering.
    #[cfg(multi_threading)]
    let stream_id = self.profile_acquire();

    // A GPU span is only allocated when the device itself provides the timings.
    #[cfg(feature = "profile-tracy")]
    let gpu_span = if self.state.properties.timing_method == TimingMethod::Device {
        let gpu_span = self
            .state
            .gpu_client
            .span_alloc(func_name, "profile", location.file(), location.line())
            .unwrap();
        Some(gpu_span)
    } else {
        None
    };

    let token = self.channel.start_profile();

    let out = func();

    // Only reassigned when `profile-tracy` is enabled.
    #[allow(unused_mut)]
    let mut result = self.channel.end_profile(token);

    // The output of `func` is kept alive until the profile has ended.
    core::mem::drop(out);

    #[cfg(feature = "profile-tracy")]
    if let Some(mut gpu_span) = gpu_span {
        gpu_span.end_zone();
        let epoch = self.state.epoch_time;
        // Wrap the duration so that resolving it also uploads the timestamps to Tracy.
        result = result.map(|result| {
            ProfileDuration::new(
                Box::pin(async move {
                    let ticks = result.resolve().await;
                    // Timestamps are uploaded as nanoseconds elapsed since `epoch`.
                    let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                    let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                    gpu_span.upload_timestamp_start(start_duration);
                    gpu_span.upload_timestamp_end(end_duration);
                    ticks
                }),
                TimingMethod::Device,
            )
        });
    }

    // No-op when `stream_id` is `None`, i.e. when this call did not take the lock itself.
    #[cfg(multi_threading)]
    self.profile_release(stream_id);

    result
}
464
465 #[cfg(not(multi_threading))]
466 fn profile_guard(&self) {}
467
468 #[cfg(multi_threading)]
469 fn profile_guard(&self) {
470 let current = self.state.current_profiling.read();
471
472 if let Some(current_stream_id) = current.as_ref() {
473 let stream_id = StreamId::current();
474
475 if current_stream_id == &stream_id {
476 return;
477 }
478
479 core::mem::drop(current);
480
481 loop {
482 std::thread::sleep(core::time::Duration::from_millis(10));
483
484 let current = self.state.current_profiling.read();
485 match current.as_ref() {
486 Some(current_stream_id) => {
487 if current_stream_id == &stream_id {
488 return;
489 }
490 }
491 None => {
492 return;
493 }
494 }
495 }
496 }
497 }
498
/// Takes the profiling lock for the current stream, waiting for other streams if needed.
///
/// Returns `Some(stream_id)` when the lock was taken by this call, and `None` when the
/// current stream already held it, in which case nothing has to be released.
/// While another stream holds the lock, it is polled every 10 milliseconds.
#[cfg(multi_threading)]
fn profile_acquire(&self) -> Option<StreamId> {
    let stream_id = StreamId::current();

    loop {
        {
            let mut current = self.state.current_profiling.write();

            match current.as_ref() {
                // Free: claim it for this stream.
                None => {
                    *current = Some(stream_id);
                    return Some(stream_id);
                }
                // Already owned by this stream: nothing to take.
                Some(owner) if *owner == stream_id => return None,
                // Owned by another stream: retry later.
                Some(_) => {}
            }
        } // The write lock is released here, before sleeping.

        std::thread::sleep(core::time::Duration::from_millis(10));
    }
}
536
/// Releases the profiling lock previously taken by `profile_acquire`.
///
/// Passing `None` is a no-op.
///
/// # Panics
///
/// Panics if the lock is held by a different stream than `stream_id`, or if it is not
/// held at all.
#[cfg(multi_threading)]
fn profile_release(&self, stream_id: Option<StreamId>) {
    let Some(stream_id) = stream_id else {
        return;
    };

    let mut current = self.state.current_profiling.write();

    match current.as_ref() {
        Some(owner) if *owner == stream_id => *current = None,
        Some(_) => panic!("Can't release a different profiling guard."),
        None => panic!("Can't release an empty profiling guard"),
    }
}
556}