use crate::{
    DeviceProperties,
    config::{TypeNameFormatLevel, type_name_format},
    kernel::KernelMetadata,
    logging::ProfileLevel,
    memory_management::{MemoryAllocationMode, MemoryUsage},
    server::{
        Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer,
        CopyDescriptor, CubeCount, Handle, IoError, ProfileError, ServerUtilities,
    },
    storage::{BindingResource, ComputeStorage},
};
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::DerefMut;
use cubecl_common::{
    ExecutionMode,
    bytes::Bytes,
    device::{Device, DeviceContext},
    future::DynFut,
    profile::ProfileDuration,
};

#[allow(unused)]
use cubecl_common::profile::TimingMethod;
use cubecl_common::stream_id::StreamId;

/// Client used to communicate with a [`ComputeServer`] for a given device.
///
/// Cloning a client is cheap: the clone shares the underlying [`DeviceContext`] and
/// [`ServerUtilities`] with every other client created for the same device.
pub struct ComputeClient<Server: ComputeServer> {
    context: DeviceContext<Server>,
    utilities: Arc<ServerUtilities<Server>>,
    stream_id: Option<StreamId>,
}

impl<S> Clone for ComputeClient<S>
where
    S: ComputeServer,
{
    fn clone(&self) -> Self {
        Self {
            context: self.context.clone(),
            utilities: self.utilities.clone(),
            stream_id: self.stream_id,
        }
    }
}

impl<Server> ComputeClient<Server>
where
    Server: ComputeServer,
{
    /// Returns the server-specific information exposed by the runtime.
    pub fn info(&self) -> &Server::Info {
        &self.utilities.info
    }

    /// Creates a new client by registering the given `server` for the `device`.
    ///
    /// # Panics
    ///
    /// Panics if a server is already registered for this device.
    pub fn init<D: Device>(device: &D, server: Server) -> Self {
        let utilities = server.utilities();

        let context = DeviceContext::<Server>::insert(device, server)
            .expect("Can't create a new client on an already registered server");

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    /// Loads the client registered for the given `device`.
    pub fn load<D: Device>(device: &D) -> Self {
        let context = DeviceContext::<Server>::locate(device);
        let utilities = context.lock().utilities();

        Self {
            context,
            utilities,
            stream_id: None,
        }
    }

    fn stream_id(&self) -> StreamId {
        match self.stream_id {
            Some(val) => val,
            None => StreamId::current(),
        }
    }

    /// Pins this client to the given stream instead of resolving the stream from the
    /// current thread.
    ///
    /// # Safety
    ///
    /// The caller must ensure that operations submitted through this client remain
    /// correctly ordered with respect to the chosen stream.
    pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
        self.stream_id = Some(stream_id);
    }

    fn do_read(&self, descriptors: Vec<CopyDescriptor<'_>>) -> DynFut<Result<Vec<Bytes>, IoError>> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.read(descriptors, stream_id);
        core::mem::drop(state);
        fut
    }

    /// Given a list of handles, returns a future that resolves to their data once the read
    /// completes.
    pub fn read_async(&self, handles: Vec<Handle>) -> impl Future<Output = Vec<Bytes>> + Send {
        let strides = [1];
        let shapes = handles
            .iter()
            .map(|it| [it.size() as usize])
            .collect::<Vec<_>>();
        let bindings = handles
            .into_iter()
            .map(|it| it.binding())
            .collect::<Vec<_>>();
        let descriptors = bindings
            .into_iter()
            .zip(shapes.iter())
            .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1))
            .collect();

        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

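    /// Given a list of handles, reads their data synchronously, blocking the current thread
    /// until the data is available.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient` for some runtime and uses only methods
    /// defined on this type):
    ///
    /// ```ignore
    /// let a = client.create(&[1u8, 2, 3, 4]);
    /// let b = client.create(&[5u8, 6, 7, 8]);
    /// // Blocks until both buffers have been read back.
    /// let outputs = client.read(vec![a, b]);
    /// assert_eq!(outputs.len(), 2);
    /// ```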
    pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_async(handles))
    }

    /// Given a single handle, reads its data synchronously, blocking until the data is
    /// available.
    pub fn read_one(&self, handle: Handle) -> Bytes {
        cubecl_common::reader::read_sync(self.read_async(vec![handle])).remove(0)
    }

    /// Given a list of copy descriptors, returns a future that resolves to the described
    /// tensor data once the read completes.
    pub fn read_tensor_async(
        &self,
        descriptors: Vec<CopyDescriptor<'_>>,
    ) -> impl Future<Output = Vec<Bytes>> + Send {
        let fut = self.do_read(descriptors);

        async move { fut.await.unwrap() }
    }

    /// Given a list of copy descriptors, reads the described tensor data synchronously,
    /// blocking until the data is available.
    pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor<'_>>) -> Vec<Bytes> {
        cubecl_common::reader::read_sync(self.read_tensor_async(descriptors))
    }

    /// Given a single copy descriptor, returns a future that resolves to the described
    /// tensor data once the read completes.
    pub fn read_one_tensor_async(
        &self,
        descriptor: CopyDescriptor<'_>,
    ) -> impl Future<Output = Bytes> + Send {
        let fut = self.read_tensor_async(vec![descriptor]);

        async { fut.await.remove(0) }
    }

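    /// Given a single copy descriptor, reads the described tensor data synchronously,
    /// blocking until the data is available.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient` and builds the descriptor with
    /// `Handle::copy_descriptor`, as done elsewhere in this file):
    ///
    /// ```ignore
    /// let data = vec![0u8; 16]; // a 2x2 tensor of 4-byte elements
    /// let alloc = client.create_tensor(&data, &[2, 2], 4);
    /// let descriptor = alloc.handle.copy_descriptor(&[2, 2], &alloc.strides, 4);
    /// let bytes = client.read_one_tensor(descriptor);
    /// ```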
    pub fn read_one_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
        self.read_tensor(vec![descriptor]).remove(0)
    }

    /// Returns the underlying storage resource for the given binding.
    pub fn get_resource(
        &self,
        binding: Binding,
    ) -> BindingResource<<Server::Storage as ComputeStorage>::Resource> {
        let stream_id = self.stream_id();
        self.context.lock().get_resource(binding, stream_id)
    }

    /// Allocates a buffer for each descriptor and writes the provided data into it.
    fn do_create(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
        data: Vec<&[u8]>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        let allocations = state.create(descriptors.clone(), self.stream_id())?;
        let descriptors = descriptors
            .into_iter()
            .zip(allocations.iter())
            .zip(data)
            .map(|((desc, alloc), data)| {
                (
                    CopyDescriptor::new(
                        alloc.handle.clone().binding(),
                        desc.shape,
                        &alloc.strides,
                        desc.elem_size,
                    ),
                    data,
                )
            })
            .collect();
        let stream_id = self.stream_id();
        state.write(descriptors, stream_id)?;
        Ok(allocations)
    }

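    /// Allocates a contiguous buffer on the server, copies `data` into it, and returns a
    /// [`Handle`] to the new buffer.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient`):
    ///
    /// ```ignore
    /// let handle = client.create(&[1u8, 2, 3, 4]);
    /// let bytes = client.read_one(handle);
    /// ```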
    pub fn create(&self, data: &[u8]) -> Handle {
        let shape = [data.len()];

        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Contiguous,
                &shape,
                1,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
        .handle
    }

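    /// Allocates a tensor with a server-optimized layout, copies `data` into it, and returns
    /// the resulting [`Allocation`] (handle plus strides).
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient` and treats the data as a 2x2 tensor of
    /// 4-byte elements):
    ///
    /// ```ignore
    /// let data = [0u8; 16];
    /// let alloc = client.create_tensor(&data, &[2, 2], 4);
    /// // The strides chosen by the server describe the actual memory layout.
    /// let strides = &alloc.strides;
    /// ```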
    pub fn create_tensor(&self, data: &[u8], shape: &[usize], elem_size: usize) -> Allocation {
        self.do_create(
            vec![AllocationDescriptor::new(
                AllocationKind::Optimized,
                shape,
                elem_size,
            )],
            vec![data],
        )
        .unwrap()
        .remove(0)
    }

    /// Allocates a batch of tensors with server-optimized layouts and copies the paired data
    /// into each of them.
    pub fn create_tensors(
        &self,
        descriptors: Vec<(AllocationDescriptor<'_>, &[u8])>,
    ) -> Vec<Allocation> {
        let (descriptors, data) = descriptors.into_iter().unzip();

        self.do_create(descriptors, data).unwrap()
    }

    fn do_empty(
        &self,
        descriptors: Vec<AllocationDescriptor<'_>>,
    ) -> Result<Vec<Allocation>, IoError> {
        let mut state = self.context.lock();
        state.create(descriptors, self.stream_id())
    }

    /// Allocates an uninitialized contiguous buffer of `size` bytes and returns its
    /// [`Handle`].
    pub fn empty(&self, size: usize) -> Handle {
        let shape = [size];
        let descriptor = AllocationDescriptor::new(AllocationKind::Contiguous, &shape, 1);
        self.do_empty(vec![descriptor]).unwrap().remove(0).handle
    }

    /// Allocates an uninitialized tensor with a server-optimized layout and returns the
    /// resulting [`Allocation`].
    pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation {
        let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size);
        self.do_empty(vec![descriptor]).unwrap().remove(0)
    }

    /// Allocates a batch of uninitialized tensors, one per descriptor.
    pub fn empty_tensors(&self, descriptors: Vec<AllocationDescriptor<'_>>) -> Vec<Allocation> {
        self.do_empty(descriptors).unwrap()
    }

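    /// Copies the buffer behind `src` from this client to `dst_server`, returning the new
    /// allocation on the destination.
    ///
    /// When the server supports server-to-server communication, the copy is performed
    /// directly; otherwise the data is read back synchronously and written to the
    /// destination.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes
    /// `client_a` and `client_b` are initialized `ComputeClient`s for two devices of the
    /// same runtime):
    ///
    /// ```ignore
    /// let src = client_a.create(&[1u8, 2, 3, 4]);
    /// let alloc = client_a.to_client(src, &client_b);
    /// let bytes = client_b.read_one(alloc.handle);
    /// ```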
    pub fn to_client(&self, src: Handle, dst_server: &Self) -> Allocation {
        let shape = [src.size() as usize];
        let src_descriptor = src.copy_descriptor(&shape, &[1], 1);

        if Server::SERVER_COMM_ENABLED {
            self.to_client_tensor(src_descriptor, dst_server)
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Contiguous,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    /// Copies the tensor described by `src_descriptor` from this client to `dst_server`,
    /// returning the new allocation on the destination.
    ///
    /// Uses direct server-to-server communication when the server supports it, and falls
    /// back to a synchronous read/write otherwise.
    pub fn to_client_tensor(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        if Server::SERVER_COMM_ENABLED {
            let mut server_src = self.context.lock();
            let mut server_dst = dst_server.context.lock();

            Server::copy(
                server_src.deref_mut(),
                server_dst.deref_mut(),
                src_descriptor,
                self.stream_id(),
                dst_server.stream_id(),
            )
            .unwrap()
        } else {
            let alloc_desc = AllocationDescriptor::new(
                AllocationKind::Optimized,
                src_descriptor.shape,
                src_descriptor.elem_size,
            );
            self.change_client_sync(src_descriptor, alloc_desc, dst_server)
        }
    }

    #[track_caller]
    unsafe fn execute_inner(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
        mode: ExecutionMode,
        stream_id: StreamId,
    ) {
        let level = self.utilities.logger.profile_level();

        match level {
            None | Some(ProfileLevel::ExecutionOnly) => {
                let mut state = self.context.lock();
                let name = kernel.name();

                unsafe { state.execute(kernel, count, bindings, mode, stream_id) };

                if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
                    let info = type_name_format(name, TypeNameFormatLevel::Balanced);
                    self.utilities.logger.register_execution(info);
                }
            }
            Some(level) => {
                let name = kernel.name();
                let kernel_id = kernel.id();
                // Profile the execution and log the result with the configured detail level.
                let profile = self
                    .profile(
                        || unsafe {
                            let mut state = self.context.lock();
                            state.execute(kernel, count.clone(), bindings, mode, stream_id)
                        },
                        name,
                    )
                    .unwrap();
                let info = match level {
                    ProfileLevel::Full => {
                        format!("{name}: {kernel_id} CubeCount {count:?}")
                    }
                    _ => type_name_format(name, TypeNameFormatLevel::Balanced),
                };
                self.utilities.logger.register_profiled(info, profile);
            }
        }
    }

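    /// Launches the `kernel` with the given `count` and `bindings` in checked execution
    /// mode.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `kernel`
    /// is a compiled `Server::Kernel`, `count` a `CubeCount`, and `bindings` the `Bindings`
    /// prepared for that kernel):
    ///
    /// ```ignore
    /// client.execute(kernel, count, bindings);
    /// // Wait for the submitted work to complete before reading results.
    /// cubecl_common::future::block_on(client.sync());
    /// ```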
    #[track_caller]
    pub fn execute(&self, kernel: Server::Kernel, count: CubeCount, bindings: Bindings) {
        unsafe {
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Checked,
                self.stream_id(),
            );
        }
    }

    /// Launches the `kernel` with the given `count` and `bindings` in unchecked execution
    /// mode.
    ///
    /// # Safety
    ///
    /// Without the validation of [`ExecutionMode::Checked`], the caller must guarantee that
    /// the kernel is sound for the provided bindings and cube count.
    #[track_caller]
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: CubeCount,
        bindings: Bindings,
    ) {
        unsafe {
            self.execute_inner(
                kernel,
                count,
                bindings,
                ExecutionMode::Unchecked,
                self.stream_id(),
            );
        }
    }

    /// Flushes all pending work on this client's current stream to the server.
    pub fn flush(&self) {
        let stream_id = self.stream_id();
        self.context.lock().flush(stream_id);
    }

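    /// Returns a future that resolves once all work submitted on this client's current
    /// stream has completed on the server.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient`):
    ///
    /// ```ignore
    /// client.flush();
    /// cubecl_common::future::block_on(client.sync());
    /// ```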
    pub fn sync(&self) -> DynFut<()> {
        let stream_id = self.stream_id();
        let mut state = self.context.lock();
        let fut = state.sync(stream_id);
        core::mem::drop(state);
        self.utilities.logger.profile_summary();

        fut
    }

    /// Returns the properties of the device this client is bound to.
    pub fn properties(&self) -> &DeviceProperties {
        &self.utilities.properties
    }

    /// Returns a mutable reference to the device properties, or `None` if the utilities are
    /// currently shared with other clients.
    pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
        Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
    }

    /// Returns the memory usage of this client's current stream.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.context.lock().memory_usage(self.stream_id())
    }

    /// Changes the memory allocation mode for this client's current stream.
    ///
    /// # Safety
    ///
    /// The selected mode changes how allocations are managed and reused; see
    /// [`MemoryAllocationMode`] and make sure the mode is used correctly for the workload.
    pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.context.lock().allocation_mode(mode, self.stream_id())
    }

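    /// Runs `func` with the persistent memory allocation mode enabled, restoring the
    /// automatic mode afterwards.
    ///
    /// Allocations performed inside `func` are made in [`MemoryAllocationMode::Persistent`],
    /// which the server can use for long-lived buffers.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient`):
    ///
    /// ```ignore
    /// let weight_bytes = vec![0u8; 1024];
    /// let weights = client.memory_persistent_allocation(weight_bytes, |data: Vec<u8>| {
    ///     client.create(&data)
    /// });
    /// ```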
    pub fn memory_persistent_allocation<Input, Output, Func: Fn(Input) -> Output>(
        &self,
        input: Input,
        func: Func,
    ) -> Output {
        // Keep the device locked while the allocation mode is temporarily switched to
        // `Persistent`.
        let device_guard = self.context.lock_device();

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id());

        let output = func(input);

        self.context
            .lock()
            .allocation_mode(MemoryAllocationMode::Auto, self.stream_id());

        core::mem::drop(device_guard);

        output
    }

    /// Asks the server to release memory that is no longer needed on this client's stream.
    pub fn memory_cleanup(&self) {
        self.context.lock().memory_cleanup(self.stream_id())
    }

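    /// Measures the time spent on the server executing the work submitted by `func`,
    /// returning a [`ProfileDuration`] that can be resolved asynchronously.
    ///
    /// A minimal usage sketch (not part of the original documentation; it assumes `client`
    /// is an already initialized `ComputeClient` and that `kernel`, `count` and `bindings`
    /// are already prepared):
    ///
    /// ```ignore
    /// let profile = client
    ///     .profile(|| client.execute(kernel, count, bindings), "my_kernel")
    ///     .unwrap();
    /// // `resolve` is the same call the tracy integration below awaits internally.
    /// let ticks = cubecl_common::future::block_on(profile.resolve());
    /// ```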
    #[track_caller]
    pub fn profile<O>(
        &self,
        func: impl FnOnce() -> O,
        #[allow(unused)] func_name: &str,
    ) -> Result<ProfileDuration, ProfileError> {
        #[cfg(feature = "profile-tracy")]
        let location = std::panic::Location::caller();

        // Open a CPU span for the profiled region when tracy support is enabled.
        #[cfg(feature = "profile-tracy")]
        let _span = tracy_client::Client::running().unwrap().span_alloc(
            None,
            func_name,
            location.file(),
            location.line(),
            0,
        );

        let device_guard = self.context.lock_device();

        // Open a GPU span as well when device-side timing is used. The tracy-specific
        // fields (`gpu_client`, `epoch_time`) are assumed to live on `ServerUtilities`.
        #[cfg(feature = "profile-tracy")]
        let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
            let gpu_span = self
                .utilities
                .gpu_client
                .span_alloc(func_name, "profile", location.file(), location.line())
                .unwrap();
            Some(gpu_span)
        } else {
            None
        };

        let token = self.context.lock().start_profile(self.stream_id());

        let out = func();

        #[allow(unused_mut)]
        let mut result = self.context.lock().end_profile(self.stream_id(), token);

        core::mem::drop(out);

        #[cfg(feature = "profile-tracy")]
        if let Some(mut gpu_span) = gpu_span {
            gpu_span.end_zone();
            let epoch = self.utilities.epoch_time;
            // Upload the resolved device timestamps to tracy once the profile is ready.
            result = result.map(|result| {
                ProfileDuration::new(
                    Box::pin(async move {
                        let ticks = result.resolve().await;
                        let start_duration = ticks.start_duration_since(epoch).as_nanos() as i64;
                        let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
                        gpu_span.upload_timestamp_start(start_duration);
                        gpu_span.upload_timestamp_end(end_duration);
                        ticks
                    }),
                    TimingMethod::Device,
                )
            });
        }
        core::mem::drop(device_guard);

        result
    }

    /// Fallback path for cross-client copies: reads the source data back to the host and
    /// writes it into a fresh allocation on the destination server.
    fn change_client_sync(
        &self,
        src_descriptor: CopyDescriptor<'_>,
        alloc_descriptor: AllocationDescriptor<'_>,
        dst_server: &Self,
    ) -> Allocation {
        let shape = src_descriptor.shape;
        let elem_size = src_descriptor.elem_size;
        let stream_id = self.stream_id();

        let alloc = dst_server
            .context
            .lock()
            .create(vec![alloc_descriptor], self.stream_id())
            .unwrap()
            .remove(0);

        let read = self.context.lock().read(vec![src_descriptor], stream_id);
        let data = cubecl_common::future::block_on(read).unwrap();

        let dst_descriptor = CopyDescriptor {
            binding: alloc.handle.clone().binding(),
            shape,
            strides: &alloc.strides,
            elem_size,
        };

        dst_server
            .context
            .lock()
            .write(vec![(dst_descriptor, &data[0])], stream_id)
            .unwrap();

        alloc
    }
}