1use crate::{
2 DeviceProperties,
3 config::{TypeNameFormatLevel, type_name_format},
4 kernel::KernelMetadata,
5 logging::ProfileLevel,
6 memory_management::{MemoryAllocationMode, MemoryUsage},
7 server::{
8 Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
9 CopyDescriptor, CubeCount, Handle, IoError, ProfileError, ServerUtilities,
10 },
11 storage::{BindingResource, ComputeStorage},
12};
13use alloc::format;
14use alloc::sync::Arc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::ops::DerefMut;
18use cubecl_common::{
19 ExecutionMode,
20 bytes::{AllocationProperty, Bytes},
21 device::{Device, DeviceContext},
22 future::DynFut,
23 profile::ProfileDuration,
24};
25
26#[allow(unused)]
27use cubecl_common::profile::TimingMethod;
28use cubecl_common::stream_id::StreamId;
29
/// Handle used to issue reads, allocations and kernel launches against a [`ComputeServer`].
pub struct ComputeClient<Server: ComputeServer> {
    /// Device context the server is registered in; it is locked whenever the server is accessed.
    context: DeviceContext<Server>,
    /// Shared server utilities; `info`, `properties` and `logger` are read from here.
    utilities: Arc<ServerUtilities<Server>>,
    /// Stream pinned through `set_stream`; when `None`, `StreamId::current()` is used instead.
    stream_id: Option<StreamId>,
}
37
38impl<S> Clone for ComputeClient<S>
39where
40 S: ComputeServer,
41{
42 fn clone(&self) -> Self {
43 Self {
44 context: self.context.clone(),
45 utilities: self.utilities.clone(),
46 stream_id: self.stream_id,
47 }
48 }
49}
50
51impl<Server> ComputeClient<Server>
52where
53 Server: ComputeServer,
54{
55 pub fn info(&self) -> &Server::Info {
57 &self.utilities.info
58 }
59
60 pub fn init<D: Device>(device: &D, server: Server) -> Self {
62 let utilities = server.utilities();
63
64 let context = DeviceContext::<Server>::insert(device, server)
65 .expect("Can't create a new client on an already registered server");
66
67 Self {
68 context,
69 utilities,
70 stream_id: None,
71 }
72 }
73
74 pub fn load<D: Device>(device: &D) -> Self {
76 let context = DeviceContext::<Server>::locate(device);
77 let utilities = context.lock().utilities();
78
79 Self {
80 context,
81 utilities,
82 stream_id: None,
83 }
84 }
85
86 fn stream_id(&self) -> StreamId {
87 match self.stream_id {
88 Some(val) => val,
89 None => StreamId::current(),
90 }
91 }
92
93 pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
99 self.stream_id = Some(stream_id);
100 }
101
102 fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
103 let stream_id = self.stream_id();
104 let mut state = self.context.lock();
105 let fut = state.read(descriptors, stream_id);
106 core::mem::drop(state);
107 fut
108 }
109
110 pub fn read_async(&self, handles: Vec<Handle>) -> impl Future<Output = Vec<Bytes>> + Send {
112 let strides = [1];
113 let shapes = handles
114 .iter()
115 .map(|it| [it.size() as usize])
116 .collect::<Vec<_>>();
117 let bindings = handles
118 .into_iter()
119 .map(|it| it.binding())
120 .collect::<Vec<_>>();
121 let descriptors = bindings
122 .into_iter()
123 .zip(shapes.iter())
124 .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
125 .collect();
126
127 let fut = self.do_read(descriptors);
128
129 async move { fut.await.unwrap() }
130 }
131
132 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
138 cubecl_common::reader::read_sync(self.read_async(handles))
139 }
140
141 pub fn read_one(&self, handle: Handle) -> Bytes {
146 cubecl_common::reader::read_sync(self.read_async(vec![handle])).remove(0)
147 }
148
149 pub fn read_tensor_async(
151 &self,
152 descriptors: Vec<CopyDescriptor<'_>>,
153 ) -> impl Future<Output = Vec<Bytes>> + Send {
154 let fut = self.do_read(descriptors);
155
156 async move { fut.await.unwrap() }
157 }
158
159 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
172 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
173 }
174
175 pub fn read_one_tensor_async(
178 &self,
179 descriptor: CopyDescriptor<'_>,
180 ) -> impl Future<Output = Bytes> + Send {
181 let fut = self.read_tensor_async(vec![descriptor]);
182
183 async { fut.await.remove(0) }
184 }
185
186 pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
192 self.read_tensor(vec![descriptor]).remove(0)
193 }
194
195 pub fn get_resource(
197 &self,
198 binding: Binding,
199 ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
200 let stream_id = self.stream_id();
201 self.context.lock().get_resource(binding, stream_id)
202 }
203
204 fn do_create_from_slices(
205 &self,
206 descriptors: Vec<AllocationDescriptor<'_>>,
207 slices: Vec<&[u8]>,
208 ) -> Result<Vec<Allocation>, IoError> {
209 let mut state = self.context.lock();
210 let allocations = state.create(descriptors.clone(), self.stream_id())?;
211 let descriptors = descriptors
212 .into_iter()
213 .zip(allocations.iter())
214 .zip(slices)
215 .map(|((desc, alloc), data)| {
216 (
217 CopyDescriptor::new(
218 alloc.handle.clone().binding(),
219 desc.shape,
220 &alloc.strides,
221 desc.elem_size,
222 ),
223 Bytes::from_bytes_vec(data.to_vec()),
224 )
225 })
226 .collect();
227 let stream_id = self.stream_id();
228 state.write(descriptors, stream_id)?;
229 Ok(allocations)
230 }
231
232 fn do_create(
233 &self,
234 descriptors: Vec<AllocationDescriptor<'_>>,
235 mut data: Vec<Bytes>,
236 ) -> Result<Vec<Allocation>, IoError> {
237 self.staging(data.iter_mut(), true);
238
239 let mut state = self.context.lock();
240 let allocations = state.create(descriptors.clone(), self.stream_id())?;
241 let descriptors = descriptors
242 .into_iter()
243 .zip(allocations.iter())
244 .zip(data)
245 .map(|((desc, alloc), data)| {
246 (
247 CopyDescriptor::new(
248 alloc.handle.clone().binding(),
249 desc.shape,
250 &alloc.strides,
251 desc.elem_size,
252 ),
253 data,
254 )
255 })
256 .collect();
257 let stream_id = self.stream_id();
258 state.write(descriptors, stream_id)?;
259 Ok(allocations)
260 }
261
262 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
268 let shape = [slice.len()];
269
270 self.do_create_from_slices(
271 vec![AllocationDescriptor::new(
272 AllocationKind::Contiguous,
273 &shape,
274 1,
275 )],
276 vec![slice],
277 )
278 .unwrap()
279 .remove(0)
280 .handle
281 }
282
283 pub fn create(&self, data: Bytes) -> Handle {
285 let shape = [data.len()];
286
287 self.do_create(
288 vec![AllocationDescriptor::new(
289 AllocationKind::Contiguous,
290 &shape,
291 1,
292 )],
293 vec![data],
294 )
295 .unwrap()
296 .remove(0)
297 .handle
298 }
299
300 pub fn create_tensor_from_slice(
318 &self,
319 slice: &[u8],
320 shape: &[usize],
321 elem_size: usize,
322 ) -> Allocation {
323 self.do_create_from_slices(
324 vec![AllocationDescriptor::new(
325 AllocationKind::Optimized,
326 shape,
327 elem_size,
328 )],
329 vec![slice],
330 )
331 .unwrap()
332 .remove(0)
333 }
334
335 pub fn create_tensor(&self, bytes: Bytes, shape: &[usize], elem_size: usize) -> Allocation {
349 self.do_create(
350 vec![AllocationDescriptor::new(
351 AllocationKind::Optimized,
352 shape,
353 elem_size,
354 )],
355 vec![bytes],
356 )
357 .unwrap()
358 .remove(0)
359 }
360
361 pub fn create_tensors_from_slices(
369 &self,
370 descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
371 ) -> Vec<Allocation> {
372 let (descriptors, data) = descriptors.into_iter().unzip();
373
374 self.do_create_from_slices(descriptors, data).unwrap()
375 }
376
377 pub fn create_tensors(
381 &self,
382 descriptors: Vec<(AllocationDescriptor<'_>, Bytes)>,
383 ) -> Vec<Allocation> {
384 let (descriptors, data) = descriptors.into_iter().unzip();
385
386 self.do_create(descriptors, data).unwrap()
387 }
388
389 fn do_empty(
390 &self,
391 descriptors: Vec<AllocationDescriptor<'_>>,
392 ) -> Result<Vec<Allocation>, IoError> {
393 let mut state = self.context.lock();
394 state.create(descriptors, self.stream_id())
395 }
396
397 pub fn empty(&self, size: usize) -> Handle {
399 let shape = [size];
400 let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
401 self.do_empty(vec![descriptor]).unwrap().remove(0).handle
402 }
403
404 pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
407 let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
408 self.do_empty(vec![descriptor]).unwrap().remove(0)
409 }
410
411 pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
414 self.do_empty(descriptors).unwrap()
415 }
416
417 pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
420 where
421 I: Iterator<Item = &'a mut Bytes>,
422 {
423 let has_staging = |b: &Bytes| match b.property() {
424 AllocationProperty::Pinned => false,
425 AllocationProperty::File => true,
426 AllocationProperty::Native | AllocationProperty::Other => !file_only,
427 };
428
429 let mut to_be_updated = Vec::new();
430 let sizes = bytes
431 .filter_map(|b| match has_staging(b) {
432 true => {
433 let len = b.len();
434 to_be_updated.push(b);
435 Some(len)
436 }
437 false => None,
438 })
439 .collect::<Vec<usize>>();
440
441 if sizes.is_empty() {
442 return;
443 }
444
445 let stream_id = self.stream_id();
446 let mut context = self.context.lock();
447 let stagings = match context.staging(&sizes, stream_id) {
448 Ok(val) => val,
449 Err(_) => return,
450 };
451 core::mem::drop(context);
452
453 to_be_updated
454 .into_iter()
455 .zip(stagings)
456 .for_each(|(b, mut staging)| {
457 b.copy_into(&mut staging);
458 core::mem::swap(b, &mut staging);
459 });
460 }
461
462 pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
464 let shape = [src.size() as usize];
465 let src_descriptor = src.copy_descriptor(&shape, &[1], 1);
466
467 if Server::SERVER_COMM_ENABLED {
468 self.to_client_tensor(src_descriptor, dst_server)
469 } else {
470 let alloc_desc = AllocationDescriptor::new(
471 AllocationKind::Contiguous,
472 src_descriptor.shape,
473 src_descriptor.elem_size,
474 );
475 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
476 }
477 }
478
479 pub fn to_client_tensor(
483 &self,
484 src_descriptor: CopyDescriptor<'_>,
485 dst_server: &Self,
486 ) -> Allocation {
487 if Server::SERVER_COMM_ENABLED {
488 let guard = self.context.lock_device_kind();
489 let mut server_src = self.context.lock();
490 let mut server_dst = dst_server.context.lock();
491
492 let copied = Server::copy(
493 server_src.deref_mut(),
494 server_dst.deref_mut(),
495 src_descriptor,
496 self.stream_id(),
497 dst_server.stream_id(),
498 )
499 .unwrap();
500 core::mem::drop(server_src);
501 core::mem::drop(server_dst);
502 core::mem::drop(guard);
503 copied
504 } else {
505 let alloc_desc = AllocationDescriptor::new(
506 AllocationKind::Optimized,
507 src_descriptor.shape,
508 src_descriptor.elem_size,
509 );
510 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
511 }
512 }
513
514 #[track_caller]
515 unsafe fn execute_inner(
516 &self,
517 kernel: Server::Kernel,
518 count: CubeCount,
519 bindings: Bindings,
520 mode: ExecutionMode,
521 stream_id: StreamId,
522 ) {
523 let level = self.utilities.logger.profile_level();
524
525 match level {
526 None | Some(ProfileLevel::ExecutionOnly) => {
527 let mut state = self.context.lock();
528 let name = kernel.name();
529
530 unsafe { state.execute(kernel, count, bindings, mode, stream_id) };
531
532 if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
533 let info = type_name_format(name, TypeNameFormatLevel::Balanced);
534 self.utilities.logger.register_execution(info);
535 }
536 }
537 Some(level) => {
538 let name = kernel.name();
539 let kernel_id = kernel.id();
540 let profile = self
541 .profile(
542 || unsafe {
543 let mut state = self.context.lock();
544 state.execute(kernel, count.clone(), bindings, mode, stream_id)
545 },
546 name,
547 )
548 .unwrap();
549 let info = match level {
550 ProfileLevel::Full => {
551 format!("{name}: {kernel_id} CubeCount {count:?}")
552 }
553 _ => type_name_format(name, TypeNameFormatLevel::Balanced),
554 };
555 self.utilities.logger.register_profiled(info, profile);
556 }
557 }
558 }
559
560 #[track_caller]
562 pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
563 unsafe {
565 self.execute_inner(
566 kernel,
567 count,
568 bindings,
569 ExecutionMode::Checked,
570 self.stream_id(),
571 );
572 }
573 }
574
575 #[track_caller]
583 pub unsafe fn execute_unchecked(
584 &self,
585 kernel: Server::Kernel,
586 count: CubeCount,
587 bindings: Bindings,
588 ) {
589 unsafe {
591 self.execute_inner(
592 kernel,
593 count,
594 bindings,
595 ExecutionMode::Unchecked,
596 self.stream_id(),
597 );
598 }
599 }
600
601 pub fn flush(&self) {
603 let stream_id = self.stream_id();
604 self.context.lock().flush(stream_id);
605 }
606
607 pub fn sync(&self) -> DynFut<()> {
609 let stream_id = self.stream_id();
610 let mut state = self.context.lock();
611 let fut = state.sync(stream_id);
612 core::mem::drop(state);
613 self.utilities.logger.profile_summary();
614
615 fut
616 }
617
618 pub fn properties(&self) -> &DeviceProperties {
620 &self.utilities.properties
621 }
622
623 pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
627 Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
628 }
629
630 pub fn memory_usage(&self) -> MemoryUsage {
632 self.context.lock().memory_usage(self.stream_id())
633 }
634
635 pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
641 self.context.lock().allocation_mode(mode, self.stream_id())
642 }
643
644 pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
651 &self,
652 input: Input,
653 func: Func,
654 ) -> Output {
655 let device_guard = self.context.lock_device();
656
657 self.context
658 .lock()
659 .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());
660
661 let output = func(input);
662
663 self.context
664 .lock()
665 .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());
666
667 core::mem::drop(device_guard);
668
669 output
670 }
671
672 pub fn memory_cleanup(&self) {
677 self.context.lock().memory_cleanup(self.stream_id())
678 }
679
680 #[track_caller]
682 pub fn profile<O>(
683 &self,
684 func: impl FnOnce() -> O,
685 #[allow(unused)] func_name: &str,
686 ) -> Result<ProfileDuration, ProfileError> {
687 #[cfg(feature = "profile-tracy")]
690 let location = std::panic::Location::caller();
691
692 #[cfg(feature = "profile-tracy")]
694 let _span = tracy_client::Client::running().unwrap().span_alloc(
695 None,
696 func_name,
697 location.file(),
698 location.line(),
699 0,
700 );
701
702 let device_guard = self.context.lock_device();
703
704 #[cfg(feature = "profile-tracy")]
705 let gpu_span = if self.state.properties.timing_method == TimingMethod::Device {
706 let gpu_span = self
707 .state
708 .gpu_client
709 .span_alloc(func_name, "profile", location.file(), location.line())
710 .unwrap();
711 Some(gpu_span)
712 } else {
713 None
714 };
715
716 let token = self.context.lock().start_profile(self.stream_id());
717
718 let out = func();
719
720 let result = self.context.lock().end_profile(self.stream_id(), token);
721
722 core::mem::drop(out);
723
724 #[cfg(feature = "profile-tracy")]
725 if let Some(mut gpu_span) = gpu_span {
726 gpu_span.end_zone();
727 let epoch = self.state.epoch_time;
728 result = result.map(|result| {
730 ProfileDuration::new(
731 Box::pin(async move {
732 let ticks = result.resolve().await;
733 let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
734 let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
735 gpu_span.upload_timestamp_start(start_duration);
736 gpu_span.upload_timestamp_end(end_duration);
737 ticks
738 }),
739 TimingMethod::Device,
740 )
741 });
742 }
743 core::mem::drop(device_guard);
744
745 result
746 }
747
748 fn change_client_sync(
750 &self,
751 src_descriptor: CopyDescriptor<'_>,
752 alloc_descriptor: AllocationDescriptor<'_>,
753 dst_server: &Self,
754 ) -> Allocation {
755 let shape = src_descriptor.shape;
756 let elem_size = src_descriptor.elem_size;
757 let stream_id = self.stream_id();
758
759 let alloc = dst_server
761 .context
762 .lock()
763 .create(vec![alloc_descriptor], self.stream_id())
764 .unwrap()
765 .remove(0);
766
767 let read = self.context.lock().read(vec![src_descriptor], stream_id);
768 let mut data = cubecl_common::future::block_on(read).unwrap();
769
770 let desc_descriptor = CopyDescriptor {
771 binding: alloc.handle.clone().binding(),
772 shape,
773 strides: &alloc.strides,
774 elem_size,
775 };
776
777 dst_server
778 .context
779 .lock()
780 .write(vec![(desc_descriptor, data.remove(0))], stream_id)
781 .unwrap();
782
783 alloc
784 }
785}